-
Notifications
You must be signed in to change notification settings - Fork 244
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #79 from broadinstitute/jg_multiarray2
added multiarray2
- Loading branch information
Showing
4 changed files
with
199 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -496,4 +496,4 @@ object Utils { | |
toolbox.typeCheck(ast) | ||
toolbox.eval(ast).asInstanceOf[T] | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
src/main/scala/org/broadinstitute/hail/utils/MultiArray2.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package org.broadinstitute.hail.utils | ||
|
||
import java.io.Serializable | ||
import scala.reflect.ClassTag | ||
import scala.collection.immutable.IndexedSeq | ||
|
||
|
||
class MultiArray2[T](val n1: Int, | ||
val n2: Int, | ||
val a: Array[T]) extends Serializable with Iterable[T] { | ||
|
||
require(n1 >= 0 && n2 >= 0) | ||
require(a.length == n1*n2) | ||
|
||
class Row(val i:Int) extends IndexedSeq[T] { | ||
require(i >= 0 && i < n1) | ||
def apply(j:Int): T = { | ||
if (j < 0 || j >= length) throw new ArrayIndexOutOfBoundsException | ||
a(i*n2 + j) | ||
} | ||
def length: Int = n2 | ||
} | ||
|
||
class Column(val j:Int) extends IndexedSeq[T] { | ||
require(j >= 0 && j < n2) | ||
def apply(i:Int): T = { | ||
if (i < 0 || i >= length) throw new ArrayIndexOutOfBoundsException | ||
a(i*n2 + j) | ||
} | ||
def length: Int = n1 | ||
} | ||
|
||
def row(i:Int) = new Row(i) | ||
def column(j:Int) = new Column(j) | ||
|
||
def rows: Iterable[Row] = for (i <- rowIndices) yield row(i) | ||
def columns: Iterable[Column] = for (j <- columnIndices) yield column(j) | ||
|
||
def indices: Iterable[(Int,Int)] = for (i <- 0 until n1; j <- 0 until n2) yield (i, j) | ||
|
||
def rowIndices: Iterable[Int] = for (i <- 0 until n1) yield i | ||
|
||
def columnIndices: Iterable[Int] = for (j <- 0 until n2) yield j | ||
|
||
def apply(i: Int, j: Int): T = { | ||
require(i >= 0 && i < n1 && j >= 0 && j < n2) | ||
a(i*n2 + j) | ||
} | ||
|
||
def update(i: Int, j: Int, x:T): Unit = { | ||
require(i >= 0 && i < n1 && j >= 0 && j < n2) | ||
a.update(i*n2 + j,x) | ||
} | ||
|
||
def update(t: (Int,Int), x:T): Unit = { | ||
require(t._1 >= 0 && t._1 < n1 && t._2 >= 0 && t._2 < n2) | ||
update(t._1,t._2,x) | ||
} | ||
|
||
def toArray: Array[T] = a | ||
|
||
def zip[S](other: MultiArray2[S]): MultiArray2[(T,S)] = { | ||
require(n1 == other.n1 && n2 == other.n2) | ||
new MultiArray2(n1,n2,a.zip(other.a)) | ||
} | ||
|
||
def iterator: Iterator[T] = a.iterator | ||
} | ||
|
||
object MultiArray2 { | ||
def fill[T](n1: Int, n2: Int)(elem: => T)(implicit tct: ClassTag[T]): MultiArray2[T] = | ||
new MultiArray2[T](n1, n2, Array.fill[T](n1 * n2)(elem)) | ||
} | ||
|
104 changes: 104 additions & 0 deletions
104
src/test/scala/org/broadinstitute/hail/methods/MultiArray2Suite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
package org.broadinstitute.hail.methods | ||
|
||
import org.broadinstitute.hail.SparkSuite | ||
import org.broadinstitute.hail.utils.MultiArray2 | ||
import org.testng.annotations.Test | ||
|
||
class MultiArray2Suite extends SparkSuite{ | ||
@Test def test() = { | ||
|
||
// test multiarray of size 0 will be created | ||
val ma0 = MultiArray2.fill[Int](0, 0)(0) | ||
|
||
// test multiarray of size 0 that get nothing out | ||
intercept[IllegalArgumentException] { | ||
val ma0 = MultiArray2.fill[Int](0, 0)(0) | ||
ma0(0,0) | ||
} | ||
|
||
// test array index out of bounds on row slice | ||
intercept[ArrayIndexOutOfBoundsException] { | ||
val foo = MultiArray2.fill[Int](5, 5)(0) | ||
foo.row(0)(5) | ||
} | ||
|
||
// bad multiarray initiation -- negative number | ||
intercept[IllegalArgumentException] { | ||
val a = MultiArray2.fill[Int](-5,5)(0) | ||
} | ||
|
||
// bad multiarray initiation -- negative number | ||
intercept[IllegalArgumentException] { | ||
val a = MultiArray2.fill[Int](5,-5)(0) | ||
} | ||
|
||
val ma1 = MultiArray2.fill[Int](10,3)(0) | ||
for ((i,j) <- ma1.indices) { | ||
ma1.update(i,j,i*j) | ||
} | ||
assert(ma1(2,2) == 4) | ||
assert(ma1(6,1) == 6) | ||
|
||
// Catch exception if try to get value that is not in indices of multiarray | ||
intercept[IllegalArgumentException] { | ||
val foo = ma1(100,100) | ||
} | ||
|
||
val ma2 = MultiArray2.fill[Int](10,3)(0) | ||
for ((i,j) <- ma2.indices) { | ||
ma2.update(i,j,i+j) | ||
} | ||
|
||
|
||
assert(ma2(2,2) == 4) | ||
assert(ma2(6,1) == 7) | ||
|
||
// Test zip with two ints | ||
val ma3 = ma1.zip(ma2) | ||
assert(ma3(2,2) == (4,4)) | ||
assert(ma3(6,1) == (6,7)) | ||
|
||
// Test zip with multi-arrays of different types | ||
val ma4 = MultiArray2.fill[String](10,3)("foo") | ||
val ma5 = ma1.zip(ma4) | ||
assert(ma5(2,2) == (4,"foo")) | ||
assert(ma5(0,0) == (0,"foo")) | ||
|
||
// Test row slice | ||
for (row <- ma5.rows; idx <- 0 until row.length) { | ||
assert(row(idx) == (row.i*idx,"foo")) | ||
} | ||
|
||
intercept[IllegalArgumentException] { | ||
val x = ma5.row(100) | ||
} | ||
|
||
intercept[ArrayIndexOutOfBoundsException] { | ||
val x = ma5.row(0) | ||
val y = x(100) | ||
} | ||
|
||
intercept[IllegalArgumentException] { | ||
val x = ma5.row(-5) | ||
} | ||
|
||
intercept[IllegalArgumentException] { | ||
val x = ma5.column(100) | ||
} | ||
|
||
intercept[IllegalArgumentException] { | ||
val x = ma5.column(-5) | ||
} | ||
|
||
intercept[ArrayIndexOutOfBoundsException] { | ||
val x = ma5.column(0) | ||
val y = x(100) | ||
} | ||
|
||
// Test column slice | ||
for (column <- ma5.columns; idx <- 0 until column.length) { | ||
assert(column(idx) == (column.j*idx,"foo")) | ||
} | ||
|
||
} | ||
} |