Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added CommonBlas class #95

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/main/java/org/jblas/CommonBlas.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package org.jblas;

import org.jblas.exceptions.SizeException;

/**
* This class provides some common function interfaces of m*n matrices.
* This class is distinguished from the <class>SimpleBlas</class> which
* provides an interface for the computation of one row or column matrix dot product.
* For instance,it provides an interface of the m*n and n*p matrix inner product.
* <p/>
* For example, you can do any legitimate m*n matrix operation
* <p/>
* Currently, only implements inner,scalar matrix
*/
public class CommonBlas {

/**
* inner product of x with y <- x.y
*
* @return result of compution
* @throws SizeException x colums is inconsistent with y rows
*/
public static FloatMatrix inner(FloatMatrix x,FloatMatrix y) throws SizeException {
FloatMatrix x_t = x.transpose();
if (x_t.rows == y.rows) {
FloatMatrix result = new FloatMatrix(x_t.columns,y.columns);
for (int i = 0; i < x_t.columns; i++) {
for (int j = 0; j < y.columns; j++) {
float value_ij = SimpleBlas.dot(x_t.getColumn(i), y.getColumn(j));
result.put(i, j, value_ij);
}
}
return result;
}else
throw new SizeException("x colums is inconsistent with y rows");
}

/**
* inner product of x with y <- x.y
*
* @return result of compution
* @throws SizeException x colums is inconsistent with y rows
*/
public static DoubleMatrix inner(DoubleMatrix x,DoubleMatrix y) throws SizeException {
DoubleMatrix x_t = x.transpose();
if (x_t.rows == y.rows) {
DoubleMatrix result = new DoubleMatrix(x_t.columns,y.columns);
for (int i = 0; i < x_t.columns; i++) {
for (int j = 0; j < y.columns; j++) {
double value_ij = SimpleBlas.dot(x_t.getColumn(i), y.getColumn(j));
result.put(i, j, value_ij);
}
}
return result;
}else
throw new SizeException("x colums is inconsistent with y rows");
}

/**
* provide a instance of n*n float scalar matrix which value is n
* @param rows rows=column
* @param n value
* @return result
*/
public static FloatMatrix floatScalarMatrixInstance(int rows,float n){
FloatMatrix result = new FloatMatrix(rows,rows);
for(int i=0;i<rows;i++){
result.put(i,i,n);
}
return result;
}

/**
* provide a instance of n*n double scalar matrix which value is n
* @param rows rows=column
* @param n value
* @return result
*/
public static DoubleMatrix doubleScalarMatrixInstance(int rows,double n){
DoubleMatrix result = new DoubleMatrix(rows,rows);
for(int i=0;i<rows;i++){
result.put(i,i,n);
}
return result;
}
}
7 changes: 7 additions & 0 deletions src/main/java/org/jblas/DoubleMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -2226,6 +2226,13 @@ public double dot(DoubleMatrix other) {
return SimpleBlas.dot(this, other);
}

/**
* The inner product of this with other.
*/
public DoubleMatrix inner(DoubleMatrix other) {
return CommonBlas.inner(this, other);
}

/**
* Computes the projection coefficient of other on this.
*
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/org/jblas/FloatMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -2226,6 +2226,13 @@ public float dot(FloatMatrix other) {
return SimpleBlas.dot(this, other);
}

/**
* The inner product of this with other.
*/
public FloatMatrix inner(FloatMatrix other) {
return CommonBlas.inner(this, other);
}

/**
* Computes the projection coefficient of other on this.
*
Expand Down
35 changes: 35 additions & 0 deletions src/test/java/org/jblas/CommonBlasTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.jblas;

import org.junit.Test;

import static org.junit.Assert.assertTrue;

/**
* Some test for class CommonBlas
*
* @author Jason Chen
*/
public class CommonBlasTest {
@Test
public void testInner() {
DoubleMatrix A = new DoubleMatrix(2, 2, 3.0, -3.0, 1.0, 1.0);
DoubleMatrix B = new DoubleMatrix(2,2,1.0,-1.0,2.0,4.0);
DoubleMatrix C = new DoubleMatrix(2,2,2.0,-4.0,10.0,-2.0);

FloatMatrix D = new FloatMatrix(2, 2, 3.0f, -3.0f, 1.0f, 1.0f);
FloatMatrix E = new FloatMatrix(2,2,1.0f,-1.0f,2.0f,4.0f);
FloatMatrix F = new FloatMatrix(2,2,2.0f,-4.0f,10.0f,-2.0f);

assertTrue(CommonBlas.inner(A,B).equals(C));
assertTrue(CommonBlas.inner(D,E).equals(F));
}

@Test
public void testScalar(){
DoubleMatrix scalarDoubleMatrix = CommonBlas.doubleScalarMatrixInstance(2,-9.0);
FloatMatrix scalarFloatMatrix = CommonBlas.floatScalarMatrixInstance(2,-8.0f);

assertTrue(new DoubleMatrix(2,2,-9.0,0.0,0.0,-9.0).equals(scalarDoubleMatrix));
assertTrue(new FloatMatrix(2,2,-8.0f,0,0,-8.0f).equals(scalarFloatMatrix));
}
}