Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cargo-fmt --check on CI #194

Merged
merged 5 commits into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ jobs:
with:
command: test
args: --features=openblas --no-default-features

check-format:
runs-on: ubuntu-18.04
steps:
- uses: actions/checkout@v1
- uses: actions-rs/cargo@v1
with:
command: fmt
args: -- --check
2 changes: 1 addition & 1 deletion examples/eig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ fn main() {
let a_c: Array2<c64> = a.map(|f| c64::new(*f, 0.0));
let av = a_c.dot(&vecs);
println!("AV = \n{:?}", av);
}
}
4 changes: 3 additions & 1 deletion examples/truncated_svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ fn main() {
let a = arr2(&[[3., 2., 2.], [2., 3., -2.]]);

// calculate the truncated singular value decomposition for 2 singular values
let result = TruncatedSvd::new(a, TruncatedOrder::Largest).decompose(2).unwrap();
let result = TruncatedSvd::new(a, TruncatedOrder::Largest)
.decompose(2)
.unwrap();

// acquire singular values, left-singular vectors and right-singular vectors
let (u, sigma, v_t) = result.values_vectors();
Expand Down
6 changes: 0 additions & 6 deletions rustfmt.toml

This file was deleted.

15 changes: 12 additions & 3 deletions src/cholesky.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn solvec_inplace<'a, Sb>(&self, b: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solvec_inplace<'a, Sb>(
&self,
b: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
Expand Down Expand Up @@ -327,7 +330,10 @@ pub trait SolveC<A: Scalar> {
/// Solves a system of linear equations `A * x = b` with Hermitian (or real
/// symmetric) positive definite matrix `A`, where `A` is `self`, `b` is
/// the argument, and `x` is the successful result.
fn solvec_into<S: DataMut<Elem = A>>(&self, mut b: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
fn solvec_into<S: DataMut<Elem = A>>(
&self,
mut b: ArrayBase<S, Ix1>,
) -> Result<ArrayBase<S, Ix1>> {
self.solvec_inplace(&mut b)?;
Ok(b)
}
Expand All @@ -346,7 +352,10 @@ where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn solvec_inplace<'a, Sb>(&self, b: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
fn solvec_inplace<'a, Sb>(
&self,
b: &'a mut ArrayBase<Sb, Ix1>,
) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
Expand Down
6 changes: 5 additions & 1 deletion src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ where
} else {
ArrayBase::from_shape_vec(a.dim().f(), a.into_raw_vec()).unwrap()
};
assert_eq!(new.strides(), strides.as_slice(), "Custom stride is not supported");
assert_eq!(
new.strides(),
strides.as_slice(),
"Custom stride is not supported"
);
new
}

Expand Down
11 changes: 8 additions & 3 deletions src/eig.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Eigenvalue decomposition for non-symmetric square matrices

use ndarray::*;
use crate::error::*;
use crate::layout::*;
use crate::types::*;
use ndarray::*;

/// Eigenvalue decomposition of general matrix reference
pub trait Eig {
Expand All @@ -27,7 +27,12 @@ where
let layout = a.square_layout()?;
let (s, t) = unsafe { A::eig(true, layout, a.as_allocated_mut()?)? };
let (n, _) = layout.size();
Ok((ArrayBase::from(s), ArrayBase::from(t).into_shape((n as usize, n as usize)).unwrap()))
Ok((
ArrayBase::from(s),
ArrayBase::from(t)
.into_shape((n as usize, n as usize))
.unwrap(),
))
}
}

Expand All @@ -49,4 +54,4 @@ where
let (s, _) = unsafe { A::eig(true, a.square_layout()?, a.as_allocated_mut()?)? };
Ok(ArrayBase::from(s))
}
}
}
12 changes: 9 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@ pub enum LinalgError {
impl fmt::Display for LinalgError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
LinalgError::NotSquare { rows, cols } => write!(f, "Not square: rows({}) != cols({})", rows, cols),
LinalgError::Lapack { return_code } => write!(f, "LAPACK: return_code = {}", return_code),
LinalgError::InvalidStride { s0, s1 } => write!(f, "invalid stride: s0={}, s1={}", s0, s1),
LinalgError::NotSquare { rows, cols } => {
write!(f, "Not square: rows({}) != cols({})", rows, cols)
}
LinalgError::Lapack { return_code } => {
write!(f, "LAPACK: return_code = {}", return_code)
}
LinalgError::InvalidStride { s0, s1 } => {
write!(f, "invalid stride: s0={}, s1={}", s0, s1)
}
LinalgError::MemoryNotCont => write!(f, "Memory is not contiguous"),
LinalgError::Shape(err) => write!(f, "Shape Error: {}", err),
}
Expand Down
4 changes: 3 additions & 1 deletion src/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ where
assert_eq!(self.len(), rhs.len());
Zip::from(self)
.and(rhs)
.fold_while(A::zero(), |acc, s, r| FoldWhile::Continue(acc + s.conj() * *r))
.fold_while(A::zero(), |acc, s, r| {
FoldWhile::Continue(acc + s.conj() * *r)
})
.into_inner()
}
}
12 changes: 10 additions & 2 deletions src/krylov/arnoldi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ where
}

