Skip to content

Commit

Permalink
push through declarative bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Apr 22, 2024
1 parent d3760b9 commit 5a13d5a
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 197 deletions.
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ abstract class BaseIR {

def forEachChildWithEnv[E <: GenericBindingEnv[E, Type]](env: E)(f: (BaseIR, E) => Unit): Unit =
childrenSeq.view.zipWithIndex.foreach { case (child, i) =>
val childEnv = Bindings.get(this, i, env)
val childEnv = env.extend(Bindings.get(this, i))
f(child, childEnv)
}

Expand All @@ -73,7 +73,7 @@ abstract class BaseIR {
val newChildren = childrenSeq.toArray
var res = this
for (i <- newChildren.indices) {
val childEnv = Bindings.get(res, i, env)
val childEnv = env.extend(Bindings.get(res, i))
val child = newChildren(i)
val newChild = f(child, childEnv)
if (!(newChild eq child)) {
Expand All @@ -90,7 +90,7 @@ abstract class BaseIR {
f: (BaseIR, Int, E) => StackFrame[Unit]
): StackFrame[Unit] =
childrenSeq.view.zipWithIndex.foreachRecur { case (child, i) =>
val childEnv = Bindings.get(this, i, env)
val childEnv = env.extend(Bindings.get(this, i))
f(child, i, childEnv)
}
}
137 changes: 29 additions & 108 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,142 +7,63 @@ import is.hail.utils.FastSeq

import scala.collection.mutable

object SegregatedBindingEnv {
def apply[A, B](env: BindingEnv[A]): SegregatedBindingEnv[A, B] =
SegregatedBindingEnv(env, env.dropBindings)
object Binds {
def apply(x: IR, v: String, i: Int): Boolean =
Bindings.get(x, i).eval.exists(_._1 == v)
}

case class SegregatedBindingEnv[A, B](
childEnvWithoutBindings: BindingEnv[A],
newBindings: BindingEnv[B],
) extends GenericBindingEnv[SegregatedBindingEnv[A, B], B] {
def newBlock(
eval: Seq[(String, B)] = Seq.empty,
agg: AggEnv[B] = AggEnv.NoOp,
scan: AggEnv[B] = AggEnv.NoOp,
relational: Seq[(String, B)] = Seq.empty,
dropEval: Boolean = false,
): SegregatedBindingEnv[A, B] =
SegregatedBindingEnv(
childEnvWithoutBindings.newBlock(agg = agg.empty, scan = scan.empty, dropEval = dropEval),
newBindings.newBlock(eval, agg, scan, relational, dropEval),
)

def unified(implicit ev: BindingEnv[B] =:= BindingEnv[A]): BindingEnv[A] =
childEnvWithoutBindings.merge(newBindings)

def mapNewBindings[C](f: (String, B) => C): SegregatedBindingEnv[A, C] = SegregatedBindingEnv(
childEnvWithoutBindings,
newBindings.mapValuesWithKey(f),
)

override def promoteAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.promoteAgg,
newBindings.promoteAgg,
)

override def promoteScan: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.promoteScan,
newBindings.promoteScan,
)

override def bindEval(bindings: (String, B)*): SegregatedBindingEnv[A, B] =
copy(newBindings = newBindings.bindEval(bindings: _*))

override def noEval: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.copy(eval = Env.empty),
newBindings.copy(eval = Env.empty),
)

override def bindAgg(bindings: (String, B)*): SegregatedBindingEnv[A, B] =
copy(newBindings = newBindings.bindAgg(bindings: _*))

override def bindScan(bindings: (String, B)*): SegregatedBindingEnv[A, B] =
copy(newBindings = newBindings.bindScan(bindings: _*))

override def createAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.createAgg,
newBindings.createAgg,
)

override def createScan: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.createScan,
newBindings.createScan,
)

override def noAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.noAgg,
newBindings.noAgg,
)

override def noScan: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.noScan,
newBindings.noScan,
final case class Bindings[+T](
eval: IndexedSeq[(String, T)] = FastSeq.empty,
agg: AggEnv[T] = AggEnv.NoOp,
scan: AggEnv[T] = AggEnv.NoOp,
relational: IndexedSeq[(String, T)] = FastSeq.empty,
dropEval: Boolean = false,
) {
def map[U](f: (String, T) => U): Bindings[U] = Bindings(
eval.map { case (n, v) => n -> f(n, v) },
agg.map(f),
scan.map(f),
relational.map { case (n, v) => n -> f(n, v) },
dropEval,
)

override def onlyRelational(keepAggCapabilities: Boolean = false): SegregatedBindingEnv[A, B] =
SegregatedBindingEnv(
childEnvWithoutBindings.onlyRelational(keepAggCapabilities),
newBindings.onlyRelational(keepAggCapabilities),
)

override def bindRelational(bindings: (String, B)*): SegregatedBindingEnv[A, B] =
copy(newBindings = newBindings.bindRelational(bindings: _*))
}
def allEmpty: Boolean =
eval.isEmpty && agg.isEmpty && scan.isEmpty && relational.isEmpty

