Skip to content

Commit

Permalink
whoops
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Mar 20, 2018
1 parent 1a06b9e commit ae2272b
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/methods/LDPrune.scala
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,6 @@ object LDPrune {
val ((globalPrunedRDD, nVariantsFinal), globalDuration) = time(pruneGlobal(rddLP2, r2Threshold, windowSize))
info(s"LD prune step 3 of 3: nVariantsKept=$nVariantsFinal, time=${ formatTime(globalDuration) }")

vsm.copy2(rvd = vsm.rvd.copy(rdd = vsm.rvd.orderedJoinDistinct(globalPrunedRDD, "inner").map(_.rvLeft)))
vsm.copy2(rvd = vsm.rvd.orderedJoinDistinct(globalPrunedRDD, "inner", _.map(_.rvLeft), vsm.rvd.typ))
}
}
18 changes: 1 addition & 17 deletions src/main/scala/is/hail/rvd/KeyedOrderedRVD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) extends Seria
new OrderedRVD(joinedType, newPartitioner, joinedRDD)
}

def newOrderedJoinDistinct(
def orderedJoinDistinct(
right: KeyedOrderedRVD,
joinType: String,
joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue],
Expand Down Expand Up @@ -81,22 +81,6 @@ class KeyedOrderedRVD(val rvd: OrderedRVD, val key: Array[String]) extends Seria
new OrderedRVD(joinedType, newPartitioner, joinedRDD)
}

def orderedJoinDistinct(right: KeyedOrderedRVD, joinType: String): RDD[JoinedRegionValue] = {
checkJoinCompatability(right)
val rekeyedLTyp = new OrderedRVDType(typ.partitionKey, key, typ.rowType)
val rekeyedRTyp = new OrderedRVDType(right.typ.partitionKey, right.key, right.typ.rowType)

val repartitionedRight = right.rvd.constrainToOrderedPartitioner(right.typ, this.rvd.partitioner)
val compute: (OrderedRVIterator, OrderedRVIterator) => Iterator[JoinedRegionValue] =
(joinType: @unchecked) match {
case "inner" => _.innerJoinDistinct(_)
case "left" => _.leftJoinDistinct(_)
}
this.rvd.rdd.zipPartitions(repartitionedRight.rdd, true){ (leftIt, rightIt) =>
compute(OrderedRVIterator(rekeyedLTyp, leftIt), OrderedRVIterator(rekeyedRTyp, rightIt))
}
}

def orderedZipJoin(right: KeyedOrderedRVD): RDD[JoinedRegionValue] = {
val newPartitioner = rvd.partitioner.enlargeToRange(right.rvd.partitioner.range)

Expand Down
9 changes: 7 additions & 2 deletions src/main/scala/is/hail/rvd/OrderedRVD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,13 @@ class OrderedRVD(
): OrderedRVD =
keyBy().orderedJoin(right.keyBy(), joinType, joiner, joinedType)

def orderedJoinDistinct(right: OrderedRVD, joinType: String): RDD[JoinedRegionValue] =
keyBy().orderedJoinDistinct(right.keyBy(), joinType)
def orderedJoinDistinct(
right: OrderedRVD,
joinType: String,
joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue],
joinedType: OrderedRVDType
): OrderedRVD =
keyBy().orderedJoinDistinct(right.keyBy(), joinType, joiner, joinedType)

def orderedZipJoin(right: OrderedRVD): RDD[JoinedRegionValue] =
keyBy().orderedZipJoin(right.keyBy())
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/is/hail/variant/MatrixTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
}
}

val joinedRVD = this.rvd.keyBy(rowKey.take(right.typ.key.length).toArray).newOrderedJoinDistinct(
val joinedRVD = this.rvd.keyBy(rowKey.take(right.typ.key.length).toArray).orderedJoinDistinct(
right.keyBy(),
"left",
joiner,
Expand Down Expand Up @@ -1756,7 +1756,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
val localEntriesType = matrixType.entryArrayType
assert(right.matrixType.entryArrayType == localEntriesType)

val joined = rvd.orderedJoinDistinct(right.rvd, "inner").mapPartitions({ it =>
val joiner: Iterator[JoinedRegionValue] => Iterator[RegionValue] = { it =>
val rvb = new RegionValueBuilder()
val rv2 = RegionValue()

Expand Down Expand Up @@ -1800,13 +1800,13 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
rv2.set(lrv.region, rvb.end())
rv2
}
}, preservesPartitioning = true)
}

val newMatrixType = matrixType.copyParts() // move entries to the end

copyMT(matrixType = newMatrixType,
colValues = colValues ++ right.colValues,
rvd = OrderedRVD(rvd.typ, rvd.partitioner, joined))
rvd = rvd.orderedJoinDistinct(right.rvd, "inner", joiner, rvd.typ))
}

def makeKT(rowExpr: String, entryExpr: String, keyNames: Array[String] = Array.empty, seperator: String = "."): Table = {
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class PartitioningSuite extends SparkSuite {
val mt = MatrixTable.fromRowsTable(Table.range(hc, 100, "idx", partitions=Some(6)))
val orvdType = mt.matrixType.orvdType

mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "left").count()
mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "inner").count()
mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "left", _.map(_._1), orvdType).count()
mt.rvd.orderedJoinDistinct(OrderedRVD.empty(hc.sc, orvdType), "inner", _.map(_._1), orvdType).count()
}
}
15 changes: 7 additions & 8 deletions src/test/scala/is/hail/vds/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,23 @@ class JoinSuite extends SparkSuite {
val localRowType = left.rvRowType

// Inner distinct ordered join
val jInner = left.rvd.orderedJoinDistinct(right.rvd, "inner")
val jInner = left.rvd.orderedJoinDistinct(right.rvd, "inner", _.map(_._1), left.rvd.typ)
val jInnerOrdRDD1 = left.rdd.join(right.rdd.distinct)

assert(jInner.count() == jInnerOrdRDD1.count())
assert(jInner.forall(jrv => jrv.rvLeft != null && jrv.rvRight != null))
assert(jInner.map { jrv =>
val ur = new UnsafeRow(localRowType, jrv.rvLeft)
assert(jInner.map { rv =>
val ur = new UnsafeRow(localRowType, rv)
ur.getAs[Locus](0)
}.collect() sameElements jInnerOrdRDD1.map(_._1.asInstanceOf[Row].get(0)).collect().sorted(vType.ordering.toOrdering))

// Left distinct ordered join
val jLeft = left.rvd.orderedJoinDistinct(right.rvd, "left")
val jLeft = left.rvd.orderedJoinDistinct(right.rvd, "left", _.map(_._1), left.rvd.typ)
val jLeftOrdRDD1 = left.rdd.leftOuterJoin(right.rdd.distinct)

assert(jLeft.count() == jLeftOrdRDD1.count())
assert(jLeft.forall(jrv => jrv.rvLeft != null))
assert(jLeft.map { jrv =>
val ur = new UnsafeRow(localRowType, jrv.rvLeft)
assert(jLeft.rdd.forall(rv => rv != null))
assert(jLeft.map { rv =>
val ur = new UnsafeRow(localRowType, rv)
ur.getAs[Locus](0)
}.collect() sameElements jLeftOrdRDD1.map(_._1.asInstanceOf[Row].get(0)).collect().sorted(vType.ordering.toOrdering))
}
Expand Down

0 comments on commit ae2272b

Please sign in to comment.