diff --git a/rewrite/src/main/java/vct/col/rewrite/InlineApplicables.scala b/rewrite/src/main/java/vct/col/rewrite/InlineApplicables.scala index 9d1e5468da..087a6badc8 100644 --- a/rewrite/src/main/java/vct/col/rewrite/InlineApplicables.scala +++ b/rewrite/src/main/java/vct/col/rewrite/InlineApplicables.scala @@ -1,14 +1,16 @@ package vct.col.rewrite +import com.typesafe.scalalogging.LazyLogging import hre.util.ScopedStack import vct.col.ast._ -import vct.col.origin.Origin +import vct.col.origin.{AssertFailed, Blame, FoldFailed, Origin, UnfoldFailed} import vct.col.ref.Ref import vct.col.rewrite.{Generation, NonLatchingRewriter, Rewriter, RewriterBuilder, Rewritten} import vct.col.util.AstBuildHelpers._ import vct.col.util.Substitute import vct.result.VerificationError.{Unreachable, UserError} +import scala.annotation.tailrec import scala.collection.mutable import scala.reflect.ClassTag @@ -68,6 +70,16 @@ case object InlineApplicables extends RewriterBuilder { override def inlineContext: String = s"${definition.inlineContext} [inlined from] ${usages.head.o.inlineContext}" } + case class InlineFoldAssertFailed(fold: Fold[_]) extends Blame[AssertFailed] { + override def blame(error: AssertFailed): Unit = + fold.blame.blame(FoldFailed(error.failure, fold)) + } + + case class InlineUnfoldAssertFailed(unfold: Unfold[_]) extends Blame[AssertFailed] { + override def blame(error: AssertFailed): Unit = + unfold.blame.blame(UnfoldFailed(error.failure, unfold)) + } + case class Replacement[Pre](replacing: Expr[Pre], binding: Expr[Pre])(implicit o: Origin) { val withVariable: Variable[Pre] = new Variable(replacing.t) def +(other: Replacements[Pre]): Replacements[Pre] = Replacements(Seq(this)) + other @@ -106,7 +118,7 @@ case object InlineApplicables extends RewriterBuilder { } } -case class InlineApplicables[Pre <: Generation]() extends Rewriter[Pre] { +case class InlineApplicables[Pre <: Generation]() extends Rewriter[Pre] with LazyLogging { import InlineApplicables._ val inlineStack: ScopedStack[Apply[Pre]] = ScopedStack() @@ -132,6 +144,21 @@ case class InlineApplicables[Pre <: Generation]() extends Rewriter[Pre] { case other => rewriteDefault(other) } + @tailrec + private def isInlinePredicateApply(e: Expr[Pre]): Boolean = e match { + case PredicateApply(Ref(pred), _, _) => pred.inline + case InstancePredicateApply(_, Ref(pred), _, _) => pred.inline + case Scale(_, res) => isInlinePredicateApply(res) + case _ => false + } + + override def dispatch(stat: Statement[Pre]): Statement[Post] = stat match { + case f @ Fold(e) if isInlinePredicateApply(e) => Assert(dispatch(e))(InlineFoldAssertFailed(f))(stat.o) + case u @ Unfold(e) if isInlinePredicateApply(e) => Assert(dispatch(e))(InlineUnfoldAssertFailed(u))(stat.o) + + case other => rewriteDefault(other) + } + override def dispatch(e: Expr[Pre]): Expr[Post] = e match { case apply: ApplyInlineable[Pre] if apply.ref.decl.inline => implicit val o: Origin = apply.o @@ -192,6 +219,16 @@ case class InlineApplicables[Pre <: Generation]() extends Rewriter[Pre] { } } + case Unfolding(PredicateApply(Ref(pred), args, perm), body) if pred.inline => + With(Block(args.map(dispatch).map(e => Eval(e)(e.o)) :+ Eval(dispatch(perm))(perm.o))(e.o), dispatch(body))(e.o) + + case Unfolding(InstancePredicateApply(obj, Ref(pred), args, perm), body) if pred.inline => + With(Block( + Seq(Eval(dispatch(obj))(obj.o)) ++ + args.map(dispatch).map(e => Eval(e)(e.o)) ++ + Seq(Eval(dispatch(perm))(perm.o)) + )(e.o), dispatch(body))(e.o) + case other => rewriteDefault(other) } } diff --git a/rewrite/src/main/java/vct/col/rewrite/ResolveScale.scala b/rewrite/src/main/java/vct/col/rewrite/ResolveScale.scala index 645a0555c3..c7a772961e 100644 --- a/rewrite/src/main/java/vct/col/rewrite/ResolveScale.scala +++ b/rewrite/src/main/java/vct/col/rewrite/ResolveScale.scala @@ -66,6 +66,8 @@ case class ResolveScale[Pre <: Generation]() extends Rewriter[Pre] { case Select(cond, whenTrue, whenFalse) => Select(dispatch(cond), scale(whenTrue, amount), scale(whenFalse, amount)) case s: Starall[Pre] => s.rewrite(body = scale(s.body, amount)) + case l: Let[Pre] => l.rewrite(main = scale(l.main, amount)) + case other => throw WrongScale(other) } } diff --git a/rewrite/src/main/java/vct/col/rewrite/SingletonStarall.scala b/rewrite/src/main/java/vct/col/rewrite/SingletonStarall.scala deleted file mode 100644 index 2693d9f78e..0000000000 --- a/rewrite/src/main/java/vct/col/rewrite/SingletonStarall.scala +++ /dev/null @@ -1,49 +0,0 @@ -package vct.col.rewrite -import vct.col.ast._ -import vct.col.origin.{Blame, Origin, ReceiverNotInjective} -import vct.col.rewrite.SingletonStarall.UnknownStarallFormat -import vct.col.util.AstBuildHelpers.{ExprBuildHelpers, foldAnd, foldStar} -import vct.result.VerificationError.SystemError - -case object SingletonStarall extends RewriterBuilder { - override def key: String = "singletonStarall" - override def desc: String = "Convert multi-resource staralls to single-resource staralls" - - case class UnknownStarallFormat(e: Expr[_]) extends SystemError { - override def text: String = e.o.messageInContext("This expression should probably not occur in a ∀*") - } -} - -/** - * Starall's that have multiple resources are expanded into several starall's with one resource. This is mostly already - * done by the simplifier, but some passes after the simplifier may introduce them. - */ -case class SingletonStarall[Pre <: Generation]() extends Rewriter[Pre] { - def expand(bindings: Seq[Variable[Pre]], - triggers: Seq[Seq[Expr[Pre]]], - blame: Blame[ReceiverNotInjective], - conds: Seq[Expr[Pre]], - body: Expr[Pre])(implicit o: Origin): Seq[Expr[Post]] = body match { - case Perm(_, _) | Value(_) | PredicateApply(_, _, _) => variables.scope { Seq( - Starall(variables.dispatch(bindings), triggers.map(_.map(dispatch)), foldAnd(conds.map(dispatch)) ==> dispatch(body))(blame) - ) } - case body if body.t == TBool[Pre]() => variables.scope { Seq( - Forall(variables.dispatch(bindings), triggers.map(_.map(dispatch)), foldAnd(conds.map(dispatch)) ==> dispatch(body)) - ) } - - case Star(left, right) => expand(bindings, triggers, blame, conds, left) ++ expand(bindings, triggers, blame, conds, right) - case Implies(cond, res) => - expand(bindings, triggers, blame, conds :+ cond, res) - case Select(cond, left, right) => - expand(bindings, triggers, blame, conds :+ cond, left) ++ expand(bindings, triggers, blame, conds :+ !cond, right) - case Starall(moreBindings, _, body) => - expand(bindings ++ moreBindings, triggers, blame, conds, body) - case other => throw UnknownStarallFormat(other) - } - - override def dispatch(e: Expr[Pre]): Expr[Post] = e match { - case starall @ Starall(bindings, triggers, body) => - foldStar(expand(bindings, triggers, starall.blame, Nil, body)(e.o))(e.o) - case other => rewriteDefault(other) - } -} diff --git a/src/main/java/vct/main/stages/Transformation.scala b/src/main/java/vct/main/stages/Transformation.scala index bcaf8ccb20..f5354af285 100644 --- a/src/main/java/vct/main/stages/Transformation.scala +++ b/src/main/java/vct/main/stages/Transformation.scala @@ -244,7 +244,6 @@ case class SilverTransformation ForLoopToWhileLoop, BranchToIfElse, EvaluationTargetDummy, - SingletonStarall, // Final translation to rigid silver nodes SilverIntRatCoercion, diff --git a/viper/src/main/java/viper/api/transform/ColToSilver.scala b/viper/src/main/java/viper/api/transform/ColToSilver.scala index e1a9c82a96..ae6b6be9e1 100644 --- a/viper/src/main/java/viper/api/transform/ColToSilver.scala +++ b/viper/src/main/java/viper/api/transform/ColToSilver.scala @@ -281,9 +281,15 @@ case class ColToSilver(program: col.Program[_]) { case col.Forall(bindings, triggers, body) => scoped { silver.Forall(bindings.map(variable), triggers.map(trigger), exp(body))(pos=pos(e), info=expInfo(e)) } case starall @ col.Starall(bindings, triggers, body) => - scoped { currentStarall.having(starall) { - silver.Forall(bindings.map(variable), triggers.map(trigger), exp(body))(pos=pos(e), info=expInfo(e)) - } } + scoped { + currentStarall.having(starall) { + val foralls: Seq[silver.Forall] = silver.utility.QuantifiedPermissions.desugarSourceQuantifiedPermissionSyntax( + silver.Forall(bindings.map(variable), triggers.map(trigger), exp(body))(pos=pos(e), info=expInfo(e)) + ) + + foralls.reduce[silver.Exp] { case (l, r) => silver.And(l, r)(pos=pos(e), info=expInfo(e)) } + } + } case col.Let(binding, value, main) => scoped { silver.Let(variable(binding), exp(value), exp(main))(pos=pos(e), info=expInfo(e)) } case col.Not(arg) => silver.Not(exp(arg))(pos=pos(e), info=expInfo(e))