Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute JuvixAsm stack usage info #1604

Merged
merged 3 commits into from
Nov 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions src/Juvix/Compiler/Asm/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Juvix.Compiler.Asm.Language
validateCode :: forall r. Member (Error AsmError) r => InfoTable -> Arguments -> Code -> Sem r ()
validateCode tab args = void . recurse sig args
where
sig :: RecursorSig r ()
sig :: RecursorSig Memory r ()
sig =
RecursorSig
{ _recursorInfoTable = tab,
Expand All @@ -30,11 +30,49 @@ validateCode tab args = void . recurse sig args
validateFunction :: Member (Error AsmError) r => InfoTable -> FunctionInfo -> Sem r ()
validateFunction tab fi = validateCode tab (argumentsFromFunctionInfo fi) (fi ^. functionCode)

validateInfoTable :: Member (Error AsmError) r => InfoTable -> Sem r ()
validateInfoTable tab = mapM_ (validateFunction tab) (HashMap.elems (tab ^. infoFunctions))
validateInfoTable :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable
validateInfoTable tab = do
mapM_ (validateFunction tab) (HashMap.elems (tab ^. infoFunctions))
return tab

validate :: InfoTable -> Maybe AsmError
validate tab =
case run $ runError $ validateInfoTable tab of
Left err -> Just err
_ -> Nothing

computeFunctionStackUsage :: Member (Error AsmError) r => InfoTable -> FunctionInfo -> Sem r FunctionInfo
computeFunctionStackUsage tab fi = do
ps <- snd <$> recurseS sig initialStackInfo (fi ^. functionCode)
let maxValueStack = maximum (map fst ps)
maxTempStack = maximum (map snd ps)
return
fi
{ _functionMaxValueStackHeight = maxValueStack,
_functionMaxTempStackHeight = maxTempStack
}
where
sig :: RecursorSig StackInfo r (Int, Int)
sig =
RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = \si _ -> return (si ^. stackInfoValueStackHeight, si ^. stackInfoTempStackHeight),
_recurseBranch = \si _ l r ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map fst l)) (maximum (map fst r))),
max (si ^. stackInfoTempStackHeight) (max (maximum (map snd l)) (maximum (map snd r)))
),
_recurseCase = \si _ cs md ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map (maximum . map fst) cs)) (maybe 0 (maximum . map fst) md)),
max (si ^. stackInfoTempStackHeight) (max (maximum (map (maximum . map snd) cs)) (maybe 0 (maximum . map snd) md))
)
}

computeStackUsage :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable
computeStackUsage tab = do
fns <- mapM (computeFunctionStackUsage tab) (tab ^. infoFunctions)
return tab {_infoFunctions = fns}

computeStackUsage' :: InfoTable -> Either AsmError InfoTable
computeStackUsage' tab = run $ runError $ computeStackUsage tab
170 changes: 163 additions & 7 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@ import Juvix.Compiler.Asm.Language
import Juvix.Compiler.Asm.Pretty

-- | Recursor signature. Contains read-only recursor parameters.
data RecursorSig r a = RecursorSig
data RecursorSig m r a = RecursorSig
{ _recursorInfoTable :: InfoTable,
_recurseInstr :: Memory -> CmdInstr -> Sem r a,
_recurseBranch :: Memory -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: Memory -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a
_recurseInstr :: m -> CmdInstr -> Sem r a,
_recurseBranch :: m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a
}

makeLenses ''RecursorSig

recurseFun :: Member (Error AsmError) r => RecursorSig r a -> FunctionInfo -> Sem r [a]
recurseFun :: Member (Error AsmError) r => RecursorSig Memory r a -> FunctionInfo -> Sem r [a]
recurseFun sig fi = recurse sig (argumentsFromFunctionInfo fi) (fi ^. functionCode)

recurse :: Member (Error AsmError) r => RecursorSig r a -> Arguments -> Code -> Sem r [a]
recurse :: Member (Error AsmError) r => RecursorSig Memory r a -> Arguments -> Code -> Sem r [a]
recurse sig args = fmap snd . recurse' sig (mkMemory args)

recurse' :: forall r a. Member (Error AsmError) r => RecursorSig r a -> Memory -> Code -> Sem r (Memory, [a])
recurse' :: forall r a. Member (Error AsmError) r => RecursorSig Memory r a -> Memory -> Code -> Sem r (Memory, [a])
recurse' sig = go True
where
go :: Bool -> Memory -> Code -> Sem r (Memory, [a])
Expand Down Expand Up @@ -276,3 +276,159 @@ recurse' sig = go True
)
$ throw
$ AsmError loc "temporary stack height changed after branching"

data StackInfo = StackInfo
{ _stackInfoValueStackHeight :: Int,
_stackInfoTempStackHeight :: Int
}
deriving stock (Eq)

makeLenses ''StackInfo

initialStackInfo :: StackInfo
initialStackInfo = StackInfo {_stackInfoValueStackHeight = 0, _stackInfoTempStackHeight = 0}

-- | A simplified recursor which doesn't perform validity checking and only
-- computes stack height information. This makes a significant performance
-- difference, since we'll have many passes recursing over the entire JuvixAsm
-- program code which need only stack height information. Also, the code using
-- the simplified recursor can itself be simpler if it doesn't need the extra
-- info provided by the full recursor.
recurseS :: forall r a. Member (Error AsmError) r => RecursorSig StackInfo r a -> StackInfo -> Code -> Sem r (StackInfo, [a])
recurseS sig = go
where
go :: StackInfo -> Code -> Sem r (StackInfo, [a])
go si = \case
[] -> return (si, [])
h : t -> case h of
Instr x -> do
goNextCmd (goInstr si x) t
Branch x ->
goNextCmd (goBranch si x) t
Case x ->
goNextCmd (goCase si x) t

