Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] declarative Bindings #14496

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ abstract class BaseIR {

def forEachChildWithEnv[E <: GenericBindingEnv[E, Type]](env: E)(f: (BaseIR, E) => Unit): Unit =
childrenSeq.view.zipWithIndex.foreach { case (child, i) =>
val childEnv = Bindings.get(this, i, env)
val childEnv = env.extend(Bindings.get(this, i))
f(child, childEnv)
}

Expand All @@ -73,7 +73,7 @@ abstract class BaseIR {
val newChildren = childrenSeq.toArray
var res = this
for (i <- newChildren.indices) {
val childEnv = Bindings.get(res, i, env)
val childEnv = env.extend(Bindings.get(res, i))
val child = newChildren(i)
val newChild = f(child, childEnv)
if (!(newChild eq child)) {
Expand All @@ -90,7 +90,7 @@ abstract class BaseIR {
f: (BaseIR, Int, E) => StackFrame[Unit]
): StackFrame[Unit] =
childrenSeq.view.zipWithIndex.foreachRecur { case (child, i) =>
val childEnv = Bindings.get(this, i, env)
val childEnv = env.extend(Bindings.get(this, i))
f(child, i, childEnv)
}
}
137 changes: 29 additions & 108 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,142 +7,63 @@ import is.hail.utils.FastSeq

import scala.collection.mutable

object SegregatedBindingEnv {
def apply[A, B](env: BindingEnv[A]): SegregatedBindingEnv[A, B] =
SegregatedBindingEnv(env, env.dropBindings)
object Binds {
def apply(x: IR, v: String, i: Int): Boolean =
Bindings.get(x, i).eval.exists(_._1 == v)
}

case class SegregatedBindingEnv[A, B](
childEnvWithoutBindings: BindingEnv[A],
newBindings: BindingEnv[B],
) extends GenericBindingEnv[SegregatedBindingEnv[A, B], B] {
def newBlock(
eval: Seq[(String, B)] = Seq.empty,
agg: AggEnv[B] = AggEnv.NoOp,
scan: AggEnv[B] = AggEnv.NoOp,
relational: Seq[(String, B)] = Seq.empty,
dropEval: Boolean = false,
): SegregatedBindingEnv[A, B] =
SegregatedBindingEnv(
childEnvWithoutBindings.newBlock(agg = agg.empty, scan = scan.empty, dropEval = dropEval),
newBindings.newBlock(eval, agg, scan, relational, dropEval),
)

def unified(implicit ev: BindingEnv[B] =:= BindingEnv[A]): BindingEnv[A] =
childEnvWithoutBindings.merge(newBindings)

def mapNewBindings[C](f: (String, B) => C): SegregatedBindingEnv[A, C] = SegregatedBindingEnv(
childEnvWithoutBindings,
newBindings.mapValuesWithKey(f),
)

override def promoteAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.promoteAgg,
newBindings.promoteAgg,
)

override def promoteScan: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.promoteScan,
newBindings.promoteScan,
)

override def bindEval(bindings: (String, B)*): SegregatedBindingEnv[A, B] =
copy(newBindings = newBindings.bindEval(bindings: _*))

override def noEval: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.copy(eval = Env.empty),
newBindings.copy(eval = Env.empty),
)

override def bindAgg(bindings: (String, B)*): SegregatedBindingEnv[A, B] =
copy(newBindings = newBindings.bindAgg(bindings: _*))

override def bindScan(bindings: (String, B)*): SegregatedBindingEnv[A, B] =
copy(newBindings = newBindings.bindScan(bindings: _*))

override def createAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.createAgg,
newBindings.createAgg,
)

override def createScan: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.createScan,
newBindings.createScan,
)

override def noAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.noAgg,
newBindings.noAgg,
)

