diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py
index bb1716eb613..5b54f1dd482 100644
--- a/hail/python/hail/ir/ir.py
+++ b/hail/python/hail/ir/ir.py
@@ -454,7 +454,7 @@ def copy(self, *children):
         return TailLoop(self.name, [(n, v) for (n, _), v in zip(self.params, params)], body)
 
     def head_str(self):
-        return f'{escape_id(self.name)} ({" ".join([escape_id(n) for n, _ in self.params])})'
+        return f'{escape_id(self.name)} ({" ".join([escape_id(n) for n, _ in self.params])}) {self.body.typ._parsable_string()}'
 
     def _eq(self, other):
         return self.name == other.name
@@ -493,7 +493,7 @@ def copy(self, args):
         return Recur(self.name, args, self.return_type)
 
     def head_str(self):
-        return f'{escape_id(self.name)} {self.return_type._parsable_string()}'
+        return f'{escape_id(self.name)}'
 
     def _eq(self, other):
         return other.name == self.name
diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py
index 44fcc8f44b6..06644337b19 100644
--- a/hail/python/hail/ir/matrix_ir.py
+++ b/hail/python/hail/ir/matrix_ir.py
@@ -1183,7 +1183,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name):
         return MatrixFilterIntervals(child, self.intervals, self.point_type, self.keep)
 
     def head_str(self):
-        return f'{dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'
+        return f'{self.child.typ.row_key_type._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'
 
     def _eq(self, other):
         return self.intervals == other.intervals and self.point_type == other.point_type and self.keep == other.keep
diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py
index 38fcaae51c3..22c0530c35a 100644
--- a/hail/python/hail/ir/table_ir.py
+++ b/hail/python/hail/ir/table_ir.py
@@ -881,7 +881,7 @@ def _handle_randomness(self, uid_field_name):
         return TableFilterIntervals(self.child.handle_randomness(uid_field_name), self.intervals, self.point_type, self.keep)
 
     def head_str(self):
-        return f'{dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'
+        return f'{self.child.typ.key_type._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'
 
     def _eq(self, other):
         return self.intervals == other.intervals and self.point_type == other.point_type and self.keep == other.keep
diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py
index b2a53f670b7..b2340e99e33 100644
--- a/hail/python/test/hail/test_ir.py
+++ b/hail/python/test/hail/test_ir.py
@@ -22,8 +22,9 @@ def value_irs_env(self):
             'mat': hl.tndarray(hl.tfloat64, 2),
             'aa': hl.tarray(hl.tarray(hl.tint32)),
             'sta': hl.tstream(hl.tarray(hl.tint32)),
-            'da': hl.tarray(hl.ttuple(hl.tint32, hl.tstr)),
-            'nd': hl.tndarray(hl.tfloat64, 1),
+            'sts': hl.tstream(hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64)),
+            'da': hl.tstream(hl.ttuple(hl.tint32, hl.tstr)),
+            'nd': hl.tndarray(hl.tfloat64, 2),
             'v': hl.tint32,
             's': hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64),
             't': hl.ttuple(hl.tint32, hl.tint64, hl.tfloat64),
@@ -42,6 +43,7 @@ def value_irs(self):
         mat = ir.Ref('mat')
         aa = ir.Ref('aa', env['aa'])
         sta = ir.Ref('sta', env['sta'])
+        sts = ir.Ref('sts', env['sts'])
         da = ir.Ref('da', env['da'])
         nd = ir.Ref('nd', env['nd'])
         v = ir.Ref('v', env['v'])
@@ -77,7 +79,7 @@ def aggregate(x):
             ir.ArrayRef(a, i),
             ir.ArrayLen(a),
             ir.ArraySort(ir.ToStream(a), 'l', 'r', ir.ApplyComparisonOp("LT", ir.Ref('l', hl.tint32), ir.Ref('r', hl.tint32))),
-            ir.ToSet(a),
+            ir.ToSet(st),
             ir.ToDict(da),
             ir.ToArray(st),
             ir.CastToArray(ir.NA(hl.tset(hl.tint32))),
@@ -89,17 +91,17 @@ def aggregate(x):
             ir.NDArrayRef(nd, [ir.I64(1), ir.I64(2)]),
             ir.NDArrayMap(nd, 'v', v),
             ir.NDArrayMatMul(nd, nd),
-            ir.LowerBoundOnOrderedCollection(a, i, True),
+            ir.LowerBoundOnOrderedCollection(a, i, False),
             ir.GroupByKey(da),
-            ir.RNGSplit(rngState, ir.MakeTuple([ir.I64(1), ir.MakeTuple([ir.I64(2), ir.I64(3)])])),
+            ir.RNGSplit(rngState, ir.MakeTuple([ir.I64(1), ir.I64(2), ir.I64(3)])),
             ir.StreamMap(st, 'v', v),
             ir.StreamZip([st, st], ['a', 'b'], ir.TrueIR(), 'ExtendNA'),
-            ir.StreamFilter(st, 'v', v),
+            ir.StreamFilter(st, 'v', c),
             ir.StreamFlatMap(sta, 'v', ir.ToStream(v)),
             ir.StreamFold(st, ir.I32(0), 'x', 'v', v),
             ir.StreamScan(st, ir.I32(0), 'x', 'v', v),
-            ir.StreamWhiten(whitenStream, "newChunk", "prevWindow", 0, 0, 0, 0, False),
-            ir.StreamJoinRightDistinct(st, st, ['k'], ['k'], 'l', 'r', ir.I32(1), "left"),
+            ir.StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, False),
+            ir.StreamJoinRightDistinct(sts, sts, ['x'], ['x'], 'l', 'r', ir.I32(1), "left"),
             ir.StreamFor(st, 'v', ir.Void()),
             aggregate(ir.AggFilter(ir.TrueIR(), ir.I32(0), False)),
             aggregate(ir.AggExplode(ir.StreamRange(ir.I32(0), ir.I32(2), ir.I32(1)), 'x', ir.I32(0), False)),
diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala
index 4fb3e28e7de..0f77f6d91ed 100644
--- a/hail/src/main/scala/is/hail/backend/Backend.scala
+++ b/hail/src/main/scala/is/hail/backend/Backend.scala
@@ -151,7 +151,7 @@ abstract class Backend {
   def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T
 
   final def valueType(s: String): Array[Byte] = {
-    withExecuteContext("tableType") { ctx =>
+    withExecuteContext("valueType") { ctx =>
       val v = IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
       v.typ.toString.getBytes(StandardCharsets.UTF_8)
     }
diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
index 57270f3b75d..b133f622bd4 100644
--- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
@@ -296,7 +296,7 @@ class LocalBackend(
   def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = {
     ExecutionTimer.logTime("LocalBackend.parse_value_ir") { timer =>
       withExecuteContext(timer) { ctx =>
-        IRParser.parse_value_ir(s, IRParserEnvironment(ctx, BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), persistedIR.toMap))
+        IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap), BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*))
       }
     }
   }
diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
index f1ddfb72172..7b48b641487 100644
--- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -670,7 +670,7 @@ class SparkBackend(
   def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = {
     ExecutionTimer.logTime("SparkBackend.parse_value_ir") { timer =>
       withExecuteContext(timer) { ctx =>
-        IRParser.parse_value_ir(s, IRParserEnvironment(ctx, BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), irMap = persistedIR.toMap))
+        IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap), BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*))
       }
     }
   }
diff --git a/hail/src/main/scala/is/hail/expr/ir/AggOp.scala b/hail/src/main/scala/is/hail/expr/ir/AggOp.scala
index 3ffb10e9228..81f500a4288 100644
--- a/hail/src/main/scala/is/hail/expr/ir/AggOp.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/AggOp.scala
@@ -26,9 +26,9 @@ object AggSignature {
 
 case class AggSignature(
   op: AggOp,
-  initOpArgs: Seq[Type],
-  seqOpArgs: Seq[Type]) {
-
+  var initOpArgs: Seq[Type],
+  var seqOpArgs: Seq[Type]
+) {
   // only to be used with virtual non-nested signatures on ApplyAggOp and ApplyScanOp
   lazy val returnType: Type = Extract.getResultType(this)
 }
diff --git a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
index 7fc1eb4d385..b3008b7a6c9 100644
--- a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
@@ -1,6 +1,7 @@
 package is.hail.expr.ir
 
 import is.hail.types.BaseType
+import is.hail.types.virtual.Type
 import is.hail.utils.StackSafe._
 import is.hail.utils._
 
@@ -50,4 +51,33 @@ abstract class BaseIR {
       copy(newChildren)
     }
   }
+
+  def forEachChildWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => Unit): Unit = {
+    childrenSeq.view.zipWithIndex.foreach { case (child, i) =>
+      val childEnv = ChildBindings(this, i, env)
+      f(child, childEnv)
+    }
+  }
+
+  def mapChildrenWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => BaseIR): BaseIR = {
+    val newChildren = childrenSeq.toArray
+    var res = this
+    for (i <- newChildren.indices) {
+      val childEnv = ChildBindings(res, i, env)
+      val child = newChildren(i)
+      val newChild = f(child, childEnv)
+      if (!(newChild eq child)) {
+        newChildren(i) = newChild
+        res = res.copy(newChildren)
+      }
+    }
+    res
+  }
+
+  def forEachChildWithEnvStackSafe(env: BindingEnv[Type])(f: (BaseIR, Int, BindingEnv[Type]) => StackFrame[Unit]): StackFrame[Unit] = {
+    childrenSeq.view.zipWithIndex.foreachRecur { case (child, i) =>
+      val childEnv = ChildBindings(this, i, env)
+      f(child, i, childEnv)
+    }
+  }
 }
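A quick sketch of how the new environment-threading helpers can be used (illustrative only; freeRefNames is a hypothetical pass, and Env.lookupOption is assumed from the surrounding codebase):

    // Illustrative only: collect names of references not bound in the
    // environment, threading bindings with the new forEachChildWithEnv.
    def freeRefNames(ir: BaseIR, env: BindingEnv[Type]): Set[String] = {
      val acc = scala.collection.mutable.Set.empty[String]
      def go(x: BaseIR, e: BindingEnv[Type]): Unit = {
        x match {
          case Ref(name, _) if e.eval.lookupOption(name).isEmpty => acc += name
          case _ =>
        }
        x.forEachChildWithEnv(e)(go)
      }
      go(ir, env)
      acc.toSet
    }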
diff --git a/hail/src/main/scala/is/hail/expr/ir/Binds.scala b/hail/src/main/scala/is/hail/expr/ir/Binds.scala
index b095aaa735e..1c6f3049b5a 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Binds.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Binds.scala
@@ -11,11 +11,13 @@ object Binds {
 object Bindings {
   private val empty: Array[(String, Type)] = Array()
 
+  // A call to Bindings(x, i) may only query the types of children with
+  // index < i
   def apply(x: BaseIR, i: Int): Iterable[(String, Type)] = x match {
     case Let(name, value, _) => if (i == 1) Array(name -> value.typ) else empty
-    case TailLoop(name, args, body) => if (i == args.length)
+    case TailLoop(name, args, resultType, _) => if (i == args.length)
       args.map { case (name, ir) => name -> ir.typ } :+
-        name -> TTuple(TTuple(args.map(_._2.typ): _*), body.typ) else empty
+        name -> TTuple(TTuple(args.map(_._2.typ): _*), resultType) else empty
     case StreamMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
     case StreamZip(as, names, _, _, _) => if (i == as.length) names.zip(as.map(a => tcoerce[TStream](a.typ).elementType)) else empty
     case StreamZipJoin(as, key, curKey, curVals, _) =>
@@ -26,14 +28,14 @@ object Bindings {
       else
         empty
     case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) =>
-      val contextType = TIterable.elementType(contexts.typ)
-      val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType)
-      if (i == 1)
+      if (i == 1) {
+        val contextType = TIterable.elementType(contexts.typ)
         Array(ctxName -> contextType)
-      else if (i == 2)
+      } else if (i == 2) {
+        val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType)
         Array(curKey -> eltType.typeAfterSelectNames(key), curVals -> TArray(eltType))
-      else
+      } else
         empty
     case StreamFor(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
     case StreamFlatMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
diff --git a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
index fe4099ccfb6..b56a02edbc2 100644
--- a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
@@ -68,6 +68,8 @@ abstract sealed class BlockMatrixIR extends BaseIR {
   def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixIR
 
   def blockCostIsLinear: Boolean
+
+  def typecheck(): Unit = {}
 }
 
 case class BlockMatrixRead(reader: BlockMatrixReader) extends BlockMatrixIR {
@@ -246,8 +248,11 @@ case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends Bl
 }
 
 case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDense: Boolean) extends BlockMatrixIR {
-  override lazy val typ: BlockMatrixType = child.typ
-  assert(!needsDense || !typ.isSparse)
+  override def typecheck(): Unit = {
+    assert(!(needsDense && child.typ.isSparse))
+  }
+
+  override def typ: BlockMatrixType = child.typ
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, f)
 
@@ -370,11 +375,19 @@ case object NeedsDense extends SparsityStrategy {
   }
 }
 
-case class BlockMatrixMap2(left: BlockMatrixIR, right: BlockMatrixIR, leftName: String, rightName: String, f: IR, sparsityStrategy: SparsityStrategy) extends BlockMatrixIR {
-  assert(
-    left.typ.nRows == right.typ.nRows &&
-    left.typ.nCols == right.typ.nCols &&
-    left.typ.blockSize == right.typ.blockSize)
+case class BlockMatrixMap2(
+  left: BlockMatrixIR,
+  right: BlockMatrixIR,
+  leftName: String,
+  rightName: String,
+  f: IR,
+  sparsityStrategy: SparsityStrategy
+) extends BlockMatrixIR {
+  override def typecheck(): Unit = {
+    assert(left.typ.nRows == right.typ.nRows)
+    assert(left.typ.nCols == right.typ.nCols)
+    assert(left.typ.blockSize == right.typ.blockSize)
+  }
 
   override lazy val typ: BlockMatrixType = left.typ.copy(sparsity = sparsityStrategy.mergeSparsity(left.typ.sparsity, right.typ.sparsity))
 
@@ -477,7 +490,6 @@ case class BlockMatrixMap2(left: BlockMatrixIR, right: BlockMatrixIR, leftName:
 }
 
 case class BlockMatrixDot(left: BlockMatrixIR, right: BlockMatrixIR) extends BlockMatrixIR {
-
   override lazy val typ: BlockMatrixType = {
     val blockSize = left.typ.blockSize
     val (lRows, lCols) = BlockMatrixIR.tensorShapeToMatrixShape(left)
@@ -531,20 +543,24 @@ case class BlockMatrixBroadcast(
   child: BlockMatrixIR,
   inIndexExpr: IndexedSeq[Int],
   shape: IndexedSeq[Long],
-  blockSize: Int) extends BlockMatrixIR {
+  blockSize: Int
+) extends BlockMatrixIR {
 
   val blockCostIsLinear: Boolean = child.blockCostIsLinear
 
   assert(shape.length == 2)
   assert(inIndexExpr.length <= 2 && inIndexExpr.forall(x => x == 0 || x == 1))
 
-  val (nRows, nCols) = BlockMatrixIR.tensorShapeToMatrixShape(child)
-  val childMatrixShape = IndexedSeq(nRows, nCols)
-  assert(inIndexExpr.zipWithIndex.forall({ case (out: Int, in: Int) =>
-    !child.typ.shape.contains(in) || childMatrixShape(in) == shape(out)
-  }))
+  override def typecheck(): Unit = {
+    val (nRows, nCols) = BlockMatrixIR.tensorShapeToMatrixShape(child)
+    val childMatrixShape = IndexedSeq(nRows, nCols)
+
+    assert(inIndexExpr.zipWithIndex.forall({ case (out: Int, in: Int) =>
+      !child.typ.shape.contains(in) || childMatrixShape(in) == shape(out)
+    }))
+  }
 
-  override val typ: BlockMatrixType = {
+  override lazy val typ: BlockMatrixType = {
     val (tensorShape, isRowVector) = BlockMatrixIR.matrixShapeToTensorShape(shape(0), shape(1))
     val nRowBlocks = BlockMatrixType.numBlocks(shape(0), blockSize)
    val nColBlocks = BlockMatrixType.numBlocks(shape(1), blockSize)
@@ -626,11 +642,12 @@ case class BlockMatrixBroadcast(
 
 case class BlockMatrixAgg(
   child: BlockMatrixIR,
-  axesToSumOut: IndexedSeq[Int]) extends BlockMatrixIR {
+  axesToSumOut: IndexedSeq[Int]
+) extends BlockMatrixIR {
 
   val blockCostIsLinear: Boolean = child.blockCostIsLinear
 
-  assert(axesToSumOut.length > 0)
+  assert(axesToSumOut.nonEmpty)
 
   override lazy val typ: BlockMatrixType = {
     val matrixShape = BlockMatrixIR.tensorShapeToMatrixShape(child)
@@ -675,21 +692,22 @@ case class BlockMatrixAgg(
 
 case class BlockMatrixFilter(
   child: BlockMatrixIR,
-  indices: Array[Array[Long]]) extends BlockMatrixIR {
+  indices: Array[Array[Long]]
+) extends BlockMatrixIR {
 
   assert(indices.length == 2)
 
   val blockCostIsLinear: Boolean = child.blockCostIsLinear
   private[this] val Array(keepRow, keepCol) = indices
-  private[this] val blockSize = child.typ.blockSize
-  lazy val keepRowPartitioned: Array[Array[Long]] = keepRow.grouped(blockSize).toArray
-  lazy val keepColPartitioned: Array[Array[Long]] = keepCol.grouped(blockSize).toArray
+  override lazy val typ: BlockMatrixType = {
+    val blockSize = child.typ.blockSize
+    val keepRowPartitioned: Array[Array[Long]] = keepRow.grouped(blockSize).toArray
+    val keepColPartitioned: Array[Array[Long]] = keepCol.grouped(blockSize).toArray
 
-  lazy val rowBlockDependents: Array[Array[Int]] = child.typ.rowBlockDependents(keepRowPartitioned)
-  lazy val colBlockDependents: Array[Array[Int]] = child.typ.colBlockDependents(keepColPartitioned)
+    val rowBlockDependents: Array[Array[Int]] = child.typ.rowBlockDependents(keepRowPartitioned)
+    val colBlockDependents: Array[Array[Int]] = child.typ.colBlockDependents(keepColPartitioned)
 
-  override lazy val typ: BlockMatrixType = {
     val childTensorShape = child.typ.shape
     val childMatrixShape = (childTensorShape, child.typ.isRowVector) match {
       case (IndexedSeq(vectorLength), true) => IndexedSeq(1, vectorLength)
@@ -918,7 +936,11 @@ case class BlockMatrixSlice(child: BlockMatrixIR, slices: IndexedSeq[IndexedSeq[
 case class ValueToBlockMatrix(
   child: IR,
   shape: IndexedSeq[Long],
-  blockSize: Int) extends BlockMatrixIR {
+  blockSize: Int
+) extends BlockMatrixIR {
+  override def typecheck(): Unit = {
+    assert(child.typ.isInstanceOf[TArray] || child.typ.isInstanceOf[TNDArray] || child.typ == TFloat64)
+  }
 
   assert(shape.length == 2)
 
@@ -984,7 +1006,7 @@ case class BlockMatrixRandom(
 }
 
 case class RelationalLetBlockMatrix(name: String, value: IR, body: BlockMatrixIR) extends BlockMatrixIR {
-  override lazy val typ: BlockMatrixType = body.typ
+  override def typ: BlockMatrixType = body.typ
 
   def childrenSeq: IndexedSeq[BaseIR] = Array(value, body)
diff --git a/hail/src/main/scala/is/hail/expr/ir/Children.scala b/hail/src/main/scala/is/hail/expr/ir/Children.scala
index 7aec32d3d53..d81ab5c764e 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Children.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Children.scala
@@ -40,7 +40,7 @@ object Children {
       Array(value, body)
     case AggLet(name, value, body, _) =>
       Array(value, body)
-    case TailLoop(_, args, body) =>
+    case TailLoop(_, args, _, body) =>
      args.map(_._2).toFastSeq :+ body
     case Recur(_, args, _) =>
       args.toFastSeq
@@ -227,7 +227,7 @@ object Children {
       Array(child)
     case ConsoleLog(message, result) =>
       Array(message, result)
-    case ApplyIR(_, _, args, _) =>
+    case ApplyIR(_, _, args, _, _) =>
       args.toFastSeq
     case Apply(_, _, args, _, _) =>
       args.toFastSeq
diff --git a/hail/src/main/scala/is/hail/expr/ir/ComparisonOp.scala b/hail/src/main/scala/is/hail/expr/ir/ComparisonOp.scala
index b406445130f..f9d83ebb980 100644
--- a/hail/src/main/scala/is/hail/expr/ir/ComparisonOp.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/ComparisonOp.scala
@@ -14,28 +14,21 @@ object ComparisonOp {
     if (lt != rt)
       throw new RuntimeException(s"Cannot compare types $lt and $rt")
 
-  val fromStringAndTypes: PartialFunction[(String, Type, Type), ComparisonOp[_]] = {
-    case ("==" | "EQ", t1, t2) =>
-      checkCompatible(t1, t2)
-      EQ(t1, t2)
-    case ("!=" | "NEQ", t1, t2) =>
-      checkCompatible(t1, t2)
-      NEQ(t1, t2)
-    case (">=" | "GTEQ", t1, t2) =>
-      checkCompatible(t1, t2)
-      GTEQ(t1, t2)
-    case ("<=" | "LTEQ", t1, t2) =>
-      checkCompatible(t1, t2)
-      LTEQ(t1, t2)
-    case (">" | "GT", t1, t2) =>
-      checkCompatible(t1, t2)
-      GT(t1, t2)
-    case ("<" | "LT", t1, t2) =>
-      checkCompatible(t1, t2)
-      LT(t1, t2)
-    case ("Compare", t1, t2) =>
-      checkCompatible(t1, t2)
-      Compare(t1, t2)
+  val fromString: PartialFunction[String, ComparisonOp[_]] = {
+    case "==" | "EQ" =>
+      EQ(null, null)
+    case "!=" | "NEQ" =>
+      NEQ(null, null)
+    case ">=" | "GTEQ" =>
+      GTEQ(null, null)
+    case "<=" | "LTEQ" =>
+      LTEQ(null, null)
+    case ">" | "GT" =>
+      GT(null, null)
+    case "<" | "LT" =>
+      LT(null, null)
+    case "Compare" =>
+      Compare(null, null)
   }
 
   def negate(op: ComparisonOp[Boolean]): ComparisonOp[Boolean] = {
@@ -84,33 +77,56 @@ sealed trait ComparisonOp[ReturnType] {
   }
 
   def render(): is.hail.utils.prettyPrint.Doc = Pretty.prettyClass(this)
+
+  def copy(t1: Type, t2: Type): ComparisonOp[ReturnType]
 }
 
-case class GT(t1: Type, t2: Type) extends ComparisonOp[Boolean] { val op: CodeOrdering.Op = CodeOrdering.Gt() }
+case class GT(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
+  val op: CodeOrdering.Op = CodeOrdering.Gt()
+  override def copy(t1: Type = t1, t2: Type = t2): GT = GT(t1, t2)
+}
 object GT { def apply(typ: Type): GT = GT(typ, typ) }
-case class GTEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] { val op: CodeOrdering.Op = CodeOrdering.Gteq() }
+case class GTEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
+  val op: CodeOrdering.Op = CodeOrdering.Gteq()
+  override def copy(t1: Type = t1, t2: Type = t2): GTEQ = GTEQ(t1, t2)
+}
 object GTEQ { def apply(typ: Type): GTEQ = GTEQ(typ, typ) }
-case class LTEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] { val op: CodeOrdering.Op = CodeOrdering.Lteq() }
+case class LTEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
+  val op: CodeOrdering.Op = CodeOrdering.Lteq()
+  override def copy(t1: Type = t1, t2: Type = t2): LTEQ = LTEQ(t1, t2)
+}
 object LTEQ { def apply(typ: Type): LTEQ = LTEQ(typ, typ) }
-case class LT(t1: Type, t2: Type) extends ComparisonOp[Boolean] { val op: CodeOrdering.Op = CodeOrdering.Lt() }
+case class LT(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
+  val op: CodeOrdering.Op = CodeOrdering.Lt()
+  override def copy(t1: Type = t1, t2: Type = t2): LT = LT(t1, t2)
+}
 object LT { def apply(typ: Type): LT = LT(typ, typ) }
-case class EQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] { val op: CodeOrdering.Op = CodeOrdering.Equiv() }
+case class EQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
+  val op: CodeOrdering.Op = CodeOrdering.Equiv()
+  override def copy(t1: Type = t1, t2: Type = t2): EQ = EQ(t1, t2)
+}
 object EQ { def apply(typ: Type): EQ = EQ(typ, typ) }
-case class NEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] { val op: CodeOrdering.Op = CodeOrdering.Neq() }
+case class NEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
+  val op: CodeOrdering.Op = CodeOrdering.Neq()
+  override def copy(t1: Type = t1, t2: Type = t2): NEQ = NEQ(t1, t2)
+}
 object NEQ { def apply(typ: Type): NEQ = NEQ(typ, typ) }
 
 case class EQWithNA(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
   val op: CodeOrdering.Op = CodeOrdering.Equiv()
   override val strict: Boolean = false
+  override def copy(t1: Type = t1, t2: Type = t2): EQWithNA = EQWithNA(t1, t2)
 }
 object EQWithNA { def apply(typ: Type): EQWithNA = EQWithNA(typ, typ) }
 
 case class NEQWithNA(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
   val op: CodeOrdering.Op = CodeOrdering.Neq()
   override val strict: Boolean = false
+  override def copy(t1: Type = t1, t2: Type = t2): NEQWithNA = NEQWithNA(t1, t2)
 }
 object NEQWithNA { def apply(typ: Type): NEQWithNA = NEQWithNA(typ, typ) }
 
 case class Compare(t1: Type, t2: Type) extends ComparisonOp[Int] {
   override val strict: Boolean = false
   val op: CodeOrdering.Op = CodeOrdering.Compare()
+  override def copy(t1: Type = t1, t2: Type = t2): Compare = Compare(t1, t2)
 }
 object Compare { def apply(typ: Type): Compare = Compare(typ, typ) }
@@ -126,28 +142,33 @@ trait StructComparisonOp[T] extends ComparisonOp[T] {
 
 case class StructCompare(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Int] {
   val op: CodeOrdering.Op = CodeOrdering.StructCompare()
   override val strict: Boolean = false
+  override def copy(t1: Type = t1, t2: Type = t2): StructCompare = StructCompare(t1, t2, sortFields)
 }
 
 case class StructLT(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Boolean] {
   val op: CodeOrdering.Op = CodeOrdering.StructLt()
+  override def copy(t1: Type = t1, t2: Type = t2): StructLT = StructLT(t1, t2, sortFields)
 }
 object StructLT { def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructLT = StructLT(typ, typ, sortFields.toArray) }
 
 case class StructLTEQ(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Boolean] {
   val op: CodeOrdering.Op = CodeOrdering.StructLteq()
+  override def copy(t1: Type = t1, t2: Type = t2): StructLTEQ = StructLTEQ(t1, t2, sortFields)
 }
 object StructLTEQ { def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructLTEQ = StructLTEQ(typ, typ, sortFields.toArray) }
 
 case class StructGT(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Boolean] {
   val op: CodeOrdering.Op = CodeOrdering.StructGt()
+  override def copy(t1: Type = t1, t2: Type = t2): StructGT = StructGT(t1, t2, sortFields)
 }
 object StructGT { def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructGT = StructGT(typ, typ, sortFields.toArray) }
 
 case class StructGTEQ(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Boolean] {
   val op: CodeOrdering.Op = CodeOrdering.StructGteq()
+  override def copy(t1: Type = t1, t2: Type = t2): StructGTEQ = StructGTEQ(t1, t2, sortFields)
 }
 object StructGTEQ { def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructGTEQ = StructGTEQ(typ, typ, sortFields.toArray) }
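Because `fromString` now builds ops with null operand types, some later pass must complete them once child types are known; a minimal sketch of that step under that assumption (`fixComparisonOp` is hypothetical, not part of this diff):

    // Illustrative only: patch an untyped comparison in place. This relies on
    // two changes in this diff: ApplyComparisonOp.op is now a var, and every
    // ComparisonOp implements copy(t1, t2).
    def fixComparisonOp(x: ApplyComparisonOp): Unit =
      if (x.op.t1 == null || x.op.t2 == null)
        x.op = x.op.copy(x.l.typ, x.r.typ)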
diff --git a/hail/src/main/scala/is/hail/expr/ir/Copy.scala b/hail/src/main/scala/is/hail/expr/ir/Copy.scala
index 5e1bd4a4030..93f6ef4913b 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Copy.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Copy.scala
@@ -40,9 +40,9 @@ object Copy {
     case AggLet(name, _, _, isScan) =>
       assert(newChildren.length == 2)
       AggLet(name, newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], isScan)
-    case TailLoop(name, params, _) =>
+    case TailLoop(name, params, resultType, _) =>
       assert(newChildren.length == params.length + 1)
-      TailLoop(name, params.map(_._1).zip(newChildren.init.map(_.asInstanceOf[IR])), newChildren.last.asInstanceOf[IR])
+      TailLoop(name, params.map(_._1).zip(newChildren.init.map(_.asInstanceOf[IR])), resultType, newChildren.last.asInstanceOf[IR])
     case Recur(name, args, t) =>
       assert(newChildren.length == args.length)
       Recur(name, newChildren.map(_.asInstanceOf[IR]), t)
@@ -339,8 +339,8 @@ object Copy {
     case ConsoleLog(message, result) =>
       assert(newChildren.length == 2)
       ConsoleLog(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR])
-    case x@ApplyIR(fn, typeArgs, args, errorID) =>
-      val r = ApplyIR(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), errorID)
+    case x@ApplyIR(fn, typeArgs, args, rt, errorID) =>
+      val r = ApplyIR(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), rt, errorID)
       r.conversion = x.conversion
       r.inline = x.inline
       r
diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala
index ca9c1cefcd1..ecfc03297ef 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala
@@ -2368,7 +2368,7 @@ class Emit[C](
         }
       }
 
-      case x@TailLoop(name, args, body) =>
+      case x@TailLoop(name, args, _, body) =>
        val loopStartLabel = CodeLabel()
 
         val accTypes = ctx.req.lookupState(x).zip(args.map(_._2.typ))
diff --git a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala
index ccdb45572a8..4619fe17287 100644
--- a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala
@@ -755,7 +755,7 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
             .restrict(keySet)
         case (IsNA(_), Seq(b: BoolValue)) => b.isNA.restrict(keySet)
         // collection contains
-        case (ApplyIR("contains", _, _, _), Seq(ConstantValue(collectionVal), queryVal)) if literalSizeOkay(collectionVal) =>
+        case (ApplyIR("contains", _, _, _, _), Seq(ConstantValue(collectionVal), queryVal)) if literalSizeOkay(collectionVal) =>
           if (collectionVal == null) {
             BoolValue.allNA(keySet)
           } else queryVal match {
diff --git a/hail/src/main/scala/is/hail/expr/ir/IR.scala b/hail/src/main/scala/is/hail/expr/ir/IR.scala
index 54367a89db9..ccaba5925d5 100644
--- a/hail/src/main/scala/is/hail/expr/ir/IR.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/IR.scala
@@ -196,22 +196,27 @@ sealed abstract class BaseRef extends IR with TrivialIR {
   def _typ: Type
 }
 
-final case class Ref(name: String, var _typ: Type) extends BaseRef
+final case class Ref(name: String, var _typ: Type) extends BaseRef {
+  override def typ: Type = {
+    assert(_typ != null)
+    _typ
+  }
+}
 
 // Recur can't exist outside of loop
 // Loops can be nested, but we can't call outer loops in terms of inner loops so there can only be one loop "active" in a given context
-final case class TailLoop(name: String, params: IndexedSeq[(String, IR)], body: IR) extends IR {
+final case class TailLoop(name: String, params: IndexedSeq[(String, IR)], resultType: Type, body: IR) extends IR {
   lazy val paramIdx: Map[String, Int] = params.map(_._1).zipWithIndex.toMap
 }
-final case class Recur(name: String, args: IndexedSeq[IR], _typ: Type) extends BaseRef
+final case class Recur(name: String, args: IndexedSeq[IR], var _typ: Type) extends BaseRef
 
 final case class RelationalLet(name: String, value: IR, body: IR) extends IR
 final case class RelationalRef(name: String, _typ: Type) extends BaseRef
 
 final case class ApplyBinaryPrimOp(op: BinaryOp, l: IR, r: IR) extends IR
 final case class ApplyUnaryPrimOp(op: UnaryOp, x: IR) extends IR
-final case class ApplyComparisonOp(op: ComparisonOp[_], l: IR, r: IR) extends IR
+final case class ApplyComparisonOp(var op: ComparisonOp[_], l: IR, r: IR) extends IR
 
 object MakeArray {
   def apply(args: IR*): MakeArray = {
@@ -319,8 +324,7 @@ final case class StreamLen(a: IR) extends IR
 final case class StreamGrouped(a: IR, groupSize: IR) extends IR
 final case class StreamGroupByKey(a: IR, key: IndexedSeq[String], missingEqual: Boolean) extends IR
 
-final case class StreamMap(a: IR, name: String, body: IR) extends IR {
-  override def typ: TStream = tcoerce[TStream](super.typ)
+final case class StreamMap(a: IR, name: String, body: IR) extends TypedIR[TStream] {
   def elementTyp: Type = typ.elementType
 }
@@ -354,34 +358,30 @@ object ArrayZipBehavior extends Enumeration {
   val ExtendNA: Value = Value(3)
 }
 
-final case class StreamZip(as: IndexedSeq[IR], names: IndexedSeq[String], body: IR, behavior: ArrayZipBehavior,
-  errorID: Int = ErrorIDs.NO_ERROR) extends IR {
-  lazy val nameIdx: Map[String, Int] = names.zipWithIndex.toMap
-  override def typ: TStream = tcoerce[TStream](super.typ)
-}
-final case class StreamMultiMerge(as: IndexedSeq[IR], key: IndexedSeq[String]) extends IR {
-  override def typ: TStream = tcoerce[TStream](super.typ)
-}
+final case class StreamZip(
+  as: IndexedSeq[IR], names: IndexedSeq[String], body: IR, behavior: ArrayZipBehavior,
+  errorID: Int = ErrorIDs.NO_ERROR
+) extends TypedIR[TStream]
+
+final case class StreamMultiMerge(as: IndexedSeq[IR], key: IndexedSeq[String]) extends TypedIR[TStream]
 
-final case class StreamZipJoinProducers(contexts: IR, ctxName: String, makeProducer: IR,
-  key: IndexedSeq[String], curKey: String, curVals: String, joinF: IR) extends IR {
-  override def typ: TStream = tcoerce[TStream](super.typ)
-}
+
+final case class StreamZipJoinProducers(
+  contexts: IR, ctxName: String, makeProducer: IR, key: IndexedSeq[String],
+  curKey: String, curVals: String, joinF: IR
+) extends TypedIR[TStream]
 
 /**
   * The StreamZipJoin node assumes that input streams have distinct keys. If input streams
  * do not have distinct keys, the key that is included in the result is undefined, but
  * is likely the last.
  */
-final case class StreamZipJoin(as: IndexedSeq[IR], key: IndexedSeq[String], curKey: String, curVals: String, joinF: IR) extends IR {
-  override def typ: TStream = tcoerce[TStream](super.typ)
-}
-final case class StreamFilter(a: IR, name: String, cond: IR) extends IR {
-  override def typ: TStream = tcoerce[TStream](super.typ)
-}
-final case class StreamFlatMap(a: IR, name: String, body: IR) extends IR {
-  override def typ: TStream = tcoerce[TStream](super.typ)
-}
+final case class StreamZipJoin(
+  as: IndexedSeq[IR], key: IndexedSeq[String], curKey: String, curVals: String, joinF: IR
+) extends TypedIR[TStream]
+
+final case class StreamFilter(a: IR, name: String, cond: IR) extends TypedIR[TStream]
+final case class StreamFlatMap(a: IR, name: String, body: IR) extends TypedIR[TStream]
 
 final case class StreamFold(a: IR, zero: IR, accumName: String, valueName: String, body: IR) extends IR
@@ -661,10 +661,7 @@ final case class SelectFields(old: IR, fields: IndexedSeq[String]) extends IR
 object InsertFields {
   def apply(old: IR, fields: Seq[(String, IR)]): InsertFields = InsertFields(old, fields, None)
 }
 
-final case class InsertFields(old: IR, fields: Seq[(String, IR)], fieldOrder: Option[IndexedSeq[String]]) extends IR {
-
-  override def typ: TStruct = tcoerce[TStruct](super.typ)
-}
+final case class InsertFields(old: IR, fields: Seq[(String, IR)], fieldOrder: Option[IndexedSeq[String]]) extends TypedIR[TStruct]
 
 object GetFieldByIdx {
   def apply(s: IR, field: Int): IR = {
@@ -715,7 +712,7 @@ final case class Trap(child: IR) extends IR
 final case class Die(message: IR, _typ: Type, errorId: Int) extends IR
 final case class ConsoleLog(message: IR, result: IR) extends IR
 
-final case class ApplyIR(function: String, typeArgs: Seq[Type], args: Seq[IR], errorID: Int) extends IR {
+final case class ApplyIR(function: String, typeArgs: Seq[Type], args: Seq[IR], returnType: Type, errorID: Int) extends IR {
   var conversion: (Seq[Type], Seq[IR], Int) => IR = _
   var inline: Boolean = _
 
@@ -725,7 +722,9 @@ final case class ApplyIR(function: String, typeArgs: Seq[Type], args: Seq[IR], e
 
   lazy val explicitNode: IR = {
     // foldRight because arg1 should be at the top so it is evaluated first
-    refs.zip(args).foldRight(body) { case ((ref, arg), bodyIR) => Let(ref.name, arg, bodyIR) }
+    val ir = refs.zip(args).foldRight(body) { case ((ref, arg), bodyIR) => Let(ref.name, arg, bodyIR) }
+    assert(ir.typ == returnType)
+    ir
   }
 }
@@ -756,22 +755,14 @@ final case class MatrixAggregate(child: MatrixIR, query: IR) extends IR
 
 final case class TableWrite(child: TableIR, writer: TableWriter) extends IR
 
-final case class TableMultiWrite(_children: IndexedSeq[TableIR], writer: WrappedMatrixNativeMultiWriter) extends IR {
-  private val t = _children.head.typ
-  require(_children.forall(_.typ == t))
-}
+final case class TableMultiWrite(_children: IndexedSeq[TableIR], writer: WrappedMatrixNativeMultiWriter) extends IR
 
 final case class TableGetGlobals(child: TableIR) extends IR
 final case class TableCollect(child: TableIR) extends IR
 
 final case class MatrixWrite(child: MatrixIR, writer: MatrixWriter) extends IR
 
-final case class MatrixMultiWrite(_children: IndexedSeq[MatrixIR], writer: MatrixNativeMultiWriter) extends IR {
-  private val t = _children.head.typ
-  assert(!t.rowType.hasField(MatrixReader.rowUIDFieldName) &&
-    !t.colType.hasField(MatrixReader.colUIDFieldName), t)
-  require(_children.forall(_.typ == t))
-}
+final case class MatrixMultiWrite(_children: IndexedSeq[MatrixIR], writer: MatrixNativeMultiWriter) extends IR
 
 final case class TableToValueApply(child: TableIR, function: TableToValueFunction) extends IR
 final case class MatrixToValueApply(child: MatrixIR, function: MatrixToValueFunction) extends IR
@@ -960,10 +951,7 @@ final case class SimpleMetadataWriter(val annotationType: Type) extends Metadata
     writeAnnotations.consume(cb, {}, {_ => ()})
 }
 
-final case class ReadPartition(context: IR, rowType: TStruct, reader: PartitionReader) extends IR {
-  assert(context.typ == reader.contextType, s"context: ${context.typ}, expected: ${reader.contextType}")
-  assert(PruneDeadFields.isSupertype(rowType, reader.fullRowType), s"requested type: $rowType, full type: ${reader.fullRowType}")
-}
+final case class ReadPartition(context: IR, rowType: TStruct, reader: PartitionReader) extends IR
 
 final case class WritePartition(value: IR, writeCtx: IR, writer: PartitionWriter) extends IR
 final case class WriteMetadata(writeAnnotations: IR, writer: MetadataWriter) extends IR
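`TypedIR` is referenced above but not defined in this diff; presumably it centralizes the deleted per-node coercion boilerplate, roughly:

    // Assumed shape of TypedIR (not shown in this diff): a single
    // type-parameterized downcast of the inferred type, replacing the
    // repeated `override def typ: TStream = tcoerce[TStream](super.typ)`.
    trait TypedIR[T <: Type] extends IR {
      override def typ: T = tcoerce[T](super.typ)
    }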
diff --git a/hail/src/main/scala/is/hail/expr/ir/InTailPosition.scala b/hail/src/main/scala/is/hail/expr/ir/InTailPosition.scala
index 2a83dc673b7..108e3082f7e 100644
--- a/hail/src/main/scala/is/hail/expr/ir/InTailPosition.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/InTailPosition.scala
@@ -5,7 +5,7 @@ object InTailPosition {
     case Let(_, _, _) => i == 1
     case If(_, _, _) => i != 0
     case _: Switch => i != 0
-    case TailLoop(_, params, _) => i == params.length
+    case TailLoop(_, params, _, _) => i == params.length
     case _ => false
   }
 }
diff --git a/hail/src/main/scala/is/hail/expr/ir/InferType.scala b/hail/src/main/scala/is/hail/expr/ir/InferType.scala
index 7c7b3f78d40..528577b4384 100644
--- a/hail/src/main/scala/is/hail/expr/ir/InferType.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/InferType.scala
@@ -66,8 +66,8 @@ object InferType {
       body.typ
     case AggLet(name, value, body, _) =>
       body.typ
-    case TailLoop(_, _, body) =>
-      body.typ
+    case TailLoop(_, _, resultType, _) =>
+      resultType
     case Recur(_, _, typ) =>
       typ
     case ApplyBinaryPrimOp(op, l, r) =>
@@ -80,7 +80,7 @@
       case _: Compare => TInt32
      case _ => TBoolean
    }
-    case a: ApplyIR => a.explicitNode.typ
+    case a: ApplyIR => a.returnType
    case a: AbstractApplyNode[_] =>
      val typeArgs = a.typeArgs
      val argTypes = a.args.map(_.typ)
diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala
index 98979a60db0..e6f7127865e 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala
@@ -790,7 +790,7 @@ object Interpret {
         val message_ = interpret(message).asInstanceOf[String]
         info(message_)
         interpret(result)
-      case ir@ApplyIR(function, _, functionArgs, _) =>
+      case ir@ApplyIR(function, _, _, functionArgs, _) =>
         interpret(ir.explicitNode, env, args)
       case ApplySpecial("lor", _, Seq(left_, right_), _, _) =>
         val left = interpret(left_)
diff --git a/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
index bce01b0b994..7e08faa9901 100644
--- a/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
@@ -51,6 +51,8 @@ abstract sealed class MatrixIR extends BaseIR {
   }
 
   def pyUnpersist(): MatrixIR = unpersist()
+
+  def typecheck(): Unit = {}
 }
 
 object MatrixLiteral {
@@ -482,7 +484,7 @@ case class MatrixFilterCols(child: MatrixIR, pred: IR) extends MatrixIR {
     MatrixFilterCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR])
   }
 
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts
 
@@ -513,7 +515,7 @@ case class MatrixChooseCols(child: MatrixIR, oldIndices: IndexedSeq[Int]) extend
     MatrixChooseCols(newChildren(0).asInstanceOf[MatrixIR], oldIndices)
   }
 
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts
 
@@ -530,7 +532,7 @@ case class MatrixCollectColsByKey(child: MatrixIR) extends MatrixIR {
     MatrixCollectColsByKey(newChildren(0).asInstanceOf[MatrixIR])
   }
 
-  val typ: MatrixType = {
+  lazy val typ: MatrixType = {
     val newColValueType = TStruct(child.typ.colValueStruct.fields.map(f => f.copy(typ = TArray(f.typ))))
     val newColType = child.typ.colKeyStruct ++ newColValueType
     val newEntryType = TStruct(child.typ.entryType.fields.map(f => f.copy(typ = TArray(f.typ))))
@@ -544,7 +546,9 @@ case class MatrixCollectColsByKey(child: MatrixIR) extends MatrixIR {
 }
 
 case class MatrixAggregateRowsByKey(child: MatrixIR, entryExpr: IR, rowExpr: IR) extends MatrixIR {
-  require(child.typ.rowKey.nonEmpty)
+  override def typecheck(): Unit = {
+    assert(child.typ.rowKey.nonEmpty)
+  }
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, entryExpr, rowExpr)
 
@@ -553,7 +557,7 @@ case class MatrixAggregateRowsByKey(child: MatrixIR, entryExpr: IR, rowExpr: IR)
     MatrixAggregateRowsByKey(newChild, newEntryExpr, newRowExpr)
   }
 
-  val typ: MatrixType = child.typ.copy(
+  lazy val typ: MatrixType = child.typ.copy(
     rowType = child.typ.rowKeyStruct ++ tcoerce[TStruct](rowExpr.typ),
     entryType = tcoerce[TStruct](entryExpr.typ)
   )
@@ -564,7 +568,9 @@ case class MatrixAggregateRowsByKey(child: MatrixIR, entryExpr: IR, rowExpr: IR)
 }
 
 case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR) extends MatrixIR {
-  require(child.typ.colKey.nonEmpty)
+  override def typecheck(): Unit = {
+    assert(child.typ.colKey.nonEmpty)
+  }
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, entryExpr, colExpr)
 
@@ -573,7 +579,7 @@ case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR)
     MatrixAggregateColsByKey(newChild, newEntryExpr, newColExpr)
   }
 
-  val typ = child.typ.copy(
+  lazy val typ = child.typ.copy(
     entryType = tcoerce[TStruct](entryExpr.typ),
     colType = child.typ.colKeyStruct ++ tcoerce[TStruct](colExpr.typ))
 
@@ -584,7 +590,12 @@ case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR)
 
 case class MatrixUnionCols(left: MatrixIR, right: MatrixIR, joinType: String) extends MatrixIR {
   require(joinType == "inner" || joinType == "outer")
-  require(left.typ.rowKeyStruct isIsomorphicTo right.typ.rowKeyStruct)
+
+  override def typecheck(): Unit = {
+    assert(left.typ.rowKeyStruct == right.typ.rowKeyStruct, s"${left.typ.rowKeyStruct} != ${right.typ.rowKeyStruct}")
+    assert(left.typ.colType == right.typ.colType, s"${left.typ.colType} != ${right.typ.colType}")
+    assert(left.typ.entryType == right.typ.entryType, s"${left.typ.entryType} != ${right.typ.entryType}")
+  }
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right)
 
@@ -593,7 +604,7 @@ case class MatrixUnionCols(left: MatrixIR, right: MatrixIR, joinType: String) ex
     MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR], joinType)
   }
 
-  private val newRowType = {
+  private def newRowType = {
     val leftKeyType = left.typ.rowKeyStruct
     val leftValueType = left.typ.rowValueStruct
     val rightValueType = right.typ.rowValueStruct
@@ -605,7 +616,7 @@ case class MatrixUnionCols(left: MatrixIR, right: MatrixIR, joinType: String) ex
     leftKeyType ++ leftValueType ++ rightValueType
   }
 
-  val typ: MatrixType = if (joinType == "inner")
+  lazy val typ: MatrixType = if (joinType == "inner")
     left.typ.copy(rowType = newRowType)
   else
     left.typ.copy(
@@ -632,7 +643,7 @@ case class MatrixMapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
     MatrixMapEntries(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR])
   }
 
-  val typ: MatrixType =
+  lazy val typ: MatrixType =
     child.typ.copy(entryType = tcoerce[TStruct](newEntries.typ))
 
   override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts
@@ -643,12 +654,14 @@ case class MatrixMapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
 }
 
 case class MatrixKeyRowsBy(child: MatrixIR, keys: IndexedSeq[String], isSorted: Boolean = false) extends MatrixIR {
-  private val fields = child.typ.rowType.fieldNames.toSet
-  assert(keys.forall(fields.contains), s"${ keys.filter(k => !fields.contains(k)).mkString(", ") }")
+  override def typecheck(): Unit = {
+    val fields = child.typ.rowType.fieldNames.toSet
+    assert(keys.forall(fields.contains), s"${keys.filter(k => !fields.contains(k)).mkString(", ")}")
+  }
 
   val childrenSeq: IndexedSeq[BaseIR] = Array(child)
 
-  val typ: MatrixType = child.typ.copy(rowKey = keys)
+  lazy val typ: MatrixType = child.typ.copy(rowKey = keys)
 
   def copy(newChildren: IndexedSeq[BaseIR]): MatrixKeyRowsBy = {
     assert(newChildren.length == 1)
@@ -669,7 +682,7 @@ case class MatrixMapRows(child: MatrixIR, newRow: IR) extends MatrixIR {
     MatrixMapRows(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR])
   }
 
-  val typ: MatrixType = {
+  lazy val typ: MatrixType = {
     child.typ.copy(rowType = newRow.typ.asInstanceOf[TStruct])
   }
 
@@ -688,7 +701,7 @@ case class MatrixMapCols(child: MatrixIR, newCol: IR, newKey: Option[IndexedSeq[
     MatrixMapCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR], newKey)
   }
 
-  val typ: MatrixType = {
+  lazy val typ: MatrixType = {
     val newColType = newCol.typ.asInstanceOf[TStruct]
     val newColKey = newKey.getOrElse(child.typ.colKey)
     child.typ.copy(colKey = newColKey, colType = newColType)
@@ -704,7 +717,7 @@ case class MatrixMapCols(child: MatrixIR, newCol: IR, newKey: Option[IndexedSeq[
 case class MatrixMapGlobals(child: MatrixIR, newGlobals: IR) extends MatrixIR {
   val childrenSeq: IndexedSeq[BaseIR] = Array(child, newGlobals)
 
-  val typ: MatrixType =
+  lazy val typ: MatrixType =
     child.typ.copy(globalType = newGlobals.typ.asInstanceOf[TStruct])
 
   def copy(newChildren: IndexedSeq[BaseIR]): MatrixMapGlobals = {
@@ -727,7 +740,7 @@ case class MatrixFilterEntries(child: MatrixIR, pred: IR) extends MatrixIR {
     MatrixFilterEntries(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR])
   }
 
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts
 
@@ -739,8 +752,11 @@ case class MatrixFilterEntries(child: MatrixIR, pred: IR) extends MatrixIR {
 case class MatrixAnnotateColsTable(
   child: MatrixIR,
   table: TableIR,
-  root: String) extends MatrixIR {
-  require(child.typ.colType.fieldOption(root).isEmpty)
+  root: String
+) extends MatrixIR {
+  override def typecheck(): Unit = {
+    assert(child.typ.colType.fieldOption(root).isEmpty)
+  }
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = FastSeq(child, table)
 
@@ -748,8 +764,8 @@ case class MatrixAnnotateColsTable(
 
   override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts
 
-  private val (colType, inserter) = child.typ.colType.structInsert(table.typ.valueType, List(root))
-  val typ: MatrixType = child.typ.copy(colType = colType)
+  lazy val typ: MatrixType = child.typ.copy(
    colType = child.typ.colType.structInsert(table.typ.valueType, List(root)))
 
   def copy(newChildren: IndexedSeq[BaseIR]): MatrixAnnotateColsTable = {
     MatrixAnnotateColsTable(
@@ -767,9 +783,12 @@ case class MatrixAnnotateRowsTable(
   root: String,
   product: Boolean
 ) extends MatrixIR {
-  require((!product && table.typ.keyType.isPrefixOf(child.typ.rowKeyStruct)) ||
-    (table.typ.keyType.size == 1 && table.typ.keyType.types(0) == TInterval(child.typ.rowKeyStruct.types(0))),
-    s"\n  L: ${ child.typ }\n  R: ${ table.typ }")
+  override def typecheck(): Unit = {
+    assert(
+      (!product && table.typ.keyType.isPrefixOf(child.typ.rowKeyStruct)) ||
+        (table.typ.keyType.size == 1 && table.typ.keyType.types(0) == TInterval(child.typ.rowKeyStruct.types(0))),
+      s"\n  L: ${child.typ}\n  R: ${table.typ}")
+  }
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = FastSeq(child, table)
 
@@ -779,13 +798,13 @@ case class MatrixAnnotateRowsTable(
 
   lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound
 
-  private val annotationType =
+  private def annotationType =
     if (product)
       TArray(table.typ.valueType)
     else
       table.typ.valueType
 
-  val typ: MatrixType =
+  lazy val typ: MatrixType =
     child.typ.copy(rowType = child.typ.rowType.appendKey(root, annotationType))
 
   def copy(newChildren: IndexedSeq[BaseIR]): MatrixAnnotateRowsTable = {
@@ -808,27 +827,28 @@ case class MatrixExplodeRows(child: MatrixIR, path: IndexedSeq[String]) extends
 
   override def columnCount: Option[Int] = child.columnCount
 
-  val idx = Ref(genUID(), TInt32)
-
-  val newRow: InsertFields = {
-    val refs = path.init.scanLeft(Ref("va", child.typ.rowType))((struct, name) =>
-      Ref(genUID(), tcoerce[TStruct](struct.typ).field(name).typ))
+  lazy val typ: MatrixType = {
+    // FIXME: compute row type directly
+    val newRow: InsertFields = {
+      val refs = path.init.scanLeft(Ref("va", child.typ.rowType))((struct, name) =>
+        Ref(genUID(), tcoerce[TStruct](struct.typ).field(name).typ))
+
+      path.zip(refs).zipWithIndex.foldRight[IR](Ref(genUID(), TInt32)) {
+        case (((field, ref), i), arg) =>
+          InsertFields(ref, FastSeq(field ->
+            (if (i == refs.length - 1)
+              ArrayRef(ToArray(ToStream(GetField(ref, field))), arg)
+            else
+              Let(refs(i + 1).name, GetField(ref, field), arg))))
+      }.asInstanceOf[InsertFields]
+    }
 
-    path.zip(refs).zipWithIndex.foldRight[IR](idx) {
-      case (((field, ref), i), arg) =>
-        InsertFields(ref, FastSeq(field ->
-          (if (i == refs.length - 1)
-            ArrayRef(ToArray(ToStream(GetField(ref, field))), arg)
-          else
-            Let(refs(i + 1).name, GetField(ref, field), arg))))
-    }.asInstanceOf[InsertFields]
+    child.typ.copy(rowType = newRow.typ)
   }
-
-  val typ: MatrixType = child.typ.copy(rowType = newRow.typ)
 }
 
 case class MatrixRepartition(child: MatrixIR, n: Int, strategy: Int) extends MatrixIR {
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = FastSeq(child)
@@ -844,8 +864,12 @@ case class MatrixRepartition(child: MatrixIR, n: Int, strategy: Int) extends Mat
 
 case class MatrixUnionRows(childrenSeq: IndexedSeq[MatrixIR]) extends MatrixIR {
   require(childrenSeq.length > 1)
-  require(childrenSeq.tail.forall(c => compatible(c.typ, childrenSeq.head.typ)), childrenSeq.map(_.typ))
-  val typ: MatrixType = childrenSeq.head.typ
+
+  override def typecheck(): Unit = {
+    assert(childrenSeq.tail.forall(c => compatible(c.typ, childrenSeq.head.typ)), childrenSeq.map(_.typ))
+  }
+
+  def typ: MatrixType = childrenSeq.head.typ
 
   def compatible(t1: MatrixType, t2: MatrixType): Boolean = {
     t1.colKeyStruct == t2.colKeyStruct &&
@@ -873,8 +897,7 @@ case class MatrixUnionRows(childrenSeq: IndexedSeq[MatrixIR]) extends MatrixIR {
 }
 
 case class MatrixDistinctByRow(child: MatrixIR) extends MatrixIR {
-
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = FastSeq(child)
 
@@ -890,7 +913,7 @@ case class MatrixDistinctByRow(child: MatrixIR) extends MatrixIR {
 case class MatrixRowsHead(child: MatrixIR, n: Long) extends MatrixIR {
   require(n >= 0)
 
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   override lazy val partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts.map { pc =>
     val prefixSums = pc.iterator.scanLeft(0L)(_ + _)
@@ -919,7 +942,7 @@ case class MatrixRowsHead(child: MatrixIR, n: Long) extends MatrixIR {
 case class MatrixColsHead(child: MatrixIR, n: Int) extends MatrixIR {
   require(n >= 0)
 
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)
 
@@ -937,7 +960,7 @@ case class MatrixColsHead(child: MatrixIR, n: Int) extends MatrixIR {
 case class MatrixRowsTail(child: MatrixIR, n: Long) extends MatrixIR {
   require(n >= 0)
 
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)
 
@@ -956,7 +979,7 @@ case class MatrixRowsTail(child: MatrixIR, n: Long) extends MatrixIR {
 case class MatrixColsTail(child: MatrixIR, n: Int) extends MatrixIR {
   require(n >= 0)
 
-  val typ: MatrixType = child.typ
+  def typ: MatrixType = child.typ
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)
 
@@ -987,13 +1010,12 @@ case class MatrixExplodeCols(child: MatrixIR, path: IndexedSeq[String]) extends
 
   lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound
 
-  private val (keysType, querier) = child.typ.colType.queryTyped(path.toList)
-  private val keyType = keysType match {
-    case TArray(e) => e
-    case TSet(e) => e
+  lazy val typ: MatrixType = {
+    val (keysType, _) = child.typ.colType.queryTyped(path.toList)
+    val keyType = keysType.asInstanceOf[TContainer].elementType
+    child.typ.copy(
+      colType = child.typ.colType.structInsert(keyType, path.toList))
   }
-  val (newColType, inserter) = child.typ.colType.structInsert(keyType, path.toList)
-  val typ: MatrixType = child.typ.copy(colType = newColType)
 }
 
 /** Create a MatrixTable from a Table, where the column values are stored in a
@@ -1006,13 +1028,14 @@ case class CastTableToMatrix(
   colsFieldName: String,
   colKey: IndexedSeq[String]
 ) extends MatrixIR {
-
-  child.typ.rowType.fieldType(entriesFieldName) match {
-    case TArray(TStruct(_)) =>
-    case t => fatal(s"expected entry field to be an array of structs, found $t")
+  override def typecheck(): Unit = {
+    child.typ.rowType.fieldType(entriesFieldName) match {
+      case TArray(TStruct(_)) =>
+      case t => fatal(s"expected entry field to be an array of structs, found $t")
+    }
   }
 
-  val typ: MatrixType = MatrixType.fromTableType(child.typ, colsFieldName, entriesFieldName, colKey)
+  lazy val typ: MatrixType = MatrixType.fromTableType(child.typ, colsFieldName, entriesFieldName, colKey)
 
   lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)
 
@@ -1046,12 +1069,19 @@ case class MatrixToMatrixApply(child: MatrixIR, function: MatrixToMatrixFunction
   lazy val rowCountUpperBound: Option[Long] = if (function.preservesPartitionCounts) child.rowCountUpperBound else None
 }
 
-case class MatrixRename(child: MatrixIR,
-  globalMap: Map[String, String], colMap: Map[String, String], rowMap: Map[String, String], entryMap: Map[String, String]) extends MatrixIR {
-  require(globalMap.keys.forall(child.typ.globalType.hasField))
-  require(colMap.keys.forall(child.typ.colType.hasField))
-  require(rowMap.keys.forall(child.typ.rowType.hasField))
-  require(entryMap.keys.forall(child.typ.entryType.hasField))
+case class MatrixRename(
+  child: MatrixIR,
+  globalMap: Map[String, String],
+  colMap: Map[String, String],
+  rowMap: Map[String, String],
+  entryMap: Map[String, String]
+) extends MatrixIR {
+  override def typecheck(): Unit = {
+    assert(globalMap.keys.forall(child.typ.globalType.hasField))
+    assert(colMap.keys.forall(child.typ.colType.hasField))
+    assert(rowMap.keys.forall(child.typ.rowType.hasField))
+    assert(entryMap.keys.forall(child.typ.entryType.hasField))
+  }
 
   lazy val typ: MatrixType = MatrixType(
     globalType = child.typ.globalType.rename(globalMap),
@@ -1083,7 +1113,7 @@ case class MatrixFilterIntervals(child: MatrixIR, intervals: IndexedSeq[Interval
     MatrixFilterIntervals(newChild, intervals, keep)
   }
 
-  override lazy val typ: MatrixType = child.typ
+  override def typ: MatrixType = child.typ
 
   override def columnCount: Option[Int] = child.columnCount
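The recurring move in this file, eager `val typ` fields becoming `lazy val` or `def` and constructor-time `require`s becoming `typecheck()` overrides, defers all type queries until children are typed. A separate driver can then validate a whole tree; a minimal sketch under that assumption (`typecheckAll` is hypothetical):

    // Illustrative only: run every node's typecheck() children-first, so each
    // node's assertions may safely force child.typ.
    def typecheckAll(ir: BaseIR): Unit = {
      ir.childrenSeq.foreach(typecheckAll)
      ir match {
        case x: MatrixIR => x.typecheck()
        case x: BlockMatrixIR => x.typecheck()
        case _ =>
      }
    }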
diff --git a/hail/src/main/scala/is/hail/expr/ir/Parser.scala b/hail/src/main/scala/is/hail/expr/ir/Parser.scala
index 915c5eafe1e..668ca0d7f20 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Parser.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -3,7 +3,7 @@ package is.hail.expr.ir
 import is.hail.HailContext
 import is.hail.backend.ExecuteContext
 import is.hail.expr.ir.agg._
-import is.hail.expr.ir.functions.RelationalFunctions
+import is.hail.expr.ir.functions.{IRFunctionRegistry, RelationalFunctions}
 import is.hail.expr.{JSONAnnotationImpex, Nat, ParserUtils}
 import is.hail.io.{BufferSpec, TypedCodecSpec}
 import is.hail.rvd.{RVDPartitioner, RVDType}
@@ -128,70 +128,7 @@ object IRLexer extends JavaTokenParsers {
 
 case class IRParserEnvironment(
   ctx: ExecuteContext,
-  refMap: BindingEnv[Type] = BindingEnv.empty[Type],
-  irMap: Map[Int, BaseIR] = Map.empty,
-) {
-
-  def promoteAgg: IRParserEnvironment = copy(refMap = refMap.promoteAgg)
-
-  def promoteScan: IRParserEnvironment = copy(refMap = refMap.promoteScan)
-
-  def promoteAggScan(isScan: Boolean): IRParserEnvironment =
-    if (isScan) promoteScan else promoteAgg
-
-  def noAgg: IRParserEnvironment = copy(refMap = refMap.noAgg)
-
-  def noScan: IRParserEnvironment = copy(refMap = refMap.noScan)
-
-  def noAggScan(isScan: Boolean): IRParserEnvironment =
-    if (isScan) noScan else noAgg
-
-  def createAgg: IRParserEnvironment = copy(refMap = refMap.createAgg)
-
-  def createScan: IRParserEnvironment = copy(refMap = refMap.createScan)
-
-  def onlyRelational: IRParserEnvironment = {
-    if (refMap.eval.isEmpty && refMap.agg.isEmpty && refMap.scan.isEmpty)
-      this
-    else
-      copy(refMap = refMap.onlyRelational)
-  }
-
-  def empty: IRParserEnvironment = copy(refMap = BindingEnv.empty)
-
-  def bindEval(name: String, t: Type): IRParserEnvironment =
-    copy(refMap = refMap.bindEval(name, t))
-
-  def bindEval(bindings: (String, Type)*): IRParserEnvironment =
-    copy(refMap = refMap.bindEval(bindings: _*))
-
-  def bindEval(bindings: Env[Type]): IRParserEnvironment =
-    copy(refMap = refMap.bindEval(bindings.m.toSeq: _*))
-
-  def bindAggScan(isScan: Boolean, bindings: (String, Type)*): IRParserEnvironment =
-    copy(refMap = if (isScan) refMap.bindScan(bindings: _*) else refMap.bindAgg(bindings: _*))
-
-  def bindAgg(name: String, t: Type): IRParserEnvironment =
-    copy(refMap = refMap.bindAgg(name, t))
-
-  def bindAgg(bindings: (String, Type)*): IRParserEnvironment =
-    copy(refMap = refMap.bindAgg(bindings: _*))
-
-  def bindAgg(bindings: Env[Type]): IRParserEnvironment =
-    copy(refMap = refMap.bindAgg(bindings.m.toSeq: _*))
-
-  def bindScan(name: String, t: Type): IRParserEnvironment =
-    copy(refMap = refMap.bindScan(name, t))
-
-  def bindScan(bindings: (String, Type)*): IRParserEnvironment =
-    copy(refMap = refMap.bindScan(bindings: _*))
-
-  def bindScan(bindings: Env[Type]): IRParserEnvironment =
-    copy(refMap = refMap.bindScan(bindings.m.toSeq: _*))
-
-  def bindRelational(name: String, t: Type): IRParserEnvironment =
-    copy(refMap = refMap.bindRelational(name, t))
-}
+  irMap: Map[Int, BaseIR] = Map.empty)
 
 object IRParser {
   def error(t: Token, msg: String): Nothing = ParserUtils.error(t.pos, msg)
@@ -743,7 +680,7 @@ object IRParser {
         val vtwr = vtwr_expr(it)
         val accumName = identifier(it)
         val otherAccumName = identifier(it)
-        val combIR = ir_value_expr(env.empty.bindEval(accumName -> vtwr.t, otherAccumName -> vtwr.t))(it).run()
+        val combIR = ir_value_expr(env)(it).run()
         FoldStateSig(vtwr.canonicalEmitType, accumName, otherAccumName, combIR)
       }
       punctuation(it, ")")
@@ -829,6 +766,16 @@ object IRParser {
     } yield ir
   }
 
+  def apply_like(env: IRParserEnvironment, cons: (String, Seq[Type], Seq[IR], Type, Int) => IR)(it: TokenIterator): StackFrame[IR] = {
+    val errorID = int32_literal(it)
+    val function = identifier(it)
+    val typeArgs = type_exprs(it)
+    val rt = type_expr(it)
+    ir_value_children(env)(it).map { args =>
+      cons(function, typeArgs, args, rt, errorID)
+    }
+  }
+
   def ir_value_expr_1(env: IRParserEnvironment)(it: TokenIterator): StackFrame[IR] = {
     identifier(it) match {
       case "I32" => done(I32(int32_literal(it)))
@@ -881,33 +828,32 @@ object IRParser {
         val name = identifier(it)
         for {
           value <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env.bindEval(name, value.typ))(it)
+          body <- ir_value_expr(env)(it)
         } yield Let(name, value, body)
       case "AggLet" =>
         val name = identifier(it)
         val isScan = boolean_literal(it)
         for {
-          value <- ir_value_expr(env.promoteAggScan(isScan))(it)
-          body <- ir_value_expr(env.bindAggScan(isScan, name -> value.typ))(it)
+          value <- ir_value_expr(env)(it)
+          body <- ir_value_expr(env)(it)
         } yield AggLet(name, value, body, isScan)
       case "TailLoop" =>
         val name = identifier(it)
         val paramNames = identifiers(it)
+        val resultType = type_expr(it)
         for {
           paramIRs <- fillArray(paramNames.length)(ir_value_expr(env)(it))
           params = paramNames.zip(paramIRs)
-          bodyEnv = env.bindEval(params.map { case (n, v) => n -> v.typ}: _*)
-          body <- ir_value_expr(bodyEnv)(it)
-        } yield TailLoop(name, params, body)
+          body <- ir_value_expr(env)(it)
+        } yield TailLoop(name, params, resultType, body)
       case "Recur" =>
         val name = identifier(it)
-        val typ = type_expr(it)
         ir_value_children(env)(it).map { args =>
-          Recur(name, args, typ)
+          Recur(name, args, null)
         }
       case "Ref" =>
         val id = identifier(it)
-        done(Ref(id, env.refMap.eval(id)))
+        done(Ref(id, null))
       case "RelationalRef" =>
         val id = identifier(it)
         val t = type_expr(it)
@@ -915,8 +861,8 @@ object IRParser {
       case "RelationalLet" =>
         val name = identifier(it)
         for {
-          value <- ir_value_expr(env.onlyRelational)(it)
-          body <- ir_value_expr(env.noAgg.noScan.bindRelational(name, value.typ))(it)
+          value <- ir_value_expr(env)(it)
+          body <- ir_value_expr(env)(it)
         } yield RelationalLet(name, value, body)
       case "ApplyBinaryPrimOp" =>
         val op = BinaryOp.fromString(identifier(it))
@@ -932,11 +878,11 @@ object IRParser {
         for {
           l <- ir_value_expr(env)(it)
           r <- ir_value_expr(env)(it)
-        } yield ApplyComparisonOp(ComparisonOp.fromStringAndTypes((opName, l.typ, r.typ)), l, r)
+        } yield ApplyComparisonOp(ComparisonOp.fromString(opName), l, r)
       case "MakeArray" =>
         val typ = opt(it, type_expr).map(_.asInstanceOf[TArray]).orNull
         ir_value_children(env)(it).map { args =>
-          MakeArray.unify(env.ctx, args, typ)
+          MakeArray(args, typ)
         }
       case "MakeStream" =>
         val typ = opt(it, type_expr).map(_.asInstanceOf[TStream]).orNull
@@ -990,8 +936,7 @@ object IRParser {
         val r = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          elt = tcoerce[TStream](a.typ).elementType
-          lessThan <- ir_value_expr(env.bindEval(l -> elt, r -> elt))(it)
+          lessThan <- ir_value_expr(env)(it)
         } yield ArraySort(a, l, r, lessThan)
       case "ArrayMaximalIndependentSet" =>
         val hasTieBreaker = boolean_literal(it)
@@ -999,10 +944,8 @@ object IRParser {
         for {
           edges <- ir_value_expr(env)(it)
           tieBreaker <- if (hasTieBreaker) {
-            val eltType = tcoerce[TArray](edges.typ).elementType.asInstanceOf[TBaseStruct].types.head
-            val tbType = TTuple(eltType)
             val Some((left, right)) = bindings
-            ir_value_expr(IRParserEnvironment(env.ctx, BindingEnv.eval(left -> tbType, right -> tbType)))(it).map(tbf => Some((left, right, tbf)))
+            ir_value_expr(env)(it).map(tbf => Some((left, right, tbf)))
           } else {
             done(None)
           }
@@ -1030,7 +973,7 @@ object IRParser {
         val name = identifier(it)
         for {
           nd <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env.bindEval(name, tcoerce[TNDArray](nd.typ).elementType))(it)
+          body <- ir_value_expr(env)(it)
         } yield NDArrayMap(nd, name, body)
       case "NDArrayMap2" =>
         val errorID = int32_literal(it)
@@ -1039,10 +982,7 @@ object IRParser {
         for {
           l <- ir_value_expr(env)(it)
           r <- ir_value_expr(env)(it)
-          body_env = env.bindEval(
-            lName -> tcoerce[TNDArray](l.typ).elementType,
-            rName -> tcoerce[TNDArray](r.typ).elementType)
-          body <- ir_value_expr(body_env)(it)
+          body <- ir_value_expr(env)(it)
         } yield NDArrayMap2(l, r, lName, rName, body, errorID)
       case "NDArrayReindex" =>
         val indexExpr = int32_literals(it)
@@ -1068,7 +1008,7 @@ object IRParser {
       case "NDArrayFilter" =>
         for {
           nd <- ir_value_expr(env)(it)
-          filters <- fillArray(tcoerce[TNDArray](nd.typ).nDims)(ir_value_expr(env)(it))
+          filters <- repUntil(it, ir_value_expr(env), PunctuationToken(")"))
         } yield NDArrayFilter(nd, filters.toFastSeq)
       case "NDArrayMatMul" =>
         val errorID = int32_literal(it)
@@ -1123,7 +1063,7 @@ object IRParser {
         val name = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env.bindEval(name, tcoerce[TStream](a.typ).elementType))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamMap(a, name, body)
       case "StreamTake" =>
         for {
@@ -1146,7 +1086,7 @@ object IRParser {
         val names = identifiers(it)
         for {
           as <- names.mapRecur(_ => ir_value_expr(env)(it))
-          body <- ir_value_expr(env.bindEval(names.zip(as.map(a => tcoerce[TStream](a.typ).elementType)): _*))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamZip(as, names, body, behavior, errorID)
       case "StreamZipJoinProducers" =>
         val key = identifiers(it)
         val ctxName = identifier(it)
         val curKey = identifier(it)
         val curVals = identifier(it)
         for {
           ctxs <- ir_value_expr(env)(it)
-          makeProducer <- ir_value_expr(env.bindEval(ctxName, TIterable.elementType(ctxs.typ)))(it)
+          makeProducer <- ir_value_expr(env)(it)
           body <- {
-            val structType = TIterable.elementType(makeProducer.typ).asInstanceOf[TStruct]
-            ir_value_expr(env.bindEval((curKey, structType.typeAfterSelectNames(key)), (curVals, TArray(structType))))(it)
+            ir_value_expr(env)(it)
           }
         } yield StreamZipJoinProducers(ctxs, ctxName, makeProducer, key, curKey, curVals, body)
       case "StreamZipJoin" =>
@@ -1169,8 +1108,7 @@ object IRParser {
         for {
           streams <- (0 until nStreams).mapRecur(_ => ir_value_expr(env)(it))
           body <- {
-            val structType = streams.head.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct]
-            ir_value_expr(env.bindEval((curKey, structType.typeAfterSelectNames(key)), (curVals, TArray(structType))))(it)
+            ir_value_expr(env)(it)
           }
         } yield StreamZipJoin(streams, key, curKey, curVals, body)
       case "StreamMultiMerge" =>
@@ -1182,25 +1120,25 @@ object IRParser {
         val name = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env.bindEval(name, tcoerce[TStream](a.typ).elementType))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamFilter(a, name, body)
       case "StreamTakeWhile" =>
         val name = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env.bindEval(name, tcoerce[TStream](a.typ).elementType))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamTakeWhile(a, name, body)
       case "StreamDropWhile" =>
         val name = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env.bindEval(name, tcoerce[TStream](a.typ).elementType))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamDropWhile(a, name, body)
       case "StreamFlatMap" =>
         val name = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env.bindEval(name, tcoerce[TStream](a.typ).elementType))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamFlatMap(a, name, body)
       case "StreamFold" =>
         val accumName = identifier(it)
         val valueName = identifier(it)
@@ -1208,8 +1146,7 @@ object IRParser {
         for {
           a <- ir_value_expr(env)(it)
           zero <- ir_value_expr(env)(it)
-          eltType = tcoerce[TStream](a.typ).elementType
-          body <- ir_value_expr(env.bindEval(accumName -> zero.typ, valueName -> eltType))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamFold(a, zero, accumName, valueName, body)
       case "StreamFold2" =>
         val accumNames = identifiers(it)
         val valueName = identifier(it)
@@ -1218,11 +1155,8 @@ object IRParser {
           a <- ir_value_expr(env)(it)
           accIRs <- fillArray(accumNames.length)(ir_value_expr(env)(it))
           accs = accumNames.zip(accIRs)
-          eltType = tcoerce[TStream](a.typ).elementType
-          resultEnv = env.bindEval(accs.map { case (name, value) => (name, value.typ) }: _*)
-          seqEnv = resultEnv.bindEval(valueName, eltType)
-          seqs <- fillArray(accs.length)(ir_value_expr(seqEnv)(it))
-          res <- ir_value_expr(resultEnv)(it)
+          seqs <- fillArray(accs.length)(ir_value_expr(env)(it))
+          res <- ir_value_expr(env)(it)
         } yield StreamFold2(a, accs, valueName, seqs, res)
       case "StreamScan" =>
         val accumName = identifier(it)
@@ -1230,8 +1164,7 @@ object IRParser {
         for {
           a <- ir_value_expr(env)(it)
           zero <- ir_value_expr(env)(it)
-          eltType = tcoerce[TStream](a.typ).elementType
-          body <- ir_value_expr(env.bindEval(accumName -> zero.typ, valueName -> eltType))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamScan(a, zero, accumName, valueName, body)
       case "StreamWhiten" =>
         val newChunk = identifier(it)
         val prevWindow = identifier(it)
@@ -1253,27 +1186,25 @@ object IRParser {
         for {
           left <- ir_value_expr(env)(it)
           right <- ir_value_expr(env)(it)
-          lelt = tcoerce[TStream](left.typ).elementType
-          relt = tcoerce[TStream](right.typ).elementType
-          join <- ir_value_expr(env.bindEval(l -> lelt, r -> relt))(it)
+          join <- ir_value_expr(env)(it)
         } yield StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType)
       case "StreamFor" =>
         val name = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          body <- ir_value_expr(env.bindEval(name, tcoerce[TStream](a.typ).elementType))(it)
+          body <- ir_value_expr(env)(it)
         } yield StreamFor(a, name, body)
       case "StreamAgg" =>
         val name = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          query <- ir_value_expr(env.createAgg.bindAgg(name, tcoerce[TStream](a.typ).elementType))(it)
+          query <- ir_value_expr(env)(it)
         } yield StreamAgg(a, name, query)
       case "StreamAggScan" =>
         val name = identifier(it)
         for {
           a <- ir_value_expr(env)(it)
-          query <- ir_value_expr(env.createScan.bindScan(name, tcoerce[TStream](a.typ).elementType))(it)
+          query <- ir_value_expr(env)(it)
         } yield StreamAggScan(a, name, query)
       case "RunAgg" =>
         val signatures = agg_state_signatures(env)(it)
@@ -1286,28 +1217,27 @@ object IRParser {
         val signatures = agg_state_signatures(env)(it)
         for {
           array <- ir_value_expr(env)(it)
-          newE = env.bindEval(name, tcoerce[TStream](array.typ).elementType)
           init <- ir_value_expr(env)(it)
-          seq <- ir_value_expr(newE)(it)
-          result <- ir_value_expr(newE)(it)
+          seq <- ir_value_expr(env)(it)
+          result <- ir_value_expr(env)(it)
         } yield RunAggScan(array, name, init, seq, result, signatures)
       case "AggFilter" =>
         val isScan = boolean_literal(it)
         for {
-          cond <- ir_value_expr(env.promoteAggScan(isScan))(it)
+          cond <- ir_value_expr(env)(it)
           aggIR <- ir_value_expr(env)(it)
         } yield AggFilter(cond, aggIR, isScan)
       case "AggExplode" =>
         val name = identifier(it)
         val isScan = boolean_literal(it)
         for {
-          a <- ir_value_expr(env.promoteAggScan(isScan))(it)
-          aggBody <- ir_value_expr(env.bindAggScan(isScan, name -> tcoerce[TStream](a.typ).elementType))(it)
+          a <- ir_value_expr(env)(it)
+          aggBody <- ir_value_expr(env)(it)
         } yield AggExplode(a, name, aggBody, isScan)
       case "AggGroupBy" =>
         val isScan = boolean_literal(it)
         for {
-          key <- ir_value_expr(env.promoteAggScan(isScan))(it)
+          key <- ir_value_expr(env)(it)
           aggIR <- ir_value_expr(env)(it)
         } yield AggGroupBy(key, aggIR, isScan)
       case "AggArrayPerElement" =>
@@ -1316,39 +1246,32 @@ object IRParser {
         val isScan = boolean_literal(it)
         val hasKnownLength = boolean_literal(it)
         for {
-          a <- ir_value_expr(env.promoteAggScan(isScan))(it)
-          aggBody <- ir_value_expr(env
-            .bindEval(indexName, TInt32)
-            .bindAggScan(isScan, indexName -> TInt32, elementName -> tcoerce[TArray](a.typ).elementType))(it)
+          a <- ir_value_expr(env)(it)
+          aggBody <- ir_value_expr(env)(it)
           knownLength <- if (hasKnownLength) ir_value_expr(env)(it).map(Some(_)) else done(None)
         } yield AggArrayPerElement(a, elementName, indexName, aggBody, knownLength, isScan)
       case "ApplyAggOp" =>
         val aggOp = agg_op(it)
         for {
-          initOpArgs <- ir_value_exprs(env.noAgg)(it)
-          seqOpArgs <- ir_value_exprs(env.promoteAgg)(it)
-          aggSig = AggSignature(aggOp, initOpArgs.map(arg => arg.typ), seqOpArgs.map(arg => arg.typ))
+          initOpArgs <- ir_value_exprs(env)(it)
+          seqOpArgs <- ir_value_exprs(env)(it)
+          aggSig = AggSignature(aggOp, null, null)
         } yield ApplyAggOp(initOpArgs, seqOpArgs, aggSig)
       case "ApplyScanOp" =>
         val aggOp = agg_op(it)
         for {
-          initOpArgs <- ir_value_exprs(env.noScan)(it)
-          seqOpArgs <- ir_value_exprs(env.promoteScan)(it)
-          aggSig = AggSignature(aggOp, initOpArgs.map(arg => arg.typ), seqOpArgs.map(arg => arg.typ))
+          initOpArgs <- ir_value_exprs(env)(it)
+          seqOpArgs <- ir_value_exprs(env)(it)
+          aggSig = AggSignature(aggOp, null, null)
         } yield ApplyScanOp(initOpArgs, seqOpArgs, aggSig)
       case "AggFold" =>
         val accumName = identifier(it)
         val otherAccumName = identifier(it)
         val isScan = boolean_literal(it)
         for {
-          zero <- ir_value_expr(env.noAggScan(isScan))(it)
-          seqOp <- ir_value_expr(env.promoteAggScan(isScan).bindEval(accumName, zero.typ))(it)
-          combEnv = (if (isScan)
-            env.copy(refMap = env.refMap.copy(eval = Env.empty, scan = None))
-          else
-            env.copy(refMap = env.refMap.copy(eval = Env.empty, agg = None))
-          ).bindEval(accumName -> zero.typ, otherAccumName -> zero.typ)
-          combOp <- ir_value_expr(combEnv)(it)
+          zero <- ir_value_expr(env)(it)
+          seqOp <- ir_value_expr(env)(it)
+          combOp <- ir_value_expr(env)(it)
         } yield AggFold(zero, seqOp, combOp, accumName, otherAccumName, isScan)
       case "InitOp" =>
         val i = int32_literal(it)
@@ -1450,14 +1373,12 @@ object IRParser {
           rngState <- ir_value_expr(env)(it)
           args <- ir_value_children(env)(it)
         } yield ApplySeeded(function, args, rngState, staticUID, rt)
-      case "ApplyIR" | "ApplySpecial" | "Apply" =>
-        val errorID = int32_literal(it)
-        val function = identifier(it)
-        val typeArgs = type_exprs(it)
-        val rt = type_expr(it)
-        ir_value_children(env)(it).map { args =>
-          invoke(function, rt, typeArgs, errorID, args: _*)
-        }
+      case "ApplyIR" =>
+        apply_like(env, ApplyIR)(it)
+      case "ApplySpecial" =>
+        apply_like(env, ApplySpecial)(it)
+      case "Apply" =>
+        apply_like(env, Apply)(it)
       case "MatrixCount" =>
         matrix_ir(env)(it).map(MatrixCount)
       case "TableCount" =>
@@ -1468,17 +1389,17 @@ object IRParser {
         table_ir(env)(it).map(TableCollect)
       case "TableAggregate" =>
         for {
-          child <- table_ir(env.onlyRelational)(it)
-          query <- ir_value_expr(env.onlyRelational.createAgg.bindEval(child.typ.globalEnv).bindAgg(child.typ.rowEnv))(it)
+          child <- table_ir(env)(it)
+          query <- ir_value_expr(env)(it)
         } yield TableAggregate(child, query)
       case "TableToValueApply" =>
         val config = string_literal(it)
-        table_ir(env.onlyRelational)(it).map { child =>
+        table_ir(env)(it).map { child =>
           TableToValueApply(child, RelationalFunctions.lookupTableToValue(env.ctx, config))
         }
case "MatrixToValueApply" => val config = string_literal(it) - matrix_ir(env.onlyRelational)(it).map { child => + matrix_ir(env)(it).map { child => MatrixToValueApply(child, RelationalFunctions.lookupMatrixToValue(env.ctx, config)) } case "BlockMatrixToValueApply" => @@ -1502,8 +1423,8 @@ object IRParser { } case "MatrixAggregate" => for { - child <- matrix_ir(env.onlyRelational)(it) - query <- ir_value_expr(env.onlyRelational.createAgg.bindEval(child.typ.globalEnv).bindAgg(child.typ.entryEnv))(it) + child <- matrix_ir(env)(it) + query <- ir_value_expr(env)(it) } yield MatrixAggregate(child, query) case "MatrixWrite" => val writerStr = string_literal(it) @@ -1540,7 +1461,7 @@ object IRParser { for { ctxs <- ir_value_expr(env)(it) globals <- ir_value_expr(env)(it) - body <- ir_value_expr(env.onlyRelational.bindEval(cname -> tcoerce[TStream](ctxs.typ).elementType, gname -> globals.typ))(it) + body <- ir_value_expr(env)(it) dynamicID <- ir_value_expr(env)(it) } yield CollectDistributedArray(ctxs, globals, cname, gname, body, dynamicID, staticID) case "JavaIR" => @@ -1624,14 +1545,14 @@ object IRParser { case "TableKeyBy" => val keys = identifiers(it) val isSorted = boolean_literal(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableKeyBy(child, keys, isSorted) } - case "TableDistinct" => table_ir(env.onlyRelational)(it).map(TableDistinct) + case "TableDistinct" => table_ir(env)(it).map(TableDistinct) case "TableFilter" => for { - child <- table_ir(env.onlyRelational)(it) - pred <- ir_value_expr(env.onlyRelational.bindEval(child.typ.rowEnv))(it) + child <- table_ir(env)(it) + pred <- ir_value_expr(env)(it) } yield TableFilter(child, pred) case "TableRead" => val requestedTypeRaw = it.head match { @@ -1650,123 +1571,123 @@ object IRParser { case Right(t) => t } done(TableRead(requestedType, dropRows, reader)) - case "MatrixColsTable" => matrix_ir(env.onlyRelational)(it).map(MatrixColsTable) - case "MatrixRowsTable" => matrix_ir(env.onlyRelational)(it).map(MatrixRowsTable) - case "MatrixEntriesTable" => matrix_ir(env.onlyRelational)(it).map(MatrixEntriesTable) + case "MatrixColsTable" => matrix_ir(env)(it).map(MatrixColsTable) + case "MatrixRowsTable" => matrix_ir(env)(it).map(MatrixRowsTable) + case "MatrixEntriesTable" => matrix_ir(env)(it).map(MatrixEntriesTable) case "TableAggregateByKey" => for { - child <- table_ir(env.onlyRelational)(it) - expr <- ir_value_expr(env.onlyRelational.createAgg.bindEval(child.typ.globalEnv).bindAgg(child.typ.rowEnv))(it) + child <- table_ir(env)(it) + expr <- ir_value_expr(env)(it) } yield TableAggregateByKey(child, expr) case "TableKeyByAndAggregate" => val nPartitions = opt(it, int32_literal) val bufferSize = int32_literal(it) for { - child <- table_ir(env.onlyRelational)(it) - expr <- ir_value_expr(env.onlyRelational.createAgg.bindEval(child.typ.globalEnv).bindAgg(child.typ.rowEnv))(it) - newKey <- ir_value_expr(env.onlyRelational.bindEval(child.typ.rowEnv))(it) + child <- table_ir(env)(it) + expr <- ir_value_expr(env)(it) + newKey <- ir_value_expr(env)(it) } yield TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize) case "TableRepartition" => val n = int32_literal(it) val strategy = int32_literal(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableRepartition(child, n, strategy) } case "TableHead" => val n = int64_literal(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableHead(child, n) } case "TableTail" => val n = 
int64_literal(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableTail(child, n) } case "TableJoin" => val joinType = identifier(it) val joinKey = int32_literal(it) for { - left <- table_ir(env.onlyRelational)(it) - right <- table_ir(env.onlyRelational)(it) + left <- table_ir(env)(it) + right <- table_ir(env)(it) } yield TableJoin(left, right, joinType, joinKey) case "TableLeftJoinRightDistinct" => val root = identifier(it) for { - left <- table_ir(env.onlyRelational)(it) - right <- table_ir(env.onlyRelational)(it) + left <- table_ir(env)(it) + right <- table_ir(env)(it) } yield TableLeftJoinRightDistinct(left, right, root) case "TableIntervalJoin" => val root = identifier(it) val product = boolean_literal(it) for { - left <- table_ir(env.onlyRelational)(it) - right <- table_ir(env.onlyRelational)(it) + left <- table_ir(env)(it) + right <- table_ir(env)(it) } yield TableIntervalJoin(left, right, root, product) case "TableMultiWayZipJoin" => val dataName = string_literal(it) val globalsName = string_literal(it) - table_ir_children(env.onlyRelational)(it).map { children => + table_ir_children(env)(it).map { children => TableMultiWayZipJoin(children, dataName, globalsName) } case "TableParallelize" => val nPartitions = opt(it, int32_literal) - ir_value_expr(env.onlyRelational)(it).map { rowsAndGlobal => + ir_value_expr(env)(it).map { rowsAndGlobal => TableParallelize(rowsAndGlobal, nPartitions) } case "TableMapRows" => for { - child <- table_ir(env.onlyRelational)(it) - newRow <- ir_value_expr(env.onlyRelational.createScan.bindEval(child.typ.rowEnv).bindScan(child.typ.rowEnv))(it) + child <- table_ir(env)(it) + newRow <- ir_value_expr(env)(it) } yield TableMapRows(child, newRow) case "TableMapGlobals" => for { - child <- table_ir(env.onlyRelational)(it) - newRow <- ir_value_expr(env.onlyRelational.bindEval(child.typ.globalEnv))(it) + child <- table_ir(env)(it) + newRow <- ir_value_expr(env)(it) } yield TableMapGlobals(child, newRow) case "TableRange" => val n = int32_literal(it) val nPartitions = opt(it, int32_literal) done(TableRange(n, nPartitions.getOrElse(HailContext.backend.defaultParallelism))) - case "TableUnion" => table_ir_children(env.onlyRelational)(it).map(TableUnion(_)) + case "TableUnion" => table_ir_children(env)(it).map(TableUnion(_)) case "TableOrderBy" => val sortFields = sort_fields(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableOrderBy(child, sortFields) } case "TableExplode" => val path = string_literals(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableExplode(child, path) } case "CastMatrixToTable" => val entriesField = string_literal(it) val colsField = string_literal(it) - matrix_ir(env.onlyRelational)(it).map { child => + matrix_ir(env)(it).map { child => CastMatrixToTable(child, entriesField, colsField) } case "MatrixToTableApply" => val config = string_literal(it) - matrix_ir(env.onlyRelational)(it).map { child => + matrix_ir(env)(it).map { child => MatrixToTableApply(child, RelationalFunctions.lookupMatrixToTable(env.ctx, config)) } case "TableToTableApply" => val config = string_literal(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableToTableApply(child, RelationalFunctions.lookupTableToTable(env.ctx, config)) } case "BlockMatrixToTableApply" => val config = string_literal(it) for { - bm <- blockmatrix_ir(env.onlyRelational)(it) - aux <- ir_value_expr(env.onlyRelational)(it) + bm <- 
blockmatrix_ir(env)(it) + aux <- ir_value_expr(env)(it) } yield BlockMatrixToTableApply(bm, aux, RelationalFunctions.lookupBlockMatrixToTable(env.ctx, config)) - case "BlockMatrixToTable" => blockmatrix_ir(env.onlyRelational)(it).map(BlockMatrixToTable) + case "BlockMatrixToTable" => blockmatrix_ir(env)(it).map(BlockMatrixToTable) case "TableRename" => val rowK = string_literals(it) val rowV = string_literals(it) val globalK = string_literals(it) val globalV = string_literals(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableRename(child, rowK.zip(rowV).toMap, globalK.zip(globalV).toMap) } @@ -1776,21 +1697,19 @@ object IRParser { val partitioner = between(punctuation(_, "("), punctuation(_, ")"), partitioner_literal(env))(it) val errorId = int32_literal(it) for { - contexts <- ir_value_expr(env.onlyRelational)(it) - globals <- ir_value_expr(env.onlyRelational)(it) - body <- ir_value_expr(env.onlyRelational.bindEval( - cname -> TIterable.elementType(contexts.typ), - gname -> globals.typ - ))(it) + contexts <- ir_value_expr(env)(it) + globals <- ir_value_expr(env)(it) + body <- ir_value_expr(env)(it) } yield TableGen(contexts, globals, cname, gname, body, partitioner, errorId) case "TableFilterIntervals" => + val keyType = type_expr(it) val intervals = string_literal(it) val keep = boolean_literal(it) - table_ir(env.onlyRelational)(it).map { child => + table_ir(env)(it).map { child => TableFilterIntervals(child, JSONAnnotationImpex.importAnnotation(JsonMethods.parse(intervals), - TArray(TInterval(child.typ.keyType)), + TArray(TInterval(keyType)), padNulls = false).asInstanceOf[IndexedSeq[Interval]], keep) } @@ -1800,14 +1719,14 @@ object IRParser { val requestedKey = int32_literal(it) val allowedOverlap = int32_literal(it) for { - child <- table_ir(env.onlyRelational)(it) - body <- ir_value_expr(env.onlyRelational.bindEval(globalsName -> child.typ.globalType, partitionStreamName -> TStream(child.typ.rowType)))(it) + child <- table_ir(env)(it) + body <- ir_value_expr(env)(it) } yield TableMapPartitions(child, globalsName, partitionStreamName, body, requestedKey, allowedOverlap) case "RelationalLetTable" => val name = identifier(it) for { - value <- ir_value_expr(env.onlyRelational)(it) - body <- table_ir(env.onlyRelational.bindRelational(name, value.typ))(it) + value <- ir_value_expr(env)(it) + body <- table_ir(env)(it) } yield RelationalLetTable(name, value, body) case "JavaTable" => val id = int32_literal(it) @@ -1830,69 +1749,63 @@ object IRParser { identifier(it) match { case "MatrixFilterCols" => for { - child <- matrix_ir(env.onlyRelational)(it) - pred <- ir_value_expr(env.onlyRelational.bindEval(child.typ.colEnv))(it) + child <- matrix_ir(env)(it) + pred <- ir_value_expr(env)(it) } yield MatrixFilterCols(child, pred) case "MatrixFilterRows" => for { - child <- matrix_ir(env.onlyRelational)(it) - pred <- ir_value_expr(env.onlyRelational.bindEval(child.typ.rowEnv))(it) + child <- matrix_ir(env)(it) + pred <- ir_value_expr(env)(it) } yield MatrixFilterRows(child, pred) case "MatrixFilterEntries" => for { - child <- matrix_ir(env.onlyRelational)(it) - pred <- ir_value_expr(env.onlyRelational.bindEval(child.typ.entryEnv))(it) + child <- matrix_ir(env)(it) + pred <- ir_value_expr(env)(it) } yield MatrixFilterEntries(child, pred) case "MatrixMapCols" => val newKey = opt(it, string_literals) for { - child <- matrix_ir(env.onlyRelational)(it) - newEnv = env.onlyRelational.createAgg.createScan - .bindEval(child.typ.colEnv).bindEval("n_rows", TInt64) - 
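A hedged sketch of the decoding the new leading key type enables (the struct field and the JSON literal are illustrative; the field names follow the exporter used in Pretty.scala below):

  // With the key type serialized in the IR text, intervals can be imported
  // without consulting child.typ, which is not yet known while parsing.
  val keyType = TStruct("idx" -> TInt32)
  val intervals = JSONAnnotationImpex.importAnnotation(
      JsonMethods.parse("""[{"start":{"idx":0},"end":{"idx":10},"includeStart":true,"includeEnd":false}]"""),
      TArray(TInterval(keyType)), padNulls = false)
    .asInstanceOf[IndexedSeq[Interval]]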
@@ -1800,14 +1719,14 @@ object IRParser {
        val requestedKey = int32_literal(it)
        val allowedOverlap = int32_literal(it)
        for {
-          child <- table_ir(env.onlyRelational)(it)
-          body <- ir_value_expr(env.onlyRelational.bindEval(globalsName -> child.typ.globalType, partitionStreamName -> TStream(child.typ.rowType)))(it)
+          child <- table_ir(env)(it)
+          body <- ir_value_expr(env)(it)
        } yield TableMapPartitions(child, globalsName, partitionStreamName, body, requestedKey, allowedOverlap)
      case "RelationalLetTable" =>
        val name = identifier(it)
        for {
-          value <- ir_value_expr(env.onlyRelational)(it)
-          body <- table_ir(env.onlyRelational.bindRelational(name, value.typ))(it)
+          value <- ir_value_expr(env)(it)
+          body <- table_ir(env)(it)
        } yield RelationalLetTable(name, value, body)
      case "JavaTable" =>
        val id = int32_literal(it)
@@ -1830,69 +1749,63 @@ object IRParser {
    identifier(it) match {
      case "MatrixFilterCols" =>
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          pred <- ir_value_expr(env.onlyRelational.bindEval(child.typ.colEnv))(it)
+          child <- matrix_ir(env)(it)
+          pred <- ir_value_expr(env)(it)
        } yield MatrixFilterCols(child, pred)
      case "MatrixFilterRows" =>
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          pred <- ir_value_expr(env.onlyRelational.bindEval(child.typ.rowEnv))(it)
+          child <- matrix_ir(env)(it)
+          pred <- ir_value_expr(env)(it)
        } yield MatrixFilterRows(child, pred)
      case "MatrixFilterEntries" =>
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          pred <- ir_value_expr(env.onlyRelational.bindEval(child.typ.entryEnv))(it)
+          child <- matrix_ir(env)(it)
+          pred <- ir_value_expr(env)(it)
        } yield MatrixFilterEntries(child, pred)
      case "MatrixMapCols" =>
        val newKey = opt(it, string_literals)
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          newEnv = env.onlyRelational.createAgg.createScan
-            .bindEval(child.typ.colEnv).bindEval("n_rows", TInt64)
-            .bindAgg(child.typ.entryEnv).bindScan(child.typ.colEnv)
-          newCol <- ir_value_expr(newEnv)(it)
+          child <- matrix_ir(env)(it)
+          newCol <- ir_value_expr(env)(it)
        } yield MatrixMapCols(child, newCol, newKey.map(_.toFastSeq))
      case "MatrixKeyRowsBy" =>
        val key = identifiers(it)
        val isSorted = boolean_literal(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixKeyRowsBy(child, key, isSorted)
        }
      case "MatrixMapRows" =>
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          newEnv = env.onlyRelational.createAgg.createScan
-            .bindEval(child.typ.rowEnv).bindEval("n_cols", TInt32)
-            .bindAgg(child.typ.entryEnv).bindScan(child.typ.rowEnv)
-          newRow <- ir_value_expr(newEnv)(it)
+          child <- matrix_ir(env)(it)
+          newRow <- ir_value_expr(env)(it)
        } yield MatrixMapRows(child, newRow)
      case "MatrixMapEntries" =>
        for {
          child <- matrix_ir(env)(it)
-          newEntry <- ir_value_expr(env.onlyRelational.bindEval(child.typ.entryEnv))(it)
+          newEntry <- ir_value_expr(env)(it)
        } yield MatrixMapEntries(child, newEntry)
      case "MatrixUnionCols" =>
        val joinType = identifier(it)
        for {
-          left <- matrix_ir(env.onlyRelational)(it)
-          right <- matrix_ir(env.onlyRelational)(it)
+          left <- matrix_ir(env)(it)
+          right <- matrix_ir(env)(it)
        } yield MatrixUnionCols(left, right, joinType)
      case "MatrixMapGlobals" =>
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          newGlobals <- ir_value_expr(env.onlyRelational.bindEval(child.typ.globalEnv))(it)
+          child <- matrix_ir(env)(it)
+          newGlobals <- ir_value_expr(env)(it)
        } yield MatrixMapGlobals(child, newGlobals)
      case "MatrixAggregateColsByKey" =>
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          entryExpr <- ir_value_expr(env.onlyRelational.createAgg.bindEval(child.typ.rowEnv).bindAgg(child.typ.entryEnv))(it)
-          colExpr <- ir_value_expr(env.onlyRelational.createAgg.bindEval(child.typ.globalEnv).bindAgg(child.typ.colEnv))(it)
+          child <- matrix_ir(env)(it)
+          entryExpr <- ir_value_expr(env)(it)
+          colExpr <- ir_value_expr(env)(it)
        } yield MatrixAggregateColsByKey(child, entryExpr, colExpr)
      case "MatrixAggregateRowsByKey" =>
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          entryExpr <- ir_value_expr(env.onlyRelational.createAgg.bindEval(child.typ.colEnv).bindAgg(child.typ.entryEnv))(it)
-          rowExpr <- ir_value_expr(env.onlyRelational.createAgg.bindEval(child.typ.globalEnv).bindAgg(child.typ.rowEnv))(it)
+          child <- matrix_ir(env)(it)
+          entryExpr <- ir_value_expr(env)(it)
+          rowExpr <- ir_value_expr(env)(it)
        } yield MatrixAggregateRowsByKey(child, entryExpr, rowExpr)
      case "MatrixRead" =>
        val requestedTypeRaw = it.head match {
@@ -1905,7 +1818,7 @@ object IRParser {
        val dropCols = boolean_literal(it)
        val dropRows = boolean_literal(it)
        val readerStr = string_literal(it)
-        val reader = MatrixReader.fromJson(env.onlyRelational, JsonMethods.parse(readerStr).asInstanceOf[JObject])
+        val reader = MatrixReader.fromJson(env, JsonMethods.parse(readerStr).asInstanceOf[JObject])
        val fullType = reader.fullMatrixType
        val requestedType = requestedTypeRaw match {
          case Left("None") => fullType
@@ -1923,70 +1836,70 @@ object IRParser {
        val root = string_literal(it)
        val product = boolean_literal(it)
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          table <- table_ir(env.onlyRelational)(it)
+          child <- matrix_ir(env)(it)
+          table <- table_ir(env)(it)
        } yield MatrixAnnotateRowsTable(child, table, root, product)
      case "MatrixAnnotateColsTable" =>
        val root = string_literal(it)
        for {
-          child <- matrix_ir(env.onlyRelational)(it)
-          table <- table_ir(env.onlyRelational)(it)
+          child <- matrix_ir(env)(it)
+          table <- table_ir(env)(it)
        } yield MatrixAnnotateColsTable(child, table, root)
      case "MatrixExplodeRows" =>
        val path = identifiers(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixExplodeRows(child, path)
        }
      case "MatrixExplodeCols" =>
        val path = identifiers(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixExplodeCols(child, path)
        }
      case "MatrixChooseCols" =>
        val oldIndices = int32_literals(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixChooseCols(child, oldIndices)
        }
      case "MatrixCollectColsByKey" =>
-        matrix_ir(env.onlyRelational)(it).map(MatrixCollectColsByKey)
+        matrix_ir(env)(it).map(MatrixCollectColsByKey)
      case "MatrixRepartition" =>
        val n = int32_literal(it)
        val strategy = int32_literal(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixRepartition(child, n, strategy)
        }
-      case "MatrixUnionRows" => matrix_ir_children(env.onlyRelational)(it).map(MatrixUnionRows(_))
-      case "MatrixDistinctByRow" => matrix_ir(env.onlyRelational)(it).map(MatrixDistinctByRow)
+      case "MatrixUnionRows" => matrix_ir_children(env)(it).map(MatrixUnionRows(_))
+      case "MatrixDistinctByRow" => matrix_ir(env)(it).map(MatrixDistinctByRow)
      case "MatrixRowsHead" =>
        val n = int64_literal(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixRowsHead(child, n)
        }
      case "MatrixColsHead" =>
        val n = int32_literal(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixColsHead(child, n)
        }
      case "MatrixRowsTail" =>
        val n = int64_literal(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixRowsTail(child, n)
        }
      case "MatrixColsTail" =>
        val n = int32_literal(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixColsTail(child, n)
        }
      case "CastTableToMatrix" =>
        val entriesField = identifier(it)
        val colsField = identifier(it)
        val colKey = identifiers(it)
-        table_ir(env.onlyRelational)(it).map { child =>
+        table_ir(env)(it).map { child =>
          CastTableToMatrix(child, entriesField, colsField, colKey)
        }
      case "MatrixToMatrixApply" =>
        val config = string_literal(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixToMatrixApply(child, RelationalFunctions.lookupMatrixToMatrix(env.ctx, config))
        }
      case "MatrixRename" =>
@@ -1998,24 +1911,25 @@ object IRParser {
        val rowV = string_literals(it)
        val entryK = string_literals(it)
        val entryV = string_literals(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixRename(child, globalK.zip(globalV).toMap, colK.zip(colV).toMap, rowK.zip(rowV).toMap, entryK.zip(entryV).toMap)
        }
      case "MatrixFilterIntervals" =>
+        val keyType = type_expr(it)
        val intervals = string_literal(it)
        val keep = boolean_literal(it)
-        matrix_ir(env.onlyRelational)(it).map { child =>
+        matrix_ir(env)(it).map { child =>
          MatrixFilterIntervals(child,
            JSONAnnotationImpex.importAnnotation(JsonMethods.parse(intervals),
-              TArray(TInterval(child.typ.rowKeyStruct)),
+              TArray(TInterval(keyType)),
              padNulls = false).asInstanceOf[IndexedSeq[Interval]],
            keep)
        }
      case "RelationalLetMatrixTable" =>
        val name = identifier(it)
        for {
-          value <- ir_value_expr(env.onlyRelational)(it)
-          body <- matrix_ir(env.onlyRelational.bindRelational(name, value.typ))(it)
+          value <- ir_value_expr(env)(it)
+          body <- matrix_ir(env)(it)
        } yield RelationalLetMatrixTable(name, value, body)
    }
  }
@@ -2026,7 +1940,8 @@ object IRParser {
      case "PyRowIntervalSparsifier" =>
        val blocksOnly = boolean_literal(it)
        punctuation(it, ")")
-        ir_value_expr(env)(it).map { ir =>
+        ir_value_expr(env)(it).map { ir_ =>
+          val ir = annotateTypes(env.ctx, ir_, BindingEnv.empty).asInstanceOf[IR]
          val Row(starts: IndexedSeq[Long @unchecked], stops: IndexedSeq[Long @unchecked]) =
            CompileAndEvaluate[Row](env.ctx, ir)
          RowIntervalSparsifier(blocksOnly, starts, stops)
@@ -2034,20 +1949,23 @@ object IRParser {
      case "PyBandSparsifier" =>
        val blocksOnly = boolean_literal(it)
        punctuation(it, ")")
-        ir_value_expr(env)(it).map { ir =>
+        ir_value_expr(env)(it).map { ir_ =>
+          val ir = annotateTypes(env.ctx, ir_, BindingEnv.empty).asInstanceOf[IR]
          val Row(l: Long, u: Long) =
            CompileAndEvaluate[Row](env.ctx, ir)
          BandSparsifier(blocksOnly, l, u)
        }
      case "PyPerBlockSparsifier" =>
        punctuation(it, ")")
-        ir_value_expr(env)(it).map { ir =>
+        ir_value_expr(env)(it).map { ir_ =>
+          val ir = annotateTypes(env.ctx, ir_, BindingEnv.empty).asInstanceOf[IR]
          val indices: IndexedSeq[Int] =
            CompileAndEvaluate[IndexedSeq[Int]](env.ctx, ir)
          PerBlockSparsifier(indices)
        }
      case "PyRectangleSparsifier" =>
        punctuation(it, ")")
-        ir_value_expr(env)(it).map { ir =>
+        ir_value_expr(env)(it).map { ir_ =>
+          val ir = annotateTypes(env.ctx, ir_, BindingEnv.empty).asInstanceOf[IR]
          val rectangles: IndexedSeq[Long] =
            CompileAndEvaluate[IndexedSeq[Long]](env.ctx, ir)
          RectangleSparsifier(rectangles.grouped(4).toIndexedSeq)
@@ -2089,56 +2007,56 @@ object IRParser {
        val name = identifier(it)
        val needs_dense = boolean_literal(it)
        for {
-          child <- blockmatrix_ir(env.onlyRelational)(it)
-          f <- ir_value_expr(env.onlyRelational.bindEval(name, child.typ.elementType))(it)
+          child <- blockmatrix_ir(env)(it)
+          f <- ir_value_expr(env)(it)
        } yield BlockMatrixMap(child, name, f, needs_dense)
      case "BlockMatrixMap2" =>
        val lName = identifier(it)
        val rName = identifier(it)
        val sparsityStrategy = SparsityStrategy.fromString(identifier(it))
        for {
-          left <- blockmatrix_ir(env.onlyRelational)(it)
-          right <- blockmatrix_ir(env.onlyRelational)(it)
-          f <- ir_value_expr(env.onlyRelational.bindEval(lName -> left.typ.elementType, rName -> right.typ.elementType))(it)
+          left <- blockmatrix_ir(env)(it)
+          right <- blockmatrix_ir(env)(it)
+          f <- ir_value_expr(env)(it)
        } yield BlockMatrixMap2(left, right, lName, rName, f, sparsityStrategy)
      case "BlockMatrixDot" =>
        for {
-          left <- blockmatrix_ir(env.onlyRelational)(it)
-          right <- blockmatrix_ir(env.onlyRelational)(it)
+          left <- blockmatrix_ir(env)(it)
+          right <- blockmatrix_ir(env)(it)
        } yield BlockMatrixDot(left, right)
      case "BlockMatrixBroadcast" =>
        val inIndexExpr = int32_literals(it)
        val shape = int64_literals(it)
        val blockSize = int32_literal(it)
-        blockmatrix_ir(env.onlyRelational)(it).map { child =>
+        blockmatrix_ir(env)(it).map { child =>
          BlockMatrixBroadcast(child, inIndexExpr, shape, blockSize)
        }
      case "BlockMatrixAgg" =>
        val outIndexExpr = int32_literals(it)
-        blockmatrix_ir(env.onlyRelational)(it).map { child =>
+        blockmatrix_ir(env)(it).map { child =>
          BlockMatrixAgg(child, outIndexExpr)
        }
      case "BlockMatrixFilter" =>
        val indices = literals(literals(int64_literal))(it)
-        blockmatrix_ir(env.onlyRelational)(it).map { child =>
+        blockmatrix_ir(env)(it).map { child =>
          BlockMatrixFilter(child, indices)
        }
      case "BlockMatrixDensify" =>
-        blockmatrix_ir(env.onlyRelational)(it).map(BlockMatrixDensify)
+        blockmatrix_ir(env)(it).map(BlockMatrixDensify)
      case "BlockMatrixSparsify" =>
        for {
-          sparsifier <- blockmatrix_sparsifier(env.onlyRelational)(it)
-          child <- blockmatrix_ir(env.onlyRelational)(it)
+          sparsifier <- blockmatrix_sparsifier(env)(it)
+          child <- blockmatrix_ir(env)(it)
        } yield BlockMatrixSparsify(child, sparsifier)
      case "BlockMatrixSlice" =>
        val slices = literals(literals(int64_literal))(it)
-        blockmatrix_ir(env.onlyRelational)(it).map { child =>
+        blockmatrix_ir(env)(it).map { child =>
          BlockMatrixSlice(child, slices.map(_.toFastSeq).toFastSeq)
        }
      case "ValueToBlockMatrix" =>
        val shape = int64_literals(it)
        val blockSize = int32_literal(it)
-        ir_value_expr(env.onlyRelational)(it).map { child =>
+        ir_value_expr(env)(it).map { child =>
          ValueToBlockMatrix(child, shape, blockSize)
        }
      case "BlockMatrixRandom" =>
@@ -2150,20 +2068,68 @@ object IRParser {
      case "RelationalLetBlockMatrix" =>
        val name = identifier(it)
        for {
-          value <- ir_value_expr(env.onlyRelational)(it)
-          body <- blockmatrix_ir(env.onlyRelational.bindRelational(name, value.typ))(it)
+          value <- ir_value_expr(env)(it)
+          body <- blockmatrix_ir(env)(it)
        } yield RelationalLetBlockMatrix(name, value, body)
    }
  }

+  def annotateTypes(ctx: ExecuteContext, ir: BaseIR, env: BindingEnv[Type]): BaseIR = {
+    def run(ir: BaseIR, env: BindingEnv[Type]): BaseIR = {
+      val rw = ir.mapChildrenWithEnv(env)(run)
+      rw match {
+        case x: Ref =>
+          x._typ = env.eval(x.name)
+          x
+        case x: Recur =>
+          val TTuple(IndexedSeq(_, TupleField(_, rt))) = env.eval.lookup(x.name)
+          x._typ = rt
+          x
+        case x: ApplyAggOp =>
+          x.aggSig.initOpArgs = x.initOpArgs.map(_.typ)
+          x.aggSig.seqOpArgs = x.seqOpArgs.map(_.typ)
+          x
+        case x: ApplyScanOp =>
+          x.aggSig.initOpArgs = x.initOpArgs.map(_.typ)
+          x.aggSig.seqOpArgs = x.seqOpArgs.map(_.typ)
+          x
+        case x: ApplyComparisonOp =>
+          x.op = x.op.copy(x.l.typ, x.r.typ)
+          x
+        case MakeArray(args, typ) =>
+          MakeArray.unify(ctx, args, typ)
+        case x@InitOp(_, _, BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR))) =>
+          run(combIR, BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType))
+          x
+        case x@SeqOp(_, _, BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR))) =>
+          run(combIR, BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType))
+          x
+        case x@CombOp(_, _, BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR))) =>
+          run(combIR, BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType))
+          x
+        case x@ResultOp(_, BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR))) =>
+          run(combIR, BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType))
+          x
+        case Apply(name, typeArgs, args, rt, errorID) =>
+          invoke(name, rt, typeArgs, errorID, args: _*)
+        case _ =>
+          rw
+      }
+    }
+
+    run(ir, env)
+  }
+
  def parse[T](s: String, f: (TokenIterator) => T): T = {
-    val t = System.nanoTime()
    val it = IRLexer.parse(s).toIterator.buffered
    f(it)
  }

-  def parse_value_ir(s: String, env: IRParserEnvironment): IR = {
-    parse(s, ir_value_expr(env)(_).run())
+  def parse_value_ir(s: String, env: IRParserEnvironment, typeEnv: BindingEnv[Type] = BindingEnv.empty): IR = {
+    var ir = parse(s, ir_value_expr(env)(_).run())
+    ir = annotateTypes(env.ctx, ir, typeEnv).asInstanceOf[IR]
+    TypeCheck(env.ctx, ir, typeEnv)
+    ir
  }

  def parse_value_ir(ctx: ExecuteContext, s: String): IR = {
@@ -2172,13 +2138,28 @@ object IRParser {
  def parse_table_ir(ctx: ExecuteContext, s: String): TableIR = parse_table_ir(s, IRParserEnvironment(ctx))

-  def parse_table_ir(s: String, env: IRParserEnvironment): TableIR = parse(s, table_ir(env)(_).run())
+  def parse_table_ir(s: String, env: IRParserEnvironment): TableIR = {
+    var ir = parse(s, table_ir(env)(_).run())
+    ir = annotateTypes(env.ctx, ir, BindingEnv.empty).asInstanceOf[TableIR]
+    TypeCheck(env.ctx, ir)
+    ir
+  }

-  def parse_matrix_ir(s: String, env: IRParserEnvironment): MatrixIR = parse(s, matrix_ir(env)(_).run())
+  def parse_matrix_ir(s: String, env: IRParserEnvironment): MatrixIR = {
+    var ir = parse(s, matrix_ir(env)(_).run())
+    ir = annotateTypes(env.ctx, ir, BindingEnv.empty).asInstanceOf[MatrixIR]
+    TypeCheck(env.ctx, ir)
+    ir
+  }

  def parse_matrix_ir(ctx: ExecuteContext, s: String): MatrixIR = parse_matrix_ir(s, IRParserEnvironment(ctx))

-  def parse_blockmatrix_ir(s: String, env: IRParserEnvironment): BlockMatrixIR = parse(s, blockmatrix_ir(env)(_).run())
+  def parse_blockmatrix_ir(s: String, env: IRParserEnvironment): BlockMatrixIR = {
+    var ir = parse(s, blockmatrix_ir(env)(_).run())
+    ir = annotateTypes(env.ctx, ir, BindingEnv.empty).asInstanceOf[BlockMatrixIR]
+    TypeCheck(env.ctx, ir)
+    ir
+  }

  def parse_blockmatrix_ir(ctx: ExecuteContext, s: String): BlockMatrixIR = parse_blockmatrix_ir(s, IRParserEnvironment(ctx))

diff --git a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala
index 353b2d1760f..c2bb7922ea9 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala
@@ -183,12 +183,10 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
      single(Pretty.prettyBooleanLiteral(isScan))
    else
      FastSeq(prettyIdentifier(name), Pretty.prettyBooleanLiteral(isScan))
-    case TailLoop(name, args, _) if !elideBindings =>
-      FastSeq(prettyIdentifier(name), prettyIdentifiers(args.map(_._1).toFastSeq))
-    case Recur(name, _, t) => if (elideBindings)
-      single(t.parsableString())
-    else
-      FastSeq(prettyIdentifier(name), t.parsableString())
+    case TailLoop(name, args, returnType, _) if !elideBindings =>
+      FastSeq(prettyIdentifier(name), prettyIdentifiers(args.map(_._1).toFastSeq), returnType.parsableString())
+    case Recur(name, _, t) if !elideBindings =>
+      FastSeq(prettyIdentifier(name))
    // case Ref(name, t) if t != null => FastSeq(prettyIdentifier(name), t.parsableString()) // For debug purposes
    case Ref(name, _) => single(prettyIdentifier(name))
    case RelationalRef(name, t) => if (elideBindings)
@@ -276,7 +274,7 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
    case NDArrayInv(_, errorID) => single(s"$errorID")
    case ArraySort(_, l, r, _) if !elideBindings => FastSeq(prettyIdentifier(l), prettyIdentifier(r))
    case ArrayRef(_,_, errorID) => single(s"$errorID")
-    case ApplyIR(function, typeArgs, _, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), ir.typ.parsableString())
+    case ApplyIR(function, typeArgs, _, _, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), ir.typ.parsableString())
    case Apply(function, typeArgs, _, t, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString())
    case ApplySeeded(function, _, rngState, staticUID, t) => FastSeq(prettyIdentifier(function), staticUID.toString, t.parsableString())
    case ApplySpecial(function, typeArgs, _, t, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString())
@@ -414,12 +412,14 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
      prettyStrings(entryKV.map(_._1)), prettyStrings(entryKV.map(_._2)))
    case TableFilterIntervals(child, intervals, keep) =>
      FastSeq(
+        child.typ.keyType.parsableString(),
        prettyStringLiteral(Serialization.write(
          JSONAnnotationImpex.exportAnnotation(intervals, TArray(TInterval(child.typ.keyType)))
        )(RelationalSpec.formats)),
        Pretty.prettyBooleanLiteral(keep))
    case MatrixFilterIntervals(child, intervals, keep) =>
      FastSeq(
+        child.typ.rowKeyStruct.parsableString(),
        prettyStringLiteral(Serialization.write(
          JSONAnnotationImpex.exportAnnotation(intervals, TArray(TInterval(child.typ.rowKeyStruct)))
        )(RelationalSpec.formats)),
        Pretty.prettyBooleanLiteral(keep))
@@ -494,7 +494,7 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
      if (i > 0) Some(FastSeq()) else None
    case _: Switch =>
      if (i > 0) Some(FastSeq()) else None
-    case TailLoop(name, args, body) => if (i == args.length)
+    case TailLoop(name, args, _, body) => if (i == args.length)
      Some(args.map { case (name, ir) => name -> "loopvar" } :+ name -> "loop")
    else
      None
    case StreamMap(a, name, _) =>
diff --git a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
index 4bf3dd0ee76..eea3d832126 100644
--- a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
@@ -1511,7 +1511,7 @@ object PruneDeadFields {
      )
      memoizeMatrixIR(ctx, child, dep, memo)
      BindingEnv.empty
-    case TailLoop(name, params, body) =>
+    case TailLoop(name, params, _, body) =>
      val bodyEnv = memoizeValueIR(ctx, body, body.typ, memo)
      val paramTypes = params.map{ case (paramName, paramIR) =>
        unifySeq(paramIR.typ, uses(paramName, bodyEnv.eval))
diff --git a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
index b8ee0de9978..e2c6b0c317e 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
@@ -158,7 +158,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
      case Let(name, value, body) => addBinding(name, value)
      case RelationalLet(name, value, body) => addBinding(name, value)
      case RelationalLetTable(name, value, body) => addBinding(name, value)
-      case TailLoop(loopName, params, body) =>
+      case TailLoop(loopName, params, _, body) =>
        addBinding(loopName, body)
        val argDefs = Array.fill(params.length)(new BoxedArrayBuilder[IR]())
        refMap.getOrElse(loopName, FastSeq()).map(_.t).foreach { case Recur(_, args, _) =>
@@ -173,7 +173,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
          i += 1
        }
        states.bind(node, s)
-      case x@ApplyIR(_, _, args, _) =>
+      case x@ApplyIR(_, _, args, _, _) =>
        x.refIdx.foreach { case (n, i) => addBinding(n, args(i)) }
      case ArraySort(a, l, r, c) =>
        addElementBinding(l, a, makeRequired = true)
@@ -543,7 +543,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
        requiredness.unionFrom(lookup(body))
      case RelationalLet(name, value, body) =>
        requiredness.unionFrom(lookup(body))
-      case TailLoop(name, params, body) =>
+      case TailLoop(name, params, _, body) =>
        requiredness.unionFrom(lookup(body))
      case x: BaseRef =>
        requiredness.unionFrom(defs(node).map(tcoerce[TypeWithRequiredness]))
diff --git a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala
index 492510e0707..a963f36b350 100644
---
 a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala
@@ -276,15 +276,15 @@ object Simplify {
    case CastRename(x, t) if x.typ == t => x
    case CastRename(CastRename(x, _), t) => CastRename(x, t)

-    case ApplyIR("indexArray", _, Seq(a, i@I32(v)), errorID) if v >= 0 =>
+    case ApplyIR("indexArray", _, Seq(a, i@I32(v)), _, errorID) if v >= 0 =>
      ArrayRef(a, i, errorID)

-    case ApplyIR("contains", _, Seq(CastToArray(x), element), _) if x.typ.isInstanceOf[TSet] => invoke("contains", TBoolean, x, element)
+    case ApplyIR("contains", _, Seq(CastToArray(x), element), _, _) if x.typ.isInstanceOf[TSet] => invoke("contains", TBoolean, x, element)

-    case ApplyIR("contains", _, Seq(Literal(t, v), element), _) if t.isInstanceOf[TArray] =>
+    case ApplyIR("contains", _, Seq(Literal(t, v), element), _, _) if t.isInstanceOf[TArray] =>
      invoke("contains", TBoolean, Literal(TSet(t.asInstanceOf[TArray].elementType), v.asInstanceOf[IndexedSeq[_]].toSet), element)

-    case ApplyIR("contains", _, Seq(ToSet(x), element), _) if x.typ.isInstanceOf[TArray] => invoke("contains", TBoolean, x, element)
+    case ApplyIR("contains", _, Seq(ToSet(x), element), _, _) if x.typ.isInstanceOf[TArray] => invoke("contains", TBoolean, x, element)

    case x: ApplyIR if x.inline || x.body.size < 10 => x.explicitNode
@@ -635,7 +635,7 @@ object Simplify {
    //      ArrayAgg(GetField(Ref(uid, rowsAndGlobal.typ), "rows"), "row", query)))
    // }

-    case ApplyIR("annotate", _, Seq(s, MakeStruct(fields)), _) =>
+    case ApplyIR("annotate", _, Seq(s, MakeStruct(fields)), _, _) =>
      InsertFields(s, fields)

    // simplify Boolean equality
diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala
index cf38e3b00aa..fa1e40fbc08 100644
--- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala
@@ -8,7 +8,7 @@ import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext, TaskF
 import is.hail.expr.ir
 import is.hail.expr.ir.functions.{BlockMatrixToTableFunction, IntervalFunctions, MatrixToTableFunction, TableToTableFunction}
 import is.hail.expr.ir.lowering._
-import is.hail.expr.ir.streams.StreamProducer
+import is.hail.expr.ir.streams.{StreamProducer, StreamUtils}
 import is.hail.io._
 import is.hail.io.avro.AvroTableReader
 import is.hail.io.fs.FS
@@ -72,6 +72,8 @@ abstract sealed class TableIR extends BaseIR {
  }

  def pyUnpersist(): TableIR = unpersist()
+
+  def typecheck(): Unit = {}
 }

 object TableLiteral {
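How the new typecheck() hook is driven is not shown in this diff; a plausible sketch (an assumption about the TypeCheck pass, not its actual implementation) is a post-order walk that validates each relational node only after annotateTypes has filled in child types:

  def validate(ir: BaseIR): Unit = {
    ir.childrenSeq.foreach(validate)   // children first, so child.typ is safe to read
    ir match {
      case t: TableIR  => t.typecheck()
      case m: MatrixIR => m.typecheck()
      case _           =>
    }
  }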
@@ -1684,15 +1686,14 @@ case class TableRead(typ: TableType, dropRows: Boolean, tr: TableReader) extends
 }

 case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) extends TableIR {
-  require(rowsAndGlobal.typ.isInstanceOf[TStruct])
-  require(rowsAndGlobal.typ.asInstanceOf[TStruct].fieldNames.sameElements(Array("rows", "global")))
-  require(nPartitions.forall(_ > 0))
+  override def typecheck(): Unit = {
+    assert(rowsAndGlobal.typ.isInstanceOf[TStruct])
+    assert(rowsAndGlobal.typ.asInstanceOf[TStruct].fieldNames.sameElements(Array("rows", "global")))
+    assert(nPartitions.forall(_ > 0))
+  }

  lazy val rowCountUpperBound: Option[Long] = None

-  private val rowsType = rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("rows").asInstanceOf[TArray]
-  private val globalsType = rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("global").asInstanceOf[TStruct]
-
  val childrenSeq: IndexedSeq[BaseIR] = FastSeq(rowsAndGlobal)

  def copy(newChildren: IndexedSeq[BaseIR]): TableParallelize = {
@@ -1700,10 +1701,14 @@ case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None)
    TableParallelize(newrowsAndGlobal, nPartitions)
  }

-  val typ: TableType = TableType(
-    rowsType.elementType.asInstanceOf[TStruct],
-    FastSeq(),
-    globalsType)
+  lazy val typ: TableType = {
+    def rowsType = rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("rows").asInstanceOf[TArray]
+    def globalsType = rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("global").asInstanceOf[TStruct]
+    TableType(
+      rowsType.elementType.asInstanceOf[TStruct],
+      FastSeq(),
+      globalsType)
+  }

  protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = {
    val (ptype: PStruct, res) = CompileAndEvaluate._apply(ctx, rowsAndGlobal, optimize = false) match {
@@ -1776,14 +1781,16 @@ case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None)
  * - Otherwise, if 'isSorted' is false and n < 'keys.length', then shuffle.
  */
 case class TableKeyBy(child: TableIR, keys: IndexedSeq[String], isSorted: Boolean = false) extends TableIR {
-  private val fields = child.typ.rowType.fieldNames.toSet
-  assert(keys.forall(fields.contains), s"${ keys.filter(k => !fields.contains(k)).mkString(", ") }")
+  override def typecheck(): Unit = {
+    val fields = child.typ.rowType.fieldNames.toSet
+    assert(keys.forall(fields.contains), s"${keys.filter(k => !fields.contains(k)).mkString(", ")}")
+  }

  lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound

  val childrenSeq: IndexedSeq[BaseIR] = Array(child)

-  val typ: TableType = child.typ.copy(key = keys)
+  lazy val typ: TableType = child.typ.copy(key = keys)

  def definitelyDoesNotShuffle: Boolean = child.typ.key.startsWith(keys) || isSorted
@@ -1794,7 +1801,7 @@ case class TableKeyBy(child: TableIR, keys: IndexedSeq[String], isSorted: Boolea

  protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = {
    val tv = child.execute(ctx, r).asTableValue(ctx)
-    new TableValueIntermediate(tv.copy(typ = typ, rvd = tv.rvd.enforceKey(ctx, keys, isSorted)))
+    TableValueIntermediate(tv.copy(typ = typ, rvd = tv.rvd.enforceKey(ctx, keys, isSorted)))
  }
 }
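Why `val typ` keeps turning into `lazy val typ` in these hunks: the parser now constructs nodes whose children are typed afterwards by annotateTypes, so any type computed eagerly at construction would observe incomplete children. A self-contained illustration of the pattern (toy types, not Hail's):

  final class Child { var typ: String = null }     // stand-in for an untyped child
  final class Node(child: Child) {
    lazy val typ: String = child.typ.toUpperCase   // forced only after annotation
  }
  val c = new Child
  val n = new Node(c)
  c.typ = "int32"                                  // the annotation pass runs...
  assert(n.typ == "INT32")                         // ...before typ is first forced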
@@ -1820,24 +1827,29 @@ case class TableGen(contexts: IR,
  errorId: Int = ErrorIDs.NO_ERROR
 ) extends TableIR {
-  TypeCheck.coerce[TStream]("contexts", contexts.typ)
+  override def typecheck(): Unit = {
+    TypeCheck.coerce[TStream]("contexts", contexts.typ)
+    TypeCheck.coerce[TStruct]("globals", globals.typ)
+    val bodyType = TypeCheck.coerce[TStream]("body", body.typ)
+    val rowType = TypeCheck.coerce[TStruct]("body.elementType", bodyType.elementType)
+
+    if (!partitioner.kType.isSubsetOf(rowType))
+      throw new IllegalArgumentException(
+        s"""'partitioner': key type contains fields absent from row type
+           |  Key type: ${partitioner.kType}
+           |  Row type: $rowType""".stripMargin
+      )
+  }

-  private val globalType =
+  private def globalType =
    TypeCheck.coerce[TStruct]("globals", globals.typ)

-  private val rowType = {
+  private def rowType = {
    val bodyType = TypeCheck.coerce[TStream]("body", body.typ)
    TypeCheck.coerce[TStruct]("body.elementType", bodyType.elementType)
  }

-  if (!partitioner.kType.isSubsetOf(rowType))
-    throw new IllegalArgumentException(
-      s"""'partitioner': key type contains fields absent from row type
-         |  Key type: ${partitioner.kType}
-         |  Row type: $rowType""".stripMargin
-    )
-
-  override def typ: TableType =
+  override lazy val typ: TableType =
    TableType(rowType, partitioner.kType.fieldNames, globalType)

  override val rowCountUpperBound: Option[Long] =
@@ -1910,7 +1922,7 @@ case class TableRange(n: Int, nPartitions: Int) extends TableIR {

 case class TableFilter(child: TableIR, pred: IR) extends TableIR {
  val childrenSeq: IndexedSeq[BaseIR] = Array(child, pred)

-  val typ: TableType = child.typ
+  def typ: TableType = child.typ

  lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound
@@ -2048,22 +2060,26 @@ case class TableJoin(left: TableIR, right: TableIR, joinType: String, joinKey: I
  extends TableIR {

  require(joinKey >= 0)
-  require(left.typ.key.length >= joinKey)
-  require(right.typ.key.length >= joinKey)
-  require(left.typ.keyType.truncate(joinKey) isIsomorphicTo right.typ.keyType.truncate(joinKey))
-  require(left.typ.globalType.fieldNames.toSet
-    .intersect(right.typ.globalType.fieldNames.toSet)
-    .isEmpty)
  require(joinType == "inner" ||
    joinType == "left" ||
    joinType == "right" ||
    joinType == "outer")

+  override def typecheck(): Unit = {
+    assert(left.typ.key.length >= joinKey)
+    assert(right.typ.key.length >= joinKey)
+    assert(left.typ.keyType.truncate(joinKey) isIsomorphicTo right.typ.keyType.truncate(joinKey))
+    assert(
+      left.typ.globalType.fieldNames.toSet
+        .intersect(right.typ.globalType.fieldNames.toSet)
+        .isEmpty)
+  }
+
  val childrenSeq: IndexedSeq[BaseIR] = Array(left, right)

  lazy val rowCountUpperBound: Option[Long] = None

-  private val newRowType = {
+  lazy val typ: TableType = {
    val leftRowType = left.typ.rowType
    val rightRowType = right.typ.rowType
    val leftKey = left.typ.key.take(joinKey)
@@ -2077,14 +2093,12 @@ case class TableJoin(left: TableIR, right: TableIR, joinType: String, joinKey: I
        .nonEmpty)
      throw new RuntimeException(s"invalid join: \n  left value:  $leftValueType\n  right value: $rightValueType")

-    leftKeyType ++ leftValueType ++ rightValueType
-  }
-
-  private val newGlobalType = left.typ.globalType ++ right.typ.globalType
-
-  private val newKey = left.typ.key ++ right.typ.key.drop(joinKey)
+    val newRowType = leftKeyType ++ leftValueType ++ rightValueType
+    val newGlobalType = left.typ.globalType ++ right.typ.globalType
+    val newKey = left.typ.key ++ right.typ.key.drop(joinKey)

-  val typ: TableType = TableType(newRowType, newKey, newGlobalType)
+    TableType(newRowType, newKey, newGlobalType)
+  }

  def copy(newChildren: IndexedSeq[BaseIR]): TableJoin = {
    assert(newChildren.length == 2)
@@ -2112,8 +2126,10 @@ case class TableIntervalJoin(

  lazy val rowCountUpperBound: Option[Long] = left.rowCountUpperBound

-  val rightType: Type = if (product) TArray(right.typ.valueType) else right.typ.valueType
-  val typ: TableType = left.typ.copy(rowType = left.typ.rowType.appendKey(root, rightType))
+  lazy val typ: TableType = {
+    val rightType: Type = if (product) TArray(right.typ.valueType) else right.typ.valueType
+    left.typ.copy(rowType = left.typ.rowType.appendKey(root, rightType))
+  }

  override def copy(newChildren: IndexedSeq[BaseIR]): TableIR =
    TableIntervalJoin(newChildren(0).asInstanceOf[TableIR], newChildren(1).asInstanceOf[TableIR], root, product)
@@ -2194,21 +2210,25 @@ case class TableIntervalJoin(
  * is likely the last.
  */
 case class TableMultiWayZipJoin(childrenSeq: IndexedSeq[TableIR], fieldName: String, globalName: String) extends TableIR {
-  require(childrenSeq.length > 0, "there must be at least one table as an argument")
+  require(childrenSeq.nonEmpty, "there must be at least one table as an argument")

-  private val first = childrenSeq.head
-  private val rest = childrenSeq.tail
+  override def typecheck(): Unit = {
+    val first = childrenSeq.head
+    val rest = childrenSeq.tail
+    assert(rest.forall(e => e.typ.rowType == first.typ.rowType), "all rows must have the same type")
+    assert(rest.forall(e => e.typ.key == first.typ.key), "all keys must be the same")
+    assert(
+      rest.forall(e => e.typ.globalType == first.typ.globalType),
+      "all globals must have the same type")
+  }

-  lazy val rowCountUpperBound: Option[Long] = None
+  private def first = childrenSeq.head

-  require(rest.forall(e => e.typ.rowType == first.typ.rowType), "all rows must have the same type")
-  require(rest.forall(e => e.typ.key == first.typ.key), "all keys must be the same")
-  require(rest.forall(e => e.typ.globalType == first.typ.globalType),
-    "all globals must have the same type")
+  lazy val rowCountUpperBound: Option[Long] = None

-  private val newGlobalType = TStruct(globalName -> TArray(first.typ.globalType))
-  private val newValueType = TStruct(fieldName -> TArray(first.typ.valueType))
-  private val newRowType = first.typ.keyType ++ newValueType
+  private def newGlobalType = TStruct(globalName -> TArray(first.typ.globalType))
+  private def newValueType = TStruct(fieldName -> TArray(first.typ.valueType))
+  private def newRowType = first.typ.keyType ++ newValueType

  lazy val typ: TableType = first.typ.copy(
    rowType = newRowType,
@@ -2294,15 +2314,18 @@ case class TableMultiWayZipJoin(childrenSeq: IndexedSeq[TableIR], fieldName: Str
 }

 case class TableLeftJoinRightDistinct(left: TableIR, right: TableIR, root: String) extends TableIR {
-  require(right.typ.keyType isPrefixOf left.typ.keyType,
-    s"\n  L: ${ left.typ }\n  R: ${ right.typ }")
+  override def typecheck(): Unit = {
+    assert(
+      right.typ.keyType isPrefixOf left.typ.keyType,
+      s"\n  L: ${left.typ}\n  R: ${right.typ}")
+  }

  lazy val rowCountUpperBound: Option[Long] = left.rowCountUpperBound

  lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right)

-  private val newRowType = left.typ.rowType.structInsert(right.typ.valueType, List(root))._1
-  val typ: TableType = left.typ.copy(rowType = newRowType)
+  lazy val typ: TableType = left.typ.copy(
+    rowType = left.typ.rowType.structInsert(right.typ.valueType, List(root)))

  override def partitionCounts: Option[IndexedSeq[Long]] = left.partitionCounts
@@ -2337,11 +2360,18 @@ case class TableMapPartitions(child: TableIR,
  requestedKey: Int,
  allowedOverlap: Int
 ) extends TableIR {
-  assert(body.typ.isInstanceOf[TStream], s"${ body.typ }")
-  assert(allowedOverlap >= -1 && allowedOverlap <= child.typ.key.size)
-  assert(requestedKey >= 0 && requestedKey <= child.typ.key.size)
+  override def typecheck(): Unit = {
+    assert(body.typ.isInstanceOf[TStream], s"${body.typ}")
+    assert(allowedOverlap >= -1)
+    assert(allowedOverlap <= child.typ.key.size)
+    assert(requestedKey >= 0)
+    assert(requestedKey <= child.typ.key.size)
+    assert(StreamUtils.isIterationLinear(body, partitionStreamName), "must iterate over the partition exactly once")
+    val newRowType = body.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct]
+    child.typ.key.foreach { k => if (!newRowType.hasField(k)) throw new RuntimeException(s"prev key: ${child.typ.key}, new row: ${newRowType}") }
+  }

-  lazy val typ = child.typ.copy(
+  lazy val typ: TableType = child.typ.copy(
    rowType = body.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct])

  lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, body)
@@ -2406,11 +2436,16 @@ case class TableMapPartitions(child: TableIR,

 // Must leave key fields unchanged.
 case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {
+  override def typecheck(): Unit = {
+    val newFieldSet = newRow.typ.asInstanceOf[TStruct].fieldNames.toSet
+    assert(child.typ.key.forall(newFieldSet.contains))
+  }
+
  val childrenSeq: IndexedSeq[BaseIR] = Array(child, newRow)

  lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound

-  val typ: TableType = child.typ.copy(rowType = newRow.typ.asInstanceOf[TStruct])
+  lazy val typ: TableType = child.typ.copy(rowType = newRow.typ.asInstanceOf[TStruct])

  def copy(newChildren: IndexedSeq[BaseIR]): TableMapRows = {
    assert(newChildren.length == 2)
@@ -2729,7 +2764,7 @@ case class TableMapGlobals(child: TableIR, newGlobals: IR) extends TableIR {

  lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound

-  val typ: TableType =
+  lazy val typ: TableType =
    child.typ.copy(globalType = newGlobals.typ.asInstanceOf[TStruct])

  def copy(newChildren: IndexedSeq[BaseIR]): TableMapGlobals = {
@@ -2758,24 +2793,19 @@ case class TableMapGlobals(child: TableIR, newGlobals: IR) extends TableIR {

 case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableIR {
  assert(path.nonEmpty)
-  assert(!child.typ.key.contains(path.head))
+
+  override def typecheck(): Unit = {
+    assert(!child.typ.key.contains(path.head))
+  }

  lazy val rowCountUpperBound: Option[Long] = None

  lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child)

-  private val childRowType = child.typ.rowType
+  private def childRowType = child.typ.rowType

-  private val length: IR = {
-    Coalesce(FastSeq(
-      ArrayLen(CastToArray(
-        path.foldLeft[IR](Ref("row", childRowType))((struct, field) =>
-          GetField(struct, field)))),
-      0))
-  }
-
-  val idx = Ref(genUID(), TInt32)
-  val newRow: InsertFields = {
+  private[this] lazy val idx = Ref(genUID(), TInt32)
+  private[this] lazy val newRow: InsertFields = {
    val refs = path.init.scanLeft(Ref("row", childRowType))((struct, name) =>
      Ref(genUID(), tcoerce[TStruct](struct.typ).field(name).typ))
@@ -2789,7 +2819,7 @@ case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableI
    }.asInstanceOf[InsertFields]
  }

-  val typ: TableType = child.typ.copy(rowType = newRow.typ)
+  lazy val typ: TableType = child.typ.copy(rowType = newRow.typ)

  def copy(newChildren: IndexedSeq[BaseIR]): TableExplode = {
    assert(newChildren.length == 1)
@@ -2799,6 +2829,14 @@ case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableI
  protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = {
    val prev = child.execute(ctx, r).asTableValue(ctx)

+    val length: IR =
+      Coalesce(FastSeq(
+        ArrayLen(CastToArray(
+          path.foldLeft[IR](Ref("row", childRowType)) { (struct, field) =>
+            GetField(struct, field)
+          })),
+        0))
+
    val (len, l) = Compile[AsmFunction2RegionLongInt](ctx,
      FastSeq(("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType)))),
      FastSeq(classInfo[Region], LongInfo), IntInfo,
@@ -2816,7 +2854,7 @@ case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableI
      prev.rvd.typ.key.takeWhile(_ != path.head)
    )
    val fsBc = ctx.fsBc
-    new TableValueIntermediate(
+    TableValueIntermediate(
      TableValue(ctx, typ,
        prev.globals,
        prev.rvd.boundary.mapPartitionsWithIndex(rvdType) { (i, ctx, it) =>
@@ -2844,8 +2882,11 @@ case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableI

 case class TableUnion(childrenSeq: IndexedSeq[TableIR]) extends TableIR {
  assert(childrenSeq.nonEmpty)
-  assert(childrenSeq.tail.forall(_.typ.rowType == childrenSeq(0).typ.rowType))
-  assert(childrenSeq.tail.forall(_.typ.key == childrenSeq(0).typ.key))
+
+  override def typecheck(): Unit = {
+    assert(childrenSeq.tail.forall(_.typ.rowType == childrenSeq(0).typ.rowType))
+    assert(childrenSeq.tail.forall(_.typ.key == childrenSeq(0).typ.key))
+  }

  lazy val rowCountUpperBound: Option[Long] = {
    val definedChildren = childrenSeq.flatMap(_.rowCountUpperBound)
@@ -2859,11 +2900,11 @@ case class TableUnion(childrenSeq: IndexedSeq[TableIR]) extends TableIR {
    TableUnion(newChildren.map(_.asInstanceOf[TableIR]))
  }

-  val typ: TableType = childrenSeq(0).typ
+  def typ: TableType = childrenSeq(0).typ

  protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = {
    val tvs = childrenSeq.map(_.execute(ctx, r).asTableValue(ctx))
-    new TableValueIntermediate(
+    TableValueIntermediate(
      tvs(0).copy(
        rvd = RVD.union(RVD.unify(ctx, tvs.map(_.rvd)), tvs(0).typ.key.length, ctx)))
  }
@@ -2881,7 +2922,7 @@ case class MatrixRowsTable(child: MatrixIR) extends TableIR {
    MatrixRowsTable(newChildren(0).asInstanceOf[MatrixIR])
  }

-  val typ: TableType = child.typ.rowsTableType
+  def typ: TableType = child.typ.rowsTableType
 }

 case class MatrixColsTable(child: MatrixIR) extends TableIR {
@@ -2894,7 +2935,7 @@ case class MatrixColsTable(child: MatrixIR) extends TableIR {
    MatrixColsTable(newChildren(0).asInstanceOf[MatrixIR])
  }

-  val typ: TableType = child.typ.colsTableType
+  def typ: TableType = child.typ.colsTableType
 }

 case class MatrixEntriesTable(child: MatrixIR) extends TableIR {
@@ -2907,7 +2948,7 @@ case class MatrixEntriesTable(child: MatrixIR) extends TableIR {
    MatrixEntriesTable(newChildren(0).asInstanceOf[MatrixIR])
  }

-  val typ: TableType = child.typ.entriesTableType
+  def typ: TableType = child.typ.entriesTableType
 }

 case class TableDistinct(child: TableIR) extends TableIR {
@@ -2920,7 +2961,7 @@ case class TableDistinct(child: TableIR) extends TableIR {
    TableDistinct(newChild.asInstanceOf[TableIR])
  }

-  val typ: TableType = child.typ
+  def typ: TableType = child.typ

  protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = {
    val prev = child.execute(ctx, r).asTableValue(ctx)
@@ -2933,10 +2974,14 @@ case class TableKeyByAndAggregate(
  expr: IR,
  newKey: IR,
  nPartitions: Option[Int] = None,
-  bufferSize: Int) extends TableIR {
-  require(expr.typ.isInstanceOf[TStruct])
-  require(newKey.typ.isInstanceOf[TStruct])
-  require(bufferSize > 0)
+  bufferSize: Int
+) extends TableIR {
+  assert(bufferSize > 0)
+
+  override def typecheck(): Unit = {
+    assert(expr.typ.isInstanceOf[TStruct])
+    assert(newKey.typ.isInstanceOf[TStruct])
+  }

  lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, expr, newKey)
@@ -2947,8 +2992,9 @@ case class TableKeyByAndAggregate(
    TableKeyByAndAggregate(newChild, newExpr, newNewKey, nPartitions, bufferSize)
  }

-  private val keyType = newKey.typ.asInstanceOf[TStruct]
-  val typ: TableType = TableType(rowType = keyType ++ tcoerce[TStruct](expr.typ),
+  private lazy val keyType = newKey.typ.asInstanceOf[TStruct]
+  lazy val typ: TableType = TableType(
+    rowType = keyType ++ tcoerce[TStruct](expr.typ),
    globalType = 
child.typ.globalType, key = keyType.fieldNames ) @@ -3094,7 +3140,9 @@ case class TableKeyByAndAggregate( // follows key_by non-empty key case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { - require(child.typ.key.nonEmpty) + override def typecheck(): Unit = { + assert(child.typ.key.nonEmpty) + } lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound @@ -3106,7 +3154,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { TableAggregateByKey(newChild, newExpr) } - val typ: TableType = child.typ.copy(rowType = child.typ.keyType ++ tcoerce[TStruct](expr.typ)) + lazy val typ: TableType = child.typ.copy(rowType = child.typ.keyType ++ tcoerce[TStruct](expr.typ)) protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) @@ -3236,7 +3284,7 @@ case class TableOrderBy(child: TableIR, sortFields: IndexedSeq[SortField]) exten TableOrderBy(newChild.asInstanceOf[TableIR], sortFields) } - val typ: TableType = child.typ.copy(key = FastSeq()) + lazy val typ: TableType = child.typ.copy(key = FastSeq()) protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) @@ -3289,15 +3337,15 @@ case class CastMatrixToTable( } case class TableRename(child: TableIR, rowMap: Map[String, String], globalMap: Map[String, String]) extends TableIR { - require(rowMap.keys.forall(child.typ.rowType.hasField)) - require(globalMap.keys.forall(child.typ.globalType.hasField)) + override def typecheck(): Unit = { + assert(rowMap.keys.forall(child.typ.rowType.hasField)) + assert(globalMap.keys.forall(child.typ.globalType.hasField)) + } lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound def rowF(old: String): String = rowMap.getOrElse(old, old) - def globalF(old: String): String = globalMap.getOrElse(old, old) - lazy val typ: TableType = child.typ.copy( rowType = child.typ.rowType.rename(rowMap), globalType = child.typ.globalType.rename(globalMap), @@ -3314,7 +3362,7 @@ case class TableRename(child: TableIR, rowMap: Map[String, String], globalMap: M } protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = - new TableValueIntermediate( + TableValueIntermediate( child.execute(ctx, r).asTableValue(ctx).rename(globalMap, rowMap)) } @@ -3328,7 +3376,7 @@ case class TableFilterIntervals(child: TableIR, intervals: IndexedSeq[Interval], TableFilterIntervals(newChild, intervals, keep) } - override lazy val typ: TableType = child.typ + override def typ: TableType = child.typ protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { val tv = child.execute(ctx, r).asTableValue(ctx) @@ -3381,7 +3429,8 @@ case class TableToTableApply(child: TableIR, function: TableToTableFunction) ext case class BlockMatrixToTableApply( bm: BlockMatrixIR, aux: IR, - function: BlockMatrixToTableFunction) extends TableIR { + function: BlockMatrixToTableFunction +) extends TableIR { override lazy val childrenSeq: IndexedSeq[BaseIR] = Array(bm, aux) @@ -3418,7 +3467,7 @@ case class BlockMatrixToTable(child: BlockMatrixIR) extends TableIR { } protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { - new TableValueIntermediate(child.execute(ctx).entriesTable(ctx)) + TableValueIntermediate(child.execute(ctx).entriesTable(ctx)) } } diff --git 
a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala index 566be722765..582baeb089a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala @@ -29,17 +29,15 @@ object TypeCheck { def check(ctx: ExecuteContext, ir: BaseIR, env: BindingEnv[Type]): StackFrame[Unit] = { for { - _ <- ir.children - .zipWithIndex - .foreachRecur { case (child, i) => - for { - _ <- call(check(ctx, child, ChildBindings(ir, i, env))) - } yield { - if (child.typ == TVoid) { - checkVoidTypedChild(ctx, ir, i, env) - } else () - } + _ <- ir.forEachChildWithEnvStackSafe(env) { (child, i, childEnv) => + for { + _ <- call(check(ctx, child, childEnv)) + } yield { + if (child.typ == TVoid) { + checkVoidTypedChild(ctx, ir, i, env) + } else () } + } } yield checkSingleNode(ctx, ir, env) } @@ -118,17 +116,18 @@ object TypeCheck { case None => throw new RuntimeException(s"RelationalRef not found in env: $name") } - case x@TailLoop(name, _, body) => - assert(x.typ == body.typ) + case x@TailLoop(name, _, rt, body) => + assert(x.typ == rt) + assert(body.typ == rt) def recurInTail(node: IR, tailPosition: Boolean): Boolean = node match { case x: Recur => x.name != name || tailPosition case _ => node.children.zipWithIndex .forall { - case (c: IR, i) => recurInTail(c, tailPosition && InTailPosition(node, i)) + case (c: IR, i) => recurInTail(c, tailPosition && InTailPosition(node, i)) case _ => true - } + } } assert(recurInTail(body, tailPosition = true)) case x@Recur(name, args, typ) => @@ -142,6 +141,7 @@ object TypeCheck { case x@ApplyComparisonOp(op, l, r) => assert(op.t1 == l.typ) assert(op.t2 == r.typ) + ComparisonOp.checkCompatible(op.t1, op.t2) op match { case _: Compare => assert(x.typ == TInt32) case _ => assert(x.typ == TBoolean) @@ -283,7 +283,7 @@ object TypeCheck { assert(a.typ.isInstanceOf[TStream]) assert(lessThan.typ == TBoolean) case x@ToSet(a) => - assert(a.typ.isInstanceOf[TStream]) + assert(a.typ.isInstanceOf[TStream], a.typ) case x@ToDict(a) => assert(a.typ.isInstanceOf[TStream]) assert(tcoerce[TBaseStruct](tcoerce[TStream](a.typ).elementType).size == 2) @@ -360,7 +360,7 @@ object TypeCheck { assert(key.forall(eltType.hasField)) case x@StreamFilter(a, name, cond) => assert(a.typ.asInstanceOf[TStream].elementType.isRealizable) - assert(cond.typ == TBoolean) + assert(cond.typ == TBoolean, cond.typ) assert(x.typ == a.typ) case x@StreamTakeWhile(a, name, cond) => assert(a.typ.asInstanceOf[TStream].elementType.isRealizable) @@ -520,18 +520,26 @@ object TypeCheck { assert(msg.typ == TString) case Trap(child) => case ConsoleLog(msg, _) => assert(msg.typ == TString) - case x@ApplyIR(fn, typeArgs, args, _) => + case x@ApplyIR(fn, _, typeArgs, args, _) => case x: AbstractApplyNode[_] => assert(x.implementation.unify(x.typeArgs, x.args.map(_.typ), x.returnType)) case MatrixWrite(_, _) => - case MatrixMultiWrite(_, _) => // do nothing + case MatrixMultiWrite(children, _) => + val t = children.head.typ + assert( + !t.rowType.hasField(MatrixReader.rowUIDFieldName) && + !t.colType.hasField(MatrixReader.colUIDFieldName), t + ) + assert(children.forall(_.typ == t)) case x@TableAggregate(child, query) => assert(x.typ == query.typ) case x@MatrixAggregate(child, query) => assert(x.typ == query.typ) case RelationalLet(_, _, _) => case TableWrite(_, _) => - case TableMultiWrite(_, _) => + case TableMultiWrite(children, _) => + val t = children.head.typ + assert(children.forall(_.typ == t)) case TableCount(_) => case 
MatrixCount(_) => case TableGetGlobals(_) => @@ -543,8 +551,6 @@ object TypeCheck { case BlockMatrixCollect(_) => case BlockMatrixWrite(_, writer) => writer.loweredTyp case BlockMatrixMultiWrite(_, _) => - case ValueToBlockMatrix(child, _, _) => - assert(child.typ.isInstanceOf[TArray] || child.typ.isInstanceOf[TNDArray] || child.typ == TFloat64) case CollectDistributedArray(ctxs, globals, cname, gname, body, dynamicID, _, _) => assert(ctxs.typ.isInstanceOf[TStream]) assert(dynamicID.typ == TString) @@ -571,21 +577,10 @@ object TypeCheck { assert(stagingFile.forall(_.typ == TString)) case LiftMeOut(_) => case Consume(_) => - case TableMapRows(child, newRow) => - val newFieldSet = newRow.typ.asInstanceOf[TStruct].fieldNames.toSet - assert(child.typ.key.forall(newFieldSet.contains)) - case TableMapPartitions(child, globalName, partitionStreamName, body, requestedKey, allowedOverlap) => - assert(StreamUtils.isIterationLinear(body, partitionStreamName), "must iterate over the partition exactly once") - val newRowType = body.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct] - child.typ.key.foreach { k => if (!newRowType.hasField(k)) throw new RuntimeException(s"prev key: ${child.typ.key}, new row: ${newRowType}")} - case MatrixUnionCols(left, right, joinType) => - assert(left.typ.rowKeyStruct == right.typ.rowKeyStruct, s"${left.typ.rowKeyStruct} != ${right.typ.rowKeyStruct}") - assert(left.typ.colType == right.typ.colType, s"${left.typ.colType} != ${right.typ.colType}") - assert(left.typ.entryType == right.typ.entryType, s"${left.typ.entryType} != ${right.typ.entryType}") - case _: TableIR => - case _: MatrixIR => - case _: BlockMatrixIR => + case x: TableIR => x.typecheck() + case x: MatrixIR => x.typecheck() + case x: BlockMatrixIR => x.typecheck() } } diff --git a/hail/src/main/scala/is/hail/expr/ir/analyses/ControlFlowPreventsSplit.scala b/hail/src/main/scala/is/hail/expr/ir/analyses/ControlFlowPreventsSplit.scala index afee062cd8e..6796e4fef72 100644 --- a/hail/src/main/scala/is/hail/expr/ir/analyses/ControlFlowPreventsSplit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/analyses/ControlFlowPreventsSplit.scala @@ -11,7 +11,7 @@ object ControlFlowPreventsSplit { case r@Recur(name, _, _) => var parent: BaseIR = r while (parent match { - case TailLoop(`name`, _, _) => false + case TailLoop(`name`, _, _, _) => false case _ => true }) { if (!m.contains(parent)) diff --git a/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala b/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala index 39b5e9e6b72..a012c87f560 100644 --- a/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala +++ b/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala @@ -114,7 +114,7 @@ case object SemanticHash extends Logging { case ApplyComparisonOp(op, _, _) => buffer ++= Bytes.fromClass(op.getClass) - case ApplyIR(fname, tyArgs, _, _) => + case ApplyIR(fname, tyArgs, _, _, _) => buffer ++= fname.getBytes tyArgs.foreach(buffer ++= EncodeTypename(_)) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala index 5812d74dac1..1143d680843 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -99,8 +99,8 @@ object IRFunctionRegistry { val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*) val body = IRParser.parse_value_ir( bodyStr, - IRParserEnvironment(ctx, refMap, Map()) - ) + 
IRParserEnvironment(ctx, Map()), + refMap) userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes))) addIR( @@ -147,7 +147,7 @@ object IRFunctionRegistry { ): JVMFunction = { jvmRegistry.lift(name) match { case None => - fatal(s"no functions found with the name ${name}") + fatal(s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType") case Some(functions) => functions.filter(t => t.unify(typeParameters, valueParameterTypes, returnType)).toSeq match { case Seq() => @@ -196,7 +196,7 @@ object IRFunctionRegistry { def lookupUnseeded(name: String, returnType: Type, typeParameters: Seq[Type], arguments: Seq[Type]): Option[IRFunctionImplementation] = { val validIR: Option[IRFunctionImplementation] = lookupIR(name, returnType, typeParameters, arguments).map { case ((_, _, _, inline), conversion) => (typeParametersPassed, args, errorID) => - val x = ApplyIR(name, typeParametersPassed, args, errorID) + val x = ApplyIR(name, typeParametersPassed, args, returnType, errorID) x.conversion = conversion x.inline = inline x diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala index 984c7d9a4ed..f0d168dc829 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -729,27 +729,31 @@ object LowerTableIR { })) } - bindIR(TailLoop(treeAggFunction, + val loopBody = If( + ArrayLen(currentAggStates) <= I32(branchFactor), + currentAggStates, + Recur( + treeAggFunction, + FastSeq( + CollectDistributedArray( + mapIR(StreamGrouped(ToStream(currentAggStates), I32(branchFactor)))(x => ToArray(x)), + MakeStruct(FastSeq()), + distAggStatesRef.name, + genUID(), + RunAgg( + combineGroup(distAggStatesRef, false), + WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), + aggs.states), + strConcat(Str("iteration="), invoke("str", TString, iterNumber), Str(", n_states="), invoke("str", TString, ArrayLen(currentAggStates))), + "table_tree_aggregate"), + iterNumber + 1), + currentAggStates.typ)) + bindIR(TailLoop( + treeAggFunction, FastSeq[(String, IR)](currentAggStates.name -> collected, iterNumber.name -> I32(0)), - If(ArrayLen(currentAggStates) <= I32(branchFactor), - currentAggStates, - Recur(treeAggFunction, - FastSeq( - CollectDistributedArray( - mapIR(StreamGrouped(ToStream(currentAggStates), I32(branchFactor)))(x => ToArray(x)), - MakeStruct(FastSeq()), - distAggStatesRef.name, - genUID(), - RunAgg( - combineGroup(distAggStatesRef, false), - WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), - aggs.states - ), - strConcat(Str("iteration="), invoke("str", TString, iterNumber), Str(", n_states="), invoke("str", TString, ArrayLen(currentAggStates))), - "table_tree_aggregate"), - iterNumber + 1), - currentAggStates.typ))) - ) { finalParts => + loopBody.typ, + loopBody + )) { finalParts => RunAgg( combineGroup(finalParts, true), Let("global", globals, @@ -1073,17 +1077,25 @@ object LowerTableIR { val howManyPartsToTry = if (targetNumRows == 1L) 1 else 4 val iteration = Ref(genUID(), TInt32) - TailLoop( - partitionSizeArrayFunc, - FastSeq(howManyPartsToTryRef.name -> howManyPartsToTry, iteration.name -> 0), - bindIR(loweredChild.mapContexts(_ => StreamTake(ToStream(childContexts), howManyPartsToTryRef)) { ctx: IR 
=> ctx } - .mapCollect("table_head_recursive_count", + val loopBody = bindIR( + loweredChild + .mapContexts(_ => StreamTake(ToStream(childContexts), howManyPartsToTryRef)) { ctx: IR => ctx } + .mapCollect( + "table_head_recursive_count", strConcat(Str("iteration="), invoke("str", TString, iteration), Str(",nParts="), invoke("str", TString, howManyPartsToTryRef)) - )(streamLenOrMax)) { counts => - If((Cast(streamSumIR(ToStream(counts)), TInt64) >= targetNumRows) || (ArrayLen(childContexts) <= ArrayLen(counts)), + )(streamLenOrMax) + ) { counts => + If( + (Cast(streamSumIR(ToStream(counts)), TInt64) >= targetNumRows) + || (ArrayLen(childContexts) <= ArrayLen(counts)), counts, Recur(partitionSizeArrayFunc, FastSeq(howManyPartsToTryRef * 4, iteration + 1), TArray(TInt32))) - }) + } + TailLoop( + partitionSizeArrayFunc, + FastSeq(howManyPartsToTryRef.name -> howManyPartsToTry, iteration.name -> 0), + loopBody.typ, + loopBody) } } @@ -1094,16 +1106,22 @@ object LowerTableIR { val numLeft = Ref(genUID(), TInt64) def makeAnswer(howManyParts: IR, howManyFromLast: IR) = MakeTuple(FastSeq((0, howManyParts), (1, howManyFromLast))) + val loopBody = If( + (i ceq numPartitions - 1) || ((numLeft - Cast(ArrayRef(partitionSizeArrayRef, i), TInt64)) <= 0L), + makeAnswer(i + 1, numLeft), + Recur( + howManyPartsToKeep, + FastSeq( + i + 1, + numLeft - Cast(ArrayRef(partitionSizeArrayRef, i), TInt64)), + TTuple(TInt32, TInt64))) If(numPartitions ceq 0, makeAnswer(0, 0L), - TailLoop(howManyPartsToKeep, FastSeq(i.name -> 0, numLeft.name -> targetNumRows), - If((i ceq numPartitions - 1) || ((numLeft - Cast(ArrayRef(partitionSizeArrayRef, i), TInt64)) <= 0L), - makeAnswer(i + 1, numLeft), - Recur(howManyPartsToKeep, - FastSeq( - i + 1, - numLeft - Cast(ArrayRef(partitionSizeArrayRef, i), TInt64)), - TTuple(TInt32, TInt64))))) + TailLoop( + howManyPartsToKeep, + FastSeq(i.name -> 0, numLeft.name -> targetNumRows), + loopBody.typ, + loopBody)) } val newCtxs = bindIR(ToArray(loweredChild.contexts)) { childContexts => @@ -1168,19 +1186,24 @@ object LowerTableIR { val iteration = Ref(genUID(), TInt32) + val loopBody = bindIR( + loweredChild + .mapContexts(_ => StreamDrop(ToStream(childContexts), maxIR(totalNumPartitions - howManyPartsToTryRef, 0))) { ctx: IR => ctx } + .mapCollect( + "table_tail_recursive_count", + strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", nParts="), invoke("str", TString, howManyPartsToTryRef)) + )(StreamLen) + ) { counts => + If( + (Cast(streamSumIR(ToStream(counts)), TInt64) >= targetNumRows) || (totalNumPartitions <= ArrayLen(counts)), + counts, + Recur(partitionSizeArrayFunc, FastSeq(howManyPartsToTryRef * 4, iteration + 1), TArray(TInt32))) + } TailLoop( partitionSizeArrayFunc, FastSeq(howManyPartsToTryRef.name -> howManyPartsToTry, iteration.name -> 0), - bindIR( - loweredChild - .mapContexts(_ => StreamDrop(ToStream(childContexts), maxIR(totalNumPartitions - howManyPartsToTryRef, 0))) { ctx: IR => ctx } - .mapCollect("table_tail_recursive_count", - strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", nParts="), invoke("str", TString, howManyPartsToTryRef)))(StreamLen) - ) { counts => - If((Cast(streamSumIR(ToStream(counts)), TInt64) >= targetNumRows) || (totalNumPartitions <= ArrayLen(counts)), - counts, - Recur(partitionSizeArrayFunc, FastSeq(howManyPartsToTryRef * 4, iteration + 1), TArray(TInt32))) - }) + loopBody.typ, + loopBody) } } @@ -1192,19 +1215,22 @@ object LowerTableIR { val nRowsToRight = Ref(genUID(), TInt64) def 
makeAnswer(howManyParts: IR, howManyFromLast: IR) = MakeTuple.ordered(FastSeq(howManyParts, howManyFromLast)) + val loopBody = If( + (i ceq numPartitions) || ((nRowsToRight + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64)) >= targetNumRows), + makeAnswer(i, maxIR(0L, Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64) - (I64(targetNumRows) - nRowsToRight)).toI), + Recur( + howManyPartsToDrop, + FastSeq( + i + 1, + nRowsToRight + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64)), + TTuple(TInt32, TInt32))) If(numPartitions ceq 0, makeAnswer(0, 0), TailLoop( howManyPartsToDrop, FastSeq(i.name -> 1, nRowsToRight.name -> 0L), - If((i ceq numPartitions) || ((nRowsToRight + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64)) >= targetNumRows), - makeAnswer(i, maxIR(0L, Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64) - (I64(targetNumRows) - nRowsToRight)).toI), - Recur( - howManyPartsToDrop, - FastSeq( - i + 1, - nRowsToRight + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64)), - TTuple(TInt32, TInt32))))) + loopBody.typ, + loopBody)) } } @@ -1325,30 +1351,33 @@ object LowerTableIR { val iteration = Ref(genUID(), TInt32) val loopName = genUID() - TailLoop(loopName, IndexedSeq((aggStack.name, MakeArray(collected)), (iteration.name, I32(0))), - bindIR(ArrayRef(aggStack, (ArrayLen(aggStack) - 1))) { states => - bindIR(ArrayLen(states)) { statesLen => - If(statesLen > branchFactor, - bindIR((statesLen + branchFactor - 1) floorDiv branchFactor) { nCombines => - val contexts = mapIR(rangeIR(nCombines)) { outerIdxRef => - sliceArrayIR(states, outerIdxRef * branchFactor, (outerIdxRef + 1) * branchFactor) - } - val cdaResult = cdaIR(contexts, MakeStruct(FastSeq()), "table_scan_up_pass", - strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", nStates="), invoke("str", TString, statesLen)) - ) { case (contexts, _) => - RunAgg( - combineGroup(contexts), - WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), - aggs.states - ) - } - Recur(loopName, IndexedSeq(invoke("extend", TArray(TArray(TString)), aggStack, MakeArray(cdaResult)), iteration + 1), TArray(TArray(TString))) - }, - aggStack - ) - } + val loopBody = bindIR(ArrayRef(aggStack, (ArrayLen(aggStack) - 1))) { states => + bindIR(ArrayLen(states)) { statesLen => + If( + statesLen > branchFactor, + bindIR((statesLen + branchFactor - 1) floorDiv branchFactor) { nCombines => + val contexts = mapIR(rangeIR(nCombines)) { outerIdxRef => + sliceArrayIR(states, outerIdxRef * branchFactor, (outerIdxRef + 1) * branchFactor) + } + val cdaResult = cdaIR( + contexts, MakeStruct(FastSeq()), "table_scan_up_pass", + strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", nStates="), invoke("str", TString, statesLen)) + ) { case (contexts, _) => + RunAgg( + combineGroup(contexts), + WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), + aggs.states) + } + Recur(loopName, IndexedSeq(invoke("extend", TArray(TArray(TString)), aggStack, MakeArray(cdaResult)), iteration + 1), TArray(TArray(TString))) + }, + aggStack) } - ) + } + TailLoop( + loopName, + IndexedSeq((aggStack.name, MakeArray(collected)), (iteration.name, I32(0))), + loopBody.typ, + loopBody) } // The downward pass traverses the tree from root to leaves, computing partial scan @@ -1365,53 +1394,54 @@ object LowerTableIR 
{ bindIR(WriteValue(initState, Str(tmpDir) + UUID4(), writer)) { freshState => - TailLoop(downPassLoopName, IndexedSeq((level.name, ArrayLen(aggStack) - 1), (last.name, MakeArray(freshState)), (iteration.name, I32(0))), - If(level < 0, - last, - bindIR(ArrayRef(aggStack, level)) { aggsArray => - - val groups = mapIR(zipWithIndex(mapIR(StreamGrouped(ToStream(aggsArray), I32(branchFactor)))(x => ToArray(x)))) { eltAndIdx => - MakeStruct(FastSeq( - ("prev", ArrayRef(last, GetField(eltAndIdx, "idx"))), - ("partialSums", GetField(eltAndIdx, "elt")) - )) - } + val loopBody = If( + level < 0, + last, + bindIR(ArrayRef(aggStack, level)) { aggsArray => + val groups = mapIR(zipWithIndex(mapIR(StreamGrouped(ToStream(aggsArray), I32(branchFactor)))(x => ToArray(x)))) { eltAndIdx => + MakeStruct(FastSeq( + ("prev", ArrayRef(last, GetField(eltAndIdx, "idx"))), + ("partialSums", GetField(eltAndIdx, "elt")))) + } - val results = cdaIR(groups, MakeTuple.ordered(FastSeq()), "table_scan_down_pass", - strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", level="), invoke("str", TString, level)) - ) { case (context, _) => - bindIR(GetField(context, "prev")) { prev => - - val elt = Ref(genUID(), TString) - ToArray(RunAggScan( - ToStream(GetField(context, "partialSums"), requiresMemoryManagementPerElement = true), - elt.name, - bindIR(ReadValue(prev, reader, reader.spec.encodedVirtualType)) { serializedTuple => - Begin( - aggs.aggs.zipWithIndex.map { case (sig, i) => - InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state) - }) - }, - bindIR(ReadValue(elt, reader, reader.spec.encodedVirtualType)) { serializedTuple => - Begin( - aggs.aggs.zipWithIndex.map { case (sig, i) => - CombOpValue(i, GetTupleElement(serializedTuple, i), sig) - }) - }, - WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), - aggs.states - )) - } + val results = cdaIR( + groups, MakeTuple.ordered(FastSeq()), "table_scan_down_pass", + strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", level="), invoke("str", TString, level)) + ) { case (context, _) => + bindIR(GetField(context, "prev")) { prev => + val elt = Ref(genUID(), TString) + ToArray(RunAggScan( + ToStream(GetField(context, "partialSums"), requiresMemoryManagementPerElement = true), + elt.name, + bindIR(ReadValue(prev, reader, reader.spec.encodedVirtualType)) { serializedTuple => + Begin( + aggs.aggs.zipWithIndex.map { case (sig, i) => + InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state) + }) + }, + bindIR(ReadValue(elt, reader, reader.spec.encodedVirtualType)) { serializedTuple => + Begin( + aggs.aggs.zipWithIndex.map { case (sig, i) => + CombOpValue(i, GetTupleElement(serializedTuple, i), sig) + }) + }, + WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), + aggs.states)) } - Recur(downPassLoopName, - IndexedSeq( - level - 1, - ToArray(flatten(ToStream(results))), - iteration + 1), - TArray(TString)) } - ) - ) + Recur( + downPassLoopName, + IndexedSeq( + level - 1, + ToArray(flatten(ToStream(results))), + iteration + 1), + TArray(TString)) + }) + TailLoop( + downPassLoopName, + IndexedSeq((level.name, ArrayLen(aggStack) - 1), (last.name, MakeArray(freshState)), (iteration.name, I32(0))), + loopBody.typ, + loopBody) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/package.scala b/hail/src/main/scala/is/hail/expr/ir/package.scala 
index 6f7d56838fa..4e25e2cc232 100644 --- a/hail/src/main/scala/is/hail/expr/ir/package.scala +++ b/hail/src/main/scala/is/hail/expr/ir/package.scala @@ -37,10 +37,12 @@ package object ir { ir.Ref(pred, TBoolean))) } - def invoke(name: String, rt: Type, typeArgs: Array[Type], errorID: Int, args: IR*): IR = IRFunctionRegistry.lookupUnseeded(name, rt, typeArgs, args.map(_.typ)) match { - case Some(f) => f(typeArgs, args, errorID) - case None => fatal(s"no conversion found for $name(${typeArgs.mkString(", ")}, ${args.map(_.typ).mkString(", ")}) => $rt") - } + def invoke(name: String, rt: Type, typeArgs: Seq[Type], errorID: Int, args: IR*): IR = + IRFunctionRegistry.lookupUnseeded(name, rt, typeArgs, args.map(_.typ)) match { + case Some(f) => f(typeArgs, args, errorID) + case None => fatal(s"no conversion found for $name(${typeArgs.mkString(", ")}, ${args.map(_.typ).mkString(", ")}) => $rt") + } + def invoke(name: String, rt: Type, typeArgs: Array[Type], args: IR*): IR = invoke(name, rt, typeArgs, ErrorIDs.NO_ERROR, args:_*) diff --git a/hail/src/main/scala/is/hail/types/virtual/TStruct.scala b/hail/src/main/scala/is/hail/types/virtual/TStruct.scala index 5e467ec0f4c..a104c2c399d 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TStruct.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TStruct.scala @@ -144,10 +144,10 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { } } - def structInsert(signature: Type, p: List[String]): (TStruct, Inserter) = { + def structInsert(signature: Type, p: List[String]): TStruct = { require(p.nonEmpty || signature.isInstanceOf[TStruct], s"tried to remap top-level struct to non-struct $signature") val (t, f) = insert(signature, p) - (t.asInstanceOf[TStruct], f) + t.asInstanceOf[TStruct] } def updateKey(key: String, i: Int, sig: Type): TStruct = { diff --git a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala index 0beaa71fe6b..c778465f0cd 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala @@ -35,7 +35,7 @@ class ForwardLetsSuite extends HailSuite { Array( NDArrayMap(In(1, TNDArray(TInt32, Nat(1))), "y", x + y), NDArrayMap2(In(1, TNDArray(TInt32, Nat(1))), In(2, TNDArray(TInt32, Nat(1))), "y", "z", x + y + Ref("z", TInt32), ErrorIDs.NO_ERROR), - TailLoop("f", FastSeq("y" -> I32(0)), If(y < x, Recur("f", FastSeq[IR](y - I32(1)), TInt32), x)) + TailLoop("f", FastSeq("y" -> I32(0)), TInt32, If(y < x, Recur("f", FastSeq[IR](y - I32(1)), TInt32), x)) ).map(ir => Array[IR](Let("x", In(0, TInt32) + In(0, TInt32), ir))) } diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index 621b56fca5a..f3d83d9507e 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -2766,12 +2766,12 @@ class IRSuite extends HailSuite { NDArrayAgg(nd, FastSeq(0)), NDArrayWrite(nd, Str("/path/to/ndarray")), NDArrayMatMul(nd, nd, ErrorIDs.NO_ERROR), - NDArraySlice(nd, MakeTuple.ordered(FastSeq(MakeTuple.ordered(FastSeq(F64(0), F64(2), F64(1))), - MakeTuple.ordered(FastSeq(F64(0), F64(2), F64(1)))))), + NDArraySlice(nd, MakeTuple.ordered(FastSeq(MakeTuple.ordered(FastSeq(I64(0), I64(2), I64(1))), + MakeTuple.ordered(FastSeq(I64(0), I64(2), I64(1)))))), NDArrayFilter(nd, FastSeq(NA(TArray(TInt64)), NA(TArray(TInt64)))), ArrayRef(a, i) -> Array(a), ArrayLen(a) -> Array(a), - 
RNGSplit(rngState, MakeTuple.ordered(FastSeq(I64(1), MakeTuple.ordered(FastSeq(I64(2), I64(3)))))), + RNGSplit(rngState, MakeTuple.ordered(FastSeq(I64(1), I64(2), I64(3)))), StreamLen(st) -> Array(st), StreamRange(I32(0), I32(5), I32(1)), StreamRange(I32(0), I32(5), I32(1)), @@ -2781,8 +2781,8 @@ class IRSuite extends HailSuite { ToArray(st) -> Array(st), CastToArray(NA(TSet(TInt32))), ToStream(a) -> Array(a), - LowerBoundOnOrderedCollection(a, i, onKey = true) -> Array(a), - GroupByKey(da) -> Array(da), + LowerBoundOnOrderedCollection(a, i, onKey = false) -> Array(a), + GroupByKey(std) -> Array(std), StreamTake(st, I32(10)) -> Array(st), StreamDrop(st, I32(10)) -> Array(st), StreamTakeWhile(st, "v", v < I32(5)) -> Array(st), @@ -2794,7 +2794,7 @@ class IRSuite extends HailSuite { StreamFold(st, I32(0), "x", "v", v) -> Array(st), StreamFold2(StreamFold(st, I32(0), "x", "v", v)) -> Array(st), StreamScan(st, I32(0), "x", "v", v) -> Array(st), - StreamWhiten(whitenStream, "newChunk", "prevWindow", 0, 0, 0, 0, false) -> Array(whitenStream), + StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, false) -> Array(whitenStream), StreamJoinRightDistinct( StreamMap(StreamRange(0, 2, 1), "x", MakeStruct(FastSeq("x" -> Ref("x", TInt32)))), StreamMap(StreamRange(0, 3, 1), "x", MakeStruct(FastSeq("x" -> Ref("x", TInt32)))), @@ -2804,12 +2804,12 @@ class IRSuite extends HailSuite { StreamAggScan(st, "x", ApplyScanOp(FastSeq.empty, FastSeq(Cast(Ref("x", TInt32), TInt64)), sumSig)) -> Array(st), RunAgg(Begin(FastSeq( InitOp(0, FastSeq(Begin(FastSeq(InitOp(0, FastSeq(), pSumSig)))), groupSignature), - SeqOp(0, FastSeq(I32(1), SeqOp(0, FastSeq(), pSumSig)), groupSignature))), + SeqOp(0, FastSeq(I32(1), SeqOp(0, FastSeq(I64(1)), pSumSig)), groupSignature))), AggStateValue(0, groupSignature.state), FastSeq(groupSignature.state)), RunAggScan(StreamRange(I32(0), I32(1), I32(1)), "foo", InitOp(0, FastSeq(Begin(FastSeq(InitOp(0, FastSeq(), pSumSig)))), groupSignature), - SeqOp(0, FastSeq(Ref("foo", TInt32), SeqOp(0, FastSeq(), pSumSig)), groupSignature), + SeqOp(0, FastSeq(Ref("foo", TInt32), SeqOp(0, FastSeq(I64(1)), pSumSig)), groupSignature), AggStateValue(0, groupSignature.state), FastSeq(groupSignature.state)), AggFilter(True(), I32(0), false) -> (_.createAgg), @@ -2845,7 +2845,7 @@ class IRSuite extends HailSuite { TableCount(table), MatrixCount(mt), TableGetGlobals(table), - TableCollect(table), + TableCollect(TableKeyBy(table, FastSeq())), TableAggregate(table, MakeStruct(IndexedSeq("foo" -> count))), TableToValueApply(table, ForceCountTable()), MatrixToValueApply(mt, ForceCountMatrixTable()), @@ -2872,14 +2872,14 @@ class IRSuite extends HailSuite { MakeStream(FastSeq(), TStream(TStruct())), NA(TString), PartitionNativeWriter(TypedCodecSpec(PType.canonical(TStruct()), BufferSpec.default), IndexedSeq(), "path", None, None)), WriteMetadata( - NA(TStruct("global" -> TString, "partitions" -> TStruct("filePath" -> TString, "partitionCounts" -> TInt64))), + Begin(FastSeq()), RelationalWriter("path", overwrite = false, None)), ReadValue(Str("foo"), ETypeValueReader(TypedCodecSpec(PCanonicalStruct("foo" -> PInt32(), "bar" -> PCanonicalString()), BufferSpec.default)), TStruct("foo" -> TInt32)), WriteValue(I32(1), Str("foo"), ETypeValueWriter(TypedCodecSpec(PInt32(), BufferSpec.default))), WriteValue(I32(1), Str("foo"), ETypeValueWriter(TypedCodecSpec(PInt32(), BufferSpec.default)), Some(Str("/tmp/uid/part"))), LiftMeOut(I32(1)), RelationalLet("x", I32(0), I32(0)), - TailLoop("y", IndexedSeq("x" -> 
I32(0)), Recur("y", FastSeq(I32(4)), TInt32)) + TailLoop("y", IndexedSeq("x" -> I32(0)), TInt32, Recur("y", FastSeq(I32(4)), TInt32)) ) val emptyEnv = BindingEnv.empty[Type] irs.map { case (ir, bind) => Array(ir, bind(emptyEnv)) } @@ -3092,11 +3092,11 @@ class IRSuite extends HailSuite { @Test(dataProvider = "valueIRs") def testValueIRParser(x: IR, refMap: BindingEnv[Type]) { - val env = IRParserEnvironment(ctx, refMap = refMap) + val env = IRParserEnvironment(ctx) val s = Pretty.sexprStyle(x, elideLiterals = false) - val x2 = IRParser.parse_value_ir(s, env) + val x2 = IRParser.parse_value_ir(s, env, refMap) assert(x2 == x) } @@ -3314,6 +3314,7 @@ class IRSuite extends HailSuite { implicit val execStrats = ExecStrategy.compileOnly val triangleSum: IR = TailLoop("f", FastSeq("x" -> In(0, TInt32), "accum" -> In(1, TInt32)), + TInt32, If(Ref("x", TInt32) <= I32(0), Ref("accum", TInt32), Recur("f", @@ -3331,9 +3332,11 @@ class IRSuite extends HailSuite { implicit val execStrats = ExecStrategy.compileOnly val triangleSum: IR = TailLoop("f1", FastSeq("x" -> In(0, TInt32), "accum" -> I32(0)), + TInt32, If(Ref("x", TInt32) <= I32(0), TailLoop("f2", FastSeq("x2" -> Ref("accum", TInt32), "accum2" -> I32(0)), + TInt32, If(Ref("x2", TInt32) <= I32(0), Ref("accum2", TInt32), Recur("f2", @@ -3357,6 +3360,7 @@ class IRSuite extends HailSuite { val ndSum: IR = TailLoop("f", FastSeq("x" -> In(0, TInt32), "accum" -> In(1, ndType)), + ndType, If(Ref("x", TInt32) <= I32(0), Ref("accum", ndType), Recur("f", diff --git a/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala index 394f673868c..b747c351247 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala @@ -265,7 +265,7 @@ class MatrixIRSuite extends HailSuite { // The entry field must be an array interceptFatal("") { - CastTableToMatrix(rowTab, "animal", "__cols", Array("col_idx")) + TypeCheck(ctx, CastTableToMatrix(rowTab, "animal", "__cols", Array("col_idx"))) } val rdata2 = Array( @@ -322,8 +322,8 @@ class MatrixIRSuite extends HailSuite { val range = rangeMatrix(10, 2, None) val path1 = ctx.createTmpPath("test1") val path2 = ctx.createTmpPath("test2") - intercept[java.lang.IllegalArgumentException] { - val ir = MatrixMultiWrite(FastSeq(vcf, range), MatrixNativeMultiWriter(IndexedSeq(path1, path2))) + intercept[HailException] { + TypeCheck(ctx, MatrixMultiWrite(FastSeq(vcf, range), MatrixNativeMultiWriter(IndexedSeq(path1, path2)))) } } } diff --git a/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala b/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala index e50fdc48b39..c75c4aef7dd 100644 --- a/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala @@ -200,9 +200,12 @@ class RequirednessSuite extends HailSuite { // TailLoop val param1 = Ref(genUID(), tarray) val param2 = Ref(genUID(), TInt32) - val loop = TailLoop("loop", FastSeq( - param1.name -> array(required, required), - param2.name -> int(required)), + val loop = TailLoop( + "loop", + FastSeq( + param1.name -> array(required, required), + param2.name -> int(required)), + tnestedarray, If(False(), // required MakeArray(FastSeq(param1), tnestedarray), // required If(param2 <= I32(1), // possibly missing diff --git a/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala b/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala index 8a86b410c21..ed6b72c9584 
100644 --- a/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala @@ -20,7 +20,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("construction", "typecheck")) def testWithInvalidContextsType: Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(contexts = Some(Str("oh noes :'("))) + mkTableGen(contexts = Some(Str("oh noes :'("))).typecheck() } ex.getMessage should include("contexts") @@ -31,7 +31,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("construction", "typecheck")) def testWithInvalidGlobalsType: Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(globals = Some(Str("oh noes :'(")), body = Some(MakeStream(IndexedSeq(), TStream(TStruct())))) + mkTableGen(globals = Some(Str("oh noes :'(")), body = Some(MakeStream(IndexedSeq(), TStream(TStruct())))).typecheck() } ex.getMessage should include("globals") ex.getMessage should include(s"Expected: ${classOf[TStruct].getName}") @@ -41,7 +41,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("construction", "typecheck")) def testWithInvalidBodyType: Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(body = Some(Str("oh noes :'("))) + mkTableGen(body = Some(Str("oh noes :'("))).typecheck() } ex.getMessage should include("body") ex.getMessage should include(s"Expected: ${classOf[TStream].getName}") @@ -51,7 +51,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("construction", "typecheck")) def testWithInvalidBodyElementType: Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(body = Some(MakeStream(IndexedSeq(Str("oh noes :'(")), TStream(TString)))) + mkTableGen(body = Some(MakeStream(IndexedSeq(Str("oh noes :'(")), TStream(TString)))).typecheck() } ex.getMessage should include("body.elementType") ex.getMessage should include(s"Expected: ${classOf[TStruct].getName}") @@ -61,7 +61,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("construction", "typecheck")) def testWithInvalidPartitionerKeyType: Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(partitioner = Some(RVDPartitioner.empty(ctx.stateManager, TStruct("does-not-exist" -> TInt32)))) + mkTableGen(partitioner = Some(RVDPartitioner.empty(ctx.stateManager, TStruct("does-not-exist" -> TInt32)))).typecheck() } ex.getMessage should include("partitioner") } @@ -69,7 +69,7 @@ class TableGenSuite extends HailSuite { @Test(groups = Array("construction", "typecheck")) def testWithTooLongPartitionerKeyType: Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(partitioner = Some(RVDPartitioner.empty(ctx.stateManager, TStruct("does-not-exist" -> TInt32)))) + mkTableGen(partitioner = Some(RVDPartitioner.empty(ctx.stateManager, TStruct("does-not-exist" -> TInt32)))).typecheck() } ex.getMessage should include("partitioner") }
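The hunks above apply one refactor repeatedly: node-level type requirements move out of case-class constructors into an overridable typecheck() hook driven by the TypeCheck pass, with typ becoming a lazy val (or a plain def) so that merely constructing a node no longer forces its children's types; the TailLoop hunks complement this by carrying the loop's result type explicitly so Recur and the body can be checked against it. The sketch below illustrates the deferred-typecheck pattern in isolation. It is a minimal toy under assumed names (TableNode, KeyBy, typeCheck, and so on), not Hail's actual class hierarchy.

```scala
// Minimal sketch (illustrative names, not Hail code): constructor-time
// requires become an overridable typecheck() hook, and `typ` is deferred.
object DeferredTypecheckSketch extends App {
  final case class TableType(rowFields: List[String], key: List[String])

  sealed abstract class TableNode {
    // Invoked by a separate TypeCheck pass; constructors stay cheap.
    def typecheck(): Unit = {}
    def typ: TableType
    def children: List[TableNode] = Nil
  }

  final case class Literal(typ: TableType) extends TableNode

  final case class Distinct(child: TableNode) extends TableNode {
    // `def`, not `val`: forward the child's type without caching it.
    def typ: TableType = child.typ
    override def children: List[TableNode] = List(child)
  }

  final case class KeyBy(child: TableNode, key: List[String]) extends TableNode {
    // The checks that used to be constructor `require`s live here now.
    override def typecheck(): Unit =
      assert(key.forall(child.typ.rowFields.contains), s"bad key $key for ${child.typ}")

    // Computed at most once, and only when first demanded.
    lazy val typ: TableType = child.typ.copy(key = key)
    override def children: List[TableNode] = List(child)
  }

  // A toy TypeCheck pass: validate children first, then the node itself.
  def typeCheck(node: TableNode): Unit = {
    node.children.foreach(typeCheck)
    node.typecheck()
  }

  val t = Literal(TableType(List("x", "y"), Nil))
  typeCheck(Distinct(KeyBy(t, List("x"))))  // passes
  // typeCheck(KeyBy(t, List("nope")))      // would fail inside typecheck()
}
```

The payoff is visible in the test hunks: building an ill-typed node no longer throws at construction time, so MatrixIRSuite and TableGenSuite now wrap the expression in an explicit TypeCheck(ctx, ...) or .typecheck() call to provoke the expected failure.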