diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 7b91a261c7e1..bb9b01b13aa7 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -68,6 +68,8 @@ where LT: ArrowNumericType, RT: ArrowNumericType, F: Fn(LT::Native, RT::Native) -> LT::Native, + LT::Native: ArrowNativeTypeOp, + RT::Native: ArrowNativeTypeOp, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -78,6 +80,24 @@ where Ok(binary(left, right, op)) } +/// This is similar to `math_op` as it performs given operation between two input primitive arrays. +/// But the given operation can return `Err` if overflow is detected. For the case, this function +/// returns an `Err`. +fn math_checked_op( + left: &PrimitiveArray, + right: &PrimitiveArray, + op: F, +) -> Result> +where + LT: ArrowNumericType, + RT: ArrowNumericType, + F: Fn(LT::Native, RT::Native) -> Result, + LT::Native: ArrowNativeTypeOp, + RT::Native: ArrowNativeTypeOp, +{ + try_binary(left, right, op) +} + /// Helper function for operations where a valid `0` on the right array should /// result in an [ArrowError::DivideByZero], namely the division and modulo operations /// @@ -522,57 +542,64 @@ macro_rules! typed_dict_math_op { }}; } -/// Helper function to perform math lambda function on values from two dictionary arrays, this -/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) -macro_rules! math_dict_op { - ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError(format!( - "Cannot perform operation on arrays of different length ({}, {})", - $left.len(), - $right.len() - ))); - } +/// Perform given operation on two `DictionaryArray`s. +/// Returns an error if the two arrays have different value type +fn math_op_dict( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result> +where + K: ArrowNumericType, + T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> T::Native, + T::Native: ArrowNativeTypeOp, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError(format!( + "Cannot perform operation on arrays of different length ({}, {})", + left.len(), + right.len() + ))); + } - // Safety justification: Since the inputs are valid Arrow arrays, all values are - // valid indexes into the dictionary (which is verified during construction) - - let left_iter = unsafe { - $left - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($left.keys_iter()) - }; - - let right_iter = unsafe { - $right - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($right.keys_iter()) - }; - - let result = left_iter - .zip(right_iter) - .map(|(left_value, right_value)| { - if let (Some(left), Some(right)) = (left_value, right_value) { - Some($op(left, right)) - } else { - None - } - }) - .collect(); + // Safety justification: Since the inputs are valid Arrow arrays, all values are + // valid indexes into the dictionary (which is verified during construction) - Ok(result) - }}; + let left_iter = unsafe { + left.values() + .as_any() + .downcast_ref::>() + .unwrap() + .take_iter_unchecked(left.keys_iter()) + }; + + let right_iter = unsafe { + right + .values() + .as_any() + .downcast_ref::>() + .unwrap() + .take_iter_unchecked(right.keys_iter()) + }; + + let result = left_iter + .zip(right_iter) + .map(|(left_value, right_value)| { + if let (Some(left), Some(right)) = (left_value, right_value) { + Some(op(left, right)) + } else { + None + } + }) + .collect(); + + Ok(result) } /// Perform given operation on two `DictionaryArray`s. /// Returns an error if the two arrays have different value type -fn math_op_dict( +fn math_checked_op_dict( left: &DictionaryArray, right: &DictionaryArray, op: F, @@ -580,9 +607,21 @@ fn math_op_dict( where K: ArrowNumericType, T: ArrowNumericType, - F: Fn(T::Native, T::Native) -> T::Native, + F: Fn(T::Native, T::Native) -> Result, + T::Native: ArrowNativeTypeOp, { - math_dict_op!(left, right, op, PrimitiveArray) + // left and right's value types are supposed to be same as guaranteed by the caller macro now. + if left.value_type() != T::DATA_TYPE { + return Err(ArrowError::NotYetImplemented(format!( + "Cannot perform provided operation on dictionary array of value type {}", + left.value_type() + ))); + } + + let left = left.downcast_dict::>().unwrap(); + let right = right.downcast_dict::>().unwrap(); + + try_binary(left, right, op) } /// Helper function for operations where a valid `0` on the right array should @@ -678,10 +717,13 @@ where /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `add_dyn_checked` instead. pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a + b, math_op_dict) + typed_dict_math_op!(left, right, |a, b| a.add_wrapping(b), math_op_dict) } DataType::Date32 => { let l = as_primitive_array::(left); @@ -734,7 +776,84 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { _ => { downcast_primitive_array!( (left, right) => { - math_op(left, right, |a, b| a + b).map(|a| Arc::new(a) as ArrayRef) + math_op(left, right, |a, b| a.add_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + +/// Perform `left + right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `add_dyn` instead. +pub fn add_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| a.add_checked(b), + math_checked_op_dict + ) + } + DataType::Date32 => { + let l = as_primitive_array::(left); + match right.data_type() { + DataType::Interval(IntervalUnit::YearMonth) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date32Type::add_year_months)?; + Ok(Arc::new(res)) + } + DataType::Interval(IntervalUnit::DayTime) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date32Type::add_day_time)?; + Ok(Arc::new(res)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date32Type::add_month_day_nano)?; + Ok(Arc::new(res)) + } + _ => Err(ArrowError::CastError(format!( + "Cannot perform arithmetic operation between array of type {} and array of type {}", + left.data_type(), right.data_type() + ))), + } + } + DataType::Date64 => { + let l = as_primitive_array::(left); + match right.data_type() { + DataType::Interval(IntervalUnit::YearMonth) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date64Type::add_year_months)?; + Ok(Arc::new(res)) + } + DataType::Interval(IntervalUnit::DayTime) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date64Type::add_day_time)?; + Ok(Arc::new(res)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date64Type::add_month_day_nano)?; + Ok(Arc::new(res)) + } + _ => Err(ArrowError::CastError(format!( + "Cannot perform arithmetic operation between array of type {} and array of type {}", + left.data_type(), right.data_type() + ))), + } + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_checked_op(left, right, |a, b| a.add_checked(b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", @@ -845,15 +964,47 @@ where /// Perform `left - right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `subtract_dyn_checked` instead. pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a - b, math_op_dict) + typed_dict_math_op!(left, right, |a, b| a.sub_wrapping(b), math_op_dict) } _ => { downcast_primitive_array!( (left, right) => { - math_op(left, right, |a, b| a - b).map(|a| Arc::new(a) as ArrayRef) + math_op(left, right, |a, b| a.sub_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + +/// Perform `left - right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `subtract_dyn` instead. +pub fn subtract_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| a.sub_checked(b), + math_checked_op_dict + ) + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_checked_op(left, right, |a, b| a.sub_checked(b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", @@ -983,15 +1134,47 @@ where /// Perform `left * right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `multiply_dyn_checked` instead. pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a * b, math_op_dict) + typed_dict_math_op!(left, right, |a, b| a.mul_wrapping(b), math_op_dict) } _ => { downcast_primitive_array!( (left, right) => { - math_op(left, right, |a, b| a * b).map(|a| Arc::new(a) as ArrayRef) + math_op(left, right, |a, b| a.mul_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + +/// Perform `left * right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `multiply_dyn` instead. +pub fn multiply_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| a.mul_checked(b), + math_checked_op_dict + ) + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_checked_op(left, right, |a, b| a.mul_checked(b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", @@ -1140,7 +1323,52 @@ where /// Perform `left / right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `divide_dyn_checked` instead. pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| { + if b.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(a.div_wrapping(b)) + } + }, + math_divide_checked_op_dict + ) + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_checked_divide_op(left, right, |a, b| { + if b.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(a.div_wrapping(b)) + } + }).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + +/// Perform `left / right` operation on two arrays. If either left or right value is null +/// then the result is also null. If any right hand value is zero then the result of this +/// operation will be `Err(ArrowError::DivideByZero)`. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `divide_dyn` instead. +pub fn divide_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { typed_dict_math_op!( @@ -2363,4 +2591,140 @@ mod tests { let expected = Int32Array::from(vec![None]); assert_eq!(expected, overflow.unwrap()); } + + #[test] + fn test_primitive_add_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + + let wrapped = add_dyn(&a, &b).unwrap(); + let expected = + Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = add_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_dictionary_add_dyn_wrapping_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(2, 2); + builder.append(i32::MAX).unwrap(); + builder.append(i32::MIN).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(2, 2); + builder.append(1).unwrap(); + builder.append(1).unwrap(); + let b = builder.finish(); + + let wrapped = add_dyn(&a, &b).unwrap(); + let expected = + Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = add_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_subtract_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![-2]); + let b = Int32Array::from(vec![i32::MAX]); + + let wrapped = subtract_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = subtract_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_dictionary_subtract_dyn_wrapping_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(-2).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(i32::MAX).unwrap(); + let b = builder.finish(); + + let wrapped = subtract_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = subtract_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_mul_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![10]); + let b = Int32Array::from(vec![i32::MAX]); + + let wrapped = multiply_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = multiply_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_dictionary_mul_dyn_wrapping_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(10).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(i32::MAX).unwrap(); + let b = builder.finish(); + + let wrapped = multiply_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = multiply_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_div_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MIN]); + let b = Int32Array::from(vec![-1]); + + let wrapped = divide_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = divide_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_dictionary_div_dyn_wrapping_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(i32::MIN).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(-1).unwrap(); + let b = builder.finish(); + + let wrapped = divide_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = divide_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 216e3bfcac30..7174de0ce44e 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -18,7 +18,8 @@ //! Defines kernels suitable to perform operations to primitive arrays. use crate::array::{ - Array, ArrayData, ArrayIter, ArrayRef, BufferBuilder, DictionaryArray, PrimitiveArray, + Array, ArrayAccessor, ArrayData, ArrayIter, ArrayRef, BufferBuilder, DictionaryArray, + PrimitiveArray, }; use crate::buffer::Buffer; use crate::compute::util::combine_option_bitmap; @@ -26,6 +27,7 @@ use crate::datatypes::{ArrowNumericType, ArrowPrimitiveType}; use crate::downcast_dictionary_array; use crate::error::{ArrowError, Result}; use crate::util::bit_iterator::try_for_each_valid_idx; +use arrow_buffer::MutableBuffer; use std::sync::Arc; #[inline] @@ -276,16 +278,14 @@ where /// /// Return an error if the arrays have different lengths or /// the operation is under erroneous -pub fn try_binary( - a: &PrimitiveArray, - b: &PrimitiveArray, +pub fn try_binary( + a: A, + b: B, op: F, ) -> Result> where - A: ArrowPrimitiveType, - B: ArrowPrimitiveType, O: ArrowPrimitiveType, - F: Fn(A::Native, B::Native) -> Result, + F: Fn(A::Item, B::Item) -> Result, { if a.len() != b.len() { return Err(ArrowError::ComputeError( @@ -298,36 +298,52 @@ where let len = a.len(); if a.null_count() == 0 && b.null_count() == 0 { - let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r)); - let buffer = unsafe { Buffer::try_from_trusted_len_iter(values) }?; - // JUSTIFICATION - // Benefit - // ~75% speedup - // Soundness - // `values` is an iterator with a known size from a PrimitiveArray - return Ok(unsafe { build_primitive_array(len, buffer, 0, None) }); + try_binary_no_nulls(len, a, b, op) + } else { + let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); + + let null_count = null_buffer + .as_ref() + .map(|x| len - x.count_set_bits()) + .unwrap_or_default(); + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { + unsafe { + *slice.get_unchecked_mut(idx) = + op(a.value_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + })?; + + Ok(unsafe { + build_primitive_array(len, buffer.finish(), null_count, null_buffer) + }) } +} - let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); - - let null_count = null_buffer - .as_ref() - .map(|x| len - x.count_set_bits()) - .unwrap_or_default(); - - let mut buffer = BufferBuilder::::new(len); - buffer.append_n_zeroed(len); - let slice = buffer.as_slice_mut(); - - try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { +/// This intentional inline(never) attribute helps LLVM optimize the loop. +#[inline(never)] +fn try_binary_no_nulls( + len: usize, + a: A, + b: B, + op: F, +) -> Result> +where + O: ArrowPrimitiveType, + F: Fn(A::Item, B::Item) -> Result, +{ + let mut buffer = MutableBuffer::new(len * O::get_byte_width()); + for idx in 0..len { unsafe { - *slice.get_unchecked_mut(idx) = - op(a.value_unchecked(idx), b.value_unchecked(idx))? + buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?); }; - Ok::<_, ArrowError>(()) - })?; - - Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) }) + } + Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) }) } /// Applies the provided binary operation across `a` and `b`, collecting the optional results