Skip to content

Commit

Permalink
Compute JuvixAsm stack usage info (#1604)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz authored Nov 7, 2022
1 parent a3b2aa6 commit 0d90a61
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 11 deletions.
44 changes: 41 additions & 3 deletions src/Juvix/Compiler/Asm/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Juvix.Compiler.Asm.Language
validateCode :: forall r. Member (Error AsmError) r => InfoTable -> Arguments -> Code -> Sem r ()
validateCode tab args = void . recurse sig args
where
sig :: RecursorSig r ()
sig :: RecursorSig Memory r ()
sig =
RecursorSig
{ _recursorInfoTable = tab,
Expand All @@ -30,11 +30,49 @@ validateCode tab args = void . recurse sig args
validateFunction :: Member (Error AsmError) r => InfoTable -> FunctionInfo -> Sem r ()
validateFunction tab fi = validateCode tab (argumentsFromFunctionInfo fi) (fi ^. functionCode)

validateInfoTable :: Member (Error AsmError) r => InfoTable -> Sem r ()
validateInfoTable tab = mapM_ (validateFunction tab) (HashMap.elems (tab ^. infoFunctions))
validateInfoTable :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable
validateInfoTable tab = do
mapM_ (validateFunction tab) (HashMap.elems (tab ^. infoFunctions))
return tab

validate :: InfoTable -> Maybe AsmError
validate tab =
case run $ runError $ validateInfoTable tab of
Left err -> Just err
_ -> Nothing

computeFunctionStackUsage :: Member (Error AsmError) r => InfoTable -> FunctionInfo -> Sem r FunctionInfo
computeFunctionStackUsage tab fi = do
ps <- snd <$> recurseS sig initialStackInfo (fi ^. functionCode)
let maxValueStack = maximum (map fst ps)
maxTempStack = maximum (map snd ps)
return
fi
{ _functionMaxValueStackHeight = maxValueStack,
_functionMaxTempStackHeight = maxTempStack
}
where
sig :: RecursorSig StackInfo r (Int, Int)
sig =
RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = \si _ -> return (si ^. stackInfoValueStackHeight, si ^. stackInfoTempStackHeight),
_recurseBranch = \si _ l r ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map fst l)) (maximum (map fst r))),
max (si ^. stackInfoTempStackHeight) (max (maximum (map snd l)) (maximum (map snd r)))
),
_recurseCase = \si _ cs md ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map (maximum . map fst) cs)) (maybe 0 (maximum . map fst) md)),
max (si ^. stackInfoTempStackHeight) (max (maximum (map (maximum . map snd) cs)) (maybe 0 (maximum . map snd) md))
)
}

computeStackUsage :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable
computeStackUsage tab = do
fns <- mapM (computeFunctionStackUsage tab) (tab ^. infoFunctions)
return tab {_infoFunctions = fns}

computeStackUsage' :: InfoTable -> Either AsmError InfoTable
computeStackUsage' tab = run $ runError $ computeStackUsage tab
170 changes: 163 additions & 7 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@ import Juvix.Compiler.Asm.Language
import Juvix.Compiler.Asm.Pretty

-- | Recursor signature. Contains read-only recursor parameters.
data RecursorSig r a = RecursorSig
data RecursorSig m r a = RecursorSig
{ _recursorInfoTable :: InfoTable,
_recurseInstr :: Memory -> CmdInstr -> Sem r a,
_recurseBranch :: Memory -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: Memory -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a
_recurseInstr :: m -> CmdInstr -> Sem r a,
_recurseBranch :: m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a
}

makeLenses ''RecursorSig

recurseFun :: Member (Error AsmError) r => RecursorSig r a -> FunctionInfo -> Sem r [a]
recurseFun :: Member (Error AsmError) r => RecursorSig Memory r a -> FunctionInfo -> Sem r [a]
recurseFun sig fi = recurse sig (argumentsFromFunctionInfo fi) (fi ^. functionCode)

