From 2b9a40274d21eb6523fc42d55c4c0477c9a821ee Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Fri, 4 May 2018 15:50:44 -0400
Subject: [PATCH] RVD spicy meatball (now with only +876/-587) (#3414)

* # This is a combination of 22 commits.
# This is the 1st commit message:

apply resettable context
forgot to fix one use of AutoCloseable
fix
add setup iterator
more sensible method ordering
make TrivialContext Resettable
a few more missing resettablecontexts
address comments
apply resettable context
forgot to fix one use of AutoCloseable
fix
add setup iterator
more sensible method ordering
remove rogue element type type
make TrivialContext Resettable
wip
wip
wip
wip
use safe row in join suite
pull over hailcontext
remove Region.clear(newEnd)
add selectRegionValue

# This is the commit message #2:
convert relational.scala ;

# This is the commit message #3:
scope the extract aggregators constfb call

# This is the commit message #4:
scope interpret

# This is the commit message #5:
typeAfterSelect used by selectRegionValue

# This is the commit message #6:
load matrix

# This is the commit message #7:
imports

# This is the commit message #8:
loadbgen converted

# This is the commit message #9:
convert loadplink

# This is the commit message #10:
convert loadgdb

# This is the commit message #11:
convert loadvcf

# This is the commit message #12:
convert blockmatrix

# This is the commit message #13:
convert filterintervals

# This is the commit message #14:
convert ibd

# This is the commit message #15:
convert a few methods

# This is the commit message #16:
convert split multi

# This is the commit message #17:
convert VEP

# This is the commit message #18:
formatting fix

# This is the commit message #19:
add partitionBy and values

# This is the commit message #20:
fix bug in localkeysort

# This is the commit message #21:
fixup HailContext.readRowsPartition use

# This is the commit message #22:
port balding nichols model

* apply resettable context
forgot to fix one use of AutoCloseable
fix
add setup iterator
more sensible method ordering
make TrivialContext Resettable
a few more missing resettablecontexts
address comments
apply resettable context
forgot to fix one use of AutoCloseable
fix
add setup iterator
more sensible method ordering
remove rogue element type type
make TrivialContext Resettable
wip
wip
wip
wip
use safe row in join suite
pull over hailcontext
remove Region.clear(newEnd)
add selectRegionValue
convert relational.scala ;
scope the extract aggregators constfb call
scope interpret
typeAfterSelect used by selectRegionValue
load matrix
imports
loadbgen converted
convert loadplink
convert loadgdb
convert loadvcf
convert blockmatrix
convert filterintervals
convert ibd
convert a few methods
convert split multi
convert VEP
formatting fix
add partitionBy and values
fix bug in localkeysort
fixup HailContext.readRowsPartition use
port balding nichols model
port over table.scala
couple fixes
convert matrix table
remove necessary use of rdd
variety of fixups
wip
add a clear

* Remove direct Region allocation from FilterColsIR

When regions are off-heap, we can allow the globals to live in a separate,
longer-lived Region that is not cleared until the whole partition is
finished. For now, we pay the memory cost.

* Use RVDContext in MatrixRead zip

This Region will get cleared by consumers. I introduced the zip primitive
which is a safer way to zip two RVDs because it does not rely on the user
correctly clearing the regions used by the left and right hand sides of the
zip.
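The shape of that zip is roughly the following sketch. It is not the code in
this patch; Region, RVDContext, and zip below are simplified, illustrative
stand-ins meant only to show why a single context-owned region with one clear
is safer than each side clearing its own region.

    // Sketch only (not the implementation in this patch): a minimal model of
    // zipping two per-partition iterators under one shared context.
    object ZipSketch {
      // hypothetical bump-allocating Region
      final class Region {
        private var end: Long = 0L
        def allocate(n: Long): Long = { val off = end; end += n; off }
        def clear(): Unit = { end = 0L }
      }

      // hypothetical per-partition context that owns the Region
      final class RVDContext(val region: Region)

      // the zip clears the shared region once per element, so neither the
      // left nor the right producer ever has to remember to call clear()
      def zip[A, B, C](ctx: RVDContext, left: Iterator[A], right: Iterator[B])
                      (f: (RVDContext, A, B) => C): Iterator[C] =
        new Iterator[C] {
          def hasNext: Boolean = {
            val hn = left.hasNext
            assert(hn == right.hasNext) // the two sides must stay in lockstep
            hn
          }
          def next(): C = {
            ctx.region.clear() // one clear, in one place
            f(ctx, left.next(), right.next())
          }
        }
    }

In this model, a consumer that needs a value to outlive one step copies it
out of the region before the next clear.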
* Control the Regions in LoadGDB

I do not fully understand how LoadGDB works, but a simple solution to this
use case is to serialize to arrays of bytes and parallelize those.

I realize there is a proliferation of `coerce` methods. I plan to trim this
down once we do not have RDD and ContextRDD coexisting.

* wip
* unify RVD.run
* reset in write
* fixes
* use context region when allocating
* also read RVDs using RVDContext
* formatting
* address comments
* remove unused val
* abstract over boundary
* little fixes
* whoops, forgot to clear before persisting

This fixes LDPrune: if you don't clear the region, things go wrong. Not sure
what causes that bug. Maybe it's something about encoders?

* serialize for shuffles, region.scoped in matrixmapglobals, fix joins
* clear more!
* wip
* wip
* rework GeneralRDD to ease ContextRDD transition
* formatting
* final fixes
* formatting
* merge failures
* more bad merge stuff
* formatting
* remove unnecessary stuff
* remove fixme
* boom!
* variety of merge mistakes
* fix destabilize bug
* add missing newline
* remember to clear the producer region in localkeysort
* switch def to val
* cleanup filteralleles and exporbidbimfam
* fix clearing and serialization issue
* fix BitPackedVectorView

Previously it always assumed the variant struct started at offset zero,
which is not true.

* address comments, remove a comment
* remove direct use of Region
* oops
* werrrks, mebbe
* needs cleanup
* fix filter intervals
* fixes
* fixes
* fix filterintervals
* remove unnecessary copy in TableJoin
* and finally fix the last test
* re-use existing CodecSpec definition
* remove unnecessary boundaries
* use RVD abstraction when possible
* formatting
* bugfix: RegionValue must know its region
* remove unnecessary val and comment
* remove unused methods
* eliminate unused constructors
* undo debug change
* formatting
* remove unused imports
* fix bug in tablejoin
* fix RichRDDSuite test

If you have no data, then you have no partitions, not 1 partition.
---
 build.sbt                                      |   1 +
 src/main/scala/is/hail/HailContext.scala       |  76 ++--
 .../hail/annotations/OrderedRVIterator.scala   |  30 +-
 .../scala/is/hail/annotations/Region.scala     |   5 -
 .../annotations/WritableRegionValue.scala      |   5 +-
 src/main/scala/is/hail/expr/Relational.scala   | 149 ++++----
 .../is/hail/expr/ir/ExtractAggregators.scala   |   2 +-
 src/main/scala/is/hail/io/LoadMatrix.scala     |  10 +-
 src/main/scala/is/hail/io/RowStore.scala       |  21 +-
 src/main/scala/is/hail/io/bgen/LoadBgen.scala  |  15 +-
 .../scala/is/hail/io/plink/ExportPlink.scala   |   4 +-
 .../scala/is/hail/io/plink/LoadPlink.scala     |  22 +-
 src/main/scala/is/hail/io/vcf/LoadGDB.scala    |  39 +-
 src/main/scala/is/hail/io/vcf/LoadVCF.scala    |  32 +-
 .../scala/is/hail/linalg/BlockMatrix.scala     |  22 +-
 .../scala/is/hail/methods/FilterAlleles.scala  |  12 +-
 .../is/hail/methods/FilterIntervals.scala      |  10 +-
 src/main/scala/is/hail/methods/IBD.scala       |  17 +-
 .../is/hail/methods/LinearRegression.scala     |  10 +-
 .../scala/is/hail/methods/LocalLDPrune.scala   |   6 +-
 src/main/scala/is/hail/methods/Nirvana.scala   |   7 +-
 src/main/scala/is/hail/methods/PCA.scala       |  20 +-
 src/main/scala/is/hail/methods/PCRelate.scala  |   2 +-
 src/main/scala/is/hail/methods/Skat.scala      |   5 +-
 .../scala/is/hail/methods/SplitMulti.scala     | 109 +++---
 src/main/scala/is/hail/methods/VEP.scala       |  12 +-
 .../scala/is/hail/rvd/KeyedOrderedRVD.scala    |  63 +--
 src/main/scala/is/hail/rvd/OrderedRVD.scala    | 361 +++++++++++-------
 src/main/scala/is/hail/rvd/RVD.scala           | 254 +++++++++---
 src/main/scala/is/hail/rvd/RVDContext.scala    |   4 -
.../scala/is/hail/rvd/UnpartitionedRVD.scala | 13 +- .../is/hail/sparkextras/BlockedRDD.scala | 2 +- .../is/hail/sparkextras/ContextRDD.scala | 84 +++- .../sparkextras/RepartitionedOrderedRDD.scala | 4 +- .../is/hail/stats/BaldingNicholsModel.scala | 11 +- src/main/scala/is/hail/table/Table.scala | 64 ++-- .../is/hail/utils/FlipbookIterator.scala | 17 +- .../is/hail/utils/richUtils/Implicits.scala | 6 +- .../hail/utils/richUtils/RichContextRDD.scala | 11 +- .../hail/utils/richUtils/RichIterator.scala | 4 +- .../is/hail/utils/richUtils/RichRDD.scala | 13 +- .../scala/is/hail/variant/MatrixTable.scala | 162 ++++---- .../is/hail/methods/LocalLDPruneSuite.scala | 49 +-- .../is/hail/testUtils/RichMatrixTable.scala | 3 +- .../is/hail/utils/FlipbookIteratorSuite.scala | 1 + .../scala/is/hail/utils/RichRDDSuite.scala | 4 +- .../hail/variant/vsm/PartitioningSuite.scala | 12 +- src/test/scala/is/hail/vds/JoinSuite.scala | 16 +- 48 files changed, 1079 insertions(+), 722 deletions(-) diff --git a/build.sbt b/build.sbt index 7bddfa5e680f..d85ac4efcc48 100644 --- a/build.sbt +++ b/build.sbt @@ -42,6 +42,7 @@ lazy val root = (project in file(".")). , "org.json4s" %% "json4s-core" % "3.2.10" , "org.json4s" %% "json4s-jackson" % "3.2.10" , "org.json4s" %% "json4s-ast" % "3.2.10" + , "org.elasticsearch" % "elasticsearch-spark-20_2.11" % "6.2.4" , "org.apache.solr" % "solr-solrj" % "6.2.0" , "com.datastax.cassandra" % "cassandra-driver-core" % "3.0.0" , "com.jayway.restassured" % "rest-assured" % "2.8.0" diff --git a/src/main/scala/is/hail/HailContext.scala b/src/main/scala/is/hail/HailContext.scala index 7f889d24fc72..dddf9c5ea4a6 100644 --- a/src/main/scala/is/hail/HailContext.scala +++ b/src/main/scala/is/hail/HailContext.scala @@ -11,7 +11,9 @@ import is.hail.io.bgen.LoadBgen import is.hail.io.gen.LoadGen import is.hail.io.plink.{FamFileConfig, LoadPlink} import is.hail.io.vcf._ +import is.hail.rvd.RVDContext import is.hail.table.Table +import is.hail.sparkextras.ContextRDD import is.hail.stats.{BaldingNicholsModel, Distribution, UniformDist} import is.hail.utils.{log, _} import is.hail.variant.{MatrixTable, ReferenceGenome, VSMSubgen} @@ -208,9 +210,14 @@ object HailContext { ProgressBarBuilder.build(sc) } - def readRowsPartition(makeDec: (InputStream) => Decoder)(i: Int, in: InputStream, metrics: InputMetrics = null): Iterator[RegionValue] = { + def readRowsPartition( + makeDec: (InputStream) => Decoder + )(ctx: RVDContext, + in: InputStream, + metrics: InputMetrics = null + ): Iterator[RegionValue] = new Iterator[RegionValue] { - private val region = Region() + private val region = ctx.region private val rv = RegionValue(region) private val trackedIn = new ByteTrackingInputStream(in) @@ -236,7 +243,6 @@ object HailContext { throw new NoSuchElementException("next on empty iterator") try { - region.clear() rv.setOffset(dec.readRegionValue(region)) if (metrics != null) { ExposedMetrics.incrementRecord(metrics) @@ -259,7 +265,6 @@ object HailContext { dec.close() } } - } } class HailContext private(val sc: SparkContext, @@ -531,8 +536,19 @@ class HailContext private(val sc: SparkContext, } } - def readRows(path: String, t: TStruct, codecSpec: CodecSpec, partFiles: Array[String]): RDD[RegionValue] = - readPartitions(path, partFiles, HailContext.readRowsPartition(codecSpec.buildDecoder(t))) + def readRows( + path: String, + t: TStruct, + codecSpec: CodecSpec, + partFiles: Array[String] + ): ContextRDD[RVDContext, RegionValue] = + ContextRDD.weaken[RVDContext](readPartitions(path, partFiles, (_, is, m) => 
Iterator.single(is -> m))) + .cmapPartitions { (ctx, it) => + assert(it.hasNext) + val (is, m) = it.next + assert(!it.hasNext) + HailContext.readRowsPartition(codecSpec.buildDecoder(t))(ctx, is, m) + } def parseVCFMetadata(file: String): Map[String, Map[String, Map[String, String]]] = { val reader = new HtsjdkRecordReader(Set.empty) @@ -663,30 +679,30 @@ class HailContext private(val sc: SparkContext, val ast = Parser.parseToAST(expr, ec) ast.toIR() match { case Some(body) => - val region = Region() - val t = ast.`type` - t match { - case _: TBoolean => - val (_, f) = ir.Compile[Boolean](body) - (f()(region), t) - case _: TInt32 => - val (_, f) = ir.Compile[Int](body) - (f()(region), t) - case _: TInt64 => - val (_, f) = ir.Compile[Long](body) - (f()(region), t) - case _: TFloat32 => - val (_, f) = ir.Compile[Float](body) - (f()(region), t) - case _: TFloat64 => - val (_, f) = ir.Compile[Double](body) - (f()(region), t) - case _ => - val (_, f) = ir.Compile[Long](body) - val off = f()(region) - val v2 = UnsafeRow.read(t, region, off) - val v3 = Annotation.copy(t, v2) - (v3, t) + Region.scoped { region => + val t = ast.`type` + t match { + case _: TBoolean => + val (_, f) = ir.Compile[Boolean](body) + (f()(region), t) + case _: TInt32 => + val (_, f) = ir.Compile[Int](body) + (f()(region), t) + case _: TInt64 => + val (_, f) = ir.Compile[Long](body) + (f()(region), t) + case _: TFloat32 => + val (_, f) = ir.Compile[Float](body) + (f()(region), t) + case _: TFloat64 => + val (_, f) = ir.Compile[Double](body) + (f()(region), t) + case _ => + val (_, f) = ir.Compile[Long](body) + val off = f()(region) + val v2 = SafeRow.read(t, region, off) + (v2, t) + } } case None => val (t, f) = Parser.eval(ast, ec) diff --git a/src/main/scala/is/hail/annotations/OrderedRVIterator.scala b/src/main/scala/is/hail/annotations/OrderedRVIterator.scala index 3608ac157c97..4bde68498f57 100644 --- a/src/main/scala/is/hail/annotations/OrderedRVIterator.scala +++ b/src/main/scala/is/hail/annotations/OrderedRVIterator.scala @@ -3,6 +3,8 @@ package is.hail.annotations import is.hail.rvd.OrderedRVDType import is.hail.utils._ +import scala.collection.generic.Growable + case class OrderedRVIterator(t: OrderedRVDType, iterator: Iterator[RegionValue]) { def restrictToPKInterval(interval: Interval): Iterator[RegionValue] = { @@ -56,50 +58,62 @@ case class OrderedRVIterator(t: OrderedRVDType, iterator: Iterator[RegionValue]) this.t.kComp(other.t).compare ) - def innerJoin(other: OrderedRVIterator): Iterator[JoinedRegionValue] = { + def innerJoin( + other: OrderedRVIterator, + rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.innerJoin( other.iterator.toFlipbookIterator, this.t.kRowOrdView, other.t.kRowOrdView, null, null, - new RegionValueArrayBuffer(other.t.rowType), + rightBuffer, this.t.kComp(other.t).compare ) } - def leftJoin(other: OrderedRVIterator): Iterator[JoinedRegionValue] = { + def leftJoin( + other: OrderedRVIterator, + rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.leftJoin( other.iterator.toFlipbookIterator, this.t.kRowOrdView, other.t.kRowOrdView, null, null, - new RegionValueArrayBuffer(other.t.rowType), + rightBuffer, this.t.kComp(other.t).compare ) } - def rightJoin(other: OrderedRVIterator): Iterator[JoinedRegionValue] = { + def rightJoin( + other: OrderedRVIterator, + rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + ): 
Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.rightJoin( other.iterator.toFlipbookIterator, this.t.kRowOrdView, other.t.kRowOrdView, null, null, - new RegionValueArrayBuffer(other.t.rowType), + rightBuffer, this.t.kComp(other.t).compare ) } - def outerJoin(other: OrderedRVIterator): Iterator[JoinedRegionValue] = { + def outerJoin( + other: OrderedRVIterator, + rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.outerJoin( other.iterator.toFlipbookIterator, this.t.kRowOrdView, other.t.kRowOrdView, null, null, - new RegionValueArrayBuffer(other.t.rowType), + rightBuffer, this.t.kComp(other.t).compare ) } diff --git a/src/main/scala/is/hail/annotations/Region.scala b/src/main/scala/is/hail/annotations/Region.scala index 059d86d3c98d..55ddea66a198 100644 --- a/src/main/scala/is/hail/annotations/Region.scala +++ b/src/main/scala/is/hail/annotations/Region.scala @@ -255,11 +255,6 @@ final class Region( off } - def clear(newEnd: Long) { - assert(newEnd <= end) - end = newEnd - } - def clear() { end = 0 } diff --git a/src/main/scala/is/hail/annotations/WritableRegionValue.scala b/src/main/scala/is/hail/annotations/WritableRegionValue.scala index 8bfc06123619..11c8d33af20f 100644 --- a/src/main/scala/is/hail/annotations/WritableRegionValue.scala +++ b/src/main/scala/is/hail/annotations/WritableRegionValue.scala @@ -55,10 +55,9 @@ class WritableRegionValue private (val t: Type) { def pretty: String = value.pretty(t) } -class RegionValueArrayBuffer(val t: Type) +class RegionValueArrayBuffer(val t: Type, region: Region) extends Iterable[RegionValue] with Growable[RegionValue] { - val region = Region() val value = RegionValue(region, 0) private val rvb = new RegionValueBuilder(region) @@ -94,7 +93,7 @@ class RegionValueArrayBuffer(val t: Type) def clear() { region.clear() idx.clear() - rvb.clear() + rvb.clear() // remove } private var itIdx = 0 diff --git a/src/main/scala/is/hail/expr/Relational.scala b/src/main/scala/is/hail/expr/Relational.scala index 1d44a316bebb..f21d99f3d396 100644 --- a/src/main/scala/is/hail/expr/Relational.scala +++ b/src/main/scala/is/hail/expr/Relational.scala @@ -2,12 +2,14 @@ package is.hail.expr import is.hail.HailContext import is.hail.annotations._ +import is.hail.annotations.Annotation._ import is.hail.annotations.aggregators.RegionValueAggregator import is.hail.expr.ir._ import is.hail.expr.types._ import is.hail.io._ import is.hail.methods.Aggregators import is.hail.rvd._ +import is.hail.sparkextras.ContextRDD import is.hail.table.TableSpec import is.hail.variant._ import org.apache.spark.rdd.RDD @@ -201,8 +203,10 @@ case class MatrixValue( val hc = HailContext.get val signature = typ.colType - new UnpartitionedRVD(signature, hc.sc.parallelize(colValues.value.toArray.map(_.asInstanceOf[Row])) - .mapPartitions(_.toRegionValueIterator(signature))) + new UnpartitionedRVD( + signature, + ContextRDD.parallelize(hc.sc, colValues.value.toArray.map(_.asInstanceOf[Row])) + .cmapPartitions { (ctx, it) => it.toRegionValueIterator(ctx.region, signature) }) } def entriesRVD(): RVD = { @@ -216,21 +220,17 @@ case class MatrixValue( val localColValues = colValues.broadcast.value - rvd.mapPartitions(resultStruct) { it => - - val rv2b = new RegionValueBuilder() - val rv2 = RegionValue() + rvd.boundary.mapPartitions(resultStruct, { (ctx, it) => + val rv2b = ctx.rvb + val rv2 = RegionValue(ctx.region) it.flatMap { rv => - val rvEnd = rv.region.size - rv2b.set(rv.region) val gsOffset = 
fullRowType.loadField(rv, localEntriesIndex) (0 until localNCols).iterator .filter { i => localEntriesType.isElementDefined(rv.region, gsOffset, i) } .map { i => - rv.region.clear(rvEnd) rv2b.clear() rv2b.start(resultStruct) rv2b.startStruct() @@ -245,11 +245,11 @@ case class MatrixValue( rv2b.addInlineRow(localColType, localColValues(i).asInstanceOf[Row]) rv2b.addAllFields(localEntryType, rv.region, localEntriesType.elementOffsetInRegion(rv.region, gsOffset, i)) rv2b.endStruct() - rv2.set(rv.region, rv2b.end()) + rv2.setOffset(rv2b.end()) rv2 } } - } + }) } } @@ -471,42 +471,25 @@ case class MatrixRead( } else { val entriesRVD = spec.entriesComponent.read(hc, path) val entriesRowType = entriesRVD.rowType - rowsRVD.zipPartitionsPreservesPartitioning( - typ.orvdType, - entriesRVD - ) { case (it1, it2) => - val rvb = new RegionValueBuilder() - - new Iterator[RegionValue] { - def hasNext: Boolean = { - val hn = it1.hasNext - assert(hn == it2.hasNext) - hn - } - - def next(): RegionValue = { - val rv1 = it1.next() - val rv2 = it2.next() - val region = rv2.region - rvb.set(region) - rvb.start(fullRowType) - rvb.startStruct() - var i = 0 - while (i < localEntriesIndex) { - rvb.addField(rowType, rv1, i) - i += 1 - } - rvb.addField(entriesRowType, rv2, 0) - i += 1 - while (i < fullRowType.size) { - rvb.addField(rowType, rv1, i - 1) - i += 1 - } - rvb.endStruct() - rv2.set(region, rvb.end()) - rv2 - } + rowsRVD.zip(typ.orvdType, entriesRVD) { (ctx, rv1, rv2) => + val rvb = ctx.rvb + val region = ctx.region + rvb.start(fullRowType) + rvb.startStruct() + var i = 0 + while (i < localEntriesIndex) { + rvb.addField(rowType, rv1, i) + i += 1 + } + rvb.addField(entriesRowType, rv2, 0) + i += 1 + while (i < fullRowType.size) { + rvb.addField(rowType, rv1, i - 1) + i += 1 } + rvb.endStruct() + rv2.set(region, rvb.end()) + rv2 } } } @@ -556,16 +539,15 @@ case class MatrixRange(nRows: Int, nCols: Int, nPartitions: Int) extends MatrixI val start = partStarts(i) Interval(Row(start), Row(start + localPartCounts(i)), includesStart = true, includesEnd = false) }), - hc.sc.parallelize(Range(0, nPartitionsAdj), nPartitionsAdj) - .mapPartitionsWithIndex { case (i, _) => - val region = Region() - val rvb = new RegionValueBuilder(region) + ContextRDD.parallelize[RVDContext](hc.sc, Range(0, nPartitionsAdj), nPartitionsAdj) + .cmapPartitionsWithIndex { (i, ctx, _) => + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) val start = partStarts(i) Iterator.range(start, start + localPartCounts(i)) .map { j => - region.clear() rvb.start(localRVType) rvb.startStruct() @@ -963,12 +945,12 @@ case class MatrixMapRows(child: MatrixIR, newRow: IR) extends MatrixIR { }) assert(rTyp == typ.rvRowType, s"$rTyp, ${ typ.rvRowType }") - val mapPartitionF = { it: Iterator[RegionValue] => + val mapPartitionF = { (ctx: RVDContext, it: Iterator[RegionValue]) => val rvb = new RegionValueBuilder() val newRV = RegionValue() val rowF = f() - val partRegion = Region() + val partRegion = ctx.freshContext.region rvb.set(partRegion) rvb.start(localGlobalsType) @@ -1020,11 +1002,12 @@ case class MatrixMapRows(child: MatrixIR, newRow: IR) extends MatrixIR { } if (touchesKeys) { - val newRDD = prev.rvd.mapPartitions(mapPartitionF) prev.copy(typ = typ, - rvd = OrderedRVD.coerce(typ.orvdType, newRDD, None, None)) + rvd = OrderedRVD.coerce( + typ.orvdType, + prev.rvd.mapPartitions(typ.rvRowType, mapPartitionF))) } else { - val newRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType)(mapPartitionF) + val newRVD = 
prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType, mapPartitionF) prev.copy(typ = typ, rvd = newRVD) } } @@ -1210,14 +1193,15 @@ case class MatrixMapGlobals(child: MatrixIR, newRow: IR, value: BroadcastRow) ex newRow) assert(rTyp == typ.globalType) - val globalRegion = Region() - val globalOff = prev.globals.toRegion(globalRegion) - val valueOff = value.toRegion(globalRegion) - val newOff = f()(globalRegion, globalOff, false, valueOff, false) + val newGlobals = Region.scoped { globalRegion => + val globalOff = prev.globals.toRegion(globalRegion) + val valueOff = value.toRegion(globalRegion) + val newOff = f()(globalRegion, globalOff, false, valueOff, false) - val newGlobals = prev.globals.copy( - value = SafeRow(rTyp.asInstanceOf[TStruct], globalRegion, newOff), - t = rTyp.asInstanceOf[TStruct]) + prev.globals.copy( + value = SafeRow(rTyp.asInstanceOf[TStruct], globalRegion, newOff), + t = rTyp.asInstanceOf[TStruct]) + } prev.copy(typ = typ, globals = newGlobals) } @@ -1372,8 +1356,8 @@ case class TableParallelize(typ: TableType, rows: IndexedSeq[Row], nPartitions: def execute(hc: HailContext): TableValue = { val rowTyp = typ.rowType - val rvd = hc.sc.parallelize(rows, nPartitions.getOrElse(hc.sc.defaultParallelism)) - .mapPartitions(_.toRegionValueIterator(rowTyp)) + val rvd = ContextRDD.parallelize[RVDContext](hc.sc, rows, nPartitions) + .cmapPartitions((ctx, it) => it.toRegionValueIterator(ctx.region, rowTyp)) TableValue(typ, BroadcastRow(Row(), typ.globalType, hc.sc), new UnpartitionedRVD(rowTyp, rvd)) } } @@ -1397,20 +1381,18 @@ case class TableImport(paths: Array[String], typ: TableType, readerOpts: TableRe val useColIndices = readerOpts.useColIndices - val rvd = hc.sc.textFilesLines(paths, readerOpts.nPartitions) + val rvd = ContextRDD.textFilesLines[RVDContext](hc.sc, paths, readerOpts.nPartitions) .filter { line => !readerOpts.isComment(line.value) && (readerOpts.noHeader || readerOpts.header != line.value) && !(readerOpts.skipBlankLines && line.value.isEmpty) - }.mapPartitions { it => - val region = Region() - val rvb = new RegionValueBuilder(region) + }.cmapPartitions { (ctx, it) => + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) it.map { _.map { line => - region.clear() - val sp = TextTableReader.splitLine(line, readerOpts.separator, readerOpts.quote) if (sp.length != nFieldOrig) fatal(s"expected $nFieldOrig fields, but found ${ sp.length } fields") @@ -1483,16 +1465,15 @@ case class TableRange(n: Int, nPartitions: Int) extends TableIR { val end = partStarts(i + 1) Interval(Row(start), Row(end), includesStart = true, includesEnd = false) }), - hc.sc.parallelize(Range(0, nPartitionsAdj), nPartitionsAdj) - .mapPartitionsWithIndex { case (i, _) => - val region = Region() - val rvb = new RegionValueBuilder(region) + ContextRDD.parallelize(hc.sc, Range(0, nPartitionsAdj), nPartitionsAdj) + .cmapPartitionsWithIndex { case (i, ctx, _) => + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) val start = partStarts(i) Iterator.range(start, start + localPartCounts(i)) .map { j => - region.clear() rvb.start(localRowType) rvb.startStruct() rvb.addInt(j) @@ -1562,7 +1543,7 @@ case class TableJoin(left: TableIR, right: TableIR, joinType: String) extends Ta val leftValueFieldIdx = left.typ.valueFieldIdx val rightValueFieldIdx = right.typ.valueFieldIdx val localNewRowType = newRowType - val rvMerger: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it => + val rvMerger = { (ctx: RVDContext, it: Iterator[JoinedRegionValue]) => val 
rvb = new RegionValueBuilder() val rv = RegionValue() it.map { joined => @@ -1751,26 +1732,26 @@ case class TableExplode(child: TableIR, column: String) extends TableIR { assert(resultType == typ.rowType) TableValue(typ, prev.globals, - prev.rvd.mapPartitions(typ.rowType) { it => + prev.rvd.boundary.mapPartitions(typ.rowType, { (ctx, it) => val rv2 = RegionValue() - it.flatMap { rv => val isMissing = isMissingF()(rv.region, rv.offset, false) if (isMissing) Iterator.empty else { - val end = rv.region.size val n = lengthF()(rv.region, rv.offset, false) Iterator.range(0, n) .map { i => - rv.region.clear(end) - val off = explodeF()(rv.region, rv.offset, false, i, false) - rv2.set(rv.region, off) + ctx.rvb.start(childRowType) + ctx.rvb.addRegionValue(childRowType, rv) + val incomingRow = ctx.rvb.end() + val off = explodeF()(ctx.region, incomingRow, false, i, false) + rv2.set(ctx.region, off) rv2 } } } - }) + })) } } diff --git a/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala b/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala index fe862a0cdbab..c0b624223b1b 100644 --- a/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala +++ b/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala @@ -68,6 +68,6 @@ object ExtractAggregators { Code(codeArgs.map(_.setup): _*), AggOp.get(op, x.inputType, args.map(_.typ)) .stagedNew(codeArgs.map(_.v).toArray, codeArgs.map(_.m).toArray))) - constfb.result()()(Region()) + Region.scoped(constfb.result()()(_)) } } diff --git a/src/main/scala/is/hail/io/LoadMatrix.scala b/src/main/scala/is/hail/io/LoadMatrix.scala index 5cf481ae3641..d93a71eeb433 100644 --- a/src/main/scala/is/hail/io/LoadMatrix.scala +++ b/src/main/scala/is/hail/io/LoadMatrix.scala @@ -3,7 +3,8 @@ package is.hail.io import is.hail.HailContext import is.hail.annotations._ import is.hail.expr.types._ -import is.hail.rvd.{OrderedRVD, OrderedRVDPartitioner} +import is.hail.rvd.{OrderedRVD, OrderedRVDPartitioner, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant._ import org.apache.hadoop.conf.Configuration @@ -359,9 +360,9 @@ object LoadMatrix { rowPartitionKey = rowKey.toFastIndexedSeq, entryType = cellType) - val rdd = lines.filter(l => l.value.nonEmpty) - .mapPartitionsWithIndex { (i, it) => - val region = Region() + val rdd = ContextRDD.weaken[RVDContext](lines.filter(l => l.value.nonEmpty)) + .cmapPartitionsWithIndex { (i, ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) @@ -374,7 +375,6 @@ object LoadMatrix { val fileRowNum = partitionStartInFile + row val line = v.value - region.clear() rvb.start(matrixType.rvRowType) rvb.startStruct() if (useIndex) { diff --git a/src/main/scala/is/hail/io/RowStore.scala b/src/main/scala/is/hail/io/RowStore.scala index 56b0dae51e82..44e6aa41702f 100644 --- a/src/main/scala/is/hail/io/RowStore.scala +++ b/src/main/scala/is/hail/io/RowStore.scala @@ -4,8 +4,8 @@ import is.hail.annotations._ import is.hail.expr.JSONAnnotationImpex import is.hail.expr.types._ import is.hail.io.compress.LZ4Utils -import is.hail.rvd.{OrderedRVDPartitioner, OrderedRVDSpec, RVDSpec, UnpartitionedRVDSpec} -import is.hail.sparkextras.ContextRDD +import is.hail.rvd.{OrderedRVDPartitioner, OrderedRVDSpec, RVDContext, RVDSpec, UnpartitionedRVDSpec} +import is.hail.sparkextras._ import is.hail.utils._ import org.apache.spark.rdd.RDD import org.json4s.{Extraction, JValue} @@ -62,6 +62,10 @@ object CodecSpec { new LZ4BlockBufferSpec(32 * 1024, new StreamBlockBufferSpec)))) 
+ val defaultUncompressed = new PackCodecSpec( + new BlockingBufferSpec(32 * 1024, + new StreamBlockBufferSpec)) + val blockSpecs: Array[BufferSpec] = Array( new BlockingBufferSpec(64 * 1024, new StreamBlockBufferSpec), @@ -850,13 +854,14 @@ final class PackEncoder(rowType: Type, out: OutputBuffer) extends Encoder { } object RichContextRDDRegionValue { - def writeRowsPartition(makeEnc: (OutputStream) => Encoder)(i: Int, it: Iterator[RegionValue], os: OutputStream): Long = { + def writeRowsPartition(makeEnc: (OutputStream) => Encoder)(ctx: RVDContext, it: Iterator[RegionValue], os: OutputStream): Long = { val en = makeEnc(os) var rowCount = 0L it.foreach { rv => en.writeByte(1) en.writeRegionValue(rv.region, rv.offset) + ctx.region.clear() rowCount += 1 } @@ -868,7 +873,7 @@ object RichContextRDDRegionValue { } } -class RichContextRDDRegionValue[C <: AutoCloseable](val crdd: ContextRDD[C, RegionValue]) extends AnyVal { +class RichContextRDDRegionValue(val crdd: ContextRDD[RVDContext, RegionValue]) extends AnyVal { def writeRows(path: String, t: TStruct, codecSpec: CodecSpec): (Array[String], Array[Long]) = { crdd.writePartitions(path, RichContextRDDRegionValue.writeRowsPartition(codecSpec.buildEncoder(t))) } @@ -892,11 +897,11 @@ class RichContextRDDRegionValue[C <: AutoCloseable](val crdd: ContextRDD[C, Regi val entriesRVType = TStruct( MatrixType.entriesIdentifier -> TArray(t.entryType)) - val makeRowsEnc = codecSpec.buildEncoder(rowsRVType)(_) + val makeRowsEnc = codecSpec.buildEncoder(rowsRVType) - val makeEntriesEnc = codecSpec.buildEncoder(entriesRVType)(_) + val makeEntriesEnc = codecSpec.buildEncoder(entriesRVType) - val partFilePartitionCounts = crdd.mapPartitionsWithIndex { case (i, it) => + val partFilePartitionCounts = crdd.cmapPartitionsWithIndex { (i, ctx, it) => val hConf = sHConfBc.value.value val f = partFile(d, i, TaskContext.get) @@ -940,6 +945,8 @@ class RichContextRDDRegionValue[C <: AutoCloseable](val crdd: ContextRDD[C, Regi rowsEN.writeByte(0) // end entriesEN.writeByte(0) + ctx.region.clear() + Iterator.single(f -> rowCount) } } diff --git a/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/src/main/scala/is/hail/io/bgen/LoadBgen.scala index 7cad671d81bd..0035318f315e 100644 --- a/src/main/scala/is/hail/io/bgen/LoadBgen.scala +++ b/src/main/scala/is/hail/io/bgen/LoadBgen.scala @@ -5,7 +5,8 @@ import is.hail.annotations._ import is.hail.expr.types._ import is.hail.io.vcf.LoadVCF import is.hail.io.{HadoopFSDataBinaryReader, IndexBTree} -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant._ import org.apache.hadoop.io.LongWritable @@ -103,8 +104,10 @@ object LoadBgen { val kType = matrixType.orvdType.kType val rowType = matrixType.rvRowType - val fastKeys = sc.union(results.map(_.rdd.mapPartitions { it => - val region = Region() + val crdds = results.map(x => ContextRDD.weaken[RVDContext](x.rdd)) + + val fastKeys = ContextRDD.union(sc, crdds.map(_.cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) @@ -112,7 +115,6 @@ object LoadBgen { val (contig, pos, alleles) = record.getKey val contigRecoded = contigRecoding.getOrElse(contig, contig) - region.clear() rvb.start(kType) rvb.startStruct() rvb.addAnnotation(kType.types(0), Locus.annotation(contigRecoded, pos, rg)) @@ -134,8 +136,8 @@ object LoadBgen { val loadEntries = entryFields.length > 0 - val rdd2 = 
sc.union(results.map(_.rdd.mapPartitions { it => - val region = Region() + val rdd2 = ContextRDD.union(sc, crdds.map(_.cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) @@ -145,7 +147,6 @@ object LoadBgen { val contigRecoded = contigRecoding.getOrElse(contig, contig) - region.clear() rvb.start(rowType) rvb.startStruct() rvb.addAnnotation(kType.types(0), Locus.annotation(contigRecoded, pos, rg)) diff --git a/src/main/scala/is/hail/io/plink/ExportPlink.scala b/src/main/scala/is/hail/io/plink/ExportPlink.scala index 103fa5b3aa77..774c8da86b87 100644 --- a/src/main/scala/is/hail/io/plink/ExportPlink.scala +++ b/src/main/scala/is/hail/io/plink/ExportPlink.scala @@ -76,7 +76,7 @@ object ExportPlink { val nSamples = mv.colValues.value.length val fullRowType = mv.typ.rvRowType - val nRecordsWritten = mv.rvd.mapPartitionsWithIndex { case (i, it) => + val nRecordsWritten = mv.rvd.mapPartitionsWithIndex { (i, ctx, it) => val hConf = sHConfBc.value.value val f = partFile(d, i, TaskContext.get) val bedPartPath = tmpBedDir + "/" + f @@ -97,7 +97,7 @@ object ExportPlink { hcv.setRegion(rv) ExportPlink.writeBedRow(hcv, bp, nSamples) - + ctx.region.clear() rowCount += 1 } } diff --git a/src/main/scala/is/hail/io/plink/LoadPlink.scala b/src/main/scala/is/hail/io/plink/LoadPlink.scala index 9e1af3de8e5f..119e1f18977b 100644 --- a/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -4,7 +4,8 @@ import is.hail.HailContext import is.hail.annotations._ import is.hail.expr.types._ import is.hail.io.vcf.LoadVCF -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils.StringEscapeUtils._ import is.hail.utils._ import is.hail.variant.{Locus, _} @@ -123,8 +124,13 @@ object LoadPlink { sc.hadoopConfiguration.setInt("nSamples", nSamples) sc.hadoopConfiguration.setBoolean("a2Reference", a2Reference) - val rdd = sc.hadoopFile(bedPath, classOf[PlinkInputFormat], classOf[LongWritable], classOf[PlinkRecord], - nPartitions.getOrElse(sc.defaultMinPartitions)) + val crdd = ContextRDD.weaken[RVDContext]( + sc.hadoopFile( + bedPath, + classOf[PlinkInputFormat], + classOf[LongWritable], + classOf[PlinkRecord], + nPartitions.getOrElse(sc.defaultMinPartitions))) val matrixType = MatrixType.fromParts( globalType = TStruct.empty(), @@ -138,15 +144,14 @@ object LoadPlink { val kType = matrixType.orvdType.kType val rvRowType = matrixType.rvRowType - val fastKeys = rdd.mapPartitions { it => - val region = Region() + val fastKeys = crdd.cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) it.map { case (_, record) => val (contig, pos, posMorgan, ref, alt, rsid) = variantsBc.value(record.getKey) - region.clear() rvb.start(kType) rvb.startStruct() rvb.addAnnotation(kType.types(0), Locus.annotation(contig, pos, rg)) @@ -161,15 +166,14 @@ object LoadPlink { } } - val rdd2 = rdd.mapPartitions { it => - val region = Region() + val rdd2 = crdd.cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) it.map { case (_, record) => val (contig, pos, posMorgan, ref, alt, rsid) = variantsBc.value(record.getKey) - region.clear() rvb.start(rvRowType) rvb.startStruct() rvb.addAnnotation(kType.types(0), Locus.annotation(contig, pos, rg)) diff --git a/src/main/scala/is/hail/io/vcf/LoadGDB.scala 
b/src/main/scala/is/hail/io/vcf/LoadGDB.scala index 861f7ae578f6..488b243addba 100644 --- a/src/main/scala/is/hail/io/vcf/LoadGDB.scala +++ b/src/main/scala/is/hail/io/vcf/LoadGDB.scala @@ -5,17 +5,19 @@ import htsjdk.variant.vcf.{VCFCompoundHeaderLine, VCFHeader} import is.hail.HailContext import is.hail.annotations._ import is.hail.utils._ +import is.hail.io._ import is.hail.variant.{MatrixTable, ReferenceGenome} import org.json4s._ import scala.collection.JavaConversions._ import scala.collection.JavaConverters.asScalaIteratorConverter -import java.io.{File, FileWriter} +import java.io._ import is.hail.expr.types._ import is.hail.io.VCFAttributes import is.hail.io.vcf.LoadVCF.headerSignature -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import org.apache.spark.sql.Row import scala.collection.mutable @@ -170,27 +172,42 @@ object LoadGDB { entryType = genotypeSignature) val localRowType = matrixType.rvRowType - val region = Region() - val rvb = new RegionValueBuilder(region) + val rvCodec = CodecSpec.defaultUncompressed - val records = gdbReader - .iterator - .asScala - .map { vc => + val records = Region.scoped { region => + val baos = new ByteArrayOutputStream() + val enc = rvCodec.buildEncoder(localRowType)(baos) + val rvb = new RegionValueBuilder(region) + gdbReader + .iterator + .asScala + .map { vc => rvb.clear() region.clear() rvb.start(localRowType) reader.readRecord(vc, rvb, infoSignature, genotypeSignature, dropSamples, canonicalFlags, infoFlagFieldNames) - rvb.result().copy() + enc.writeRegionValue(region, rvb.end()) + baos.toByteArray() }.toArray + } - val recordRDD = sc.parallelize(records, nPartitions.getOrElse(sc.defaultMinPartitions)) + val recordCRDD = ContextRDD.parallelize[RVDContext](sc, records, nPartitions) + .cmapPartitions { (ctx, it) => + val region = ctx.region + val rv = RegionValue(region) + it.map { bytes => + val bais = new ByteArrayInputStream(bytes) + val dec = rvCodec.buildDecoder(localRowType)(bais) + rv.setOffset(dec.readRegionValue(region)) + rv + } + } queryFile.delete() new MatrixTable(hc, matrixType, BroadcastRow(Row.empty, matrixType.globalType, sc), BroadcastIndexedSeq(sampleIds.map(x => Annotation(x)), TArray(matrixType.colType), sc), - OrderedRVD.coerce(matrixType.orvdType, hc.sc.parallelize(records), None, None)) + OrderedRVD.coerce(matrixType.orvdType, recordCRDD)) } } diff --git a/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/src/main/scala/is/hail/io/vcf/LoadVCF.scala index 35586aba31a9..fa70d03b0610 100644 --- a/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -5,7 +5,8 @@ import is.hail.HailContext import is.hail.annotations._ import is.hail.expr.types._ import is.hail.io.{VCFAttributes, VCFMetadata} -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant._ import org.apache.hadoop @@ -711,12 +712,18 @@ object LoadVCF { } // parses the Variant (key), leaves the rest to f - def parseLines[C](makeContext: () => C)(f: (C, VCFLine, RegionValueBuilder) => Unit)( - lines: RDD[WithContext[String]], t: Type, rg: Option[ReferenceGenome], contigRecoding: Map[String, String]): RDD[RegionValue] = { - lines.mapPartitions { it => + def parseLines[C]( + makeContext: () => C + )(f: (C, VCFLine, RegionValueBuilder) => Unit + )(lines: ContextRDD[RVDContext, WithContext[String]], + t: Type, + rg: Option[ReferenceGenome], + contigRecoding: 
Map[String, String] + ): ContextRDD[RVDContext, RegionValue] = { + lines.cmapPartitions { (ctx, it) => new Iterator[RegionValue] { - val region = Region() - val rvb = new RegionValueBuilder(region) + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) val context: C = makeContext() @@ -729,7 +736,6 @@ object LoadVCF { val line = lwc.value try { val vcfLine = new VCFLine(line) - region.clear() rvb.start(t) rvb.startStruct() present = vcfLine.parseAddVariant(rvb, rg, contigRecoding) @@ -845,7 +851,7 @@ object LoadVCF { val headerLinesBc = sc.broadcast(headerLines1) - val lines = sc.textFilesLines(files, nPartitions.getOrElse(sc.defaultMinPartitions)) + val lines = ContextRDD.textFilesLines[RVDContext](sc, files, nPartitions) val matrixType: MatrixType = MatrixType.fromParts( TStruct.empty(true), @@ -860,7 +866,13 @@ object LoadVCF { val rowType = matrixType.rvRowType // nothing after the key - val justVariants = parseLines(() => ())((c, l, rvb) => ())(lines, kType, rg, contigRecoding) + val justVariants = parseLines( + () => () + )((c, l, rvb) => () + )(ContextRDD.textFilesLines[RVDContext](sc, files, nPartitions), + kType, + rg, + contigRecoding) val rdd = OrderedRVD.coerce( matrixType.orvdType, @@ -895,7 +907,7 @@ object LoadVCF { } rvb.endArray() }(lines, rowType, rg, contigRecoding), - Some(justVariants), None) + justVariants) new MatrixTable(hc, matrixType, diff --git a/src/main/scala/is/hail/linalg/BlockMatrix.scala b/src/main/scala/is/hail/linalg/BlockMatrix.scala index e39ff2f00ea5..e3da2cc5c7cc 100644 --- a/src/main/scala/is/hail/linalg/BlockMatrix.scala +++ b/src/main/scala/is/hail/linalg/BlockMatrix.scala @@ -11,7 +11,8 @@ import is.hail.table.Table import is.hail.expr.types._ import is.hail.io.{BlockingBufferSpec, BufferSpec, LZ4BlockBufferSpec, StreamBlockBufferSpec} import is.hail.methods.UpperIndexBounds -import is.hail.rvd.RVD +import is.hail.rvd.{RVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.utils.richUtils.RichDenseMatrixDouble import org.apache.commons.math3.random.MersenneTwister @@ -247,7 +248,7 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], val hadoop = blocks.sparkContext.hadoopConfiguration hadoop.mkDir(uri) - def writeBlock(i: Int, it: Iterator[((Int, Int), BDM[Double])], os: OutputStream): Int = { + def writeBlock(it: Iterator[((Int, Int), BDM[Double])], os: OutputStream): Int = { assert(it.hasNext) val bdm = it.next()._2 assert(!it.hasNext) @@ -621,17 +622,16 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], case None => blocks } - val entriesRDD = rdd.flatMap { case ((blockRow, blockCol), block) => + val entriesRDD = ContextRDD.weaken[RVDContext](rdd).cflatMap { case (ctx, ((blockRow, blockCol), block)) => val rowOffset = blockRow * blockSize.toLong val colOffset = blockCol * blockSize.toLong - val region = Region() + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) block.activeIterator .map { case ((i, j), entry) => - region.clear() rvb.start(rvRowType) rvb.startStruct() rvb.addLong(rowOffset + i) @@ -1008,7 +1008,7 @@ case class WriteBlocksRDDPartition(index: Int, start: Int, skip: Int, end: Int) } class WriteBlocksRDD(path: String, - rdd: RDD[RegionValue], + crdd: ContextRDD[RVDContext, RegionValue], sc: SparkContext, matrixType: MatrixType, parentPartStarts: Array[Long], @@ -1017,7 +1017,7 @@ class WriteBlocksRDD(path: String, require(gp.nRows == parentPartStarts.last) - private val parentParts = rdd.partitions + 
private val parentParts = crdd.partitions private val blockSize = gp.blockSize private val d = digitsNeeded(gp.numPartitions) @@ -1025,7 +1025,7 @@ class WriteBlocksRDD(path: String, override def getDependencies: Seq[Dependency[_]] = Array[Dependency[_]]( - new NarrowDependency(rdd) { + new NarrowDependency(crdd.rdd) { def getParents(partitionId: Int): Seq[Int] = partitions(partitionId).asInstanceOf[WriteBlocksRDDPartition].range } @@ -1100,8 +1100,8 @@ class WriteBlocksRDD(path: String, val writeBlocksPart = split.asInstanceOf[WriteBlocksRDDPartition] val start = writeBlocksPart.start - writeBlocksPart.range.foreach { pi => - val it = rdd.iterator(parentParts(pi), context) + writeBlocksPart.range.foreach { pi => using(crdd.mkc()) { ctx => + val it = crdd.iterator(parentParts(pi), context, ctx) if (pi == start) { var j = 0 @@ -1147,7 +1147,7 @@ class WriteBlocksRDD(path: String, } i += 1 } - } + } } outPerBlockCol.foreach(_.close()) diff --git a/src/main/scala/is/hail/methods/FilterAlleles.scala b/src/main/scala/is/hail/methods/FilterAlleles.scala index 0407f589355d..161bea782500 100644 --- a/src/main/scala/is/hail/methods/FilterAlleles.scala +++ b/src/main/scala/is/hail/methods/FilterAlleles.scala @@ -104,15 +104,14 @@ object FilterAlleles { val localSampleAnnotationsBc = vsm.colValues.broadcast - rdd.mapPartitions(newRVType) { it => + rdd.boundary.mapPartitions(newRVType, { (ctx, it) => var prevLocus: Locus = null val fullRow = new UnsafeRow(fullRowType) val rvv = new RegionValueVariant(fullRowType) + val rvb = ctx.rvb + val rv2 = RegionValue(ctx.region) it.flatMap { rv => - val rvb = new RegionValueBuilder() - val rv2 = RegionValue() - rvv.setRegion(rv) fullRow.set(rv) @@ -132,7 +131,6 @@ object FilterAlleles { || (!isLeftAligned && removeMoving)) None else { - rvb.set(rv.region) rvb.start(newRVType) rvb.startStruct() @@ -162,13 +160,13 @@ object FilterAlleles { } rvb.endArray() rvb.endStruct() - rv2.set(rv.region, rvb.end()) + rv2.setOffset(rvb.end()) Some(rv2) } } } - } + }) } val newRDD2: OrderedRVD = diff --git a/src/main/scala/is/hail/methods/FilterIntervals.scala b/src/main/scala/is/hail/methods/FilterIntervals.scala index 84fdefea1cbf..971cfcc3ecc2 100644 --- a/src/main/scala/is/hail/methods/FilterIntervals.scala +++ b/src/main/scala/is/hail/methods/FilterIntervals.scala @@ -22,15 +22,15 @@ object FilterIntervals { val pkRowFieldIdx = vsm.rvd.typ.pkRowFieldIdx val rowType = vsm.rvd.typ.rowType - vsm.copy2(rvd = vsm.rvd.mapPartitionsPreservesPartitioning(vsm.rvd.typ) { it => - val pk = WritableRegionValue(pkType) + vsm.copy2(rvd = vsm.rvd.mapPartitionsPreservesPartitioning(vsm.rvd.typ, { (ctx, it) => val pkUR = new UnsafeRow(pkType) it.filter { rv => - pk.setSelect(rowType, pkRowFieldIdx, rv) - pkUR.set(pk.value) + ctx.rvb.start(pkType) + ctx.rvb.selectRegionValue(rowType, pkRowFieldIdx, rv) + pkUR.set(ctx.region, ctx.rvb.end()) !intervalsBc.value.contains(pkType.ordering, pkUR) } - }) + })) } } } diff --git a/src/main/scala/is/hail/methods/IBD.scala b/src/main/scala/is/hail/methods/IBD.scala index e51a52832a4a..4fe53ad59856 100644 --- a/src/main/scala/is/hail/methods/IBD.scala +++ b/src/main/scala/is/hail/methods/IBD.scala @@ -5,6 +5,8 @@ import is.hail.expr.EvalContext import is.hail.table.Table import is.hail.annotations._ import is.hail.expr.types._ +import is.hail.rvd.RVDContext +import is.hail.sparkextras.ContextRDD import is.hail.variant.{Call, Genotype, HardCallView, MatrixTable} import is.hail.stats.RegressionUtils import org.apache.spark.rdd.RDD @@ -207,7 +209,7 @@ 
object IBD { min: Option[Double], max: Option[Double], sampleIds: IndexedSeq[String], - bounded: Boolean): RDD[RegionValue] = { + bounded: Boolean): ContextRDD[RVDContext, RegionValue] = { val nSamples = vds.numCols @@ -253,7 +255,7 @@ object IBD { }) .map { case ((s, v), gs) => (v, (s, IBSFFI.pack(chunkSize, chunkSize, gs))) } - chunkedGenotypeMatrix.join(chunkedGenotypeMatrix) + val joined = ContextRDD.weaken[RVDContext](chunkedGenotypeMatrix.join(chunkedGenotypeMatrix) // optimization: Ignore chunks below the diagonal .filter { case (_, ((i, _), (j, _))) => j >= i } .map { case (_, ((s1, gs1), (s2, gs2))) => @@ -266,9 +268,11 @@ object IBD { i += 1 } a - } - .mapPartitions { it => - val region = Region() + }) + + joined + .cmapPartitions { (ctx, it) => + val region = ctx.region val rv = RegionValue(region) val rvb = new RegionValueBuilder(region) for { @@ -282,7 +286,6 @@ object IBD { eibd = calculateIBDInfo(ibses(idx * 3), ibses(idx * 3 + 1), ibses(idx * 3 + 2), ibse, bounded) if min.forall(eibd.ibd.PI_HAT >= _) && max.forall(eibd.ibd.PI_HAT <= _) } yield { - region.clear() rvb.start(ibdSignature) rvb.startStruct() rvb.addString(sampleIds(i)) @@ -292,7 +295,7 @@ object IBD { rv.setOffset(rvb.end()) rv } - } + } } def apply(vds: MatrixTable, diff --git a/src/main/scala/is/hail/methods/LinearRegression.scala b/src/main/scala/is/hail/methods/LinearRegression.scala index f51230dc3623..fdc54acc1d54 100644 --- a/src/main/scala/is/hail/methods/LinearRegression.scala +++ b/src/main/scala/is/hail/methods/LinearRegression.scala @@ -60,10 +60,10 @@ object LinearRegression { val newMatrixType = vsm.matrixType.copy(rvRowType = newRVType) - val newRDD2 = vsm.rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType) { it => - - val region2 = Region() - val rvb = new RegionValueBuilder(region2) + val newRDD2 = vsm.rvd.boundary.mapPartitionsPreservesPartitioning( + newMatrixType.orvdType, { (ctx, it) => + val region2 = ctx.region + val rvb = ctx.rvb val rv2 = RegionValue(region2) val missingCompleteCols = new ArrayBuilder[Int] @@ -136,7 +136,7 @@ object LinearRegression { rv2 } } - } + }) vsm.copyMT(matrixType = newMatrixType, rvd = newRDD2) diff --git a/src/main/scala/is/hail/methods/LocalLDPrune.scala b/src/main/scala/is/hail/methods/LocalLDPrune.scala index eb486b071897..85b659b68226 100644 --- a/src/main/scala/is/hail/methods/LocalLDPrune.scala +++ b/src/main/scala/is/hail/methods/LocalLDPrune.scala @@ -319,13 +319,13 @@ object LocalLDPrune { } }) - val rddLP = pruneLocal(standardizedRDD, r2Threshold, windowSize, Some(maxQueueSize)) + val rvdLP = pruneLocal(standardizedRDD, r2Threshold, windowSize, Some(maxQueueSize)) val tableType = TableType( rowType = mt.rowKeyStruct ++ TStruct("mean" -> TFloat64Required, "centered_length_sd_reciprocal" -> TFloat64Required), key = Some(mt.rowKey), globalType = TStruct.empty()) - val sitesOnlyRDD = rddLP.mapPartitionsPreservesPartitioning( + val sitesOnly = rvdLP.mapPartitionsPreservesPartitioning( new OrderedRVDType(typ.partitionKey, typ.key, tableType.rowType))({ it => val region = Region() @@ -345,7 +345,7 @@ object LocalLDPrune { } }) - new Table(hc = mt.hc, rdd = sitesOnlyRDD.rdd, signature = tableType.rowType, key = tableType.key) + new Table(hc = mt.hc, crdd = sitesOnly.crdd, signature = tableType.rowType, key = tableType.key) } } diff --git a/src/main/scala/is/hail/methods/Nirvana.scala b/src/main/scala/is/hail/methods/Nirvana.scala index 27f9a1a62e25..b2c22fdb10ce 100644 --- a/src/main/scala/is/hail/methods/Nirvana.scala +++ 
b/src/main/scala/is/hail/methods/Nirvana.scala @@ -6,7 +6,8 @@ import java.util.Properties import is.hail.annotations._ import is.hail.expr.types._ import is.hail.expr.{JSONAnnotationImpex, Parser} -import is.hail.rvd.{OrderedRVD, OrderedRVDType} +import is.hail.rvd.{OrderedRVD, OrderedRVDType, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant.{Locus, MatrixTable, RegionValueVariant} import org.apache.spark.sql.Row @@ -327,8 +328,8 @@ object Nirvana { val nirvanaRVD: OrderedRVD = OrderedRVD( nirvanaORVDType, vds.rvd.partitioner, - annotations.mapPartitions { it => - val region = Region() + ContextRDD.weaken[RVDContext](annotations).cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) diff --git a/src/main/scala/is/hail/methods/PCA.scala b/src/main/scala/is/hail/methods/PCA.scala index 57ce218dec77..d175a2f6a5ab 100644 --- a/src/main/scala/is/hail/methods/PCA.scala +++ b/src/main/scala/is/hail/methods/PCA.scala @@ -3,6 +3,8 @@ package is.hail.methods import breeze.linalg.{*, DenseMatrix, DenseVector} import is.hail.annotations._ import is.hail.expr.types._ +import is.hail.rvd.RVDContext +import is.hail.sparkextras.ContextRDD import is.hail.table.Table import is.hail.utils._ import is.hail.variant.MatrixTable @@ -23,14 +25,13 @@ object PCA { val scoresBc = sc.broadcast(scores) val localSSignature = vsm.colKeyTypes - val scoresRDD = sc.parallelize(vsm.colKeys.zipWithIndex).mapPartitions[RegionValue] { it => - val region = Region() + val scoresRDD = ContextRDD.weaken[RVDContext](sc.parallelize(vsm.colKeys.zipWithIndex)).cmapPartitions { (ctx, it) => + val region = ctx.region val rv = RegionValue(region) val rvb = new RegionValueBuilder(region) val localRowType = rowTypeBc.value it.map { case (s, i) => - region.clear() rvb.start(localRowType) rvb.startStruct() var j = 0 @@ -87,18 +88,17 @@ object PCA { }.collect() } - val optionLoadings = someIf(computeLoadings, { + val optionLoadings = if (computeLoadings) { val rowType = TStruct(vsm.rowKey.zip(vsm.rowKeyTypes): _*) ++ TStruct("loadings" -> TArray(TFloat64())) val rowTypeBc = vsm.sparkContext.broadcast(rowType) val rowKeysBc = vsm.sparkContext.broadcast(collectRowKeys()) val localRowKeySignature = vsm.rowKeyTypes - val rdd = svd.U.rows.mapPartitions[RegionValue] { it => - val region = Region() + val rdd = ContextRDD.weaken[RVDContext](svd.U.rows).cmapPartitions { (ctx, it) => + val region = ctx.region val rv = RegionValue(region) val rvb = new RegionValueBuilder(region) it.map { ir => - region.clear() rvb.start(rowTypeBc.value) rvb.startStruct() @@ -121,8 +121,10 @@ object PCA { rv } } - new Table(vsm.hc, rdd, rowType, Some(vsm.rowKey)) - }) + Some(new Table(vsm.hc, rdd, rowType, Some(vsm.rowKey))) + } else { + None + } val data = if (!svd.V.isTransposed) diff --git a/src/main/scala/is/hail/methods/PCRelate.scala b/src/main/scala/is/hail/methods/PCRelate.scala index 0eaa85e6212c..f28898adc44a 100644 --- a/src/main/scala/is/hail/methods/PCRelate.scala +++ b/src/main/scala/is/hail/methods/PCRelate.scala @@ -168,7 +168,7 @@ object PCRelate { val localRowType = vds.rvRowType val partStarts = vds.partitionStarts() val partStartsBc = vds.sparkContext.broadcast(partStarts) - val rdd = vds.rvd.mapPartitionsWithIndex { case (partIdx, it) => + val rdd = vds.rvd.mapPartitionsWithIndex { (partIdx, it) => val view = HardCallView(localRowType) val missingIndices = new ArrayBuilder[Int]() diff --git 
a/src/main/scala/is/hail/methods/Skat.scala b/src/main/scala/is/hail/methods/Skat.scala index c8b3b1e3f45a..0007d6037639 100644 --- a/src/main/scala/is/hail/methods/Skat.scala +++ b/src/main/scala/is/hail/methods/Skat.scala @@ -227,8 +227,8 @@ object Skat { val n = completeColIdx.length val completeColIdxBc = sc.broadcast(completeColIdx) - - (vsm.rvd.rdd.flatMap { rv => + + (vsm.rvd.boundary.mapPartitions { it => it.flatMap { rv => val keyIsDefined = fullRowType.isFieldDefined(rv, keyIndex) val weightIsDefined = fullRowType.isFieldDefined(rv, weightIndex) @@ -243,6 +243,7 @@ object Skat { rv, fullRowType, entryArrayType, entryType, entryArrayIdx, fieldIdx) Some(key -> (BDV(data), weight)) } else None + } }.groupByKey(), keyType) } diff --git a/src/main/scala/is/hail/methods/SplitMulti.scala b/src/main/scala/is/hail/methods/SplitMulti.scala index 88b5fdd2b40a..d0707493cbaa 100644 --- a/src/main/scala/is/hail/methods/SplitMulti.scala +++ b/src/main/scala/is/hail/methods/SplitMulti.scala @@ -5,10 +5,10 @@ import is.hail.asm4s.AsmFunction13 import is.hail.expr._ import is.hail.expr.ir._ import is.hail.expr.types._ -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row class ExprAnnotator(val ec: EvalContext, t: TStruct, expr: String, head: Option[String]) extends Serializable { @@ -66,37 +66,26 @@ class SplitMultiRowIR(rowIRs: Array[(String, IR)], entryIRs: Array[(String, IR)] class SplitMultiPartitionContextIR( keepStar: Boolean, nSamples: Int, globalAnnotation: Annotation, matrixType: MatrixType, - rowF: () => AsmFunction13[Region, Long, Boolean, Long, Boolean, Long, Boolean, Long, Boolean, Int, Boolean, Boolean, Boolean, Long], newRVRowType: TStruct) extends - SplitMultiPartitionContext(keepStar, nSamples, globalAnnotation, matrixType, newRVRowType) { - - private var globalsCopied = false - private var globals: Long = 0 - private var globalsEnd: Long = 0 - - private def copyGlobals() { - splitRegion.clear() - rvb.set(splitRegion) - rvb.start(matrixType.globalType) - rvb.addAnnotation(matrixType.globalType, globalAnnotation) - globals = rvb.end() - globalsEnd = splitRegion.size - globalsCopied = true - } + rowF: () => AsmFunction13[Region, Long, Boolean, Long, Boolean, Long, Boolean, Long, Boolean, Int, Boolean, Boolean, Boolean, Long], + newRVRowType: TStruct, + region: Region +) extends + SplitMultiPartitionContext(keepStar, nSamples, globalAnnotation, matrixType, newRVRowType, region) { private val allelesType = matrixType.rowType.fieldByName("alleles").typ private val locusType = matrixType.rowType.fieldByName("locus").typ val f = rowF() def constructSplitRow(splitVariants: Iterator[(Locus, IndexedSeq[String], Int)], rv: RegionValue, wasSplit: Boolean): Iterator[RegionValue] = { - if (!globalsCopied) - copyGlobals() - splitRegion.clear(globalsEnd) - rvb.start(matrixType.rvRowType) - rvb.addRegionValue(matrixType.rvRowType, rv) - val oldRow = rvb.end() - val oldEnd = splitRegion.size splitVariants.map { case (newLocus, newAlleles, aIndex) => - splitRegion.clear(oldEnd) + rvb.set(splitRegion) + rvb.start(matrixType.globalType) + rvb.addAnnotation(matrixType.globalType, globalAnnotation) + val globals = rvb.end() + + rvb.start(matrixType.rvRowType) + rvb.addRegionValue(matrixType.rvRowType, rv) + val oldRow = rvb.end() rvb.start(locusType) rvb.addAnnotation(locusType, newLocus) @@ -115,9 +104,15 @@ class 
SplitMultiPartitionContextIR( class SplitMultiPartitionContextAST( keepStar: Boolean, - nSamples: Int, globalAnnotation: Annotation, matrixType: MatrixType, - vAnnotator: ExprAnnotator, gAnnotator: ExprAnnotator, newRVRowType: TStruct) extends - SplitMultiPartitionContext(keepStar, nSamples, globalAnnotation, matrixType, newRVRowType) { + nSamples: Int, + globalAnnotation: Annotation, + matrixType: MatrixType, + vAnnotator: ExprAnnotator, + gAnnotator: ExprAnnotator, + newRVRowType: TStruct, + region: Region +) extends + SplitMultiPartitionContext(keepStar, nSamples, globalAnnotation, matrixType, newRVRowType, region) { val (t1, locusInserter) = vAnnotator.newT.insert(matrixType.rowType.fieldByName("locus").typ, "locus") assert(t1 == vAnnotator.newT) @@ -127,7 +122,6 @@ class SplitMultiPartitionContextAST( def constructSplitRow(splitVariants: Iterator[(Locus, IndexedSeq[String], Int)], rv: RegionValue, wasSplit: Boolean): Iterator[RegionValue] = { val gs = fullRow.getAs[IndexedSeq[Any]](matrixType.entriesIdx) splitVariants.map { case (newLocus, newAlleles, i) => - splitRegion.clear() rvb.set(splitRegion) rvb.start(newRVRowType) rvb.startStruct() @@ -165,12 +159,16 @@ class SplitMultiPartitionContextAST( abstract class SplitMultiPartitionContext( keepStar: Boolean, - nSamples: Int, globalAnnotation: Annotation, matrixType: MatrixType, newRVRowType: TStruct) extends Serializable { + nSamples: Int, + globalAnnotation: Annotation, + matrixType: MatrixType, + newRVRowType: TStruct, + val splitRegion: Region +) extends Serializable { var fullRow = new UnsafeRow(matrixType.rvRowType) var prevLocus: Locus = null val rvv = new RegionValueVariant(matrixType.rvRowType) - val splitRegion = Region() val rvb = new RegionValueBuilder() val splitrv = RegionValue() val locusAllelesOrdering = matrixType.rowKeyStruct.ordering @@ -232,10 +230,14 @@ object SplitMulti { splitmulti.split() } - def unionMovedVariants(ordered: OrderedRVD, - moved: RDD[RegionValue]): OrderedRVD = { - val movedRVD = OrderedRVD.adjustBoundsAndShuffle(ordered.typ, - ordered.partitioner, moved) + def unionMovedVariants( + ordered: OrderedRVD, + moved: RVD + ): OrderedRVD = { + val movedRVD = OrderedRVD.adjustBoundsAndShuffle( + ordered.typ, + ordered.partitioner, + moved) ordered.copy(orderedPartitioner = movedRVD.partitioner).partitionSortedUnion(movedRVD) } @@ -284,7 +286,7 @@ class SplitMulti(vsm: MatrixTable, variantExpr: String, genotypeExpr: String, ke (t, true, null) } - def split(sortAlleles: Boolean, removeLeftAligned: Boolean, removeMoving: Boolean, verifyLeftAligned: Boolean): RDD[RegionValue] = { + def split(sortAlleles: Boolean, removeLeftAligned: Boolean, removeMoving: Boolean, verifyLeftAligned: Boolean): RVD = { val localKeepStar = keepStar val globalsBc = vsm.globals.broadcast val localNSamples = vsm.numCols @@ -298,27 +300,22 @@ class SplitMulti(vsm: MatrixTable, variantExpr: String, genotypeExpr: String, ke val locusIndex = localRowType.fieldIdx("locus") - if (useAST) { - vsm.rvd.mapPartitions { it => - val context = new SplitMultiPartitionContextAST(localKeepStar, localNSamples, globalsBc.value, - localMatrixType, localVAnnotator, localGAnnotator, newRowType) - it.flatMap { rv => - val splitit = context.splitRow(rv, sortAlleles, removeLeftAligned, removeMoving, verifyLeftAligned) - context.prevLocus = context.fullRow.getAs[Locus](locusIndex) - splitit - } - } + val makeContext = if (useAST) { + (region: Region) => new SplitMultiPartitionContextAST(localKeepStar, localNSamples, globalsBc.value, + localMatrixType, 
localVAnnotator, localGAnnotator, newRowType, region) } else { - vsm.rvd.mapPartitions { it => - val context = new SplitMultiPartitionContextIR(localKeepStar, localNSamples, globalsBc.value, - localMatrixType, localSplitRow, newRowType) - it.flatMap { rv => - val splitit = context.splitRow(rv, sortAlleles, removeLeftAligned, removeMoving, verifyLeftAligned) - context.prevLocus = context.fullRow.getAs[Locus](locusIndex) - splitit - } - } + (region: Region) => new SplitMultiPartitionContextIR(localKeepStar, localNSamples, globalsBc.value, + localMatrixType, localSplitRow, newRowType, region) } + + vsm.rvd.boundary.mapPartitions(newRowType, { (ctx, it) => + val splitMultiContext = makeContext(ctx.region) + it.flatMap { rv => + val splitit = splitMultiContext.splitRow(rv, sortAlleles, removeLeftAligned, removeMoving, verifyLeftAligned) + splitMultiContext.prevLocus = splitMultiContext.fullRow.getAs[Locus](locusIndex) + splitit + } + }) } def split(): MatrixTable = { diff --git a/src/main/scala/is/hail/methods/VEP.scala b/src/main/scala/is/hail/methods/VEP.scala index c68b23f2581c..e742acad839e 100644 --- a/src/main/scala/is/hail/methods/VEP.scala +++ b/src/main/scala/is/hail/methods/VEP.scala @@ -6,7 +6,8 @@ import java.util.Properties import is.hail.annotations.{Annotation, Region, RegionValue, RegionValueBuilder} import is.hail.expr._ import is.hail.expr.types._ -import is.hail.rvd.{OrderedRVD, OrderedRVDType} +import is.hail.rvd.{OrderedRVD, OrderedRVDType, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant.{Locus, MatrixTable, RegionValueVariant, VariantMethods} import org.apache.spark.sql.Row @@ -336,9 +337,9 @@ object VEP { val vepRVD: OrderedRVD = OrderedRVD( vepORVDType, vsm.rvd.partitioner, - annotations.mapPartitions { it => - val region = Region() - val rvb = new RegionValueBuilder(region) + ContextRDD.weaken[RVDContext](annotations).cmapPartitions { (ctx, it) => + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) it.map { case (v, vep) => @@ -351,7 +352,8 @@ object VEP { rv.setOffset(rvb.end()) rv - }}) + } + }) info(s"vep: annotated ${ annotations.count() } variants") diff --git a/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala b/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala index 0fdad405407f..4643e9992301 100644 --- a/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala +++ b/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala @@ -5,6 +5,8 @@ import is.hail.sparkextras._ import is.hail.utils.fatal import org.apache.spark.rdd.RDD +import scala.collection.generic.Growable + class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { val typ: OrderedRVDType = rvd.typ val (kType, _) = rvd.rowType.select(key) @@ -22,7 +24,7 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { def orderedJoin( right: KeyedOrderedRVD, joinType: String, - joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: OrderedRVDType ): OrderedRVD = { checkJoinCompatability(right) @@ -35,28 +37,34 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { this.rvd.constrainToOrderedPartitioner(this.typ, newPartitioner) val repartitionedRight = right.rvd.constrainToOrderedPartitioner(right.typ, newPartitioner) - val compute: (OrderedRVIterator, OrderedRVIterator) => Iterator[JoinedRegionValue] = + val compute: (OrderedRVIterator, OrderedRVIterator, Iterable[RegionValue] with Growable[RegionValue]) => 
Iterator[JoinedRegionValue] = (joinType: @unchecked) match { - case "inner" => _.innerJoin(_) - case "left" => _.leftJoin(_) - case "right" => _.rightJoin(_) - case "outer" => _.outerJoin(_) + case "inner" => _.innerJoin(_, _) + case "left" => _.leftJoin(_, _) + case "right" => _.rightJoin(_, _) + case "outer" => _.outerJoin(_, _) } - val joinedRDD = - repartitionedLeft.crdd.zipPartitions(repartitionedRight.crdd, true) { - (leftIt, rightIt) => - joiner(compute( - OrderedRVIterator(lTyp, leftIt), - OrderedRVIterator(rTyp, rightIt))) - } - new OrderedRVD(joinedType, newPartitioner, joinedRDD) + repartitionedLeft.zipPartitions( + joinedType, + newPartitioner, + repartitionedRight, + preservesPartitioning = true + ) { (ctx, leftIt, rightIt) => + val sideBuffer = ctx.freshContext.region + joiner( + ctx, + compute( + OrderedRVIterator(lTyp, leftIt), + OrderedRVIterator(rTyp, rightIt), + new RegionValueArrayBuffer(rTyp.rowType, sideBuffer))) + } } def orderedJoinDistinct( right: KeyedOrderedRVD, joinType: String, - joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: OrderedRVDType ): OrderedRVD = { checkJoinCompatability(right) @@ -72,15 +80,19 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { case "inner" => _.innerJoinDistinct(_) case "left" => _.leftJoinDistinct(_) } - val joinedRDD = - this.rvd.crdd.zipPartitions(repartitionedRight.crdd, true) { - (leftIt, rightIt) => - joiner(compute( - OrderedRVIterator(rekeyedLTyp, leftIt), - OrderedRVIterator(rekeyedRTyp, rightIt))) - } - new OrderedRVD(joinedType, newPartitioner, joinedRDD) + rvd.zipPartitions( + joinedType, + newPartitioner, + repartitionedRight, + preservesPartitioning = true + ) { (ctx, leftIt, rightIt) => + joiner( + ctx, + compute( + OrderedRVIterator(rekeyedLTyp, leftIt), + OrderedRVIterator(rekeyedRTyp, rightIt))) + } } def orderedZipJoin(right: KeyedOrderedRVD): ContextRDD[RVDContext, JoinedRegionValue] = { @@ -91,7 +103,10 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { val leftType = this.typ val rightType = right.typ - repartitionedLeft.crdd.zipPartitions(repartitionedRight.crdd, true){ (leftIt, rightIt) => + repartitionedLeft.zipPartitions( + repartitionedRight, + preservesPartitioning = true + ) { (_, leftIt, rightIt) => OrderedRVIterator(leftType, leftIt).zipJoin(OrderedRVIterator(rightType, rightIt)) } } diff --git a/src/main/scala/is/hail/rvd/OrderedRVD.scala b/src/main/scala/is/hail/rvd/OrderedRVD.scala index eca65d8f6396..1d0b08ad842a 100644 --- a/src/main/scala/is/hail/rvd/OrderedRVD.scala +++ b/src/main/scala/is/hail/rvd/OrderedRVD.scala @@ -1,5 +1,6 @@ package is.hail.rvd +import java.io.ByteArrayInputStream import java.util import is.hail.annotations._ @@ -8,6 +9,7 @@ import is.hail.expr.types._ import is.hail.io.CodecSpec import is.hail.sparkextras._ import is.hail.utils._ +import org.apache.commons.io.output.ByteArrayOutputStream import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.SparkContext @@ -23,38 +25,45 @@ class OrderedRVD( ) extends RVD { self => - def this( - typ: OrderedRVDType, - partitioner: OrderedRVDPartitioner, - rdd: RDD[RegionValue] - ) = this(typ, partitioner, ContextRDD.weaken[RVDContext](rdd)) - - val rdd = crdd.run + def boundary: OrderedRVD = OrderedRVD(typ, partitioner, crddBoundary) def rowType: TStruct = typ.rowType def updateType(newTyp: OrderedRVDType): OrderedRVD = - 
OrderedRVD(newTyp, partitioner, rdd) + OrderedRVD(newTyp, partitioner, crdd) def mapPreservesPartitioning(newTyp: OrderedRVDType)(f: (RegionValue) => RegionValue): OrderedRVD = OrderedRVD(newTyp, partitioner, - rdd.map(f)) + crdd.map(f)) def mapPartitionsWithIndexPreservesPartitioning(newTyp: OrderedRVDType)(f: (Int, Iterator[RegionValue]) => Iterator[RegionValue]): OrderedRVD = OrderedRVD(newTyp, partitioner, - rdd.mapPartitionsWithIndex(f)) + crdd.mapPartitionsWithIndex(f)) + + def mapPartitionsWithIndexPreservesPartitioning( + newTyp: OrderedRVDType, + f: (Int, RVDContext, Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = OrderedRVD( + newTyp, + partitioner, + crdd.cmapPartitionsWithIndex(f)) def mapPartitionsPreservesPartitioning(newTyp: OrderedRVDType)(f: (Iterator[RegionValue]) => Iterator[RegionValue]): OrderedRVD = OrderedRVD(newTyp, partitioner, - rdd.mapPartitions(f)) + crdd.mapPartitions(f)) + + def mapPartitionsPreservesPartitioning( + newTyp: OrderedRVDType, + f: (RVDContext, Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = OrderedRVD(newTyp, partitioner, crdd.cmapPartitions(f)) override def filter(p: (RegionValue) => Boolean): OrderedRVD = OrderedRVD(typ, partitioner, - rdd.filter(p)) + crdd.filter(p)) def sample(withReplacement: Boolean, p: Double, seed: Long): OrderedRVD = OrderedRVD(typ, partitioner, crdd.sample(withReplacement, p, seed)) @@ -105,7 +114,7 @@ class OrderedRVD( def orderedJoin( right: OrderedRVD, joinType: String, - joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: OrderedRVDType ): OrderedRVD = keyBy().orderedJoin(right.keyBy(), joinType, joiner, joinedType) @@ -113,7 +122,7 @@ class OrderedRVD( def orderedJoinDistinct( right: OrderedRVD, joinType: String, - joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: OrderedRVDType ): OrderedRVD = keyBy().orderedJoinDistinct(right.keyBy(), joinType, joiner, joinedType) @@ -126,15 +135,16 @@ class OrderedRVD( assert(partitioner == rdd2.partitioner) val localTyp = typ - OrderedRVD(typ, partitioner, - rdd.zipPartitions(rdd2.rdd) { case (it, it2) => - new Iterator[RegionValue] { - private val bit = it.buffered - private val bit2 = it2.buffered + zipPartitions(typ, partitioner, rdd2) { (ctx, it, it2) => + new Iterator[RegionValue] { + private val bit = it.buffered + private val bit2 = it2.buffered + private val rv = RegionValue() - def hasNext: Boolean = bit.hasNext || bit2.hasNext + def hasNext: Boolean = bit.hasNext || bit2.hasNext - def next(): RegionValue = { + def next(): RegionValue = { + val old = if (!bit.hasNext) bit2.next() else if (!bit2.hasNext) @@ -146,21 +156,25 @@ class OrderedRVD( else bit2.next() } - } + ctx.rvb.start(localTyp.rowType) + ctx.rvb.addRegionValue(localTyp.rowType, old) + rv.set(ctx.region, ctx.rvb.end()) + rv } - }) + } + } } def copy(typ: OrderedRVDType = typ, orderedPartitioner: OrderedRVDPartitioner = partitioner, - rdd: RDD[RegionValue] = rdd): OrderedRVD = { + rdd: ContextRDD[RVDContext, RegionValue] = crdd): OrderedRVD = { OrderedRVD(typ, orderedPartitioner, rdd) } def blockCoalesce(partitionEnds: Array[Int]): OrderedRVD = { assert(partitionEnds.last == partitioner.numPartitions - 1 && partitionEnds(0) >= 0) assert(partitionEnds.zip(partitionEnds.tail).forall { case (i, inext) => i < inext }) - OrderedRVD(typ, 
partitioner.coalesceRangeBounds(partitionEnds), new BlockedRDD(rdd, partitionEnds)) + OrderedRVD(typ, partitioner.coalesceRangeBounds(partitionEnds), crdd.blocked(partitionEnds)) } def naiveCoalesce(maxPartitions: Int): OrderedRVD = { @@ -176,11 +190,11 @@ class OrderedRVD( override def coalesce(maxPartitions: Int, shuffle: Boolean): OrderedRVD = { require(maxPartitions > 0, "cannot coalesce to nPartitions <= 0") - val n = rdd.partitions.length + val n = crdd.partitions.length if (!shuffle && maxPartitions >= n) return this if (shuffle) { - val shuffled = crdd.coalesce(maxPartitions, shuffle = true) + val shuffled = stably(_.shuffleCoalesce(maxPartitions)) val ranges = OrderedRVD.calculateKeyRanges( typ, OrderedRVD.getPartitionKeyInfo(typ, OrderedRVD.getKeys(typ, shuffled)), @@ -191,7 +205,7 @@ class OrderedRVD( shuffled) } else { - val partSize = rdd.context.runJob(rdd, getIteratorSize _) + val partSize = countPerPartition() log.info(s"partSize = ${ partSize.toSeq }") val partCumulativeSize = mapAccumulate[Array, Long](partSize, 0L)((s, acc) => (s + acc, s + acc)) @@ -225,7 +239,7 @@ class OrderedRVD( def filterIntervals(intervals: IntervalTree[_]): OrderedRVD = { val pkOrdering = typ.pkType.ordering - val intervalsBc = rdd.sparkContext.broadcast(intervals) + val intervalsBc = crdd.sparkContext.broadcast(intervals) val rowType = typ.rowType val pkRowFieldIdx = typ.pkRowFieldIdx @@ -258,7 +272,7 @@ class OrderedRVD( OrderedRVD.empty(sparkContext, typ) else { val sub = subsetPartitions(newPartitionIndices) - sub.copy(rdd = sub.rdd.filter(pred)) + sub.copy(rdd = sub.crdd.filter(pred)) } } @@ -268,9 +282,9 @@ class OrderedRVD( if (n == 0) return OrderedRVD.empty(sparkContext, typ) - val newRDD = rdd.head(n) + val newRDD = crdd.head(n) val newNParts = newRDD.getNumPartitions - assert(newNParts > 0) + assert(newNParts >= 0) val newRangeBounds = Array.range(0, newNParts).map(partitioner.rangeBounds) val newPartitioner = new OrderedRVDPartitioner(partitioner.partitionKey, @@ -289,16 +303,21 @@ class OrderedRVD( val localType = typ - val newRDD: RDD[RegionValue] = rdd.mapPartitions { it => - val region = Region() - val rvb = new RegionValueBuilder(region) - val outRV = RegionValue(region) - val buffer = new RegionValueArrayBuffer(localType.valueType) - val stepped: FlipbookIterator[FlipbookIterator[RegionValue]] = - OrderedRVIterator(localType, it).staircase + OrderedRVD(newTyp, partitioner, crdd.cmapPartitionsAndContext { (consumerCtx, useCtxes) => + val consumerRegion = consumerCtx.region + val rvb = consumerCtx.rvb + val outRV = RegionValue(consumerRegion) + + val bufferRegion = consumerCtx.freshContext.region + val buffer = new RegionValueArrayBuffer(localType.valueType, bufferRegion) + + val producerCtx = consumerCtx.freshContext + val producerRegion = producerCtx.region + val it = useCtxes.flatMap(_ (producerCtx)) + + val stepped = OrderedRVIterator(localType, it).staircase stepped.map { stepIt => - region.clear() buffer.clear() rvb.start(newRowType) rvb.startStruct() @@ -307,8 +326,10 @@ class OrderedRVD( rvb.addField(localType.rowType, stepIt.value, localType.kRowFieldIdx(i)) i += 1 } - for (rv <- stepIt) + for (rv <- stepIt) { buffer.appendSelect(localType.rowType, localType.valueFieldIdx, rv) + producerRegion.clear() + } rvb.startArray(buffer.length) for (rv <- buffer) rvb.addRegionValue(localType.valueType, rv) @@ -317,24 +338,21 @@ class OrderedRVD( outRV.setOffset(rvb.end()) outRV } - } - - OrderedRVD(newTyp, partitioner, newRDD) + }) } def distinctByKey(): OrderedRVD = { val 
localType = typ - val newRVD = rdd.mapPartitions { it => + mapPartitionsPreservesPartitioning(typ)(it => OrderedRVIterator(localType, it) .staircase .map(_.value) - } - OrderedRVD(typ, partitioner, newRVD) + ) } def subsetPartitions(keep: Array[Int]): OrderedRVD = { - require(keep.length <= rdd.partitions.length, "tried to subset to more partitions than exist") - require(keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < rdd.partitions.length)), + require(keep.length <= crdd.partitions.length, "tried to subset to more partitions than exist") + require(keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < crdd.partitions.length)), "values not sorted or not in range [0, number of partitions)") val newRangeBounds = Array.tabulate(keep.length) { i => @@ -350,7 +368,7 @@ class OrderedRVD( partitioner.kType, newRangeBounds) - OrderedRVD(typ, newPartitioner, rdd.subsetPartitions(keep)) + OrderedRVD(typ, newPartitioner, crdd.subsetPartitions(keep)) } override protected def rvdSpec(codecSpec: CodecSpec, partFiles: Array[String]): RVDSpec = @@ -362,25 +380,67 @@ class OrderedRVD( partitioner.rangeBounds, partitioner.rangeBoundsType)) + def zipPartitionsAndContext( + newTyp: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner, + that: OrderedRVD, + preservesPartitioning: Boolean = false + )(zipper: (RVDContext, RVDContext => Iterator[RegionValue], RVDContext => Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = OrderedRVD( + newTyp, + newPartitioner, + crdd.czipPartitionsAndContext(that.crdd, preservesPartitioning) { (ctx, lit, rit) => + zipper(ctx, ctx => lit.flatMap(_(ctx)), ctx => rit.flatMap(_(ctx))) + } + ) + def zipPartitionsPreservesPartitioning[T: ClassTag]( newTyp: OrderedRVDType, - that: RDD[T] + that: ContextRDD[RVDContext, T] )(zipper: (Iterator[RegionValue], Iterator[T]) => Iterator[RegionValue] - ): OrderedRVD = - OrderedRVD( - newTyp, - partitioner, - this.rdd.zipPartitions(that, preservesPartitioning = true)(zipper)) + ): OrderedRVD = OrderedRVD( + newTyp, + partitioner, + crdd.zipPartitions(that)(zipper)) + + def zipPartitions( + newTyp: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner, + that: OrderedRVD + )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = zipPartitions(newTyp, newPartitioner, that, false)(zipper) - def zipPartitionsPreservesPartitioning( + def zipPartitions( + newTyp: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner, + that: OrderedRVD, + preservesPartitioning: Boolean + )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = OrderedRVD( + newTyp, + newPartitioner, + boundary.crdd.czipPartitions(that.boundary.crdd, preservesPartitioning)(zipper)) + + def zipPartitions[T: ClassTag]( + that: OrderedRVD + )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[T] + ): ContextRDD[RVDContext, T] = zipPartitions(that, false)(zipper) + + def zipPartitions[T: ClassTag]( + that: OrderedRVD, + preservesPartitioning: Boolean + )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[T] + ): ContextRDD[RVDContext, T] = + boundary.crdd.czipPartitions(that.boundary.crdd, preservesPartitioning)(zipper) + + def zip( newTyp: OrderedRVDType, that: RVD - )(zipper: (Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue] - ): OrderedRVD = - OrderedRVD( - newTyp, - partitioner, - this.rdd.zipPartitions(that.rdd, preservesPartitioning = 
true)(zipper)) + )(zipper: (RVDContext, RegionValue, RegionValue) => RegionValue + ): OrderedRVD = OrderedRVD( + newTyp, + partitioner, + this.crdd.czip(that.crdd, preservesPartitioning = true)(zipper)) def writeRowsSplit( path: String, @@ -396,21 +456,29 @@ object OrderedRVD { def empty(sc: SparkContext, typ: OrderedRVDType): OrderedRVD = { OrderedRVD(typ, OrderedRVDPartitioner.empty(typ), - sc.emptyRDD[RegionValue]) + ContextRDD.empty[RVDContext, RegionValue](sc)) } /** * Precondition: the iterator it is PK-sorted. We lazily K-sort each block * of PK-equivalent elements. */ - def localKeySort(typ: OrderedRVDType, + def localKeySort( + consumerRegion: Region, + producerRegion: Region, + typ: OrderedRVDType, // it: Iterator[RegionValue[rowType]] - it: Iterator[RegionValue]): Iterator[RegionValue] = { + it: Iterator[RegionValue] + ): Iterator[RegionValue] = new Iterator[RegionValue] { private val bit = it.buffered private val q = new mutable.PriorityQueue[RegionValue]()(typ.kInRowOrd.reverse) + private val rvb = new RegionValueBuilder(consumerRegion) + + private val rv = RegionValue() + def hasNext: Boolean = bit.hasNext || q.nonEmpty def next(): RegionValue = { @@ -418,17 +486,17 @@ object OrderedRVD { do { val rv = bit.next() // FIXME ugh, no good answer here - q.enqueue(RegionValue( - rv.region.copy(), - rv.offset)) + q.enqueue(rv.copy()) + producerRegion.clear() } while (bit.hasNext && typ.pkInRowOrd.compare(q.head, bit.head) == 0) } - val rv = q.dequeue() + rvb.start(typ.rowType) + rvb.addRegionValue(typ.rowType, q.dequeue()) + rv.set(consumerRegion, rvb.end()) rv } } - } // getKeys: RDD[RegionValue[kType]] def getKeys( @@ -446,14 +514,6 @@ object OrderedRVD { } } - // FIXME: delete when I've removed all need for RDDs - def getKeys( - typ: OrderedRVDType, - rdd: RDD[RegionValue] - ): RDD[RegionValue] = getKeys( - typ, - ContextRDD.weaken[RVDContext](rdd)).run - def getPartitionKeyInfo( typ: OrderedRVDType, // keys: RDD[kType] @@ -471,25 +531,18 @@ object OrderedRVD { val localType = typ - val pkis = keys.mapPartitionsWithIndex { case (i, it) => - if (it.hasNext) + val pkis = keys.cmapPartitionsWithIndex { (i, ctx, it) => + val out = if (it.hasNext) Iterator(OrderedRVPartitionInfo(localType, samplesPerPartition, i, it, partitionSeed(i))) else Iterator() + ctx.region.clear() + out }.collect() pkis.sortBy(_.min)(typ.pkType.ordering.toOrdering) } - // FIXME: delete when I've removed all need for RDDs - def getPartitionKeyInfo[C]( - typ: OrderedRVDType, - // keys: RDD[kType] - keys: RDD[RegionValue] - ): Array[OrderedRVPartitionInfo] = getPartitionKeyInfo( - typ, - ContextRDD.weaken[RVDContext](keys)) - def coerce( typ: OrderedRVDType, rvd: RVD @@ -498,7 +551,7 @@ object OrderedRVD { def coerce( typ: OrderedRVDType, rvd: RVD, - fastKeys: RDD[RegionValue] + fastKeys: ContextRDD[RVDContext, RegionValue] ): OrderedRVD = coerce(typ, rvd, Some(fastKeys), None) def coerce( @@ -510,9 +563,9 @@ object OrderedRVD { def coerce( typ: OrderedRVDType, rvd: RVD, - fastKeys: Option[RDD[RegionValue]], + fastKeys: Option[ContextRDD[RVDContext, RegionValue]], hintPartitioner: Option[OrderedRVDPartitioner] - ): OrderedRVD = coerce(typ, rvd.rdd, fastKeys, hintPartitioner) + ): OrderedRVD = coerce(typ, rvd.crdd, fastKeys, hintPartitioner) def coerce( typ: OrderedRVDType, @@ -536,23 +589,47 @@ object OrderedRVD { rdd: RDD[RegionValue], fastKeys: RDD[RegionValue], hintPartitioner: OrderedRVDPartitioner - ): OrderedRVD = coerce(typ, rdd, Some(fastKeys), Some(hintPartitioner)) + ): OrderedRVD = coerce( + typ, + 
rdd, + Some(fastKeys), + Some(hintPartitioner)) def coerce( typ: OrderedRVDType, - // rdd: RDD[RegionValue[rowType]] rdd: RDD[RegionValue], - // fastKeys: Option[RDD[RegionValue[kType]]] fastKeys: Option[RDD[RegionValue]], hintPartitioner: Option[OrderedRVDPartitioner] - ): OrderedRVD = { - val sc = rdd.sparkContext + ): OrderedRVD = coerce( + typ, + ContextRDD.weaken[RVDContext](rdd), + fastKeys.map(ContextRDD.weaken[RVDContext](_)), + hintPartitioner) - if (rdd.partitions.isEmpty) - return empty(sc, typ) + def coerce( + typ: OrderedRVDType, + crdd: ContextRDD[RVDContext, RegionValue] + ): OrderedRVD = coerce(typ, crdd, None, None) + + def coerce( + typ: OrderedRVDType, + crdd: ContextRDD[RVDContext, RegionValue], + fastKeys: ContextRDD[RVDContext, RegionValue] + ): OrderedRVD = coerce(typ, crdd, Some(fastKeys), None) + + def coerce( + typ: OrderedRVDType, + crdd: ContextRDD[RVDContext, RegionValue], + fastKeys: Option[ContextRDD[RVDContext, RegionValue]], + hintPartitioner: Option[OrderedRVDPartitioner] + ): OrderedRVD = { + val sc = crdd.sparkContext + + if (crdd.partitions.isEmpty) + return empty(sc, typ) // keys: RDD[RegionValue[kType]] - val keys = fastKeys.getOrElse(getKeys(typ, rdd)) + val keys = fastKeys.getOrElse(getKeys(typ, crdd)) val pkis = getPartitionKeyInfo(typ, keys) @@ -574,8 +651,9 @@ object OrderedRVD { typ.kType, rangeBounds) - val reorderedPartitionsRDD = rdd.reorderPartitions(pkis.map(_.partitionIndex)) - val adjustedRDD = new AdjustedPartitionsRDD(reorderedPartitionsRDD, adjustedPartitions) + val adjustedRDD = crdd + .reorderPartitions(pkis.map(_.partitionIndex)) + .adjustPartitions(adjustedPartitions) (adjSortedness: @unchecked) match { case OrderedRVPartitionInfo.KSORTED => info("Coerced sorted dataset") @@ -587,19 +665,20 @@ object OrderedRVD { info("Coerced almost-sorted dataset") OrderedRVD(typ, partitioner, - adjustedRDD.mapPartitions { it => - localKeySort(typ, it) + adjustedRDD.cmapPartitionsAndContext { (consumerCtx, it) => + val producerCtx = consumerCtx.freshContext + localKeySort(consumerCtx.region, producerCtx.region, typ, it.flatMap(_(producerCtx))) }) } } else { info("Ordering unsorted dataset with network shuffle") hintPartitioner - .filter(_.numPartitions >= rdd.partitions.length) - .map(adjustBoundsAndShuffle(typ, _, rdd)) + .filter(_.numPartitions >= crdd.partitions.length) + .map(adjustBoundsAndShuffle(typ, _, crdd)) .getOrElse { - val ranges = calculateKeyRanges(typ, pkis, rdd.getNumPartitions) + val ranges = calculateKeyRanges(typ, pkis, crdd.getNumPartitions) val p = new OrderedRVDPartitioner(typ.partitionKey, typ.kType, ranges) - shuffle(typ, p, rdd) + shuffle(typ, p, crdd) } } } @@ -641,21 +720,31 @@ object OrderedRVD { OrderedRVDPartitioner.makeRangeBoundIntervals(typ.pkType, partitionEdges) } - def adjustBoundsAndShuffle(typ: OrderedRVDType, + def adjustBoundsAndShuffle( + typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, - rdd: RDD[RegionValue]): OrderedRVD = { + rvd: RVD + ): OrderedRVD = { + assert(typ.rowType == rvd.rowType) + adjustBoundsAndShuffle(typ, partitioner, rvd.crdd) + } + private[this] def adjustBoundsAndShuffle( + typ: OrderedRVDType, + partitioner: OrderedRVDPartitioner, + crdd: ContextRDD[RVDContext, RegionValue] + ): OrderedRVD = { val pkType = partitioner.pkType val pkOrd = pkType.ordering.toOrdering - val pkis = getPartitionKeyInfo(typ, OrderedRVD.getKeys(typ, rdd)) + val pkis = getPartitionKeyInfo(typ, getKeys(typ, crdd)) if (pkis.isEmpty) - return OrderedRVD(typ, partitioner, rdd) + return OrderedRVD(typ, 
partitioner, crdd) val min = pkis.map(_.min).min(pkOrd) val max = pkis.map(_.max).max(pkOrd) - shuffle(typ, partitioner.enlargeToRange(Interval(min, max, true, true)), rdd) + shuffle(typ, partitioner.enlargeToRange(Interval(min, max, true, true)), crdd) } def shuffle( @@ -664,12 +753,6 @@ object OrderedRVD { rvd: RVD ): OrderedRVD = shuffle(typ, partitioner, rvd.crdd) - def shuffle(typ: OrderedRVDType, - partitioner: OrderedRVDPartitioner, - rdd: RDD[RegionValue] - ): OrderedRVD = - shuffle(typ, partitioner, ContextRDD.weaken[RVDContext](rdd)) - def shuffle( typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, @@ -679,19 +762,23 @@ object OrderedRVD { val partBc = partitioner.broadcast(crdd.sparkContext) OrderedRVD(typ, partitioner, - crdd.mapPartitions { it => - val wrv = WritableRegionValue(typ.rowType) + crdd.cmapPartitions { (ctx, it) => + val enc = RVD.wireCodec.buildEncoder(localType.rowType) + it.map { rv => val wkrv = WritableRegionValue(typ.kType) - it.map { rv => - wrv.set(rv) - wkrv.setSelect(localType.rowType, localType.kRowFieldIdx, rv) - (wkrv.value, wrv.value) - } + wkrv.setSelect(localType.rowType, localType.kRowFieldIdx, rv) + val bytes = + RVD.regionValueToBytes(enc, ctx)(rv) + (wkrv.value, bytes) + } }.shuffle(partitioner.sparkPartitioner(crdd.sparkContext), typ.kOrd) - .mapPartitionsWithIndex { case (i, it) => - it.map { case (k, v) => + .cmapPartitionsWithIndex { case (i, ctx, it) => + val dec = RVD.wireCodec.buildDecoder(localType.rowType) + val region = ctx.region + val rv = RegionValue(region) + it.map { case (k, bytes) => assert(partBc.value.getPartition(k) == i) - v + RVD.bytesToRegionValue(dec, region, rv)(bytes) } }) } @@ -784,13 +871,21 @@ object OrderedRVD { typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, rvd: RVD - ): OrderedRVD = apply(typ, partitioner, rvd.rdd) + ): OrderedRVD = apply(typ, partitioner, rvd.crdd) def apply( typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, - rdd: RDD[RegionValue] - ): OrderedRVD = apply(typ, partitioner, ContextRDD.weaken[RVDContext](rdd)) + codec: CodecSpec, + rdd: RDD[Array[Byte]] + ): OrderedRVD = apply( + typ, + partitioner, + ContextRDD.weaken[RVDContext](rdd).cmapPartitions { (ctx, it) => + val dec = codec.buildDecoder(typ.rowType) + val rv = RegionValue() + it.map(RVD.bytesToRegionValue(dec, ctx.region, rv)) + }) def apply( typ: OrderedRVDType, diff --git a/src/main/scala/is/hail/rvd/RVD.scala b/src/main/scala/is/hail/rvd/RVD.scala index e275fae5a370..c7a212da0c0d 100644 --- a/src/main/scala/is/hail/rvd/RVD.scala +++ b/src/main/scala/is/hail/rvd/RVD.scala @@ -1,6 +1,7 @@ package is.hail.rvd -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import is.hail.expr.types.Type +import java.io.{ ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream } import is.hail.HailContext import is.hail.annotations._ @@ -43,9 +44,16 @@ object RVDSpec { val hConf = hc.hadoopConf partFiles.flatMap { p => val f = path + "/parts/" + p - val in = hConf.unsafeReader(f) - HailContext.readRowsPartition(codecSpec.buildDecoder(rowType))(0, in) - .map(rv => SafeRow(rowType, rv.region, rv.offset)) + hConf.readFile(f) { in => + using(RVDContext.default) { ctx => + HailContext.readRowsPartition(codecSpec.buildDecoder(rowType))(ctx, in) + .map { rv => + val r = SafeRow(rowType, rv.region, rv.offset) + ctx.region.clear() + r + } + } + } }.toFastIndexedSeq } } @@ -102,19 +110,19 @@ object RVD { val hConf = hc.hadoopConf hConf.mkDir(path + "/parts") - val os = hConf.unsafeWriter(path + "/parts/part-0") - 
val part0Count = Region.scoped { region => - val rvb = new RegionValueBuilder(region) - val rv = RegionValue(region) - RichContextRDDRegionValue.writeRowsPartition(codecSpec.buildEncoder(rowType))(0, - rows.iterator.map { a => - region.clear() - rvb.start(rowType) - rvb.addAnnotation(rowType, a) - rv.setOffset(rvb.end()) - rv - }, os) - } + val part0Count = + hConf.writeFile(path + "/parts/part-0") { os => + using(RVDContext.default) { ctx => + val rvb = ctx.rvb + val region = ctx.region + RichContextRDDRegionValue.writeRowsPartition(codecSpec.buildEncoder(rowType))(ctx, + rows.iterator.map { a => + rvb.start(rowType) + rvb.addAnnotation(rowType, a) + RegionValue(region, rvb.end()) + }, os) + } + } val spec = UnpartitionedRVDSpec(rowType, codecSpec, Array("part-0")) spec.write(hConf, path) @@ -129,15 +137,130 @@ object RVD { new UnpartitionedRVD(first.rowType, ContextRDD.union(sc, rvds.map(_.crdd))) } + + val memoryCodec = CodecSpec.defaultUncompressed + + val wireCodec = memoryCodec + + def regionValueToBytes( + makeEnc: OutputStream => Encoder, + ctx: RVDContext + )(rv: RegionValue + ): Array[Byte] = + using(new ByteArrayOutputStream()) { baos => + using(makeEnc(baos)) { enc => + enc.writeRegionValue(rv.region, rv.offset) + enc.flush() + ctx.region.clear() + baos.toByteArray + } + } + + def bytesToRegionValue( + makeDec: InputStream => Decoder, + r: Region, + carrierRv: RegionValue + )(bytes: Array[Byte] + ): RegionValue = + using(new ByteArrayInputStream(bytes)) { bais => + using(makeDec(bais)) { dec => + carrierRv.setOffset(dec.readRegionValue(r)) + carrierRv + } + } } trait RVD { self => + def rowType: TStruct def crdd: ContextRDD[RVDContext, RegionValue] - def rdd: RDD[RegionValue] + private[rvd] def stabilize( + unstable: ContextRDD[RVDContext, RegionValue], + codec: CodecSpec = RVD.memoryCodec + ): ContextRDD[RVDContext, Array[Byte]] = { + val enc = codec.buildEncoder(rowType) + unstable.cmapPartitions { (ctx, it) => + it.map(RVD.regionValueToBytes(enc, ctx)) + } + } + + private[rvd] def destabilize( + stable: ContextRDD[RVDContext, Array[Byte]], + codec: CodecSpec = RVD.memoryCodec + ): ContextRDD[RVDContext, RegionValue] = { + val dec = codec.buildDecoder(rowType) + stable.cmapPartitions { (ctx, it) => + val rv = RegionValue(ctx.region) + it.map(RVD.bytesToRegionValue(dec, ctx.region, rv)) + } + } + + private[rvd] def stably( + f: ContextRDD[RVDContext, Array[Byte]] => ContextRDD[RVDContext, Array[Byte]] + ): ContextRDD[RVDContext, RegionValue] = stably(crdd, f) + + private[rvd] def stably( + unstable: ContextRDD[RVDContext, RegionValue], + f: ContextRDD[RVDContext, Array[Byte]] => ContextRDD[RVDContext, Array[Byte]] + ): ContextRDD[RVDContext, RegionValue] = destabilize(f(stabilize(unstable))) + + private[rvd] def crddBoundary: ContextRDD[RVDContext, RegionValue] = + crdd.cmapPartitionsAndContext { (consumerCtx, part) => + val producerCtx = consumerCtx.freshContext + val it = part.flatMap(_ (producerCtx)) + new Iterator[RegionValue]() { + private[this] var cleared: Boolean = false + + def hasNext = { + if (!cleared) { + cleared = true + producerCtx.region.clear() + } + it.hasNext + } + + def next = { + if (!cleared) { + producerCtx.region.clear() + } + cleared = false + it.next + } + } + } + + def boundary: RVD + + def encodedRDD(codec: CodecSpec): RDD[Array[Byte]] = + stabilize(crdd, codec).run + + def head(n: Long): RVD + + final def takeAsBytes(n: Int, codec: CodecSpec): Array[Array[Byte]] = + head(n).encodedRDD(codec).collect() + + final def take(n: Int, codec: CodecSpec): 
Array[Row] = { + val dec = codec.buildDecoder(rowType) + val encodedData = takeAsBytes(n, codec) + Region.scoped { region => + encodedData.iterator + .map(RVD.bytesToRegionValue(dec, region, RegionValue(region))) + .map { rv => + val row = SafeRow(rowType, rv) + region.clear() + row + }.toArray + } + } + + def forall(p: RegionValue => Boolean): Boolean = + crdd.map(p).run.forall(x => x) + + def exists(p: RegionValue => Boolean): Boolean = + crdd.map(p).run.exists(x => x) def sparkContext: SparkContext = crdd.sparkContext @@ -154,14 +277,36 @@ trait RVD { it.map { rv => f(c, rv) } }) - def map[T](f: (RegionValue) => T)(implicit tct: ClassTag[T]): RDD[T] = rdd.map(f) - def mapPartitions(newRowType: TStruct)(f: (Iterator[RegionValue]) => Iterator[RegionValue]): RVD = new UnpartitionedRVD(newRowType, crdd.mapPartitions(f)) - def mapPartitionsWithIndex[T](f: (Int, Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = rdd.mapPartitionsWithIndex(f) + def mapPartitions(newRowType: TStruct, f: (RVDContext, Iterator[RegionValue]) => Iterator[RegionValue]): RVD = + new UnpartitionedRVD(newRowType, crdd.cmapPartitions(f)) + + def find(codec: CodecSpec, p: (RegionValue) => Boolean): Option[Array[Byte]] = + filter(p).takeAsBytes(1, codec).headOption + + def find(region: Region)(p: (RegionValue) => Boolean): Option[RegionValue] = + find(RVD.wireCodec, p).map( + RVD.bytesToRegionValue(RVD.wireCodec.buildDecoder(rowType), region, RegionValue(region))) + + // Only use on CRDD's whose T is not dependent on the context + private[rvd] def clearingRun[T: ClassTag]( + crdd: ContextRDD[RVDContext, T] + ): RDD[T] = crdd.cmap { (ctx, v) => + ctx.region.clear() + v + }.run - def mapPartitions[T](f: (Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = rdd.mapPartitions(f) + def map[T](f: (RegionValue) => T)(implicit tct: ClassTag[T]): RDD[T] = clearingRun(crdd.map(f)) + + def mapPartitionsWithIndex[T](f: (Int, Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = clearingRun(crdd.mapPartitionsWithIndex(f)) + + def mapPartitionsWithIndex[T: ClassTag]( + f: (Int, RVDContext, Iterator[RegionValue]) => Iterator[T] + ): RDD[T] = clearingRun(crdd.cmapPartitionsWithIndex(f)) + + def mapPartitions[T](f: (Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = clearingRun(crdd.mapPartitions(f)) def constrainToOrderedPartitioner( ordType: OrderedRVDType, @@ -171,7 +316,7 @@ trait RVD { def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, RegionValue) => U, combOp: (U, U) => U, - depth: Int = treeAggDepth(HailContext.get, rdd.getNumPartitions) + depth: Int = treeAggDepth(HailContext.get, crdd.getNumPartitions) ): U = crdd.treeAggregate(zeroValue, seqOp, combOp, depth) def aggregate[U: ClassTag]( @@ -180,50 +325,45 @@ trait RVD { combOp: (U, U) => U ): U = crdd.aggregate(zeroValue, seqOp, combOp) - def count(): Long = rdd.count() - - def countPerPartition(): Array[Long] = rdd.countPerPartition() + def count(): Long = + crdd.cmapPartitions { (ctx, it) => + var count = 0L + it.foreach { rv => + count += 1 + ctx.region.clear() + } + Iterator.single(count) + }.run.fold(0L)(_ + _) + + def countPerPartition(): Array[Long] = + crdd.cmapPartitions { (ctx, it) => + var count = 0L + it.foreach { rv => + count += 1 + ctx.region.clear() + } + Iterator.single(count) + }.collect() protected def persistRVRDD(level: StorageLevel): PersistedRVRDD = { val localRowType = rowType - val persistCodec = - new PackCodecSpec( - new BlockingBufferSpec(32 * 1024, - 
new StreamBlockBufferSpec)) - - val makeEnc = persistCodec.buildEncoder(localRowType) + val makeEnc = RVD.memoryCodec.buildEncoder(localRowType) - val makeDec = persistCodec.buildDecoder(localRowType) + val makeDec = RVD.memoryCodec.buildDecoder(localRowType) // copy, persist region values - val persistedRDD = rdd.mapPartitions { it => - it.map { rv => - using(new ByteArrayOutputStream()) { baos => - using(makeEnc(baos)) { enc => - enc.writeRegionValue(rv.region, rv.offset) - enc.flush() - baos.toByteArray - } - } - } - } + val persistedRDD = crdd.cmapPartitions { (ctx, it) => + it.map(RVD.regionValueToBytes(makeEnc, ctx)) + } .run .persist(level) PersistedRVRDD(persistedRDD, ContextRDD.weaken[RVDContext](persistedRDD) - .mapPartitions { it => - val region = Region() - val rv2 = RegionValue(region) - it.map { bytes => - region.clear() - using(new ByteArrayInputStream(bytes)) { bais => - using(makeDec(bais)) { dec => - rv2.setOffset(dec.readRegionValue(region)) - rv2 - } - } - } + .cmapPartitions { (ctx, it) => + val region = ctx.region + val rv = RegionValue(region) + it.map(RVD.bytesToRegionValue(makeDec, region, rv)) }) } @@ -249,7 +389,7 @@ trait RVD { def toRows: RDD[Row] = { val localRowType = rowType - crdd.map { rv => SafeRow(localRowType, rv.region, rv.offset) }.run + map(rv => SafeRow(localRowType, rv.region, rv.offset)) } def toUnpartitionedRVD: UnpartitionedRVD diff --git a/src/main/scala/is/hail/rvd/RVDContext.scala b/src/main/scala/is/hail/rvd/RVDContext.scala index f26f41ace514..a9003393cc32 100644 --- a/src/main/scala/is/hail/rvd/RVDContext.scala +++ b/src/main/scala/is/hail/rvd/RVDContext.scala @@ -28,10 +28,6 @@ class RVDContext(r: Region) extends AutoCloseable { private[this] val theRvb = new RegionValueBuilder(r) def rvb = theRvb - def reset(): Unit = { - r.clear() - } - // frees the memory associated with this context def close(): Unit = { var e: Exception = null diff --git a/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala b/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala index 2eb8926d23aa..4c1d3eb25e48 100644 --- a/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala +++ b/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala @@ -17,10 +17,10 @@ object UnpartitionedRVD { class UnpartitionedRVD(val rowType: TStruct, val crdd: ContextRDD[RVDContext, RegionValue]) extends RVD { self => - def this(rowType: TStruct, rdd: RDD[RegionValue]) = - this(rowType, ContextRDD.weaken[RVDContext](rdd)) + def boundary = new UnpartitionedRVD(rowType, crddBoundary) - val rdd = crdd.run + def head(n: Long): UnpartitionedRVD = + new UnpartitionedRVD(rowType, crdd.head(n)) def filter(f: (RegionValue) => Boolean): UnpartitionedRVD = new UnpartitionedRVD(rowType, crdd.filter(f)) @@ -52,7 +52,12 @@ class UnpartitionedRVD(val rowType: TStruct, val crdd: ContextRDD[RVDContext, Re UnpartitionedRVDSpec(rowType, codecSpec, partFiles) def coalesce(maxPartitions: Int, shuffle: Boolean): UnpartitionedRVD = - new UnpartitionedRVD(rowType, crdd.coalesce(maxPartitions, shuffle = shuffle)) + new UnpartitionedRVD( + rowType, + if (shuffle) + stably(_.shuffleCoalesce(maxPartitions)) + else + crdd.noShuffleCoalesce(maxPartitions)) def constrainToOrderedPartitioner( ordType: OrderedRVDType, diff --git a/src/main/scala/is/hail/sparkextras/BlockedRDD.scala b/src/main/scala/is/hail/sparkextras/BlockedRDD.scala index 2c1f2408dd5a..5405c3a1d6f1 100644 --- a/src/main/scala/is/hail/sparkextras/BlockedRDD.scala +++ b/src/main/scala/is/hail/sparkextras/BlockedRDD.scala @@ -66,4 +66,4 @@ class BlockedRDD[T](@transient 
var prev: RDD[T], .keys .toFastSeq } -} \ No newline at end of file +} diff --git a/src/main/scala/is/hail/sparkextras/ContextRDD.scala b/src/main/scala/is/hail/sparkextras/ContextRDD.scala index 6e9c84a760ce..d7ea9e3fe06e 100644 --- a/src/main/scala/is/hail/sparkextras/ContextRDD.scala +++ b/src/main/scala/is/hail/sparkextras/ContextRDD.scala @@ -151,8 +151,10 @@ class ContextRDD[C <: AutoCloseable, T: ClassTag]( cmapPartitions((_, part) => f(part), preservesPartitioning) def mapPartitionsWithIndex[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U] - ): ContextRDD[C, U] = cmapPartitionsWithIndex((i, _, part) => f(i, part)) + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false + ): ContextRDD[C, U] = + cmapPartitionsWithIndex((i, _, part) => f(i, part), preservesPartitioning) // FIXME: delete when region values are non-serializable def aggregate[U: ClassTag]( @@ -373,13 +375,11 @@ class ContextRDD[C <: AutoCloseable, T: ClassTag]( onRDD(rdd => new AdjustedPartitionsRDD(rdd, contextIgnorantAdjustments)) } - def coalesce(numPartitions: Int, shuffle: Boolean = false): ContextRDD[C, T] = - // NB: the run marks the end of a context lifetime, the next one starts - // after the shuffle - if (shuffle) - ContextRDD.weaken(run.coalesce(numPartitions, shuffle), mkc) - else - onRDD(_.coalesce(numPartitions, shuffle)) + def noShuffleCoalesce(numPartitions: Int): ContextRDD[C, T] = + onRDD(_.coalesce(numPartitions, false)) + + def shuffleCoalesce(numPartitions: Int): ContextRDD[C, T] = + ContextRDD.weaken(run.coalesce(numPartitions, true), mkc) def sample( withReplacement: Boolean, @@ -401,15 +401,73 @@ class ContextRDD[C <: AutoCloseable, T: ClassTag]( }, preservesPartitioning = true) } - def partitionSizes: Array[Long] = - // safe because we don't actually touch the offsets, we just count how many - // there are - sparkContext.runJob(run, getIteratorSize _) + def head(n: Long): ContextRDD[C, T] = { + require(n >= 0) + + val sc = sparkContext + val nPartitions = getNumPartitions + + var partScanned = 0 + var nLeft = n + var idxLast = -1 + var nLast = 0L + var numPartsToTry = 1L + + while (nLeft > 0 && partScanned < nPartitions) { + val nSeen = n - nLeft + + if (partScanned > 0) { + // If we didn't find any rows after the previous iteration, quadruple and retry. + // Otherwise, interpolate the number of partitions we need to try, but overestimate + // it by 50%. We also cap the estimation in the end. 
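+        // Worked example (illustrative numbers, not taken from the PR): with n = 1000 and
+        // nSeen = 100 rows found after scanning partScanned = 2 partitions, the interpolation
+        // below gives max((1.5 * 1000 * 2 / 100).toInt - 2, 1) = 28, and the 4x cap then
+        // reduces that to min(28, 2 * 4) = 8 partitions to try on the next pass.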
+ if (nSeen == 0) { + numPartsToTry = partScanned * 4 + } else { + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * n * partScanned / nSeen).toInt - partScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partScanned * 4) + } + } + + val p = partScanned.until(math.min(partScanned + numPartsToTry, nPartitions).toInt) + val counts = runJob(getIteratorSizeWithMaxN(nLeft), p) + + p.zip(counts).foreach { case (idx, c) => + if (nLeft > 0) { + idxLast = idx + nLast = if (c < nLeft) c else nLeft + nLeft -= nLast + } + } + + partScanned += p.size + } + + mapPartitionsWithIndex({ case (i, it) => + if (i == idxLast) + it.take(nLast.toInt) + else + it + }, preservesPartitioning = true) + .subsetPartitions((0 to idxLast).toArray) + } + + def runJob[U: ClassTag](f: Iterator[T] => U, partitions: Seq[Int]): Array[U] = + sparkContext.runJob( + rdd, + { (it: Iterator[ElementType]) => using(mkc())(c => f(it.flatMap(_(c)))) }, + partitions) + + def blocked(partitionEnds: Array[Int]): ContextRDD[C, T] = + new ContextRDD(new BlockedRDD(rdd, partitionEnds), mkc) def sparkContext: SparkContext = rdd.sparkContext def getNumPartitions: Int = rdd.getNumPartitions + def preferredLocations(partition: Partition): Seq[String] = + rdd.preferredLocations(partition) + private[this] def clean[T <: AnyRef](value: T): T = ExposedUtils.clean(sparkContext, value) diff --git a/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD.scala b/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD.scala index b6eb9b9f8ec8..f8d504ea2b47 100644 --- a/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD.scala +++ b/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD.scala @@ -19,8 +19,8 @@ object RepartitionedOrderedRDD { new RepartitionedOrderedRDD( prev.crdd, prev.typ, - prev.partitioner.broadcast(prev.rdd.sparkContext), - newPartitioner.broadcast(prev.rdd.sparkContext)) + prev.partitioner.broadcast(prev.crdd.sparkContext), + newPartitioner.broadcast(prev.crdd.sparkContext)) } } diff --git a/src/main/scala/is/hail/stats/BaldingNicholsModel.scala b/src/main/scala/is/hail/stats/BaldingNicholsModel.scala index 31454c08e40b..c6cfc7f45f15 100644 --- a/src/main/scala/is/hail/stats/BaldingNicholsModel.scala +++ b/src/main/scala/is/hail/stats/BaldingNicholsModel.scala @@ -5,7 +5,8 @@ import breeze.stats.distributions._ import is.hail.HailContext import is.hail.annotations._ import is.hail.expr.types._ -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant.{Call2, MatrixTable, ReferenceGenome} import org.apache.commons.math3.random.JDKRandomGenerator @@ -132,10 +133,9 @@ object BaldingNicholsModel { val rvType = matrixType.rvRowType - val rdd = sc.parallelize((0 until M).view.map(m => (m, Rand.randInt.draw())), nPartitions) - .mapPartitions { it => - - val region = Region() + val rdd = ContextRDD.weaken[RVDContext](sc.parallelize((0 until M).view.map(m => (m, Rand.randInt.draw())), nPartitions)) + .cmapPartitions { (ctx, it) => + val region = ctx.region val rv = RegionValue(region) val rvb = new RegionValueBuilder(region) @@ -152,7 +152,6 @@ object BaldingNicholsModel { .draw() }) - region.clear() rvb.start(rvType) rvb.startStruct() diff --git a/src/main/scala/is/hail/table/Table.scala b/src/main/scala/is/hail/table/Table.scala index 2ee38fb502b1..bb488e03e6e1 100644 --- a/src/main/scala/is/hail/table/Table.scala +++ b/src/main/scala/is/hail/table/Table.scala @@ -176,7 
+176,7 @@ object Table { globals: Annotation, sort: Boolean ): Table = { - val crdd2 = crdd.mapPartitions(_.toRegionValueIterator(signature)) + val crdd2 = crdd.cmapPartitions((ctx, it) => it.toRegionValueIterator(ctx.region, signature)) new Table(hc, TableLiteral( TableValue( TableType(signature, None, globalSignature), @@ -211,9 +211,9 @@ class Table(val hc: HailContext, val tir: TableIR) { hc: HailContext, crdd: ContextRDD[RVDContext, RegionValue], signature: TStruct, - key: Option[IndexedSeq[String]], - globalSignature: TStruct, - globals: Row + key: Option[IndexedSeq[String]] = None, + globalSignature: TStruct = TStruct.empty(), + globals: Row = Row.empty ) = this(hc, TableLiteral( TableValue( @@ -221,20 +221,6 @@ class Table(val hc: HailContext, val tir: TableIR) { BroadcastRow(globals, globalSignature, hc.sc), new UnpartitionedRVD(signature, crdd)))) - def this(hc: HailContext, - rdd: RDD[RegionValue], - signature: TStruct, - key: Option[IndexedSeq[String]], - globalSignature: TStruct, - globals: Row - ) = this( - hc, - ContextRDD.weaken[RVDContext](rdd), - signature, - key, - globalSignature, - globals) - def typ: TableType = tir.typ private def useIR(ast: AST): Boolean = { @@ -251,19 +237,6 @@ class Table(val hc: HailContext, val tir: TableIR) { opt.execute(hc) } - def this( - hc: HailContext, - rdd: RDD[RegionValue], - signature: TStruct, - key: Option[IndexedSeq[String]] - ) = this(hc, rdd, signature, key, TStruct.empty(), Row.empty) - - def this( - hc: HailContext, - rdd: RDD[RegionValue], - signature: TStruct - ) = this(hc, rdd, signature, None) - lazy val TableValue(ktType, globals, rvd) = value val TableType(signature, key, globalSignature) = tir.typ @@ -545,7 +518,7 @@ class Table(val hc: HailContext, val tir: TableIR) { def head(n: Long): Table = { if (n < 0) fatal(s"n must be non-negative! 
Found `$n'.") - copy(rdd = rdd.head(n)) + copy2(rvd = rvd.head(n)) } def keyBy(key: String*): Table = keyBy(key.toArray, key.toArray) @@ -587,9 +560,17 @@ class Table(val hc: HailContext, val tir: TableIR) { else if (ordered.typ.key.length <= keys.length) { val localSortType = new OrderedRVDType(ordered.typ.key, keys, signature) val newType = new OrderedRVDType(ordered.typ.partitionKey, keys, signature) - ordered.mapPartitionsPreservesPartitioning(newType) { it => - OrderedRVD.localKeySort(localSortType, it) - } + OrderedRVD( + newType, + ordered.partitioner, + ordered.crdd.cmapPartitionsAndContext { (consumerCtx, it) => + val producerCtx = consumerCtx.freshContext + OrderedRVD.localKeySort( + consumerCtx.region, + producerCtx.region, + localSortType, + it.flatMap(_(producerCtx))) + }) } else resort } else resort case _: UnpartitionedRVD => @@ -785,16 +766,15 @@ class Table(val hc: HailContext, val tir: TableIR) { val newRVType = matrixType.rvRowType val orderedRKStruct = matrixType.rowKeyStruct - val newRVD = ordered.mapPartitionsPreservesPartitioning(matrixType.orvdType) { it => - val region = Region() - val rvb = new RegionValueBuilder(region) + val newRVD = ordered.boundary.mapPartitionsPreservesPartitioning(matrixType.orvdType, { (ctx, it) => + val region = ctx.region + val rvb = ctx.rvb val outRV = RegionValue(region) OrderedRVIterator( new OrderedRVDType(partitionKeys, rowKeys, rowEntryStruct), it ).staircase.map { rowIt => - region.clear() rvb.start(newRVType) rvb.startStruct() var i = 0 @@ -828,7 +808,7 @@ class Table(val hc: HailContext, val tir: TableIR) { outRV.setOffset(rvb.end()) outRV } - } + }) new MatrixTable(hc, matrixType, globals, @@ -997,6 +977,7 @@ class Table(val hc: HailContext, val tir: TableIR) { } val act = implicitly[ClassTag[Annotation]] + // FIXME: need to add sortBy on rvd? 
copy(rdd = rdd.sortBy(identity[Annotation], ascending = true)(ord, act)) } @@ -1006,7 +987,7 @@ class Table(val hc: HailContext, val tir: TableIR) { def union(kts: Table*): Table = new Table(hc, TableUnion((tir +: kts.map(_.tir)).toFastIndexedSeq)) - def take(n: Int): Array[Row] = rdd.take(n) + def take(n: Int): Array[Row] = rvd.take(n, RVD.wireCodec) def takeJSON(n: Int): String = { val r = JSONAnnotationImpex.exportAnnotation(take(n).toFastIndexedSeq, TArray(signature)) @@ -1024,6 +1005,7 @@ class Table(val hc: HailContext, val tir: TableIR) { val (newSignature, ins) = signature.insert(TInt64(), name) + // FIXME: should use RVD, need zipWithIndex val newRDD = rdd.zipWithIndex().map { case (r, ind) => ins(r, ind).asInstanceOf[Row] } copy(signature = newSignature.asInstanceOf[TStruct], rdd = newRDD) diff --git a/src/main/scala/is/hail/utils/FlipbookIterator.scala b/src/main/scala/is/hail/utils/FlipbookIterator.scala index 74e2ae94decd..299333d7fe91 100644 --- a/src/main/scala/is/hail/utils/FlipbookIterator.scala +++ b/src/main/scala/is/hail/utils/FlipbookIterator.scala @@ -162,13 +162,16 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => def staircased(ord: OrderingView[A]): StagingIterator[FlipbookIterator[A]] = { ord.setBottom() - val stepIterator: FlipbookIterator[A] = FlipbookIterator( - new StateMachine[A] { - def value: A = self.value - def isValid: Boolean = self.isValid && ord.isEquivalent(value) - def advance() = { self.advance() } + val stepSM = new StateMachine[A] { + def value: A = self.value + var _isValid: Boolean = self.isValid && ord.isEquivalent(self.value) + def isValid = _isValid + def advance() = { + self.advance() + _isValid = self.isValid && ord.isEquivalent(self.value) } - ) + } + val stepIterator: FlipbookIterator[A] = FlipbookIterator(stepSM) val sm = new StateMachine[FlipbookIterator[A]] { var isValid: Boolean = true val value: FlipbookIterator[A] = stepIterator @@ -176,8 +179,10 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => stepIterator.exhaust() if (self.isValid) { ord.setValue(self.value) + stepSM._isValid = self.isValid && ord.isEquivalent(self.value) } else { ord.setBottom() + stepSM._isValid = false isValid = false } } diff --git a/src/main/scala/is/hail/utils/richUtils/Implicits.scala b/src/main/scala/is/hail/utils/richUtils/Implicits.scala index b6630fd8ca6d..fa5a0d19b833 100644 --- a/src/main/scala/is/hail/utils/richUtils/Implicits.scala +++ b/src/main/scala/is/hail/utils/richUtils/Implicits.scala @@ -5,6 +5,8 @@ import java.io.InputStream import breeze.linalg.DenseMatrix import is.hail.annotations.{JoinedRegionValue, Region, RegionValue} import is.hail.asm4s.Code +import is.hail.io.RichContextRDDRegionValue +import is.hail.rvd.RVDContext import is.hail.io.{InputBuffer, RichContextRDDRegionValue} import is.hail.sparkextras._ import is.hail.utils.{ArrayBuilder, HailIterator, JSONWriter, MultiArray2, Truncatable, WithContext} @@ -73,7 +75,7 @@ trait Implicits { implicit def toRichRDD[T](r: RDD[T])(implicit tct: ClassTag[T]): RichRDD[T] = new RichRDD(r) - implicit def toRichContextRDDRegionValue[C <: AutoCloseable](r: ContextRDD[C, RegionValue]): RichContextRDDRegionValue[C] = new RichContextRDDRegionValue(r) + implicit def toRichContextRDDRegionValue(r: ContextRDD[RVDContext, RegionValue]): RichContextRDDRegionValue = new RichContextRDDRegionValue(r) implicit def toRichRDDByteArray(r: RDD[Array[Byte]]): RichRDDByteArray = new RichRDDByteArray(r) @@ -111,7 +113,7 @@ trait Implicits { implicit def 
toContextPairRDDFunctions[C <: AutoCloseable, K: ClassTag, V: ClassTag](x: ContextRDD[C, (K, V)]): ContextPairRDDFunctions[C, K, V] = new ContextPairRDDFunctions(x)
-  implicit def toRichContextRDD[C <: AutoCloseable, T: ClassTag](x: ContextRDD[C, T]): RichContextRDD[C, T] = new RichContextRDD(x)
+  implicit def toRichContextRDD[T: ClassTag](x: ContextRDD[RVDContext, T]): RichContextRDD[T] = new RichContextRDD(x)
   implicit def toRichCodeInputBuffer(in: Code[InputBuffer]): RichCodeInputBuffer = new RichCodeInputBuffer(in)
 }
diff --git a/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala b/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala
index dc408531508a..fe0d597a30db 100644
--- a/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala
+++ b/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala
@@ -2,6 +2,7 @@ package is.hail.utils.richUtils
 import java.io._
+import is.hail.rvd.RVDContext
 import org.apache.commons.lang3.StringUtils
 import org.apache.spark.TaskContext
 import is.hail.utils._
@@ -9,9 +10,9 @@ import is.hail.sparkextras._
 import scala.reflect.ClassTag
-class RichContextRDD[C <: AutoCloseable, T: ClassTag](crdd: ContextRDD[C, T]) {
+class RichContextRDD[T: ClassTag](crdd: ContextRDD[RVDContext, T]) {
   def writePartitions(path: String,
-    write: (Int, Iterator[T], OutputStream) => Long,
+    write: (RVDContext, Iterator[T], OutputStream) => Long,
     remapPartitions: Option[(Array[Int], Int)] = None): (Array[String], Array[Long]) = {
     val sc = crdd.sparkContext
     val hadoopConf = sc.hadoopConfiguration
@@ -31,12 +32,14 @@ class RichContextRDD[C <: AutoCloseable, T: ClassTag](crdd: ContextRDD[C, T]) {
     val remapBc = sc.broadcast(remap)
-    val (partFiles, partitionCounts) = crdd.mapPartitionsWithIndex { case (index, it) =>
+    val (partFiles, partitionCounts) = crdd.cmapPartitionsWithIndex { (index, ctx, it) =>
       val i = remapBc.value(index)
       val f = partFile(d, i, TaskContext.get)
       val filename = path + "/parts/" + f
       val os = sHadoopConfBc.value.value.unsafeWriter(filename)
-      Iterator.single(f -> write(i, it, os))
+      val out = Iterator.single(f -> write(ctx, it, os))
+      ctx.region.clear()
+      out
     }
       .collect()
       .unzip
diff --git a/src/main/scala/is/hail/utils/richUtils/RichIterator.scala b/src/main/scala/is/hail/utils/richUtils/RichIterator.scala
index 47b962a4a061..5306820296d4 100644
--- a/src/main/scala/is/hail/utils/richUtils/RichIterator.scala
+++ b/src/main/scala/is/hail/utils/richUtils/RichIterator.scala
@@ -105,12 +105,10 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal {
 }
 class RichRowIterator(val it: Iterator[Row]) extends AnyVal {
-  def toRegionValueIterator(rowTyp: TStruct): Iterator[RegionValue] = {
-    val region = Region()
+  def toRegionValueIterator(region: Region, rowTyp: TStruct): Iterator[RegionValue] = {
     val rvb = new RegionValueBuilder(region)
     val rv = RegionValue(region)
     it.map { row =>
-      region.clear()
       rvb.start(rowTyp)
       rvb.addAnnotation(rowTyp, row)
       rv.setOffset(rvb.end())
diff --git a/src/main/scala/is/hail/utils/richUtils/RichRDD.scala b/src/main/scala/is/hail/utils/richUtils/RichRDD.scala
index 302c0484a3c6..650efd3ec331 100644
--- a/src/main/scala/is/hail/utils/richUtils/RichRDD.scala
+++ b/src/main/scala/is/hail/utils/richUtils/RichRDD.scala
@@ -2,6 +2,7 @@ package is.hail.utils.richUtils
 import java.io.OutputStream
+import is.hail.rvd.RVDContext
 import is.hail.sparkextras._
 import is.hail.utils._
 import org.apache.commons.lang3.StringUtils
@@ -99,7 +100,8 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal {
   }
   def subsetPartitions(keep: Array[Int])(implicit ct: ClassTag[T]): RDD[T] = {
-    require(keep.length <= r.partitions.length, "tried to subset to more partitions than exist")
+    require(keep.length <= r.partitions.length,
+      s"tried to subset to more partitions than exist ${keep.toSeq} ${r.partitions.toSeq}")
     require(keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < r.partitions.length)),
       "values not sorted or not in range [0, number of partitions)")
     val parentPartitions = r.partitions
@@ -139,7 +141,7 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal {
     var partScanned = 0
     var nLeft = n
-    var idxLast = 0
+    var idxLast = -1
     var nLast = 0L
     var numPartsToTry = 1L
@@ -183,9 +185,12 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal {
   }
   def writePartitions(path: String,
-    write: (Int, Iterator[T], OutputStream) => Long,
+    write: (Iterator[T], OutputStream) => Long,
     remapPartitions: Option[(Array[Int], Int)] = None
   )(implicit tct: ClassTag[T]
   ): (Array[String], Array[Long]) =
-    ContextRDD.weaken[TrivialContext](r).writePartitions(path, write, remapPartitions)
+    ContextRDD.weaken[RVDContext](r).writePartitions(
+      path,
+      (_, it, os) => write(it, os),
+      remapPartitions)
 }
diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala
index 5dbd0b1cd8c3..e81002412c72 100644
--- a/src/main/scala/is/hail/variant/MatrixTable.scala
+++ b/src/main/scala/is/hail/variant/MatrixTable.scala
@@ -13,6 +13,8 @@ import is.hail.methods.Aggregators.ColFunctions
 import is.hail.utils._
 import is.hail.{HailContext, utils}
 import is.hail.expr.types._
+import is.hail.io.CodecSpec
+import is.hail.sparkextras.ContextRDD
 import org.apache.hadoop
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
@@ -143,8 +145,8 @@ object MatrixTable {
       BroadcastRow(globals.asInstanceOf[Row], matrixType.globalType, hc.sc),
       BroadcastIndexedSeq(colValues, TArray(matrixType.colType), hc.sc),
       OrderedRVD.coerce(matrixType.orvdType,
-        rdd.mapPartitions { it =>
-          val region = Region()
+        ContextRDD.weaken[RVDContext](rdd).cmapPartitions { (ctx, it) =>
+          val region = ctx.region
           val rvb = new RegionValueBuilder(region)
           val rv = RegionValue(region)
@@ -152,7 +154,6 @@ object MatrixTable {
             val vaRow = va.asInstanceOf[Row]
             assert(matrixType.rowType.typeCheck(vaRow), s"${ matrixType.rowType }, $vaRow")
-            region.clear()
             rvb.start(localRVRowType)
             rvb.startStruct()
             var i = 0
@@ -266,7 +267,7 @@ object MatrixTable {
     val oldRowType = kt.signature
-    val rdd = kt.rvd.mapPartitions { it =>
+    val rdd = kt.rvd.mapPartitions(matrixType.rvRowType) { it =>
       val rvb = new RegionValueBuilder()
       val rv2 = RegionValue()
@@ -286,7 +287,7 @@ object MatrixTable {
     new MatrixTable(kt.hc, matrixType,
       BroadcastRow(Row(), matrixType.globalType, kt.hc.sc),
       BroadcastIndexedSeq(Array.empty[Annotation], TArray(matrixType.colType), kt.hc.sc),
-      OrderedRVD.coerce(matrixType.orvdType, rdd, None, None))
+      OrderedRVD.coerce(matrixType.orvdType, rdd))
   }
 }
@@ -655,13 +656,13 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val selectIdx = matrixType.orvdType.kRowFieldIdx
     val keyOrd = matrixType.orvdType.kRowOrd
-    val newRVD = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType) { it =>
+    val newRVD = rvd.boundary.mapPartitionsPreservesPartitioning(newMatrixType.orvdType, { (ctx, it) =>
       new Iterator[RegionValue] {
         var isEnd = false
         var current: RegionValue = null
         val rvRowKey: WritableRegionValue = WritableRegionValue(newRowType)
-        val region = Region()
-        val rvb = new RegionValueBuilder(region)
+        val region = ctx.region
+        val rvb = ctx.rvb
         val newRV = RegionValue(region)
         def hasNext: Boolean = {
@@ -683,7 +684,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
             aggs = seqOp(aggs, current)
             current = null
           }
-          region.clear()
           rvb.start(newRVType)
           rvb.startStruct()
           var i = 0
@@ -697,7 +697,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
           newRV
         }
       }
-    }
+    })
     copyMT(rvd = newRVD, matrixType = newMatrixType)
   }
@@ -780,18 +780,17 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val (newRVType, ins) = rvRowType.unsafeStructInsert(valueType, List(root))
     val rightRowType = rightRVD.rowType
+    val leftRowType = rvRowType
     val rightValueIndices = rightRVD.typ.valueIndices
     assert(!product || rightValueIndices.length == 1)
-    val joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it =>
-      val rvb = new RegionValueBuilder()
+    val joiner = { (ctx: RVDContext, it: Iterator[JoinedRegionValue]) =>
+      val rvb = ctx.rvb
       val rv = RegionValue()
       it.map { jrv =>
         val lrv = jrv.rvLeft
-
-        rvb.set(lrv.region)
         rvb.start(newRVType)
         ins(lrv.region, lrv.offset, rvb,
           () => {
@@ -815,7 +814,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
             }
           }
         })
-        rv.set(lrv.region, rvb.end())
+        rv.set(ctx.region, rvb.end())
         rv
       }
     }
@@ -847,15 +846,21 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val ktSignature = kt.signature
     val ktKeyFieldIdx = kt.keyFieldIdx.get(0)
     val ktValueFieldIdx = kt.valueFieldIdx
-    val partitionKeyedIntervals = kt.rvd.rdd
-      .flatMap { rv =>
+    val partitionKeyedIntervals = kt.rvd.boundary.crdd
+      .cflatMap { (ctx, rv) =>
+        val region = ctx.region
+        val rv2 = RegionValue(region)
+        val rvb = ctx.rvb
+        rvb.start(ktSignature)
+        rvb.addRegionValue(ktSignature, rv)
+        rv2.setOffset(rvb.end())
         val ur = new UnsafeRow(ktSignature, rv)
         val interval = ur.getAs[Interval](ktKeyFieldIdx)
         if (interval != null) {
           val rangeTree = partBc.value.rangeTree
           val pkOrd = partBc.value.pkType.ordering
           val wrappedInterval = interval.copy(start = Row(interval.start), end = Row(interval.end))
-          rangeTree.queryOverlappingValues(pkOrd, wrappedInterval).map(i => (i, rv))
+          rangeTree.queryOverlappingValues(pkOrd, wrappedInterval).map(i => (i, rv2))
         } else
           Iterator()
       }
@@ -876,9 +881,9 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     ) { case (it, intervals) =>
       val intervalAnnotations: Array[(Interval, Any)] =
         intervals.map { rv =>
-          val ur = new UnsafeRow(ktSignature, rv)
-          val interval = ur.getAs[Interval](ktKeyFieldIdx)
-          (interval, Row.fromSeq(ktValueFieldIdx.map(ur.get)))
+          val r = SafeRow(ktSignature, rv)
+          val interval = r.getAs[Interval](ktKeyFieldIdx)
+          (interval, Row.fromSeq(ktValueFieldIdx.map(r.get)))
         }.toArray
       val iTree = IntervalTree.annotationTree(typOrdering, intervalAnnotations)
@@ -1083,7 +1088,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     }
     if (touchesKeys) {
       warn("modified row key, rescanning to compute ordering...")
-      val newRDD = rvd.mapPartitions(mapPartitionsF)
+      val newRDD = rvd.mapPartitions(newMatrixType.rvRowType)(mapPartitionsF)
       copyMT(matrixType = newMatrixType,
         rvd = OrderedRVD.coerce(newMatrixType.orvdType, newRDD, None, None))
     } else copyMT(matrixType = newMatrixType,
@@ -1145,7 +1150,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
   def forceCountRows(): Long = rvd.count()
   def deduplicate(): MatrixTable =
-    copy2(rvd = rvd.mapPartitionsPreservesPartitioning(rvd.typ)(
+    copy2(rvd = rvd.boundary.mapPartitionsPreservesPartitioning(rvd.typ)(
       SortedDistinctRowIterator.transformer(rvd.typ)))
   def deleteVA(args: String*): (Type, Deleter) = deleteVA(args.toList)
@@ -1172,10 +1177,10 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val localEntriesIndex = entriesIndex
-    val explodedRDD = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType) { it =>
-      val region2 = Region()
+    val explodedRDD = rvd.boundary.mapPartitionsPreservesPartitioning(newMatrixType.orvdType, { (ctx, it) =>
+      val region2 = ctx.region
       val rv2 = RegionValue(region2)
-      val rv2b = new RegionValueBuilder(region2)
+      val rv2b = ctx.rvb
       val ur = new UnsafeRow(oldRVType)
       it.flatMap { rv =>
         ur.set(rv)
@@ -1184,7 +1189,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
           None
         else
           keys.iterator.map { explodedElement =>
-            region2.clear()
             rv2b.start(newRVType)
             inserter(rv.region, rv.offset, rv2b,
               () => rv2b.addAnnotation(keyType, explodedElement))
@@ -1192,7 +1196,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
             rv2
           }
       }
-    }
+    })
     copyMT(matrixType = newMatrixType, rvd = explodedRDD)
   }
@@ -1433,15 +1437,14 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val localEntriesType = matrixType.entryArrayType
     assert(right.matrixType.entryArrayType == localEntriesType)
-    val joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it =>
-      val rvb = new RegionValueBuilder()
+    val joiner = { (ctx: RVDContext, it: Iterator[JoinedRegionValue]) =>
+      val rvb = ctx.rvb
       val rv2 = RegionValue()
       it.map { jrv =>
         val lrv = jrv.rvLeft
         val rrv = jrv.rvRight
-        rvb.set(lrv.region)
         rvb.start(leftRVType)
         rvb.startStruct()
         var i = 0
@@ -1474,7 +1477,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
         rvb.endArray()
         rvb.endStruct()
-        rv2.set(lrv.region, rvb.end())
+        rv2.set(ctx.region, rvb.end())
         rv2
       }
     }
@@ -1889,55 +1892,47 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val localColKeys = colKeys
     metadataSame &&
-      rvd.rdd.zipPartitions(
+      rvd.crdd.czip(
         OrderedRVD.adjustBoundsAndShuffle(
           that.rvd.typ,
           rvd.partitioner.withKType(that.rvd.typ.partitionKey, that.rvd.typ.kType),
-          that.rvd.rdd)
-          .rdd) { (it1, it2) =>
+          that.rvd)
+          .crdd) { (ctx, rv1, rv2) =>
+        var partSame = true
+
         val fullRow1 = new UnsafeRow(leftRVType)
         val fullRow2 = new UnsafeRow(rightRVType)
-        var partSame = true
-        while (it1.hasNext && it2.hasNext) {
-          val rv1 = it1.next()
-          val rv2 = it2.next()
-          fullRow1.set(rv1)
-          fullRow2.set(rv2)
-          val row1 = fullRow1.deleteField(localLeftEntriesIndex)
-          val row2 = fullRow2.deleteField(localRightEntriesIndex)
+        fullRow1.set(rv1)
+        fullRow2.set(rv2)
+        val row1 = fullRow1.deleteField(localLeftEntriesIndex)
+        val row2 = fullRow2.deleteField(localRightEntriesIndex)
-          if (!localRowType.valuesSimilar(row1, row2, tolerance, absolute)) {
-            println(
-              s"""row fields not the same:
+        if (!localRowType.valuesSimilar(row1, row2, tolerance, absolute)) {
+          println(
+            s"""row fields not the same:
               | $row1
              | $row2""".stripMargin)
-            partSame = false
-          }
+          partSame = false
+        }
-          val gs1 = fullRow1.getAs[IndexedSeq[Annotation]](localLeftEntriesIndex)
-          val gs2 = fullRow2.getAs[IndexedSeq[Annotation]](localRightEntriesIndex)
+        val gs1 = fullRow1.getAs[IndexedSeq[Annotation]](localLeftEntriesIndex)
+        val gs2 = fullRow2.getAs[IndexedSeq[Annotation]](localRightEntriesIndex)
-          var i = 0
-          while (partSame && i < gs1.length) {
-            if (!localEntryType.valuesSimilar(gs1(i), gs2(i), tolerance, absolute)) {
-              partSame = false
-              println(
-                s"""different entry at row ${ localRKF(row1) }, col ${ localColKeys(i) }
+        var i = 0
+        while (partSame && i < gs1.length) {
+          if (!localEntryType.valuesSimilar(gs1(i), gs2(i), tolerance, absolute)) {
+            partSame = false
+            println(
+              s"""different entry at row ${ localRKF(row1) }, col ${ localColKeys(i) }
               | ${ gs1(i) }
             | ${ gs2(i) }""".stripMargin)
-            }
-            i += 1
          }
+          i += 1
        }
-
-        if ((it1.hasNext || it2.hasNext) && partSame) {
-          println("partition has different number of rows")
-          partSame = false
-        }
-
-        Iterator(partSame)
-      }.forall(t => t)
+        ctx.region.clear()
+        partSame
+      }.run.forall(t => t)
   }
   def colEC: EvalContext = {
@@ -2025,17 +2020,22 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     }
     val localRVRowType = rvRowType
-    rvd.map { rv =>
-      new UnsafeRow(localRVRowType, rv)
-    }.find(ur => !localRVRowType.typeCheck(ur))
-      .foreach { ur =>
+    val predicate = { (rv: RegionValue) =>
+      val ur = new UnsafeRow(localRVRowType, rv)
+      !localRVRowType.typeCheck(ur)
+    }
+
+    Region.scoped { region =>
+      rvd.find(region)(predicate).foreach { rv =>
+        val ur = new UnsafeRow(localRVRowType, rv)
         foundError = true
         warn(
           s"""found violation in row
-            |Schema: $localRVRowType
-            |Annotation: ${ Annotation.printAnnotation(ur) }""".stripMargin)
+              |Schema: $localRVRowType
+              |Annotation: ${ Annotation.printAnnotation(ur) }""".stripMargin)
       }
+    }
     if (foundError)
       fatal("found one or more type check errors")
@@ -2194,12 +2194,13 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val locusIndex = rvRowType.fieldIdx("locus")
     val allelesIndex = rvRowType.fieldIdx("alleles")
-    def minRep1(removeLeftAligned: Boolean, removeMoving: Boolean, verifyLeftAligned: Boolean): RDD[RegionValue] = {
-      rvd.mapPartitions { it =>
+    def minRep1(removeLeftAligned: Boolean, removeMoving: Boolean, verifyLeftAligned: Boolean): RVD = {
+      rvd.mapPartitions(localRVRowType) { it =>
         var prevLocus: Locus = null
         val rvb = new RegionValueBuilder()
         val rv2 = RegionValue()
+        // FIXME: how is this not broken?
         it.flatMap { rv =>
           val ur = new UnsafeRow(localRVRowType, rv.region, rv.offset)
@@ -2378,7 +2379,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val fieldIdx = entryType.fieldIdx(entryField)
     val numColsLocal = numCols
-    val rows = rvd.mapPartitionsWithIndex { case (pi, it) =>
+    val rows = rvd.mapPartitionsWithIndex { (pi, it) =>
       var i = partStartsBc.value(pi)
       it.map { rv =>
         val region = rv.region
@@ -2420,7 +2421,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     hadoop.mkDir(dirname + "/parts")
     val gp = GridPartitioner(blockSize, nRows, localNCols)
     val blockPartFiles =
-      new WriteBlocksRDD(dirname, rvd.rdd, sparkContext, matrixType, partStarts, entryField, gp)
+      new WriteBlocksRDD(dirname, rvd.crdd, sparkContext, matrixType, partStarts, entryField, gp)
         .collect()
     val blockCount = blockPartFiles.length
@@ -2446,15 +2447,14 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val partStarts = partitionStarts()
     val newMatrixType = matrixType.copy(rvRowType = newRVType)
-    val indexedRVD = rvd.mapPartitionsWithIndexPreservesPartitioning(newMatrixType.orvdType) { case (i, it) =>
-      val region2 = Region()
+    val indexedRVD = rvd.boundary.mapPartitionsWithIndexPreservesPartitioning(newMatrixType.orvdType, { (i, ctx, it) =>
+      val region2 = ctx.region
       val rv2 = RegionValue(region2)
-      val rv2b = new RegionValueBuilder(region2)
+      val rv2b = ctx.rvb
       var idx = partStarts(i)
       it.map { rv =>
-        region2.clear()
        rv2b.start(newRVType)
        inserter(rv.region, rv.offset, rv2b,
@@ -2464,7 +2464,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
        rv2.setOffset(rv2b.end())
        rv2
      }
-    }
+    })
     copyMT(matrixType = newMatrixType, rvd = indexedRVD)
   }
diff --git a/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala b/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala
index 9d2860652f9b..f9724f7bb4da 100644
--- a/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala
+++ b/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala
@@ -342,30 +342,31 @@ class LocalLDPruneSuite extends SparkSuite {
   }
   @Test def bitPackedVectorCorrectWhenOffsetNotZero() {
-    val r = Region()
-    val rvb = new RegionValueBuilder(r)
-    val t = BitPackedVectorView.rvRowType(
-      +TLocus(ReferenceGenome.GRCh37),
-      +TArray(+TString()))
-    val bpv = new BitPackedVectorView(t)
-    r.appendInt(0xbeef)
-    rvb.start(t)
-    rvb.startStruct()
-    rvb.startStruct()
-    rvb.addString("X")
-    rvb.addInt(42)
-    rvb.endStruct()
-    rvb.startArray(0)
-    rvb.endArray()
-    rvb.startArray(0)
-    rvb.endArray()
-    rvb.addInt(0)
-    rvb.addDouble(0.0)
-    rvb.addDouble(0.0)
-    rvb.endStruct()
-    bpv.setRegion(r, rvb.end())
-    assert(bpv.getContig == "X")
-    assert(bpv.getStart == 42)
+    Region.scoped { r =>
+      val rvb = new RegionValueBuilder(r)
+      val t = BitPackedVectorView.rvRowType(
+        +TLocus(ReferenceGenome.GRCh37),
+        +TArray(+TString()))
+      val bpv = new BitPackedVectorView(t)
+      r.appendInt(0xbeef)
+      rvb.start(t)
+      rvb.startStruct()
+      rvb.startStruct()
+      rvb.addString("X")
+      rvb.addInt(42)
+      rvb.endStruct()
+      rvb.startArray(0)
+      rvb.endArray()
+      rvb.startArray(0)
+      rvb.endArray()
+      rvb.addInt(0)
+      rvb.addDouble(0.0)
+      rvb.addDouble(0.0)
+      rvb.endStruct()
+      bpv.setRegion(r, rvb.end())
+      assert(bpv.getContig == "X")
+      assert(bpv.getStart == 42)
+    }
   }
   @Test def testIsLocallyUncorrelated() {
diff --git a/src/test/scala/is/hail/testUtils/RichMatrixTable.scala b/src/test/scala/is/hail/testUtils/RichMatrixTable.scala
index 2c27d1b5b36e..d6e633c0fd41 100644
--- a/src/test/scala/is/hail/testUtils/RichMatrixTable.scala
+++ b/src/test/scala/is/hail/testUtils/RichMatrixTable.scala
@@ -87,7 +87,8 @@ class RichMatrixTable(vsm: MatrixTable) {
     val localEntriesIndex = vsm.entriesIndex
     val localRowType = vsm.rowType
     val rowKeyF = vsm.rowKeysF
-    vsm.rvd.rdd.map { rv =>
+    vsm.rvd.map { rv =>
+      val unsafeFullRow = new UnsafeRow(fullRowType, rv)
       val fullRow = SafeRow(fullRowType, rv.region, rv.offset)
       val row = fullRow.deleteField(localEntriesIndex)
       (rowKeyF(fullRow), (row, fullRow.getAs[IndexedSeq[Any]](localEntriesIndex)))
diff --git a/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala b/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala
index 23bdbe3a23fd..2ff33a3cdff8 100644
--- a/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala
+++ b/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala
@@ -2,6 +2,7 @@ package is.hail.utils
 import is.hail.SparkSuite
 import org.testng.annotations.Test
+
 import scala.collection.generic.Growable
 import scala.collection.mutable.ArrayBuffer
diff --git a/src/test/scala/is/hail/utils/RichRDDSuite.scala b/src/test/scala/is/hail/utils/RichRDDSuite.scala
index 21f8d231cfcf..6af6a8503c82 100644
--- a/src/test/scala/is/hail/utils/RichRDDSuite.scala
+++ b/src/test/scala/is/hail/utils/RichRDDSuite.scala
@@ -11,10 +11,10 @@ class RichRDDSuite extends SparkSuite {
   @Test def testHead() {
     val r = sc.parallelize(0 until 1024, numSlices = 20)
-    val partitionRanges = r.countPerPartition().scanLeft(Range(0, 0)) { case (x, c) => Range(x.end, x.end + c.toInt) }
+    val partitionRanges = r.countPerPartition().scanLeft(Range(0, 1)) { case (x, c) => Range(x.end, x.end + c.toInt + 1) }
     def getExpectedNumPartitions(n: Int): Int =
-      partitionRanges.indexWhere(_.contains(math.max(0, n - 1)))
+      partitionRanges.indexWhere(_.contains(n))
     for (n <- Array(0, 15, 200, 562, 1024, 2000)) {
       val t = r.head(n)
diff --git a/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala b/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala
index 4b1655158e7a..eaa2f26e1756 100644
--- a/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala
+++ b/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala
@@ -75,8 +75,8 @@ class PartitioningSuite extends SparkSuite {
     val mt = MatrixTable.fromRowsTable(Table.range(hc, 100, nPartitions=Some(6)))
     val orvdType = mt.matrixType.orvdType
-    mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "left", _.map(_._1), orvdType).count()
-    mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "inner", _.map(_._1), orvdType).count()
+    mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "left", (_, it) => it.map(_._1), orvdType).count()
+    mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "inner", (_, it) => it.map(_._1), orvdType).count()
   }
   @Test def testEmptyRDDOrderedJoin() {
@@ -86,9 +86,9 @@ class PartitioningSuite extends SparkSuite {
     val nonEmptyRVD = mt.rvd
     val emptyRVD = OrderedRVD.empty(hc.sc, orvdType)
-    emptyRVD.orderedJoin(nonEmptyRVD, "left", _.map(_._1), orvdType).count()
-    emptyRVD.orderedJoin(nonEmptyRVD, "inner", _.map(_._1), orvdType).count()
-    nonEmptyRVD.orderedJoin(emptyRVD, "left", _.map(_._1), orvdType).count()
-    nonEmptyRVD.orderedJoin(emptyRVD, "inner", _.map(_._1), orvdType).count()
+    emptyRVD.orderedJoin(nonEmptyRVD, "left", (_, it) => it.map(_._1), orvdType).count()
+    emptyRVD.orderedJoin(nonEmptyRVD, "inner", (_, it) => it.map(_._1), orvdType).count()
+    nonEmptyRVD.orderedJoin(emptyRVD, "left", (_, it) => it.map(_._1), orvdType).count()
+    nonEmptyRVD.orderedJoin(emptyRVD, "inner", (_, it) => it.map(_._1), orvdType).count()
   }
 }
diff --git a/src/test/scala/is/hail/vds/JoinSuite.scala b/src/test/scala/is/hail/vds/JoinSuite.scala
index 1020844eb98c..863c1c01220e 100644
--- a/src/test/scala/is/hail/vds/JoinSuite.scala
+++ b/src/test/scala/is/hail/vds/JoinSuite.scala
@@ -1,7 +1,7 @@
 package is.hail.vds
 import is.hail.SparkSuite
-import is.hail.annotations.UnsafeRow
+import is.hail.annotations._
 import is.hail.expr.types.{TLocus, TStruct}
 import is.hail.table.Table
 import is.hail.variant.{Locus, MatrixTable, ReferenceGenome}
@@ -67,24 +67,24 @@ class JoinSuite extends SparkSuite {
     val localRowType = left.rvRowType
     // Inner distinct ordered join
-    val jInner = left.rvd.orderedJoinDistinct(right.rvd, "inner", _.map(_._1), left.rvd.typ)
+    val jInner = left.rvd.orderedJoinDistinct(right.rvd, "inner", (_, it) => it.map(_._1), left.rvd.typ)
     val jInnerOrdRDD1 = left.rdd.join(right.rdd.distinct)
     assert(jInner.count() == jInnerOrdRDD1.count())
     assert(jInner.map { rv =>
-      val ur = new UnsafeRow(localRowType, rv)
-      ur.getAs[Locus](0)
+      val r = SafeRow(localRowType, rv)
+      r.getAs[Locus](0)
     }.collect() sameElements jInnerOrdRDD1.map(_._1.asInstanceOf[Row].get(0)).collect().sorted(vType.ordering.toOrdering))
     // Left distinct ordered join
-    val jLeft = left.rvd.orderedJoinDistinct(right.rvd, "left", _.map(_._1), left.rvd.typ)
+    val jLeft = left.rvd.orderedJoinDistinct(right.rvd, "left", (_, it) => it.map(_._1), left.rvd.typ)
     val jLeftOrdRDD1 = left.rdd.leftOuterJoin(right.rdd.distinct)
     assert(jLeft.count() == jLeftOrdRDD1.count())
-    assert(jLeft.rdd.forall(rv => rv != null))
+    assert(jLeft.forall(rv => rv != null))
     assert(jLeft.map { rv =>
-      val ur = new UnsafeRow(localRowType, rv)
-      ur.getAs[Locus](0)
+      val r = SafeRow(localRowType, rv)
+      r.getAs[Locus](0)
     }.collect() sameElements jLeftOrdRDD1.map(_._1.asInstanceOf[Row].get(0)).collect().sorted(vType.ordering.toOrdering))
   }
 }
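
Note on the idiom applied throughout this patch: per-partition Regions now come from an RVDContext (or a Region.scoped block in tests), and any value that must outlive its Region is copied out with SafeRow rather than viewed through an UnsafeRow. Below is a minimal, hypothetical sketch of that pattern, assuming only the is.hail.annotations API as it is used in the hunks above (Region.scoped, RegionValueBuilder, RegionValue, SafeRow); the struct type, field names, and values are invented for illustration and are not part of this change.

    import is.hail.annotations.{Region, RegionValue, RegionValueBuilder, SafeRow}
    import is.hail.expr.types.{TInt32, TString, TStruct}
    import org.apache.spark.sql.Row

    object RegionScopingSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical row type, for illustration only.
        val t = TStruct("contig" -> TString(), "count" -> TInt32())

        // Build a region value inside a scoped Region, then copy it out with
        // SafeRow so the result remains valid after the scoped Region is done.
        val safe: Row = Region.scoped { region =>
          val rvb = new RegionValueBuilder(region)
          rvb.start(t)
          rvb.startStruct()
          rvb.addString("X")
          rvb.addInt(42)
          rvb.endStruct()
          val rv = RegionValue(region)
          rv.setOffset(rvb.end())
          SafeRow(t, rv.region, rv.offset) // deep copy off the region
        }

        assert(safe.getAs[String](0) == "X")
        assert(safe.getInt(1) == 42)
      }
    }

The same reasoning explains the SafeRow substitutions in JoinSuite and RichMatrixTable above: once rows are produced from a context-managed Region that may be cleared per element, only a copied (safe) row can be collected or compared later.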