diff --git a/build.sbt b/build.sbt index 7bddfa5e680f..d85ac4efcc48 100644 --- a/build.sbt +++ b/build.sbt @@ -42,6 +42,7 @@ lazy val root = (project in file(".")). , "org.json4s" %% "json4s-core" % "3.2.10" , "org.json4s" %% "json4s-jackson" % "3.2.10" , "org.json4s" %% "json4s-ast" % "3.2.10" + , "org.elasticsearch" % "elasticsearch-spark-20_2.11" % "6.2.4" , "org.apache.solr" % "solr-solrj" % "6.2.0" , "com.datastax.cassandra" % "cassandra-driver-core" % "3.0.0" , "com.jayway.restassured" % "rest-assured" % "2.8.0" diff --git a/src/main/scala/is/hail/HailContext.scala b/src/main/scala/is/hail/HailContext.scala index 7f889d24fc72..dddf9c5ea4a6 100644 --- a/src/main/scala/is/hail/HailContext.scala +++ b/src/main/scala/is/hail/HailContext.scala @@ -11,7 +11,9 @@ import is.hail.io.bgen.LoadBgen import is.hail.io.gen.LoadGen import is.hail.io.plink.{FamFileConfig, LoadPlink} import is.hail.io.vcf._ +import is.hail.rvd.RVDContext import is.hail.table.Table +import is.hail.sparkextras.ContextRDD import is.hail.stats.{BaldingNicholsModel, Distribution, UniformDist} import is.hail.utils.{log, _} import is.hail.variant.{MatrixTable, ReferenceGenome, VSMSubgen} @@ -208,9 +210,14 @@ object HailContext { ProgressBarBuilder.build(sc) } - def readRowsPartition(makeDec: (InputStream) => Decoder)(i: Int, in: InputStream, metrics: InputMetrics = null): Iterator[RegionValue] = { + def readRowsPartition( + makeDec: (InputStream) => Decoder + )(ctx: RVDContext, + in: InputStream, + metrics: InputMetrics = null + ): Iterator[RegionValue] = new Iterator[RegionValue] { - private val region = Region() + private val region = ctx.region private val rv = RegionValue(region) private val trackedIn = new ByteTrackingInputStream(in) @@ -236,7 +243,6 @@ object HailContext { throw new NoSuchElementException("next on empty iterator") try { - region.clear() rv.setOffset(dec.readRegionValue(region)) if (metrics != null) { ExposedMetrics.incrementRecord(metrics) @@ -259,7 +265,6 @@ object HailContext { dec.close() } } - } } class HailContext private(val sc: SparkContext, @@ -531,8 +536,19 @@ class HailContext private(val sc: SparkContext, } } - def readRows(path: String, t: TStruct, codecSpec: CodecSpec, partFiles: Array[String]): RDD[RegionValue] = - readPartitions(path, partFiles, HailContext.readRowsPartition(codecSpec.buildDecoder(t))) + def readRows( + path: String, + t: TStruct, + codecSpec: CodecSpec, + partFiles: Array[String] + ): ContextRDD[RVDContext, RegionValue] = + ContextRDD.weaken[RVDContext](readPartitions(path, partFiles, (_, is, m) => Iterator.single(is -> m))) + .cmapPartitions { (ctx, it) => + assert(it.hasNext) + val (is, m) = it.next + assert(!it.hasNext) + HailContext.readRowsPartition(codecSpec.buildDecoder(t))(ctx, is, m) + } def parseVCFMetadata(file: String): Map[String, Map[String, Map[String, String]]] = { val reader = new HtsjdkRecordReader(Set.empty) @@ -663,30 +679,30 @@ class HailContext private(val sc: SparkContext, val ast = Parser.parseToAST(expr, ec) ast.toIR() match { case Some(body) => - val region = Region() - val t = ast.`type` - t match { - case _: TBoolean => - val (_, f) = ir.Compile[Boolean](body) - (f()(region), t) - case _: TInt32 => - val (_, f) = ir.Compile[Int](body) - (f()(region), t) - case _: TInt64 => - val (_, f) = ir.Compile[Long](body) - (f()(region), t) - case _: TFloat32 => - val (_, f) = ir.Compile[Float](body) - (f()(region), t) - case _: TFloat64 => - val (_, f) = ir.Compile[Double](body) - (f()(region), t) - case _ => - val (_, f) = ir.Compile[Long](body) - val 
off = f()(region) - val v2 = UnsafeRow.read(t, region, off) - val v3 = Annotation.copy(t, v2) - (v3, t) + Region.scoped { region => + val t = ast.`type` + t match { + case _: TBoolean => + val (_, f) = ir.Compile[Boolean](body) + (f()(region), t) + case _: TInt32 => + val (_, f) = ir.Compile[Int](body) + (f()(region), t) + case _: TInt64 => + val (_, f) = ir.Compile[Long](body) + (f()(region), t) + case _: TFloat32 => + val (_, f) = ir.Compile[Float](body) + (f()(region), t) + case _: TFloat64 => + val (_, f) = ir.Compile[Double](body) + (f()(region), t) + case _ => + val (_, f) = ir.Compile[Long](body) + val off = f()(region) + val v2 = SafeRow.read(t, region, off) + (v2, t) + } } case None => val (t, f) = Parser.eval(ast, ec) diff --git a/src/main/scala/is/hail/annotations/OrderedRVIterator.scala b/src/main/scala/is/hail/annotations/OrderedRVIterator.scala index 3608ac157c97..4bde68498f57 100644 --- a/src/main/scala/is/hail/annotations/OrderedRVIterator.scala +++ b/src/main/scala/is/hail/annotations/OrderedRVIterator.scala @@ -3,6 +3,8 @@ package is.hail.annotations import is.hail.rvd.OrderedRVDType import is.hail.utils._ +import scala.collection.generic.Growable + case class OrderedRVIterator(t: OrderedRVDType, iterator: Iterator[RegionValue]) { def restrictToPKInterval(interval: Interval): Iterator[RegionValue] = { @@ -56,50 +58,62 @@ case class OrderedRVIterator(t: OrderedRVDType, iterator: Iterator[RegionValue]) this.t.kComp(other.t).compare ) - def innerJoin(other: OrderedRVIterator): Iterator[JoinedRegionValue] = { + def innerJoin( + other: OrderedRVIterator, + rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.innerJoin( other.iterator.toFlipbookIterator, this.t.kRowOrdView, other.t.kRowOrdView, null, null, - new RegionValueArrayBuffer(other.t.rowType), + rightBuffer, this.t.kComp(other.t).compare ) } - def leftJoin(other: OrderedRVIterator): Iterator[JoinedRegionValue] = { + def leftJoin( + other: OrderedRVIterator, + rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.leftJoin( other.iterator.toFlipbookIterator, this.t.kRowOrdView, other.t.kRowOrdView, null, null, - new RegionValueArrayBuffer(other.t.rowType), + rightBuffer, this.t.kComp(other.t).compare ) } - def rightJoin(other: OrderedRVIterator): Iterator[JoinedRegionValue] = { + def rightJoin( + other: OrderedRVIterator, + rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.rightJoin( other.iterator.toFlipbookIterator, this.t.kRowOrdView, other.t.kRowOrdView, null, null, - new RegionValueArrayBuffer(other.t.rowType), + rightBuffer, this.t.kComp(other.t).compare ) } - def outerJoin(other: OrderedRVIterator): Iterator[JoinedRegionValue] = { + def outerJoin( + other: OrderedRVIterator, + rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.outerJoin( other.iterator.toFlipbookIterator, this.t.kRowOrdView, other.t.kRowOrdView, null, null, - new RegionValueArrayBuffer(other.t.rowType), + rightBuffer, this.t.kComp(other.t).compare ) } diff --git a/src/main/scala/is/hail/annotations/Region.scala b/src/main/scala/is/hail/annotations/Region.scala index 059d86d3c98d..55ddea66a198 100644 --- a/src/main/scala/is/hail/annotations/Region.scala +++ b/src/main/scala/is/hail/annotations/Region.scala @@ -255,11 +255,6 @@ final class Region( off } - 
def clear(newEnd: Long) { - assert(newEnd <= end) - end = newEnd - } - def clear() { end = 0 } diff --git a/src/main/scala/is/hail/annotations/WritableRegionValue.scala b/src/main/scala/is/hail/annotations/WritableRegionValue.scala index 8bfc06123619..11c8d33af20f 100644 --- a/src/main/scala/is/hail/annotations/WritableRegionValue.scala +++ b/src/main/scala/is/hail/annotations/WritableRegionValue.scala @@ -55,10 +55,9 @@ class WritableRegionValue private (val t: Type) { def pretty: String = value.pretty(t) } -class RegionValueArrayBuffer(val t: Type) +class RegionValueArrayBuffer(val t: Type, region: Region) extends Iterable[RegionValue] with Growable[RegionValue] { - val region = Region() val value = RegionValue(region, 0) private val rvb = new RegionValueBuilder(region) @@ -94,7 +93,7 @@ class RegionValueArrayBuffer(val t: Type) def clear() { region.clear() idx.clear() - rvb.clear() + rvb.clear() // remove } private var itIdx = 0 diff --git a/src/main/scala/is/hail/expr/Relational.scala b/src/main/scala/is/hail/expr/Relational.scala index 1d44a316bebb..f21d99f3d396 100644 --- a/src/main/scala/is/hail/expr/Relational.scala +++ b/src/main/scala/is/hail/expr/Relational.scala @@ -2,12 +2,14 @@ package is.hail.expr import is.hail.HailContext import is.hail.annotations._ +import is.hail.annotations.Annotation._ import is.hail.annotations.aggregators.RegionValueAggregator import is.hail.expr.ir._ import is.hail.expr.types._ import is.hail.io._ import is.hail.methods.Aggregators import is.hail.rvd._ +import is.hail.sparkextras.ContextRDD import is.hail.table.TableSpec import is.hail.variant._ import org.apache.spark.rdd.RDD @@ -201,8 +203,10 @@ case class MatrixValue( val hc = HailContext.get val signature = typ.colType - new UnpartitionedRVD(signature, hc.sc.parallelize(colValues.value.toArray.map(_.asInstanceOf[Row])) - .mapPartitions(_.toRegionValueIterator(signature))) + new UnpartitionedRVD( + signature, + ContextRDD.parallelize(hc.sc, colValues.value.toArray.map(_.asInstanceOf[Row])) + .cmapPartitions { (ctx, it) => it.toRegionValueIterator(ctx.region, signature) }) } def entriesRVD(): RVD = { @@ -216,21 +220,17 @@ case class MatrixValue( val localColValues = colValues.broadcast.value - rvd.mapPartitions(resultStruct) { it => - - val rv2b = new RegionValueBuilder() - val rv2 = RegionValue() + rvd.boundary.mapPartitions(resultStruct, { (ctx, it) => + val rv2b = ctx.rvb + val rv2 = RegionValue(ctx.region) it.flatMap { rv => - val rvEnd = rv.region.size - rv2b.set(rv.region) val gsOffset = fullRowType.loadField(rv, localEntriesIndex) (0 until localNCols).iterator .filter { i => localEntriesType.isElementDefined(rv.region, gsOffset, i) } .map { i => - rv.region.clear(rvEnd) rv2b.clear() rv2b.start(resultStruct) rv2b.startStruct() @@ -245,11 +245,11 @@ case class MatrixValue( rv2b.addInlineRow(localColType, localColValues(i).asInstanceOf[Row]) rv2b.addAllFields(localEntryType, rv.region, localEntriesType.elementOffsetInRegion(rv.region, gsOffset, i)) rv2b.endStruct() - rv2.set(rv.region, rv2b.end()) + rv2.setOffset(rv2b.end()) rv2 } } - } + }) } } @@ -471,42 +471,25 @@ case class MatrixRead( } else { val entriesRVD = spec.entriesComponent.read(hc, path) val entriesRowType = entriesRVD.rowType - rowsRVD.zipPartitionsPreservesPartitioning( - typ.orvdType, - entriesRVD - ) { case (it1, it2) => - val rvb = new RegionValueBuilder() - - new Iterator[RegionValue] { - def hasNext: Boolean = { - val hn = it1.hasNext - assert(hn == it2.hasNext) - hn - } - - def next(): RegionValue = { - val rv1 = 
it1.next() - val rv2 = it2.next() - val region = rv2.region - rvb.set(region) - rvb.start(fullRowType) - rvb.startStruct() - var i = 0 - while (i < localEntriesIndex) { - rvb.addField(rowType, rv1, i) - i += 1 - } - rvb.addField(entriesRowType, rv2, 0) - i += 1 - while (i < fullRowType.size) { - rvb.addField(rowType, rv1, i - 1) - i += 1 - } - rvb.endStruct() - rv2.set(region, rvb.end()) - rv2 - } + rowsRVD.zip(typ.orvdType, entriesRVD) { (ctx, rv1, rv2) => + val rvb = ctx.rvb + val region = ctx.region + rvb.start(fullRowType) + rvb.startStruct() + var i = 0 + while (i < localEntriesIndex) { + rvb.addField(rowType, rv1, i) + i += 1 + } + rvb.addField(entriesRowType, rv2, 0) + i += 1 + while (i < fullRowType.size) { + rvb.addField(rowType, rv1, i - 1) + i += 1 } + rvb.endStruct() + rv2.set(region, rvb.end()) + rv2 } } } @@ -556,16 +539,15 @@ case class MatrixRange(nRows: Int, nCols: Int, nPartitions: Int) extends MatrixI val start = partStarts(i) Interval(Row(start), Row(start + localPartCounts(i)), includesStart = true, includesEnd = false) }), - hc.sc.parallelize(Range(0, nPartitionsAdj), nPartitionsAdj) - .mapPartitionsWithIndex { case (i, _) => - val region = Region() - val rvb = new RegionValueBuilder(region) + ContextRDD.parallelize[RVDContext](hc.sc, Range(0, nPartitionsAdj), nPartitionsAdj) + .cmapPartitionsWithIndex { (i, ctx, _) => + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) val start = partStarts(i) Iterator.range(start, start + localPartCounts(i)) .map { j => - region.clear() rvb.start(localRVType) rvb.startStruct() @@ -963,12 +945,12 @@ case class MatrixMapRows(child: MatrixIR, newRow: IR) extends MatrixIR { }) assert(rTyp == typ.rvRowType, s"$rTyp, ${ typ.rvRowType }") - val mapPartitionF = { it: Iterator[RegionValue] => + val mapPartitionF = { (ctx: RVDContext, it: Iterator[RegionValue]) => val rvb = new RegionValueBuilder() val newRV = RegionValue() val rowF = f() - val partRegion = Region() + val partRegion = ctx.freshContext.region rvb.set(partRegion) rvb.start(localGlobalsType) @@ -1020,11 +1002,12 @@ case class MatrixMapRows(child: MatrixIR, newRow: IR) extends MatrixIR { } if (touchesKeys) { - val newRDD = prev.rvd.mapPartitions(mapPartitionF) prev.copy(typ = typ, - rvd = OrderedRVD.coerce(typ.orvdType, newRDD, None, None)) + rvd = OrderedRVD.coerce( + typ.orvdType, + prev.rvd.mapPartitions(typ.rvRowType, mapPartitionF))) } else { - val newRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType)(mapPartitionF) + val newRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType, mapPartitionF) prev.copy(typ = typ, rvd = newRVD) } } @@ -1210,14 +1193,15 @@ case class MatrixMapGlobals(child: MatrixIR, newRow: IR, value: BroadcastRow) ex newRow) assert(rTyp == typ.globalType) - val globalRegion = Region() - val globalOff = prev.globals.toRegion(globalRegion) - val valueOff = value.toRegion(globalRegion) - val newOff = f()(globalRegion, globalOff, false, valueOff, false) + val newGlobals = Region.scoped { globalRegion => + val globalOff = prev.globals.toRegion(globalRegion) + val valueOff = value.toRegion(globalRegion) + val newOff = f()(globalRegion, globalOff, false, valueOff, false) - val newGlobals = prev.globals.copy( - value = SafeRow(rTyp.asInstanceOf[TStruct], globalRegion, newOff), - t = rTyp.asInstanceOf[TStruct]) + prev.globals.copy( + value = SafeRow(rTyp.asInstanceOf[TStruct], globalRegion, newOff), + t = rTyp.asInstanceOf[TStruct]) + } prev.copy(typ = typ, globals = newGlobals) } @@ -1372,8 +1356,8 @@ case class 
TableParallelize(typ: TableType, rows: IndexedSeq[Row], nPartitions: def execute(hc: HailContext): TableValue = { val rowTyp = typ.rowType - val rvd = hc.sc.parallelize(rows, nPartitions.getOrElse(hc.sc.defaultParallelism)) - .mapPartitions(_.toRegionValueIterator(rowTyp)) + val rvd = ContextRDD.parallelize[RVDContext](hc.sc, rows, nPartitions) + .cmapPartitions((ctx, it) => it.toRegionValueIterator(ctx.region, rowTyp)) TableValue(typ, BroadcastRow(Row(), typ.globalType, hc.sc), new UnpartitionedRVD(rowTyp, rvd)) } } @@ -1397,20 +1381,18 @@ case class TableImport(paths: Array[String], typ: TableType, readerOpts: TableRe val useColIndices = readerOpts.useColIndices - val rvd = hc.sc.textFilesLines(paths, readerOpts.nPartitions) + val rvd = ContextRDD.textFilesLines[RVDContext](hc.sc, paths, readerOpts.nPartitions) .filter { line => !readerOpts.isComment(line.value) && (readerOpts.noHeader || readerOpts.header != line.value) && !(readerOpts.skipBlankLines && line.value.isEmpty) - }.mapPartitions { it => - val region = Region() - val rvb = new RegionValueBuilder(region) + }.cmapPartitions { (ctx, it) => + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) it.map { _.map { line => - region.clear() - val sp = TextTableReader.splitLine(line, readerOpts.separator, readerOpts.quote) if (sp.length != nFieldOrig) fatal(s"expected $nFieldOrig fields, but found ${ sp.length } fields") @@ -1483,16 +1465,15 @@ case class TableRange(n: Int, nPartitions: Int) extends TableIR { val end = partStarts(i + 1) Interval(Row(start), Row(end), includesStart = true, includesEnd = false) }), - hc.sc.parallelize(Range(0, nPartitionsAdj), nPartitionsAdj) - .mapPartitionsWithIndex { case (i, _) => - val region = Region() - val rvb = new RegionValueBuilder(region) + ContextRDD.parallelize(hc.sc, Range(0, nPartitionsAdj), nPartitionsAdj) + .cmapPartitionsWithIndex { case (i, ctx, _) => + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) val start = partStarts(i) Iterator.range(start, start + localPartCounts(i)) .map { j => - region.clear() rvb.start(localRowType) rvb.startStruct() rvb.addInt(j) @@ -1562,7 +1543,7 @@ case class TableJoin(left: TableIR, right: TableIR, joinType: String) extends Ta val leftValueFieldIdx = left.typ.valueFieldIdx val rightValueFieldIdx = right.typ.valueFieldIdx val localNewRowType = newRowType - val rvMerger: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it => + val rvMerger = { (ctx: RVDContext, it: Iterator[JoinedRegionValue]) => val rvb = new RegionValueBuilder() val rv = RegionValue() it.map { joined => @@ -1751,26 +1732,26 @@ case class TableExplode(child: TableIR, column: String) extends TableIR { assert(resultType == typ.rowType) TableValue(typ, prev.globals, - prev.rvd.mapPartitions(typ.rowType) { it => + prev.rvd.boundary.mapPartitions(typ.rowType, { (ctx, it) => val rv2 = RegionValue() - it.flatMap { rv => val isMissing = isMissingF()(rv.region, rv.offset, false) if (isMissing) Iterator.empty else { - val end = rv.region.size val n = lengthF()(rv.region, rv.offset, false) Iterator.range(0, n) .map { i => - rv.region.clear(end) - val off = explodeF()(rv.region, rv.offset, false, i, false) - rv2.set(rv.region, off) + ctx.rvb.start(childRowType) + ctx.rvb.addRegionValue(childRowType, rv) + val incomingRow = ctx.rvb.end() + val off = explodeF()(ctx.region, incomingRow, false, i, false) + rv2.set(ctx.region, off) rv2 } } } - }) + })) } } diff --git a/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala 
b/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala index fe862a0cdbab..c0b624223b1b 100644 --- a/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala +++ b/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala @@ -68,6 +68,6 @@ object ExtractAggregators { Code(codeArgs.map(_.setup): _*), AggOp.get(op, x.inputType, args.map(_.typ)) .stagedNew(codeArgs.map(_.v).toArray, codeArgs.map(_.m).toArray))) - constfb.result()()(Region()) + Region.scoped(constfb.result()()(_)) } } diff --git a/src/main/scala/is/hail/io/LoadMatrix.scala b/src/main/scala/is/hail/io/LoadMatrix.scala index 5cf481ae3641..d93a71eeb433 100644 --- a/src/main/scala/is/hail/io/LoadMatrix.scala +++ b/src/main/scala/is/hail/io/LoadMatrix.scala @@ -3,7 +3,8 @@ package is.hail.io import is.hail.HailContext import is.hail.annotations._ import is.hail.expr.types._ -import is.hail.rvd.{OrderedRVD, OrderedRVDPartitioner} +import is.hail.rvd.{OrderedRVD, OrderedRVDPartitioner, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant._ import org.apache.hadoop.conf.Configuration @@ -359,9 +360,9 @@ object LoadMatrix { rowPartitionKey = rowKey.toFastIndexedSeq, entryType = cellType) - val rdd = lines.filter(l => l.value.nonEmpty) - .mapPartitionsWithIndex { (i, it) => - val region = Region() + val rdd = ContextRDD.weaken[RVDContext](lines.filter(l => l.value.nonEmpty)) + .cmapPartitionsWithIndex { (i, ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) @@ -374,7 +375,6 @@ object LoadMatrix { val fileRowNum = partitionStartInFile + row val line = v.value - region.clear() rvb.start(matrixType.rvRowType) rvb.startStruct() if (useIndex) { diff --git a/src/main/scala/is/hail/io/RowStore.scala b/src/main/scala/is/hail/io/RowStore.scala index 56b0dae51e82..44e6aa41702f 100644 --- a/src/main/scala/is/hail/io/RowStore.scala +++ b/src/main/scala/is/hail/io/RowStore.scala @@ -4,8 +4,8 @@ import is.hail.annotations._ import is.hail.expr.JSONAnnotationImpex import is.hail.expr.types._ import is.hail.io.compress.LZ4Utils -import is.hail.rvd.{OrderedRVDPartitioner, OrderedRVDSpec, RVDSpec, UnpartitionedRVDSpec} -import is.hail.sparkextras.ContextRDD +import is.hail.rvd.{OrderedRVDPartitioner, OrderedRVDSpec, RVDContext, RVDSpec, UnpartitionedRVDSpec} +import is.hail.sparkextras._ import is.hail.utils._ import org.apache.spark.rdd.RDD import org.json4s.{Extraction, JValue} @@ -62,6 +62,10 @@ object CodecSpec { new LZ4BlockBufferSpec(32 * 1024, new StreamBlockBufferSpec)))) + val defaultUncompressed = new PackCodecSpec( + new BlockingBufferSpec(32 * 1024, + new StreamBlockBufferSpec)) + val blockSpecs: Array[BufferSpec] = Array( new BlockingBufferSpec(64 * 1024, new StreamBlockBufferSpec), @@ -850,13 +854,14 @@ final class PackEncoder(rowType: Type, out: OutputBuffer) extends Encoder { } object RichContextRDDRegionValue { - def writeRowsPartition(makeEnc: (OutputStream) => Encoder)(i: Int, it: Iterator[RegionValue], os: OutputStream): Long = { + def writeRowsPartition(makeEnc: (OutputStream) => Encoder)(ctx: RVDContext, it: Iterator[RegionValue], os: OutputStream): Long = { val en = makeEnc(os) var rowCount = 0L it.foreach { rv => en.writeByte(1) en.writeRegionValue(rv.region, rv.offset) + ctx.region.clear() rowCount += 1 } @@ -868,7 +873,7 @@ object RichContextRDDRegionValue { } } -class RichContextRDDRegionValue[C <: AutoCloseable](val crdd: ContextRDD[C, RegionValue]) extends AnyVal { +class RichContextRDDRegionValue(val crdd: 
ContextRDD[RVDContext, RegionValue]) extends AnyVal { def writeRows(path: String, t: TStruct, codecSpec: CodecSpec): (Array[String], Array[Long]) = { crdd.writePartitions(path, RichContextRDDRegionValue.writeRowsPartition(codecSpec.buildEncoder(t))) } @@ -892,11 +897,11 @@ class RichContextRDDRegionValue[C <: AutoCloseable](val crdd: ContextRDD[C, Regi val entriesRVType = TStruct( MatrixType.entriesIdentifier -> TArray(t.entryType)) - val makeRowsEnc = codecSpec.buildEncoder(rowsRVType)(_) + val makeRowsEnc = codecSpec.buildEncoder(rowsRVType) - val makeEntriesEnc = codecSpec.buildEncoder(entriesRVType)(_) + val makeEntriesEnc = codecSpec.buildEncoder(entriesRVType) - val partFilePartitionCounts = crdd.mapPartitionsWithIndex { case (i, it) => + val partFilePartitionCounts = crdd.cmapPartitionsWithIndex { (i, ctx, it) => val hConf = sHConfBc.value.value val f = partFile(d, i, TaskContext.get) @@ -940,6 +945,8 @@ class RichContextRDDRegionValue[C <: AutoCloseable](val crdd: ContextRDD[C, Regi rowsEN.writeByte(0) // end entriesEN.writeByte(0) + ctx.region.clear() + Iterator.single(f -> rowCount) } } diff --git a/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/src/main/scala/is/hail/io/bgen/LoadBgen.scala index 7cad671d81bd..0035318f315e 100644 --- a/src/main/scala/is/hail/io/bgen/LoadBgen.scala +++ b/src/main/scala/is/hail/io/bgen/LoadBgen.scala @@ -5,7 +5,8 @@ import is.hail.annotations._ import is.hail.expr.types._ import is.hail.io.vcf.LoadVCF import is.hail.io.{HadoopFSDataBinaryReader, IndexBTree} -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant._ import org.apache.hadoop.io.LongWritable @@ -103,8 +104,10 @@ object LoadBgen { val kType = matrixType.orvdType.kType val rowType = matrixType.rvRowType - val fastKeys = sc.union(results.map(_.rdd.mapPartitions { it => - val region = Region() + val crdds = results.map(x => ContextRDD.weaken[RVDContext](x.rdd)) + + val fastKeys = ContextRDD.union(sc, crdds.map(_.cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) @@ -112,7 +115,6 @@ object LoadBgen { val (contig, pos, alleles) = record.getKey val contigRecoded = contigRecoding.getOrElse(contig, contig) - region.clear() rvb.start(kType) rvb.startStruct() rvb.addAnnotation(kType.types(0), Locus.annotation(contigRecoded, pos, rg)) @@ -134,8 +136,8 @@ object LoadBgen { val loadEntries = entryFields.length > 0 - val rdd2 = sc.union(results.map(_.rdd.mapPartitions { it => - val region = Region() + val rdd2 = ContextRDD.union(sc, crdds.map(_.cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) @@ -145,7 +147,6 @@ object LoadBgen { val contigRecoded = contigRecoding.getOrElse(contig, contig) - region.clear() rvb.start(rowType) rvb.startStruct() rvb.addAnnotation(kType.types(0), Locus.annotation(contigRecoded, pos, rg)) diff --git a/src/main/scala/is/hail/io/plink/ExportPlink.scala b/src/main/scala/is/hail/io/plink/ExportPlink.scala index 103fa5b3aa77..774c8da86b87 100644 --- a/src/main/scala/is/hail/io/plink/ExportPlink.scala +++ b/src/main/scala/is/hail/io/plink/ExportPlink.scala @@ -76,7 +76,7 @@ object ExportPlink { val nSamples = mv.colValues.value.length val fullRowType = mv.typ.rvRowType - val nRecordsWritten = mv.rvd.mapPartitionsWithIndex { case (i, it) => + val nRecordsWritten = mv.rvd.mapPartitionsWithIndex { (i, ctx, it) => val hConf = 
sHConfBc.value.value val f = partFile(d, i, TaskContext.get) val bedPartPath = tmpBedDir + "/" + f @@ -97,7 +97,7 @@ object ExportPlink { hcv.setRegion(rv) ExportPlink.writeBedRow(hcv, bp, nSamples) - + ctx.region.clear() rowCount += 1 } } diff --git a/src/main/scala/is/hail/io/plink/LoadPlink.scala b/src/main/scala/is/hail/io/plink/LoadPlink.scala index 9e1af3de8e5f..119e1f18977b 100644 --- a/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -4,7 +4,8 @@ import is.hail.HailContext import is.hail.annotations._ import is.hail.expr.types._ import is.hail.io.vcf.LoadVCF -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils.StringEscapeUtils._ import is.hail.utils._ import is.hail.variant.{Locus, _} @@ -123,8 +124,13 @@ object LoadPlink { sc.hadoopConfiguration.setInt("nSamples", nSamples) sc.hadoopConfiguration.setBoolean("a2Reference", a2Reference) - val rdd = sc.hadoopFile(bedPath, classOf[PlinkInputFormat], classOf[LongWritable], classOf[PlinkRecord], - nPartitions.getOrElse(sc.defaultMinPartitions)) + val crdd = ContextRDD.weaken[RVDContext]( + sc.hadoopFile( + bedPath, + classOf[PlinkInputFormat], + classOf[LongWritable], + classOf[PlinkRecord], + nPartitions.getOrElse(sc.defaultMinPartitions))) val matrixType = MatrixType.fromParts( globalType = TStruct.empty(), @@ -138,15 +144,14 @@ object LoadPlink { val kType = matrixType.orvdType.kType val rvRowType = matrixType.rvRowType - val fastKeys = rdd.mapPartitions { it => - val region = Region() + val fastKeys = crdd.cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) it.map { case (_, record) => val (contig, pos, posMorgan, ref, alt, rsid) = variantsBc.value(record.getKey) - region.clear() rvb.start(kType) rvb.startStruct() rvb.addAnnotation(kType.types(0), Locus.annotation(contig, pos, rg)) @@ -161,15 +166,14 @@ object LoadPlink { } } - val rdd2 = rdd.mapPartitions { it => - val region = Region() + val rdd2 = crdd.cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) it.map { case (_, record) => val (contig, pos, posMorgan, ref, alt, rsid) = variantsBc.value(record.getKey) - region.clear() rvb.start(rvRowType) rvb.startStruct() rvb.addAnnotation(kType.types(0), Locus.annotation(contig, pos, rg)) diff --git a/src/main/scala/is/hail/io/vcf/LoadGDB.scala b/src/main/scala/is/hail/io/vcf/LoadGDB.scala index 861f7ae578f6..488b243addba 100644 --- a/src/main/scala/is/hail/io/vcf/LoadGDB.scala +++ b/src/main/scala/is/hail/io/vcf/LoadGDB.scala @@ -5,17 +5,19 @@ import htsjdk.variant.vcf.{VCFCompoundHeaderLine, VCFHeader} import is.hail.HailContext import is.hail.annotations._ import is.hail.utils._ +import is.hail.io._ import is.hail.variant.{MatrixTable, ReferenceGenome} import org.json4s._ import scala.collection.JavaConversions._ import scala.collection.JavaConverters.asScalaIteratorConverter -import java.io.{File, FileWriter} +import java.io._ import is.hail.expr.types._ import is.hail.io.VCFAttributes import is.hail.io.vcf.LoadVCF.headerSignature -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import org.apache.spark.sql.Row import scala.collection.mutable @@ -170,27 +172,42 @@ object LoadGDB { entryType = genotypeSignature) val localRowType = matrixType.rvRowType - val region = Region() - val rvb = 
new RegionValueBuilder(region) + val rvCodec = CodecSpec.defaultUncompressed - val records = gdbReader - .iterator - .asScala - .map { vc => + val records = Region.scoped { region => + val baos = new ByteArrayOutputStream() + val enc = rvCodec.buildEncoder(localRowType)(baos) + val rvb = new RegionValueBuilder(region) + gdbReader + .iterator + .asScala + .map { vc => rvb.clear() region.clear() rvb.start(localRowType) reader.readRecord(vc, rvb, infoSignature, genotypeSignature, dropSamples, canonicalFlags, infoFlagFieldNames) - rvb.result().copy() + enc.writeRegionValue(region, rvb.end()) + baos.toByteArray() }.toArray + } - val recordRDD = sc.parallelize(records, nPartitions.getOrElse(sc.defaultMinPartitions)) + val recordCRDD = ContextRDD.parallelize[RVDContext](sc, records, nPartitions) + .cmapPartitions { (ctx, it) => + val region = ctx.region + val rv = RegionValue(region) + it.map { bytes => + val bais = new ByteArrayInputStream(bytes) + val dec = rvCodec.buildDecoder(localRowType)(bais) + rv.setOffset(dec.readRegionValue(region)) + rv + } + } queryFile.delete() new MatrixTable(hc, matrixType, BroadcastRow(Row.empty, matrixType.globalType, sc), BroadcastIndexedSeq(sampleIds.map(x => Annotation(x)), TArray(matrixType.colType), sc), - OrderedRVD.coerce(matrixType.orvdType, hc.sc.parallelize(records), None, None)) + OrderedRVD.coerce(matrixType.orvdType, recordCRDD)) } } diff --git a/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/src/main/scala/is/hail/io/vcf/LoadVCF.scala index 35586aba31a9..fa70d03b0610 100644 --- a/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -5,7 +5,8 @@ import is.hail.HailContext import is.hail.annotations._ import is.hail.expr.types._ import is.hail.io.{VCFAttributes, VCFMetadata} -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant._ import org.apache.hadoop @@ -711,12 +712,18 @@ object LoadVCF { } // parses the Variant (key), leaves the rest to f - def parseLines[C](makeContext: () => C)(f: (C, VCFLine, RegionValueBuilder) => Unit)( - lines: RDD[WithContext[String]], t: Type, rg: Option[ReferenceGenome], contigRecoding: Map[String, String]): RDD[RegionValue] = { - lines.mapPartitions { it => + def parseLines[C]( + makeContext: () => C + )(f: (C, VCFLine, RegionValueBuilder) => Unit + )(lines: ContextRDD[RVDContext, WithContext[String]], + t: Type, + rg: Option[ReferenceGenome], + contigRecoding: Map[String, String] + ): ContextRDD[RVDContext, RegionValue] = { + lines.cmapPartitions { (ctx, it) => new Iterator[RegionValue] { - val region = Region() - val rvb = new RegionValueBuilder(region) + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) val context: C = makeContext() @@ -729,7 +736,6 @@ object LoadVCF { val line = lwc.value try { val vcfLine = new VCFLine(line) - region.clear() rvb.start(t) rvb.startStruct() present = vcfLine.parseAddVariant(rvb, rg, contigRecoding) @@ -845,7 +851,7 @@ object LoadVCF { val headerLinesBc = sc.broadcast(headerLines1) - val lines = sc.textFilesLines(files, nPartitions.getOrElse(sc.defaultMinPartitions)) + val lines = ContextRDD.textFilesLines[RVDContext](sc, files, nPartitions) val matrixType: MatrixType = MatrixType.fromParts( TStruct.empty(true), @@ -860,7 +866,13 @@ object LoadVCF { val rowType = matrixType.rvRowType // nothing after the key - val justVariants = parseLines(() => ())((c, l, rvb) => ())(lines, kType, rg, contigRecoding) + val 
justVariants = parseLines( + () => () + )((c, l, rvb) => () + )(ContextRDD.textFilesLines[RVDContext](sc, files, nPartitions), + kType, + rg, + contigRecoding) val rdd = OrderedRVD.coerce( matrixType.orvdType, @@ -895,7 +907,7 @@ object LoadVCF { } rvb.endArray() }(lines, rowType, rg, contigRecoding), - Some(justVariants), None) + justVariants) new MatrixTable(hc, matrixType, diff --git a/src/main/scala/is/hail/linalg/BlockMatrix.scala b/src/main/scala/is/hail/linalg/BlockMatrix.scala index e39ff2f00ea5..e3da2cc5c7cc 100644 --- a/src/main/scala/is/hail/linalg/BlockMatrix.scala +++ b/src/main/scala/is/hail/linalg/BlockMatrix.scala @@ -11,7 +11,8 @@ import is.hail.table.Table import is.hail.expr.types._ import is.hail.io.{BlockingBufferSpec, BufferSpec, LZ4BlockBufferSpec, StreamBlockBufferSpec} import is.hail.methods.UpperIndexBounds -import is.hail.rvd.RVD +import is.hail.rvd.{RVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.utils.richUtils.RichDenseMatrixDouble import org.apache.commons.math3.random.MersenneTwister @@ -247,7 +248,7 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], val hadoop = blocks.sparkContext.hadoopConfiguration hadoop.mkDir(uri) - def writeBlock(i: Int, it: Iterator[((Int, Int), BDM[Double])], os: OutputStream): Int = { + def writeBlock(it: Iterator[((Int, Int), BDM[Double])], os: OutputStream): Int = { assert(it.hasNext) val bdm = it.next()._2 assert(!it.hasNext) @@ -621,17 +622,16 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], case None => blocks } - val entriesRDD = rdd.flatMap { case ((blockRow, blockCol), block) => + val entriesRDD = ContextRDD.weaken[RVDContext](rdd).cflatMap { case (ctx, ((blockRow, blockCol), block)) => val rowOffset = blockRow * blockSize.toLong val colOffset = blockCol * blockSize.toLong - val region = Region() + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) block.activeIterator .map { case ((i, j), entry) => - region.clear() rvb.start(rvRowType) rvb.startStruct() rvb.addLong(rowOffset + i) @@ -1008,7 +1008,7 @@ case class WriteBlocksRDDPartition(index: Int, start: Int, skip: Int, end: Int) } class WriteBlocksRDD(path: String, - rdd: RDD[RegionValue], + crdd: ContextRDD[RVDContext, RegionValue], sc: SparkContext, matrixType: MatrixType, parentPartStarts: Array[Long], @@ -1017,7 +1017,7 @@ class WriteBlocksRDD(path: String, require(gp.nRows == parentPartStarts.last) - private val parentParts = rdd.partitions + private val parentParts = crdd.partitions private val blockSize = gp.blockSize private val d = digitsNeeded(gp.numPartitions) @@ -1025,7 +1025,7 @@ class WriteBlocksRDD(path: String, override def getDependencies: Seq[Dependency[_]] = Array[Dependency[_]]( - new NarrowDependency(rdd) { + new NarrowDependency(crdd.rdd) { def getParents(partitionId: Int): Seq[Int] = partitions(partitionId).asInstanceOf[WriteBlocksRDDPartition].range } @@ -1100,8 +1100,8 @@ class WriteBlocksRDD(path: String, val writeBlocksPart = split.asInstanceOf[WriteBlocksRDDPartition] val start = writeBlocksPart.start - writeBlocksPart.range.foreach { pi => - val it = rdd.iterator(parentParts(pi), context) + writeBlocksPart.range.foreach { pi => using(crdd.mkc()) { ctx => + val it = crdd.iterator(parentParts(pi), context, ctx) if (pi == start) { var j = 0 @@ -1147,7 +1147,7 @@ class WriteBlocksRDD(path: String, } i += 1 } - } + } } outPerBlockCol.foreach(_.close()) diff --git a/src/main/scala/is/hail/methods/FilterAlleles.scala 
b/src/main/scala/is/hail/methods/FilterAlleles.scala index 0407f589355d..161bea782500 100644 --- a/src/main/scala/is/hail/methods/FilterAlleles.scala +++ b/src/main/scala/is/hail/methods/FilterAlleles.scala @@ -104,15 +104,14 @@ object FilterAlleles { val localSampleAnnotationsBc = vsm.colValues.broadcast - rdd.mapPartitions(newRVType) { it => + rdd.boundary.mapPartitions(newRVType, { (ctx, it) => var prevLocus: Locus = null val fullRow = new UnsafeRow(fullRowType) val rvv = new RegionValueVariant(fullRowType) + val rvb = ctx.rvb + val rv2 = RegionValue(ctx.region) it.flatMap { rv => - val rvb = new RegionValueBuilder() - val rv2 = RegionValue() - rvv.setRegion(rv) fullRow.set(rv) @@ -132,7 +131,6 @@ object FilterAlleles { || (!isLeftAligned && removeMoving)) None else { - rvb.set(rv.region) rvb.start(newRVType) rvb.startStruct() @@ -162,13 +160,13 @@ object FilterAlleles { } rvb.endArray() rvb.endStruct() - rv2.set(rv.region, rvb.end()) + rv2.setOffset(rvb.end()) Some(rv2) } } } - } + }) } val newRDD2: OrderedRVD = diff --git a/src/main/scala/is/hail/methods/FilterIntervals.scala b/src/main/scala/is/hail/methods/FilterIntervals.scala index 84fdefea1cbf..971cfcc3ecc2 100644 --- a/src/main/scala/is/hail/methods/FilterIntervals.scala +++ b/src/main/scala/is/hail/methods/FilterIntervals.scala @@ -22,15 +22,15 @@ object FilterIntervals { val pkRowFieldIdx = vsm.rvd.typ.pkRowFieldIdx val rowType = vsm.rvd.typ.rowType - vsm.copy2(rvd = vsm.rvd.mapPartitionsPreservesPartitioning(vsm.rvd.typ) { it => - val pk = WritableRegionValue(pkType) + vsm.copy2(rvd = vsm.rvd.mapPartitionsPreservesPartitioning(vsm.rvd.typ, { (ctx, it) => val pkUR = new UnsafeRow(pkType) it.filter { rv => - pk.setSelect(rowType, pkRowFieldIdx, rv) - pkUR.set(pk.value) + ctx.rvb.start(pkType) + ctx.rvb.selectRegionValue(rowType, pkRowFieldIdx, rv) + pkUR.set(ctx.region, ctx.rvb.end()) !intervalsBc.value.contains(pkType.ordering, pkUR) } - }) + })) } } } diff --git a/src/main/scala/is/hail/methods/IBD.scala b/src/main/scala/is/hail/methods/IBD.scala index e51a52832a4a..4fe53ad59856 100644 --- a/src/main/scala/is/hail/methods/IBD.scala +++ b/src/main/scala/is/hail/methods/IBD.scala @@ -5,6 +5,8 @@ import is.hail.expr.EvalContext import is.hail.table.Table import is.hail.annotations._ import is.hail.expr.types._ +import is.hail.rvd.RVDContext +import is.hail.sparkextras.ContextRDD import is.hail.variant.{Call, Genotype, HardCallView, MatrixTable} import is.hail.stats.RegressionUtils import org.apache.spark.rdd.RDD @@ -207,7 +209,7 @@ object IBD { min: Option[Double], max: Option[Double], sampleIds: IndexedSeq[String], - bounded: Boolean): RDD[RegionValue] = { + bounded: Boolean): ContextRDD[RVDContext, RegionValue] = { val nSamples = vds.numCols @@ -253,7 +255,7 @@ object IBD { }) .map { case ((s, v), gs) => (v, (s, IBSFFI.pack(chunkSize, chunkSize, gs))) } - chunkedGenotypeMatrix.join(chunkedGenotypeMatrix) + val joined = ContextRDD.weaken[RVDContext](chunkedGenotypeMatrix.join(chunkedGenotypeMatrix) // optimization: Ignore chunks below the diagonal .filter { case (_, ((i, _), (j, _))) => j >= i } .map { case (_, ((s1, gs1), (s2, gs2))) => @@ -266,9 +268,11 @@ object IBD { i += 1 } a - } - .mapPartitions { it => - val region = Region() + }) + + joined + .cmapPartitions { (ctx, it) => + val region = ctx.region val rv = RegionValue(region) val rvb = new RegionValueBuilder(region) for { @@ -282,7 +286,6 @@ object IBD { eibd = calculateIBDInfo(ibses(idx * 3), ibses(idx * 3 + 1), ibses(idx * 3 + 2), ibse, bounded) if 
min.forall(eibd.ibd.PI_HAT >= _) && max.forall(eibd.ibd.PI_HAT <= _) } yield { - region.clear() rvb.start(ibdSignature) rvb.startStruct() rvb.addString(sampleIds(i)) @@ -292,7 +295,7 @@ object IBD { rv.setOffset(rvb.end()) rv } - } + } } def apply(vds: MatrixTable, diff --git a/src/main/scala/is/hail/methods/LinearRegression.scala b/src/main/scala/is/hail/methods/LinearRegression.scala index f51230dc3623..fdc54acc1d54 100644 --- a/src/main/scala/is/hail/methods/LinearRegression.scala +++ b/src/main/scala/is/hail/methods/LinearRegression.scala @@ -60,10 +60,10 @@ object LinearRegression { val newMatrixType = vsm.matrixType.copy(rvRowType = newRVType) - val newRDD2 = vsm.rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType) { it => - - val region2 = Region() - val rvb = new RegionValueBuilder(region2) + val newRDD2 = vsm.rvd.boundary.mapPartitionsPreservesPartitioning( + newMatrixType.orvdType, { (ctx, it) => + val region2 = ctx.region + val rvb = ctx.rvb val rv2 = RegionValue(region2) val missingCompleteCols = new ArrayBuilder[Int] @@ -136,7 +136,7 @@ object LinearRegression { rv2 } } - } + }) vsm.copyMT(matrixType = newMatrixType, rvd = newRDD2) diff --git a/src/main/scala/is/hail/methods/LocalLDPrune.scala b/src/main/scala/is/hail/methods/LocalLDPrune.scala index eb486b071897..85b659b68226 100644 --- a/src/main/scala/is/hail/methods/LocalLDPrune.scala +++ b/src/main/scala/is/hail/methods/LocalLDPrune.scala @@ -319,13 +319,13 @@ object LocalLDPrune { } }) - val rddLP = pruneLocal(standardizedRDD, r2Threshold, windowSize, Some(maxQueueSize)) + val rvdLP = pruneLocal(standardizedRDD, r2Threshold, windowSize, Some(maxQueueSize)) val tableType = TableType( rowType = mt.rowKeyStruct ++ TStruct("mean" -> TFloat64Required, "centered_length_sd_reciprocal" -> TFloat64Required), key = Some(mt.rowKey), globalType = TStruct.empty()) - val sitesOnlyRDD = rddLP.mapPartitionsPreservesPartitioning( + val sitesOnly = rvdLP.mapPartitionsPreservesPartitioning( new OrderedRVDType(typ.partitionKey, typ.key, tableType.rowType))({ it => val region = Region() @@ -345,7 +345,7 @@ object LocalLDPrune { } }) - new Table(hc = mt.hc, rdd = sitesOnlyRDD.rdd, signature = tableType.rowType, key = tableType.key) + new Table(hc = mt.hc, crdd = sitesOnly.crdd, signature = tableType.rowType, key = tableType.key) } } diff --git a/src/main/scala/is/hail/methods/Nirvana.scala b/src/main/scala/is/hail/methods/Nirvana.scala index 27f9a1a62e25..b2c22fdb10ce 100644 --- a/src/main/scala/is/hail/methods/Nirvana.scala +++ b/src/main/scala/is/hail/methods/Nirvana.scala @@ -6,7 +6,8 @@ import java.util.Properties import is.hail.annotations._ import is.hail.expr.types._ import is.hail.expr.{JSONAnnotationImpex, Parser} -import is.hail.rvd.{OrderedRVD, OrderedRVDType} +import is.hail.rvd.{OrderedRVD, OrderedRVDType, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant.{Locus, MatrixTable, RegionValueVariant} import org.apache.spark.sql.Row @@ -327,8 +328,8 @@ object Nirvana { val nirvanaRVD: OrderedRVD = OrderedRVD( nirvanaORVDType, vds.rvd.partitioner, - annotations.mapPartitions { it => - val region = Region() + ContextRDD.weaken[RVDContext](annotations).cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) diff --git a/src/main/scala/is/hail/methods/PCA.scala b/src/main/scala/is/hail/methods/PCA.scala index 57ce218dec77..d175a2f6a5ab 100644 --- a/src/main/scala/is/hail/methods/PCA.scala +++ 
b/src/main/scala/is/hail/methods/PCA.scala @@ -3,6 +3,8 @@ package is.hail.methods import breeze.linalg.{*, DenseMatrix, DenseVector} import is.hail.annotations._ import is.hail.expr.types._ +import is.hail.rvd.RVDContext +import is.hail.sparkextras.ContextRDD import is.hail.table.Table import is.hail.utils._ import is.hail.variant.MatrixTable @@ -23,14 +25,13 @@ object PCA { val scoresBc = sc.broadcast(scores) val localSSignature = vsm.colKeyTypes - val scoresRDD = sc.parallelize(vsm.colKeys.zipWithIndex).mapPartitions[RegionValue] { it => - val region = Region() + val scoresRDD = ContextRDD.weaken[RVDContext](sc.parallelize(vsm.colKeys.zipWithIndex)).cmapPartitions { (ctx, it) => + val region = ctx.region val rv = RegionValue(region) val rvb = new RegionValueBuilder(region) val localRowType = rowTypeBc.value it.map { case (s, i) => - region.clear() rvb.start(localRowType) rvb.startStruct() var j = 0 @@ -87,18 +88,17 @@ object PCA { }.collect() } - val optionLoadings = someIf(computeLoadings, { + val optionLoadings = if (computeLoadings) { val rowType = TStruct(vsm.rowKey.zip(vsm.rowKeyTypes): _*) ++ TStruct("loadings" -> TArray(TFloat64())) val rowTypeBc = vsm.sparkContext.broadcast(rowType) val rowKeysBc = vsm.sparkContext.broadcast(collectRowKeys()) val localRowKeySignature = vsm.rowKeyTypes - val rdd = svd.U.rows.mapPartitions[RegionValue] { it => - val region = Region() + val rdd = ContextRDD.weaken[RVDContext](svd.U.rows).cmapPartitions { (ctx, it) => + val region = ctx.region val rv = RegionValue(region) val rvb = new RegionValueBuilder(region) it.map { ir => - region.clear() rvb.start(rowTypeBc.value) rvb.startStruct() @@ -121,8 +121,10 @@ object PCA { rv } } - new Table(vsm.hc, rdd, rowType, Some(vsm.rowKey)) - }) + Some(new Table(vsm.hc, rdd, rowType, Some(vsm.rowKey))) + } else { + None + } val data = if (!svd.V.isTransposed) diff --git a/src/main/scala/is/hail/methods/PCRelate.scala b/src/main/scala/is/hail/methods/PCRelate.scala index 0eaa85e6212c..f28898adc44a 100644 --- a/src/main/scala/is/hail/methods/PCRelate.scala +++ b/src/main/scala/is/hail/methods/PCRelate.scala @@ -168,7 +168,7 @@ object PCRelate { val localRowType = vds.rvRowType val partStarts = vds.partitionStarts() val partStartsBc = vds.sparkContext.broadcast(partStarts) - val rdd = vds.rvd.mapPartitionsWithIndex { case (partIdx, it) => + val rdd = vds.rvd.mapPartitionsWithIndex { (partIdx, it) => val view = HardCallView(localRowType) val missingIndices = new ArrayBuilder[Int]() diff --git a/src/main/scala/is/hail/methods/Skat.scala b/src/main/scala/is/hail/methods/Skat.scala index c8b3b1e3f45a..0007d6037639 100644 --- a/src/main/scala/is/hail/methods/Skat.scala +++ b/src/main/scala/is/hail/methods/Skat.scala @@ -227,8 +227,8 @@ object Skat { val n = completeColIdx.length val completeColIdxBc = sc.broadcast(completeColIdx) - - (vsm.rvd.rdd.flatMap { rv => + + (vsm.rvd.boundary.mapPartitions { it => it.flatMap { rv => val keyIsDefined = fullRowType.isFieldDefined(rv, keyIndex) val weightIsDefined = fullRowType.isFieldDefined(rv, weightIndex) @@ -243,6 +243,7 @@ object Skat { rv, fullRowType, entryArrayType, entryType, entryArrayIdx, fieldIdx) Some(key -> (BDV(data), weight)) } else None + } }.groupByKey(), keyType) } diff --git a/src/main/scala/is/hail/methods/SplitMulti.scala b/src/main/scala/is/hail/methods/SplitMulti.scala index 88b5fdd2b40a..d0707493cbaa 100644 --- a/src/main/scala/is/hail/methods/SplitMulti.scala +++ b/src/main/scala/is/hail/methods/SplitMulti.scala @@ -5,10 +5,10 @@ import 
is.hail.asm4s.AsmFunction13 import is.hail.expr._ import is.hail.expr.ir._ import is.hail.expr.types._ -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row class ExprAnnotator(val ec: EvalContext, t: TStruct, expr: String, head: Option[String]) extends Serializable { @@ -66,37 +66,26 @@ class SplitMultiRowIR(rowIRs: Array[(String, IR)], entryIRs: Array[(String, IR)] class SplitMultiPartitionContextIR( keepStar: Boolean, nSamples: Int, globalAnnotation: Annotation, matrixType: MatrixType, - rowF: () => AsmFunction13[Region, Long, Boolean, Long, Boolean, Long, Boolean, Long, Boolean, Int, Boolean, Boolean, Boolean, Long], newRVRowType: TStruct) extends - SplitMultiPartitionContext(keepStar, nSamples, globalAnnotation, matrixType, newRVRowType) { - - private var globalsCopied = false - private var globals: Long = 0 - private var globalsEnd: Long = 0 - - private def copyGlobals() { - splitRegion.clear() - rvb.set(splitRegion) - rvb.start(matrixType.globalType) - rvb.addAnnotation(matrixType.globalType, globalAnnotation) - globals = rvb.end() - globalsEnd = splitRegion.size - globalsCopied = true - } + rowF: () => AsmFunction13[Region, Long, Boolean, Long, Boolean, Long, Boolean, Long, Boolean, Int, Boolean, Boolean, Boolean, Long], + newRVRowType: TStruct, + region: Region +) extends + SplitMultiPartitionContext(keepStar, nSamples, globalAnnotation, matrixType, newRVRowType, region) { private val allelesType = matrixType.rowType.fieldByName("alleles").typ private val locusType = matrixType.rowType.fieldByName("locus").typ val f = rowF() def constructSplitRow(splitVariants: Iterator[(Locus, IndexedSeq[String], Int)], rv: RegionValue, wasSplit: Boolean): Iterator[RegionValue] = { - if (!globalsCopied) - copyGlobals() - splitRegion.clear(globalsEnd) - rvb.start(matrixType.rvRowType) - rvb.addRegionValue(matrixType.rvRowType, rv) - val oldRow = rvb.end() - val oldEnd = splitRegion.size splitVariants.map { case (newLocus, newAlleles, aIndex) => - splitRegion.clear(oldEnd) + rvb.set(splitRegion) + rvb.start(matrixType.globalType) + rvb.addAnnotation(matrixType.globalType, globalAnnotation) + val globals = rvb.end() + + rvb.start(matrixType.rvRowType) + rvb.addRegionValue(matrixType.rvRowType, rv) + val oldRow = rvb.end() rvb.start(locusType) rvb.addAnnotation(locusType, newLocus) @@ -115,9 +104,15 @@ class SplitMultiPartitionContextIR( class SplitMultiPartitionContextAST( keepStar: Boolean, - nSamples: Int, globalAnnotation: Annotation, matrixType: MatrixType, - vAnnotator: ExprAnnotator, gAnnotator: ExprAnnotator, newRVRowType: TStruct) extends - SplitMultiPartitionContext(keepStar, nSamples, globalAnnotation, matrixType, newRVRowType) { + nSamples: Int, + globalAnnotation: Annotation, + matrixType: MatrixType, + vAnnotator: ExprAnnotator, + gAnnotator: ExprAnnotator, + newRVRowType: TStruct, + region: Region +) extends + SplitMultiPartitionContext(keepStar, nSamples, globalAnnotation, matrixType, newRVRowType, region) { val (t1, locusInserter) = vAnnotator.newT.insert(matrixType.rowType.fieldByName("locus").typ, "locus") assert(t1 == vAnnotator.newT) @@ -127,7 +122,6 @@ class SplitMultiPartitionContextAST( def constructSplitRow(splitVariants: Iterator[(Locus, IndexedSeq[String], Int)], rv: RegionValue, wasSplit: Boolean): Iterator[RegionValue] = { val gs = fullRow.getAs[IndexedSeq[Any]](matrixType.entriesIdx) 
splitVariants.map { case (newLocus, newAlleles, i) => - splitRegion.clear() rvb.set(splitRegion) rvb.start(newRVRowType) rvb.startStruct() @@ -165,12 +159,16 @@ class SplitMultiPartitionContextAST( abstract class SplitMultiPartitionContext( keepStar: Boolean, - nSamples: Int, globalAnnotation: Annotation, matrixType: MatrixType, newRVRowType: TStruct) extends Serializable { + nSamples: Int, + globalAnnotation: Annotation, + matrixType: MatrixType, + newRVRowType: TStruct, + val splitRegion: Region +) extends Serializable { var fullRow = new UnsafeRow(matrixType.rvRowType) var prevLocus: Locus = null val rvv = new RegionValueVariant(matrixType.rvRowType) - val splitRegion = Region() val rvb = new RegionValueBuilder() val splitrv = RegionValue() val locusAllelesOrdering = matrixType.rowKeyStruct.ordering @@ -232,10 +230,14 @@ object SplitMulti { splitmulti.split() } - def unionMovedVariants(ordered: OrderedRVD, - moved: RDD[RegionValue]): OrderedRVD = { - val movedRVD = OrderedRVD.adjustBoundsAndShuffle(ordered.typ, - ordered.partitioner, moved) + def unionMovedVariants( + ordered: OrderedRVD, + moved: RVD + ): OrderedRVD = { + val movedRVD = OrderedRVD.adjustBoundsAndShuffle( + ordered.typ, + ordered.partitioner, + moved) ordered.copy(orderedPartitioner = movedRVD.partitioner).partitionSortedUnion(movedRVD) } @@ -284,7 +286,7 @@ class SplitMulti(vsm: MatrixTable, variantExpr: String, genotypeExpr: String, ke (t, true, null) } - def split(sortAlleles: Boolean, removeLeftAligned: Boolean, removeMoving: Boolean, verifyLeftAligned: Boolean): RDD[RegionValue] = { + def split(sortAlleles: Boolean, removeLeftAligned: Boolean, removeMoving: Boolean, verifyLeftAligned: Boolean): RVD = { val localKeepStar = keepStar val globalsBc = vsm.globals.broadcast val localNSamples = vsm.numCols @@ -298,27 +300,22 @@ class SplitMulti(vsm: MatrixTable, variantExpr: String, genotypeExpr: String, ke val locusIndex = localRowType.fieldIdx("locus") - if (useAST) { - vsm.rvd.mapPartitions { it => - val context = new SplitMultiPartitionContextAST(localKeepStar, localNSamples, globalsBc.value, - localMatrixType, localVAnnotator, localGAnnotator, newRowType) - it.flatMap { rv => - val splitit = context.splitRow(rv, sortAlleles, removeLeftAligned, removeMoving, verifyLeftAligned) - context.prevLocus = context.fullRow.getAs[Locus](locusIndex) - splitit - } - } + val makeContext = if (useAST) { + (region: Region) => new SplitMultiPartitionContextAST(localKeepStar, localNSamples, globalsBc.value, + localMatrixType, localVAnnotator, localGAnnotator, newRowType, region) } else { - vsm.rvd.mapPartitions { it => - val context = new SplitMultiPartitionContextIR(localKeepStar, localNSamples, globalsBc.value, - localMatrixType, localSplitRow, newRowType) - it.flatMap { rv => - val splitit = context.splitRow(rv, sortAlleles, removeLeftAligned, removeMoving, verifyLeftAligned) - context.prevLocus = context.fullRow.getAs[Locus](locusIndex) - splitit - } - } + (region: Region) => new SplitMultiPartitionContextIR(localKeepStar, localNSamples, globalsBc.value, + localMatrixType, localSplitRow, newRowType, region) } + + vsm.rvd.boundary.mapPartitions(newRowType, { (ctx, it) => + val splitMultiContext = makeContext(ctx.region) + it.flatMap { rv => + val splitit = splitMultiContext.splitRow(rv, sortAlleles, removeLeftAligned, removeMoving, verifyLeftAligned) + splitMultiContext.prevLocus = splitMultiContext.fullRow.getAs[Locus](locusIndex) + splitit + } + }) } def split(): MatrixTable = { diff --git 
a/src/main/scala/is/hail/methods/VEP.scala b/src/main/scala/is/hail/methods/VEP.scala index c68b23f2581c..e742acad839e 100644 --- a/src/main/scala/is/hail/methods/VEP.scala +++ b/src/main/scala/is/hail/methods/VEP.scala @@ -6,7 +6,8 @@ import java.util.Properties import is.hail.annotations.{Annotation, Region, RegionValue, RegionValueBuilder} import is.hail.expr._ import is.hail.expr.types._ -import is.hail.rvd.{OrderedRVD, OrderedRVDType} +import is.hail.rvd.{OrderedRVD, OrderedRVDType, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant.{Locus, MatrixTable, RegionValueVariant, VariantMethods} import org.apache.spark.sql.Row @@ -336,9 +337,9 @@ object VEP { val vepRVD: OrderedRVD = OrderedRVD( vepORVDType, vsm.rvd.partitioner, - annotations.mapPartitions { it => - val region = Region() - val rvb = new RegionValueBuilder(region) + ContextRDD.weaken[RVDContext](annotations).cmapPartitions { (ctx, it) => + val region = ctx.region + val rvb = ctx.rvb val rv = RegionValue(region) it.map { case (v, vep) => @@ -351,7 +352,8 @@ object VEP { rv.setOffset(rvb.end()) rv - }}) + } + }) info(s"vep: annotated ${ annotations.count() } variants") diff --git a/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala b/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala index 0fdad405407f..4643e9992301 100644 --- a/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala +++ b/src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala @@ -5,6 +5,8 @@ import is.hail.sparkextras._ import is.hail.utils.fatal import org.apache.spark.rdd.RDD +import scala.collection.generic.Growable + class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { val typ: OrderedRVDType = rvd.typ val (kType, _) = rvd.rowType.select(key) @@ -22,7 +24,7 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { def orderedJoin( right: KeyedOrderedRVD, joinType: String, - joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: OrderedRVDType ): OrderedRVD = { checkJoinCompatability(right) @@ -35,28 +37,34 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { this.rvd.constrainToOrderedPartitioner(this.typ, newPartitioner) val repartitionedRight = right.rvd.constrainToOrderedPartitioner(right.typ, newPartitioner) - val compute: (OrderedRVIterator, OrderedRVIterator) => Iterator[JoinedRegionValue] = + val compute: (OrderedRVIterator, OrderedRVIterator, Iterable[RegionValue] with Growable[RegionValue]) => Iterator[JoinedRegionValue] = (joinType: @unchecked) match { - case "inner" => _.innerJoin(_) - case "left" => _.leftJoin(_) - case "right" => _.rightJoin(_) - case "outer" => _.outerJoin(_) + case "inner" => _.innerJoin(_, _) + case "left" => _.leftJoin(_, _) + case "right" => _.rightJoin(_, _) + case "outer" => _.outerJoin(_, _) } - val joinedRDD = - repartitionedLeft.crdd.zipPartitions(repartitionedRight.crdd, true) { - (leftIt, rightIt) => - joiner(compute( - OrderedRVIterator(lTyp, leftIt), - OrderedRVIterator(rTyp, rightIt))) - } - new OrderedRVD(joinedType, newPartitioner, joinedRDD) + repartitionedLeft.zipPartitions( + joinedType, + newPartitioner, + repartitionedRight, + preservesPartitioning = true + ) { (ctx, leftIt, rightIt) => + val sideBuffer = ctx.freshContext.region + joiner( + ctx, + compute( + OrderedRVIterator(lTyp, leftIt), + OrderedRVIterator(rTyp, rightIt), + new RegionValueArrayBuffer(rTyp.rowType, sideBuffer))) + } } def orderedJoinDistinct( right: 
KeyedOrderedRVD, joinType: String, - joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: OrderedRVDType ): OrderedRVD = { checkJoinCompatability(right) @@ -72,15 +80,19 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { case "inner" => _.innerJoinDistinct(_) case "left" => _.leftJoinDistinct(_) } - val joinedRDD = - this.rvd.crdd.zipPartitions(repartitionedRight.crdd, true) { - (leftIt, rightIt) => - joiner(compute( - OrderedRVIterator(rekeyedLTyp, leftIt), - OrderedRVIterator(rekeyedRTyp, rightIt))) - } - new OrderedRVD(joinedType, newPartitioner, joinedRDD) + rvd.zipPartitions( + joinedType, + newPartitioner, + repartitionedRight, + preservesPartitioning = true + ) { (ctx, leftIt, rightIt) => + joiner( + ctx, + compute( + OrderedRVIterator(rekeyedLTyp, leftIt), + OrderedRVIterator(rekeyedRTyp, rightIt))) + } } def orderedZipJoin(right: KeyedOrderedRVD): ContextRDD[RVDContext, JoinedRegionValue] = { @@ -91,7 +103,10 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) { val leftType = this.typ val rightType = right.typ - repartitionedLeft.crdd.zipPartitions(repartitionedRight.crdd, true){ (leftIt, rightIt) => + repartitionedLeft.zipPartitions( + repartitionedRight, + preservesPartitioning = true + ) { (_, leftIt, rightIt) => OrderedRVIterator(leftType, leftIt).zipJoin(OrderedRVIterator(rightType, rightIt)) } } diff --git a/src/main/scala/is/hail/rvd/OrderedRVD.scala b/src/main/scala/is/hail/rvd/OrderedRVD.scala index eca65d8f6396..1d0b08ad842a 100644 --- a/src/main/scala/is/hail/rvd/OrderedRVD.scala +++ b/src/main/scala/is/hail/rvd/OrderedRVD.scala @@ -1,5 +1,6 @@ package is.hail.rvd +import java.io.ByteArrayInputStream import java.util import is.hail.annotations._ @@ -8,6 +9,7 @@ import is.hail.expr.types._ import is.hail.io.CodecSpec import is.hail.sparkextras._ import is.hail.utils._ +import org.apache.commons.io.output.ByteArrayOutputStream import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.SparkContext @@ -23,38 +25,45 @@ class OrderedRVD( ) extends RVD { self => - def this( - typ: OrderedRVDType, - partitioner: OrderedRVDPartitioner, - rdd: RDD[RegionValue] - ) = this(typ, partitioner, ContextRDD.weaken[RVDContext](rdd)) - - val rdd = crdd.run + def boundary: OrderedRVD = OrderedRVD(typ, partitioner, crddBoundary) def rowType: TStruct = typ.rowType def updateType(newTyp: OrderedRVDType): OrderedRVD = - OrderedRVD(newTyp, partitioner, rdd) + OrderedRVD(newTyp, partitioner, crdd) def mapPreservesPartitioning(newTyp: OrderedRVDType)(f: (RegionValue) => RegionValue): OrderedRVD = OrderedRVD(newTyp, partitioner, - rdd.map(f)) + crdd.map(f)) def mapPartitionsWithIndexPreservesPartitioning(newTyp: OrderedRVDType)(f: (Int, Iterator[RegionValue]) => Iterator[RegionValue]): OrderedRVD = OrderedRVD(newTyp, partitioner, - rdd.mapPartitionsWithIndex(f)) + crdd.mapPartitionsWithIndex(f)) + + def mapPartitionsWithIndexPreservesPartitioning( + newTyp: OrderedRVDType, + f: (Int, RVDContext, Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = OrderedRVD( + newTyp, + partitioner, + crdd.cmapPartitionsWithIndex(f)) def mapPartitionsPreservesPartitioning(newTyp: OrderedRVDType)(f: (Iterator[RegionValue]) => Iterator[RegionValue]): OrderedRVD = OrderedRVD(newTyp, partitioner, - rdd.mapPartitions(f)) + crdd.mapPartitions(f)) + + def mapPartitionsPreservesPartitioning( + newTyp: 
OrderedRVDType, + f: (RVDContext, Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = OrderedRVD(newTyp, partitioner, crdd.cmapPartitions(f)) override def filter(p: (RegionValue) => Boolean): OrderedRVD = OrderedRVD(typ, partitioner, - rdd.filter(p)) + crdd.filter(p)) def sample(withReplacement: Boolean, p: Double, seed: Long): OrderedRVD = OrderedRVD(typ, partitioner, crdd.sample(withReplacement, p, seed)) @@ -105,7 +114,7 @@ class OrderedRVD( def orderedJoin( right: OrderedRVD, joinType: String, - joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: OrderedRVDType ): OrderedRVD = keyBy().orderedJoin(right.keyBy(), joinType, joiner, joinedType) @@ -113,7 +122,7 @@ class OrderedRVD( def orderedJoinDistinct( right: OrderedRVD, joinType: String, - joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue], + joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: OrderedRVDType ): OrderedRVD = keyBy().orderedJoinDistinct(right.keyBy(), joinType, joiner, joinedType) @@ -126,15 +135,16 @@ class OrderedRVD( assert(partitioner == rdd2.partitioner) val localTyp = typ - OrderedRVD(typ, partitioner, - rdd.zipPartitions(rdd2.rdd) { case (it, it2) => - new Iterator[RegionValue] { - private val bit = it.buffered - private val bit2 = it2.buffered + zipPartitions(typ, partitioner, rdd2) { (ctx, it, it2) => + new Iterator[RegionValue] { + private val bit = it.buffered + private val bit2 = it2.buffered + private val rv = RegionValue() - def hasNext: Boolean = bit.hasNext || bit2.hasNext + def hasNext: Boolean = bit.hasNext || bit2.hasNext - def next(): RegionValue = { + def next(): RegionValue = { + val old = if (!bit.hasNext) bit2.next() else if (!bit2.hasNext) @@ -146,21 +156,25 @@ class OrderedRVD( else bit2.next() } - } + ctx.rvb.start(localTyp.rowType) + ctx.rvb.addRegionValue(localTyp.rowType, old) + rv.set(ctx.region, ctx.rvb.end()) + rv } - }) + } + } } def copy(typ: OrderedRVDType = typ, orderedPartitioner: OrderedRVDPartitioner = partitioner, - rdd: RDD[RegionValue] = rdd): OrderedRVD = { + rdd: ContextRDD[RVDContext, RegionValue] = crdd): OrderedRVD = { OrderedRVD(typ, orderedPartitioner, rdd) } def blockCoalesce(partitionEnds: Array[Int]): OrderedRVD = { assert(partitionEnds.last == partitioner.numPartitions - 1 && partitionEnds(0) >= 0) assert(partitionEnds.zip(partitionEnds.tail).forall { case (i, inext) => i < inext }) - OrderedRVD(typ, partitioner.coalesceRangeBounds(partitionEnds), new BlockedRDD(rdd, partitionEnds)) + OrderedRVD(typ, partitioner.coalesceRangeBounds(partitionEnds), crdd.blocked(partitionEnds)) } def naiveCoalesce(maxPartitions: Int): OrderedRVD = { @@ -176,11 +190,11 @@ class OrderedRVD( override def coalesce(maxPartitions: Int, shuffle: Boolean): OrderedRVD = { require(maxPartitions > 0, "cannot coalesce to nPartitions <= 0") - val n = rdd.partitions.length + val n = crdd.partitions.length if (!shuffle && maxPartitions >= n) return this if (shuffle) { - val shuffled = crdd.coalesce(maxPartitions, shuffle = true) + val shuffled = stably(_.shuffleCoalesce(maxPartitions)) val ranges = OrderedRVD.calculateKeyRanges( typ, OrderedRVD.getPartitionKeyInfo(typ, OrderedRVD.getKeys(typ, shuffled)), @@ -191,7 +205,7 @@ class OrderedRVD( shuffled) } else { - val partSize = rdd.context.runJob(rdd, getIteratorSize _) + val partSize = countPerPartition() log.info(s"partSize = ${ partSize.toSeq }") val partCumulativeSize = mapAccumulate[Array, 
Long](partSize, 0L)((s, acc) => (s + acc, s + acc)) @@ -225,7 +239,7 @@ class OrderedRVD( def filterIntervals(intervals: IntervalTree[_]): OrderedRVD = { val pkOrdering = typ.pkType.ordering - val intervalsBc = rdd.sparkContext.broadcast(intervals) + val intervalsBc = crdd.sparkContext.broadcast(intervals) val rowType = typ.rowType val pkRowFieldIdx = typ.pkRowFieldIdx @@ -258,7 +272,7 @@ class OrderedRVD( OrderedRVD.empty(sparkContext, typ) else { val sub = subsetPartitions(newPartitionIndices) - sub.copy(rdd = sub.rdd.filter(pred)) + sub.copy(rdd = sub.crdd.filter(pred)) } } @@ -268,9 +282,9 @@ class OrderedRVD( if (n == 0) return OrderedRVD.empty(sparkContext, typ) - val newRDD = rdd.head(n) + val newRDD = crdd.head(n) val newNParts = newRDD.getNumPartitions - assert(newNParts > 0) + assert(newNParts >= 0) val newRangeBounds = Array.range(0, newNParts).map(partitioner.rangeBounds) val newPartitioner = new OrderedRVDPartitioner(partitioner.partitionKey, @@ -289,16 +303,21 @@ class OrderedRVD( val localType = typ - val newRDD: RDD[RegionValue] = rdd.mapPartitions { it => - val region = Region() - val rvb = new RegionValueBuilder(region) - val outRV = RegionValue(region) - val buffer = new RegionValueArrayBuffer(localType.valueType) - val stepped: FlipbookIterator[FlipbookIterator[RegionValue]] = - OrderedRVIterator(localType, it).staircase + OrderedRVD(newTyp, partitioner, crdd.cmapPartitionsAndContext { (consumerCtx, useCtxes) => + val consumerRegion = consumerCtx.region + val rvb = consumerCtx.rvb + val outRV = RegionValue(consumerRegion) + + val bufferRegion = consumerCtx.freshContext.region + val buffer = new RegionValueArrayBuffer(localType.valueType, bufferRegion) + + val producerCtx = consumerCtx.freshContext + val producerRegion = producerCtx.region + val it = useCtxes.flatMap(_ (producerCtx)) + + val stepped = OrderedRVIterator(localType, it).staircase stepped.map { stepIt => - region.clear() buffer.clear() rvb.start(newRowType) rvb.startStruct() @@ -307,8 +326,10 @@ class OrderedRVD( rvb.addField(localType.rowType, stepIt.value, localType.kRowFieldIdx(i)) i += 1 } - for (rv <- stepIt) + for (rv <- stepIt) { buffer.appendSelect(localType.rowType, localType.valueFieldIdx, rv) + producerRegion.clear() + } rvb.startArray(buffer.length) for (rv <- buffer) rvb.addRegionValue(localType.valueType, rv) @@ -317,24 +338,21 @@ class OrderedRVD( outRV.setOffset(rvb.end()) outRV } - } - - OrderedRVD(newTyp, partitioner, newRDD) + }) } def distinctByKey(): OrderedRVD = { val localType = typ - val newRVD = rdd.mapPartitions { it => + mapPartitionsPreservesPartitioning(typ)(it => OrderedRVIterator(localType, it) .staircase .map(_.value) - } - OrderedRVD(typ, partitioner, newRVD) + ) } def subsetPartitions(keep: Array[Int]): OrderedRVD = { - require(keep.length <= rdd.partitions.length, "tried to subset to more partitions than exist") - require(keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < rdd.partitions.length)), + require(keep.length <= crdd.partitions.length, "tried to subset to more partitions than exist") + require(keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < crdd.partitions.length)), "values not sorted or not in range [0, number of partitions)") val newRangeBounds = Array.tabulate(keep.length) { i => @@ -350,7 +368,7 @@ class OrderedRVD( partitioner.kType, newRangeBounds) - OrderedRVD(typ, newPartitioner, rdd.subsetPartitions(keep)) + OrderedRVD(typ, newPartitioner, crdd.subsetPartitions(keep)) } override protected def rvdSpec(codecSpec: 
CodecSpec, partFiles: Array[String]): RVDSpec = @@ -362,25 +380,67 @@ class OrderedRVD( partitioner.rangeBounds, partitioner.rangeBoundsType)) + def zipPartitionsAndContext( + newTyp: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner, + that: OrderedRVD, + preservesPartitioning: Boolean = false + )(zipper: (RVDContext, RVDContext => Iterator[RegionValue], RVDContext => Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = OrderedRVD( + newTyp, + newPartitioner, + crdd.czipPartitionsAndContext(that.crdd, preservesPartitioning) { (ctx, lit, rit) => + zipper(ctx, ctx => lit.flatMap(_(ctx)), ctx => rit.flatMap(_(ctx))) + } + ) + def zipPartitionsPreservesPartitioning[T: ClassTag]( newTyp: OrderedRVDType, - that: RDD[T] + that: ContextRDD[RVDContext, T] )(zipper: (Iterator[RegionValue], Iterator[T]) => Iterator[RegionValue] - ): OrderedRVD = - OrderedRVD( - newTyp, - partitioner, - this.rdd.zipPartitions(that, preservesPartitioning = true)(zipper)) + ): OrderedRVD = OrderedRVD( + newTyp, + partitioner, + crdd.zipPartitions(that)(zipper)) + + def zipPartitions( + newTyp: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner, + that: OrderedRVD + )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = zipPartitions(newTyp, newPartitioner, that, false)(zipper) - def zipPartitionsPreservesPartitioning( + def zipPartitions( + newTyp: OrderedRVDType, + newPartitioner: OrderedRVDPartitioner, + that: OrderedRVD, + preservesPartitioning: Boolean + )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue] + ): OrderedRVD = OrderedRVD( + newTyp, + newPartitioner, + boundary.crdd.czipPartitions(that.boundary.crdd, preservesPartitioning)(zipper)) + + def zipPartitions[T: ClassTag]( + that: OrderedRVD + )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[T] + ): ContextRDD[RVDContext, T] = zipPartitions(that, false)(zipper) + + def zipPartitions[T: ClassTag]( + that: OrderedRVD, + preservesPartitioning: Boolean + )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[T] + ): ContextRDD[RVDContext, T] = + boundary.crdd.czipPartitions(that.boundary.crdd, preservesPartitioning)(zipper) + + def zip( newTyp: OrderedRVDType, that: RVD - )(zipper: (Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue] - ): OrderedRVD = - OrderedRVD( - newTyp, - partitioner, - this.rdd.zipPartitions(that.rdd, preservesPartitioning = true)(zipper)) + )(zipper: (RVDContext, RegionValue, RegionValue) => RegionValue + ): OrderedRVD = OrderedRVD( + newTyp, + partitioner, + this.crdd.czip(that.crdd, preservesPartitioning = true)(zipper)) def writeRowsSplit( path: String, @@ -396,21 +456,29 @@ object OrderedRVD { def empty(sc: SparkContext, typ: OrderedRVDType): OrderedRVD = { OrderedRVD(typ, OrderedRVDPartitioner.empty(typ), - sc.emptyRDD[RegionValue]) + ContextRDD.empty[RVDContext, RegionValue](sc)) } /** * Precondition: the iterator it is PK-sorted. We lazily K-sort each block * of PK-equivalent elements. 
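   *
   * An illustrative sketch (values assumed): with partition key `pk` and full
   * key `(pk, k)`, input arriving as (1, b), (1, a), (2, c) is emitted as
   * (1, a), (1, b), (2, c). Rows are buffered in a priority queue only while
   * they share the current pk, then drained in k-order and copied into the
   * consumer region, so each pk-equivalent block is sorted without sorting
   * the whole partition.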
*/ - def localKeySort(typ: OrderedRVDType, + def localKeySort( + consumerRegion: Region, + producerRegion: Region, + typ: OrderedRVDType, // it: Iterator[RegionValue[rowType]] - it: Iterator[RegionValue]): Iterator[RegionValue] = { + it: Iterator[RegionValue] + ): Iterator[RegionValue] = new Iterator[RegionValue] { private val bit = it.buffered private val q = new mutable.PriorityQueue[RegionValue]()(typ.kInRowOrd.reverse) + private val rvb = new RegionValueBuilder(consumerRegion) + + private val rv = RegionValue() + def hasNext: Boolean = bit.hasNext || q.nonEmpty def next(): RegionValue = { @@ -418,17 +486,17 @@ object OrderedRVD { do { val rv = bit.next() // FIXME ugh, no good answer here - q.enqueue(RegionValue( - rv.region.copy(), - rv.offset)) + q.enqueue(rv.copy()) + producerRegion.clear() } while (bit.hasNext && typ.pkInRowOrd.compare(q.head, bit.head) == 0) } - val rv = q.dequeue() + rvb.start(typ.rowType) + rvb.addRegionValue(typ.rowType, q.dequeue()) + rv.set(consumerRegion, rvb.end()) rv } } - } // getKeys: RDD[RegionValue[kType]] def getKeys( @@ -446,14 +514,6 @@ object OrderedRVD { } } - // FIXME: delete when I've removed all need for RDDs - def getKeys( - typ: OrderedRVDType, - rdd: RDD[RegionValue] - ): RDD[RegionValue] = getKeys( - typ, - ContextRDD.weaken[RVDContext](rdd)).run - def getPartitionKeyInfo( typ: OrderedRVDType, // keys: RDD[kType] @@ -471,25 +531,18 @@ object OrderedRVD { val localType = typ - val pkis = keys.mapPartitionsWithIndex { case (i, it) => - if (it.hasNext) + val pkis = keys.cmapPartitionsWithIndex { (i, ctx, it) => + val out = if (it.hasNext) Iterator(OrderedRVPartitionInfo(localType, samplesPerPartition, i, it, partitionSeed(i))) else Iterator() + ctx.region.clear() + out }.collect() pkis.sortBy(_.min)(typ.pkType.ordering.toOrdering) } - // FIXME: delete when I've removed all need for RDDs - def getPartitionKeyInfo[C]( - typ: OrderedRVDType, - // keys: RDD[kType] - keys: RDD[RegionValue] - ): Array[OrderedRVPartitionInfo] = getPartitionKeyInfo( - typ, - ContextRDD.weaken[RVDContext](keys)) - def coerce( typ: OrderedRVDType, rvd: RVD @@ -498,7 +551,7 @@ object OrderedRVD { def coerce( typ: OrderedRVDType, rvd: RVD, - fastKeys: RDD[RegionValue] + fastKeys: ContextRDD[RVDContext, RegionValue] ): OrderedRVD = coerce(typ, rvd, Some(fastKeys), None) def coerce( @@ -510,9 +563,9 @@ object OrderedRVD { def coerce( typ: OrderedRVDType, rvd: RVD, - fastKeys: Option[RDD[RegionValue]], + fastKeys: Option[ContextRDD[RVDContext, RegionValue]], hintPartitioner: Option[OrderedRVDPartitioner] - ): OrderedRVD = coerce(typ, rvd.rdd, fastKeys, hintPartitioner) + ): OrderedRVD = coerce(typ, rvd.crdd, fastKeys, hintPartitioner) def coerce( typ: OrderedRVDType, @@ -536,23 +589,47 @@ object OrderedRVD { rdd: RDD[RegionValue], fastKeys: RDD[RegionValue], hintPartitioner: OrderedRVDPartitioner - ): OrderedRVD = coerce(typ, rdd, Some(fastKeys), Some(hintPartitioner)) + ): OrderedRVD = coerce( + typ, + rdd, + Some(fastKeys), + Some(hintPartitioner)) def coerce( typ: OrderedRVDType, - // rdd: RDD[RegionValue[rowType]] rdd: RDD[RegionValue], - // fastKeys: Option[RDD[RegionValue[kType]]] fastKeys: Option[RDD[RegionValue]], hintPartitioner: Option[OrderedRVDPartitioner] - ): OrderedRVD = { - val sc = rdd.sparkContext + ): OrderedRVD = coerce( + typ, + ContextRDD.weaken[RVDContext](rdd), + fastKeys.map(ContextRDD.weaken[RVDContext](_)), + hintPartitioner) - if (rdd.partitions.isEmpty) - return empty(sc, typ) + def coerce( + typ: OrderedRVDType, + crdd: ContextRDD[RVDContext, 
RegionValue] + ): OrderedRVD = coerce(typ, crdd, None, None) + + def coerce( + typ: OrderedRVDType, + crdd: ContextRDD[RVDContext, RegionValue], + fastKeys: ContextRDD[RVDContext, RegionValue] + ): OrderedRVD = coerce(typ, crdd, Some(fastKeys), None) + + def coerce( + typ: OrderedRVDType, + crdd: ContextRDD[RVDContext, RegionValue], + fastKeys: Option[ContextRDD[RVDContext, RegionValue]], + hintPartitioner: Option[OrderedRVDPartitioner] + ): OrderedRVD = { + val sc = crdd.sparkContext + + if (crdd.partitions.isEmpty) + return empty(sc, typ) // keys: RDD[RegionValue[kType]] - val keys = fastKeys.getOrElse(getKeys(typ, rdd)) + val keys = fastKeys.getOrElse(getKeys(typ, crdd)) val pkis = getPartitionKeyInfo(typ, keys) @@ -574,8 +651,9 @@ object OrderedRVD { typ.kType, rangeBounds) - val reorderedPartitionsRDD = rdd.reorderPartitions(pkis.map(_.partitionIndex)) - val adjustedRDD = new AdjustedPartitionsRDD(reorderedPartitionsRDD, adjustedPartitions) + val adjustedRDD = crdd + .reorderPartitions(pkis.map(_.partitionIndex)) + .adjustPartitions(adjustedPartitions) (adjSortedness: @unchecked) match { case OrderedRVPartitionInfo.KSORTED => info("Coerced sorted dataset") @@ -587,19 +665,20 @@ object OrderedRVD { info("Coerced almost-sorted dataset") OrderedRVD(typ, partitioner, - adjustedRDD.mapPartitions { it => - localKeySort(typ, it) + adjustedRDD.cmapPartitionsAndContext { (consumerCtx, it) => + val producerCtx = consumerCtx.freshContext + localKeySort(consumerCtx.region, producerCtx.region, typ, it.flatMap(_(producerCtx))) }) } } else { info("Ordering unsorted dataset with network shuffle") hintPartitioner - .filter(_.numPartitions >= rdd.partitions.length) - .map(adjustBoundsAndShuffle(typ, _, rdd)) + .filter(_.numPartitions >= crdd.partitions.length) + .map(adjustBoundsAndShuffle(typ, _, crdd)) .getOrElse { - val ranges = calculateKeyRanges(typ, pkis, rdd.getNumPartitions) + val ranges = calculateKeyRanges(typ, pkis, crdd.getNumPartitions) val p = new OrderedRVDPartitioner(typ.partitionKey, typ.kType, ranges) - shuffle(typ, p, rdd) + shuffle(typ, p, crdd) } } } @@ -641,21 +720,31 @@ object OrderedRVD { OrderedRVDPartitioner.makeRangeBoundIntervals(typ.pkType, partitionEdges) } - def adjustBoundsAndShuffle(typ: OrderedRVDType, + def adjustBoundsAndShuffle( + typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, - rdd: RDD[RegionValue]): OrderedRVD = { + rvd: RVD + ): OrderedRVD = { + assert(typ.rowType == rvd.rowType) + adjustBoundsAndShuffle(typ, partitioner, rvd.crdd) + } + private[this] def adjustBoundsAndShuffle( + typ: OrderedRVDType, + partitioner: OrderedRVDPartitioner, + crdd: ContextRDD[RVDContext, RegionValue] + ): OrderedRVD = { val pkType = partitioner.pkType val pkOrd = pkType.ordering.toOrdering - val pkis = getPartitionKeyInfo(typ, OrderedRVD.getKeys(typ, rdd)) + val pkis = getPartitionKeyInfo(typ, getKeys(typ, crdd)) if (pkis.isEmpty) - return OrderedRVD(typ, partitioner, rdd) + return OrderedRVD(typ, partitioner, crdd) val min = pkis.map(_.min).min(pkOrd) val max = pkis.map(_.max).max(pkOrd) - shuffle(typ, partitioner.enlargeToRange(Interval(min, max, true, true)), rdd) + shuffle(typ, partitioner.enlargeToRange(Interval(min, max, true, true)), crdd) } def shuffle( @@ -664,12 +753,6 @@ object OrderedRVD { rvd: RVD ): OrderedRVD = shuffle(typ, partitioner, rvd.crdd) - def shuffle(typ: OrderedRVDType, - partitioner: OrderedRVDPartitioner, - rdd: RDD[RegionValue] - ): OrderedRVD = - shuffle(typ, partitioner, ContextRDD.weaken[RVDContext](rdd)) - def shuffle( typ: 
OrderedRVDType, partitioner: OrderedRVDPartitioner, @@ -679,19 +762,23 @@ object OrderedRVD { val partBc = partitioner.broadcast(crdd.sparkContext) OrderedRVD(typ, partitioner, - crdd.mapPartitions { it => - val wrv = WritableRegionValue(typ.rowType) + crdd.cmapPartitions { (ctx, it) => + val enc = RVD.wireCodec.buildEncoder(localType.rowType) + it.map { rv => val wkrv = WritableRegionValue(typ.kType) - it.map { rv => - wrv.set(rv) - wkrv.setSelect(localType.rowType, localType.kRowFieldIdx, rv) - (wkrv.value, wrv.value) - } + wkrv.setSelect(localType.rowType, localType.kRowFieldIdx, rv) + val bytes = + RVD.regionValueToBytes(enc, ctx)(rv) + (wkrv.value, bytes) + } }.shuffle(partitioner.sparkPartitioner(crdd.sparkContext), typ.kOrd) - .mapPartitionsWithIndex { case (i, it) => - it.map { case (k, v) => + .cmapPartitionsWithIndex { case (i, ctx, it) => + val dec = RVD.wireCodec.buildDecoder(localType.rowType) + val region = ctx.region + val rv = RegionValue(region) + it.map { case (k, bytes) => assert(partBc.value.getPartition(k) == i) - v + RVD.bytesToRegionValue(dec, region, rv)(bytes) } }) } @@ -784,13 +871,21 @@ object OrderedRVD { typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, rvd: RVD - ): OrderedRVD = apply(typ, partitioner, rvd.rdd) + ): OrderedRVD = apply(typ, partitioner, rvd.crdd) def apply( typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, - rdd: RDD[RegionValue] - ): OrderedRVD = apply(typ, partitioner, ContextRDD.weaken[RVDContext](rdd)) + codec: CodecSpec, + rdd: RDD[Array[Byte]] + ): OrderedRVD = apply( + typ, + partitioner, + ContextRDD.weaken[RVDContext](rdd).cmapPartitions { (ctx, it) => + val dec = codec.buildDecoder(typ.rowType) + val rv = RegionValue() + it.map(RVD.bytesToRegionValue(dec, ctx.region, rv)) + }) def apply( typ: OrderedRVDType, diff --git a/src/main/scala/is/hail/rvd/RVD.scala b/src/main/scala/is/hail/rvd/RVD.scala index e275fae5a370..c7a212da0c0d 100644 --- a/src/main/scala/is/hail/rvd/RVD.scala +++ b/src/main/scala/is/hail/rvd/RVD.scala @@ -1,6 +1,7 @@ package is.hail.rvd -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import is.hail.expr.types.Type +import java.io.{ ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream } import is.hail.HailContext import is.hail.annotations._ @@ -43,9 +44,16 @@ object RVDSpec { val hConf = hc.hadoopConf partFiles.flatMap { p => val f = path + "/parts/" + p - val in = hConf.unsafeReader(f) - HailContext.readRowsPartition(codecSpec.buildDecoder(rowType))(0, in) - .map(rv => SafeRow(rowType, rv.region, rv.offset)) + hConf.readFile(f) { in => + using(RVDContext.default) { ctx => + HailContext.readRowsPartition(codecSpec.buildDecoder(rowType))(ctx, in) + .map { rv => + val r = SafeRow(rowType, rv.region, rv.offset) + ctx.region.clear() + r + } + } + } }.toFastIndexedSeq } } @@ -102,19 +110,19 @@ object RVD { val hConf = hc.hadoopConf hConf.mkDir(path + "/parts") - val os = hConf.unsafeWriter(path + "/parts/part-0") - val part0Count = Region.scoped { region => - val rvb = new RegionValueBuilder(region) - val rv = RegionValue(region) - RichContextRDDRegionValue.writeRowsPartition(codecSpec.buildEncoder(rowType))(0, - rows.iterator.map { a => - region.clear() - rvb.start(rowType) - rvb.addAnnotation(rowType, a) - rv.setOffset(rvb.end()) - rv - }, os) - } + val part0Count = + hConf.writeFile(path + "/parts/part-0") { os => + using(RVDContext.default) { ctx => + val rvb = ctx.rvb + val region = ctx.region + 
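+          // Each local row is built in the context-owned region and handed to
+          // the partition writer as a RegionValue; writeRowsPartition streams
+          // the values through the encoder into part-0 and returns the row count.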
RichContextRDDRegionValue.writeRowsPartition(codecSpec.buildEncoder(rowType))(ctx, + rows.iterator.map { a => + rvb.start(rowType) + rvb.addAnnotation(rowType, a) + RegionValue(region, rvb.end()) + }, os) + } + } val spec = UnpartitionedRVDSpec(rowType, codecSpec, Array("part-0")) spec.write(hConf, path) @@ -129,15 +137,130 @@ object RVD { new UnpartitionedRVD(first.rowType, ContextRDD.union(sc, rvds.map(_.crdd))) } + + val memoryCodec = CodecSpec.defaultUncompressed + + val wireCodec = memoryCodec + + def regionValueToBytes( + makeEnc: OutputStream => Encoder, + ctx: RVDContext + )(rv: RegionValue + ): Array[Byte] = + using(new ByteArrayOutputStream()) { baos => + using(makeEnc(baos)) { enc => + enc.writeRegionValue(rv.region, rv.offset) + enc.flush() + ctx.region.clear() + baos.toByteArray + } + } + + def bytesToRegionValue( + makeDec: InputStream => Decoder, + r: Region, + carrierRv: RegionValue + )(bytes: Array[Byte] + ): RegionValue = + using(new ByteArrayInputStream(bytes)) { bais => + using(makeDec(bais)) { dec => + carrierRv.setOffset(dec.readRegionValue(r)) + carrierRv + } + } } trait RVD { self => + def rowType: TStruct def crdd: ContextRDD[RVDContext, RegionValue] - def rdd: RDD[RegionValue] + private[rvd] def stabilize( + unstable: ContextRDD[RVDContext, RegionValue], + codec: CodecSpec = RVD.memoryCodec + ): ContextRDD[RVDContext, Array[Byte]] = { + val enc = codec.buildEncoder(rowType) + unstable.cmapPartitions { (ctx, it) => + it.map(RVD.regionValueToBytes(enc, ctx)) + } + } + + private[rvd] def destabilize( + stable: ContextRDD[RVDContext, Array[Byte]], + codec: CodecSpec = RVD.memoryCodec + ): ContextRDD[RVDContext, RegionValue] = { + val dec = codec.buildDecoder(rowType) + stable.cmapPartitions { (ctx, it) => + val rv = RegionValue(ctx.region) + it.map(RVD.bytesToRegionValue(dec, ctx.region, rv)) + } + } + + private[rvd] def stably( + f: ContextRDD[RVDContext, Array[Byte]] => ContextRDD[RVDContext, Array[Byte]] + ): ContextRDD[RVDContext, RegionValue] = stably(crdd, f) + + private[rvd] def stably( + unstable: ContextRDD[RVDContext, RegionValue], + f: ContextRDD[RVDContext, Array[Byte]] => ContextRDD[RVDContext, Array[Byte]] + ): ContextRDD[RVDContext, RegionValue] = destabilize(f(stabilize(unstable))) + + private[rvd] def crddBoundary: ContextRDD[RVDContext, RegionValue] = + crdd.cmapPartitionsAndContext { (consumerCtx, part) => + val producerCtx = consumerCtx.freshContext + val it = part.flatMap(_ (producerCtx)) + new Iterator[RegionValue]() { + private[this] var cleared: Boolean = false + + def hasNext = { + if (!cleared) { + cleared = true + producerCtx.region.clear() + } + it.hasNext + } + + def next = { + if (!cleared) { + producerCtx.region.clear() + } + cleared = false + it.next + } + } + } + + def boundary: RVD + + def encodedRDD(codec: CodecSpec): RDD[Array[Byte]] = + stabilize(crdd, codec).run + + def head(n: Long): RVD + + final def takeAsBytes(n: Int, codec: CodecSpec): Array[Array[Byte]] = + head(n).encodedRDD(codec).collect() + + final def take(n: Int, codec: CodecSpec): Array[Row] = { + val dec = codec.buildDecoder(rowType) + val encodedData = takeAsBytes(n, codec) + Region.scoped { region => + encodedData.iterator + .map(RVD.bytesToRegionValue(dec, region, RegionValue(region))) + .map { rv => + val row = SafeRow(rowType, rv) + region.clear() + row + }.toArray + } + } + + def forall(p: RegionValue => Boolean): Boolean = + crdd.map(p).run.forall(x => x) + + def exists(p: RegionValue => Boolean): Boolean = + crdd.map(p).run.exists(x => x) def sparkContext: 
SparkContext = crdd.sparkContext @@ -154,14 +277,36 @@ trait RVD { it.map { rv => f(c, rv) } }) - def map[T](f: (RegionValue) => T)(implicit tct: ClassTag[T]): RDD[T] = rdd.map(f) - def mapPartitions(newRowType: TStruct)(f: (Iterator[RegionValue]) => Iterator[RegionValue]): RVD = new UnpartitionedRVD(newRowType, crdd.mapPartitions(f)) - def mapPartitionsWithIndex[T](f: (Int, Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = rdd.mapPartitionsWithIndex(f) + def mapPartitions(newRowType: TStruct, f: (RVDContext, Iterator[RegionValue]) => Iterator[RegionValue]): RVD = + new UnpartitionedRVD(newRowType, crdd.cmapPartitions(f)) + + def find(codec: CodecSpec, p: (RegionValue) => Boolean): Option[Array[Byte]] = + filter(p).takeAsBytes(1, codec).headOption + + def find(region: Region)(p: (RegionValue) => Boolean): Option[RegionValue] = + find(RVD.wireCodec, p).map( + RVD.bytesToRegionValue(RVD.wireCodec.buildDecoder(rowType), region, RegionValue(region))) + + // Only use on CRDD's whose T is not dependent on the context + private[rvd] def clearingRun[T: ClassTag]( + crdd: ContextRDD[RVDContext, T] + ): RDD[T] = crdd.cmap { (ctx, v) => + ctx.region.clear() + v + }.run - def mapPartitions[T](f: (Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = rdd.mapPartitions(f) + def map[T](f: (RegionValue) => T)(implicit tct: ClassTag[T]): RDD[T] = clearingRun(crdd.map(f)) + + def mapPartitionsWithIndex[T](f: (Int, Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = clearingRun(crdd.mapPartitionsWithIndex(f)) + + def mapPartitionsWithIndex[T: ClassTag]( + f: (Int, RVDContext, Iterator[RegionValue]) => Iterator[T] + ): RDD[T] = clearingRun(crdd.cmapPartitionsWithIndex(f)) + + def mapPartitions[T](f: (Iterator[RegionValue]) => Iterator[T])(implicit tct: ClassTag[T]): RDD[T] = clearingRun(crdd.mapPartitions(f)) def constrainToOrderedPartitioner( ordType: OrderedRVDType, @@ -171,7 +316,7 @@ trait RVD { def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, RegionValue) => U, combOp: (U, U) => U, - depth: Int = treeAggDepth(HailContext.get, rdd.getNumPartitions) + depth: Int = treeAggDepth(HailContext.get, crdd.getNumPartitions) ): U = crdd.treeAggregate(zeroValue, seqOp, combOp, depth) def aggregate[U: ClassTag]( @@ -180,50 +325,45 @@ trait RVD { combOp: (U, U) => U ): U = crdd.aggregate(zeroValue, seqOp, combOp) - def count(): Long = rdd.count() - - def countPerPartition(): Array[Long] = rdd.countPerPartition() + def count(): Long = + crdd.cmapPartitions { (ctx, it) => + var count = 0L + it.foreach { rv => + count += 1 + ctx.region.clear() + } + Iterator.single(count) + }.run.fold(0L)(_ + _) + + def countPerPartition(): Array[Long] = + crdd.cmapPartitions { (ctx, it) => + var count = 0L + it.foreach { rv => + count += 1 + ctx.region.clear() + } + Iterator.single(count) + }.collect() protected def persistRVRDD(level: StorageLevel): PersistedRVRDD = { val localRowType = rowType - val persistCodec = - new PackCodecSpec( - new BlockingBufferSpec(32 * 1024, - new StreamBlockBufferSpec)) - - val makeEnc = persistCodec.buildEncoder(localRowType) + val makeEnc = RVD.memoryCodec.buildEncoder(localRowType) - val makeDec = persistCodec.buildDecoder(localRowType) + val makeDec = RVD.memoryCodec.buildDecoder(localRowType) // copy, persist region values - val persistedRDD = rdd.mapPartitions { it => - it.map { rv => - using(new ByteArrayOutputStream()) { baos => - using(makeEnc(baos)) { enc => - enc.writeRegionValue(rv.region, rv.offset) - enc.flush() 
- baos.toByteArray - } - } - } - } + val persistedRDD = crdd.cmapPartitions { (ctx, it) => + it.map(RVD.regionValueToBytes(makeEnc, ctx)) + } .run .persist(level) PersistedRVRDD(persistedRDD, ContextRDD.weaken[RVDContext](persistedRDD) - .mapPartitions { it => - val region = Region() - val rv2 = RegionValue(region) - it.map { bytes => - region.clear() - using(new ByteArrayInputStream(bytes)) { bais => - using(makeDec(bais)) { dec => - rv2.setOffset(dec.readRegionValue(region)) - rv2 - } - } - } + .cmapPartitions { (ctx, it) => + val region = ctx.region + val rv = RegionValue(region) + it.map(RVD.bytesToRegionValue(makeDec, region, rv)) }) } @@ -249,7 +389,7 @@ trait RVD { def toRows: RDD[Row] = { val localRowType = rowType - crdd.map { rv => SafeRow(localRowType, rv.region, rv.offset) }.run + map(rv => SafeRow(localRowType, rv.region, rv.offset)) } def toUnpartitionedRVD: UnpartitionedRVD diff --git a/src/main/scala/is/hail/rvd/RVDContext.scala b/src/main/scala/is/hail/rvd/RVDContext.scala index f26f41ace514..a9003393cc32 100644 --- a/src/main/scala/is/hail/rvd/RVDContext.scala +++ b/src/main/scala/is/hail/rvd/RVDContext.scala @@ -28,10 +28,6 @@ class RVDContext(r: Region) extends AutoCloseable { private[this] val theRvb = new RegionValueBuilder(r) def rvb = theRvb - def reset(): Unit = { - r.clear() - } - // frees the memory associated with this context def close(): Unit = { var e: Exception = null diff --git a/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala b/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala index 2eb8926d23aa..4c1d3eb25e48 100644 --- a/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala +++ b/src/main/scala/is/hail/rvd/UnpartitionedRVD.scala @@ -17,10 +17,10 @@ object UnpartitionedRVD { class UnpartitionedRVD(val rowType: TStruct, val crdd: ContextRDD[RVDContext, RegionValue]) extends RVD { self => - def this(rowType: TStruct, rdd: RDD[RegionValue]) = - this(rowType, ContextRDD.weaken[RVDContext](rdd)) + def boundary = new UnpartitionedRVD(rowType, crddBoundary) - val rdd = crdd.run + def head(n: Long): UnpartitionedRVD = + new UnpartitionedRVD(rowType, crdd.head(n)) def filter(f: (RegionValue) => Boolean): UnpartitionedRVD = new UnpartitionedRVD(rowType, crdd.filter(f)) @@ -52,7 +52,12 @@ class UnpartitionedRVD(val rowType: TStruct, val crdd: ContextRDD[RVDContext, Re UnpartitionedRVDSpec(rowType, codecSpec, partFiles) def coalesce(maxPartitions: Int, shuffle: Boolean): UnpartitionedRVD = - new UnpartitionedRVD(rowType, crdd.coalesce(maxPartitions, shuffle = shuffle)) + new UnpartitionedRVD( + rowType, + if (shuffle) + stably(_.shuffleCoalesce(maxPartitions)) + else + crdd.noShuffleCoalesce(maxPartitions)) def constrainToOrderedPartitioner( ordType: OrderedRVDType, diff --git a/src/main/scala/is/hail/sparkextras/BlockedRDD.scala b/src/main/scala/is/hail/sparkextras/BlockedRDD.scala index 2c1f2408dd5a..5405c3a1d6f1 100644 --- a/src/main/scala/is/hail/sparkextras/BlockedRDD.scala +++ b/src/main/scala/is/hail/sparkextras/BlockedRDD.scala @@ -66,4 +66,4 @@ class BlockedRDD[T](@transient var prev: RDD[T], .keys .toFastSeq } -} \ No newline at end of file +} diff --git a/src/main/scala/is/hail/sparkextras/ContextRDD.scala b/src/main/scala/is/hail/sparkextras/ContextRDD.scala index 6e9c84a760ce..d7ea9e3fe06e 100644 --- a/src/main/scala/is/hail/sparkextras/ContextRDD.scala +++ b/src/main/scala/is/hail/sparkextras/ContextRDD.scala @@ -151,8 +151,10 @@ class ContextRDD[C <: AutoCloseable, T: ClassTag]( cmapPartitions((_, part) => f(part), preservesPartitioning) def 
mapPartitionsWithIndex[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U] - ): ContextRDD[C, U] = cmapPartitionsWithIndex((i, _, part) => f(i, part)) + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false + ): ContextRDD[C, U] = + cmapPartitionsWithIndex((i, _, part) => f(i, part), preservesPartitioning) // FIXME: delete when region values are non-serializable def aggregate[U: ClassTag]( @@ -373,13 +375,11 @@ class ContextRDD[C <: AutoCloseable, T: ClassTag]( onRDD(rdd => new AdjustedPartitionsRDD(rdd, contextIgnorantAdjustments)) } - def coalesce(numPartitions: Int, shuffle: Boolean = false): ContextRDD[C, T] = - // NB: the run marks the end of a context lifetime, the next one starts - // after the shuffle - if (shuffle) - ContextRDD.weaken(run.coalesce(numPartitions, shuffle), mkc) - else - onRDD(_.coalesce(numPartitions, shuffle)) + def noShuffleCoalesce(numPartitions: Int): ContextRDD[C, T] = + onRDD(_.coalesce(numPartitions, false)) + + def shuffleCoalesce(numPartitions: Int): ContextRDD[C, T] = + ContextRDD.weaken(run.coalesce(numPartitions, true), mkc) def sample( withReplacement: Boolean, @@ -401,15 +401,73 @@ class ContextRDD[C <: AutoCloseable, T: ClassTag]( }, preservesPartitioning = true) } - def partitionSizes: Array[Long] = - // safe because we don't actually touch the offsets, we just count how many - // there are - sparkContext.runJob(run, getIteratorSize _) + def head(n: Long): ContextRDD[C, T] = { + require(n >= 0) + + val sc = sparkContext + val nPartitions = getNumPartitions + + var partScanned = 0 + var nLeft = n + var idxLast = -1 + var nLast = 0L + var numPartsToTry = 1L + + while (nLeft > 0 && partScanned < nPartitions) { + val nSeen = n - nLeft + + if (partScanned > 0) { + // If we didn't find any rows after the previous iteration, quadruple and retry. + // Otherwise, interpolate the number of partitions we need to try, but overestimate + // it by 50%. We also cap the estimation in the end. 
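+        // Worked example with assumed numbers: if n = 100, partScanned = 1 and
+        // nSeen = 10, the interpolation gives (1.5 * 100 * 1 / 10).toInt - 1 = 14,
+        // and the cap partScanned * 4 = 4 limits the next pass to 4 partitions.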
+ if (nSeen == 0) { + numPartsToTry = partScanned * 4 + } else { + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * n * partScanned / nSeen).toInt - partScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partScanned * 4) + } + } + + val p = partScanned.until(math.min(partScanned + numPartsToTry, nPartitions).toInt) + val counts = runJob(getIteratorSizeWithMaxN(nLeft), p) + + p.zip(counts).foreach { case (idx, c) => + if (nLeft > 0) { + idxLast = idx + nLast = if (c < nLeft) c else nLeft + nLeft -= nLast + } + } + + partScanned += p.size + } + + mapPartitionsWithIndex({ case (i, it) => + if (i == idxLast) + it.take(nLast.toInt) + else + it + }, preservesPartitioning = true) + .subsetPartitions((0 to idxLast).toArray) + } + + def runJob[U: ClassTag](f: Iterator[T] => U, partitions: Seq[Int]): Array[U] = + sparkContext.runJob( + rdd, + { (it: Iterator[ElementType]) => using(mkc())(c => f(it.flatMap(_(c)))) }, + partitions) + + def blocked(partitionEnds: Array[Int]): ContextRDD[C, T] = + new ContextRDD(new BlockedRDD(rdd, partitionEnds), mkc) def sparkContext: SparkContext = rdd.sparkContext def getNumPartitions: Int = rdd.getNumPartitions + def preferredLocations(partition: Partition): Seq[String] = + rdd.preferredLocations(partition) + private[this] def clean[T <: AnyRef](value: T): T = ExposedUtils.clean(sparkContext, value) diff --git a/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD.scala b/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD.scala index b6eb9b9f8ec8..f8d504ea2b47 100644 --- a/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD.scala +++ b/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD.scala @@ -19,8 +19,8 @@ object RepartitionedOrderedRDD { new RepartitionedOrderedRDD( prev.crdd, prev.typ, - prev.partitioner.broadcast(prev.rdd.sparkContext), - newPartitioner.broadcast(prev.rdd.sparkContext)) + prev.partitioner.broadcast(prev.crdd.sparkContext), + newPartitioner.broadcast(prev.crdd.sparkContext)) } } diff --git a/src/main/scala/is/hail/stats/BaldingNicholsModel.scala b/src/main/scala/is/hail/stats/BaldingNicholsModel.scala index 31454c08e40b..c6cfc7f45f15 100644 --- a/src/main/scala/is/hail/stats/BaldingNicholsModel.scala +++ b/src/main/scala/is/hail/stats/BaldingNicholsModel.scala @@ -5,7 +5,8 @@ import breeze.stats.distributions._ import is.hail.HailContext import is.hail.annotations._ import is.hail.expr.types._ -import is.hail.rvd.OrderedRVD +import is.hail.rvd.{OrderedRVD, RVDContext} +import is.hail.sparkextras.ContextRDD import is.hail.utils._ import is.hail.variant.{Call2, MatrixTable, ReferenceGenome} import org.apache.commons.math3.random.JDKRandomGenerator @@ -132,10 +133,9 @@ object BaldingNicholsModel { val rvType = matrixType.rvRowType - val rdd = sc.parallelize((0 until M).view.map(m => (m, Rand.randInt.draw())), nPartitions) - .mapPartitions { it => - - val region = Region() + val rdd = ContextRDD.weaken[RVDContext](sc.parallelize((0 until M).view.map(m => (m, Rand.randInt.draw())), nPartitions)) + .cmapPartitions { (ctx, it) => + val region = ctx.region val rv = RegionValue(region) val rvb = new RegionValueBuilder(region) @@ -152,7 +152,6 @@ object BaldingNicholsModel { .draw() }) - region.clear() rvb.start(rvType) rvb.startStruct() diff --git a/src/main/scala/is/hail/table/Table.scala b/src/main/scala/is/hail/table/Table.scala index 2ee38fb502b1..bb488e03e6e1 100644 --- a/src/main/scala/is/hail/table/Table.scala +++ b/src/main/scala/is/hail/table/Table.scala @@ -176,7 
+176,7 @@ object Table { globals: Annotation, sort: Boolean ): Table = { - val crdd2 = crdd.mapPartitions(_.toRegionValueIterator(signature)) + val crdd2 = crdd.cmapPartitions((ctx, it) => it.toRegionValueIterator(ctx.region, signature)) new Table(hc, TableLiteral( TableValue( TableType(signature, None, globalSignature), @@ -211,9 +211,9 @@ class Table(val hc: HailContext, val tir: TableIR) { hc: HailContext, crdd: ContextRDD[RVDContext, RegionValue], signature: TStruct, - key: Option[IndexedSeq[String]], - globalSignature: TStruct, - globals: Row + key: Option[IndexedSeq[String]] = None, + globalSignature: TStruct = TStruct.empty(), + globals: Row = Row.empty ) = this(hc, TableLiteral( TableValue( @@ -221,20 +221,6 @@ class Table(val hc: HailContext, val tir: TableIR) { BroadcastRow(globals, globalSignature, hc.sc), new UnpartitionedRVD(signature, crdd)))) - def this(hc: HailContext, - rdd: RDD[RegionValue], - signature: TStruct, - key: Option[IndexedSeq[String]], - globalSignature: TStruct, - globals: Row - ) = this( - hc, - ContextRDD.weaken[RVDContext](rdd), - signature, - key, - globalSignature, - globals) - def typ: TableType = tir.typ private def useIR(ast: AST): Boolean = { @@ -251,19 +237,6 @@ class Table(val hc: HailContext, val tir: TableIR) { opt.execute(hc) } - def this( - hc: HailContext, - rdd: RDD[RegionValue], - signature: TStruct, - key: Option[IndexedSeq[String]] - ) = this(hc, rdd, signature, key, TStruct.empty(), Row.empty) - - def this( - hc: HailContext, - rdd: RDD[RegionValue], - signature: TStruct - ) = this(hc, rdd, signature, None) - lazy val TableValue(ktType, globals, rvd) = value val TableType(signature, key, globalSignature) = tir.typ @@ -545,7 +518,7 @@ class Table(val hc: HailContext, val tir: TableIR) { def head(n: Long): Table = { if (n < 0) fatal(s"n must be non-negative! 
Found `$n'.") - copy(rdd = rdd.head(n)) + copy2(rvd = rvd.head(n)) } def keyBy(key: String*): Table = keyBy(key.toArray, key.toArray) @@ -587,9 +560,17 @@ class Table(val hc: HailContext, val tir: TableIR) { else if (ordered.typ.key.length <= keys.length) { val localSortType = new OrderedRVDType(ordered.typ.key, keys, signature) val newType = new OrderedRVDType(ordered.typ.partitionKey, keys, signature) - ordered.mapPartitionsPreservesPartitioning(newType) { it => - OrderedRVD.localKeySort(localSortType, it) - } + OrderedRVD( + newType, + ordered.partitioner, + ordered.crdd.cmapPartitionsAndContext { (consumerCtx, it) => + val producerCtx = consumerCtx.freshContext + OrderedRVD.localKeySort( + consumerCtx.region, + producerCtx.region, + localSortType, + it.flatMap(_(producerCtx))) + }) } else resort } else resort case _: UnpartitionedRVD => @@ -785,16 +766,15 @@ class Table(val hc: HailContext, val tir: TableIR) { val newRVType = matrixType.rvRowType val orderedRKStruct = matrixType.rowKeyStruct - val newRVD = ordered.mapPartitionsPreservesPartitioning(matrixType.orvdType) { it => - val region = Region() - val rvb = new RegionValueBuilder(region) + val newRVD = ordered.boundary.mapPartitionsPreservesPartitioning(matrixType.orvdType, { (ctx, it) => + val region = ctx.region + val rvb = ctx.rvb val outRV = RegionValue(region) OrderedRVIterator( new OrderedRVDType(partitionKeys, rowKeys, rowEntryStruct), it ).staircase.map { rowIt => - region.clear() rvb.start(newRVType) rvb.startStruct() var i = 0 @@ -828,7 +808,7 @@ class Table(val hc: HailContext, val tir: TableIR) { outRV.setOffset(rvb.end()) outRV } - } + }) new MatrixTable(hc, matrixType, globals, @@ -997,6 +977,7 @@ class Table(val hc: HailContext, val tir: TableIR) { } val act = implicitly[ClassTag[Annotation]] + // FIXME: need to add sortBy on rvd? 
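+    // One possible shape for that (an assumption, not implemented here): stabilize
+    // rows to Array[Byte] with RVD.wireCodec, sortBy on the byte-backed RDD, then
+    // decode back into region values, mirroring the encode/shuffle/decode pattern
+    // used by shuffle and shuffleCoalesce elsewhere in this change.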
copy(rdd = rdd.sortBy(identity[Annotation], ascending = true)(ord, act)) } @@ -1006,7 +987,7 @@ class Table(val hc: HailContext, val tir: TableIR) { def union(kts: Table*): Table = new Table(hc, TableUnion((tir +: kts.map(_.tir)).toFastIndexedSeq)) - def take(n: Int): Array[Row] = rdd.take(n) + def take(n: Int): Array[Row] = rvd.take(n, RVD.wireCodec) def takeJSON(n: Int): String = { val r = JSONAnnotationImpex.exportAnnotation(take(n).toFastIndexedSeq, TArray(signature)) @@ -1024,6 +1005,7 @@ class Table(val hc: HailContext, val tir: TableIR) { val (newSignature, ins) = signature.insert(TInt64(), name) + // FIXME: should use RVD, need zipWithIndex val newRDD = rdd.zipWithIndex().map { case (r, ind) => ins(r, ind).asInstanceOf[Row] } copy(signature = newSignature.asInstanceOf[TStruct], rdd = newRDD) diff --git a/src/main/scala/is/hail/utils/FlipbookIterator.scala b/src/main/scala/is/hail/utils/FlipbookIterator.scala index 74e2ae94decd..299333d7fe91 100644 --- a/src/main/scala/is/hail/utils/FlipbookIterator.scala +++ b/src/main/scala/is/hail/utils/FlipbookIterator.scala @@ -162,13 +162,16 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => def staircased(ord: OrderingView[A]): StagingIterator[FlipbookIterator[A]] = { ord.setBottom() - val stepIterator: FlipbookIterator[A] = FlipbookIterator( - new StateMachine[A] { - def value: A = self.value - def isValid: Boolean = self.isValid && ord.isEquivalent(value) - def advance() = { self.advance() } + val stepSM = new StateMachine[A] { + def value: A = self.value + var _isValid: Boolean = self.isValid && ord.isEquivalent(self.value) + def isValid = _isValid + def advance() = { + self.advance() + _isValid = self.isValid && ord.isEquivalent(self.value) } - ) + } + val stepIterator: FlipbookIterator[A] = FlipbookIterator(stepSM) val sm = new StateMachine[FlipbookIterator[A]] { var isValid: Boolean = true val value: FlipbookIterator[A] = stepIterator @@ -176,8 +179,10 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => stepIterator.exhaust() if (self.isValid) { ord.setValue(self.value) + stepSM._isValid = self.isValid && ord.isEquivalent(self.value) } else { ord.setBottom() + stepSM._isValid = false isValid = false } } diff --git a/src/main/scala/is/hail/utils/richUtils/Implicits.scala b/src/main/scala/is/hail/utils/richUtils/Implicits.scala index b6630fd8ca6d..fa5a0d19b833 100644 --- a/src/main/scala/is/hail/utils/richUtils/Implicits.scala +++ b/src/main/scala/is/hail/utils/richUtils/Implicits.scala @@ -5,6 +5,8 @@ import java.io.InputStream import breeze.linalg.DenseMatrix import is.hail.annotations.{JoinedRegionValue, Region, RegionValue} import is.hail.asm4s.Code +import is.hail.io.RichContextRDDRegionValue +import is.hail.rvd.RVDContext import is.hail.io.{InputBuffer, RichContextRDDRegionValue} import is.hail.sparkextras._ import is.hail.utils.{ArrayBuilder, HailIterator, JSONWriter, MultiArray2, Truncatable, WithContext} @@ -73,7 +75,7 @@ trait Implicits { implicit def toRichRDD[T](r: RDD[T])(implicit tct: ClassTag[T]): RichRDD[T] = new RichRDD(r) - implicit def toRichContextRDDRegionValue[C <: AutoCloseable](r: ContextRDD[C, RegionValue]): RichContextRDDRegionValue[C] = new RichContextRDDRegionValue(r) + implicit def toRichContextRDDRegionValue(r: ContextRDD[RVDContext, RegionValue]): RichContextRDDRegionValue = new RichContextRDDRegionValue(r) implicit def toRichRDDByteArray(r: RDD[Array[Byte]]): RichRDDByteArray = new RichRDDByteArray(r) @@ -111,7 +113,7 @@ trait Implicits { implicit def 
toContextPairRDDFunctions[C <: AutoCloseable, K: ClassTag, V: ClassTag](x: ContextRDD[C, (K, V)]): ContextPairRDDFunctions[C, K, V] = new ContextPairRDDFunctions(x) - implicit def toRichContextRDD[C <: AutoCloseable, T: ClassTag](x: ContextRDD[C, T]): RichContextRDD[C, T] = new RichContextRDD(x) + implicit def toRichContextRDD[T: ClassTag](x: ContextRDD[RVDContext, T]): RichContextRDD[T] = new RichContextRDD(x) implicit def toRichCodeInputBuffer(in: Code[InputBuffer]): RichCodeInputBuffer = new RichCodeInputBuffer(in) } diff --git a/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala b/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala index dc408531508a..fe0d597a30db 100644 --- a/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala +++ b/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala @@ -2,6 +2,7 @@ package is.hail.utils.richUtils import java.io._ +import is.hail.rvd.RVDContext import org.apache.commons.lang3.StringUtils import org.apache.spark.TaskContext import is.hail.utils._ @@ -9,9 +10,9 @@ import is.hail.sparkextras._ import scala.reflect.ClassTag -class RichContextRDD[C <: AutoCloseable, T: ClassTag](crdd: ContextRDD[C, T]) { +class RichContextRDD[T: ClassTag](crdd: ContextRDD[RVDContext, T]) { def writePartitions(path: String, - write: (Int, Iterator[T], OutputStream) => Long, + write: (RVDContext, Iterator[T], OutputStream) => Long, remapPartitions: Option[(Array[Int], Int)] = None): (Array[String], Array[Long]) = { val sc = crdd.sparkContext val hadoopConf = sc.hadoopConfiguration @@ -31,12 +32,14 @@ class RichContextRDD[C <: AutoCloseable, T: ClassTag](crdd: ContextRDD[C, T]) { val remapBc = sc.broadcast(remap) - val (partFiles, partitionCounts) = crdd.mapPartitionsWithIndex { case (index, it) => + val (partFiles, partitionCounts) = crdd.cmapPartitionsWithIndex { (index, ctx, it) => val i = remapBc.value(index) val f = partFile(d, i, TaskContext.get) val filename = path + "/parts/" + f val os = sHadoopConfBc.value.value.unsafeWriter(filename) - Iterator.single(f -> write(i, it, os)) + val out = Iterator.single(f -> write(ctx, it, os)) + ctx.region.clear() + out } .collect() .unzip diff --git a/src/main/scala/is/hail/utils/richUtils/RichIterator.scala b/src/main/scala/is/hail/utils/richUtils/RichIterator.scala index 47b962a4a061..5306820296d4 100644 --- a/src/main/scala/is/hail/utils/richUtils/RichIterator.scala +++ b/src/main/scala/is/hail/utils/richUtils/RichIterator.scala @@ -105,12 +105,10 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal { } class RichRowIterator(val it: Iterator[Row]) extends AnyVal { - def toRegionValueIterator(rowTyp: TStruct): Iterator[RegionValue] = { - val region = Region() + def toRegionValueIterator(region: Region, rowTyp: TStruct): Iterator[RegionValue] = { val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) it.map { row => - region.clear() rvb.start(rowTyp) rvb.addAnnotation(rowTyp, row) rv.setOffset(rvb.end()) diff --git a/src/main/scala/is/hail/utils/richUtils/RichRDD.scala b/src/main/scala/is/hail/utils/richUtils/RichRDD.scala index 302c0484a3c6..650efd3ec331 100644 --- a/src/main/scala/is/hail/utils/richUtils/RichRDD.scala +++ b/src/main/scala/is/hail/utils/richUtils/RichRDD.scala @@ -2,6 +2,7 @@ package is.hail.utils.richUtils import java.io.OutputStream +import is.hail.rvd.RVDContext import is.hail.sparkextras._ import is.hail.utils._ import org.apache.commons.lang3.StringUtils @@ -99,7 +100,8 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { } def subsetPartitions(keep: 
Array[Int])(implicit ct: ClassTag[T]): RDD[T] = { - require(keep.length <= r.partitions.length, "tried to subset to more partitions than exist") + require(keep.length <= r.partitions.length, + s"tried to subset to more partitions than exist ${keep.toSeq} ${r.partitions.toSeq}") require(keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < r.partitions.length)), "values not sorted or not in range [0, number of partitions)") val parentPartitions = r.partitions @@ -139,7 +141,7 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { var partScanned = 0 var nLeft = n - var idxLast = 0 + var idxLast = -1 var nLast = 0L var numPartsToTry = 1L @@ -183,9 +185,12 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { } def writePartitions(path: String, - write: (Int, Iterator[T], OutputStream) => Long, + write: (Iterator[T], OutputStream) => Long, remapPartitions: Option[(Array[Int], Int)] = None )(implicit tct: ClassTag[T] ): (Array[String], Array[Long]) = - ContextRDD.weaken[TrivialContext](r).writePartitions(path, write, remapPartitions) + ContextRDD.weaken[RVDContext](r).writePartitions( + path, + (_, it, os) => write(it, os), + remapPartitions) } diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index 5dbd0b1cd8c3..e81002412c72 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -13,6 +13,8 @@ import is.hail.methods.Aggregators.ColFunctions import is.hail.utils._ import is.hail.{HailContext, utils} import is.hail.expr.types._ +import is.hail.io.CodecSpec +import is.hail.sparkextras.ContextRDD import org.apache.hadoop import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -143,8 +145,8 @@ object MatrixTable { BroadcastRow(globals.asInstanceOf[Row], matrixType.globalType, hc.sc), BroadcastIndexedSeq(colValues, TArray(matrixType.colType), hc.sc), OrderedRVD.coerce(matrixType.orvdType, - rdd.mapPartitions { it => - val region = Region() + ContextRDD.weaken[RVDContext](rdd).cmapPartitions { (ctx, it) => + val region = ctx.region val rvb = new RegionValueBuilder(region) val rv = RegionValue(region) @@ -152,7 +154,6 @@ object MatrixTable { val vaRow = va.asInstanceOf[Row] assert(matrixType.rowType.typeCheck(vaRow), s"${ matrixType.rowType }, $vaRow") - region.clear() rvb.start(localRVRowType) rvb.startStruct() var i = 0 @@ -266,7 +267,7 @@ object MatrixTable { val oldRowType = kt.signature - val rdd = kt.rvd.mapPartitions { it => + val rdd = kt.rvd.mapPartitions(matrixType.rvRowType) { it => val rvb = new RegionValueBuilder() val rv2 = RegionValue() @@ -286,7 +287,7 @@ object MatrixTable { new MatrixTable(kt.hc, matrixType, BroadcastRow(Row(), matrixType.globalType, kt.hc.sc), BroadcastIndexedSeq(Array.empty[Annotation], TArray(matrixType.colType), kt.hc.sc), - OrderedRVD.coerce(matrixType.orvdType, rdd, None, None)) + OrderedRVD.coerce(matrixType.orvdType, rdd)) } } @@ -655,13 +656,13 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { val selectIdx = matrixType.orvdType.kRowFieldIdx val keyOrd = matrixType.orvdType.kRowOrd - val newRVD = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType) { it => + val newRVD = rvd.boundary.mapPartitionsPreservesPartitioning(newMatrixType.orvdType, { (ctx, it) => new Iterator[RegionValue] { var isEnd = false var current: RegionValue = null val rvRowKey: WritableRegionValue = WritableRegionValue(newRowType) - val region = Region() - val rvb = new RegionValueBuilder(region) + val region = 
ctx.region
+        val rvb = ctx.rvb
         val newRV = RegionValue(region)

         def hasNext: Boolean = {
@@ -683,7 +684,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
             aggs = seqOp(aggs, current)
             current = null
           }
-          region.clear()
           rvb.start(newRVType)
           rvb.startStruct()
           var i = 0
@@ -697,7 +697,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
           newRV
         }
       }
-    }
+    })

     copyMT(rvd = newRVD, matrixType = newMatrixType)
   }
@@ -780,18 +780,17 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val (newRVType, ins) = rvRowType.unsafeStructInsert(valueType, List(root))

     val rightRowType = rightRVD.rowType
+    val leftRowType = rvRowType

     val rightValueIndices = rightRVD.typ.valueIndices
     assert(!product || rightValueIndices.length == 1)

-    val joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it =>
-      val rvb = new RegionValueBuilder()
+    val joiner = { (ctx: RVDContext, it: Iterator[JoinedRegionValue]) =>
+      val rvb = ctx.rvb
       val rv = RegionValue()

       it.map { jrv =>
         val lrv = jrv.rvLeft
-
-        rvb.set(lrv.region)
         rvb.start(newRVType)
         ins(lrv.region, lrv.offset, rvb,
           () => {
@@ -815,7 +814,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
             }
           }
         })
-        rv.set(lrv.region, rvb.end())
+        rv.set(ctx.region, rvb.end())
         rv
       }
     }
@@ -847,15 +846,21 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val ktSignature = kt.signature
     val ktKeyFieldIdx = kt.keyFieldIdx.get(0)
     val ktValueFieldIdx = kt.valueFieldIdx
-    val partitionKeyedIntervals = kt.rvd.rdd
-      .flatMap { rv =>
+    val partitionKeyedIntervals = kt.rvd.boundary.crdd
+      .cflatMap { (ctx, rv) =>
+        val region = ctx.region
+        val rv2 = RegionValue(region)
+        val rvb = ctx.rvb
+        rvb.start(ktSignature)
+        rvb.addRegionValue(ktSignature, rv)
+        rv2.setOffset(rvb.end())
         val ur = new UnsafeRow(ktSignature, rv)
         val interval = ur.getAs[Interval](ktKeyFieldIdx)
         if (interval != null) {
           val rangeTree = partBc.value.rangeTree
           val pkOrd = partBc.value.pkType.ordering
           val wrappedInterval = interval.copy(start = Row(interval.start), end = Row(interval.end))
-          rangeTree.queryOverlappingValues(pkOrd, wrappedInterval).map(i => (i, rv))
+          rangeTree.queryOverlappingValues(pkOrd, wrappedInterval).map(i => (i, rv2))
         } else
           Iterator()
       }
@@ -876,9 +881,9 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     ) { case (it, intervals) =>
       val intervalAnnotations: Array[(Interval, Any)] =
         intervals.map { rv =>
-          val ur = new UnsafeRow(ktSignature, rv)
-          val interval = ur.getAs[Interval](ktKeyFieldIdx)
-          (interval, Row.fromSeq(ktValueFieldIdx.map(ur.get)))
+          val r = SafeRow(ktSignature, rv)
+          val interval = r.getAs[Interval](ktKeyFieldIdx)
+          (interval, Row.fromSeq(ktValueFieldIdx.map(r.get)))
         }.toArray

       val iTree = IntervalTree.annotationTree(typOrdering, intervalAnnotations)
@@ -1083,7 +1088,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     }
     if (touchesKeys) {
       warn("modified row key, rescanning to compute ordering...")
-      val newRDD = rvd.mapPartitions(mapPartitionsF)
+      val newRDD = rvd.mapPartitions(newMatrixType.rvRowType)(mapPartitionsF)
       copyMT(matrixType = newMatrixType,
         rvd = OrderedRVD.coerce(newMatrixType.orvdType, newRDD, None, None))
     } else copyMT(matrixType = newMatrixType,
@@ -1145,7 +1150,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
   def forceCountRows(): Long = rvd.count()

   def deduplicate(): MatrixTable =
-    copy2(rvd = rvd.mapPartitionsPreservesPartitioning(rvd.typ)(
+    copy2(rvd = rvd.boundary.mapPartitionsPreservesPartitioning(rvd.typ)(
       SortedDistinctRowIterator.transformer(rvd.typ)))

   def deleteVA(args: String*): (Type, Deleter) = deleteVA(args.toList)
@@ -1172,10 +1177,10 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val localEntriesIndex = entriesIndex

-    val explodedRDD = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType) { it =>
-      val region2 = Region()
+    val explodedRDD = rvd.boundary.mapPartitionsPreservesPartitioning(newMatrixType.orvdType, { (ctx, it) =>
+      val region2 = ctx.region
       val rv2 = RegionValue(region2)
-      val rv2b = new RegionValueBuilder(region2)
+      val rv2b = ctx.rvb
       val ur = new UnsafeRow(oldRVType)
       it.flatMap { rv =>
         ur.set(rv)
@@ -1184,7 +1189,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
           None
         else
           keys.iterator.map { explodedElement =>
-            region2.clear()
             rv2b.start(newRVType)
             inserter(rv.region, rv.offset, rv2b,
               () => rv2b.addAnnotation(keyType, explodedElement))
@@ -1192,7 +1196,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
             rv2
           }
       }
-    }
+    })

     copyMT(matrixType = newMatrixType, rvd = explodedRDD)
   }
@@ -1433,15 +1437,14 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val localEntriesType = matrixType.entryArrayType
     assert(right.matrixType.entryArrayType == localEntriesType)

-    val joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it =>
-      val rvb = new RegionValueBuilder()
+    val joiner = { (ctx: RVDContext, it: Iterator[JoinedRegionValue]) =>
+      val rvb = ctx.rvb
       val rv2 = RegionValue()

       it.map { jrv =>
         val lrv = jrv.rvLeft
         val rrv = jrv.rvRight

-        rvb.set(lrv.region)
         rvb.start(leftRVType)
         rvb.startStruct()
         var i = 0
@@ -1474,7 +1477,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
         rvb.endArray()
         rvb.endStruct()

-        rv2.set(lrv.region, rvb.end())
+        rv2.set(ctx.region, rvb.end())
         rv2
       }
     }
@@ -1889,55 +1892,47 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val localColKeys = colKeys

     metadataSame &&
-      rvd.rdd.zipPartitions(
+      rvd.crdd.czip(
         OrderedRVD.adjustBoundsAndShuffle(
           that.rvd.typ,
           rvd.partitioner.withKType(that.rvd.typ.partitionKey, that.rvd.typ.kType),
-          that.rvd.rdd)
-          .rdd) { (it1, it2) =>
+          that.rvd)
+          .crdd) { (ctx, rv1, rv2) =>
+        var partSame = true
+
         val fullRow1 = new UnsafeRow(leftRVType)
         val fullRow2 = new UnsafeRow(rightRVType)
-          var partSame = true
-          while (it1.hasNext && it2.hasNext) {
-            val rv1 = it1.next()
-            val rv2 = it2.next()
-            fullRow1.set(rv1)
-            fullRow2.set(rv2)
-            val row1 = fullRow1.deleteField(localLeftEntriesIndex)
-            val row2 = fullRow2.deleteField(localRightEntriesIndex)
+        fullRow1.set(rv1)
+        fullRow2.set(rv2)
+        val row1 = fullRow1.deleteField(localLeftEntriesIndex)
+        val row2 = fullRow2.deleteField(localRightEntriesIndex)
-            if (!localRowType.valuesSimilar(row1, row2, tolerance, absolute)) {
-              println(
-                s"""row fields not the same:
+        if (!localRowType.valuesSimilar(row1, row2, tolerance, absolute)) {
+          println(
+            s"""row fields not the same:
               | $row1
               | $row2""".stripMargin)
-              partSame = false
-            }
+          partSame = false
+        }
-            val gs1 = fullRow1.getAs[IndexedSeq[Annotation]](localLeftEntriesIndex)
-            val gs2 = fullRow2.getAs[IndexedSeq[Annotation]](localRightEntriesIndex)
+        val gs1 = fullRow1.getAs[IndexedSeq[Annotation]](localLeftEntriesIndex)
+        val gs2 = fullRow2.getAs[IndexedSeq[Annotation]](localRightEntriesIndex)
-            var i = 0
-            while (partSame && i < gs1.length) {
-              if (!localEntryType.valuesSimilar(gs1(i), gs2(i), tolerance, absolute)) {
-                partSame = false
-                println(
-                  s"""different entry at row ${ localRKF(row1) }, col ${ localColKeys(i) }
+        var i = 0
+        while (partSame && i < gs1.length) {
+          if (!localEntryType.valuesSimilar(gs1(i), gs2(i), tolerance, absolute)) {
+            partSame = false
+            println(
+              s"""different entry at row ${ localRKF(row1) }, col ${ localColKeys(i) }
               | ${ gs1(i) }
               | ${ gs2(i) }""".stripMargin)
-              }
-              i += 1
           }
+          i += 1
         }
-
-          if ((it1.hasNext || it2.hasNext) && partSame) {
-            println("partition has different number of rows")
-            partSame = false
-          }
-
-          Iterator(partSame)
-      }.forall(t => t)
+        ctx.region.clear()
+        partSame
+      }.run.forall(t => t)
   }

   def colEC: EvalContext = {
@@ -2025,17 +2020,22 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     }

     val localRVRowType = rvRowType
-    rvd.map { rv =>
-      new UnsafeRow(localRVRowType, rv)
-    }.find(ur => !localRVRowType.typeCheck(ur))
-      .foreach { ur =>
+    val predicate = { (rv: RegionValue) =>
+      val ur = new UnsafeRow(localRVRowType, rv)
+      !localRVRowType.typeCheck(ur)
+    }
+
+    Region.scoped { region =>
+      rvd.find(region)(predicate).foreach { rv =>
+        val ur = new UnsafeRow(localRVRowType, rv)
         foundError = true
         warn(
           s"""found violation in row
-           |Schema: $localRVRowType
-           |Annotation: ${ Annotation.printAnnotation(ur) }""".stripMargin)
+             |Schema: $localRVRowType
+             |Annotation: ${ Annotation.printAnnotation(ur) }""".stripMargin)
       }
+    }

     if (foundError)
       fatal("found one or more type check errors")
@@ -2194,12 +2194,13 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val locusIndex = rvRowType.fieldIdx("locus")
     val allelesIndex = rvRowType.fieldIdx("alleles")

-    def minRep1(removeLeftAligned: Boolean, removeMoving: Boolean, verifyLeftAligned: Boolean): RDD[RegionValue] = {
-      rvd.mapPartitions { it =>
+    def minRep1(removeLeftAligned: Boolean, removeMoving: Boolean, verifyLeftAligned: Boolean): RVD = {
+      rvd.mapPartitions(localRVRowType) { it =>
         var prevLocus: Locus = null
         val rvb = new RegionValueBuilder()
         val rv2 = RegionValue()
+        // FIXME: how is this not broken?
         it.flatMap { rv =>
           val ur = new UnsafeRow(localRVRowType, rv.region, rv.offset)
@@ -2378,7 +2379,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val fieldIdx = entryType.fieldIdx(entryField)
     val numColsLocal = numCols

-    val rows = rvd.mapPartitionsWithIndex { case (pi, it) =>
+    val rows = rvd.mapPartitionsWithIndex { (pi, it) =>
       var i = partStartsBc.value(pi)
       it.map { rv =>
         val region = rv.region
@@ -2420,7 +2421,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     hadoop.mkDir(dirname + "/parts")
     val gp = GridPartitioner(blockSize, nRows, localNCols)
     val blockPartFiles =
-      new WriteBlocksRDD(dirname, rvd.rdd, sparkContext, matrixType, partStarts, entryField, gp)
+      new WriteBlocksRDD(dirname, rvd.crdd, sparkContext, matrixType, partStarts, entryField, gp)
         .collect()

     val blockCount = blockPartFiles.length
@@ -2446,15 +2447,14 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val partStarts = partitionStarts()
     val newMatrixType = matrixType.copy(rvRowType = newRVType)

-    val indexedRVD = rvd.mapPartitionsWithIndexPreservesPartitioning(newMatrixType.orvdType) { case (i, it) =>
-      val region2 = Region()
+    val indexedRVD = rvd.boundary.mapPartitionsWithIndexPreservesPartitioning(newMatrixType.orvdType, { (i, ctx, it) =>
+      val region2 = ctx.region
       val rv2 = RegionValue(region2)
-      val rv2b = new RegionValueBuilder(region2)
+      val rv2b = ctx.rvb
       var idx = partStarts(i)
       it.map { rv =>
-        region2.clear()
         rv2b.start(newRVType)
         inserter(rv.region, rv.offset, rv2b,
@@ -2464,7 +2464,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
         rv2.setOffset(rv2b.end())
         rv2
       }
-    }
+    })

     copyMT(matrixType = newMatrixType, rvd = indexedRVD)
   }
diff --git a/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala b/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala
index 9d2860652f9b..f9724f7bb4da 100644
--- a/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala
+++ b/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala
@@ -342,30 +342,31 @@ class LocalLDPruneSuite extends SparkSuite {
   }

   @Test def bitPackedVectorCorrectWhenOffsetNotZero() {
-    val r = Region()
-    val rvb = new RegionValueBuilder(r)
-    val t = BitPackedVectorView.rvRowType(
-      +TLocus(ReferenceGenome.GRCh37),
-      +TArray(+TString()))
-    val bpv = new BitPackedVectorView(t)
-    r.appendInt(0xbeef)
-    rvb.start(t)
-    rvb.startStruct()
-    rvb.startStruct()
-    rvb.addString("X")
-    rvb.addInt(42)
-    rvb.endStruct()
-    rvb.startArray(0)
-    rvb.endArray()
-    rvb.startArray(0)
-    rvb.endArray()
-    rvb.addInt(0)
-    rvb.addDouble(0.0)
-    rvb.addDouble(0.0)
-    rvb.endStruct()
-    bpv.setRegion(r, rvb.end())
-    assert(bpv.getContig == "X")
-    assert(bpv.getStart == 42)
+    Region.scoped { r =>
+      val rvb = new RegionValueBuilder(r)
+      val t = BitPackedVectorView.rvRowType(
+        +TLocus(ReferenceGenome.GRCh37),
+        +TArray(+TString()))
+      val bpv = new BitPackedVectorView(t)
+      r.appendInt(0xbeef)
+      rvb.start(t)
+      rvb.startStruct()
+      rvb.startStruct()
+      rvb.addString("X")
+      rvb.addInt(42)
+      rvb.endStruct()
+      rvb.startArray(0)
+      rvb.endArray()
+      rvb.startArray(0)
+      rvb.endArray()
+      rvb.addInt(0)
+      rvb.addDouble(0.0)
+      rvb.addDouble(0.0)
+      rvb.endStruct()
+      bpv.setRegion(r, rvb.end())
+      assert(bpv.getContig == "X")
+      assert(bpv.getStart == 42)
+    }
   }

   @Test def testIsLocallyUncorrelated() {
diff --git a/src/test/scala/is/hail/testUtils/RichMatrixTable.scala b/src/test/scala/is/hail/testUtils/RichMatrixTable.scala
index 2c27d1b5b36e..d6e633c0fd41 100644
--- a/src/test/scala/is/hail/testUtils/RichMatrixTable.scala
+++ b/src/test/scala/is/hail/testUtils/RichMatrixTable.scala
@@ -87,7 +87,8 @@ class RichMatrixTable(vsm: MatrixTable) {
     val localEntriesIndex = vsm.entriesIndex
     val localRowType = vsm.rowType
     val rowKeyF = vsm.rowKeysF
-    vsm.rvd.rdd.map { rv =>
+    vsm.rvd.map { rv =>
+      val unsafeFullRow = new UnsafeRow(fullRowType, rv)
       val fullRow = SafeRow(fullRowType, rv.region, rv.offset)
       val row = fullRow.deleteField(localEntriesIndex)
       (rowKeyF(fullRow), (row, fullRow.getAs[IndexedSeq[Any]](localEntriesIndex)))
diff --git a/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala b/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala
index 23bdbe3a23fd..2ff33a3cdff8 100644
--- a/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala
+++ b/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala
@@ -2,6 +2,7 @@ package is.hail.utils

 import is.hail.SparkSuite
 import org.testng.annotations.Test
+
 import scala.collection.generic.Growable
 import scala.collection.mutable.ArrayBuffer

diff --git a/src/test/scala/is/hail/utils/RichRDDSuite.scala b/src/test/scala/is/hail/utils/RichRDDSuite.scala
index 21f8d231cfcf..6af6a8503c82 100644
--- a/src/test/scala/is/hail/utils/RichRDDSuite.scala
+++ b/src/test/scala/is/hail/utils/RichRDDSuite.scala
@@ -11,10 +11,10 @@ class RichRDDSuite extends SparkSuite {
   @Test def testHead() {
     val r = sc.parallelize(0 until 1024, numSlices = 20)
-    val partitionRanges = r.countPerPartition().scanLeft(Range(0, 0)) { case (x, c) => Range(x.end, x.end + c.toInt) }
+    val partitionRanges = r.countPerPartition().scanLeft(Range(0, 1)) { case (x, c) => Range(x.end, x.end + c.toInt + 1) }

     def getExpectedNumPartitions(n: Int): Int =
-      partitionRanges.indexWhere(_.contains(math.max(0, n - 1)))
+      partitionRanges.indexWhere(_.contains(n))

     for (n <- Array(0, 15, 200, 562, 1024, 2000)) {
       val t = r.head(n)
diff --git a/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala b/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala
index 4b1655158e7a..eaa2f26e1756 100644
--- a/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala
+++ b/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala
@@ -75,8 +75,8 @@ class PartitioningSuite extends SparkSuite {
     val mt = MatrixTable.fromRowsTable(Table.range(hc, 100, nPartitions=Some(6)))
     val orvdType = mt.matrixType.orvdType

-    mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "left", _.map(_._1), orvdType).count()
-    mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "inner", _.map(_._1), orvdType).count()
+    mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "left", (_, it) => it.map(_._1), orvdType).count()
+    mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "inner", (_, it) => it.map(_._1), orvdType).count()
   }

   @Test def testEmptyRDDOrderedJoin() {
@@ -86,9 +86,9 @@ class PartitioningSuite extends SparkSuite {
     val nonEmptyRVD = mt.rvd
     val emptyRVD = OrderedRVD.empty(hc.sc, orvdType)

-    emptyRVD.orderedJoin(nonEmptyRVD, "left", _.map(_._1), orvdType).count()
-    emptyRVD.orderedJoin(nonEmptyRVD, "inner", _.map(_._1), orvdType).count()
-    nonEmptyRVD.orderedJoin(emptyRVD, "left", _.map(_._1), orvdType).count()
-    nonEmptyRVD.orderedJoin(emptyRVD, "inner", _.map(_._1), orvdType).count()
+    emptyRVD.orderedJoin(nonEmptyRVD, "left", (_, it) => it.map(_._1), orvdType).count()
+    emptyRVD.orderedJoin(nonEmptyRVD, "inner", (_, it) => it.map(_._1), orvdType).count()
+    nonEmptyRVD.orderedJoin(emptyRVD, "left", (_, it) => it.map(_._1), orvdType).count()
+    nonEmptyRVD.orderedJoin(emptyRVD, "inner", (_, it) => it.map(_._1), orvdType).count()
   }
 }
diff --git a/src/test/scala/is/hail/vds/JoinSuite.scala b/src/test/scala/is/hail/vds/JoinSuite.scala
index 1020844eb98c..863c1c01220e 100644
--- a/src/test/scala/is/hail/vds/JoinSuite.scala
+++ b/src/test/scala/is/hail/vds/JoinSuite.scala
@@ -1,7 +1,7 @@
 package is.hail.vds

 import is.hail.SparkSuite
-import is.hail.annotations.UnsafeRow
+import is.hail.annotations._
 import is.hail.expr.types.{TLocus, TStruct}
 import is.hail.table.Table
 import is.hail.variant.{Locus, MatrixTable, ReferenceGenome}
@@ -67,24 +67,24 @@ class JoinSuite extends SparkSuite {
     val localRowType = left.rvRowType

     // Inner distinct ordered join
-    val jInner = left.rvd.orderedJoinDistinct(right.rvd, "inner", _.map(_._1), left.rvd.typ)
+    val jInner = left.rvd.orderedJoinDistinct(right.rvd, "inner", (_, it) => it.map(_._1), left.rvd.typ)
     val jInnerOrdRDD1 = left.rdd.join(right.rdd.distinct)

     assert(jInner.count() == jInnerOrdRDD1.count())
     assert(jInner.map { rv =>
-      val ur = new UnsafeRow(localRowType, rv)
-      ur.getAs[Locus](0)
+      val r = SafeRow(localRowType, rv)
+      r.getAs[Locus](0)
     }.collect() sameElements jInnerOrdRDD1.map(_._1.asInstanceOf[Row].get(0)).collect().sorted(vType.ordering.toOrdering))

     // Left distinct ordered join
-    val jLeft = left.rvd.orderedJoinDistinct(right.rvd, "left", _.map(_._1), left.rvd.typ)
+    val jLeft = left.rvd.orderedJoinDistinct(right.rvd, "left", (_, it) => it.map(_._1), left.rvd.typ)
     val jLeftOrdRDD1 = left.rdd.leftOuterJoin(right.rdd.distinct)

     assert(jLeft.count() == jLeftOrdRDD1.count())
-    assert(jLeft.rdd.forall(rv => rv != null))
+    assert(jLeft.forall(rv => rv != null))
     assert(jLeft.map { rv =>
-      val ur = new UnsafeRow(localRowType, rv)
-      ur.getAs[Locus](0)
+      val r = SafeRow(localRowType, rv)
+      r.getAs[Locus](0)
     }.collect() sameElements jLeftOrdRDD1.map(_._1.asInstanceOf[Row].get(0)).collect().sorted(vType.ordering.toOrdering))
   }
 }
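The change that recurs throughout this diff is that per-partition closures stop allocating their own Region() (and calling region.clear()) and instead use the region and RegionValueBuilder supplied by an enclosing scope: an RVDContext (ctx.region / ctx.rvb) in the library code, or Region.scoped in the tests. The sketch below is not part of the patch; it is only an illustration of the Region.scoped idiom as used in LocalLDPruneSuite and the typecheck change above, and it assumes Region and RegionValueBuilder are the is.hail.annotations classes that the test files import.

import is.hail.annotations.{Region, RegionValueBuilder}

object RegionScopedSketch {
  // Illustrative only: Region.scoped allocates a Region, passes it to the body,
  // and releases it when the body returns, so the caller never constructs,
  // clears, or frees a Region by hand.
  def withScratchRegion(): Unit =
    Region.scoped { r =>
      val rvb = new RegionValueBuilder(r)
      // Build region values against `r` here (rvb.start(...), ..., rvb.end());
      // everything allocated in `r` goes away when this block exits.
    }
}

The same lifecycle appears to be what ctx.region and ctx.rvb provide implicitly inside the cmapPartitions, cflatMap, and czip operations this patch introduces, which is why the explicit region.clear() calls are removed rather than moved.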