Skip to content

Commit

Permalink
Add Provenance feature for Type guards (#265)
Browse files Browse the repository at this point in the history
Co-authored-by: Jihyeok Park <jihyeok_park@korea.ac.kr>
  • Loading branch information
kimjg1119 and jhnaldo authored Nov 21, 2024
1 parent 6208b60 commit 0dfdf87
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 8 deletions.
53 changes: 46 additions & 7 deletions src/main/scala/esmeta/analyzer/tychecker/AbsTransfer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ trait AbsTransferDecl { analyzer: TyChecker =>
val NodePoint(func, node, view) = np
node match
case Block(_, insts, next) =>
val newSt = insts.foldLeft(st) {
val newSt = insts.zipWithIndex.foldLeft(st) {
case (nextSt, _) if nextSt.isBottom => nextSt
case (nextSt, inst) => transfer(inst)(nextSt)
case (nextSt, (inst, idx)) => transfer(inst, idx)(nextSt)
}
next.foreach(to => analyzer += getNextNp(np, to) -> newSt)
case call: Call =>
Expand All @@ -48,15 +48,36 @@ trait AbsTransferDecl { analyzer: TyChecker =>
} yield ())(st)
call.next.foreach(to => analyzer += getNextNp(np, to) -> newSt)
case br @ Branch(_, kind, c, thenNode, elseNode) =>
import RefinementTarget.*
import RefinementKind.*
(for { v <- transfer(c); newSt <- get } yield {
if (v.ty.bool.contains(true))
val refinedSt = refine(c, v, true)(newSt)
thenNode.map(analyzer += getNextNp(np, _) -> refinedSt)
val rst = refine(c, v, true)(newSt)
val pred = v.guard.get(True)
if (detail) logRefined(BranchTarget(br, true), pred, newSt, rst)
thenNode.map(analyzer += getNextNp(np, _) -> rst)
if (v.ty.bool.contains(false))
val refinedSt = refine(c, v, false)(newSt)
elseNode.map(analyzer += getNextNp(np, _) -> refinedSt)
val rst = refine(c, v, false)(newSt)
val pred = v.guard.get(False)
if (detail) logRefined(BranchTarget(br, false), pred, newSt, rst)
elseNode.map(analyzer += getNextNp(np, _) -> rst)
})(st)

def logRefined(
target: RefinementTarget,
pred: Option[SymPred],
st: AbsState,
refinedSt: AbsState,
): Unit =
val xs = for {
(x, v) <- st.locals
ty = v.ty(using st)
refinedTy = refinedSt.get(x).ty(using refinedSt)
if refinedTy != ty
} yield x
if (xs.isEmpty) refined -= target
else refined += target -> (xs.toSet, pred.fold(0)(_.depth))

/** refine with an expression and its abstract value */
def refine(
expr: Expr,
Expand Down Expand Up @@ -332,6 +353,12 @@ trait AbsTransferDecl { analyzer: TyChecker =>
/** transfer function for normal instructions */
def transfer(
inst: NormalInst,
)(using np: NodePoint[_]): Updater = transfer(inst, -1)

/** transfer function for normal instructions */
def transfer(
inst: NormalInst,
idx: Int,
)(using np: NodePoint[_]): Updater = inst match {
case IExpr(expr) =>
for {
Expand Down Expand Up @@ -389,8 +416,20 @@ trait AbsTransferDecl { analyzer: TyChecker =>
case IAssert(expr) =>
for {
v <- transfer(expr)
st <- get
pred = v.guard.get(RefinementKind.True)
_ <- modify(refine(expr, v, true))
given AbsState <- get
refinedSt <- get
given AbsState = refinedSt
_ = if (detail) np.node match
case block: Block =>
logRefined(
RefinementTarget.AssertTarget(block, idx),
pred,
st,
refinedSt,
)
case _ =>
_ <- if (v False) put(AbsState.Bot) else pure(())
} yield ()
case IPrint(expr) => st => st /* skip */
Expand Down
34 changes: 34 additions & 0 deletions src/main/scala/esmeta/analyzer/tychecker/TyChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ class TyChecker(
"returns" -> ratioSimpleString(analyzedReturns.size, cfg.funcs.size),
),
)
if (detail)
info :+= "refined" -> Map(
"targets" -> refinedTargets,
"locals" -> refinedLocals,
"avg. depth" -> refinedAvgDepth,
)
if (inferTypeGuard) info :+= "guards" -> typeGuards.size
Yaml(info: _*)
},
Expand Down Expand Up @@ -149,6 +155,9 @@ class TyChecker(
cfg.funcs.size, // total funcs
this.analyzedNodes.size, // analyzed nodes
cfg.nodes.size, // total nodes
if (detail) refinedTargets else 0, // refined targets
if (detail) refinedLocals else 0, // refined locals
if (detail) refinedAvgDepth else 0, // refined avg. depth
if (inferTypeGuard) typeGuards.size else 0, // guards
).mkString("\t"),
filename = s"$ANALYZE_LOG_DIR/summary",
Expand Down Expand Up @@ -204,6 +213,12 @@ class TyChecker(
filename = s"$ANALYZE_LOG_DIR/detailed-types",
silent = silent,
)
dumpFile(
name = "refined targets",
data = refinedString,
filename = s"$ANALYZE_LOG_DIR/refined",
silent = silent,
)
if (inferTypeGuard)
val names = typeGuards.map(_._1.name).toSet
dumpFile(
Expand All @@ -217,6 +232,25 @@ class TyChecker(
)
}

/** refined targets */
var refined: Map[RefinementTarget, (Set[Local], Int)] = Map()
def refinedTargets: Int = refined.size
def refinedLocals: Int = refined.values.map(_._1.size).sum
def refinedAvgDepth: Double =
refined.values.map(_._2).sum.toDouble / refined.size
def refinedString: String =
given Rule[Map[RefinementTarget, (Set[Local], Int)]] =
(app, refined) =>
val sorted = refined.toList.sortBy { (t, _) => t }
for ((target, (xs, depth)) <- sorted)
app >> target >> " -> "
app >> xs.toList.sorted.mkString("[locals: ", ", ", "]")
app >> " [depth: " >> depth >> "]"
app >> LINE_SEP
app
(new Appender >> refined).toString

/** inferred type guards */
def getTypeGuards: List[(Func, AbsValue)] = for {
func <- cfg.funcs
entrySt = getResult(NodePoint(func, func.entry, emptyView))
Expand Down
31 changes: 30 additions & 1 deletion src/main/scala/esmeta/analyzer/tychecker/TypeGuard.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package esmeta.analyzer.tychecker

import esmeta.cfg.{Func, Call}
import esmeta.cfg.*
import esmeta.ir.{Func => _, *}
import esmeta.ty.*
import esmeta.ty.util.{Stringifier => TyStringifier}
Expand Down Expand Up @@ -63,6 +63,15 @@ trait TypeGuardDecl { self: TyChecker =>
def apply(ps: (RefinementKind, SymPred)*): TypeGuard = TypeGuard(ps.toMap)
}

/** type refinement target */
enum RefinementTarget:
case BranchTarget(branch: Branch, isTrue: Boolean)
case AssertTarget(block: Block, idx: Int)
def node: Node = this match
case BranchTarget(branch, _) => branch
case AssertTarget(block, _) => block
def func: Func = cfg.funcOf(node)

/** type refinement kinds */
enum RefinementKind {
case True, False, Normal, Abrupt, NormalTrue, NormalFalse
Expand Down Expand Up @@ -97,6 +106,7 @@ trait TypeGuardDecl { self: TyChecker =>
/** Symbol */
type Sym = Int
case class Provenance(map: Map[Func, List[Call]] = Map()) {
def depth: Int = map.values.map(_.length).max
def join(that: Provenance): Provenance = Provenance((for {
key <- (this.map.keySet union that.map.keySet).toList
calls <- (this.map.get(key), that.map.get(key)) match
Expand Down Expand Up @@ -165,6 +175,9 @@ trait TypeGuardDecl { self: TyChecker =>
} yield x -> (origTy && ty, prov),
sexpr = None,
)
def depth: Int =
val provs = map.values.map(_._2).toList
sexpr.fold(provs)(_._2 :: provs).map(_.depth).max
override def toString: String = (new Appender >> this).toString
}
object SymPred {
Expand Down Expand Up @@ -356,6 +369,22 @@ trait TypeGuardDecl { self: TyChecker =>
given Rule[TypeGuard] = (app, guard) =>
given Rule[Map[RefinementKind, SymPred]] = sortedMapRule("{", "}", " => ")
app >> guard.map
given Rule[RefinementTarget] = (app, target) =>
import RefinementTarget.*
val node = target.node
val func = target.func
app >> func.nameWithId >> ":" >> node.name >> ":"
target match
case BranchTarget(branch, isTrue) =>
app >> (if (isTrue) "T" else "F")
case AssertTarget(block, idx) =>
app >> idx
given Ordering[RefinementTarget] = Ordering.by { target =>
import RefinementTarget.*
target match
case BranchTarget(branch, isTrue) => (branch.id, if (isTrue) 1 else 0)
case AssertTarget(block, idx) => (block.id, idx)
}
given Rule[RefinementKind] = (app, kind) =>
import RefinementKind.*
kind match
Expand Down

0 comments on commit 0dfdf87

Please sign in to comment.