diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 67af6409..7c39cf88 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -5,65 +5,71 @@ use crate::{error::*, layout::MatrixLayout}; use cauchy::*; use num_traits::Zero; -/// Wraps `*getrf`, `*getri`, and `*getrs` pub trait Solve_: Scalar + Sized { /// Computes the LU factorization of a general `m x n` matrix `a` using /// partial pivoting with row interchanges. /// - /// If the result matches `Err(LinalgError::Lapack(LapackError { - /// return_code )) if return_code > 0`, then `U[(return_code-1, - /// return_code-1)]` is exactly zero. The factorization has been completed, - /// but the factor `U` is exactly singular, and division by zero will occur - /// if it is used to solve a system of equations. - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; - unsafe fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + /// $ PA = LU $ + /// + /// Error + /// ------ + /// - `LapackComputationalFailure { return_code }` when the matrix is singular + /// - `U[(return_code-1, return_code-1)]` is exactly zero. + /// - Division by zero will occur if it is used to solve a system of equations. + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; + + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. /// /// `anorm` should be the 1-norm of the matrix `a`. - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; - unsafe fn solve( - l: MatrixLayout, - t: Transpose, - a: &[Self], - p: &Pivot, - b: &mut [Self], - ) -> Result<()>; + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; + + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_solve { ($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => { impl Solve_ for $scalar { - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); let k = ::std::cmp::min(row, col); let mut ipiv = vec![0; k as usize]; - $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv).as_lapack_result()?; + unsafe { + $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv) + .as_lapack_result()?; + } Ok(ipiv) } - unsafe fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); - $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; + unsafe { + $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; + } Ok(()) } - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { let (n, _) = l.size(); let mut rcond = Self::Real::zero(); - $gecon( - l.lapacke_layout(), - NormType::One as u8, - n, - a, - l.lda(), - anorm, - &mut rcond, - ) + unsafe { + $gecon( + l.lapacke_layout(), + NormType::One as u8, + n, + a, + l.lda(), + anorm, + &mut rcond, + ) + } .as_lapack_result()?; + Ok(rcond) } - unsafe fn solve( + fn solve( l: MatrixLayout, t: Transpose, a: &[Self], @@ -73,18 +79,20 @@ macro_rules! impl_solve { let (n, _) = l.size(); let nrhs = 1; let ldb = 1; - $getrs( - l.lapacke_layout(), - t as u8, - n, - nrhs, - a, - l.lda(), - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + unsafe { + $getrs( + l.lapacke_layout(), + t as u8, + n, + nrhs, + a, + l.lda(), + ipiv, + b, + ldb, + ) + .as_lapack_result()?; + } Ok(()) } } diff --git a/ndarray-linalg/src/solve.rs b/ndarray-linalg/src/solve.rs index 566511f3..fd4b3017 100644 --- a/ndarray-linalg/src/solve.rs +++ b/ndarray-linalg/src/solve.rs @@ -167,15 +167,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::No, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::No, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_t_inplace<'a, Sb>( @@ -185,15 +183,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::Transpose, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::Transpose, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_h_inplace<'a, Sb>( @@ -203,15 +199,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::Hermite, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::Hermite, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } } @@ -273,7 +267,7 @@ where S: DataMut + RawDataClone, { fn factorize_into(mut self) -> Result> { - let ipiv = unsafe { A::lu(self.layout()?, self.as_allocated_mut()?)? }; + let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?; Ok(LUFactorized { a: self, ipiv }) } } @@ -285,7 +279,7 @@ where { fn factorize(&self) -> Result>> { let mut a: Array2 = replicate(self); - let ipiv = unsafe { A::lu(a.layout()?, a.as_allocated_mut()?)? }; + let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?; Ok(LUFactorized { a, ipiv }) } } @@ -312,13 +306,11 @@ where type Output = ArrayBase; fn inv_into(mut self) -> Result> { - unsafe { - A::inv( - self.a.square_layout()?, - self.a.as_allocated_mut()?, - &self.ipiv, - )? - }; + A::inv( + self.a.square_layout()?, + self.a.as_allocated_mut()?, + &self.ipiv, + )?; Ok(self.a) } } @@ -539,13 +531,11 @@ where S: Data + RawDataClone, { fn rcond(&self) -> Result { - unsafe { - Ok(A::rcond( - self.a.layout()?, - self.a.as_allocated()?, - self.a.opnorm_one()?, - )?) - } + Ok(A::rcond( + self.a.layout()?, + self.a.as_allocated()?, + self.a.opnorm_one()?, + )?) } }