diff --git a/intTests/test2049/test.log.1.good b/intTests/test2049/test.log.1.good index d28ff60b97..c17e4facd7 100644 --- a/intTests/test2049/test.log.1.good +++ b/intTests/test2049/test.log.1.good @@ -11,9 +11,7 @@ Subgoal failed: zero test.saw:52:4: error: in ghost_value Literal equality postcondition Expected term: -let { x@1 = Prelude.Vec 8 Prelude.Bool - } - in zero::table +zero::table Actual term: let { x@1 = Prelude.Vec 8 Prelude.Bool } diff --git a/intTests/test2049/test.log.2.good b/intTests/test2049/test.log.2.good index 3283f712b6..b816f1ad45 100644 --- a/intTests/test2049/test.log.2.good +++ b/intTests/test2049/test.log.2.good @@ -11,9 +11,7 @@ Subgoal failed: zero test.saw:52:4: error: in ghost_value Literal equality postcondition Expected term: -let { x@1 = Prelude.Vec 8 Prelude.Bool - } - in zero::table +zero::table Actual term: let { x@1 = Prelude.Vec 8 Prelude.Bool } diff --git a/intTests/test_solver_cache/test_basics.saw b/intTests/test_solver_cache/test_basics.saw index a81b8b53b5..c10b010462 100644 --- a/intTests/test_solver_cache/test_basics.saw +++ b/intTests/test_solver_cache/test_basics.saw @@ -10,32 +10,32 @@ test_solver_cache_stats 0 0 0 0 0; prove_print z3 {{ \(x:[64]) -> x == x }}; test_solver_cache_stats 1 0 0 1 0; -// Testing that cached results do not depend on variable names - thus, the -// cache should now have one more usage, but not a new entry or insertion +// Testing that cached results depend on variable names - thus, the cache +// should now have one more entry and one more insertion, but not a new usage prove_print z3 {{ \(new_name:[64]) -> new_name == new_name }}; -test_solver_cache_stats 1 1 0 1 0; +test_solver_cache_stats 2 0 0 2 0; // Testing that cached results depend on the backend used - thus, the cache // should now have one more entry and one more insertion, but not a new usage prove_print (w4_unint_z3 []) {{ \(x:[64]) -> x == x }}; -test_solver_cache_stats 2 1 0 2 0; +test_solver_cache_stats 3 0 0 3 0; // Testing that cached results depend on the options passed to the given // backend - thus, the cache should now have one more entry and one more // insertion, but not a new usage prove_print (w4_unint_z3_using "qfnia" []) {{ \(x:[64]) -> x == x }}; -test_solver_cache_stats 3 1 0 3 0; +test_solver_cache_stats 4 0 0 4 0; // Same as the above but for sat results fails (prove_print z3 {{ \(x:[64])(y:[64]) -> x == y }}); -test_solver_cache_stats 4 1 0 4 0; +test_solver_cache_stats 5 0 0 5 0; fails (prove_print z3 {{ \(new_name_1:[64])(new_name_2:[64]) -> new_name_1 == new_name_2 }}); -test_solver_cache_stats 4 2 0 4 0; +test_solver_cache_stats 6 0 0 6 0; fails (prove_print w4 {{ \(x:[64])(y:[64]) -> x == y }}); -test_solver_cache_stats 5 2 0 5 0; +test_solver_cache_stats 7 0 0 7 0; fails (prove_print (w4_unint_z3_using "qfnia" []) {{ \(x:[64])(y:[64]) -> x == y }}); -test_solver_cache_stats 6 2 0 6 0; +test_solver_cache_stats 8 0 0 8 0; diff --git a/intTests/test_solver_cache/test_clean.saw b/intTests/test_solver_cache/test_clean.saw index f0bef7a161..b79bba5581 100644 --- a/intTests/test_solver_cache/test_clean.saw +++ b/intTests/test_solver_cache/test_clean.saw @@ -3,7 +3,7 @@ set_solver_cache_path "test_solver_cache.cache"; // The cache still has entries from prior runs -test_solver_cache_stats 6 0 0 0 0; +test_solver_cache_stats 8 0 0 0 0; // After cleaning, all SBV entries should be removed (see test.sh) clean_mismatched_versions_solver_cache; @@ -13,4 +13,4 @@ test_solver_cache_stats 4 0 0 0 0; // as many insertions as there were SBV entries, and as many usages as there // were in test_path_second less the number of SBV entries include "test_ops.saw"; -test_solver_cache_stats 6 6 0 2 0; +test_solver_cache_stats 8 4 0 4 0; diff --git a/intTests/test_solver_cache/test_path_and_reuse.saw b/intTests/test_solver_cache/test_path_and_reuse.saw index 5efce48a0a..4566b991bb 100644 --- a/intTests/test_solver_cache/test_path_and_reuse.saw +++ b/intTests/test_solver_cache/test_path_and_reuse.saw @@ -3,10 +3,10 @@ set_solver_cache_path "test_solver_cache.cache"; // The cache still has entries from the last run -test_solver_cache_stats 6 0 0 0 0; +test_solver_cache_stats 8 0 0 0 0; // After running test_path_ops, we should have the same number of entries, but // no insertions and and as many usages as there were insertions plus usages // the first time include "test_ops.saw"; -test_solver_cache_stats 6 8 0 0 0; +test_solver_cache_stats 8 8 0 0 0; diff --git a/otherTests/saw-core/Tests/Functor.hs b/otherTests/saw-core/Tests/Functor.hs index 291c9a58ee..34a744a199 100644 --- a/otherTests/saw-core/Tests/Functor.hs +++ b/otherTests/saw-core/Tests/Functor.hs @@ -50,15 +50,14 @@ shared :: TermIndex -> TermF Term -> Term shared ix t = STApp { stAppIndex = ix, stAppHash = hash t, - stAppLooseVars = emptyBitSet, stAppFreeVars = mempty, stAppTermF = t } --- | Some LocalNames (LocalName is just Text) -lnFoo, lnBar :: LocalName -lnFoo = "foo" -lnBar = "bar" +-- | Some variable names +vnFoo, vnBar :: VarName +vnFoo = VarName 2 "foo" +vnBar = VarName 3 "bar" -- | An external constant nmFoo :: Name @@ -75,7 +74,7 @@ nmFoo = Name { type Failure = (String, String, String) type Result = Either Failure () --- | Check a result. +-- | Check a result. check :: Failure -> Bool -> Result check _ True = Right () check f False = Left f @@ -175,19 +174,19 @@ instance TestIt Term where let depth' = depth + 1 unit = Unshared $ FTermF $ UnitValue zero = Unshared $ FTermF $ NatLit 0 - localvar = Unshared $ LocalVar 0 + localvar = Unshared $ Variable vnBar t testOne depth' $ PairValue t t testOne depth' $ PairValue t zero testOne depth' $ PairValue unit t testOne depth' $ App t t testOne depth' $ App t zero testOne depth' $ App unit t - testOne depth' $ Lambda lnFoo t t - testOne depth' $ Lambda lnBar t localvar - testOne depth' $ Pi lnFoo t t - testOne depth' $ Pi lnBar t localvar - testTwo depth' GT (Lambda lnFoo t t) (Lambda lnBar t t) - testTwo depth' GT (Pi lnFoo t t) (Pi lnBar t t) + testOne depth' $ Lambda vnFoo t t + testOne depth' $ Lambda vnBar t localvar + testOne depth' $ Pi vnFoo t t + testOne depth' $ Pi vnBar t localvar + testTwo depth' LT (Lambda vnFoo t t) (Lambda vnBar t t) + testTwo depth' LT (Pi vnFoo t t) (Pi vnBar t t) testTwo depth' EQ (Constant nmFoo :: TermF Term) (Constant nmFoo :: TermF Term) testTwo depth comp t1 t2 = do @@ -197,12 +196,12 @@ instance TestIt Term where when (depth < 2 && comp /= EQ) $ do let depth' = depth + 1 -- check that the variable name affects the comparison - testTwo depth' comp (Lambda lnFoo t1 t1) (Lambda lnFoo t2 t2) - testTwo depth' GT (Lambda lnFoo t1 t1) (Lambda lnBar t2 t2) - testTwo depth' LT (Lambda lnBar t1 t1) (Lambda lnFoo t2 t2) - testTwo depth' comp (Pi lnFoo t1 t1) (Pi lnFoo t2 t2) - testTwo depth' GT (Pi lnFoo t1 t1) (Pi lnBar t2 t2) - testTwo depth' LT (Pi lnBar t1 t1) (Pi lnFoo t2 t2) + testTwo depth' comp (Lambda vnFoo t1 t1) (Lambda vnFoo t2 t2) + testTwo depth' LT (Lambda vnFoo t1 t1) (Lambda vnBar t2 t2) + testTwo depth' GT (Lambda vnBar t1 t1) (Lambda vnFoo t2 t2) + testTwo depth' comp (Pi vnFoo t1 t1) (Pi vnFoo t2 t2) + testTwo depth' LT (Pi vnFoo t1 t1) (Pi vnBar t2 t2) + testTwo depth' GT (Pi vnBar t1 t1) (Pi vnFoo t2 t2) pure () -- | Run some tests diff --git a/otherTests/saw-core/Tests/Parser.hs b/otherTests/saw-core/Tests/Parser.hs index 5837bdeaac..de91279738 100644 --- a/otherTests/saw-core/Tests/Parser.hs +++ b/otherTests/saw-core/Tests/Parser.hs @@ -24,11 +24,11 @@ checkDef :: Def -> Assertion checkDef d = do let sym = nameInfo (defName d) let tp = defType d - assertBool (namedMsg sym "Type is not ground.") (termIsClosed tp) + assertBool (namedMsg sym "Type is not ground.") (closedTerm tp) case defBody d of Nothing -> return () Just body -> - assertBool (namedMsg sym "Body is not ground.") (termIsClosed body) + assertBool (namedMsg sym "Body is not ground.") (closedTerm body) checkPrelude :: Assertion checkPrelude = diff --git a/saw-central/src/SAWCentral/Bisimulation.hs b/saw-central/src/SAWCentral/Bisimulation.hs index 517817eed2..19dc5dd20d 100644 --- a/saw-central/src/SAWCentral/Bisimulation.hs +++ b/saw-central/src/SAWCentral/Bisimulation.hs @@ -259,7 +259,6 @@ openConstantApp :: TypedTerm -- 'Constant' (will be unfolded). -> TopLevel (ExtCns Term, TermF Term) openConstantApp constant t = do - sc <- getSharedContext -- Unfold constant name <- constantName (unwrapTermF (ttTerm t)) tUnfolded <- unfold_term [name] t @@ -269,15 +268,15 @@ openConstantApp constant t = do -- Replace outer function's argument with an 'ExtCns' -- NOTE: The bisimulation relation type ensures this is a single argument - -- lambda, so it's OK to apply scOpenTerm once and not recurse - (ec, tExtconsified) <- io $ scOpenTerm sc nm tp 0 body - extractedF <- extractApp constant tExtconsified + -- lambda, so it's OK to not recurse + let ec = EC nm tp + extractedF <- extractApp constant body pure (ec, extractedF) where -- Break down lambda into its component parts. Fails if 'tt' is not a -- lambda. - lambdaOrFail :: TypedTerm -> TopLevel (LocalName, Term, Term) + lambdaOrFail :: TypedTerm -> TopLevel (VarName, Term, Term) lambdaOrFail tt = case asLambda (ttTerm tt) of Just lambda -> return lambda diff --git a/saw-central/src/SAWCentral/Builtins.hs b/saw-central/src/SAWCentral/Builtins.hs index c5c13995f4..976d6f4c93 100644 --- a/saw-central/src/SAWCentral/Builtins.hs +++ b/saw-central/src/SAWCentral/Builtins.hs @@ -71,7 +71,7 @@ import SAWCore.FiniteValue , FirstOrderValue(..) , scFirstOrderValue ) -import SAWCore.Name (ecShortName) +import SAWCore.Name (VarName(..), ecShortName) import SAWCore.SATQuery import SAWCore.SCTypeCheck import SAWCore.Recognizer @@ -272,10 +272,10 @@ replacePrim pat replace t = do let tpat = ttTerm pat let trepl = ttTerm replace - unless (termIsClosed tpat) $ fail $ unlines + unless (closedTerm tpat) $ fail $ unlines [ "pattern term is not closed", show tpat ] - unless (termIsClosed trepl) $ fail $ unlines + unless (closedTerm trepl) $ fail $ unlines [ "replacement term is not closed", show trepl ] io $ do @@ -732,15 +732,15 @@ build_congruence sc tm = case asPiList ty of ([],_) -> fail "congruence_for: Term is not a function" (pis, body) -> - if termIsClosed body then + if closedTerm body then loop pis [] else fail "congruence_for: cannot build congruence for dependent functions" where loop ((nm,tp):pis) vars = - if termIsClosed tp then - do l <- scFreshEC sc (nm <> "_1") tp - r <- scFreshEC sc (nm <> "_2") tp + if closedTerm tp then + do l <- scFreshEC sc (vnName nm <> "_1") tp + r <- scFreshEC sc (vnName nm <> "_2") tp loop pis ((l,r):vars) else fail "congruence_for: cannot build congruence for dependent functions" @@ -1396,8 +1396,8 @@ proveByBVInduction script t = -- and the width of the bitvector we are doing induction on. checkInductionScheme sc opts pis ty = do ty' <- scWhnf sc ty - scAsPi sc ty' >>= \case - Just (ec, body) -> checkInductionScheme sc opts (ec : pis) body + case asPi ty' of + Just (nm, tp, body) -> checkInductionScheme sc opts (EC nm tp : pis) body Nothing -> case asTupleType ty' of Just [bv, bool] -> @@ -1953,7 +1953,7 @@ parseCoreMod mnm_str input = let mnm = mkModuleName $ Text.splitOn "." mnm_str _ <- io $ scFindModule sc mnm -- Check that mnm exists - err_or_t <- io $ inferCompleteTermCtx sc (Just mnm) [] uterm + err_or_t <- io $ inferCompleteTermCtx sc (Just mnm) mempty uterm case err_or_t of Left err -> fail (show err) Right x -> pure x diff --git a/saw-central/src/SAWCentral/Crucible/Common/Override.hs b/saw-central/src/SAWCentral/Crucible/Common/Override.hs index ed1ae3adf2..72be709697 100644 --- a/saw-central/src/SAWCentral/Crucible/Common/Override.hs +++ b/saw-central/src/SAWCentral/Crucible/Common/Override.hs @@ -108,7 +108,7 @@ import Data.Parameterized.TraversableFC (toListFC) import qualified SAWSupport.Pretty as PPS (defaultOpts, limitMaxDepth) -import SAWCore.Name (ecShortName) +import SAWCore.Name (VarName(..), ecShortName) import SAWCore.Prelude as SAWVerifier (scEq) import SAWCore.SharedTerm as SAWVerifier import SAWCore.Term.Functor (unwrapTermF) @@ -723,9 +723,9 @@ matchTerm sc md prepost real expect = do let loc = MS.conditionLoc md free <- OM (use osFree) case unwrapTermF expect of - Variable ec - | Set.member (ecVarIndex ec) free -> - do assignTerm sc md prepost (ecVarIndex ec) real + Variable vn _tp + | Set.member (vnIndex vn) free -> + do assignTerm sc md prepost (vnIndex vn) real _ -> do t <- liftIO $ scEq sc real expect diff --git a/saw-central/src/SAWCentral/Proof.hs b/saw-central/src/SAWCentral/Proof.hs index 3505167fea..121291028b 100644 --- a/saw-central/src/SAWCentral/Proof.hs +++ b/saw-central/src/SAWCentral/Proof.hs @@ -137,6 +137,7 @@ import Control.Monad.Except (ExceptT, MonadError(..), runExceptT) import Control.Monad.Trans.Class (MonadTrans(..)) import qualified Data.Foldable as Fold import qualified Data.IntMap as IntMap +import qualified Data.IntSet as IntSet import Data.List (genericDrop, genericLength, genericSplitAt) import Data.Map (Map) import qualified Data.Map as Map @@ -156,7 +157,7 @@ import SAWCore.Prelude (scApplyPrelude_False) import SAWCore.Recognizer import SAWCore.Rewriter import SAWCore.SATQuery -import SAWCore.Name (DisplayNameEnv, ecShortName) +import SAWCore.Name (DisplayNameEnv, VarName(..), ecShortName) import SAWCore.SharedTerm import SAWCore.Term.Functor import CryptolSAWCore.TypedTerm @@ -255,23 +256,23 @@ splitIte sc (Prop p) = -- | Attempt to split a conjunctive proposition into two propositions. splitConj :: SharedContext -> Prop -> IO (Maybe (Prop, Prop)) splitConj sc (Prop p) = - do (vars, body) <- scAsPiList sc p + do let (vars, body) = asPiList p case (isGlobalDef "Prelude.and" <@> return <@> return) =<< asEqTrue body of Nothing -> pure Nothing Just (_ :*: p1 :*: p2) -> - do t1 <- scGeneralizeExts sc vars =<< scEqTrue sc p1 - t2 <- scGeneralizeExts sc vars =<< scEqTrue sc p2 + do t1 <- scPiList sc vars =<< scEqTrue sc p1 + t2 <- scPiList sc vars =<< scEqTrue sc p2 return (Just (Prop t1,Prop t2)) -- | Attempt to split a disjunctive proposition into two propositions. splitDisj :: SharedContext -> Prop -> IO (Maybe (Prop, Prop)) splitDisj sc (Prop p) = - do (vars, body) <- scAsPiList sc p + do let (vars, body) = asPiList p case (isGlobalDef "Prelude.or" <@> return <@> return) =<< asEqTrue body of Nothing -> pure Nothing Just (_ :*: p1 :*: p2) -> - do t1 <- scGeneralizeExts sc vars =<< scEqTrue sc p1 - t2 <- scGeneralizeExts sc vars =<< scEqTrue sc p2 + do t1 <- scPiList sc vars =<< scEqTrue sc p1 + t2 <- scPiList sc vars =<< scEqTrue sc p2 return (Just (Prop t1,Prop t2)) -- | Attempt to split an implication into a hypothesis and a conclusion @@ -295,8 +296,8 @@ splitImpl sc (Prop p) return (Just (Prop h', Prop c')) -- Handle the case of (H1 -> H2), where H1 and H2 are in Prop - | Just (_nm, arg, c) <- asPi p - , termIsClosed c -- make sure this is a nondependent Pi (AKA arrow type) + | Just (nm, arg, c) <- asPi p + , IntSet.notMember (vnIndex nm) (freeVars c) -- make sure this is a nondependent Pi (AKA arrow type) = termToMaybeProp sc arg >>= \case Nothing -> return Nothing Just h -> return (Just (h, Prop c)) @@ -448,7 +449,9 @@ hoistIfsInProp sc p = do -- fresh ExtCns values, being careful to ensure that -- dependent types are properly substituted. unbindAndFreshenProp :: SharedContext -> Prop -> IO ([ExtCns Term], Term) -unbindAndFreshenProp sc (Prop p0) = scAsPiList sc p0 +unbindAndFreshenProp _sc (Prop p0) = + do let (vars, body) = asPiList p0 + pure (map (uncurry EC) vars, body) -- | Evaluate the given proposition by round-tripping -- through the What4 formula representation. This will @@ -488,10 +491,10 @@ trivialProofTerm :: SharedContext -> Prop -> IO (Either String Term) trivialProofTerm sc (Prop p) = runExceptT (loop =<< lift (scWhnf sc p)) where loop t = - lift (scAsPi sc t) >>= \case - Just (ec, body) -> + case asPi t of + Just (nm, tp, body) -> do pf <- loop =<< lift (scWhnf sc body) - lift $ scAbstractExts sc [ec] pf + lift $ scLambda sc nm tp pf Nothing -> case asEq t of Just (tp, x, _y) -> @@ -1222,7 +1225,7 @@ specializeTheorem sc what4PushMuxOps db loc rsn thm ts = constructTheorem sc what4PushMuxOps db p' (ApplyEvidence thm (map Left ts)) loc Nothing rsn 0 specializeProp :: SharedContext -> Prop -> [Term] -> IO (Either TC.TCError Prop) -specializeProp sc (Prop p0) ts0 = TC.runTCM (loop p0 ts0) sc [] +specializeProp sc (Prop p0) ts0 = TC.runTCM (loop p0 ts0) sc mempty where loop p [] = return (Prop p) loop p (t:ts) = @@ -1311,12 +1314,12 @@ predicateToProp :: SharedContext -> Quantification -> Term -> IO Prop predicateToProp sc quant = loop where loop t = - scAsLambda sc t >>= \case - Just (x, body) -> + case asLambda t of + Just (x, ty, body) -> do Prop body' <- loop body - Prop <$> scGeneralizeExts sc [x] body' + Prop <$> scPi sc x ty body' Nothing -> - do (argTs, resT) <- scAsPiList sc =<< scTypeOf sc t + do (argTs, resT) <- asPiList <$> scTypeOf sc t let toPi [] t0 = case asBoolType resT of Nothing -> fail $ unlines ["predicateToProp : Expected boolean result type but got", showTerm resT] @@ -1324,10 +1327,10 @@ predicateToProp sc quant = loop case quant of Universal -> scEqTrue sc t0 Existential -> scEqTrue sc =<< scNot sc t0 - toPi (ec : ecs) t1 = - do t2 <- scApply sc t1 =<< scVariable sc ec - t3 <- toPi ecs t2 - scGeneralizeExts sc [ec] t3 + toPi ((x, xT) : tys) t1 = + do t2 <- scApply sc t1 =<< scVariable sc (EC x xT) + t3 <- toPi tys t2 + scPi sc x xT t3 Prop <$> toPi argTs t @@ -1465,9 +1468,9 @@ normalizeConcl sc p = _ -> -- handle the case of (H1 -> H2), where H1 and H2 are in Prop case asPi t of - Just (_nm, arg, body) + Just (nm, arg, body) -- check that this is non-dependent Pi (AKA arrow type) - | termIsClosed body -> + | IntSet.notMember (vnIndex nm) (freeVars body) -> termToMaybeProp sc arg >>= \case Nothing -> return (RawSequent [] [p]) Just h -> @@ -1532,8 +1535,8 @@ checkEvidence sc what4PushMuxOps = \e p -> do -- (i.e., nondependent Pi quantifying over a Prop) -- and the given evidence must match the expected prop. checkApply nenv mkSqt (Prop p) (Right e:es) - | Just (_lnm, tp, body) <- asPi p - , termIsClosed body + | Just (lnm, tp, body) <- asPi p + , IntSet.notMember (vnIndex lnm) (freeVars body) = do (d1,sy1) <- check nenv e . mkSqt =<< termToProp sc tp (d2,sy2,p') <- checkApply nenv mkSqt (Prop body) es return (Set.union d1 d2, sy1 <> sy2, p') @@ -1549,7 +1552,7 @@ checkEvidence sc what4PushMuxOps = \e p -> do p_typed <- TC.typeInferComplete p let err = TC.NotFuncTypeInApp p_typed tm' TC.applyPiTyped err p tm' - res <- TC.runTCM m sc [] + res <- TC.runTCM m sc mempty case res of Left msg -> fail (unlines (TC.prettyTCError msg)) Right p' -> checkApply nenv mkSqt (Prop p') es @@ -1734,11 +1737,10 @@ checkEvidence sc what4PushMuxOps = \e p -> do Unfocused -> fail "Intro evidence requires a focused sequent" HypFocus _ _ -> fail "Intro evidence apply in hypothesis" ConclFocus (Prop ptm) mkSqt -> - scAsPi sc ptm >>= \case + case asPi ptm of Nothing -> fail $ unlines ["Intro evidence expected function prop", showTerm ptm] - Just (ec, body) -> - do let ty = ecType ec - let ty' = ecType x + Just (nm, ty, body) -> + do let ty' = ecType x ok <- scConvertible sc False ty ty' unless ok $ fail $ unlines ["Intro evidence types do not match" @@ -1746,7 +1748,7 @@ checkEvidence sc what4PushMuxOps = \e p -> do , showTerm ty ] x' <- scVariable sc x - body' <- scInstantiateExt sc (IntMap.singleton (ecVarIndex ec) x') body + body' <- scInstantiateExt sc (IntMap.singleton (vnIndex nm) x') body check nenv e' (mkSqt (Prop body')) passthroughEvidence :: [Evidence] -> IO Evidence @@ -1889,12 +1891,12 @@ predicateToSATQuery sc unintSet tm0 = Just fot -> filterFirstOrderVars mmap (Map.insert e fot fovars) absvars es processTerm mmap vars tm = - scAsLambda sc tm >>= \case - Just (ec, body) -> - case evalFOT mmap (ecType ec) of - Nothing -> fail ("predicateToSATQuery: expected first order type: " ++ showTerm (ecType ec)) + case asLambda tm of + Just (nm, tp, body) -> + case evalFOT mmap tp of + Nothing -> fail ("predicateToSATQuery: expected first order type: " ++ showTerm tp) Just fot -> - processTerm mmap (Map.insert ec fot vars) body + processTerm mmap (Map.insert (EC nm tp) fot vars) body -- TODO: check that the type is a boolean Nothing -> @@ -1949,16 +1951,15 @@ sequentToSATQuery sc unintSet sqt = do -- TODO: See related TODO in processConcl let tm' = tm - scAsPi sc tm' >>= \case - Just (ec, body) -> - do let tp = ecType ec - -- TODO, same issue + case asPi tm' of + Just (nm, tp, body) -> + do -- TODO, same issue let tp' = tp case evalFOT mmap tp' of Just fot -> - processUnivAssert mmap ((ec, fot) : vars) xs body + processUnivAssert mmap ((EC nm tp, fot) : vars) xs body Nothing - | termIsClosed body -> + | IntSet.null (foldr IntSet.delete (freeVars body) (map (ecVarIndex . fst) vars)) -> case asEqTrue tp' of Just x -> processUnivAssert mmap vars (x:xs) body Nothing -> @@ -1979,17 +1980,16 @@ sequentToSATQuery sc unintSet sqt = -- tm' <- scWhnf sc tm let tm' = tm - scAsPi sc tm' >>= \case - Just (ec, body) -> - do let tp = ecType ec - -- same issue with WHNF + case asPi tm' of + Just (nm, tp, body) -> + do -- same issue with WHNF -- tp' <- scWhnf sc tp let tp' = tp case evalFOT mmap tp' of Just fot -> - processConcl mmap (Map.insert ec fot vars, xs) body + processConcl mmap (Map.insert (EC nm tp) fot vars, xs) body Nothing - | termIsClosed body -> + | IntSet.null (foldr IntSet.delete (freeVars body) (map ecVarIndex (Map.keys vars))) -> do asrt <- processAssert mmap tp processConcl mmap (vars, asrt : xs) body | otherwise -> @@ -2010,7 +2010,7 @@ propApply :: Prop {- ^ propsition to apply (usually a quantified and/or implication term) -} -> Prop {- ^ goal to apply the proposition to -} -> IO (Maybe [Either Term Prop]) -propApply sc rule goal = applyFirst =<< asPiLists (unProp rule) +propApply sc rule goal = applyFirst (asPiLists (unProp rule)) where applyFirst :: [([ExtCns Term], Term)] -> IO (Maybe [Either Term Prop]) applyFirst [] = pure Nothing @@ -2035,13 +2035,12 @@ propApply sc rule goal = applyFirst =<< asPiLists (unProp rule) pure (Left tm) Just <$> traverse mkNewGoal ruleArgs - asPiLists :: Term -> IO [([ExtCns Term], Term)] + asPiLists :: Term -> [([ExtCns Term], Term)] asPiLists t = - scAsPi sc t >>= \case - Nothing -> pure [([], t)] - Just (ec, body) -> - do lists <- asPiLists body - pure $ [ (ec : args, concl) | (args, concl) <- lists ] ++ [([], t)] + case asPi t of + Nothing -> [([], t)] + Just (nm, tp, body) -> + [ (EC nm tp : args, concl) | (args, concl) <- asPiLists body ] ++ [([], t)] -- | Attempt to prove a universally quantified goal by introducing a fresh variable @@ -2053,15 +2052,14 @@ tacticIntro :: (F.MonadFail m, MonadIO m) => tacticIntro sc usernm = Tactic \goal -> case sequentState (goalSequent goal) of ConclFocus p mkSqt -> - liftIO (scAsPi sc (unProp p)) >>= \case - Just (ec, body) -> - do let nm = ecShortName ec - let tp = ecType ec + case asPi (unProp p) of + Just (vn, tp, body) -> + do let nm = vnName vn let name = if Text.null usernm then nm else usernm xv <- liftIO $ scFreshEC sc name tp x <- liftIO $ scVariable sc xv tt <- liftIO $ mkTypedTerm sc x - body' <- liftIO $ scInstantiateExt sc (IntMap.singleton (ecVarIndex ec) x) body + body' <- liftIO $ scInstantiateExt sc (IntMap.singleton (vnIndex vn) x) body let goal' = goal { goalSequent = mkSqt (Prop body') } return (tt, mempty, [goal'], introEvidence xv) diff --git a/saw-central/src/SAWCentral/Yosys/Utils.hs b/saw-central/src/SAWCentral/Yosys/Utils.hs index 1d3346d9f0..f4a926afcc 100644 --- a/saw-central/src/SAWCentral/Yosys/Utils.hs +++ b/saw-central/src/SAWCentral/Yosys/Utils.hs @@ -155,7 +155,7 @@ validateTermAtType :: MonadIO m => SC.SharedContext -> Text -> SC.Term -> SC.Term -> m () validateTermAtType sc msg trm tp = liftIO (SC.TC.runTCM (SC.TC.typeInferComplete trm >>= \tp_trm -> - SC.TC.checkSubtype tp_trm tp) sc []) >>= \case + SC.TC.checkSubtype tp_trm tp) sc mempty) >>= \case Right _ -> return () Left err -> throw diff --git a/saw-core-aig/src/SAWCoreAIG/BitBlast.hs b/saw-core-aig/src/SAWCoreAIG/BitBlast.hs index 30a676a3f9..606dd37a71 100644 --- a/saw-core-aig/src/SAWCoreAIG/BitBlast.hs +++ b/saw-core-aig/src/SAWCoreAIG/BitBlast.hs @@ -482,7 +482,7 @@ asPiTypes sc t = Nothing -> pure ([], t) Just (n, t1, t2) -> do (args, ret) <- asPiTypes sc t2 - pure ((Text.unpack n, t1) : args, ret) + pure ((Text.unpack (vnName n), t1) : args, ret) bitBlastTerm :: AIG.IsAIG l g => diff --git a/saw-core-coq/src/SAWCoreCoq/Term.hs b/saw-core-coq/src/SAWCoreCoq/Term.hs index c0f43b7adb..54a6a3ada8 100644 --- a/saw-core-coq/src/SAWCoreCoq/Term.hs +++ b/saw-core-coq/src/SAWCoreCoq/Term.hs @@ -80,10 +80,6 @@ data TranslationReader = TranslationReader { _currentModule :: Maybe ModuleName -- ^ The current Coq module for the translation - , _localEnvironment :: [Coq.Ident] - -- ^ The list of Coq identifiers associated with the current SAW core - -- Bruijn-indexed local variables in scope, innermost (index 0) first - , _namedEnvironment :: Map.Map VarName Coq.Ident -- ^ The map of Coq identifiers associated with the SAW core named -- variables in scope @@ -188,22 +184,11 @@ invalidateOpenSharing :: TermTranslationMonad m => m a -> m a invalidateOpenSharing = localTR (over sharedNames $ IntMap.filter sharedNameIsClosed) --- | Run a translation in a context with one more SAW core variable with the --- given name. Pass the corresponding Coq identifier used for this SAW core --- variable to the computation in which it is bound. This invalidates all shared --- terms that are not closed, since these shared terms now correspond to --- different terms (with greater deBruijn indices) that have different --- 'TermIndex'es. -withSAWVar :: TermTranslationMonad m => LocalName -> (Coq.Ident -> m a) -> m a -withSAWVar n m = - invalidateOpenSharing $ withFreshIdent n $ \n_coq -> - localTR (over localEnvironment (n_coq :)) $ m n_coq - -- | Run a translation in a context with one more SAW core variable with the -- given name. Pass the corresponding Coq identifier used for this SAW core -- variable to the computation in which it is bound. -withSAWVarEC :: TermTranslationMonad m => VarName -> (Coq.Ident -> m a) -> m a -withSAWVarEC n m = +withSAWVar :: TermTranslationMonad m => VarName -> (Coq.Ident -> m a) -> m a +withSAWVar n m = withFreshIdent (vnName n) $ \n_coq -> localTR (over namedEnvironment (Map.insert n n_coq)) $ m n_coq @@ -215,7 +200,7 @@ withSharedTerm :: TermTranslationMonad m => TermIndex -> Term -> (Coq.Ident -> m a) -> m a withSharedTerm idx t f = do ident <- (view nextSharedName <$> askTR) >>= freshVariant - let sh_nm = SharedName ident $ termIsClosed t + let sh_nm = SharedName ident $ closedTerm t localTR (set nextSharedName (nextVariant ident) . over sharedNames (IntMap.insert idx sh_nm)) $ withUsedCoqIdent ident $ f ident @@ -280,7 +265,6 @@ runTermTranslationMonad configuration mname mm globalDecls localEnv = runTranslationMonad configuration (TranslationReader { _currentModule = mname - , _localEnvironment = localEnv , _namedEnvironment = Map.empty , _unavailableIdents = Set.union reservedIdents (Set.fromList localEnv) , _sharedNames = IntMap.empty @@ -493,7 +477,6 @@ withTopTranslationState m = localTR (\r -> TranslationReader { _currentModule = view currentModule r, - _localEnvironment = [], _namedEnvironment = Map.empty, _unavailableIdents = reservedIdents, _sharedNames = IntMap.empty, @@ -541,24 +524,25 @@ bindTransToPiBinder (BindTrans { .. }) = Coq.PiBinder (Just bindTransIdent) bindTransType : map (\(n,ty) -> Coq.PiImplicitBinder (Just n) ty) bindTransImps --- | Given a 'LocalName' and its type (as a 'Term'), translate the 'LocalName' +-- | Given a 'VarName' and its type (as a 'Term'), translate the 'VarName' -- to a Coq identifier, translate the type to a Coq term, and generate zero or -- more additional 'Ident's and 'Type's representing additonal implicit -- typeclass arguments, added if the given type is @isort@, etc. Pass all of -- this information to the supplied computation, in which the SAW core variable -- is bound to its Coq identifier. -translateBinder :: TermTranslationMonad m => LocalName -> Term -> +translateBinder :: TermTranslationMonad m => VarName -> Term -> (BindTrans -> m a) -> m a -translateBinder n ty@(asPiList -> (args, pi_body)) f = +translateBinder vn ty@(asPiList -> (args, pi_body)) f = do ty' <- translateTerm ty let mb_sort = asSortWithFlags pi_body flagValues = sortFlagsToList $ maybe noFlags snd mb_sort flagLocalNames = [("Inh", "SAWCoreScaffolding.Inhabited"), ("QT", "QuantType")] - withSAWVar n $ \n' -> + withSAWVar vn $ \n' -> helper n' (zip flagValues flagLocalNames) (\imps -> f $ BindTrans n' ty' imps) where + n = vnName vn helper _ [] g = g [] helper n' ((True,(prefix,tc)):rest) g = do nhty <- translateImplicitHyp (Coq.Var tc) args (Coq.Var n') @@ -581,7 +565,7 @@ translateBinderEC ec f = flagValues = sortFlagsToList $ maybe noFlags snd mb_sort flagLocalNames = [("Inh", "SAWCoreScaffolding.Inhabited"), ("QT", "QuantType")] - withSAWVarEC nm $ \n' -> + withSAWVar nm $ \n' -> helper n' (zip flagValues flagLocalNames) (\imps -> f $ BindTrans n' ty' imps) where @@ -602,7 +586,7 @@ translateBinderEC ec f = helper n' ((False,_):rest) g = helper n' rest g -- | Call 'translateBinder' on a list of SAW core bindings -translateBinders :: TermTranslationMonad m => [(LocalName,Term)] -> +translateBinders :: TermTranslationMonad m => [(VarName,Term)] -> ([BindTrans] -> m a) -> m a translateBinders [] f = f [] translateBinders ((n,ty):ns_tys) f = @@ -624,7 +608,7 @@ translateBindersEC (ec : ecs) f = -- function translateImplicitHyp :: TermTranslationMonad m => - Coq.Term -> [(LocalName, Term)] -> Coq.Term -> m Coq.Term + Coq.Term -> [(VarName, Term)] -> Coq.Term -> m Coq.Term translateImplicitHyp tc [] tm = return (Coq.App tc [tm]) translateImplicitHyp tc args tm = translateBinders args $ \args' -> @@ -638,7 +622,7 @@ translateImplicitHyp tc args tm = -- | Given a list of 'LocalName's and their corresponding types (as 'Term's), -- return a list of explicit 'Binder's, for use representing the bound variables -- in 'Lambda's, 'Let's, etc. -translateParams :: TermTranslationMonad m => [(LocalName, Term)] -> +translateParams :: TermTranslationMonad m => [(VarName, Term)] -> ([Coq.Binder] -> m a) -> m a translateParams bs m = translateBinders bs (m . concat . map bindTransToBinder) @@ -652,11 +636,11 @@ translateParamsEC bs m = translateBindersEC bs (m . concatMap bindTransToBinder) --- | Given a list of 'LocalName's and their corresponding types (as 'Term's) +-- | Given a list of 'VarName's and their corresponding types (as 'Term's) -- representing argument types and a 'Term' representing the return type, -- return the resulting 'Pi', with additional implicit arguments added after -- each instance of @isort@, @qsort@, etc. -translatePi :: TermTranslationMonad m => [(LocalName, Term)] -> Term -> m Coq.Term +translatePi :: TermTranslationMonad m => [(VarName, Term)] -> Term -> m Coq.Term translatePi binders body = translatePiBinders binders $ \bindersT -> do bodyT <- translateTermLet body @@ -666,7 +650,7 @@ translatePi binders body = -- 'PiBinder' followed by zero or more implicit 'PiBinder's representing -- additonal implicit typeclass arguments, added if the given type is @isort@, -- @qsort@, etc. -translatePiBinders :: TermTranslationMonad m => [(LocalName, Term)] -> +translatePiBinders :: TermTranslationMonad m => [(VarName, Term)] -> ([Coq.PiBinder] -> m a) -> m a translatePiBinders bs m = translateBinders bs (m . concat . map bindTransToPiBinder) @@ -704,12 +688,8 @@ translateTerm t = -- | Translate a SAW core 'Term' to Coq without using sharing translateTermUnshared :: TermTranslationMonad m => Term -> m Coq.Term -translateTermUnshared t = do +translateTermUnshared t = -- traceTerm "translateTerm" t $ - -- NOTE: env is in innermost-first order - env <- view localEnvironment <$> askTR - -- let t' = trace ("translateTerm: " ++ "env = " ++ show env ++ ", t =" ++ showTerm t) t - -- case t' of case unwrapTermF t of FTermF ftf -> flatTermFToExpr ftf @@ -760,16 +740,11 @@ translateTermUnshared t = do _ -> translateIdentWithArgs i args _ -> Coq.App <$> translateTerm f <*> traverse translateTerm args - LocalVar n - | n < length env -> Coq.Var <$> pure (env !! n) - | otherwise -> Except.throwError $ LocalVarOutOfBounds t - -- Constants Constant n -> translateConstant n - Variable ec -> + Variable nm _tp -> do nenv <- view namedEnvironment <$> askTR - let nm = ecName ec case Map.lookup nm nenv of Just ident -> pure (Coq.Var ident) Nothing -> @@ -807,7 +782,7 @@ defaultTermForType typ = do (asPiList -> (bs,body)) | not (null bs) - , looseVars body == emptyBitSet -> + , closedTerm body -> do bs' <- forM bs $ \ (_nm, ty) -> Coq.Binder "_" . Just <$> translateTerm ty body' <- defaultTermForType body return $ Coq.Lambda bs' body' diff --git a/saw-core-what4/src/SAWCoreWhat4/What4.hs b/saw-core-what4/src/SAWCoreWhat4/What4.hs index 51ab4012df..2161deb8d5 100644 --- a/saw-core-what4/src/SAWCoreWhat4/What4.hs +++ b/saw-core-what4/src/SAWCoreWhat4/What4.hs @@ -1495,7 +1495,7 @@ w4EvalAny sym st sc ps unintSet t = ty <- eval =<< scTypeOf sc t -- get the names of the arguments to the function - let lamNames = map (Text.unpack . fst) (fst (R.asLambdaList t)) + let lamNames = map (Text.unpack . vnName . fst) (fst (R.asLambdaList t)) let varNames = [ "var" ++ show (i :: Integer) | i <- [0 ..] ] let argNames = zipWith (++) varNames (map ("_" ++) lamNames ++ repeat "") diff --git a/saw-core/src/SAWCore/Conversion.hs b/saw-core/src/SAWCore/Conversion.hs index a8e73a05eb..2ef7ce13f7 100644 --- a/saw-core/src/SAWCore/Conversion.hs +++ b/saw-core/src/SAWCore/Conversion.hs @@ -54,7 +54,6 @@ module SAWCore.Conversion , asAnyNatLit , asAnyVecLit , asVariable - , asLocalVar -- ** Prelude matchers , asBoolType , asSuccLit @@ -283,10 +282,6 @@ asAnyVecLit = asVar $ \t -> do ArrayValue u xs <- R.asFTermF t; return (u,xs) asVariable :: Matcher (ExtCns Term) asVariable = asVar R.asVariable --- | Returns index of local var if any. -asLocalVar :: Matcher DeBruijnIndex -asLocalVar = asVar $ \t -> do i <- R.asLocalVar t; return i - ---------------------------------------------------------------------- -- Prelude matchers diff --git a/saw-core/src/SAWCore/ExternalFormat.hs b/saw-core/src/SAWCore/ExternalFormat.hs index e7f03cc16f..1c4b478624 100644 --- a/saw-core/src/SAWCore/ExternalFormat.hs +++ b/saw-core/src/SAWCore/ExternalFormat.hs @@ -82,8 +82,8 @@ scWriteExternal t0 = stashName ec = do (m, nms, lns, x) <- State.get State.put (m, Map.insert (nameIndex ec) (Right (nameInfo ec)) nms, lns, x) - stashEC :: ExtCns Int -> WriteM () - stashEC (EC vn _) = + stashVarName :: VarName -> WriteM () + stashVarName vn = do (m, nms, lns, x) <- State.get State.put (m, Map.insert (vnIndex vn) (Left (vnName vn)) nms, lns, x) @@ -110,15 +110,18 @@ scWriteExternal t0 = writeTermF tf = case tf of App e1 e2 -> pure $ unwords ["App", show e1, show e2] - Lambda s t e -> pure $ unwords ["Lam", Text.unpack s, show t, show e] - Pi s t e -> pure $ unwords ["Pi", Text.unpack s, show t, show e] - LocalVar i -> pure $ unwords ["Var", show i] + Lambda s t e -> + do stashVarName s + pure $ unwords ["Lam", show (vnIndex s), show t, show e] + Pi s t e -> + do stashVarName s + pure $ unwords ["Pi", show (vnIndex s), show t, show e] Constant nm -> do stashName nm pure $ unwords ["Constant", show (nameIndex nm)] - Variable ec -> - do stashEC ec - pure $ unwords ["Variable", show (ecVarIndex ec), show (ecType ec)] + Variable nm tp -> + do stashVarName nm + pure $ unwords ["Variable", show (vnIndex nm), show tp] FTermF ftf -> case ftf of UnitValue -> pure $ unwords ["Unit"] @@ -225,32 +228,30 @@ scReadExternal sc input = do vi <- readM i readName' vi - readEC' :: VarIndex -> Term -> ReadM (ExtCns Term) - readEC' vi t' = + readVarName' :: VarIndex -> ReadM VarName + readVarName' vi = do (ts, nms, vs) <- State.get case Map.lookup vi nms of Just (Left x) -> case Map.lookup vi vs of - Just vi' -> pure (EC (VarName vi' x) t') + Just vi' -> pure (VarName vi' x) Nothing -> do vn <- lift $ scFreshVarName sc x State.put (ts, nms, Map.insert vi (vnIndex vn) vs) - pure $ EC vn t' - _ -> lift $ fail $ "scReadExternal: ExtCns missing name: " ++ show vi + pure vn + _ -> lift $ fail $ "scReadExternal: VarName missing name: " ++ show vi - readEC :: String -> String -> ReadM (ExtCns Term) - readEC i t = + readVarName :: String -> ReadM VarName + readVarName i = do vi <- readM i - t' <- readIdx t - readEC' vi t' + readVarName' vi parse :: [String] -> ReadM (TermF Term) parse tokens = case tokens of ["App", e1, e2] -> App <$> readIdx e1 <*> readIdx e2 - ["Lam", x, t, e] -> Lambda (Text.pack x) <$> readIdx t <*> readIdx e - ["Pi", s, t, e] -> Pi (Text.pack s) <$> readIdx t <*> readIdx e - ["Var", i] -> pure $ LocalVar (read i) + ["Lam", s, t, e] -> Lambda <$> readVarName s <*> readIdx t <*> readIdx e + ["Pi", s, t, e] -> Pi <$> readVarName s <*> readIdx t <*> readIdx e ["Constant",i] -> Constant <$> readName i ["ConstantOpaque",i] -> Constant <$> readName i ["Unit"] -> pure $ FTermF UnitValue @@ -279,5 +280,5 @@ scReadExternal sc input = ["Nat", n] -> FTermF <$> (NatLit <$> readM n) ("Array" : e : es) -> FTermF <$> (ArrayValue <$> readIdx e <*> (V.fromList <$> traverse readIdx es)) ("String" : ts) -> FTermF <$> (StringLit <$> (readM (unwords ts))) - ["Variable", i, t] -> Variable <$> readEC i t + ["Variable", i, t] -> Variable <$> readVarName i <*> readIdx t _ -> fail $ "Parse error: " ++ unwords tokens diff --git a/saw-core/src/SAWCore/OpenTerm.hs b/saw-core/src/SAWCore/OpenTerm.hs index ac2489e1f6..4b1876192f 100644 --- a/saw-core/src/SAWCore/OpenTerm.hs +++ b/saw-core/src/SAWCore/OpenTerm.hs @@ -52,7 +52,7 @@ SAW core 'Term'. module SAWCore.OpenTerm ( -- * Open terms and converting to closed terms - OpenTerm(..), completeOpenTerm, completeNormOpenTerm, completeOpenTermType, + OpenTerm(..), completeOpenTerm, completeOpenTermType, -- * Basic operations for building open terms closedOpenTerm, openOpenTerm, failOpenTerm, bindTCMOpenTerm, bindPPOpenTerm, openTermType, @@ -80,22 +80,16 @@ module SAWCore.OpenTerm ( pairTermLike, pairTypeTermLike, pairLeftTermLike, pairRightTermLike, tupleTermLike, tupleTypeTermLike, projTupleTermLike, letTermLike, sawLetTermLike, - -- * Other exported helper functions - sawLetMinimize ) where import qualified Data.Vector as V -import Control.Monad.State -import Control.Monad.Writer import Control.Monad.Reader import Data.Text (Text) import Numeric.Natural -import Data.IntMap.Strict (IntMap) -import qualified Data.IntMap.Strict as IntMap - import qualified SAWSupport.Pretty as PPS (defaultOpts, render) +import SAWCore.Name import SAWCore.Panic import SAWCore.Term.Functor import SAWCore.Term.Pretty @@ -105,7 +99,6 @@ import SAWCore.Module ( ctorName , dtName ) -import SAWCore.Recognizer -- | An open term is represented as a type-checking computation that computes a @@ -117,25 +110,20 @@ newtype OpenTerm = OpenTerm { unOpenTerm :: TCM SCTypedTerm } completeOpenTerm :: SharedContext -> OpenTerm -> IO Term completeOpenTerm sc (OpenTerm termM) = either (fail . show) return =<< - runTCM (typedVal <$> termM) sc [] - --- | \"Complete\" an 'OpenTerm' to a closed term and 'betaNormalize' the result -completeNormOpenTerm :: SharedContext -> OpenTerm -> IO Term -completeNormOpenTerm sc m = - completeOpenTerm sc m >>= sawLetMinimize sc >>= betaNormalize sc + runTCM (typedVal <$> termM) sc mempty -- | \"Complete\" an 'OpenTerm' to a closed term for its type completeOpenTermType :: SharedContext -> OpenTerm -> IO Term completeOpenTermType sc (OpenTerm termM) = either (fail . show) return =<< - runTCM (typedType <$> termM) sc [] + runTCM (typedType <$> termM) sc mempty -- | Embed a closed 'Term' into an 'OpenTerm' closedOpenTerm :: Term -> OpenTerm closedOpenTerm t = OpenTerm $ typeInferComplete t -- | Embed a 'Term' in the given typing context into an 'OpenTerm' -openOpenTerm :: [(LocalName, Term)] -> Term -> OpenTerm +openOpenTerm :: [(VarName, Term)] -> Term -> OpenTerm openOpenTerm ctx t = -- Extend the local type-checking context, wherever this OpenTerm gets used, -- by appending ctx to the end, so that variables 0..length ctx-1 all get @@ -160,14 +148,13 @@ bindTCMOpenTerm m f = OpenTerm (m >>= unOpenTerm . f) bindPPOpenTerm :: OpenTerm -> (String -> OpenTerm) -> OpenTerm bindPPOpenTerm (OpenTerm m) f = OpenTerm $ - do ctx <- askCtx - t <- typedVal <$> m + do t <- typedVal <$> m -- XXX: this could use scPrettyTermInCtx (which builds in the call to -- PPS.render) except that it's slightly different under the covers -- (in its use of the "global" flag, and it isn't entirely clear what -- that actually does) unOpenTerm $ f $ PPS.render PPS.defaultOpts $ - ppTermInCtx PPS.defaultOpts (map fst ctx) t + ppTermInCtx PPS.defaultOpts [] t -- | Return type type of an 'OpenTerm' as an 'OpenTerm openTermType :: OpenTerm -> OpenTerm @@ -383,36 +370,17 @@ piArgOpenTerm (OpenTerm m) = OpenTerm $ m >>= \case (unwrapTermF . typedVal -> Pi _ tp _) -> typeInferComplete tp t -> - do ctx <- askCtx - fail ("piArgOpenTerm: not a pi type: " ++ - scPrettyTermInCtx PPS.defaultOpts (map fst ctx) (typedVal t)) - --- | Build an 'OpenTerm' for the top variable in the current context, by --- building the 'TCM' computation which checks how much longer the context has --- gotten since the variable was created and uses this to compute its deBruijn --- index -openTermTopVar :: TCM OpenTerm -openTermTopVar = - do outer_ctx <- askCtx - return $ OpenTerm $ do - inner_ctx <- askCtx - typeInferComplete (LocalVar (length inner_ctx - - length outer_ctx) :: TermF Term) - --- | Build an open term inside a binder of a variable with the given name and --- type, where the binder is represented as a Haskell function on 'OpenTerm's -bindOpenTerm :: LocalName -> SCTypedTerm -> (OpenTerm -> OpenTerm) -> - TCM SCTypedTerm -bindOpenTerm x tp body_f = - do tp_whnf <- typeCheckWHNF $ typedVal tp - withVar x tp_whnf (openTermTopVar >>= (unOpenTerm . body_f)) + fail ("piArgOpenTerm: not a pi type: " ++ + scPrettyTermInCtx PPS.defaultOpts [] (typedVal t)) -- | Build a lambda abstraction as an 'OpenTerm' lambdaOpenTerm :: LocalName -> OpenTerm -> (OpenTerm -> OpenTerm) -> OpenTerm lambdaOpenTerm x (OpenTerm tpM) body_f = OpenTerm $ do tp <- tpM - body <- bindOpenTerm x tp body_f - typeInferComplete $ Lambda x tp body + vn <- liftTCM scFreshVarName x + var <- typeInferComplete $ Variable vn tp + body <- unOpenTerm (body_f (OpenTerm (pure var))) + typeInferComplete $ Lambda vn tp body -- | Build a nested sequence of lambda abstractions as an 'OpenTerm' lambdaOpenTermMulti :: [(LocalName, OpenTerm)] -> ([OpenTerm] -> OpenTerm) -> @@ -425,8 +393,10 @@ lambdaOpenTermMulti xs_tps body_f = piOpenTerm :: LocalName -> OpenTerm -> (OpenTerm -> OpenTerm) -> OpenTerm piOpenTerm x (OpenTerm tpM) body_f = OpenTerm $ do tp <- tpM - body <- bindOpenTerm x tp body_f - typeInferComplete $ Pi x tp body + nm <- liftTCM scFreshVarName x + var <- typeInferComplete $ Variable nm tp + body <- unOpenTerm (body_f (OpenTerm (pure var))) + typeInferComplete $ Pi nm tp body -- | Build a non-dependent function type. arrowOpenTerm :: LocalName -> OpenTerm -> OpenTerm -> OpenTerm @@ -645,104 +615,3 @@ sawLetTermLike :: OpenTermLike t => LocalName -> t -> t -> t -> (t -> t) -> t sawLetTermLike x tp tp_ret rhs body_f = applyTermLikeMulti (globalTermLike "Prelude.sawLet") [tp, tp_ret, rhs, lambdaTermLike x tp body_f] - - --------------------------------------------------------------------------------- --- sawLet-minimization - --- | A map from each deBruijn index to a count of its occurrences in a term -newtype VarOccs = VarOccs [Integer] - --- | Make a 'VarOccs' with a single occurrence of a deBruijn index -varOccs1 :: DeBruijnIndex -> VarOccs -varOccs1 i = VarOccs (take i (repeat 0) ++ [1]) - --- | Move a 'VarOccs' out of a binder by returning the number of occurrences of --- deBruijn index 0 along with the result of subtracting 1 from all other indices -unconsVarOccs :: VarOccs -> (Integer, VarOccs) -unconsVarOccs (VarOccs []) = (0, VarOccs []) -unconsVarOccs (VarOccs (cnt:occs)) = (cnt, VarOccs occs) - --- | Multiply every index in a 'VarOccs' by a constant -multVarOccs :: Integer -> VarOccs -> VarOccs -multVarOccs i (VarOccs occs) = VarOccs $ map (* i) occs - --- | The infinite list of zeroes -zeroes :: [Integer] -zeroes = 0:zeroes - -instance Semigroup VarOccs where - (VarOccs occs1) <> (VarOccs occs2) - | length occs1 < length occs2 - = VarOccs (zipWith (+) (occs1 ++ zeroes) occs2) - (VarOccs occs1) <> (VarOccs occs2) - = VarOccs (zipWith (+) occs1 (occs2 ++ zeroes)) - -instance Monoid VarOccs where - mempty = VarOccs [] - --- | 'listen' to the output of a writer computation and return that output but --- drop it from the writer output of the computation -listenDrop :: MonadWriter w m => m a -> m (a, w) -listenDrop m = pass (listen m >>= \aw -> return (aw, const mempty)) - --- | The monad for sawLet minimization -type SLMinM = StateT (IntMap (Term, VarOccs)) (WriterT VarOccs IO) - --- | Find every subterm of the form @sawLet a b rhs (\ x -> body)@ and, whenever --- @x@ occurs at most once in @body@, unfold the @sawLet@ by substituting @rhs@ --- into @body@ -sawLetMinimize :: SharedContext -> Term -> IO Term -sawLetMinimize sc t_top = - fst <$> runWriterT (evalStateT (slMinTerm t_top) IntMap.empty) where - slMinTerm :: Term -> SLMinM Term - slMinTerm (Unshared tf) = slMinTermF tf - slMinTerm t@(STApp { stAppIndex = i }) = - do memo_table <- get - case IntMap.lookup i memo_table of - Just (t', occs) -> - -- NOTE: the fact that we explicitly tell occs here means that we are - -- going to double-count variable occurrences for multiple - -- occurrences of the same subterm. That is, a variable occurence - -- counts for each copy of a shared subterm. - tell occs >> return t' - Nothing -> - do (t', occs) <- listen $ slMinTermF (unwrapTermF t) - modify $ IntMap.insert i (t', occs) - return t' - - slMinTermF :: TermF Term -> SLMinM Term - slMinTermF tf@(App (asApplyAll -> - (isGlobalDef "Prelude.sawLet" -> Just _, [_a, _b, rhs])) - (asLambda -> Just (_, _, body))) = - do (body', (unconsVarOccs -> - (x_cnt, body_occs))) <- listenDrop $ slMinTerm body - if x_cnt > 1 then slMinTermF' tf else - do (rhs', rhs_occs) <- listenDrop $ slMinTerm rhs - tell (multVarOccs x_cnt rhs_occs <> body_occs) - liftIO $ instantiateVar sc 0 rhs' body' - slMinTermF tf = slMinTermF' tf - - slMinTermF' :: TermF Term -> SLMinM Term - slMinTermF' (FTermF ftf) = slMinFTermF ftf - slMinTermF' (App f arg) = - do f' <- slMinTerm f - arg' <- slMinTerm arg - liftIO $ scTermF sc (App f' arg') - slMinTermF' (Lambda x tp body) = - do tp' <- slMinTerm tp - (body', body_occs) <- listenDrop $ slMinTerm body - tell $ snd $ unconsVarOccs body_occs - liftIO $ scTermF sc (Lambda x tp' body') - slMinTermF' (Pi x tp body) = - do tp' <- slMinTerm tp - (body', body_occs) <- listenDrop $ slMinTerm body - tell $ snd $ unconsVarOccs body_occs - liftIO $ scTermF sc (Pi x tp' body') - slMinTermF' tf@(LocalVar i) = - tell (varOccs1 i) >> liftIO (scTermF sc tf) - slMinTermF' tf@(Constant _) = liftIO (scTermF sc tf) - slMinTermF' tf@(Variable _) = liftIO (scTermF sc tf) - - slMinFTermF :: FlatTermF Term -> SLMinM Term - slMinFTermF ftf = traverse slMinTerm ftf >>= liftIO . scFlatTermF sc diff --git a/saw-core/src/SAWCore/Prelude.hs b/saw-core/src/SAWCore/Prelude.hs index f90da1a498..e558c8f8fb 100644 --- a/saw-core/src/SAWCore/Prelude.hs +++ b/saw-core/src/SAWCore/Prelude.hs @@ -119,12 +119,12 @@ scDecEq sc fot args = case fot of Just (x,y) -> mkRecordEqBody (Map.toList fs) x y - Nothing -> - do x <- scLocalVar sc 1 - y <- scLocalVar sc 0 - tp <- scFirstOrderType sc fot - body <- mkRecordEqBody (Map.toList fs) x y - scLambdaList sc [("x",tp),("y",tp)] body + Nothing -> + do tp <- scFirstOrderType sc fot + x <- scFreshVariable sc "x" tp + y <- scFreshVariable sc "y" tp + body <- mkRecordEqBody (Map.toList fs) x y + scAbstractTerms sc [x, y] body where mkRecordEqBody [] _x _y = scBool sc True diff --git a/saw-core/src/SAWCore/Recognizer.hs b/saw-core/src/SAWCore/Recognizer.hs index bbb5cf494f..1dd6b9a7f2 100644 --- a/saw-core/src/SAWCore/Recognizer.hs +++ b/saw-core/src/SAWCore/Recognizer.hs @@ -49,7 +49,6 @@ module SAWCore.Recognizer , asLambdaList , asPi , asPiList - , asLocalVar , asConstant , asVariable , asSort @@ -309,30 +308,26 @@ asArrayValue _ = Nothing asStringLit :: Recognizer Term Text asStringLit t = do StringLit i <- asFTermF t; return i -asLambda :: Recognizer Term (LocalName, Term, Term) +asLambda :: Recognizer Term (VarName, Term, Term) asLambda (unwrapTermF -> Lambda s ty body) = return (s, ty, body) asLambda _ = Nothing -asLambdaList :: Term -> ([(LocalName, Term)], Term) +asLambdaList :: Term -> ([(VarName, Term)], Term) asLambdaList = go [] where go r (asLambda -> Just (nm,tp,rhs)) = go ((nm,tp):r) rhs go r rhs = (reverse r, rhs) -asPi :: Recognizer Term (LocalName, Term, Term) +asPi :: Recognizer Term (VarName, Term, Term) asPi (unwrapTermF -> Pi nm tp body) = return (nm, tp, body) asPi _ = Nothing -- | Decomposes a term into a list of pi bindings, followed by a right -- term that is not a pi binding. -asPiList :: Term -> ([(LocalName, Term)], Term) +asPiList :: Term -> ([(VarName, Term)], Term) asPiList = go [] where go r (asPi -> Just (nm,tp,rhs)) = go ((nm,tp):r) rhs go r rhs = (reverse r, rhs) -asLocalVar :: Recognizer Term DeBruijnIndex -asLocalVar (unwrapTermF -> LocalVar i) = return i -asLocalVar _ = Nothing - asConstant :: Recognizer Term Name asConstant (unwrapTermF -> Constant nm) = pure nm asConstant _ = Nothing @@ -340,8 +335,8 @@ asConstant _ = Nothing asVariable :: Recognizer Term (ExtCns Term) asVariable t = case unwrapTermF t of - Variable ec -> pure ec - _ -> Nothing + Variable nm tp -> pure (EC nm tp) + _ -> Nothing asSort :: Recognizer Term Sort asSort t = do diff --git a/saw-core/src/SAWCore/Rewriter.hs b/saw-core/src/SAWCore/Rewriter.hs index ce2919381d..2fb6ec9dff 100644 --- a/saw-core/src/SAWCore/Rewriter.hs +++ b/saw-core/src/SAWCore/Rewriter.hs @@ -59,7 +59,7 @@ module SAWCore.Rewriter , hoistIfs ) where -import Control.Monad (MonadPlus(..), (>=>), guard, unless) +import Control.Monad (MonadPlus(..), (>=>), guard) import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Maybe import Data.IntMap (IntMap) @@ -167,7 +167,7 @@ firstOrderMatch ctxt pat term = match pat term IntMap.empty match :: Term -> Term -> IntMap Term -> Maybe (IntMap Term) match x y m = case (unwrapTermF x, unwrapTermF y) of - (Variable (ecVarIndex -> i), _) | IntSet.member i ixs -> + (Variable (vnIndex -> i) _, _) | IntSet.member i ixs -> case my' of Nothing -> Just m' Just y' -> if alphaEquiv y y' then Just m' else Nothing @@ -230,7 +230,7 @@ scMatch :: scMatch sc ctxt pat term = runMaybeT $ do -- lift $ putStrLn $ "********** scMatch **********" - MatchState inst cs <- match 0 [] pat term emptyMatchState + MatchState inst cs <- match IntSet.empty pat term emptyMatchState mapM_ (check inst) cs return inst where @@ -253,61 +253,42 @@ scMatch sc ctxt pat term = -- Check if a term is a higher-order variable pattern, i.e., a free variable -- (meaning one that can match anything) applied to 0 or more bound variable - -- arguments. Depth is the number of variables bound by lambdas or pis since - -- the top of the current pattern, so "bound" means less than the current depth - asVarPat :: Int -> Term -> Maybe (VarIndex, [DeBruijnIndex]) - asVarPat depth = go [] + -- arguments. + asVarPat :: IntSet -> Term -> Maybe (VarIndex, [ExtCns Term]) + asVarPat locals = go [] where go js x = case unwrapTermF x of - Variable ec - | IntSet.member (ecVarIndex ec) ixs -> Just (ecVarIndex ec, js) + Variable nm _tp + | IntSet.member (vnIndex nm) ixs -> Just (vnIndex nm, js) | otherwise -> Nothing - App t (unwrapTermF -> LocalVar j) - | j < depth -> go (j : js) t + App t (unwrapTermF -> Variable nm tp) + | IntSet.member (vnIndex nm) locals -> go (EC nm tp : js) t _ -> Nothing -- Test if term y matches pattern x, meaning whether there is a substitution - -- to the free variables of x to make it equal to y. Depth is the number of - -- bound variables, so a "free" variable is a deBruijn index >= depth. Env - -- saves the names associated with those bound variables. - match :: Int -> [(LocalName, Term)] -> Term -> Term -> MatchState -> - MaybeT IO MatchState - match _ _ t@(STApp{stAppIndex = i}) (STApp{stAppIndex = j}) s - | termIsClosed t && i == j = return s - match depth env x y s@(MatchState m cs) = + -- to the free variables of x to make it equal to y. + -- The IntSet contains the VarIndexes named variables that are locally bound. + match :: IntSet -> Term -> Term -> MatchState -> MaybeT IO MatchState + match _ t@(STApp{stAppIndex = i}) (STApp{stAppIndex = j}) s + | closedTerm t && i == j = return s + match locals x y s@(MatchState m cs) = -- (lift $ putStrLn $ "matching (lhs): " ++ scPrettyTerm PPS.defaultOpts x) >> -- (lift $ putStrLn $ "matching (rhs): " ++ scPrettyTerm PPS.defaultOpts y) >> - case asVarPat depth x of + case asVarPat locals x of -- If the lhs pattern is of the form (?u b1..bk) where ?u is a -- unification variable and b1..bk are all locally bound -- variables: First check whether the rhs contains any locally -- bound variables *not* in the list b1..bk. If it contains any -- others, then there is no match. If it only uses a subset of -- b1..bk, then we can instantiate ?u to (\b1..bk -> rhs). - Just (i, js) -> + Just (i, vs) -> do -- ensure parameter variables are distinct - guard (Set.size (Set.fromList js) == length js) - -- ensure y mentions only variables that are in js - let fvj = foldl unionBitSets emptyBitSet (map singletonBitSet js) - let fvy = looseVars y `intersectBitSets` (completeBitSet depth) - guard (fvy `unionBitSets` fvj == fvj) - let fixVar t (nm, ty) = - do ec <- scFreshEC sc nm ty - v <- scVariable sc ec - t' <- instantiateVar sc 0 v t - return (t', ec) - let fixVars t [] = return (t, []) - fixVars t (ty : tys) = - do (t', ec) <- fixVar t ty - (t'', ecs) <- fixVars t' tys - return (t'', ec : ecs) - -- replace local bound variables with global ones - -- this also decrements loose variables in y by `depth` - (y1, ecs) <- lift $ fixVars y env - -- replace global variables with reindexed bound vars - -- y2 should have no more of the newly-created ExtCns vars - y2 <- lift $ scAbstractExts sc [ ecs !! j | j <- js ] y1 + guard (Set.size (Set.fromList vs) == length vs) + -- ensure y mentions only variables that are in vs + let vset = IntSet.fromList (map ecVarIndex vs) + guard (IntSet.disjoint (IntSet.difference locals vset) (freeVars y)) + y2 <- lift $ scAbstractExts sc vs y let (my3, m') = insertLookup i y2 m case my3 of Nothing -> return (MatchState m' cs) @@ -318,18 +299,18 @@ scMatch sc ctxt pat term = | Just [x'] <- R.asGlobalApply preludeSuccIdent x , n > 0 -> do y' <- lift $ scNat sc (n-1) - match depth env x' y' s + match locals x' y' s -- check that neither x nor y contains bound variables less than `depth` (FTermF xf, FTermF yf) -> - case zipWithFlatTermF (match depth env) xf yf of + case zipWithFlatTermF (match locals) xf yf of Nothing -> mzero Just zf -> Foldable.foldl (>=>) return zf s (App x1 x2, App y1 y2) -> - match depth env x1 y1 s >>= match depth env x2 y2 - (Lambda _ t1 x1, Lambda nm t2 x2) -> - match depth env t1 t2 s >>= match (depth + 1) ((nm, t2) : env) x1 x2 - (Pi _ t1 x1, Pi nm t2 x2) -> - match depth env t1 t2 s >>= match (depth + 1) ((nm, t2) : env) x1 x2 + match locals x1 y1 s >>= match locals x2 y2 + (Lambda nm t1 x1, Lambda _ t2 x2) -> + match locals t1 t2 s >>= match (IntSet.insert (vnIndex nm) locals) x1 x2 + (Pi nm t1 x1, Pi _ t2 x2) -> + match locals t1 t2 s >>= match (IntSet.insert (vnIndex nm) locals) x1 x2 (App _ _, FTermF (NatLit n)) -> -- add deferred constraint return (MatchState m ((x, n) : cs)) @@ -372,11 +353,12 @@ intModEqIdent = mkIdent (mkModuleName ["Prelude"]) "intModEq" -- | Converts a universally quantified equality proposition from a -- Term representation to a RewriteRule. -ruleOfTerm :: SharedContext -> Term -> Maybe a -> IO (RewriteRule a) -ruleOfTerm sc t ann = - do (ecs, body) <- scAsPiList sc t +ruleOfTerm :: Term -> Maybe a -> RewriteRule a +ruleOfTerm t ann = + do let (vars, body) = R.asPiList t + let ecs = map (uncurry EC) vars case R.asGlobalApply eqIdent body of - Just [_, x, y] -> pure $ mkRewriteRule ecs x y False ann + Just [_, x, y] -> mkRewriteRule ecs x y False ann _ -> panic "ruleOfTerm" ["Illegal argument"] -- Test whether a rewrite rule is permutative @@ -410,15 +392,15 @@ ruleOfTerms l r = mkRewriteRule [] l r False Nothing -- returning 'Nothing' if the predicate is not an equation. ruleOfProp :: SharedContext -> Term -> Maybe a -> IO (Maybe (RewriteRule a)) ruleOfProp sc term ann = - scAsPi sc term >>= \case - Just (ec, body) -> + case R.asPi term of + Just (nm, tp, body) -> do rule <- ruleOfProp sc body ann - pure $ (\r -> r { ctxt = ec : ctxt r}) <$> rule + pure $ (\r -> r { ctxt = EC nm tp : ctxt r}) <$> rule Nothing -> - scAsLambda sc term >>= \case - Just (ec, body) -> + case R.asLambda term of + Just (nm, tp, body) -> do rule <- ruleOfProp sc body ann - pure $ (\r -> r { ctxt = ec : ctxt r}) <$> rule + pure $ (\r -> r { ctxt = EC nm tp : ctxt r}) <$> rule Nothing -> case term of (R.asGlobalApply ecEqIdent -> Just [_, _, x, y]) -> eqRule x y @@ -451,9 +433,7 @@ ruleOfProp sc term ann = -- | Generate a rewrite rule from the type of an identifier, using 'ruleOfTerm' scEqRewriteRule :: SharedContext -> Ident -> IO (RewriteRule a) -scEqRewriteRule sc i = - do ty <- scTypeOfIdent sc i - ruleOfTerm sc ty Nothing +scEqRewriteRule sc i = ruleOfTerm <$> scTypeOfIdent sc i <*> pure Nothing -- | Collects rewrite rules from named constants, whose types must be equations. scEqsRewriteRules :: SharedContext -> [Ident] -> IO [RewriteRule a] @@ -467,9 +447,10 @@ scEqsRewriteRules sc = mapM (scEqRewriteRule sc) -- * If the rhs is a record, then split into a separate rule for each accessor. scExpandRewriteRule :: SharedContext -> RewriteRule a -> IO (Maybe [RewriteRule a]) scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann) = - scAsLambda sc rhs >>= \case - Just (ec, body) -> - do let ctxt' = ctxt ++ [ec] + case R.asLambda rhs of + Just (nm, tp, body) -> + do let ec = EC nm tp + let ctxt' = ctxt ++ [ec] var0 <- scVariable sc ec lhs' <- scApply sc lhs var0 pure $ Just [mkRewriteRule ctxt' lhs' body shallow ann] @@ -497,7 +478,7 @@ scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann) = let ctorRule ctor = do -- Compute the argument types @argTs@. ctorT <- piAppType (ctorType ctor) params1 - argECs <- fst <$> scAsPiList sc ctorT + let argECs = map (uncurry EC) $ fst $ R.asPiList ctorT -- Build a fully-applied constructor @c@. args <- traverse (scVariable sc) argECs c <- scConstApply sc (ctorName ctor) (params1 ++ args) @@ -546,7 +527,8 @@ scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ shallow ann) = do f' <- betaReduce f case R.asLambda f' of Nothing -> scApply sc f' arg - Just (_, _, body) -> instantiateVar sc 0 arg body + Just (vn, _, body) -> + scInstantiateExt sc (IntMap.singleton (vnIndex vn) arg) body -- | Repeatedly apply the rule transformations in 'scExpandRewriteRule'. scExpandRewriteRules :: SharedContext -> [RewriteRule a] -> IO [RewriteRule a] @@ -597,11 +579,11 @@ delRule rule = Net.delete_term (lhs rule, Left rule) addRules :: [RewriteRule a] -> Simpset a -> Simpset a addRules rules ss = foldr addRule ss rules -addSimp :: SharedContext -> Term -> Maybe a -> Simpset a -> IO (Simpset a) -addSimp sc prop ann ss = flip addRule ss <$> ruleOfTerm sc prop ann +addSimp :: Term -> Maybe a -> Simpset a -> Simpset a +addSimp prop ann = addRule (ruleOfTerm prop ann) -delSimp :: SharedContext -> Term -> Simpset a -> IO (Simpset a) -delSimp sc prop ss = flip delRule ss <$> ruleOfTerm sc prop Nothing +delSimp :: Term -> Simpset a -> Simpset a +delSimp prop = delRule (ruleOfTerm prop Nothing) addConv :: Conversion -> Simpset a -> Simpset a addConv conv = Net.insert_term (conv, Right conv) @@ -621,7 +603,7 @@ listRules ss = [ r | Left r <- Net.content ss ] ---------------------------------------------------------------------- -- Destructors for terms -asBetaRedex :: R.Recognizer Term (LocalName, Term, Term, Term) +asBetaRedex :: R.Recognizer Term (VarName, Term, Term, Term) asBetaRedex t = do (f, arg) <- R.asApp t (s, ty, body) <- R.asLambda f @@ -700,7 +682,8 @@ termWeightLt t t' = -- | Do a single reduction step (beta, record or tuple selector) at top -- level, if possible. reduceSharedTerm :: SharedContext -> Term -> IO (Maybe Term) -reduceSharedTerm sc (asBetaRedex -> Just (_, _, body, arg)) = Just <$> instantiateVar sc 0 arg body +reduceSharedTerm sc (asBetaRedex -> Just (vn, _, body, arg)) = + Just <$> scInstantiateExt sc (IntMap.singleton (vnIndex vn) arg) body reduceSharedTerm _ (asPairRedex -> Just t) = pure (Just t) reduceSharedTerm _ (asRecordRedex -> Just t) = pure (Just t) reduceSharedTerm sc (asNatIotaRedex -> Just (f1, f2, n)) = @@ -826,7 +809,9 @@ rewriteSharedTermTypeSafe sc ss t0 = case unwrapTermF t1 of -- We only rewrite e2 if type of e1 is not a dependent type. -- This prevents rewriting e2 from changing type of @App e1 e2@. - Pi _ _ t | inBitSet 0 (looseVars t) -> App <$> rewriteAll e1 <*> rewriteAll e2 + Pi x _ t + | IntSet.notMember (vnIndex x) (freeVars t) -> + App <$> rewriteAll e1 <*> rewriteAll e2 _ -> App <$> rewriteAll e1 <*> pure e2 Lambda pat t e -> Lambda pat t <$> rewriteAll e Constant{} -> return tf @@ -924,8 +909,6 @@ replaceTerm :: Ord a => Term {- ^ the term in which to perform the replacement -} -> IO (Set a, Term) replaceTerm sc ss (pat, repl) t = do - unless (termIsClosed pat) $ fail $ unwords - [ "replaceTerm: term to replace has free variables!", scPrettyTerm PPS.defaultOpts t ] let rule = ruleOfTerms pat repl let ss' = addRule rule ss rewriteSharedTerm sc ss' t @@ -946,7 +929,7 @@ hoistIfs :: SharedContext hoistIfs sc t = do cache <- newCache - rules <- mapM (\i -> scTypeOfIdent sc i >>= \rt -> ruleOfTerm sc rt Nothing) + rules <- map (\rt -> ruleOfTerm rt Nothing) <$> mapM (scTypeOfIdent sc) [ "Prelude.ite_true" , "Prelude.ite_false" , "Prelude.ite_not" @@ -1028,7 +1011,6 @@ doHoistIfs sc ss hoistCache = go goF :: Term -> TermF Term -> IO (HoistIfs s) - goF t (LocalVar _) = return (t, []) goF t (Constant {}) = return (t, []) goF t (Variable {}) = return (t, []) @@ -1046,13 +1028,13 @@ doHoistIfs sc ss hoistCache = go goF _ (Lambda nm tp body) = goBinder scLambda nm tp body goF _ (Pi nm tp body) = goBinder scPi nm tp body - goBinder close nm tp body = do - (ec, body') <- scOpenTerm sc nm tp 0 body - (body'', conds) <- go body' - let (stuck, float) = List.partition (\(_,ecs) -> Set.member ec ecs) conds + goBinder close nm tp body = + do let ec = EC nm tp + (body'', conds) <- go body + let (stuck, float) = List.partition (\(_,ecs) -> Set.member ec ecs) conds - stuck' <- orderTerms sc (map fst stuck) - body''' <- splitConds sc ss stuck' body'' + stuck' <- orderTerms sc (map fst stuck) + body''' <- splitConds sc ss stuck' body'' - t' <- scCloseTerm close sc ec body''' - return (t', float) + t' <- close sc nm tp body''' + return (t', float) diff --git a/saw-core/src/SAWCore/SCTypeCheck.hs b/saw-core/src/SAWCore/SCTypeCheck.hs index 55c2a6b05d..790f284001 100644 --- a/saw-core/src/SAWCore/SCTypeCheck.hs +++ b/saw-core/src/SAWCore/SCTypeCheck.hs @@ -58,6 +58,8 @@ import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Reader (MonadReader(..), Reader, ReaderT(..), asks, runReader) import Control.Monad.State.Strict (MonadState(..), StateT, evalStateT, modify) +import Data.IntMap (IntMap) +import qualified Data.IntMap as IntMap import Data.Map (Map) import qualified Data.Map as Map import Data.Text (Text) @@ -91,7 +93,7 @@ type TCState = Map TermIndex Term data TCEnv = TCEnv { tcSharedContext :: SharedContext -- ^ the SAW context - , tcCtx :: [(LocalName, Term)] -- ^ the mapping of names to de Bruijn bound variables + , tcCtx :: IntMap Term -- ^ the type environment for variables } -- | The monad for type checking and inference, which: @@ -109,18 +111,14 @@ newtype TCM a = TCM (ReaderT TCEnv (StateT TCState (ExceptT TCError IO)) a) -- | Run a type-checking computation in a given context, starting from the empty -- memoization table runTCM :: - TCM a -> SharedContext -> [(LocalName, Term)] -> IO (Either TCError a) + TCM a -> SharedContext -> IntMap Term -> IO (Either TCError a) runTCM (TCM m) sc ctx = runExceptT $ evalStateT (runReaderT m (TCEnv sc ctx)) Map.empty -- | Read the current typing context -askCtx :: TCM [(LocalName, Term)] +askCtx :: TCM (IntMap Term) askCtx = asks tcCtx --- | Read the current typing context, without names. -askCtx' :: TCM [Term] -askCtx' = map snd <$> askCtx - -- | Run a type-checking computation in a typing context extended with a new -- variable with the given type. This throws away the memoization table while -- running the sub-computation, as memoization tables are tied to specific sets @@ -128,15 +126,15 @@ askCtx' = map snd <$> askCtx -- -- NOTE: the type given for the variable should be in WHNF, so that we do not -- have to normalize the types of variables each time we see them. -withVar :: LocalName -> Term -> TCM a -> TCM a +withVar :: VarName -> Term -> TCM a -> TCM a withVar x tp m = - rethrowTCError (ErrorCtx x tp) $ + rethrowTCError (ErrorCtx (vnName x) tp) $ withEmptyTCState $ - local (\env -> env { tcCtx = (x,tp) : tcCtx env }) m + local (\env -> env { tcCtx = IntMap.insert (vnIndex x) tp (tcCtx env) }) m -- | Run a type-checking computation in a typing context extended by a list of -- variables and their types. See 'withVar'. -withCtx :: [(LocalName, Term)] -> TCM a -> TCM a +withCtx :: [(VarName, Term)] -> TCM a -> TCM a withCtx = flip (foldr (\(x,tp) -> withVar x tp)) -- | Augment and rethrow any 'TCError' thrown by the given computation. @@ -294,7 +292,7 @@ prettyTCError e = runReader (helper e) ([], Nothing) where ishow :: Term -> PPErrM String ishow tm = -- return $ show tm - (\(ctx,_) -> indent " " $ scPrettyTermInCtx PPS.defaultOpts ctx tm) <$> ask + (\(_ctx,_) -> indent " " $ scPrettyTermInCtx PPS.defaultOpts [] tm) <$> ask instance Show TCError where show = unlines . prettyTCError @@ -308,13 +306,13 @@ scTypeCheckError sc t0 = -- well-formed and that all internal type annotations are correct. Types are -- evaluated to WHNF as necessary, and the returned type is in WHNF. scTypeCheck :: TypeInfer a => SharedContext -> a -> IO (Either TCError Term) -scTypeCheck sc = scTypeCheckInCtx sc [] +scTypeCheck sc = scTypeCheckInCtx sc IntMap.empty -- | Like 'scTypeCheck', but type-check the term relative to a typing context, -- which assigns types to free variables in the term scTypeCheckInCtx :: TypeInfer a => SharedContext -> - [(LocalName, Term)] -> a -> IO (Either TCError Term) + IntMap Term -> a -> IO (Either TCError Term) scTypeCheckInCtx sc ctx t0 = runTCM (typeInfer t0) sc ctx -- | Infer the type of an @a@ and complete it to a term using @@ -331,12 +329,12 @@ scTypeCheckCompleteError sc t0 = -- returned type is in WHNF, though the returned term may not be. scTypeCheckComplete :: TypeInfer a => SharedContext -> a -> IO (Either TCError SCTypedTerm) -scTypeCheckComplete sc = scTypeCheckCompleteInCtx sc [] +scTypeCheckComplete sc = scTypeCheckCompleteInCtx sc IntMap.empty -- | Like 'scTypeCheckComplete', but type-check the term relative to a typing -- context, which assigns types to free variables in the term scTypeCheckCompleteInCtx :: TypeInfer a => SharedContext -> - [(LocalName, Term)] -> a -> + IntMap Term -> a -> IO (Either TCError SCTypedTerm) scTypeCheckCompleteInCtx sc ctx t0 = runTCM (typeInferComplete t0) sc ctx @@ -346,14 +344,14 @@ scTypeCheckCompleteInCtx sc ctx t0 = scCheckSubtype :: SharedContext -> SCTypedTerm -> Term -> IO () scCheckSubtype sc arg req_tp = either (fail . unlines . prettyTCError) return =<< - runTCM (checkSubtype arg req_tp) sc [] + runTCM (checkSubtype arg req_tp) sc IntMap.empty -- | An abstract datatype pairing a 'Term' with its type. data SCTypedTerm = SCTypedTerm Term -- ^ value Term -- ^ type - [Term] -- ^ de Bruijn typing context + (IntMap Term) -- ^ typing context -- | The raw 'Term' of an 'SCTypedTerm'. typedVal :: SCTypedTerm -> Term @@ -365,7 +363,7 @@ typedType (SCTypedTerm _ typ _) = typ -- | The de Bruijn typing context of an 'SCTypedTerm', with de Bruijn -- index 0 at the head of the list. -typedCtx :: SCTypedTerm -> [Term] +typedCtx :: SCTypedTerm -> IntMap Term typedCtx (SCTypedTerm _ _ ctx) = ctx -- | The class of things that we can infer types of. The 'typeInfer' method @@ -399,7 +397,7 @@ instance TypeInfer Term where modify (Map.insert i x') return x' typeInferComplete trm = - SCTypedTerm trm <$> typeInfer trm <*> askCtx' + SCTypedTerm trm <$> typeInfer trm <*> askCtx -- Type inference for TermF Term dispatches to that for TermF SCTypedTerm by -- calling inference on all the sub-components and extending the context inside @@ -430,7 +428,7 @@ instance TypeInfer (TermF Term) where typeInfer (Constant nm) = typeInferConstant nm typeInfer t = typeInfer =<< mapM typeInferComplete t typeInferComplete tf = - SCTypedTerm <$> liftTCM scTermF tf <*> withErrorTermF tf (typeInfer tf) <*> askCtx' + SCTypedTerm <$> liftTCM scTermF tf <*> withErrorTermF tf (typeInfer tf) <*> askCtx typeInferConstant :: Name -> TCM Term typeInferConstant nm = @@ -448,7 +446,7 @@ instance TypeInfer (FlatTermF Term) where SCTypedTerm <$> liftTCM scFlatTermF ftf <*> typeInfer ftf - <*> askCtx' + <*> askCtx -- Type inference for TermF SCTypedTerm is the main workhorse. Intuitively, this @@ -459,7 +457,7 @@ instance TypeInfer (TermF SCTypedTerm) where typeInfer (App x@(SCTypedTerm _ x_tp _) y) = applyPiTyped (NotFuncTypeInApp x y) x_tp y typeInfer (Lambda x (SCTypedTerm a a_tp _) (SCTypedTerm _ b _)) = - void (ensureSort a_tp) >> liftTCM scTermF (Pi x a b) + void (ensureSort a_tp) >> liftTCM scPi x a b typeInfer (Pi _ (SCTypedTerm _ a_tp _) (SCTypedTerm _ b_tp _)) = do s1 <- ensureSort a_tp s2 <- ensureSort b_tp @@ -467,26 +465,16 @@ instance TypeInfer (TermF SCTypedTerm) where -- when b is a Prop (this is a forall proposition), otherwise it is a -- (Type (max (sortOf a) (sortOf b))) liftTCM scSort $ if s2 == propSort then propSort else max s1 s2 - typeInfer (LocalVar i) = - do ctx <- askCtx - if i < length ctx then - -- The ith type in the current variable typing context is well-typed - -- relative to the suffix of the context after it, so we have to lift it - -- (i.e., call incVars) to make it well-typed relative to all of ctx - liftTCM incVars 0 (i+1) (snd (ctx !! i)) - else - error ("Context = " ++ show ctx) - -- throwTCError (DanglingVar (i - length ctx)) typeInfer (Constant nm) = typeInferConstant nm - typeInfer (Variable ec) = + typeInfer (Variable _nm tp) = -- FIXME: should we check that the type of ecType is a sort? - typeCheckWHNF $ typedVal $ ecType ec + typeCheckWHNF $ typedVal tp typeInferComplete tf = SCTypedTerm <$> liftTCM scTermF (fmap typedVal tf) <*> withErrorSCTypedTermF tf (typeInfer tf) - <*> askCtx' + <*> askCtx -- Type inference for FlatTermF SCTypedTerm is the main workhorse for flat @@ -535,7 +523,7 @@ instance TypeInfer (FlatTermF SCTypedTerm) where SCTypedTerm <$> liftTCM scFlatTermF (fmap typedVal ftf) <*> withErrorSCTypedTermF (FTermF ftf) (typeInfer ftf) - <*> askCtx' + <*> askCtx -- | Check that @fun_tp=Pi x a b@ and that @arg@ has type @a@, and return the -- result of substituting @arg@ for @x@ in the result type @b@, i.e., @@ -543,9 +531,10 @@ instance TypeInfer (FlatTermF SCTypedTerm) where -- evaluator. If @fun_tp@ is not a pi type, raise the supplied error. applyPiTyped :: TCError -> Term -> SCTypedTerm -> TCM Term applyPiTyped err fun_tp arg = - ensurePiType err fun_tp >>= \(_,arg_tp,ret_tp) -> + ensurePiType err fun_tp >>= \(nm, arg_tp, ret_tp) -> do checkSubtype arg arg_tp - liftTCM instantiateVar 0 (typedVal arg) ret_tp >>= typeCheckWHNF + let sub = IntMap.singleton (vnIndex nm) (typedVal arg) + liftTCM scInstantiateExt sub ret_tp >>= typeCheckWHNF -- | Ensure that a 'Term' matches a recognizer function, normalizing if -- necessary; otherwise throw the supplied 'TCError' @@ -572,7 +561,7 @@ ensureRecordType err tp = ensureRecognizer asRecordType err tp -- | Ensure a 'Term' is a pi type, normalizing if necessary. Return the -- components of that pi type on success; otherwise throw the supplied error. -ensurePiType :: TCError -> Term -> TCM (LocalName, Term, Term) +ensurePiType :: TCError -> Term -> TCM (VarName, Term, Term) ensurePiType err tp = ensureRecognizer asPi err tp -- | Reduce a type to WHNF (using 'scWhnf'), also adding in some conversions for @@ -598,8 +587,17 @@ checkSubtype arg req_tp = -- types, i.e., that both have type Sort s for some s, and that they are both -- already in WHNF isSubtype :: Term -> Term -> TCM Bool -isSubtype (unwrapTermF -> Pi x1 a1 b1) (unwrapTermF -> Pi _ a2 b2) = +isSubtype (unwrapTermF -> Pi x1 a1 b1) (unwrapTermF -> Pi x2 a2 b2) + | x1 == x2 = (&&) <$> areConvertible a1 a2 <*> withVar x1 a1 (isSubtype b1 b2) + | otherwise = + do conv1 <- areConvertible a1 a2 + let ec1 = EC x1 a1 + var1 <- liftTCM scVariable ec1 + let sub = IntMap.singleton (vnIndex x2) var1 + b2' <- liftTCM scInstantiateExt sub b2 + conv2 <- withVar x1 a1 (isSubtype b1 b2') + pure (conv1 && conv2) isSubtype (asSort -> Just s1) (asSort -> Just s2) | s1 <= s2 = return True isSubtype t1' t2' = areConvertible t1' t2' @@ -643,9 +641,7 @@ inferRecursor r = scTypeOfTypedTerm :: SharedContext -> SCTypedTerm -> IO SCTypedTerm scTypeOfTypedTerm sc (SCTypedTerm _tm tp ctx) = do tp_tp <- scTypeOf' sc ctx tp - -- Shrink de Bruijn context if possible - let ctx' = take (bitSetBound (looseVars tp_tp)) ctx - pure (SCTypedTerm tp tp_tp ctx') + pure (SCTypedTerm tp tp_tp ctx) -- | Reduce an 'SCTypedTerm' to WHNF (see also 'scTypeCheckWHNF'). scTypedTermWHNF :: SharedContext -> SCTypedTerm -> IO SCTypedTerm @@ -657,4 +653,4 @@ scGlobalTypedTerm :: SharedContext -> Ident -> IO SCTypedTerm scGlobalTypedTerm sc ident = do tm <- scGlobalDef sc ident tp <- scTypeOfIdent sc ident - pure (SCTypedTerm tm tp []) + pure (SCTypedTerm tm tp IntMap.empty) diff --git a/saw-core/src/SAWCore/SharedTerm.hs b/saw-core/src/SAWCore/SharedTerm.hs index 1b40c5250a..fe69d4d6bf 100644 --- a/saw-core/src/SAWCore/SharedTerm.hs +++ b/saw-core/src/SAWCore/SharedTerm.hs @@ -32,8 +32,6 @@ module SAWCore.SharedTerm -- * Shared terms , Term(..) , TermIndex - , looseVars - , smallestLooseVar , scSharedTerm , unshare , scImport @@ -98,7 +96,6 @@ module SAWCore.SharedTerm , scISort , scSortWithFlags -- *** Variables and constants - , scLocalVar , scConst , scConstApply -- *** Functions and function application @@ -235,18 +232,7 @@ module SAWCore.SharedTerm , scArrayCopy , scArraySet , scArrayRangeEq - -- ** Utilities --- , scTrue --- , scFalse - , scOpenTerm - , scCloseTerm - , scAsLambda - , scAsLambdaList - , scAsPi - , scAsPiList -- ** Variable substitution - , instantiateVar - , instantiateVarList , betaNormalize , getAllExts , getAllExtSet @@ -257,7 +243,6 @@ module SAWCore.SharedTerm , scAbstractExtsEtaCollapse , scGeneralizeExts , scGeneralizeTerms - , incVars , scUnfoldConstants , scUnfoldConstants' , scUnfoldConstantSet @@ -281,7 +266,7 @@ import Control.Monad.IO.Class (MonadIO(..)) import qualified Control.Monad.State.Strict as State import Control.Monad.Trans.Class (MonadTrans(..)) import Data.Bits -import Data.List (inits, find) +import Data.List (find) import Data.Maybe import qualified Data.Foldable as Fold import Data.Foldable (foldl', foldlM, foldrM, maximum) @@ -446,7 +431,7 @@ scConstApply sc i ts = -- | Create a named variable 'Term' from an 'ExtCns'. scVariable :: SharedContext -> ExtCns Term -> IO Term -scVariable sc ec = scTermF sc (Variable ec) +scVariable sc (EC nm tp) = scTermF sc (Variable nm tp) data DuplicateNameException = DuplicateNameException URI instance Exception DuplicateNameException @@ -758,7 +743,6 @@ getTerm cache termF = i <- getUniqueInt let term = STApp { stAppIndex = i , stAppHash = hash termF - , stAppLooseVars = looseTermF (fmap looseVars termF) , stAppFreeVars = freesTermF (fmap freeVars termF) , stAppTermF = termF } @@ -1174,7 +1158,7 @@ scWhnf sc t0 = go xs (asApp -> Just (t, x)) = go (ElimApp x : xs) t go xs (asRecordSelector -> Just (t, n)) = go (ElimProj n : xs) t go xs (asPairSelector -> Just (t, i)) = go (ElimPair i : xs) t - go (ElimApp x : xs) (asLambda -> Just (_, _, body)) = betaReduce xs [x] body + go (ElimApp x : xs) (asLambda -> Just (vn, _, body)) = betaReduce xs [(vn, x)] body go (ElimPair i : xs) (asPairValue -> Just (a, b)) = go xs (if i then b else a) go (ElimProj fld : xs) (asRecordValue -> Just elems) = case Map.lookup fld elems of Just t -> go xs t @@ -1219,11 +1203,13 @@ scWhnf sc t0 = go xs t = foldM reapply t xs betaReduce :: (?cache :: Cache IO TermIndex Term) => - [WHNFElim] -> [Term] -> Term -> IO Term - betaReduce (ElimApp x : xs) vs (asLambda -> Just (_,_,body)) = - betaReduce xs (x:vs) body + [WHNFElim] -> [(VarName, Term)] -> Term -> IO Term + betaReduce (ElimApp x : xs) vs (asLambda -> Just (vn,_,body)) = + betaReduce xs ((vn, x) : vs) body betaReduce xs vs body = - instantiateVarList sc 0 vs body >>= go xs + do let subst = IntMap.fromList [ (vnIndex vn, x) | (vn, x) <- vs ] + body' <- scInstantiateExt sc subst body + go xs body' reapply :: Term -> WHNFElim -> IO Term reapply t (ElimApp x) = scApply sc t x @@ -1265,55 +1251,58 @@ scConvertibleEval :: SharedContext -> IO Bool scConvertibleEval sc eval unfoldConst tm1 tm2 = do c <- newCache - go c tm1 tm2 + go c IntMap.empty tm1 tm2 where whnf :: Cache IO TermIndex Term -> Term -> IO (TermF Term) whnf _c t@(Unshared _) = unwrapTermF <$> eval sc t whnf c t@(STApp{ stAppIndex = idx}) = unwrapTermF <$> useCache c idx (eval sc t) - go :: Cache IO TermIndex Term -> Term -> Term -> IO Bool - go _c (STApp{ stAppIndex = idx1}) (STApp{ stAppIndex = idx2}) - | idx1 == idx2 = return True -- succeed early case - go c t1 t2 = join (goF c <$> whnf c t1 <*> whnf c t2) + go :: Cache IO TermIndex Term -> IntMap VarIndex -> Term -> Term -> IO Bool + go _c vm (STApp{stAppIndex = idx1, stAppFreeVars = vs1}) (STApp{stAppIndex = idx2}) + | IntSet.disjoint vs1 (IntMap.keysSet vm) && idx1 == idx2 = pure True -- succeed early case + go c vm t1 t2 = join (goF c vm <$> whnf c t1 <*> whnf c t2) - goF :: Cache IO TermIndex Term -> TermF Term -> TermF Term -> IO Bool + goF :: Cache IO TermIndex Term -> IntMap VarIndex -> TermF Term -> TermF Term -> IO Bool - goF _c (Constant nx) (Constant ny) | nameIndex nx == nameIndex ny = pure True - goF c (Constant nx) y + goF _c _vm (Constant nx) (Constant ny) | nameIndex nx == nameIndex ny = pure True + goF c vm (Constant nx) y | unfoldConst = do mx <- scFindDefBody sc (nameIndex nx) case mx of - Just x -> join (goF c <$> whnf c x <*> return y) + Just x -> join (goF c vm <$> whnf c x <*> return y) Nothing -> pure False - goF c x (Constant ny) + goF c vm x (Constant ny) | unfoldConst = do my <- scFindDefBody sc (nameIndex ny) case my of - Just y -> join (goF c <$> return x <*> whnf c y) + Just y -> join (goF c vm <$> return x <*> whnf c y) Nothing -> pure False - goF c (FTermF ftf1) (FTermF ftf2) = - case zipWithFlatTermF (go c) ftf1 ftf2 of + goF c vm (FTermF ftf1) (FTermF ftf2) = + case zipWithFlatTermF (go c vm) ftf1 ftf2 of Nothing -> return False Just zipped -> Fold.and <$> traverse id zipped - goF _c (LocalVar i) (LocalVar j) = return (i == j) + goF c vm (App f1 x1) (App f2 x2) = + pure (&&) <*> go c vm f1 f2 <*> go c vm x1 x2 - goF c (App f1 x1) (App f2 x2) = - pure (&&) <*> go c f1 f2 <*> go c x1 x2 + goF c vm (Lambda (vnIndex -> i1) ty1 body1) (Lambda (vnIndex -> i2) ty2 body2) = + pure (&&) <*> go c vm ty1 ty2 <*> go c vm' body1 body2 + where vm' = if i1 == i2 then vm else IntMap.insert i1 i2 vm - goF c (Lambda _ ty1 body1) (Lambda _ ty2 body2) = - pure (&&) <*> go c ty1 ty2 <*> go c body1 body2 + goF c vm (Pi (vnIndex -> i1) ty1 body1) (Pi (vnIndex -> i2) ty2 body2) = + pure (&&) <*> go c vm ty1 ty2 <*> go c vm' body1 body2 + where vm' = if i1 == i2 then vm else IntMap.insert i1 i2 vm - goF c (Pi _ ty1 body1) (Pi _ ty2 body2) = - pure (&&) <*> go c ty1 ty2 <*> go c body1 body2 - - goF c (Variable ec1) (Variable ec2) - | ecVarIndex ec1 == ecVarIndex ec2 = go c (ecType ec1) (ecType ec2) + goF c vm (Variable x1 t1) (Variable x2 t2) + | i' == vnIndex x2 = go c vm t1 t2 + where i' = case IntMap.lookup (vnIndex x1) vm of + Nothing -> vnIndex x1 + Just i -> i -- final catch-all case - goF _c _x _y = return False + goF _c _vm _x _y = pure False -- | Test if two terms are convertible using 'scWhnf' for evaluation scConvertible :: SharedContext @@ -1332,7 +1321,8 @@ reducePi :: SharedContext -> Term -> Term -> IO Term reducePi sc t arg = do t' <- scWhnf sc t case asPi t' of - Just (_, _, body) -> instantiateVar sc 0 arg body + Just (vn, _, body) -> + scInstantiateExt sc (IntMap.singleton (vnIndex vn) arg) body _ -> fail $ unlines ["reducePi: not a Pi term", showTerm t'] @@ -1349,10 +1339,10 @@ scTypeOfIdent sc ident = -- | Computes the type of a term as quickly as possible, assuming that -- the term is well-typed. scTypeOf :: SharedContext -> Term -> IO Term -scTypeOf sc t0 = scTypeOf' sc [] t0 +scTypeOf sc t0 = scTypeOf' sc IntMap.empty t0 --- | A version for open terms; the list argument encodes the type environment. -scTypeOf' :: SharedContext -> [Term] -> Term -> IO Term +-- | A version for open terms; the map argument encodes the type environment. +scTypeOf' :: SharedContext -> IntMap Term -> Term -> IO Term scTypeOf' sc env t0 = State.evalStateT (memo t0) Map.empty where memo :: Term -> State.StateT (Map TermIndex Term) IO Term @@ -1380,28 +1370,28 @@ scTypeOf' sc env t0 = State.evalStateT (memo t0) Map.empty App x y -> do tx <- memo x lift $ reducePi sc tx y - Lambda name tp rhs -> do - rtp <- lift $ scTypeOf' sc (tp : env) rhs - lift $ scTermF sc (Pi name tp rtp) - Pi _ tp rhs -> do - ltp <- sort tp - rtp <- toSort =<< lift (scTypeOf' sc (tp : env) rhs) - - -- NOTE: the rule for type-checking Pi types is that (Pi x a b) is a Prop - -- when b is a Prop (this is a forall proposition), otherwise it is a - -- (Type (max (sortOf a) (sortOf b))) - let srt = if rtp == propSort then propSort else max ltp rtp - - lift $ scSort sc srt - LocalVar i - | i < length env -> lift $ incVars sc 0 (i + 1) (env !! i) - | otherwise -> fail $ "Dangling bound variable: " ++ show (i - length env) + Lambda x tp rhs -> + do let env' = IntMap.insert (vnIndex x) tp env + rtp <- lift $ scTypeOf' sc env' rhs + lift $ scPi sc x tp rtp + Pi x tp rhs -> + do ltp <- sort tp + let env' = IntMap.insert (vnIndex x) tp env + rtp <- toSort =<< lift (scTypeOf' sc env' rhs) + -- NOTE: the rule for type-checking Pi types is that (Pi x a b) is a Prop + -- when b is a Prop (this is a forall proposition), otherwise it is a + -- (Type (max (sortOf a) (sortOf b))) + let srt = if rtp == propSort then propSort else max ltp rtp + lift $ scSort sc srt Constant nm -> do mm <- liftIO $ scGetModuleMap sc case lookupVarIndexInMap (nameIndex nm) mm of Just r -> pure $ resolvedNameType r _ -> panic "scTypeOf'" ["Constant not found: " <> toAbsoluteName (nameInfo nm)] - Variable ec -> pure $ ecType ec + Variable x tp -> + case IntMap.lookup (vnIndex x) env of + Just tx -> pure tx + Nothing -> pure tp ftermf :: FlatTermF Term -> State.StateT (Map TermIndex Term) IO Term ftermf tf = @@ -1492,131 +1482,6 @@ scImport sc t0 = go cache (STApp{ stAppIndex = idx, stAppTermF = tf}) = useCache cache idx (scTermF sc =<< traverse (go cache) tf) --------------------------------------------------------------------------------- --- Instantiating variables - --- | The second argument is a function that takes the number of --- enclosing lambdas and the de Bruijn index of the variable, --- returning the new term to replace it with. -instantiateLocalVars :: - SharedContext -> - (DeBruijnIndex -> DeBruijnIndex -> IO Term) -> - DeBruijnIndex -> Term -> IO Term -instantiateLocalVars sc f initialLevel t0 = - do cache <- newCache - let ?cache = cache in go initialLevel t0 - where - go :: (?cache :: Cache IO (TermIndex, DeBruijnIndex) Term) => - DeBruijnIndex -> Term -> IO Term - go l t = - case t of - Unshared tf -> go' l tf - STApp{ stAppIndex = tidx, stAppTermF = tf } - | termIsClosed t -> return t -- closed terms map to themselves - | otherwise -> useCache ?cache (tidx, l) (go' l tf) - - go' :: (?cache :: Cache IO (TermIndex, DeBruijnIndex) Term) => - DeBruijnIndex -> TermF Term -> IO Term - go' l (FTermF tf) = scFlatTermF sc =<< (traverse (go l) tf) - go' l (App x y) = scTermF sc =<< (App <$> go l x <*> go l y) - go' l (Lambda i tp rhs) = scTermF sc =<< (Lambda i <$> go l tp <*> go (l+1) rhs) - go' l (Pi i lhs rhs) = scTermF sc =<< (Pi i <$> go l lhs <*> go (l+1) rhs) - go' l (LocalVar i) - | i < l = scTermF sc (LocalVar i) - | otherwise = f l i - go' _ tf@(Constant {}) = scTermF sc tf - go' _ tf@(Variable {}) = scTermF sc tf - -instantiateVars :: SharedContext - -> ((Term -> IO Term) -> DeBruijnIndex -> Either (ExtCns Term) DeBruijnIndex -> IO Term) - -> DeBruijnIndex -> Term -> IO Term -instantiateVars sc f initialLevel t0 = - do cache <- newCache - let ?cache = cache in go initialLevel t0 - where - go :: (?cache :: Cache IO (TermIndex, DeBruijnIndex) Term) => - DeBruijnIndex -> Term -> IO Term - go l (Unshared tf) = - go' l tf - go l (STApp{ stAppIndex = tidx, stAppTermF = tf}) = - useCache ?cache (tidx, l) (go' l tf) - - go' :: (?cache :: Cache IO (TermIndex, DeBruijnIndex) Term) => - DeBruijnIndex -> TermF Term -> IO Term - go' l (Variable ec) = f (go l) l (Left ec) - go' l (FTermF tf) = scFlatTermF sc =<< (traverse (go l) tf) - go' l (App x y) = scTermF sc =<< (App <$> go l x <*> go l y) - go' l (Lambda i tp rhs) = scTermF sc =<< (Lambda i <$> go l tp <*> go (l+1) rhs) - go' l (Pi i lhs rhs) = scTermF sc =<< (Pi i <$> go l lhs <*> go (l+1) rhs) - go' l (LocalVar i) - | i < l = scTermF sc (LocalVar i) - | otherwise = f (go l) l (Right i) - go' _ tf@(Constant {}) = scTermF sc tf - --- | @incVars k j t@ increments free variables at least @k@ by @j@. --- e.g., incVars 1 2 (C ?0 ?1) = C ?0 ?3 -incVars :: SharedContext - -> DeBruijnIndex -> DeBruijnIndex -> Term -> IO Term -incVars sc initialLevel j - | j == 0 = return - | otherwise = instantiateLocalVars sc fn initialLevel - where - fn _ i = scTermF sc (LocalVar (i+j)) - --- | Substitute @t0@ for variable @k@ in @t@ and decrement all higher --- dangling variables. -instantiateVar :: SharedContext - -> DeBruijnIndex -> Term -> Term -> IO Term -instantiateVar sc k t0 t = - do cache <- newCache - let ?cache = cache in instantiateLocalVars sc fn k t - where -- Use map reference to memoize instantiated versions of t. - term :: (?cache :: Cache IO DeBruijnIndex Term) => - DeBruijnIndex -> IO Term - term i = useCache ?cache i (incVars sc 0 i t0) - -- Instantiate variable 0. - fn :: (?cache :: Cache IO DeBruijnIndex Term) => - DeBruijnIndex -> DeBruijnIndex -> IO Term - fn i j | j > i = scTermF sc (LocalVar (j - 1)) - | j == i = term i - | otherwise = scTermF sc (LocalVar j) - --- | Substitute @ts@ for variables @[k .. k + length ts - 1]@ and decrement all --- higher deBruijn indices by @length ts@. Assume that deBruijn index 0 in @ts@ --- refers to deBruijn index @k + length ts@ in the current term; i.e., this --- substitution lifts terms in @ts@ by @k@ (plus any additional binders). --- --- For example, @instantiateVarList 0 [x,y,z] t@ is the beta-reduced form of --- --- > Lam (Lam (Lam t)) `App` z `App` y `App` x --- --- Note that the first element of the @ts@ list corresponds to @x@, which is the --- outermost, or last, application. In terms of 'instantiateVar', we can write --- this as: --- --- > instantiateVarList 0 [x,y,z] t == --- > instantiateVar 0 x (instantiateVar 1 (incVars 0 1 y) --- > (instantiateVar 2 (incVars 0 2 z) t)) -instantiateVarList :: SharedContext - -> DeBruijnIndex -> [Term] -> Term -> IO Term -instantiateVarList _ _ [] t = return t -instantiateVarList sc k ts t = - do caches <- mapM (const newCache) ts - instantiateLocalVars sc (fn (zip caches ts)) k t - where - l = length ts - -- Memoize instantiated versions of ts. - term :: (Cache IO DeBruijnIndex Term, Term) - -> DeBruijnIndex -> IO Term - term (cache, x) i = useCache cache i (incVars sc 0 (i-k) x) - -- Instantiate variables [k .. k+l-1]. - fn :: [(Cache IO DeBruijnIndex Term, Term)] - -> DeBruijnIndex -> DeBruijnIndex -> IO Term - fn rs i j | j >= i + l = scTermF sc (LocalVar (j - l)) - | j >= i = term (rs !! (j - i)) i - | otherwise = scTermF sc (LocalVar j) - - -------------------------------------------------------------------------------- -- Beta Normalization @@ -1638,9 +1503,11 @@ betaNormalize sc t0 = let n = length (zip args params) if n == 0 then go3 t else do body' <- go body - f' <- scLambdaList sc (drop n params) body' + let ecs = map (uncurry EC) (drop n params) + f' <- scAbstractExts sc ecs body' args' <- mapM go args - f'' <- instantiateVarList sc 0 (reverse (take n args')) f' + let sub = IntMap.fromList [(vnIndex nm, arg) | (arg, (nm, _)) <- zip args params] + f'' <- scInstantiateExt sc sub f' scApplyAll sc f'' (drop n args') go3 :: (?cache :: Cache IO TermIndex Term) => Term -> IO Term @@ -1661,9 +1528,12 @@ scApplyAll sc = foldlM (scApply sc) -- | Apply a function to an argument, beta-reducing if the function is a lambda scApplyBeta :: SharedContext -> Term -> Term -> IO Term -scApplyBeta sc (asLambda -> Just (_, _, body)) arg = - instantiateVar sc 0 arg body -scApplyBeta sc f arg = scApply sc f arg +scApplyBeta sc f arg = + case asLambda f of + Just (name, _, body) -> + scInstantiateExt sc (IntMap.singleton (vnIndex name) arg) body + Nothing -> + scApply sc f arg -- | Apply a function 'Term' to zero or more arguments, beta reducing any time -- the function is a lambda @@ -1792,8 +1662,9 @@ scFun :: SharedContext -> Term -- ^ The parameter type -> Term -- ^ The result type -> IO Term -scFun sc a b = do b' <- incVars sc 0 1 b - scTermF sc (Pi "_" a b') +scFun sc a b = + do nm <- scFreshVarName sc "_" + scTermF sc (Pi nm a b) -- | Create a term representing the type of a non-dependent n-ary function, -- given a list of parameter types and a result type (as terms). @@ -1807,7 +1678,7 @@ scFunAll sc argTypes resultType = foldrM (scFun sc) resultType argTypes -- (as a 'Term'), and a body. Regarding deBruijn indices, in the body of the -- function, an index of 0 refers to the bound parameter. scLambda :: SharedContext - -> LocalName -- ^ The parameter name + -> VarName -- ^ The parameter name -> Term -- ^ The parameter type -> Term -- ^ The body -> IO Term @@ -1819,7 +1690,7 @@ scLambda sc varname ty body = scTermF sc (Lambda varname ty body) -- parameter in the list, and n-1 (where n is the list length) refers to the -- first. scLambdaList :: SharedContext - -> [(LocalName, Term)] -- ^ List of parameter / parameter type pairs + -> [(VarName, Term)] -- ^ List of parameter / parameter type pairs -> Term -- ^ The body -> IO Term scLambdaList _ [] rhs = return rhs @@ -1827,31 +1698,24 @@ scLambdaList sc ((nm,tp):r) rhs = scLambda sc nm tp =<< scLambdaList sc r rhs -- | Create a (possibly dependent) function given a parameter name, parameter --- type (as a 'Term'), and a body. This function follows the same deBruijn --- index convention as 'scLambda'. +-- type (as a 'Term'), and a body. scPi :: SharedContext - -> LocalName -- ^ The parameter name + -> VarName -- ^ The parameter name -> Term -- ^ The parameter type -> Term -- ^ The body -> IO Term scPi sc nm tp body = scTermF sc (Pi nm tp body) + -- | Create a (possibly dependent) function of multiple arguments (curried) -- from a list associating parameter names to types (as 'Term's) and a body. --- This function follows the same deBruijn index convention as 'scLambdaList'. scPiList :: SharedContext - -> [(LocalName, Term)] -- ^ List of parameter / parameter type pairs + -> [(VarName, Term)] -- ^ List of parameter / parameter type pairs -> Term -- ^ The body -> IO Term scPiList _ [] rhs = return rhs scPiList sc ((nm,tp):r) rhs = scPi sc nm tp =<< scPiList sc r rhs --- | Create a local variable term from a 'DeBruijnIndex'. -scLocalVar :: SharedContext - -> DeBruijnIndex - -> IO Term -scLocalVar sc i = scTermF sc (LocalVar i) - -- | Create an abstract constant with the specified name, body, and -- type. The term for the body must not have any loose de Bruijn -- indices. If the body contains any ExtCns variables, they will be @@ -1862,7 +1726,7 @@ scConstant :: SharedContext -> Term -- ^ The type -> IO Term scConstant sc name rhs ty = - do unless (termIsClosed rhs) $ + do unless (closedTerm rhs) $ fail "scConstant: term contains loose variables" unless (null (getAllExts rhs)) $ fail $ unlines @@ -1891,7 +1755,7 @@ scConstant' :: SharedContext -> Term -- ^ The type -> IO Term scConstant' sc nmi rhs ty = - do unless (termIsClosed rhs) $ + do unless (closedTerm rhs) $ fail "scConstant': term contains loose variables" unless (null (getAllExts rhs)) $ fail $ unlines @@ -2739,18 +2603,33 @@ getAllExts t = Set.toList (getAllExtSet t) -- | Return a set of all ExtCns subterms in the given term. -- Does not traverse the unfoldings of @Constant@ terms. getAllExtSet :: Term -> Set.Set (ExtCns Term) -getAllExtSet t = snd $ getExtCns (IntSet.empty, Set.empty) t - where getExtCns acc@(is, _) (STApp{ stAppIndex = idx }) | IntSet.member idx is = acc - getExtCns (is, a) (STApp{ stAppIndex = idx, stAppTermF = (Variable ec) }) = - (IntSet.insert idx is, Set.insert ec a) - getExtCns (is, a) (Unshared (Variable ec)) = - (is, Set.insert ec a) - getExtCns acc (STApp{ stAppTermF = Constant {} }) = acc - getExtCns acc (Unshared (Constant {})) = acc - getExtCns (is, a) (STApp{ stAppIndex = idx, stAppTermF = tf'}) = - foldl' getExtCns (IntSet.insert idx is, a) tf' - getExtCns acc (Unshared tf') = - foldl' getExtCns acc tf' +getAllExtSet t = State.evalState (go t) IntMap.empty + where + go :: Term -> State.State (IntMap (Set.Set (ExtCns Term))) (Set.Set (ExtCns Term)) + go (Unshared tf) = termf tf + go STApp{ stAppIndex = i, stAppTermF = tf, stAppFreeVars = fvs } + | IntSet.null fvs = pure Set.empty + | otherwise = + do memo <- State.get + case IntMap.lookup i memo of + Just ecs -> pure ecs + Nothing -> + do ecs <- termf tf + State.modify' (IntMap.insert i ecs) + pure ecs + termf :: TermF Term -> State.State (IntMap (Set.Set (ExtCns Term))) (Set.Set (ExtCns Term)) + termf tf = + case tf of + Variable x tp -> pure (Set.singleton (EC x tp)) + Lambda x t1 t2 -> + do ecs1 <- go t1 + ecs2 <- go t2 + pure (ecs1 <> Set.delete (EC x t1) ecs2) + Pi x t1 t2 -> + do ecs1 <- go t1 + ecs2 <- go t2 + pure (ecs1 <> Set.delete (EC x t1) ecs2) + _ -> Fold.fold <$> traverse go tf getConstantSet :: Term -> Map VarIndex NameInfo getConstantSet t = snd $ go (IntSet.empty, Map.empty) t @@ -2765,26 +2644,14 @@ getConstantSet t = snd $ go (IntSet.empty, Map.empty) t Constant (Name vidx n) -> (idxs, Map.insert vidx n names) _ -> foldl' go acc tf --- | Instantiate some of the external constants. --- Note: this replacement is _not_ applied recursively --- to the terms in the replacement map; so external constants --- in those terms will not be replaced. +-- | Instantiate some of the named variables in the term. +-- The 'IntMap' is keyed by 'VarIndex'. +-- Note: The replacement is _not_ applied recursively +-- to the terms in the substitution map. scInstantiateExt :: SharedContext -> IntMap Term -> Term -> IO Term -scInstantiateExt sc vmap - | all termIsClosed vmap = scInstantiateExtClosed sc vmap - | otherwise = instantiateVars sc fn 0 - where fn _rec l (Left ec) = - case IntMap.lookup (ecVarIndex ec) vmap of - Just t -> incVars sc 0 l t - Nothing -> scVariable sc ec - fn _ _ (Right i) = scLocalVar sc i - --- | Internal variant of 'scInstantiateExt' that requires the --- substituted terms to all be closed, i.e. they must not have any --- loose de Bruijn indices. -scInstantiateExtClosed :: SharedContext -> IntMap Term -> Term -> IO Term -scInstantiateExtClosed sc vmap t0 = - do let vs = IntMap.keysSet vmap +scInstantiateExt sc vmap t0 = + do let domainVars = IntMap.keysSet vmap + let rangeVars = foldMap freeVars vmap tcache <- newCacheIntMap let memo :: Term -> IO Term memo t = @@ -2793,66 +2660,52 @@ scInstantiateExtClosed sc vmap t0 = STApp {stAppIndex = i} -> useCache tcache i (go t) go :: Term -> IO Term go t - | IntSet.disjoint vs (freeVars t) = pure t + | IntSet.disjoint domainVars (freeVars t) = pure t | otherwise = case unwrapTermF t of FTermF ftf -> scFlatTermF sc =<< traverse memo ftf App t1 t2 -> scTermF sc =<< App <$> memo t1 <*> memo t2 - Lambda x t1 t2 -> scTermF sc =<< Lambda x <$> memo t1 <*> memo t2 - Pi x t1 t2 -> scTermF sc =<< Pi x <$> memo t1 <*> memo t2 - LocalVar {} -> pure t + Lambda x t1 t2 -> + do t1' <- memo t1 + (x', t2') <- goBinder x t1' t2 + scLambda sc x' t1' t2' + Pi x t1 t2 -> + do t1' <- memo t1 + (x', t2') <- goBinder x t1' t2 + scPi sc x' t1' t2' Constant {} -> pure t - Variable ec -> - case IntMap.lookup (ecVarIndex ec) vmap of + Variable nm tp -> + case IntMap.lookup (vnIndex nm) vmap of Just t' -> pure t' - Nothing -> pure t + Nothing -> scVariable sc =<< traverse memo (EC nm tp) + goBinder :: VarName -> Term -> Term -> IO (VarName, Term) + goBinder x@(vnIndex -> i) t body + | IntSet.member i rangeVars = + -- Possibility of capture; rename bound variable. + do x' <- scFreshVarName sc (vnName x) + var <- scVariable sc (EC x' t) + let vmap' = IntMap.insert i var vmap + body' <- scInstantiateExt sc vmap' body + pure (x', body') + | IntMap.member i vmap = + -- Shadowing; remove entry from substitution. + do let vmap' = IntMap.delete i vmap + body' <- scInstantiateExt sc vmap' body + pure (x, body') + | otherwise = + -- No possibility of shadowing or capture. + do body' <- memo body + pure (x, body') go t0 --- | Convert the given list of external constants to local variables, --- with the right-most mapping to local variable 0. If the term is --- open (i.e. it contains loose de Bruijn indices) then increment them --- accordingly. -scExtsToLocals :: SharedContext -> [ExtCns Term] -> Term -> IO Term -scExtsToLocals _ [] x = return x -scExtsToLocals sc exts x = instantiateVars sc fn 0 x - where - m = Map.fromList [ (ecVarIndex ec, k) | (ec, k) <- zip (reverse exts) [0 ..] ] - fn r l e = - case e of - Left ec -> - case Map.lookup (ecVarIndex ec) m of - Just k -> scLocalVar sc (l + k) - Nothing -> scVariable sc =<< traverse r ec - Right i -> - scLocalVar sc (i + length exts) - -- | Abstract over the given list of external constants by wrapping -- the given term with lambdas and replacing the external constant -- occurrences with the appropriate local variables. scAbstractExts :: SharedContext -> [ExtCns Term] -> Term -> IO Term scAbstractExts _ [] x = return x -scAbstractExts sc exts x = loop (zip (inits exts) exts) - where - -- each pair contains a single ExtCns and a list of all - -- the ExtCns values that appear before it in the original list. - loop :: [([ExtCns Term], ExtCns Term)] -> IO Term - - -- special case: outermost variable, no need to abstract - -- inside the type of ec - loop (([],ec):ecs) = - do tm' <- loop ecs - scLambda sc (ecShortName ec) (ecType ec) tm' - - -- ordinary case. We need to abstract over all the ExtCns in @begin@ - -- before apply scLambda. This ensures any dependencies between the - -- types are handled correctly. - loop ((begin,ec):ecs) = - do tm' <- loop ecs - tp' <- scExtsToLocals sc begin (ecType ec) - scLambda sc (ecShortName ec) tp' tm' - - -- base case, convert all the exts in the body of x into deBruijn variables - loop [] = scExtsToLocals sc exts x +scAbstractExts sc (ec : ecs) x = + do body <- scAbstractExts sc ecs x + scLambda sc (ecName ec) (ecType ec) body -- | Create a lambda term by abstracting over the list of arguments, -- which must all be named variables (e.g. terms generated by @@ -2895,28 +2748,9 @@ scAbstractExtsEtaCollapse sc = \exts tm -> loop (reverse exts) tm -- occurrences with the appropriate local variables. scGeneralizeExts :: SharedContext -> [ExtCns Term] -> Term -> IO Term scGeneralizeExts _ [] x = return x -scGeneralizeExts sc exts x = loop (zip (inits exts) exts) - where - -- each pair contains a single ExtCns and a list of all - -- the ExtCns values that appear before it in the original list. - loop :: [([ExtCns Term], ExtCns Term)] -> IO Term - - -- specical case: outermost variable, no need to abstract - -- inside the type of ec - loop (([],ec):ecs) = - do tm' <- loop ecs - scPi sc (ecShortName ec) (ecType ec) tm' - - -- ordinary case. We need to abstract over all the ExtCns in @begin@ - -- before apply scLambda. This ensures any dependenices between the - -- types are handled correctly. - loop ((begin,ec):ecs) = - do tm' <- loop ecs - tp' <- scExtsToLocals sc begin (ecType ec) - scPi sc (ecShortName ec) tp' tm' - - -- base case, convert all the exts in the body of x into deBruijn variables - loop [] = scExtsToLocals sc exts x +scGeneralizeExts sc (ec : ecs) x = + do body <- scGeneralizeExts sc ecs x + scPi sc (ecName ec) (ecType ec) body -- | Create a pi term by abstracting over the list of arguments, which -- must all be named variables (e.g. terms generated by 'scVariable' or @@ -3068,81 +2902,3 @@ scTreeSizeAux = go Just sz' -> (sz + sz', seen) Nothing -> (sz + sz', Map.insert idx sz' seen') where (sz', seen') = foldl' go (1, seen) tf - - --- | `openTerm sc nm ty i body` replaces the loose deBruijn variable `i` --- with a fresh external constant (with name `nm`, and type `ty`) in `body`. -scOpenTerm :: SharedContext - -> Text - -> Term - -> DeBruijnIndex - -> Term - -> IO (ExtCns Term, Term) -scOpenTerm sc nm tp idx body = do - ec <- scFreshEC sc nm tp - ec_term <- scVariable sc ec - body' <- instantiateVar sc idx ec_term body - return (ec, body') - --- | `closeTerm close sc ec body` replaces the external constant `ec` in `body` by --- a new deBruijn variable and binds it using the binding form given by 'close'. --- The name and type of the new bound variable are given by the name and type of `ec`. -scCloseTerm :: (SharedContext -> LocalName -> Term -> Term -> IO Term) - -> SharedContext - -> ExtCns Term - -> Term - -> IO Term -scCloseTerm close sc ec body = do - lv <- scLocalVar sc 0 - body' <- scInstantiateExt sc (IntMap.singleton (ecVarIndex ec) lv) =<< incVars sc 0 1 body - close sc (ecShortName ec) (ecType ec) body' - --- | Deconstruct a lambda term into a bound variable and a body, using --- a fresh 'ExtCns' for the bound variable. -scAsLambda :: SharedContext -> Term -> IO (Maybe (ExtCns Term, Term)) -scAsLambda sc t = - case asLambda t of - Nothing -> pure Nothing - Just (nm, tp, body) -> Just <$> scOpenTerm sc nm tp 0 body - --- | Deconstruct a nested lambda term with 0 or more binders into a --- list of bound variables and a body, using a fresh 'ExtCns' for each --- bound variable. -scAsLambdaList :: SharedContext -> Term -> IO ([ExtCns Term], Term) -scAsLambdaList sc = loop [] [] - where - loop ecs vs t = - case asLambda t of - Nothing -> - do t' <- instantiateVarList sc 0 vs t - pure (reverse ecs, t') - Just (nm, tp, body) -> - do tp' <- instantiateVarList sc 0 vs tp - ec <- scFreshEC sc nm tp' - v <- scVariable sc ec - loop (ec : ecs) (v : vs) body - --- | Deconstruct a pi term into a bound variable and a body, using --- a fresh 'ExtCns' for the bound variable. -scAsPi :: SharedContext -> Term -> IO (Maybe (ExtCns Term, Term)) -scAsPi sc t = - case asPi t of - Nothing -> pure Nothing - Just (nm, tp, body) -> Just <$> scOpenTerm sc nm tp 0 body - --- | Deconstruct a nested pi term with 0 or more binders into a list --- of bound variables and a body, using a fresh 'ExtCns' for each --- bound variable. -scAsPiList :: SharedContext -> Term -> IO ([ExtCns Term], Term) -scAsPiList sc = loop [] [] - where - loop ecs vs t = - case asPi t of - Nothing -> - do t' <- instantiateVarList sc 0 vs t - pure (reverse ecs, t') - Just (nm, tp, body) -> - do tp' <- instantiateVarList sc 0 vs tp - ec <- scFreshEC sc nm tp' - v <- scVariable sc ec - loop (ec : ecs) (v : vs) body diff --git a/saw-core/src/SAWCore/Simulator.hs b/saw-core/src/SAWCore/Simulator.hs index 4d65a656b8..7808d1066c 100644 --- a/saw-core/src/SAWCore/Simulator.hs +++ b/saw-core/src/SAWCore/Simulator.hs @@ -46,6 +46,8 @@ import qualified Data.Map as Map import Data.IntMap (IntMap) import qualified Data.IntMap as IntMap import qualified Data.IntMap as IMap +import Data.IntSet (IntSet) +import qualified Data.IntSet as IntSet import Data.Text (Text) import qualified Data.Text as Text import Data.Traversable @@ -113,7 +115,7 @@ data SimulatorConfig l = ------------------------------------------------------------ -- Evaluation of terms -type Env l = [Thunk l] +type Env l = IntMap (Thunk l) -- indexed by VarIndex type EnvIn m l = Env (WithM m l) -- | Meaning of an open term, parameterized by environment of bound variables @@ -148,19 +150,15 @@ evalTermF cfg lam recEval tf env = do x <- recEvalDelay t2 f x _ -> panic "evalTermF" ["Expected VFun"] - Lambda _nm _tp t -> pure $ VFun (\x -> lam t (x : env)) - Pi _nm t1 t2 -> do v <- evalType t1 + Lambda nm _tp t -> pure $ VFun (\x -> lam t (IntMap.insert (vnIndex nm) x env)) + Pi nm t1 t2 -> do v <- evalType t1 body <- - if inBitSet 0 (looseVars t2) then - pure (VDependentPi (\x -> toTValue <$> lam t2 (x : env))) + if IntSet.member (vnIndex nm) (freeVars t2) then + pure (VDependentPi (\x -> toTValue <$> lam t2 (IntMap.insert (vnIndex nm) x env))) else - do -- put dummy values in the environment; the term should never reference them - let val = ready VUnit - VNondependentPi . toTValue <$> lam t2 (val : env) + VNondependentPi . toTValue <$> lam t2 env return $ TValue $ VPiType v body - LocalVar i -> force (env !! i) - Constant nm -> do let r = requireNameInMap nm (simModMap cfg) ty' <- evalType (resolvedNameType r) case simConstant cfg tf nm ty' of @@ -176,8 +174,10 @@ evalTermF cfg lam recEval tf env = Just t -> recEval t Nothing -> simPrimitive cfg nm - Variable ec -> do ec' <- traverse evalType ec - simExtCns cfg tf ec' + Variable nm tp -> do tp' <- evalType tp + case IntMap.lookup (vnIndex nm) env of + Nothing -> simExtCns cfg tf (EC nm tp') + Just x -> force x FTermF ftf -> case ftf of UnitValue -> return VUnit @@ -366,7 +366,7 @@ reduceRecursor r elim c_args argstruct = go elim c_args (map snd (ctorArgs argst Map Ident (PrimIn Id l) -> (ExtCns (TValueIn Id l) -> MValueIn Id l) -> (Name -> TValueIn Id l -> Maybe (MValueIn Id l)) -> - (Name -> Text -> EnvIn Id l -> MValueIn Id l) -> + (Name -> Text -> [ThunkIn Id l] -> MValueIn Id l) -> (VBool (WithM Id l) -> MValueIn Id l -> MValueIn Id l -> MValueIn Id l) -> Id (SimulatorConfigIn Id l) #-} {-# SPECIALIZE evalGlobal :: @@ -375,7 +375,7 @@ reduceRecursor r elim c_args argstruct = go elim c_args (map snd (ctorArgs argst Map Ident (PrimIn IO l) -> (ExtCns (TValueIn IO l) -> MValueIn IO l) -> (Name -> TValueIn IO l -> Maybe (MValueIn IO l)) -> - (Name -> Text -> EnvIn IO l -> MValueIn IO l) -> + (Name -> Text -> [ThunkIn IO l] -> MValueIn IO l) -> (VBool (WithM IO l) -> MValueIn IO l -> MValueIn IO l -> MValueIn IO l) -> IO (SimulatorConfigIn IO l) #-} evalGlobal :: forall l. (VMonadLazy l, MonadFix (EvalM l), Show (Extra l)) => @@ -383,7 +383,7 @@ evalGlobal :: forall l. (VMonadLazy l, MonadFix (EvalM l), Show (Extra l)) => Map Ident (Prims.Prim l) -> (ExtCns (TValue l) -> MValue l) -> (Name -> TValue l -> Maybe (EvalM l (Value l))) -> - (Name -> Text -> Env l -> MValue l) -> + (Name -> Text -> [Thunk l] -> MValue l) -> (VBool l -> MValue l -> MValue l -> MValue l) -> EvalM l (SimulatorConfig l) evalGlobal modmap prims extcns uninterpreted primHandler lazymux = @@ -395,7 +395,7 @@ evalGlobal modmap prims extcns uninterpreted primHandler lazymux = Map Ident (PrimIn Id l) -> (TermF Term -> ExtCns (TValueIn Id l) -> MValueIn Id l) -> (TermF Term -> Name -> TValueIn Id l -> Maybe (MValueIn Id l)) -> - (Name -> Text -> EnvIn Id l -> MValueIn Id l) -> + (Name -> Text -> [ThunkIn Id l] -> MValueIn Id l) -> (VBool l -> MValueIn Id l -> MValueIn Id l -> MValueIn Id l) -> Id (SimulatorConfigIn Id l) #-} {-# SPECIALIZE evalGlobal' :: @@ -404,7 +404,7 @@ evalGlobal modmap prims extcns uninterpreted primHandler lazymux = Map Ident (PrimIn IO l) -> (TermF Term -> ExtCns (TValueIn IO l) -> MValueIn IO l) -> (TermF Term -> Name -> TValueIn IO l -> Maybe (MValueIn IO l)) -> - (Name -> Text -> EnvIn IO l -> MValueIn IO l) -> + (Name -> Text -> [ThunkIn IO l] -> MValueIn IO l) -> (VBool l -> MValueIn IO l -> MValueIn IO l -> MValueIn IO l) -> IO (SimulatorConfigIn IO l) #-} -- | A variant of 'evalGlobal' that lets the uninterpreted function @@ -419,7 +419,7 @@ evalGlobal' :: -- | Overrides for Constant terms (e.g. uninterpreted functions) (TermF Term -> Name -> TValue l -> Maybe (MValue l)) -> -- | Handler for stuck primitives - (Name -> Text -> Env l -> MValue l) -> + (Name -> Text -> [Thunk l] -> MValue l) -> -- | Lazy mux operation (VBool l -> MValue l -> MValue l -> MValue l) -> EvalM l (SimulatorConfig l) @@ -508,7 +508,7 @@ evalSharedTerm :: (VMonadLazy l, MonadFix (EvalM l), Show (Extra l)) => SimulatorConfig l -> Term -> MValue l evalSharedTerm cfg t = do memoClosed <- mkMemoClosed cfg t - evalOpen cfg memoClosed t [] + evalOpen cfg memoClosed t IntMap.empty {-# SPECIALIZE mkMemoClosed :: Show (Extra l) => @@ -525,9 +525,9 @@ mkMemoClosed cfg t = where -- | Map of all closed subterms of t. subterms :: IntMap (TermF Term) - subterms = fmap fst $ IMap.filter ((== emptyBitSet) . snd) $ State.execState (go t) IMap.empty + subterms = fmap fst $ IMap.filter (IntSet.null . snd) $ State.execState (go t) IMap.empty - go :: Term -> State.State (IntMap (TermF Term, BitSet)) BitSet + go :: Term -> State.State (IntMap (TermF Term, IntSet)) IntSet go (Unshared tf) = termf tf go (STApp{ stAppIndex = i, stAppTermF = tf }) = do memo <- State.get @@ -538,7 +538,7 @@ mkMemoClosed cfg t = State.modify (IMap.insert i (tf, b)) pure b - termf :: TermF Term -> State.State (IntMap (TermF Term, BitSet)) BitSet + termf :: TermF Term -> State.State (IntMap (TermF Term, IntSet)) IntSet termf tf = do -- if tf is a defined constant, traverse the definition body and type case tf of @@ -549,7 +549,7 @@ mkMemoClosed cfg t = ResolvedDef (defBody -> Just body) -> void $ go body _ -> pure () _ -> pure () - looseTermF <$> traverse go tf + freesTermF <$> traverse go tf {-# SPECIALIZE evalClosedTermF :: Show (Extra l) => @@ -569,10 +569,10 @@ evalClosedTermF :: (VMonadLazy l, Show (Extra l)) => SimulatorConfig l -> IntMap (Thunk l) -> TermF Term -> MValue l -evalClosedTermF cfg memoClosed tf = evalTermF cfg lam recEval tf [] +evalClosedTermF cfg memoClosed tf = evalTermF cfg lam recEval tf IntMap.empty where lam = evalOpen cfg memoClosed - recEval (Unshared tf') = evalTermF cfg lam recEval tf' [] + recEval (Unshared tf') = evalTermF cfg lam recEval tf' IntMap.empty recEval (STApp{ stAppIndex = i }) = case IMap.lookup i memoClosed of Just x -> force x @@ -602,7 +602,7 @@ mkMemoLocal cfg memoClosed t env = go mempty t go :: IntMap (Thunk l) -> Term -> EvalM l (IntMap (Thunk l)) go memo (Unshared tf) = goTermF memo tf go memo (t'@STApp{ stAppIndex = i, stAppTermF = tf }) - | termIsClosed t' = pure memo + | closedTerm t' = pure memo | otherwise = case IMap.lookup i memo of Just _ -> pure memo @@ -618,9 +618,8 @@ mkMemoLocal cfg memoClosed t env = go mempty t go memo' t2 Lambda _ t1 _ -> go memo t1 Pi _ t1 _ -> go memo t1 - LocalVar _ -> pure memo Constant{} -> pure memo - Variable ec -> go memo (ecType ec) + Variable _nm tp -> go memo tp {-# SPECIALIZE evalLocalTermF :: Show (Extra l) => @@ -649,7 +648,7 @@ evalLocalTermF cfg memoClosed memoLocal tf0 env = evalTermF cfg lam recEval tf0 case IMap.lookup i memo of Just x -> force x Nothing -> evalTermF cfg lam recEval tf env - where memo = if termIsClosed t then memoClosed else memoLocal + where memo = if closedTerm t then memoClosed else memoLocal {-# SPECIALIZE evalOpen :: Show (Extra l) => @@ -678,7 +677,7 @@ evalOpen cfg memoClosed t env = do case IMap.lookup i memo of Just x -> force x Nothing -> evalF tf - where memo = if termIsClosed t' then memoClosed else memoLocal + where memo = if closedTerm t' then memoClosed else memoLocal evalF :: TermF Term -> MValue l evalF tf = evalTermF cfg (evalOpen cfg memoClosed) eval tf env eval t @@ -686,23 +685,23 @@ evalOpen cfg memoClosed t env = do {-# SPECIALIZE evalPrim :: Show (Extra l) => - (Text -> EnvIn Id l -> MValueIn Id l) -> + (Text -> [ThunkIn Id l] -> MValueIn Id l) -> PrimIn Id l -> MValueIn Id l #-} {-# SPECIALIZE evalPrim :: Show (Extra l) => - (Text -> EnvIn IO l -> MValueIn IO l) -> + (Text -> [ThunkIn IO l] -> MValueIn IO l) -> PrimIn IO l -> MValueIn IO l #-} evalPrim :: forall l. (VMonadLazy l, Show (Extra l)) => - (Text -> Env l -> MValue l) -> + (Text -> [Thunk l] -> MValue l) -> Prims.Prim l -> MValue l evalPrim fallback = loop [] where - loop :: Env l -> Prims.Prim l -> MValue l + loop :: [Thunk l] -> Prims.Prim l -> MValue l loop env (Prims.PrimFun f) = pure $ VFun $ \x -> loop (x : env) (f x) @@ -728,7 +727,7 @@ evalPrim fallback = loop [] -- | A basic handler for stuck primitives. defaultPrimHandler :: (VMonadLazy l, MonadFail (EvalM l)) => - Name -> Text -> Env l -> MValue l + Name -> Text -> [Thunk l] -> MValue l defaultPrimHandler nm msg env = fail $ unlines [ "Could not evaluate primitive " ++ Text.unpack (toAbsoluteName (nameInfo nm)) diff --git a/saw-core/src/SAWCore/Term/CtxTerm.hs b/saw-core/src/SAWCore/Term/CtxTerm.hs index 65308c082f..234a08fa14 100644 --- a/saw-core/src/SAWCore/Term/CtxTerm.hs +++ b/saw-core/src/SAWCore/Term/CtxTerm.hs @@ -81,58 +81,55 @@ asCtorDTApp _ _ _ _ = Nothing -- | Check that an argument for a constructor has one of the allowed forms asCtorArg :: - SharedContext -> Name -> [ExtCns Term] -> [index] -> Term -> - IO (Maybe CtorArg) -asCtorArg _ d _ _ tp + Maybe CtorArg +asCtorArg d _ _ tp | not (usesDataType d tp) - = pure $ Just (ConstArg tp) -asCtorArg sc d params dt_ixs tp = - do (zs, ret) <- scAsPiList sc tp + = Just (ConstArg tp) +asCtorArg d params dt_ixs tp = + do let (zs, ret) = asPiList tp case asCtorDTApp d params dt_ixs ret of Just ixs - | not (any (usesDataType d . ecType) zs) -> - pure $ Just (RecursiveArg zs ixs) + | not (any (usesDataType d . snd) zs) -> + Just (RecursiveArg (map (uncurry EC) zs) ixs) _ -> - pure Nothing + Nothing -- | Check that a constructor type is a pi-abstraction that takes as input an -- argument of one of the allowed forms described by 'CtorArg' asPiCtorArg :: - SharedContext -> Name -> [ExtCns Term] -> [index] -> Term -> - IO (Maybe (VarName, CtorArg, Term)) -asPiCtorArg sc d params dt_ixs t = - scAsPi sc t >>= \case - Nothing -> pure Nothing - Just (ec, rest) -> - asCtorArg sc d params dt_ixs (ecType ec) >>= \case - Nothing -> pure Nothing - Just arg -> pure $ Just (ecName ec, arg, rest) + Maybe (VarName, CtorArg, Term) +asPiCtorArg d params dt_ixs t = + case asPi t of + Nothing -> Nothing + Just (nm, tp, rest) -> + case asCtorArg d params dt_ixs tp of + Nothing -> Nothing + Just arg -> Just (nm, arg, rest) -- | Helper function for 'mkCtorArgStruct' mkCtorArgsIxs :: - SharedContext -> Name -> [ExtCns Term] -> [index] -> Term -> - IO (Maybe ([(VarName, CtorArg)], [Term])) -mkCtorArgsIxs _sc d params dt_ixs (asCtorDTApp d params dt_ixs -> Just ixs) = - pure $ Just ([], ixs) -mkCtorArgsIxs sc d params dt_ixs ty = - asPiCtorArg sc d params dt_ixs ty >>= \case - Nothing -> pure Nothing + Maybe ([(VarName, CtorArg)], [Term]) +mkCtorArgsIxs d params dt_ixs (asCtorDTApp d params dt_ixs -> Just ixs) = + Just ([], ixs) +mkCtorArgsIxs d params dt_ixs ty = + case asPiCtorArg d params dt_ixs ty of + Nothing -> Nothing Just (x, arg, rest) -> - mkCtorArgsIxs sc d params dt_ixs rest >>= \case - Nothing -> pure Nothing - Just (args, ixs) -> pure $ Just ((x, arg) : args, ixs) + case mkCtorArgsIxs d params dt_ixs rest of + Nothing -> Nothing + Just (args, ixs) -> Just ((x, arg) : args, ixs) -- | Take in a datatype and bindings lists for its parameters and indices, and -- also a prospective type of a constructor for that datatype, where the @@ -140,14 +137,13 @@ mkCtorArgsIxs sc d params dt_ixs ty = -- Test that the constructor type is an allowed type for a constructor of this -- datatype, and, if so, build a 'CtorArgStruct' for it. mkCtorArgStruct :: - SharedContext -> Name -> [ExtCns Term] -> [ExtCns Term] -> Term -> - IO (Maybe CtorArgStruct) -mkCtorArgStruct sc d params dt_ixs ctor_tp = - mkCtorArgsIxs sc d params dt_ixs ctor_tp >>= \case - Nothing -> pure Nothing + Maybe CtorArgStruct +mkCtorArgStruct d params dt_ixs ctor_tp = + case mkCtorArgsIxs d params dt_ixs ctor_tp of + Nothing -> Nothing Just (args, ctor_ixs) -> - pure $ Just (CtorArgStruct params args ctor_ixs) + Just (CtorArgStruct params args ctor_ixs) diff --git a/saw-core/src/SAWCore/Term/Functor.hs b/saw-core/src/SAWCore/Term/Functor.hs index 2c1049a394..b49a374990 100644 --- a/saw-core/src/SAWCore/Term/Functor.hs +++ b/saw-core/src/SAWCore/Term/Functor.hs @@ -30,7 +30,6 @@ module SAWCore.Term.Functor , identText , identPieces -- * Data types and definitions - , DeBruijnIndex , FieldName , LocalName , ExtCns(..) @@ -45,7 +44,6 @@ module SAWCore.Term.Functor , TermF(..) , FlatTermF(..) , zipWithFlatTermF - , looseTermF , unwrapTermF , termToPat , alphaEquiv @@ -54,17 +52,14 @@ module SAWCore.Term.Functor , Sort(..), mkSort, propSort, sortOf, maxSort , SortFlags(..), noFlags, sortFlagsLift2, sortFlagsToList, sortFlagsFromList -- * Sets of free variables - , BitSet, emptyBitSet, inBitSet, unionBitSets, intersectBitSets - , decrBitSet, multiDecrBitSet, completeBitSet, singletonBitSet, bitSetElems - , smallestBitSetElem - , bitSetBound - , looseVars, smallestLooseVar, termIsClosed , freesTermF, freeVars + , closedTerm ) where -import Data.Bits import qualified Data.Foldable as Foldable (and, foldl') import Data.Hashable +import Data.IntMap (IntMap) +import qualified Data.IntMap as IntMap import Data.IntSet (IntSet) import qualified Data.IntSet as IntSet import Data.Text (Text) @@ -72,7 +67,6 @@ import qualified Data.Text as Text import Data.Typeable (Typeable) import Data.Vector (Vector) import qualified Data.Vector as V -import Data.Word import GHC.Generics (Generic) import Numeric.Natural @@ -82,13 +76,8 @@ import Instances.TH.Lift () -- for instance TH.Lift Text import SAWCore.Name import qualified SAWCore.TermNet as Net -type DeBruijnIndex = Int type FieldName = Text type LocalName = Text - -- ^ 'LocalName' is used for pretty printing purposes, but does not affect the semantics of SAWCore terms, - -- rather, the 'DeBruijnIndex'-s are what is used to reference variables. - -- FIXME: Verify the above statement - -- FIXME: Possibly, change to a name that suggests this use. instance Hashable a => Hashable (Vector a) where hashWithSalt x v = hashWithSalt x (V.toList v) @@ -333,15 +322,13 @@ data TermF e -- ^ The atomic, or builtin, term constructs | App !e !e -- ^ Applications of functions - | Lambda !LocalName !e !e + | Lambda !VarName !e !e -- ^ Function abstractions - | Pi !LocalName !e !e + | Pi !VarName !e !e -- ^ The type of a (possibly) dependent function - | LocalVar !DeBruijnIndex - -- ^ Local variables are referenced by deBruijn index. | Constant !Name -- ^ A global constant identified by its name. - | Variable !(ExtCns e) + | Variable !VarName !e -- ^ A named variable with a type. deriving (Eq, Ord, Show, Functor, Foldable, Traversable, Generic) @@ -378,9 +365,6 @@ data Term -- ^ The hash, according to 'hash', of the 'stAppTermF' field associated -- with this 'Term'. This should be as unique as a hash can be, but is -- not guaranteed unique as 'stAppIndex' is. - , stAppLooseVars :: !BitSet - -- ^ A set containing the 'DeBruijnIndex' of each of the loose - -- de Bruijn indices from 'LocalVar' constructors in the term. , stAppFreeVars :: !IntSet -- ^ A set containing the 'VarIndex' of each of the free named -- variables from 'Variable' constructors in the term. @@ -441,39 +425,46 @@ equalTerm (STApp{stAppIndex = i1, stAppHash = h1, stAppTermF = tf1}) -- inequality. -- | Return 'True' iff the given terms are equal modulo alpha equivalence (i.e. --- 'LocalNames' in 'Lambda' and 'Pi' expressions) and sharing (i.e. 'STApp' vs. +-- 'VarName's in 'Lambda' and 'Pi' expressions) and sharing (i.e. 'STApp' vs. -- 'Unshared' expressions). alphaEquiv :: Term -> Term -> Bool -alphaEquiv = term +alphaEquiv = term IntMap.empty where - term :: Term -> Term -> Bool - term (Unshared tf1) (Unshared tf2) = termf tf1 tf2 - term (Unshared tf1) (STApp{stAppTermF = tf2}) = termf tf1 tf2 - term (STApp{stAppTermF = tf1}) (Unshared tf2) = termf tf1 tf2 - term (STApp{stAppIndex = i1, stAppTermF = tf1}) - (STApp{stAppIndex = i2, stAppTermF = tf2}) = - i1 == i2 || termf tf1 tf2 - - termf :: TermF Term -> TermF Term -> Bool - termf (FTermF ftf1) (FTermF ftf2) = ftermf ftf1 ftf2 - termf (App t1 u1) (App t2 u2) = term t1 t2 && term u1 u2 - termf (Lambda _ t1 u1) (Lambda _ t2 u2) = term t1 t2 && term u1 u2 - termf (Pi _ t1 u1) (Pi _ t2 u2) = term t1 t2 && term u1 u2 - termf (LocalVar i1) (LocalVar i2) = i1 == i2 - termf (Constant x1) (Constant x2) = x1 == x2 - termf (Variable x1) (Variable x2) = x1 == x2 - termf FTermF{} _ = False - termf App{} _ = False - termf Lambda{} _ = False - termf Pi{} _ = False - termf LocalVar{} _ = False - termf Constant{} _ = False - termf Variable{} _ = False - - ftermf :: FlatTermF Term -> FlatTermF Term -> Bool - ftermf ftf1 ftf2 = case zipWithFlatTermF term ftf1 ftf2 of - Nothing -> False - Just ftf3 -> Foldable.and ftf3 + term :: IntMap VarIndex -> Term -> Term -> Bool + term vm (Unshared tf1) (Unshared tf2) = termf vm tf1 tf2 + term vm (Unshared tf1) (STApp{stAppTermF = tf2}) = termf vm tf1 tf2 + term vm (STApp{stAppTermF = tf1}) (Unshared tf2) = termf vm tf1 tf2 + term vm + (STApp{stAppIndex = i1, stAppTermF = tf1, stAppFreeVars = vs1}) + (STApp{stAppIndex = i2, stAppTermF = tf2}) = + (IntSet.disjoint vs1 (IntMap.keysSet vm) && i1 == i2) || termf vm tf1 tf2 + + termf :: IntMap VarIndex -> TermF Term -> TermF Term -> Bool + termf vm (FTermF ftf1) (FTermF ftf2) = ftermf vm ftf1 ftf2 + termf vm (App t1 u1) (App t2 u2) = term vm t1 t2 && term vm u1 u2 + termf vm (Lambda (vnIndex -> i1) t1 u1) (Lambda (vnIndex -> i2) t2 u2) = + let vm' = if i1 == i2 then vm else IntMap.insert i1 i2 vm + in term vm t1 t2 && term vm' u1 u2 + termf vm (Pi (vnIndex -> i1) t1 u1) (Pi (vnIndex -> i2) t2 u2) = + let vm' = if i1 == i2 then vm else IntMap.insert i1 i2 vm + in term vm t1 t2 && term vm' u1 u2 + termf _vm (Constant x1) (Constant x2) = x1 == x2 + termf vm (Variable x1 _t1) (Variable x2 _t2) = + case IntMap.lookup (vnIndex x1) vm of + Just i -> vnIndex x2 == i + Nothing -> x1 == x2 + termf _ FTermF{} _ = False + termf _ App{} _ = False + termf _ Lambda{} _ = False + termf _ Pi{} _ = False + termf _ Constant{} _ = False + termf _ Variable{} _ = False + + ftermf :: IntMap Int -> FlatTermF Term -> FlatTermF Term -> Bool + ftermf vm ftf1 ftf2 = + case zipWithFlatTermF (term vm) ftf1 ftf2 of + Nothing -> False + Just ftf3 -> Foldable.and ftf3 instance Ord Term where compare (STApp{stAppIndex = i}) (STApp{stAppIndex = j}) | i == j = EQ @@ -498,100 +489,6 @@ unwrapTermF STApp{stAppTermF = tf} = tf unwrapTermF (Unshared tf) = tf --- Free de Bruijn Variables ---------------------------------------------------- - --- | A @BitSet@ represents a set of natural numbers. --- Bit n is a 1 iff n is in the set. -newtype BitSet = BitSet Integer deriving (Eq, Ord, Show) - --- | The empty 'BitSet' -emptyBitSet :: BitSet -emptyBitSet = BitSet 0 - --- | The singleton 'BitSet' -singletonBitSet :: Int -> BitSet -singletonBitSet = BitSet . bit - --- | Test if a number is in a 'BitSet' -inBitSet :: Int -> BitSet -> Bool -inBitSet i (BitSet j) = testBit j i - --- | Union two 'BitSet's -unionBitSets :: BitSet -> BitSet -> BitSet -unionBitSets (BitSet i1) (BitSet i2) = BitSet (i1 .|. i2) - --- | Intersect two 'BitSet's -intersectBitSets :: BitSet -> BitSet -> BitSet -intersectBitSets (BitSet i1) (BitSet i2) = BitSet (i1 .&. i2) - --- | Decrement all elements of a 'BitSet' by 1, removing 0 if it is in the --- set. This is useful for moving a 'BitSet' out of the scope of a variable. -decrBitSet :: BitSet -> BitSet -decrBitSet (BitSet i) = BitSet (shiftR i 1) - --- | Decrement all elements of a 'BitSet' by some non-negative amount @N@, --- removing any value less than @N@. This is the same as calling 'decrBitSet' --- @N@ times. -multiDecrBitSet :: Int -> BitSet -> BitSet -multiDecrBitSet n (BitSet i) = BitSet (shiftR i n) - --- | The 'BitSet' containing all elements less than a given index @i@ -completeBitSet :: Int -> BitSet -completeBitSet i = BitSet (bit i - 1) - --- | Compute the smallest element of a 'BitSet', if any -smallestBitSetElem :: BitSet -> Maybe Int -smallestBitSetElem (BitSet 0) = Nothing -smallestBitSetElem (BitSet i) | i < 0 = error "smallestBitSetElem" -smallestBitSetElem (BitSet i) = Just $ go 0 i where - go :: Int -> Integer -> Int - go !shft !x - | xw == 0 = go (shft+64) (shiftR x 64) - | otherwise = shft + countTrailingZeros xw - where xw :: Word64 - xw = fromInteger x - --- | Compute the list of all elements of a 'BitSet' -bitSetElems :: BitSet -> [Int] -bitSetElems = go 0 where - -- Return the addition of shft to all elements of a BitSet - go :: Int -> BitSet -> [Int] - go shft bs = case smallestBitSetElem bs of - Nothing -> [] - Just i -> - shft + i : go (shft + i + 1) (multiDecrBitSet (i + 1) bs) - --- | Return the smallest non-negative integer greater than every --- element of the 'BitSet'. -bitSetBound :: BitSet -> Int -bitSetBound b = length $ takeWhile (/= emptyBitSet) $ iterate decrBitSet b - --- | Compute the loose de Bruijn indices of a term given the loose --- indices for its immediate subterms. -looseTermF :: TermF BitSet -> BitSet -looseTermF tf = - case tf of - FTermF ftf -> Foldable.foldl' unionBitSets emptyBitSet ftf - App l r -> unionBitSets l r - Lambda _name tp rhs -> unionBitSets tp (decrBitSet rhs) - Pi _name lhs rhs -> unionBitSets lhs (decrBitSet rhs) - LocalVar i -> singletonBitSet i - Constant {} -> emptyBitSet -- assume type is a closed term - Variable ec -> ecType ec - --- | Return a bitset containing indices of all loose de Bruijn indices. -looseVars :: Term -> BitSet -looseVars STApp{ stAppLooseVars = x } = x -looseVars (Unshared f) = looseTermF (fmap looseVars f) - --- | Compute the value of the smallest variable in the term, if any. -smallestLooseVar :: Term -> Maybe Int -smallestLooseVar = smallestBitSetElem . looseVars - --- | Test whether a 'Term' is closed, i.e., has no loose de Bruijn indices. -termIsClosed :: Term -> Bool -termIsClosed t = looseVars t == emptyBitSet - -- Free Named Variables -------------------------------------------------------- -- | Compute an 'IntSet' containing the 'VarIndex' of the free @@ -602,14 +499,17 @@ freesTermF tf = case tf of FTermF ftf -> Foldable.foldl' IntSet.union IntSet.empty ftf App l r -> IntSet.union l r - Lambda _name tp rhs -> IntSet.union tp rhs - Pi _name lhs rhs -> IntSet.union lhs rhs - LocalVar _ -> IntSet.empty + Lambda nm tp rhs -> IntSet.union tp (IntSet.delete (vnIndex nm) rhs) + Pi nm lhs rhs -> IntSet.union lhs (IntSet.delete (vnIndex nm) rhs) Constant {} -> IntSet.empty - Variable ec -> IntSet.singleton (ecVarIndex ec) + Variable nm tp -> IntSet.insert (vnIndex nm) tp -- | Return an 'IntSet' containing the 'VarIndex' of all free -- variables in the 'Term'. freeVars :: Term -> IntSet freeVars STApp{ stAppFreeVars = s } = s freeVars (Unshared tf) = freesTermF (fmap freeVars tf) + +-- | Test whether a 'Term' is closed, i.e., it has no free variables. +closedTerm :: Term -> Bool +closedTerm t = IntSet.null (freeVars t) diff --git a/saw-core/src/SAWCore/Term/Pretty.hs b/saw-core/src/SAWCore/Term/Pretty.hs index 9ea65fde91..d1bec439f0 100644 --- a/saw-core/src/SAWCore/Term/Pretty.hs +++ b/saw-core/src/SAWCore/Term/Pretty.hs @@ -36,6 +36,7 @@ import Control.Monad.Reader (MonadReader(..), Reader, asks, runReader) import Control.Monad.State.Strict (MonadState(..), State, evalState, execState, get, modify) import qualified Data.Foldable as Fold import Data.Hashable (hash) +import qualified Data.IntSet as IntSet import qualified Data.Text as Text import qualified Data.Map as Map import Data.Set (Set) @@ -111,15 +112,6 @@ data VarNaming = VarNaming [LocalName] (IntMap LocalName) (Set LocalName) emptyVarNaming :: Set LocalName -> VarNaming emptyVarNaming reserved = VarNaming [] IntMap.empty reserved --- | Look up a string to use for a 'DeBruijnIndex', if the first --- argument is 'True', or just print the variable number if the first --- argument is 'False'. -lookupDeBruijn :: Bool -> VarNaming -> DeBruijnIndex -> LocalName -lookupDeBruijn True (VarNaming names _ _) i - | i >= length names = Text.pack ('!' : show (i - length names)) -lookupDeBruijn True (VarNaming names _ _) i = names!!i -lookupDeBruijn False _ i = Text.pack ('!' : show i) - -- | Look up a string to use for a 'VarName'. lookupVarName :: VarNaming -> VarName -> LocalName lookupVarName (VarNaming _ renames _) vn = @@ -146,14 +138,6 @@ nextName = Text.pack . reverse . go . reverse . Text.unpack | isDigit c = succ c : cs go cs = '1' : cs --- | Add a new variable with the given base name to the local variable list, --- returning both the fresh name actually used and the new variable list. As a --- special case, if the base name is "_", it is not modified. -consVarNaming :: VarNaming -> LocalName -> (LocalName, VarNaming) -consVarNaming (VarNaming names renames used) name = - let nm = freshName used name - in (nm, VarNaming (nm : names) renames (Set.insert nm used)) - -- | Add a new variable with the given 'VarName' to the 'VarNaming', -- returning both the chosen fresh name and the new 'VarNaming'. -- As a special case, if the base name is "_", it is not modified. @@ -183,11 +167,10 @@ termVarNames t0 = evalState (go t0) IntMap.empty case tf of FTermF ftf -> Fold.fold ftf App e1 e2 -> Set.union e1 e2 - Lambda _ e1 e2 -> Set.union e1 e2 - Pi _ e1 e2 -> Set.union e1 e2 - LocalVar _ -> Set.empty + Lambda x e1 e2 -> Set.union e1 (Set.delete x e2) + Pi x e1 e2 -> Set.union e1 (Set.delete x e2) Constant _ -> Set.empty - Variable ec -> Set.insert (ecName ec) (ecType ec) + Variable vn e1 -> Set.insert vn e1 -------------------------------------------------------------------------------- -- * Pretty-printing monad @@ -257,12 +240,6 @@ instance MonadReader PPState PPM where ask = PPM ask local f (PPM m) = PPM $ local f m --- | Look up the given local variable by deBruijn index to get its name -varLookupM :: DeBruijnIndex -> PPM LocalName -varLookupM idx = - lookupDeBruijn <$> (PPS.ppShowLocalNames <$> ppOpts <$> ask) - <*> (ppNaming <$> ask) <*> return idx - -- | Test if a given term index is memoized, returning its memoization variable -- if so and otherwise returning 'Nothing' memoLookupM :: TermIndex -> PPM (Maybe MemoVar) @@ -288,10 +265,10 @@ atNextDepthM dflt m = -- also erasing the local memoization table (which is no longer valid in an -- extended variable context) during that computation. Return the result of the -- computation and also the name that was actually used for the bound variable. -withBoundVarM :: LocalName -> PPM a -> PPM (LocalName, a) +withBoundVarM :: VarName -> PPM a -> PPM (LocalName, a) withBoundVarM basename m = do st <- ask - let (var, naming) = consVarNaming (ppNaming st) basename + let (var, naming) = insertVarNaming (ppNaming st) basename ret <- local (\_ -> st { ppNaming = naming, ppLocalMemoTable = IntMap.empty }) m return (var, ret) @@ -494,10 +471,6 @@ ppBitsToHex bits = ] where bits' = Text.pack (show bits) --- | Pretty-print an 'ExtCns' according to the current 'VarNaming'. -ppExtCns :: ExtCns e -> PPM PPS.Doc -ppExtCns ec = ppVarName (ecName ec) - -- | Pretty-print a 'VarName' according to the current 'VarNaming'. ppVarName :: VarName -> PPM PPS.Doc ppVarName vn = @@ -529,10 +502,8 @@ ppTermF prec (Pi x tp body) = ppParensPrec prec PrecLambda <$> (ppPi <$> ppTerm' PrecApp tp <*> ppTermInBinder PrecLambda x body) -ppTermF _ (LocalVar x) = annotate PPS.LocalVarStyle <$> pretty <$> varLookupM x ppTermF _ (Constant nm) = annotate PPS.ConstantStyle <$> ppBestName nm -ppTermF _ (Variable ec) = annotate PPS.ExtCnsStyle <$> ppExtCns ec - +ppTermF _ (Variable vn _tp) = annotate PPS.ExtCnsStyle <$> ppVarName vn -- | Internal function to recursively pretty-print a term ppTerm' :: Prec -> Term -> PPM PPS.Doc @@ -588,6 +559,7 @@ scTermCountAux doBinders = go Lambda _ t1 _ | not doBinders -> [t1] Pi _ t1 _ | not doBinders -> [t1] Constant{} -> [] + Variable{} -> [] FTermF (Recursor _) -> [] tf -> Fold.toList tf @@ -604,7 +576,6 @@ shouldMemoizeTerm t = FTermF (ArrayValue _ v) | V.length v == 0 -> False FTermF StringLit{} -> False Constant{} -> False - LocalVar{} -> False Variable{} -> False _ -> True @@ -626,7 +597,7 @@ filterOccurenceMap min_occs global_p = IntMap.filter (\(t,cnt) -> cnt >= min_occs && shouldMemoizeTerm t && - (if global_p then termIsClosed t else True)) + (if global_p then closedTerm t else True)) -- For each (TermIndex, Term) pair in the occurrence map, pretty-print the @@ -662,11 +633,11 @@ ppLets global_p ((termIdx, (term,_)):idxs) bindings baseDoc = -- -- Also, pretty-print let-bindings around the term for all subterms that occur -- more than once at the same binding level. -ppTermInBinder :: Prec -> LocalName -> Term -> PPM (LocalName, PPS.Doc) -ppTermInBinder prec basename trm = - let nm = if basename == "_" && inBitSet 0 (looseVars trm) then "_x" +ppTermInBinder :: Prec -> VarName -> Term -> PPM (LocalName, PPS.Doc) +ppTermInBinder prec (VarName i basename) trm = + let nm = if basename == "_" && IntSet.member i (freeVars trm) then "_x" else basename in - withBoundVarM nm $ ppTermWithMemoTable prec False trm + withBoundVarM (VarName i nm) $ ppTermWithMemoTable prec False trm -- | Pretty-print a term, also adding let-bindings for all subterms that occur -- more than once at the same binding level @@ -675,7 +646,7 @@ ppTerm opts = ppTermWithNames opts emptyDisplayNameEnv -- | Like 'ppTerm', but also supply a context of bound names, where the most -- recently-bound variable is listed first in the context -ppTermInCtx :: PPS.Opts -> [LocalName] -> Term -> PPS.Doc +ppTermInCtx :: PPS.Opts -> [VarName] -> Term -> PPS.Doc ppTermInCtx opts ctx trm = runPPM opts emptyDisplayNameEnv $ withVarNames (Set.toList (termVarNames trm)) $ @@ -689,7 +660,7 @@ scPrettyTerm opts t = -- | Like 'scPrettyTerm', but also supply a context of bound names, where the -- most recently-bound variable is listed first in the context -scPrettyTermInCtx :: PPS.Opts -> [LocalName] -> Term -> String +scPrettyTermInCtx :: PPS.Opts -> [VarName] -> Term -> String scPrettyTermInCtx opts ctx trm = PPS.render opts $ runPPM opts emptyDisplayNameEnv $ diff --git a/saw-core/src/SAWCore/Typechecker.hs b/saw-core/src/SAWCore/Typechecker.hs index 1c704622d5..da9a699540 100644 --- a/saw-core/src/SAWCore/Typechecker.hs +++ b/saw-core/src/SAWCore/Typechecker.hs @@ -31,7 +31,8 @@ module SAWCore.Typechecker import Control.Monad (forM, forM_, void, unless) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Reader (ReaderT(..), asks, lift, local) -import Data.List (findIndex) +import Data.IntMap (IntMap) +import qualified Data.IntMap as IntMap import Data.Map (Map) import qualified Data.Map as Map import Data.Text (Text) @@ -52,6 +53,7 @@ import SAWCore.Module , DefQualifier(..) ) import qualified SAWCore.Parser.AST as Un +import SAWCore.Name import SAWCore.Parser.Position import SAWCore.Term.Functor import SAWCore.Term.CtxTerm @@ -67,12 +69,12 @@ import Debug.Trace -- empty typing context inferCompleteTerm :: SharedContext -> Maybe ModuleName -> Un.UTerm -> IO (Either PPS.Doc Term) -inferCompleteTerm sc mnm t = inferCompleteTermCtx sc mnm [] t +inferCompleteTerm sc mnm t = inferCompleteTermCtx sc mnm IntMap.empty t -- | Infer the type of an untyped term and complete it to a 'Term' in a given -- typing context inferCompleteTermCtx :: - SharedContext -> Maybe ModuleName -> [(LocalName, Term)] -> + SharedContext -> Maybe ModuleName -> IntMap Term -> Un.UTerm -> IO (Either PPS.Doc Term) inferCompleteTermCtx sc mnm ctx t = do res <- runCheckM (typeInferCompleteUTerm t) sc mnm ctx @@ -94,7 +96,7 @@ data CheckEnv = type CheckM = ReaderT CheckEnv TC.TCM runCheckM :: - CheckM a -> SharedContext -> Maybe ModuleName -> [(LocalName, Term)] -> + CheckM a -> SharedContext -> Maybe ModuleName -> IntMap Term -> IO (Either TC.TCError a) runCheckM m sc mnm ctx = TC.runTCM (runReaderT m (CheckEnv mnm Map.empty)) sc ctx @@ -131,23 +133,19 @@ inferApplyAll t (arg:args) = -- | Resolve a name in the current module and apply it to some arguments inferResolveNameApp :: Text -> [SCTypedTerm] -> CheckM SCTypedTerm inferResolveNameApp n args = - do ctx <- lift $ TC.askCtx - nctx <- askCtxEC + do nctx <- askCtxEC mnm <- getModuleName mm <- lift $ TC.liftTCM scGetModuleMap let ident = mkIdent mnm n - case (findIndex ((== n) . fst) ctx, Map.lookup n nctx, resolveNameInMap mm ident) of - (Just i, _, _) -> - do t <- typeInferComplete (LocalVar i :: TermF SCTypedTerm) + case (Map.lookup n nctx, resolveNameInMap mm ident) of + (Just ec, _) -> + do t <- typeInferComplete (Variable (ecName ec) (ecType ec)) inferApplyAll t args - (_, Just ec, _) -> - do t <- typeInferComplete (Variable ec) - inferApplyAll t args - (_, _, Just rn) -> + (_, Just rn) -> do let c = resolvedNameName rn t <- typeInferComplete (Constant c :: TermF SCTypedTerm) inferApplyAll t args - (Nothing, Nothing, Nothing) -> + (Nothing, Nothing) -> throwTCError $ UnboundName n -- | Match an untyped term as a name applied to 0 or more arguments @@ -226,10 +224,11 @@ typeInferCompleteTerm (Un.Lambda p ((Un.termVarLocalName -> x, tp) : ctx) t) = -- context in withVar, but we do not want to normalize this type in the -- output, as the contract for typeInferComplete only normalizes the type, -- so we use the unnormalized tp_trm in the return - tp_whnf <- lift $ TC.typeCheckWHNF $ typedVal tp_trm - body <- withVar x tp_whnf $ + -- tp_whnf <- lift $ TC.typeCheckWHNF $ typedVal tp_trm + vn <- lift $ TC.liftTCM scFreshVarName x + body <- withVar x (EC vn (typedVal tp_trm)) $ typeInferCompleteUTerm $ Un.Lambda p ctx t - typeInferComplete (Lambda x tp_trm body) + typeInferComplete (Lambda vn tp_trm body) typeInferCompleteTerm (Un.Pi _ [] t) = typeInferCompleteUTerm t typeInferCompleteTerm (Un.Pi p ((Un.termVarLocalName -> x, tp) : ctx) t) = do tp_trm <- typeInferCompleteUTerm tp @@ -238,9 +237,11 @@ typeInferCompleteTerm (Un.Pi p ((Un.termVarLocalName -> x, tp) : ctx) t) = -- output, as the contract for typeInferComplete only normalizes the type, -- so we use the unnormalized tp_trm in the return tp_whnf <- lift $ TC.typeCheckWHNF $ typedVal tp_trm - body <- withVar x tp_whnf $ + vn <- lift $ TC.liftTCM scFreshVarName x + body <- withVar x (EC vn tp_whnf) $ typeInferCompleteUTerm $ Un.Pi p ctx t - typeInferComplete (Pi x tp_trm body) + result <- typeInferComplete (Pi vn tp_trm body) + pure result -- Non-dependent records typeInferCompleteTerm (Un.RecordValue _ elems) = @@ -352,7 +353,8 @@ processDecls (Un.TypeDecl NoQualifier (PosPair p nm) tp : withCtx ctx $ do typed_body <- typeInferCompleteUTerm body lift $ TC.checkSubtype typed_body req_body_tp - lift $ TC.liftTCM scLambdaList ctx (typedVal typed_body) + result <- lift $ TC.liftTCM scAbstractExts (map snd ctx) (typedVal typed_body) + pure result -- Step 4: add the definition to the current module mnm <- getModuleName @@ -445,7 +447,7 @@ processDecls (Un.DataDecl (PosPair p nm) param_ctx dt_tp c_decls : rest) = "Type of that type: " <> Text.pack (showTerm $ typedType typed_tp) ] let tp = typedVal typed_tp - result <- lift $ TC.liftTCM mkCtorArgStruct pn dtParams dtIndices tp + let result = mkCtorArgStruct pn dtParams dtIndices tp case result of Just arg_struct -> lift $ TC.liftTCM scBuildCtor pn (mkIdent mnm c) arg_struct @@ -467,7 +469,7 @@ tcInsertModule sc (Un.Module (PosPair _ mnm) imports decls) = do unless i_exists $ fail $ "Imported module not found: " ++ show imn scImportModule sc (Un.nameSatsConstraint (Un.importConstraints imp) . Text.unpack) imn mnm -- Finally, process all the decls - decls_res <- runCheckM (processDecls decls) sc (Just mnm) [] + decls_res <- runCheckM (processDecls decls) sc (Just mnm) IntMap.empty case decls_res of Left err -> fail $ unlines $ TC.prettyTCError err Right _ -> return () @@ -479,11 +481,11 @@ tcInsertModule sc (Un.Module (PosPair _ mnm) imports decls) = do -- | Pattern match a nested pi-abstraction, like 'asPiList', but only match as -- far as the supplied list of variables, and use them as the new names -matchPiWithNames :: [LocalName] -> Term -> Maybe ([(LocalName, Term)], Term) +matchPiWithNames :: [LocalName] -> Term -> Maybe ([(LocalName, ExtCns Term)], Term) matchPiWithNames [] tp = return ([], tp) -matchPiWithNames (var:vars) (asPi -> Just (_, arg_tp, body_tp)) = +matchPiWithNames (var : vars) (asPi -> Just (nm, arg_tp, body_tp)) = do (ctx,body) <- matchPiWithNames vars body_tp - return ((var,arg_tp):ctx,body) + return ((var, EC nm arg_tp) : ctx,body) matchPiWithNames _ _ = Nothing -- | Run a type-checking computation in a typing context extended with a new @@ -493,25 +495,17 @@ matchPiWithNames _ _ = Nothing -- -- NOTE: the type given for the variable should be in WHNF, so that we do not -- have to normalize the types of variables each time we see them. -withVar :: LocalName -> Term -> CheckM a -> CheckM a -withVar x tp m = ReaderT $ \env -> TC.withVar x tp (runReaderT m env) - -withEC :: LocalName -> ExtCns Term -> CheckM a -> CheckM a -withEC x ec m = +withVar :: LocalName -> ExtCns Term -> CheckM a -> CheckM a +withVar x ec m = TC.rethrowTCError (ErrorCtx x (ecType ec)) $ TC.withEmptyTCState $ local (\env -> env { tcCtxEC = Map.insert x ec (tcCtxEC env) }) m -- | Run a type-checking computation in a typing context extended by a list of -- variables and their types. See 'withVar'. -withCtx :: [(LocalName, Term)] -> CheckM a -> CheckM a +withCtx :: [(LocalName, ExtCns Term)] -> CheckM a -> CheckM a withCtx = flip (foldr (\(x,tp) -> withVar x tp)) --- | Run a type-checking computation in a typing context extended by a list of --- variables and their types. See 'withEC'. -withCtxEC :: [(LocalName, ExtCns Term)] -> CheckM a -> CheckM a -withCtxEC = flip (foldr (\(x,ec) -> withEC x ec)) - -- | Perform type inference on a context, i.e., a list of variable names and -- their associated types. This will give us 'Term's for each type, as -- well as their 'Sort's, since the type of any type is a 'Sort'. @@ -522,14 +516,14 @@ typeInferCompleteCtxEC ((x, tp) : ctx) = do typed_tp <- typeInferCompleteUTerm tp s <- lift $ TC.ensureSort (typedType typed_tp) ec <- lift $ TC.liftTCM scFreshEC x (typedVal typed_tp) - ((x, ec, s) :) <$> withEC x ec (typeInferCompleteCtxEC ctx) + ((x, ec, s) :) <$> withVar x ec (typeInferCompleteCtxEC ctx) -- | Perform type inference on a context via 'typeInferCompleteCtxEC', and then --- run a computation in that context via 'withCtxEC', also passing in that context +-- run a computation in that context via 'withCtx', also passing in that context -- to the computation typeInferCompleteInCtxEC :: [(LocalName, Un.UTerm)] -> ([(LocalName, ExtCns Term, Sort)] -> CheckM a) -> CheckM a typeInferCompleteInCtxEC ctx f = do typed_ctx <- typeInferCompleteCtxEC ctx - withCtxEC (map (\(x,ec,_) -> (x,ec)) typed_ctx) (f typed_ctx) + withCtx (map (\(x,ec,_) -> (x,ec)) typed_ctx) (f typed_ctx)