recurse :: Member (Error AsmError) r => RecursorSig r a -> Arguments -> Code -> Sem r [a]
recurse :: Member (Error AsmError) r => RecursorSig Memory r a -> Arguments -> Code -> Sem r [a]
recurse sig args = fmap snd . recurse' sig (mkMemory args)

recurse' :: forall r a. Member (Error AsmError) r => RecursorSig r a -> Memory -> Code -> Sem r (Memory, [a])
recurse' :: forall r a. Member (Error AsmError) r => RecursorSig Memory r a -> Memory -> Code -> Sem r (Memory, [a])
recurse' sig = go True
where
go :: Bool -> Memory -> Code -> Sem r (Memory, [a])
Expand Down Expand Up @@ -276,3 +276,159 @@ recurse' sig = go True
)
$ throw
$ AsmError loc "temporary stack height changed after branching"

data StackInfo = StackInfo
{ _stackInfoValueStackHeight :: Int,
_stackInfoTempStackHeight :: Int
}
deriving stock (Eq)

makeLenses ''StackInfo

initialStackInfo :: StackInfo
initialStackInfo = StackInfo {_stackInfoValueStackHeight = 0, _stackInfoTempStackHeight = 0}

-- | A simplified recursor which doesn't perform validity checking and only
-- computes stack height information. This makes a significant performance
-- difference, since we'll have many passes recursing over the entire JuvixAsm
-- program code which need only stack height information. Also, the code using
-- the simplified recursor can itself be simpler if it doesn't need the extra
-- info provided by the full recursor.
recurseS :: forall r a. Member (Error AsmError) r => RecursorSig StackInfo r a -> StackInfo -> Code -> Sem r (StackInfo, [a])
recurseS sig = go
where
go :: StackInfo -> Code -> Sem r (StackInfo, [a])
go si = \case
[] -> return (si, [])
h : t -> case h of
Instr x -> do
goNextCmd (goInstr si x) t
Branch x ->
goNextCmd (goBranch si x) t
Case x ->
goNextCmd (goCase si x) t

