Skip to content

Commit

Permalink
Merge pull request #79 from broadinstitute/jg_multiarray2
Browse files Browse the repository at this point in the history
added multiarray2
  • Loading branch information
cseed committed Dec 3, 2015
2 parents 4ed94ca + dc46651 commit f65627e
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/org/broadinstitute/hail/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -496,4 +496,4 @@ object Utils {
toolbox.typeCheck(ast)
toolbox.eval(ast).asInstanceOf[T]
}
}
}
41 changes: 20 additions & 21 deletions src/main/scala/org/broadinstitute/hail/methods/MendelErrors.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.broadinstitute.hail.methods

import org.apache.spark.rdd.RDD
import org.broadinstitute.hail.utils.MultiArray2
import org.broadinstitute.hail.Utils._
import org.broadinstitute.hail.variant._
import org.broadinstitute.hail.variant.GenotypeType._
Expand Down Expand Up @@ -35,7 +36,7 @@ case class MendelError(variant: Variant, trio: CompleteTrio, code: Int,

object MendelErrors {

def getCode(gts: Array[GenotypeType], isHemizygous: Boolean): Int = {
def getCode(gts: IndexedSeq[GenotypeType], isHemizygous: Boolean): Int = {
(gts(1), gts(2), gts(0), isHemizygous) match { // gtDad, gtMom, gtKid, isHemizygous
case (HomRef, HomRef, Het, false) => 2 // Kid is het and not hemizygous
case (HomVar, HomVar, Het, false) => 1
Expand Down Expand Up @@ -67,36 +68,34 @@ object MendelErrors {
// all trios have defined sex, see require above
val trioSexBc = sc.broadcast(trios.map(_.sex.get))

val zeroVal: Array[Array[GenotypeType]] = // FIXME: change to MultiArray2 once available
Array.fill[Array[GenotypeType]](trios.size)(Array.fill[GenotypeType](3)(NoCall))
val zeroVal: MultiArray2[GenotypeType] = MultiArray2.fill(trios.length,3)(NoCall)

def seqOp(a: Array[Array[GenotypeType]], s: Int, g: Genotype): Array[Array[GenotypeType]] = {
sampleTrioRolesBc.value(s).foreach{ case (ti, ri) => a(ti)(ri) = g.gtType }
def seqOp(a: MultiArray2[GenotypeType], s: Int, g: Genotype): MultiArray2[GenotypeType] = {
sampleTrioRolesBc.value(s).foreach{ case (ti, ri) => a.update(ti,ri,g.gtType) }
a
}

def mergeOp(a: Array[Array[GenotypeType]], b: Array[Array[GenotypeType]]): Array[Array[GenotypeType]] = {
for (ti <- a.indices; ri <- 0 until 3)
if (b(ti)(ri) != NoCall)
a(ti)(ri) = b(ti)(ri)
def mergeOp(a: MultiArray2[GenotypeType], b: MultiArray2[GenotypeType]): MultiArray2[GenotypeType] = {
for ((i,j) <- a.indices)
if (b(i,j) != NoCall)
a(i,j) = b(i,j)
a
}

new MendelErrors(trios, vds.sampleIds,
vds
.aggregateByVariantWithKeys(zeroVal)(
(a, v, s, g) => seqOp(a, s, g),
mergeOp)
.flatMap{ case (v, a) =>
a.zipWithIndex.flatMap{ case (ati, ti) =>
val code = getCode(ati, v.isHemizygous(trioSexBc.value(ti)))
if (code != 0)
Some(new MendelError(v, triosBc.value(ti), code, ati(0), ati(1), ati(2)))
else
None
.aggregateByVariantWithKeys(zeroVal)(
(a, v, s, g) => seqOp(a, s, g),
mergeOp)
.flatMap { case (v, a) =>
a.rows.flatMap { case (row) => val code = getCode(row, v.isHemizygous(trioSexBc.value(row.i)))
if (code != 0)
Some(new MendelError(v, triosBc.value(row.i), code, row(0), row(1), row(2)))
else
None
}
}
}
.cache()
.cache()
)
}
}
Expand Down
74 changes: 74 additions & 0 deletions src/main/scala/org/broadinstitute/hail/utils/MultiArray2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.broadinstitute.hail.utils

import java.io.Serializable
import scala.reflect.ClassTag
import scala.collection.immutable.IndexedSeq


/**
  * A dense, mutable two-dimensional array with `n1` rows and `n2` columns, stored in the flat
  * array `a` in row-major order: element `(i, j)` lives at `a(i * n2 + j)`.
  *
  * The backing array is used directly (it is not copied), so writes made through `update` are
  * visible through `a`, `toArray`, `iterator` and every `Row` / `Column` view.
  *
  * @param n1 number of rows; must be non-negative
  * @param n2 number of columns; must be non-negative
  * @param a  row-major backing storage; its length must equal `n1 * n2`
  * @throws IllegalArgumentException if a dimension is negative or `a` has the wrong length
  */
class MultiArray2[T](val n1: Int,
  val n2: Int,
  val a: Array[T]) extends Serializable with Iterable[T] {

  require(n1 >= 0 && n2 >= 0, s"dimensions must be non-negative, got n1=$n1, n2=$n2")
  // Compare as Long: in Int arithmetic `n1 * n2` can wrap around (65536 * 65536 == 0), which
  // would accept a backing array that is far too small for the declared dimensions.
  require(a.length.toLong == n1.toLong * n2.toLong,
    s"backing array has length ${a.length}, expected n1 * n2 = ${n1.toLong * n2.toLong}")

  /**
    * Read-only view of row `i`. Reads go straight to the backing array, so the view reflects
    * later updates of the enclosing array.
    *
    * @throws IllegalArgumentException if `i` is not a valid row index
    */
  class Row(val i: Int) extends IndexedSeq[T] {
    require(i >= 0 && i < n1, s"row index $i out of range [0, $n1)")

    /** Element at column `j` of this row; throws ArrayIndexOutOfBoundsException if out of range. */
    def apply(j: Int): T = {
      if (j < 0 || j >= length)
        throw new ArrayIndexOutOfBoundsException(s"column index $j out of range [0, $length)")
      a(i * n2 + j)
    }

    def length: Int = n2
  }

  /**
    * Read-only view of column `j`. Reads go straight to the backing array, so the view reflects
    * later updates of the enclosing array.
    *
    * @throws IllegalArgumentException if `j` is not a valid column index
    */
  class Column(val j: Int) extends IndexedSeq[T] {
    require(j >= 0 && j < n2, s"column index $j out of range [0, $n2)")

    /** Element at row `i` of this column; throws ArrayIndexOutOfBoundsException if out of range. */
    def apply(i: Int): T = {
      if (i < 0 || i >= length)
        throw new ArrayIndexOutOfBoundsException(s"row index $i out of range [0, $length)")
      a(i * n2 + j)
    }

    def length: Int = n1
  }

  /** Fails with IllegalArgumentException (the type `require` throws) unless `(i, j)` is in range. */
  private def requireInBounds(i: Int, j: Int): Unit =
    if (i < 0 || i >= n1 || j < 0 || j >= n2)
      throw new IllegalArgumentException(
        s"requirement failed: index ($i, $j) out of range for a $n1 x $n2 MultiArray2")

  /** View of row `i`. */
  def row(i: Int): Row = new Row(i)

  /** View of column `j`. */
  def column(j: Int): Column = new Column(j)

  /** All row views, in increasing row order. */
  def rows: Iterable[Row] = for (i <- rowIndices) yield row(i)

  /** All column views, in increasing column order. */
  def columns: Iterable[Column] = for (j <- columnIndices) yield column(j)

  /** Every valid `(row, column)` index pair, in row-major order. */
  def indices: Iterable[(Int, Int)] = for (i <- 0 until n1; j <- 0 until n2) yield (i, j)

  /** Valid row indices, `0` until `n1`. */
  def rowIndices: Iterable[Int] = 0 until n1

  /** Valid column indices, `0` until `n2`. */
  def columnIndices: Iterable[Int] = 0 until n2

  /**
    * Element at row `i`, column `j`.
    *
    * @throws IllegalArgumentException if `(i, j)` is out of range
    */
  def apply(i: Int, j: Int): T = {
    requireInBounds(i, j)
    a(i * n2 + j)
  }

  /**
    * Overwrites the element at row `i`, column `j` with `x`.
    *
    * @throws IllegalArgumentException if `(i, j)` is out of range
    */
  def update(i: Int, j: Int, x: T): Unit = {
    requireInBounds(i, j)
    a.update(i * n2 + j, x)
  }

  /** Same as `update(t._1, t._2, x)`; the indices are validated by that overload. */
  def update(t: (Int, Int), x: T): Unit = update(t._1, t._2, x)

  /** The row-major backing array itself (not a copy). */
  def toArray: Array[T] = a

  /**
    * Pairs each element with the element at the same position in `other`.
    *
    * @throws IllegalArgumentException if the two arrays do not have the same dimensions
    */
  def zip[S](other: MultiArray2[S]): MultiArray2[(T, S)] = {
    require(n1 == other.n1 && n2 == other.n2,
      s"cannot zip a $n1 x $n2 MultiArray2 with a ${other.n1} x ${other.n2} MultiArray2")
    new MultiArray2(n1, n2, a.zip(other.a))
  }

  /** Iterates over all elements in row-major order. */
  def iterator: Iterator[T] = a.iterator
}

object MultiArray2 {
  /**
    * Builds an `n1` x `n2` array in which every cell holds the result of evaluating `elem`.
    * `elem` is by-name and is passed to `Array.fill`, so it is evaluated once per cell.
    *
    * @param n1   number of rows; must be non-negative
    * @param n2   number of columns; must be non-negative
    * @param elem expression producing the value of each cell
    * @throws IllegalArgumentException if a dimension is negative or `n1 * n2` does not fit in
    *                                  an `Int`
    */
  def fill[T](n1: Int, n2: Int)(elem: => T)(implicit tct: ClassTag[T]): MultiArray2[T] = {
    // Validate before allocating: otherwise two negative dimensions multiply to a positive
    // size and `elem` is evaluated that many times before construction fails.
    require(n1 >= 0 && n2 >= 0, s"dimensions must be non-negative, got n1=$n1, n2=$n2")
    // Check the product as Long so an Int wrap-around cannot yield a wrongly sized array.
    require(n1.toLong * n2.toLong <= Int.MaxValue,
      s"n1 * n2 = ${n1.toLong * n2.toLong} exceeds the maximum array size")
    new MultiArray2[T](n1, n2, Array.fill[T](n1 * n2)(elem))
  }
}

104 changes: 104 additions & 0 deletions src/test/scala/org/broadinstitute/hail/methods/MultiArray2Suite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package org.broadinstitute.hail.methods

import org.broadinstitute.hail.SparkSuite
import org.broadinstitute.hail.utils.MultiArray2
import org.testng.annotations.Test

/**
  * Exercises MultiArray2: construction, bounds checking, element updates, zipping, and the
  * row / column views.
  */
class MultiArray2Suite extends SparkSuite {
  @Test def test(): Unit = {

    // A 0 x 0 array can be constructed ...
    val empty = MultiArray2.fill[Int](0, 0)(0)

    // ... but it has no element to read.
    intercept[IllegalArgumentException] {
      val emptyAgain = MultiArray2.fill[Int](0, 0)(0)
      emptyAgain(0, 0)
    }

    // Indexing one past the end of a row view is rejected.
    intercept[ArrayIndexOutOfBoundsException] {
      val square = MultiArray2.fill[Int](5, 5)(0)
      square.row(0)(5)
    }

    // A negative row count is rejected at construction.
    intercept[IllegalArgumentException] {
      MultiArray2.fill[Int](-5, 5)(0)
    }

    // A negative column count is rejected at construction.
    intercept[IllegalArgumentException] {
      MultiArray2.fill[Int](5, -5)(0)
    }

    // ma1 holds the product of its indices.
    val ma1 = MultiArray2.fill[Int](10, 3)(0)
    ma1.indices.foreach { case (r, c) => ma1.update(r, c, r * c) }
    assert(ma1(2, 2) == 4)
    assert(ma1(6, 1) == 6)

    // Reading outside the valid index range is rejected.
    intercept[IllegalArgumentException] {
      ma1(100, 100)
    }

    // ma2 holds the sum of its indices.
    val ma2 = MultiArray2.fill[Int](10, 3)(0)
    ma2.indices.foreach { case (r, c) => ma2.update(r, c, r + c) }
    assert(ma2(2, 2) == 4)
    assert(ma2(6, 1) == 7)

    // Zipping two Int arrays pairs up the elements position by position.
    val ma3 = ma1.zip(ma2)
    assert(ma3(2, 2) == (4, 4))
    assert(ma3(6, 1) == (6, 7))

    // Zipping also works across element types.
    val ma4 = MultiArray2.fill[String](10, 3)("foo")
    val ma5 = ma1.zip(ma4)
    assert(ma5(2, 2) == (4, "foo"))
    assert(ma5(0, 0) == (0, "foo"))

    // Every row view exposes the expected elements.
    ma5.rows.foreach { r =>
      (0 until r.length).foreach { c =>
        assert(r(c) == (r.i * c, "foo"))
      }
    }

    // Row views: invalid row index, invalid position within a row, negative row index.
    intercept[IllegalArgumentException] {
      ma5.row(100)
    }

    intercept[ArrayIndexOutOfBoundsException] {
      val firstRow = ma5.row(0)
      firstRow(100)
    }

    intercept[IllegalArgumentException] {
      ma5.row(-5)
    }

    // Column views: invalid column index, negative column index, invalid position within a column.
    intercept[IllegalArgumentException] {
      ma5.column(100)
    }

    intercept[IllegalArgumentException] {
      ma5.column(-5)
    }

    intercept[ArrayIndexOutOfBoundsException] {
      val firstColumn = ma5.column(0)
      firstColumn(100)
    }

    // Every column view exposes the expected elements.
    ma5.columns.foreach { col =>
      (0 until col.length).foreach { r =>
        assert(col(r) == (col.j * r, "foo"))
      }
    }
  }
}

0 comments on commit f65627e

Please sign in to comment.