diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs
index b978a1ab..bd6aa354 100644
--- a/src/ops/matmul.rs
+++ b/src/ops/matmul.rs
@@ -87,7 +87,24 @@ impl Operator for Gemm {
     }
 }
 
+/// Hints for how a batched MatMul should be performed. This exists to enable
+/// comparisons in tests and benchmarks.
+#[derive(Copy, Clone, Debug, PartialEq)]
+enum MatmulStrategy {
+    /// Use the best strategy for the input shapes.
+    Auto,
+
+    /// Perform separate GEMM calls for each pair of matrices to multiply in
+    /// the batch.
+    #[cfg(test)]
+    Batch,
+}
+
 pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
+    matmul_impl(a, b, MatmulStrategy::Auto)
+}
+
+fn matmul_impl(a: TensorView, b: TensorView, strategy: MatmulStrategy) -> Result<Tensor, OpError> {
     if a.ndim() < 2 || b.ndim() < 2 {
         return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
     }
@@ -106,10 +123,31 @@ pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
 
     let a_prefix = &a.shape()[..a.ndim() - 2];
     let b_prefix = &b.shape()[..b.ndim() - 2];
+
+    let num_a_matrices: usize = a_prefix.iter().product();
+    let num_b_matrices: usize = b_prefix.iter().product();
+
     let out_prefix = broadcast_shapes(a_prefix, b_prefix)
         .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?;
     let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat();
+
+    // A batched matrix multiplication with `[A, M, K] x [K, N]`, where `A`
+    // can consist of multiple dimensions, can be converted to a non-batched
+    // matmul by reshaping the inputs as `[A * M, K] x [K, N]`, and then
+    // reshaping the `[A * M, N]` output to `[A, M, N]`.
+    //
+    // The upside is that one larger matmul is likely to be more efficient than
+    // `A` smaller matmuls. This is especially true if `M` is small (eg. 1).
+    if strategy == MatmulStrategy::Auto && a.ndim() > 2 && b.ndim() == 2 {
+        // nb. We assume `a` is likely already contiguous, so this will be cheap.
+        let a_contig = a.to_contiguous();
+        let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
+        let mut output = matmul(a_matrix, b.clone())?;
+        output.reshape(out_shape);
+        return Ok(output);
+    }
+
     let mut output = Tensor::zeros(out_shape);
     if output.is_empty() {
         return Ok(output);
     }
@@ -128,9 +164,6 @@ pub fn matmul(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
         .unwrap()
         .chunks_mut(out_row_stride * a_rows);
 
-    let num_a_matrices: usize = a_prefix.iter().product();
-    let num_b_matrices: usize = b_prefix.iter().product();
-
     let gemm = GemmExecutor::new();
 
     // Prepack re-used inputs to amortize packing cost.
@@ -196,10 +229,12 @@ mod tests {
     use rten_tensor::prelude::*;
     use rten_tensor::rng::XorShiftRng;
     use rten_tensor::test_util::expect_equal;
-    use rten_tensor::Tensor;
+    use rten_tensor::{Tensor, TensorView, TensorViewMut};
 
     use crate::gemm::gemm;
-    use crate::ops::matmul::{gemm_op, matmul, OpError};
+    use crate::test_util::run_bench;
+
+    use super::{gemm_op, matmul, matmul_impl, MatmulStrategy, OpError};
 
     fn gemm_tensors(c: &mut Tensor, a: &Tensor, b: &Tensor, alpha: f32, beta: f32) {
         c.make_contiguous();
@@ -214,6 +249,34 @@ mod tests {
         )
     }
 
+    /// Multiply matrices in `a` by corresponding matrices in `b` and write to
+    /// `c`. The shapes of `a` and `b` are broadcast so that their first N-2
+    /// dims match `c`.
+    fn reference_matmul(mut c: TensorViewMut, a: TensorView, b: TensorView) {
+        let a_batch_dims = a.ndim() - 2;
+        let b_batch_dims = b.ndim() - 2;
+        let out_prefix = &c.shape()[..c.ndim() - 2];
+
+        let a_bcast = [out_prefix, &a.shape()[a_batch_dims..]].concat();
+        let b_bcast = [out_prefix, &b.shape()[b_batch_dims..]].concat();
+
+        a.broadcast(a_bcast.as_slice())
+            .inner_iter::<2>()
+            .zip(b.broadcast(b_bcast.as_slice()).inner_iter::<2>())
+            .zip(c.inner_iter_mut::<2>())
+            .for_each(|((a, b), mut c)| {
+                let c_row_stride = c.stride(0);
+                gemm(
+                    c.data_mut().unwrap(),
+                    c_row_stride,
+                    a,
+                    b,
+                    1., /* alpha */
+                    0., /* beta */
+                )
+            });
+    }
+
     #[test]
     fn test_gemm_op() -> Result<(), Box<dyn Error>> {
         let mut rng = XorShiftRng::new(1234);
@@ -286,15 +349,104 @@ mod tests {
 
     #[test]
     fn test_matmul() -> Result<(), Box<dyn Error>> {
-        let mut rng = XorShiftRng::new(1234);
-        let a = Tensor::rand(&[3, 10], &mut rng);
-        let b = Tensor::rand(&[10, 8], &mut rng);
+        struct Case<'a> {
+            a_shape: &'a [usize],
+            b_shape: &'a [usize],
+            out_shape: &'a [usize],
+        }
 
-        let mut expected = Tensor::zeros(&[3, 8]);
-        gemm_tensors(&mut expected, &a, &b, 1., 1.);
+        let cases = [
+            // Simple matmul
+            Case {
+                a_shape: &[3, 10],
+                b_shape: &[10, 8],
+                out_shape: &[3, 8],
+            },
+            // LHS input is a batch
+            Case {
+                a_shape: &[2, 3, 10],
+                b_shape: &[10, 8],
+                out_shape: &[2, 3, 8],
+            },
+            // RHS input is a batch
+            Case {
+                a_shape: &[3, 10],
+                b_shape: &[2, 10, 8],
+                out_shape: &[2, 3, 8],
+            },
+            // Both inputs are batches
+            Case {
+                a_shape: &[2, 3, 10],
+                b_shape: &[2, 10, 8],
+                out_shape: &[2, 3, 8],
+            },
+        ];
 
-        let result = matmul(a.view(), b.view()).unwrap();
-        expect_equal(&result, &expected)?;
+        for Case {
+            a_shape,
+            b_shape,
+            out_shape,
+        } in cases
+        {
+            let mut rng = XorShiftRng::new(1234);
+            let a = Tensor::rand(a_shape, &mut rng);
+            let b = Tensor::rand(b_shape, &mut rng);
+            let mut expected = Tensor::zeros(out_shape);
+
+            reference_matmul(expected.view_mut(), a.view(), b.view());
+            let result = matmul(a.view(), b.view()).unwrap();
+            expect_equal(&result, &expected)?;
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_matmul_invalid() -> Result<(), Box<dyn Error>> {
+        struct Case<'a> {
+            a_shape: &'a [usize],
+            b_shape: &'a [usize],
+            error: OpError,
+        }
+
+        let cases = [
+            Case {
+                a_shape: &[3],
+                b_shape: &[10, 8],
+                error: OpError::InvalidValue("Inputs must have >= 2 dimensions"),
+            },
+            Case {
+                a_shape: &[3, 10],
+                b_shape: &[10],
+                error: OpError::InvalidValue("Inputs must have >= 2 dimensions"),
+            },
+            Case {
+                a_shape: &[3, 10],
+                b_shape: &[11, 8],
+                error: OpError::IncompatibleInputShapes(
+                    "Columns of first matrix does not match rows of second matrix",
+                ),
+            },
+            Case {
+                a_shape: &[2, 3, 10],
+                b_shape: &[3, 10, 8],
+                error: OpError::IncompatibleInputShapes("Cannot broadcast shapes"),
+            },
+        ];
+
+        for Case {
+            a_shape,
+            b_shape,
+            error,
+        } in cases
+        {
+            let mut rng = XorShiftRng::new(1234);
+            let a = Tensor::rand(a_shape, &mut rng);
+            let b = Tensor::rand(b_shape, &mut rng);
+
+            let result = matmul(a.view(), b.view());
+            assert_eq!(result, Err(error));
+        }
 
         Ok(())
     }
@@ -327,42 +479,54 @@ mod tests {
     }
 
     #[test]
-    fn test_matmul_broadcast() -> Result<(), Box<dyn Error>> {
-        let mut rng = XorShiftRng::new(1234);
-        let mut a = Tensor::rand(&[3, 10], &mut rng);
-        let mut b = Tensor::rand(&[10, 8], &mut rng);
-
-        let mut expected = Tensor::zeros(&[3, 8]);
-        gemm_tensors(&mut expected, &a, &b, 1., 1.);
-        expected.reshape(&[1, 1, 3, 8]);
-
-        // LHS input has excess 1 dims
-        a.reshape(&[1, 1, 3, 10]);
-        let result = matmul(a.view(), b.view()).unwrap();
-        expect_equal(&result, &expected)?;
+    #[ignore]
+    fn bench_matmul() {
+        struct Case {
+            a_batch: usize,
+            a_rows: usize,
+            a_cols: usize,
+            b_cols: usize,
+        }
 
-        // RHS input has excess 1 dims
-        a.reshape(&[3, 10]);
-        b.reshape(&[1, 1, 10, 8]);
-        let result = matmul(a.view(), b.view()).unwrap();
-        expect_equal(&result, &expected)?;
+        let mut cases = Vec::new();
+        let a_cols = 512;
+        let b_cols = 1536;
+
+        for a_batch in [1, 10, 128, 256, 512, 1024] {
+            for a_rows in [1, 16, 32, 64] {
+                cases.push(Case {
+                    a_batch,
+                    a_rows,
+                    a_cols,
+                    b_cols,
+                });
+            }
+        }
 
-        // RHS input requires broadcasting
-        let broadcast_a_shape = &[1, 4, 3, 10][..];
-        let broadcast_expected_shape = &[1, 4, 3, 8][..];
-        let broadcast_a = a.broadcast(broadcast_a_shape);
-        let broadcast_expected = expected.broadcast(broadcast_expected_shape);
-        let result = matmul(broadcast_a, b.view()).unwrap();
-        expect_equal(&result.view(), &broadcast_expected)?;
-
-        // LHS input requires broadcasting
-        let broadcast_b_shape = &[1, 3, 10, 8][..];
-        let broadcast_expected_shape = &[1, 3, 3, 8][..];
-        let broadcast_b = b.broadcast(broadcast_b_shape);
-        let expected = expected.broadcast(broadcast_expected_shape);
-        let result = matmul(a.view(), broadcast_b).unwrap();
-        expect_equal(&result.view(), &expected)?;
+        for Case {
+            a_batch,
+            a_rows,
+            a_cols,
+            b_cols,
+        } in cases
+        {
+            let mut rng = XorShiftRng::new(1234);
+            let a = Tensor::rand(&[a_batch, a_rows, a_cols], &mut rng);
+            let b = Tensor::rand(&[a_cols, b_cols], &mut rng);
+
+            let run_trial = |strategy| {
+                let trials = 10;
+                let desc = format!(
+                    "matmul [{a_batch},{a_rows},{a_cols}] x [{a_cols},{b_cols}], strategy={strategy:?}",
+                );
+                run_bench(trials, &desc, || {
+                    matmul_impl(a.view(), b.view(), strategy).unwrap();
+                });
+            };
 
-        Ok(())
+            run_trial(MatmulStrategy::Batch);
+            run_trial(MatmulStrategy::Auto);
+            println!();
+        }
     }
 }
diff --git a/src/test_util.rs b/src/test_util.rs
index 3b85058c..8fd56f31 100644
--- a/src/test_util.rs
+++ b/src/test_util.rs
@@ -7,7 +7,7 @@ pub fn run_bench<F: FnMut()>(trials: usize, description: &str, mut f: F) {
         return;
     }
 
-    let mut times = Vec::new();
+    let mut times = Vec::with_capacity(trials);
     for _ in 0..trials {
         let mut t = Timer::new();
         t.start();
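
Note on the reshape optimization in `matmul_impl` (an illustrative sketch, not
part of the patch): the equivalence the new comment describes can be checked
with a plain row-major matmul. The `naive_matmul` helper below is hypothetical
and unrelated to rten's GEMM; it only shows that, for a contiguous row-major
`[A, M, K]` input, batch `i` occupies rows `i * M .. (i + 1) * M` of the
flattened `[A * M, K]` matrix, so one large matmul against `[K, N]` produces
exactly the rows that `A` separate per-batch matmuls would.

    /// Naive row-major matmul: out[m, n] = sum over k of a[m, k] * b[k, n].
    fn naive_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut out = vec![0.; m * n];
        for row in 0..m {
            for col in 0..n {
                for i in 0..k {
                    out[row * n + col] += a[row * k + i] * b[i * n + col];
                }
            }
        }
        out
    }

    fn main() {
        let (batch, m, k, n) = (4, 3, 5, 2);
        let a: Vec<f32> = (0..batch * m * k).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..k * n).map(|i| i as f32 * 0.5).collect();

        // `Batch` strategy: one matmul per `[M, K]` matrix in `a`.
        let per_batch: Vec<f32> = a
            .chunks(m * k)
            .flat_map(|a_mat| naive_matmul(a_mat, &b, m, k, n))
            .collect();

        // `Auto` strategy for this case: treat `a` as a single `[A * M, K]`
        // matrix. No data movement is needed since `a` is contiguous.
        let single = naive_matmul(&a, &b, batch * m, k, n);

        assert_eq!(per_batch, single);
    }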
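Note on batch broadcasting (also illustrative): `reference_matmul` and
`matmul_impl` rely on the crate's `broadcast_shapes` to combine the inputs'
leading (batch) dims. The sketch below is a hypothetical stand-in for that
function, shown only to make the test cases concrete: dims are right-aligned,
and each pair must be equal or contain a 1.

    fn broadcast_prefixes(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
        let len = a.len().max(b.len());
        // Right-align the shorter prefix, padding it with leading 1s.
        let dim = |s: &[usize], i: usize| {
            let pad = len - s.len();
            if i < pad { 1 } else { s[i - pad] }
        };
        (0..len)
            .map(|i| match (dim(a, i), dim(b, i)) {
                (x, y) if x == y => Some(x),
                (1, y) => Some(y),
                (x, 1) => Some(x),
                _ => None,
            })
            .collect()
    }

    fn main() {
        // `[2, 3, 10] x [10, 8]` (batch prefix `[2]` vs. none): the output
        // prefix is `[2]`, giving the `[2, 3, 8]` shape used in `test_matmul`.
        assert_eq!(broadcast_prefixes(&[2], &[]), Some(vec![2]));
        // `[2, 3, 10] x [3, 10, 8]`: prefixes `[2]` and `[3]` do not combine,
        // the `Cannot broadcast shapes` case in `test_matmul_invalid`.
        assert_eq!(broadcast_prefixes(&[2], &[3]), None);
    }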