Skip to content

Commit

Permalink
[query] fix bug in new dict decoder (#13939)
Browse files Browse the repository at this point in the history
The decoder uses a `StagedArrayBuilder` to hold elements while being
sorted. The array builder is stored in a class field. When the same
decoder function is called more than once, that array builder is reused.

Before this fix, the array builder was never cleared, so if the decoder
function was called more than once, the array builder would still
contain the elements from previously decoded dicts.

Since it's highly non-obvious that you need to call `clear` immediately
after `new StagedCodeBuilder`, this PR makes the constructor take a
CodeBuilder, and always inserts a clear at the call site. I also took
the opportunity to CodeBuilderify the rest of the interface.
  • Loading branch information
patrick-schultz authored Oct 30, 2023
1 parent aade5c6 commit 28e56f7
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 52 deletions.
7 changes: 7 additions & 0 deletions hail/python/test/hail/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,13 @@ def test_literal_ndarray_encodings(value):
_assert_encoding_roundtrip(value.T)


def test_decoding_multiple_dicts():
dict = {0: 'a', 1: 'b', 2: 'c'}
dict2 = {0: 'x', 1: 'y', 2: 'z'}
ht = hl.utils.range_table(1).annotate(indices = hl.array([0, 1, 2]))
ht.select(a=ht.indices.map(lambda i: hl.struct(x=hl.dict(dict).get(i), y=hl.dict(dict2).get(i)))).collect()


def test_locus_interval_encoding():
start = hl.Locus(contig='chr1', position=10001, reference_genome='GRCh38')
end = hl.Locus(contig='chr1', position=11001, reference_genome='GRCh38')
Expand Down
16 changes: 8 additions & 8 deletions hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) {

cb.while_(i < size, {
cb.if_(!array.isMissing(i), {
cb.if_(newEnd.cne(i), cb += array.update(newEnd, array.apply(i)))
cb.if_(newEnd.cne(i), array.update(cb, newEnd, array.apply(i)))
cb.assign(newEnd, newEnd + 1)
})
cb.assign(i, i + 1)
})
cb.assign(i, newEnd)
cb.while_(i < size, {
cb += array.setMissing(i, true)
array.setMissing(cb, i, true)
cb.assign(i, i + 1)
})

Expand Down Expand Up @@ -126,7 +126,7 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) {

cb.assign(i, 0)
cb.while_(i < newEnd, {
cb += array.update(i, arrayRef(workingArray2)(i))
array.update(cb, i, arrayRef(workingArray2)(i))
cb.assign(i, i + 1)
})

Expand Down Expand Up @@ -162,12 +162,12 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) {
cb.while_(i < size, {
cb.if_(!array.isMissing(i), {
cb.if_(i.cne(n),
cb += array.update(n, array(i)))
array.update(cb, n, array(i)))
cb.assign(n, n + 1)
})
cb.assign(i, i + 1)
})
cb += array.setSize(n)
array.setSize(cb, n)
}

def distinctFromSorted(cb: EmitCodeBuilder, region: Value[Region], discardNext: (EmitCodeBuilder, Value[Region], EmitCode, EmitCode) => Code[Boolean]): Unit = {
Expand Down Expand Up @@ -197,12 +197,12 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) {
cb.assign(n, n + 1)

cb.if_(i < size && i.cne(n), {
cb += array.setMissing(n, array.isMissing(i))
cb.if_(!array.isMissing(n), cb += array.update(n, array(i)))
array.setMissing(cb, n, array.isMissing(i))
cb.if_(!array.isMissing(n), array.update(cb, n, array(i)))
})

})
cb += array.setSize(n)
array.setSize(cb, n)
}

cb.invokeVoid(distinctMB, region)
Expand Down
18 changes: 8 additions & 10 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ class Emit[C](

val sct = SingleCodeType.fromSType(producer.element.st)

val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0)
val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0)
StreamUtils.writeToArrayBuilder(cb, producer, vab, region)
val sorter = new ArraySorter(EmitRegion(mb, region), vab)
sorter.sort(cb, region, makeDependentSortingFunction(cb, sct, lessThan, env, emitSelf, Array(left, right)))
Expand Down Expand Up @@ -1188,7 +1188,7 @@ class Emit[C](

val sct = SingleCodeType.fromSType(producer.element.st)

val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0)
val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0)
StreamUtils.writeToArrayBuilder(cb, producer, vab, region)
val sorter = new ArraySorter(EmitRegion(mb, region), vab)