override def noScan: SegregatedBindingEnv[A, B] = SegregatedBindingEnv(
childEnvWithoutBindings.noScan,
newBindings.noScan,
final case class Bindings[+T](
eval: IndexedSeq[(String, T)] = FastSeq.empty,
agg: AggEnv[T] = AggEnv.NoOp,
scan: AggEnv[T] = AggEnv.NoOp,
relational: IndexedSeq[(String, T)] = FastSeq.empty,
dropEval: Boolean = false,
) {
def map[U](f: (String, T) => U): Bindings[U] = Bindings(
patrick-schultz marked this conversation as resolved.
Show resolved Hide resolved
eval.map { case (n, v) => n -> f(n, v) },
agg.map(f),
scan.map(f),
relational.map { case (n, v) => n -> f(n, v) },
dropEval,
)

override def onlyRelational(keepAggCapabilities: Boolean = false): SegregatedBindingEnv[A, B] =
SegregatedBindingEnv(
childEnvWithoutBindings.onlyRelational(keepAggCapabilities),
newBindings.onlyRelational(keepAggCapabilities),
)

override def bindRelational(bindings: (String, B)*): SegregatedBindingEnv[A, B] =
copy(newBindings = newBindings.bindRelational(bindings: _*))
}
def allEmpty: Boolean =
eval.isEmpty && agg.isEmpty && scan.isEmpty && relational.isEmpty

object Binds {
def apply(x: IR, v: String, i: Int): Boolean =
Bindings.get(x, i, BindingEnv.empty[Type].createAgg.createScan).eval.contains(v)
def dropBindings[U]: Bindings[U] =
Bindings(FastSeq.empty, agg.empty, scan.empty, FastSeq.empty, dropEval)
}

final case class Bindings(
eval: IndexedSeq[(String, Type)] = FastSeq.empty,
agg: AggEnv[Type] = AggEnv.NoOp,
scan: AggEnv[Type] = AggEnv.NoOp,
relational: IndexedSeq[(String, Type)] = FastSeq.empty,
dropEval: Boolean = false,
)

object Bindings {
val empty: Bindings = Bindings(FastSeq.empty, AggEnv.NoOp, AggEnv.NoOp, FastSeq.empty, false)
val empty: Bindings[Nothing] =
Bindings(FastSeq.empty, AggEnv.NoOp, AggEnv.NoOp, FastSeq.empty, false)

/** Returns the environment of the `i`th child of `ir` given the environment of the parent node
* `ir`.
*/
def get[E <: GenericBindingEnv[E, Type]](ir: BaseIR, i: Int, baseEnv: E): E = {
val bindings = ir match {
def get(ir: BaseIR, i: Int): Bindings[Type] =
ir match {
case ir: MatrixIR => childEnvMatrix(ir, i)
case ir: TableIR => childEnvTable(ir, i)
case ir: BlockMatrixIR => childEnvBlockMatrix(ir, i)
case ir: IR => childEnvValue(ir, i)
}
baseEnv.extend(bindings)
}

/** Like [[Bindings.get]], but keeps separate any new bindings introduced by `ir`. Always
* satisfies the identity
* {{{
* Bindings.segregated(ir, i, baseEnv).unified == Bindings(ir, i, baseEnv)
* }}}
*/
def segregated[A](ir: BaseIR, i: Int, baseEnv: BindingEnv[A]): SegregatedBindingEnv[A, Type] =
get(ir, i, SegregatedBindingEnv(baseEnv))

// Create a `Bindings` which cannot see anything bound in the enclosing context.
private def inFreshScope(
eval: IndexedSeq[(String, Type)] = FastSeq.empty,
agg: Option[IndexedSeq[(String, Type)]] = None,
scan: Option[IndexedSeq[(String, Type)]] = None,
relational: IndexedSeq[(String, Type)] = FastSeq.empty,
): Bindings = Bindings(
): Bindings[Type] = Bindings(
eval,
agg.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop),
scan.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop),
relational,
dropEval = true,
)

private def childEnvMatrix(ir: MatrixIR, i: Int): Bindings = {
private def childEnvMatrix(ir: MatrixIR, i: Int): Bindings[Type] = {
ir match {
case MatrixMapRows(child, _) if i == 1 =>
Bindings.inFreshScope(
Expand Down Expand Up @@ -197,7 +118,7 @@ object Bindings {
}
}

private def childEnvTable(ir: TableIR, i: Int): Bindings = {
private def childEnvTable(ir: TableIR, i: Int): Bindings[Type] = {
ir match {
case TableFilter(child, _) if i == 1 =>
Bindings.inFreshScope(child.typ.rowBindings)
Expand Down Expand Up @@ -239,7 +160,7 @@ object Bindings {
}
}

private def childEnvBlockMatrix(ir: BlockMatrixIR, i: Int): Bindings = {
private def childEnvBlockMatrix(ir: BlockMatrixIR, i: Int): Bindings[Type] = {
ir match {
case BlockMatrixMap(_, eltName, _, _) if i == 1 =>
Bindings.inFreshScope(FastSeq(eltName -> TFloat64))
Expand All @@ -252,7 +173,7 @@ object Bindings {
}
}

private def childEnvValue(ir: IR, i: Int): Bindings =
private def childEnvValue(ir: IR, i: Int): Bindings[Type] =
ir match {
case Block(bindings, _) =>
val eval = mutable.ArrayBuilder.make[(String, Type)]
Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/ComputeUsesAndDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ object ComputeUsesAndDefs {
ir.children
.zipWithIndex
.foreach { case (child, i) =>
val bindings = Bindings.segregated(ir, i, env).mapNewBindings((_, _) => ir)
if (!bindings.newBindings.allEmpty && !uses.contains(ir))
val newBindings = Bindings.get(ir, i).map((_, _) => ir)
if (!newBindings.allEmpty && !uses.contains(ir))
uses.bind(ir, mutable.Set.empty[RefEquality[BaseRef]])
compute(child, bindings.unified)
compute(child, env.extend(newBindings))
}
}

Expand Down
Loading