diff --git a/src/Juvix/Compiler/Core/Extra/Base.hs b/src/Juvix/Compiler/Core/Extra/Base.hs index 106e4bd6ad..a3a2540a8a 100644 --- a/src/Juvix/Compiler/Core/Extra/Base.hs +++ b/src/Juvix/Compiler/Core/Extra/Base.hs @@ -1,9 +1,9 @@ module Juvix.Compiler.Core.Extra.Base where import Data.Functor.Identity -import Data.List qualified as List import Juvix.Compiler.Core.Info qualified as Info import Juvix.Compiler.Core.Language +import Polysemy.Input {------------------------------------------------------------------------} {- Node constructors -} @@ -245,8 +245,8 @@ unfoldLambdas' = first length . unfoldLambdas {------------------------------------------------------------------------} {- functions on Pattern -} -getBinderPatternInfos :: Pattern -> [Binder] -getBinderPatternInfos = go [] +getPatternBinders :: Pattern -> [Binder] +getPatternBinders = reverse . go [] where go :: [Binder] -> Pattern -> [Binder] go acc = \case @@ -255,7 +255,7 @@ getBinderPatternInfos = go [] PatWildcard {} -> acc getPatternInfos :: Pattern -> [Info] -getPatternInfos = go [] +getPatternInfos = reverse . go [] where go :: [Info] -> Pattern -> [Info] go acc = \case @@ -342,6 +342,12 @@ twoChildren f i _ ch = case ch of [l, r] -> f i l r _ -> impossible +{-# INLINE threeChildren #-} +threeChildren :: (Info -> NodeChild -> NodeChild -> NodeChild -> Node) -> Reassemble +threeChildren f i _ ch = case ch of + [a, b, c] -> f i a b c + _ -> impossible + {-# INLINE manyChildren #-} manyChildren :: (Info -> [NodeChild] -> Node) -> Reassemble manyChildren f i _ = f i @@ -360,6 +366,10 @@ twoManyChildrenI f i is = \case (x : y : xs) -> f i is x y xs _ -> impossible +{-# INLINE input' #-} +input' :: Members '[Input (Maybe a)] r => Sem r a +input' = fmap fromJust input + -- | Destruct a node into NodeDetails. This is an ugly internal function used to -- implement more high-level accessors and recursors. destruct :: Node -> NodeDetails @@ -410,16 +420,21 @@ destruct = \case NodeDetails { _nodeInfo = i, _nodeSubinfos = [], - _nodeChildren = [oneBinder bi b], - _nodeReassemble = oneChild $ \i' ch' -> mkLambda i' (hd (ch' ^. childBinders)) (ch' ^. childNode) + _nodeChildren = [noBinders (bi ^. binderType), oneBinder bi b], + _nodeReassemble = twoChildren $ \i' ty' b' -> + let binder' :: Binder + binder' = set binderType (ty' ^. childNode) bi + in mkLambda i' binder' (b' ^. childNode) } NLet (Let i (LetItem bi v) b) -> NodeDetails { _nodeInfo = i, _nodeSubinfos = [], - _nodeChildren = [noBinders v, oneBinder bi b], - _nodeReassemble = twoChildren $ \i' v' b' -> - mkLet i' (hd (b' ^. childBinders)) (v' ^. childNode) (b' ^. childNode) + _nodeChildren = [noBinders (bi ^. binderType), noBinders v, oneBinder bi b], + _nodeReassemble = threeChildren $ \i' ty' v' b' -> + let binder' :: Binder + binder' = set binderType (ty' ^. childNode) bi + in mkLet i' binder' (v' ^. childNode) (b' ^. childNode) } NRec (LetRec i vs b) -> NodeDetails @@ -429,89 +444,142 @@ destruct = \case let binders :: [Binder] values :: [Node] (binders, values) = unzip [(it ^. letItemBinder, it ^. letItemValue) | it <- toList vs] - in map (manyBinders binders) (b : values), - _nodeReassemble = someChildren $ \i' (b' :| values') -> - let items' = + binderTypes :: [Type] + binderTypes = map (^. binderType) binders + in map (manyBinders binders) (b : values) ++ map noBinders binderTypes, + _nodeReassemble = someChildren $ \i' (b' :| valuesTys') -> + let numItems :: Int + numItems = length vs + tys' :: [Type] + values' :: [NodeChild] + (values', tys') = second (map (^. childNode)) (splitAtExact numItems valuesTys') + items' = nonEmpty' - [ LetItem (item ^. letItemBinder) (v' ^. childNode) | (v', item) <- zipExact values' (toList vs) + [ LetItem (Binder name ty') (v' ^. childNode) + | (v', ty', name) <- + zip3Exact + values' + tys' + (map (^. letItemBinder . binderName) (toList vs)) ] in mkLetRec i' items' (b' ^. childNode) } NCase (Case i v brs mdef) -> - let branchChildren :: [NodeChild] + let branchChildren :: [([Binder], NodeChild)] branchChildren = - [ manyBinders (br ^. caseBranchBinders) (br ^. caseBranchBody) - | br <- brs + [ (binders, manyBinders binders (br ^. caseBranchBody)) + | br <- brs, + let binders = br ^. caseBranchBinders ] + -- in this list we have the bodies and the binder types interleaved + allNodes :: [NodeChild] + allNodes = + concat + [ b : map (noBinders . (^. binderType)) bi + | (bi, b) <- branchChildren + ] + mkBranch :: Info -> CaseBranch -> Sem '[Input (Maybe NodeChild)] CaseBranch + mkBranch nfo' br = do + b' <- input' + let nBinders = br ^. caseBranchBindersNum + tys' <- map (^. childNode) <$> replicateM nBinders input' + return + br + { _caseBranchInfo = nfo', + _caseBranchBinders = zipWithExact (set binderType) tys' (b' ^. childBinders), + _caseBranchBody = b' ^. childNode + } + mkBranches :: [Info] -> [NodeChild] -> [CaseBranch] + mkBranches is' allNodes' = + run $ + runInputList allNodes' $ + sequence + [ mkBranch ci' br + | (ci', br) <- zipExact is' brs + ] in case mdef of Nothing -> NodeDetails { _nodeInfo = i, _nodeSubinfos = map (^. caseBranchInfo) brs, - _nodeChildren = noBinders v : branchChildren, - _nodeReassemble = someChildrenI $ \i' is' (v' :| bodies') -> - let branches :: [CaseBranch] - branches = - [ br - { _caseBranchInfo = ib', - _caseBranchBinders = body' ^. childBinders, - _caseBranchBody = body' ^. childNode - } - | (body', ib', br) <- zip3Exact bodies' is' brs - ] - in mkCase i' (v' ^. childNode) branches Nothing + _nodeChildren = noBinders v : allNodes, + _nodeReassemble = someChildrenI $ \i' is' (v' :| allNodes') -> + mkCase i' (v' ^. childNode) (mkBranches is' allNodes') Nothing } Just def -> NodeDetails { _nodeInfo = i, _nodeSubinfos = map (^. caseBranchInfo) brs, - _nodeChildren = noBinders v : noBinders def : branchChildren, - _nodeReassemble = twoManyChildrenI $ \i' is' v' def' bodies' -> - let branches :: [CaseBranch] - branches = - [ br - { _caseBranchInfo = ib', - _caseBranchBinders = body' ^. childBinders, - _caseBranchBody = body' ^. childNode - } - | (body', ib', br) <- zip3Exact bodies' is' brs - ] - in mkCase i' (v' ^. childNode) branches (Just (def' ^. childNode)) + _nodeChildren = noBinders v : noBinders def : allNodes, + _nodeReassemble = twoManyChildrenI $ \i' is' v' def' allNodes' -> + mkCase i' (v' ^. childNode) (mkBranches is' allNodes') (Just (def' ^. childNode)) } NMatch (Match i vs branches) -> - let branchChildren :: [NodeChild] - branchChildren = - [ manyBinders binders (br ^. matchBranchBody) - | br <- branches, - let binders = concatMap getBinderPatternInfos (reverse (toList (br ^. matchBranchPatterns))) - ] - branchPatternInfos :: [Info] - branchPatternInfos = - concatMap - ( \br -> - concatMap - (reverse . getPatternInfos) - (br ^. matchBranchPatterns) - ) - branches - n = length vs + let allNodes :: [NodeChild] + allNodes = + concat + [ b + : map (noBinders . (^. binderType)) bis + | (bis, b) <- branchChildren + ] + where + branchChildren :: [([Binder], NodeChild)] + branchChildren = + [ (binders, manyBinders binders (br ^. matchBranchBody)) + | br <- branches, + let binders = concatMap getPatternBinders (toList (br ^. matchBranchPatterns)) + ] + branchInfos :: [Info] + branchInfos = + concat + [ br + ^. matchBranchInfo + : concatMap getPatternInfos (br ^. matchBranchPatterns) + | br <- branches + ] + setPatternsInfos :: forall r. Members '[Input (Maybe Info), Input (Maybe NodeChild)] r => NonEmpty Pattern -> Sem r (NonEmpty Pattern) + setPatternsInfos = mapM goPattern + where + goPattern :: Pattern -> Sem r Pattern + goPattern = \case + PatWildcard x -> do + i' <- input' + return (PatWildcard (set patternWildcardInfo i' x)) + PatBinder x -> do + ty <- (^. childNode) <$> input' + let _patternBinder = set binderType ty (x ^. patternBinder) + _patternBinderPattern <- goPattern (x ^. patternBinderPattern) + return (PatBinder PatternBinder {..}) + PatConstr x -> do + i' <- input' + args' <- mapM goPattern (x ^. patternConstrArgs) + return (PatConstr (set patternConstrInfo i' (set patternConstrArgs args' x))) in NodeDetails { _nodeInfo = i, - _nodeSubinfos = branchPatternInfos, - _nodeChildren = map noBinders (toList vs) ++ branchChildren, + _nodeSubinfos = branchInfos, + _nodeChildren = map noBinders (toList vs) ++ allNodes, _nodeReassemble = someChildrenI $ \i' is' chs' -> - let values' :: NonEmpty NodeChild - bodies' :: [NodeChild] - (values', bodies') = first nonEmpty' (splitAtExact n (toList chs')) + let mkBranch :: MatchBranch -> Sem '[Input (Maybe NodeChild), Input (Maybe Info)] MatchBranch + mkBranch br = do + bi' <- input' + b' <- input' + pats' <- setPatternsInfos (br ^. matchBranchPatterns) + return + br + { _matchBranchInfo = bi', + _matchBranchPatterns = pats', + _matchBranchBody = b' ^. childNode + } + numVals = length vs + values' :: NonEmpty NodeChild + branchesChilds' :: [NodeChild] + (values', branchesChilds') = first nonEmpty' (splitAtExact numVals (toList chs')) branches' :: [MatchBranch] branches' = - [ br - { _matchBranchPatterns = nonEmpty' $ setPatternsInfos binders' is' (toList (br ^. matchBranchPatterns)), - _matchBranchBody = body' ^. childNode - } - | (body', br) <- zipExact bodies' branches, - let binders' = body' ^. childBinders - ] + run $ + runInputList is' $ + runInputList branchesChilds' $ + mapM mkBranch branches in mkMatch i' (fmap (^. childNode) values') branches' } NPi (Pi i bi b) -> @@ -520,10 +588,9 @@ destruct = \case _nodeSubinfos = [], _nodeChildren = [noBinders (bi ^. binderType), oneBinder bi b], _nodeReassemble = twoChildren $ \i' bi' b' -> - -- NOTE the binder type here is treated as a node - let binder :: Binder - binder = set binderType (bi' ^. childNode) (hd (b' ^. childBinders)) - in mkPi i' binder (b' ^. childNode) + let binder' :: Binder + binder' = set binderType (bi' ^. childNode) bi + in mkPi i' binder' (b' ^. childNode) } NUniv (Univ i l) -> NodeDetails @@ -561,32 +628,6 @@ destruct = \case _nodeReassemble = someChildren $ \i' (b' :| env') -> Closure (map (^. childNode) env') (Lambda i' bi (b' ^. childNode)) } - where - setPatternsInfos :: [Binder] -> [Info] -> [Pattern] -> [Pattern] - setPatternsInfos binders infos = snd . setPatternsInfos' binders infos - where - setPatternsInfos' :: [Binder] -> [Info] -> [Pattern] -> (([Binder], [Info]), [Pattern]) - setPatternsInfos' bs is [] = ((bs, is), []) - setPatternsInfos' bs is (p : ps) = - let ((bs', is'), p') = setPatInfos bs is p - (bis'', ps') = setPatternsInfos' bs' is' ps - in (bis'', p' : ps') - - setPatInfos :: [Binder] -> [Info] -> Pattern -> (([Binder], [Info]), Pattern) - setPatInfos bs is = \case - PatWildcard x -> - ((bs, tl is), PatWildcard (x {_patternWildcardInfo = hd is})) - PatBinder x -> - ((tl bs, is), PatBinder (x {_patternBinder = hd bs})) - PatConstr x -> - let (bis', ps) = setPatternsInfos' bs (tl is) (x ^. patternConstrArgs) - in (bis', PatConstr (x {_patternConstrInfo = hd is, _patternConstrArgs = ps})) - - hd :: [a] -> a - hd = List.head - - tl :: [a] -> [a] - tl = List.tail reassembleDetails :: NodeDetails -> [Node] -> Node reassembleDetails d ns = (d ^. nodeReassemble) (d ^. nodeInfo) (d ^. nodeSubinfos) children' diff --git a/src/Juvix/Compiler/Core/Translation/FromSource.hs b/src/Juvix/Compiler/Core/Translation/FromSource.hs index 6c20aa8153..bda13130d0 100644 --- a/src/Juvix/Compiler/Core/Translation/FromSource.hs +++ b/src/Juvix/Compiler/Core/Translation/FromSource.hs @@ -891,7 +891,7 @@ matchBranch patsNum varsNum vars = do unless (length pats == patsNum) $ parseFailure off "wrong number of patterns" let pis :: [Binder] - pis = concatMap (reverse . getBinderPatternInfos) pats + pis = concatMap getPatternBinders pats (vars', varsNum') = foldl' ( \(vs, k) name ->