Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JuvixTree validation #2616

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions app/Commands/Dev/Tree/Read.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ runCommand opts = do
case Tree.runParser (toFilePath afile) s of
Left err -> exitJuvixError (JuvixError err)
Right tab -> do
tab' <- Tree.applyTransformations (project opts ^. treeReadTransformations) tab
unless (project opts ^. treeReadNoPrint) $
renderStdOut (Tree.ppOutDefault tab' tab')
doEval tab'
r <- runError @JuvixError (Tree.applyTransformations (project opts ^. treeReadTransformations) tab)
case r of
Left err -> exitJuvixError (JuvixError err)
Right tab' -> do
unless (project opts ^. treeReadNoPrint) $
renderStdOut (Tree.ppOutDefault tab' tab')
doEval tab'
where
file :: AppPath File
file = opts ^. treeReadInputFile
Expand Down
6 changes: 3 additions & 3 deletions src/Juvix/Compiler/Asm/Extra/Memory.hs
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ unifyMemory' loc tab mem1 mem2 = do
unless (length (mem1 ^. memoryValueStack) == length (mem2 ^. memoryValueStack)) $
throw $
AsmError loc "value stack height mismatch"
vs <- zipWithM (unifyTypes' loc tab) (toList (mem1 ^. memoryValueStack)) (toList (mem2 ^. memoryValueStack))
vs <- zipWithM (unifyTypes'' loc tab) (toList (mem1 ^. memoryValueStack)) (toList (mem2 ^. memoryValueStack))
unless (length (mem1 ^. memoryTempStack) == length (mem2 ^. memoryTempStack)) $
throw $
AsmError loc "temporary stack height mismatch"
ts <- zipWithM (unifyTypes' loc tab) (toList (mem1 ^. memoryTempStack)) (toList (mem2 ^. memoryTempStack))
ts <- zipWithM (unifyTypes'' loc tab) (toList (mem1 ^. memoryTempStack)) (toList (mem2 ^. memoryTempStack))
unless
( length (mem1 ^. memoryArgumentArea) == length (mem2 ^. memoryArgumentArea)
&& mem1 ^. memoryArgsNum == mem2 ^. memoryArgsNum
Expand All @@ -183,7 +183,7 @@ unifyMemory' loc tab mem1 mem2 = do
args <-
mapM
( \off ->
unifyTypes'
unifyTypes''
loc
tab
(fromJust $ HashMap.lookup off (mem1 ^. memoryArgumentArea))
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ recurse' sig = go True
checkValueStack' loc (sig ^. recursorInfoTable) tyargs mem
tys <-
zipWithM
(\ty idx -> unifyTypes' loc (sig ^. recursorInfoTable) ty (topValueStack' idx mem))
(\ty idx -> unifyTypes'' loc (sig ^. recursorInfoTable) ty (topValueStack' idx mem))
tyargs
[0 ..]
return $
Expand Down Expand Up @@ -226,7 +226,7 @@ recurse' sig = go True
checkValueStack' loc (sig ^. recursorInfoTable) (take argsNum (typeArgs ty)) mem'
let tyargs = topValuesFromValueStack' argsNum mem'
-- `typeArgs ty` may be shorter than `tyargs` only if `ty` is dynamic
zipWithM_ (unifyTypes' loc (sig ^. recursorInfoTable)) tyargs (typeArgs ty)
zipWithM_ (unifyTypes'' loc (sig ^. recursorInfoTable)) tyargs (typeArgs ty)
return $
pushValueStack (mkTypeFun (drop argsNum (typeArgs ty)) (typeTarget ty)) $
popValueStack argsNum mem'
Expand Down
84 changes: 9 additions & 75 deletions src/Juvix/Compiler/Asm/Extra/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,84 +4,18 @@ module Juvix.Compiler.Asm.Extra.Type
)
where

import Data.List.NonEmpty qualified as NonEmpty
import Juvix.Compiler.Asm.Data.InfoTable
import Juvix.Compiler.Asm.Error
import Juvix.Compiler.Asm.Language
import Juvix.Compiler.Asm.Pretty
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Extra.Type

unifyTypes :: forall r. (Members '[Error AsmError, Reader (Maybe Location), Reader InfoTable] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM unifyTypes (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM unifyTypes args1 args2
tgt <- unifyTypes tgt1 tgt2
return $ TyFun (TypeFun (NonEmpty.fromList args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
unifyTypes'' :: forall t e r. (Member (Error AsmError) r) => Maybe Location -> InfoTable' t e -> Type -> Type -> Sem r Type
unifyTypes'' loc tab ty1 ty2 = mapError toAsmError $ unifyTypes' loc tab ty1 ty2
where
err :: Sem r a
err = do
loc <- ask
tab <- ask
throw $ AsmError loc ("not unifiable: " <> ppTrace tab ty1 <> ", " <> ppTrace tab ty2)

unifyTypes' :: (Member (Error AsmError) r) => Maybe Location -> InfoTable -> Type -> Type -> Sem r Type
unifyTypes' loc tab ty1 ty2 =
runReader loc $
runReader tab $
-- The `if` is to ensure correct behaviour with dynamic type targets. E.g.
-- `(A, B) -> *` should unify with `A -> B -> C -> D`.
if
| tgt1 == TyDynamic || tgt2 == TyDynamic ->
unifyTypes (curryType ty1) (curryType ty2)
| otherwise ->
unifyTypes ty1 ty2
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)
toAsmError :: TreeError -> AsmError
toAsmError TreeError {..} =
AsmError
{ _asmErrorLoc = _treeErrorLoc,
_asmErrorMsg = _treeErrorMsg
}
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ coreToVampIR' = Core.toStored' >=> storedCoreToVampIR'
-- Other workflows
--------------------------------------------------------------------------------

treeToAsm :: Tree.InfoTable -> Sem r Asm.InfoTable
treeToAsm :: (Member (Error JuvixError) r) => Tree.InfoTable -> Sem r Asm.InfoTable
treeToAsm = Tree.toAsm >=> return . Asm.fromTree

treeToNockma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Tree.InfoTable -> Sem r (Nockma.Cell Natural)
Expand Down
6 changes: 4 additions & 2 deletions src/Juvix/Compiler/Tree/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ data TransformationId
| Apply
| TempHeight
| FilterUnreachable
| Validate
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -21,10 +22,10 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

toNockmaTransformations :: [TransformationId]
toNockmaTransformations = [Apply, FilterUnreachable, TempHeight]
toNockmaTransformations = [Validate, Apply, FilterUnreachable, TempHeight]

toAsmTransformations :: [TransformationId]
toAsmTransformations = []
toAsmTransformations = [Validate]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
Expand All @@ -35,6 +36,7 @@ instance TransformationId' TransformationId where
Apply -> strApply
TempHeight -> strTempHeight
FilterUnreachable -> strFilterUnreachable
Validate -> strValidate

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ strTempHeight = "temp-height"

strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"

strValidate :: Text
strValidate = "validate"
85 changes: 84 additions & 1 deletion src/Juvix/Compiler/Tree/Extra/Type.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Avoid restricted extensions" #-}
{-# HLINT ignore "Avoid restricted flags" #-}

module Juvix.Compiler.Tree.Extra.Type where

import Juvix.Compiler.Tree.Data.InfoTable.Base
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Language.Base
import Juvix.Compiler.Tree.Language.Type
import Juvix.Compiler.Tree.Pretty

mkTypeInteger :: Type
mkTypeInteger = TyInteger (TypeInteger Nothing Nothing)
Expand Down Expand Up @@ -98,3 +106,78 @@ isSubtype' ty1 ty2
tgt2 = typeTarget (uncurryType ty2)
isSubtype' ty1 ty2 =
isSubtype ty1 ty2

unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
where
err :: Sem r a
err = do
loc <- ask
tab <- ask @(InfoTable' t e)
throw $ TreeError loc ("not unifiable: " <> ppTrace' (defaultOptions tab) ty1 <> ", " <> ppTrace' (defaultOptions tab) ty2)

unifyTypes' :: forall t e r. (Member (Error TreeError) r) => Maybe Location -> InfoTable' t e -> Type -> Type -> Sem r Type
unifyTypes' loc tab ty1 ty2 =
runReader loc $
runReader tab $
-- The `if` is to ensure correct behaviour with dynamic type targets. E.g.
-- `(A, B) -> *` should unify with `A -> B -> C -> D`.
if
| tgt1 == TyDynamic || tgt2 == TyDynamic ->
unifyTypes @t @e (curryType ty1) (curryType ty2)
| otherwise ->
unifyTypes @t @e ty1 ty2
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Tree/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ where
import Juvix.Compiler.Tree.Data.InfoTable
import Juvix.Compiler.Tree.Transformation

toNockma :: InfoTable -> Sem r InfoTable
toNockma :: (Member (Error JuvixError) r) => InfoTable -> Sem r InfoTable
toNockma = applyTransformations toNockmaTransformations

toAsm :: InfoTable -> Sem r InfoTable
toAsm :: (Member (Error JuvixError) r) => InfoTable -> Sem r InfoTable
toAsm = applyTransformations toAsmTransformations
5 changes: 4 additions & 1 deletion src/Juvix/Compiler/Tree/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ module Juvix.Compiler.Tree.Transformation
where

import Juvix.Compiler.Tree.Data.TransformationId
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Transformation.Apply
import Juvix.Compiler.Tree.Transformation.Base
import Juvix.Compiler.Tree.Transformation.FilterUnreachable
import Juvix.Compiler.Tree.Transformation.Identity
import Juvix.Compiler.Tree.Transformation.TempHeight
import Juvix.Compiler.Tree.Transformation.Validate

applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable
applyTransformations :: forall r. (Member (Error JuvixError) r) => [TransformationId] -> InfoTable -> Sem r InfoTable
applyTransformations ts tbl = foldM (flip appTrans) tbl ts
where
appTrans :: TransformationId -> InfoTable -> Sem r InfoTable
Expand All @@ -23,3 +25,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
Apply -> return . computeApply
TempHeight -> return . computeTempHeight
FilterUnreachable -> return . filterUnreachable
Validate -> mapError (JuvixError @TreeError) . validate
Loading
Loading