Skip to content

Commit

Permalink
Impl solve using LAPACK
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Jul 6, 2020
1 parent f7d93f4 commit a7dc42c
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 73 deletions.
3 changes: 3 additions & 0 deletions lax/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pub mod layout;
pub mod least_squares;
pub mod opnorm;
pub mod qr;
pub mod rcond;
pub mod solve;
pub mod solveh;
pub mod svd;
Expand All @@ -83,6 +84,7 @@ pub use self::eigh::*;
pub use self::least_squares::*;
pub use self::opnorm::*;
pub use self::qr::*;
pub use self::rcond::*;
pub use self::solve::*;
pub use self::solveh::*;
pub use self::svd::*;
Expand All @@ -107,6 +109,7 @@ pub trait Lapack:
+ Eigh_
+ Triangular_
+ Tridiagonal_
+ Rcond_
{
}

Expand Down
78 changes: 78 additions & 0 deletions lax/src/rcond.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use super::*;
use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::Zero;

pub trait Rcond_: Scalar + Sized {
/// Estimates the the reciprocal of the condition number of the matrix in 1-norm.
///
/// `anorm` should be the 1-norm of the matrix `a`.
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>;
}

macro_rules! impl_rcond_real {
($scalar:ty, $gecon:path) => {
impl Rcond_ for $scalar {
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
let (n, _) = l.size();
let mut rcond = Self::Real::zero();
let mut info = 0;

let mut work = vec![Self::zero(); 4 * n as usize];
let mut iwork = vec![0; n as usize];
unsafe {
$gecon(
NormType::One as u8,
n,
a,
l.lda(),
anorm,
&mut rcond,
&mut work,
&mut iwork,
&mut info,
)
};
info.as_lapack_result()?;

Ok(rcond)
}
}
};
}

impl_rcond_real!(f32, lapack::sgecon);
impl_rcond_real!(f64, lapack::dgecon);

macro_rules! impl_rcond_complex {
($scalar:ty, $gecon:path) => {
impl Rcond_ for $scalar {
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
let (n, _) = l.size();
let mut rcond = Self::Real::zero();
let mut info = 0;
let mut work = vec![Self::zero(); 2 * n as usize];
let mut rwork = vec![Self::Real::zero(); 2 * n as usize];
unsafe {
$gecon(
NormType::One as u8,
n,
a,
l.lda(),
anorm,
&mut rcond,
&mut work,
&mut rwork,
&mut info,
)
};
info.as_lapack_result()?;

Ok(rcond)
}
}
};
}

impl_rcond_complex!(c32, lapack::cgecon);
impl_rcond_complex!(c64, lapack::zgecon);
114 changes: 41 additions & 73 deletions lax/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use super::*;
use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::Zero;
use num_traits::{ToPrimitive, Zero};

pub trait Solve_: Scalar + Sized {
/// Computes the LU factorization of a general `m x n` matrix `a` using
Expand All @@ -14,59 +14,55 @@ pub trait Solve_: Scalar + Sized {
/// Error
/// ------
/// - `LapackComputationalFailure { return_code }` when the matrix is singular
/// - `U[(return_code-1, return_code-1)]` is exactly zero.
/// - Division by zero will occur if it is used to solve a system of equations.
/// - Division by zero will occur if it is used to solve a system of equations
/// because `U[(return_code-1, return_code-1)]` is exactly zero.
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;

fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;

/// Estimates the the reciprocal of the condition number of the matrix in 1-norm.
///
/// `anorm` should be the 1-norm of the matrix `a`.
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>;

fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
}

macro_rules! impl_solve {
($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => {
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
impl Solve_ for $scalar {
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
let (row, col) = l.size();
assert_eq!(a.len() as i32, row * col);
let k = ::std::cmp::min(row, col);
let mut ipiv = vec![0; k as usize];
unsafe {
$getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv)
.as_lapack_result()?;
}
let mut info = 0;
unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) };
info.as_lapack_result()?;
Ok(ipiv)
}

fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
let (n, _) = l.size();
unsafe {
$getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?;
}
Ok(())
}

fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
let (n, _) = l.size();
let mut rcond = Self::Real::zero();
// calc work size
let mut info = 0;
let mut work_size = [Self::zero()];
unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) };
info.as_lapack_result()?;

// actual
let lwork = work_size[0].to_usize().unwrap();
let mut work = vec![Self::zero(); lwork];
unsafe {
$gecon(
l.lapacke_layout(),
NormType::One as u8,
n,
$getri(
l.len(),
a,
l.lda(),
anorm,
&mut rcond,
ipiv,
&mut work,
lwork as i32,
&mut info,
)
}
.as_lapack_result()?;
};
info.as_lapack_result()?;

Ok(rcond)
Ok(())
}

fn solve(
Expand All @@ -76,54 +72,26 @@ macro_rules! impl_solve {
ipiv: &Pivot,
b: &mut [Self],
) -> Result<()> {
let t = match l {
MatrixLayout::C { .. } => match t {
Transpose::No => Transpose::Transpose,
Transpose::Transpose | Transpose::Hermite => Transpose::No,
},
_ => t,
};
let (n, _) = l.size();
let nrhs = 1;
let ldb = 1;
unsafe {
$getrs(
l.lapacke_layout(),
t as u8,
n,
nrhs,
a,
l.lda(),
ipiv,
b,
ldb,
)
.as_lapack_result()?;
}
let ldb = l.lda();
let mut info = 0;
unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
info.as_lapack_result()?;
Ok(())
}
}
};
} // impl_solve!

impl_solve!(
f64,
lapacke::dgetrf,
lapacke::dgetri,
lapacke::dgecon,
lapacke::dgetrs
);
impl_solve!(
f32,
lapacke::sgetrf,
lapacke::sgetri,
lapacke::sgecon,
lapacke::sgetrs
);
impl_solve!(
c64,
lapacke::zgetrf,
lapacke::zgetri,
lapacke::zgecon,
lapacke::zgetrs
);
impl_solve!(
c32,
lapacke::cgetrf,
lapacke::cgetri,
lapacke::cgecon,
lapacke::cgetrs
);
impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs);
impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs);
impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs);
impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs);

0 comments on commit a7dc42c

Please sign in to comment.