Skip to content

Commit

Permalink
More impl. of refinement
Browse files Browse the repository at this point in the history
  • Loading branch information
jhnaldo committed Oct 2, 2024
1 parent e279372 commit f6aecad
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 54 deletions.
17 changes: 12 additions & 5 deletions src/main/scala/esmeta/analyzer/tychecker/AbsState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,21 @@ trait AbsStateDecl { self: TyChecker =>
case _ if this.isBottom => that
case _ if that.isBottom => this
case (l, r) =>
val newLocals = (for {
x <- (l.locals.keySet ++ r.locals.keySet).toList
v = l.get(x) ⊔ r.get(x)
} yield x -> v).toMap
var killed = Set[Sym]()
def handleKilled(v: AbsValue)(using AbsState): AbsValue =
if (killed.exists(v.has)) AbsValue(v.ty, Many, v.guard)
else v
val newSymEnv = (for {
sym <- (l.symEnv.keySet ++ r.symEnv.keySet).toList
ty = l.getTy(sym) ⊔ r.getTy(sym)
lty = l.getTy(sym)
rty = r.getTy(sym)
_ = if (lty != rty) killed += sym
ty = lty || rty
} yield sym -> ty).toMap
val newLocals = (for {
x <- (l.locals.keySet ++ r.locals.keySet).toList
v = handleKilled(l.get(x))(using l) ⊔ handleKilled(r.get(x))(using r)
} yield x -> v).toMap
AbsState(true, newLocals, newSymEnv)

/** meet operator */
Expand Down
167 changes: 138 additions & 29 deletions src/main/scala/esmeta/analyzer/tychecker/AbsTransfer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,19 @@ trait AbsTransferDecl { analyzer: TyChecker =>
} yield ())(st)
call.next.foreach(to => analyzer += getNextNp(np, to) -> newSt)
case br @ Branch(_, kind, c, thenNode, elseNode) =>
import RefinementKind.*
(for { v <- transfer(c); newSt <- get } yield {
if (v.ty.bool.contains(true))
thenNode.map(
analyzer += getNextNp(np, _) -> refine(c, true)(newSt),
)
val refinedSt = v.guard.get(True) match
case Some(pred) => refine(pred, true)(newSt)
case None => if (useTypeGuard) newSt else refine(c, true)(newSt)
thenNode.map(analyzer += getNextNp(np, _) -> refinedSt)
if (v.ty.bool.contains(false))
elseNode.map(
analyzer += getNextNp(np, _) -> refine(c, false)(newSt),
)
val refinedSt = v.guard.get(False) match
case Some(pred) => refine(pred, true)(newSt)
case None =>
if (useTypeGuard) newSt else refine(c, false)(newSt)
elseNode.map(analyzer += getNextNp(np, _) -> refinedSt)
})(st)

