{-# LANGUAGE TupleSections #-}

-- | Free-variable analysis, argument lifting and variable renaming over the
-- 'AExp' / 'CExp' syntax trees imported from "Types".
module Lift where

import Data.Functor.Foldable
import Data.List
import Types

-- | Iterate @f@ starting from @a@ until the result stops changing and return
-- that fixed point. Does not terminate if @f@ never reaches one.
converge :: Eq a => (a -> a) -> a -> a
converge f a =
  let a' = f a
   in if a' == a then a else converge f a'

-- | Sample lambda term for experimenting with the functions in this module
-- (presumably a recursive factorial; note that its body calls a function
-- named @factorial@).
factorial :: AExp
factorial =
  Lam
    ["n"]
    ( Let
        "m"
        (Call "factorial" [ASub (Ident "n") (Number 1)])
        ( If
            (ALt (Ident "n") (Number 2))
            (FC (Atom (Number 1)))
            (FC (Atom (AMul (Ident "m") (Ident "n"))))
        )
    )

-- | Sample 'CExp' that binds @n@ three times in nested @let@s, each new
-- binding referring to the previous @n@. Useful for exercising shadowing
-- (see 'substitute').
three :: CExp
three =
  Let "n" (Atom (Number 1)) $
    Let "n" (Atom (AAdd (Ident "n") (Number 1))) $
      Let "n" (Atom (AAdd (Ident "n") (Number 1))) $
        FC (Atom (Ident "n"))

-- | Append the free variables of a lambda to its parameter list. Each free
-- variable is appended once, even when it occurs several times in the body
-- (the raw free-variable list may contain duplicates). Names used as call
-- targets count as free (see 'findVarsFC'). Non-lambda expressions are
-- returned unchanged.
liftArgs :: AExp -> AExp
liftArgs lam@(Lam args body) = Lam (args ++ nub (cata findVarsAExp lam)) body
liftArgs rest = rest

-- F-algebras computing the free variables of an expression. The lists they
-- produce may contain duplicates.

-- | Free variables of one 'CExp' layer whose children are already reduced to
-- their free-variable lists. A @let@ binding scopes over its body only, so
-- the bound name is removed from the body's free variables but not from those
-- of the bound expression, which may refer to an outer variable of the same
-- name (as in 'three').
findVars :: CExpF [String] -> [String]
findVars (LetF ident fc rest) = filter (/= ident) rest ++ findVarsFC fc
findVars (IfF cond t e) = t ++ e ++ cata findVarsAExp cond
findVars (FCF fc) = findVarsFC fc

-- | Free variables of a 'Funcall'. For a 'Call', the callee name itself is
-- included along with the free variables of every argument.
findVarsFC :: Funcall -> [String]
findVarsFC (Atom aexp) = cata findVarsAExp aexp
findVarsFC (Call fname args) = fname : concatMap (cata findVarsAExp) args

-- | Free variables of one 'AExp' layer. A lambda removes every occurrence of
-- its parameters from the free variables of its body; any other layer just
-- concatenates the results of its children.
findVarsAExp :: AExpF [String] -> [String]
findVarsAExp (IdentF name) = [name]
findVarsAExp (LamF args cexp) = filter (`notElem` args) (cata findVars cexp)
findVarsAExp other = concat other

-- | @compareNames n m@ appends an underscore to @n@ when it equals @m@ and
-- returns @n@ unchanged otherwise; used to rename occurrences of @m@.
compareNames :: String -> String -> String
compareNames n m
  | n == m = n ++ "_"
  | otherwise = n

-- | Rename every occurrence of @n@ in an 'AExp' layer to @n ++ "_"@,
-- including lambda parameters and the bodies of nested lambdas. Intended to
-- be used with 'hoist', e.g. @hoist (replaceVarsAExp "n") factorial@.
replaceVarsAExp :: String -> AExpF a -> AExpF a
replaceVarsAExp n (IdentF m) = IdentF (compareNames m n)
replaceVarsAExp n (LamF args body) =
  LamF (fmap (`compareNames` n) args) (cata replaceVarsCExp body n)
replaceVarsAExp _ rest = rest

-- | Algebra for 'cata': given a 'CExp' layer whose children are already
-- renaming functions, build a function that renames every occurrence of the
-- given name (@let@ binders included) to that name with an underscore
-- appended.
replaceVarsCExp :: CExpF (String -> CExp) -> String -> CExp
replaceVarsCExp (LetF name fc restF) env =
  Let (compareNames name env) (replaceVarsFC env fc) (restF env)
replaceVarsCExp (IfF cond thenF elseF) env =
  If (hoist (replaceVarsAExp env) cond) (thenF env) (elseF env)
replaceVarsCExp (FCF fc) env = FC (replaceVarsFC env fc)

-- | Rename every occurrence of @n@ in a 'Funcall' (callee name included) to
-- @n ++ "_"@.
replaceVarsFC :: String -> Funcall -> Funcall
replaceVarsFC n (Atom aexp) = Atom (hoist (replaceVarsAExp n) aexp)
replaceVarsFC n (Call name args) =
  Call (compareNames name n) (fmap (hoist (replaceVarsAExp n)) args)

-- | Step function over an 'AExp' layer that carries the names bound so far.
-- Lambda parameters that are already bound get an underscore appended, the
-- same renaming is applied to the body, and the body is then processed with
-- 'subVarsCExp' with the (renamed) parameters added to the bound set. Other
-- layers pass the bound set down unchanged.
subVarsAExp :: ([String], AExpF a) -> AExpF ([String], a)
subVarsAExp (env, LamF args body) =
  let toReplace = intersect env args
      newArgs = fmap (\x -> if x `elem` toReplace then x ++ "_" else x) args
      newBody = foldl' (cata replaceVarsCExp) body toReplace
   in LamF newArgs (cotransverse subVarsCExp (newArgs ++ env, [], newBody))
subVarsAExp (env, rest) = fmap (env,) rest

-- | Step function for 'cotransverse' over a 'CExp'. The first list holds the
-- names bound so far; the second is a queue of names to be renamed, with the
-- newest entry at the head (entries are prepended).
subVarsCExp :: ([String], [String], CExpF a) -> CExpF ([String], [String], a)
subVarsCExp (env, queue, LetF name fc rest) =
  let -- foldr compares against the queue from its end, i.e. the oldest entry
      -- first (make sure to do oldest first!); every match appends another
      -- underscore to the binder and records the name it had before.
      (newName, oldNames) =
        foldr
          (\m (n, ns) -> if n == m then (n ++ "_", n : ns) else (n, ns))
          (name, [])
          queue
      -- repeatedly replace variables in the function call
      newFC = foldr replaceVarsFC fc (oldNames \\ [newName])
      next
        | name `elem` env = (env, newName : queue, rest)
        | otherwise = (name : env, name : queue, rest)
   in LetF newName newFC next
subVarsCExp (env, queue, IfF cond thenPart elsePart) =
  IfF
    (foldr (\x c -> hoist (replaceVarsAExp x) c) cond queue)
    (env, queue, thenPart)
    (env, queue, elsePart)
subVarsCExp (env, queue, FCF fc) = FCF (foldr replaceVarsFC fc queue)

-- | Variant of 'cotransverse' whose step function is not required to be
-- polymorphic in the recursive position, so it may inspect and replace the
-- children it receives (presumably the reason for the name). Each seed is
-- 'project'ed one layer before the step function runs, and the result is
-- unfolded with 'ana'.
unsafeCotransverse ::
  (Corecursive t, Recursive a, Functor f) =>
  (f (Base a a) -> Base t (f a)) ->
  f a ->
  t
unsafeCotransverse n = ana (n . fmap project)

-- | Step for 'unsafeCotransverse' carrying the names bound by enclosing
-- @let@s. When a @let@ rebinds a name already in scope, the binder gets an
-- underscore appended and every occurrence of the old name in the body
-- (nested binders included) is renamed to match. Occurrences of the old name
-- in the bound expression are left alone, so it still refers to the outer
-- binding; occurrences of the new name there get a further underscore. Other
-- layers pass the bound set down unchanged (the condition of an @if@ is not
-- traversed).
subVarsCExp' :: ([String], CExpF CExp) -> CExpF ([String], CExp)
subVarsCExp' (env, LetF name fc body)
  | name `elem` env =
      let newName = name ++ "_"
          newFC = replaceVarsFC newName fc
          newBody = cata replaceVarsCExp body name
       in LetF newName newFC (newName : env, newBody)
  | otherwise = LetF name fc (name : env, body)
subVarsCExp' (env, rest) = fmap (env,) rest

-- | Rename @let@ binders that rebind a name already bound by an enclosing
-- @let@ (by appending underscores, see 'subVarsCExp''), starting from an
-- empty scope.
substitute :: CExp -> CExp
substitute = unsafeCotransverse subVarsCExp' . ([],)