[query] Stage IndexReader #12159

Merged 11 commits on Sep 10, 2022
2 changes: 2 additions & 0 deletions hail/python/test/hail/table/test_table.py
@@ -903,12 +903,14 @@ def test_indexed_read(self):
t = hl.utils.range_table(2000, 10)
f = new_temp_file(extension='ht')
t.write(f)

t2 = hl.read_table(f, _intervals=[
hl.Interval(start=150, end=250, includes_start=True, includes_end=False),
hl.Interval(start=250, end=500, includes_start=True, includes_end=False),
])
self.assertEqual(t2.n_partitions(), 2)
self.assertEqual(t2.count(), 350)
self.assertEqual(t2._force_count(), 350)
self.assertTrue(t.filter((t.idx >= 150) & (t.idx < 500))._same(t2))

t2 = hl.read_table(f, _intervals=[
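The figures asserted in the test above follow from interval arithmetic: [150, 250) covers 100 rows and [250, 500) covers 250, giving 350 rows in 2 partitions (one per requested interval). A minimal Scala sanity check of that arithmetic, with illustrative names only (the test itself exercises hl.read_table):

```scala
object IndexedReadSketch extends App {
  // Expected figures asserted in the Python test above: two half-open intervals over
  // idx in range_table(2000, 10), one output partition per requested interval.
  val intervals = Seq((150, 250), (250, 500))
  val expectedPartitions = intervals.length                      // 2
  val expectedRows = intervals.map { case (s, e) => e - s }.sum  // 100 + 250 = 350
  assert(expectedPartitions == 2 && expectedRows == 350)
}
```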
214 changes: 0 additions & 214 deletions hail/src/main/scala/is/hail/HailContext.scala
@@ -150,103 +150,6 @@ object HailContext {
theContext = null
}

def readRowsPartition(
makeDec: (InputStream, HailClassLoader) => Decoder
)(theHailClassLoader: HailClassLoader,
r: Region,
in: InputStream,
metrics: InputMetrics = null
): Iterator[Long] =
new Iterator[Long] {
private val region = r

private val trackedIn = new ByteTrackingInputStream(in)
private val dec =
try {
makeDec(trackedIn, theHailClassLoader)
} catch {
case e: Exception =>
in.close()
throw e
}

private var cont: Byte = dec.readByte()
if (cont == 0)
dec.close()

// can't throw
def hasNext: Boolean = cont != 0

def next(): Long = {
// !hasNext => cont == 0 => dec has been closed
if (!hasNext)
throw new NoSuchElementException("next on empty iterator")

try {
val res = dec.readRegionValue(region)
cont = dec.readByte()
if (metrics != null) {
ExposedMetrics.incrementRecord(metrics)
ExposedMetrics.incrementBytes(metrics, trackedIn.bytesReadAndClear())
}

if (cont == 0)
dec.close()

res
} catch {
case e: Exception =>
dec.close()
throw e
}
}

override def finalize(): Unit = {
dec.close()
}
}

def readRowsIndexedPartition(
makeDec: (InputStream, HailClassLoader) => Decoder
)(theHailClassLoader: HailClassLoader,
ctx: RVDContext,
in: InputStream,
idxr: IndexReader,
offsetField: Option[String],
bounds: Option[Interval],
metrics: InputMetrics = null
): Iterator[Long] =
bounds match {
case Some(b) =>
new IndexReadIterator(theHailClassLoader, makeDec, ctx.r, in, idxr, offsetField.orNull, b, metrics)
case None =>
idxr.close()
HailContext.readRowsPartition(makeDec)(theHailClassLoader, ctx.r, in, metrics)
}

def readSplitRowsPartition(
theHailClassLoader: HailClassLoader,
fs: BroadcastValue[FS],
mkRowsDec: (InputStream, HailClassLoader) => Decoder,
mkEntriesDec: (InputStream, HailClassLoader) => Decoder,
mkInserter: (HailClassLoader, FS, Int, Region) => AsmFunction3RegionLongLongLong
)(ctx: RVDContext,
isRows: InputStream,
isEntries: InputStream,
idxr: Option[IndexReader],
rowsOffsetField: Option[String],
entriesOffsetField: Option[String],
bounds: Option[Interval],
partIdx: Int,
metrics: InputMetrics = null
): Iterator[Long] = new MaybeIndexedReadZippedIterator(
is => mkRowsDec(is, theHailClassLoaderForSparkWorkers),
is => mkEntriesDec(is, theHailClassLoaderForSparkWorkers),
mkInserter(theHailClassLoader, fs.value, partIdx, ctx.partitionRegion),
ctx.r,
isRows, isEntries,
idxr.orNull, rowsOffsetField.orNull, entriesOffsetField.orNull, bounds.orNull, metrics)

def pyRemoveIrVector(id: Int) {
get.irVectors.remove(id)
}
@@ -275,123 +178,6 @@ object HailContext {
@transient override val partitioner: Option[Partitioner] = optPartitioner
}
}

def readRows(
ctx: ExecuteContext,
path: String,
enc: AbstractTypedCodecSpec,
partFiles: Array[String],
requestedType: TStruct
): (PStruct, ContextRDD[Long]) = {
val fs = ctx.fs
val (pType: PStruct, makeDec) = enc.buildDecoder(ctx, requestedType)
(pType, ContextRDD.weaken(HailContext.readPartitions(fs, path, partFiles, (_, is, m) => Iterator.single(is -> m)))
.cmapPartitions { (ctx, it) =>
assert(it.hasNext)
val (is, m) = it.next
assert(!it.hasNext)
HailContext.readRowsPartition(makeDec)(theHailClassLoaderForSparkWorkers, ctx.r, is, m)
})
}

def readIndexedRows(
ctx: ExecuteContext,
path: String,
indexSpec: AbstractIndexSpec,
enc: AbstractTypedCodecSpec,
partFiles: Array[String],
bounds: Array[Interval],
requestedType: TStruct
): (PStruct, ContextRDD[Long]) = {
val (pType: PStruct, makeDec) = enc.buildDecoder(ctx, requestedType)
(pType, ContextRDD.weaken(readIndexedPartitions(ctx, path, indexSpec, partFiles, Some(bounds)))
.cmapPartitions { (ctx, it) =>
assert(it.hasNext)
val (is, idxr, bounds, m) = it.next
assert(!it.hasNext)
readRowsIndexedPartition(makeDec)(theHailClassLoaderForSparkWorkers, ctx, is, idxr, indexSpec.offsetField, bounds, m)
})
}

def readIndexedPartitions(
ctx: ExecuteContext,
path: String,
indexSpec: AbstractIndexSpec,
partFiles: Array[String],
intervalBounds: Option[Array[Interval]] = None
): RDD[(InputStream, IndexReader, Option[Interval], InputMetrics)] = {
val idxPath = indexSpec.relPath
val fsBc = ctx.fsBc
val (keyType, annotationType) = indexSpec.types
indexSpec.offsetField.foreach { f =>
require(annotationType.asInstanceOf[TStruct].hasField(f))
require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64)
}
val (leafPType: PStruct, leafDec) = indexSpec.leafCodec.buildDecoder(ctx, indexSpec.leafCodec.encodedVirtualType)
val (intPType: PStruct, intDec) = indexSpec.internalNodeCodec.buildDecoder(ctx, indexSpec.internalNodeCodec.encodedVirtualType)
val mkIndexReader = IndexReaderBuilder.withDecoders(leafDec, intDec, keyType, annotationType, leafPType, intPType)

new IndexReadRDD(partFiles, intervalBounds, { (p, context) =>
val fs = fsBc.value
val idxname = s"$path/$idxPath/${ p.file }.idx"
val filename = s"$path/parts/${ p.file }"
val idxr = mkIndexReader(theHailClassLoaderForSparkWorkers, fs, idxname, 8, SparkTaskContext.get().getRegionPool()) // default cache capacity
val in = fs.open(filename)
(in, idxr, p.bounds, context.taskMetrics().inputMetrics)
})
}


def readRowsSplit(
ctx: ExecuteContext,
pathRows: String,
pathEntries: String,
indexSpecRows: Option[AbstractIndexSpec],
indexSpecEntries: Option[AbstractIndexSpec],
partFiles: Array[String],
bounds: Array[Interval],
makeRowsDec: (InputStream, HailClassLoader) => Decoder,
makeEntriesDec: (InputStream, HailClassLoader) => Decoder,
makeInserter: (HailClassLoader, FS, Int, Region) => AsmFunction3RegionLongLongLong
): ContextRDD[Long] = {
require(!(indexSpecRows.isEmpty ^ indexSpecEntries.isEmpty))
val fsBc = ctx.fsBc

val mkIndexReader = indexSpecRows.map { indexSpec =>
val (keyType, annotationType) = indexSpec.types
indexSpec.offsetField.foreach { f =>
require(annotationType.asInstanceOf[TStruct].hasField(f))
require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64)
}
indexSpecEntries.get.offsetField.foreach { f =>
require(annotationType.asInstanceOf[TStruct].hasField(f))
require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64)
}
IndexReaderBuilder.fromSpec(ctx, indexSpec)
}

val rdd = new IndexReadRDD(partFiles, indexSpecRows.map(_ => bounds), (p, context) => {
val fs = fsBc.value
val idxr = mkIndexReader.map { mk =>
val idxname = s"$pathRows/${ indexSpecRows.get.relPath }/${ p.file }.idx"
mk(theHailClassLoaderForSparkWorkers, fs, idxname, 8, SparkTaskContext.get().getRegionPool()) // default cache capacity
}
val inRows = fs.open(s"$pathRows/parts/${ p.file }")
val inEntries = fs.open(s"$pathEntries/parts/${ p.file }")
(inRows, inEntries, idxr, p.bounds, context.taskMetrics().inputMetrics)
})

val rowsOffsetField = indexSpecRows.flatMap(_.offsetField)
val entriesOffsetField = indexSpecEntries.flatMap(_.offsetField)
ContextRDD.weaken(rdd).cmapPartitionsWithIndex { (i, ctx, it) =>
assert(it.hasNext)
val (isRows, isEntries, idxr, bounds, m) = it.next
assert(!it.hasNext)
HailContext.readSplitRowsPartition(theHailClassLoaderForSparkWorkers, fsBc, makeRowsDec, makeEntriesDec, makeInserter)(
ctx, isRows, isEntries, idxr, rowsOffsetField, entriesOffsetField, bounds, i, m)
}

}
}

class HailContext private(
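The deleted readRowsPartition above encodes the partition-level record framing: every record is preceded by a continuation byte, 1 meaning another region value follows and 0 meaning end of stream, and the decoder is closed eagerly on the terminator or on error. A minimal plain-Scala sketch of that framing, assuming a bare Int payload in place of Hail's Decoder and region values (readFramed and the byte layout are illustrative, not Hail's API):

```scala
import java.io.{ByteArrayInputStream, DataInputStream, InputStream}

object FramingSketch extends App {
  // Each record is preceded by a continuation byte: 1 = a record follows, 0 = end of
  // stream. The payload here is a single Int standing in for a decoded region value.
  def readFramed(in: InputStream): Iterator[Int] = new Iterator[Int] {
    private val data = new DataInputStream(in)
    private var cont: Byte = data.readByte() // prime the first continuation byte
    if (cont == 0) data.close()

    def hasNext: Boolean = cont != 0

    def next(): Int = {
      if (!hasNext) throw new NoSuchElementException("next on empty iterator")
      val value = data.readInt()  // decode one record
      cont = data.readByte()      // read the next continuation byte
      if (cont == 0) data.close() // close eagerly once the terminator is seen
      value
    }
  }

  // A stream holding the records 7 and 11, followed by the 0 terminator.
  val bytes = Array[Byte](1, 0, 0, 0, 7, 1, 0, 0, 0, 11, 0)
  assert(readFramed(new ByteArrayInputStream(bytes)).toList == List(7, 11))
}
```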
@@ -84,14 +84,6 @@ trait ShimRVDSpec extends AbstractRVDSpec {

override def partitioner: RVDPartitioner = shim.partitioner

override def read(
ctx: ExecuteContext,
path: String,
requestedType: TStruct,
newPartitioner: Option[RVDPartitioner],
filterIntervals: Boolean
): RVD = shim.read(ctx, path, requestedType, newPartitioner, filterIntervals)

override def typedCodecSpec: AbstractTypedCodecSpec = shim.typedCodecSpec

override def partFiles: Array[String] = shim.partFiles
18 changes: 0 additions & 18 deletions hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala
@@ -110,24 +110,6 @@ case class RVDComponentSpec(rel_path: String) extends ComponentSpec {
AbstractRVDSpec.read(fs, absolutePath(path))

def indexed(fs: FS, path: String): Boolean = rvdSpec(fs, path).indexed

def read(
ctx: ExecuteContext,
path: String,
requestedType: TStruct,
newPartitioner: Option[RVDPartitioner] = None,
filterIntervals: Boolean = false
): RVD = {
val rvdPath = path + "/" + rel_path
rvdSpec(ctx.fs, path)
.read(ctx, rvdPath, requestedType, newPartitioner, filterIntervals)
}

def readLocalSingleRow(ctx: ExecuteContext, path: String, requestedType: TStruct): (PStruct, Long) = {
val rvdPath = path + "/" + rel_path
rvdSpec(ctx.fs, path)
.readLocalSingleRow(ctx, rvdPath, requestedType)
}
}

case class PartitionCountsComponentSpec(counts: Seq[Long]) extends ComponentSpec
47 changes: 22 additions & 25 deletions hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala
@@ -279,7 +279,12 @@ object BinarySearch {
runSearchBounded[T](cb, haystack, compare, 0, haystack.loadLength(), found, notFound)
}

class BinarySearch[C](mb: EmitMethodBuilder[C], containerType: SContainer, eltType: EmitType, keyOnly: Boolean) {
class BinarySearch[C](mb: EmitMethodBuilder[C],
containerType: SContainer,
eltType: EmitType,
getKey: (EmitCodeBuilder, EmitValue) => EmitValue,
bound: String = "lower",
ltF: CodeOrdering.F[Boolean] = null) {
val containerElementType: EmitType = containerType.elementEmitType
val findElt = mb.genEmitMethod("findElt", FastIndexedSeq[ParamType](containerType.paramType, eltType.paramType), typeInfo[Int])

@@ -289,35 +294,27 @@ class BinarySearch[C](mb: EmitMethodBuilder[C], containerType: SContainer, eltTy
val haystack = findElt.getSCodeParam(1).asIndexable
val needle = findElt.getEmitParam(cb, 2, null) // no streams

def ltNeedle(x: IEmitCode): Code[Boolean] = if (keyOnly) {
val kt: EmitType = containerElementType.st match {
case s: SBaseStruct =>
require(s.size == 2)
s.fieldEmitTypes(0)
case interval: SInterval =>
interval.pointEmitType
}

val keyLT = mb.ecb.getOrderingFunction(kt.st, eltType.st, CodeOrdering.Lt())

val key = cb.memoize(x.flatMap(cb) {
case x: SBaseStructValue =>
x.loadField(cb, 0)
case x: SIntervalValue =>
x.loadStart(cb)
})

keyLT(cb, key, needle)
} else {
val lt = mb.ecb.getOrderingFunction(containerElementType.st, eltType.st, CodeOrdering.Lt())
lt(cb, cb.memoize(x), needle)
val f: (
EmitCodeBuilder,
SIndexableValue,
IEmitCode => Code[Boolean],
Value[Int],
Value[Int]
) => Value[Int] = bound match {
case "upper" => BinarySearch.upperBound
case "lower" => BinarySearch.lowerBound
}

BinarySearch.lowerBound(cb, haystack, ltNeedle)
f(cb, haystack, { containerElement =>
val elementVal = cb.memoize(containerElement, "binary_search_elt")
val compareVal = getKey(cb, elementVal)
val lt = Option(ltF).getOrElse(mb.ecb.getOrderingFunction(compareVal.st, eltType.st, CodeOrdering.Lt()))
lt(cb, compareVal, needle)
}, 0, haystack.loadLength())
}

// check missingness of v before calling
def lowerBound(cb: EmitCodeBuilder, array: SValue, v: EmitCode): Value[Int] = {
def search(cb: EmitCodeBuilder, array: SValue, v: EmitCode): Value[Int] = {
cb.memoize(cb.invokeCode[Int](findElt, array, v))
}
}
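The refactor above replaces the keyOnly flag with two orthogonal parameters: a getKey function mapping a container element to the value actually compared, and a bound selector choosing lowerBound or upperBound (plus an optional ltF comparator). A plain, unstaged Scala analogue of that interface, where search and the tuple key are illustrative stand-ins for the EmitCodeBuilder version above:

```scala
object BoundSketch extends App {
  // Search haystack for the position of needle, comparing by getKey; bound selects
  // the lower bound (first index with key >= needle) or the upper bound (first
  // index with key > needle), mirroring the String-valued bound parameter above.
  def search[A, K](haystack: IndexedSeq[A], needle: K, getKey: A => K, bound: String = "lower")
                  (implicit ord: Ordering[K]): Int = {
    val before: A => Boolean = bound match {
      case "lower" => (a: A) => ord.lt(getKey(a), needle)
      case "upper" => (a: A) => ord.lteq(getKey(a), needle)
    }
    var lo = 0
    var hi = haystack.length
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (before(haystack(mid))) lo = mid + 1 else hi = mid
    }
    lo
  }

  // Interval-like (start, end) pairs searched by their start key, the analogue of a
  // getKey that loads a struct's first field or an interval's start.
  val intervals = IndexedSeq((0, 10), (10, 20), (20, 30))
  assert(search[(Int, Int), Int](intervals, 10, _._1) == 1)          // lower bound
  assert(search[(Int, Int), Int](intervals, 10, _._1, "upper") == 2) // upper bound
}
```

In the Emit.scala change below, the onKey case supplies a getKey that loads a struct's first field or an interval's start, which is what the `_._1` extractor stands in for here.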
14 changes: 12 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
@@ -1104,8 +1104,18 @@ class Emit[C](
case x@LowerBoundOnOrderedCollection(orderedCollection, elem, onKey) =>
emitI(orderedCollection).map(cb) { a =>
val e = EmitCode.fromI(cb.emb)(cb => this.emitI(elem, cb, region, env, container, loopEnv))
val bs = new BinarySearch[C](mb, a.st.asInstanceOf[SContainer], e.emitType, keyOnly = onKey)
primitive(bs.lowerBound(cb, a, e))
val bs = new BinarySearch[C](mb, a.st.asInstanceOf[SContainer], e.emitType, { (cb, elt) =>

if (onKey) {
cb.memoize(elt.toI(cb).flatMap(cb) {
case x: SBaseStructValue =>
x.loadField(cb, 0)
case x: SIntervalValue =>
x.loadStart(cb)
})
} else elt
})
primitive(bs.search(cb, a, e))
}

case x@ArraySort(a, left, right, lessThan) =>