Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve inlining #2377

Merged
merged 5 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/Juvix/Compiler/Core/Data/BinderList.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,14 @@ lookupsSortedRev bl = go [] 0 bl
head' :: BinderList a -> a
head' = lookup 0

lookupMay :: Index -> BinderList a -> Maybe a
lookupMay idx bl
| idx < bl ^. blLength = Just $ (bl ^. blMap) !! idx
| otherwise = Nothing

-- | lookup de Bruijn Index
lookup :: Index -> BinderList a -> a
lookup idx bl
| idx < bl ^. blLength = (bl ^. blMap) !! idx
| otherwise = err
lookup idx bl = fromMaybe err (lookupMay idx bl)
where
err :: a
err =
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Extra/Recursors/Collector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ binderInfoCollector' ini = Collector ini collect
collect :: (Int, [Binder]) -> BinderList Binder -> BinderList Binder
collect (k, bi) c
| k == 0 = c
| otherwise = BL.prepend (reverse bi) c
| otherwise = BL.prependRev bi c

binderInfoCollector :: Collector (Int, [Binder]) (BinderList Binder)
binderInfoCollector = binderInfoCollector' mempty
Expand Down
32 changes: 23 additions & 9 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ filterOutTypeSynonyms tab = pruneInfoTable tab'
tab' = tab {_infoIdentifiers = idents'}
idents' = HashMap.filter (\ii -> not (isTypeConstr tab (ii ^. identifierType))) (tab ^. infoIdentifiers)

isType :: Node -> Bool
isType = \case
isType' :: Node -> Bool
isType' = \case
NPi {} -> True
NUniv {} -> True
NPrim {} -> True
Expand All @@ -83,6 +83,16 @@ isType = \case
NMatch {} -> False
Closure {} -> False

isType :: InfoTable -> BinderList Binder -> Node -> Bool
isType tab bl node = case node of
NVar Var {..}
| Just Binder {..} <- BL.lookupMay _varIndex bl ->
isTypeConstr tab _binderType
NIdt Ident {..}
| Just ii <- lookupIdentifierInfo' tab _identSymbol ->
isTypeConstr tab (ii ^. identifierType)
_ -> isType' node

-- | True for nodes whose evaluation immediately returns a value, i.e.,
-- no reduction or memory allocation in the runtime is required.
isImmediate :: InfoTable -> Node -> Bool
Expand All @@ -97,8 +107,8 @@ isImmediate tab = \case
| Just ii <- lookupIdentifierInfo' tab _identSymbol ->
let paramsNum = length (takeWhile (isTypeConstr tab) (typeArgs (ii ^. identifierType)))
in length args <= paramsNum
_ -> all isType args
node -> isType node
_ -> all (isType tab mempty) args
node -> isType tab mempty node

isImmediate' :: Node -> Bool
isImmediate' = isImmediate emptyInfoTable
Expand Down Expand Up @@ -350,13 +360,17 @@ translateCase translateIf dflt Case {..} = case _caseBranches of
branchFailure :: Node
branchFailure = mkBuiltinApp' OpFail [mkConstant' (ConstString "illegal `if` branch")]

checkDepth :: Int -> Node -> Bool
checkDepth 0 _ = False
checkDepth d node = case node of
checkDepth :: InfoTable -> BinderList Binder -> Int -> Node -> Bool
checkDepth tab bl 0 node = isType tab bl node
checkDepth tab bl d node = case node of
NApp App {..} ->
checkDepth d _appLeft && checkDepth (d - 1) _appRight
checkDepth tab bl d _appLeft && checkDepth tab bl (d - 1) _appRight
_ ->
all (checkDepth (d - 1)) (childrenNodes node)
all go (children node)
where
go :: NodeChild -> Bool
go NodeChild {..} =
checkDepth tab (BL.prependRev _childBinders bl) (d - 1) _childNode

isCaseBoolean :: [CaseBranch] -> Bool
isCaseBoolean = \case
Expand Down
17 changes: 10 additions & 7 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base

