From c8cfa800f3c7e69193c0ff1248b34bd57b79a737 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 13 Nov 2023 08:57:10 -0500 Subject: [PATCH] misc fixes --- hail/python/hail/ir/table_ir.py | 2 +- hail/python/test/hail/test_ir.py | 18 ++++++++++-------- .../src/main/scala/is/hail/expr/ir/Binds.scala | 12 +++++++----- .../main/scala/is/hail/expr/ir/Pretty.scala | 6 ++---- .../main/scala/is/hail/expr/ir/TypeCheck.scala | 4 ++-- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 0d7e96a5d90..22c0530c35a 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -881,7 +881,7 @@ def _handle_randomness(self, uid_field_name): return TableFilterIntervals(self.child.handle_randomness(uid_field_name), self.intervals, self.point_type, self.keep) def head_str(self): - return f'{self.child.typ.key_type()._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}' + return f'{self.child.typ.key_type._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}' def _eq(self, other): return self.intervals == other.intervals and self.point_type == other.point_type and self.keep == other.keep diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index b2a53f670b7..b2340e99e33 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -22,8 +22,9 @@ def value_irs_env(self): 'mat': hl.tndarray(hl.tfloat64, 2), 'aa': hl.tarray(hl.tarray(hl.tint32)), 'sta': hl.tstream(hl.tarray(hl.tint32)), - 'da': hl.tarray(hl.ttuple(hl.tint32, hl.tstr)), - 'nd': hl.tndarray(hl.tfloat64, 1), + 'sts': hl.tstream(hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64)), + 'da': hl.tstream(hl.ttuple(hl.tint32, hl.tstr)), + 'nd': hl.tndarray(hl.tfloat64, 2), 'v': hl.tint32, 's': hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64), 't': hl.ttuple(hl.tint32, hl.tint64, hl.tfloat64), @@ -42,6 +43,7 @@ def value_irs(self): mat = ir.Ref('mat') aa = ir.Ref('aa', env['aa']) sta = ir.Ref('sta', env['sta']) + sts = ir.Ref('sts', env['sts']) da = ir.Ref('da', env['da']) nd = ir.Ref('nd', env['nd']) v = ir.Ref('v', env['v']) @@ -77,7 +79,7 @@ def aggregate(x): ir.ArrayRef(a, i), ir.ArrayLen(a), ir.ArraySort(ir.ToStream(a), 'l', 'r', ir.ApplyComparisonOp("LT", ir.Ref('l', hl.tint32), ir.Ref('r', hl.tint32))), - ir.ToSet(a), + ir.ToSet(st), ir.ToDict(da), ir.ToArray(st), ir.CastToArray(ir.NA(hl.tset(hl.tint32))), @@ -89,17 +91,17 @@ def aggregate(x): ir.NDArrayRef(nd, [ir.I64(1), ir.I64(2)]), ir.NDArrayMap(nd, 'v', v), ir.NDArrayMatMul(nd, nd), - ir.LowerBoundOnOrderedCollection(a, i, True), + ir.LowerBoundOnOrderedCollection(a, i, False), ir.GroupByKey(da), - ir.RNGSplit(rngState, ir.MakeTuple([ir.I64(1), ir.MakeTuple([ir.I64(2), ir.I64(3)])])), + ir.RNGSplit(rngState, ir.MakeTuple([ir.I64(1), ir.I64(2), ir.I64(3)])), ir.StreamMap(st, 'v', v), ir.StreamZip([st, st], ['a', 'b'], ir.TrueIR(), 'ExtendNA'), - ir.StreamFilter(st, 'v', v), + ir.StreamFilter(st, 'v', c), ir.StreamFlatMap(sta, 'v', ir.ToStream(v)), ir.StreamFold(st, ir.I32(0), 'x', 'v', v), ir.StreamScan(st, ir.I32(0), 'x', 'v', v), - ir.StreamWhiten(whitenStream, "newChunk", "prevWindow", 0, 0, 0, 0, False), - ir.StreamJoinRightDistinct(st, st, ['k'], ['k'], 'l', 'r', ir.I32(1), "left"), + ir.StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, False), + ir.StreamJoinRightDistinct(sts, sts, ['x'], ['x'], 'l', 'r', ir.I32(1), "left"), ir.StreamFor(st, 'v', ir.Void()), aggregate(ir.AggFilter(ir.TrueIR(), ir.I32(0), False)), aggregate(ir.AggExplode(ir.StreamRange(ir.I32(0), ir.I32(2), ir.I32(1)), 'x', ir.I32(0), False)), diff --git a/hail/src/main/scala/is/hail/expr/ir/Binds.scala b/hail/src/main/scala/is/hail/expr/ir/Binds.scala index c4942c59134..1c6f3049b5a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Binds.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Binds.scala @@ -11,6 +11,8 @@ object Binds { object Bindings { private val empty: Array[(String, Type)] = Array() + // A call to Bindings(x, i) may only query the types of children with + // index < i def apply(x: BaseIR, i: Int): Iterable[(String, Type)] = x match { case Let(name, value, _) => if (i == 1) Array(name -> value.typ) else empty case TailLoop(name, args, resultType, _) => if (i == args.length) @@ -26,14 +28,14 @@ object Bindings { else empty case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) => - val contextType = TIterable.elementType(contexts.typ) - val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType) - if (i == 1) + if (i == 1) { + val contextType = TIterable.elementType(contexts.typ) Array(ctxName -> contextType) - else if (i == 2) + } else if (i == 2) { + val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType) Array(curKey -> eltType.typeAfterSelectNames(key), curVals -> TArray(eltType)) - else + } else empty case StreamFor(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty case StreamFlatMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty diff --git a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala index 7e0f6d59053..542c69d0214 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala @@ -185,10 +185,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, FastSeq(prettyIdentifier(name), Pretty.prettyBooleanLiteral(isScan)) case TailLoop(name, args, returnType, _) if !elideBindings => FastSeq(prettyIdentifier(name), prettyIdentifiers(args.map(_._1).toFastSeq), returnType.parsableString()) - case Recur(name, _, t) => if (elideBindings) - single(t.parsableString()) - else - FastSeq(prettyIdentifier(name), t.parsableString()) + case Recur(name, _, t) if !elideBindings => + FastSeq(prettyIdentifier(name)) // case Ref(name, t) if t != null => FastSeq(prettyIdentifier(name), t.parsableString()) // For debug purposes case Ref(name, _) => single(prettyIdentifier(name)) case RelationalRef(name, t) => if (elideBindings) diff --git a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala index baa951be064..582baeb089a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala @@ -283,7 +283,7 @@ object TypeCheck { assert(a.typ.isInstanceOf[TStream]) assert(lessThan.typ == TBoolean) case x@ToSet(a) => - assert(a.typ.isInstanceOf[TStream]) + assert(a.typ.isInstanceOf[TStream], a.typ) case x@ToDict(a) => assert(a.typ.isInstanceOf[TStream]) assert(tcoerce[TBaseStruct](tcoerce[TStream](a.typ).elementType).size == 2) @@ -360,7 +360,7 @@ object TypeCheck { assert(key.forall(eltType.hasField)) case x@StreamFilter(a, name, cond) => assert(a.typ.asInstanceOf[TStream].elementType.isRealizable) - assert(cond.typ == TBoolean) + assert(cond.typ == TBoolean, cond.typ) assert(x.typ == a.typ) case x@StreamTakeWhile(a, name, cond) => assert(a.typ.asInstanceOf[TStream].elementType.isRealizable)