diff --git a/src/Juvix/Compiler/Asm/Extra.hs b/src/Juvix/Compiler/Asm/Extra.hs index d3f37d22b5..d4faa04ef6 100644 --- a/src/Juvix/Compiler/Asm/Extra.hs +++ b/src/Juvix/Compiler/Asm/Extra.hs @@ -18,7 +18,7 @@ import Juvix.Compiler.Asm.Language validateCode :: forall r. Member (Error AsmError) r => InfoTable -> Arguments -> Code -> Sem r () validateCode tab args = void . recurse sig args where - sig :: RecursorSig r () + sig :: RecursorSig Memory r () sig = RecursorSig { _recursorInfoTable = tab, @@ -30,11 +30,49 @@ validateCode tab args = void . recurse sig args validateFunction :: Member (Error AsmError) r => InfoTable -> FunctionInfo -> Sem r () validateFunction tab fi = validateCode tab (argumentsFromFunctionInfo fi) (fi ^. functionCode) -validateInfoTable :: Member (Error AsmError) r => InfoTable -> Sem r () -validateInfoTable tab = mapM_ (validateFunction tab) (HashMap.elems (tab ^. infoFunctions)) +validateInfoTable :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable +validateInfoTable tab = do + mapM_ (validateFunction tab) (HashMap.elems (tab ^. infoFunctions)) + return tab validate :: InfoTable -> Maybe AsmError validate tab = case run $ runError $ validateInfoTable tab of Left err -> Just err _ -> Nothing + +computeFunctionStackUsage :: Member (Error AsmError) r => InfoTable -> FunctionInfo -> Sem r FunctionInfo +computeFunctionStackUsage tab fi = do + ps <- snd <$> recurseS sig initialStackInfo (fi ^. functionCode) + let maxValueStack = maximum (map fst ps) + maxTempStack = maximum (map snd ps) + return + fi + { _functionMaxValueStackHeight = maxValueStack, + _functionMaxTempStackHeight = maxTempStack + } + where + sig :: RecursorSig StackInfo r (Int, Int) + sig = + RecursorSig + { _recursorInfoTable = tab, + _recurseInstr = \si _ -> return (si ^. stackInfoValueStackHeight, si ^. stackInfoTempStackHeight), + _recurseBranch = \si _ l r -> + return + ( max (si ^. stackInfoValueStackHeight) (max (maximum (map fst l)) (maximum (map fst r))), + max (si ^. stackInfoTempStackHeight) (max (maximum (map snd l)) (maximum (map snd r))) + ), + _recurseCase = \si _ cs md -> + return + ( max (si ^. stackInfoValueStackHeight) (max (maximum (map (maximum . map fst) cs)) (maybe 0 (maximum . map fst) md)), + max (si ^. stackInfoTempStackHeight) (max (maximum (map (maximum . map snd) cs)) (maybe 0 (maximum . map snd) md)) + ) + } + +computeStackUsage :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable +computeStackUsage tab = do + fns <- mapM (computeFunctionStackUsage tab) (tab ^. infoFunctions) + return tab {_infoFunctions = fns} + +computeStackUsage' :: InfoTable -> Either AsmError InfoTable +computeStackUsage' tab = run $ runError $ computeStackUsage tab diff --git a/src/Juvix/Compiler/Asm/Extra/Recursors.hs b/src/Juvix/Compiler/Asm/Extra/Recursors.hs index 05ba93ea7c..09ac50e18e 100644 --- a/src/Juvix/Compiler/Asm/Extra/Recursors.hs +++ b/src/Juvix/Compiler/Asm/Extra/Recursors.hs @@ -14,22 +14,22 @@ import Juvix.Compiler.Asm.Language import Juvix.Compiler.Asm.Pretty -- | Recursor signature. Contains read-only recursor parameters. -data RecursorSig r a = RecursorSig +data RecursorSig m r a = RecursorSig { _recursorInfoTable :: InfoTable, - _recurseInstr :: Memory -> CmdInstr -> Sem r a, - _recurseBranch :: Memory -> CmdBranch -> [a] -> [a] -> Sem r a, - _recurseCase :: Memory -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a + _recurseInstr :: m -> CmdInstr -> Sem r a, + _recurseBranch :: m -> CmdBranch -> [a] -> [a] -> Sem r a, + _recurseCase :: m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a } makeLenses ''RecursorSig -recurseFun :: Member (Error AsmError) r => RecursorSig r a -> FunctionInfo -> Sem r [a] +recurseFun :: Member (Error AsmError) r => RecursorSig Memory r a -> FunctionInfo -> Sem r [a] recurseFun sig fi = recurse sig (argumentsFromFunctionInfo fi) (fi ^. functionCode) -recurse :: Member (Error AsmError) r => RecursorSig r a -> Arguments -> Code -> Sem r [a] +recurse :: Member (Error AsmError) r => RecursorSig Memory r a -> Arguments -> Code -> Sem r [a] recurse sig args = fmap snd . recurse' sig (mkMemory args) -recurse' :: forall r a. Member (Error AsmError) r => RecursorSig r a -> Memory -> Code -> Sem r (Memory, [a]) +recurse' :: forall r a. Member (Error AsmError) r => RecursorSig Memory r a -> Memory -> Code -> Sem r (Memory, [a]) recurse' sig = go True where go :: Bool -> Memory -> Code -> Sem r (Memory, [a]) @@ -276,3 +276,159 @@ recurse' sig = go True ) $ throw $ AsmError loc "temporary stack height changed after branching" + +data StackInfo = StackInfo + { _stackInfoValueStackHeight :: Int, + _stackInfoTempStackHeight :: Int + } + deriving stock (Eq) + +makeLenses ''StackInfo + +initialStackInfo :: StackInfo +initialStackInfo = StackInfo {_stackInfoValueStackHeight = 0, _stackInfoTempStackHeight = 0} + +-- | A simplified recursor which doesn't perform validity checking and only +-- computes stack height information. This makes a significant performance +-- difference, since we'll have many passes recursing over the entire JuvixAsm +-- program code which need only stack height information. Also, the code using +-- the simplified recursor can itself be simpler if it doesn't need the extra +-- info provided by the full recursor. +recurseS :: forall r a. Member (Error AsmError) r => RecursorSig StackInfo r a -> StackInfo -> Code -> Sem r (StackInfo, [a]) +recurseS sig = go + where + go :: StackInfo -> Code -> Sem r (StackInfo, [a]) + go si = \case + [] -> return (si, []) + h : t -> case h of + Instr x -> do + goNextCmd (goInstr si x) t + Branch x -> + goNextCmd (goBranch si x) t + Case x -> + goNextCmd (goCase si x) t + + goNextCmd :: Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a]) + goNextCmd mp t = do + (si', r) <- mp + (si'', rs) <- go si' t + return (si'', r : rs) + + goInstr :: StackInfo -> CmdInstr -> Sem r (StackInfo, a) + goInstr stackInfo cmd = do + a <- (sig ^. recurseInstr) stackInfo cmd + si' <- fixStackInstr stackInfo (cmd ^. cmdInstrInstruction) + return (si', a) + where + fixStackInstr :: StackInfo -> Instruction -> Sem r StackInfo + fixStackInstr si instr = + case instr of + Binop IntAdd -> + fixStackBinOp si + Binop IntSub -> + fixStackBinOp si + Binop IntMul -> + fixStackBinOp si + Binop IntDiv -> + fixStackBinOp si + Binop IntMod -> + fixStackBinOp si + Binop IntLt -> + fixStackBinOp si + Binop IntLe -> + fixStackBinOp si + Binop ValEq -> + fixStackBinOp si + Push {} -> do + return (stackInfoPushValueStack 1 si) + Pop -> do + return (stackInfoPopValueStack 1 si) + PushTemp -> do + return $ stackInfoPushTempStack 1 (stackInfoPopValueStack 1 si) + PopTemp -> do + return $ stackInfoPopTempStack 1 si + Trace -> + return si + Dump -> + return si + Failure -> + return si + AllocConstr tag -> do + let ci = getConstrInfo (sig ^. recursorInfoTable) tag + n = ci ^. constructorArgsNum + return $ + stackInfoPopValueStack (n - 1) si + AllocClosure InstrAllocClosure {..} -> do + return $ + stackInfoPopValueStack (_allocClosureArgsNum - 1) si + ExtendClosure InstrExtendClosure {..} -> + return $ + stackInfoPopValueStack (_extendClosureArgsNum - 1) si + Call x -> + fixStackCall si x + TailCall x -> + fixStackCall si x + CallClosures x -> + fixStackCallClosures si x + TailCallClosures x -> + fixStackCallClosures si x + Return -> + return si + + fixStackBinOp :: StackInfo -> Sem r StackInfo + fixStackBinOp si = return $ stackInfoPopValueStack 1 si + + fixStackCall :: StackInfo -> InstrCall -> Sem r StackInfo + fixStackCall si InstrCall {..} = do + return $ stackInfoPopValueStack (_callArgsNum + (if _callType == CallClosure then 1 else 0) - 1) si + + fixStackCallClosures :: StackInfo -> InstrCallClosures -> Sem r StackInfo + fixStackCallClosures si InstrCallClosures {..} = do + return $ stackInfoPopValueStack (_callClosuresArgsNum - 1) si + + goBranch :: StackInfo -> CmdBranch -> Sem r (StackInfo, a) + goBranch si cmd@CmdBranch {..} = do + let si0 = stackInfoPopValueStack 1 si + (si1, as1) <- go si0 _cmdBranchTrue + (si2, as2) <- go si0 _cmdBranchFalse + a' <- (sig ^. recurseBranch) si cmd as1 as2 + checkStackInfo loc si1 si2 + return (si1, a') + where + loc = cmd ^. cmdBranchInfo . commandInfoLocation + + goCase :: StackInfo -> CmdCase -> Sem r (StackInfo, a) + goCase si cmd@CmdCase {..} = do + rs <- mapM (go si . (^. caseBranchCode)) _cmdCaseBranches + let sis = map fst rs + ass = map snd rs + rd <- maybe (return Nothing) (fmap Just . go si) _cmdCaseDefault + let sd = fmap fst rd + ad = fmap snd rd + a' <- (sig ^. recurseCase) si cmd ass ad + case sis of + [] -> return (fromMaybe si sd, a') + si0 : sis' -> do + mapM_ (checkStackInfo loc si0) sis' + forM_ sd (checkStackInfo loc si0) + return (si0, a') + where + loc = cmd ^. (cmdCaseInfo . commandInfoLocation) + + checkStackInfo :: Maybe Location -> StackInfo -> StackInfo -> Sem r () + checkStackInfo loc si1 si2 = + when (si1 /= si2) $ + throw $ + AsmError loc "stack height mismatch" + + stackInfoPushValueStack :: Int -> StackInfo -> StackInfo + stackInfoPushValueStack n si = si {_stackInfoValueStackHeight = si ^. stackInfoValueStackHeight + n} + + stackInfoPopValueStack :: Int -> StackInfo -> StackInfo + stackInfoPopValueStack n si = si {_stackInfoValueStackHeight = si ^. stackInfoValueStackHeight - n} + + stackInfoPushTempStack :: Int -> StackInfo -> StackInfo + stackInfoPushTempStack n si = si {_stackInfoTempStackHeight = si ^. stackInfoTempStackHeight + n} + + stackInfoPopTempStack :: Int -> StackInfo -> StackInfo + stackInfoPopTempStack n si = si {_stackInfoTempStackHeight = si ^. stackInfoTempStackHeight - n} diff --git a/src/Juvix/Compiler/Asm/Pipeline.hs b/src/Juvix/Compiler/Asm/Pipeline.hs new file mode 100644 index 0000000000..1ba5a31ea8 --- /dev/null +++ b/src/Juvix/Compiler/Asm/Pipeline.hs @@ -0,0 +1,10 @@ +module Juvix.Compiler.Asm.Pipeline where + +import Juvix.Compiler.Asm.Data.InfoTable +import Juvix.Compiler.Asm.Extra +import Juvix.Compiler.Asm.Language + +-- | Perform transformations on JuvixAsm necessary before the translation to +-- JuvixReg +toReg :: Member (Error AsmError) r => InfoTable -> Sem r InfoTable +toReg = validateInfoTable >=> computeStackUsage diff --git a/src/Juvix/Compiler/Reg/Translation/FromAsm.hs b/src/Juvix/Compiler/Reg/Translation/FromAsm.hs index f1be5bb711..42cb3a0a24 100644 --- a/src/Juvix/Compiler/Reg/Translation/FromAsm.hs +++ b/src/Juvix/Compiler/Reg/Translation/FromAsm.hs @@ -59,7 +59,7 @@ fromAsmFun tab fi = Left err -> error (show err) Right code -> code where - sig :: Asm.RecursorSig (Error Asm.AsmError ': r) Instruction + sig :: Asm.RecursorSig Asm.Memory (Error Asm.AsmError ': r) Instruction sig = Asm.RecursorSig { _recursorInfoTable = tab,