From 0c05baba2b78f2baca05fc55b20cd2a95000ea48 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 14 Sep 2022 09:39:24 -0700 Subject: [PATCH 01/15] Init --- arrow/src/compute/kernels/arithmetic.rs | 156 +++++++++++++++++++++++- arrow/src/datatypes/native.rs | 1 + 2 files changed, 156 insertions(+), 1 deletion(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index a344407e426d..2a94bf0c72ca 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -66,6 +66,8 @@ where LT: ArrowNumericType, RT: ArrowNumericType, F: Fn(LT::Native, RT::Native) -> LT::Native, + LT::Native: ArrowNativeTypeOp, + RT::Native: ArrowNativeTypeOp, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -88,6 +90,8 @@ where LT: ArrowNumericType, RT: ArrowNumericType, F: Fn(LT::Native, RT::Native) -> Option, + LT::Native: ArrowNativeTypeOp, + RT::Native: ArrowNativeTypeOp, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -620,6 +624,59 @@ macro_rules! math_dict_op { }}; } +/// Helper function to perform math lambda function on values from two dictionary arrays, this +/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) +macro_rules! math_dict_checked_op { + ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ + if $left.len() != $right.len() { + return Err(ArrowError::ComputeError(format!( + "Cannot perform operation on arrays of different length ({}, {})", + $left.len(), + $right.len() + ))); + } + + // Safety justification: Since the inputs are valid Arrow arrays, all values are + // valid indexes into the dictionary (which is verified during construction) + + let left_iter = unsafe { + $left + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap() + .take_iter_unchecked($left.keys_iter()) + }; + + let right_iter = unsafe { + $right + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap() + .take_iter_unchecked($right.keys_iter()) + }; + + let result = left_iter + .zip(right_iter) + .map(|(left_value, right_value)| { + if let (Some(left), Some(right)) = (left_value, right_value) { + Some($op(left, right).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?}, {:?}", + left, right + )) + })) + } else { + None + } + }) + .collect(); + + Ok(result) + }}; +} + /// Perform given operation on two `DictionaryArray`s. /// Returns an error if the two arrays have different value type fn math_op_dict( @@ -631,10 +688,27 @@ where K: ArrowNumericType, T: ArrowNumericType, F: Fn(T::Native, T::Native) -> T::Native, + T::Native: ArrowNativeTypeOp, { math_dict_op!(left, right, op, PrimitiveArray) } +/// Perform given operation on two `DictionaryArray`s. +/// Returns an error if the two arrays have different value type +fn math_checked_op_dict( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result> +where + K: ArrowNumericType, + T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> Option, + T::Native: ArrowNativeTypeOp, +{ + math_dict_checked_op!(left, right, op, PrimitiveArray) +} + /// Helper function for operations where a valid `0` on the right array should /// result in an [ArrowError::DivideByZero], namely the division and modulo operations /// @@ -728,6 +802,9 @@ where /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `add_checked_dyn` instead. pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { @@ -784,7 +861,84 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { _ => { downcast_primitive_array!( (left, right) => { - math_op(left, right, |a, b| a + b).map(|a| Arc::new(a) as ArrayRef) + math_op(left, right, |a, b| a.add_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + +/// Perform `left + right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `add_dyn` instead. +pub fn add_checked_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| a.add_checked(b), + math_checked_op_dict + ) + } + DataType::Date32 => { + let l = as_primitive_array::(left); + match right.data_type() { + DataType::Interval(IntervalUnit::YearMonth) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date32Type::add_year_months)?; + Ok(Arc::new(res)) + } + DataType::Interval(IntervalUnit::DayTime) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date32Type::add_day_time)?; + Ok(Arc::new(res)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date32Type::add_month_day_nano)?; + Ok(Arc::new(res)) + } + _ => Err(ArrowError::CastError(format!( + "Cannot perform arithmetic operation between array of type {} and array of type {}", + left.data_type(), right.data_type() + ))), + } + } + DataType::Date64 => { + let l = as_primitive_array::(left); + match right.data_type() { + DataType::Interval(IntervalUnit::YearMonth) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date64Type::add_year_months)?; + Ok(Arc::new(res)) + } + DataType::Interval(IntervalUnit::DayTime) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date64Type::add_day_time)?; + Ok(Arc::new(res)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let r = as_primitive_array::(right); + let res = math_op(l, r, Date64Type::add_month_day_nano)?; + Ok(Arc::new(res)) + } + _ => Err(ArrowError::CastError(format!( + "Cannot perform arithmetic operation between array of type {} and array of type {}", + left.data_type(), right.data_type() + ))), + } + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_checked_op(left, right, |a, b| a.add_checked(b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index 444f2b27dce6..563e6be904da 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -211,6 +211,7 @@ native_type_op!(i8); native_type_op!(i16); native_type_op!(i32); native_type_op!(i64); +native_type_op!(i128); native_type_op!(u8); native_type_op!(u16); native_type_op!(u32); From ac376a1b05c678f5aba02d229864d3ac5204dc97 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 15 Sep 2022 00:06:02 -0700 Subject: [PATCH 02/15] More --- arrow/src/compute/kernels/arithmetic.rs | 164 ++++++++++++++++++------ 1 file changed, 125 insertions(+), 39 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 2a94bf0c72ca..ead830ee0b32 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -639,41 +639,22 @@ macro_rules! math_dict_checked_op { // Safety justification: Since the inputs are valid Arrow arrays, all values are // valid indexes into the dictionary (which is verified during construction) - let left_iter = unsafe { - $left - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($left.keys_iter()) - }; + let left = $left.values().as_any().downcast_ref::<$value_ty>().unwrap(); - let right_iter = unsafe { - $right - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($right.keys_iter()) - }; + let right = $right + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap(); - let result = left_iter - .zip(right_iter) - .map(|(left_value, right_value)| { - if let (Some(left), Some(right)) = (left_value, right_value) { - Some($op(left, right).ok_or_else(|| { - ArrowError::ComputeError(format!( - "Overflow happened on: {:?}, {:?}", - left, right - )) - })) - } else { - None - } + try_binary(left, right, |a, b| { + $op(a, b).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?}, {:?}", + a, b + )) }) - .collect(); - - Ok(result) + }) }}; } @@ -804,11 +785,11 @@ where /// then the result is also null. /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `add_checked_dyn` instead. +/// For an overflow-checking variant, use `add_dyn_checked` instead. pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a + b, math_op_dict) + typed_dict_math_op!(left, right, |a, b| a.add_wrapping(b), math_op_dict) } DataType::Date32 => { let l = as_primitive_array::(left); @@ -877,7 +858,7 @@ pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, /// use `add_dyn` instead. -pub fn add_checked_dyn(left: &dyn Array, right: &dyn Array) -> Result { +pub fn add_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { typed_dict_math_op!( @@ -1030,15 +1011,47 @@ where /// Perform `left - right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `subtract_dyn_checked` instead. pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a - b, math_op_dict) + typed_dict_math_op!(left, right, |a, b| a.sub_wrapping(b), math_op_dict) + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_op(left, right, |a, b| a.sub_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + +/// Perform `left - right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `subtract_dyn` instead. +pub fn subtract_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| a.sub_checked(b), + math_checked_op_dict + ) } _ => { downcast_primitive_array!( (left, right) => { - math_op(left, right, |a, b| a - b).map(|a| Arc::new(a) as ArrayRef) + math_checked_op(left, right, |a, b| a.sub_checked(b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", @@ -1158,15 +1171,47 @@ where /// Perform `left * right` operation on two arrays. If either left or right value is null /// then the result is also null. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `multiply_dyn_checked` instead. pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a * b, math_op_dict) + typed_dict_math_op!(left, right, |a, b| a.mul_wrapping(b), math_op_dict) + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_op(left, right, |a, b| a.mul_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + +/// Perform `left * right` operation on two arrays. If either left or right value is null +/// then the result is also null. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `multiply_dyn` instead. +pub fn multiply_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| a.mul_checked(b), + math_checked_op_dict + ) } _ => { downcast_primitive_array!( (left, right) => { - math_op(left, right, |a, b| a * b).map(|a| Arc::new(a) as ArrayRef) + math_checked_op(left, right, |a, b| a.mul_checked(b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", @@ -1302,6 +1347,33 @@ where /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| a.div_wrapping(b), + math_divide_checked_op_dict + ) + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_checked_divide_op(left, right, |a, b| Some(a.div_wrapping(b))).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + +/// Perform `left / right` operation on two arrays. If either left or right value is null +/// then the result is also null. If any right hand value is zero then the result of this +/// operation will be `Err(ArrowError::DivideByZero)`. +pub fn divide_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { typed_dict_math_op!(left, right, |a, b| a / b, math_divide_checked_op_dict) @@ -2394,4 +2466,18 @@ mod tests { let expected = Int32Array::from(vec![None]); assert_eq!(expected, overflow.unwrap()); } + + #[test] + fn test_primitive_add_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + + let wrapped = add_dyn(&a, &b).unwrap(); + let expected = + Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = add_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } } From 7f570d18e0421b9b53836ca970d93503dc4066c4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 15 Sep 2022 00:22:46 -0700 Subject: [PATCH 03/15] More --- arrow/src/compute/kernels/arithmetic.rs | 30 +++++++++++++++++++------ 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index ead830ee0b32..242560dd7230 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -163,7 +163,7 @@ fn math_checked_divide_op_on_iters( where T: ArrowNumericType, T::Native: One + Zero, - F: Fn(T::Native, T::Native) -> T::Native, + F: Fn(T::Native, T::Native) -> Result, { let buffer = if null_bit_buffer.is_some() { let values = left.zip(right).map(|(left, right)| { @@ -171,7 +171,7 @@ where if r.is_zero() { Err(ArrowError::DivideByZero) } else { - Ok(op(l, r)) + op(l, r) } } else { Ok(T::default_value()) @@ -186,7 +186,7 @@ where if right.is_zero() { Err(ArrowError::DivideByZero) } else { - Ok(op(left, right)) + op(left, right) } }, ); @@ -707,7 +707,7 @@ where K: ArrowNumericType, T: ArrowNumericType, T::Native: One + Zero, - F: Fn(T::Native, T::Native) -> T::Native, + F: Fn(T::Native, T::Native) -> Result, { if left.len() != right.len() { return Err(ArrowError::ComputeError(format!( @@ -1346,13 +1346,16 @@ where /// Perform `left / right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `divide_dyn_checked` instead. pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { typed_dict_math_op!( left, right, - |a, b| a.div_wrapping(b), + |a, b| Ok(a.div_wrapping(b)), math_divide_checked_op_dict ) } @@ -1373,15 +1376,28 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// Perform `left / right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `divide_dyn` instead. pub fn divide_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a / b, math_divide_checked_op_dict) + typed_dict_math_op!( + left, + right, + |a, b| a.div_checked(b).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?}, {:?}", + a, b + )) + }), + math_divide_checked_op_dict + ) } _ => { downcast_primitive_array!( (left, right) => { - math_checked_divide_op(left, right, |a, b| Some(a / b)).map(|a| Arc::new(a) as ArrayRef) + math_checked_divide_op(left, right, |a, b| a.div_checked(b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", From 97c7094b43b209585c6c60dcb2e45bdd9fb03985 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 15 Sep 2022 16:20:30 -0700 Subject: [PATCH 04/15] Add tests --- arrow/src/compute/kernels/arithmetic.rs | 128 +++++++++++++++++++++++- arrow/src/compute/kernels/arity.rs | 13 ++- 2 files changed, 133 insertions(+), 8 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 242560dd7230..540d6b92598c 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -639,13 +639,17 @@ macro_rules! math_dict_checked_op { // Safety justification: Since the inputs are valid Arrow arrays, all values are // valid indexes into the dictionary (which is verified during construction) - let left = $left.values().as_any().downcast_ref::<$value_ty>().unwrap(); + // let left = $left.values().as_any().downcast_ref::<$value_ty>().unwrap(); + let left = $left.downcast_dict::<$value_ty>().unwrap(); + let right = $right.downcast_dict::<$value_ty>().unwrap(); + /* let right = $right .values() .as_any() .downcast_ref::<$value_ty>() .unwrap(); + */ try_binary(left, right, |a, b| { $op(a, b).ok_or_else(|| { @@ -2496,4 +2500,126 @@ mod tests { let overflow = add_dyn_checked(&a, &b); overflow.expect_err("overflow should be detected"); } + + #[test] + fn test_dictionary_add_dyn_wrapping_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(2, 2); + builder.append(i32::MAX).unwrap(); + builder.append(i32::MIN).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(2, 2); + builder.append(1).unwrap(); + builder.append(1).unwrap(); + let b = builder.finish(); + + let wrapped = add_dyn(&a, &b).unwrap(); + let expected = + Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = add_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_subtract_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![-2]); + let b = Int32Array::from(vec![i32::MAX]); + + let wrapped = subtract_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = subtract_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_dictionary_subtract_dyn_wrapping_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(-2).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(i32::MAX).unwrap(); + let b = builder.finish(); + + let wrapped = subtract_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = subtract_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_mul_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![10]); + let b = Int32Array::from(vec![i32::MAX]); + + let wrapped = multiply_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = multiply_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_dictionary_mul_dyn_wrapping_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(10).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(i32::MAX).unwrap(); + let b = builder.finish(); + + let wrapped = multiply_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = multiply_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_primitive_div_dyn_wrapping_overflow() { + let a = Int32Array::from(vec![i32::MIN]); + let b = Int32Array::from(vec![-1]); + + let wrapped = divide_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = divide_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } + + #[test] + fn test_dictionary_div_dyn_wrapping_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(i32::MIN).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(-1).unwrap(); + let b = builder.finish(); + + let wrapped = divide_dyn(&a, &b).unwrap(); + let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; + assert_eq!(&expected, &wrapped); + + let overflow = divide_dyn_checked(&a, &b); + overflow.expect_err("overflow should be detected"); + } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index fffa81af8190..e6d8a00f6c71 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -18,7 +18,8 @@ //! Defines kernels suitable to perform operations to primitive arrays. use crate::array::{ - Array, ArrayData, ArrayIter, ArrayRef, BufferBuilder, DictionaryArray, PrimitiveArray, + Array, ArrayAccessor, ArrayData, ArrayIter, ArrayRef, BufferBuilder, DictionaryArray, + PrimitiveArray, }; use crate::buffer::Buffer; use crate::compute::util::combine_option_bitmap; @@ -218,16 +219,14 @@ where /// # Panic /// /// Panics if the arrays have different lengths -pub fn try_binary( - a: &PrimitiveArray, - b: &PrimitiveArray, +pub fn try_binary( + a: A, + b: B, op: F, ) -> Result> where - A: ArrowPrimitiveType, - B: ArrowPrimitiveType, O: ArrowPrimitiveType, - F: Fn(A::Native, B::Native) -> Result, + F: Fn(A::Item, B::Item) -> Result, { assert_eq!(a.len(), b.len()); let len = a.len(); From 9d526c90ba92cfefc18adaa8fe9b0349cf25235a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 16 Sep 2022 16:56:26 -0700 Subject: [PATCH 05/15] Fix clippy --- arrow/src/compute/kernels/arithmetic.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index e577f263ade5..fd27a038825c 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -101,7 +101,7 @@ where )); } - try_binary(left, right, |a, b| op(a, b)) + try_binary(left, right, op) } /// Helper function for operations where a valid `0` on the right array should From 04066c0dd70a236094dc97c448a0bcf0b982c765 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 16 Sep 2022 17:18:29 -0700 Subject: [PATCH 06/15] Remove macro --- arrow/src/compute/kernels/arithmetic.rs | 124 ++++++++++-------------- 1 file changed, 52 insertions(+), 72 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index fd27a038825c..1bfa59288a8b 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -548,76 +548,6 @@ macro_rules! typed_dict_math_op { }}; } -/// Helper function to perform math lambda function on values from two dictionary arrays, this -/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) -macro_rules! math_dict_op { - ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError(format!( - "Cannot perform operation on arrays of different length ({}, {})", - $left.len(), - $right.len() - ))); - } - - // Safety justification: Since the inputs are valid Arrow arrays, all values are - // valid indexes into the dictionary (which is verified during construction) - - let left_iter = unsafe { - $left - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($left.keys_iter()) - }; - - let right_iter = unsafe { - $right - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($right.keys_iter()) - }; - - let result = left_iter - .zip(right_iter) - .map(|(left_value, right_value)| { - if let (Some(left), Some(right)) = (left_value, right_value) { - Some($op(left, right)) - } else { - None - } - }) - .collect(); - - Ok(result) - }}; -} - -/// Helper function to perform math lambda function on values from two dictionary arrays, this -/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) -macro_rules! math_dict_checked_op { - ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError(format!( - "Cannot perform operation on arrays of different length ({}, {})", - $left.len(), - $right.len() - ))); - } - - // Safety justification: Since the inputs are valid Arrow arrays, all values are - // valid indexes into the dictionary (which is verified during construction) - - let left = $left.downcast_dict::<$value_ty>().unwrap(); - let right = $right.downcast_dict::<$value_ty>().unwrap(); - - try_binary(left, right, |a, b| $op(a, b)) - }}; -} - /// Perform given operation on two `DictionaryArray`s. /// Returns an error if the two arrays have different value type fn math_op_dict( @@ -631,7 +561,46 @@ where F: Fn(T::Native, T::Native) -> T::Native, T::Native: ArrowNativeTypeOp, { - math_dict_op!(left, right, op, PrimitiveArray) + if left.len() != right.len() { + return Err(ArrowError::ComputeError(format!( + "Cannot perform operation on arrays of different length ({}, {})", + left.len(), + right.len() + ))); + } + + // Safety justification: Since the inputs are valid Arrow arrays, all values are + // valid indexes into the dictionary (which is verified during construction) + + let left_iter = unsafe { + left.values() + .as_any() + .downcast_ref::>() + .unwrap() + .take_iter_unchecked(left.keys_iter()) + }; + + let right_iter = unsafe { + right + .values() + .as_any() + .downcast_ref::>() + .unwrap() + .take_iter_unchecked(right.keys_iter()) + }; + + let result = left_iter + .zip(right_iter) + .map(|(left_value, right_value)| { + if let (Some(left), Some(right)) = (left_value, right_value) { + Some(op(left, right)) + } else { + None + } + }) + .collect(); + + Ok(result) } /// Perform given operation on two `DictionaryArray`s. @@ -647,7 +616,18 @@ where F: Fn(T::Native, T::Native) -> Result, T::Native: ArrowNativeTypeOp, { - math_dict_checked_op!(left, right, op, PrimitiveArray) + if left.len() != right.len() { + return Err(ArrowError::ComputeError(format!( + "Cannot perform operation on arrays of different length ({}, {})", + left.len(), + right.len() + ))); + } + + let left = left.downcast_dict::>().unwrap(); + let right = right.downcast_dict::>().unwrap(); + + try_binary(left, right, |a, b| op(a, b)) } /// Helper function for operations where a valid `0` on the right array should From c69c48bb87d206f6490b47f0f09236e7a9c2c76a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 16 Sep 2022 17:41:00 -0700 Subject: [PATCH 07/15] Update doc --- arrow/src/compute/kernels/arithmetic.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 1bfa59288a8b..26876159b6aa 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -81,7 +81,7 @@ where } /// This is similar to `math_op` as it performs given operation between two input primitive arrays. -/// But the given operation can return `None` if overflow is detected. For the case, this function +/// But the given operation can return `Err` if overflow is detected. For the case, this function /// returns an `Err`. fn math_checked_op( left: &PrimitiveArray, From dc6077f3cc7a03637b19d6b8f5b9cfbea9e08b65 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 16 Sep 2022 17:48:59 -0700 Subject: [PATCH 08/15] Fix clippy --- arrow/src/compute/kernels/arithmetic.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 26876159b6aa..4c82d73401d0 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -627,7 +627,7 @@ where let left = left.downcast_dict::>().unwrap(); let right = right.downcast_dict::>().unwrap(); - try_binary(left, right, |a, b| op(a, b)) + try_binary(left, right, op) } /// Helper function for operations where a valid `0` on the right array should From 83dcff1dfef9a9f14c6b7161b7f1f87cfc5962ef Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 16 Sep 2022 18:49:11 -0700 Subject: [PATCH 09/15] Remove length check --- arrow/src/compute/kernels/arithmetic.rs | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 4c82d73401d0..03d919e33897 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -95,12 +95,6 @@ where LT::Native: ArrowNativeTypeOp, RT::Native: ArrowNativeTypeOp, { - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); - } - try_binary(left, right, op) } @@ -616,14 +610,6 @@ where F: Fn(T::Native, T::Native) -> Result, T::Native: ArrowNativeTypeOp, { - if left.len() != right.len() { - return Err(ArrowError::ComputeError(format!( - "Cannot perform operation on arrays of different length ({}, {})", - left.len(), - right.len() - ))); - } - let left = left.downcast_dict::>().unwrap(); let right = right.downcast_dict::>().unwrap(); From 4394ff1e04fbe6ad3c1813e4bac293bb15bd6af8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 17 Sep 2022 00:49:23 -0700 Subject: [PATCH 10/15] Tweak try_binary to coordinate latest optimization --- arrow/src/array/array.rs | 13 +++++++++++++ arrow/src/array/array_primitive.rs | 12 +++++++++++- arrow/src/compute/kernels/arity.rs | 17 +++++++++++++++-- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index 38ba2025a2e3..109b65838a64 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -356,6 +356,19 @@ pub trait ArrayAccessor: Array { /// # Safety /// Caller is responsible for ensuring that the index is within the bounds of the array unsafe fn value_unchecked(&self, index: usize) -> Self::Item; + + /// Returns a values accessor [`ArrayValuesAccessor`] for this [`ArrayAccessor`] if + /// it supports. Returns [`None`] if it doesn't support accessing values directly. + fn get_values_accessor(&self) -> Option<&dyn ArrayValuesAccessor> { + None + } +} + +/// A trait for accessing the values of an [`ArrayAccessor`] as a slice at once. Not all +/// [`ArrayAccessor`] implementations support this trait. Currently only [`PrimitiveArray`] +/// supports it. +pub trait ArrayValuesAccessor: ArrayAccessor { + fn values(&self) -> &[Self::Item]; } /// Constructs an array using the input `data`. diff --git a/arrow/src/array/array_primitive.rs b/arrow/src/array/array_primitive.rs index 57168b7b9e60..303ebb7fc037 100644 --- a/arrow/src/array/array_primitive.rs +++ b/arrow/src/array/array_primitive.rs @@ -33,7 +33,7 @@ use crate::{ util::trusted_len_unzip, }; -use crate::array::array::ArrayAccessor; +use crate::array::array::{ArrayAccessor, ArrayValuesAccessor}; use half::f16; /// Array whose elements are of primitive types. @@ -204,6 +204,16 @@ impl<'a, T: ArrowPrimitiveType> ArrayAccessor for &'a PrimitiveArray { unsafe fn value_unchecked(&self, index: usize) -> Self::Item { PrimitiveArray::value_unchecked(self, index) } + + fn get_values_accessor(&self) -> Option<&dyn ArrayValuesAccessor> { + Some(self) + } +} + +impl<'a, T: ArrowPrimitiveType> ArrayValuesAccessor for &'a PrimitiveArray { + fn values(&self) -> &[Self::Item] { + PrimitiveArray::values(self) + } } pub(crate) fn as_datetime(v: i64) -> Option { diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index dfeb171d72cd..77432f3b8829 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -285,6 +285,8 @@ pub fn try_binary( where O: ArrowPrimitiveType, F: Fn(A::Item, B::Item) -> Result, + A::Item: Copy, + B::Item: Copy, { if a.len() != b.len() { return Err(ArrowError::ComputeError( @@ -296,8 +298,19 @@ where } let len = a.len(); - if a.null_count() == 0 && b.null_count() == 0 { - let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r)); + let a_values_accessor = a.get_values_accessor(); + let b_values_accrssor = b.get_values_accessor(); + if a.null_count() == 0 + && b.null_count() == 0 + && a_values_accessor.is_some() + && b_values_accrssor.is_some() + { + let values = a_values_accessor + .unwrap() + .values() + .iter() + .zip(b_values_accrssor.unwrap().values()) + .map(|(l, r)| op(*l, *r)); let buffer = unsafe { Buffer::try_from_trusted_len_iter(values) }?; // JUSTIFICATION // Benefit From 0f8a5bbf51d373ab97ae5002fbebafaa3718bed0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 17 Sep 2022 01:01:33 -0700 Subject: [PATCH 11/15] Fix clippy --- arrow/src/compute/kernels/arity.rs | 80 +++++++++++++++--------------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 77432f3b8829..05dae37973b2 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -300,46 +300,48 @@ where let a_values_accessor = a.get_values_accessor(); let b_values_accrssor = b.get_values_accessor(); - if a.null_count() == 0 - && b.null_count() == 0 - && a_values_accessor.is_some() - && b_values_accrssor.is_some() - { - let values = a_values_accessor - .unwrap() - .values() - .iter() - .zip(b_values_accrssor.unwrap().values()) - .map(|(l, r)| op(*l, *r)); - let buffer = unsafe { Buffer::try_from_trusted_len_iter(values) }?; - // JUSTIFICATION - // Benefit - // ~75% speedup - // Soundness - // `values` is an iterator with a known size from a PrimitiveArray - return Ok(unsafe { build_primitive_array(len, buffer, 0, None) }); + match (a_values_accessor, b_values_accrssor) { + (Some(a_values), Some(b_values)) + if a.null_count() == 0 && b.null_count() == 0 => + { + let values = a_values + .values() + .iter() + .zip(b_values.values()) + .map(|(l, r)| op(*l, *r)); + let buffer = unsafe { Buffer::try_from_trusted_len_iter(values) }?; + // JUSTIFICATION + // Benefit + // ~75% speedup + // Soundness + // `values` is an iterator with a known size from a PrimitiveArray + Ok(unsafe { build_primitive_array(len, buffer, 0, None) }) + } + _ => { + let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); + + let null_count = null_buffer + .as_ref() + .map(|x| len - x.count_set_bits()) + .unwrap_or_default(); + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { + unsafe { + *slice.get_unchecked_mut(idx) = + op(a.value_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + })?; + + Ok(unsafe { + build_primitive_array(len, buffer.finish(), null_count, null_buffer) + }) + } } - - let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); - - let null_count = null_buffer - .as_ref() - .map(|x| len - x.count_set_bits()) - .unwrap_or_default(); - - let mut buffer = BufferBuilder::::new(len); - buffer.append_n_zeroed(len); - let slice = buffer.as_slice_mut(); - - try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { - unsafe { - *slice.get_unchecked_mut(idx) = - op(a.value_unchecked(idx), b.value_unchecked(idx))? - }; - Ok::<_, ArrowError>(()) - })?; - - Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) }) } /// Applies the provided binary operation across `a` and `b`, collecting the optional results From d81924e16f18a1bc76b0c54822d2af8d6506cda6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 17 Sep 2022 10:05:38 -0700 Subject: [PATCH 12/15] Use for loop --- arrow/src/array/array.rs | 13 ----- arrow/src/array/array_primitive.rs | 12 +---- arrow/src/compute/kernels/arity.rs | 78 +++++++++++++----------------- 3 files changed, 35 insertions(+), 68 deletions(-) diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index 109b65838a64..38ba2025a2e3 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -356,19 +356,6 @@ pub trait ArrayAccessor: Array { /// # Safety /// Caller is responsible for ensuring that the index is within the bounds of the array unsafe fn value_unchecked(&self, index: usize) -> Self::Item; - - /// Returns a values accessor [`ArrayValuesAccessor`] for this [`ArrayAccessor`] if - /// it supports. Returns [`None`] if it doesn't support accessing values directly. - fn get_values_accessor(&self) -> Option<&dyn ArrayValuesAccessor> { - None - } -} - -/// A trait for accessing the values of an [`ArrayAccessor`] as a slice at once. Not all -/// [`ArrayAccessor`] implementations support this trait. Currently only [`PrimitiveArray`] -/// supports it. -pub trait ArrayValuesAccessor: ArrayAccessor { - fn values(&self) -> &[Self::Item]; } /// Constructs an array using the input `data`. diff --git a/arrow/src/array/array_primitive.rs b/arrow/src/array/array_primitive.rs index 303ebb7fc037..57168b7b9e60 100644 --- a/arrow/src/array/array_primitive.rs +++ b/arrow/src/array/array_primitive.rs @@ -33,7 +33,7 @@ use crate::{ util::trusted_len_unzip, }; -use crate::array::array::{ArrayAccessor, ArrayValuesAccessor}; +use crate::array::array::ArrayAccessor; use half::f16; /// Array whose elements are of primitive types. @@ -204,16 +204,6 @@ impl<'a, T: ArrowPrimitiveType> ArrayAccessor for &'a PrimitiveArray { unsafe fn value_unchecked(&self, index: usize) -> Self::Item { PrimitiveArray::value_unchecked(self, index) } - - fn get_values_accessor(&self) -> Option<&dyn ArrayValuesAccessor> { - Some(self) - } -} - -impl<'a, T: ArrowPrimitiveType> ArrayValuesAccessor for &'a PrimitiveArray { - fn values(&self) -> &[Self::Item] { - PrimitiveArray::values(self) - } } pub(crate) fn as_datetime(v: i64) -> Option { diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 05dae37973b2..ec44215aa6e0 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -285,8 +285,6 @@ pub fn try_binary( where O: ArrowPrimitiveType, F: Fn(A::Item, B::Item) -> Result, - A::Item: Copy, - B::Item: Copy, { if a.len() != b.len() { return Err(ArrowError::ComputeError( @@ -298,49 +296,41 @@ where } let len = a.len(); - let a_values_accessor = a.get_values_accessor(); - let b_values_accrssor = b.get_values_accessor(); - match (a_values_accessor, b_values_accrssor) { - (Some(a_values), Some(b_values)) - if a.null_count() == 0 && b.null_count() == 0 => - { - let values = a_values - .values() - .iter() - .zip(b_values.values()) - .map(|(l, r)| op(*l, *r)); - let buffer = unsafe { Buffer::try_from_trusted_len_iter(values) }?; - // JUSTIFICATION - // Benefit - // ~75% speedup - // Soundness - // `values` is an iterator with a known size from a PrimitiveArray - Ok(unsafe { build_primitive_array(len, buffer, 0, None) }) - } - _ => { - let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); - - let null_count = null_buffer - .as_ref() - .map(|x| len - x.count_set_bits()) - .unwrap_or_default(); - - let mut buffer = BufferBuilder::::new(len); - buffer.append_n_zeroed(len); - let slice = buffer.as_slice_mut(); - - try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { - unsafe { - *slice.get_unchecked_mut(idx) = - op(a.value_unchecked(idx), b.value_unchecked(idx))? - }; - Ok::<_, ArrowError>(()) - })?; - - Ok(unsafe { - build_primitive_array(len, buffer.finish(), null_count, null_buffer) - }) + if a.null_count() == 0 && b.null_count() == 0 { + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + for idx in 0..len { + unsafe { + *slice.get_unchecked_mut(idx) = + op(a.value_unchecked(idx), b.value_unchecked(idx))? + }; } + Ok(unsafe { build_primitive_array(len, buffer.finish(), 0, None) }) + } else { + let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); + + let null_count = null_buffer + .as_ref() + .map(|x| len - x.count_set_bits()) + .unwrap_or_default(); + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { + unsafe { + *slice.get_unchecked_mut(idx) = + op(a.value_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + })?; + + Ok(unsafe { + build_primitive_array(len, buffer.finish(), null_count, null_buffer) + }) } } From 4eb0fe42d0dced646dec007a8525d2bacf11aa87 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 18 Sep 2022 22:59:37 -0700 Subject: [PATCH 13/15] Split non-null variant into never inline function --- arrow/src/compute/kernels/arity.rs | 34 ++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index ec44215aa6e0..6f1e5a4c35c5 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -27,6 +27,7 @@ use crate::datatypes::{ArrowNumericType, ArrowPrimitiveType}; use crate::downcast_dictionary_array; use crate::error::{ArrowError, Result}; use crate::util::bit_iterator::try_for_each_valid_idx; +use arrow_buffer::MutableBuffer; use std::sync::Arc; #[inline] @@ -297,17 +298,7 @@ where let len = a.len(); if a.null_count() == 0 && b.null_count() == 0 { - let mut buffer = BufferBuilder::::new(len); - buffer.append_n_zeroed(len); - let slice = buffer.as_slice_mut(); - - for idx in 0..len { - unsafe { - *slice.get_unchecked_mut(idx) = - op(a.value_unchecked(idx), b.value_unchecked(idx))? - }; - } - Ok(unsafe { build_primitive_array(len, buffer.finish(), 0, None) }) + try_binary_no_nulls(len, a, b, op) } else { let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); @@ -334,6 +325,27 @@ where } } +/// This intentional inline(never) attribute helps LLVM optimize the loop. +#[inline(never)] +fn try_binary_no_nulls( + len: usize, + a: A, + b: B, + op: F, +) -> Result> +where + O: ArrowPrimitiveType, + F: Fn(A::Item, B::Item) -> Result, +{ + let mut buffer = MutableBuffer::new(len); + for idx in 0..len { + unsafe { + buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?); + }; + } + Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) }) +} + /// Applies the provided binary operation across `a` and `b`, collecting the optional results /// into a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the corresponding /// index in the result will also be null. The binary operation could return `None` which From 38de6653d3b3eff2aeeb998de79168984ace7693 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 18 Sep 2022 23:13:04 -0700 Subject: [PATCH 14/15] Add value type check --- arrow/src/compute/kernels/arithmetic.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 03d919e33897..bb9b01b13aa7 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -610,6 +610,14 @@ where F: Fn(T::Native, T::Native) -> Result, T::Native: ArrowNativeTypeOp, { + // left and right's value types are supposed to be same as guaranteed by the caller macro now. + if left.value_type() != T::DATA_TYPE { + return Err(ArrowError::NotYetImplemented(format!( + "Cannot perform provided operation on dictionary array of value type {}", + left.value_type() + ))); + } + let left = left.downcast_dict::>().unwrap(); let right = right.downcast_dict::>().unwrap(); From 4d0682683a879173fe31fac255d0e40d9ec64b27 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 19 Sep 2022 09:29:30 -0700 Subject: [PATCH 15/15] Multiply by get_byte_width of output type. --- arrow/src/compute/kernels/arity.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 6f1e5a4c35c5..7174de0ce44e 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -337,7 +337,7 @@ where O: ArrowPrimitiveType, F: Fn(A::Item, B::Item) -> Result, { - let mut buffer = MutableBuffer::new(len); + let mut buffer = MutableBuffer::new(len * O::get_byte_width()); for idx in 0..len { unsafe { buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);