/// Utility to execute Arnoldi iteration with Householder reflection
pub fn arnoldi_householder<A, S>(a: impl LinearOperator<Elem = A>, v: ArrayBase<S, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
pub fn arnoldi_householder<A, S>(
a: impl LinearOperator<Elem = A>,
v: ArrayBase<S, Ix1>,
tol: A::Real,
) -> (Q<A>, H<A>)
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
Expand All @@ -107,7 +111,11 @@ where
}

/// Utility to execute Arnoldi iteration with modified Gram-Schmit orthogonalizer
pub fn arnoldi_mgs<A, S>(a: impl LinearOperator<Elem = A>, v: ArrayBase<S, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
pub fn arnoldi_mgs<A, S>(
a: impl LinearOperator<Elem = A>,
v: ArrayBase<S, Ix1>,
tol: A::Real,
) -> (Q<A>, H<A>)
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
Expand Down
6 changes: 5 additions & 1 deletion src/krylov/householder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ impl<A: Scalar + Lapack> Householder<A> {
S: DataMut<Elem = A>,
{
assert!(k < self.v.len());
assert_eq!(a.len(), self.dim, "Input array size mismaches to the dimension");
assert_eq!(
a.len(),
self.dim,
"Input array size mismaches to the dimension"
);
reflect(&self.v[k].slice(s![k..]), &mut a.slice_mut(s![k..]));
}

Expand Down
10 changes: 8 additions & 2 deletions src/lapack/cholesky.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ pub trait Cholesky_: Sized {
/// **Warning: Only the portion of `a` corresponding to `UPLO` is written.**
unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
/// Wrapper of `*potrs`
unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self])
-> Result<()>;
}

