Skip to content

Commit

Permalink
VampIR integration (#2103)
Browse files Browse the repository at this point in the history
* Closes #2035 
* Depends on #2086 
* Depends on #2096 
* Adds end-to-end tests for the Juvix-to-VampIR compilation pipeline.

---------

Co-authored-by: Jonathan Cubides <jonathan.cubides@uib.no>
  • Loading branch information
lukaszcz and jonaprieto authored May 22, 2023
1 parent 2148d17 commit d576111
Show file tree
Hide file tree
Showing 91 changed files with 1,127 additions and 361 deletions.
13 changes: 5 additions & 8 deletions src/Juvix/Compiler/Backend/VampIR/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ instance PrettyCode LocalDef where
ppCode LocalDef {..} = do
n <- ppName KNameLocal _localDefName
v <- ppCode _localDefValue
return $ kwDef <+> n <+> kwEq <+> v <> semi <> line
return $ kwDef <+> n <+> kwEq <+> v <> semi

instance PrettyCode Function where
ppCode Function {..} = do
Expand All @@ -71,9 +71,9 @@ instance PrettyCode Function where
ppEquation :: Function -> Sem r (Doc Ann)
ppEquation Function {..} = do
let n = length _functionArguments
args = if n == 1 then ["in"] else map (\k -> "in" <> show k) [1 .. n]
args = if n == 1 then ["(in + 0)"] else map (\k -> "(in" <> show k <> " + 0)") [1 .. n]
fn <- ppName KNameFunction _functionName
return $ fn <+> hsep args <+> kwEq <+> "out" <> semi
return $ fn <+> hsep args <+> kwEq <+> "(out + 0)" <> semi

instance PrettyCode Program where
ppCode Program {..} = do
Expand All @@ -92,11 +92,8 @@ vampIRDefs =
"def mul x y = x * y;",
"def isZero x = {def xi = fresh (1 | x); x * (1 - xi * x) = 0; 1 - xi * x};",
"def equal x y = isZero (x - y);",
"def bool x = {x * (x - 1) = 0; x};",
"def range32 a = {def a0 = bool (fresh ((a \\ 1) % 2)); def a1 = bool (fresh ((a \\ 2) % 2)); def a2 = bool (fresh ((a \\ 4) % 2)); def a3 = bool (fresh ((a \\ 8) % 2)); def a4 = bool (fresh ((a \\ 16) % 2)); def a5 = bool (fresh ((a \\ 32) % 2)); def a6 = bool (fresh ((a \\ 64) % 2)); def a7 = bool (fresh ((a \\ 128) % 2)); def a8 = bool (fresh ((a \\ 256) % 2)); def a9 = bool (fresh ((a \\ 512) % 2)); def a10 = bool (fresh ((a \\ 1024) % 2)); def a11 = bool (fresh ((a \\ 2048) % 2)); def a12 = bool (fresh ((a \\ 4096) % 2)); def a13 = bool (fresh ((a \\ 8192) % 2)); def a14 = bool (fresh ((a \\ 16384) % 2)); def a15 = bool (fresh ((a \\ 32768) % 2)); def a16 = bool (fresh ((a \\ 65536) % 2)); def a17 = bool (fresh ((a \\ 131072) % 2)); def a18 = bool (fresh ((a \\ 262144) % 2)); def a19 = bool (fresh ((a \\ 524288) % 2)); def a20 = bool (fresh ((a \\ 1048576) % 2)); def a21 = bool (fresh ((a \\ 2097152) % 2)); def a22 = bool (fresh ((a \\ 4194304) % 2)); def a23 = bool (fresh ((a \\ 8388608) % 2)); def a24 = bool (fresh ((a \\ 16777216) % 2)); def a25 = bool (fresh ((a \\ 33554432) % 2)); def a26 = bool (fresh ((a \\ 67108864) % 2)); def a27 = bool (fresh ((a \\ 134217728) % 2)); def a28 = bool (fresh ((a \\ 268435456) % 2)); def a29 = bool (fresh ((a \\ 536870912) % 2)); def a30 = bool (fresh ((a \\ 1073741824) % 2)); def a31 = bool (fresh ((a \\ 2147483648) % 2)); a = ((((((((((((((((((((((((((((((a0 + (2 * a1)) + (4 * a2)) + (8 * a3)) + (16 * a4)) + (32 * a5)) + (64 * a6)) + (128 * a7)) + (256 * a8)) + (512 * a9)) + (1024 * a10)) + (2048 * a11)) + (4096 * a12)) + (8192 * a13)) + (16384 * a14)) + (32768 * a15)) + (65536 * a16)) + (131072 * a17)) + (262144 * a18)) + (524288 * a19)) + (1048576 * a20)) + (2097152 * a21)) + (4194304 * a22)) + (8388608 * a23)) + (16777216 * a24)) + (33554432 * a25)) + (67108864 * a26)) + (134217728 * a27)) + (268435456 * a28)) + (536870912 * a29)) + (1073741824 * a30)) + (2147483648 * a31); (a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31, ())};",
"def intrange32 a = {range32 (a + 2147483648)};",
"def negative32 a = {def (a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31, ()) = intrange32 a; a31};",
"def isNegative x = negative32 x;",
"def isBool x = (x * (x - 1) = 0);",
"def isNegative a = {def e = 2^30; def b = a + e; def b0 = fresh (b % e); def b1 = fresh (b \\ e); isBool b1; b = b0 + e * b1; 1 - b1};",
"def lessThan x y = isNegative (x - y);",
"def lessOrEqual x y = lessThan x (y + 1);",
"def divRem a b = {def q = fresh (a\\b); def r = fresh (a%b); isNegative r = 0; lessThan r b = 1; a = b * q + r; (q, r) };",
Expand Down
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ data TransformationId
| DisambiguateNames
| CheckGeb
| CheckExec
| CheckVampIR
| Normalize
| LetFolding
| LambdaFolding
Expand Down Expand Up @@ -68,7 +69,7 @@ toNormalizeTransformations :: [TransformationId]
toNormalizeTransformations = toEvalTransformations ++ [LetRecLifting, LetFolding, UnrollRecursion]

toVampIRTransformations :: [TransformationId]
toVampIRTransformations = toNormalizeTransformations ++ [Normalize, LetHoisting]
toVampIRTransformations = toEvalTransformations ++ [CheckVampIR, LetRecLifting, LetFolding, UnrollRecursion, Normalize, LetHoisting]

toStrippedTransformations :: [TransformationId]
toStrippedTransformations =
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ transformationText = \case
DisambiguateNames -> strDisambiguateNames
CheckGeb -> strCheckGeb
CheckExec -> strCheckExec
CheckVampIR -> strCheckVampIR
Normalize -> strNormalize
LetFolding -> strLetFolding
LambdaFolding -> strLambdaFolding
Expand Down Expand Up @@ -169,6 +170,9 @@ strCheckGeb = "check-geb"
strCheckExec :: Text
strCheckExec = "check-exec"

strCheckVampIR :: Text
strCheckVampIR = "check-vampir"

strNormalize :: Text
strNormalize = "normalize"

Expand Down
6 changes: 4 additions & 2 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import Juvix.Compiler.Core.Data.TransformationId
import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.CheckExec
import Juvix.Compiler.Core.Transformation.CheckGeb
import Juvix.Compiler.Core.Transformation.Check.Exec
import Juvix.Compiler.Core.Transformation.Check.Geb
import Juvix.Compiler.Core.Transformation.Check.VampIR
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo
import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes
import Juvix.Compiler.Core.Transformation.DisambiguateNames
Expand Down Expand Up @@ -61,6 +62,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
DisambiguateNames -> return . disambiguateNames
CheckGeb -> mapError (JuvixError @CoreError) . checkGeb
CheckExec -> mapError (JuvixError @CoreError) . checkExec
CheckVampIR -> mapError (JuvixError @CoreError) . checkVampIR
Normalize -> return . normalize
LetFolding -> return . letFolding
LambdaFolding -> return . lambdaFolding
Expand Down
118 changes: 118 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Check/Base.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
module Juvix.Compiler.Core.Transformation.Check.Base where

import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Data.TypeDependencyInfo (createTypeDependencyInfo)
import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.LocationInfo (getInfoLocation)
import Juvix.Compiler.Core.Info.TypeInfo qualified as Info
import Juvix.Compiler.Core.Language
import Juvix.Data.PPOutput

dynamicTypeError :: Node -> Maybe Location -> CoreError
dynamicTypeError node loc =
CoreError
{ _coreErrorMsg = ppOutput $ "compilation for this target requires full type information",
_coreErrorNode = Just node,
_coreErrorLoc = fromMaybe defaultLoc loc
}

unsupportedError :: Text -> Node -> Maybe Location -> CoreError
unsupportedError what node loc =
CoreError
{ _coreErrorMsg = ppOutput $ pretty what <> " not supported for this target",
_coreErrorNode = Just node,
_coreErrorLoc = fromMaybe defaultLoc loc
}

defaultLoc :: Interval
defaultLoc = singletonInterval (mkInitialLoc mockFile)
where
mockFile :: Path Abs File
mockFile = $(mkAbsFile "/core-check")

checkBuiltins :: forall r. Member (Error CoreError) r => Bool -> Node -> Sem r Node
checkBuiltins allowUntypedFail = dmapRM go
where
go :: Node -> Sem r Recur
go node = case node of
NPrim TypePrim {..}
| _typePrimPrimitive == PrimString ->
throw $ unsupportedError "strings" node (getInfoLocation _typePrimInfo)
NBlt BuiltinApp {..} ->
case _builtinAppOp of
OpShow -> throw $ unsupportedError "strings" node (getInfoLocation _builtinAppInfo)
OpStrConcat -> throw $ unsupportedError "strings" node (getInfoLocation _builtinAppInfo)
OpStrToInt -> throw $ unsupportedError "strings" node (getInfoLocation _builtinAppInfo)
OpTrace -> throw $ unsupportedError "tracing" node (getInfoLocation _builtinAppInfo)
OpFail | not allowUntypedFail -> do
let ty = Info.getInfoType _builtinAppInfo
when (isDynamic ty) $
throw $
unsupportedError "failing without type info" node (getInfoLocation _builtinAppInfo)
return $ Recur node
OpFail -> do
return $ End node
_ -> return $ Recur node
_ -> return $ Recur node

checkNoIO :: forall r. Member (Error CoreError) r => Node -> Sem r Node
checkNoIO = dmapM go
where
go :: Node -> Sem r Node
go node = case node of
NCtr Constr {..} ->
case _constrTag of
BuiltinTag TagReturn -> throw $ unsupportedError "IO" node (getInfoLocation _constrInfo)
BuiltinTag TagBind -> throw $ unsupportedError "IO" node (getInfoLocation _constrInfo)
BuiltinTag TagReadLn -> throw $ unsupportedError "IO" node (getInfoLocation _constrInfo)
BuiltinTag TagWrite -> throw $ unsupportedError "IO" node (getInfoLocation _constrInfo)
_ -> return node
_ -> return node

checkTypes :: forall r. Member (Error CoreError) r => Bool -> InfoTable -> Node -> Sem r Node
checkTypes allowPolymorphism tab = dmapM go
where
go :: Node -> Sem r Node
go node = case node of
NIdt Ident {..}
| isDynamic (lookupIdentifierInfo tab _identSymbol ^. identifierType) ->
throw (dynamicTypeError node (getInfoLocation _identInfo))
NLam Lambda {..}
| isDynamic (_lambdaBinder ^. binderType) ->
throw (dynamicTypeError node (_lambdaBinder ^. binderLocation))
NLet Let {..}
| isDynamic (_letItem ^. letItemBinder . binderType) ->
throw (dynamicTypeError node (_letItem ^. letItemBinder . binderLocation))
NRec LetRec {..}
| any (isDynamic . (^. letItemBinder . binderType)) _letRecValues ->
throw (dynamicTypeError node (head _letRecValues ^. letItemBinder . binderLocation))
NPi Pi {..}
| not allowPolymorphism && isTypeConstr tab (_piBinder ^. binderType) ->
throw
CoreError
{ _coreErrorMsg = ppOutput "polymorphism not supported for this target",
_coreErrorNode = Just node,
_coreErrorLoc = fromMaybe defaultLoc (_piBinder ^. binderLocation)
}
_ -> return node

checkNoRecursiveTypes :: forall r. Member (Error CoreError) r => InfoTable -> Sem r ()
checkNoRecursiveTypes tab =
when (isCyclic (createTypeDependencyInfo tab)) $
throw
CoreError
{ _coreErrorMsg = ppOutput "recursive types not supported for the GEB target",
_coreErrorNode = Nothing,
_coreErrorLoc = defaultLoc
}

checkMainExists :: forall r. Member (Error CoreError) r => InfoTable -> Sem r ()
checkMainExists tab =
when (isNothing (tab ^. infoMain)) $
throw
CoreError
{ _coreErrorMsg = ppOutput "no `main` function",
_coreErrorNode = Nothing,
_coreErrorLoc = defaultLoc
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
module Juvix.Compiler.Core.Transformation.CheckExec where
module Juvix.Compiler.Core.Transformation.Check.Exec where

import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Check.Base
import Juvix.Data.PPOutput

checkExec :: forall r. Member (Error CoreError) r => InfoTable -> Sem r InfoTable
checkExec tab =
case tab ^. infoMain of
Nothing -> return tab
Nothing ->
throw
CoreError
{ _coreErrorMsg = ppOutput "no `main` function",
_coreErrorNode = Nothing,
_coreErrorLoc = defaultLoc
}
Just sym ->
case ii ^. identifierType of
NPi {} ->
Expand All @@ -31,9 +38,3 @@ checkExec tab =
where
ii = lookupIdentifierInfo tab sym
loc = fromMaybe defaultLoc (ii ^. identifierLocation)

mockFile :: Path Abs File
mockFile = $(mkAbsFile "/core-to-exec")

defaultLoc :: Interval
defaultLoc = singletonInterval (mkInitialLoc mockFile)
13 changes: 13 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Check/Geb.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module Juvix.Compiler.Core.Transformation.Check.Geb where

import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Check.Base

checkGeb :: forall r. Member (Error CoreError) r => InfoTable -> Sem r InfoTable
checkGeb tab =
checkMainExists tab
>> checkNoRecursiveTypes tab
>> mapAllNodesM checkNoIO tab
>> mapAllNodesM (checkBuiltins False) tab
>> mapAllNodesM (checkTypes False tab) tab
36 changes: 36 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Check/VampIR.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
module Juvix.Compiler.Core.Transformation.Check.VampIR where

import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Check.Base
import Juvix.Data.PPOutput

checkVampIR :: forall r. Member (Error CoreError) r => InfoTable -> Sem r InfoTable
checkVampIR tab =
checkMainExists tab
>> checkMainType
>> mapAllNodesM checkNoIO tab
>> mapAllNodesM (checkBuiltins True) tab
where
checkMainType :: Sem r ()
checkMainType =
unless (checkType (ii ^. identifierType)) $
throw
CoreError
{ _coreErrorMsg = ppOutput "for this target the arguments and the result of the `main` function must be numbers",
_coreErrorLoc = fromMaybe defaultLoc (ii ^. identifierLocation),
_coreErrorNode = Nothing
}
where
ii = lookupIdentifierInfo tab (fromJust (tab ^. infoMain))

checkType :: Node -> Bool
checkType ty =
let (tyargs, tgt) = unfoldPi' ty
in all isPrimInteger (tgt : tyargs)
where
isPrimInteger ty' = case ty' of
NPrim (TypePrim _ (PrimInteger _)) -> True
NDyn _ -> True
_ -> False
Loading

0 comments on commit d576111

Please sign in to comment.