Merge pull request #83 from broadinstitute/tp_gc
Tp gc
Showing 9 changed files with 211 additions and 266 deletions.
src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala (213 changes: 174 additions & 39 deletions)
@@ -1,94 +1,229 @@
 package org.broadinstitute.hail.variant
 
-import org.apache.spark.SparkContext
+import java.nio.ByteBuffer
+
+import org.apache.spark.{SparkEnv, SparkContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.broadinstitute.hail.Utils._
-import org.broadinstitute.hail.variant.vsm.SparkyVSM
+import scala.language.implicitConversions
 
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe._
 
 
 object VariantSampleMatrix {
-  def apply(vsmtype: String,
-    metadata: VariantMetadata,
-    rdd: RDD[(Variant, GenotypeStream)]): VariantSampleMatrix[Genotype] = {
-    vsmtype match {
-      case "sparky" => new SparkyVSM(metadata, rdd)
-    }
+  def apply(metadata: VariantMetadata,
+    rdd: RDD[(Variant, Iterable[Genotype])]): VariantDataset = {
+    new VariantSampleMatrix(metadata, rdd)
   }
 
-  def read(sqlContext: SQLContext, dirname: String) = {
-    val (vsmType, metadata) = readObjectFile(dirname + "/metadata.ser", sqlContext.sparkContext.hadoopConfiguration)(
-      _.readObject().asInstanceOf[(String, VariantMetadata)])
+  def read(sqlContext: SQLContext, dirname: String): VariantDataset = {
+    require(dirname.endsWith(".vds"))
+    import RichRow._
 
-    vsmType match {
-      case "sparky" => SparkyVSM.read(sqlContext, dirname, metadata)
-    }
+    val metadata = readObjectFile(dirname + "/metadata.ser", sqlContext.sparkContext.hadoopConfiguration)(
+      _.readObject().asInstanceOf[VariantMetadata])
+
+    // val df = sqlContext.read.parquet(dirname + "/rdd.parquet")
+    val df = sqlContext.parquetFile(dirname + "/rdd.parquet")
+    new VariantSampleMatrix[Genotype](metadata, df.rdd.map(r => (r.getVariant(0), r.getGenotypeStream(1))))
   }
 }
 
-// FIXME all maps should become RDDs
-abstract class VariantSampleMatrix[T](val metadata: VariantMetadata,
-  val localSamples: Array[Int]) {
+class VariantSampleMatrix[T](val metadata: VariantMetadata,
+  val localSamples: Array[Int],
+  val rdd: RDD[(Variant, Iterable[T])])
+  (implicit ttt: TypeTag[T], tct: ClassTag[T],
+    vct: ClassTag[Variant]) {
+
+  def this(metadata: VariantMetadata, rdd: RDD[(Variant, Iterable[T])])
+    (implicit ttt: TypeTag[T], tct: ClassTag[T]) =
+    this(metadata, Array.range(0, metadata.nSamples), rdd)
 
   def sampleIds: Array[String] = metadata.sampleIds
 
   def nSamples: Int = metadata.sampleIds.length
 
   def nLocalSamples: Int = localSamples.length
 
-  def sparkContext: SparkContext
+  def copy[U](metadata: VariantMetadata = this.metadata,
+    localSamples: Array[Int] = this.localSamples,
+    rdd: RDD[(Variant, Iterable[U])] = this.rdd)
+    (implicit ttt: TypeTag[U], tct: ClassTag[U]): VariantSampleMatrix[U] =
+    new VariantSampleMatrix(metadata, localSamples, rdd)
 
-  // underlying RDD
-  def nPartitions: Int
-  def cache(): VariantSampleMatrix[T]
-  def repartition(nPartitions: Int): VariantSampleMatrix[T]
+  def sparkContext: SparkContext = rdd.sparkContext
 
-  def variants: RDD[Variant]
-  def nVariants: Long = variants.count()
+  def cache(): VariantSampleMatrix[T] = copy[T](rdd = rdd.cache())
+
+  def repartition(nPartitions: Int) = copy[T](rdd = rdd.repartition(nPartitions))
+
+  def nPartitions: Int = rdd.partitions.length
 
-  def expand(): RDD[(Variant, Int, T)]
+  def variants: RDD[Variant] = rdd.keys
+
+  def nVariants: Long = variants.count()
 
-  def write(sqlContext: SQLContext, dirname: String)
+  def expand(): RDD[(Variant, Int, T)] =
+    mapWithKeys[(Variant, Int, T)]((v, s, g) => (v, s, g))
 
-  def mapValuesWithKeys[U](f: (Variant, Int, T) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): VariantSampleMatrix[U]
 
   def mapValues[U](f: (T) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): VariantSampleMatrix[U] = {
     mapValuesWithKeys((v, s, g) => f(g))
   }
 
-  def mapWithKeys[U](f: (Variant, Int, T) => U)(implicit uct: ClassTag[U]): RDD[U]
+  def mapValuesWithKeys[U](f: (Variant, Int, T) => U)
+    (implicit utt: TypeTag[U], uct: ClassTag[U]): VariantSampleMatrix[U] = {
+    val localSamplesBc = sparkContext.broadcast(localSamples)
+    copy(rdd = rdd.map { case (v, gs) =>
+      (v, localSamplesBc.value.view.zip(gs.view)
+        .map { case (s, t) => f(v, s, t) })
+    })
+  }
 
   def map[U](f: T => U)(implicit uct: ClassTag[U]): RDD[U] =
     mapWithKeys((v, s, g) => f(g))
 
-  def flatMapWithKeys[U](f: (Variant, Int, T) => TraversableOnce[U])(implicit uct: ClassTag[U]): RDD[U]
+  def mapWithKeys[U](f: (Variant, Int, T) => U)(implicit uct: ClassTag[U]): RDD[U] = {
+    val localSamplesBc = sparkContext.broadcast(localSamples)
+    rdd
+      .flatMap { case (v, gs) => localSamplesBc.value.view.zip(gs.view)
+        .map { case (s, g) => f(v, s, g) }
+      }
+  }
 
   def flatMap[U](f: T => TraversableOnce[U])(implicit uct: ClassTag[U]): RDD[U] =
     flatMapWithKeys((v, s, g) => f(g))
 
-  def filterVariants(p: (Variant) => Boolean): VariantSampleMatrix[T]
+  def flatMapWithKeys[U](f: (Variant, Int, T) => TraversableOnce[U])(implicit uct: ClassTag[U]): RDD[U] = {
+    val localSamplesBc = sparkContext.broadcast(localSamples)
+    rdd
+      .flatMap { case (v, gs) => localSamplesBc.value.view.zip(gs.view)
+        .flatMap { case (s, g) => f(v, s, g) }
+      }
+  }
 
   def filterVariants(ilist: IntervalList): VariantSampleMatrix[T] =
     filterVariants(v => ilist.contains(v.contig, v.start))
 
-  def filterSamples(p: (Int) => Boolean): VariantSampleMatrix[T]
-
-  def aggregateBySampleWithKeys[U](zeroValue: U)(
-    seqOp: (U, Variant, Int, T) => U,
-    combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Int, U)]
+  def filterVariants(p: (Variant) => Boolean): VariantSampleMatrix[T] =
+    copy(rdd = rdd.filter { case (v, _) => p(v) })
+
+  def filterSamples(p: (Int) => Boolean) = {
+    val localSamplesBc = sparkContext.broadcast(localSamples)
+    copy[T](localSamples = localSamples.filter(p),
+      rdd = rdd.map { case (v, gs) =>
+        (v, localSamplesBc.value.view.zip(gs.view)
+          .filter { case (s, _) => p(s) }
+          .map(_._2))
+      })
+  }
 
   def aggregateBySample[U](zeroValue: U)(
     seqOp: (U, T) => U,
     combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Int, U)] =
     aggregateBySampleWithKeys(zeroValue)((e, v, s, g) => seqOp(e, g), combOp)
 
-  def aggregateByVariantWithKeys[U](zeroValue: U)(
+  def aggregateBySampleWithKeys[U](zeroValue: U)(
     seqOp: (U, Variant, Int, T) => U,
-    combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Variant, U)]
+    combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Int, U)] = {
+
+    val localSamplesBc = sparkContext.broadcast(localSamples)
+
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val zeroBuffer = serializer.serialize(zeroValue)
+    val zeroArray = new Array[Byte](zeroBuffer.limit)
+    zeroBuffer.get(zeroArray)
+
+    rdd
+      .mapPartitions { (it: Iterator[(Variant, Iterable[T])]) =>
+        val serializer = SparkEnv.get.serializer.newInstance()
+        def copyZeroValue() = serializer.deserialize[U](ByteBuffer.wrap(zeroArray))
+        val arrayZeroValue = Array.fill[U](localSamplesBc.value.length)(copyZeroValue())
+
+        localSamplesBc.value.iterator
+          .zip(it.foldLeft(arrayZeroValue) { case (acc, (v, gs)) =>
+            for ((g, i) <- gs.zipWithIndex)
+              acc(i) = seqOp(acc(i), v, localSamplesBc.value(i), g)
+            acc
+          }.iterator)
+      }.foldByKey(zeroValue)(combOp)
+  }
 
   def aggregateByVariant[U](zeroValue: U)(
     seqOp: (U, T) => U,
     combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Variant, U)] =
     aggregateByVariantWithKeys(zeroValue)((e, v, s, g) => seqOp(e, g), combOp)
 
-  def foldBySample(zeroValue: T)(combOp: (T, T) => T): RDD[(Int, T)]
+  def aggregateByVariantWithKeys[U](zeroValue: U)(
+    seqOp: (U, Variant, Int, T) => U,
+    combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Variant, U)] = {
+
+    val localSamplesBc = sparkContext.broadcast(localSamples)
+
+    // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+    val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue)
+    val zeroArray = new Array[Byte](zeroBuffer.limit)
+    zeroBuffer.get(zeroArray)
+
+    rdd
+      .map { case (v, gs) =>
+        val serializer = SparkEnv.get.serializer.newInstance()
+        val zeroValue = serializer.deserialize[U](ByteBuffer.wrap(zeroArray))
+
+        (v, gs.zipWithIndex.foldLeft(zeroValue) { case (acc, (g, i)) =>
+          seqOp(acc, v, localSamplesBc.value(i), g)
+        })
+      }
+  }
+
+  def foldBySample(zeroValue: T)(combOp: (T, T) => T): RDD[(Int, T)] = {
+
+    val localSamplesBc = sparkContext.broadcast(localSamples)
+    val localtct = tct
+
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val zeroBuffer = serializer.serialize(zeroValue)
+    val zeroArray = new Array[Byte](zeroBuffer.limit)
+    zeroBuffer.get(zeroArray)
+
+    rdd
+      .mapPartitions { (it: Iterator[(Variant, Iterable[T])]) =>
+        val serializer = SparkEnv.get.serializer.newInstance()
+        def copyZeroValue() = serializer.deserialize[T](ByteBuffer.wrap(zeroArray))(localtct)
+        val arrayZeroValue = Array.fill[T](localSamplesBc.value.length)(copyZeroValue())
+        localSamplesBc.value.iterator
+          .zip(it.foldLeft(arrayZeroValue) { case (acc, (v, gs)) =>
+            for ((g, i) <- gs.zipWithIndex)
+              acc(i) = combOp(acc(i), g)
+            acc
+          }.iterator)
+      }.foldByKey(zeroValue)(combOp)
+  }
+
+  def foldByVariant(zeroValue: T)(combOp: (T, T) => T): RDD[(Variant, T)] =
+    rdd.mapValues(_.foldLeft(zeroValue)((acc, g) => combOp(acc, g)))
+
+}
+
+// FIXME AnyVal Scala 2.11
+class RichVDS(vds: VariantDataset) {
+
+  def write(sqlContext: SQLContext, dirname: String, compress: Boolean = true) {
+    import sqlContext.implicits._
+
+    require(dirname.endsWith(".vds"))
 
-  def foldByVariant(zeroValue: T)(combOp: (T, T) => T): RDD[(Variant, T)]
+    val hConf = vds.sparkContext.hadoopConfiguration
+    hadoopMkdir(dirname, hConf)
+    writeObjectFile(dirname + "/metadata.ser", hConf)(
+      _.writeObject(vds.metadata))
+
+    // rdd.toDF().write.parquet(dirname + "/rdd.parquet")
+    vds.rdd
+      .map { case (v, gs) => (v, gs.toGenotypeStream(v, compress)) }
+      .toDF()
+      .saveAsParquetFile(dirname + "/rdd.parquet")
+  }
 }
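
With the abstract hierarchy gone, a VDS can be driven like any other Spark collection. A minimal usage sketch follows; it is illustrative only: the path is a placeholder, and Genotype.isCalled plus a String-valued Variant.contig are assumptions, not shown in this diff.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.broadinstitute.hail.variant._

object VsmExample {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("vsm-example").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)

    // read() expects the layout produced by RichVDS.write:
    // dirname/metadata.ser plus dirname/rdd.parquet
    val vds = VariantSampleMatrix.read(sqlContext, "/path/to/genotypes.vds")

    // filterVariants filters only the underlying RDD; metadata and
    // localSamples are carried along through copy()
    val chr1 = vds.filterVariants(v => v.contig == "1")

    // aggregateByVariant folds each variant's genotypes down to a single
    // value per variant; here, a called-genotype count (isCalled is an
    // assumed Genotype method, hypothetical in this sketch)
    val callCounts = chr1.aggregateByVariant(0)(
      (acc, g) => if (g.isCalled) acc + 1 else acc,
      _ + _)

    callCounts.take(5).foreach(println)

    sc.stop()
  }
}

One design choice worth noting: in aggregateBySampleWithKeys, aggregateByVariantWithKeys, and foldBySample, the zero value is serialized once with SparkEnv.get.serializer and a fresh copy is deserialized per variant or per sample slot, as the comment in the diff says. Cloning the zero this way keeps a mutable zero value (a growable buffer, say) from being shared across accumulators; it is the same trick Spark itself uses inside aggregateByKey.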
src/main/scala/org/broadinstitute/hail/variant/package.scala (15 changes: 15 additions & 0 deletions)