Skip to content

Commit

Permalink
[query] fix bug in new dict decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Oct 27, 2023
1 parent 1c53068 commit a869537
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 52 deletions.
16 changes: 8 additions & 8 deletions hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) {

cb.while_(i < size, {
cb.if_(!array.isMissing(i), {
cb.if_(newEnd.cne(i), cb += array.update(newEnd, array.apply(i)))
cb.if_(newEnd.cne(i), array.update(cb, newEnd, array.apply(i)))
cb.assign(newEnd, newEnd + 1)
})
cb.assign(i, i + 1)
})
cb.assign(i, newEnd)
cb.while_(i < size, {
cb += array.setMissing(i, true)
array.setMissing(cb, i, true)
cb.assign(i, i + 1)
})

Expand Down Expand Up @@ -126,7 +126,7 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) {

cb.assign(i, 0)
cb.while_(i < newEnd, {
cb += array.update(i, arrayRef(workingArray2)(i))
array.update(cb, i, arrayRef(workingArray2)(i))
cb.assign(i, i + 1)
})

Expand Down Expand Up @@ -162,12 +162,12 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) {
cb.while_(i < size, {
cb.if_(!array.isMissing(i), {
cb.if_(i.cne(n),
cb += array.update(n, array(i)))
array.update(cb, n, array(i)))
cb.assign(n, n + 1)
})
cb.assign(i, i + 1)
})
cb += array.setSize(n)
array.setSize(cb, n)
}

def distinctFromSorted(cb: EmitCodeBuilder, region: Value[Region], discardNext: (EmitCodeBuilder, Value[Region], EmitCode, EmitCode) => Code[Boolean]): Unit = {
Expand Down Expand Up @@ -197,12 +197,12 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) {
cb.assign(n, n + 1)

cb.if_(i < size && i.cne(n), {
cb += array.setMissing(n, array.isMissing(i))
cb.if_(!array.isMissing(n), cb += array.update(n, array(i)))
array.setMissing(cb, n, array.isMissing(i))
cb.if_(!array.isMissing(n), array.update(cb, n, array(i)))
})

})
cb += array.setSize(n)
array.setSize(cb, n)
}

cb.invokeVoid(distinctMB, region)
Expand Down
18 changes: 8 additions & 10 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ class Emit[C](

val sct = SingleCodeType.fromSType(producer.element.st)

val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0)
val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0)
StreamUtils.writeToArrayBuilder(cb, producer, vab, region)
val sorter = new ArraySorter(EmitRegion(mb, region), vab)
sorter.sort(cb, region, makeDependentSortingFunction(cb, sct, lessThan, env, emitSelf, Array(left, right)))
Expand Down Expand Up @@ -1188,7 +1188,7 @@ class Emit[C](

val sct = SingleCodeType.fromSType(producer.element.st)

val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0)
val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0)
StreamUtils.writeToArrayBuilder(cb, producer, vab, region)
val sorter = new ArraySorter(EmitRegion(mb, region), vab)

Expand All @@ -1214,7 +1214,7 @@ class Emit[C](

val sct = SingleCodeType.fromSType(producer.element.st)

val vab = new StagedArrayBuilder(sct, producer.element.required, mb, 0)
val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0)
StreamUtils.writeToArrayBuilder(cb, producer, vab, region)
val sorter = new ArraySorter(EmitRegion(mb, region), vab)

Expand Down Expand Up @@ -1255,7 +1255,7 @@ class Emit[C](

val producer = stream.getProducer(mb)
val sct = SingleCodeType.fromSType(producer.element.st)
val sortedElts = new StagedArrayBuilder(sct, producer.element.required, mb, 16)
val sortedElts = new StagedArrayBuilder(cb, sct, producer.element.required, 16)
StreamUtils.writeToArrayBuilder(cb, producer, sortedElts, region)
val sorter = new ArraySorter(EmitRegion(mb, region), sortedElts)

Expand All @@ -1270,20 +1270,18 @@ class Emit[C](
sorter.sort(cb, region, lt)
sorter.pruneMissing(cb)

val groupSizes = new StagedArrayBuilder(Int32SingleCodeType, true, mb, 0)
val groupSizes = new StagedArrayBuilder(cb, Int32SingleCodeType, true, 0)

val eltIdx = mb.newLocal[Int]("groupByKey_eltIdx")
val grpIdx = mb.newLocal[Int]("groupByKey_grpIdx")
val withinGrpIdx = mb.newLocal[Int]("groupByKey_withinGrpIdx")
val outerSize = mb.newLocal[Int]("groupByKey_outerSize")
val groupSize = mb.newLocal[Int]("groupByKey_groupSize")


cb += groupSizes.clear
cb.assign(eltIdx, 0)
cb.assign(groupSize, 0)

def sameKeyAtIndices(cb: EmitCodeBuilder, region: Value[Region], idx1: Code[Int], idx2: Code[Int]): Code[Boolean] = {
def sameKeyAtIndices(cb: EmitCodeBuilder, region: Value[Region], idx1: Value[Int], idx2: Value[Int]): Code[Boolean] = {
val lk = cb.memoize(
sortedElts.loadFromIndex(cb, region, idx1).flatMap(cb) { x =>
x.asBaseStruct.loadField(cb, 0)
Expand All @@ -1306,14 +1304,14 @@ class Emit[C](
cb.if_(eltIdx.ceq(sortedElts.size - 1), {
cb.goto(newGroup)
}, {
cb.if_(sameKeyAtIndices(cb, region, eltIdx, eltIdx + 1), {
cb.if_(sameKeyAtIndices(cb, region, eltIdx, cb.memoize(eltIdx + 1)), {
cb.goto(bottomOfLoop)
}, {
cb.goto(newGroup)
})
})
cb.define(newGroup)
cb += groupSizes.add(groupSize)
groupSizes.add(cb, groupSize)
cb.assign(groupSize, 0)

cb.define(bottomOfLoop)
Expand Down
39 changes: 24 additions & 15 deletions hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import is.hail.utils.BoxedArrayBuilder

import scala.reflect.ClassTag

class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb: EmitMethodBuilder[_], len: Code[Int]) {
class StagedArrayBuilder(cb: EmitCodeBuilder, val elt: SingleCodeType, val eltRequired: Boolean, len: Value[Int]) {

def mb = cb.emb
val ti: TypeInfo[_] = elt.ti

val ref: Value[Any] = coerce[Any](ti match {
Expand All @@ -22,13 +23,19 @@ class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb:
case ti => throw new RuntimeException(s"unsupported typeinfo found: $ti")
})

def add(x: Code[_]): Code[Unit] = ti match {
// If a method containing `new StagedArrayBuilder(...)` is called multiple times,
// the invocations will share the same array builder at runtime. Clearing
// here ensures a "new" array builder is always empty.
clear(cb)
ensureCapacity(cb, len)

def add(cb: EmitCodeBuilder, x: Code[_]): Unit = cb.append(ti match {
case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Boolean, Unit]("add", coerce[Boolean](x))
case IntInfo => coerce[IntMissingArrayBuilder](ref).invoke[Int, Unit]("add", coerce[Int](x))
case LongInfo => coerce[LongMissingArrayBuilder](ref).invoke[Long, Unit]("add", coerce[Long](x))
case FloatInfo => coerce[FloatMissingArrayBuilder](ref).invoke[Float, Unit]("add", coerce[Float](x))
case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Double, Unit]("add", coerce[Double](x))
}
})

def apply(i: Code[Int]): Code[_] = ti match {
case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Int, Boolean]("apply", i)
Expand All @@ -38,34 +45,36 @@ class StagedArrayBuilder(val elt: SingleCodeType, val eltRequired: Boolean, mb:
case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Int, Double]("apply", i)
}

def update(i: Code[Int], x: Code[_]): Code[Unit] = ti match {
def update(cb: EmitCodeBuilder, i: Code[Int], x: Code[_]): Unit = cb.append(ti match {
case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("update", i, coerce[Boolean](x))
case IntInfo => coerce[IntMissingArrayBuilder](ref).invoke[Int, Int, Unit]("update", i, coerce[Int](x))
case LongInfo => coerce[LongMissingArrayBuilder](ref).invoke[Int, Long, Unit]("update", i, coerce[Long](x))
case FloatInfo => coerce[FloatMissingArrayBuilder](ref).invoke[Int, Float, Unit]("update", i, coerce[Float](x))
case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Int, Double, Unit]("update", i, coerce[Double](x))
}
})

def addMissing(): Code[Unit] =
coerce[MissingArrayBuilder](ref).invoke[Unit]("addMissing")
def addMissing(cb: EmitCodeBuilder): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Unit]("addMissing")

def isMissing(i: Code[Int]): Code[Boolean] =
coerce[MissingArrayBuilder](ref).invoke[Int, Boolean]("isMissing", i)

def setMissing(i: Code[Int], m: Code[Boolean]): Code[Unit] =
coerce[MissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("setMissing", i, m)
def setMissing(cb: EmitCodeBuilder, i: Code[Int], m: Code[Boolean]): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("setMissing", i, m)

def size: Code[Int] = coerce[MissingArrayBuilder](ref).invoke[Int]("size")

def setSize(n: Code[Int]): Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("setSize", n)
def setSize(cb: EmitCodeBuilder, n: Code[Int]): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("setSize", n)

def ensureCapacity(n: Code[Int]): Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("ensureCapacity", n)
def ensureCapacity(cb: EmitCodeBuilder, n: Code[Int]): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Int, Unit]("ensureCapacity", n)

def clear: Code[Unit] = coerce[MissingArrayBuilder](ref).invoke[Unit]("clear")
def clear(cb: EmitCodeBuilder): Unit =
cb += coerce[MissingArrayBuilder](ref).invoke[Unit]("clear")

def loadFromIndex(cb: EmitCodeBuilder, r: Value[Region], i: Code[Int]): IEmitCode = {
val idx = cb.newLocal[Int]("loadFromIndex_idx", i)
IEmitCode(cb, isMissing(idx), elt.loadToSValue(cb, cb.memoizeAny(apply(idx), ti)))
def loadFromIndex(cb: EmitCodeBuilder, r: Value[Region], i: Value[Int]): IEmitCode = {
IEmitCode(cb, isMissing(i), elt.loadToSValue(cb, cb.memoizeAny(apply(i), ti)))
}
}

Expand Down
12 changes: 6 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ object StreamUtils {
val aTyp = PCanonicalArray(stream.element.emitType.storageType, true)
stream.length match {
case None =>
val vab = new StagedArrayBuilder(SingleCodeType.fromSType(stream.element.st), stream.element.required, mb, 0)
val vab = new StagedArrayBuilder(cb, SingleCodeType.fromSType(stream.element.st), stream.element.required, 0)
writeToArrayBuilder(cb, stream, vab, destRegion)
cb.assign(xLen, vab.size)

Expand Down Expand Up @@ -86,17 +86,17 @@ object StreamUtils {
destRegion: Value[Region]
): Unit = {
stream.memoryManagedConsume(destRegion, cb, setup = { cb =>
cb += ab.clear
ab.clear(cb)
stream.length match {
case Some(computeLen) => cb += ab.ensureCapacity(computeLen(cb))
case None => cb += ab.ensureCapacity(16)
case Some(computeLen) => ab.ensureCapacity(cb, computeLen(cb))
case None => ab.ensureCapacity(cb, 16)
}


}) { cb =>
stream.element.toI(cb).consume(cb,
cb += ab.addMissing(),
sc => cb += ab.add(ab.elt.coerceSCode(cb, sc, destRegion, deepCopy = stream.requiresMemoryManagementPerElement).code)
ab.addMissing(cb),
sc => ab.add(cb, ab.elt.coerceSCode(cb, sc, destRegion, deepCopy = stream.requiresMemoryManagementPerElement).code)
)
}
}
Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ object BGENFunctions extends RegistryFunctions {
"alleles" -> PCanonicalArray(PCanonicalString(true), true),
"offset" -> PInt64Required)
val bufferSct = SingleCodeType.fromSType(rowPType.sType)
val buffer = new StagedArrayBuilder(bufferSct, true, mb, 8)
val buffer = new StagedArrayBuilder(cb, bufferSct, true, 8)
val currSize = cb.newLocal[Int]("currSize", 0)

val spec = TypedCodecSpec(
Expand Down Expand Up @@ -553,7 +553,7 @@ object BGENFunctions extends RegistryFunctions {
cb += ob.invoke[Unit]("close")

cb.assign(groupIndex, groupIndex + 1)
cb += buffer.clear
buffer.clear(cb)
cb.assign(currSize, 0)
}

Expand All @@ -569,7 +569,7 @@ object BGENFunctions extends RegistryFunctions {
cb.if_(currSize ceq bufferSize, {
dumpBuffer(cb)
})
cb += buffer.add(bufferSct.coerceSCode(cb, row, er.region, false).code)
buffer.add(cb, bufferSct.coerceSCode(cb, row, er.region, false).code)
cb.assign(currSize, currSize + 1)
cb.assign(nWritten, nWritten + 1)
})
Expand Down
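5 changes: 2 additions & 3 deletions hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala
Original file line number Diff line number Diff line change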
Expand Up @@ -39,10 +39,9 @@ final case class EDictAsUnsortedArrayOfPairs(val elementType: EType, override va
val decodedUnsortedArray = arrayDecoder(cb, region, in).asInstanceOf[SIndexablePointerValue]
val sct = SingleCodeType.fromSType(decodedUnsortedArray.st.elementType)

val ab = new StagedArrayBuilder(sct, true, cb.emb, 0)
cb.append(ab.ensureCapacity(decodedUnsortedArray.length))
val ab = new StagedArrayBuilder(cb, sct, true, decodedUnsortedArray.length)
decodedUnsortedArray.forEachDefined(cb) { (cb, i, res) =>
cb.append(ab.add(ab.elt.coerceSCode(cb, res, region, false).code))
ab.add(cb, ab.elt.coerceSCode(cb, res, region, false).code)
}

val sorter = new ArraySorter(EmitRegion(cb.emb, region), ab)
Expand Down
5 changes: 2 additions & 3 deletions hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ final case class EUnsortedSet(val elementType: EType, override val required: Boo
val decodedUnsortedArray = arrayDecoder(cb, region, in).asInstanceOf[SIndexablePointerValue]
val sct = SingleCodeType.fromSType(decodedUnsortedArray.st.elementType)

val ab = new StagedArrayBuilder(sct, elementType.required, cb.emb, 0)
cb.append(ab.ensureCapacity(decodedUnsortedArray.length))
val ab = new StagedArrayBuilder(cb, sct, elementType.required, decodedUnsortedArray.length)
decodedUnsortedArray.forEachDefined(cb) { (cb, i, res) =>
cb.append(ab.add(ab.elt.coerceSCode(cb, res, region, false).code))
ab.add(cb, ab.elt.coerceSCode(cb, res, region, false).code)
}

val sorter = new ArraySorter(EmitRegion(cb.emb, region), ab)
Expand Down
7 changes: 3 additions & 4 deletions hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,19 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) {
val key = new TestBTreeKey(fb.apply_method)
val btree = new AppendOnlyBTree(cb, key, r, root, maxElements = n)

val sab = new StagedArrayBuilder(Int64SingleCodeType, true, fb.apply_method, 16)
val idx = fb.newLocal[Int]()
val returnArray = fb.newLocal[Array[java.lang.Long]]()

fb.emitWithBuilder { cb =>
val sab = new StagedArrayBuilder(cb, Int64SingleCodeType, true, 16)
cb += (r := fb.getCodeParam[Region](1))
cb += (root := fb.getCodeParam[Long](2))
cb += sab.clear
btree.foreach(cb) { (cb, _koff) =>
val koff = cb.memoize(_koff)
val ec = key.loadCompKey(cb, koff)
cb.if_(ec.m,
cb += sab.addMissing(),
cb += sab.add(ec.pv.asInt64.value))
sab.addMissing(cb),
sab.add(cb, ec.pv.asInt64.value))
}
cb += (returnArray := Code.newArray[java.lang.Long](sab.size))
cb.for_(
Expand Down

0 comments on commit a869537

Please sign in to comment.