Expand All @@ -1214,7 +1214,7 @@ class Emit[C](

val sct = SingleCodeType.fromSType(producer.element.st)

val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0)
val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0)
StreamUtils.writeToArrayBuilder(cb, producer, vab, region)
val sorter = new ArraySorter(EmitRegion(mb, region), vab)

Expand Down Expand Up @@ -1255,7 +1255,7 @@ class Emit[C](

val producer = stream.getProducer(mb)
val sct = SingleCodeType.fromSType(producer.element.st)
val sortedElts = new StagedArrayBuilder(sct, producer.element.required, mb, 16)
val sortedElts = new StagedArrayBuilder(cb, sct, producer.element.required, 16)
StreamUtils.writeToArrayBuilder(cb, producer, sortedElts, region)
val sorter = new ArraySorter(EmitRegion(mb, region), sortedElts)

Expand All @@ -1270,20 +1270,18 @@ class Emit[C](
sorter.sort(cb, region, lt)
sorter.pruneMissing(cb)

val groupSizes = new StagedArrayBuilder(Int32SingleCodeType, true, mb, 0)
val groupSizes = new StagedArrayBuilder(cb, Int32SingleCodeType, true, 0)

val eltIdx = mb.newLocal[Int]("groupByKey_eltIdx")
val grpIdx = mb.newLocal[Int]("groupByKey_grpIdx")
val withinGrpIdx = mb.newLocal[Int]("groupByKey_withinGrpIdx")
val outerSize = mb.newLocal[Int]("groupByKey_outerSize")
val groupSize = mb.newLocal[Int]("groupByKey_groupSize")


cb += groupSizes.clear
cb.assign(eltIdx, 0)
cb.assign(groupSize, 0)

def sameKeyAtIndices(cb: EmitCodeBuilder, region: Value[Region], idx1: Code[Int], idx2: Code[Int]): Code[Boolean] = {
def sameKeyAtIndices(cb: EmitCodeBuilder, region: Value[Region], idx1: Value[Int], idx2: Value[Int]): Code[Boolean] = {
val lk = cb.memoize(
sortedElts.loadFromIndex(cb, region, idx1).flatMap(cb) { x =>
x.asBaseStruct.loadField(cb, 0)
Expand All @@ -1306,14 +1304,14 @@ class Emit[C](
cb.if_(eltIdx.ceq(sortedElts.size - 1), {
cb.goto(newGroup)
}, {
cb.if_(sameKeyAtIndices(cb, region, eltIdx, eltIdx + 1), {
cb.if_(sameKeyAtIndices(cb, region, eltIdx, cb.memoize(eltIdx + 1)), {
cb.goto(bottomOfLoop)
}, {
cb.goto(newGroup)
})
})
cb.define(newGroup)
cb += groupSizes.add(groupSize)
groupSizes.add(cb, groupSize)
cb.assign(groupSize, 0)

cb.define(bottomOfLoop)
Expand Down
39 changes: 24 additions & 15 deletions hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import is.hail.utils.BoxedArrayBuilder

import scala.reflect.ClassTag

class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb: EmitMethodBuilder[_], len: Code[Int]) {
class StagedArrayBuilder(cb: EmitCodeBuilder, val elt: SingleCodeType, val eltRequired: Boolean, len: Int) {

def mb = cb.emb
val ti: TypeInfo[_] = elt.ti

val ref: Value[Any] = coerce[Any](ti match {
Expand All @@ -22,13 +23,19 @@ class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb:
case ti => throw new RuntimeException(s"unsupported typeinfo found: $ti")
})

def add(x: Code[_]): Code[Unit] = ti match {
// If a method containing `new StagedArrayBuilder(...)` is called multiple times,
// the invocations will share the same array builder at runtime. Clearing
// here ensures a "new" array builder is always empty.
clear(cb)
ensureCapacity(cb, len)

def add(cb: EmitCodeBuilder, x: Code[_]): Unit = cb.append(ti match {
case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Boolean, Unit]("add", coerce[Boolean](x))
case IntInfo => coerce[IntMissingArrayBuilder](ref).invoke[Int, Unit]("add", coerce[Int](x))
case LongInfo => coerce[LongMissingArrayBuilder](ref).invoke[Long, Unit]("add", coerce[Long](x))
case FloatInfo => coerce[FloatMissingArrayBuilder](ref).invoke[Float, Unit]("add", coerce[Float](x))
case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Double, Unit]("add", coerce[Double](x))
}
})

def apply(i: Code[Int]): Code[_] = ti match {
case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Int, Boolean]("apply", i)
Expand All @@ -38,34 +45,36 @@ class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb:
case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Int, Double]("apply", i)
}

def update(i: Code[Int], x: Code[_]): Code[Unit] = ti match {
def update(cb: EmitCodeBuilder, i: Code[Int], x: Code[_]): Unit = cb.append(ti match {
case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("update", i, coerce[Boolean](x))
case IntInfo => coerce[IntMissingArrayBuilder](ref).invoke[Int, Int, Unit]("update", i, coerce[Int](x))
case LongInfo => coerce[LongMissingArrayBuilder](ref).invoke[Int, Long, Unit]("update", i, coerce[Long](x))
case FloatInfo => coerce[FloatMissingArrayBuilder](ref).invoke[Int, Float, Unit]("update", i, coerce[Float](x))
case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Int, Double, Unit]("update", i, coerce[Double](x))
}
})

def addMissing(): Code[Unit] =
coerce[MissingArrayBuilder](ref).invoke[Unit]("addMissing")
def addMissing(cb: EmitCodeBuilder): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Unit]("addMissing")

def isMissing(i: Code[Int]): Code[Boolean] =
coerce[MissingArrayBuilder](ref).invoke[Int, Boolean]("isMissing", i)

def setMissing(i: Code[Int], m: Code[Boolean]): Code[Unit] =
coerce[MissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("setMissing", i, m)
def setMissing(cb: EmitCodeBuilder, i: Code[Int], m: Code[Boolean]): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("setMissing", i, m)

def size: Code[Int] = coerce[MissingArrayBuilder](ref).invoke[Int]("size")

def setSize(n: Code[Int]): Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("setSize", n)
def setSize(cb: EmitCodeBuilder, n: Code[Int]): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("setSize", n)

def ensureCapacity(n: Code[Int]): Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("ensureCapacity", n)
def ensureCapacity(cb: EmitCodeBuilder, n: Code[Int]): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("ensureCapacity", n)

def clear: Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Unit]("clear")
def clear(cb: EmitCodeBuilder): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Unit]("clear")

def loadFromIndex(cb: EmitCodeBuilder, r: Value[Region], i: Code[Int]): IEmitCode = {
val idx = cb.newLocal[Int]("loadFromIndex_idx", i)
IEmitCode(cb, isMissing(idx), elt.loadToSValue(cb, cb.memoizeAny(apply(idx), ti)))
def loadFromIndex(cb: EmitCodeBuilder, r: Value[Region], i: Value[Int]): IEmitCode = {
IEmitCode(cb, isMissing(i), elt.loadToSValue(cb, cb.memoizeAny(apply(i), ti)))
}
}

Expand Down
12 changes: 6 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ object StreamUtils {
val aTyp = PCanonicalArray(stream.element.emitType.storageType, true)
stream.length match {
case None =>
val vab = new StagedArrayBuilder(SingleCodeType.fromSType(stream.element.st), stream.element.required, mb, 0)
val vab = new StagedArrayBuilder(cb, SingleCodeType.fromSType(stream.element.st), stream.element.required, 0)
writeToArrayBuilder(cb, stream, vab, destRegion)
cb.assign(xLen, vab.size)

Expand Down Expand Up @@ -86,17 +86,17 @@ object StreamUtils {
destRegion: Value[Region]
): Unit = {
stream.memoryManagedConsume(destRegion, cb, setup = { cb =>
cb += ab.clear
ab.clear(cb)
stream.length match {
case Some(computeLen) => cb += ab.ensureCapacity(computeLen(cb))
case None => cb += ab.ensureCapacity(16)
case Some(computeLen) => ab.ensureCapacity(cb, computeLen(cb))
case None => ab.ensureCapacity(cb, 16)
}


}) { cb =>
stream.element.toI(cb).consume(cb,
cb += ab.addMissing(),
sc => cb += ab.add(ab.elt.coerceSCode(cb, sc, destRegion, deepCopy = stream.requiresMemoryManagementPerElement).code)
ab.addMissing(cb),
sc => ab.add(cb, ab.elt.coerceSCode(cb, sc, destRegion, deepCopy = stream.requiresMemoryManagementPerElement).code)
)
}
}
Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ object BGENFunctions extends RegistryFunctions {
"alleles" -> PCanonicalArray(PCanonicalString(true), true),
"offset" -> PInt64Required)
val bufferSct = SingleCodeType.fromSType(rowPType.sType)
val buffer = new StagedArrayBuilder(bufferSct, true, mb, 8)
val buffer = new StagedArrayBuilder(cb, bufferSct, true, 8)
val currSize = cb.newLocal[Int]("currSize", 0)

val spec = TypedCodecSpec(
Expand Down Expand Up @@ -553,7 +553,7 @@ object BGENFunctions extends RegistryFunctions {
cb += ob.invoke[Unit]("close")

cb.assign(groupIndex, groupIndex + 1)
cb += buffer.clear
buffer.clear(cb)
cb.assign(currSize, 0)
}

Expand All @@ -569,7 +569,7 @@ object BGENFunctions extends RegistryFunctions {
cb.if_(currSize ceq bufferSize, {
dumpBuffer(cb)
})
cb += buffer.add(bufferSct.coerceSCode(cb, row, er.region, false).code)
buffer.add(cb, bufferSct.coerceSCode(cb, row, er.region, false).code)
cb.assign(currSize, currSize + 1)
cb.assign(nWritten, nWritten + 1)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ final case class EDictAsUnsortedArrayOfPairs(val elementType: EType, override va
val decodedUnsortedArray = arrayDecoder(cb, region, in).asInstanceOf[SIndexablePointerValue]
val sct = SingleCodeType.fromSType(decodedUnsortedArray.st.elementType)

val ab = new StagedArrayBuilder(sct, true, cb.emb, 0)
cb.append(ab.ensureCapacity(decodedUnsortedArray.length))
val ab = new StagedArrayBuilder(cb, sct, true, 0)
ab.ensureCapacity(cb, decodedUnsortedArray.length)
decodedUnsortedArray.forEachDefined(cb) { (cb, i, res) =>
cb.append(ab.add(ab.elt.coerceSCode(cb, res, region, false).code))
ab.add(cb, ab.elt.coerceSCode(cb, res, region, false).code)
}

val sorter = new ArraySorter(EmitRegion(cb.emb, region), ab)
Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ final case class EUnsortedSet(val elementType: EType, override val required: Boo
val decodedUnsortedArray = arrayDecoder(cb, region, in).asInstanceOf[SIndexablePointerValue]
val sct = SingleCodeType.fromSType(decodedUnsortedArray.st.elementType)

val ab = new StagedArrayBuilder(sct, elementType.required, cb.emb, 0)
cb.append(ab.ensureCapacity(decodedUnsortedArray.length))
val ab = new StagedArrayBuilder(cb, sct, elementType.required, 0)
ab.ensureCapacity(cb, decodedUnsortedArray.length)
decodedUnsortedArray.forEachDefined(cb) { (cb, i, res) =>
cb.append(ab.add(ab.elt.coerceSCode(cb, res, region, false).code))
ab.add(cb, ab.elt.coerceSCode(cb, res, region, false).code)
}

val sorter = new ArraySorter(EmitRegion(cb.emb, region), ab)
Expand Down
7 changes: 3 additions & 4 deletions hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,19 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) {
val key = new TestBTreeKey(fb.apply_method)
val btree = new AppendOnlyBTree(cb, key, r, root, maxElements = n)

val sab = new StagedArrayBuilder(Int64SingleCodeType, true, fb.apply_method, 16)
val idx = fb.newLocal[Int]()
val returnArray = fb.newLocal[Array[java.lang.Long]]()

fb.emitWithBuilder { cb =>
val sab = new StagedArrayBuilder(cb, Int64SingleCodeType, true, 16)
cb += (r := fb.getCodeParam[Region](1))
cb += (root := fb.getCodeParam[Long](2))
cb += sab.clear
btree.foreach(cb) { (cb, _koff) =>
val koff = cb.memoize(_koff)
val ec = key.loadCompKey(cb, koff)
cb.if_(ec.m,
cb += sab.addMissing(),
cb += sab.add(ec.pv.asInt64.value))
sab.addMissing(cb),
sab.add(cb, ec.pv.asInt64.value))
}
cb += (returnArray := Code.newArray[java.lang.Long](sab.size))
cb.for_(
Expand Down

0 comments on commit 28e56f7

Please sign in to comment.