From 5562fd239dc7f92331330715c7304638c05c5eb3 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sun, 4 Feb 2024 09:09:40 +0000
Subject: [PATCH] Add fast path for `TensorBase::map` for contiguous tensors

This fast path already existed for `TensorBase::apply`. In a simple
transformer model with ~25M params, it reduced the total time spent in 9
`Pow` operations, used to square inputs with 128*512 elements, from ~1.2ms
to ~0.2ms.
---
 rten-tensor/src/tensor.rs     | 25 ++++++++++++++++++++++---
 src/ops/binary_elementwise.rs |  4 ++--
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs
index 1577b833..a5a39a4c 100644
--- a/rten-tensor/src/tensor.rs
+++ b/rten-tensor/src/tensor.rs
@@ -385,8 +385,9 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>, L: MutLayout> TensorBase<T, S, L> {
     /// Replace each element in this tensor with the result of applying `f` to
     /// the element.
     pub fn apply<F: Fn(&T) -> T>(&mut self, f: F) {
-        if self.is_contiguous() {
-            self.data.as_mut().iter_mut().for_each(|x| *x = f(x));
+        if let Some(data) = self.data_mut() {
+            // Fast path for contiguous tensors.
+            data.iter_mut().for_each(|x| *x = f(x));
         } else {
             self.iter_mut().for_each(|x| *x = f(x));
         }
@@ -1086,7 +1087,12 @@ impl<T, S: AsRef<[T]>, L: MutLayout + Clone> AsView for TensorBase<T, S, L> {
     where
         F: Fn(&Self::Elem) -> U,
     {
-        let data: Vec<_> = self.iter().map(f).collect();
+        let data: Vec<_> = if let Some(data) = self.data() {
+            // Fast path for contiguous tensors.
+            data.iter().map(f).collect()
+        } else {
+            self.iter().map(f).collect()
+        };
         TensorBase::from_data(self.shape(), data)
     }
 
@@ -1545,9 +1551,16 @@ mod tests {
     #[test]
     fn test_apply() {
         let data = vec![1., 2., 3., 4.];
+
+        // Contiguous tensor.
         let mut tensor = NdTensor::from_data([2, 2], data);
         tensor.apply(|x| *x * 2.);
         assert_eq!(tensor.to_vec(), &[2., 4., 6., 8.]);
+
+        // Non-contiguous tensor.
+        tensor.transpose();
+        tensor.apply(|x| *x / 2.);
+        assert_eq!(tensor.to_vec(), &[1., 3., 2., 4.]);
     }
 
     #[test]
@@ -2198,8 +2211,14 @@ mod tests {
     fn test_map() {
         let data = vec![1., 2., 3., 4.];
         let tensor = NdTensor::from_data([2, 2], data);
+
+        // Contiguous tensor.
         let doubled = tensor.map(|x| x * 2.);
         assert_eq!(doubled.to_vec(), &[2., 4., 6., 8.]);
+
+        // Non-contiguous tensor.
+        let halved = doubled.transposed().map(|x| x / 2.);
+        assert_eq!(halved.to_vec(), &[1., 3., 2., 4.]);
     }
 
     #[test]
diff --git a/src/ops/binary_elementwise.rs b/src/ops/binary_elementwise.rs
index 387ae9d3..5c44ab14 100644
--- a/src/ops/binary_elementwise.rs
+++ b/src/ops/binary_elementwise.rs
@@ -657,8 +657,8 @@ fn powf(x: f32, y: f32) -> f32 {
 
 /// Raise elements of `a` to powers of corresponding elements in `b`.
 pub fn pow(a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
-    if let Some(exp) = b.item() {
-        Ok(a.map(|x| powf(*x, *exp)))
+    if let Some(&exp) = b.item() {
+        Ok(a.map(|x| powf(*x, exp)))
     } else {
         binary_op(a, b, |x, y| x.powf(y))
     }
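
Illustrative sketch (not part of the patch): the snippet below shows the usage
pattern this fast path targets, assuming `NdTensor` and the `AsView` trait can
be imported from the `rten_tensor` crate root (the exact import paths are an
assumption; they may live behind a prelude). Since `map` returns a contiguous
tensor, the second call goes through a transposed, non-contiguous view and
exercises the strided fallback instead.

    use rten_tensor::{AsView, NdTensor};

    fn main() {
        // Contiguous tensor: `map` can iterate directly over the underlying
        // data slice (the fast path added by this patch).
        let squares = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]).map(|x| x * x);
        assert_eq!(squares.to_vec(), &[1., 4., 9., 16.]);

        // Transposed view: not contiguous, so `map` falls back to the strided
        // element iterator.
        let transposed_squares = squares.transposed().map(|x| x * x);
        assert_eq!(transposed_squares.to_vec(), &[1., 81., 16., 256.]);
    }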
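
A second minimal sketch, in plain standard Rust with a hypothetical `scale`
helper (not rten code), of the `Some(&exp)` binding adopted in `pow`: matching
through the reference copies the `f32` out of the `Option<&f32>`, so the body
can use `exp` by value instead of dereferencing it.

    fn scale(factor: Option<&f32>, x: f32) -> f32 {
        // `Some(&exp)` matches an `Option<&f32>` and binds `exp: f32` by copy,
        // so no `*exp` is needed below.
        if let Some(&exp) = factor {
            x * exp
        } else {
            x
        }
    }

    fn main() {
        assert_eq!(scale(Some(&2.0), 3.0), 6.0);
        assert_eq!(scale(None, 3.0), 3.0);
    }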