Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose lapack routines for solving least squares problems #197

Merged
merged 11 commits into from
Jun 26, 2020
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ rand = "0.5"

[dependencies.ndarray]
version = "0.13.0"
features = ["blas"]
features = ["blas", "approx"]
default-features = false

[dependencies.blas-src]
Expand All @@ -51,6 +51,7 @@ optional = true
[dev-dependencies]
paste = "0.1.9"
criterion = "0.3.1"
approx = { version = "0.3.2", features = ["num-complex"] }

[[bench]]
name = "truncated_eig"
Expand Down
132 changes: 132 additions & 0 deletions src/lapack/least_squares.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
//! Least squares

use lapacke;
use ndarray::{ErrorKind, ShapeError};
use num_traits::Zero;

use crate::error::*;
use crate::layout::MatrixLayout;
use crate::types::*;

use super::into_result;

/// Result of LeastSquares
pub struct LeastSquaresOutput<A: Scalar> {
/// singular values
pub singular_values: Vec<A::Real>,
/// The rank of the input matrix A
pub rank: i32,
}

/// Wraps `*gelsd`
pub trait LeastSquaresSvdDivideConquer_: Scalar {
unsafe fn least_squares(
a_layout: MatrixLayout,
a: &mut [Self],
b: &mut [Self],
) -> Result<LeastSquaresOutput<Self>>;

unsafe fn least_squares_nrhs(
a_layout: MatrixLayout,
a: &mut [Self],
b_layout: MatrixLayout,
b: &mut [Self],
) -> Result<LeastSquaresOutput<Self>>;
}

macro_rules! impl_least_squares {
($scalar:ty, $gelsd:path) => {
impl LeastSquaresSvdDivideConquer_ for $scalar {
unsafe fn least_squares(
a_layout: MatrixLayout,
a: &mut [Self],
b: &mut [Self],
) -> Result<LeastSquaresOutput<Self>> {
let (m, n) = a_layout.size();
if (m as usize) > b.len() || (n as usize) > b.len() {
return Err(LinalgError::Shape(ShapeError::from_kind(
ErrorKind::IncompatibleShape,
)));
}
let k = ::std::cmp::min(m, n);
let nrhs = 1;
let rcond: Self::Real = -1.;
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
let mut rank: i32 = 0;

let status = $gelsd(
a_layout.lapacke_layout(),
m,
n,
nrhs,
a,
a_layout.lda(),
b,
// this is the 'leading dimension of b', in the case where
// b is a single vector, this is 1
nrhs,
&mut singular_values,
rcond,
&mut rank,
);

into_result(
status,
LeastSquaresOutput {
singular_values,
rank,
},
)
}

unsafe fn least_squares_nrhs(
a_layout: MatrixLayout,
a: &mut [Self],
b_layout: MatrixLayout,
b: &mut [Self],
) -> Result<LeastSquaresOutput<Self>> {
let (m, n) = a_layout.size();
if (m as usize) > b.len()
|| (n as usize) > b.len()
|| a_layout.lapacke_layout() != b_layout.lapacke_layout()
{
return Err(LinalgError::Shape(ShapeError::from_kind(
ErrorKind::IncompatibleShape,
)));
}
let k = ::std::cmp::min(m, n);
let nrhs = b_layout.size().1;
let rcond: Self::Real = -1.;
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
let mut rank: i32 = 0;

let status = $gelsd(
a_layout.lapacke_layout(),
m,
n,
nrhs,
a,
a_layout.lda(),
b,
b_layout.lda(),
&mut singular_values,
rcond,
&mut rank,
);

into_result(
status,
LeastSquaresOutput {
singular_values,
rank,
},
)
}
}
};
}

impl_least_squares!(f64, lapacke::dgelsd);
impl_least_squares!(f32, lapacke::sgelsd);
impl_least_squares!(c64, lapacke::zgelsd);
impl_least_squares!(c32, lapacke::cgelsd);
2 changes: 2 additions & 0 deletions src/lapack/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pub mod cholesky;
pub mod eig;
pub mod eigh;
pub mod least_squares;
pub mod opnorm;
pub mod qr;
pub mod solve;
Expand All @@ -14,6 +15,7 @@ pub mod triangular;
pub use self::cholesky::*;
pub use self::eig::*;
pub use self::eigh::*;
pub use self::least_squares::*;
pub use self::opnorm::*;
pub use self::qr::*;
pub use self::solve::*;
Expand Down
Loading