/** get next node point */
Expand Down Expand Up @@ -131,26 +135,6 @@ trait AbsTransferDecl { analyzer: TyChecker =>
doCall(callPoint, st, args, ast :: vs, method = true)

newV
// TODO bv.getSingle match
// case One(AstValue(syn: Syntactic)) =>
// getSdo((syn, method)) match
// case Some((ast, sdo)) =>
// val callPoint = CallPoint(callerNp, sdo)
// val astV = AbsValue(ast)
// doCall(callPoint, st, args, astV :: vs, method = true)
// case None => error("invalid sdo")
// case One(AstValue(lex: Lexical)) =>
// newV ⊔= AbsValue(Interpreter.eval(lex, method))
// case Many =>
// // lexical sdo
// newV ⊔= bv.getLexical(method)

// // syntactic sdo
// for ((sdo, ast) <- bv.getSdo(method))
// val callPoint = CallPoint(callerNp, sdo)
// doCall(callPoint, st, args, ast :: vs, method = true)
// case _ => /* do nothing */
// newV
}
}
}
Expand All @@ -175,7 +159,6 @@ trait AbsTransferDecl { analyzer: TyChecker =>
}.toMap
val newRetV = (for {
refine <- typeGuards.get(callee.name)
if useTypeGuard || defaultTypeGuards.contains(callee.name)
v = refine(vs, retTy, callerSt)
guard = for {
(kind, pred) <- v.guard
Expand Down Expand Up @@ -369,6 +352,7 @@ trait AbsTransferDecl { analyzer: TyChecker =>
def transfer(
expr: Expr,
)(using np: NodePoint[Node]): Result[AbsValue] = expr match {
case Refiner(v) => v
case EParse(code, rule) =>
for {
c <- transfer(code)
Expand Down Expand Up @@ -534,6 +518,125 @@ trait AbsTransferDecl { analyzer: TyChecker =>
case ECodeUnit(c) => AbsValue(CodeUnitT)
}

object Refiner {
import RefinementKind.*, SymExpr.*, SymRef.*
def apply(
expr: Expr,
)(using np: NodePoint[_]): Option[Result[AbsValue]] = unapply(expr)

def unapply(
expr: Expr,
)(using np: NodePoint[_]): Option[Result[AbsValue]] = {
if (!useTypeGuard) None
else
expr match {
case EBinary(BOp.Eq, ERef(x: Local), expr) =>
Some(for {
lv <- transfer(x)
rv <- transfer(expr)
given AbsState <- get
} yield {
val lty = lv.ty
val rty = rv.ty
val thenTy = lty && rty
val elseTy = if (rty.isSingle) lty -- rty else lty
val expr = lv.expr match
case One(expr) => expr
case _ => SERef(SLocal(x))
var guard: TypeGuard = Map()
if (lty != thenTy) guard += True -> SETypeCheck(expr, thenTy)
if (lty != elseTy) guard += False -> SETypeCheck(expr, elseTy)
AbsValue(BoolT, Many, guard)
})
case ETypeCheck(ERef(x: Local), givenTy) =>
Some(for {
lv <- transfer(x)
given AbsState <- get
} yield {
val lty = lv.ty
val rty = givenTy.toValue
val thenTy = lty && rty
val elseTy = lty -- rty
val expr = lv.expr match
case One(expr) => expr
case _ => SERef(SLocal(x))
var guard: TypeGuard = Map()
if (lty != thenTy) guard += True -> SETypeCheck(expr, thenTy)
if (lty != elseTy) guard += False -> SETypeCheck(expr, elseTy)
AbsValue(BoolT, Many, guard)
})
case EExists(Field(x: Local, EStr(field))) =>
Some(for {
lv <- transfer(x)
given AbsState <- get
} yield {
val lty = lv.ty
def aux(binding: Binding) = ValueTy(
ast = lty.ast,
record = lty.record.update(field, binding, refine = true),
)
val binding = Binding.Exist
val thenTy = aux(binding)
val elseTy = aux(lty.record(field) -- binding)
val expr = lv.expr match
case One(expr) => expr
case _ => SERef(SLocal(x))
var guard: TypeGuard = Map()
if (lty != thenTy) guard += True -> SETypeCheck(expr, thenTy)
if (lty != elseTy) guard += False -> SETypeCheck(expr, elseTy)
AbsValue(BoolT, Many, guard)
})
case EBinary(BOp.Eq, ETypeOf(ERef(x: Local)), expr) =>
Some(for {
lv <- transfer(x)
rv <- transfer(expr)
given AbsState <- get
} yield {
val lty = lv.ty
val rty = rv.ty
def aux(positive: Boolean): ValueTy = rty.str.getSingle match
case One(tname) =>
val vty = ValueTy.fromTypeOf(tname)
if (positive) lty && vty else lty -- vty
case _ => lty
val thenTy = aux(true)
val elseTy = aux(false)
val expr = lv.expr match
case One(expr) => expr
case _ => SERef(SLocal(x))
var guard: TypeGuard = Map()
if (lty != thenTy) guard += True -> SETypeCheck(expr, thenTy)
if (lty != elseTy) guard += False -> SETypeCheck(expr, elseTy)
AbsValue(BoolT, Many, guard)
})
case EBinary(BOp.And, l, r) =>
Some(for {
lv <- transfer(l)
st <- get
lguard = lv.guard
lt = lguard.get(True)
lf = lguard.get(False)
} yield {
var guard: TypeGuard = Map()
val refinedSt = lt.fold(st)(refine(_, true)(st))
println(refinedSt)
val (thenPred, _) = (for {
rv <- transfer(r)
rt = rv.guard.get(True)
} yield lt && rt)(refinedSt)
thenPred.map { guard += True -> _ }
val (elsePred, _) = (for {
rv <- transfer(r)
rf = rv.guard.get(False)
} yield lf || rf)(refinedSt)
elsePred.map { guard += False -> _ }
AbsValue(BoolT, Many, guard)
})
case _ => None
}
}
}

/** transfer function for references */
def transfer(
ref: Ref,
Expand Down Expand Up @@ -823,7 +926,10 @@ trait AbsTransferDecl { analyzer: TyChecker =>
)(using np: NodePoint[_]): Updater =
import SymExpr.*, SymRef.*
ref match
case SSym(sym) => ???
case SSym(sym) =>
st =>
val refinedTy = st.symEnv.get(sym).fold(ty)(_ && ty)
st.copy(symEnv = st.symEnv + (sym -> refinedTy))
case SLocal(x) =>
for {
v <- transfer(x)
Expand Down Expand Up @@ -1019,7 +1125,7 @@ trait AbsTransferDecl { analyzer: TyChecker =>
/** check if the return type can be used */
private lazy val canUseReturnTy: Func => Boolean = cached { func =>
!func.retTy.isImprec ||
(useTypeGuard && typeGuards.contains(func.name)) ||
typeGuards.contains(func.name) ||
defaultTypeGuards.contains(func.name)
}

Expand Down Expand Up @@ -1077,14 +1183,17 @@ trait AbsTransferDecl { analyzer: TyChecker =>
},
"IteratorClose" -> { (vs, retTy, st) =>
given AbsState = st
// Throw | #1
AbsValue(vs(1).ty || ThrowT, Zero, Map())
},
"AsyncIteratorClose" -> { (vs, retTy, st) =>
given AbsState = st
// Throw | #1
AbsValue(vs(1).ty || ThrowT, Zero, Map())
},
"OrdinaryObjectCreate" -> { (vs, retTy, st) =>
given AbsState = st
// Object
AbsValue(RecordT("Object"), Zero, Map())
},
"UpdateEmpty" -> { (vs, retTy, st) =>
Expand Down
26 changes: 15 additions & 11 deletions src/main/scala/esmeta/analyzer/tychecker/AbsValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,22 +95,26 @@ trait AbsValueDecl { self: TyChecker =>
case _ => None
}.toMap
(lexpr, rexpr) match
case (Zero, Zero) => AbsValue(llty && rlty, Zero, guard)
case (Zero, One(r)) => AbsValue(llty && rlty, One(r), guard)
case (One(l), Zero) => AbsValue(llty && rlty, One(l), guard)
case (One(l), One(r)) if l == r =>
AbsValue(llty && rlty, One(l), guard)
case (One(_), One(_)) => AbsValue(l.ty && r.ty, Many, guard)
case (One(_), Many) => AbsValue(l.ty && rlty, Many, guard)
case (Many, One(_)) => AbsValue(llty && r.ty, Many, guard)
case (Zero, Many) => AbsValue(llty && rlty, Many, guard)
case (Many, Zero) => AbsValue(llty && rlty, Many, guard)
case (Many, Many) => AbsValue(llty && rlty, Many, guard)
case (Zero, Zero) => AbsValue(llty && rlty, Zero, guard)
case (Zero, One(r)) => AbsValue(llty && rlty, One(r), guard)
case (One(l), Zero) => AbsValue(llty && rlty, One(l), guard)
case (One(l), One(r)) if l == r => AbsValue(llty && rlty, One(l), guard)
case (One(_), One(_)) => AbsValue(l.ty && r.ty, Many, guard)
case (One(_), Many) => AbsValue(l.ty && rlty, Many, guard)
case (Many, One(_)) => AbsValue(llty && r.ty, Many, guard)
case (Zero, Many) => AbsValue(llty && rlty, Many, guard)
case (Many, Zero) => AbsValue(llty && rlty, Many, guard)
case (Many, Many) => AbsValue(llty && rlty, Many, guard)

/** prune operator */
def --(that: AbsValue)(using AbsState): AbsValue =
this.copy(lowerTy = this.lowerTy -- that.lowerTy)

/** has symbols */
def has(sym: Sym): Boolean = expr match
case One(expr) => expr.has(sym)
case _ => false

/** get lexical result */
def getLexical(method: String)(using AbsState): AbsValue = {
val ty = this.ty
Expand Down
19 changes: 10 additions & 9 deletions src/main/scala/esmeta/analyzer/tychecker/TyChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,16 @@ class TyChecker(
): AbsState =
import SymExpr.*, SymRef.*
given AbsState = callerSt
val idxLocals = locals.zipWithIndex
val (newLocals, symEnv) = (for {
((x, value), sym) <- idxLocals
} yield (
x -> value,
// TODO x -> AbsValue(BotT, One(SERef(SSym(sym))), Map()),
sym -> value.ty,
)).unzip
callerSt.copy(locals = newLocals.toMap) // TODO , symEnv = symEnv.toMap)
if (useTypeGuard) {
val idxLocals = locals.zipWithIndex
val (newLocals, symEnv) = (for {
((x, value), sym) <- idxLocals
} yield (
x -> AbsValue(BotT, One(SERef(SSym(sym))), Map()),
sym -> value.ty,
)).unzip
callerSt.copy(locals = newLocals.toMap, symEnv = symEnv.toMap)
} else callerSt.copy(locals = locals.toMap)

/** get initial abstract states in each node point */
private def getInitNpMap(
Expand Down
5 changes: 5 additions & 0 deletions src/main/scala/esmeta/ty/ValueTy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ sealed trait ValueTy extends Ty with Lattice[ValueTy] {
(if (this.undef.isBottom) Zero else One(Undef)) ||
(if (this.nullv.isBottom) Zero else One(Null))

/** single value check */
def isSingle: Boolean = getSingle match
case One(_) => true
case _ => false

/** types having no field */
def noField: ValueTy = this match
case ValueTopTy =>
Expand Down
30 changes: 30 additions & 0 deletions src/main/scala/esmeta/ty/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,45 @@ enum SymExpr:
case SETypeCheck(base: SymExpr, ty: ValueTy)
case SEBinary(bop: BOp, left: SymExpr, right: SymExpr)
case SEUnary(uop: UOp, expr: SymExpr)
def &&(that: SymExpr): SymExpr = (this, that) match
case (SEBool(true), _) => that
case (_, SEBool(true)) => this
case (SEBool(false), _) | (_, SEBool(false)) => SEBool(false)
case _ => SEBinary(BOp.And, this, that)
def ||(that: SymExpr): SymExpr = (this, that) match
case (SEBool(false), _) => that
case (_, SEBool(false)) => this
case (SEBool(true), _) | (_, SEBool(true)) => SEBool(true)
case _ => SEBinary(BOp.Or, this, that)
def has(sym: Sym): Boolean = this match
case SEBool(b) => false
case SEStr(s) => false
case SERef(ref) => ref.has(sym)
case SETypeCheck(base, ty) => base.has(sym)
case SEBinary(bop, left, right) => left.has(sym) || right.has(sym)
case SEUnary(uop, expr) => expr.has(sym)
object SymExpr:
val T: SymExpr = SEBool(true)
val F: SymExpr = SEBool(false)
extension (l: Option[SymExpr])
def &&(r: Option[SymExpr]): Option[SymExpr] = (l, r) match
case (Some(l), Some(r)) => Some(l && r)
case (Some(l), None) => Some(l)
case (None, Some(r)) => Some(r)
case _ => None
def ||(r: Option[SymExpr]): Option[SymExpr] = (l, r) match
case (Some(l), Some(r)) => Some(l || r)
case _ => None

/** symbolic references */
enum SymRef:
case SSym(sym: Sym)
case SLocal(x: Local)
case SField(base: SymRef, field: SymExpr)
def has(sym: Sym): Boolean = this match
case SSym(s) => s == sym
case SLocal(x) => false
case SField(base, f) => base.has(sym) || f.has(sym)

/** type guard */
type TypeGuard = Map[RefinementKind, SymExpr]
Expand Down

0 comments on commit f6aecad

Please sign in to comment.