Skip to content

Commit

Permalink
Clean-up and added unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
StaticBeagle committed Dec 21, 2024
1 parent 0e52f49 commit ab59aef
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ public ComplexMatrixDense(int rows, int cols, double val) {
Arrays.fill(data, val);
}

public static ComplexMatrixDense from2DArray(Complex[][] A) {
return new ComplexMatrixDense(A);
}

/***
* Deep copy
* @return A newly created {@link ComplexMatrixDense} with the values of this current one.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,59 +14,6 @@ public class ComplexQRDecompositionDense extends ComplexQRDecomposition<ComplexM
private int m = 0;
private int n = 0;

public static void main(String[] args) {
// Define a 2x2 complex matrix
Complex[][] A = {
{new Complex(1, 1), new Complex(2, 2)},
{new Complex(3, 3), new Complex(4, 4)}
};

// Perform Householder QR decomposition
ComplexQRDecompositionDense result = new ComplexQRDecompositionDense(new ComplexMatrixDense(A));

// Print results
System.out.println("Matrix Q:");
System.out.println(result.getQ());

// System.out.println("Matrix Q thin:");
// System.out.println(result.getQThin());

System.out.println("Matrix R:");
System.out.println(result.getR());

System.out.println("Matrix H:");
System.out.println(result.getH());

System.out.println("Matrix Q * R:");
System.out.println(result.getQ().multiply(result.getR()));

System.out.println("Check that Q is orthogonal");
System.out.println(result.getQ().multiply(result.getQT()));

Complex[] x = {new Complex(1, 1), new Complex(2, 2)};
Complex[] u = {new Complex(3, 3), new Complex(4, 4)};

Complex[][] y = {x, u};
//houseQR(new ComplexMatrixDense(y));

Complex[] g = {new Complex(60, 70)};
Complex[] f = {new Complex(80, 90)};

ComplexMatrixDense sol = new ComplexQRDecompositionDense(new ComplexMatrixDense(y))
.solve(new ComplexMatrixDense(new Complex[][]{g, f}));
System.out.println("Solution to system");
System.out.println(sol);

ComplexMatrixDense AA = new ComplexMatrixDense(new Complex[][] {
{new Complex(1, 1), new Complex(2, 0)},
{new Complex(2, -1), new Complex(3, 0)}
});

ComplexQRDecompositionDense qrDecompositionDense = new ComplexQRDecompositionDense(AA);
System.out.println(qrDecompositionDense.getQ());
System.out.println(qrDecompositionDense.getR());
}

private static Complex sig(Complex u) {
return u.sign().add(u.equals(new Complex()) ? 1 : 0);
}
Expand Down Expand Up @@ -138,7 +85,6 @@ private static void setColumn(ComplexMatrixDense A, Complex[] values, int col, i
A.set(i + startingRow, col, values[i]);
}
}

private static void setColumn(ComplexMatrixDense A, Complex[] values, int col) {
setColumn(A, values, col, 0);
}
Expand All @@ -149,15 +95,6 @@ private static void setColumn(ComplexMatrixDense A, Complex[] values, int row0,
}
}

// private static ComplexMatrixDense getSubMatrix(ComplexMatrixDense A, int row0, int row1, int col0, int col1) {
// ComplexMatrixDense result = new ComplexMatrixDense(row1 - row0, col1 - col0);
// for(int i = row0; i < row1; i++) {
// for(int j = col0; j < col1; j++) {
// A.set
// }
// }
// }

private static Complex[][] getSubMatrix(ComplexMatrixDense A, int row0, int row1, int col0, int col1) {
Complex[][] result = new Complex[row1 - row0][col1 - col0];
for (int i = row0; i < row1; i++) {
Expand All @@ -169,15 +106,6 @@ private static Complex[][] getSubMatrix(ComplexMatrixDense A, int row0, int row1
}

private static void setSubMatrix(ComplexMatrixDense A, Complex[][] values, int row0, int row1, int col0, int col1) {
// if(col0 == col1) {
// setColumn(A, values[0], row0, row1, col0);
// } else {
// for(int i = row0; i < row1; i++) {
// for(int j = col0; j < col1; j++) {
// A.set(i, j, values[i - row0][j - col0]);
// }
// }
// }
for (int i = row0; i < row1; i++) {
for (int j = col0; j < col1; j++) {
A.set(i, j, values[i - row0][j - col0]);
Expand Down Expand Up @@ -245,30 +173,6 @@ public static ComplexMatrixDense backSubstitutionSolve(ComplexMatrixDense R, Com
return new ComplexMatrixDense(X);
}

// public static ComplexMatrixDense solve(ComplexMatrixDense A, ComplexMatrixDense B) {
// ComplexMatrixDense X = new ComplexMatrixDense(zeros(B.getRowCount(), B.getColumnCount()));
//
// Tuples.Tuple2<ComplexMatrixDense, ComplexMatrixDense> UR = houseQR(A);
// ComplexMatrixDense U = UR.getItem1();
// ComplexMatrixDense R = UR.getItem2();
//
// // Compute Y = transpose(Q) * B
// ComplexMatrixDense Y = houseApplyTranspose(U, B);
//
// // Solve R * X = Y;
// // Back Substitution
//
//
//// if (B.getRowCount() != rows) {
//// throw new IllegalArgumentException("Matrix row dimensions must agree.");
//// }
////// if (!this.isFullRank()) {
////// throw new RuntimeException("Matrix is rank deficient.");
////// }
////
// return backSubstitutionSolve(R, Y);
// }

private ComplexMatrixDense U;
private ComplexMatrixDense R;

Expand Down Expand Up @@ -333,26 +237,6 @@ public ComplexMatrixDense getH() {
public ComplexMatrixDense getR() {
return R;
}
//
// /**
// * Return the upper triangular factor
// *
// * @return R
// */
//
// public MatrixDense getRT() {

