Skip to content

Commit

Permalink
misc fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Nov 13, 2023
1 parent 8e0c7fe commit c8cfa80
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 20 deletions.
2 changes: 1 addition & 1 deletion hail/python/hail/ir/table_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def _handle_randomness(self, uid_field_name):
return TableFilterIntervals(self.child.handle_randomness(uid_field_name), self.intervals, self.point_type, self.keep)

def head_str(self):
return f'{self.child.typ.key_type()._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'
return f'{self.child.typ.key_type._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}'

def _eq(self, other):
return self.intervals == other.intervals and self.point_type == other.point_type and self.keep == other.keep
Expand Down
18 changes: 10 additions & 8 deletions hail/python/test/hail/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def value_irs_env(self):
'mat': hl.tndarray(hl.tfloat64, 2),
'aa': hl.tarray(hl.tarray(hl.tint32)),
'sta': hl.tstream(hl.tarray(hl.tint32)),
'da': hl.tarray(hl.ttuple(hl.tint32, hl.tstr)),
'nd': hl.tndarray(hl.tfloat64, 1),
'sts': hl.tstream(hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64)),
'da': hl.tstream(hl.ttuple(hl.tint32, hl.tstr)),
'nd': hl.tndarray(hl.tfloat64, 2),
'v': hl.tint32,
's': hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64),
't': hl.ttuple(hl.tint32, hl.tint64, hl.tfloat64),
Expand All @@ -42,6 +43,7 @@ def value_irs(self):
mat = ir.Ref('mat')
aa = ir.Ref('aa', env['aa'])
sta = ir.Ref('sta', env['sta'])
sts = ir.Ref('sts', env['sts'])
da = ir.Ref('da', env['da'])
nd = ir.Ref('nd', env['nd'])
v = ir.Ref('v', env['v'])
Expand Down Expand Up @@ -77,7 +79,7 @@ def aggregate(x):
ir.ArrayRef(a, i),
ir.ArrayLen(a),
ir.ArraySort(ir.ToStream(a), 'l', 'r', ir.ApplyComparisonOp("LT", ir.Ref('l', hl.tint32), ir.Ref('r', hl.tint32))),
ir.ToSet(a),
ir.ToSet(st),
ir.ToDict(da),
ir.ToArray(st),
ir.CastToArray(ir.NA(hl.tset(hl.tint32))),
Expand All @@ -89,17 +91,17 @@ def aggregate(x):
ir.NDArrayRef(nd, [ir.I64(1), ir.I64(2)]),
ir.NDArrayMap(nd, 'v', v),
ir.NDArrayMatMul(nd, nd),
ir.LowerBoundOnOrderedCollection(a, i, True),
ir.LowerBoundOnOrderedCollection(a, i, False),
ir.GroupByKey(da),
ir.RNGSplit(rngState, ir.MakeTuple([ir.I64(1), ir.MakeTuple([ir.I64(2), ir.I64(3)])])),
ir.RNGSplit(rngState, ir.MakeTuple([ir.I64(1), ir.I64(2), ir.I64(3)])),
ir.StreamMap(st, 'v', v),
ir.StreamZip([st, st], ['a', 'b'], ir.TrueIR(), 'ExtendNA'),
ir.StreamFilter(st, 'v', v),
ir.StreamFilter(st, 'v', c),
ir.StreamFlatMap(sta, 'v', ir.ToStream(v)),
ir.StreamFold(st, ir.I32(0), 'x', 'v', v),
ir.StreamScan(st, ir.I32(0), 'x', 'v', v),
ir.StreamWhiten(whitenStream, "newChunk", "prevWindow", 0, 0, 0, 0, False),
ir.StreamJoinRightDistinct(st, st, ['k'], ['k'], 'l', 'r', ir.I32(1), "left"),
ir.StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, False),
ir.StreamJoinRightDistinct(sts, sts, ['x'], ['x'], 'l', 'r', ir.I32(1), "left"),
ir.StreamFor(st, 'v', ir.Void()),
aggregate(ir.AggFilter(ir.TrueIR(), ir.I32(0), False)),
aggregate(ir.AggExplode(ir.StreamRange(ir.I32(0), ir.I32(2), ir.I32(1)), 'x', ir.I32(0), False)),
Expand Down
12 changes: 7 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ object Binds {
object Bindings {
private val empty: Array[(String, Type)] = Array()

// A call to Bindings(x, i) may only query the types of children with
// index < i
def apply(x: BaseIR, i: Int): Iterable[(String, Type)] = x match {
case Let(name, value, _) => if (i == 1) Array(name -> value.typ) else empty
case TailLoop(name, args, resultType, _) => if (i == args.length)
Expand All @@ -26,14 +28,14 @@ object Bindings {
else
empty
case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) =>
val contextType = TIterable.elementType(contexts.typ)
val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType)
if (i == 1)
if (i == 1) {
val contextType = TIterable.elementType(contexts.typ)
Array(ctxName -> contextType)
else if (i == 2)
} else if (i == 2) {
val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType)
Array(curKey -> eltType.typeAfterSelectNames(key),
curVals -> TArray(eltType))
else
} else
empty
case StreamFor(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
case StreamFlatMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
Expand Down
6 changes: 2 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
FastSeq(prettyIdentifier(name), Pretty.prettyBooleanLiteral(isScan))
case TailLoop(name, args, returnType, _) if !elideBindings =>
FastSeq(prettyIdentifier(name), prettyIdentifiers(args.map(_._1).toFastSeq), returnType.parsableString())
case Recur(name, _, t) => if (elideBindings)
single(t.parsableString())
else
FastSeq(prettyIdentifier(name), t.parsableString())
case Recur(name, _, t) if !elideBindings =>
FastSeq(prettyIdentifier(name))
// case Ref(name, t) if t != null => FastSeq(prettyIdentifier(name), t.parsableString()) // For debug purposes
case Ref(name, _) => single(prettyIdentifier(name))
case RelationalRef(name, t) => if (elideBindings)
Expand Down
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ object TypeCheck {
assert(a.typ.isInstanceOf[TStream])
assert(lessThan.typ == TBoolean)
case x@ToSet(a) =>
assert(a.typ.isInstanceOf[TStream])
assert(a.typ.isInstanceOf[TStream], a.typ)
case x@ToDict(a) =>
assert(a.typ.isInstanceOf[TStream])
assert(tcoerce[TBaseStruct](tcoerce[TStream](a.typ).elementType).size == 2)
Expand Down Expand Up @@ -360,7 +360,7 @@ object TypeCheck {
assert(key.forall(eltType.hasField))
case x@StreamFilter(a, name, cond) =>
assert(a.typ.asInstanceOf[TStream].elementType.isRealizable)
assert(cond.typ == TBoolean)
assert(cond.typ == TBoolean, cond.typ)
assert(x.typ == a.typ)
case x@StreamTakeWhile(a, name, cond) =>
assert(a.typ.asInstanceOf[TStream].elementType.isRealizable)
Expand Down

0 comments on commit c8cfa80

Please sign in to comment.