diff --git a/juvix-stdlib b/juvix-stdlib index 216cb609cb..17f22fcec5 160000 --- a/juvix-stdlib +++ b/juvix-stdlib @@ -1 +1 @@ -Subproject commit 216cb609cbe5aec9badea858f151a5ea400f2e66 +Subproject commit 17f22fcec5d78be511ea59984aee3499da5f3342 diff --git a/src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs b/src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs index 063908ff3e..f70452518f 100644 --- a/src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs +++ b/src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs @@ -101,3 +101,6 @@ nonRecursiveIdents' tab = HashSet.difference (HashSet.fromList (HashMap.keys (tab ^. infoIdentifiers))) (recursiveIdentsClosure tab) + +nonRecursiveIdents :: Module -> HashSet Symbol +nonRecursiveIdents = nonRecursiveIdents' . computeCombinedInfoTable diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index 3542f00418..f2d754696b 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -39,6 +39,7 @@ data TransformationId | SpecializeArgs | CaseFolding | CasePermutation + | ConstantFolding | FilterUnreachable | OptPhaseEval | OptPhaseExec @@ -113,6 +114,7 @@ instance TransformationId' TransformationId where SpecializeArgs -> strSpecializeArgs CaseFolding -> strCaseFolding CasePermutation -> strCasePermutation + ConstantFolding -> strConstantFolding FilterUnreachable -> strFilterUnreachable OptPhaseEval -> strOptPhaseEval OptPhaseExec -> strOptPhaseExec diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs index 7aaa71fae8..ac50840786 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs @@ -119,6 +119,9 @@ strCaseFolding = "case-folding" strCasePermutation :: Text strCasePermutation = "case-permutation" +strConstantFolding :: Text +strConstantFolding = "constant-folding" + strFilterUnreachable :: Text strFilterUnreachable = "filter-unreachable" diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index d87e85f57e..f0a0a23764 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -36,6 +36,7 @@ import Juvix.Compiler.Core.Transformation.Normalize import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation) +import Juvix.Compiler.Core.Transformation.Optimize.ConstantFolding import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable) import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding @@ -96,6 +97,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts SpecializeArgs -> return . specializeArgs CaseFolding -> return . caseFolding CasePermutation -> return . casePermutation + ConstantFolding -> constantFolding FilterUnreachable -> return . filterUnreachable OptPhaseEval -> Phase.Eval.optimize OptPhaseExec -> Phase.Exec.optimize diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs index 9c54e4ac21..1dccadbff3 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs @@ -17,7 +17,7 @@ isInlineableLambda inlineDepth md bl node = case node of False convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node -convertNode inlineDepth recSyms md = dmapL go +convertNode inlineDepth nonRecSyms md = dmapL go where go :: BinderList Binder -> Node -> Node go bl node = case node of @@ -37,7 +37,7 @@ convertNode inlineDepth recSyms md = dmapL go Just InlineNever -> node _ - | not (HashSet.member _identSymbol recSyms) + | HashSet.member _identSymbol nonRecSyms && isInlineableLambda inlineDepth md bl def && length args >= argsNum -> mkApps def args @@ -57,7 +57,7 @@ convertNode inlineDepth recSyms md = dmapL go Just InlineAlways -> def Just InlineNever -> node _ - | not (HashSet.member _identSymbol recSyms) + | HashSet.member _identSymbol nonRecSyms && isImmediate md def -> def | otherwise -> @@ -76,7 +76,7 @@ convertNode inlineDepth recSyms md = dmapL go Just InlineCase -> NCase cs {_caseValue = mkApps def args} Nothing - | not (HashSet.member _identSymbol recSyms) + | HashSet.member _identSymbol nonRecSyms && isConstructorApp def && checkDepth md bl inlineDepth def -> NCase cs {_caseValue = mkApps def args} @@ -92,9 +92,9 @@ convertNode inlineDepth recSyms md = dmapL go node inlining' :: Int -> HashSet Symbol -> Module -> Module -inlining' inliningDepth recSyms md = mapT (const (convertNode inliningDepth recSyms md)) md +inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module inlining md = do d <- asks (^. optInliningDepth) - return $ inlining' d (recursiveIdents md) md + return $ inlining' d (nonRecursiveIdents md) md diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index 0c60f7fd4f..05a7cec3be 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -33,9 +33,6 @@ optimize' opts@CoreOptions {..} md = tab :: InfoTable tab = computeCombinedInfoTable md - recs :: HashSet Symbol - recs = recursiveIdents' tab - nonRecs :: HashSet Symbol nonRecs = nonRecursiveIdents' tab @@ -48,12 +45,12 @@ optimize' opts@CoreOptions {..} md = | otherwise = nonRecs doInlining :: Module -> Module - doInlining md' = inlining' _optInliningDepth recs' md' + doInlining md' = inlining' _optInliningDepth nonRecs' md' where - recs' = + nonRecs' = if - | _optOptimizationLevel > 1 -> recursiveIdents md' - | otherwise -> recs + | _optOptimizationLevel > 1 -> nonRecursiveIdents md' + | otherwise -> nonRecs doSimplification :: Int -> Module -> Module doSimplification n = diff --git a/tests/Compilation/positive/test056.juvix b/tests/Compilation/positive/test056.juvix index 2981ae3b5d..e7fb7b746a 100644 --- a/tests/Compilation/positive/test056.juvix +++ b/tests/Compilation/positive/test056.juvix @@ -9,20 +9,12 @@ mymap {A B} (f : A -> B) : List A -> List B | (x :: xs) := f x :: mymap f xs; {-# specialize: [2, 5], inline: false #-} -myf - : {A B : Type} - -> A - -> (A -> A -> B) - -> A - -> B - -> Bool - -> B +myf : {A B : Type} -> A -> (A -> A -> B) -> A -> B -> Bool -> B | a0 f a b true := f a0 a | a0 f a b false := b; {-# inline: false #-} -myf' - : {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B +myf' : {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B | a0 f a b := myf a0 (f a0) a b true; sum : List Nat -> Nat @@ -40,8 +32,7 @@ funa : {A : Type} -> (A -> A) -> A -> A {-# specialize: true #-} type Additive A := mkAdditive {add : A -> A -> A}; -type Multiplicative A := - mkMultiplicative {mul : A -> A -> A}; +type Multiplicative A := mkMultiplicative {mul : A -> A -> A}; addNat : Additive Nat := mkAdditive (+); @@ -49,20 +40,20 @@ addNat : Additive Nat := mkAdditive (+); mulNat : Multiplicative Nat := mkMultiplicative (*); {-# inline: false #-} -fadd {A} (a : Additive A) (x y : A) : A := - Additive.add a x y; +fadd {A} (a : Additive A) (x y : A) : A := Additive.add a x y; {-# inline: false #-} -fmul {A} (m : Multiplicative A) (x y : A) : A := - Multiplicative.mul m x y; +fmul {A} (m : Multiplicative A) (x y : A) : A := Multiplicative.mul m x y; + +{-# specialize: [1] #-} +myfilter {A} (f : A → Bool) : List A → List A + | nil := nil + | (h :: hs) := ite (f h) (h :: myfilter f hs) (myfilter f hs); main : Nat := - sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil)) - + sum - (flatten - (mymap - (mymap λ {x := x + 2}) - ((1 :: nil) :: (2 :: 3 :: nil) :: nil))) + sum (myfilter (const false) []) + + sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil)) + + sum (flatten (mymap (mymap λ {x := x + 2}) ((1 :: nil) :: (2 :: 3 :: nil) :: nil))) + myf 3 (*) 2 5 true + myf 1 (+) 2 0 false + myf' 7 (const (+)) 2 0