diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e827d9cb..0abcb1aa 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -55,3 +55,12 @@ jobs: with: command: test args: --features=openblas --no-default-features + + check-format: + runs-on: ubuntu-18.04 + steps: + - uses: actions/checkout@v1 + - uses: actions-rs/cargo@v1 + with: + command: fmt + args: -- --check diff --git a/examples/eig.rs b/examples/eig.rs index f4195dfc..3e41556a 100644 --- a/examples/eig.rs +++ b/examples/eig.rs @@ -9,4 +9,4 @@ fn main() { let a_c: Array2 = a.map(|f| c64::new(*f, 0.0)); let av = a_c.dot(&vecs); println!("AV = \n{:?}", av); -} \ No newline at end of file +} diff --git a/examples/truncated_svd.rs b/examples/truncated_svd.rs index 8cc164eb..1416368c 100644 --- a/examples/truncated_svd.rs +++ b/examples/truncated_svd.rs @@ -8,7 +8,9 @@ fn main() { let a = arr2(&[[3., 2., 2.], [2., 3., -2.]]); // calculate the truncated singular value decomposition for 2 singular values - let result = TruncatedSvd::new(a, TruncatedOrder::Largest).decompose(2).unwrap(); + let result = TruncatedSvd::new(a, TruncatedOrder::Largest) + .decompose(2) + .unwrap(); // acquire singular values, left-singular vectors and right-singular vectors let (u, sigma, v_t) = result.values_vectors(); diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index 28ce2ac8..00000000 --- a/rustfmt.toml +++ /dev/null @@ -1,6 +0,0 @@ -max_width = 120 -hard_tabs = false -tab_spaces = 4 -newline_style = "Unix" -merge_derives = true -force_explicit_abi = true diff --git a/src/cholesky.rs b/src/cholesky.rs index efc135bc..7f38e86b 100644 --- a/src/cholesky.rs +++ b/src/cholesky.rs @@ -166,7 +166,10 @@ where A: Scalar + Lapack, S: Data, { - fn solvec_inplace<'a, Sb>(&self, b: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solvec_inplace<'a, Sb>( + &self, + b: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { @@ -327,7 +330,10 @@ pub trait SolveC { /// Solves a system of linear equations `A * x = b` with Hermitian (or real /// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is /// the argument, and `x` is the successful result. - fn solvec_into>(&self, mut b: ArrayBase) -> Result> { + fn solvec_into>( + &self, + mut b: ArrayBase, + ) -> Result> { self.solvec_inplace(&mut b)?; Ok(b) } @@ -346,7 +352,10 @@ where A: Scalar + Lapack, S: Data, { - fn solvec_inplace<'a, Sb>(&self, b: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solvec_inplace<'a, Sb>( + &self, + b: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { diff --git a/src/convert.rs b/src/convert.rs index 92b0e99e..52ce0dd8 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -93,7 +93,11 @@ where } else { ArrayBase::from_shape_vec(a.dim().f(), a.into_raw_vec()).unwrap() }; - assert_eq!(new.strides(), strides.as_slice(), "Custom stride is not supported"); + assert_eq!( + new.strides(), + strides.as_slice(), + "Custom stride is not supported" + ); new } diff --git a/src/eig.rs b/src/eig.rs index 69917e45..e9f09080 100644 --- a/src/eig.rs +++ b/src/eig.rs @@ -1,9 +1,9 @@ //! Eigenvalue decomposition for non-symmetric square matrices -use ndarray::*; use crate::error::*; use crate::layout::*; use crate::types::*; +use ndarray::*; /// Eigenvalue decomposition of general matrix reference pub trait Eig { @@ -27,7 +27,12 @@ where let layout = a.square_layout()?; let (s, t) = unsafe { A::eig(true, layout, a.as_allocated_mut()?)? }; let (n, _) = layout.size(); - Ok((ArrayBase::from(s), ArrayBase::from(t).into_shape((n as usize, n as usize)).unwrap())) + Ok(( + ArrayBase::from(s), + ArrayBase::from(t) + .into_shape((n as usize, n as usize)) + .unwrap(), + )) } } @@ -49,4 +54,4 @@ where let (s, _) = unsafe { A::eig(true, a.square_layout()?, a.as_allocated_mut()?)? }; Ok(ArrayBase::from(s)) } -} \ No newline at end of file +} diff --git a/src/error.rs b/src/error.rs index 6fea8673..4e487acf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -24,9 +24,15 @@ pub enum LinalgError { impl fmt::Display for LinalgError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - LinalgError::NotSquare { rows, cols } => write!(f, "Not square: rows({}) != cols({})", rows, cols), - LinalgError::Lapack { return_code } => write!(f, "LAPACK: return_code = {}", return_code), - LinalgError::InvalidStride { s0, s1 } => write!(f, "invalid stride: s0={}, s1={}", s0, s1), + LinalgError::NotSquare { rows, cols } => { + write!(f, "Not square: rows({}) != cols({})", rows, cols) + } + LinalgError::Lapack { return_code } => { + write!(f, "LAPACK: return_code = {}", return_code) + } + LinalgError::InvalidStride { s0, s1 } => { + write!(f, "invalid stride: s0={}, s1={}", s0, s1) + } LinalgError::MemoryNotCont => write!(f, "Memory is not contiguous"), LinalgError::Shape(err) => write!(f, "Shape Error: {}", err), } diff --git a/src/inner.rs b/src/inner.rs index 0ce99a43..87ca96cc 100644 --- a/src/inner.rs +++ b/src/inner.rs @@ -24,7 +24,9 @@ where assert_eq!(self.len(), rhs.len()); Zip::from(self) .and(rhs) - .fold_while(A::zero(), |acc, s, r| FoldWhile::Continue(acc + s.conj() * *r)) + .fold_while(A::zero(), |acc, s, r| { + FoldWhile::Continue(acc + s.conj() * *r) + }) .into_inner() } } diff --git a/src/krylov/arnoldi.rs b/src/krylov/arnoldi.rs index 297cafbf..66c29991 100644 --- a/src/krylov/arnoldi.rs +++ b/src/krylov/arnoldi.rs @@ -97,7 +97,11 @@ where } /// Utility to execute Arnoldi iteration with Householder reflection -pub fn arnoldi_householder(a: impl LinearOperator, v: ArrayBase, tol: A::Real) -> (Q, H) +pub fn arnoldi_householder( + a: impl LinearOperator, + v: ArrayBase, + tol: A::Real, +) -> (Q, H) where A: Scalar + Lapack, S: DataMut, @@ -107,7 +111,11 @@ where } /// Utility to execute Arnoldi iteration with modified Gram-Schmit orthogonalizer -pub fn arnoldi_mgs(a: impl LinearOperator, v: ArrayBase, tol: A::Real) -> (Q, H) +pub fn arnoldi_mgs( + a: impl LinearOperator, + v: ArrayBase, + tol: A::Real, +) -> (Q, H) where A: Scalar + Lapack, S: DataMut, diff --git a/src/krylov/householder.rs b/src/krylov/householder.rs index 9515e7b1..951ecdb0 100644 --- a/src/krylov/householder.rs +++ b/src/krylov/householder.rs @@ -71,7 +71,11 @@ impl Householder { S: DataMut, { assert!(k < self.v.len()); - assert_eq!(a.len(), self.dim, "Input array size mismaches to the dimension"); + assert_eq!( + a.len(), + self.dim, + "Input array size mismaches to the dimension" + ); reflect(&self.v[k].slice(s![k..]), &mut a.slice_mut(s![k..])); } diff --git a/src/lapack/cholesky.rs b/src/lapack/cholesky.rs index 94d29a39..494e004c 100644 --- a/src/lapack/cholesky.rs +++ b/src/lapack/cholesky.rs @@ -18,7 +18,8 @@ pub trait Cholesky_: Sized { /// **Warning: Only the portion of `a` corresponding to `UPLO` is written.** unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; /// Wrapper of `*potrs` - unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; + unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) + -> Result<()>; } macro_rules! impl_cholesky { @@ -36,7 +37,12 @@ macro_rules! impl_cholesky { into_result(info, ()) } - unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()> { + unsafe fn solve_cholesky( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + b: &mut [Self], + ) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; let ldb = 1; diff --git a/src/lapack/eig.rs b/src/lapack/eig.rs index eec83727..e2a50961 100644 --- a/src/lapack/eig.rs +++ b/src/lapack/eig.rs @@ -11,19 +11,39 @@ use super::into_result; /// Wraps `*geev` for real/complex pub trait Eig_: Scalar { - unsafe fn eig(calc_v: bool, l: MatrixLayout, a: &mut [Self]) -> Result<(Vec, Vec)>; + unsafe fn eig( + calc_v: bool, + l: MatrixLayout, + a: &mut [Self], + ) -> Result<(Vec, Vec)>; } macro_rules! impl_eig_complex { ($scalar:ty, $ev:path) => { impl Eig_ for $scalar { - unsafe fn eig(calc_v: bool, l: MatrixLayout, mut a: &mut [Self]) -> Result<(Vec, Vec)> { + unsafe fn eig( + calc_v: bool, + l: MatrixLayout, + mut a: &mut [Self], + ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); let jobvr = if calc_v { b'V' } else { b'N' }; let mut w = vec![Self::Complex::zero(); n as usize]; let mut vl = Vec::new(); let mut vr = vec![Self::Complex::zero(); (n * n) as usize]; - let info = $ev(l.lapacke_layout(), b'N', jobvr, n, &mut a, n, &mut w, &mut vl, n, &mut vr, n); + let info = $ev( + l.lapacke_layout(), + b'N', + jobvr, + n, + &mut a, + n, + &mut w, + &mut vl, + n, + &mut vr, + n, + ); into_result(info, (w, vr)) } } @@ -33,49 +53,75 @@ macro_rules! impl_eig_complex { macro_rules! impl_eig_real { ($scalar:ty, $ev:path) => { impl Eig_ for $scalar { - unsafe fn eig(calc_v: bool, l: MatrixLayout, mut a: &mut [Self]) -> Result<(Vec, Vec)> { + unsafe fn eig( + calc_v: bool, + l: MatrixLayout, + mut a: &mut [Self], + ) -> Result<(Vec, Vec)> { let (n, _) = l.size(); let jobvr = if calc_v { b'V' } else { b'N' }; let mut wr = vec![Self::Real::zero(); n as usize]; let mut wi = vec![Self::Real::zero(); n as usize]; let mut vl = Vec::new(); let mut vr = vec![Self::Real::zero(); (n * n) as usize]; - let info = $ev(l.lapacke_layout(), b'N', jobvr, n, &mut a, n, &mut wr, &mut wi, &mut vl, n, &mut vr, n); - let w: Vec = wr.iter().zip(wi.iter()).map(|(&r, &i)| Self::Complex::new(r, i)).collect(); + let info = $ev( + l.lapacke_layout(), + b'N', + jobvr, + n, + &mut a, + n, + &mut wr, + &mut wi, + &mut vl, + n, + &mut vr, + n, + ); + let w: Vec = wr + .iter() + .zip(wi.iter()) + .map(|(&r, &i)| Self::Complex::new(r, i)) + .collect(); // If the j-th eigenvalue is real, then // eigenvector = [ vr[j], vr[j+n], vr[j+2*n], ... ]. // - // If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, + // If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, // eigenvector(j) = [ vr[j] + i*vr[j+1], vr[j+n] + i*vr[j+n+1], vr[j+2*n] + i*vr[j+2*n+1], ... ] and // eigenvector(j+1) = [ vr[j] - i*vr[j+1], vr[j+n] - i*vr[j+n+1], vr[j+2*n] - i*vr[j+2*n+1], ... ]. - // + // // Therefore, if eigenvector(j) is written as [ v_{j0}, v_{j1}, v_{j2}, ... ], - // you have to make + // you have to make // v = vec![ v_{00}, v_{10}, v_{20}, ..., v_{jk}, v_{(j+1)k}, v_{(j+2)k}, ... ] (v.len() = n*n) // based on wi and vr. // After that, v is converted to Array2 (see ../eig.rs). let n = n as usize; let mut flg = false; - let conj: Vec = wi.iter().map(|&i| { - if flg { - flg = false; - -1 - } else if i != 0.0 { - flg = true; - 1 - } else { - 0 - } - }).collect(); - let v: Vec = (0..n*n).map(|i| { - let j = i % n; - match conj[j] { - 1 => Self::Complex::new(vr[i], vr[i+1]), - -1 => Self::Complex::new(vr[i-1], -vr[i]), - _ => Self::Complex::new(vr[i], 0.0), - } - }).collect(); - + let conj: Vec = wi + .iter() + .map(|&i| { + if flg { + flg = false; + -1 + } else if i != 0.0 { + flg = true; + 1 + } else { + 0 + } + }) + .collect(); + let v: Vec = (0..n * n) + .map(|i| { + let j = i % n; + match conj[j] { + 1 => Self::Complex::new(vr[i], vr[i + 1]), + -1 => Self::Complex::new(vr[i - 1], -vr[i]), + _ => Self::Complex::new(vr[i], 0.0), + } + }) + .collect(); + into_result(info, (w, v)) } } @@ -85,4 +131,4 @@ macro_rules! impl_eig_real { impl_eig_real!(f64, lapacke::dgeev); impl_eig_real!(f32, lapacke::sgeev); impl_eig_complex!(c64, lapacke::zgeev); -impl_eig_complex!(c32, lapacke::cgeev); \ No newline at end of file +impl_eig_complex!(c32, lapacke::cgeev); diff --git a/src/lapack/eigh.rs b/src/lapack/eigh.rs index 11b40bb0..d77b3b2e 100644 --- a/src/lapack/eigh.rs +++ b/src/lapack/eigh.rs @@ -11,7 +11,12 @@ use super::{into_result, UPLO}; /// Wraps `*syev` for real and `*heev` for complex pub trait Eigh_: Scalar { - unsafe fn eigh(calc_eigenvec: bool, l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result>; + unsafe fn eigh( + calc_eigenvec: bool, + l: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + ) -> Result>; unsafe fn eigh_generalized( calc_eigenvec: bool, l: MatrixLayout, @@ -24,7 +29,12 @@ pub trait Eigh_: Scalar { macro_rules! impl_eigh { ($scalar:ty, $ev:path, $evg:path) => { impl Eigh_ for $scalar { - unsafe fn eigh(calc_v: bool, l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result> { + unsafe fn eigh( + calc_v: bool, + l: MatrixLayout, + uplo: UPLO, + mut a: &mut [Self], + ) -> Result> { let (n, _) = l.size(); let jobz = if calc_v { b'V' } else { b'N' }; let mut w = vec![Self::Real::zero(); n as usize]; diff --git a/src/lapack/mod.rs b/src/lapack/mod.rs index 42e72073..6a6903fe 100644 --- a/src/lapack/mod.rs +++ b/src/lapack/mod.rs @@ -28,7 +28,10 @@ use super::types::*; pub type Pivot = Vec; /// Trait for primitive types which implements LAPACK subroutines -pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_ {} +pub trait Lapack: + OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_ +{ +} impl Lapack for f32 {} impl Lapack for f64 {} diff --git a/src/lapack/opnorm.rs b/src/lapack/opnorm.rs index 0cdc9640..8e3f9d8d 100644 --- a/src/lapack/opnorm.rs +++ b/src/lapack/opnorm.rs @@ -18,7 +18,9 @@ macro_rules! impl_opnorm { unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { match l { MatrixLayout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda), - MatrixLayout::C((row, lda)) => $lange(cm, t.transpose() as u8, lda, row, a, lda), + MatrixLayout::C((row, lda)) => { + $lange(cm, t.transpose() as u8, lda, row, a, lda) + } } } } diff --git a/src/lapack/solve.rs b/src/lapack/solve.rs index d71836af..58430995 100644 --- a/src/lapack/solve.rs +++ b/src/lapack/solve.rs @@ -26,7 +26,13 @@ pub trait Solve_: Scalar + Sized { /// /// `anorm` should be the 1-norm of the matrix `a`. unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; - unsafe fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; + unsafe fn solve( + l: MatrixLayout, + t: Transpose, + a: &[Self], + p: &Pivot, + b: &mut [Self], + ) -> Result<()>; } macro_rules! impl_solve { @@ -61,18 +67,58 @@ macro_rules! impl_solve { into_result(info, rcond) } - unsafe fn solve(l: MatrixLayout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { + unsafe fn solve( + l: MatrixLayout, + t: Transpose, + a: &[Self], + ipiv: &Pivot, + b: &mut [Self], + ) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; let ldb = 1; - let info = $getrs(l.lapacke_layout(), t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb); + let info = $getrs( + l.lapacke_layout(), + t as u8, + n, + nrhs, + a, + l.lda(), + ipiv, + b, + ldb, + ); into_result(info, ()) } } }; } // impl_solve! -impl_solve!(f64, lapacke::dgetrf, lapacke::dgetri, lapacke::dgecon, lapacke::dgetrs); -impl_solve!(f32, lapacke::sgetrf, lapacke::sgetri, lapacke::sgecon, lapacke::sgetrs); -impl_solve!(c64, lapacke::zgetrf, lapacke::zgetri, lapacke::zgecon, lapacke::zgetrs); -impl_solve!(c32, lapacke::cgetrf, lapacke::cgetri, lapacke::cgecon, lapacke::cgetrs); +impl_solve!( + f64, + lapacke::dgetrf, + lapacke::dgetri, + lapacke::dgecon, + lapacke::dgetrs +); +impl_solve!( + f32, + lapacke::sgetrf, + lapacke::sgetri, + lapacke::sgecon, + lapacke::sgetrs +); +impl_solve!( + c64, + lapacke::zgetrf, + lapacke::zgetri, + lapacke::zgecon, + lapacke::zgetrs +); +impl_solve!( + c32, + lapacke::cgetrf, + lapacke::cgetri, + lapacke::cgecon, + lapacke::cgetrs +); diff --git a/src/lapack/solveh.rs b/src/lapack/solveh.rs index e6ea9e60..b17b68ac 100644 --- a/src/lapack/solveh.rs +++ b/src/lapack/solveh.rs @@ -16,7 +16,13 @@ pub trait Solveh_: Sized { /// Wrapper of `*sytri` and `*hetri` unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; /// Wrapper of `*sytrs` and `*hetrs` - unsafe fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; + unsafe fn solveh( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + ipiv: &Pivot, + b: &mut [Self], + ) -> Result<()>; } macro_rules! impl_solveh { @@ -34,20 +40,41 @@ macro_rules! impl_solveh { } } - unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + unsafe fn invh( + l: MatrixLayout, + uplo: UPLO, + a: &mut [Self], + ipiv: &Pivot, + ) -> Result<()> { let (n, _) = l.size(); let info = $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv); into_result(info, ()) } - unsafe fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { + unsafe fn solveh( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + ipiv: &Pivot, + b: &mut [Self], + ) -> Result<()> { let (n, _) = l.size(); let nrhs = 1; let ldb = match l { MatrixLayout::C(_) => 1, MatrixLayout::F(_) => n, }; - let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), ipiv, b, ldb); + let info = $trs( + l.lapacke_layout(), + uplo as u8, + n, + nrhs, + a, + l.lda(), + ipiv, + b, + ldb, + ); into_result(info, ()) } } diff --git a/src/lapack/svd.rs b/src/lapack/svd.rs index 5d5b44eb..adc50702 100644 --- a/src/lapack/svd.rs +++ b/src/lapack/svd.rs @@ -29,13 +29,23 @@ pub struct SVDOutput { /// Wraps `*gesvd` pub trait SVD_: Scalar { - unsafe fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; + unsafe fn svd( + l: MatrixLayout, + calc_u: bool, + calc_vt: bool, + a: &mut [Self], + ) -> Result>; } macro_rules! impl_svd { ($scalar:ty, $gesvd:path) => { impl SVD_ for $scalar { - unsafe fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result> { + unsafe fn svd( + l: MatrixLayout, + calc_u: bool, + calc_vt: bool, + mut a: &mut [Self], + ) -> Result> { let (m, n) = l.size(); let k = ::std::cmp::min(n, m); let lda = l.lda(); diff --git a/src/lapack/svddc.rs b/src/lapack/svddc.rs index 22f770e3..57b727d1 100644 --- a/src/lapack/svddc.rs +++ b/src/lapack/svddc.rs @@ -15,7 +15,11 @@ pub trait SVDDC_: Scalar { macro_rules! impl_svdd { ($scalar:ty, $gesdd:path) => { impl SVDDC_ for $scalar { - unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self]) -> Result> { + unsafe fn svddc( + l: MatrixLayout, + jobz: UVTFlag, + mut a: &mut [Self], + ) -> Result> { let (m, n) = l.size(); let k = m.min(n); let lda = l.lda(); @@ -47,7 +51,11 @@ macro_rules! impl_svdd { SVDOutput { s: s, u: if jobz == UVTFlag::None { None } else { Some(u) }, - vt: if jobz == UVTFlag::None { None } else { Some(vt) }, + vt: if jobz == UVTFlag::None { + None + } else { + Some(vt) + }, }, ) } diff --git a/src/lapack/triangular.rs b/src/lapack/triangular.rs index 0d3ab874..c3ee706d 100644 --- a/src/lapack/triangular.rs +++ b/src/lapack/triangular.rs @@ -31,7 +31,12 @@ pub trait Triangular_: Sized { macro_rules! impl_triangular { ($scalar:ty, $trtri:path, $trtrs:path) => { impl Triangular_ for $scalar { - unsafe fn inv_triangular(l: MatrixLayout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> { + unsafe fn inv_triangular( + l: MatrixLayout, + uplo: UPLO, + diag: Diag, + a: &mut [Self], + ) -> Result<()> { let (n, _) = l.size(); let lda = l.lda(); let info = $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda); diff --git a/src/layout.rs b/src/layout.rs index 51099e6f..d75eebc5 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -129,7 +129,9 @@ where } fn as_allocated(&self) -> Result<&[A]> { - Ok(self.as_slice_memory_order().ok_or_else(|| LinalgError::MemoryNotCont)?) + Ok(self + .as_slice_memory_order() + .ok_or_else(|| LinalgError::MemoryNotCont)?) } } diff --git a/src/lobpcg/eig.rs b/src/lobpcg/eig.rs index 6f4d1ac4..f5f3aa3f 100644 --- a/src/lobpcg/eig.rs +++ b/src/lobpcg/eig.rs @@ -1,5 +1,5 @@ use super::lobpcg::{lobpcg, LobpcgResult, Order}; -use crate::{Lapack, Scalar, generate}; +use crate::{generate, Lapack, Scalar}; ///! Implements truncated eigenvalue decomposition /// use ndarray::prelude::*; @@ -87,7 +87,9 @@ impl Truncate } } -impl IntoIterator for TruncatedEig { +impl IntoIterator + for TruncatedEig +{ type Item = (Array1, Array2); type IntoIter = TruncatedEigIterator; @@ -110,7 +112,9 @@ pub struct TruncatedEigIterator { eig: TruncatedEig, } -impl Iterator for TruncatedEigIterator { +impl Iterator + for TruncatedEigIterator +{ type Item = (Array1, Array2); fn next(&mut self) -> Option { @@ -163,13 +167,20 @@ mod tests { #[test] fn test_truncated_eig() { let diag = arr1(&[ - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., + 20., ]); let a = Array2::from_diag(&diag); - let teig = TruncatedEig::new(a, Order::Largest).precision(1e-5).maxiter(500); + let teig = TruncatedEig::new(a, Order::Largest) + .precision(1e-5) + .maxiter(500); - let res = teig.into_iter().take(3).flat_map(|x| x.0.to_vec()).collect::>(); + let res = teig + .into_iter() + .take(3) + .flat_map(|x| x.0.to_vec()) + .collect::>(); let ground_truth = vec![20., 19., 18.]; assert!( diff --git a/src/lobpcg/lobpcg.rs b/src/lobpcg/lobpcg.rs index b9f3e52c..cf59b8de 100644 --- a/src/lobpcg/lobpcg.rs +++ b/src/lobpcg/lobpcg.rs @@ -6,7 +6,7 @@ use crate::error::{LinalgError, Result}; use crate::{cholesky::*, close_l2, eigh::*, norm::*, triangular::*}; use crate::{Lapack, Scalar}; use ndarray::prelude::*; -use ndarray::{OwnedRepr, ScalarOperand, Data}; +use ndarray::{Data, OwnedRepr, ScalarOperand}; use num_traits::{Float, NumCast}; /// Find largest or smallest eigenvalues @@ -46,7 +46,8 @@ fn sorted_eig, A: Scalar + Lapack>( Ok(match order { Order::Largest => ( - vals.slice_move(s![n-size..; -1]).mapv(|x| Scalar::from_real(x)), + vals.slice_move(s![n-size..; -1]) + .mapv(|x| Scalar::from_real(x)), vecs.slice_move(s![.., n-size..; -1]), ), Order::Smallest => ( @@ -220,7 +221,11 @@ pub fn lobpcg< let r = &ax - &lambda_x; // calculate L2 norm of error for every eigenvalue - let residual_norms = r.gencolumns().into_iter().map(|x| x.norm()).collect::>(); + let residual_norms = r + .gencolumns() + .into_iter() + .map(|x| x.norm()) + .collect::>(); residual_norms_history.push(residual_norms.clone()); // compare best result and update if we improved @@ -279,7 +284,9 @@ pub fn lobpcg< }; // if we are once below the max_rnorm, enable explicit gram flag - let max_norm = residual_norms.into_iter().fold(A::Real::neg_infinity(), A::Real::max); + let max_norm = residual_norms + .into_iter() + .fold(A::Real::neg_infinity(), A::Real::max); explicit_gram_flag = max_norm <= max_rnorm_float || explicit_gram_flag; // perform the Rayleigh Ritz procedure @@ -293,7 +300,12 @@ pub fn lobpcg< rar = (&rar + &rar.t()) / two; let xax = x.t().dot(&ax); - ((&xax + &xax.t()) / two, x.t().dot(&x), r.t().dot(&r), x.t().dot(&r)) + ( + (&xax + &xax.t()) / two, + x.t().dot(&x), + r.t().dot(&r), + x.t().dot(&r), + ) } else { ( lambda_diag, @@ -324,7 +336,8 @@ pub fn lobpcg< // // first try to compute the eigenvalue decomposition of the span{R, X, P}, // if this fails (or the algorithm was restarted), then just use span{R, X} - let result = p_ap.as_ref() + let result = p_ap + .as_ref() .ok_or(LinalgError::Lapack { return_code: 1 }) .and_then(|(active_p, active_ap)| { let xap = x.t().dot(active_ap); @@ -352,29 +365,36 @@ pub fn lobpcg< stack![Axis(1), xp.t(), rp.t(), pp] ]), size_x, - &order + &order, ) }) .or_else(|_| { p_ap = None; sorted_eig( - stack![Axis(0), stack![Axis(1), xax, xar], stack![Axis(1), xar.t(), rar]], - Some(stack![Axis(0), stack![Axis(1), xx, xr], stack![Axis(1), xr.t(), rr]]), + stack![ + Axis(0), + stack![Axis(1), xax, xar], + stack![Axis(1), xar.t(), rar] + ], + Some(stack![ + Axis(0), + stack![Axis(1), xx, xr], + stack![Axis(1), xr.t(), rr] + ]), size_x, - &order + &order, ) }); - // update eigenvalues and eigenvectors (lambda is also used in the next iteration) let eig_vecs; match result { Ok((x, y)) => { lambda = x; eig_vecs = y; - }, - Err(x) => break Err(x) + } + Err(x) => break Err(x), } // approximate eigenvector X and conjugate vectors P with solution of eigenproblem @@ -432,8 +452,8 @@ mod tests { use super::LobpcgResult; use super::Order; use crate::close_l2; - use crate::qr::*; use crate::generate; + use crate::qr::*; use ndarray::prelude::*; /// Test the `sorted_eigen` function @@ -457,7 +477,11 @@ mod tests { fn test_masking() { let matrix: Array2 = generate::random((10, 5)) * 10.0; let masked_matrix = ndarray_mask(matrix.view(), &[true, true, false, true, false]); - close_l2(&masked_matrix.slice(s![.., 2]), &matrix.slice(s![.., 3]), 1e-12); + close_l2( + &masked_matrix.slice(s![.., 2]), + &matrix.slice(s![.., 3]), + 1e-12, + ); } /// Test orthonormalization of a random matrix @@ -500,7 +524,11 @@ mod tests { // check correct order of eigenvalues if ground_truth_eigvals.len() == num { - close_l2(&Array1::from(ground_truth_eigvals.to_vec()), &vals, num as f64 * 5e-4) + close_l2( + &Array1::from(ground_truth_eigvals.to_vec()), + &vals, + num as f64 * 5e-4, + ) } } LobpcgResult::NoResult(err) => panic!("Did not converge: {:?}", err), @@ -511,7 +539,8 @@ mod tests { #[test] fn test_eigsolver_diag() { let diag = arr1(&[ - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., + 20., ]); let a = Array2::from_diag(&diag); @@ -547,7 +576,15 @@ mod tests { ]) .reversed_axes(); - let result = lobpcg(|y| a.dot(&y), x, |_| {}, Some(y), 1e-10, 50, Order::Smallest); + let result = lobpcg( + |y| a.dot(&y), + x, + |_| {}, + Some(y), + 1e-10, + 50, + Order::Smallest, + ); match result { LobpcgResult::Ok(vals, vecs, r_norms) | LobpcgResult::Err(vals, vecs, r_norms, _) => { // check convergence diff --git a/src/lobpcg/svd.rs b/src/lobpcg/svd.rs index bc85f06f..1324034b 100644 --- a/src/lobpcg/svd.rs +++ b/src/lobpcg/svd.rs @@ -3,7 +3,7 @@ ///! This module computes the k largest/smallest singular values/vectors for a dense matrix. use super::lobpcg::{lobpcg, LobpcgResult, Order}; use crate::error::Result; -use crate::{Lapack, Scalar, generate}; +use crate::{generate, Lapack, Scalar}; use ndarray::prelude::*; use ndarray::ScalarOperand; use num_traits::{Float, NumCast}; @@ -158,12 +158,14 @@ impl Truncate // convert into TruncatedSvdResult match res { - LobpcgResult::Ok(vals, vecs, _) | LobpcgResult::Err(vals, vecs, _, _) => Ok(TruncatedSvdResult { - problem: self.problem.clone(), - eigvals: vals, - eigvecs: vecs, - ngm: n > m, - }), + LobpcgResult::Ok(vals, vecs, _) | LobpcgResult::Err(vals, vecs, _, _) => { + Ok(TruncatedSvdResult { + problem: self.problem.clone(), + eigvals: vals, + eigvecs: vecs, + ngm: n > m, + }) + } LobpcgResult::NoResult(err) => Err(err), } } diff --git a/src/norm.rs b/src/norm.rs index 55259409..94e50382 100644 --- a/src/norm.rs +++ b/src/norm.rs @@ -53,7 +53,10 @@ pub enum NormalizeAxis { } /// normalize in L2 norm -pub fn normalize(mut m: ArrayBase, axis: NormalizeAxis) -> (ArrayBase, Vec) +pub fn normalize( + mut m: ArrayBase, + axis: NormalizeAxis, +) -> (ArrayBase, Vec) where A: Scalar + Lapack, S: DataMut, diff --git a/src/solve.rs b/src/solve.rs index 7346ad1b..e595fb6e 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -84,7 +84,10 @@ pub trait Solve { } /// Solves a system of linear equations `A * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. - fn solve_into>(&self, mut b: ArrayBase) -> Result> { + fn solve_into>( + &self, + mut b: ArrayBase, + ) -> Result> { self.solve_inplace(&mut b)?; Ok(b) } @@ -104,7 +107,10 @@ pub trait Solve { } /// Solves a system of linear equations `A^T * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. - fn solve_t_into>(&self, mut b: ArrayBase) -> Result> { + fn solve_t_into>( + &self, + mut b: ArrayBase, + ) -> Result> { self.solve_t_inplace(&mut b)?; Ok(b) } @@ -124,7 +130,10 @@ pub trait Solve { } /// Solves a system of linear equations `A^H * x = b` where `A` is `self`, `b` /// is the argument, and `x` is the successful result. - fn solve_h_into>(&self, mut b: ArrayBase) -> Result> { + fn solve_h_into>( + &self, + mut b: ArrayBase, + ) -> Result> { self.solve_h_inplace(&mut b)?; Ok(b) } @@ -151,7 +160,10 @@ where A: Scalar + Lapack, S: Data + RawDataClone, { - fn solve_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solve_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { @@ -166,7 +178,10 @@ where }; Ok(rhs) } - fn solve_t_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solve_t_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { @@ -181,7 +196,10 @@ where }; Ok(rhs) } - fn solve_h_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solve_h_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { @@ -203,21 +221,30 @@ where A: Scalar + Lapack, S: Data, { - fn solve_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solve_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { let f = self.factorize()?; f.solve_inplace(rhs) } - fn solve_t_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solve_t_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { let f = self.factorize()?; f.solve_t_inplace(rhs) } - fn solve_h_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solve_h_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { @@ -247,7 +274,10 @@ where { fn factorize_into(mut self) -> Result> { let ipiv = unsafe { A::lu(self.layout()?, self.as_allocated_mut()?)? }; - Ok(LUFactorized { a: self, ipiv: ipiv }) + Ok(LUFactorized { + a: self, + ipiv: ipiv, + }) } } @@ -285,7 +315,13 @@ where type Output = ArrayBase; fn inv_into(mut self) -> Result> { - unsafe { A::inv(self.a.square_layout()?, self.a.as_allocated_mut()?, &self.ipiv)? }; + unsafe { + A::inv( + self.a.square_layout()?, + self.a.as_allocated_mut()?, + &self.ipiv, + )? + }; Ok(self.a) } } @@ -399,10 +435,16 @@ where } else { -A::one() }; - let (upper_sign, ln_det) = u_diag_iter.fold((A::one(), A::Real::zero()), |(upper_sign, ln_det), &elem| { - let abs_elem: A::Real = elem.abs(); - (upper_sign * elem / A::from_real(abs_elem), ln_det + abs_elem.ln()) - }); + let (upper_sign, ln_det) = u_diag_iter.fold( + (A::one(), A::Real::zero()), + |(upper_sign, ln_det), &elem| { + let abs_elem: A::Real = elem.abs(); + ( + upper_sign * elem / A::from_real(abs_elem), + ln_det + abs_elem.ln(), + ) + }, + ); (pivot_sign * upper_sign, ln_det) } @@ -498,7 +540,13 @@ where S: Data + RawDataClone, { fn rcond(&self) -> Result { - unsafe { A::rcond(self.a.layout()?, self.a.as_allocated()?, self.a.opnorm_one()?) } + unsafe { + A::rcond( + self.a.layout()?, + self.a.as_allocated()?, + self.a.opnorm_one()?, + ) + } } } diff --git a/src/solveh.rs b/src/solveh.rs index 582eeb26..0e12720c 100644 --- a/src/solveh.rs +++ b/src/solveh.rs @@ -77,7 +77,10 @@ pub trait SolveH { /// Solves a system of linear equations `A * x = b` with Hermitian (or real /// symmetric) matrix `A`, where `A` is `self`, `b` is the argument, and /// `x` is the successful result. - fn solveh_into>(&self, mut b: ArrayBase) -> Result> { + fn solveh_into>( + &self, + mut b: ArrayBase, + ) -> Result> { self.solveh_inplace(&mut b)?; Ok(b) } @@ -103,7 +106,10 @@ where A: Scalar + Lapack, S: Data, { - fn solveh_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solveh_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { @@ -125,7 +131,10 @@ where A: Scalar + Lapack, S: Data, { - fn solveh_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + fn solveh_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> where Sb: DataMut, { @@ -157,7 +166,10 @@ where { fn factorizeh_into(mut self) -> Result> { let ipiv = unsafe { A::bk(self.square_layout()?, UPLO::Upper, self.as_allocated_mut()?)? }; - Ok(BKFactorized { a: self, ipiv: ipiv }) + Ok(BKFactorized { + a: self, + ipiv: ipiv, + }) } } diff --git a/src/svd.rs b/src/svd.rs index 7ad2dc8f..9bb90977 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -14,7 +14,11 @@ pub trait SVD { type U; type VT; type Sigma; - fn svd(&self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)>; + fn svd( + &self, + calc_u: bool, + calc_vt: bool, + ) -> Result<(Option, Self::Sigma, Option)>; } /// singular-value decomposition @@ -22,7 +26,11 @@ pub trait SVDInto { type U; type VT; type Sigma; - fn svd_into(self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)>; + fn svd_into( + self, + calc_u: bool, + calc_vt: bool, + ) -> Result<(Option, Self::Sigma, Option)>; } /// singular-value decomposition for mutable reference of matrix @@ -30,7 +38,11 @@ pub trait SVDInplace { type U; type VT; type Sigma; - fn svd_inplace(&mut self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)>; + fn svd_inplace( + &mut self, + calc_u: bool, + calc_vt: bool, + ) -> Result<(Option, Self::Sigma, Option)>; } impl SVDInto for ArrayBase @@ -42,7 +54,11 @@ where type VT = Array2; type Sigma = Array1; - fn svd_into(mut self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)> { + fn svd_into( + mut self, + calc_u: bool, + calc_vt: bool, + ) -> Result<(Option, Self::Sigma, Option)> { self.svd_inplace(calc_u, calc_vt) } } @@ -56,7 +72,11 @@ where type VT = Array2; type Sigma = Array1; - fn svd(&self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)> { + fn svd( + &self, + calc_u: bool, + calc_vt: bool, + ) -> Result<(Option, Self::Sigma, Option)> { let a = self.to_owned(); a.svd_into(calc_u, calc_vt) } @@ -71,7 +91,11 @@ where type VT = Array2; type Sigma = Array1; - fn svd_inplace(&mut self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)> { + fn svd_inplace( + &mut self, + calc_u: bool, + calc_vt: bool, + ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? }; let (n, m) = l.size(); diff --git a/src/svddc.rs b/src/svddc.rs index 97d63d29..f2c102fd 100644 --- a/src/svddc.rs +++ b/src/svddc.rs @@ -34,7 +34,10 @@ pub trait SVDDCInto { type U; type VT; type Sigma; - fn svddc_into(self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)>; + fn svddc_into( + self, + uvt_flag: UVTFlag, + ) -> Result<(Option, Self::Sigma, Option)>; } /// Singular-value decomposition of matrix reference by divide-and-conquer @@ -42,7 +45,10 @@ pub trait SVDDCInplace { type U; type VT; type Sigma; - fn svddc_inplace(&mut self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)>; + fn svddc_inplace( + &mut self, + uvt_flag: UVTFlag, + ) -> Result<(Option, Self::Sigma, Option)>; } impl SVDDC for ArrayBase @@ -68,7 +74,10 @@ where type VT = Array2; type Sigma = Array1; - fn svddc_into(mut self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)> { + fn svddc_into( + mut self, + uvt_flag: UVTFlag, + ) -> Result<(Option, Self::Sigma, Option)> { self.svddc_inplace(uvt_flag) } } @@ -82,7 +91,10 @@ where type VT = Array2; type Sigma = Array1; - fn svddc_inplace(&mut self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)> { + fn svddc_inplace( + &mut self, + uvt_flag: UVTFlag, + ) -> Result<(Option, Self::Sigma, Option)> { let l = self.layout()?; let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; let (m, n) = l.size(); diff --git a/src/triangular.rs b/src/triangular.rs index 4b4dc7bc..c54beafd 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -27,7 +27,12 @@ where S: DataMut, D: Dimension, { - fn solve_triangular_into(&self, uplo: UPLO, diag: Diag, b: ArrayBase) -> Result>; + fn solve_triangular_into( + &self, + uplo: UPLO, + diag: Diag, + b: ArrayBase, + ) -> Result>; } /// solve a triangular system with upper triangular matrix @@ -50,7 +55,12 @@ where Si: Data, So: DataMut + DataOwned, { - fn solve_triangular_into(&self, uplo: UPLO, diag: Diag, mut b: ArrayBase) -> Result> { + fn solve_triangular_into( + &self, + uplo: UPLO, + diag: Diag, + mut b: ArrayBase, + ) -> Result> { self.solve_triangular_inplace(uplo, diag, &mut b)?; Ok(b) } @@ -86,7 +96,12 @@ where Si: Data, So: DataMut + DataOwned, { - fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &ArrayBase) -> Result> { + fn solve_triangular( + &self, + uplo: UPLO, + diag: Diag, + b: &ArrayBase, + ) -> Result> { let b = replicate(b); self.solve_triangular_into(uplo, diag, b) } @@ -98,7 +113,12 @@ where Si: Data, So: DataMut + DataOwned, { - fn solve_triangular_into(&self, uplo: UPLO, diag: Diag, b: ArrayBase) -> Result> { + fn solve_triangular_into( + &self, + uplo: UPLO, + diag: Diag, + b: ArrayBase, + ) -> Result> { let b = into_col(b); let b = self.solve_triangular_into(uplo, diag, b)?; Ok(flatten(b)) @@ -111,7 +131,12 @@ where Si: Data, So: DataMut + DataOwned, { - fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &ArrayBase) -> Result> { + fn solve_triangular( + &self, + uplo: UPLO, + diag: Diag, + b: &ArrayBase, + ) -> Result> { let b = b.to_owned(); self.solve_triangular_into(uplo, diag, b) } diff --git a/tests/cholesky.rs b/tests/cholesky.rs index 9a7c56c1..accdbdf8 100644 --- a/tests/cholesky.rs +++ b/tests/cholesky.rs @@ -16,7 +16,11 @@ fn cholesky() { ); let lower = a_orig.cholesky(UPLO::Lower).unwrap(); - assert_close_l2!(&lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); let a: Array2<$elem> = replicate(&a_orig); let upper = a.cholesky_into(UPLO::Upper).unwrap(); @@ -28,7 +32,11 @@ fn cholesky() { let a: Array2<$elem> = replicate(&a_orig); let lower = a.cholesky_into(UPLO::Lower).unwrap(); - assert_close_l2!(&lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); let mut a: Array2<$elem> = replicate(&a_orig); { @@ -39,12 +47,20 @@ fn cholesky() { $rtol ); } - assert_close_l2!(&a.t().mapv(|elem| elem.conj()).dot(&upper.view()), &a_orig, $rtol); + assert_close_l2!( + &a.t().mapv(|elem| elem.conj()).dot(&upper.view()), + &a_orig, + $rtol + ); let mut a: Array2<$elem> = replicate(&a_orig); { let lower = a.cholesky_inplace(UPLO::Lower).unwrap(); - assert_close_l2!(&lower.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); + assert_close_l2!( + &lower.dot(&lower.t().mapv(|elem| elem.conj())), + &a_orig, + $rtol + ); } assert_close_l2!(&a.dot(&lower.t().mapv(|elem| elem.conj())), &a_orig, $rtol); }; @@ -120,7 +136,11 @@ fn cholesky_det() { assert_aclose!(a.factorizec(UPLO::Upper).unwrap().detc(), det, $atol); assert_aclose!(a.factorizec(UPLO::Upper).unwrap().ln_detc(), ln_det, $atol); assert_aclose!(a.factorizec(UPLO::Lower).unwrap().detc_into(), det, $atol); - assert_aclose!(a.factorizec(UPLO::Lower).unwrap().ln_detc_into(), ln_det, $atol); + assert_aclose!( + a.factorizec(UPLO::Lower).unwrap().ln_detc_into(), + ln_det, + $atol + ); assert_aclose!(a.detc().unwrap(), det, $atol); assert_aclose!(a.ln_detc().unwrap(), ln_det, $atol); assert_aclose!(a.clone().detc_into().unwrap(), det, $atol); @@ -145,15 +165,29 @@ fn cholesky_solve() { assert_close_l2!(&a.solvec(&b).unwrap(), &x, $rtol); assert_close_l2!(&a.solvec_into(b.clone()).unwrap(), &x, $rtol); assert_close_l2!(&a.solvec_inplace(&mut b.clone()).unwrap(), &x, $rtol); - assert_close_l2!(&a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), &x, $rtol); - assert_close_l2!(&a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), &x, $rtol); assert_close_l2!( - &a.factorizec(UPLO::Upper).unwrap().solvec_into(b.clone()).unwrap(), + &a.factorizec(UPLO::Upper).unwrap().solvec(&b).unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower).unwrap().solvec(&b).unwrap(), &x, $rtol ); assert_close_l2!( - &a.factorizec(UPLO::Lower).unwrap().solvec_into(b.clone()).unwrap(), + &a.factorizec(UPLO::Upper) + .unwrap() + .solvec_into(b.clone()) + .unwrap(), + &x, + $rtol + ); + assert_close_l2!( + &a.factorizec(UPLO::Lower) + .unwrap() + .solvec_into(b.clone()) + .unwrap(), &x, $rtol ); diff --git a/tests/det.rs b/tests/det.rs index 610b0b59..ec9c0d4a 100644 --- a/tests/det.rs +++ b/tests/det.rs @@ -12,7 +12,8 @@ where select_rows.remove(row); let mut select_cols = (0..a.ncols()).collect::>(); select_cols.remove(col); - a.select(Axis(0), &select_rows).select(Axis(1), &select_cols) + a.select(Axis(0), &select_rows) + .select(Axis(1), &select_cols) } /// Computes the determinant of matrix `a`. @@ -47,7 +48,10 @@ fn det_empty() { assert_eq!(a.factorize().unwrap().det().unwrap(), det); assert_eq!(a.factorize().unwrap().sln_det().unwrap(), (sign, ln_det)); assert_eq!(a.factorize().unwrap().det_into().unwrap(), det); - assert_eq!(a.factorize().unwrap().sln_det_into().unwrap(), (sign, ln_det)); + assert_eq!( + a.factorize().unwrap().sln_det_into().unwrap(), + (sign, ln_det) + ); assert_eq!(a.det().unwrap(), det); assert_eq!(a.sln_det().unwrap(), (sign, ln_det)); assert_eq!(a.clone().det_into().unwrap(), det); diff --git a/tests/deth.rs b/tests/deth.rs index 30155864..abd54105 100644 --- a/tests/deth.rs +++ b/tests/deth.rs @@ -8,7 +8,10 @@ fn deth_empty() { ($elem:ty) => { let a: Array2<$elem> = Array2::zeros((0, 0)); assert_eq!(a.factorizeh().unwrap().deth(), One::one()); - assert_eq!(a.factorizeh().unwrap().sln_deth(), (One::one(), Zero::zero())); + assert_eq!( + a.factorizeh().unwrap().sln_deth(), + (One::one(), Zero::zero()) + ); assert_eq!(a.factorizeh().unwrap().deth_into(), One::one()); assert_eq!( a.factorizeh().unwrap().sln_deth_into(), @@ -34,7 +37,10 @@ fn deth_zero() { assert_eq!(a.deth().unwrap(), Zero::zero()); assert_eq!(a.sln_deth().unwrap(), (Zero::zero(), Float::neg_infinity())); assert_eq!(a.clone().deth_into().unwrap(), Zero::zero()); - assert_eq!(a.sln_deth_into().unwrap(), (Zero::zero(), Float::neg_infinity())); + assert_eq!( + a.sln_deth_into().unwrap(), + (Zero::zero(), Float::neg_infinity()) + ); }; } deth_zero!(f64); @@ -71,11 +77,18 @@ fn deth() { // Compute determinant from eigenvalues. let (sign, ln_det) = a.eigvalsh(UPLO::Upper).unwrap().iter().fold( - (<$elem as Scalar>::Real::one(), <$elem as Scalar>::Real::zero()), + ( + <$elem as Scalar>::Real::one(), + <$elem as Scalar>::Real::zero(), + ), |(sign, ln_det), eigval| (sign * eigval.signum(), ln_det + eigval.abs().ln()), ); let det = sign * ln_det.exp(); - assert_aclose!(det, a.eigvalsh(UPLO::Upper).unwrap().iter().product(), $atol); + assert_aclose!( + det, + a.eigvalsh(UPLO::Upper).unwrap().iter().product(), + $atol + ); assert_aclose!(a.factorizeh().unwrap().deth(), det, $atol); { diff --git a/tests/eig.rs b/tests/eig.rs index 802356ee..ac520152 100644 --- a/tests/eig.rs +++ b/tests/eig.rs @@ -4,15 +4,25 @@ use ndarray_linalg::*; #[test] fn dgeev() { // https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/dgeev_ex.f.htm - let a: Array2 = arr2(&[[-1.01, 0.86, -4.60, 3.31, -4.81], - [ 3.98, 0.53, -7.04, 5.29, 3.55], - [ 3.30, 8.26, -3.89, 8.20, -1.51], - [ 4.43, 4.96, -7.66, -7.33, 6.18], - [ 7.31, -6.43, -6.16, 2.47, 5.58]]); + let a: Array2 = arr2(&[ + [-1.01, 0.86, -4.60, 3.31, -4.81], + [3.98, 0.53, -7.04, 5.29, 3.55], + [3.30, 8.26, -3.89, 8.20, -1.51], + [4.43, 4.96, -7.66, -7.33, 6.18], + [7.31, -6.43, -6.16, 2.47, 5.58], + ]); let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap(); - assert_close_l2!(&e, - &arr1(&[c64::new( 2.86, 10.76), c64::new( 2.86,-10.76), c64::new( -0.69, 4.70), c64::new( -0.69, -4.70), c64::new(-10.46, 0.00)]), - 1.0e-3); + assert_close_l2!( + &e, + &arr1(&[ + c64::new(2.86, 10.76), + c64::new(2.86, -10.76), + c64::new(-0.69, 4.70), + c64::new(-0.69, -4.70), + c64::new(-10.46, 0.00) + ]), + 1.0e-3 + ); /* let answer = &arr2(&[[c64::new( 0.11, 0.17), c64::new( 0.11, -0.17), c64::new( 0.73, 0.00), c64::new( 0.73, 0.00), c64::new( 0.46, 0.00)], @@ -33,15 +43,25 @@ fn dgeev() { #[test] fn fgeev() { // https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/dgeev_ex.f.htm - let a: Array2 = arr2(&[[-1.01, 0.86, -4.60, 3.31, -4.81], - [ 3.98, 0.53, -7.04, 5.29, 3.55], - [ 3.30, 8.26, -3.89, 8.20, -1.51], - [ 4.43, 4.96, -7.66, -7.33, 6.18], - [ 7.31, -6.43, -6.16, 2.47, 5.58]]); + let a: Array2 = arr2(&[ + [-1.01, 0.86, -4.60, 3.31, -4.81], + [3.98, 0.53, -7.04, 5.29, 3.55], + [3.30, 8.26, -3.89, 8.20, -1.51], + [4.43, 4.96, -7.66, -7.33, 6.18], + [7.31, -6.43, -6.16, 2.47, 5.58], + ]); let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap(); - assert_close_l2!(&e, - &arr1(&[c32::new( 2.86, 10.76), c32::new( 2.86,-10.76), c32::new( -0.69, 4.70), c32::new( -0.69, -4.70), c32::new(-10.46, 0.00)]), - 1.0e-3); + assert_close_l2!( + &e, + &arr1(&[ + c32::new(2.86, 10.76), + c32::new(2.86, -10.76), + c32::new(-0.69, 4.70), + c32::new(-0.69, -4.70), + c32::new(-10.46, 0.00) + ]), + 1.0e-3 + ); /* let answer = &arr2(&[[c32::new( 0.11, 0.17), c32::new( 0.11, -0.17), c32::new( 0.73, 0.00), c32::new( 0.73, 0.00), c32::new( 0.46, 0.00)], @@ -62,14 +82,43 @@ fn fgeev() { #[test] fn zgeev() { // https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/zgeev_ex.f.htm - let a: Array2 = arr2(&[[c64::new( -3.84, 2.25), c64::new( -8.94, -4.75), c64::new( 8.95, -6.53), c64::new( -9.87, 4.82)], - [c64::new( -0.66, 0.83), c64::new( -4.40, -3.82), c64::new( -3.50, -4.26), c64::new( -3.15, 7.36)], - [c64::new( -3.99, -4.73), c64::new( -5.88, -6.60), c64::new( -3.36, -0.40), c64::new( -0.75, 5.23)], - [c64::new( 7.74, 4.18), c64::new( 3.66, -7.53), c64::new( 2.58, 3.60), c64::new( 4.59, 5.41)],]); + let a: Array2 = arr2(&[ + [ + c64::new(-3.84, 2.25), + c64::new(-8.94, -4.75), + c64::new(8.95, -6.53), + c64::new(-9.87, 4.82), + ], + [ + c64::new(-0.66, 0.83), + c64::new(-4.40, -3.82), + c64::new(-3.50, -4.26), + c64::new(-3.15, 7.36), + ], + [ + c64::new(-3.99, -4.73), + c64::new(-5.88, -6.60), + c64::new(-3.36, -0.40), + c64::new(-0.75, 5.23), + ], + [ + c64::new(7.74, 4.18), + c64::new(3.66, -7.53), + c64::new(2.58, 3.60), + c64::new(4.59, 5.41), + ], + ]); let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap(); - assert_close_l2!(&e, - &arr1(&[c64::new( -9.43,-12.98), c64::new( -3.44, 12.69), c64::new( 0.11, -3.40), c64::new( 5.76, 7.13)]), - 1.0e-3); + assert_close_l2!( + &e, + &arr1(&[ + c64::new(-9.43, -12.98), + c64::new(-3.44, 12.69), + c64::new(0.11, -3.40), + c64::new(5.76, 7.13) + ]), + 1.0e-3 + ); /* let answer = &arr2(&[[c64::new( 0.43, 0.33), c64::new( 0.83, 0.00), c64::new( 0.60, 0.00), c64::new( -0.31, 0.03)], @@ -88,14 +137,43 @@ fn zgeev() { #[test] fn cgeev() { // https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/zgeev_ex.f.htm - let a: Array2 = arr2(&[[c32::new( -3.84, 2.25), c32::new( -8.94, -4.75), c32::new( 8.95, -6.53), c32::new( -9.87, 4.82)], - [c32::new( -0.66, 0.83), c32::new( -4.40, -3.82), c32::new( -3.50, -4.26), c32::new( -3.15, 7.36)], - [c32::new( -3.99, -4.73), c32::new( -5.88, -6.60), c32::new( -3.36, -0.40), c32::new( -0.75, 5.23)], - [c32::new( 7.74, 4.18), c32::new( 3.66, -7.53), c32::new( 2.58, 3.60), c32::new( 4.59, 5.41)],]); + let a: Array2 = arr2(&[ + [ + c32::new(-3.84, 2.25), + c32::new(-8.94, -4.75), + c32::new(8.95, -6.53), + c32::new(-9.87, 4.82), + ], + [ + c32::new(-0.66, 0.83), + c32::new(-4.40, -3.82), + c32::new(-3.50, -4.26), + c32::new(-3.15, 7.36), + ], + [ + c32::new(-3.99, -4.73), + c32::new(-5.88, -6.60), + c32::new(-3.36, -0.40), + c32::new(-0.75, 5.23), + ], + [ + c32::new(7.74, 4.18), + c32::new(3.66, -7.53), + c32::new(2.58, 3.60), + c32::new(4.59, 5.41), + ], + ]); let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap(); - assert_close_l2!(&e, - &arr1(&[c32::new( -9.43,-12.98), c32::new( -3.44, 12.69), c32::new( 0.11, -3.40), c32::new( 5.76, 7.13)]), - 1.0e-3); + assert_close_l2!( + &e, + &arr1(&[ + c32::new(-9.43, -12.98), + c32::new(-3.44, 12.69), + c32::new(0.11, -3.40), + c32::new(5.76, 7.13) + ]), + 1.0e-3 + ); /* let answer = &arr2(&[[c32::new( 0.43, 0.33), c32::new( 0.83, 0.00), c32::new( 0.60, 0.00), c32::new( -0.31, 0.03)], @@ -109,4 +187,4 @@ fn cgeev() { let ev = v.mapv(|f| e[i] * f); assert_close_l2!(&av, &ev, 1.0e-5); } -} \ No newline at end of file +} diff --git a/tests/opnorm.rs b/tests/opnorm.rs index 8e52bbd4..fbda597f 100644 --- a/tests/opnorm.rs +++ b/tests/opnorm.rs @@ -14,7 +14,10 @@ fn test(a: Array2, one: f64, inf: f64, fro: f64) { fn gen(i: usize, j: usize, rev: bool) -> Array2 { let n = (i * j + 1) as f64; if rev { - Array::range(1., n, 1.).into_shape((j, i)).unwrap().reversed_axes() + Array::range(1., n, 1.) + .into_shape((j, i)) + .unwrap() + .reversed_axes() } else { Array::range(1., n, 1.).into_shape((i, j)).unwrap() } diff --git a/tests/solve.rs b/tests/solve.rs index 13b91d99..f26ebccc 100644 --- a/tests/solve.rs +++ b/tests/solve.rs @@ -41,7 +41,9 @@ fn rcond() { fn rcond_hilbert() { macro_rules! rcond_hilbert { ($elem:ty, $rows:expr, $atol:expr) => { - let a = Array2::<$elem>::from_shape_fn(($rows, $rows), |(i, j)| 1. / (i as $elem + j as $elem - 1.)); + let a = Array2::<$elem>::from_shape_fn(($rows, $rows), |(i, j)| { + 1. / (i as $elem + j as $elem - 1.) + }); assert_aclose!(a.rcond().unwrap(), 0., $atol); assert_aclose!(a.rcond_into().unwrap(), 0., $atol); };