Skip to content

Commit

Permalink
Merge pull request #227 from rust-ndarray/least-square-test
Browse files Browse the repository at this point in the history
Revise tests for least-square problems
  • Loading branch information
termoshtt authored Jul 24, 2020
2 parents 9613cfe + f6a9c2a commit 4d0d8c3
Show file tree
Hide file tree
Showing 4 changed files with 356 additions and 194 deletions.
8 changes: 5 additions & 3 deletions lax/src/least_squares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ macro_rules! impl_least_squares {
}
let k = ::std::cmp::min(m, n);
let nrhs = 1;
let ldb = match a_layout {
MatrixLayout::F { .. } => m.max(n),
MatrixLayout::C { .. } => 1,
};
let rcond: Self::Real = -1.;
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
let mut rank: i32 = 0;
Expand All @@ -54,9 +58,7 @@ macro_rules! impl_least_squares {
a,
a_layout.lda(),
b,
// this is the 'leading dimension of b', in the case where
// b is a single vector, this is 1
nrhs,
ldb,
&mut singular_values,
rcond,
&mut rank,
Expand Down
212 changes: 21 additions & 191 deletions ndarray-linalg/src/least_squares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
//! // `a` and `b` have been moved, no longer valid
//! ```
use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Ix0, Ix1, Ix2};
use ndarray::*;

use crate::error::*;
use crate::lapack::least_squares::*;
Expand Down Expand Up @@ -352,7 +352,10 @@ where
// we need a new rhs b/c it will be overwritten with the solution
// for which we need `n` entries
let k = rhs.shape()[1];
let mut new_rhs = Array2::<E>::zeros((n, k));
let mut new_rhs = match self.layout()? {
MatrixLayout::C { .. } => Array2::<E>::zeros((n, k)),
MatrixLayout::F { .. } => Array2::<E>::zeros((n, k).f()),
};
new_rhs.slice_mut(s![0..m, ..]).assign(rhs);
compute_least_squares_nrhs(self, &mut new_rhs)
} else {
Expand Down Expand Up @@ -414,117 +417,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(

#[cfg(test)]
mod tests {
use super::*;
use crate::{error::LinalgError, *};
use approx::AbsDiffEq;
use ndarray::{ArcArray1, ArcArray2, Array1, Array2, CowArray};
use num_complex::Complex;

//
// Test cases taken from the scipy test suite for the scipy lstsq function
// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
//
#[test]
fn scipy_test_simple_exact() {
let a = array![[1., 20.], [-30., 4.]];
let bs = vec![
array![[1., 0.], [0., 1.]],
array![[1.], [0.]],
array![[2., 1.], [-30., 4.]],
];
for b in &bs {
let res = a.least_squares(b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0));
assert!(res
.residual_sum_of_squares
.unwrap()
.abs_diff_eq(&rssq, 1e-12));
assert!(b_hat.abs_diff_eq(&b, 1e-12));
}
}

#[test]
fn scipy_test_simple_overdetermined() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
assert!(res
.solution
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
}

#[test]
fn scipy_test_simple_overdetermined_f32() {
let a: Array2<f32> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f32> = array![1., 2., 3.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-6));
assert!(res
.solution
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-6));
}

fn c(re: f64, im: f64) -> Complex<f64> {
Complex::new(re, im)
}

#[test]
fn scipy_test_simple_overdetermined_complex() {
let a: Array2<c64> = array![
[c(1., 2.), c(2., 0.)],
[c(4., 0.), c(5., 0.)],
[c(3., 0.), c(4., 0.)]
];
let b: Array1<c64> = array![c(1., 0.), c(2., 4.), c(3., 0.)];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
let b_hat = a.dot(&res.solution);
let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum();
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
assert!(res.solution.abs_diff_eq(
&array![
c(-0.4831460674157303, 0.258426966292135),
c(0.921348314606741, 0.292134831460674)
],
1e-12
));
}

#[test]
fn scipy_test_simple_underdetermined() {
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
let b: Array1<f64> = array![1., 2.];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
assert!(res.residual_sum_of_squares.is_none());
let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777];
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
}