object Binds {
def apply(x: IR, v: String, i: Int): Boolean =
Bindings.get(x, i, BindingEnv.empty[Type].createAgg.createScan).eval.contains(v)
def dropBindings[U]: Bindings[U] =
Bindings(FastSeq.empty, agg.empty, scan.empty, FastSeq.empty, dropEval)
}

final case class Bindings(
eval: IndexedSeq[(String, Type)] = FastSeq.empty,
agg: AggEnv[Type] = AggEnv.NoOp,
scan: AggEnv[Type] = AggEnv.NoOp,
relational: IndexedSeq[(String, Type)] = FastSeq.empty,
dropEval: Boolean = false,
)

object Bindings {
val empty: Bindings = Bindings(FastSeq.empty, AggEnv.NoOp, AggEnv.NoOp, FastSeq.empty, false)
val empty: Bindings[Nothing] =
Bindings(FastSeq.empty, AggEnv.NoOp, AggEnv.NoOp, FastSeq.empty, false)

/** Returns the environment of the `i`th child of `ir` given the environment of the parent node
* `ir`.
*/
def get[E <: GenericBindingEnv[E, Type]](ir: BaseIR, i: Int, baseEnv: E): E = {
val bindings = ir match {
def get(ir: BaseIR, i: Int): Bindings[Type] =
ir match {
case ir: MatrixIR => childEnvMatrix(ir, i)
case ir: TableIR => childEnvTable(ir, i)
case ir: BlockMatrixIR => childEnvBlockMatrix(ir, i)
case ir: IR => childEnvValue(ir, i)
}
baseEnv.extend(bindings)
}

/** Like [[Bindings.get]], but keeps separate any new bindings introduced by `ir`. Always
* satisfies the identity
* {{{
* Bindings.segregated(ir, i, baseEnv).unified == Bindings(ir, i, baseEnv)
* }}}
*/
def segregated[A](ir: BaseIR, i: Int, baseEnv: BindingEnv[A]): SegregatedBindingEnv[A, Type] =
get(ir, i, SegregatedBindingEnv(baseEnv))

// Create a `Bindings` which cannot see anything bound in the enclosing context.
private def inFreshScope(
eval: IndexedSeq[(String, Type)] = FastSeq.empty,
agg: Option[IndexedSeq[(String, Type)]] = None,
scan: Option[IndexedSeq[(String, Type)]] = None,
relational: IndexedSeq[(String, Type)] = FastSeq.empty,
): Bindings = Bindings(
): Bindings[Type] = Bindings(
eval,
agg.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop),
scan.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop),
relational,
dropEval = true,
)

private def childEnvMatrix(ir: MatrixIR, i: Int): Bindings = {
private def childEnvMatrix(ir: MatrixIR, i: Int): Bindings[Type] = {
ir match {
case MatrixMapRows(child, _) if i == 1 =>
Bindings.inFreshScope(
Expand Down Expand Up @@ -197,7 +118,7 @@ object Bindings {
}
}

private def childEnvTable(ir: TableIR, i: Int): Bindings = {
private def childEnvTable(ir: TableIR, i: Int): Bindings[Type] = {
ir match {
case TableFilter(child, _) if i == 1 =>
Bindings.inFreshScope(child.typ.rowBindings)
Expand Down Expand Up @@ -239,7 +160,7 @@ object Bindings {
}
}

private def childEnvBlockMatrix(ir: BlockMatrixIR, i: Int): Bindings = {
private def childEnvBlockMatrix(ir: BlockMatrixIR, i: Int): Bindings[Type] = {
ir match {
case BlockMatrixMap(_, eltName, _, _) if i == 1 =>
Bindings.inFreshScope(FastSeq(eltName -> TFloat64))
Expand All @@ -252,7 +173,7 @@ object Bindings {
}
}

private def childEnvValue(ir: IR, i: Int): Bindings =
private def childEnvValue(ir: IR, i: Int): Bindings[Type] =
ir match {
case Block(bindings, _) =>
val eval = mutable.ArrayBuilder.make[(String, Type)]
Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/ComputeUsesAndDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ object ComputeUsesAndDefs {
ir.children
.zipWithIndex
.foreach { case (child, i) =>
val bindings = Bindings.segregated(ir, i, env).mapNewBindings((_, _) => ir)
if (!bindings.newBindings.allEmpty && !uses.contains(ir))
val newBindings = Bindings.get(ir, i).map((_, _) => ir)
if (!newBindings.allEmpty && !uses.contains(ir))
uses.bind(ir, mutable.Set.empty[RefEquality[BaseRef]])
compute(child, bindings.unified)
compute(child, env.extend(newBindings))
}
}

Expand Down
Loading

0 comments on commit 5a13d5a

Please sign in to comment.