Skip to content

Commit

Permalink
[query] refactor BaseIR so children only exposes iterable interface (#…
Browse files Browse the repository at this point in the history
…13214)

This is a simple tidying refactoring to hide the `children` array from
consumers of `BaseIR`. My main motivation is to make it easier to
explore other ir data structure designs, and to migrate to a new design
in the future, e.g. to allow for in-place mutation without requiring
large-scale changes to every compiler pass, and to simplify how we
encode binding structure.
  • Loading branch information
patrick-schultz authored Jun 26, 2023
1 parent 3160473 commit 3f4c39a
Show file tree
Hide file tree
Showing 32 changed files with 219 additions and 235 deletions.
40 changes: 35 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
@@ -1,23 +1,53 @@
package is.hail.expr.ir

import is.hail.types.BaseType
import is.hail.utils.StackSafe._
import is.hail.utils._

abstract class BaseIR {
def typ: BaseType

def children: IndexedSeq[BaseIR]
protected def childrenSeq: IndexedSeq[BaseIR]

def copy(newChildren: IndexedSeq[BaseIR]): BaseIR
def children: Iterable[BaseIR] = childrenSeq

def deepCopy(): this.type = copy(newChildren = children.map(_.deepCopy())).asInstanceOf[this.type]
protected def copy(newChildren: IndexedSeq[BaseIR]): BaseIR

def deepCopy(): this.type = copy(newChildren = childrenSeq.map(_.deepCopy())).asInstanceOf[this.type]

lazy val noSharing: this.type = if (HasIRSharing(this)) this.deepCopy() else this

def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
if (childrenSeq.elementsSameObjects(newChildren))
this
else
copy(newChildren)
}

def mapChildren(f: (BaseIR) => BaseIR): BaseIR = {
val newChildren = children.map(f)
if ((children, newChildren).zipped.forall(_ eq _))
val newChildren = childrenSeq.map(f)
if (childrenSeq.elementsSameObjects(newChildren))
this
else
copy(newChildren)
}

def mapChildrenWithIndexStackSafe(f: (BaseIR, Int) => StackFrame[BaseIR]): StackFrame[BaseIR] = {
call(childrenSeq.iterator.zipWithIndex.map(f.tupled).collectRecur).map { newChildren =>
if (childrenSeq.elementsSameObjects(newChildren))
this
else
copy(newChildren)
}
}