// }
//
// /**
// * Generate and return the (economy-sized) orthogonal factor
// *
// * @return Q
// */
//
// public ComplexMatrixDense getQThin() {
// }


/**
* Generate and return the unitary orthogonal factor
Expand All @@ -369,11 +253,11 @@ public ComplexMatrixDense QmultiplyX(ComplexMatrixDense X) {
}

/**
* Generate and return the transpose of the orthogonal factor
* Generate and return the conjugate transpose of the orthogonal factor
*
* @return transpose(Q)
*/
public ComplexMatrixDense getQT() {
public ComplexMatrixDense getQH() {
ComplexMatrixDense I = ComplexMatrixDense.Factory.identity(U.getRowCount(), U.getColumnCount());
return houseApplyTranspose(U, I);
}
Expand Down Expand Up @@ -401,59 +285,16 @@ public ComplexMatrixDense solve(ComplexMatrixDense B) {
// Back Substitution
return backSubstitutionSolve(R, Y);
}
//
//// public Matrix solveTranspose(Matrix B) {
////// if (B.getRowCount() != rows) {
////// throw new IllegalArgumentException("Matrix row dimensions must agree.");
////// }
//// if (!this.isFullRank()) {
//// throw new RuntimeException("Matrix is rank deficient.");
//// }
////
//// // Copy right hand side
//// int nx = B.getColumnCount();
//// double[] X = B.getArrayCopy();
////
//// // Solve RT*X = Y;
//// for (int k = cols - 1; k >= 0; k--) {
//// for (int j = 0; j < nx; j++) {
//// X[k * nx + j] /= _rdiag[k];
//// }
//// for (int i = 0; i < k; i++) {
//// for (int j = 0; j < nx; j++) {
//// X[i * nx + j] -= X[k * nx + j] * _data[i * cols + k];
//// }
//// }
//// }
////
////
//// int mr = Math.min(rows, cols);
//// for (int k = mr - 1; k >= 0; --k) {
//// for (int j = nx - 1; j >= 0; --j) {
//// double s = 0.0;
//// for (int i = k; i < rows; i++) {
//// s += _data[i * cols + k] * X[i * nx + j];
//// }
//// s = -s / _data[k * cols + k];
//// for (int i = k; i < rows; i++) {
//// X[i * nx + j] += s * _data[i * cols + k];
//// }
//// }
//// }
//// return new Matrix(X, B.getRowCount(), B.getColumnCount());
////
//// //return (new Matrix(X, cols, nx).subMatrix(0, cols - 1, 0, nx - 1));
//// }
//
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < rows * cols; ++i) {
if (i > 0 && i % cols == 0) {
sb.append(System.lineSeparator());
}
sb.append(String.format("%.4f", _data[i])).append(" ");

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < rows * cols; ++i) {
if (i > 0 && i % cols == 0) {
sb.append(System.lineSeparator());
}
return sb.toString();
sb.append(String.format("%.4f", _data[i])).append(" ");
}
return sb.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package com.wildbitsfoundry.etk4j.math.linearalgebra;

import com.wildbitsfoundry.etk4j.math.complex.Complex;
import org.junit.Test;

import static org.junit.Assert.assertArrayEquals;

