Skip to content

Commit

Permalink
[SPARK-4409] Added JavaAPI Tests, and fixed a couple of bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed Nov 26, 2014
1 parent d662f9d commit c75f3cd
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 25 deletions.
29 changes: 17 additions & 12 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ object SparseMatrix {
var i = 0
var nnz = 0
var lastCol = -1

raw.foreach { v =>
val r = i % numRows
val c = (i - r) / numRows
Expand All @@ -378,7 +377,10 @@ object SparseMatrix {
}
i += 1
}
sCols.append(sparseA.length)
while (numCols > lastCol){
sCols.append(sparseA.length)
lastCol += 1
}
new SparseMatrix(numRows, numCols, sCols.toArray, sRows.toArray, sparseA.toArray)
}

Expand All @@ -399,11 +401,11 @@ object SparseMatrix {
s"0.0 < d < 1.0. Currently, density: $density")
val rand = new XORShiftRandom(seed)
val length = numRows * numCols
val rawA = Array.fill(length)(0.0)
val rawA = new Array[Double](length)
var nnz = 0
for (i <- 0 until length) {
val p = rand.nextDouble()
if (p < density) {
if (p <= density) {
rawA.update(i, rand.nextDouble())
nnz += 1
}
Expand Down Expand Up @@ -439,11 +441,11 @@ object SparseMatrix {
s"0.0 < d < 1.0. Currently, density: $density")
val rand = new XORShiftRandom(seed)
val length = numRows * numCols
val rawA = Array.fill(length)(0.0)
val rawA = new Array[Double](length)
var nnz = 0
for (i <- 0 until length) {
val p = rand.nextDouble()
if (p < density) {
if (p <= density) {
rawA.update(i, rand.nextGaussian())
nnz += 1
}
Expand Down Expand Up @@ -476,21 +478,24 @@ object SparseMatrix {
val values = sVec.values
var i = 0
var lastCol = -1
val colPtrs = new ArrayBuffer[Int](n)
val colPtrs = new ArrayBuffer[Int](n + 1)
rows.foreach { r =>
while (r != lastCol) {
colPtrs.append(i)
lastCol += 1
}
i += 1
}
colPtrs.append(n)
while (n > lastCol) {
colPtrs.append(i)
lastCol += 1
}
new SparseMatrix(n, n, colPtrs.toArray, rows, values)
case dVec: DenseVector =>
val values = dVec.values
var i = 0
var nnz = 0
val sVals = values.filter( v => v != 0.0)
val sVals = values.filter(v => v != 0.0)
var lastCol = -1
val colPtrs = new ArrayBuffer[Int](n + 1)
val sRows = new ArrayBuffer[Int](sVals.length)
Expand Down Expand Up @@ -687,10 +692,10 @@ object Matrices {
* Horizontally concatenate a sequence of matrices. The returned matrix will be in the format
* the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in
* a dense matrix.
* @param matrices sequence of matrices
* @param matrices array of matrices
* @return a single `Matrix` composed of the matrices that were horizontally concatenated
*/
private[mllib] def horzCat(matrices: Seq[Matrix]): Matrix = {
def horzcat(matrices: Array[Matrix]): Matrix = {
if (matrices.size == 1) {
return matrices(0)
}
Expand Down Expand Up @@ -744,7 +749,7 @@ object Matrices {
* @param matrices sequence of matrices
* @return a single `Matrix` composed of the matrices that were horizontally concatenated
*/
private[mllib] def vertCat(matrices: Seq[Matrix]): Matrix = {
def vertcat(matrices: Array[Matrix]): Matrix = {
if (matrices.size == 1) {
return matrices(0)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.linalg;

import static org.junit.Assert.*;
import org.junit.Test;

import java.io.Serializable;

public class JavaMatricesSuite implements Serializable {

@Test
public void randMatrixConstruction() {
Matrix r = Matrices.rand(3, 4, 24);
DenseMatrix dr = DenseMatrix.rand(3, 4, 24);
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);

Matrix rn = Matrices.randn(3, 4, 24);
DenseMatrix drn = DenseMatrix.randn(3, 4, 24);
assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);

Matrix s = Matrices.sprand(3, 4, 0.5, 24);
SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, 24);
assertArrayEquals(s.toArray(), sr.toArray(), 0.0);

Matrix sn = Matrices.sprandn(3, 4, 0.5, 24);
SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, 24);
assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
}

@Test
public void identityMatrixConstruction() {
Matrix r = Matrices.eye(2);
DenseMatrix dr = DenseMatrix.eye(2);
SparseMatrix sr = SparseMatrix.speye(2);
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
}

@Test
public void diagonalMatrixConstruction() {
Vector v = Vectors.dense(1.0, 0.0, 2.0);
Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});

Matrix m = Matrices.diag(v);
Matrix sm = Matrices.diag(sv);
DenseMatrix d = DenseMatrix.diag(v);
DenseMatrix sd = DenseMatrix.diag(sv);
SparseMatrix s = SparseMatrix.diag(v);
SparseMatrix ss = SparseMatrix.diag(sv);

assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
assertArrayEquals(s.values(), ss.values(), 0.0);
assert(s.values().length == 2);
assert(ss.values().length == 2);
assert(s.colPtrs().length == 2);
assert(ss.colPtrs().length == 2);
}

@Test
public void zerosMatrixConstruction() {
Matrix z = Matrices.zeros(2, 2);
Matrix one = Matrices.ones(2, 2);
DenseMatrix dz = DenseMatrix.zeros(2, 2);
DenseMatrix done = DenseMatrix.ones(2, 2);

assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
}

@Test
public void concatenateMatrices() {
int m = 3;
int n = 2;

SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, 42);
DenseMatrix deMat1 = DenseMatrix.rand(m, n, 42);
Matrix deMat2 = Matrices.eye(3);
Matrix spMat2 = Matrices.speye(3);
Matrix deMat3 = Matrices.eye(2);
Matrix spMat3 = Matrices.speye(2);

Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});

assert(deHorz1.numRows() == 3);
assert(deHorz2.numRows() == 3);
assert(deHorz3.numRows() == 3);
assert(spHorz.numRows() == 3);
assert(deHorz1.numCols() == 5);
assert(deHorz2.numCols() == 5);
assert(deHorz3.numCols() == 5);
assert(spHorz.numCols() == 5);

Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});

assert(deVert1.numRows() == 5);
assert(deVert2.numRows() == 5);
assert(deVert3.numRows() == 5);
assert(spVert.numRows() == 5);
assert(deVert1.numCols() == 2);
assert(deVert2.numCols() == 2);
assert(deVert3.numCols() == 2);
assert(spVert.numCols() == 2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class MatricesSuite extends FunSuite {
assert(deMat1.toArray === deMat2.toArray)
}

test("horzCat, vertCat, eye, speye") {
test("horzcat, vertcat, eye, speye") {
val m = 3
val n = 2
val values = Array(1.0, 2.0, 4.0, 5.0)
Expand All @@ -147,10 +147,10 @@ class MatricesSuite extends FunSuite {
val deMat3 = Matrices.eye(2)
val spMat3 = Matrices.speye(2)

val spHorz = Matrices.horzCat(Seq(spMat1, spMat2))
val deHorz1 = Matrices.horzCat(Seq(deMat1, deMat2))
val deHorz2 = Matrices.horzCat(Seq(spMat1, deMat2))
val deHorz3 = Matrices.horzCat(Seq(deMat1, spMat2))
val spHorz = Matrices.horzcat(Array(spMat1, spMat2))
val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2))
val deHorz2 = Matrices.horzcat(Array(spMat1, deMat2))
val deHorz3 = Matrices.horzcat(Array(deMat1, spMat2))

assert(deHorz1.numRows === 3)
assert(deHorz2.numRows === 3)
Expand Down Expand Up @@ -179,17 +179,17 @@ class MatricesSuite extends FunSuite {
assert(deHorz1(1, 4) === 0.0)

intercept[IllegalArgumentException] {
Matrices.horzCat(Seq(spMat1, spMat3))
Matrices.horzcat(Array(spMat1, spMat3))
}

intercept[IllegalArgumentException] {
Matrices.horzCat(Seq(deMat1, spMat3))
Matrices.horzcat(Array(deMat1, spMat3))
}

val spVert = Matrices.vertCat(Seq(spMat1, spMat3))
val deVert1 = Matrices.vertCat(Seq(deMat1, deMat3))
val deVert2 = Matrices.vertCat(Seq(spMat1, deMat3))
val deVert3 = Matrices.vertCat(Seq(deMat1, spMat3))
val spVert = Matrices.vertcat(Array(spMat1, spMat3))
val deVert1 = Matrices.vertcat(Array(deMat1, deMat3))
val deVert2 = Matrices.vertcat(Array(spMat1, deMat3))
val deVert3 = Matrices.vertcat(Array(deMat1, spMat3))

assert(deVert1.numRows === 5)
assert(deVert2.numRows === 5)
Expand All @@ -214,11 +214,11 @@ class MatricesSuite extends FunSuite {
assert(deVert1(4, 1) === 1.0)

intercept[IllegalArgumentException] {
Matrices.vertCat(Seq(spMat1, spMat2))
Matrices.vertcat(Array(spMat1, spMat2))
}

intercept[IllegalArgumentException] {
Matrices.vertCat(Seq(deMat1, spMat2))
Matrices.vertcat(Array(deMat1, spMat2))
}
}
}

0 comments on commit c75f3cd

Please sign in to comment.