Skip to content

Commit

Permalink
Improve inference for --new-typechecker (#2524)
Browse files Browse the repository at this point in the history
This pr applies a number of fixes to the new typechecker.
The fixes implemented are:
1. When guessing the arity of the body, we properly use the type
information of the variables in the patterns.
2. When generating wildcards, we name them properly so that they align
with the name in the type signature.
3. When compiling named applications, we inline all clauses of the form
`fun : _ := body`. This is a workaround to
#2247 and
#2517
4. I've had to ignore test027 (Church numerals). While the typechecker
passes and one can see that the types are correct, there is a lambda
where its clauses have different number of patterns. Our goal is to
support that in the near future
(#1706). This is the conflicting
lambda:
    ```
    mutual num : Nat → Num
      := λ : Nat → Num {| (zero : Nat) := czero
      | ((suc n : Nat)) {A} := csuc (num n) {A}}
    ```
5. I've added non-trivial a compilation test involving monad
transformers.
  • Loading branch information
janmasrovira authored Nov 28, 2023
1 parent 628dd23 commit d6c1a74
Show file tree
Hide file tree
Showing 32 changed files with 244 additions and 78 deletions.
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Juvix.Compiler.Core.Translation.FromInternal.Builtins.Nat
import Juvix.Compiler.Core.Translation.FromInternal.Data
import Juvix.Compiler.Internal.Data.Name
import Juvix.Compiler.Internal.Extra qualified as Internal
import Juvix.Compiler.Internal.Pretty (ppTrace)
import Juvix.Compiler.Internal.Pretty qualified as Internal
import Juvix.Compiler.Internal.Translation.Extra qualified as Internal
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.TypeChecking qualified as InternalTyped
Expand Down Expand Up @@ -856,7 +857,8 @@ goIden i = do
)
case i of
Internal.IdenVar n -> do
k <- HashMap.lookupDefault impossible id_ <$> asks (^. indexTableVars)
let err = error ("impossible: var not found: " <> ppTrace n <> " at " <> prettyText (getLoc n))
k <- HashMap.lookupDefault err id_ <$> asks (^. indexTableVars)
varsNum <- asks (^. indexTableVarsNum)
return (mkVar (setInfoLocation (n ^. nameLoc) (Info.singleton (NameInfo (n ^. nameText)))) (varsNum - k - 1))
Internal.IdenFunction n -> do
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Internal/Data/LocalVars.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,6 @@ addTypeMapping v v' = over localTyMap (HashMap.insert v v')

