diff --git a/src/Juvix/Compiler/Core/Extra/Recursors/RMap.hs b/src/Juvix/Compiler/Core/Extra/Recursors/RMap.hs index 13ec8f174b..6c223c3019 100644 --- a/src/Juvix/Compiler/Core/Extra/Recursors/RMap.hs +++ b/src/Juvix/Compiler/Core/Extra/Recursors/RMap.hs @@ -66,37 +66,39 @@ rmapG coll f = go mempty 0 (coll ^. cEmpty) go binders bl c n = f recur c n where recur :: c -> [BinderChange] -> Node -> m Node - recur c' changes = \case - NVar v -> - return $ - maybe - (NVar v {_varIndex = getBinderIndex bl lvl}) - (shift (bl - lvl)) - mnode - where - (lvl, mnode) = BL.lookup (v ^. varIndex) binders - n' -> - let ni = destruct n' - in reassembleDetails ni <$> mapM goChild (ni ^. nodeChildren) - where - goChild :: NodeChild -> m Node - goChild ch = - let (bl', rbs, rbs') = - foldl' - ( \(l, bs, acc) chg -> case chg of - BCAdd k -> (l + k, bs, acc) - BCKeep b -> (l + 1, b : bs, (l, Nothing) : acc) - BCRemove (BinderRemove b node) -> (l, b : bs, (l, Just node) : acc) - ) - (bl, [], []) - changes - cbs = map (\l -> (l, Nothing)) [bl' .. bl' + ch ^. childBindersNum - 1] - binders' = BL.prependRev cbs (BL.prepend rbs' binders) - in go - binders' - (bl' + ch ^. childBindersNum) - ((coll ^. cCollect) (length rbs + ch ^. childBindersNum, reverse rbs ++ ch ^. childBinders) c') - (ch ^. childNode) + recur c' changes = + let (bl', rbs, rbs') = + foldl' + ( \(l, bs, acc) chg -> case chg of + BCAdd k -> (l + k, bs, acc) + BCKeep b -> (l + 1, b : bs, (l, Nothing) : acc) + BCRemove (BinderRemove b node) -> (l, b : bs, (l, Just node) : acc) + ) + (bl, [], []) + changes + binders' = BL.prepend rbs' binders + in \case + NVar v -> + return $ + maybe + (NVar v {_varIndex = getBinderIndex bl' lvl}) + (shift (bl' - lvl)) + mnode + where + (lvl, mnode) = BL.lookup (v ^. varIndex) binders' + n' -> + let ni = destruct n' + in reassembleDetails ni <$> mapM goChild (ni ^. nodeChildren) + where + goChild :: NodeChild -> m Node + goChild ch = + let cbs = map (\l -> (l, Nothing)) [bl' .. bl' + ch ^. childBindersNum - 1] + binders'' = BL.prependRev cbs binders' + in go + binders'' + (bl' + ch ^. childBindersNum) + ((coll ^. cCollect) (length rbs + ch ^. childBindersNum, reverse rbs ++ ch ^. childBinders) c') + (ch ^. childNode) rmapEmbedIden :: ((([BinderChange] -> Node -> Node) -> Node -> Node)) -> (([BinderChange] -> Node -> Identity Node) -> Node -> Identity Node) rmapEmbedIden f recur = return . f (\bcs -> runIdentity . recur bcs) diff --git a/test/Core/Recursor/RMap.hs b/test/Core/Recursor/RMap.hs index cf11eb5296..318ffc6452 100644 --- a/test/Core/Recursor/RMap.hs +++ b/test/Core/Recursor/RMap.hs @@ -28,6 +28,9 @@ tests = addLambdas [ ( mkLambdas' [mkTypeInteger', mkTypeInteger'] (mkBuiltinApp' OpIntAdd [mkVar' 1, mkVar' 0]), mkLambdas' [mkDynamic', mkTypeInteger', mkDynamic', mkTypeInteger'] (mkBuiltinApp' OpIntAdd [mkVar' 2, mkVar' 0]) + ), + ( mkLambdas' [mkTypeInteger', mkTypeInteger'] (mkVar' 1), + mkLambdas' [mkDynamic', mkTypeInteger', mkDynamic', mkTypeInteger'] (mkVar' 2) ) ], UnitTest @@ -38,6 +41,9 @@ tests = ), ( mkLambda' mkTypeInteger' $ mkLet' mkTypeInteger' (mkVar' 0) $ mkLambda' mkTypeInteger' $ mkBuiltinApp' OpIntAdd [mkVar' 2, mkVar' 0], mkLets' [(mkTypeInteger', mkConstant' (ConstInteger 0)), (mkTypeInteger', mkVar' 0), (mkTypeInteger', mkConstant' (ConstInteger 2))] (mkBuiltinApp' OpIntAdd [mkVar' 2, mkVar' 0]) + ), + ( mkLambda' mkTypeInteger' $ mkVar' 0, + mkLet' mkTypeInteger' (mkConstant' (ConstInteger 0)) (mkVar' 0) ) ], UnitTest @@ -45,6 +51,15 @@ tests = removeLambdas [ ( mkLambdas' [mkTypeInteger', mkTypeInteger', mkTypeInteger', mkTypeInteger'] (mkBuiltinApp' OpIntAdd [mkBuiltinApp' OpIntAdd [mkVar' 3, mkVar' 1], mkVar' 2]), mkLambdas' [mkTypeInteger', mkTypeInteger'] (mkBuiltinApp' OpIntAdd [mkBuiltinApp' OpIntAdd [mkVar' 1, mkVar' 0], mkVar' 1]) + ), + ( mkLambdas' [mkTypeInteger', mkTypeInteger'] (mkVar' 0), + mkLambda' mkTypeInteger' (mkVar' 0) + ), + ( mkLambdas' [mkTypeInteger', mkTypeInteger', mkTypeInteger', mkTypeInteger'] (mkVar' 2), + mkLambdas' [mkTypeInteger', mkTypeInteger'] (mkVar' 1) + ), + ( mkLambdas' [mkTypeInteger', mkTypeInteger', mkTypeInteger'] (mkVar' 1), + mkLambdas' [mkTypeInteger', mkTypeInteger'] (mkVar' 1) ) ] ]