Skip to content

Commit

Permalink
Automatically detect and split mutually recursive blocks in let expre…
Browse files Browse the repository at this point in the history
…ssions (#1894)

- Closes #1677
  • Loading branch information
janmasrovira authored Mar 17, 2023
1 parent da44ad6 commit 934a273
Show file tree
Hide file tree
Showing 24 changed files with 294 additions and 126 deletions.
61 changes: 44 additions & 17 deletions src/Juvix/Compiler/Abstract/Extra/DependencyBuilder.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module Juvix.Compiler.Abstract.Extra.DependencyBuilder (buildDependencyInfo, ExportsTable) where
module Juvix.Compiler.Abstract.Extra.DependencyBuilder (buildDependencyInfo, buildDependencyInfoExpr, ExportsTable) where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
Expand All @@ -18,7 +18,23 @@ type ExportsTable = HashSet NameId

buildDependencyInfo :: NonEmpty TopModule -> ExportsTable -> NameDependencyInfo
buildDependencyInfo ms tab =
createDependencyInfo graph startNodes
buildDependencyInfoHelper tab (mapM_ goModule ms)

buildDependencyInfoExpr :: Expression -> NameDependencyInfo
buildDependencyInfoExpr = buildDependencyInfoHelper mempty . goExpression Nothing

buildDependencyInfoHelper ::
ExportsTable ->
( Sem
'[ Reader ExportsTable,
State DependencyGraph,
State StartNodes,
State VisitedModules
]
()
) ->
NameDependencyInfo
buildDependencyInfoHelper tbl m = createDependencyInfo graph startNodes
where
startNodes :: StartNodes
graph :: DependencyGraph
Expand All @@ -27,12 +43,14 @@ buildDependencyInfo ms tab =
evalState (HashSet.empty :: VisitedModules) $
runState HashSet.empty $
execState HashMap.empty $
runReader tab $
mapM_ goModule ms
runReader tbl m

addStartNode :: (Member (State StartNodes) r) => Name -> Sem r ()
addStartNode n = modify (HashSet.insert n)

addEdgeMay :: (Member (State DependencyGraph) r) => Maybe Name -> Name -> Sem r ()
addEdgeMay mn1 n2 = whenJust mn1 $ \n1 -> addEdge n1 n2

addEdge :: (Member (State DependencyGraph) r) => Name -> Name -> Sem r ()
addEdge n1 n2 =
modify
Expand Down Expand Up @@ -87,16 +105,16 @@ goStatement modName = \case
StatementAxiom ax -> do
checkStartNode (ax ^. axiomName)
addEdge (ax ^. axiomName) modName
goExpression (ax ^. axiomName) (ax ^. axiomType)
goExpression (Just (ax ^. axiomName)) (ax ^. axiomType)
StatementFunction f -> goTopFunctionDef modName f
StatementImport m -> guardNotVisited (m ^. moduleName) (goModule m)
StatementLocalModule m -> goLocalModule modName m
StatementInductive i -> do
checkStartNode (i ^. inductiveName)
checkBuiltinInductiveStartNode i
addEdge (i ^. inductiveName) modName
mapM_ (goFunctionParameter (i ^. inductiveName)) (i ^. inductiveParameters)
goExpression (i ^. inductiveName) (i ^. inductiveType)
mapM_ (goFunctionParameter (Just (i ^. inductiveName))) (i ^. inductiveParameters)
goExpression (Just (i ^. inductiveName)) (i ^. inductiveType)
mapM_ (goConstructorDef (i ^. inductiveName)) (i ^. inductiveConstructors)

goTopFunctionDef :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionDef -> Sem r ()
Expand All @@ -110,22 +128,22 @@ goFunctionDefHelper ::
Sem r ()
goFunctionDefHelper f = do
checkStartNode (f ^. funDefName)
goExpression (f ^. funDefName) (f ^. funDefTypeSig)
goExpression (Just (f ^. funDefName)) (f ^. funDefTypeSig)
mapM_ (goFunctionClause (f ^. funDefName)) (f ^. funDefClauses)

-- constructors of an inductive type depend on the inductive type, not the other
-- way round; an inductive type depends on the types of its constructors
goConstructorDef :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> InductiveConstructorDef -> Sem r ()
goConstructorDef indName c = do
addEdge (c ^. constructorName) indName
goExpression indName (c ^. constructorType)
goExpression (Just indName) (c ^. constructorType)

goFunctionClause :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionClause -> Sem r ()
goFunctionClause p c = do
mapM_ (goPattern p) (c ^. clausePatterns)
goExpression p (c ^. clauseBody)
mapM_ (goPattern (Just p)) (c ^. clausePatterns)
goExpression (Just p) (c ^. clauseBody)

