added multiarray2 #79

Merged
merged 10 commits into from
Dec 3, 2015
2 changes: 1 addition & 1 deletion src/main/scala/org/broadinstitute/hail/Utils.scala
@@ -496,4 +496,4 @@ object Utils {
     toolbox.typeCheck(ast)
     toolbox.eval(ast).asInstanceOf[T]
   }
-}
+}
41 changes: 20 additions & 21 deletions src/main/scala/org/broadinstitute/hail/methods/MendelErrors.scala
@@ -1,6 +1,7 @@
 package org.broadinstitute.hail.methods

 import org.apache.spark.rdd.RDD
+import org.broadinstitute.hail.utils.MultiArray2
 import org.broadinstitute.hail.Utils._
 import org.broadinstitute.hail.variant._
 import org.broadinstitute.hail.variant.GenotypeType._
@@ -35,7 +36,7 @@ case class MendelError(variant: Variant, trio: CompleteTrio, code: Int,

 object MendelErrors {

-  def getCode(gts: Array[GenotypeType], isHemizygous: Boolean): Int = {
+  def getCode(gts: IndexedSeq[GenotypeType], isHemizygous: Boolean): Int = {
     (gts(1), gts(2), gts(0), isHemizygous) match { // gtDad, gtMom, gtKid, isHemizygous
       case (HomRef, HomRef, Het, false) => 2 // Kid is het and not hemizygous
       case (HomVar, HomVar, Het, false) => 1
@@ -67,36 +68,34 @@ object MendelErrors {
     // all trios have defined sex, see require above
     val trioSexBc = sc.broadcast(trios.map(_.sex.get))

-    val zeroVal: Array[Array[GenotypeType]] = // FIXME: change to MultiArray2 once available
-      Array.fill[Array[GenotypeType]](trios.size)(Array.fill[GenotypeType](3)(NoCall))
+    val zeroVal: MultiArray2[GenotypeType] = MultiArray2.fill(trios.length,3)(NoCall)

-    def seqOp(a: Array[Array[GenotypeType]], s: Int, g: Genotype): Array[Array[GenotypeType]] = {
-      sampleTrioRolesBc.value(s).foreach{ case (ti, ri) => a(ti)(ri) = g.gtType }
+    def seqOp(a: MultiArray2[GenotypeType], s: Int, g: Genotype): MultiArray2[GenotypeType] = {
+      sampleTrioRolesBc.value(s).foreach{ case (ti, ri) => a.update(ti,ri,g.gtType) }
       a
     }

-    def mergeOp(a: Array[Array[GenotypeType]], b: Array[Array[GenotypeType]]): Array[Array[GenotypeType]] = {
-      for (ti <- a.indices; ri <- 0 until 3)
-        if (b(ti)(ri) != NoCall)
-          a(ti)(ri) = b(ti)(ri)
+    def mergeOp(a: MultiArray2[GenotypeType], b: MultiArray2[GenotypeType]): MultiArray2[GenotypeType] = {
+      for ((i,j) <- a.indices)
+        if (b(i,j) != NoCall)
+          a(i,j) = b(i,j)
       a
     }

     new MendelErrors(trios, vds.sampleIds,
       vds
-        .aggregateByVariantWithKeys(zeroVal)(
-          (a, v, s, g) => seqOp(a, s, g),
-          mergeOp)
-        .flatMap{ case (v, a) =>
-          a.zipWithIndex.flatMap{ case (ati, ti) =>
-            val code = getCode(ati, v.isHemizygous(trioSexBc.value(ti)))
-            if (code != 0)
-              Some(new MendelError(v, triosBc.value(ti), code, ati(0), ati(1), ati(2)))
-            else
-              None
+          .aggregateByVariantWithKeys(zeroVal)(
+            (a, v, s, g) => seqOp(a, s, g),
+            mergeOp)
+          .flatMap { case (v, a) =>
+            a.rows.flatMap { case (row) => val code = getCode(row, v.isHemizygous(trioSexBc.value(row.i)))
+              if (code != 0)
+                Some(new MendelError(v, triosBc.value(row.i), code, row(0), row(1), row(2)))
+              else
+                None
+            }
           }
-        }
-        .cache()
+          .cache()
     )
   }
 }
74 changes: 74 additions & 0 deletions src/main/scala/org/broadinstitute/hail/utils/MultiArray2.scala
@@ -0,0 +1,74 @@
+package org.broadinstitute.hail.utils
+
+import java.io.Serializable
+import scala.reflect.ClassTag
+import scala.collection.immutable.IndexedSeq
+
+
+class MultiArray2[T](val n1: Int,
+                     val n2: Int,
+                     val a: Array[T]) extends Serializable with Iterable[T] {
+
+  require(n1 >= 0 && n2 >= 0)
+  require(a.length == n1*n2)
Collaborator

Pervasive use of require is fantastic! Please keep it up.
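
A minimal sketch of what these checks buy, assuming a hypothetical `Grid` class that only mirrors the two constructor `require` calls above (it is not code from this PR): a failed `require` throws `IllegalArgumentException` at construction time, so a mis-sized backing array is rejected before any index arithmetic runs.

```scala
// Hypothetical stand-in for the constructor checks; not part of the PR.
class Grid(val n1: Int, val n2: Int, val a: Array[Int]) {
  require(n1 >= 0 && n2 >= 0, s"negative dimension: ($n1, $n2)")
  require(a.length == n1 * n2, s"expected ${n1 * n2} elements, got ${a.length}")
}

object GridDemo {
  // Returns true when construction is rejected by one of the require calls.
  def rejected(n1: Int, n2: Int, len: Int): Boolean =
    try { new Grid(n1, n2, new Array[Int](len)); false }
    catch { case _: IllegalArgumentException => true }
}
```

Here `GridDemo.rejected(2, 3, 5)` is true because a 2 x 3 grid needs six elements, while `GridDemo.rejected(2, 3, 6)` is false.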


+  class Row(val i:Int) extends IndexedSeq[T] {
+    require(i >= 0 && i < n1)
+    def apply(j:Int): T = {
+      if (j < 0 || j >= length) throw new ArrayIndexOutOfBoundsException
+      a(i*n2 + j)
+    }
Collaborator

I would rename idx to j for consistency, although this is debatable. In apply(), rather than using require, you should throw IndexOutOfBoundsException. According to the IndexedSeq.apply docs: "Exceptions thrown: IndexOutOfBoundsException if idx does not satisfy 0 <= idx < length." See:

http://www.scala-lang.org/api/2.11.7/#scala.collection.IndexedSeq

and click apply to expand.
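
The contract described above can be sketched as follows. `RowView` is a hypothetical, read-only class written for illustration, not the PR's `Row`; it assumes a flat row-major backing array with `n2` columns. Note that `ArrayIndexOutOfBoundsException` is a subclass of `IndexOutOfBoundsException`, so throwing either satisfies callers that catch the documented type.

```scala
// Hypothetical read-only view over row i of a flat, row-major array.
class RowView[T](a: Array[T], n2: Int, i: Int) extends IndexedSeq[T] {
  def length: Int = n2

  def apply(j: Int): T = {
    // Follow the IndexedSeq.apply contract rather than using require,
    // which would throw IllegalArgumentException instead.
    if (j < 0 || j >= length)
      throw new IndexOutOfBoundsException(s"index $j not in [0, $length)")
    a(i * n2 + j)
  }
}
```

Because `RowView` is an `IndexedSeq`, collection methods such as `toList` work on it without further code.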

Collaborator

You still need a check here. If the size is (5, 5), rowSlice(0)(5) will return element (1, 0), but it should throw an error. Basically, I'm thinking:

  def apply(j: Int): T = {
    if (j < 0 || j >= n2)
      throw new IndexOutOfBoundsException
    a(...
  }
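
The arithmetic behind this can be checked directly. The sketch below assumes the row-major layout `i*n2 + j` used in this PR; `FlatIndex` and `offset` are hypothetical names introduced only for the illustration.

```scala
object FlatIndex {
  // Row-major offset of cell (i, j) in a flat array with n2 columns.
  def offset(i: Int, j: Int, n2: Int): Int = i * n2 + j
}
```

With `n2 = 5`, `FlatIndex.offset(0, 5, 5)` and `FlatIndex.offset(1, 0, 5)` are both 5, so an unchecked read of column 5 in row 0 silently lands on the first element of row 1 instead of failing.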

Collaborator

So, as we discussed: def apply(...) { if (opposite condition) throw new ArrayIndex...; a(i*n2 + j) }. The same applies below.

+    def length: Int = n2
+  }
Collaborator

OK, here are some thoughts on rowSlice:

  • You did something clever, making a class look like a function. However, to call it you need to write new a.rowSlice(i), which isn't quite what we want. You could make it a case class or define a companion object with apply.
  • However, I think it is clearer as RowSlice: classes are Capitalized, functions and variables are lowerCase.
  • So you need a def rowSlice(i: Int) = ....
  • RowSlice should be a mutable.IndexedSeq. This will give it an array-like interface. Now would be a good time to read the Scala collections chapter of Introduction to Scala. Make sure you understand the differences between Traversable, Iterable, Iterator and IndexedSeq.
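
One way the suggestions in this list could fit together is sketched below. All names (`Matrix`, `RowSlice`, `rowSlice`) are hypothetical and this is not the code that was merged; it assumes the same flat row-major array as the PR. The capitalized inner class is a `mutable.IndexedSeq` that writes through to the backing array, and the lowerCase method is the factory callers use.

```scala
import scala.collection.mutable

// Hypothetical sketch of the reviewer's suggestion; not the merged implementation.
class Matrix[T](val n1: Int, val n2: Int, val a: Array[T]) {
  require(n1 >= 0 && n2 >= 0 && a.length == n1 * n2)

  // Mutable, array-like view of row i; updates write through to `a`.
  class RowSlice(val i: Int) extends mutable.IndexedSeq[T] {
    require(i >= 0 && i < n1)
    def length: Int = n2
    def apply(j: Int): T = { check(j); a(i * n2 + j) }
    def update(j: Int, x: T): Unit = { check(j); a(i * n2 + j) = x }
    private def check(j: Int): Unit =
      if (j < 0 || j >= n2) throw new IndexOutOfBoundsException(j.toString)
  }

  // Factory method, so callers write m.rowSlice(i) rather than new m.RowSlice(i).
  def rowSlice(i: Int): RowSlice = new RowSlice(i)
}
```

With this shape, `m.rowSlice(1)(0) = 40` updates the underlying array in place, and the slice still supports the usual sequence operations.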


+  class Column(val j:Int) extends IndexedSeq[T] {
+    require(j >= 0 && j < n2)
+    def apply(i:Int): T = {
+      if (i < 0 || i >= length) throw new ArrayIndexOutOfBoundsException
+      a(i*n2 + j)
+    }
+    def length: Int = n1
+  }
+
+  def row(i:Int) = new Row(i)
+  def column(j:Int) = new Column(j)
+
+  def rows: Iterable[Row] = for (i <- rowIndices) yield row(i)
+  def columns: Iterable[Column] = for (j <- columnIndices) yield column(j)
+
+  def indices: Iterable[(Int,Int)] = for (i <- 0 until n1; j <- 0 until n2) yield (i, j)
+
+  def rowIndices: Iterable[Int] = for (i <- 0 until n1) yield i
+
+  def columnIndices: Iterable[Int] = for (j <- 0 until n2) yield j
+
+  def apply(i: Int, j: Int): T = {
+    require(i >= 0 && i < n1 && j >= 0 && j < n2)
+    a(i*n2 + j)
+  }
+
+  def update(i: Int, j: Int, x:T): Unit = {
+    require(i >= 0 && i < n1 && j >= 0 && j < n2)
+    a.update(i*n2 + j,x)
+  }
+
+  def update(t: (Int,Int), x:T): Unit = {
+    require(t._1 >= 0 && t._1 < n1 && t._2 >= 0 && t._2 < n2)
+    update(t._1,t._2,x)
+  }
+
+  def toArray: Array[T] = a
+
+  def zip[S](other: MultiArray2[S]): MultiArray2[(T,S)] = {
+    require(n1 == other.n1 && n2 == other.n2)
+    new MultiArray2(n1,n2,a.zip(other.a))
+  }
+
+  def iterator: Iterator[T] = a.iterator
Collaborator

If you define iterator you should extend from Iterable. See my comment above about the Scala collections.
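
To illustrate why this matters, here is a small sketch with a hypothetical `Flat2D` class (not from this PR): once a class extends `Iterable[T]` and supplies `iterator`, the standard collection methods come along without being written by hand.

```scala
// Hypothetical container; only `iterator` is implemented explicitly.
class Flat2D[T](val n1: Int, val n2: Int, a: Array[T]) extends Iterable[T] {
  require(a.length == n1 * n2)

  // Everything else (map, filter, sum, exists, toList, ...) is inherited.
  def iterator: Iterator[T] = a.iterator
}
```

For example, `new Flat2D(2, 2, Array(1, 2, 3, 4)).map(_ * 2).toList` yields the doubled elements in storage order.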

+}
+
+object MultiArray2 {
+  def fill[T](n1: Int, n2: Int)(elem: => T)(implicit tct: ClassTag[T]): MultiArray2[T] =
+    new MultiArray2[T](n1, n2, Array.fill[T](n1 * n2)(elem))
+}

104 changes: 104 additions & 0 deletions src/test/scala/org/broadinstitute/hail/methods/MultiArray2Suite.scala
@@ -0,0 +1,104 @@
+package org.broadinstitute.hail.methods
+
+import org.broadinstitute.hail.SparkSuite
+import org.broadinstitute.hail.utils.MultiArray2
+import org.testng.annotations.Test
+
+class MultiArray2Suite extends SparkSuite{
+  @Test def test() = {
Collaborator

This is great work. After yesterday's meeting, my one comment is that we should try to use ScalaCheck here. However, let's get this branch merged first and then do that as a separate pull request.
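
As a rough idea of what a property-based version of these tests might check, the sketch below uses a seeded `scala.util.Random` loop as a dependency-free stand-in for ScalaCheck's generators; it is not ScalaCheck code. `IntGrid` and `GridProperties` are hypothetical names, and `IntGrid` is an Int-only miniature that assumes the same row-major layout and `require` checks as `MultiArray2`.

```scala
import scala.util.Random

// Hypothetical miniature of MultiArray2, just enough to state a property against.
class IntGrid(val n1: Int, val n2: Int) {
  require(n1 >= 0 && n2 >= 0)
  private val a = new Array[Int](n1 * n2)

  def apply(i: Int, j: Int): Int = {
    require(i >= 0 && i < n1 && j >= 0 && j < n2)
    a(i * n2 + j)
  }

  def update(i: Int, j: Int, x: Int): Unit = {
    require(i >= 0 && i < n1 && j >= 0 && j < n2)
    a(i * n2 + j) = x
  }
}

object GridProperties {
  // Property: for randomly chosen shapes, a value written at (i, j) is read back at (i, j).
  def roundTripHolds(trials: Int, seed: Long): Boolean = {
    val rng = new Random(seed)
    (0 until trials).forall { _ =>
      val n1 = rng.nextInt(8)
      val n2 = rng.nextInt(8)
      val g = new IntGrid(n1, n2)
      for (i <- 0 until n1; j <- 0 until n2) g(i, j) = i * n2 + j
      (for (i <- 0 until n1; j <- 0 until n2) yield g(i, j) == i * n2 + j).forall(identity)
    }
  }
}
```

In ScalaCheck the random shapes would come from generators and the loop would be replaced by a `forAll` property, which also reports a failing input when the property does not hold.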


+    // test multiarray of size 0 will be created
+    val ma0 = MultiArray2.fill[Int](0, 0)(0)
+
+    // test multiarray of size 0 that get nothing out
+    intercept[IllegalArgumentException] {
+      val ma0 = MultiArray2.fill[Int](0, 0)(0)
+      ma0(0,0)
+    }
+
+    // test array index out of bounds on row slice
+    intercept[ArrayIndexOutOfBoundsException] {
+      val foo = MultiArray2.fill[Int](5, 5)(0)
+      foo.row(0)(5)
+    }
+
+    // bad multiarray initiation -- negative number
+    intercept[IllegalArgumentException] {
+      val a = MultiArray2.fill[Int](-5,5)(0)
+    }
+
+    // bad multiarray initiation -- negative number
+    intercept[IllegalArgumentException] {
+      val a = MultiArray2.fill[Int](5,-5)(0)
+    }
+
+    val ma1 = MultiArray2.fill[Int](10,3)(0)
Collaborator

Great! I want to up the testing ante by using ScalaCheck (property-based testing) more extensively. I'll set up a tutorial for the week after Thanksgiving and then we can all start using it.

+    for ((i,j) <- ma1.indices) {
+      ma1.update(i,j,i*j)
+    }
+    assert(ma1(2,2) == 4)
+    assert(ma1(6,1) == 6)
+
+    // Catch exception if try to get value that is not in indices of multiarray
+    intercept[IllegalArgumentException] {
+      val foo = ma1(100,100)
+    }
+
+    val ma2 = MultiArray2.fill[Int](10,3)(0)
+    for ((i,j) <- ma2.indices) {
+      ma2.update(i,j,i+j)
+    }
+
+
+    assert(ma2(2,2) == 4)
+    assert(ma2(6,1) == 7)
+
+    // Test zip with two ints
+    val ma3 = ma1.zip(ma2)
+    assert(ma3(2,2) == (4,4))
+    assert(ma3(6,1) == (6,7))
+
+    // Test zip with multi-arrays of different types
+    val ma4 = MultiArray2.fill[String](10,3)("foo")
+    val ma5 = ma1.zip(ma4)
+    assert(ma5(2,2) == (4,"foo"))
+    assert(ma5(0,0) == (0,"foo"))
+
+    // Test row slice
+    for (row <- ma5.rows; idx <- 0 until row.length) {
+      assert(row(idx) == (row.i*idx,"foo"))
+    }
+
+    intercept[IllegalArgumentException] {
+      val x = ma5.row(100)
+    }
+
+    intercept[ArrayIndexOutOfBoundsException] {
+      val x = ma5.row(0)
+      val y = x(100)
+    }
+
+    intercept[IllegalArgumentException] {
+      val x = ma5.row(-5)
+    }
+
+    intercept[IllegalArgumentException] {
+      val x = ma5.column(100)
+    }
+
+    intercept[IllegalArgumentException] {
+      val x = ma5.column(-5)
+    }
+
+    intercept[ArrayIndexOutOfBoundsException] {
+      val x = ma5.column(0)
+      val y = x(100)
+    }
+
+    // Test column slice
+    for (column <- ma5.columns; idx <- 0 until column.length) {
+      assert(column(idx) == (column.j*idx,"foo"))
+    }
+
+  }
+}