From 8b39e6a0e4eff523ac756a59b00acc148de8450a Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Fri, 23 Mar 2018 18:48:05 -0400
Subject: [PATCH] Eliminate some but not all uses of RVD.rdd (#3186)

* Eliminate some but not all uses of RVD.rdd

This change anticipates the ContextRDD change, wherein `RVD.rdd` will no
longer be an RDD. Moreover, enforcing an abstraction barrier at the level
of `RVD` will ease changes to the implementation of `RVD`.

There are two remaining kinds of calls that I cannot eliminate:

- uses in BlockMatrix and OrderedRDD2: these two classes build new RDDs
  from the RVD's rdd and should be considered part of the implementation
  of the RVD abstraction. Because these two classes are outside of
  `is.hail.rvd`, I cannot enforce an access modifier on `RVD.rdd`.

- uses by methods:

  - LDPrune: it seems we need a "GeneralRVD"

  - Skat: it seems like some of this could actually be moved to Python,
    but there is some matrix math that cannot be moved until the expr
    lang has efficient small-matrix ops

  - MatrixTable.same: I could probably move this if I re-implemented
    forall in terms of RVD.aggregate?

  - MatrixTable.annotateRowsIntervalTable: really not sure about this
    one; this seems like a performance optimization that purposely
    reaches through the abstraction to do Smart Things

A toy sketch of the abstraction barrier this change works toward appears
after the diff.

* clean up

* formatting

* more formatting

* use assertOrdered instead of old apply

* fixes

* improve use of assertions

* rename toUnsafeRows to toRows

* rename unsafeChangeType to updateType

* wip zip not sure what to do

* finish renames

* fix invalid assertions

* remove coerceOrdered, remove OrderedRVD.apply

* fixes and eliminate coerceOrdered

* actually remove coerceOrdered

* fix

* clean up zipPartitions definitions and uses

* name error

* fix name

* Update OrderedRVD.scala

* Update OrderedRVD.scala

* fix filteralleles shuffle and friends

* formatting

* rebase errors

* harmonize formatting

* remove rebase cruft
---
 src/main/scala/is/hail/expr/Relational.scala  |  77 +++++-----
 src/main/scala/is/hail/io/LoadMatrix.scala    |   2 +-
 src/main/scala/is/hail/io/bgen/LoadBgen.scala |   2 +-
 .../scala/is/hail/io/plink/LoadPlink.scala    |   2 +-
 src/main/scala/is/hail/io/vcf/LoadGDB.scala   |   2 +-
 src/main/scala/is/hail/io/vcf/LoadVCF.scala   |   2 +-
 src/main/scala/is/hail/rvd/OrderedRVD.scala   | 143 +++++++++++++-----
 .../scala/is/hail/rvd/OrderedRVDType.scala    |   6 +
 src/main/scala/is/hail/rvd/RVD.scala          |   6 +
 .../is/hail/stats/BaldingNicholsModel.scala   |   2 +-
 src/main/scala/is/hail/table/Table.scala      |   4 +-
 .../scala/is/hail/variant/MatrixTable.scala   |  63 ++++----
 12 files changed, 188 insertions(+), 123 deletions(-)

diff --git a/src/main/scala/is/hail/expr/Relational.scala b/src/main/scala/is/hail/expr/Relational.scala
index a26b82b5c14c..2fdd56832790 100644
--- a/src/main/scala/is/hail/expr/Relational.scala
+++ b/src/main/scala/is/hail/expr/Relational.scala
@@ -85,6 +85,8 @@ case class MatrixValue(
   colValues: IndexedSeq[Annotation],
   rvd: OrderedRVD) {
 
+  assert(rvd.typ == typ.orvdType)
+
   def sparkContext: SparkContext = rvd.sparkContext
 
   def nPartitions: Int = rvd.partitions.length
@@ -252,42 +254,43 @@ case class MatrixRead(
     } else {
       val entriesRVD = spec.entriesComponent.read(hc, path)
       val entriesRowType = entriesRVD.rowType
-      OrderedRVD(typ.orvdType,
-        rowsRVD.partitioner,
-        rowsRVD.rdd.zipPartitions(entriesRVD.rdd) { case (it1, it2) =>
-          val rvb = new RegionValueBuilder()
-
-          new Iterator[RegionValue] {
-            def hasNext: Boolean = {
-              val hn = it1.hasNext
-              assert(hn == it2.hasNext)
-              hn
-            }
+      rowsRVD.zipPartitionsPreservesPartitioning(
+        typ.orvdType,
+        entriesRVD
+      ) { case (it1, it2) =>
+        val rvb = new RegionValueBuilder()
+
+        new Iterator[RegionValue] {
+          def hasNext: Boolean = {
+            val hn = it1.hasNext
+            assert(hn == it2.hasNext)
+            hn
+          }
 
-            def next(): RegionValue = {
-              val rv1 = it1.next()
-              val rv2 = it2.next()
-              val region = rv2.region
-              rvb.set(region)
-              rvb.start(fullRowType)
-              rvb.startStruct()
-              var i = 0
-              while (i < localEntriesIndex) {
-                rvb.addField(fullRowType, rv1, i)
-                i += 1
-              }
-              rvb.addField(entriesRowType, rv2, 0)
+          def next(): RegionValue = {
+            val rv1 = it1.next()
+            val rv2 = it2.next()
+            val region = rv2.region
+            rvb.set(region)
+            rvb.start(fullRowType)
+            rvb.startStruct()
+            var i = 0
+            while (i < localEntriesIndex) {
+              rvb.addField(fullRowType, rv1, i)
               i += 1
-              while (i < fullRowType.size) {
-                rvb.addField(fullRowType, rv1, i - 1)
-                i += 1
-              }
-              rvb.endStruct()
-              rv2.set(region, rvb.end())
-              rv2
             }
+            rvb.addField(entriesRowType, rv2, 0)
+            i += 1
+            while (i < fullRowType.size) {
+              rvb.addField(fullRowType, rv1, i - 1)
+              i += 1
+            }
+            rvb.endStruct()
+            rv2.set(region, rvb.end())
+            rv2
           }
-        })
+        }
+      }
     }
   }
@@ -450,10 +453,8 @@ case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
 }
 
 case class TableValue(typ: TableType, globals: BroadcastValue, rvd: RVD) {
-  def rdd: RDD[Row] = {
-    val localRowType = typ.rowType
-    rvd.rdd.map { rv => new UnsafeRow(localRowType, rv.region.copy(), rv.offset) }
-  }
+  def rdd: RDD[Row] =
+    rvd.toRows
 
   def filter(p: (RegionValue, RegionValue) => Boolean): TableValue = {
     val globalType = typ.globalType
@@ -642,7 +643,7 @@ case class TableJoin(left: TableIR, right: TableIR, joinType: String) extends Ta
     val leftORVD = leftTV.rvd match {
       case ordered: OrderedRVD => ordered
       case unordered =>
-        OrderedRVD(
+        OrderedRVD.coerce(
           new OrderedRVDType(left.typ.key.toArray, left.typ.key.toArray, leftRowType),
           unordered.rdd,
           None,
@@ -656,7 +657,7 @@
         if (joinType == "left" || joinType == "inner")
           unordered.constrainToOrderedPartitioner(ordType, leftORVD.partitioner)
         else
-          OrderedRVD(ordType, unordered.rdd, None, Some(leftORVD.partitioner))
+          OrderedRVD.coerce(ordType, unordered.rdd, None, Some(leftORVD.partitioner))
     }
     val joinedRVD = leftORVD.orderedJoin(
       rightORVD,
diff --git a/src/main/scala/is/hail/io/LoadMatrix.scala b/src/main/scala/is/hail/io/LoadMatrix.scala
index 4abb24941080..8ad472e613cf 100644
--- a/src/main/scala/is/hail/io/LoadMatrix.scala
+++ b/src/main/scala/is/hail/io/LoadMatrix.scala
@@ -392,7 +392,7 @@ object LoadMatrix {
       val (partitioner, keepPartitions) = makePartitionerFromCounts(partitionCounts, matrixType.orvdType.pkType)
       OrderedRVD(matrixType.orvdType, partitioner, rdd.subsetPartitions(keepPartitions))
     } else
-      OrderedRVD(matrixType.orvdType, rdd, None, None)
+      OrderedRVD.coerce(matrixType.orvdType, rdd, None, None)
 
     new MatrixTable(hc,
       matrixType,
diff --git a/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/src/main/scala/is/hail/io/bgen/LoadBgen.scala
index 41764d4dae7b..17fbf0bcd828 100644
--- a/src/main/scala/is/hail/io/bgen/LoadBgen.scala
+++ b/src/main/scala/is/hail/io/bgen/LoadBgen.scala
@@ -169,7 +169,7 @@ object LoadBgen {
     new MatrixTable(hc, matrixType,
       BroadcastValue(Annotation.empty, matrixType.globalType, sc),
       sampleIds.map(x => Annotation(x)),
-      OrderedRVD(matrixType.orvdType, rdd2, Some(fastKeys), None))
+      OrderedRVD.coerce(matrixType.orvdType, rdd2, Some(fastKeys), None))
   }
 
   def index(hConf: org.apache.hadoop.conf.Configuration, file: String) {
diff --git a/src/main/scala/is/hail/io/plink/LoadPlink.scala b/src/main/scala/is/hail/io/plink/LoadPlink.scala
index 75284cce2ac7..970c6616f577 100644
--- a/src/main/scala/is/hail/io/plink/LoadPlink.scala
+++ b/src/main/scala/is/hail/io/plink/LoadPlink.scala
@@ -189,7 +189,7 @@ object LoadPlink {
     new MatrixTable(hc, matrixType,
       BroadcastValue(Annotation.empty, matrixType.globalType, sc),
       sampleAnnotations,
-      OrderedRVD(matrixType.orvdType, rdd2, Some(fastKeys), None))
+      OrderedRVD.coerce(matrixType.orvdType, rdd2, Some(fastKeys), None))
   }
 
   def apply(hc: HailContext, bedPath: String, bimPath: String, famPath: String, ffConfig: FamFileConfig,
diff --git a/src/main/scala/is/hail/io/vcf/LoadGDB.scala b/src/main/scala/is/hail/io/vcf/LoadGDB.scala
index 9e8784e47e08..ad7a76302dbe 100644
--- a/src/main/scala/is/hail/io/vcf/LoadGDB.scala
+++ b/src/main/scala/is/hail/io/vcf/LoadGDB.scala
@@ -190,6 +190,6 @@ object LoadGDB {
     new MatrixTable(hc, matrixType,
       BroadcastValue(Annotation.empty, matrixType.globalType, sc),
       sampleIds.map(x => Annotation(x)),
-      OrderedRVD(matrixType.orvdType, hc.sc.parallelize(records), None, None))
+      OrderedRVD.coerce(matrixType.orvdType, hc.sc.parallelize(records), None, None))
   }
 }
diff --git a/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/src/main/scala/is/hail/io/vcf/LoadVCF.scala
index 43e9e0b44f1d..3dc65e4af7f4 100644
--- a/src/main/scala/is/hail/io/vcf/LoadVCF.scala
+++ b/src/main/scala/is/hail/io/vcf/LoadVCF.scala
@@ -841,7 +841,7 @@ object LoadVCF {
     // nothing after the key
     val justVariants = parseLines(() => ())((c, l, rvb) => ())(lines, kType, rg, contigRecoding)
 
-    val rdd = OrderedRVD(
+    val rdd = OrderedRVD.coerce(
       matrixType.orvdType,
       parseLines { () =>
         new ParseLineContext(genotypeSignature, new BufferedLineIterator(headerLinesBc.value.iterator.buffered))
diff --git a/src/main/scala/is/hail/rvd/OrderedRVD.scala b/src/main/scala/is/hail/rvd/OrderedRVD.scala
index 3a65aa9a9b85..9755b910246d 100644
--- a/src/main/scala/is/hail/rvd/OrderedRVD.scala
+++ b/src/main/scala/is/hail/rvd/OrderedRVD.scala
@@ -23,6 +23,9 @@ class OrderedRVD(
   self =>
   def rowType: TStruct = typ.rowType
 
+  def updateType(newTyp: OrderedRVDType): OrderedRVD =
+    OrderedRVD(newTyp, partitioner, rdd)
+
   def mapPreservesPartitioning(newTyp: OrderedRVDType)(f: (RegionValue) => RegionValue): OrderedRVD =
     OrderedRVD(newTyp,
       partitioner,
@@ -38,13 +41,6 @@ class OrderedRVD(
       partitioner,
       rdd.mapPartitions(f))
 
-  def zipPartitionsPreservesPartitioning[T](newTyp: OrderedRVDType, rdd2: RDD[T])(f: (Iterator[RegionValue], Iterator[T]) => Iterator[RegionValue])(implicit tct: ClassTag[T]): OrderedRVD =
-    OrderedRVD(newTyp,
-      partitioner,
-      rdd.zipPartitions(rdd2) { case (it, it2) =>
-        f(it, it2)
-      })
-
   override def filter(p: (RegionValue) => Boolean): OrderedRVD =
     OrderedRVD(typ,
       partitioner,
@@ -356,28 +352,41 @@ class OrderedRVD(
     spec.write(sparkContext.hadoopConfiguration, path)
     partitionCounts
   }
-}
 
-object OrderedRVD {
-  type CoercionMethod = Int
+  def zipPartitionsPreservesPartitioning[T: ClassTag](
+    newTyp: OrderedRVDType,
+    that: RDD[T]
+  )(zipper: (Iterator[RegionValue], Iterator[T]) => Iterator[RegionValue]
+  ): OrderedRVD =
+    OrderedRVD(
+      newTyp,
+      partitioner,
+      this.rdd.zipPartitions(that, preservesPartitioning = true)(zipper))
+
+  def zipPartitionsPreservesPartitioning(
+    newTyp: OrderedRVDType,
+    that: RVD
+  )(zipper: (Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue]
+  ): OrderedRVD =
+    OrderedRVD(
+      newTyp,
+      partitioner,
+      this.rdd.zipPartitions(that.rdd, preservesPartitioning = true)(zipper))
 
-  final val ORDERED_PARTITIONER: CoercionMethod = 0
-  final val AS_IS: CoercionMethod = 1
-  final val LOCAL_SORT: CoercionMethod = 2
-  final val SHUFFLE: CoercionMethod = 3
+  def writeRowsSplit(
+    path: String,
+    t: MatrixType,
+    codecSpec: CodecSpec
+  ): Array[Long] = rdd.writeRowsSplit(path, t, codecSpec, partitioner)
+}
 
 object OrderedRVD {
   def empty(sc: SparkContext, typ: OrderedRVDType): OrderedRVD = {
     OrderedRVD(typ, OrderedRVDPartitioner.empty(typ), sc.emptyRDD[RegionValue])
   }
 
-  def apply(typ: OrderedRVDType,
-    rdd: RDD[RegionValue], fastKeys: Option[RDD[RegionValue]], hintPartitioner: Option[OrderedRVDPartitioner]): OrderedRVD = {
-    val (_, orderedRVD) = coerce(typ, rdd, fastKeys, hintPartitioner)
-    orderedRVD
-  }
-
   /**
     * Precondition: the iterator it is PK-sorted. We lazily K-sort each block
     * of PK-equivalent elements.
@@ -448,16 +457,54 @@ object OrderedRVD {
     pkis.sortBy(_.min)(typ.pkOrd)
   }
 
-  def coerce(typ: OrderedRVDType,
+  def coerce(
+    typ: OrderedRVDType,
+    rvd: RVD
+  ): OrderedRVD = coerce(typ, rvd, None, None)
+
+  def coerce(
+    typ: OrderedRVDType,
+    rvd: RVD,
+    fastKeys: Option[RDD[RegionValue]],
+    hintPartitioner: Option[OrderedRVDPartitioner]
+  ): OrderedRVD = coerce(typ, rvd.rdd, fastKeys, hintPartitioner)
+
+  def coerce(
+    typ: OrderedRVDType,
+    rdd: RDD[RegionValue]
+  ): OrderedRVD = coerce(typ, rdd, None, None)
+
+  def coerce(
+    typ: OrderedRVDType,
+    rdd: RDD[RegionValue],
+    fastKeys: RDD[RegionValue]
+  ): OrderedRVD = coerce(typ, rdd, Some(fastKeys), None)
+
+  def coerce(
+    typ: OrderedRVDType,
+    rdd: RDD[RegionValue],
+    hintPartitioner: OrderedRVDPartitioner
+  ): OrderedRVD = coerce(typ, rdd, None, Some(hintPartitioner))
+
+  def coerce(
+    typ: OrderedRVDType,
+    rdd: RDD[RegionValue],
+    fastKeys: RDD[RegionValue],
+    hintPartitioner: OrderedRVDPartitioner
+  ): OrderedRVD = coerce(typ, rdd, Some(fastKeys), Some(hintPartitioner))
+
+  def coerce(
+    typ: OrderedRVDType,
     // rdd: RDD[RegionValue[rowType]]
     rdd: RDD[RegionValue],
     // fastKeys: Option[RDD[RegionValue[kType]]]
-    fastKeys: Option[RDD[RegionValue]] = None,
-    hintPartitioner: Option[OrderedRVDPartitioner] = None): (CoercionMethod, OrderedRVD) = {
+    fastKeys: Option[RDD[RegionValue]],
+    hintPartitioner: Option[OrderedRVDPartitioner]
+  ): OrderedRVD = {
     val sc = rdd.sparkContext
 
     if (rdd.partitions.isEmpty)
-      return (ORDERED_PARTITIONER, empty(sc, typ))
+      return empty(sc, typ)
 
     // keys: RDD[RegionValue[kType]]
     val keys = fastKeys.getOrElse(getKeys(typ, rdd))
@@ -465,7 +512,7 @@ object OrderedRVD {
     val pkis = getPartitionKeyInfo(typ, keys)
 
     if (pkis.isEmpty)
-      return (AS_IS, empty(sc, typ))
+      return empty(sc, typ)
 
     val partitionsSorted = (pkis, pkis.tail).zipped.forall { case (p, pnext) =>
       val r = typ.pkOrd.lteq(p.max, pnext.min)
@@ -487,29 +534,28 @@ object OrderedRVD {
       (adjSortedness: @unchecked) match {
         case OrderedRVPartitionInfo.KSORTED =>
           info("Coerced sorted dataset")
-          (AS_IS, OrderedRVD(typ,
+          OrderedRVD(typ,
             partitioner,
-            adjustedRDD))
+            adjustedRDD)
 
         case OrderedRVPartitionInfo.TSORTED =>
           info("Coerced almost-sorted dataset")
-          (LOCAL_SORT, OrderedRVD(typ,
+          OrderedRVD(typ,
             partitioner,
             adjustedRDD.mapPartitions { it =>
               localKeySort(typ, it)
-            }))
+            })
       }
     } else {
       info("Ordering unsorted dataset with network shuffle")
-      val orvd = hintPartitioner
+      hintPartitioner
         .filter(_.numPartitions >= rdd.partitions.length)
         .map(adjustBoundsAndShuffle(typ, _, rdd))
         .getOrElse {
-          val ranges = calculateKeyRanges(typ, pkis, rdd.getNumPartitions)
-          val p = new OrderedRVDPartitioner(typ.partitionKey, typ.kType, ranges)
-          shuffle(typ, p, rdd)
-        }
-      (SHUFFLE, orvd)
+          val ranges = calculateKeyRanges(typ, pkis, rdd.getNumPartitions)
+          val p = new OrderedRVDPartitioner(typ.partitionKey, typ.kType, ranges)
+          shuffle(typ, p, rdd)
+        }
     }
   }
@@ -567,6 +613,12 @@ object OrderedRVD {
     shuffle(typ, partitioner.enlargeToRange(Interval(min, max, true, true)), rdd)
   }
 
+  def shuffle(
+    typ: OrderedRVDType,
+    partitioner: OrderedRVDPartitioner,
+    rvd: RVD
+  ): OrderedRVD = shuffle(typ, partitioner, rvd.rdd)
+
   def shuffle(typ: OrderedRVDType,
     partitioner: OrderedRVDPartitioner,
     rdd: RDD[RegionValue]): OrderedRVD = {
@@ -594,9 +646,6 @@ object OrderedRVD {
       })
   }
 
-  def shuffle(typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, rvd: RVD): OrderedRVD =
-    shuffle(typ, partitioner, rvd.rdd)
-
   def rangesAndAdjustments(typ: OrderedRVDType,
     sortedKeyInfo: Array[OrderedRVPartitionInfo],
     sortedness: Int): (IndexedSeq[Array[Adjustment[RegionValue]]], UnsafeIndexedSeq, Int) = {
@@ -662,6 +711,12 @@ object OrderedRVD {
     (adjustmentsBuffer, rangeBounds, adjSortedness)
   }
 
+  def apply(
+    typ: OrderedRVDType,
+    partitioner: OrderedRVDPartitioner,
+    rvd: RVD
+  ): OrderedRVD = apply(typ, partitioner, rvd.rdd)
+
   def apply(typ: OrderedRVDType,
     partitioner: OrderedRVDPartitioner,
     rdd: RDD[RegionValue]): OrderedRVD = {
@@ -703,8 +758,14 @@ object OrderedRVD {
       })
   }
 
-  def apply(typ: OrderedRVDType, partitioner: OrderedRVDPartitioner, rvd: RVD): OrderedRVD = {
-    assert(typ.rowType == rvd.rowType)
-    apply(typ, partitioner, rvd.rdd)
+  def union(rvds: Array[OrderedRVD]): OrderedRVD = {
+    require(rvds.length > 1)
+    val first = rvds(0)
+    val sc = first.sparkContext
+    OrderedRVD.coerce(
+      first.typ,
+      sc.union(rvds.map(_.rdd)),
+      None,
+      None)
   }
 }
diff --git a/src/main/scala/is/hail/rvd/OrderedRVDType.scala b/src/main/scala/is/hail/rvd/OrderedRVDType.scala
index c3b359283bfc..615b539eea8d 100644
--- a/src/main/scala/is/hail/rvd/OrderedRVDType.scala
+++ b/src/main/scala/is/hail/rvd/OrderedRVDType.scala
@@ -113,6 +113,12 @@ class OrderedRVDType(
     sb += '}'
     sb.result()
   }
+
+  def copy(
+    partitionKey: Array[String] = partitionKey,
+    key: Array[String] = key,
+    rowType: TStruct = rowType
+  ): OrderedRVDType = new OrderedRVDType(partitionKey, key, rowType)
 }
 
 object OrderedRVDType {
diff --git a/src/main/scala/is/hail/rvd/RVD.scala b/src/main/scala/is/hail/rvd/RVD.scala
index 1d3e1d5027a8..3951968e9727 100644
--- a/src/main/scala/is/hail/rvd/RVD.scala
+++ b/src/main/scala/is/hail/rvd/RVD.scala
@@ -4,6 +4,7 @@ import is.hail.HailContext
 import is.hail.annotations._
 import is.hail.expr.{JSONAnnotationImpex, Parser}
 import is.hail.expr.types.{TArray, TInterval, TStruct, TStructSerializer}
+import is.hail.sparkextras._
 import is.hail.io._
 import is.hail.utils._
 import org.apache.hadoop
@@ -209,4 +210,9 @@ trait RVD {
   def sample(withReplacement: Boolean, p: Double, seed: Long): RVD
 
   def write(path: String, codecSpec: CodecSpec): Array[Long]
+
+  def toRows: RDD[Row] = {
+    val localRowType = rowType
+    rdd.map { rv => new UnsafeRow(localRowType, rv.region.copy(), rv.offset) }
+  }
 }
diff --git a/src/main/scala/is/hail/stats/BaldingNicholsModel.scala b/src/main/scala/is/hail/stats/BaldingNicholsModel.scala
index 18c0a1cdfd06..020e09182c25 100644
--- a/src/main/scala/is/hail/stats/BaldingNicholsModel.scala
+++ b/src/main/scala/is/hail/stats/BaldingNicholsModel.scala
@@ -216,7 +216,7 @@ object BaldingNicholsModel {
       Array.tabulate(N)(i => Annotation(i, popOfSample_n(0, i).toInt))
 
     // FIXME: should use fast keys
-    val ordrdd = OrderedRVD(matrixType.orvdType, rdd, None, None)
+    val ordrdd = OrderedRVD.coerce(matrixType.orvdType, rdd, None, None)
 
     new MatrixTable(hc,
       matrixType,
diff --git a/src/main/scala/is/hail/table/Table.scala b/src/main/scala/is/hail/table/Table.scala
index 160ab2bbfa5f..08919be84a87 100644
--- a/src/main/scala/is/hail/table/Table.scala
+++ b/src/main/scala/is/hail/table/Table.scala
@@ -602,7 +602,7 @@ class Table(val hc: HailContext, val tir: TableIR) {
     }
 
     val ordType = new OrderedRVDType(partitionKeys, rowKeys ++ Array(INDEX_UID), rowEntryStruct)
-    val ordered = OrderedRVD(ordType, rowEntryRVD.rdd, None, None)
+    val ordered = OrderedRVD.coerce(ordType, rowEntryRVD)
 
     val matrixType: MatrixType = MatrixType.fromParts(
       globalSignature,
@@ -1144,6 +1144,6 @@ class Table(val hc: HailContext, val tir: TableIR) {
   def toOrderedRVD(hintPartitioner: Option[OrderedRVDPartitioner], partitionKeys: Int): OrderedRVD = {
     val orderedKTType = new OrderedRVDType(key.take(partitionKeys).toArray, key.toArray, signature)
     assert(hintPartitioner.forall(p => p.pkType.types.sameElements(orderedKTType.pkType.types)))
-    OrderedRVD(orderedKTType, rvd.rdd, None, hintPartitioner)
+    OrderedRVD.coerce(orderedKTType, rvd, None, hintPartitioner)
   }
 }
diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala
index 05bfefc2e444..c2a243169d37 100644
--- a/src/main/scala/is/hail/variant/MatrixTable.scala
+++ b/src/main/scala/is/hail/variant/MatrixTable.scala
@@ -143,7 +143,7 @@ object MatrixTable {
     val localNCols = colValues.length
 
     var ds = new MatrixTable(hc, matrixType, BroadcastValue(globals, matrixType.globalType, hc.sc), colValues,
-      OrderedRVD(matrixType.orvdType,
+      OrderedRVD.coerce(matrixType.orvdType,
         rdd.mapPartitions { it =>
           val region = Region()
           val rvb = new RegionValueBuilder(region)
@@ -212,7 +212,7 @@ object MatrixTable {
         rv
       }
     }
-    val rvd = OrderedRVD(mt.orvdType, rdd, None, None)
+    val rvd = OrderedRVD.coerce(mt.orvdType, rdd, None, None)
     new MatrixTable(hc, mt, BroadcastValue(Annotation.empty, mt.globalType, hc.sc), Array.tabulate(nCols)(Annotation(_)), rvd)
   }
 
@@ -303,15 +303,8 @@ object MatrixTable {
   def unionRows(datasets: Array[MatrixTable]): MatrixTable = {
     require(datasets.length >= 2)
     val first = datasets(0)
-    val sc = first.sparkContext
-
     checkDatasetSchemasCompatible(datasets)
-
-    first.copyMT(
-      rvd = OrderedRVD(
-        first.rvd.typ,
-        sc.union(datasets.map(_.rvd.rdd)),
-        None, None))
+    first.copyMT(rvd = OrderedRVD.union(datasets.map(_.rvd)))
   }
 
   def fromRowsTable(kt: Table, partitionKey: java.util.ArrayList[String] = null): MatrixTable = {
@@ -347,7 +340,7 @@ object MatrixTable {
     new MatrixTable(kt.hc, matrixType,
       BroadcastValue(Annotation.empty, matrixType.globalType, kt.hc.sc),
       Array.empty[Annotation],
-      OrderedRVD(matrixType.orvdType, rdd, None, None))
+      OrderedRVD.coerce(matrixType.orvdType, rdd, None, None))
   }
 }
 
@@ -576,7 +569,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
       rowPartitionKey = partitionKeys)
 
     copyMT(matrixType = newMatrixType,
-      rvd = OrderedRVD(newMatrixType.orvdType, rvd.rdd, None, None))
+      rvd = OrderedRVD.coerce(newMatrixType.orvdType, rvd))
   }
 
   def keyColsBy(keys: java.util.ArrayList[String]): MatrixTable = keyColsBy(keys.asScala: _*)
@@ -970,7 +963,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
       warn("modified row key, rescanning to compute ordering...")
       val newRDD = rvd.mapPartitions(mapPartitionsF)
       copyMT(matrixType = newMatrixType,
-        rvd = OrderedRVD(newMatrixType.orvdType, newRDD, None, None))
+        rvd = OrderedRVD.coerce(newMatrixType.orvdType, newRDD, None, None))
     } else copyMT(matrixType = newMatrixType,
       rvd = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType)(mapPartitionsF))
   }
@@ -1080,7 +1073,11 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
     val localRVRowType = rvRowType
     val pkIndex = rvRowType.fieldIdx(rowPartitionKey(0))
 
-    val newRDD = rvd.rdd.zipPartitions(zipRDD, preservesPartitioning = true) { case (it, intervals) =>
+    val newMatrixType = matrixType.copy(rvRowType = newRVType)
+    val newRVD = rvd.zipPartitionsPreservesPartitioning(
+      newMatrixType.orvdType,
+      zipRDD
+    ) { case (it, intervals) =>
       val intervalAnnotations: Array[(Interval, Any)] =
         intervals.map { rv =>
           val ur = new UnsafeRow(ktSignature, rv)
@@ -1118,13 +1115,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
       }
     }
 
-    val newMatrixType = matrixType.copy(rvRowType = newRVType)
-
-    val newRVD = OrderedRVD(
-      newMatrixType.orvdType,
-      rvd.partitioner,
-      newRDD)
-
     copyMT(rvd = newRVD, matrixType = newMatrixType)
   }
 
@@ -1273,7 +1263,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
       warn("modified row key, rescanning to compute ordering...")
       val newRDD = rvd.mapPartitions(mapPartitionsF)
       copyMT(matrixType = newMatrixType,
-        rvd = OrderedRVD(newMatrixType.orvdType, newRDD, None, None))
+        rvd = OrderedRVD.coerce(newMatrixType.orvdType, newRDD, None, None))
     } else copyMT(matrixType = newMatrixType,
       rvd = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType)(mapPartitionsF))
   }
@@ -1327,7 +1317,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
       warn("modified row key, rescanning to compute ordering...")
       val newRDD = rvd.mapPartitions(mapPartitionsF)
       copyMT(matrixType = newMatrixType,
-        rvd = OrderedRVD(newMatrixType.orvdType, newRDD, None, None))
+        rvd = OrderedRVD.coerce(newMatrixType.orvdType, newRDD, None, None))
     } else copyMT(matrixType = newMatrixType,
      rvd = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType)(mapPartitionsF))
  }
@@ -1376,13 +1366,8 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
 
   def nPartitions: Int = rvd.partitions.length
 
-  def annotateRowsVDS(right: MatrixTable, root: String): MatrixTable = {
-    // need to strip entries!
- // FIXME: HACK - val rTyp = new OrderedRVDType(right.rowPartitionKey.toArray, right.rowKey.toArray, right.rowType) - val rightRVD = OrderedRVD(rTyp, right.rvd.partitioner, right.rowsTable().rvd) - orderedRVDLeftJoinDistinctAndInsert(rightRVD, root, product = false) - } + def annotateRowsVDS(right: MatrixTable, root: String): MatrixTable = + orderedRVDLeftJoinDistinctAndInsert(right.rowFieldsRVD, root, product = false) def count(): (Long, Long) = (countRows(), numCols) @@ -2052,7 +2037,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { val newRVD = if (fieldMapRows.isEmpty) rvd else { val newType = newMatrixType.orvdType val newPartitioner = rvd.partitioner.withKType(pk.toArray, newType.kType) - OrderedRVD(newType, newPartitioner, rvd.rdd) + rvd.updateType(newType) } new MatrixTable(hc, newMatrixType, globals, colValues, newRVD) @@ -2319,12 +2304,13 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { Array.empty[String]) } - def rowsTable(): Table = { + private def rowFieldsRVD: OrderedRVD = { val localRowType = rowType val fullRowType = rvRowType val localEntriesIndex = entriesIndex - val tableType = TableType(rowType, rowKey, globalType) - new Table(hc, TableLiteral(TableValue(tableType, globals, rvd.mapPartitions(rowType) { it => + rvd.mapPartitionsPreservesPartitioning( + new OrderedRVDType(rowPartitionKey.toArray, rowKey.toArray, rowType) + ) { it => val rv2b = new RegionValueBuilder() val rv2 = RegionValue() it.map { rv => @@ -2341,7 +2327,12 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { rv2.set(rv.region, rv2b.end()) rv2 } - }))) + } + } + + def rowsTable(): Table = { + val tableType = TableType(rowType, rowKey, globalType) + new Table(hc, TableLiteral(TableValue(tableType, globals, rowFieldsRVD))) } def entriesTable(): Table = { @@ -2558,7 +2549,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { hc.hadoopConf.mkDir(path) - val partitionCounts = rvd.rdd.writeRowsSplit(path, matrixType, codecSpec, rvd.partitioner) + val partitionCounts = rvd.writeRowsSplit(path, matrixType, codecSpec) val globalsPath = path + "/globals" hadoopConf.mkDir(globalsPath)
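
A toy sketch of the abstraction barrier described in the commit message, assuming nothing from Hail itself: ToyRVD, coerce, zipWith, and ToyRVDDemo below are hypothetical stand-ins, not Hail classes or signatures. The pattern mirrors the diff: construction funnels through a coerce-style entry point that establishes the ordering invariant, and combination happens through combinators on the wrapper, so the backing collection never escapes and its representation (e.g. an RDD becoming a ContextRDD) can change without touching callers.

// Toy sketch, illustrative only: these are hypothetical stand-ins,
// not Hail classes. The point is the shape of the API: the backing
// collection is private, so its representation can change freely.
final class ToyRVD[A] private (private val data: Vector[A]) {
  // Combinator instead of exposing `data` (cf. mapPartitionsPreservesPartitioning).
  def map[B](f: A => B): ToyRVD[B] = new ToyRVD(data.map(f))

  // Pairwise zip standing in for zipPartitionsPreservesPartitioning:
  // the right operand is another wrapper, not a raw collection.
  def zipWith[B, C](that: ToyRVD[B])(f: (A, B) => C): ToyRVD[C] =
    new ToyRVD(data.zip(that.data).map(f.tupled))

  override def toString: String = data.mkString("ToyRVD(", ", ", ")")
}

object ToyRVD {
  // Stand-in for OrderedRVD.coerce: accept arbitrary input and establish
  // the ordering invariant internally, sorting only when necessary.
  def coerce[A](data: Vector[A])(implicit ord: Ordering[A]): ToyRVD[A] = {
    val alreadySorted =
      data.zip(data.drop(1)).forall { case (a, b) => ord.lteq(a, b) }
    new ToyRVD(if (alreadySorted) data else data.sorted)
  }
}

object ToyRVDDemo extends App {
  val left = ToyRVD.coerce(Vector(3, 1, 2))
  val right = ToyRVD.coerce(Vector(30, 10, 20))
  // Callers compose through the wrapper; the Vector never escapes.
  println(left.zipWith(right)(_ + _)) // prints ToyRVD(11, 22, 33)
}

Running ToyRVDDemo prints ToyRVD(11, 22, 33): both operands are coerced to sorted order on construction and then combined through the wrapper, which is the same discipline the patch imposes on callers of OrderedRVD.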