def mapChildrenStackSafe(f: BaseIR => StackFrame[BaseIR]): StackFrame[BaseIR] = {
call(childrenSeq.mapRecur(f)).map { newChildren =>
if (childrenSeq.elementsSameObjects(newChildren))
this
else
copy(newChildren)
}
}
}
26 changes: 13 additions & 13 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ abstract sealed class BlockMatrixIR extends BaseIR {
case class BlockMatrixRead(reader: BlockMatrixReader) extends BlockMatrixIR {
override lazy val typ: BlockMatrixType = reader.fullType

lazy val children: IndexedSeq[BaseIR] = Array.empty[BlockMatrixIR]
lazy val childrenSeq: IndexedSeq[BaseIR] = Array.empty[BlockMatrixIR]

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixRead = {
assert(newChildren.isEmpty)
Expand Down Expand Up @@ -249,7 +249,7 @@ case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDen
override lazy val typ: BlockMatrixType = child.typ
assert(!needsDense || !typ.isSparse)

lazy val children: IndexedSeq[BaseIR] = Array(child, f)
lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, f)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixMap = {
val IndexedSeq(newChild: BlockMatrixIR, newF: IR) = newChildren
Expand Down Expand Up @@ -378,7 +378,7 @@ case class BlockMatrixMap2(left: BlockMatrixIR, right: BlockMatrixIR, leftName:

override lazy val typ: BlockMatrixType = left.typ.copy(sparsity = sparsityStrategy.mergeSparsity(left.typ.sparsity, right.typ.sparsity))

lazy val children: IndexedSeq[BaseIR] = Array(left, right, f)
lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right, f)

val blockCostIsLinear: Boolean = left.blockCostIsLinear && right.blockCostIsLinear

Expand Down Expand Up @@ -496,7 +496,7 @@ case class BlockMatrixDot(left: BlockMatrixIR, right: BlockMatrixIR) extends Blo
BlockMatrixType(left.typ.elementType, tensorShape, isRowVector, blockSize, sparsity)
}

lazy val children: IndexedSeq[BaseIR] = Array(left, right)
lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixDot = {
assert(newChildren.length == 2)
Expand Down Expand Up @@ -578,7 +578,7 @@ case class BlockMatrixBroadcast(
BlockMatrixType(child.typ.elementType, tensorShape, isRowVector, blockSize, sparsity)
}

lazy val children: IndexedSeq[BaseIR] = Array(child)
lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixBroadcast = {
assert(newChildren.length == 1)
Expand Down Expand Up @@ -655,7 +655,7 @@ case class BlockMatrixAgg(
BlockMatrixType(child.typ.elementType, shape, isRowVector, child.typ.blockSize, sparsity)
}

lazy val children: IndexedSeq[BaseIR] = Array(child)
lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixAgg = {
assert(newChildren.length == 1)
Expand Down Expand Up @@ -708,7 +708,7 @@ case class BlockMatrixFilter(
BlockMatrixType(child.typ.elementType, tensorShape, isRowVector, blockSize, sparsity)
}

override def children: IndexedSeq[BaseIR] = Array(child)
override def childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixFilter = {
assert(newChildren.length == 1)
Expand All @@ -734,7 +734,7 @@ case class BlockMatrixDensify(child: BlockMatrixIR) extends BlockMatrixIR {

def blockCostIsLinear: Boolean = child.blockCostIsLinear

val children: IndexedSeq[BaseIR] = FastIndexedSeq(child)
val childrenSeq: IndexedSeq[BaseIR] = FastIndexedSeq(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
val IndexedSeq(newChild: BlockMatrixIR) = newChildren
Expand Down Expand Up @@ -847,7 +847,7 @@ case class BlockMatrixSparsify(

def blockCostIsLinear: Boolean = child.blockCostIsLinear

val children: IndexedSeq[BaseIR] = Array(child)
val childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
val IndexedSeq(newChild: BlockMatrixIR) = newChildren
Expand Down Expand Up @@ -886,7 +886,7 @@ case class BlockMatrixSlice(child: BlockMatrixIR, slices: IndexedSeq[IndexedSeq[
BlockMatrixType(child.typ.elementType, tensorShape, isRowVector, child.typ.blockSize, sparsity)
}

override def children: IndexedSeq[BaseIR] = Array(child)
override def childrenSeq: IndexedSeq[BaseIR] = Array(child)

override def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR = {
assert(newChildren.length == 1)
Expand Down Expand Up @@ -936,7 +936,7 @@ case class ValueToBlockMatrix(
}
}

lazy val children: IndexedSeq[BaseIR] = Array(child)
lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)

def copy(newChildren: IndexedSeq[BaseIR]): ValueToBlockMatrix = {
assert(newChildren.length == 1)
Expand Down Expand Up @@ -971,7 +971,7 @@ case class BlockMatrixRandom(
override lazy val typ: BlockMatrixType =
BlockMatrixType.dense(TFloat64, shape(0), shape(1), blockSize)

lazy val children: IndexedSeq[BaseIR] = Array.empty[BaseIR]
lazy val childrenSeq: IndexedSeq[BaseIR] = Array.empty[BaseIR]

def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixRandom = {
assert(newChildren.isEmpty)
Expand All @@ -986,7 +986,7 @@ case class BlockMatrixRandom(
case class RelationalLetBlockMatrix(name: String, value: IR, body: BlockMatrixIR) extends BlockMatrixIR {
override lazy val typ: BlockMatrixType = body.typ

def children: IndexedSeq[BaseIR] = Array(value, body)
def childrenSeq: IndexedSeq[BaseIR] = Array(value, body)

val blockCostIsLinear: Boolean = body.blockCostIsLinear

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ object ComputeUsesAndDefs {
}

ir.children
.iterator
.zipWithIndex
.foreach { case (child, i) =>
val e = ChildEnvWithoutBindings(ir, i, env)
Expand Down
7 changes: 3 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ object ForwardLets {
!ContainsAggIntermediate(value)
}

def mapRewrite(): BaseIR = ir.copy(ir.children
.iterator
.zipWithIndex
.map { case (ir1, i) => rewrite(ir1, ChildEnvWithoutBindings(ir, i, env)) }.toFastIndexedSeq)
def mapRewrite(): BaseIR = ir.mapChildrenWithIndex { (ir1, i) =>
rewrite(ir1, ChildEnvWithoutBindings(ir, i, env))
}

ir match {
case l@Let(name, value, body) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ object ForwardRelationalLets {
} else RelationalLetBlockMatrix(name, recur(value).asInstanceOf[IR], recur(body).asInstanceOf[BlockMatrixIR])
case x@RelationalRef(name, _) =>
m.getOrElse(name, x)
case _ => ir1.copy(ir1.children.map(recur))
case _ => ir1.mapChildren(recur)
}

recur(ir0)
Expand Down
1 change: 0 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/FreeVariables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ object FreeVariables {
zeroFreeVars.merge(seqOpFreeVars).merge(combOpFreeVars)
case _ =>
ir1.children
.iterator
.zipWithIndex
.map {
case (child: IR, i) =>
Expand Down
8 changes: 6 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ sealed trait IR extends BaseIR {
_typ
}

lazy val children: IndexedSeq[BaseIR] =
protected lazy val childrenSeq: IndexedSeq[BaseIR] =
Children(this)

override def copy(newChildren: IndexedSeq[BaseIR]): IR =
protected override def copy(newChildren: IndexedSeq[BaseIR]): IR =
Copy(this, newChildren)

override def mapChildren(f: BaseIR => BaseIR): IR = super.mapChildren(f).asInstanceOf[IR]

override def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): IR = super.mapChildrenWithIndex(f).asInstanceOf[IR]

override def deepCopy(): this.type = {

val cp = super.deepCopy()
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Interpret.scala
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ object Interpret {
else
aValue.asInstanceOf[IndexedSeq[Row]].filter(_ != null).map { case Row(k, v) => (k, v) }.toMap
case _: CastToArray | _: ToArray | _: ToStream =>
val c = ir.children(0).asInstanceOf[IR]
val c = ir.children.head.asInstanceOf[IR]
val cValue = interpret(c, env, args)
if (cValue == null)
null
Expand Down
12 changes: 2 additions & 10 deletions hail/src/main/scala/is/hail/expr/ir/LiftRelationalValues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,11 @@ object LiftRelationalValues {
| _: BlockMatrixCollect
| _: TableGetGlobals) if ir.typ != TVoid =>
val ref = RelationalRef(genUID(), ir.asInstanceOf[IR].typ)
val rwChildren = ir.children.map(rewrite(_, ab, memo))
val newChild = if ((rwChildren, ir.children).zipped.forall(_ eq _))
ir
else
ir.copy(rwChildren)
val newChild = ir.mapChildren(rewrite(_, ab, memo))
ab += ((ref.name, newChild.asInstanceOf[IR]))
ref
case x =>
val rwChildren = x.children.map(rewrite(_, ab, memo))
if ((rwChildren, ir.children).zipped.forall(_ eq _))
ir
else
ir.copy(rwChildren)
x.mapChildren(rewrite(_, ab, memo))
}

val ab = new BoxedArrayBuilder[(String, IR)]
Expand Down
6 changes: 1 addition & 5 deletions hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,13 @@ object LowerMatrixIR {
ir: BaseIR,
ab: BoxedArrayBuilder[(String, IR)]
): BaseIR = {
val loweredChildren = ir.children.map {
ir.mapChildren {
case tir: TableIR => lower(ctx, tir, ab)
case mir: MatrixIR => throw new RuntimeException(s"expect specialized lowering rule for " +
s"${ ir.getClass.getName }\n Found MatrixIR child $mir")
case bmir: BlockMatrixIR => lower(ctx, bmir, ab)
case vir: IR => lower(ctx, vir, ab)
}
if ((ir.children, loweredChildren).zipped.forall(_ eq _))
ir
else
ir.copy(loweredChildren)
}

def colVals(tir: TableIR): IR =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,8 @@ object LowerOrInterpretNonCompilable {
result
}

def rewriteChildren(x: BaseIR, m: mutable.Map[String, IR]): BaseIR = {
val children = x.children
val newChildren = children.map(rewrite(_, m))

// only recons if necessary
if ((children, newChildren).zipped.forall(_ eq _))
x
else
x.copy(newChildren)
}

def rewriteChildren(x: BaseIR, m: mutable.Map[String, IR]): BaseIR =
x.mapChildren(rewrite(_, m))

def rewrite(x: BaseIR, m: mutable.Map[String, IR]): BaseIR = {

Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/MapIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ object MapIR {
def apply(f: IR => IR)(ir: IR): IR = ir match {
case ta: TableAggregate => ta
case ma: MatrixAggregate => ma
case _ => Copy(ir, Children(ir).map {
case _ => ir.mapChildren {
case c: IR => f(c)
case c => c
})
}
}

def mapBaseIR(ir: BaseIR, f: BaseIR => BaseIR): BaseIR = f(ir.copy(newChildren = ir.children.map(mapBaseIR(_, f))))
def mapBaseIR(ir: BaseIR, f: BaseIR => BaseIR): BaseIR = f(ir.mapChildren(mapBaseIR(_, f)))
}

object VisitIR {
Expand Down
Loading

0 comments on commit 3f4c39a

Please sign in to comment.