macro_rules! impl_cholesky {
Expand All @@ -36,7 +37,12 @@ macro_rules! impl_cholesky {
into_result(info, ())
}

unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()> {
unsafe fn solve_cholesky(
l: MatrixLayout,
uplo: UPLO,
a: &[Self],
b: &mut [Self],
) -> Result<()> {
let (n, _) = l.size();
let nrhs = 1;
let ldb = 1;
Expand Down
106 changes: 76 additions & 30 deletions src/lapack/eig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,39 @@ use super::into_result;

/// Wraps `*geev` for real/complex
pub trait Eig_: Scalar {
unsafe fn eig(calc_v: bool, l: MatrixLayout, a: &mut [Self]) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)>;
unsafe fn eig(
calc_v: bool,
l: MatrixLayout,
a: &mut [Self],
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)>;
}

macro_rules! impl_eig_complex {
($scalar:ty, $ev:path) => {
impl Eig_ for $scalar {
unsafe fn eig(calc_v: bool, l: MatrixLayout, mut a: &mut [Self]) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
unsafe fn eig(
calc_v: bool,
l: MatrixLayout,
mut a: &mut [Self],
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
let (n, _) = l.size();
let jobvr = if calc_v { b'V' } else { b'N' };
let mut w = vec![Self::Complex::zero(); n as usize];
let mut vl = Vec::new();
let mut vr = vec![Self::Complex::zero(); (n * n) as usize];
let info = $ev(l.lapacke_layout(), b'N', jobvr, n, &mut a, n, &mut w, &mut vl, n, &mut vr, n);
let info = $ev(
l.lapacke_layout(),
b'N',
jobvr,
n,
&mut a,
n,
&mut w,
&mut vl,
n,
&mut vr,
n,
);
into_result(info, (w, vr))
}
}
Expand All @@ -33,49 +53,75 @@ macro_rules! impl_eig_complex {
macro_rules! impl_eig_real {
($scalar:ty, $ev:path) => {
impl Eig_ for $scalar {
unsafe fn eig(calc_v: bool, l: MatrixLayout, mut a: &mut [Self]) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
unsafe fn eig(
calc_v: bool,
l: MatrixLayout,
mut a: &mut [Self],
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
let (n, _) = l.size();
let jobvr = if calc_v { b'V' } else { b'N' };
let mut wr = vec![Self::Real::zero(); n as usize];
let mut wi = vec![Self::Real::zero(); n as usize];
let mut vl = Vec::new();
let mut vr = vec![Self::Real::zero(); (n * n) as usize];
let info = $ev(l.lapacke_layout(), b'N', jobvr, n, &mut a, n, &mut wr, &mut wi, &mut vl, n, &mut vr, n);
let w: Vec<Self::Complex> = wr.iter().zip(wi.iter()).map(|(&r, &i)| Self::Complex::new(r, i)).collect();
let info = $ev(
l.lapacke_layout(),
b'N',
jobvr,
n,
&mut a,
n,
&mut wr,
&mut wi,
&mut vl,
n,
&mut vr,
n,
);
let w: Vec<Self::Complex> = wr
.iter()
.zip(wi.iter())
.map(|(&r, &i)| Self::Complex::new(r, i))
.collect();
// If the j-th eigenvalue is real, then
// eigenvector = [ vr[j], vr[j+n], vr[j+2*n], ... ].
//
// If the j-th and (j+1)-st eigenvalues form a complex conjugate pair,
// If the j-th and (j+1)-st eigenvalues form a complex conjugate pair,
// eigenvector(j) = [ vr[j] + i*vr[j+1], vr[j+n] + i*vr[j+n+1], vr[j+2*n] + i*vr[j+2*n+1], ... ] and
// eigenvector(j+1) = [ vr[j] - i*vr[j+1], vr[j+n] - i*vr[j+n+1], vr[j+2*n] - i*vr[j+2*n+1], ... ].
//
//
// Therefore, if eigenvector(j) is written as [ v_{j0}, v_{j1}, v_{j2}, ... ],
// you have to make
// you have to make
// v = vec![ v_{00}, v_{10}, v_{20}, ..., v_{jk}, v_{(j+1)k}, v_{(j+2)k}, ... ] (v.len() = n*n)
// based on wi and vr.
// After that, v is converted to Array2 (see ../eig.rs).
let n = n as usize;
let mut flg = false;
let conj: Vec<i8> = wi.iter().map(|&i| {
if flg {
flg = false;
-1
} else if i != 0.0 {
flg = true;
1
} else {
0
}
}).collect();
let v: Vec<Self::Complex> = (0..n*n).map(|i| {
let j = i % n;
match conj[j] {
1 => Self::Complex::new(vr[i], vr[i+1]),
-1 => Self::Complex::new(vr[i-1], -vr[i]),
_ => Self::Complex::new(vr[i], 0.0),
}
}).collect();

let conj: Vec<i8> = wi
.iter()
.map(|&i| {
if flg {
flg = false;
-1
} else if i != 0.0 {
flg = true;
1
} else {
0
}
})
.collect();
let v: Vec<Self::Complex> = (0..n * n)
.map(|i| {
let j = i % n;
match conj[j] {
1 => Self::Complex::new(vr[i], vr[i + 1]),
-1 => Self::Complex::new(vr[i - 1], -vr[i]),
_ => Self::Complex::new(vr[i], 0.0),
}
})
.collect();

into_result(info, (w, v))
}
}
Expand All @@ -85,4 +131,4 @@ macro_rules! impl_eig_real {
impl_eig_real!(f64, lapacke::dgeev);
impl_eig_real!(f32, lapacke::sgeev);
impl_eig_complex!(c64, lapacke::zgeev);
impl_eig_complex!(c32, lapacke::cgeev);
impl_eig_complex!(c32, lapacke::cgeev);
14 changes: 12 additions & 2 deletions src/lapack/eigh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ use super::{into_result, UPLO};

/// Wraps `*syev` for real and `*heev` for complex
pub trait Eigh_: Scalar {
unsafe fn eigh(calc_eigenvec: bool, l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Vec<Self::Real>>;
unsafe fn eigh(
calc_eigenvec: bool,
l: MatrixLayout,
uplo: UPLO,
a: &mut [Self],
) -> Result<Vec<Self::Real>>;
unsafe fn eigh_generalized(
calc_eigenvec: bool,
l: MatrixLayout,
Expand All @@ -24,7 +29,12 @@ pub trait Eigh_: Scalar {
macro_rules! impl_eigh {
($scalar:ty, $ev:path, $evg:path) => {
impl Eigh_ for $scalar {
unsafe fn eigh(calc_v: bool, l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<Vec<Self::Real>> {
unsafe fn eigh(
calc_v: bool,
l: MatrixLayout,
uplo: UPLO,
mut a: &mut [Self],
) -> Result<Vec<Self::Real>> {
let (n, _) = l.size();
let jobz = if calc_v { b'V' } else { b'N' };
let mut w = vec![Self::Real::zero(); n as usize];
Expand Down
Loading