goNextCmd :: Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd mp t = do
(si', r) <- mp
(si'', rs) <- go si' t
return (si'', r : rs)

goInstr :: StackInfo -> CmdInstr -> Sem r (StackInfo, a)
goInstr stackInfo cmd = do
a <- (sig ^. recurseInstr) stackInfo cmd
si' <- fixStackInstr stackInfo (cmd ^. cmdInstrInstruction)
return (si', a)
where
fixStackInstr :: StackInfo -> Instruction -> Sem r StackInfo
fixStackInstr si instr =
case instr of
Binop IntAdd ->
fixStackBinOp si
Binop IntSub ->
fixStackBinOp si
Binop IntMul ->
fixStackBinOp si
Binop IntDiv ->
fixStackBinOp si
Binop IntMod ->
fixStackBinOp si
Binop IntLt ->
fixStackBinOp si
Binop IntLe ->
fixStackBinOp si
Binop ValEq ->
fixStackBinOp si
Push {} -> do
return (stackInfoPushValueStack 1 si)
Pop -> do
return (stackInfoPopValueStack 1 si)
PushTemp -> do
return $ stackInfoPushTempStack 1 (stackInfoPopValueStack 1 si)
PopTemp -> do
return $ stackInfoPopTempStack 1 si
Trace ->
return si
Dump ->
return si
Failure ->
return si
AllocConstr tag -> do
let ci = getConstrInfo (sig ^. recursorInfoTable) tag
n = ci ^. constructorArgsNum
return $
stackInfoPopValueStack (n - 1) si
AllocClosure InstrAllocClosure {..} -> do
return $
stackInfoPopValueStack (_allocClosureArgsNum - 1) si
ExtendClosure InstrExtendClosure {..} ->
return $
stackInfoPopValueStack (_extendClosureArgsNum - 1) si
Call x ->
fixStackCall si x
TailCall x ->
fixStackCall si x
CallClosures x ->
fixStackCallClosures si x
TailCallClosures x ->
fixStackCallClosures si x
Return ->
return si

fixStackBinOp :: StackInfo -> Sem r StackInfo
fixStackBinOp si = return $ stackInfoPopValueStack 1 si

fixStackCall :: StackInfo -> InstrCall -> Sem r StackInfo
fixStackCall si InstrCall {..} = do
return $ stackInfoPopValueStack (_callArgsNum + (if _callType == CallClosure then 1 else 0) - 1) si

fixStackCallClosures :: StackInfo -> InstrCallClosures -> Sem r StackInfo
fixStackCallClosures si InstrCallClosures {..} = do
return $ stackInfoPopValueStack (_callClosuresArgsNum - 1) si

goBranch :: StackInfo -> CmdBranch -> Sem r (StackInfo, a)
goBranch si cmd@CmdBranch {..} = do
let si0 = stackInfoPopValueStack 1 si
(si1, as1) <- go si0 _cmdBranchTrue
(si2, as2) <- go si0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) si cmd as1 as2
checkStackInfo loc si1 si2
return (si1, a')
where
loc = cmd ^. cmdBranchInfo . commandInfoLocation

goCase :: StackInfo -> CmdCase -> Sem r (StackInfo, a)
goCase si cmd@CmdCase {..} = do
rs <- mapM (go si . (^. caseBranchCode)) _cmdCaseBranches
let sis = map fst rs
ass = map snd rs
rd <- maybe (return Nothing) (fmap Just . go si) _cmdCaseDefault
let sd = fmap fst rd
ad = fmap snd rd
a' <- (sig ^. recurseCase) si cmd ass ad
case sis of
[] -> return (fromMaybe si sd, a')
si0 : sis' -> do
mapM_ (checkStackInfo loc si0) sis'
forM_ sd (checkStackInfo loc si0)
return (si0, a')
where
loc = cmd ^. (cmdCaseInfo . commandInfoLocation)

checkStackInfo :: Maybe Location -> StackInfo -> StackInfo -> Sem r ()
checkStackInfo loc si1 si2 =
when (si1 /= si2) $
throw $
AsmError loc "stack height mismatch"

stackInfoPushValueStack :: Int -> StackInfo -> StackInfo
stackInfoPushValueStack n si = si {_stackInfoValueStackHeight = si ^. stackInfoValueStackHeight + n}

stackInfoPopValueStack :: Int -> StackInfo -> StackInfo
stackInfoPopValueStack n si = si {_stackInfoValueStackHeight = si ^. stackInfoValueStackHeight - n}

stackInfoPushTempStack :: Int -> StackInfo -> StackInfo
stackInfoPushTempStack n si = si {_stackInfoTempStackHeight = si ^. stackInfoTempStackHeight + n}

stackInfoPopTempStack :: Int -> StackInfo -> StackInfo
stackInfoPopTempStack n si = si {_stackInfoTempStackHeight = si ^. stackInfoTempStackHeight - n}
10 changes: 10 additions & 0 deletions src/Juvix/Compiler/Asm/Pipeline.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module Juvix.Compiler.Asm.Pipeline where

import Juvix.Compiler.Asm.Data.InfoTable
import Juvix.Compiler.Asm.Extra
import Juvix.Compiler.Asm.Language

-- | Perform transformations on JuvixAsm necessary before the translation to
-- JuvixReg
toReg :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable
toReg = validateInfoTable >=> computeStackUsage
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Reg/Translation/FromAsm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fromAsmFun tab fi =
Left err -> error (show err)
Right code -> code
where
sig :: Asm.RecursorSig (Error Asm.AsmError ': r) Instruction
sig :: Asm.RecursorSig Asm.Memory (Error Asm.AsmError ': r) Instruction
sig =
Asm.RecursorSig
{ _recursorInfoTable = tab,
Expand Down

0 comments on commit 0d90a61

Please sign in to comment.