Skip to content

Commit

Permalink
Fix unsafe signature
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Jul 5, 2020
1 parent 261e79a commit f7d93f4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 86 deletions.
94 changes: 51 additions & 43 deletions lax/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,65 +5,71 @@ use crate::{error::*, layout::MatrixLayout};
use cauchy::*;
use num_traits::Zero;

/// Wraps `*getrf`, `*getri`, and `*getrs`
pub trait Solve_: Scalar + Sized {
/// Computes the LU factorization of a general `m x n` matrix `a` using
/// partial pivoting with row interchanges.
///
/// If the result matches `Err(LinalgError::Lapack(LapackError {
/// return_code )) if return_code > 0`, then `U[(return_code-1,
/// return_code-1)]` is exactly zero. The factorization has been completed,
/// but the factor `U` is exactly singular, and division by zero will occur
/// if it is used to solve a system of equations.
unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
unsafe fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;
/// $ PA = LU $
///
/// Error
/// ------
/// - `LapackComputationalFailure { return_code }` when the matrix is singular
/// - `U[(return_code-1, return_code-1)]` is exactly zero.
/// - Division by zero will occur if it is used to solve a system of equations.
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;

fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>;

/// Estimates the the reciprocal of the condition number of the matrix in 1-norm.
///
/// `anorm` should be the 1-norm of the matrix `a`.
unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>;
unsafe fn solve(
l: MatrixLayout,
t: Transpose,
a: &[Self],
p: &Pivot,
b: &mut [Self],
) -> Result<()>;
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>;

fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
}

macro_rules! impl_solve {
($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => {
impl Solve_ for $scalar {
unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
let (row, col) = l.size();
let k = ::std::cmp::min(row, col);
let mut ipiv = vec![0; k as usize];
$getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv).as_lapack_result()?;
unsafe {
$getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv)
.as_lapack_result()?;
}
Ok(ipiv)
}

unsafe fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
let (n, _) = l.size();
$getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?;
unsafe {
$getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?;
}
Ok(())
}

unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
let (n, _) = l.size();
let mut rcond = Self::Real::zero();
$gecon(
l.lapacke_layout(),
NormType::One as u8,
n,
a,
l.lda(),
anorm,
&mut rcond,
)
unsafe {
$gecon(
l.lapacke_layout(),
NormType::One as u8,
n,
a,
l.lda(),
anorm,
&mut rcond,
)
}
.as_lapack_result()?;

Ok(rcond)
}

unsafe fn solve(
fn solve(
l: MatrixLayout,
t: Transpose,
a: &[Self],
Expand All @@ -73,18 +79,20 @@ macro_rules! impl_solve {
let (n, _) = l.size();
let nrhs = 1;
let ldb = 1;
$getrs(
l.lapacke_layout(),
t as u8,
n,
nrhs,
a,
l.lda(),
ipiv,
b,
ldb,
)
.as_lapack_result()?;
unsafe {
$getrs(
l.lapacke_layout(),
t as u8,
n,
nrhs,
a,
l.lda(),
ipiv,
b,
ldb,
)
.as_lapack_result()?;
}
Ok(())
}
}
Expand Down
76 changes: 33 additions & 43 deletions ndarray-linalg/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,13 @@ where
where
Sb: DataMut<Elem = A>,
{
unsafe {
A::solve(
self.a.square_layout()?,
Transpose::No,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?
};
A::solve(
self.a.square_layout()?,
Transpose::No,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
fn solve_t_inplace<'a, Sb>(
Expand All @@ -185,15 +183,13 @@ where
where
Sb: DataMut<Elem = A>,
{
unsafe {
A::solve(
self.a.square_layout()?,
Transpose::Transpose,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?
};
A::solve(
self.a.square_layout()?,
Transpose::Transpose,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
fn solve_h_inplace<'a, Sb>(
Expand All @@ -203,15 +199,13 @@ where
where
Sb: DataMut<Elem = A>,
{
unsafe {
A::solve(
self.a.square_layout()?,
Transpose::Hermite,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?
};
A::solve(
self.a.square_layout()?,
Transpose::Hermite,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?;
Ok(rhs)
}
}
Expand Down Expand Up @@ -273,7 +267,7 @@ where
S: DataMut<Elem = A> + RawDataClone,
{
fn factorize_into(mut self) -> Result<LUFactorized<S>> {
let ipiv = unsafe { A::lu(self.layout()?, self.as_allocated_mut()?)? };
let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
Ok(LUFactorized { a: self, ipiv })
}
}
Expand All @@ -285,7 +279,7 @@ where
{
fn factorize(&self) -> Result<LUFactorized<OwnedRepr<A>>> {
let mut a: Array2<A> = replicate(self);
let ipiv = unsafe { A::lu(a.layout()?, a.as_allocated_mut()?)? };
let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
Ok(LUFactorized { a, ipiv })
}
}
Expand All @@ -312,13 +306,11 @@ where
type Output = ArrayBase<S, Ix2>;

fn inv_into(mut self) -> Result<ArrayBase<S, Ix2>> {
unsafe {
A::inv(
self.a.square_layout()?,
self.a.as_allocated_mut()?,
&self.ipiv,
)?
};
A::inv(
self.a.square_layout()?,
self.a.as_allocated_mut()?,
&self.ipiv,
)?;
Ok(self.a)
}
}
Expand Down Expand Up @@ -539,13 +531,11 @@ where
S: Data<Elem = A> + RawDataClone,
{
fn rcond(&self) -> Result<A::Real> {
unsafe {
Ok(A::rcond(
self.a.layout()?,
self.a.as_allocated()?,
self.a.opnorm_one()?,
)?)
}
Ok(A::rcond(
self.a.layout()?,
self.a.as_allocated()?,
self.a.opnorm_one()?,
)?)
}
}

Expand Down

0 comments on commit f7d93f4

Please sign in to comment.