Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] rewrite FreeVariables to not duplicate binding structure #14451

Merged
merged 3 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,14 @@ abstract class BaseIR {
}
}

def forEachChildWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => Unit): Unit =
def forEachChildWithEnv[E <: GenericBindingEnv[E, Type]](env: E)(f: (BaseIR, E) => Unit): Unit =
childrenSeq.view.zipWithIndex.foreach { case (child, i) =>
val childEnv = Bindings(this, i, env)
f(child, childEnv)
}

def mapChildrenWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => BaseIR): BaseIR = {
def mapChildrenWithEnv[E <: GenericBindingEnv[E, Type]](env: E)(f: (BaseIR, E) => BaseIR)
: BaseIR = {
val newChildren = childrenSeq.toArray
var res = this
for (i <- newChildren.indices) {
Expand All @@ -83,10 +84,10 @@ abstract class BaseIR {
res
}

def forEachChildWithEnvStackSafe(
env: BindingEnv[Type]
def forEachChildWithEnvStackSafe[E <: GenericBindingEnv[E, Type]](
env: E
)(
f: (BaseIR, Int, BindingEnv[Type]) => StackFrame[Unit]
f: (BaseIR, Int, E) => StackFrame[Unit]
): StackFrame[Unit] =
childrenSeq.view.zipWithIndex.foreachRecur { case (child, i) =>
val childEnv = Bindings(this, i, env)
Expand Down
157 changes: 76 additions & 81 deletions hail/src/main/scala/is/hail/expr/ir/FreeVariables.scala
Original file line number Diff line number Diff line change
@@ -1,90 +1,85 @@
package is.hail.expr.ir

import is.hail.types.virtual.Type

import scala.collection.mutable

class FreeVariableEnv(boundVars: Env[Unit], freeVars: mutable.Set[String]) {
def this(boundVars: Env[Unit]) =
this(boundVars, mutable.Set.empty)

private def copy(boundVars: Env[Unit]): FreeVariableEnv =
new FreeVariableEnv(boundVars, freeVars)

def visitRef(name: String): Unit =
if (!boundVars.contains(name))
freeVars += name

def bindIterable(bindings: Seq[(String, Type)]): FreeVariableEnv =
copy(boundVars.bindIterable(bindings.view.map(b => (b._1, ()))))

def getFreeVars: Env[Unit] = new Env[Unit].bindIterable(freeVars.view.map(n => (n, ())))
}

case class FreeVariableBindingEnv(
evalVars: Option[FreeVariableEnv],
aggVars: Option[FreeVariableEnv],
scanVars: Option[FreeVariableEnv],
) extends GenericBindingEnv[FreeVariableBindingEnv, Type] {
def visitRef(name: String): Unit =
evalVars.foreach(_.visitRef(name))

def getFreeVars: BindingEnv[Unit] = BindingEnv(
evalVars.map(_.getFreeVars).getOrElse(Env.empty),
aggVars.map(_.getFreeVars),
scanVars.map(_.getFreeVars),
)

override def promoteAgg: FreeVariableBindingEnv =
copy(evalVars = aggVars, aggVars = None)

override def promoteScan: FreeVariableBindingEnv =
copy(evalVars = scanVars, scanVars = None)

override def bindEval(bindings: (String, Type)*): FreeVariableBindingEnv =
copy(evalVars = evalVars.map(_.bindIterable(bindings)))

override def dropEval: FreeVariableBindingEnv = copy(evalVars = None)

override def bindAgg(bindings: (String, Type)*): FreeVariableBindingEnv =
copy(aggVars = aggVars.map(_.bindIterable(bindings)))

override def bindScan(bindings: (String, Type)*): FreeVariableBindingEnv =
copy(scanVars = scanVars.map(_.bindIterable(bindings)))

override def createAgg: FreeVariableBindingEnv = copy(aggVars = evalVars)

override def createScan: FreeVariableBindingEnv = copy(scanVars = evalVars)

override def noAgg: FreeVariableBindingEnv = copy(aggVars = None)

override def noScan: FreeVariableBindingEnv = copy(scanVars = None)

override def onlyRelational(keepAggCapabilities: Boolean): FreeVariableBindingEnv =
FreeVariableBindingEnv(None, None, None)

override def bindRelational(bindings: (String, Type)*): FreeVariableBindingEnv =
this
}

object FreeVariables {
def apply(ir: IR, supportsAgg: Boolean, supportsScan: Boolean): BindingEnv[Unit] = {

def compute(ir1: IR, baseEnv: BindingEnv[Unit]): BindingEnv[Unit] = {
ir1 match {
case Ref(name, _) =>
baseEnv.bindEval(name, ())
case TableAggregate(_, _) => baseEnv
case MatrixAggregate(_, _) => baseEnv
case StreamAggScan(a, name, query) =>
val aE = compute(a, baseEnv)
val qE = compute(query, baseEnv.copy(scan = Some(Env.empty)))
aE.merge(qE.copy(eval = qE.eval.bindIterable(qE.scan.get.m - name), scan = baseEnv.scan))
case StreamAgg(a, name, query) =>
val aE = compute(a, baseEnv)
val qE = compute(query, baseEnv.copy(agg = Some(Env.empty)))
aE.merge(qE.copy(eval = qE.eval.bindIterable(qE.agg.get.m - name), agg = baseEnv.agg))
case ApplyAggOp(init, seq, _) =>
val initEnv = baseEnv.copy(agg = None)
val initFreeVars = init.iterator.map(x => compute(x, initEnv)).fold(initEnv)(_.merge(_))
.copy(agg = Some(Env.empty[Unit]))
val seqEnv = baseEnv.promoteAgg
seq.iterator.map { x =>
val e = compute(x, seqEnv)
e.copy(eval = Env.empty[Unit], agg = Some(e.eval))
}.fold(initFreeVars)(_.merge(_))
case ApplyScanOp(init, seq, _) =>
val initEnv = baseEnv.copy(scan = None)
val initFreeVars = init.iterator.map(x => compute(x, initEnv)).fold(initEnv)(_.merge(_))
.copy(scan = Some(Env.empty[Unit]))
val seqEnv = baseEnv.promoteScan
seq.iterator.map { x =>
val e = compute(x, seqEnv)
e.copy(eval = Env.empty[Unit], scan = Some(e.eval))
}.fold(initFreeVars)(_.merge(_))
case AggFold(zero, seqOp, combOp, accumName, otherAccumName, isScan) =>
val zeroEnv = if (isScan) baseEnv.copy(scan = None) else baseEnv.copy(agg = None)
val zeroFreeVarsCompute = compute(zero, zeroEnv)
val zeroFreeVars = if (isScan) zeroFreeVarsCompute.copy(scan = Some(Env.empty[Unit]))
else zeroFreeVarsCompute.copy(agg = Some(Env.empty[Unit]))
val seqOpEnv = if (isScan) baseEnv.promoteScan else baseEnv.promoteAgg
val seqOpFreeVarsCompute = compute(seqOp, seqOpEnv)
val seqOpFreeVars = if (isScan) {
seqOpFreeVarsCompute.copy(
eval = Env.empty[Unit],
scan = Some(seqOpFreeVarsCompute.eval),
)
} else {
seqOpFreeVarsCompute.copy(eval = Env.empty[Unit], agg = Some(seqOpFreeVarsCompute.eval))
}
val combEval = Env.fromSeq(IndexedSeq((accumName, {}), (otherAccumName, {})))
val combOpFreeVarsCompute = compute(combOp, baseEnv.copy(eval = combEval))
val combOpFreeVars = combOpFreeVarsCompute.copy(
eval = Env.empty[Unit],
scan = Some(combOpFreeVarsCompute.eval),
)
zeroFreeVars.merge(seqOpFreeVars).merge(combOpFreeVars)
val env = FreeVariableBindingEnv(
Some(new FreeVariableEnv(Env.empty)),
if (supportsAgg) Some(new FreeVariableEnv(Env.empty)) else None,
if (supportsScan) Some(new FreeVariableEnv(Env.empty)) else None,
)
VisitIR.withEnv(ir, env) { (ir, env) =>
ir match {
case Ref(name, _) => env.visitRef(name)
case _ =>
ir1.children
.zipWithIndex
.map {
case (child: IR, i) =>
val bindings = Bindings.segregated(ir1, i, baseEnv)
val childEnv = bindings.childEnvWithoutBindings
val sub = compute(child, childEnv).subtract(bindings.newBindings)
if (UsesAggEnv(ir1, i))
sub.copy(eval = Env.empty[Unit], agg = Some(sub.eval), scan = baseEnv.scan)
else if (UsesScanEnv(ir1, i))
sub.copy(eval = Env.empty[Unit], agg = baseEnv.agg, scan = Some(sub.eval))
else
sub
case _ =>
baseEnv
}
.fold(baseEnv)(_.merge(_))
}
}

compute(
ir,
BindingEnv(
Env.empty,
if (supportsAgg) Some(Env.empty[Unit]) else None,
if (supportsScan) Some(Env.empty[Unit]) else None,
),
)
env.getFreeVars
}
}
8 changes: 8 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/MapIR.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package is.hail.expr.ir

import is.hail.types.virtual.Type

object MapIR {
def apply(f: IR => IR)(ir: IR): IR = ir match {
case ta: TableAggregate => ta
Expand All @@ -18,4 +20,10 @@ object VisitIR {
f(ir)
ir.children.foreach(apply(_)(f))
}

def withEnv[V, E <: GenericBindingEnv[E, Type]](ir: BaseIR, env: E)(f: (BaseIR, E) => Unit)
: Unit = {
f(ir, env)
ir.forEachChildWithEnv(env)(withEnv(_, _)(f))
}
}
37 changes: 34 additions & 3 deletions hail/src/test/scala/is/hail/expr/ir/IRSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4432,6 +4432,40 @@ class IRSuite extends HailSuite {
assert(!HasIRSharing(ctx)(ir1.deepCopy()))
}

@Test def freeVariables(): Unit = {
val stream = rangeIR(5)
val sumSig = AggSignature(Sum(), IndexedSeq(), IndexedSeq(TInt32))
val x = Ref("x", TInt32)
val y = Ref("y", TInt32)
val z = Ref("z", TInt32)
val explodeIR = AggExplode(
stream,
"x",
z + ApplyAggOp(FastSeq.empty, FastSeq(x + y), sumSig),
false,
)

assert(FreeVariables(explodeIR, true, true) == BindingEnv[Unit](
Env(("z", ())),
Some(Env(("y", ()))),
Some(Env()),
))
assert(FreeVariables(explodeIR, false, false) == BindingEnv[Unit](Env(("z", ()))))

val streamAggIR = StreamAgg(
stream,
"x",
z + ApplyAggOp(FastSeq.empty, FastSeq(x + y), sumSig),
)
assert(
FreeVariables(streamAggIR, true, true) == BindingEnv[Unit](
Env(("z", ()), ("y", ())),
Some(Env()),
Some(Env()),
)
)
}

@Test def freeVariablesAggScanBindingEnv(): Unit = {
def testFreeVarsHelper(ir: IR): Unit = {
val irFreeVarsTrue = FreeVariables.apply(ir, true, true)
Expand All @@ -4441,9 +4475,6 @@ class IRSuite extends HailSuite {
assert(irFreeVarsFalse.agg.isEmpty && irFreeVarsFalse.scan.isEmpty)
}

val liftIR = LiftMeOut(Ref("x", TInt32))
testFreeVarsHelper(liftIR)

val sumSig = AggSignature(Sum(), IndexedSeq(), IndexedSeq(TInt64))
val streamAggIR = StreamAgg(
StreamMap(StreamRange(I32(0), I32(4), I32(1)), "x", Cast(Ref("x", TInt32), TInt64)),
Expand Down