isInlineableLambda :: Int -> Node -> Bool
isInlineableLambda inlineDepth node = case node of
isInlineableLambda :: Int -> InfoTable -> BinderList Binder -> Node -> Bool
isInlineableLambda inlineDepth tab bl node = case node of
NLam {} ->
checkDepth inlineDepth (snd (unfoldLambdas node))
let (lams, body) = unfoldLambdas node
binders = map (^. lambdaLhsBinder) lams
in checkDepth tab (BL.prependRev binders bl) inlineDepth body
_ ->
False

convertNode :: Int -> HashSet Symbol -> InfoTable -> Node -> Node
convertNode inlineDepth recSyms tab = dmap go
convertNode inlineDepth recSyms tab = dmapL go
where
go :: Node -> Node
go node = case node of
go :: BinderList Binder -> Node -> Node
go bl node = case node of
NApp {} ->
let (h, args) = unfoldApps node
in case h of
Expand All @@ -33,7 +36,7 @@ convertNode inlineDepth recSyms tab = dmap go
node
_
| not (HashSet.member _identSymbol recSyms)
&& isInlineableLambda inlineDepth def
&& isInlineableLambda inlineDepth tab bl def
&& length args >= argsNum ->
mkApps def args
_ ->
Expand Down
20 changes: 11 additions & 9 deletions src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,28 @@
-- ```
module Juvix.Compiler.Core.Transformation.Optimize.LetFolding (letFolding, letFolding') where

import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.FreeVarsInfo as Info
import Juvix.Compiler.Core.Transformation.Base

convertNode :: (Node -> Bool) -> InfoTable -> Node -> Node
convertNode isFoldable tab = rmap go
convertNode :: (InfoTable -> BinderList Binder -> Node -> Bool) -> InfoTable -> Node -> Node
convertNode isFoldable tab = rmapL go
where
go :: ([BinderChange] -> Node -> Node) -> Node -> Node
go recur = \case
go :: ([BinderChange] -> Node -> Node) -> BinderList Binder -> Node -> Node
go recur bl = \case
NLet Let {..}
| isImmediate tab (_letItem ^. letItemValue)
|| Info.freeVarOccurrences 0 _letBody <= 1
|| isFoldable (_letItem ^. letItemValue) ->
go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody
|| isFoldable tab bl (_letItem ^. letItemValue) ->
go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody
where
val' = go recur (_letItem ^. letItemValue)
val' = go recur bl (_letItem ^. letItemValue)
b = _letItem ^. letItemBinder
node ->
recur [] node

letFolding' :: (Node -> Bool) -> InfoTable -> InfoTable
letFolding' :: (InfoTable -> BinderList Binder -> Node -> Bool) -> InfoTable -> InfoTable
letFolding' isFoldable tab =
mapAllNodes
( removeInfo kFreeVarsInfo
Expand All @@ -41,4 +43,4 @@ letFolding' isFoldable tab =
tab

letFolding :: InfoTable -> InfoTable
letFolding = letFolding' (const False)
letFolding = letFolding' (\_ _ _ -> False)
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Translation/Stripped/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ translateNode node = case node of
(map translateCaseBranch _caseBranches)
(fmap translateNode _caseDefault)
_
| isType node ->
| isType' node ->
Stripped.mkConstr (Stripped.ConstrInfo "()" Nothing Stripped.TyDynamic) (BuiltinTag TagTrue) []
_ ->
unsupported
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Pipeline/EntryPoint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ defaultOptimizationLevel :: Int
defaultOptimizationLevel = 1

defaultInliningDepth :: Int
defaultInliningDepth = 2
defaultInliningDepth = 3

mainModulePath :: Traversal' EntryPoint (Path Abs File)
mainModulePath = entryPointModulePaths . _head
2 changes: 1 addition & 1 deletion test/Runtime/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ wasiArgs sysrootPath outputFile inputFile =
commonArgs outputFile
<> [ "-DARCH_WASM32",
"-DAPI_WASI",
"-Os",
"-O3",
"-nodefaultlibs",
"--target=wasm32-wasi",
"--sysroot",
Expand Down