diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index 8f21cd7c498d..3825e5ec66f4 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -20,36 +20,117 @@ use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::ArrowNativeType; -use arrow_schema::ArrowError; +use arrow_buffer::{ArrowNativeType, NullBuffer}; +use arrow_schema::{ArrowError, SortOptions}; use std::cmp::Ordering; /// Compare the values at two arbitrary indices in two arrays. pub type DynComparator = Box Ordering + Send + Sync>; -fn compare_primitive(left: &dyn Array, right: &dyn Array) -> DynComparator +/// If parent sort order is descending we need to invert the value of nulls_first so that +/// when the parent is sorted based on the produced ranks, nulls are still ordered correctly +fn child_opts(opts: SortOptions) -> SortOptions { + SortOptions { + descending: false, + nulls_first: opts.nulls_first != opts.descending, + } +} + +fn compare(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator where - T::Native: ArrowNativeTypeOp, + A: Array + Clone, + F: Fn(usize, usize) -> Ordering + Send + Sync + 'static, { - let left = left.as_primitive::().clone(); - let right = right.as_primitive::().clone(); - Box::new(move |i, j| left.value(i).compare(right.value(j))) + let l = l.logical_nulls().filter(|x| x.null_count() > 0); + let r = r.logical_nulls().filter(|x| x.null_count() > 0); + match (opts.nulls_first, opts.descending) { + (true, true) => compare_impl::(l, r, cmp), + (true, false) => compare_impl::(l, r, cmp), + (false, true) => compare_impl::(l, r, cmp), + (false, false) => compare_impl::(l, r, cmp), + } } -fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left: BooleanArray = left.as_boolean().clone(); - let right: BooleanArray = right.as_boolean().clone(); +fn compare_impl( + l: Option, + r: Option, + cmp: F, +) -> DynComparator +where + F: Fn(usize, usize) -> Ordering + Send + Sync + 'static, +{ + let cmp = move |i, j| match DESCENDING { + true => cmp(i, j).reverse(), + false => cmp(i, j), + }; + + let (left_null, right_null) = match NULLS_FIRST { + true => (Ordering::Less, Ordering::Greater), + false => (Ordering::Greater, Ordering::Less), + }; + + match (l, r) { + (None, None) => Box::new(cmp), + (Some(l), None) => Box::new(move |i, j| match l.is_null(i) { + true => left_null, + false => cmp(i, j), + }), + (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) { + true => right_null, + false => cmp(i, j), + }), + (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i), r.is_null(j)) { + (true, true) => Ordering::Equal, + (true, false) => left_null, + (false, true) => right_null, + (false, false) => cmp(i, j), + }), + } +} + +fn compare_primitive( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> DynComparator +where + T::Native: ArrowNativeTypeOp, +{ + let left = left.as_primitive::(); + let right = right.as_primitive::(); + let l_values = left.values().clone(); + let r_values = right.values().clone(); - Box::new(move |i, j| left.value(i).cmp(&right.value(j))) + compare(&left, &right, opts, move |i, j| { + l_values[i].compare(r_values[j]) + }) } -fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left.as_bytes::().clone(); - let right = right.as_bytes::().clone(); +fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) -> DynComparator { + let left = left.as_boolean(); + let right = right.as_boolean(); + + let l_values = left.values().clone(); + let r_values = right.values().clone(); - Box::new(move |i, j| { - let l: &[u8] = left.value(i).as_ref(); - let r: &[u8] = right.value(j).as_ref(); + compare(left, right, opts, move |i, j| { + l_values.value(i).cmp(&r_values.value(j)) + }) +} + +fn compare_bytes( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> DynComparator { + let left = left.as_bytes::(); + let right = right.as_bytes::(); + + let l = left.clone(); + let r = right.clone(); + compare(left, right, opts, move |i, j| { + let l: &[u8] = l.value(i).as_ref(); + let r: &[u8] = r.value(j).as_ref(); l.cmp(r) }) } @@ -57,67 +138,234 @@ fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynCo fn compare_dict( left: &dyn Array, right: &dyn Array, + opts: SortOptions, ) -> Result { let left = left.as_dictionary::(); let right = right.as_dictionary::(); - let cmp = build_compare(left.values().as_ref(), right.values().as_ref())?; - let left_keys = left.keys().clone(); - let right_keys = right.keys().clone(); + let c_opts = child_opts(opts); + let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?; + let left_keys = left.keys().values().clone(); + let right_keys = right.keys().values().clone(); - // TODO: Handle value nulls (#2687) - Ok(Box::new(move |i, j| { - let l = left_keys.value(i).as_usize(); - let r = right_keys.value(j).as_usize(); + let f = compare(left, right, opts, move |i, j| { + let l = left_keys[i].as_usize(); + let r = right_keys[j].as_usize(); cmp(l, r) - })) + }); + Ok(f) } -/// returns a comparison function that compares two values at two different positions +fn compare_list( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { + let left = left.as_list::(); + let right = right.as_list::(); + + let c_opts = child_opts(opts); + let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?; + + let l_o = left.offsets().clone(); + let r_o = right.offsets().clone(); + let f = compare(left, right, opts, move |i, j| { + let l_end = l_o[i + 1].as_usize(); + let l_start = l_o[i].as_usize(); + + let r_end = r_o[j + 1].as_usize(); + let r_start = r_o[j].as_usize(); + + for (i, j) in (l_start..l_end).zip(r_start..r_end) { + match cmp(i, j) { + Ordering::Equal => continue, + r => return r, + } + } + (l_end - l_start).cmp(&(r_end - r_start)) + }); + Ok(f) +} + +fn compare_fixed_list( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { + let left = left.as_fixed_size_list(); + let right = right.as_fixed_size_list(); + + let c_opts = child_opts(opts); + let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?; + + let l_size = left.value_length().to_usize().unwrap(); + let r_size = right.value_length().to_usize().unwrap(); + let size_cmp = l_size.cmp(&r_size); + + let f = compare(left, right, opts, move |i, j| { + let l_start = i * l_size; + let l_end = l_start + l_size; + let r_start = j * r_size; + let r_end = r_start + r_size; + for (i, j) in (l_start..l_end).zip(r_start..r_end) { + match cmp(i, j) { + Ordering::Equal => continue, + r => return r, + } + } + size_cmp + }); + Ok(f) +} + +fn compare_struct( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { + let left = left.as_struct(); + let right = right.as_struct(); + + if left.columns().len() != right.columns().len() { + return Err(ArrowError::InvalidArgumentError( + "Cannot compare StructArray with different number of columns".to_string(), + )); + } + + let c_opts = child_opts(opts); + let columns = left.columns().iter().zip(right.columns()); + let comparators = columns + .map(|(l, r)| make_comparator(l, r, c_opts)) + .collect::, _>>()?; + + let f = compare(left, right, opts, move |i, j| { + for cmp in &comparators { + match cmp(i, j) { + Ordering::Equal => continue, + r => return r, + } + } + Ordering::Equal + }); + Ok(f) +} + +#[deprecated(note = "Use make_comparator")] +#[doc(hidden)] +pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { + make_comparator(left, right, SortOptions::default()) +} + +/// Returns a comparison function that compares two values at two different positions /// between the two arrays. -/// The arrays' types must be equal. -/// # Example -/// ``` -/// use arrow_array::Int32Array; -/// use arrow_ord::ord::build_compare; /// +/// For comparing arrays element-wise, see also the vectorised kernels in [`crate::cmp`]. +/// +/// If `nulls_first` is true `NULL` values will be considered less than any non-null value, +/// otherwise they will be considered greater. +/// +/// # Basic Usage +/// +/// ``` +/// # use std::cmp::Ordering; +/// # use arrow_array::Int32Array; +/// # use arrow_ord::ord::make_comparator; +/// # use arrow_schema::SortOptions; +/// # /// let array1 = Int32Array::from(vec![1, 2]); /// let array2 = Int32Array::from(vec![3, 4]); /// -/// let cmp = build_compare(&array1, &array2).unwrap(); -/// +/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); /// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) -/// assert_eq!(std::cmp::Ordering::Less, cmp(0, 1)); +/// assert_eq!(cmp(0, 1), Ordering::Less); +/// +/// let array1 = Int32Array::from(vec![Some(1), None]); +/// let array2 = Int32Array::from(vec![None, Some(2)]); +/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); +/// +/// assert_eq!(cmp(0, 1), Ordering::Less); // Some(1) vs Some(2) +/// assert_eq!(cmp(1, 1), Ordering::Less); // None vs Some(2) +/// assert_eq!(cmp(1, 0), Ordering::Equal); // None vs None +/// assert_eq!(cmp(0, 0), Ordering::Greater); // Some(1) vs None /// ``` -// This is a factory of comparisons. -// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime. -pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { +/// +/// # Postgres-compatible Nested Comparison +/// +/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value against a NULL yields +/// a NULL, many systems, including postgres, instead apply a total ordering to comparison of +/// nested nulls. That is nulls within nested types are either greater than any value (postgres), +/// or less than any value (Spark). +/// +/// In particular +/// +/// ```ignore +/// { a: 1, b: null } == { a: 1, b: null } => true +/// { a: 1, b: null } == { a: 1, b: 1 } => false +/// { a: 1, b: null } == null => null +/// null == null => null +/// ``` +/// +/// This could be implemented as below +/// +/// ``` +/// # use arrow_array::{Array, BooleanArray}; +/// # use arrow_buffer::NullBuffer; +/// # use arrow_ord::cmp; +/// # use arrow_ord::ord::make_comparator; +/// # use arrow_schema::{ArrowError, SortOptions}; +/// fn eq(a: &dyn Array, b: &dyn Array) -> Result { +/// if !a.data_type().is_nested() { +/// return cmp::eq(&a, &b); // Use faster vectorised kernel +/// } +/// +/// let cmp = make_comparator(a, b, SortOptions::default())?; +/// let len = a.len().min(b.len()); +/// let values = (0..len).map(|i| cmp(i, i).is_eq()).collect(); +/// let nulls = NullBuffer::union(a.nulls(), b.nulls()); +/// Ok(BooleanArray::new(values, nulls)) +/// } +/// ```` +pub fn make_comparator( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { use arrow_schema::DataType::*; + macro_rules! primitive_helper { - ($t:ty, $left:expr, $right:expr) => { - Ok(compare_primitive::<$t>($left, $right)) + ($t:ty, $left:expr, $right:expr, $nulls_first:expr) => { + Ok(compare_primitive::<$t>($left, $right, $nulls_first)) }; } downcast_primitive! { - left.data_type(), right.data_type() => (primitive_helper, left, right), - (Boolean, Boolean) => Ok(compare_boolean(left, right)), - (Utf8, Utf8) => Ok(compare_bytes::(left, right)), - (LargeUtf8, LargeUtf8) => Ok(compare_bytes::(left, right)), - (Binary, Binary) => Ok(compare_bytes::(left, right)), - (LargeBinary, LargeBinary) => Ok(compare_bytes::(left, right)), + left.data_type(), right.data_type() => (primitive_helper, left, right, opts), + (Boolean, Boolean) => Ok(compare_boolean(left, right, opts)), + (Utf8, Utf8) => Ok(compare_bytes::(left, right, opts)), + (LargeUtf8, LargeUtf8) => Ok(compare_bytes::(left, right, opts)), + (Binary, Binary) => Ok(compare_bytes::(left, right, opts)), + (LargeBinary, LargeBinary) => Ok(compare_bytes::(left, right, opts)), (FixedSizeBinary(_), FixedSizeBinary(_)) => { - let left = left.as_fixed_size_binary().clone(); - let right = right.as_fixed_size_binary().clone(); - Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j)))) + let left = left.as_fixed_size_binary(); + let right = right.as_fixed_size_binary(); + + let l = left.clone(); + let r = right.clone(); + Ok(compare(left, right, opts, move |i, j| { + l.value(i).cmp(r.value(j)) + })) }, + (List(_), List(_)) => compare_list::(left, right, opts), + (LargeList(_), LargeList(_)) => compare_list::(left, right, opts), + (FixedSizeList(_, _), FixedSizeList(_, _)) => compare_fixed_list(left, right, opts), + (Struct(_), Struct(_)) => compare_struct(left, right, opts), (Dictionary(l_key, _), Dictionary(r_key, _)) => { macro_rules! dict_helper { - ($t:ty, $left:expr, $right:expr) => { - compare_dict::<$t>($left, $right) + ($t:ty, $left:expr, $right:expr, $opts: expr) => { + compare_dict::<$t>($left, $right, $opts) }; } downcast_integer! { - l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right), + l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right, opts), _ => unreachable!() } }, @@ -131,7 +379,9 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result>(); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Equal, cmp(3, 4)); @@ -330,7 +580,7 @@ pub mod tests { let d2 = vec!["e", "f", "g", "a"]; let a2 = d2.into_iter().collect::>(); - let cmp = build_compare(&a1, &a2).unwrap(); + let cmp = make_comparator(&a1, &a2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Equal, cmp(0, 3)); @@ -347,7 +597,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -366,7 +616,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -385,7 +635,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -408,7 +658,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); // v1 vs v3 assert_eq!(Ordering::Equal, cmp(0, 3)); // v1 vs v1 @@ -427,7 +677,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -446,7 +696,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -475,7 +725,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -487,7 +737,7 @@ pub mod tests { fn test_bytes_impl() { let offsets = OffsetBuffer::from_lengths([3, 3, 1]); let a = GenericByteArray::::new(offsets, b"abcdefa".into(), None); - let cmp = build_compare(&a, &a).unwrap(); + let cmp = make_comparator(&a, &a, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Greater, cmp(0, 2)); @@ -501,4 +751,157 @@ pub mod tests { test_bytes_impl::(); test_bytes_impl::(); } + + #[test] + fn test_lists() { + let mut a = ListBuilder::new(ListBuilder::new(Int32Builder::new())); + a.extend([ + Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]), + Some(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(1)]), + ]), + Some(vec![]), + ]); + let a = a.finish(); + let mut b = ListBuilder::new(ListBuilder::new(Int32Builder::new())); + b.extend([ + Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]), + Some(vec![ + Some(vec![Some(1), Some(2), None]), + Some(vec![Some(1)]), + ]), + Some(vec![ + Some(vec![Some(1), Some(2), Some(3), Some(4)]), + Some(vec![Some(1)]), + ]), + None, + ]); + let b = b.finish(); + + let opts = SortOptions { + descending: false, + nulls_first: true, + }; + let cmp = make_comparator(&a, &b, opts).unwrap(); + assert_eq!(cmp(0, 0), Ordering::Equal); + assert_eq!(cmp(0, 1), Ordering::Less); + assert_eq!(cmp(0, 2), Ordering::Less); + assert_eq!(cmp(1, 2), Ordering::Less); + assert_eq!(cmp(1, 3), Ordering::Greater); + assert_eq!(cmp(2, 0), Ordering::Less); + + let opts = SortOptions { + descending: true, + nulls_first: true, + }; + let cmp = make_comparator(&a, &b, opts).unwrap(); + assert_eq!(cmp(0, 0), Ordering::Equal); + assert_eq!(cmp(0, 1), Ordering::Less); + assert_eq!(cmp(0, 2), Ordering::Less); + assert_eq!(cmp(1, 2), Ordering::Greater); + assert_eq!(cmp(1, 3), Ordering::Greater); + assert_eq!(cmp(2, 0), Ordering::Greater); + + let opts = SortOptions { + descending: true, + nulls_first: false, + }; + let cmp = make_comparator(&a, &b, opts).unwrap(); + assert_eq!(cmp(0, 0), Ordering::Equal); + assert_eq!(cmp(0, 1), Ordering::Greater); + assert_eq!(cmp(0, 2), Ordering::Greater); + assert_eq!(cmp(1, 2), Ordering::Greater); + assert_eq!(cmp(1, 3), Ordering::Less); + assert_eq!(cmp(2, 0), Ordering::Greater); + + let opts = SortOptions { + descending: false, + nulls_first: false, + }; + let cmp = make_comparator(&a, &b, opts).unwrap(); + assert_eq!(cmp(0, 0), Ordering::Equal); + assert_eq!(cmp(0, 1), Ordering::Greater); + assert_eq!(cmp(0, 2), Ordering::Greater); + assert_eq!(cmp(1, 2), Ordering::Less); + assert_eq!(cmp(1, 3), Ordering::Less); + assert_eq!(cmp(2, 0), Ordering::Less); + } + + #[test] + fn test_struct() { + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new_list("b", Field::new("item", DataType::Int32, true), true), + ]); + + let a = Int32Array::from(vec![Some(1), Some(2), None, None]); + let mut b = ListBuilder::new(Int32Builder::new()); + b.extend([Some(vec![Some(1), Some(2)]), Some(vec![None]), None, None]); + let b = b.finish(); + + let nulls = Some(NullBuffer::from_iter([true, true, true, false])); + let values = vec![Arc::new(a) as _, Arc::new(b) as _]; + let s1 = StructArray::new(fields.clone(), values, nulls); + + let a = Int32Array::from(vec![None, Some(2), None]); + let mut b = ListBuilder::new(Int32Builder::new()); + b.extend([None, None, Some(vec![])]); + let b = b.finish(); + + let values = vec![Arc::new(a) as _, Arc::new(b) as _]; + let s2 = StructArray::new(fields.clone(), values, None); + + let opts = SortOptions { + descending: false, + nulls_first: true, + }; + let cmp = make_comparator(&s1, &s2, opts).unwrap(); + assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None) + assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None) + assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None) + assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, []) + assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, []) + assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None) + assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None) + + let opts = SortOptions { + descending: true, + nulls_first: true, + }; + let cmp = make_comparator(&s1, &s2, opts).unwrap(); + assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None) + assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None) + assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None) + assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, []) + assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, []) + assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None) + assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None) + + let opts = SortOptions { + descending: true, + nulls_first: false, + }; + let cmp = make_comparator(&s1, &s2, opts).unwrap(); + assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None) + assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None) + assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None) + assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, []) + assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, []) + assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None) + assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None) + + let opts = SortOptions { + descending: false, + nulls_first: false, + }; + let cmp = make_comparator(&s1, &s2, opts).unwrap(); + assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None) + assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None) + assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None) + assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, []) + assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, []) + assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None) + assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None) + } } diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index fe3a1f86ac00..8ae87787d283 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -17,13 +17,13 @@ //! Defines sort kernel for `ArrayRef` -use crate::ord::{build_compare, DynComparator}; +use crate::ord::{make_comparator, DynComparator}; use arrow_array::builder::BufferBuilder; use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; +use arrow_buffer::ArrowNativeType; use arrow_buffer::BooleanBufferBuilder; -use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType}; use arrow_select::take::take; @@ -704,60 +704,21 @@ where } } -type LexicographicalCompareItem = ( - Option, // nulls - DynComparator, // comparator - SortOptions, // sort_option -); - /// A lexicographical comparator that wraps given array data (columns) and can lexicographically compare data /// at given two indices. The lifetime is the same at the data wrapped. pub struct LexicographicalComparator { - compare_items: Vec, + compare_items: Vec, } impl LexicographicalComparator { /// lexicographically compare values at the wrapped columns with given indices. pub fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering { - for (nulls, comparator, sort_option) in &self.compare_items { - let (lhs_valid, rhs_valid) = match nulls { - Some(n) => (n.is_valid(a_idx), n.is_valid(b_idx)), - None => (true, true), - }; - - match (lhs_valid, rhs_valid) { - (true, true) => { - match (comparator)(a_idx, b_idx) { - // equal, move on to next column - Ordering::Equal => continue, - order => { - if sort_option.descending { - return order.reverse(); - } else { - return order; - } - } - } - } - (false, true) => { - return if sort_option.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (true, false) => { - return if sort_option.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - // equal, move on to next column - (false, false) => continue, + for comparator in &self.compare_items { + match comparator(a_idx, b_idx) { + Ordering::Equal => continue, + r => return r, } } - Ordering::Equal } @@ -766,61 +727,16 @@ impl LexicographicalComparator { pub fn try_new(columns: &[SortColumn]) -> Result { let compare_items = columns .iter() - .map(Self::build_compare_item) + .map(|c| { + make_comparator( + c.values.as_ref(), + c.values.as_ref(), + c.options.unwrap_or_default(), + ) + }) .collect::, ArrowError>>()?; Ok(LexicographicalComparator { compare_items }) } - - fn build_compare_item(column: &SortColumn) -> Result { - let values = column.values.as_ref(); - let options = column.options.unwrap_or_default(); - let comparator = match values.data_type() { - DataType::List(_) => Self::build_list_compare(values.as_list::(), options)?, - DataType::LargeList(_) => Self::build_list_compare(values.as_list::(), options)?, - DataType::FixedSizeList(_, _) => { - Self::build_fixed_size_list_compare(values.as_fixed_size_list(), options)? - } - _ => build_compare(values, values)?, - }; - Ok((values.logical_nulls(), comparator, options)) - } - - fn build_list_compare( - array: &GenericListArray, - options: SortOptions, - ) -> Result { - let rank = child_rank(array.values().as_ref(), options)?; - let offsets = array.offsets().clone(); - let cmp = Box::new(move |i: usize, j: usize| { - macro_rules! nth_value { - ($INDEX:expr) => {{ - let end = offsets[$INDEX + 1].as_usize(); - let start = offsets[$INDEX].as_usize(); - &rank[start..end] - }}; - } - Ord::cmp(nth_value!(i), nth_value!(j)) - }); - Ok(cmp) - } - - fn build_fixed_size_list_compare( - array: &FixedSizeListArray, - options: SortOptions, - ) -> Result { - let rank = child_rank(array.values().as_ref(), options)?; - let size = array.value_length() as usize; - let cmp = Box::new(move |i: usize, j: usize| { - macro_rules! nth_value { - ($INDEX:expr) => {{ - let start = $INDEX * size; - &rank[start..start + size] - }}; - } - Ord::cmp(nth_value!(i), nth_value!(j)) - }); - Ok(cmp) - } } #[cfg(test)] @@ -829,7 +745,7 @@ mod tests { use arrow_array::builder::{ FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder, }; - use arrow_buffer::i256; + use arrow_buffer::{i256, NullBuffer}; use half::f16; use rand::rngs::StdRng; use rand::{Rng, RngCore, SeedableRng}; diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 4dc7349ca2d5..8e1285493b0b 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1302,9 +1302,9 @@ mod tests { use arrow_array::builder::*; use arrow_array::types::*; use arrow_array::*; - use arrow_buffer::i256; - use arrow_buffer::Buffer; - use arrow_cast::display::array_value_to_string; + use arrow_buffer::{i256, NullBuffer}; + use arrow_buffer::{Buffer, OffsetBuffer}; + use arrow_cast::display::{ArrayFormatter, FormatOptions}; use arrow_ord::sort::{LexicographicalComparator, SortColumn}; use super::*; @@ -2099,9 +2099,35 @@ mod tests { builder.finish() } + fn generate_struct(len: usize, valid_percent: f64) -> StructArray { + let mut rng = thread_rng(); + let nulls = NullBuffer::from_iter((0..len).map(|_| rng.gen_bool(valid_percent))); + let a = generate_primitive_array::(len, valid_percent); + let b = generate_strings::(len, valid_percent); + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + let values = vec![Arc::new(a) as _, Arc::new(b) as _]; + StructArray::new(fields, values, Some(nulls)) + } + + fn generate_list(len: usize, valid_percent: f64, values: F) -> ListArray + where + F: FnOnce(usize) -> ArrayRef, + { + let mut rng = thread_rng(); + let offsets = OffsetBuffer::::from_lengths((0..len).map(|_| rng.gen_range(0..10))); + let values_len = offsets.last().unwrap().to_usize().unwrap(); + let values = values(values_len); + let nulls = NullBuffer::from_iter((0..len).map(|_| rng.gen_bool(valid_percent))); + let field = Arc::new(Field::new("item", values.data_type().clone(), true)); + ListArray::new(field, offsets, values, Some(nulls)) + } + fn generate_column(len: usize) -> ArrayRef { let mut rng = thread_rng(); - match rng.gen_range(0..10) { + match rng.gen_range(0..14) { 0 => Arc::new(generate_primitive_array::(len, 0.8)), 1 => Arc::new(generate_primitive_array::(len, 0.8)), 2 => Arc::new(generate_primitive_array::(len, 0.8)), @@ -2125,6 +2151,16 @@ mod tests { 0.8, )), 9 => Arc::new(generate_fixed_size_binary(len, 0.8)), + 10 => Arc::new(generate_struct(len, 0.8)), + 11 => Arc::new(generate_list(len, 0.8, |values_len| { + Arc::new(generate_primitive_array::(values_len, 0.8)) + })), + 12 => Arc::new(generate_list(len, 0.8, |values_len| { + Arc::new(generate_strings::(values_len, 0.8)) + })), + 13 => Arc::new(generate_list(len, 0.8, |values_len| { + Arc::new(generate_struct(values_len, 0.8)) + })), _ => unreachable!(), } } @@ -2132,7 +2168,14 @@ mod tests { fn print_row(cols: &[SortColumn], row: usize) -> String { let t: Vec<_> = cols .iter() - .map(|x| array_value_to_string(&x.values, row).unwrap()) + .map(|x| match x.values.is_valid(row) { + true => { + let opts = FormatOptions::default().with_null("NULL"); + let formatter = ArrayFormatter::try_new(x.values.as_ref(), &opts).unwrap(); + formatter.value(row).to_string() + } + false => "NULL".to_string(), + }) .collect(); t.join(",") } diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index b563c320bb6d..242c9148cac4 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -36,4 +36,5 @@ pub use arrow_array::ffi::export_array_into_raw; // --------------------- Array's values comparison --------------------- -pub use arrow_ord::ord::{build_compare, DynComparator}; +#[allow(deprecated)] +pub use arrow_ord::ord::{build_compare, make_comparator, DynComparator};