Skip to content

Commit

Permalink
Merge pull request #1421 from rust-ndarray/blas-simplify
Browse files Browse the repository at this point in the history
Refactor and simplify BLAS gemm call further
  • Loading branch information
bluss authored Aug 14, 2024
2 parents 1df6c32 + 876ad01 commit 33e2a58
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 156 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ keywords = ["array", "data-structure", "multidimensional", "matrix", "blas"]
categories = ["data-structures", "science"]

exclude = ["docgen/images/*"]
resolver = "2"

[lib]
name = "ndarray"
Expand Down
4 changes: 2 additions & 2 deletions crates/blas-mock-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ doc = false
doctest = false

[dependencies]
ndarray = { workspace = true, features = ["approx", "blas"] }
ndarray-gen = { workspace = true }
cblas-sys = { workspace = true }

[dev-dependencies]
ndarray = { workspace = true, features = ["approx", "blas"] }
ndarray-gen = { workspace = true }
itertools = { workspace = true }
16 changes: 10 additions & 6 deletions crates/blas-tests/tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use ndarray::linalg::general_mat_vec_mul;
use ndarray::Order;
use ndarray::{Data, Ix, LinalgScalar};
use ndarray_gen::array_builder::ArrayBuilder;
use ndarray_gen::array_builder::ElementGenerator;

use approx::assert_relative_eq;
use defmac::defmac;
Expand Down Expand Up @@ -230,7 +231,6 @@ fn gen_mat_mul()
let sizes = vec![
(4, 4, 4),
(8, 8, 8),
(10, 10, 10),
(8, 8, 1),
(1, 10, 10),
(10, 1, 10),
Expand All @@ -241,19 +241,23 @@ fn gen_mat_mul()
(4, 17, 3),
(17, 3, 22),
(19, 18, 2),
(16, 17, 15),
(15, 16, 17),
(67, 63, 62),
(67, 50, 62),
];
let strides = &[1, 2, -1, -2];
let cf_order = [Order::C, Order::F];
let generator = [ElementGenerator::Sequential, ElementGenerator::Checkerboard];

// test different strides and memory orders
for (&s1, &s2) in iproduct!(strides, strides) {
for (&s1, &s2, &gen) in iproduct!(strides, strides, &generator) {
for &(m, k, n) in &sizes {
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5;
println!("Case s1={}, s2={}, gen={:?}, orders={:?}, {:?}, {:?}", s1, s2, gen, ord1, ord2, ord3);
let a = ArrayBuilder::new((m, k))
.memory_order(ord1)
.generator(gen)
.build()
* 0.5;
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();

Expand Down
17 changes: 8 additions & 9 deletions crates/ndarray-gen/src/array_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct ArrayBuilder<D: Dimension>
pub enum ElementGenerator
{
Sequential,
Checkerboard,
Zero,
}

Expand Down Expand Up @@ -64,16 +65,14 @@ where D: Dimension
pub fn build<T>(self) -> Array<T, D>
where T: Num + Clone
{
let mut current = T::zero();
let zero = T::zero();
let size = self.dim.size();
let use_zeros = self.generator == ElementGenerator::Zero;
Array::from_iter((0..size).map(|_| {
let ret = current.clone();
if !use_zeros {
current = ret.clone() + T::one();
}
ret
}))
(match self.generator {
ElementGenerator::Sequential =>
Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)),
ElementGenerator::Checkerboard => Array::from_iter([T::one(), zero].iter().cycle().take(size).cloned()),
ElementGenerator::Zero => Array::zeros(size),
})
.into_shape_with_order((self.dim, self.memory_order))
.unwrap()
}
Expand Down
1 change: 1 addition & 0 deletions scripts/cross-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ QC_FEAT=--features=ndarray-rand/quickcheck

cross build -v --features="$FEATURES" $QC_FEAT --target=$TARGET
cross test -v --no-fail-fast --features="$FEATURES" $QC_FEAT --target=$TARGET
cross test -v -p blas-mock-tests
2 changes: 1 addition & 1 deletion scripts/makechangelog.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Will produce some duplicates for PRs integrated using rebase,
# but those will not occur with current merge queue.

git log --first-parent --pretty="format:%H" "$@" | while read commit_sha
git log --first-parent --pretty="tformat:%H" "$@" | while IFS= read -r commit_sha
do
gh api "/repos/:owner/:repo/commits/${commit_sha}/pulls" \
-q ".[] | \"- \(.title) by [@\(.user.login)](\(.user.html_url)) [#\(.number)](\(.html_url))\""
Expand Down
Loading

0 comments on commit 33e2a58

Please sign in to comment.