/// This test case tests the underdetermined case for multiple right hand
/// sides. Adapted from scipy lstsq tests.
#[test]
fn scipy_test_simple_underdetermined_nrhs() {
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
let b: Array2<f64> = array![[1., 1.], [2., 2.]];
let res = a.least_squares(&b).unwrap();
assert_eq!(res.rank, 2);
assert!(res.residual_sum_of_squares.is_none());
let expected = array![
[-0.055555555555555, -0.055555555555555],
[0.111111111111111, 0.111111111111111],
[0.277777777777777, 0.277777777777777]
];
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
}
use ndarray::*;

//
// Test that the different lest squares traits work as intended on the
Expand Down Expand Up @@ -554,23 +449,23 @@ mod tests {
}

#[test]
fn test_least_squares_on_arc() {
fn on_arc() {
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
let res = a.least_squares(&b).unwrap();
assert_result(&a, &b, &res);
}

#[test]
fn test_least_squares_on_cow() {
fn on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let res = a.least_squares(&b).unwrap();
assert_result(&a, &b, &res);
}

#[test]
fn test_least_squares_on_view() {
fn on_view() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let av = a.view();
Expand All @@ -580,7 +475,7 @@ mod tests {
}

#[test]
fn test_least_squares_on_view_mut() {
fn on_view_mut() {
let mut a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let mut b: Array1<f64> = array![1., 2., 3.];
let av = a.view_mut();
Expand All @@ -590,7 +485,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_owned() {
fn into_on_owned() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2., 3.];
let ac = a.clone();
Expand All @@ -600,7 +495,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_arc() {
fn into_on_arc() {
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
let a2 = a.clone();
Expand All @@ -610,7 +505,7 @@ mod tests {
}

#[test]
fn test_least_squares_into_on_cow() {
fn into_on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let a2 = a.clone();
Expand All @@ -620,7 +515,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_owned() {
fn in_place_on_owned() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![1., 2., 3.];
let mut a2 = a.clone();
Expand All @@ -630,7 +525,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_cow() {
fn in_place_on_cow() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b = CowArray::from(array![1., 2., 3.]);
let mut a2 = a.clone();
Expand All @@ -640,7 +535,7 @@ mod tests {
}

#[test]
fn test_least_squares_in_place_on_mut_view() {
fn in_place_on_mut_view() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![1., 2., 3.];
let mut a2 = a.clone();
Expand All @@ -651,95 +546,30 @@ mod tests {
assert_result(&a, &b, &res);
}

//
// Test cases taken from the netlib documentation at
// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
//
#[test]
fn netlib_lapack_example_for_dgels_1() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
let expected: Array1<f64> = array![2., 1., 1.];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = b - a.dot(&result.solution);
let resid_ssq = result.residual_sum_of_squares.unwrap();
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
}

#[test]
fn netlib_lapack_example_for_dgels_2() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array1<f64> = array![-3., 14., 12., 16., 16.];
let expected: Array1<f64> = array![1., 1., 2.];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = b - a.dot(&result.solution);
let resid_ssq = result.residual_sum_of_squares.unwrap();
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
}

#[test]
fn netlib_lapack_example_for_dgels_nrhs() {
let a: Array2<f64> = array![
[1., 1., 1.],
[2., 3., 4.],
[3., 5., 2.],
[4., 2., 5.],
[5., 4., 3.]
];
let b: Array2<f64> = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
let expected: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
let result = a.least_squares(&b).unwrap();
assert!(result.solution.abs_diff_eq(&expected, 1e-12));

let residual = &b - &a.dot(&result.solution);
let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0));
assert!(result
.residual_sum_of_squares
.unwrap()
.abs_diff_eq(&residual_ssq, 1e-12));
}

//
// Testing error cases
//
use crate::layout::MatrixLayout;

#[test]
fn test_incompatible_shape_error_on_mismatching_num_rows() {
fn incompatible_shape_error_on_mismatching_num_rows() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b: Array1<f64> = array![1., 2.];
let res = a.least_squares(&b);
match res {
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
_ => panic!("Expected Err()"),
}
}

#[test]
fn test_incompatible_shape_error_on_mismatching_layout() {
fn incompatible_shape_error_on_mismatching_layout() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b = array![[1.], [2.]].t().to_owned();
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });

let res = a.least_squares(&b);
match res {
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
_ => panic!("Expected Err()"),
}
}
Expand Down
Loading

0 comments on commit 4d0d8c3

Please sign in to comment.