Skip to content

Commit

Permalink
Letrec lifting (#1579)
Browse files Browse the repository at this point in the history
  • Loading branch information
janmasrovira authored Oct 21, 2022
1 parent 9e7a8a9 commit b02f2f8
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 59 deletions.
36 changes: 24 additions & 12 deletions app/Commands/Dev/Core/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,36 @@ doEval noIO loc tab node
| noIO = embed $ Core.catchEvalError loc (Core.eval (tab ^. Core.identContext) [] node)
| otherwise = embed $ Core.catchEvalErrorIO loc (Core.evalIO (tab ^. Core.identContext) [] node)

evalAndPrint ::
forall r.
Members '[Embed IO, App] r =>
CoreEvalOptions ->
Core.InfoTable ->
Core.Node ->
Sem r ()
evalAndPrint opts tab node = do
r <- doEval (opts ^. coreEvalNoIO) defaultLoc tab node
case r of
Left err -> exitJuvixError (JuvixError err)
Right node'
| Info.member Info.kNoDisplayInfo (Core.getInfo node') ->
return ()
Right node' -> do
renderStdOut (Core.ppOut opts node')
embed (putStrLn "")
where
defaultLoc :: Interval
defaultLoc = singletonInterval (mkLoc 0 (M.initialPos f))
f :: FilePath
f = opts ^. coreEvalInputFile . pathPath

runCommand :: forall r. Members '[Embed IO, App] r => CoreEvalOptions -> Sem r ()
runCommand opts = do
s <- embed (readFile f)
case Core.runParser f Core.emptyInfoTable s of
Left err -> exitJuvixError (JuvixError err)
Right (tab, Just node) -> do
r <- doEval (opts ^. coreEvalNoIO) defaultLoc tab node
case r of
Left err -> exitJuvixError (JuvixError err)
Right node'
| Info.member Info.kNoDisplayInfo (Core.getInfo node') ->
return ()
Right node' -> do
renderStdOut (Core.ppOut opts node')
embed (putStrLn "")
Right (tab, Just node) -> do evalAndPrint opts tab node
Right (_, Nothing) -> return ()
where
defaultLoc :: Interval
defaultLoc = singletonInterval (mkLoc 0 (M.initialPos f))
f :: FilePath
f = opts ^. coreEvalInputFile . pathPath
14 changes: 12 additions & 2 deletions app/Commands/Dev/Core/Read.hs
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
module Commands.Dev.Core.Read where

import Commands.Base
import Commands.Dev.Core.Eval qualified as Eval
import Commands.Dev.Core.Read.Options
import Juvix.Compiler.Core.Pretty qualified as Core
import Juvix.Compiler.Core.Scoper qualified as Scoper
import Juvix.Compiler.Core.Transformation qualified as Core
import Juvix.Compiler.Core.Translation.FromSource qualified as Core

runCommand :: forall r. Members '[Embed IO, App] r => CoreReadOptions -> Sem r ()
runCommand opts = do
s' <- embed (readFile f)
tab <- getRight (fst <$> mapLeft JuvixError (Core.runParser f Core.emptyInfoTable s'))
(tab, mnode) <- getRight (mapLeft JuvixError (Core.runParser f Core.emptyInfoTable s'))
let tab' = Core.applyTransformations (opts ^. coreReadTransformations) tab
renderStdOut (Core.ppOut opts tab')
embed (Scoper.scopeTrace tab')
unless (opts ^. coreReadNoPrint) (renderStdOut (Core.ppOut opts tab'))
whenJust mnode $ doEval tab'
where
doEval :: Core.InfoTable -> Core.Node -> Sem r ()
doEval tab' node = when (opts ^. coreReadEval) $ do
embed (putStrLn "--------------------------------")
embed (putStrLn "| Eval |")
embed (putStrLn "--------------------------------")
Eval.evalAndPrint (project opts) tab' node
f :: FilePath
f = opts ^. coreReadInputFile . pathPath
21 changes: 21 additions & 0 deletions app/Commands/Dev/Core/Read/Options.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
module Commands.Dev.Core.Read.Options where

import Commands.Dev.Core.Eval.Options qualified as Eval
import CommonOptions
import Juvix.Compiler.Core.Data.TransformationId.Parser
import Juvix.Compiler.Core.Pretty.Options qualified as Core

data CoreReadOptions = CoreReadOptions
{ _coreReadTransformations :: [TransformationId],
_coreReadShowDeBruijn :: Bool,
_coreReadEval :: Bool,
_coreReadNoPrint :: Bool,
_coreReadInputFile :: Path
}
deriving stock (Data)
Expand All @@ -19,9 +22,27 @@ instance CanonicalProjection CoreReadOptions Core.Options where
{ Core._optShowDeBruijnIndices = c ^. coreReadShowDeBruijn
}

instance CanonicalProjection CoreReadOptions Eval.CoreEvalOptions where
project c =
Eval.CoreEvalOptions
{ _coreEvalNoIO = False,
_coreEvalInputFile = c ^. coreReadInputFile,
_coreEvalShowDeBruijn = c ^. coreReadShowDeBruijn
}

parseCoreReadOptions :: Parser CoreReadOptions
parseCoreReadOptions = do
_coreReadShowDeBruijn <- optDeBruijn
_coreReadNoPrint <-
switch
( long "no-print"
<> help "do not print the transformed code"
)
_coreReadEval <-
switch
( long "eval"
<> help "evaluate after the transformation"
)
_coreReadTransformations <-
option
(eitherReader parseTransf)
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Control.Exception qualified as Exception
import Data.HashMap.Strict qualified as HashMap
import Debug.Trace qualified as Debug
import GHC.Conc qualified as GHC
import GHC.Show as S
import GHC.Show qualified as S
import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Error (CoreError (..))
import Juvix.Compiler.Core.Extra
Expand Down
25 changes: 16 additions & 9 deletions src/Juvix/Compiler/Core/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import Juvix.Compiler.Core.Language
isClosed :: Node -> Bool
isClosed = not . has freeVars

getFreeVars :: Node -> HashSet Var
getFreeVars n = HashSet.fromList (n ^.. freeVars)
freeVarsSet :: Node -> HashSet Var
freeVarsSet n = HashSet.fromList (n ^.. freeVars)

freeVars :: SimpleFold Node Var
freeVars f = ufoldNA reassemble go
Expand Down Expand Up @@ -98,16 +98,22 @@ captureFreeVars fv
| Just v <- s ^. at (u - k) -> NVar (Var i (v + k))
m -> m

-- | subst for multiple bindings
substs :: [Node] -> Node -> Node
substs t = umapN go
where
len = length t
go k n = case n of
NVar (Var i idx)
| idx >= k, idx - k < len -> shift k (t !! (idx - k))
| idx > k -> mkVar i (idx - len)
_ -> n

-- | substitute a term t for the free variable with de Bruijn index 0, avoiding
-- variable capture; shifts all free variabes with de Bruijn index > 0 by -1 (as
-- if the topmost binder was removed)
subst :: Node -> Node -> Node
subst t = umapN go
where
go k n = case n of
NVar (Var _ idx) | idx == k -> shift k t
NVar (Var i idx) | idx > k -> mkVar i (idx - 1)
_ -> n
subst t = substs [t]

-- | reduce all beta redexes present in a term and the ones created immediately
-- downwards (i.e., a "beta-development")
Expand All @@ -130,7 +136,8 @@ substEnv env
| otherwise = umapN go
where
go k n = case n of
NVar (Var _ idx) | idx >= k -> env !! (idx - k)
NVar (Var _ idx)
| idx >= k -> env !! (idx - k)
_ -> n

convertClosures :: Node -> Node
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Core/Language/Nodes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ data Let' i a = Let
-- | Represents a block of mutually recursive local definitions. Both in the
-- body and in the values `length _letRecValues` implicit binders are introduced
-- which hold the functions/values being defined.
-- the last item _letRecValues will have have index $0 in the body.
data LetRec' i a = LetRec
{ _letRecInfo :: i,
_letRecValues :: !(NonEmpty a),
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ppTrace' :: (CanonicalProjection a Options, PrettyCode c) => a -> c -> Text
ppTrace' opts = Ansi.renderStrict . reAnnotateS stylize . layoutPretty defaultLayoutOptions . doc (project opts)

ppTrace :: PrettyCode c => c -> Text
ppTrace = ppTrace' defaultOptions
ppTrace = ppTrace' traceOptions

ppPrint :: PrettyCode c => c -> Text
ppPrint = show . ppOutDefault
5 changes: 4 additions & 1 deletion src/Juvix/Compiler/Core/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,10 @@ instance PrettyCode InfoTable where
where
ppDef :: Symbol -> Node -> Sem r (Doc Ann)
ppDef s n = do
sym' <- maybe (return (pretty s)) ppCode (tbl ^? infoIdentifiers . at s . _Just . identifierName . _Just)
let mname :: Maybe Name
mname = tbl ^? infoIdentifiers . at s . _Just . identifierName . _Just
mname' = over (_Just . namePretty) (\nm -> nm <> "!" <> prettyText s) mname
sym' <- maybe (return (pretty s)) ppCode mname'
body' <- ppCode n
return (kwDef <+> sym' <+> kwAssign <+> nest 2 body')

Expand Down
31 changes: 31 additions & 0 deletions src/Juvix/Compiler/Core/Scoper.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module Juvix.Compiler.Core.Scoper where

import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Language
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base

type ScopeError = Text

scopeCheck :: InfoTable -> Maybe ScopeError
scopeCheck = either Just (const Nothing) . run . runError . walkT goTopNode

goTopNode :: Members '[Error ScopeError] r => Symbol -> Node -> Sem r ()
goTopNode sym = runReader sym . walkN check

check :: Members '[Reader Symbol, Error ScopeError] r => Index -> Node -> Sem r ()
check k = \case
NVar v
| v ^. varIndex < k -> return ()
| otherwise -> scopeErr ("variable " <> ppTrace (NVar v) <> " is out of scope")
_ -> return ()

scopeErr :: Members '[Reader Symbol, Error ScopeError] r => Text -> Sem r a
scopeErr msg = do
sym <- ask @Symbol
throw @ScopeError ("Scope error in the definition of " <> show sym <> "\n" <> msg)

-- | prints the scope error without exiting
scopeTrace :: InfoTable -> IO ()
scopeTrace i = whenJust (scopeCheck i) putStrLn
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ mapT' f tab =
mapM_
(\(k, v) -> f v >>= registerIdentNode k)
(HashMap.toList (tab ^. identContext))

walkT :: Applicative f => (Symbol -> Node -> f ()) -> InfoTable -> f ()
walkT f tab = for_ (HashMap.toList (tab ^. identContext)) (uncurry f)
88 changes: 86 additions & 2 deletions src/Juvix/Compiler/Core/Transformation/LambdaLifting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import Juvix.Compiler.Core.Data.BinderList (BinderList)
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.InfoTableBuilder
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.BinderInfo qualified as Info
import Juvix.Compiler.Core.Info.NameInfo
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base

Expand All @@ -26,12 +28,15 @@ lambdaLiftNode aboveBl top =
go :: BinderList Info -> Node -> Sem r Recur
go bl = \case
NLam l -> goLambda l
NRec l -> goLetRec l
m -> return (Recur m)
where
goLambda :: Lambda -> Sem r Recur
goLambda lm = do
l' <- lambdaLiftNode bl (NLam lm)
let freevars = toList (getFreeVars l')
let lambdaBinder :: Info
lambdaBinder = Info.getInfoBinder (lm ^. lambdaInfo)
l' <- lambdaLiftNode (BL.extend lambdaBinder bl) (NLam lm)
let freevars = toList (freeVarsSet l')
freevarsAssocs :: [(Index, Info)]
freevarsAssocs = [(i, BL.lookup i bl) | i <- map (^. varIndex) freevars]
fBody' = captureFreeVars freevarsAssocs l'
Expand All @@ -51,6 +56,85 @@ lambdaLiftNode aboveBl top =
let fApp = mkApps' (mkIdent mempty f) (map NVar freevars)
return (End fApp)

goLetRec :: LetRec -> Sem r Recur
goLetRec letr = do
let defs :: [Node]
defs = toList (letr ^. letRecValues)
ndefs :: Int
ndefs = length defs
letRecBinders :: [Info]
letRecBinders = Info.getInfoBinders ndefs (letr ^. letRecInfo)
bl' :: BinderList Info
bl' = BL.prepend letRecBinders bl
topSyms :: [Symbol] <- forM defs (const freshSymbol)
let recItemsFreeVars :: [(Var, Info)]
recItemsFreeVars = mapMaybe helper (toList (mconcatMap freeVarsSet defs))
where
-- free vars in each let
-- throw away variables bound in the letrec and shift others
helper :: Var -> Maybe (Var, Info)
helper v
| v ^. varIndex < ndefs = Nothing
| otherwise = Just (set varIndex idx' v, BL.lookup idx' bl)
where
idx' = (v ^. varIndex) - ndefs

subsCalls :: Node -> Node
subsCalls =
substs
( reverse
[ mkApps' (mkIdent' sym) (map (NVar . fst) recItemsFreeVars)
| sym <- topSyms
]
)
-- NOTE that we are first substituting the calls and then performing
-- lambda lifting. This is a tradeoff. We have slower compilation but
-- slightly faster execution time, since it minimizes the number of
-- free variables that need to be passed around.
liftedDefs <- mapM (lambdaLiftNode bl . subsCalls) defs
body' <- lambdaLiftNode bl' (letr ^. letRecBody)
let declareTopSyms :: Sem r ()
declareTopSyms =
sequence_
[ do
let topBody = captureFreeVars (map (first (^. varIndex)) recItemsFreeVars) b
argsInfo :: [ArgumentInfo]
argsInfo = map (argumentInfoFromInfo . snd) recItemsFreeVars
registerIdentNode sym topBody
registerIdent
IdentifierInfo
{ _identifierSymbol = sym,
_identifierName = getInfoName itemInfo,
_identifierType = typeFromArgs argsInfo,
_identifierArgsNum = length recItemsFreeVars,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False
}
| (sym, (itemInfo, b)) <- zipExact topSyms (zipExact letRecBinders liftedDefs)
]
letItems :: [Node]
letItems =
let fv = recItemsFreeVars
in [ mkApps' (mkIdent' s) (map (NVar . fst) fv)
| s <- topSyms
]
declareTopSyms

let -- TODO it can probably be simplified
shiftHelper :: Node -> NonEmpty Node -> Node
shiftHelper b = goShift 0
where
goShift :: Int -> NonEmpty Node -> Node
goShift k = \case
x :| yys -> case yys of
[]
| k == ndefs - 1 -> mkLet' (shift k x) b
| otherwise -> impossible
(y : ys) -> mkLet' (shift k x) (goShift (k + 1) (y :| ys))
let res :: Node
res = shiftHelper body' (nonEmpty' letItems)
return (Recur res)

lambdaLifting :: InfoTable -> InfoTable
lambdaLifting = run . mapT' (lambdaLiftNode mempty)

Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Prelude/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ commonPrefix a b = reverse (go [] a b)
nonEmptyUnsnoc :: NonEmpty a -> (Maybe (NonEmpty a), a)
nonEmptyUnsnoc e = (NonEmpty.nonEmpty (NonEmpty.init e), NonEmpty.last e)

nonEmpty' :: HasCallStack => [a] -> NonEmpty a
nonEmpty' = fromJust . nonEmpty

_nonEmpty :: Lens' [a] (Maybe (NonEmpty a))
_nonEmpty f x = maybe [] toList <$> f (nonEmpty x)

Expand Down
14 changes: 0 additions & 14 deletions test/Core/Transformation/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import Base
import Core.Eval.Base
import Core.Eval.Positive qualified as Eval
import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation
import Prettyprinter.Render.Text qualified as Text

data Test = Test
{ _testTransformations :: [TransformationId],
Expand All @@ -29,15 +27,3 @@ toTestDescr Test {..} =
_testRoot = tRoot,
_testAssertion = Steps $ coreEvalAssertion _file _expectedFile _testTransformations _testAssertion
}

assertExpectedOutput :: FilePath -> InfoTable -> Assertion
assertExpectedOutput testExpectedFile r = do
expected <- readFile testExpectedFile
let actualOutput = Text.renderStrict (toTextStream (ppOut opts r))
assertEqDiff ("Check: output = " <> testExpectedFile) actualOutput expected
where
opts :: Options
opts =
defaultOptions
{ _optShowDeBruijnIndices = True
}
Loading

0 comments on commit b02f2f8

Please sign in to comment.