goPattern :: forall r. (Member (State DependencyGraph) r) => Name -> PatternArg -> Sem r ()
goPattern :: forall r. (Member (State DependencyGraph) r) => Maybe Name -> PatternArg -> Sem r ()
goPattern n p = case p ^. patternArgPattern of
PatternVariable {} -> return ()
PatternWildcard {} -> return ()
Expand All @@ -134,12 +152,17 @@ goPattern n p = case p ^. patternArgPattern of
where
goApp :: ConstructorApp -> Sem r ()
goApp (ConstructorApp ctr ps) = do
addEdge n (ctr ^. constructorRefName)
addEdgeMay n (ctr ^. constructorRefName)
mapM_ (goPattern n) ps

goExpression :: forall r. (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> Expression -> Sem r ()
goExpression ::
forall r.
(Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) =>
Maybe Name ->
Expression ->
Sem r ()
goExpression p e = case e of
ExpressionIden i -> addEdge p (idenName i)
ExpressionIden i -> addEdgeMay p (idenName i)
ExpressionUniverse {} -> return ()
ExpressionFunction f -> do
goFunctionParameter p (f ^. funParameter)
Expand Down Expand Up @@ -177,8 +200,12 @@ goExpression p e = case e of
goLetClause :: LetClause -> Sem r ()
goLetClause = \case
LetFunDef f -> do
addEdge p (f ^. funDefName)
addEdgeMay p (f ^. funDefName)
goFunctionDefHelper f

goFunctionParameter :: (Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) => Name -> FunctionParameter -> Sem r ()
goFunctionParameter ::
(Members '[State DependencyGraph, State StartNodes, Reader ExportsTable] r) =>
Maybe Name ->
FunctionParameter ->
Sem r ()
goFunctionParameter p param = goExpression p (param ^. paramType)
61 changes: 30 additions & 31 deletions src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,9 @@ goFunctionDef ::
Sem r ()
goFunctionDef ((f, sym), ty) = do
mbody <- case f ^. Internal.funDefBuiltin of
Just b | isIgnoredBuiltin b -> return Nothing
Just _ -> Just <$> runReader initIndexTable (mkFunBody ty f)
Just b
| isIgnoredBuiltin b -> return Nothing
| otherwise -> Just <$> runReader initIndexTable (mkFunBody ty f)
Nothing -> Just <$> runReader initIndexTable (mkFunBody ty f)
forM_ mbody (registerIdentNode sym)
forM_ mbody setIdentArgsInfo'
Expand Down Expand Up @@ -461,35 +462,33 @@ goLet ::
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable] r) =>
Internal.Let ->
Sem r Node
goLet l = do
vars <- asks (^. indexTableVars)
varsNum <- asks (^. indexTableVarsNum)
let bs :: [Name]
bs = map (\(Internal.LetFunDef Internal.FunctionDef {..}) -> _funDefName) (toList $ l ^. Internal.letClauses)
(vars', varsNum') =
foldl'
( \(vs, k) name ->
(HashMap.insert (name ^. nameId) k vs, k + 1)
)
(vars, varsNum)
bs
(defs, value) <- do
values <-
mapM
( \(Internal.LetFunDef f) -> do
funTy <- goType (f ^. Internal.funDefType)

funBody <- local (set indexTableVars vars' . set indexTableVarsNum varsNum') (mkFunBody funTy f)
return (funTy, funBody)
)
(l ^. Internal.letClauses)