withEmptyLocalVars :: Sem (Reader LocalVars ': r) a -> Sem r a
withEmptyLocalVars = runReader emptyLocalVars

withLocalVars :: LocalVars -> Sem (Reader LocalVars ': r) a -> Sem r a
withLocalVars = runReader
46 changes: 45 additions & 1 deletion src/Juvix/Compiler/Internal/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ genFieldProjection _funDefName _funDefBuiltin info fieldIx = do
saturatedTy = unnamedParameter' implicity (constructorReturnType info)
inductiveArgs = map inductiveToFunctionParam inductiveParams
retTy = constrArgs !! fieldIx
return
cloneFunctionDefSameName
FunctionDef
{ _funDefExamples = [],
_funDefTerminating = False,
Expand Down Expand Up @@ -180,3 +180,47 @@ mkLetClauses pre = goSCC <$> (toList (buildLetMutualBlocks pre))
getFun :: PreLetStatement -> FunctionDef
getFun = \case
PreLetFunctionDef f -> f

inlineLet :: forall r. (Members '[NameIdGen] r) => Let -> Sem r Expression
inlineLet l = do
(lclauses, subs) <-
runOutputList
. execState (mempty @Subs)
$ forM (l ^. letClauses) helper
body' <- substitutionE subs (l ^. letExpression)
return $ case nonEmpty lclauses of
Nothing -> body'
Just cl' ->
ExpressionLet
Let
{ _letClauses = cl',
_letExpression = body'
}
where
helper :: forall r'. (r' ~ (State Subs ': Output LetClause ': r)) => LetClause -> Sem r' ()
helper c = do
subs <- get
c' <- substitutionE subs c
case subsClause c' of
Nothing -> output c'
Just (n, b) -> modify' @Subs (set (at n) (Just b))

subsClause :: LetClause -> Maybe (Name, Expression)
subsClause = \case
LetMutualBlock {} -> Nothing
LetFunDef f -> mkAssoc f
where
mkAssoc :: FunctionDef -> Maybe (Name, Expression)
mkAssoc = \case
FunctionDef
{ _funDefType = ExpressionHole {},
_funDefBody = body,
_funDefName = name,
_funDefArgsInfo = []
} -> Just (name, body)
_ -> Nothing

cloneFunctionDefSameName :: (Members '[NameIdGen] r) => FunctionDef -> Sem r FunctionDef
cloneFunctionDefSameName f = do
f' <- clone f
return (set funDefName (f ^. funDefName) f')
12 changes: 6 additions & 6 deletions src/Juvix/Compiler/Internal/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,13 @@ substitutionE m = leafExpressions goLeaf
where
goLeaf :: Expression -> Sem r Expression
goLeaf = \case
ExpressionIden i -> goIden i
ExpressionIden i -> goName (i ^. idenName)
e -> return e
goIden :: Iden -> Sem r Expression
goIden i = case i of
IdenVar v
| Just e <- HashMap.lookup v m -> clone e
_ -> return $ ExpressionIden i
goName :: Name -> Sem r Expression
goName n =
case HashMap.lookup n m of
Just e -> clone e
Nothing -> return (toExpression n)

smallUniverseE :: Interval -> Expression
smallUniverseE = ExpressionUniverse . SmallUniverse
Expand Down
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Internal/Extra/Clonable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,15 @@ instance Clonable ArgInfo where
_argInfoName
}

-- | Note that the name of the function is fresh. This is desirable when the
-- functionDef is part of a let.
instance Clonable FunctionDef where
freshNameIds :: (Members '[Reader FreshBindersContext, NameIdGen] r) => FunctionDef -> Sem r FunctionDef
freshNameIds fun@FunctionDef {..} = do
ty' <- freshNameIds _funDefType
underBinder fun $ \fun' -> do
body' <- freshNameIds _funDefBody
defaultSig' <- freshNameIds _funDefArgsInfo

return
FunctionDef
{ _funDefName = fun' ^. funDefName,
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Internal/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ instance HasAtomicity ConstructorApp where

instance HasAtomicity PatternArg where
atomicity p
| Implicit <- p ^. patternArgIsImplicit = Atom
| isImplicitOrInstance (p ^. patternArgIsImplicit) = Atom
| isJust (p ^. patternArgName) = Atom
| otherwise = atomicity (p ^. patternArgPattern)

Expand Down
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Internal/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import Juvix.Compiler.Internal.Data.InstanceInfo (instanceInfoResult, instanceTa
import Juvix.Compiler.Internal.Data.LocalVars
import Juvix.Compiler.Internal.Data.NameDependencyInfo
import Juvix.Compiler.Internal.Data.TypedHole
import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Extra.Base
import Juvix.Compiler.Internal.Language
import Juvix.Compiler.Internal.Pretty.Options
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.ArityChecking.Data.Types (Arity)
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.TypeChecking.CheckerNew.Arity qualified as New
Expand Down
28 changes: 15 additions & 13 deletions src/Juvix/Compiler/Internal/Translation/FromConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Juvix.Compiler.Builtins
import Juvix.Compiler.Concrete.Data.Scope.Base (ScoperState, scoperScopedConstructorFields, scoperScopedSignatures)
import Juvix.Compiler.Concrete.Data.ScopedName qualified as S
import Juvix.Compiler.Concrete.Extra qualified as Concrete
import Juvix.Compiler.Concrete.Gen qualified as Gen
import Juvix.Compiler.Concrete.Language qualified as Concrete
import Juvix.Compiler.Concrete.Translation.FromParsed.Analysis.Scoping qualified as Scoper
import Juvix.Compiler.Concrete.Translation.FromParsed.Analysis.Scoping.Error
Expand Down Expand Up @@ -784,16 +785,17 @@ goExpression = \case
napp' =
Concrete.NamedApplication
{ _namedAppName = napp ^. namedApplicationNewName,
_namedAppArgs = nonEmpty' $ createArgumentBlocks (sig ^. nameSignatureArgs)
_namedAppArgs = nonEmpty' (createArgumentBlocks (sig ^. nameSignatureArgs))
}
e <- goNamedApplication napp'
let l =
Internal.Let
{ _letClauses = cls,
_letExpression = e
}
expr <-
Internal.substitutionE updateKind
. Internal.ExpressionLet
$ Internal.Let
{ _letClauses = cls,
_letExpression = e
}
Internal.substitutionE updateKind l
>>= Internal.inlineLet
Internal.clone expr
where
goArgs :: NonEmpty (NamedArgumentNew 'Scoped) -> Sem r (NonEmpty Internal.LetClause)
Expand All @@ -803,7 +805,7 @@ goExpression = \case
goArg = fmap Internal.PreLetFunctionDef . goFunctionDef . (^. namedArgumentNewFunDef)

createArgumentBlocks :: [NameBlock 'Scoped] -> [ArgumentBlock 'Scoped]
createArgumentBlocks sblocks = snd $ foldr goBlock (args0, []) sblocks
createArgumentBlocks = snd . foldr goBlock (args0, [])
where
args0 :: HashSet S.Symbol = HashSet.fromList $ fmap (^. namedArgumentNewFunDef . signName) (toList appargs)
goBlock :: NameBlock 'Scoped -> (HashSet S.Symbol, [ArgumentBlock 'Scoped]) -> (HashSet S.Symbol, [ArgumentBlock 'Scoped])
Expand All @@ -813,11 +815,11 @@ goExpression = \case
where
namesInBlock =
HashSet.intersection
(HashSet.fromList $ HashMap.keys _nameBlock)
(HashSet.fromList (HashMap.keys _nameBlock))
(HashSet.map (^. S.nameConcrete) args)
argNames = HashMap.fromList $ map (\n -> (n ^. S.nameConcrete, n)) $ toList args
argNames = HashMap.fromList . map (\n -> (n ^. S.nameConcrete, n)) $ toList args
args' = HashSet.filter (not . flip HashSet.member namesInBlock . (^. S.nameConcrete)) args
_argBlockArgs = nonEmpty' $ map goArg (toList namesInBlock)
_argBlockArgs = nonEmpty' (map goArg (toList namesInBlock))
block' =
ArgumentBlock
{ _argBlockDelims = Irrelevant Nothing,
Expand All @@ -829,11 +831,11 @@ goExpression = \case
NamedArgument
{ _namedArgName = sym,
_namedArgAssignKw = Irrelevant dummyKw,
_namedArgValue = Concrete.ExpressionIdentifier $ ScopedIden name Nothing
_namedArgValue = Concrete.ExpressionIdentifier (ScopedIden name Nothing)
}
where
name = over S.nameConcrete NameUnqualified $ fromJust $ HashMap.lookup sym argNames
dummyKw = KeywordRef (asciiKw ":=") dummyLoc Ascii
dummyKw = run (runReader dummyLoc (Gen.kw Gen.kwAssign))
dummyLoc = getLoc sym

goDesugaredNamedApplication :: DesugaredNamedApplication -> Sem r Internal.Expression
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.ArityChecking.Data.Types where

import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Extra.Base
import Juvix.Compiler.Internal.Language
import Juvix.Prelude
import Juvix.Prelude.Pretty

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ checkPattern = go
go :: FunctionParameter -> PatternArg -> Sem r PatternArg
go argTy patArg = do
matchIsImplicit (argTy ^. paramImplicit) patArg
tyVarMap <- fmap (ExpressionIden . IdenVar) . (^. localTyMap) <$> get
tyVarMap <- localsToSubsE <$> get
ty <- substitutionE tyVarMap (argTy ^. paramType)
let pat = patArg ^. patternArgPattern
name = patArg ^. patternArgName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,12 +544,25 @@ checkClause clauseLoc clauseType clausePats body = do
helper :: [PatternArg] -> Expression -> Sem r (LocalVars, ([PatternArg], Expression))
helper pats ty = runState emptyLocalVars (go pats ty)

genPatternWildcard :: Interval -> FunctionParameter -> Sem (State LocalVars ': r) PatternArg
genPatternWildcard loc par = do
let impl = par ^. paramImplicit
var <- maybe (varFromWildcard (Wildcard loc)) return (par ^. paramName)
addPatternVar var (par ^. paramType) Nothing
return
PatternArg
{ _patternArgIsImplicit = impl,
_patternArgName = Nothing,
_patternArgPattern = PatternVariable var
}

go :: [PatternArg] -> Expression -> Sem (State LocalVars ': r) ([PatternArg], Expression)
go pats bodyTy = case pats of
[] -> do
(bodyParams, bodyRest) <- unfoldFunType' bodyTy
guessedBodyParams <- unfoldArity <$> guessArity body
let pref' :: [IsImplicit] = map (^. paramImplicit) (take pref bodyParams)
locals <- get
guessedBodyParams <- withLocalVars locals (unfoldArity <$> guessArity body)
let pref' :: [FunctionParameter] = take pref bodyParams
pref :: Int = aI - targetI
preImplicits = length . takeWhile isImplicitOrInstance
aI :: Int = preImplicits (map (^. paramImplicit) bodyParams)
Expand All @@ -559,10 +572,9 @@ checkClause clauseLoc clauseType clausePats body = do
let n = length pref'
bodyParams' = drop n bodyParams
ty' = foldFunType bodyParams' bodyRest
wildcards <- mapM (genWildcard clauseLoc) pref'
wildcards <- mapM (genPatternWildcard clauseLoc) pref'
return (wildcards, ty')
| otherwise -> do
return ([], bodyTy)
| otherwise -> return ([], bodyTy)
p : ps -> do
bodyTy' <- weakNormalize bodyTy
case bodyTy' of
Expand All @@ -581,9 +593,9 @@ checkClause clauseLoc clauseType clausePats body = do
loc :: Interval
loc = getLoc par

insertWildcard :: IsImplicit -> Sem (State LocalVars ': r) ([PatternArg], Expression)
insertWildcard impl = do
w <- genWildcard loc impl
insertWildcard :: Sem (State LocalVars ': r) ([PatternArg], Expression)
insertWildcard = do
w <- genPatternWildcard loc par
go (w : p : ps) bodyTy'

case (p ^. patternArgIsImplicit, par ^. paramImplicit) of
Expand All @@ -592,10 +604,10 @@ checkClause clauseLoc clauseType clausePats body = do
(ImplicitInstance, ImplicitInstance) -> checkPatternAndContinue
(Implicit, Explicit) -> throwWrongIsImplicit p Implicit
(ImplicitInstance, Explicit) -> throwWrongIsImplicit p ImplicitInstance
(Explicit, Implicit) -> insertWildcard Implicit
(ImplicitInstance, Implicit) -> insertWildcard Implicit
(Explicit, ImplicitInstance) -> insertWildcard ImplicitInstance
(Implicit, ImplicitInstance) -> insertWildcard ImplicitInstance
(Explicit, Implicit) -> insertWildcard
(ImplicitInstance, Implicit) -> insertWildcard
(Explicit, ImplicitInstance) -> insertWildcard
(Implicit, ImplicitInstance) -> insertWildcard
where
throwWrongIsImplicit :: (Members '[Error TypeCheckerError] r') => PatternArg -> IsImplicit -> Sem r' a
throwWrongIsImplicit patArg expected =
Expand Down Expand Up @@ -647,6 +659,12 @@ matchIsImplicit expected actual =
_wrongPatternIsImplicitActual = actual
}

addPatternVar :: (Members '[State LocalVars, Inference] r) => VarName -> Expression -> Maybe Name -> Sem r ()
addPatternVar v ty argName = do
modify (addType v ty)
registerIdenType v ty
whenJust argName (\v' -> modify (addTypeMapping v' v))

checkPattern ::
forall r.
(Members '[Reader InfoTable, Error TypeCheckerError, State LocalVars, Inference, NameIdGen, State FunctionsTable] r) =>
Expand All @@ -662,9 +680,9 @@ checkPattern = go
ty <- substitutionE tyVarMap (argTy ^. paramType)
let pat = patArg ^. patternArgPattern
name = patArg ^. patternArgName
whenJust name (\n -> addVar n ty argTy)
whenJust name (\n -> addPatternVar n ty (argTy ^. paramName))
pat' <- case pat of
PatternVariable v -> addVar v ty argTy $> pat
PatternVariable v -> addPatternVar v ty (argTy ^. paramName) $> pat
PatternWildcardConstructor {} -> impossible
PatternConstructorApp a -> goPatternConstructor pat ty a
return (set patternArgPattern pat' patArg)
Expand Down Expand Up @@ -714,12 +732,6 @@ checkPattern = go
)
PatternConstructorApp <$> goConstr (IdenInductive ind) a tyArgs

addVar :: VarName -> Expression -> FunctionParameter -> Sem r ()
addVar v ty argType = do
modify (addType v ty)
registerIdenType v ty
whenJust (argType ^. paramName) (\v' -> modify (addTypeMapping v' v))

goConstr :: Iden -> ConstructorApp -> [(InductiveParameter, Expression)] -> Sem r ConstructorApp
goConstr inductivename app@(ConstructorApp c ps _) ctx = do
(_, psTys) <- constructorArgTypes <$> lookupConstructor c
Expand Down Expand Up @@ -1129,7 +1141,7 @@ holesHelper mhint expr = do
where
goImplArgs :: Int -> [ApplicationArg] -> Sem r [ApplicationArg]
goImplArgs 0 as = return as
goImplArgs k ((ApplicationArg Implicit _) : as) = goImplArgs (k - 1) as
goImplArgs k (ApplicationArg Implicit _ : as) = goImplArgs (k - 1) as
goImplArgs _ as = return as

goArgs :: forall r'. (r' ~ State AppBuilder ': r) => Sem r' ()
Expand Down Expand Up @@ -1432,7 +1444,7 @@ typeArity = weakNormalize >=> go

guessArity ::
forall r.
(Members '[Reader InfoTable, Inference] r) =>
(Members '[Reader InfoTable, Inference, Reader LocalVars] r) =>
Expression ->
Sem r Arity
guessArity = \case
Expand All @@ -1449,7 +1461,7 @@ guessArity = \case
ExpressionCase l -> arityCase l
where
idenHelper :: Iden -> Sem r Arity
idenHelper = withEmptyLocalVars . idenArity
idenHelper = idenArity

appHelper :: Application -> Sem r Arity
appHelper a = do
Expand Down Expand Up @@ -1484,7 +1496,7 @@ arityUniverse = ArityUnit
simplelambda :: a
simplelambda = error "simple lambda expressions are not supported by the arity checker"

arityLambda :: forall r. (Members '[Reader InfoTable, Inference] r) => Lambda -> Sem r Arity
arityLambda :: forall r. (Members '[Reader InfoTable, Inference, Reader LocalVars] r) => Lambda -> Sem r Arity
arityLambda l = do
aris <- mapM guessClauseArity (l ^. lambdaClauses)
return $
Expand Down Expand Up @@ -1521,13 +1533,13 @@ guessPatternArgArity p =
}
}

arityLet :: (Members '[Reader InfoTable, Inference] r) => Let -> Sem r Arity
arityLet :: (Members '[Reader InfoTable, Inference, Reader LocalVars] r) => Let -> Sem r Arity
arityLet l = guessArity (l ^. letExpression)

-- | All branches should have the same arity. If they are all the same, we
-- return that, otherwise we return ArityBlocking. Probably something better can
-- be done.
arityCase :: (Members '[Reader InfoTable, Inference] r) => Case -> Sem r Arity
arityCase :: (Members '[Reader InfoTable, Inference, Reader LocalVars] r) => Case -> Sem r Arity
arityCase c = do
aris <- mapM (guessArity . (^. caseBranchExpression)) (c ^. caseBranches)
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.TypeChecking.CheckerNew.Arity where

import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Extra.Base
import Juvix.Compiler.Internal.Language
import Juvix.Prelude
import Juvix.Prelude.Pretty

Expand Down
Loading

0 comments on commit d6c1a74

Please sign in to comment.