[compiler] disentangle IR parser from bindings #13990

Merged · 6 commits · Nov 15, 2023
4 changes: 2 additions & 2 deletions hail/python/hail/ir/ir.py
@@ -450,7 +450,7 @@ def copy(self, *children):
return TailLoop(self.name, [(n, v) for (n, _), v in zip(self.params, params)], body)

def head_str(self):
return f'{escape_id(self.name)} ({" ".join([escape_id(n) for n, _ in self.params])})'
return f'{escape_id(self.name)} ({" ".join([escape_id(n) for n, _ in self.params])}) {self.body.typ._parsable_string()}'

def _eq(self, other):
return self.name == other.name
@@ -489,7 +489,7 @@ def copy(self, args):
return Recur(self.name, args, self.return_type)

def head_str(self):
return f'{escape_id(self.name)} {self.return_type._parsable_string()}'
return f'{escape_id(self.name)}'

def _eq(self, other):
return other.name == self.name
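With this change the serialized `TailLoop` head carries the loop's result type while the `Recur` head carries only the loop name, so the type is declared once, where the name is bound. A minimal Scala sketch of the idea on the parsing side (illustrative only, not `IRParser`'s actual code; the `LoopTypeScope` helper is made up):

```scala
// Illustrative sketch, not IRParser's real implementation: remember the
// result type printed on each TailLoop head so a later Recur, which now
// prints only the loop name, can be typed by lookup.
object LoopTypeScope {
  private val loopResultTypes = scala.collection.mutable.Map.empty[String, String]

  // Called when a TailLoop head `name (params...) resultType` is parsed.
  def enterLoop(loopName: String, resultType: String): Unit =
    loopResultTypes(loopName) = resultType

  // Called when a Recur head, now just `name`, is parsed.
  def typeOfRecur(loopName: String): String =
    loopResultTypes.getOrElse(loopName, sys.error(s"Recur references unknown loop: $loopName"))
}
```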
2 changes: 1 addition & 1 deletion hail/python/hail/ir/matrix_ir.py
@@ -1183,7 +1183,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name):
return MatrixFilterIntervals(child, self.intervals, self.point_type, self.keep)

def head_str(self):
return f'{dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'
return f'{self.child.typ.row_key_type._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'

def _eq(self, other):
return self.intervals == other.intervals and self.point_type == other.point_type and self.keep == other.keep
2 changes: 1 addition & 1 deletion hail/python/hail/ir/table_ir.py
@@ -881,7 +881,7 @@ def _handle_randomness(self, uid_field_name):
return TableFilterIntervals(self.child.handle_randomness(uid_field_name), self.intervals, self.point_type, self.keep)

def head_str(self):
return f'{dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'
return f'{self.child.typ.key_type._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'

def _eq(self, other):
return self.intervals == other.intervals and self.point_type == other.point_type and self.keep == other.keep
18 changes: 10 additions & 8 deletions hail/python/test/hail/test_ir.py
@@ -22,8 +22,9 @@ def value_irs_env(self):
'mat': hl.tndarray(hl.tfloat64, 2),
'aa': hl.tarray(hl.tarray(hl.tint32)),
'sta': hl.tstream(hl.tarray(hl.tint32)),
'da': hl.tarray(hl.ttuple(hl.tint32, hl.tstr)),
'nd': hl.tndarray(hl.tfloat64, 1),
'sts': hl.tstream(hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64)),
'da': hl.tstream(hl.ttuple(hl.tint32, hl.tstr)),
'nd': hl.tndarray(hl.tfloat64, 2),
'v': hl.tint32,
's': hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64),
't': hl.ttuple(hl.tint32, hl.tint64, hl.tfloat64),
@@ -42,6 +43,7 @@ def value_irs(self):
mat = ir.Ref('mat')
aa = ir.Ref('aa', env['aa'])
sta = ir.Ref('sta', env['sta'])
sts = ir.Ref('sts', env['sts'])
da = ir.Ref('da', env['da'])
nd = ir.Ref('nd', env['nd'])
v = ir.Ref('v', env['v'])
@@ -77,7 +79,7 @@ def aggregate(x):
ir.ArrayRef(a, i),
ir.ArrayLen(a),
ir.ArraySort(ir.ToStream(a), 'l', 'r', ir.ApplyComparisonOp("LT", ir.Ref('l', hl.tint32), ir.Ref('r', hl.tint32))),
ir.ToSet(a),
ir.ToSet(st),
ir.ToDict(da),
ir.ToArray(st),
ir.CastToArray(ir.NA(hl.tset(hl.tint32))),
@@ -89,17 +91,17 @@ def aggregate(x):
ir.NDArrayRef(nd, [ir.I64(1), ir.I64(2)]),
ir.NDArrayMap(nd, 'v', v),
ir.NDArrayMatMul(nd, nd),
ir.LowerBoundOnOrderedCollection(a, i, True),
ir.LowerBoundOnOrderedCollection(a, i, False),
ir.GroupByKey(da),
ir.RNGSplit(rngState, ir.MakeTuple([ir.I64(1), ir.MakeTuple([ir.I64(2), ir.I64(3)])])),
ir.RNGSplit(rngState, ir.MakeTuple([ir.I64(1), ir.I64(2), ir.I64(3)])),
ir.StreamMap(st, 'v', v),
ir.StreamZip([st, st], ['a', 'b'], ir.TrueIR(), 'ExtendNA'),
ir.StreamFilter(st, 'v', v),
ir.StreamFilter(st, 'v', c),
ir.StreamFlatMap(sta, 'v', ir.ToStream(v)),
ir.StreamFold(st, ir.I32(0), 'x', 'v', v),
ir.StreamScan(st, ir.I32(0), 'x', 'v', v),
ir.StreamWhiten(whitenStream, "newChunk", "prevWindow", 0, 0, 0, 0, False),
ir.StreamJoinRightDistinct(st, st, ['k'], ['k'], 'l', 'r', ir.I32(1), "left"),
ir.StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, False),
ir.StreamJoinRightDistinct(sts, sts, ['x'], ['x'], 'l', 'r', ir.I32(1), "left"),
ir.StreamFor(st, 'v', ir.Void()),
aggregate(ir.AggFilter(ir.TrueIR(), ir.I32(0), False)),
aggregate(ir.AggExplode(ir.StreamRange(ir.I32(0), ir.I32(2), ir.I32(1)), 'x', ir.I32(0), False)),
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/backend/Backend.scala
@@ -151,7 +151,7 @@ abstract class Backend {
def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T

final def valueType(s: String): Array[Byte] = {
withExecuteContext("tableType") { ctx =>
withExecuteContext("valueType") { ctx =>
val v = IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
v.typ.toString.getBytes(StandardCharsets.UTF_8)
}
@@ -296,7 +296,7 @@ class LocalBackend(
def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = {
ExecutionTimer.logTime("LocalBackend.parse_value_ir") { timer =>
withExecuteContext(timer) { ctx =>
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), persistedIR.toMap))
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap), BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*))
}
}
}
@@ -670,7 +670,7 @@ class SparkBackend(
def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = {
ExecutionTimer.logTime("SparkBackend.parse_value_ir") { timer =>
withExecuteContext(timer) { ctx =>
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), irMap = persistedIR.toMap))
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap), BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*))
}
}
}
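Both backends now pass the reference map's `BindingEnv` as a separate argument to `IRParser.parse_value_ir` instead of packing it into `IRParserEnvironment`. Condensed from the two hunks above (a sketch; `ctx`, `refMap`, `s`, and `persistedIR` are the surrounding values in those methods):

```scala
// New call shape: the parser environment holds only the context and the
// persisted-IR map; free-variable bindings travel alongside it.
import scala.collection.JavaConverters._

val parserEnv = IRParserEnvironment(ctx, irMap = persistedIR.toMap)
val freeVariables = BindingEnv.eval(
  refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*)
val parsed: IR = IRParser.parse_value_ir(s, parserEnv, freeVariables)
```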
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/AggOp.scala
@@ -26,9 +26,9 @@ object AggSignature {

case class AggSignature(
op: AggOp,
initOpArgs: Seq[Type],
seqOpArgs: Seq[Type]) {

var initOpArgs: Seq[Type],
var seqOpArgs: Seq[Type]
) {
// only to be used with virtual non-nested signatures on ApplyAggOp and ApplyScanOp
lazy val returnType: Type = Extract.getResultType(this)
}
30 changes: 30 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
@@ -1,6 +1,7 @@
package is.hail.expr.ir

import is.hail.types.BaseType
import is.hail.types.virtual.Type
import is.hail.utils.StackSafe._
import is.hail.utils._

@@ -50,4 +51,33 @@ abstract class BaseIR {
copy(newChildren)
}
}

def forEachChildWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => Unit): Unit = {
childrenSeq.view.zipWithIndex.foreach { case (child, i) =>
val childEnv = ChildBindings(this, i, env)
f(child, childEnv)
}
}

def mapChildrenWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => BaseIR): BaseIR = {
val newChildren = childrenSeq.toArray
var res = this
for (i <- newChildren.indices) {
val childEnv = ChildBindings(res, i, env)
val child = newChildren(i)
val newChild = f(child, childEnv)
if (!(newChild eq child)) {
newChildren(i) = newChild
res = res.copy(newChildren)
}
}
res
}

def forEachChildWithEnvStackSafe(env: BindingEnv[Type])(f: (BaseIR, Int, BindingEnv[Type]) => StackFrame[Unit]): StackFrame[Unit] = {
childrenSeq.view.zipWithIndex.foreachRecur { case (child, i) =>
val childEnv = ChildBindings(this, i, env)
f(child, i, childEnv)
}
}
}
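The new helpers thread a `BindingEnv[Type]` through child traversal, computing each child's environment with `ChildBindings`. A hedged sketch of how a recursive pass might build on `mapChildrenWithEnv` (the `rewriteWithEnv` driver and `rule` callback are hypothetical, not part of this PR):

```scala
// Hypothetical top-down rewrite driver: apply a rule to the current node,
// then visit each child under the environment that mapChildrenWithEnv
// derives from ChildBindings, rebuilding nodes only when a child changes.
def rewriteWithEnv(ir: BaseIR, env: BindingEnv[Type])(
    rule: (BaseIR, BindingEnv[Type]) => BaseIR): BaseIR = {
  val rewritten = rule(ir, env)
  rewritten.mapChildrenWithEnv(env) { (child, childEnv) =>
    rewriteWithEnv(child, childEnv)(rule)
  }
}
```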
16 changes: 9 additions & 7 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
@@ -11,11 +11,13 @@ object Binds {
object Bindings {
private val empty: Array[(String, Type)] = Array()

// A call to Bindings(x, i) may only query the types of children with
// index < i
def apply(x: BaseIR, i: Int): Iterable[(String, Type)] = x match {
case Let(name, value, _) => if (i == 1) Array(name -> value.typ) else empty
case TailLoop(name, args, body) => if (i == args.length)
case TailLoop(name, args, resultType, _) => if (i == args.length)
args.map { case (name, ir) => name -> ir.typ } :+
name -> TTuple(TTuple(args.map(_._2.typ): _*), body.typ) else empty
name -> TTuple(TTuple(args.map(_._2.typ): _*), resultType) else empty
case StreamMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
case StreamZip(as, names, _, _, _) => if (i == as.length) names.zip(as.map(a => tcoerce[TStream](a.typ).elementType)) else empty
case StreamZipJoin(as, key, curKey, curVals, _) =>
@@ -26,14 +28,14 @@ object Bindings {
else
empty
case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) =>
val contextType = TIterable.elementType(contexts.typ)
val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType)
if (i == 1)
if (i == 1) {
val contextType = TIterable.elementType(contexts.typ)
Array(ctxName -> contextType)
else if (i == 2)
} else if (i == 2) {
val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType)
Array(curKey -> eltType.typeAfterSelectNames(key),
curVals -> TArray(eltType))
else
} else
empty
case StreamFor(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
case StreamFlatMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
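The comment added above makes the contract explicit: `Bindings(x, i)` may only consult children with index below `i`, which is why `TailLoop` now reads its explicit `resultType` instead of `body.typ` and why `StreamZipJoinProducers` only touches `makeProducer.typ` once `i == 2`. A tiny usage example of the contract for `Let` (hypothetical test code, assuming `Let`, `Ref`, `I32`, and `TInt32` from the hail IR and virtual types):

```scala
// Only the body (child index 1) sees the new binding, and computing that
// binding needs nothing beyond value.typ, i.e. a child with index < 1.
val value = I32(5)                            // child 0
val let   = Let("x", value, Ref("x", TInt32)) // child 1 refers to the binding

assert(Bindings(let, 0).isEmpty)
assert(Bindings(let, 1).toSeq == Seq("x" -> TInt32))
```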
74 changes: 48 additions & 26 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
@@ -68,6 +68,8 @@ abstract sealed class BlockMatrixIR extends BaseIR {
def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR

def blockCostIsLinear: Boolean

def typecheck(): Unit = {}
}

case class BlockMatrixRead(reader: BlockMatrixReader) extends BlockMatrixIR {
@@ -246,8 +248,11 @@ case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends Bl
}

case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDense: Boolean) extends BlockMatrixIR {
override lazy val typ: BlockMatrixType = child.typ
assert(!needsDense || !typ.isSparse)
override def typecheck(): Unit = {
assert(!(needsDense && child.typ.isSparse))
}

override def typ: BlockMatrixType = child.typ

lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, f)

@@ -370,11 +375,19 @@ case object NeedsDense extends SparsityStrategy {
}
}

case class BlockMatrixMap2(left: BlockMatrixIR, right: BlockMatrixIR, leftName: String, rightName: String, f: IR, sparsityStrategy: SparsityStrategy) extends BlockMatrixIR {
assert(
left.typ.nRows == right.typ.nRows &&
left.typ.nCols == right.typ.nCols &&
left.typ.blockSize == right.typ.blockSize)
case class BlockMatrixMap2(
left: BlockMatrixIR,
right: BlockMatrixIR,
leftName: String,
rightName: String,
f: IR,
sparsityStrategy: SparsityStrategy
) extends BlockMatrixIR {
override def typecheck(): Unit = {
assert(left.typ.nRows == right.typ.nRows)
assert(left.typ.nCols == right.typ.nCols)
assert(left.typ.blockSize == right.typ.blockSize)
}

override lazy val typ: BlockMatrixType = left.typ.copy(sparsity = sparsityStrategy.mergeSparsity(left.typ.sparsity, right.typ.sparsity))

@@ -477,7 +490,7 @@ case class BlockMatrixMap2(left: BlockMatrixIR, right: BlockMatrixIR, leftName:
}

case class BlockMatrixDot(left: BlockMatrixIR, right: BlockMatrixIR) extends BlockMatrixIR {

override lazy val typ: BlockMatrixType = {
val blockSize = left.typ.blockSize
val (lRows, lCols) = BlockMatrixIR.tensorShapeToMatrixShape(left)
@@ -531,20 +543,24 @@ case class BlockMatrixBroadcast(
child: BlockMatrixIR,
inIndexExpr: IndexedSeq[Int],
shape: IndexedSeq[Long],
blockSize: Int) extends BlockMatrixIR {
blockSize: Int
) extends BlockMatrixIR {

val blockCostIsLinear: Boolean = child.blockCostIsLinear

assert(shape.length == 2)
assert(inIndexExpr.length <= 2 && inIndexExpr.forall(x => x == 0 || x == 1))

val (nRows, nCols) = BlockMatrixIR.tensorShapeToMatrixShape(child)
val childMatrixShape = IndexedSeq(nRows, nCols)
assert(inIndexExpr.zipWithIndex.forall({ case (out: Int, in: Int) =>
!child.typ.shape.contains(in) || childMatrixShape(in) == shape(out)
}))
override def typecheck(): Unit = {
val (nRows, nCols) = BlockMatrixIR.tensorShapeToMatrixShape(child)
val childMatrixShape = IndexedSeq(nRows, nCols)

assert(inIndexExpr.zipWithIndex.forall({ case (out: Int, in: Int) =>
!child.typ.shape.contains(in) || childMatrixShape(in) == shape(out)
}))
}

override val typ: BlockMatrixType = {
override lazy val typ: BlockMatrixType = {
val (tensorShape, isRowVector) = BlockMatrixIR.matrixShapeToTensorShape(shape(0), shape(1))
val nRowBlocks = BlockMatrixType.numBlocks(shape(0), blockSize)
val nColBlocks = BlockMatrixType.numBlocks(shape(1), blockSize)
@@ -626,11 +642,12 @@ case class BlockMatrixBroadcast(

case class BlockMatrixAgg(
child: BlockMatrixIR,
axesToSumOut: IndexedSeq[Int]) extends BlockMatrixIR {
axesToSumOut: IndexedSeq[Int]
) extends BlockMatrixIR {

val blockCostIsLinear: Boolean = child.blockCostIsLinear

assert(axesToSumOut.length > 0)
assert(axesToSumOut.nonEmpty)

override lazy val typ: BlockMatrixType = {
val matrixShape = BlockMatrixIR.tensorShapeToMatrixShape(child)
@@ -675,21 +692,22 @@ case class BlockMatrixAgg(

case class BlockMatrixFilter(
child: BlockMatrixIR,
indices: Array[Array[Long]]) extends BlockMatrixIR {
indices: Array[Array[Long]]
) extends BlockMatrixIR {

assert(indices.length == 2)

val blockCostIsLinear: Boolean = child.blockCostIsLinear
private[this] val Array(keepRow, keepCol) = indices
private[this] val blockSize = child.typ.blockSize

lazy val keepRowPartitioned: Array[Array[Long]] = keepRow.grouped(blockSize).toArray
lazy val keepColPartitioned: Array[Array[Long]] = keepCol.grouped(blockSize).toArray
override lazy val typ: BlockMatrixType = {
val blockSize = child.typ.blockSize
val keepRowPartitioned: Array[Array[Long]] = keepRow.grouped(blockSize).toArray
val keepColPartitioned: Array[Array[Long]] = keepCol.grouped(blockSize).toArray

lazy val rowBlockDependents: Array[Array[Int]] = child.typ.rowBlockDependents(keepRowPartitioned)
lazy val colBlockDependents: Array[Array[Int]] = child.typ.colBlockDependents(keepColPartitioned)
val rowBlockDependents: Array[Array[Int]] = child.typ.rowBlockDependents(keepRowPartitioned)
val colBlockDependents: Array[Array[Int]] = child.typ.colBlockDependents(keepColPartitioned)

override lazy val typ: BlockMatrixType = {
val childTensorShape = child.typ.shape
val childMatrixShape = (childTensorShape, child.typ.isRowVector) match {
case (IndexedSeq(vectorLength), true) => IndexedSeq(1, vectorLength)
@@ -918,7 +936,11 @@ case class BlockMatrixSlice(child: BlockMatrixIR, slices: IndexedSeq[IndexedSeq[
case class ValueToBlockMatrix(
child: IR,
shape: IndexedSeq[Long],
blockSize: Int) extends BlockMatrixIR {
blockSize: Int
) extends BlockMatrixIR {
override def typecheck(): Unit = {
assert(child.typ.isInstanceOf[TArray] || child.typ.isInstanceOf[TNDArray] || child.typ == TFloat64)
}

assert(shape.length == 2)

@@ -984,7 +1006,7 @@ case class BlockMatrixRandom(
}

case class RelationalLetBlockMatrix(name: String, value: IR, body: BlockMatrixIR) extends BlockMatrixIR {
override lazy val typ: BlockMatrixType = body.typ
override def typ: BlockMatrixType = body.typ

def childrenSeq: IndexedSeq[BaseIR] = Array(value, body)

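This file adds an empty `typecheck(): Unit` hook to `BlockMatrixIR` and moves constructor-time assertions that inspect child types (in `BlockMatrixMap`, `BlockMatrixMap2`, `BlockMatrixBroadcast`, and `ValueToBlockMatrix`) into overrides of it, so nodes can be constructed before their children are fully typed. A standalone sketch of the pattern (not hail code; `Dims` and `MatMul` are made up):

```scala
// Sketch of the pattern: invariants that look at children move out of the
// constructor body and into a typecheck() hook that a later pass invokes
// once all child types are known.
abstract class Node {
  def typecheck(): Unit = {}
}

final case class Dims(rows: Long, cols: Long) extends Node

final case class MatMul(left: Dims, right: Dims) extends Node {
  // Checked lazily by the typechecking pass instead of at construction time.
  override def typecheck(): Unit = assert(left.cols == right.rows)
}
```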
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Children.scala
@@ -40,7 +40,7 @@ object Children {
Array(value, body)
case AggLet(name, value, body, _) =>
Array(value, body)
case TailLoop(_, args, body) =>
case TailLoop(_, args, _, body) =>
args.map(_._2).toFastSeq :+ body
case Recur(_, args, _) =>
args.toFastSeq
@@ -227,7 +227,7 @@ object Children {
case Trap(child) => Array(child)
case ConsoleLog(message, result) =>
Array(message, result)
case ApplyIR(_, _, args, _) =>
case ApplyIR(_, _, args, _, _) =>
args.toFastSeq
case Apply(_, _, args, _, _) =>
args.toFastSeq