diff --git a/app/Commands/Dev/Reg/Read.hs b/app/Commands/Dev/Reg/Read.hs index f8a86ebe5f..2d58b6f64f 100644 --- a/app/Commands/Dev/Reg/Read.hs +++ b/app/Commands/Dev/Reg/Read.hs @@ -2,7 +2,7 @@ module Commands.Dev.Reg.Read where import Commands.Base import Commands.Dev.Reg.Read.Options -import Juvix.Compiler.Reg.Pretty qualified as Reg +import Juvix.Compiler.Reg.Pretty qualified as Reg hiding (defaultOptions) import Juvix.Compiler.Reg.Transformation qualified as Reg import Juvix.Compiler.Reg.Translation.FromSource qualified as Reg import RegInterpreter @@ -15,7 +15,10 @@ runCommand opts = do Left err -> exitJuvixError (JuvixError err) Right tab -> do - r <- runError @JuvixError (Reg.applyTransformations (project opts ^. regReadTransformations) tab) + r <- + runError @JuvixError + . runReader Reg.defaultOptions + $ (Reg.applyTransformations (project opts ^. regReadTransformations) tab) case r of Left err -> exitJuvixError (JuvixError err) Right tab' -> do diff --git a/src/Juvix/Compiler/Asm/Options.hs b/src/Juvix/Compiler/Asm/Options.hs index 417d95c541..46cb60789e 100644 --- a/src/Juvix/Compiler/Asm/Options.hs +++ b/src/Juvix/Compiler/Asm/Options.hs @@ -6,11 +6,13 @@ where import Juvix.Compiler.Backend import Juvix.Compiler.Pipeline.EntryPoint +import Juvix.Compiler.Tree.Options qualified as Tree import Juvix.Prelude data Options = Options { _optDebug :: Bool, - _optLimits :: Limits + _optLimits :: Limits, + _optTreeOptions :: Tree.Options } makeLenses ''Options @@ -19,7 +21,8 @@ makeOptions :: Target -> Bool -> Options makeOptions tgt debug = Options { _optDebug = debug, - _optLimits = getLimits tgt debug + _optLimits = getLimits tgt debug, + _optTreeOptions = Tree.defaultOptions } getClosureSize :: Options -> Int -> Int @@ -29,5 +32,6 @@ fromEntryPoint :: EntryPoint -> Options fromEntryPoint e@EntryPoint {..} = Options { _optDebug = _entryPointDebug, - _optLimits = getLimits (getEntryPointTarget e) _entryPointDebug + _optLimits = getLimits (getEntryPointTarget e) _entryPointDebug, + _optTreeOptions = Tree.fromEntryPoint e } diff --git a/src/Juvix/Compiler/Pipeline.hs b/src/Juvix/Compiler/Pipeline.hs index 8fac49cc11..58cd880cb1 100644 --- a/src/Juvix/Compiler/Pipeline.hs +++ b/src/Juvix/Compiler/Pipeline.hs @@ -358,18 +358,24 @@ regToRust = regToRust' Rust.BackendRust regToRiscZeroRust :: (Member (Reader EntryPoint) r) => Reg.InfoTable -> Sem r Rust.Result regToRiscZeroRust = regToRust' Rust.BackendRiscZero -regToCasm :: Reg.InfoTable -> Sem r Casm.Result +regToCasm :: (Member (Reader EntryPoint) r) => Reg.InfoTable -> Sem r Casm.Result regToCasm = Reg.toCasm >=> return . Casm.fromReg +regToCasm' :: (Member (Reader Reg.Options) r) => Reg.InfoTable -> Sem r Casm.Result +regToCasm' = Reg.toCasm' >=> return . Casm.fromReg + casmToCairo :: Casm.Result -> Sem r Cairo.Result casmToCairo Casm.Result {..} = return . Cairo.serialize _resultOutputSize (map Casm.builtinName _resultBuiltins) $ Cairo.fromCasm _resultCode -regToCairo :: Reg.InfoTable -> Sem r Cairo.Result +regToCairo :: (Member (Reader EntryPoint) r) => Reg.InfoTable -> Sem r Cairo.Result regToCairo = regToCasm >=> casmToCairo +regToCairo' :: (Member (Reader Reg.Options) r) => Reg.InfoTable -> Sem r Cairo.Result +regToCairo' = regToCasm' >=> casmToCairo + treeToAnoma' :: (Members '[Error JuvixError, Reader NockmaTree.CompilerOptions] r) => Tree.InfoTable -> Sem r NockmaTree.AnomaResult treeToAnoma' = Tree.toNockma >=> NockmaTree.fromTreeTable @@ -378,6 +384,6 @@ asmToMiniC' = mapError (JuvixError @Asm.AsmError) . Asm.toReg' >=> regToMiniC' . regToMiniC' :: (Member (Reader Asm.Options) r) => Reg.InfoTable -> Sem r C.MiniCResult regToMiniC' tab = do - tab' <- Reg.toC tab + tab' <- mapReader (^. Asm.optTreeOptions) $ Reg.toC' tab e <- ask return $ C.fromReg (e ^. Asm.optLimits) tab' diff --git a/src/Juvix/Compiler/Reg/Data/TransformationId.hs b/src/Juvix/Compiler/Reg/Data/TransformationId.hs index a353dd9eb0..7dca328397 100644 --- a/src/Juvix/Compiler/Reg/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Reg/Data/TransformationId.hs @@ -11,6 +11,9 @@ data TransformationId | InitBranchVars | CopyPropagation | ConstantPropagation + | DeadCodeElimination + | OptPhaseMain + | OptPhaseCairo deriving stock (Data, Bounded, Enum, Show) data PipelineId @@ -28,7 +31,7 @@ toRustTransformations :: [TransformationId] toRustTransformations = [Cleanup] toCasmTransformations :: [TransformationId] -toCasmTransformations = [Cleanup, CopyPropagation, ConstantPropagation, SSA] +toCasmTransformations = [Cleanup, SSA, OptPhaseCairo] instance TransformationId' TransformationId where transformationText :: TransformationId -> Text @@ -39,6 +42,9 @@ instance TransformationId' TransformationId where InitBranchVars -> strInitBranchVars CopyPropagation -> strCopyPropagation ConstantPropagation -> strConstantPropagation + DeadCodeElimination -> strDeadCodeElimination + OptPhaseMain -> strOptPhaseMain + OptPhaseCairo -> strOptPhaseCairo instance PipelineId' TransformationId PipelineId where pipelineText :: PipelineId -> Text diff --git a/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs index 69363f81d8..aa44771ee0 100644 --- a/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs @@ -28,3 +28,12 @@ strCopyPropagation = "copy-propagation" strConstantPropagation :: Text strConstantPropagation = "constant-propagation" + +strDeadCodeElimination :: Text +strDeadCodeElimination = "dead-code" + +strOptPhaseMain :: Text +strOptPhaseMain = "opt-main" + +strOptPhaseCairo :: Text +strOptPhaseCairo = "opt-cairo" diff --git a/src/Juvix/Compiler/Reg/Extra/Base.hs b/src/Juvix/Compiler/Reg/Extra/Base.hs index 79128ab604..e9003737af 100644 --- a/src/Juvix/Compiler/Reg/Extra/Base.hs +++ b/src/Juvix/Compiler/Reg/Extra/Base.hs @@ -28,135 +28,168 @@ setResultVar instr vref = case instr of CallClosures x -> CallClosures $ set instrCallClosuresResult vref x _ -> impossible -overValueRefs' :: (VarRef -> Value) -> Instruction -> Instruction -overValueRefs' f = \case - Binop x -> Binop $ goBinop x - Unop x -> Unop $ goUnop x - Cairo x -> Cairo $ goCairo x - Assign x -> Assign $ goAssign x - Alloc x -> Alloc $ goAlloc x - AllocClosure x -> AllocClosure $ goAllocClosure x - ExtendClosure x -> ExtendClosure $ goExtendClosure x - Call x -> Call $ goCall x - CallClosures x -> CallClosures $ goCallClosures x - TailCall x -> TailCall $ goTailCall x - TailCallClosures x -> TailCallClosures $ goTailCallClosures x - Return x -> Return $ goReturn x - Branch x -> Branch $ goBranch x - Case x -> Case $ goCase x - Trace x -> Trace $ goTrace x - Dump -> Dump - Failure x -> Failure $ goFailure x - Prealloc x -> Prealloc $ goPrealloc x - Nop -> Nop - Block x -> Block $ goBlock x +getOutVar :: Instruction -> Maybe VarRef +getOutVar = \case + Branch x -> x ^. instrBranchOutVar + Case x -> x ^. instrCaseOutVar + _ -> Nothing + +overValueRefs'' :: forall m. (Monad m) => (VarRef -> m Value) -> Instruction -> m Instruction +overValueRefs'' f = \case + Binop x -> Binop <$> goBinop x + Unop x -> Unop <$> goUnop x + Cairo x -> Cairo <$> goCairo x + Assign x -> Assign <$> goAssign x + Alloc x -> Alloc <$> goAlloc x + AllocClosure x -> AllocClosure <$> goAllocClosure x + ExtendClosure x -> ExtendClosure <$> goExtendClosure x + Call x -> Call <$> goCall x + CallClosures x -> CallClosures <$> goCallClosures x + TailCall x -> TailCall <$> goTailCall x + TailCallClosures x -> TailCallClosures <$> goTailCallClosures x + Return x -> Return <$> goReturn x + Branch x -> Branch <$> goBranch x + Case x -> Case <$> goCase x + Trace x -> Trace <$> goTrace x + Dump -> return Dump + Failure x -> Failure <$> goFailure x + Prealloc x -> Prealloc <$> goPrealloc x + Nop -> return Nop + Block x -> Block <$> goBlock x where fromVarRef :: Value -> VarRef fromVarRef = \case VRef r -> r _ -> impossible - goConstrField :: ConstrField -> ConstrField - goConstrField = over constrFieldRef (fromVarRef . f) + goConstrField :: ConstrField -> m ConstrField + goConstrField = overM constrFieldRef (fmap fromVarRef . f) - goValue :: Value -> Value + goValue :: Value -> m Value goValue = \case - ValConst c -> ValConst c - CRef x -> CRef $ goConstrField x + ValConst c -> return $ ValConst c + CRef x -> CRef <$> goConstrField x VRef x -> f x - goBinop :: InstrBinop -> InstrBinop - goBinop InstrBinop {..} = - InstrBinop - { _instrBinopArg1 = goValue _instrBinopArg1, - _instrBinopArg2 = goValue _instrBinopArg2, - .. - } - - goUnop :: InstrUnop -> InstrUnop - goUnop = over instrUnopArg goValue - - goCairo :: InstrCairo -> InstrCairo - goCairo = over instrCairoArgs (map goValue) - - goAssign :: InstrAssign -> InstrAssign - goAssign = over instrAssignValue goValue - - goAlloc :: InstrAlloc -> InstrAlloc - goAlloc = over instrAllocArgs (map goValue) - - goAllocClosure :: InstrAllocClosure -> InstrAllocClosure - goAllocClosure = over instrAllocClosureArgs (map goValue) - - goExtendClosure :: InstrExtendClosure -> InstrExtendClosure - goExtendClosure InstrExtendClosure {..} = - InstrExtendClosure - { _instrExtendClosureValue = fromVarRef (f _instrExtendClosureValue), - _instrExtendClosureArgs = map goValue _instrExtendClosureArgs, - .. - } - - goCallType :: CallType -> CallType + goBinop :: InstrBinop -> m InstrBinop + goBinop InstrBinop {..} = do + arg1 <- goValue _instrBinopArg1 + arg2 <- goValue _instrBinopArg2 + return + InstrBinop + { _instrBinopArg1 = arg1, + _instrBinopArg2 = arg2, + .. + } + + goUnop :: InstrUnop -> m InstrUnop + goUnop = overM instrUnopArg goValue + + goCairo :: InstrCairo -> m InstrCairo + goCairo = overM instrCairoArgs (mapM goValue) + + goAssign :: InstrAssign -> m InstrAssign + goAssign = overM instrAssignValue goValue + + goAlloc :: InstrAlloc -> m InstrAlloc + goAlloc = overM instrAllocArgs (mapM goValue) + + goAllocClosure :: InstrAllocClosure -> m InstrAllocClosure + goAllocClosure = overM instrAllocClosureArgs (mapM goValue) + + goExtendClosure :: InstrExtendClosure -> m InstrExtendClosure + goExtendClosure InstrExtendClosure {..} = do + val <- f _instrExtendClosureValue + args <- mapM goValue _instrExtendClosureArgs + return + InstrExtendClosure + { _instrExtendClosureValue = fromVarRef val, + _instrExtendClosureArgs = args, + .. + } + + goCallType :: CallType -> m CallType goCallType = \case - CallFun sym -> CallFun sym - CallClosure cl -> CallClosure (fromVarRef (f cl)) - - goCall :: InstrCall -> InstrCall - goCall InstrCall {..} = - InstrCall - { _instrCallType = goCallType _instrCallType, - _instrCallArgs = map goValue _instrCallArgs, - .. - } - - goCallClosures :: InstrCallClosures -> InstrCallClosures - goCallClosures InstrCallClosures {..} = - InstrCallClosures - { _instrCallClosuresArgs = map goValue _instrCallClosuresArgs, - _instrCallClosuresValue = fromVarRef (f _instrCallClosuresValue), - .. - } - - goTailCall :: InstrTailCall -> InstrTailCall - goTailCall InstrTailCall {..} = - InstrTailCall - { _instrTailCallType = goCallType _instrTailCallType, - _instrTailCallArgs = map goValue _instrTailCallArgs, - .. - } - - goTailCallClosures :: InstrTailCallClosures -> InstrTailCallClosures - goTailCallClosures InstrTailCallClosures {..} = - InstrTailCallClosures - { _instrTailCallClosuresValue = fromVarRef (f _instrTailCallClosuresValue), - _instrTailCallClosuresArgs = map goValue _instrTailCallClosuresArgs, - .. - } - - goReturn :: InstrReturn -> InstrReturn - goReturn = over instrReturnValue goValue - - goBranch :: InstrBranch -> InstrBranch - goBranch = over instrBranchValue goValue - - goCase :: InstrCase -> InstrCase - goCase = over instrCaseValue goValue - - goTrace :: InstrTrace -> InstrTrace - goTrace = over instrTraceValue goValue - - goFailure :: InstrFailure -> InstrFailure - goFailure = over instrFailureValue goValue - - goPrealloc :: InstrPrealloc -> InstrPrealloc - goPrealloc x = x - - goBlock :: InstrBlock -> InstrBlock - goBlock x = x + CallFun sym -> return $ CallFun sym + CallClosure cl -> do + val <- f cl + return $ CallClosure (fromVarRef val) + + goCall :: InstrCall -> m InstrCall + goCall InstrCall {..} = do + ct <- goCallType _instrCallType + args <- mapM goValue _instrCallArgs + return $ + InstrCall + { _instrCallType = ct, + _instrCallArgs = args, + .. + } + + goCallClosures :: InstrCallClosures -> m InstrCallClosures + goCallClosures InstrCallClosures {..} = do + args <- mapM goValue _instrCallClosuresArgs + val <- f _instrCallClosuresValue + return $ + InstrCallClosures + { _instrCallClosuresArgs = args, + _instrCallClosuresValue = fromVarRef val, + .. + } + + goTailCall :: InstrTailCall -> m InstrTailCall + goTailCall InstrTailCall {..} = do + ct <- goCallType _instrTailCallType + args <- mapM goValue _instrTailCallArgs + return + InstrTailCall + { _instrTailCallType = ct, + _instrTailCallArgs = args, + .. + } + + goTailCallClosures :: InstrTailCallClosures -> m InstrTailCallClosures + goTailCallClosures InstrTailCallClosures {..} = do + val <- f _instrTailCallClosuresValue + args <- mapM goValue _instrTailCallClosuresArgs + return + InstrTailCallClosures + { _instrTailCallClosuresValue = fromVarRef val, + _instrTailCallClosuresArgs = args, + .. + } + + goReturn :: InstrReturn -> m InstrReturn + goReturn = overM instrReturnValue goValue + + goBranch :: InstrBranch -> m InstrBranch + goBranch = overM instrBranchValue goValue + + goCase :: InstrCase -> m InstrCase + goCase = overM instrCaseValue goValue + + goTrace :: InstrTrace -> m InstrTrace + goTrace = overM instrTraceValue goValue + + goFailure :: InstrFailure -> m InstrFailure + goFailure = overM instrFailureValue goValue + + goPrealloc :: InstrPrealloc -> m InstrPrealloc + goPrealloc x = return x + + goBlock :: InstrBlock -> m InstrBlock + goBlock x = return x + +overValueRefs' :: (VarRef -> Value) -> Instruction -> Instruction +overValueRefs' f = runIdentity . overValueRefs'' (return . f) overValueRefs :: (VarRef -> VarRef) -> Instruction -> Instruction overValueRefs f = overValueRefs' (VRef . f) +getValueRefs :: Instruction -> [VarRef] +getValueRefs = + run . execOutputList . overValueRefs'' (\vr -> output vr >> return (VRef vr)) + updateLiveVars' :: (VarRef -> Maybe VarRef) -> Instruction -> Instruction updateLiveVars' f = \case Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe f) x diff --git a/src/Juvix/Compiler/Reg/Pipeline.hs b/src/Juvix/Compiler/Reg/Pipeline.hs index 43eee4a079..d7440c1c76 100644 --- a/src/Juvix/Compiler/Reg/Pipeline.hs +++ b/src/Juvix/Compiler/Reg/Pipeline.hs @@ -1,9 +1,11 @@ module Juvix.Compiler.Reg.Pipeline ( module Juvix.Compiler.Reg.Pipeline, module Juvix.Compiler.Reg.Data.InfoTable, + Options, ) where +import Juvix.Compiler.Pipeline.EntryPoint (EntryPoint) import Juvix.Compiler.Reg.Data.Blocks.InfoTable qualified as Blocks import Juvix.Compiler.Reg.Data.InfoTable import Juvix.Compiler.Reg.Transformation @@ -11,14 +13,25 @@ import Juvix.Compiler.Reg.Transformation.Blocks.Liveness qualified as Blocks import Juvix.Compiler.Reg.Translation.Blocks.FromReg qualified as Blocks -- | Perform transformations on JuvixReg necessary before the translation to C -toC :: InfoTable -> Sem r InfoTable -toC = applyTransformations toCTransformations +toC' :: (Member (Reader Options) r) => InfoTable -> Sem r InfoTable +toC' = applyTransformations toCTransformations -- | Perform transformations on JuvixReg necessary before the translation to Rust -toRust :: InfoTable -> Sem r InfoTable -toRust = applyTransformations toRustTransformations +toRust' :: (Member (Reader Options) r) => InfoTable -> Sem r InfoTable +toRust' = applyTransformations toRustTransformations -- | Perform transformations on JuvixReg necessary before the translation to -- Cairo assembly -toCasm :: InfoTable -> Sem r Blocks.InfoTable -toCasm = applyTransformations toCasmTransformations >=> return . Blocks.computeLiveness . Blocks.fromReg +toCasm' :: (Member (Reader Options) r) => InfoTable -> Sem r Blocks.InfoTable +toCasm' = + applyTransformations toCasmTransformations + >=> return . Blocks.computeLiveness . Blocks.fromReg + +toC :: (Member (Reader EntryPoint) r) => InfoTable -> Sem r InfoTable +toC = mapReader fromEntryPoint . toC' + +toRust :: (Member (Reader EntryPoint) r) => InfoTable -> Sem r InfoTable +toRust = mapReader fromEntryPoint . toRust' + +toCasm :: (Member (Reader EntryPoint) r) => InfoTable -> Sem r Blocks.InfoTable +toCasm = mapReader fromEntryPoint . toCasm' diff --git a/src/Juvix/Compiler/Reg/Transformation.hs b/src/Juvix/Compiler/Reg/Transformation.hs index e7c02ef518..e932652b08 100644 --- a/src/Juvix/Compiler/Reg/Transformation.hs +++ b/src/Juvix/Compiler/Reg/Transformation.hs @@ -8,13 +8,16 @@ where import Juvix.Compiler.Reg.Data.TransformationId import Juvix.Compiler.Reg.Transformation.Base import Juvix.Compiler.Reg.Transformation.Cleanup -import Juvix.Compiler.Reg.Transformation.ConstantPropagation (constantPropagate) -import Juvix.Compiler.Reg.Transformation.CopyPropagation import Juvix.Compiler.Reg.Transformation.IdentityTrans import Juvix.Compiler.Reg.Transformation.InitBranchVars +import Juvix.Compiler.Reg.Transformation.Optimize.ConstantPropagation +import Juvix.Compiler.Reg.Transformation.Optimize.CopyPropagation +import Juvix.Compiler.Reg.Transformation.Optimize.DeadCodeElimination +import Juvix.Compiler.Reg.Transformation.Optimize.Phase.Cairo qualified as Phase.Cairo +import Juvix.Compiler.Reg.Transformation.Optimize.Phase.Main qualified as Phase.Main import Juvix.Compiler.Reg.Transformation.SSA -applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable +applyTransformations :: forall r. (Member (Reader Options) r) => [TransformationId] -> InfoTable -> Sem r InfoTable applyTransformations ts tbl = foldM (flip appTrans) tbl ts where appTrans :: TransformationId -> InfoTable -> Sem r InfoTable @@ -25,3 +28,6 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts InitBranchVars -> return . initBranchVars CopyPropagation -> return . copyPropagate ConstantPropagation -> return . constantPropagate + DeadCodeElimination -> return . removeDeadAssignments + OptPhaseMain -> Phase.Main.optimize + OptPhaseCairo -> Phase.Cairo.optimize diff --git a/src/Juvix/Compiler/Reg/Transformation/Base.hs b/src/Juvix/Compiler/Reg/Transformation/Base.hs index 60131d0e39..9f734efb5d 100644 --- a/src/Juvix/Compiler/Reg/Transformation/Base.hs +++ b/src/Juvix/Compiler/Reg/Transformation/Base.hs @@ -1,5 +1,6 @@ module Juvix.Compiler.Reg.Transformation.Base ( module Juvix.Compiler.Tree.Transformation.Generic.Base, + module Juvix.Compiler.Tree.Options, module Juvix.Compiler.Reg.Data.InfoTable, module Juvix.Compiler.Reg.Language, ) @@ -7,4 +8,5 @@ where import Juvix.Compiler.Reg.Data.InfoTable import Juvix.Compiler.Reg.Language +import Juvix.Compiler.Tree.Options import Juvix.Compiler.Tree.Transformation.Generic.Base diff --git a/src/Juvix/Compiler/Reg/Transformation/ConstantPropagation.hs b/src/Juvix/Compiler/Reg/Transformation/Optimize/ConstantPropagation.hs similarity index 83% rename from src/Juvix/Compiler/Reg/Transformation/ConstantPropagation.hs rename to src/Juvix/Compiler/Reg/Transformation/Optimize/ConstantPropagation.hs index 9da543dc81..071d4ab40f 100644 --- a/src/Juvix/Compiler/Reg/Transformation/ConstantPropagation.hs +++ b/src/Juvix/Compiler/Reg/Transformation/Optimize/ConstantPropagation.hs @@ -1,4 +1,4 @@ -module Juvix.Compiler.Reg.Transformation.ConstantPropagation where +module Juvix.Compiler.Reg.Transformation.Optimize.ConstantPropagation where import Data.HashMap.Strict qualified as HashMap import Juvix.Compiler.Reg.Extra @@ -7,17 +7,20 @@ import Juvix.Compiler.Tree.Evaluator.Builtins type VarMap = HashMap VarRef Constant -constantPropagateFunction :: Code -> Code -constantPropagateFunction = - snd - . runIdentity - . recurseF - ForwardRecursorSig - { _forwardFun = \i acc -> return (go i acc), - _forwardCombine = combine - } - mempty +constantPropagate :: InfoTable -> InfoTable +constantPropagate = mapT (const goFun) where + goFun :: Code -> Code + goFun = + snd + . runIdentity + . recurseF + ForwardRecursorSig + { _forwardFun = \i acc -> return (go i acc), + _forwardCombine = combine + } + mempty + go :: Instruction -> VarMap -> (VarMap, Instruction) go instr mpv = case instr' of Assign InstrAssign {..} @@ -60,6 +63,3 @@ constantPropagateFunction = _ -> impossible _ -> (combineMaps mpvs, instr) - -constantPropagate :: InfoTable -> InfoTable -constantPropagate = mapT (const constantPropagateFunction) diff --git a/src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs b/src/Juvix/Compiler/Reg/Transformation/Optimize/CopyPropagation.hs similarity index 78% rename from src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs rename to src/Juvix/Compiler/Reg/Transformation/Optimize/CopyPropagation.hs index c486eeefb6..1d47bfaa4d 100644 --- a/src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs +++ b/src/Juvix/Compiler/Reg/Transformation/Optimize/CopyPropagation.hs @@ -1,4 +1,4 @@ -module Juvix.Compiler.Reg.Transformation.CopyPropagation where +module Juvix.Compiler.Reg.Transformation.Optimize.CopyPropagation where import Data.HashMap.Strict qualified as HashMap import Juvix.Compiler.Reg.Extra @@ -6,17 +6,20 @@ import Juvix.Compiler.Reg.Transformation.Base type VarMap = HashMap VarRef VarRef -copyPropagateFunction :: Code -> Code -copyPropagateFunction = - snd - . runIdentity - . recurseF - ForwardRecursorSig - { _forwardFun = \i acc -> return (go i acc), - _forwardCombine = combine - } - mempty +copyPropagate :: InfoTable -> InfoTable +copyPropagate = mapT (const goFun) where + goFun :: Code -> Code + goFun = + snd + . runIdentity + . recurseF + ForwardRecursorSig + { _forwardFun = \i acc -> return (go i acc), + _forwardCombine = combine + } + mempty + go :: Instruction -> VarMap -> (VarMap, Instruction) go instr mpv = case instr' of Assign InstrAssign {..} @@ -44,6 +47,3 @@ copyPropagateFunction = Branch x -> Branch $ over instrBranchOutVar (fmap (adjustVarRef mpv)) x Case x -> Case $ over instrCaseOutVar (fmap (adjustVarRef mpv)) x _ -> impossible - -copyPropagate :: InfoTable -> InfoTable -copyPropagate = mapT (const copyPropagateFunction) diff --git a/src/Juvix/Compiler/Reg/Transformation/Optimize/DeadCodeElimination.hs b/src/Juvix/Compiler/Reg/Transformation/Optimize/DeadCodeElimination.hs new file mode 100644 index 0000000000..4f2540a4e4 --- /dev/null +++ b/src/Juvix/Compiler/Reg/Transformation/Optimize/DeadCodeElimination.hs @@ -0,0 +1,45 @@ +module Juvix.Compiler.Reg.Transformation.Optimize.DeadCodeElimination where + +import Data.HashSet qualified as HashSet +import Juvix.Compiler.Reg.Extra +import Juvix.Compiler.Reg.Transformation.Base + +removeDeadAssignments :: InfoTable -> InfoTable +removeDeadAssignments = mapT (const goFun) + where + goFun :: Code -> Code + goFun = + snd + . runIdentity + . recurseB + BackwardRecursorSig + { _backwardFun = \is a as -> return (go is a as), + _backwardAdjust = id + } + mempty + + -- The accumulator contains live variables + go :: Code -> HashSet VarRef -> [HashSet VarRef] -> (HashSet VarRef, Code) + go is live lives = case is of + Assign InstrAssign {..} : is' + | VRef r <- _instrAssignValue, + _instrAssignResult == r -> + (live, is') + instr : is' -> case getResultVar instr of + Just var + | not (HashSet.member var liveVars) -> + (liveVars, is') + _ -> + (liveVars', instr : is') + where + liveVars' = + HashSet.union + (maybe liveVars (`HashSet.delete` liveVars) (getResultVar instr)) + (HashSet.fromList (getValueRefs instr)) + liveVars = case instr of + Branch {} -> ulives + Case {} -> ulives + _ -> live + ulives = HashSet.unions lives + [] -> + (live, []) diff --git a/src/Juvix/Compiler/Reg/Transformation/Optimize/Phase/Cairo.hs b/src/Juvix/Compiler/Reg/Transformation/Optimize/Phase/Cairo.hs new file mode 100644 index 0000000000..ba4033ebdf --- /dev/null +++ b/src/Juvix/Compiler/Reg/Transformation/Optimize/Phase/Cairo.hs @@ -0,0 +1,7 @@ +module Juvix.Compiler.Reg.Transformation.Optimize.Phase.Cairo where + +import Juvix.Compiler.Reg.Transformation.Base +import Juvix.Compiler.Reg.Transformation.Optimize.Phase.Main qualified as Main + +optimize :: (Member (Reader Options) r) => InfoTable -> Sem r InfoTable +optimize = withOptimizationLevel 1 Main.optimize diff --git a/src/Juvix/Compiler/Reg/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Reg/Transformation/Optimize/Phase/Main.hs new file mode 100644 index 0000000000..4dfcdcffaa --- /dev/null +++ b/src/Juvix/Compiler/Reg/Transformation/Optimize/Phase/Main.hs @@ -0,0 +1,20 @@ +module Juvix.Compiler.Reg.Transformation.Optimize.Phase.Main where + +import Juvix.Compiler.Reg.Transformation.Base +import Juvix.Compiler.Reg.Transformation.Optimize.ConstantPropagation +import Juvix.Compiler.Reg.Transformation.Optimize.CopyPropagation +import Juvix.Compiler.Reg.Transformation.Optimize.DeadCodeElimination + +optimize' :: Options -> InfoTable -> InfoTable +optimize' Options {..} = + compose + (2 * _optOptimizationLevel) + ( copyPropagate + . constantPropagate + . removeDeadAssignments + ) + +optimize :: (Member (Reader Options) r) => InfoTable -> Sem r InfoTable +optimize tab = do + opts <- ask + return $ optimize' opts tab diff --git a/src/Juvix/Compiler/Tree/Options.hs b/src/Juvix/Compiler/Tree/Options.hs new file mode 100644 index 0000000000..ab9684f9c3 --- /dev/null +++ b/src/Juvix/Compiler/Tree/Options.hs @@ -0,0 +1,26 @@ +module Juvix.Compiler.Tree.Options where + +import Juvix.Compiler.Pipeline.EntryPoint +import Juvix.Data.Field +import Juvix.Prelude + +data Options = Options + { _optOptimizationLevel :: Int, + _optFieldSize :: Natural + } + +makeLenses ''Options + +defaultOptions :: Options +defaultOptions = + Options + { _optOptimizationLevel = defaultOptimizationLevel, + _optFieldSize = defaultFieldSize + } + +fromEntryPoint :: EntryPoint -> Options +fromEntryPoint EntryPoint {..} = + Options + { _optOptimizationLevel = _entryPointOptimizationLevel, + _optFieldSize = _entryPointFieldSize + } diff --git a/src/Juvix/Compiler/Tree/Transformation/Generic/Base.hs b/src/Juvix/Compiler/Tree/Transformation/Generic/Base.hs index 15d547e24d..84c5fe66e8 100644 --- a/src/Juvix/Compiler/Tree/Transformation/Generic/Base.hs +++ b/src/Juvix/Compiler/Tree/Transformation/Generic/Base.hs @@ -4,6 +4,7 @@ import Data.HashMap.Strict qualified as HashMap import Juvix.Compiler.Tree.Data.InfoTable.Base import Juvix.Compiler.Tree.Data.InfoTableBuilder.Base import Juvix.Compiler.Tree.Language.Base +import Juvix.Compiler.Tree.Options mapFunctionsM :: (Monad m) => (FunctionInfo' a e -> m (FunctionInfo' a e)) -> InfoTable' a e -> m (InfoTable' a e) mapFunctionsM = overM infoFunctions . mapM @@ -36,3 +37,13 @@ mapT' f tab = walkT :: (Applicative f) => (Symbol -> a -> f ()) -> InfoTable' a e -> f () walkT f tab = for_ (HashMap.toList (tab ^. infoFunctions)) (\(k, v) -> f k (v ^. functionCode)) + +withOptimizationLevel :: (Member (Reader Options) r) => Int -> (InfoTable' a e -> Sem r (InfoTable' a e)) -> InfoTable' a e -> Sem r (InfoTable' a e) +withOptimizationLevel n f tab = do + l <- asks (^. optOptimizationLevel) + if + | l >= n -> f tab + | otherwise -> return tab + +withOptimizationLevel' :: (Member (Reader Options) r) => InfoTable' a e -> Int -> (InfoTable' a e -> Sem r (InfoTable' a e)) -> Sem r (InfoTable' a e) +withOptimizationLevel' tab n f = withOptimizationLevel n f tab diff --git a/test/Asm/Transformation/Prealloc.hs b/test/Asm/Transformation/Prealloc.hs index 6ce807be1c..c7a5b1ed2b 100644 --- a/test/Asm/Transformation/Prealloc.hs +++ b/test/Asm/Transformation/Prealloc.hs @@ -6,6 +6,7 @@ import Base import Juvix.Compiler.Asm.Options import Juvix.Compiler.Asm.Transformation import Juvix.Compiler.Asm.Transformation.Base +import Juvix.Compiler.Tree.Options qualified as Tree allTests :: TestTree allTests = testGroup "Prealloc" (map liftTest Run.tests) @@ -22,5 +23,6 @@ liftTest _testEval = opts = Options { _optDebug = True, - _optLimits = getLimits TargetCWasm32Wasi True + _optLimits = getLimits TargetCWasm32Wasi True, + _optTreeOptions = Tree.defaultOptions } diff --git a/test/Asm/Transformation/Reachability.hs b/test/Asm/Transformation/Reachability.hs index b6aa55b5c3..7f5fcce59a 100644 --- a/test/Asm/Transformation/Reachability.hs +++ b/test/Asm/Transformation/Reachability.hs @@ -7,6 +7,7 @@ import Data.HashMap.Strict qualified as HashMap import Juvix.Compiler.Asm.Options import Juvix.Compiler.Asm.Transformation import Juvix.Compiler.Asm.Transformation.Base +import Juvix.Compiler.Tree.Options qualified as Tree data ReachabilityTest = ReachabilityTest { _reachabilityTestReachable :: [Text], @@ -52,5 +53,6 @@ liftTest ReachabilityTest {..} = opts = Options { _optDebug = True, - _optLimits = getLimits TargetCWasm32Wasi True + _optLimits = getLimits TargetCWasm32Wasi True, + _optTreeOptions = Tree.defaultOptions } diff --git a/test/Casm/Reg/Base.hs b/test/Casm/Reg/Base.hs index 72f266127d..609d160d68 100644 --- a/test/Casm/Reg/Base.hs +++ b/test/Casm/Reg/Base.hs @@ -7,13 +7,14 @@ import Juvix.Compiler.Casm.Data.Result import Juvix.Compiler.Casm.Error import Juvix.Compiler.Casm.Interpreter import Juvix.Compiler.Reg.Data.InfoTable qualified as Reg +import Juvix.Compiler.Reg.Transformation qualified as Reg import Juvix.Data.PPOutput import Reg.Run.Base qualified as Reg compileAssertion' :: Maybe (Path Abs File) -> Path Abs Dir -> Path Abs File -> Symbol -> Reg.InfoTable -> (String -> IO ()) -> Assertion compileAssertion' inputFile _ outputFile _ tab step = do step "Translate to CASM" - case run $ runError @JuvixError $ regToCasm tab of + case run $ runError @JuvixError $ runReader Reg.defaultOptions $ regToCasm' tab of Left err -> assertFailure (prettyString (fromJuvixError @GenericError err)) Right Result {..} -> do step "Interpret" @@ -30,7 +31,7 @@ compileAssertion' inputFile _ outputFile _ tab step = do cairoAssertion' :: Maybe (Path Abs File) -> Path Abs Dir -> Path Abs File -> Symbol -> Reg.InfoTable -> (String -> IO ()) -> Assertion cairoAssertion' inputFile dirPath outputFile _ tab step = do step "Translate to Cairo" - case run $ runError @JuvixError $ regToCairo tab of + case run $ runError @JuvixError $ runReader Reg.defaultOptions $ regToCairo' tab of Left err -> assertFailure (prettyString (fromJuvixError @GenericError err)) Right res -> do step "Serialize to Cairo bytecode" diff --git a/test/Reg/Run/Base.hs b/test/Reg/Run/Base.hs index f6a66c9fac..5083173375 100644 --- a/test/Reg/Run/Base.hs +++ b/test/Reg/Run/Base.hs @@ -5,7 +5,7 @@ import Juvix.Compiler.Reg.Data.InfoTable import Juvix.Compiler.Reg.Error import Juvix.Compiler.Reg.Interpreter import Juvix.Compiler.Reg.Pretty -import Juvix.Compiler.Reg.Transformation +import Juvix.Compiler.Reg.Transformation as Reg import Juvix.Compiler.Reg.Translation.FromSource import Juvix.Data.PPOutput @@ -56,7 +56,7 @@ regRunAssertionParam interpretFun mainFile expectedFile trans testTrans step = d Right tab0 -> do unless (null trans) $ step "Transform" - case run $ runError @JuvixError $ applyTransformations trans tab0 of + case run $ runError @JuvixError $ runReader Reg.defaultOptions $ applyTransformations trans tab0 of Left err -> assertFailure (prettyString (fromJuvixError @GenericError err)) Right tab -> do testTrans tab diff --git a/tests/Casm/Reg/positive/out/test014.out b/tests/Casm/Reg/positive/out/test014.out index 7579c89006..d62abd38b6 100644 --- a/tests/Casm/Reg/positive/out/test014.out +++ b/tests/Casm/Reg/positive/out/test014.out @@ -2,13 +2,19 @@ 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 666 2 1 @@ -18,70 +24,106 @@ 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 2 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 2 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 1 3 2 @@ -89,68 +131,103 @@ 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 2 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 2 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 +777 1 3 2 0 +777 0 +777 2 0 +777 0 +777 2 0 +777 0 777 diff --git a/tests/Casm/Reg/positive/test014.jvr b/tests/Casm/Reg/positive/test014.jvr index 932fb19e62..f3a5cafae7 100644 --- a/tests/Casm/Reg/positive/test014.jvr +++ b/tests/Casm/Reg/positive/test014.jvr @@ -101,7 +101,7 @@ function preorder(tree) : * { nop; tmp[0] = arg[0].node2[0]; tmp[0] = call preorder (tmp[0]), live: (arg[0]); - nop; + trace tmp[0]; tmp[0] = arg[0].node2[1]; tcall preorder (tmp[0]); }; @@ -112,10 +112,10 @@ function preorder(tree) : * { nop; tmp[0] = arg[0].node3[0]; tmp[0] = call preorder (tmp[0]), live: (arg[0]); - nop; + trace tmp[0]; tmp[0] = arg[0].node3[1]; tmp[0] = call preorder (tmp[0]), live: (arg[0]); - nop; + trace tmp[0]; tmp[0] = arg[0].node3[2]; tcall preorder (tmp[0]); }; @@ -134,6 +134,7 @@ function main() : * { tmp[0] = 3; tmp[0] = call gen (tmp[0]); tmp[0] = call preorder (tmp[0]); + trace tmp[0]; tmp[0] = 666; trace tmp[0]; tmp[0] = 7;