Add alternative strategy for batched matrix multiplication #51

Merged · 3 commits · Feb 6, 2024
src/ops/matmul.rs (212 additions, 48 deletions)
@@ -87,7 +87,24 @@ impl Operator for Gemm {
}
}

/// Hints for how a batched MatMul should be performed. This exists to enable
/// comparisons in tests and benchmarks.
#[derive(Copy, Clone, Debug, PartialEq)]
enum MatmulStrategy {
/// Use the best strategy for the input shapes.
Auto,

/// Perform separate GEMM calls for each pair of matrices to multiply in
/// the batch.
#[cfg(test)]
Batch,
}

pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
matmul_impl(a, b, MatmulStrategy::Auto)
}

fn matmul_impl(a: TensorView, b: TensorView, strategy: MatmulStrategy) -> Result<Tensor, OpError> {
if a.ndim() < 2 || b.ndim() < 2 {
return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
}
@@ -106,12 +123,31 @@ pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {

let a_prefix = &a.shape()[..a.ndim() - 2];
let b_prefix = &b.shape()[..b.ndim() - 2];

let num_a_matrices: usize = a_prefix.iter().product();
let num_b_matrices: usize = b_prefix.iter().product();

let out_prefix = broadcast_shapes(a_prefix, b_prefix)
.ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?;

let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat();
let mut output = Tensor::zeros(out_shape);

// A batched matrix multiplication `[A, M, K] x [K, N]`, where `A` can
// consist of one or more dimensions, can be converted to a non-batched
// matmul by reshaping the inputs to `[A * M, K] x [K, N]`, and then
// reshaping the `[A * M, N]` output to `[A, M, N]`.
//
// The upside is that one larger matmul is likely to be more efficient than
// `A` smaller matmuls. This is especially true if `M` is small (e.g. 1).
if strategy == MatmulStrategy::Auto && a.ndim() > 2 && b.ndim() == 2 {
// nb. We assume `a` is likely already contiguous, so this will be cheap.
let a_contig = a.to_contiguous();
let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
let mut output = matmul(a_matrix, b.clone())?;
output.reshape(out_shape);
return Ok(output);
}

let mut output = Tensor::zeros(out_shape);
if output.is_empty() {
return Ok(output);
}
@@ -128,9 +164,6 @@ pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
.unwrap()
.chunks_mut(out_row_stride * a_rows);

let num_a_matrices: usize = a_prefix.iter().product();
let num_b_matrices: usize = b_prefix.iter().product();

let gemm = GemmExecutor::new();

// Prepack re-used inputs to amortize packing cost.
@@ -196,10 +229,12 @@ mod tests {
use rten_tensor::prelude::*;
use rten_tensor::rng::XorShiftRng;
use rten_tensor::test_util::expect_equal;
use rten_tensor::Tensor;
use rten_tensor::{Tensor, TensorView, TensorViewMut};

use crate::gemm::gemm;
use crate::ops::matmul::{gemm_op, matmul, OpError};
use crate::test_util::run_bench;

use super::{gemm_op, matmul, matmul_impl, MatmulStrategy, OpError};

fn gemm_tensors(c: &mut Tensor, a: &Tensor, b: &Tensor, alpha: f32, beta: f32) {
c.make_contiguous();
Expand All @@ -214,6 +249,34 @@ mod tests {
)
}

/// Multiply matrices in `a` by corresponding matrices in `b` and write to
/// `c`. The shapes of `a` and `b` are broadcast so that their first N-2
/// dims match `c`.
fn reference_matmul(mut c: TensorViewMut, a: TensorView, b: TensorView) {
let a_batch_dims = a.ndim() - 2;
let b_batch_dims = b.ndim() - 2;
let out_prefix = &c.shape()[..c.ndim() - 2];

let a_bcast = [out_prefix, &a.shape()[a_batch_dims..]].concat();
let b_bcast = [out_prefix, &b.shape()[b_batch_dims..]].concat();

a.broadcast(a_bcast.as_slice())
.inner_iter::<2>()
.zip(b.broadcast(b_bcast.as_slice()).inner_iter::<2>())
.zip(c.inner_iter_mut::<2>())
.for_each(|((a, b), mut c)| {
let c_row_stride = c.stride(0);
gemm(
c.data_mut().unwrap(),
c_row_stride,
a,
b,
1., /* alpha */
0., /* beta */
)
});
}

#[test]
fn test_gemm_op() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
@@ -286,15 +349,104 @@ mod tests {

#[test]
fn test_matmul() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
let a = Tensor::rand(&[3, 10], &mut rng);
let b = Tensor::rand(&[10, 8], &mut rng);
struct Case<'a> {
a_shape: &'a [usize],
b_shape: &'a [usize],
out_shape: &'a [usize],
}

let mut expected = Tensor::zeros(&[3, 8]);
gemm_tensors(&mut expected, &a, &b, 1., 1.);
let cases = [
// Simple matmul
Case {
a_shape: &[3, 10],
b_shape: &[10, 8],
out_shape: &[3, 8],
},
// LHS input is a batch
Case {
a_shape: &[2, 3, 10],
b_shape: &[10, 8],
out_shape: &[2, 3, 8],
},
// RHS input is a batch
Case {
a_shape: &[3, 10],
b_shape: &[2, 10, 8],
out_shape: &[2, 3, 8],
},
// Both inputs are batches
Case {
a_shape: &[2, 3, 10],
b_shape: &[2, 10, 8],
out_shape: &[2, 3, 8],
},
];

let result = matmul(a.view(), b.view()).unwrap();
expect_equal(&result, &expected)?;
for Case {
a_shape,
b_shape,
out_shape,
} in cases
{
let mut rng = XorShiftRng::new(1234);
let a = Tensor::rand(a_shape, &mut rng);
let b = Tensor::rand(b_shape, &mut rng);
let mut expected = Tensor::zeros(out_shape);

reference_matmul(expected.view_mut(), a.view(), b.view());
let result = matmul(a.view(), b.view()).unwrap();
expect_equal(&result, &expected)?;
}

Ok(())
}

#[test]
fn test_matmul_invalid() -> Result<(), Box<dyn Error>> {
struct Case<'a> {
a_shape: &'a [usize],
b_shape: &'a [usize],
error: OpError,
}

let cases = [
Case {
a_shape: &[3],
b_shape: &[10, 8],
error: OpError::InvalidValue("Inputs must have >= 2 dimensions"),
},
Case {
a_shape: &[3, 10],
b_shape: &[10],
error: OpError::InvalidValue("Inputs must have >= 2 dimensions"),
},
Case {
a_shape: &[3, 10],
b_shape: &[11, 8],
error: OpError::IncompatibleInputShapes(
"Columns of first matrix does not match rows of second matrix",
),
},
Case {
a_shape: &[2, 3, 10],
b_shape: &[3, 10, 8],
error: OpError::IncompatibleInputShapes("Cannot broadcast shapes"),
},
];

for Case {
a_shape,
b_shape,
error,
} in cases
{
let mut rng = XorShiftRng::new(1234);
let a = Tensor::rand(a_shape, &mut rng);
let b = Tensor::rand(b_shape, &mut rng);

let result = matmul(a.view(), b.view());
assert_eq!(result, Err(error));
}

Ok(())
}
@@ -327,42 +479,54 @@ mod tests {
}

#[test]
fn test_matmul_broadcast() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
let mut a = Tensor::rand(&[3, 10], &mut rng);
let mut b = Tensor::rand(&[10, 8], &mut rng);

let mut expected = Tensor::zeros(&[3, 8]);
gemm_tensors(&mut expected, &a, &b, 1., 1.);
expected.reshape(&[1, 1, 3, 8]);

// LHS input has excess 1 dims
a.reshape(&[1, 1, 3, 10]);
let result = matmul(a.view(), b.view()).unwrap();
expect_equal(&result, &expected)?;
#[ignore]
fn bench_matmul() {
struct Case {
a_batch: usize,
a_rows: usize,
a_cols: usize,
b_cols: usize,
}

// RHS input has excess 1 dims
a.reshape(&[3, 10]);
b.reshape(&[1, 1, 10, 8]);
let result = matmul(a.view(), b.view()).unwrap();
expect_equal(&result, &expected)?;
let mut cases = Vec::new();
let a_cols = 512;
let b_cols = 1536;

for a_batch in [1, 10, 128, 256, 512, 1024] {
for a_rows in [1, 16, 32, 64] {
cases.push(Case {
a_batch,
a_rows,
a_cols,
b_cols,
});
}
}

// RHS input requires broadcasting
let broadcast_a_shape = &[1, 4, 3, 10][..];
let broadcast_expected_shape = &[1, 4, 3, 8][..];
let broadcast_a = a.broadcast(broadcast_a_shape);
let broadcast_expected = expected.broadcast(broadcast_expected_shape);
let result = matmul(broadcast_a, b.view()).unwrap();
expect_equal(&result.view(), &broadcast_expected)?;

// LHS input requires broadcasting
let broadcast_b_shape = &[1, 3, 10, 8][..];
let broadcast_expected_shape = &[1, 3, 3, 8][..];
let broadcast_b = b.broadcast(broadcast_b_shape);
let expected = expected.broadcast(broadcast_expected_shape);
let result = matmul(a.view(), broadcast_b).unwrap();
expect_equal(&result.view(), &expected)?;
for Case {
a_batch,
a_rows,
a_cols,
b_cols,
} in cases
{
let mut rng = XorShiftRng::new(1234);
let a = Tensor::rand(&[a_batch, a_rows, a_cols], &mut rng);
let b = Tensor::rand(&[a_cols, b_cols], &mut rng);

let run_trial = |strategy| {
let trials = 10;
let desc = format!(
"matmul [{a_batch},{a_rows},{a_cols}] x [{a_cols},{b_cols}], strategy={strategy:?}",
);
run_bench(trials, &desc, || {
matmul_impl(a.view(), b.view(), strategy).unwrap();
});
};

Ok(())
run_trial(MatmulStrategy::Batch);
run_trial(MatmulStrategy::Auto);
println!();
}
}
}
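The comment in `matmul_impl` above explains why the `Auto` strategy flattens `[A, M, K] x [K, N]` into a single `[A * M, K] x [K, N]` call when the right-hand side is a plain matrix. The equivalence it relies on can be checked with a small standalone sketch; this uses plain `Vec<f32>` and naive loops rather than the crate's `GemmExecutor`, and every name in it is illustrative only:

```rust
// Hypothetical standalone sketch (not part of this PR): shows that the batched
// product `[A, M, K] x [K, N]` equals a single `[A * M, K] x [K, N]` product
// whose output is viewed as `[A, M, N]`.

// Naive row-major matmul: `a` is `[m, k]`, `b` is `[k, n]`, returns `[m, n]`.
fn matmul_naive(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            out[i * n + j] = acc;
        }
    }
    out
}

fn main() {
    let (batch, m, k, n) = (3, 2, 4, 5);
    let a: Vec<f32> = (0..batch * m * k).map(|x| x as f32 * 0.1).collect();
    let b: Vec<f32> = (0..k * n).map(|x| x as f32 * 0.01).collect();

    // Batched: one `[m, k] x [k, n]` product per batch entry, concatenated.
    let batched: Vec<f32> = (0..batch)
        .flat_map(|i| matmul_naive(&a[i * m * k..(i + 1) * m * k], &b, m, k, n))
        .collect();

    // Flattened: a single `[batch * m, k] x [k, n]` product over the same data.
    let flat = matmul_naive(&a, &b, batch * m, k, n);

    assert_eq!(batched, flat);
}
```

Since `bench_matmul` above is marked `#[ignore]`, it presumably has to be requested explicitly, e.g. with something like `cargo test --release -- --ignored --nocapture bench_matmul`, to compare the timings reported for the `Batch` and `Auto` strategies.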
src/test_util.rs (1 addition, 1 deletion)
@@ -7,7 +7,7 @@ pub fn run_bench<F: FnMut()>(trials: usize, description: &str, mut f: F) {
return;
}

let mut times = Vec::new();
let mut times = Vec::with_capacity(trials);
for _ in 0..trials {
let mut t = Timer::new();
t.start();