Merge pull request #83 from broadinstitute/tp_gc
Tp gc
cseed committed Dec 2, 2015
2 parents 02f4fec + 1031cf8 commit 4ed94ca
Showing 9 changed files with 211 additions and 266 deletions.
5 changes: 1 addition & 4 deletions src/main/scala/org/broadinstitute/hail/driver/Import.scala
@@ -15,9 +15,6 @@ object Import extends Command {
@Args4jOption(required = true, name = "-i", aliases = Array("--input"), usage = "Input file")
var input: String = _

@Args4jOption(required = false, name = "-m", aliases = Array("--vsm-type"), usage = "Select VariantSampleMatrix implementation")
var vsmtype: String = "sparky"

@Args4jOption(required = false, name = "-d", aliases = Array("--no-compress"), usage = "Don't compress in-memory representation")
var noCompress: Boolean = false

@@ -42,7 +39,7 @@ object Import extends Command {
fatal(".gz cannot be loaded in parallel, use .bgz or -f override")
}

LoadVCF(state.sc, input, options.vsmtype, !options.noCompress,
LoadVCF(state.sc, input, !options.noCompress,
if (options.nPartitions != 0)
Some(options.nPartitions)
else
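With the --vsm-type option gone, Import hands the VCF straight to LoadVCF. A minimal sketch of the new call, assuming a live SparkContext and an illustrative block-gzipped input path (not part of this commit):

import org.apache.spark.{SparkConf, SparkContext}
import org.broadinstitute.hail.methods.LoadVCF

// Illustrative values; any .bgz VCF and partition count will do.
val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("import-sketch"))
val vds = LoadVCF(sc, "sample.vcf.bgz", compress = true, nPartitions = Some(4))
println(s"loaded ${vds.nVariants} variants for ${vds.nSamples} samples")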
5 changes: 2 additions & 3 deletions src/main/scala/org/broadinstitute/hail/methods/LoadVCF.scala
@@ -10,7 +10,6 @@ object LoadVCF {
// FIXME move to VariantDataset
def apply(sc: SparkContext,
file: String,
vsmtype: String = "sparky",
compress: Boolean = true,
nPartitions: Option[Int] = None): VariantDataset = {

@@ -43,11 +42,11 @@
val b = new GenotypeStreamBuilder(v, compress)
for (g <- gs)
b += g
(v, b.result())
(v, b.result(): Iterable[Genotype])
}
}

// FIXME null should be contig lengths
VariantSampleMatrix(vsmtype, VariantMetadata(null, sampleIds, headerLines), genotypes)
VariantSampleMatrix(VariantMetadata(null, sampleIds, headerLines), genotypes)
}
}
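The type ascription in (v, b.result(): Iterable[Genotype]) widens the GenotypeStream produced by the builder to Iterable[Genotype], so the RDD's element type matches the RDD[(Variant, Iterable[Genotype])] that the now-concrete VariantSampleMatrix expects. A minimal construction sketch under that shape, with every input assumed to exist (names are illustrative, not from this commit):

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.broadinstitute.hail.variant._

// Hypothetical helper: wrap already-parsed (Variant, genotypes) pairs in a dataset.
def toDataset(sc: SparkContext,
              metadata: VariantMetadata,
              pairs: Seq[(Variant, Iterable[Genotype])]): VariantDataset = {
  val rdd: RDD[(Variant, Iterable[Genotype])] = sc.parallelize(pairs)
  VariantSampleMatrix(metadata, rdd)
}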
@@ -79,6 +79,11 @@ class GenotypeStreamBuilder(variant: Variant, compress: Boolean = true)
this
}

def ++=(i: Iterator[Genotype]): GenotypeStreamBuilder.this.type = {
i.foreach(_.write(b))
this
}

override def clear() {
b.clear()
}
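The new ++= overload drains a whole Iterator[Genotype] into the stream in one call, complementing the single-genotype +=. A small usage sketch, assuming an existing Variant v and genotypes gs: Iterable[Genotype] (both illustrative):

import org.broadinstitute.hail.variant._

// Re-encode a collection of genotypes as a (possibly compressed) GenotypeStream.
val b = new GenotypeStreamBuilder(v, compress = true)
b ++= gs.iterator // new overload: appends every genotype from the iterator
val stream: GenotypeStream = b.result()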
213 changes: 174 additions & 39 deletions src/main/scala/org/broadinstitute/hail/variant/VariantSampleMatrix.scala
@@ -1,94 +1,229 @@
package org.broadinstitute.hail.variant

import org.apache.spark.SparkContext
import java.nio.ByteBuffer

import org.apache.spark.{SparkEnv, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.broadinstitute.hail.Utils._
import org.broadinstitute.hail.variant.vsm.SparkyVSM
import scala.language.implicitConversions

import scala.reflect.ClassTag
import scala.reflect.runtime.universe._


object VariantSampleMatrix {
def apply(vsmtype: String,
metadata: VariantMetadata,
rdd: RDD[(Variant, GenotypeStream)]): VariantSampleMatrix[Genotype] = {
vsmtype match {
case "sparky" => new SparkyVSM(metadata, rdd)
}
def apply(metadata: VariantMetadata,
rdd: RDD[(Variant, Iterable[Genotype])]): VariantDataset = {
new VariantSampleMatrix(metadata, rdd)
}

def read(sqlContext: SQLContext, dirname: String) = {
val (vsmType, metadata) = readObjectFile(dirname + "/metadata.ser", sqlContext.sparkContext.hadoopConfiguration)(
_.readObject().asInstanceOf[(String, VariantMetadata)])
def read(sqlContext: SQLContext, dirname: String): VariantDataset = {
require(dirname.endsWith(".vds"))
import RichRow._

val metadata = readObjectFile(dirname + "/metadata.ser", sqlContext.sparkContext.hadoopConfiguration)(
_.readObject().asInstanceOf[VariantMetadata])

vsmType match {
case "sparky" => SparkyVSM.read(sqlContext, dirname, metadata)
}
// val df = sqlContext.read.parquet(dirname + "/rdd.parquet")
val df = sqlContext.parquetFile(dirname + "/rdd.parquet")
new VariantSampleMatrix[Genotype](metadata, df.rdd.map(r => (r.getVariant(0), r.getGenotypeStream(1))))
}
}

// FIXME all maps should become RDDs
abstract class VariantSampleMatrix[T](val metadata: VariantMetadata,
val localSamples: Array[Int]) {
class VariantSampleMatrix[T](val metadata: VariantMetadata,
val localSamples: Array[Int],
val rdd: RDD[(Variant, Iterable[T])])
(implicit ttt: TypeTag[T], tct: ClassTag[T],
vct: ClassTag[Variant]) {

def this(metadata: VariantMetadata, rdd: RDD[(Variant, Iterable[T])])
(implicit ttt: TypeTag[T], tct: ClassTag[T]) =
this(metadata, Array.range(0, metadata.nSamples), rdd)

def sampleIds: Array[String] = metadata.sampleIds

def nSamples: Int = metadata.sampleIds.length

def nLocalSamples: Int = localSamples.length

def sparkContext: SparkContext
def copy[U](metadata: VariantMetadata = this.metadata,
localSamples: Array[Int] = this.localSamples,
rdd: RDD[(Variant, Iterable[U])] = this.rdd)
(implicit ttt: TypeTag[U], tct: ClassTag[U]): VariantSampleMatrix[U] =
new VariantSampleMatrix(metadata, localSamples, rdd)

// underlying RDD
def nPartitions: Int
def cache(): VariantSampleMatrix[T]
def repartition(nPartitions: Int): VariantSampleMatrix[T]
def sparkContext: SparkContext = rdd.sparkContext

def variants: RDD[Variant]
def nVariants: Long = variants.count()
def cache(): VariantSampleMatrix[T] = copy[T](rdd = rdd.cache())

def repartition(nPartitions: Int) = copy[T](rdd = rdd.repartition(nPartitions))

def nPartitions: Int = rdd.partitions.length

def expand(): RDD[(Variant, Int, T)]
def variants: RDD[Variant] = rdd.keys

def nVariants: Long = variants.count()

def write(sqlContext: SQLContext, dirname: String)
def expand(): RDD[(Variant, Int, T)] =
mapWithKeys[(Variant, Int, T)]((v, s, g) => (v, s, g))

def mapValuesWithKeys[U](f: (Variant, Int, T) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): VariantSampleMatrix[U]

def mapValues[U](f: (T) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): VariantSampleMatrix[U] = {
mapValuesWithKeys((v, s, g) => f(g))
}

def mapWithKeys[U](f: (Variant, Int, T) => U)(implicit uct: ClassTag[U]): RDD[U]
def mapValuesWithKeys[U](f: (Variant, Int, T) => U)
(implicit utt: TypeTag[U], uct: ClassTag[U]): VariantSampleMatrix[U] = {
val localSamplesBc = sparkContext.broadcast(localSamples)
copy(rdd = rdd.map { case (v, gs) =>
(v, localSamplesBc.value.view.zip(gs.view)
.map { case (s, t) => f(v, s, t) })
})
}

def map[U](f: T => U)(implicit uct: ClassTag[U]): RDD[U] =
mapWithKeys((v, s, g) => f(g))

def flatMapWithKeys[U](f: (Variant, Int, T) => TraversableOnce[U])(implicit uct: ClassTag[U]): RDD[U]
def mapWithKeys[U](f: (Variant, Int, T) => U)(implicit uct: ClassTag[U]): RDD[U] = {
val localSamplesBc = sparkContext.broadcast(localSamples)
rdd
.flatMap { case (v, gs) => localSamplesBc.value.view.zip(gs.view)
.map { case (s, g) => f(v, s, g) }
}
}

def flatMap[U](f: T => TraversableOnce[U])(implicit uct: ClassTag[U]): RDD[U] =
flatMapWithKeys((v, s, g) => f(g))

def filterVariants(p: (Variant) => Boolean): VariantSampleMatrix[T]
def flatMapWithKeys[U](f: (Variant, Int, T) => TraversableOnce[U])(implicit uct: ClassTag[U]): RDD[U] = {
val localSamplesBc = sparkContext.broadcast(localSamples)
rdd
.flatMap { case (v, gs) => localSamplesBc.value.view.zip(gs.view)
.flatMap { case (s, g) => f(v, s, g) }
}
}

def filterVariants(ilist: IntervalList): VariantSampleMatrix[T] =
filterVariants(v => ilist.contains(v.contig, v.start))

def filterSamples(p: (Int) => Boolean): VariantSampleMatrix[T]

def aggregateBySampleWithKeys[U](zeroValue: U)(
seqOp: (U, Variant, Int, T) => U,
combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Int, U)]
def filterVariants(p: (Variant) => Boolean): VariantSampleMatrix[T] =
copy(rdd = rdd.filter { case (v, _) => p(v) })

def filterSamples(p: (Int) => Boolean) = {
val localSamplesBc = sparkContext.broadcast(localSamples)
copy[T](localSamples = localSamples.filter(p),
rdd = rdd.map { case (v, gs) =>
(v, localSamplesBc.value.view.zip(gs.view)
.filter { case (s, _) => p(s) }
.map(_._2))
})
}

def aggregateBySample[U](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Int, U)] =
aggregateBySampleWithKeys(zeroValue)((e, v, s, g) => seqOp(e, g), combOp)

def aggregateByVariantWithKeys[U](zeroValue: U)(
def aggregateBySampleWithKeys[U](zeroValue: U)(
seqOp: (U, Variant, Int, T) => U,
combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Variant, U)]
combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Int, U)] = {

val localSamplesBc = sparkContext.broadcast(localSamples)

val serializer = SparkEnv.get.serializer.newInstance()
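// Serialize the zero value to a byte array so that each partition can deserialize
// a fresh copy per local sample below, instead of sharing one mutable zero
// (same trick as in aggregateByVariantWithKeys).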
val zeroBuffer = serializer.serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)

rdd
.mapPartitions { (it: Iterator[(Variant, Iterable[T])]) =>
val serializer = SparkEnv.get.serializer.newInstance()
def copyZeroValue() = serializer.deserialize[U](ByteBuffer.wrap(zeroArray))
val arrayZeroValue = Array.fill[U](localSamplesBc.value.length)(copyZeroValue())

localSamplesBc.value.iterator
.zip(it.foldLeft(arrayZeroValue) { case (acc, (v, gs)) =>
for ((g, i) <- gs.zipWithIndex)
acc(i) = seqOp(acc(i), v, localSamplesBc.value(i), g)
acc
}.iterator)
}.foldByKey(zeroValue)(combOp)
}

def aggregateByVariant[U](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Variant, U)] =
aggregateByVariantWithKeys(zeroValue)((e, v, s, g) => seqOp(e, g), combOp)

def foldBySample(zeroValue: T)(combOp: (T, T) => T): RDD[(Int, T)]
def aggregateByVariantWithKeys[U](zeroValue: U)(
seqOp: (U, Variant, Int, T) => U,
combOp: (U, U) => U)(implicit utt: TypeTag[U], uct: ClassTag[U]): RDD[(Variant, U)] = {

val localSamplesBc = sparkContext.broadcast(localSamples)

// Serialize the zero value to a byte array so that we can get a new clone of it on each key
val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)

rdd
.map { case (v, gs) =>
val serializer = SparkEnv.get.serializer.newInstance()
val zeroValue = serializer.deserialize[U](ByteBuffer.wrap(zeroArray))

(v, gs.zipWithIndex.foldLeft(zeroValue) { case (acc, (g, i)) =>
seqOp(acc, v, localSamplesBc.value(i), g)
})
}
}

def foldBySample(zeroValue: T)(combOp: (T, T) => T): RDD[(Int, T)] = {

val localSamplesBc = sparkContext.broadcast(localSamples)
val localtct = tct

val serializer = SparkEnv.get.serializer.newInstance()
val zeroBuffer = serializer.serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)

rdd
.mapPartitions { (it: Iterator[(Variant, Iterable[T])]) =>
val serializer = SparkEnv.get.serializer.newInstance()
def copyZeroValue() = serializer.deserialize[T](ByteBuffer.wrap(zeroArray))(localtct)
val arrayZeroValue = Array.fill[T](localSamplesBc.value.length)(copyZeroValue())
localSamplesBc.value.iterator
.zip(it.foldLeft(arrayZeroValue) { case (acc, (v, gs)) =>
for ((g, i) <- gs.zipWithIndex)
acc(i) = combOp(acc(i), g)
acc
}.iterator)
}.foldByKey(zeroValue)(combOp)
}

def foldByVariant(zeroValue: T)(combOp: (T, T) => T): RDD[(Variant, T)] =
rdd.mapValues(_.foldLeft(zeroValue)((acc, g) => combOp(acc, g)))

}

// FIXME AnyVal Scala 2.11
class RichVDS(vds: VariantDataset) {

def write(sqlContext: SQLContext, dirname: String, compress: Boolean = true) {
import sqlContext.implicits._

require(dirname.endsWith(".vds"))

def foldByVariant(zeroValue: T)(combOp: (T, T) => T): RDD[(Variant, T)]
val hConf = vds.sparkContext.hadoopConfiguration
hadoopMkdir(dirname, hConf)
writeObjectFile(dirname + "/metadata.ser", hConf)(
_.writeObject(vds.metadata))

// rdd.toDF().write.parquet(dirname + "/rdd.parquet")
vds.rdd
.map { case (v, gs) => (v, gs.toGenotypeStream(v, compress)) }
.toDF()
.saveAsParquetFile(dirname + "/rdd.parquet")
}
}
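Taken together, read, the RDD-backed combinators, and RichVDS.write cover the usual round trip. A minimal end-to-end sketch, assuming an existing org.apache.spark.sql.SQLContext named sqlContext and illustrative .vds paths (not from this commit):

import org.broadinstitute.hail.variant._

// `sqlContext` is assumed; both paths are illustrative and must end in ".vds".
val vds: VariantDataset = VariantSampleMatrix.read(sqlContext, "/data/in.vds")

// Count genotypes per variant with the generic aggregator defined above.
val genotypesPerVariant = vds.aggregateByVariant(0)((acc, _) => acc + 1, _ + _)

// Keep one contig (illustrative filter) and write the result back out; write
// rebuilds GenotypeStreams via toGenotypeStream from package.scala.
vds.filterVariants(v => v.contig == "20")
  .write(sqlContext, "/data/out.vds", compress = true)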
15 changes: 15 additions & 0 deletions src/main/scala/org/broadinstitute/hail/variant/package.scala
@@ -1,11 +1,26 @@
package org.broadinstitute.hail

import org.apache.spark.rdd.RDD
import scala.language.implicitConversions
import org.broadinstitute.hail.variant.{GenotypeStream, Variant}

package object variant {
type VariantDataset = VariantSampleMatrix[Genotype]

class RichIterableGenotype(val it: Iterable[Genotype]) extends AnyVal {
def toGenotypeStream(v: Variant, compress: Boolean): GenotypeStream =
it match {
case gs: GenotypeStream => gs
case _ =>
val b: GenotypeStreamBuilder = new GenotypeStreamBuilder(v, compress = compress)
b ++= it
b.result()
}
}

implicit def toRichIterableGenotype(it: Iterable[Genotype]): RichIterableGenotype = new RichIterableGenotype(it)
implicit def toRichVDS(vsm: VariantDataset): RichVDS = new RichVDS(vsm)

// type VariantSampleMatrix[T, S] = managed.ManagedVSM[T, S]
// type VariantSampleMatrix[T, S <: Iterable[(Int, T)]] = sparky.SparkyVSM[T, S]
// def importToVSM(rdd: RDD[(Variant, GenotypeStream)]) = rdd
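A short sketch of the new conversion in isolation, assuming a Variant v and genotypes gs: Iterable[Genotype] (hypothetical names):

import org.broadinstitute.hail.variant._

// If `gs` is already a GenotypeStream (e.g. straight after read), it is returned
// as-is; otherwise a GenotypeStreamBuilder re-encodes it. RichVDS.write relies on
// this to avoid rebuilding streams that are already encoded.
val stream: GenotypeStream = gs.toGenotypeStream(v, compress = true)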