From 28e56f7027d43d9434817234ba8d1a30a7d7af76 Mon Sep 17 00:00:00 2001 From: Patrick Schultz Date: Mon, 30 Oct 2023 11:00:55 -0400 Subject: [PATCH] [query] fix bug in new dict decoder (#13939) The decoder uses a `StagedArrayBuilder` to hold elements while being sorted. The array builder is stored in a class field. When the same decoder function is called more than once, that array builder is reused. Before this fix, the array builder was never cleared, so if the decoder function was called more than once, the array builder would still contain the elements from previously decoded dicts. Since it's highly non-obvious that you need to call `clear` immediately after `new StagedArrayBuilder`, this PR makes the constructor take a CodeBuilder, and always inserts a clear at the call site. I also took the opportunity to CodeBuilderify the rest of the interface. --- hail/python/test/hail/test_ir.py | 7 ++++ .../scala/is/hail/expr/ir/ArraySorter.scala | 16 ++++---- .../src/main/scala/is/hail/expr/ir/Emit.scala | 18 ++++----- .../expr/ir/SpecializedArrayBuilders.scala | 39 ++++++++++++------- .../is/hail/expr/ir/streams/StreamUtils.scala | 12 +++--- .../is/hail/io/bgen/StagedBGENReader.scala | 6 +-- .../encoded/EDictAsUnsortedArrayOfPairs.scala | 6 +-- .../is/hail/types/encoded/EUnsortedSet.scala | 6 +-- .../is/hail/expr/ir/StagedBTreeSuite.scala | 7 ++-- 9 files changed, 65 insertions(+), 52 deletions(-) diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index d2544e337fa..b2a53f670b7 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -625,6 +625,13 @@ def test_literal_ndarray_encodings(value): _assert_encoding_roundtrip(value.T) +def test_decoding_multiple_dicts(): + dict = {0: 'a', 1: 'b', 2: 'c'} + dict2 = {0: 'x', 1: 'y', 2: 'z'} + ht = hl.utils.range_table(1).annotate(indices = hl.array([0, 1, 2])) + ht.select(a=ht.indices.map(lambda i: hl.struct(x=hl.dict(dict).get(i), y=hl.dict(dict2).get(i)))).collect() 
+ + def test_locus_interval_encoding(): start = hl.Locus(contig='chr1', position=10001, reference_genome='GRCh38') end = hl.Locus(contig='chr1', position=11001, reference_genome='GRCh38') diff --git a/hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala b/hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala index d11cc6d2d77..ef1109fb6b6 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala @@ -33,14 +33,14 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) { cb.while_(i < size, { cb.if_(!array.isMissing(i), { - cb.if_(newEnd.cne(i), cb += array.update(newEnd, array.apply(i))) + cb.if_(newEnd.cne(i), array.update(cb, newEnd, array.apply(i))) cb.assign(newEnd, newEnd + 1) }) cb.assign(i, i + 1) }) cb.assign(i, newEnd) cb.while_(i < size, { - cb += array.setMissing(i, true) + array.setMissing(cb, i, true) cb.assign(i, i + 1) }) @@ -126,7 +126,7 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) { cb.assign(i, 0) cb.while_(i < newEnd, { - cb += array.update(i, arrayRef(workingArray2)(i)) + array.update(cb, i, arrayRef(workingArray2)(i)) cb.assign(i, i + 1) }) @@ -162,12 +162,12 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) { cb.while_(i < size, { cb.if_(!array.isMissing(i), { cb.if_(i.cne(n), - cb += array.update(n, array(i))) + array.update(cb, n, array(i))) cb.assign(n, n + 1) }) cb.assign(i, i + 1) }) - cb += array.setSize(n) + array.setSize(cb, n) } def distinctFromSorted(cb: EmitCodeBuilder, region: Value[Region], discardNext: (EmitCodeBuilder, Value[Region], EmitCode, EmitCode) => Code[Boolean]): Unit = { @@ -197,12 +197,12 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) { cb.assign(n, n + 1) cb.if_(i < size && i.cne(n), { - cb += array.setMissing(n, array.isMissing(i)) - cb.if_(!array.isMissing(n), cb += array.update(n, array(i))) + array.setMissing(cb, n, array.isMissing(i)) + cb.if_(!array.isMissing(n), array.update(cb, n, 
array(i))) }) }) - cb += array.setSize(n) + array.setSize(cb, n) } cb.invokeVoid(distinctMB, region) diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index 66b840aad90..732265b5362 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -1137,7 +1137,7 @@ class Emit[C]( val sct = SingleCodeType.fromSType(producer.element.st) - val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0) + val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0) StreamUtils.writeToArrayBuilder(cb, producer, vab, region) val sorter = new ArraySorter(EmitRegion(mb, region), vab) sorter.sort(cb, region, makeDependentSortingFunction(cb, sct, lessThan, env, emitSelf, Array(left, right))) @@ -1188,7 +1188,7 @@ class Emit[C]( val sct = SingleCodeType.fromSType(producer.element.st) - val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0) + val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0) StreamUtils.writeToArrayBuilder(cb, producer, vab, region) val sorter = new ArraySorter(EmitRegion(mb, region), vab) @@ -1214,7 +1214,7 @@ class Emit[C]( val sct = SingleCodeType.fromSType(producer.element.st) - val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0) + val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0) StreamUtils.writeToArrayBuilder(cb, producer, vab, region) val sorter = new ArraySorter(EmitRegion(mb, region), vab) @@ -1255,7 +1255,7 @@ class Emit[C]( val producer = stream.getProducer(mb) val sct = SingleCodeType.fromSType(producer.element.st) - val sortedElts = new StagedArrayBuilder(sct, producer.element.required, mb, 16) + val sortedElts = new StagedArrayBuilder(cb, sct, producer.element.required, 16) StreamUtils.writeToArrayBuilder(cb, producer, sortedElts, region) val sorter = new ArraySorter(EmitRegion(mb, region), sortedElts) @@ -1270,7 +1270,7 @@ class Emit[C]( 
sorter.sort(cb, region, lt) sorter.pruneMissing(cb) - val groupSizes = new StagedArrayBuilder(Int32SingleCodeType, true, mb, 0) + val groupSizes = new StagedArrayBuilder(cb, Int32SingleCodeType, true, 0) val eltIdx = mb.newLocal[Int]("groupByKey_eltIdx") val grpIdx = mb.newLocal[Int]("groupByKey_grpIdx") @@ -1278,12 +1278,10 @@ class Emit[C]( val outerSize = mb.newLocal[Int]("groupByKey_outerSize") val groupSize = mb.newLocal[Int]("groupByKey_groupSize") - - cb += groupSizes.clear cb.assign(eltIdx, 0) cb.assign(groupSize, 0) - def sameKeyAtIndices(cb: EmitCodeBuilder, region: Value[Region], idx1: Code[Int], idx2: Code[Int]): Code[Boolean] = { + def sameKeyAtIndices(cb: EmitCodeBuilder, region: Value[Region], idx1: Value[Int], idx2: Value[Int]): Code[Boolean] = { val lk = cb.memoize( sortedElts.loadFromIndex(cb, region, idx1).flatMap(cb) { x => x.asBaseStruct.loadField(cb, 0) @@ -1306,14 +1304,14 @@ class Emit[C]( cb.if_(eltIdx.ceq(sortedElts.size - 1), { cb.goto(newGroup) }, { - cb.if_(sameKeyAtIndices(cb, region, eltIdx, eltIdx + 1), { + cb.if_(sameKeyAtIndices(cb, region, eltIdx, cb.memoize(eltIdx + 1)), { cb.goto(bottomOfLoop) }, { cb.goto(newGroup) }) }) cb.define(newGroup) - cb += groupSizes.add(groupSize) + groupSizes.add(cb, groupSize) cb.assign(groupSize, 0) cb.define(bottomOfLoop) diff --git a/hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala b/hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala index 29628595777..39a28150f5d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala +++ b/hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala @@ -9,8 +9,9 @@ import is.hail.utils.BoxedArrayBuilder import scala.reflect.ClassTag -class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb: EmitMethodBuilder[_], len: Code[Int]) { +class StagedArrayBuilder(cb: EmitCodeBuilder, val elt: SingleCodeType, val eltRequired: Boolean, len: Int) { + def mb = cb.emb val ti: TypeInfo[_] = 
elt.ti val ref: Value[Any] = coerce[Any](ti match { @@ -22,13 +23,19 @@ class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb: case ti => throw new RuntimeException(s"unsupported typeinfo found: $ti") }) - def add(x: Code[_]): Code[Unit] = ti match { + // If a method containing `new StagedArrayBuilder(...)` is called multiple times, + // the invocations will share the same array builder at runtime. Clearing + // here ensures a "new" array builder is always empty. + clear(cb) + ensureCapacity(cb, len) + + def add(cb: EmitCodeBuilder, x: Code[_]): Unit = cb.append(ti match { case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Boolean, Unit]("add", coerce[Boolean](x)) case IntInfo => coerce[IntMissingArrayBuilder](ref).invoke[Int, Unit]("add", coerce[Int](x)) case LongInfo => coerce[LongMissingArrayBuilder](ref).invoke[Long, Unit]("add", coerce[Long](x)) case FloatInfo => coerce[FloatMissingArrayBuilder](ref).invoke[Float, Unit]("add", coerce[Float](x)) case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Double, Unit]("add", coerce[Double](x)) - } + }) def apply(i: Code[Int]): Code[_] = ti match { case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Int, Boolean]("apply", i) @@ -38,34 +45,36 @@ class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb: case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Int, Double]("apply", i) } - def update(i: Code[Int], x: Code[_]): Code[Unit] = ti match { + def update(cb: EmitCodeBuilder, i: Code[Int], x: Code[_]): Unit = cb.append(ti match { case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("update", i, coerce[Boolean](x)) case IntInfo => coerce[IntMissingArrayBuilder](ref).invoke[Int, Int, Unit]("update", i, coerce[Int](x)) case LongInfo => coerce[LongMissingArrayBuilder](ref).invoke[Int, Long, Unit]("update", i, coerce[Long](x)) case FloatInfo => 
coerce[FloatMissingArrayBuilder](ref).invoke[Int, Float, Unit]("update", i, coerce[Float](x)) case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Int, Double, Unit]("update", i, coerce[Double](x)) - } + }) - def addMissing(): Code[Unit] = - coerce[MissingArrayBuilder](ref).invoke[Unit]("addMissing") + def addMissing(cb: EmitCodeBuilder): Unit = + cb += coerce[MissingArrayBuilder](ref).invoke[Unit]("addMissing") def isMissing(i: Code[Int]): Code[Boolean] = coerce[MissingArrayBuilder](ref).invoke[Int, Boolean]("isMissing", i) - def setMissing(i: Code[Int], m: Code[Boolean]): Code[Unit] = - coerce[MissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("setMissing", i, m) + def setMissing(cb: EmitCodeBuilder, i: Code[Int], m: Code[Boolean]): Unit = + cb += coerce[MissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("setMissing", i, m) def size: Code[Int] = coerce[MissingArrayBuilder](ref).invoke[Int]("size") - def setSize(n: Code[Int]): Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("setSize", n) + def setSize(cb: EmitCodeBuilder, n: Code[Int]): Unit = + cb += coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("setSize", n) - def ensureCapacity(n: Code[Int]): Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("ensureCapacity", n) + def ensureCapacity(cb: EmitCodeBuilder, n: Code[Int]): Unit = + cb += coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("ensureCapacity", n) - def clear: Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Unit]("clear") + def clear(cb: EmitCodeBuilder): Unit = + cb += coerce[MissingArrayBuilder](ref).invoke[Unit]("clear") - def loadFromIndex(cb: EmitCodeBuilder, r: Value[Region], i: Code[Int]): IEmitCode = { - val idx = cb.newLocal[Int]("loadFromIndex_idx", i) - IEmitCode(cb, isMissing(idx), elt.loadToSValue(cb, cb.memoizeAny(apply(idx), ti))) + def loadFromIndex(cb: EmitCodeBuilder, r: Value[Region], i: Value[Int]): IEmitCode = { + IEmitCode(cb, isMissing(i), elt.loadToSValue(cb, 
cb.memoizeAny(apply(i), ti))) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala b/hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala index 3f214878e1e..a3dfba0fe2f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala +++ b/hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala @@ -53,7 +53,7 @@ object StreamUtils { val aTyp = PCanonicalArray(stream.element.emitType.storageType, true) stream.length match { case None => - val vab = new StagedArrayBuilder(SingleCodeType.fromSType(stream.element.st), stream.element.required, mb, 0) + val vab = new StagedArrayBuilder(cb, SingleCodeType.fromSType(stream.element.st), stream.element.required, 0) writeToArrayBuilder(cb, stream, vab, destRegion) cb.assign(xLen, vab.size) @@ -86,17 +86,17 @@ object StreamUtils { destRegion: Value[Region] ): Unit = { stream.memoryManagedConsume(destRegion, cb, setup = { cb => - cb += ab.clear + ab.clear(cb) stream.length match { - case Some(computeLen) => cb += ab.ensureCapacity(computeLen(cb)) - case None => cb += ab.ensureCapacity(16) + case Some(computeLen) => ab.ensureCapacity(cb, computeLen(cb)) + case None => ab.ensureCapacity(cb, 16) } }) { cb => stream.element.toI(cb).consume(cb, - cb += ab.addMissing(), - sc => cb += ab.add(ab.elt.coerceSCode(cb, sc, destRegion, deepCopy = stream.requiresMemoryManagementPerElement).code) + ab.addMissing(cb), + sc => ab.add(cb, ab.elt.coerceSCode(cb, sc, destRegion, deepCopy = stream.requiresMemoryManagementPerElement).code) ) } } diff --git a/hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala b/hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala index 06127b6b4dd..b1fc7e1e9d8 100644 --- a/hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala +++ b/hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala @@ -522,7 +522,7 @@ object BGENFunctions extends RegistryFunctions { "alleles" -> PCanonicalArray(PCanonicalString(true), true), "offset" -> PInt64Required) val 
bufferSct = SingleCodeType.fromSType(rowPType.sType) - val buffer = new StagedArrayBuilder(bufferSct, true, mb, 8) + val buffer = new StagedArrayBuilder(cb, bufferSct, true, 8) val currSize = cb.newLocal[Int]("currSize", 0) val spec = TypedCodecSpec( @@ -553,7 +553,7 @@ object BGENFunctions extends RegistryFunctions { cb += ob.invoke[Unit]("close") cb.assign(groupIndex, groupIndex + 1) - cb += buffer.clear + buffer.clear(cb) cb.assign(currSize, 0) } @@ -569,7 +569,7 @@ object BGENFunctions extends RegistryFunctions { cb.if_(currSize ceq bufferSize, { dumpBuffer(cb) }) - cb += buffer.add(bufferSct.coerceSCode(cb, row, er.region, false).code) + buffer.add(cb, bufferSct.coerceSCode(cb, row, er.region, false).code) cb.assign(currSize, currSize + 1) cb.assign(nWritten, nWritten + 1) }) diff --git a/hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala b/hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala index ef2ad14c5f9..55fa15d9692 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala @@ -39,10 +39,10 @@ final case class EDictAsUnsortedArrayOfPairs(val elementType: EType, override va val decodedUnsortedArray = arrayDecoder(cb, region, in).asInstanceOf[SIndexablePointerValue] val sct = SingleCodeType.fromSType(decodedUnsortedArray.st.elementType) - val ab = new StagedArrayBuilder(sct, true, cb.emb, 0) - cb.append(ab.ensureCapacity(decodedUnsortedArray.length)) + val ab = new StagedArrayBuilder(cb, sct, true, 0) + ab.ensureCapacity(cb, decodedUnsortedArray.length) decodedUnsortedArray.forEachDefined(cb) { (cb, i, res) => - cb.append(ab.add(ab.elt.coerceSCode(cb, res, region, false).code)) + ab.add(cb, ab.elt.coerceSCode(cb, res, region, false).code) } val sorter = new ArraySorter(EmitRegion(cb.emb, region), ab) diff --git a/hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala 
b/hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala index ab657cedc33..b3be4d606af 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala @@ -36,10 +36,10 @@ final case class EUnsortedSet(val elementType: EType, override val required: Boo val decodedUnsortedArray = arrayDecoder(cb, region, in).asInstanceOf[SIndexablePointerValue] val sct = SingleCodeType.fromSType(decodedUnsortedArray.st.elementType) - val ab = new StagedArrayBuilder(sct, elementType.required, cb.emb, 0) - cb.append(ab.ensureCapacity(decodedUnsortedArray.length)) + val ab = new StagedArrayBuilder(cb, sct, elementType.required, 0) + ab.ensureCapacity(cb, decodedUnsortedArray.length) decodedUnsortedArray.forEachDefined(cb) { (cb, i, res) => - cb.append(ab.add(ab.elt.coerceSCode(cb, res, region, false).code)) + ab.add(cb, ab.elt.coerceSCode(cb, res, region, false).code) } val sorter = new ArraySorter(EmitRegion(cb.emb, region), ab) diff --git a/hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala index 3912c5420b0..82642328b63 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala @@ -136,20 +136,19 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) { val key = new TestBTreeKey(fb.apply_method) val btree = new AppendOnlyBTree(cb, key, r, root, maxElements = n) - val sab = new StagedArrayBuilder(Int64SingleCodeType, true, fb.apply_method, 16) val idx = fb.newLocal[Int]() val returnArray = fb.newLocal[Array[java.lang.Long]]() fb.emitWithBuilder { cb => + val sab = new StagedArrayBuilder(cb, Int64SingleCodeType, true, 16) cb += (r := fb.getCodeParam[Region](1)) cb += (root := fb.getCodeParam[Long](2)) - cb += sab.clear btree.foreach(cb) { (cb, _koff) => val koff = cb.memoize(_koff) val ec = key.loadCompKey(cb, koff) cb.if_(ec.m, - cb += 
sab.addMissing(), - cb += sab.add(ec.pv.asInt64.value)) + sab.addMissing(cb), + sab.add(cb, ec.pv.asInt64.value)) } cb += (returnArray := Code.newArray[java.lang.Long](sab.size)) cb.for_(