Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: compiler looping with the specialize pragma #2899

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion juvix-stdlib
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Data/IdentDependencyInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,6 @@ nonRecursiveIdents' tab =
HashSet.difference
(HashSet.fromList (HashMap.keys (tab ^. infoIdentifiers)))
(recursiveIdentsClosure tab)

nonRecursiveIdents :: Module -> HashSet Symbol
nonRecursiveIdents = nonRecursiveIdents' . computeCombinedInfoTable
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ data TransformationId
| SpecializeArgs
| CaseFolding
| CasePermutation
| ConstantFolding
| FilterUnreachable
| OptPhaseEval
| OptPhaseExec
Expand Down Expand Up @@ -113,6 +114,7 @@ instance TransformationId' TransformationId where
SpecializeArgs -> strSpecializeArgs
CaseFolding -> strCaseFolding
CasePermutation -> strCasePermutation
ConstantFolding -> strConstantFolding
FilterUnreachable -> strFilterUnreachable
OptPhaseEval -> strOptPhaseEval
OptPhaseExec -> strOptPhaseExec
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ strCaseFolding = "case-folding"
strCasePermutation :: Text
strCasePermutation = "case-permutation"

strConstantFolding :: Text
strConstantFolding = "constant-folding"

strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"

Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import Juvix.Compiler.Core.Transformation.Normalize
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation)
import Juvix.Compiler.Core.Transformation.Optimize.ConstantFolding
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable)
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
Expand Down Expand Up @@ -96,6 +97,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
SpecializeArgs -> return . specializeArgs
CaseFolding -> return . caseFolding
CasePermutation -> return . casePermutation
ConstantFolding -> constantFolding
FilterUnreachable -> return . filterUnreachable
OptPhaseEval -> Phase.Eval.optimize
OptPhaseExec -> Phase.Exec.optimize
Expand Down
12 changes: 6 additions & 6 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ isInlineableLambda inlineDepth md bl node = case node of
False

convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
convertNode inlineDepth recSyms md = dmapL go
convertNode inlineDepth nonRecSyms md = dmapL go
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
Expand All @@ -37,7 +37,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineNever ->
node
_
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isInlineableLambda inlineDepth md bl def
&& length args >= argsNum ->
mkApps def args
Expand All @@ -57,7 +57,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineAlways -> def
Just InlineNever -> node
_
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isImmediate md def ->
def
| otherwise ->
Expand All @@ -76,7 +76,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineCase ->
NCase cs {_caseValue = mkApps def args}
Nothing
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isConstructorApp def
&& checkDepth md bl inlineDepth def ->
NCase cs {_caseValue = mkApps def args}
Expand All @@ -92,9 +92,9 @@ convertNode inlineDepth recSyms md = dmapL go
node

inlining' :: Int -> HashSet Symbol -> Module -> Module
inlining' inliningDepth recSyms md = mapT (const (convertNode inliningDepth recSyms md)) md
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md

inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
inlining md = do
d <- asks (^. optInliningDepth)
return $ inlining' d (recursiveIdents md) md
return $ inlining' d (nonRecursiveIdents md) md
11 changes: 4 additions & 7 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ optimize' opts@CoreOptions {..} md =
tab :: InfoTable
tab = computeCombinedInfoTable md

recs :: HashSet Symbol
recs = recursiveIdents' tab

nonRecs :: HashSet Symbol
nonRecs = nonRecursiveIdents' tab

Expand All @@ -48,12 +45,12 @@ optimize' opts@CoreOptions {..} md =
| otherwise = nonRecs

doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth recs' md'
doInlining md' = inlining' _optInliningDepth nonRecs' md'
where
recs' =
nonRecs' =
if
| _optOptimizationLevel > 1 -> recursiveIdents md'
| otherwise -> recs
| _optOptimizationLevel > 1 -> nonRecursiveIdents md'
| otherwise -> nonRecs

doSimplification :: Int -> Module -> Module
doSimplification n =
Expand Down
35 changes: 13 additions & 22 deletions tests/Compilation/positive/test056.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,12 @@ mymap {A B} (f : A -> B) : List A -> List B
| (x :: xs) := f x :: mymap f xs;

{-# specialize: [2, 5], inline: false #-}
myf
: {A B : Type}
-> A
-> (A -> A -> B)
-> A
-> B
-> Bool
-> B
myf : {A B : Type} -> A -> (A -> A -> B) -> A -> B -> Bool -> B
| a0 f a b true := f a0 a
| a0 f a b false := b;

{-# inline: false #-}
myf'
: {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B
myf' : {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B
| a0 f a b := myf a0 (f a0) a b true;

sum : List Nat -> Nat
Expand All @@ -40,29 +32,28 @@ funa : {A : Type} -> (A -> A) -> A -> A
{-# specialize: true #-}
type Additive A := mkAdditive {add : A -> A -> A};

type Multiplicative A :=
mkMultiplicative {mul : A -> A -> A};
type Multiplicative A := mkMultiplicative {mul : A -> A -> A};

addNat : Additive Nat := mkAdditive (+);

{-# specialize: true #-}
mulNat : Multiplicative Nat := mkMultiplicative (*);

{-# inline: false #-}
fadd {A} (a : Additive A) (x y : A) : A :=
Additive.add a x y;
fadd {A} (a : Additive A) (x y : A) : A := Additive.add a x y;

{-# inline: false #-}
fmul {A} (m : Multiplicative A) (x y : A) : A :=
Multiplicative.mul m x y;
fmul {A} (m : Multiplicative A) (x y : A) : A := Multiplicative.mul m x y;

{-# specialize: [1] #-}
myfilter {A} (f : A → Bool) : List A → List A
| nil := nil
| (h :: hs) := ite (f h) (h :: myfilter f hs) (myfilter f hs);

main : Nat :=
sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
+ sum
(flatten
(mymap
(mymap λ {x := x + 2})
((1 :: nil) :: (2 :: 3 :: nil) :: nil)))
sum (myfilter (const false) [])
+ sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
+ sum (flatten (mymap (mymap λ {x := x + 2}) ((1 :: nil) :: (2 :: 3 :: nil) :: nil)))
+ myf 3 (*) 2 5 true
+ myf 1 (+) 2 0 false
+ myf' 7 (const (+)) 2 0
Expand Down
Loading