lbody <-
local
(set indexTableVars vars' . set indexTableVarsNum varsNum')
(goExpression (l ^. Internal.letExpression))
return (values, lbody)
return $ mkLetRec' defs value
goLet l = goClauses (toList (l ^. Internal.letClauses))
where
goClauses :: [Internal.LetClause] -> Sem r Node
goClauses = \case
[] -> goExpression (l ^. Internal.letExpression)
c : cs -> case c of
Internal.LetFunDef f -> goNonRecFun f
Internal.LetMutualBlock m -> goMutual m
where
goNonRecFun :: Internal.FunctionDef -> Sem r Node
goNonRecFun f =
do
funTy <- goType (f ^. Internal.funDefType)
funBody <- mkFunBody funTy f
rest <- localAddName (f ^. Internal.funDefName) (goClauses cs)
return $ mkLet' funTy funBody rest
goMutual :: Internal.MutualBlock -> Sem r Node
goMutual (Internal.MutualBlock funs) = do
let lfuns = toList funs
names = map (^. Internal.funDefName) lfuns
tys = map (^. Internal.funDefType) lfuns
tys' <- mapM goType tys
localAddNames names $ do
vals' <- sequence [mkFunBody ty f | (ty, f) <- zipExact tys' lfuns]
let items = nonEmpty' (zip tys' vals')
rest <- goClauses cs
return (mkLetRec' items rest)

goAxiomInductive ::
forall r.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,24 @@ makeLenses ''IndexTable
initIndexTable :: IndexTable
initIndexTable = IndexTable 0 mempty

localAddName :: forall r a. (Member (Reader IndexTable) r) => Name -> Sem r a -> Sem r a
localAddName n s = do
localAddName :: Member (Reader IndexTable) r => Name -> Sem r a -> Sem r a
localAddName n = localAddNames [n]

localAddNames :: forall r a. (Member (Reader IndexTable) r) => [Name] -> Sem r a -> Sem r a
localAddNames names s = do
updateFn <- update
local updateFn s
where
len :: Int = length names
insertMany :: [(NameId, Index)] -> HashMap NameId Index -> HashMap NameId Index
insertMany l t = foldl' (\m (k, v) -> HashMap.insert k v m) t l
update :: Sem r (IndexTable -> IndexTable)
update = do
idx <- asks (^. indexTableVarsNum)
let newElems = zip (map (^. nameId) names) [idx ..]
return
( over indexTableVars (HashMap.insert (n ^. nameId) idx)
. over indexTableVarsNum (+ 1)
( over indexTableVars (insertMany newElems)
. over indexTableVarsNum (+ len)
)

underBinders :: Members '[Reader IndexTable] r => Int -> Sem r a -> Sem r a
Expand Down
20 changes: 18 additions & 2 deletions src/Juvix/Compiler/Internal/Data/InfoTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,25 @@ extendWithReplExpression e =
over
infoFunctions
( HashMap.union
(HashMap.fromList [(f ^. funDefName, FunctionInfo f) | LetFunDef f <- universeBi e])
( HashMap.fromList
[ (f ^. funDefName, FunctionInfo f)
| f <- letFunctionDefs e
]
)
)

letFunctionDefs :: Data from => from -> [FunctionDef]
letFunctionDefs e =
concat
[ concatMap (toList . flattenClause) _letClauses
| Let {..} <- universeBi e
]
where
flattenClause :: LetClause -> NonEmpty FunctionDef
flattenClause = \case
LetFunDef f -> pure f
LetMutualBlock (MutualBlock fs) -> fs

-- | moduleName ↦ infoTable
type Cache = HashMap Name InfoTable

Expand Down Expand Up @@ -117,7 +133,7 @@ buildTable1' m = do
]
<> [ (f ^. funDefName, FunctionInfo f)
| s <- filter (not . isInclude) ss,
LetFunDef f <- universeBi s
f <- letFunctionDefs s
]
where
isInclude :: Statement -> Bool
Expand Down
5 changes: 5 additions & 0 deletions src/Juvix/Compiler/Internal/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,14 @@ instance HasExpressions Case where
where
_caseParens = l ^. caseParens

instance HasExpressions MutualBlock where
leafExpressions f (MutualBlock defs) =
MutualBlock <$> traverse (leafExpressions f) defs

instance HasExpressions LetClause where
leafExpressions f = \case
LetFunDef d -> LetFunDef <$> leafExpressions f d
LetMutualBlock b -> LetMutualBlock <$> leafExpressions f b

instance HasExpressions Let where
leafExpressions f l = do
Expand Down
14 changes: 11 additions & 3 deletions src/Juvix/Compiler/Internal/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ data Statement
newtype MutualBlock = MutualBlock
{ _mutualFunctions :: NonEmpty FunctionDef
}
deriving stock (Data)
deriving stock (Eq, Generic, Data)

instance Hashable MutualBlock

data AxiomDef = AxiomDef
{ _axiomName :: AxiomName,
Expand Down Expand Up @@ -98,8 +100,10 @@ data TypedExpression = TypedExpression
_typedExpression :: Expression
}

newtype LetClause
= LetFunDef FunctionDef
data LetClause
= -- | Non-recursive let definition
LetFunDef FunctionDef
| LetMutualBlock MutualBlock
deriving stock (Eq, Generic, Data)

instance Hashable LetClause
Expand Down Expand Up @@ -367,9 +371,13 @@ instance HasLoc FunctionClause where
instance HasLoc FunctionDef where
getLoc f = getLoc (f ^. funDefName) <> getLocSpan (f ^. funDefClauses)

instance HasLoc MutualBlock where
getLoc (MutualBlock defs) = getLocSpan defs

instance HasLoc LetClause where
getLoc = \case
LetFunDef f -> getLoc f
LetMutualBlock f -> getLoc f

instance HasLoc Let where
getLoc l = getLocSpan (l ^. letClauses) <> getLoc (l ^. letExpression)
Expand Down
9 changes: 9 additions & 0 deletions src/Juvix/Compiler/Internal/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,17 @@ instance PrettyCode Let where
return $ kwLet <+> letClauses' <+> kwIn <+> letExpression'

instance PrettyCode LetClause where
ppCode :: forall r. Member (Reader Options) r => LetClause -> Sem r (Doc Ann)
ppCode = \case
LetFunDef f -> ppCode f
LetMutualBlock b -> ppMutual b
where
ppMutual :: MutualBlock -> Sem r (Doc Ann)
ppMutual m@(MutualBlock b)
| [_] <- toList b = ppCode b
| otherwise = do
b' <- ppCode m
return (kwMutual <+> braces (line <> indent' b' <> line))

ppPipeBlock :: (PrettyCode a, Members '[Reader Options] r, Traversable t) => t a -> Sem r (Doc Ann)
ppPipeBlock items = vsep <$> mapM (fmap (kwPipe <+>) . ppCode) items
Expand Down
Loading

0 comments on commit 934a273

Please sign in to comment.