Skip to content

Commit

Permalink
implement only letrec
Browse files Browse the repository at this point in the history
  • Loading branch information
janmasrovira committed Feb 1, 2023
1 parent a0c1a7e commit fa3bf3c
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 36 deletions.
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ shift m = umapN go
| v ^. varIndex >= k -> NVar (shiftVar m v)
n -> n

-- | Prism for NRec
_NRec :: SimpleFold Node LetRec
_NRec f = \case
NRec l -> NRec <$> f l
n -> pure n

-- | Prism for NLam
_NLam :: SimpleFold Node Lambda
_NLam f = \case
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ applyTransformations ts tbl = foldl' (flip appTrans) tbl ts
where
appTrans :: TransformationId -> InfoTable -> InfoTable
appTrans = \case
LambdaLetRecLifting -> lambdaLifting
LetRecLifting -> letrecLifting
LambdaLetRecLifting -> lambdaLetRecLifting
LetRecLifting -> letRecLifting
Identity -> identity
TopEtaExpand -> topEtaExpand
RemoveTypeArgs -> removeTypeArgs
Expand Down
88 changes: 55 additions & 33 deletions src/Juvix/Compiler/Core/Transformation/LambdaLetRecLifting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import Juvix.Compiler.Core.Info.NameInfo
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base

lambdaLiftBinder :: (Member InfoTableBuilder r) => BinderList Binder -> Binder -> Sem r Binder
lambdaLiftBinder :: Members '[Reader OnlyLetRec, InfoTableBuilder] r => BinderList Binder -> Binder -> Sem r Binder
lambdaLiftBinder bl = traverseOf binderType (lambdaLiftNode bl)

lambdaLiftNode :: forall r. (Member InfoTableBuilder r) => BinderList Binder -> Node -> Sem r Node
type OnlyLetRec = Bool

lambdaLiftNode :: forall r. Members '[Reader OnlyLetRec, InfoTableBuilder] r => BinderList Binder -> Node -> Sem r Node
lambdaLiftNode aboveBl top =
let topArgs :: [LambdaLhs]
(topArgs, body) = unfoldLambdas top
Expand All @@ -40,30 +42,37 @@ lambdaLiftNode aboveBl top =
m -> return (Recur m)
where
goLambda :: Lambda -> Sem r Recur
goLambda lm = do
l' <- lambdaLiftNode bl (NLam lm)
let (freevarsAssocs, fBody') = captureFreeVarsCtx bl l'
allfreevars :: [Var]
allfreevars = map fst freevarsAssocs
argsInfo :: [ArgumentInfo]
argsInfo = map (argumentInfoFromBinder . (^. lambdaLhsBinder)) (fst (unfoldLambdas fBody'))
f <- freshSymbol
let name = uniqueName "lambda" f
registerIdent
name
IdentifierInfo
{ _identifierSymbol = f,
_identifierName = name,
_identifierLocation = Nothing,
_identifierType = typeFromArgs argsInfo,
_identifierArgsNum = length argsInfo,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False,
_identifierBuiltin = Nothing
}
registerIdentNode f fBody'
let fApp = mkApps' (mkIdent (setInfoName name mempty) f) (map NVar allfreevars)
return (End fApp)
goLambda l = do
onlyLetRec <- ask @OnlyLetRec
if
| onlyLetRec -> return (Recur (NLam l))
| otherwise -> goLambdaGo l
where
goLambdaGo :: Lambda -> Sem r Recur
goLambdaGo lm = do
l' <- lambdaLiftNode bl (NLam lm)
let (freevarsAssocs, fBody') = captureFreeVarsCtx bl l'
allfreevars :: [Var]
allfreevars = map fst freevarsAssocs
argsInfo :: [ArgumentInfo]
argsInfo = map (argumentInfoFromBinder . (^. lambdaLhsBinder)) (fst (unfoldLambdas fBody'))
f <- freshSymbol
let name = uniqueName "lambda" f
registerIdent
name
IdentifierInfo
{ _identifierSymbol = f,
_identifierName = name,
_identifierLocation = Nothing,
_identifierType = typeFromArgs argsInfo,
_identifierArgsNum = length argsInfo,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False,
_identifierBuiltin = Nothing
}
registerIdentNode f fBody'
let fApp = mkApps' (mkIdent (setInfoName name mempty) f) (map NVar allfreevars)
return (End fApp)

goLetRec :: LetRec -> Sem r Recur
goLetRec letr = do
Expand Down Expand Up @@ -148,18 +157,31 @@ lambdaLiftNode aboveBl top =
res = shiftHelper body' (nonEmpty' (zipExact letItems letRecBinders'))
return (Recur res)

lambdaLifting :: InfoTable -> InfoTable
lambdaLifting = run . mapT' (const (lambdaLiftNode mempty))
lifting :: Bool -> InfoTable -> InfoTable
lifting onlyLetRec = run . runReader onlyLetRec . mapT' (const (lambdaLiftNode mempty))

letrecLifting :: InfoTable -> InfoTable
letrecLifting = run . mapT' (const (lambdaLiftNode mempty))
lambdaLetRecLifting :: InfoTable -> InfoTable
lambdaLetRecLifting = lifting False

letRecLifting :: InfoTable -> InfoTable
letRecLifting = lifting True

-- | True if lambdas are only found at the top level
nodeIsLifted :: Node -> Bool
nodeIsLifted = not . hasNestedLambdas
nodeIsLifted = nodeIsLambdaLifted .&&. nodeIsLetRecLifted

-- | True if lambdas are only found at the top level
nodeIsLambdaLifted :: Node -> Bool
nodeIsLambdaLifted = not . hasNestedLambdas
where
hasNestedLambdas :: Node -> Bool
hasNestedLambdas = has (cosmos . _NLam) . snd . unfoldLambdas'

-- | True if there are no letrec nodes
nodeIsLetRecLifted :: Node -> Bool
nodeIsLetRecLifted = not . hasLetRecs
where
hasLetRecs :: Node -> Bool
hasLetRecs = has (cosmos . _NRec)

isLifted :: InfoTable -> Bool
isLifted = all nodeIsLifted . toList . (^. identContext)
isLifted = all nodeIsLifted . (^. identContext)
2 changes: 1 addition & 1 deletion test/Core/Transformation/Lifting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Core.Transformation.Base
import Juvix.Compiler.Core.Transformation

allTests :: TestTree
allTests = testGroup "Lambda lifting" (map liftTest Eval.tests)
allTests = testGroup "Lambda and LetRec lifting" (map liftTest Eval.tests)

pipe :: [TransformationId]
pipe = [LambdaLetRecLifting]
Expand Down

0 comments on commit fa3bf3c

Please sign in to comment.