diff --git a/arrow/benches/arithmetic_kernels.rs b/arrow/benches/arithmetic_kernels.rs index 10af0b5432ef..2aa2e7191a68 100644 --- a/arrow/benches/arithmetic_kernels.rs +++ b/arrow/benches/arithmetic_kernels.rs @@ -20,107 +20,62 @@ extern crate criterion; use criterion::Criterion; use rand::Rng; -use std::sync::Arc; - extern crate arrow; +use arrow::datatypes::Float32Type; use arrow::util::bench_util::*; -use arrow::{array::*, datatypes::Float32Type}; use arrow::{compute::kernels::arithmetic::*, util::test_util::seedable_rng}; -fn create_array(size: usize, with_nulls: bool) -> ArrayRef { - let null_density = if with_nulls { 0.5 } else { 0.0 }; - let array = create_primitive_array::(size, null_density); - Arc::new(array) -} - -fn bench_add(arr_a: &ArrayRef, arr_b: &ArrayRef) { - let arr_a = arr_a.as_any().downcast_ref::().unwrap(); - let arr_b = arr_b.as_any().downcast_ref::().unwrap(); - criterion::black_box(add(arr_a, arr_b).unwrap()); -} - -fn bench_subtract(arr_a: &ArrayRef, arr_b: &ArrayRef) { - let arr_a = arr_a.as_any().downcast_ref::().unwrap(); - let arr_b = arr_b.as_any().downcast_ref::().unwrap(); - criterion::black_box(subtract(arr_a, arr_b).unwrap()); -} - -fn bench_multiply(arr_a: &ArrayRef, arr_b: &ArrayRef) { - let arr_a = arr_a.as_any().downcast_ref::().unwrap(); - let arr_b = arr_b.as_any().downcast_ref::().unwrap(); - criterion::black_box(multiply(arr_a, arr_b).unwrap()); -} - -fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) { - let arr_a = arr_a.as_any().downcast_ref::().unwrap(); - let arr_b = arr_b.as_any().downcast_ref::().unwrap(); - criterion::black_box(divide_checked(arr_a, arr_b).unwrap()); -} - -fn bench_divide_unchecked(arr_a: &ArrayRef, arr_b: &ArrayRef) { - let arr_a = arr_a.as_any().downcast_ref::().unwrap(); - let arr_b = arr_b.as_any().downcast_ref::().unwrap(); - criterion::black_box(divide(arr_a, arr_b).unwrap()); -} - -fn bench_divide_scalar(array: &ArrayRef, divisor: f32) { - let array = array.as_any().downcast_ref::().unwrap(); - criterion::black_box(divide_scalar(array, divisor).unwrap()); -} - -fn bench_modulo(arr_a: &ArrayRef, arr_b: &ArrayRef) { - let arr_a = arr_a.as_any().downcast_ref::().unwrap(); - let arr_b = arr_b.as_any().downcast_ref::().unwrap(); - criterion::black_box(modulus(arr_a, arr_b).unwrap()); -} - -fn bench_modulo_scalar(array: &ArrayRef, divisor: f32) { - let array = array.as_any().downcast_ref::().unwrap(); - criterion::black_box(modulus_scalar(array, divisor).unwrap()); -} - fn add_benchmark(c: &mut Criterion) { const BATCH_SIZE: usize = 64 * 1024; - let arr_a = create_array(BATCH_SIZE, false); - let arr_b = create_array(BATCH_SIZE, false); - let scalar = seedable_rng().gen(); - - c.bench_function("add", |b| b.iter(|| bench_add(&arr_a, &arr_b))); - c.bench_function("subtract", |b| b.iter(|| bench_subtract(&arr_a, &arr_b))); - c.bench_function("multiply", |b| b.iter(|| bench_multiply(&arr_a, &arr_b))); - c.bench_function("divide", |b| b.iter(|| bench_divide(&arr_a, &arr_b))); - c.bench_function("divide_unchecked", |b| { - b.iter(|| bench_divide_unchecked(&arr_a, &arr_b)) - }); - c.bench_function("divide_scalar", |b| { - b.iter(|| bench_divide_scalar(&arr_a, scalar)) - }); - c.bench_function("modulo", |b| b.iter(|| bench_modulo(&arr_a, &arr_b))); - c.bench_function("modulo_scalar", |b| { - b.iter(|| bench_modulo_scalar(&arr_a, scalar)) - }); - - let arr_a_nulls = create_array(BATCH_SIZE, true); - let arr_b_nulls = create_array(BATCH_SIZE, true); - c.bench_function("add_nulls", |b| { - b.iter(|| bench_add(&arr_a_nulls, &arr_b_nulls)) - }); - c.bench_function("divide_nulls", |b| { - b.iter(|| bench_divide(&arr_a_nulls, &arr_b_nulls)) - }); - c.bench_function("divide_nulls_unchecked", |b| { - b.iter(|| bench_divide_unchecked(&arr_a_nulls, &arr_b_nulls)) - }); - c.bench_function("divide_scalar_nulls", |b| { - b.iter(|| bench_divide_scalar(&arr_a_nulls, scalar)) - }); - c.bench_function("modulo_nulls", |b| { - b.iter(|| bench_modulo(&arr_a_nulls, &arr_b_nulls)) - }); - c.bench_function("modulo_scalar_nulls", |b| { - b.iter(|| bench_modulo_scalar(&arr_a_nulls, scalar)) - }); + for null_density in [0., 0.1, 0.5, 0.9, 1.0] { + let arr_a = create_primitive_array::(BATCH_SIZE, null_density); + let arr_b = create_primitive_array::(BATCH_SIZE, null_density); + let scalar = seedable_rng().gen(); + + c.bench_function(&format!("add({})", null_density), |b| { + b.iter(|| criterion::black_box(add(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("add_checked({})", null_density), |b| { + b.iter(|| criterion::black_box(add_checked(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("add_scalar({})", null_density), |b| { + b.iter(|| criterion::black_box(add_scalar(&arr_a, scalar).unwrap())) + }); + c.bench_function(&format!("subtract({})", null_density), |b| { + b.iter(|| criterion::black_box(subtract(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("subtract_checked({})", null_density), |b| { + b.iter(|| criterion::black_box(subtract_checked(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("subtract_scalar({})", null_density), |b| { + b.iter(|| criterion::black_box(subtract_scalar(&arr_a, scalar).unwrap())) + }); + c.bench_function(&format!("multiply({})", null_density), |b| { + b.iter(|| criterion::black_box(multiply(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("multiply_checked({})", null_density), |b| { + b.iter(|| criterion::black_box(multiply_checked(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("multiply_scalar({})", null_density), |b| { + b.iter(|| criterion::black_box(multiply_scalar(&arr_a, scalar).unwrap())) + }); + c.bench_function(&format!("divide({})", null_density), |b| { + b.iter(|| criterion::black_box(divide(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("divide_checked({})", null_density), |b| { + b.iter(|| criterion::black_box(divide_checked(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("divide_scalar({})", null_density), |b| { + b.iter(|| criterion::black_box(divide_scalar(&arr_a, scalar).unwrap())) + }); + c.bench_function(&format!("modulo({})", null_density), |b| { + b.iter(|| criterion::black_box(modulus(&arr_a, &arr_b).unwrap())) + }); + c.bench_function(&format!("modulo_scalar({})", null_density), |b| { + b.iter(|| criterion::black_box(modulus_scalar(&arr_a, scalar).unwrap())) + }); + } } criterion_group!(benches, add_benchmark); diff --git a/arrow/src/array/iterator.rs b/arrow/src/array/iterator.rs index 4269e99625b7..e64712fa883a 100644 --- a/arrow/src/array/iterator.rs +++ b/arrow/src/array/iterator.rs @@ -24,8 +24,51 @@ use super::{ PrimitiveArray, }; -/// an iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`] -// Note: This implementation is based on std's [Vec]s' [IntoIter]. +/// An iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`] +/// +/// # Performance +/// +/// [`ArrayIter`] provides an idiomatic way to iterate over an array, however, this +/// comes at the cost of performance. In particular the interleaved handling of +/// the null mask is often sub-optimal. +/// +/// If performing an infallible operation, it is typically faster to perform the operation +/// on every index of the array, and handle the null mask separately. For [`PrimitiveArray`] +/// this functionality is provided by [`compute::unary`] +/// +/// ``` +/// # use arrow::array::PrimitiveArray; +/// # use arrow::compute::unary; +/// # use arrow::datatypes::Int32Type; +/// +/// fn add(a: &PrimitiveArray, b: i32) -> PrimitiveArray { +/// unary(a, |a| a + b) +/// } +/// ``` +/// +/// If performing a fallible operation, it isn't possible to perform the operation independently +/// of the null mask, as this might result in a spurious failure on a null index. However, +/// there are more efficient ways to iterate over just the non-null indices, this functionality +/// is provided by [`compute::try_unary`] +/// +/// ``` +/// # use arrow::array::PrimitiveArray; +/// # use arrow::compute::try_unary; +/// # use arrow::datatypes::Int32Type; +/// # use arrow::error::{ArrowError, Result}; +/// +/// fn checked_add(a: &PrimitiveArray, b: i32) -> Result> { +/// try_unary(a, |a| { +/// a.checked_add(b).ok_or_else(|| { +/// ArrowError::CastError(format!("overflow adding {} to {}", a, b)) +/// }) +/// }) +/// } +/// ``` +/// +/// [`PrimitiveArray`]: [crate::array::PrimitiveArray] +/// [`compute::unary`]: [crate::compute::unary] +/// [`compute::try_unary`]: [crate::compute::try_unary] #[derive(Debug)] pub struct ArrayIter { array: T, diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 9bf4b00c3132..17850f2a8cff 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -31,12 +31,11 @@ use crate::buffer::Buffer; #[cfg(feature = "simd")] use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; -use crate::compute::unary_dyn; use crate::compute::util::combine_option_bitmap; +use crate::compute::{binary, try_binary, unary_dyn}; use crate::datatypes::{ - native_op::ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, DataType, - Date32Type, Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, + native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, }; use crate::datatypes::{ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, @@ -74,33 +73,7 @@ where )); } - let null_bit_buffer = - combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; - - let values = left - .values() - .iter() - .zip(right.values().iter()) - .map(|(l, r)| op(*l, *r)); - // JUSTIFICATION - // Benefit - // ~60% speedup - // Soundness - // `values` is an iterator with a known size from a PrimitiveArray - let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; - - let data = unsafe { - ArrayData::new_unchecked( - LT::DATA_TYPE, - left.len(), - None, - null_bit_buffer, - 0, - vec![buffer], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) + Ok(binary(left, right, op)) } /// This is similar to `math_op` as it performs given operation between two input primitive arrays. @@ -122,85 +95,11 @@ where )); } - let left_iter = ArrayIter::new(left); - let right_iter = ArrayIter::new(right); - - let values: Result::Native>>> = left_iter - .into_iter() - .zip(right_iter.into_iter()) - .map(|(l, r)| { - if let (Some(l), Some(r)) = (l, r) { - let result = op(l, r); - if let Some(r) = result { - Ok(Some(r)) - } else { - // Overflow - Err(ArrowError::ComputeError(format!( - "Overflow happened on: {:?}, {:?}", - l, r - ))) - } - } else { - Ok(None) - } - }) - .collect(); - - let values = values?; - - Ok(PrimitiveArray::::from_iter(values)) -} - -/// This is similar to `math_checked_op` but just for divide op. -fn math_checked_divide( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result> -where - LT: ArrowNumericType, - RT: ArrowNumericType, - RT::Native: One + Zero, - F: Fn(LT::Native, RT::Native) -> Option, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); - } - - let left_iter = ArrayIter::new(left); - let right_iter = ArrayIter::new(right); - - let values: Result::Native>>> = left_iter - .into_iter() - .zip(right_iter.into_iter()) - .map(|(l, r)| { - if let (Some(l), Some(r)) = (l, r) { - let result = op(l, r); - if let Some(r) = result { - Ok(Some(r)) - } else if r.is_zero() { - Err(ArrowError::ComputeError(format!( - "DivideByZero on: {:?}, {:?}", - l, r - ))) - } else { - // Overflow - Err(ArrowError::ComputeError(format!( - "Overflow happened on: {:?}, {:?}", - l, r - ))) - } - } else { - Ok(None) - } + try_binary(left, right, |a, b| { + op(a, b).ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow happened on: {:?}, {:?}", a, b)) }) - .collect(); - - let values = values?; - - Ok(PrimitiveArray::::from_iter(values)) + }) } /// Helper function for operations where a valid `0` on the right array should @@ -211,15 +110,16 @@ where /// This function errors if: /// * the arrays have different lengths /// * there is an element where both left and right values are valid and the right value is `0` -fn math_checked_divide_op( - left: &PrimitiveArray, - right: &PrimitiveArray, +fn math_checked_divide_op( + left: &PrimitiveArray, + right: &PrimitiveArray, op: F, -) -> Result> +) -> Result> where - T: ArrowNumericType, - T::Native: One + Zero, - F: Fn(T::Native, T::Native) -> T::Native, + LT: ArrowNumericType, + RT: ArrowNumericType, + RT::Native: One + Zero, + F: Fn(LT::Native, RT::Native) -> Option, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -227,16 +127,18 @@ where )); } - let null_bit_buffer = - combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; - - math_checked_divide_op_on_iters( - left.into_iter(), - right.into_iter(), - op, - left.len(), - null_bit_buffer, - ) + try_binary(left, right, |l, r| { + if r.is_zero() { + Err(ArrowError::DivideByZero) + } else { + op(l, r).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?}, {:?}", + l, r + )) + }) + } + }) } /// Helper function for operations where a valid `0` on the right array should @@ -900,7 +802,7 @@ pub fn add_scalar( scalar: T::Native, ) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: Add, { Ok(unary(array, |value| value + scalar)) @@ -911,7 +813,7 @@ where /// the scalar, or a `DictionaryArray` of the value type same as the scalar. pub fn add_scalar_dyn(array: &dyn Array, scalar: T::Native) -> Result where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: Add, { unary_dyn::<_, T>(array, |value| value + scalar) @@ -927,7 +829,7 @@ pub fn subtract( right: &PrimitiveArray, ) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { math_op(left, right, |a, b| a.sub_wrapping(b)) @@ -943,7 +845,7 @@ pub fn subtract_checked( right: &PrimitiveArray, ) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { math_checked_op(left, right, |a, b| a.sub_checked(b)) @@ -1033,7 +935,7 @@ pub fn multiply( right: &PrimitiveArray, ) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { math_op(left, right, |a, b| a.mul_wrapping(b)) @@ -1049,7 +951,7 @@ pub fn multiply_checked( right: &PrimitiveArray, ) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { math_checked_op(left, right, |a, b| a.mul_checked(b)) @@ -1100,7 +1002,7 @@ where /// the scalar, or a `DictionaryArray` of the value type same as the scalar. pub fn multiply_scalar_dyn(array: &dyn Array, scalar: T::Native) -> Result where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: Add + Sub + Mul @@ -1120,7 +1022,7 @@ pub fn modulus( right: &PrimitiveArray, ) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: Rem + Zero + One, { #[cfg(feature = "simd")] @@ -1128,7 +1030,7 @@ where a % b }); #[cfg(not(feature = "simd"))] - return math_checked_divide_op(left, right, |a, b| a % b); + return math_checked_divide_op(left, right, |a, b| Some(a % b)); } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1148,7 +1050,7 @@ where #[cfg(feature = "simd")] return simd_checked_divide_op(&left, &right, simd_checked_divide::, |a, b| a / b); #[cfg(not(feature = "simd"))] - return math_checked_divide(left, right, |a, b| a.div_checked(b)); + return math_checked_divide_op(left, right, |a, b| a.div_checked(b)); } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1162,7 +1064,7 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { _ => { downcast_primitive_array!( (left, right) => { - math_checked_divide_op(left, right, |a, b| a / b).map(|a| Arc::new(a) as ArrayRef) + math_checked_divide_op(left, right, |a, b| Some(a / b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", @@ -1199,7 +1101,7 @@ pub fn modulus_scalar( modulo: T::Native, ) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: Rem + Zero, { if modulo.is_zero() { @@ -1217,7 +1119,7 @@ pub fn divide_scalar( divisor: T::Native, ) -> Result> where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: Div + Zero, { if divisor.is_zero() { @@ -1232,7 +1134,7 @@ where /// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. pub fn divide_scalar_dyn(array: &dyn Array, divisor: T::Native) -> Result where - T: datatypes::ArrowNumericType, + T: ArrowNumericType, T::Native: Div + Zero, { if divisor.is_zero() { diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 1251baf52fd8..ee3ff5e23a83 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -17,37 +17,41 @@ //! Defines kernels suitable to perform operations to primitive arrays. -use crate::array::{Array, ArrayData, ArrayRef, DictionaryArray, PrimitiveArray}; +use crate::array::{ + Array, ArrayData, ArrayRef, BufferBuilder, DictionaryArray, PrimitiveArray, +}; use crate::buffer::Buffer; +use crate::compute::util::combine_option_bitmap; use crate::datatypes::{ArrowNumericType, ArrowPrimitiveType}; use crate::downcast_dictionary_array; use crate::error::{ArrowError, Result}; +use crate::util::bit_iterator::try_for_each_valid_idx; use std::sync::Arc; #[inline] -fn into_primitive_array_data( - array: &PrimitiveArray, +unsafe fn build_primitive_array( + len: usize, buffer: Buffer, -) -> ArrayData { - let data = array.data(); - unsafe { - ArrayData::new_unchecked( - O::DATA_TYPE, - array.len(), - Some(data.null_count()), - data.null_buffer() - .map(|b| b.bit_slice(array.offset(), array.len())), - 0, - vec![buffer], - vec![], - ) - } + null_count: usize, + null_buffer: Option, +) -> PrimitiveArray { + PrimitiveArray::from(ArrayData::new_unchecked( + O::DATA_TYPE, + len, + Some(null_count), + null_buffer, + 0, + vec![buffer], + vec![], + )) } /// Applies an unary and infallible function to a primitive array. /// This is the fastest way to perform an operation on a primitive array when -/// the benefits of a vectorized operation outweights the cost of branching nulls and non-nulls. +/// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls. +/// /// # Implementation +/// /// This will apply the function for all values, including those on null slots. /// This implies that the operation must be infallible for any value of the corresponding type /// or this function may panic. @@ -68,6 +72,14 @@ where O: ArrowPrimitiveType, F: Fn(I::Native) -> O::Native, { + let data = array.data(); + let len = data.len(); + let null_count = data.null_count(); + + let null_buffer = data + .null_buffer() + .map(|b| b.bit_slice(data.offset(), data.len())); + let values = array.values().iter().map(|v| op(*v)); // JUSTIFICATION // Benefit @@ -75,9 +87,40 @@ where // Soundness // `values` is an iterator with a known size because arrays are sized. let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; + unsafe { build_primitive_array(len, buffer, null_count, null_buffer) } +} + +/// Applies a unary and fallible function to all valid values in a primitive array +/// +/// This is unlike [`unary`] which will apply an infallible function to all rows regardless +/// of validity, in many cases this will be significantly faster and should be preferred +/// if `op` is infallible. +/// +/// Note: LLVM is currently unable to effectively vectorize fallible operations +pub fn try_unary(array: &PrimitiveArray, op: F) -> Result> +where + I: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(I::Native) -> Result, +{ + let len = array.len(); + let null_count = array.null_count(); + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(array.len()); + let slice = buffer.as_slice_mut(); + + let null_buffer = array + .data_ref() + .null_buffer() + .map(|b| b.bit_slice(array.offset(), array.len())); - let data = into_primitive_array_data::<_, O>(array, buffer); - PrimitiveArray::::from(data) + try_for_each_valid_idx(array.len(), 0, null_count, null_buffer.as_deref(), |idx| { + unsafe { *slice.get_unchecked_mut(idx) = op(array.value_unchecked(idx))? }; + Ok::<_, ArrowError>(()) + })?; + + Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) }) } /// A helper function that applies an unary function to a dictionary array with primitive value type. @@ -119,6 +162,101 @@ where } } +/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting +/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the +/// corresponding index in the result will also be null +/// +/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This +/// is beneficial when the cost of the operation is low compared to the cost of branching, and +/// especially when the operation can be vectorised, however, requires `op` to be infallible +/// for all possible values of its inputs +/// +/// # Panic +/// +/// Panics if the arrays have different lengths +pub fn binary( + a: &PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> PrimitiveArray +where + A: ArrowPrimitiveType, + B: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(A::Native, B::Native) -> O::Native, +{ + assert_eq!(a.len(), b.len()); + let len = a.len(); + + if a.is_empty() { + return PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)); + } + + let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); + let null_count = null_buffer + .as_ref() + .map(|x| len - x.count_set_bits()) + .unwrap_or_default(); + + let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r)); + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size from a PrimitiveArray + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; + + unsafe { build_primitive_array(len, buffer, null_count, null_buffer) } +} + +/// Applies the provided fallible binary operation across `a` and `b`, returning any error, +/// and collecting the results into a [`PrimitiveArray`]. If any index is null in either `a` +/// or `b`, the corresponding index in the result will also be null +/// +/// Like [`try_unary`] the function is only evaluated for non-null indices +/// +/// # Panic +/// +/// Panics if the arrays have different lengths +pub fn try_binary( + a: &PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result> +where + A: ArrowPrimitiveType, + B: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(A::Native, B::Native) -> Result, +{ + assert_eq!(a.len(), b.len()); + let len = a.len(); + + if a.is_empty() { + return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); + } + + let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); + let null_count = null_buffer + .as_ref() + .map(|x| len - x.count_set_bits()) + .unwrap_or_default(); + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { + unsafe { + *slice.get_unchecked_mut(idx) = + op(a.value_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + })?; + + Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow/src/util/bit_iterator.rs b/arrow/src/util/bit_iterator.rs index bba9dac60a4b..ceefaa860cb1 100644 --- a/arrow/src/util/bit_iterator.rs +++ b/arrow/src/util/bit_iterator.rs @@ -16,6 +16,7 @@ // under the License. use crate::util::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator}; +use std::result::Result; /// Iterator of contiguous ranges of set bits within a provided packed bitmask /// @@ -157,4 +158,45 @@ impl<'a> Iterator for BitIndexIterator<'a> { } } +/// Calls the provided closure for each index in the provided null mask that is set, +/// using an adaptive strategy based on the null count +/// +/// Ideally this would be encapsulated in an [`Iterator`] that would determine the optimal +/// strategy up front, and then yield indexes based on this. +/// +/// Unfortunately, external iteration based on the resulting [`Iterator`] would match the strategy +/// variant on each call to [`Iterator::next`], and LLVM generally cannot eliminate this. +/// +/// One solution to this might be internal iteration, e.g. [`Iterator::try_fold`], however, +/// it is currently [not possible] to override this for custom iterators in stable Rust. +/// +/// As such this is the next best option +/// +/// [not possible]: https://github.com/rust-lang/rust/issues/69595 +#[inline] +pub fn try_for_each_valid_idx Result<(), E>>( + len: usize, + offset: usize, + null_count: usize, + nulls: Option<&[u8]>, + f: F, +) -> Result<(), E> { + let valid_count = len - null_count; + + if valid_count == len { + (0..len).try_for_each(f) + } else if null_count != len { + let selectivity = valid_count as f64 / len as f64; + if selectivity > 0.8 { + BitSliceIterator::new(nulls.unwrap(), offset, len) + .flat_map(|(start, end)| start..end) + .try_for_each(f) + } else { + BitIndexIterator::new(nulls.unwrap(), offset, len).try_for_each(f) + } + } else { + Ok(()) + } +} + // Note: tests located in filter module