diff --git a/src/Juvix/Compiler/Asm/Pipeline.hs b/src/Juvix/Compiler/Asm/Pipeline.hs index e124a84f56..0e8ac5ea8b 100644 --- a/src/Juvix/Compiler/Asm/Pipeline.hs +++ b/src/Juvix/Compiler/Asm/Pipeline.hs @@ -20,7 +20,7 @@ toReg' = validate >=> filterUnreachable >=> computeStackUsage >=> computePreallo -- | Perform transformations on JuvixAsm necessary before the translation to -- Nockma toNockma' :: (Members '[Error AsmError, Reader Options] r) => InfoTable -> Sem r InfoTable -toNockma' = validate >=> filterUnreachable +toNockma' = validate toReg :: (Members '[Error JuvixError, Reader EntryPoint] r) => InfoTable -> Sem r InfoTable toReg = mapReader fromEntryPoint . mapError (JuvixError @AsmError) . toReg' diff --git a/src/Juvix/Compiler/Tree/Data/CallGraph.hs b/src/Juvix/Compiler/Tree/Data/CallGraph.hs new file mode 100644 index 0000000000..b2793770d3 --- /dev/null +++ b/src/Juvix/Compiler/Tree/Data/CallGraph.hs @@ -0,0 +1,35 @@ +module Juvix.Compiler.Tree.Data.CallGraph where + +import Data.HashSet qualified as HashSet +import Juvix.Compiler.Tree.Data.InfoTable +import Juvix.Compiler.Tree.Extra.Recursors + +-- | Call graph type +type CallGraph = DependencyInfo Symbol + +-- | Compute the call graph +createCallGraph :: InfoTable -> CallGraph +createCallGraph tab = createDependencyInfo (createCallGraphMap tab) startVertices + where + startVertices :: HashSet Symbol + startVertices = HashSet.fromList syms + + syms :: [Symbol] + syms = maybe [] singleton (tab ^. infoMainFunction) + +createCallGraphMap :: InfoTable -> HashMap Symbol (HashSet Symbol) +createCallGraphMap tab = fmap (getFunSymbols . (^. functionCode)) (tab ^. infoFunctions) + +getFunSymbols :: Node -> HashSet Symbol +getFunSymbols = gather go mempty + where + go :: HashSet Symbol -> Node -> HashSet Symbol + go syms = \case + AllocClosure NodeAllocClosure {..} -> HashSet.insert _nodeAllocClosureFunSymbol syms + Call NodeCall {..} -> goCallType syms _nodeCallType + _ -> syms + + goCallType :: HashSet Symbol -> CallType -> HashSet Symbol + goCallType syms = \case + CallFun sym -> HashSet.insert sym syms + CallClosure {} -> syms diff --git a/src/Juvix/Compiler/Tree/Data/TransformationId.hs b/src/Juvix/Compiler/Tree/Data/TransformationId.hs index 2ead0ee432..e001e0535a 100644 --- a/src/Juvix/Compiler/Tree/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Tree/Data/TransformationId.hs @@ -10,6 +10,7 @@ data TransformationId | IdentityD | Apply | TempHeight + | FilterUnreachable deriving stock (Data, Bounded, Enum, Show) data PipelineId @@ -20,7 +21,7 @@ data PipelineId type TransformationLikeId = TransformationLikeId' TransformationId PipelineId toNockmaTransformations :: [TransformationId] -toNockmaTransformations = [Apply, TempHeight] +toNockmaTransformations = [Apply, FilterUnreachable, TempHeight] toAsmTransformations :: [TransformationId] toAsmTransformations = [] @@ -33,6 +34,7 @@ instance TransformationId' TransformationId where IdentityD -> strIdentityD Apply -> strApply TempHeight -> strTempHeight + FilterUnreachable -> strFilterUnreachable instance PipelineId' TransformationId PipelineId where pipelineText :: PipelineId -> Text diff --git a/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs index d5512d024a..7e454ccbdc 100644 --- a/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs @@ -22,3 +22,6 @@ strApply = "apply" strTempHeight :: Text strTempHeight = "temp-height" + +strFilterUnreachable :: Text +strFilterUnreachable = "filter-unreachable" diff --git a/src/Juvix/Compiler/Tree/Transformation.hs b/src/Juvix/Compiler/Tree/Transformation.hs index 714fe1a82c..c16456c838 100644 --- a/src/Juvix/Compiler/Tree/Transformation.hs +++ b/src/Juvix/Compiler/Tree/Transformation.hs @@ -8,6 +8,7 @@ where import Juvix.Compiler.Tree.Data.TransformationId import Juvix.Compiler.Tree.Transformation.Apply import Juvix.Compiler.Tree.Transformation.Base +import Juvix.Compiler.Tree.Transformation.FilterUnreachable import Juvix.Compiler.Tree.Transformation.Identity import Juvix.Compiler.Tree.Transformation.TempHeight @@ -21,3 +22,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts IdentityD -> return . identityD Apply -> return . computeApply TempHeight -> return . computeTempHeight + FilterUnreachable -> return . filterUnreachable diff --git a/src/Juvix/Compiler/Tree/Transformation/FilterUnreachable.hs b/src/Juvix/Compiler/Tree/Transformation/FilterUnreachable.hs new file mode 100644 index 0000000000..3bfcc7e16b --- /dev/null +++ b/src/Juvix/Compiler/Tree/Transformation/FilterUnreachable.hs @@ -0,0 +1,12 @@ +module Juvix.Compiler.Tree.Transformation.FilterUnreachable where + +import Data.HashMap.Strict qualified as HashMap +import Juvix.Compiler.Tree.Data.CallGraph +import Juvix.Compiler.Tree.Data.InfoTable +import Juvix.Prelude + +filterUnreachable :: InfoTable -> InfoTable +filterUnreachable tab = + over infoFunctions (HashMap.filterWithKey (const . isReachable graph)) tab + where + graph = createCallGraph tab diff --git a/test/Tree/Transformation.hs b/test/Tree/Transformation.hs index bc8ea92b31..cf252bcb94 100644 --- a/test/Tree/Transformation.hs +++ b/test/Tree/Transformation.hs @@ -3,11 +3,13 @@ module Tree.Transformation where import Base import Tree.Transformation.Apply qualified as Apply import Tree.Transformation.Identity qualified as Identity +import Tree.Transformation.Reachability qualified as Reachability allTests :: TestTree allTests = testGroup "JuvixTree transformations" [ Identity.allTests, - Apply.allTests + Apply.allTests, + Reachability.allTests ] diff --git a/test/Tree/Transformation/Reachability.hs b/test/Tree/Transformation/Reachability.hs new file mode 100644 index 0000000000..4f6cd7c8d6 --- /dev/null +++ b/test/Tree/Transformation/Reachability.hs @@ -0,0 +1,48 @@ +module Tree.Transformation.Reachability (allTests) where + +import Base +import Data.HashMap.Strict qualified as HashMap +import Juvix.Compiler.Tree.Transformation as Tree +import Tree.Eval.Positive qualified as Eval +import Tree.Transformation.Base + +data ReachabilityTest = ReachabilityTest + { _reachabilityTestReachable :: [Text], + _reachabilityTestEval :: Eval.PosTest + } + +allTests :: TestTree +allTests = + testGroup "Reachability" $ + map liftTest rtests + +rtests :: [ReachabilityTest] +rtests = + [ ReachabilityTest + { _reachabilityTestReachable = ["f", "f'", "g'", "h", "h'", "main"], + _reachabilityTestEval = + Eval.PosTest + "Test001: Reachability" + $(mkRelDir "reachability") + $(mkRelFile "test001.jvt") + $(mkRelFile "out/test001.out") + }, + ReachabilityTest + { _reachabilityTestReachable = ["f", "g", "id", "sum", "main"], + _reachabilityTestEval = + Eval.PosTest + "Test002: Reachability with loops & closures" + $(mkRelDir "reachability") + $(mkRelFile "test002.jvt") + $(mkRelFile "out/test002.out") + } + ] + +liftTest :: ReachabilityTest -> TestTree +liftTest ReachabilityTest {..} = + fromTest + Test + { _testTransformations = [Tree.FilterUnreachable], + _testAssertion = \tab -> unless (nubSort (map (^. functionName) (HashMap.elems (tab ^. infoFunctions))) == nubSort _reachabilityTestReachable) (error "check reachable"), + _testEval = _reachabilityTestEval + } diff --git a/tests/Tree/positive/reachability/out/test001.out b/tests/Tree/positive/reachability/out/test001.out new file mode 100644 index 0000000000..ec635144f6 --- /dev/null +++ b/tests/Tree/positive/reachability/out/test001.out @@ -0,0 +1 @@ +9 diff --git a/tests/Tree/positive/reachability/out/test002.out b/tests/Tree/positive/reachability/out/test002.out new file mode 100644 index 0000000000..e0fd17de85 --- /dev/null +++ b/tests/Tree/positive/reachability/out/test002.out @@ -0,0 +1 @@ +5051 diff --git a/tests/Tree/positive/reachability/test001.jvt b/tests/Tree/positive/reachability/test001.jvt new file mode 100644 index 0000000000..3e25a0be93 --- /dev/null +++ b/tests/Tree/positive/reachability/test001.jvt @@ -0,0 +1,36 @@ + +function h(integer) : integer; +function h'(integer) : integer; +function f(integer) : integer; +function f'(integer) : integer; +function g(integer) : integer; +function g'(integer) : integer; +function main() : integer; + +function h(integer) : integer { + arg[0] +} + +function h'(integer) : integer { + arg[0] +} + +function f(integer) : integer { + add(call[h](arg[0]), 1) +} + +function f'(integer) : integer { + add(call[h'](arg[0]), 1) +} + +function g(integer) : integer { + add(call[f](arg[0]), 2) +} + +function g'(integer) : integer { + call[f'](arg[0]) +} + +function main() : integer { + call[g'](call[f](7)) +} diff --git a/tests/Tree/positive/reachability/test002.jvt b/tests/Tree/positive/reachability/test002.jvt new file mode 100644 index 0000000000..5f0c193886 --- /dev/null +++ b/tests/Tree/positive/reachability/test002.jvt @@ -0,0 +1,39 @@ + +function f(*, integer) : integer; +function id(integer) : integer; +function g(integer) : integer; +function sum(integer) : integer; +function g'(integer) : integer; +function g''(integer) : integer; +function main() : integer; + +function f(*, integer) : integer { + call(arg[0], arg[1]) +} + +function id(integer) : integer { + arg[0] +} + +function g(integer) : integer { + add(call[f](calloc[id](), arg[0]), 1) +} + +function sum(integer) : integer { + br(eq(0, arg[0])) { + true: call[g](0) + false: add(arg[0], call[sum](sub(arg[0], 1))) + } +} + +function g'(integer) : integer { + add(call[id](arg[0]), 2) +} + +function g''(integer) : integer { + call[sum](arg[0]) +} + +function main() : integer { + call[sum](100) +}