public class ComplexQRDenseTest {

@Test
public void testGetQ() {
// Define a 2x2 complex matrix
Complex[][] A = {
{new Complex(1, 1), new Complex(2, 2)},
{new Complex(3, 3), new Complex(4, 4)}
};

// Perform Householder QR decomposition
ComplexQRDecompositionDense complexQRDecompositionDense = new ComplexQRDecompositionDense(ComplexMatrixDense.from2DArray(A));
Complex[] expected = {new Complex(-0.3162277660168378, 0), new Complex(0.9486832980505124, 0),
new Complex(-0.9486832980505134, 0), new Complex(-0.3162277660168382, 0)};
assertArrayEquals(expected, complexQRDecompositionDense.getQ().getArray());
}

@Test
public void testGetR() {
// Define a 2x2 complex matrix
Complex[][] A = {
{new Complex(1, 1), new Complex(2, 2)},
{new Complex(3, 3), new Complex(4, 4)}
};

// Perform Householder QR decomposition
ComplexQRDecompositionDense complexQRDecompositionDense = new ComplexQRDecompositionDense(ComplexMatrixDense.from2DArray(A));
Complex[] expected = {new Complex(-3.1622776601683773, -3.1622776601683773), new Complex(-4.427188724235729, -4.427188724235729),
new Complex(0, 0), new Complex(0.6324555320336729, 0.6324555320336729)};
assertArrayEquals(expected, complexQRDecompositionDense.getR().getArray());
}

@Test
public void testGetH() {
// Define a 2x2 complex matrix
Complex[][] A = {
{new Complex(1, 1), new Complex(2, 2)},
{new Complex(3, 3), new Complex(4, 4)}
};

// Perform Householder QR decomposition
ComplexQRDecompositionDense complexQRDecompositionDense = new ComplexQRDecompositionDense(ComplexMatrixDense.from2DArray(A));
Complex[] expected = {new Complex(0.8112421851755608, 0.8112421851755608), new Complex(0, 0),
new Complex(0.5847102846637647, 0.5847102846637647), new Complex(-0.9999999999999998, -0.9999999999999998)};
assertArrayEquals(expected, complexQRDecompositionDense.getH().getArray());
}

@Test
public void testGetOriginalMatrixBack() {
// Define a 2x2 complex matrix
Complex[][] A = {
{new Complex(1, 1), new Complex(2, 2)},
{new Complex(3, 3), new Complex(4, 4)}
};

ComplexMatrixDense complexMatrixDense = ComplexMatrixDense.from2DArray(A);
Complex[] expected = {new Complex(0.9999999999999989, 0.9999999999999989), new Complex(1.9999999999999951, 1.9999999999999951),
new Complex(2.999999999999997, 2.999999999999997), new Complex(3.9999999999999973, 3.9999999999999973)};

// Perform Householder QR decomposition
ComplexQRDecompositionDense complexQRDecompositionDense = new ComplexQRDecompositionDense(complexMatrixDense);
assertArrayEquals(expected, complexQRDecompositionDense.getQ().multiply(complexQRDecompositionDense.getR()).getArray());
}

@Test
public void testGetQConjugateTranspose() {
ComplexMatrixDense A = new ComplexMatrixDense(new Complex[][]{
{new Complex(1, 1), new Complex(2, 0)},
{new Complex(2, -1), new Complex(3, 0)}
});

// Perform Householder QR decomposition
ComplexQRDecompositionDense complexQRDecompositionDense = new ComplexQRDecompositionDense(A);
Complex[] expected = {new Complex(-0.5345224838248488, 0), new Complex(-0.26726124191242434, -0.8017837257372731),
new Complex(0.26726124191242423, -0.8017837257372731), new Complex(-0.5345224838248488, 5.551115123125783E-17)};
assertArrayEquals(expected, complexQRDecompositionDense.getQH().getArray());
}

@Test
public void testQisOrthogonal() {
// Define a 2x2 complex matrix
Complex[][] A = {
{new Complex(1, 1), new Complex(2, 2)},
{new Complex(3, 3), new Complex(4, 4)}
};

Complex[] expected = {new Complex(0.9999999999999973, 0), new Complex(-1.1102230246251565E-16, 0),
new Complex(-1.6653345369377348E-16, 0), new Complex(0.9999999999999996, 0)};

// Perform Householder QR decomposition
ComplexQRDecompositionDense complexQRDecompositionDense = new ComplexQRDecompositionDense(ComplexMatrixDense.from2DArray(A));
assertArrayEquals(expected, complexQRDecompositionDense.getQ().multiply(complexQRDecompositionDense.getQH()).getArray());
}

@Test
public void testSolve() {
Complex[] x = {new Complex(1, 1), new Complex(2, 2)};
Complex[] u = {new Complex(3, 3), new Complex(4, 4)};

Complex[][] y = {x, u};
Complex[] g = {new Complex(60, 70)};
Complex[] f = {new Complex(80, 90)};

Complex[] expected = {new Complex(-45.00000000000017, -5.000000000000018),
new Complex(55.000000000000114, 5.000000000000011)};

ComplexMatrixDense sol = new ComplexQRDecompositionDense(new ComplexMatrixDense(y))
.solve(new ComplexMatrixDense(new Complex[][]{g, f}));
assertArrayEquals(expected, sol.getArray());
}
}

0 comments on commit ab59aef

Please sign in to comment.