Skip to content

Commit

Permalink
Juvix core recursors should descend into nodes stored in infos (#1600)
Browse files Browse the repository at this point in the history
  • Loading branch information
janmasrovira authored Nov 2, 2022
1 parent 59e6712 commit 23c2b9e
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 99 deletions.
237 changes: 139 additions & 98 deletions src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module Juvix.Compiler.Core.Extra.Base where

import Data.Functor.Identity
import Data.List qualified as List
import Juvix.Compiler.Core.Info qualified as Info
import Juvix.Compiler.Core.Language
import Polysemy.Input

{------------------------------------------------------------------------}
{- Node constructors -}
Expand Down Expand Up @@ -245,8 +245,8 @@ unfoldLambdas' = first length . unfoldLambdas
{------------------------------------------------------------------------}
{- functions on Pattern -}

getBinderPatternInfos :: Pattern -> [Binder]
getBinderPatternInfos = go []
getPatternBinders :: Pattern -> [Binder]
getPatternBinders = reverse . go []
where
go :: [Binder] -> Pattern -> [Binder]
go acc = \case
Expand All @@ -255,7 +255,7 @@ getBinderPatternInfos = go []
PatWildcard {} -> acc

getPatternInfos :: Pattern -> [Info]
getPatternInfos = go []
getPatternInfos = reverse . go []
where
go :: [Info] -> Pattern -> [Info]
go acc = \case
Expand Down Expand Up @@ -342,6 +342,12 @@ twoChildren f i _ ch = case ch of
[l, r] -> f i l r
_ -> impossible

{-# INLINE threeChildren #-}
threeChildren :: (Info -> NodeChild -> NodeChild -> NodeChild -> Node) -> Reassemble
threeChildren f i _ ch = case ch of
[a, b, c] -> f i a b c
_ -> impossible

{-# INLINE manyChildren #-}
manyChildren :: (Info -> [NodeChild] -> Node) -> Reassemble
manyChildren f i _ = f i
Expand All @@ -360,6 +366,10 @@ twoManyChildrenI f i is = \case
(x : y : xs) -> f i is x y xs
_ -> impossible

{-# INLINE input' #-}
input' :: Members '[Input (Maybe a)] r => Sem r a
input' = fmap fromJust input

-- | Destruct a node into NodeDetails. This is an ugly internal function used to
-- implement more high-level accessors and recursors.
destruct :: Node -> NodeDetails
Expand Down Expand Up @@ -410,16 +420,21 @@ destruct = \case
NodeDetails
{ _nodeInfo = i,
_nodeSubinfos = [],
_nodeChildren = [oneBinder bi b],
_nodeReassemble = oneChild $ \i' ch' -> mkLambda i' (hd (ch' ^. childBinders)) (ch' ^. childNode)
_nodeChildren = [noBinders (bi ^. binderType), oneBinder bi b],
_nodeReassemble = twoChildren $ \i' ty' b' ->
let binder' :: Binder
binder' = set binderType (ty' ^. childNode) bi
in mkLambda i' binder' (b' ^. childNode)
}
NLet (Let i (LetItem bi v) b) ->
NodeDetails
{ _nodeInfo = i,
_nodeSubinfos = [],
_nodeChildren = [noBinders v, oneBinder bi b],
_nodeReassemble = twoChildren $ \i' v' b' ->
mkLet i' (hd (b' ^. childBinders)) (v' ^. childNode) (b' ^. childNode)
_nodeChildren = [noBinders (bi ^. binderType), noBinders v, oneBinder bi b],
_nodeReassemble = threeChildren $ \i' ty' v' b' ->
let binder' :: Binder
binder' = set binderType (ty' ^. childNode) bi
in mkLet i' binder' (v' ^. childNode) (b' ^. childNode)
}
NRec (LetRec i vs b) ->
NodeDetails
Expand All @@ -429,89 +444,142 @@ destruct = \case
let binders :: [Binder]
values :: [Node]
(binders, values) = unzip [(it ^. letItemBinder, it ^. letItemValue) | it <- toList vs]
in map (manyBinders binders) (b : values),
_nodeReassemble = someChildren $ \i' (b' :| values') ->
let items' =
binderTypes :: [Type]
binderTypes = map (^. binderType) binders
in map (manyBinders binders) (b : values) ++ map noBinders binderTypes,
_nodeReassemble = someChildren $ \i' (b' :| valuesTys') ->
let numItems :: Int
numItems = length vs
tys' :: [Type]
values' :: [NodeChild]
(values', tys') = second (map (^. childNode)) (splitAtExact numItems valuesTys')
items' =
nonEmpty'
[ LetItem (item ^. letItemBinder) (v' ^. childNode) | (v', item) <- zipExact values' (toList vs)
[ LetItem (Binder name ty') (v' ^. childNode)
| (v', ty', name) <-
zip3Exact
values'
tys'
(map (^. letItemBinder . binderName) (toList vs))
]
in mkLetRec i' items' (b' ^. childNode)
}
NCase (Case i v brs mdef) ->
let branchChildren :: [NodeChild]
let branchChildren :: [([Binder], NodeChild)]
branchChildren =
[ manyBinders (br ^. caseBranchBinders) (br ^. caseBranchBody)
| br <- brs
[ (binders, manyBinders binders (br ^. caseBranchBody))
| br <- brs,
let binders = br ^. caseBranchBinders
]
-- in this list we have the bodies and the binder types interleaved
allNodes :: [NodeChild]
allNodes =
concat
[ b : map (noBinders . (^. binderType)) bi
| (bi, b) <- branchChildren
]
mkBranch :: Info -> CaseBranch -> Sem '[Input (Maybe NodeChild)] CaseBranch
mkBranch nfo' br = do
b' <- input'
let nBinders = br ^. caseBranchBindersNum
tys' <- map (^. childNode) <$> replicateM nBinders input'
return
br
{ _caseBranchInfo = nfo',
_caseBranchBinders = zipWithExact (set binderType) tys' (b' ^. childBinders),
_caseBranchBody = b' ^. childNode
}
mkBranches :: [Info] -> [NodeChild] -> [CaseBranch]
mkBranches is' allNodes' =
run $
runInputList allNodes' $
sequence
[ mkBranch ci' br
| (ci', br) <- zipExact is' brs
]
in case mdef of
Nothing ->
NodeDetails
{ _nodeInfo = i,
_nodeSubinfos = map (^. caseBranchInfo) brs,
_nodeChildren = noBinders v : branchChildren,
_nodeReassemble = someChildrenI $ \i' is' (v' :| bodies') ->
let branches :: [CaseBranch]
branches =
[ br
{ _caseBranchInfo = ib',
_caseBranchBinders = body' ^. childBinders,
_caseBranchBody = body' ^. childNode
}
| (body', ib', br) <- zip3Exact bodies' is' brs
]
in mkCase i' (v' ^. childNode) branches Nothing
_nodeChildren = noBinders v : allNodes,
_nodeReassemble = someChildrenI $ \i' is' (v' :| allNodes') ->
mkCase i' (v' ^. childNode) (mkBranches is' allNodes') Nothing
}
Just def ->
NodeDetails
{ _nodeInfo = i,
_nodeSubinfos = map (^. caseBranchInfo) brs,
_nodeChildren = noBinders v : noBinders def : branchChildren,
_nodeReassemble = twoManyChildrenI $ \i' is' v' def' bodies' ->
let branches :: [CaseBranch]
branches =
[ br
{ _caseBranchInfo = ib',
_caseBranchBinders = body' ^. childBinders,
_caseBranchBody = body' ^. childNode
}
| (body', ib', br) <- zip3Exact bodies' is' brs
]
in mkCase i' (v' ^. childNode) branches (Just (def' ^. childNode))
_nodeChildren = noBinders v : noBinders def : allNodes,
_nodeReassemble = twoManyChildrenI $ \i' is' v' def' allNodes' ->
mkCase i' (v' ^. childNode) (mkBranches is' allNodes') (Just (def' ^. childNode))
}
NMatch (Match i vs branches) ->
let branchChildren :: [NodeChild]
branchChildren =
[ manyBinders binders (br ^. matchBranchBody)
| br <- branches,
let binders = concatMap getBinderPatternInfos (reverse (toList (br ^. matchBranchPatterns)))
]
branchPatternInfos :: [Info]
branchPatternInfos =
concatMap
( \br ->
concatMap
(reverse . getPatternInfos)
(br ^. matchBranchPatterns)
)
branches
n = length vs
let allNodes :: [NodeChild]
allNodes =
concat
[ b
: map (noBinders . (^. binderType)) bis
| (bis, b) <- branchChildren
]
where
branchChildren :: [([Binder], NodeChild)]
branchChildren =
[ (binders, manyBinders binders (br ^. matchBranchBody))
| br <- branches,
let binders = concatMap getPatternBinders (toList (br ^. matchBranchPatterns))
]
branchInfos :: [Info]
branchInfos =
concat
[ br
^. matchBranchInfo
: concatMap getPatternInfos (br ^. matchBranchPatterns)
| br <- branches
]
setPatternsInfos :: forall r. Members '[Input (Maybe Info), Input (Maybe NodeChild)] r => NonEmpty Pattern -> Sem r (NonEmpty Pattern)
setPatternsInfos = mapM goPattern
where
goPattern :: Pattern -> Sem r Pattern
goPattern = \case
PatWildcard x -> do
i' <- input'
return (PatWildcard (set patternWildcardInfo i' x))
PatBinder x -> do
ty <- (^. childNode) <$> input'
let _patternBinder = set binderType ty (x ^. patternBinder)
_patternBinderPattern <- goPattern (x ^. patternBinderPattern)
return (PatBinder PatternBinder {..})
PatConstr x -> do
i' <- input'
args' <- mapM goPattern (x ^. patternConstrArgs)
return (PatConstr (set patternConstrInfo i' (set patternConstrArgs args' x)))
in NodeDetails
{ _nodeInfo = i,
_nodeSubinfos = branchPatternInfos,
_nodeChildren = map noBinders (toList vs) ++ branchChildren,
_nodeSubinfos = branchInfos,
_nodeChildren = map noBinders (toList vs) ++ allNodes,
_nodeReassemble = someChildrenI $ \i' is' chs' ->
let values' :: NonEmpty NodeChild
bodies' :: [NodeChild]
(values', bodies') = first nonEmpty' (splitAtExact n (toList chs'))
let mkBranch :: MatchBranch -> Sem '[Input (Maybe NodeChild), Input (Maybe Info)] MatchBranch
mkBranch br = do
bi' <- input'
b' <- input'
pats' <- setPatternsInfos (br ^. matchBranchPatterns)
return
br
{ _matchBranchInfo = bi',
_matchBranchPatterns = pats',
_matchBranchBody = b' ^. childNode
}
numVals = length vs
values' :: NonEmpty NodeChild
branchesChilds' :: [NodeChild]
(values', branchesChilds') = first nonEmpty' (splitAtExact numVals (toList chs'))
branches' :: [MatchBranch]
branches' =
[ br
{ _matchBranchPatterns = nonEmpty' $ setPatternsInfos binders' is' (toList (br ^. matchBranchPatterns)),
_matchBranchBody = body' ^. childNode
}
| (body', br) <- zipExact bodies' branches,
let binders' = body' ^. childBinders
]
run $
runInputList is' $
runInputList branchesChilds' $
mapM mkBranch branches
in mkMatch i' (fmap (^. childNode) values') branches'
}
NPi (Pi i bi b) ->
Expand All @@ -520,10 +588,9 @@ destruct = \case
_nodeSubinfos = [],
_nodeChildren = [noBinders (bi ^. binderType), oneBinder bi b],
_nodeReassemble = twoChildren $ \i' bi' b' ->
-- NOTE the binder type here is treated as a node
let binder :: Binder
binder = set binderType (bi' ^. childNode) (hd (b' ^. childBinders))
in mkPi i' binder (b' ^. childNode)
let binder' :: Binder
binder' = set binderType (bi' ^. childNode) bi
in mkPi i' binder' (b' ^. childNode)
}
NUniv (Univ i l) ->
NodeDetails
Expand Down Expand Up @@ -561,32 +628,6 @@ destruct = \case
_nodeReassemble = someChildren $ \i' (b' :| env') ->
Closure (map (^. childNode) env') (Lambda i' bi (b' ^. childNode))
}
where
setPatternsInfos :: [Binder] -> [Info] -> [Pattern] -> [Pattern]
setPatternsInfos binders infos = snd . setPatternsInfos' binders infos
where
setPatternsInfos' :: [Binder] -> [Info] -> [Pattern] -> (([Binder], [Info]), [Pattern])
setPatternsInfos' bs is [] = ((bs, is), [])
setPatternsInfos' bs is (p : ps) =
let ((bs', is'), p') = setPatInfos bs is p
(bis'', ps') = setPatternsInfos' bs' is' ps
in (bis'', p' : ps')

setPatInfos :: [Binder] -> [Info] -> Pattern -> (([Binder], [Info]), Pattern)
setPatInfos bs is = \case
PatWildcard x ->
((bs, tl is), PatWildcard (x {_patternWildcardInfo = hd is}))
PatBinder x ->
((tl bs, is), PatBinder (x {_patternBinder = hd bs}))
PatConstr x ->
let (bis', ps) = setPatternsInfos' bs (tl is) (x ^. patternConstrArgs)
in (bis', PatConstr (x {_patternConstrInfo = hd is, _patternConstrArgs = ps}))

hd :: [a] -> a
hd = List.head

tl :: [a] -> [a]
tl = List.tail

reassembleDetails :: NodeDetails -> [Node] -> Node
reassembleDetails d ns = (d ^. nodeReassemble) (d ^. nodeInfo) (d ^. nodeSubinfos) children'
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Translation/FromSource.hs
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ matchBranch patsNum varsNum vars = do
unless (length pats == patsNum) $
parseFailure off "wrong number of patterns"
let pis :: [Binder]
pis = concatMap (reverse . getBinderPatternInfos) pats
pis = concatMap getPatternBinders pats
(vars', varsNum') =
foldl'
( \(vs, k) name ->
Expand Down

0 comments on commit 23c2b9e

Please sign in to comment.