goNextCmd :: Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd mp t = do
(si', r) <- mp
(si'', rs) <- go si' t
return (si'', r : rs)

goInstr :: StackInfo -> CmdInstr -> Sem r (StackInfo, a)
goInstr stackInfo cmd = do
a <- (sig ^. recurseInstr) stackInfo cmd
si' <- fixStackInstr stackInfo (cmd ^. cmdInstrInstruction)
return (si', a)
where
fixStackInstr :: StackInfo -> Instruction -> Sem r StackInfo
fixStackInstr si instr =
case instr of
Binop IntAdd ->
fixStackBinOp si
Binop IntSub ->
fixStackBinOp si
Binop IntMul ->
fixStackBinOp si
Binop IntDiv ->
fixStackBinOp si
Binop IntMod ->
fixStackBinOp si
Binop IntLt ->
fixStackBinOp si
Binop IntLe ->
fixStackBinOp si
Binop ValEq ->
fixStackBinOp si
Push {} -> do
return (stackInfoPushValueStack 1 si)
Pop -> do
return (stackInfoPopValueStack 1 si)
PushTemp -> do
return $ stackInfoPushTempStack 1 (stackInfoPopValueStack 1 si)
PopTemp -> do
return $ stackInfoPopTempStack 1 si
Trace ->
return si
Dump ->
return si
Failure ->
return si
AllocConstr tag -> do
let ci = getConstrInfo (sig ^. recursorInfoTable) tag
n = ci ^. constructorArgsNum
return $
stackInfoPopValueStack (n - 1) si
AllocClosure InstrAllocClosure {..} -> do
return $
stackInfoPopValueStack (_allocClosureArgsNum - 1) si
ExtendClosure InstrExtendClosure {..} ->
return $
stackInfoPopValueStack (_extendClosureArgsNum - 1) si
Call x ->
fixStackCall si x
TailCall x ->
fixStackCall si x
CallClosures x ->
fixStackCallClosures si x
TailCallClosures x ->
fixStackCallClosures si x
Return ->
return si

fixStackBinOp :: StackInfo -> Sem r StackInfo
fixStackBinOp si = return $ stackInfoPopValueStack 1 si

fixStackCall :: StackInfo -> InstrCall -> Sem r StackInfo
fixStackCall si InstrCall {..} = do
return $ stackInfoPopValueStack (_callArgsNum + (if _callType == CallClosure then 1 else 0) - 1) si

fixStackCallClosures :: StackInfo -> InstrCallClosures -> Sem r StackInfo
fixStackCallClosures si InstrCallClosures {..} = do
return $ stackInfoPopValueStack (_callClosuresArgsNum - 1) si

goBranch :: StackInfo -> CmdBranch -> Sem r (StackInfo, a)
goBranch si cmd@CmdBranch {..} = do
let si0 = stackInfoPopValueStack 1 si
(si1, as1) <- go si0 _cmdBranchTrue
(si2, as2) <- go si0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) si cmd as1 as2
checkStackInfo loc si1 si2
return (si1, a')
where
loc = cmd ^. cmdBranchInfo . commandInfoLocation

goCase :: StackInfo -> CmdCase -> Sem r (StackInfo, a)
goCase si cmd@CmdCase {..} = do
rs <- mapM (go si . (^. caseBranchCode)) _cmdCaseBranches
let sis = map fst rs
ass = map snd rs
rd <- maybe (return Nothing) (fmap Just . go si) _cmdCaseDefault
let sd = fmap fst rd
ad = fmap snd rd
a' <- (sig ^. recurseCase) si cmd ass ad
case sis of
[] -> return (fromMaybe si sd, a')
si0 : sis' -> do
mapM_ (checkStackInfo loc si0) sis'
forM_ sd (checkStackInfo loc si0)
return (si0, a')
where
loc = cmd ^. (cmdCaseInfo . commandInfoLocation)

checkStackInfo :: Maybe Location -> StackInfo -> StackInfo -> Sem r ()
checkStackInfo loc si1 si2 =
when (si1 /= si2) $
throw $
AsmError loc "stack height mismatch"

stackInfoPushValueStack :: Int -> StackInfo -> StackInfo
stackInfoPushValueStack n si = si {_stackInfoValueStackHeight = si ^. stackInfoValueStackHeight + n}

stackInfoPopValueStack :: Int -> StackInfo -> StackInfo
stackInfoPopValueStack n si = si {_stackInfoValueStackHeight = si ^. stackInfoValueStackHeight - n}

stackInfoPushTempStack :: Int -> StackInfo -> StackInfo
stackInfoPushTempStack n si = si {_stackInfoTempStackHeight = si ^. stackInfoTempStackHeight + n}

stackInfoPopTempStack :: Int -> StackInfo -> StackInfo
stackInfoPopTempStack n si = si {_stackInfoTempStackHeight = si ^. stackInfoTempStackHeight - n}
10 changes: 10 additions & 0 deletions src/Juvix/Compiler/Asm/Pipeline.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module Juvix.Compiler.Asm.Pipeline where

import Juvix.Compiler.Asm.Data.InfoTable
import Juvix.Compiler.Asm.Extra
import Juvix.Compiler.Asm.Language

-- | Perform transformations on JuvixAsm necessary before the translation to
-- JuvixReg
toReg :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable
toReg = validateInfoTable >=> computeStackUsage
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Reg/Translation/FromAsm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fromAsmFun tab fi =
Left err -> error (show err)
Right code -> code
where
sig :: Asm.RecursorSig (Error Asm.AsmError ': r) Instruction
sig :: Asm.RecursorSig Asm.Memory (Error Asm.AsmError ': r) Instruction
sig =
Asm.RecursorSig
{ _recursorInfoTable = tab,
Expand Down