diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index bee8823d1f59..66b40d5b8eb3 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -799,6 +799,15 @@ pub trait AsArray: private::Sealed { self.as_list_opt().expect("list array") } + /// Downcast this to a [`FixedSizeBinaryArray`] returning `None` if not possible + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray>; + + /// Downcast this to a [`FixedSizeBinaryArray`] panicking if not possible + fn as_fixed_size_binary(&self) -> &FixedSizeBinaryArray { + self.as_fixed_size_binary_opt() + .expect("fixed size binary array") + } + /// Downcast this to a [`FixedSizeListArray`] returning `None` if not possible fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray>; @@ -848,6 +857,10 @@ impl AsArray for dyn Array + '_ { self.as_any().downcast_ref() } + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { + self.as_any().downcast_ref() + } + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> { self.as_any().downcast_ref() } @@ -885,6 +898,10 @@ impl AsArray for ArrayRef { self.as_ref().as_list_opt() } + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { + self.as_ref().as_fixed_size_binary_opt() + } + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> { self.as_ref().as_fixed_size_list_opt() } diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index c3e9e26ec05e..648a7d7afcca 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -23,10 +23,9 @@ use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::BooleanBufferBuilder; -use arrow_buffer::{ArrowNativeType, MutableBuffer, NullBuffer}; -use arrow_data::ArrayData; +use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_data::ArrayDataBuilder; -use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; +use arrow_schema::{ArrowError, DataType}; use arrow_select::take::take; use std::cmp::Ordering; use std::sync::Arc; @@ -181,13 +180,6 @@ where } } -fn cmp(l: T, r: T) -> Ordering -where - T: Ord, -{ - l.cmp(&r) -} - // partition indices into valid and null indices fn partition_validity(array: &dyn Array) -> (Vec, Vec) { match array.null_count() { @@ -204,210 +196,33 @@ fn partition_validity(array: &dyn Array) -> (Vec, Vec) { /// For floating point arrays any NaN values are considered to be greater than any other non-null value. /// `limit` is an option for [partial_sort]. pub fn sort_to_indices( - values: &dyn Array, + array: &dyn Array, options: Option, limit: Option, ) -> Result { let options = options.unwrap_or_default(); - let (v, n) = partition_validity(values); - - Ok(match values.data_type() { - DataType::Decimal128(_, _) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Decimal256(_, _) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Boolean => sort_boolean(values, v, n, &options, limit), - DataType::Int8 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Int16 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Int32 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Int64 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::UInt8 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::UInt16 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::UInt32 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::UInt64 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Float16 => sort_primitive::( - values, - v, - n, - |x, y| x.total_cmp(&y), - &options, - limit, - ), - DataType::Float32 => sort_primitive::( - values, - v, - n, - |x, y| x.total_cmp(&y), - &options, - limit, - ), - DataType::Float64 => sort_primitive::( - values, - v, - n, - |x, y| x.total_cmp(&y), - &options, - limit, - ), - DataType::Date32 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Date64 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Time32(TimeUnit::Second) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Time32(TimeUnit::Millisecond) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Time64(TimeUnit::Microsecond) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Time64(TimeUnit::Nanosecond) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Timestamp(TimeUnit::Second, _) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Interval(IntervalUnit::YearMonth) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Interval(IntervalUnit::DayTime) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Duration(TimeUnit::Second) => { - sort_primitive::(values, v, n, cmp, &options, limit) + let (v, n) = partition_validity(array); + + Ok(downcast_primitive_array! { + array => sort_primitive(array, v, n, options, limit), + DataType::Boolean => sort_boolean(array.as_boolean(), v, n, options, limit), + DataType::Utf8 => sort_bytes(array.as_string::(), v, n, options, limit), + DataType::LargeUtf8 => sort_bytes(array.as_string::(), v, n, options, limit), + DataType::Binary => sort_bytes(array.as_binary::(), v, n, options, limit), + DataType::LargeBinary => sort_bytes(array.as_binary::(), v, n, options, limit), + DataType::FixedSizeBinary(_) => sort_fixed_size_binary(array.as_fixed_size_binary(), v, n, options, limit), + DataType::List(_) => sort_list(array.as_list::(), v, n, options, limit)?, + DataType::LargeList(_) => sort_list(array.as_list::(), v, n, options, limit)?, + DataType::FixedSizeList(_, _) => sort_fixed_size_list(array.as_fixed_size_list(), v, n, options, limit)?, + DataType::Dictionary(_, _) => downcast_dictionary_array!{ + array => sort_dictionary(array, v, n, options, limit)?, + _ => unreachable!() } - DataType::Duration(TimeUnit::Millisecond) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Duration(TimeUnit::Microsecond) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Duration(TimeUnit::Nanosecond) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Utf8 => sort_string::(values, v, n, &options, limit), - DataType::LargeUtf8 => sort_string::(values, v, n, &options, limit), - DataType::List(field) | DataType::FixedSizeList(field, _) => { - match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options, limit), - DataType::Int16 => sort_list::(values, v, n, &options, limit), - DataType::Int32 => sort_list::(values, v, n, &options, limit), - DataType::Int64 => sort_list::(values, v, n, &options, limit), - DataType::UInt8 => sort_list::(values, v, n, &options, limit), - DataType::UInt16 => sort_list::(values, v, n, &options, limit), - DataType::UInt32 => sort_list::(values, v, n, &options, limit), - DataType::UInt64 => sort_list::(values, v, n, &options, limit), - DataType::Float16 => sort_list::(values, v, n, &options, limit), - DataType::Float32 => sort_list::(values, v, n, &options, limit), - DataType::Float64 => sort_list::(values, v, n, &options, limit), - t => { - return Err(ArrowError::ComputeError(format!( - "Sort not supported for list type {t:?}" - ))); - } - } - } - DataType::LargeList(field) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options, limit), - DataType::Int16 => sort_list::(values, v, n, &options, limit), - DataType::Int32 => sort_list::(values, v, n, &options, limit), - DataType::Int64 => sort_list::(values, v, n, &options, limit), - DataType::UInt8 => sort_list::(values, v, n, &options, limit), - DataType::UInt16 => sort_list::(values, v, n, &options, limit), - DataType::UInt32 => sort_list::(values, v, n, &options, limit), - DataType::UInt64 => sort_list::(values, v, n, &options, limit), - DataType::Float16 => sort_list::(values, v, n, &options, limit), - DataType::Float32 => sort_list::(values, v, n, &options, limit), - DataType::Float64 => sort_list::(values, v, n, &options, limit), - t => { - return Err(ArrowError::ComputeError(format!( - "Sort not supported for list type {t:?}" - ))); - } - }, - DataType::Dictionary(_, _) => { - let value_null_first = if options.descending { - // When sorting dictionary in descending order, we take inverse of of null ordering - // when sorting the values. Because if `nulls_first` is true, null must be in front - // of non-null value. As we take the sorted order of value array to sort dictionary - // keys, these null values will be treated as smallest ones and be sorted to the end - // of sorted result. So we set `nulls_first` to false when sorting dictionary value - // array to make them as largest ones, then null values will be put at the beginning - // of sorted dictionary result. - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { - descending: false, - nulls_first: value_null_first, - }); - downcast_dictionary_array! { - values => { - let dict_values = values.values(); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let rank = sorted_rank(&sorted_value_indices); - sort_dictionary(values, &rank, v, n, options, limit) - } - _ => unreachable!(), - } - } - DataType::Binary | DataType::FixedSizeBinary(_) => { - sort_binary::(values, v, n, &options, limit) - } - DataType::LargeBinary => sort_binary::(values, v, n, &options, limit), DataType::RunEndEncoded(run_ends_field, _) => match run_ends_field.data_type() { - DataType::Int16 => sort_run_to_indices::(values, &options, limit), - DataType::Int32 => sort_run_to_indices::(values, &options, limit), - DataType::Int64 => sort_run_to_indices::(values, &options, limit), + DataType::Int16 => sort_run_to_indices::(array, options, limit), + DataType::Int32 => sort_run_to_indices::(array, options, limit), + DataType::Int64 => sort_run_to_indices::(array, options, limit), dt => { return Err(ArrowError::ComputeError(format!( "Invalid run end data type: {dt}" @@ -422,147 +237,76 @@ pub fn sort_to_indices( }) } -/// Sort boolean values -/// -/// when a limit is present, the sort is pair-comparison based as k-select might be more efficient, -/// when the limit is absent, binary partition is used to speed up (which is linear). -/// -/// TODO maybe partition_validity call can be eliminated in this case -/// and [tri-color sort](https://en.wikipedia.org/wiki/Dutch_national_flag_problem) -/// can be used instead. fn sort_boolean( - values: &dyn Array, + values: &BooleanArray, value_indices: Vec, - mut null_indices: Vec, - options: &SortOptions, + null_indices: Vec, + options: SortOptions, limit: Option, ) -> UInt32Array { - let values = values - .as_any() - .downcast_ref::() - .expect("Unable to downcast to boolean array"); - let descending = options.descending; - - let valids_len = value_indices.len(); - let nulls_len = null_indices.len(); - - let mut len = values.len(); - let valids = if let Some(limit) = limit { - len = limit.min(len); - // create tuples that are used for sorting - let mut valids = value_indices - .into_iter() - .map(|index| (index, values.value(index as usize))) - .collect::>(); - - sort_valids(descending, &mut valids, len, cmp); - valids - } else { - // when limit is not present, we have a better way than sorting: we can just partition - // the vec into [false..., true...] or [true..., false...] when descending - // TODO when https://github.com/rust-lang/rust/issues/62543 is merged we can use partition_in_place - let (mut a, b): (Vec<_>, Vec<_>) = value_indices - .into_iter() - .map(|index| (index, values.value(index as usize))) - .partition(|(_, value)| *value == descending); - a.extend(b); - if descending { - null_indices.reverse(); - } - a - }; - - let nulls = null_indices; - - // collect results directly into a buffer instead of a vec to avoid another aligned allocation - let result_capacity = len * std::mem::size_of::(); - let mut result = MutableBuffer::new(result_capacity); - // sets len to capacity so we can access the whole buffer as a typed slice - result.resize(result_capacity, 0); - let result_slice: &mut [u32] = result.typed_data_mut(); - - if options.nulls_first { - let size = nulls_len.min(len); - result_slice[0..size].copy_from_slice(&nulls[0..size]); - if nulls_len < len { - insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); - } - } else { - // nulls last - let size = valids.len().min(len); - insert_valid_values(result_slice, 0, &valids[0..size]); - if len > size { - result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); - } - } - - let result_data = unsafe { - ArrayData::new_unchecked( - DataType::UInt32, - len, - Some(0), - None, - 0, - vec![result.into()], - vec![], - ) - }; + let mut valids = value_indices + .into_iter() + .map(|index| (index, values.value(index as usize))) + .collect::>(); + sort_impl(options, &mut valids, &null_indices, limit, |a, b| a.cmp(&b)).into() +} - UInt32Array::from(result_data) +fn sort_primitive( + values: &PrimitiveArray, + value_indices: Vec, + nulls: Vec, + options: SortOptions, + limit: Option, +) -> UInt32Array { + let mut valids = value_indices + .into_iter() + .map(|index| (index, values.value(index as usize))) + .collect::>(); + sort_impl(options, &mut valids, &nulls, limit, T::Native::compare).into() } -/// Sort primitive values -fn sort_primitive( - values: &dyn Array, +fn sort_bytes( + values: &GenericByteArray, value_indices: Vec, - null_indices: Vec, - cmp: F, - options: &SortOptions, + nulls: Vec, + options: SortOptions, limit: Option, -) -> UInt32Array -where - T: ArrowPrimitiveType, - T::Native: PartialOrd, - F: Fn(T::Native, T::Native) -> Ordering, -{ - // create tuples that are used for sorting - let valids = { - let values = values.as_primitive::(); - value_indices - .into_iter() - .map(|index| (index, values.value(index as usize))) - .collect::>() - }; - sort_primitive_inner(values.len(), null_indices, cmp, options, limit, valids) +) -> UInt32Array { + let mut valids = value_indices + .into_iter() + .map(|index| (index, values.value(index as usize).as_ref())) + .collect::>(); + + sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into() } -/// Given a list of indices that yield a sorted order, returns the ordered -/// rank of each index -/// -/// e.g. [2, 4, 3, 1, 0] -> [4, 3, 0, 2, 1] -fn sorted_rank(sorted_value_indices: &UInt32Array) -> Vec { - assert_eq!(sorted_value_indices.null_count(), 0); - let sorted_indices = sorted_value_indices.values(); - let mut out: Vec<_> = vec![0_u32; sorted_indices.len()]; - for (ix, val) in sorted_indices.iter().enumerate() { - out[*val as usize] = ix as u32; - } - out +fn sort_fixed_size_binary( + values: &FixedSizeBinaryArray, + value_indices: Vec, + nulls: Vec, + options: SortOptions, + limit: Option, +) -> UInt32Array { + let mut valids = value_indices + .iter() + .copied() + .map(|index| (index, values.value(index as usize))) + .collect::>(); + sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into() } -/// Sort dictionary given the sorted rank of each key fn sort_dictionary( dict: &DictionaryArray, - rank: &[u32], value_indices: Vec, null_indices: Vec, options: SortOptions, limit: Option, -) -> UInt32Array { +) -> Result { let keys: &PrimitiveArray = dict.keys(); + let rank = child_rank(dict.values().as_ref(), options)?; // create tuples that are used for sorting - let valids = value_indices + let mut valids = value_indices .into_iter() .map(|index| { let key: K::Native = keys.value(index as usize); @@ -570,83 +314,100 @@ fn sort_dictionary( }) .collect::>(); - sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, &options, limit, valids) + Ok(sort_impl(options, &mut valids, &null_indices, limit, |a, b| a.cmp(&b)).into()) } -// sort is instantiated a lot so we only compile this inner version for each native type -fn sort_primitive_inner( - value_len: usize, - nulls: Vec, - cmp: F, - options: &SortOptions, +fn sort_list( + array: &GenericListArray, + value_indices: Vec, + null_indices: Vec, + options: SortOptions, limit: Option, - mut valids: Vec<(u32, T)>, -) -> UInt32Array -where - T: ArrowNativeType, - T: PartialOrd, - F: Fn(T, T) -> Ordering, -{ - let valids_len = valids.len(); - let nulls_len = nulls.len(); - let mut len = value_len; +) -> Result { + let rank = child_rank(array.values().as_ref(), options)?; + let offsets = array.value_offsets(); + let mut valids = value_indices + .into_iter() + .map(|index| { + let end = offsets[index as usize + 1].as_usize(); + let start = offsets[index as usize].as_usize(); + (index, &rank[start..end]) + }) + .collect::>(); + Ok(sort_impl(options, &mut valids, &null_indices, limit, Ord::cmp).into()) +} - if let Some(limit) = limit { - len = limit.min(len); - } +fn sort_fixed_size_list( + array: &FixedSizeListArray, + value_indices: Vec, + null_indices: Vec, + options: SortOptions, + limit: Option, +) -> Result { + let rank = child_rank(array.values().as_ref(), options)?; + let size = array.value_length() as usize; + let mut valids = value_indices + .into_iter() + .map(|index| { + let start = index as usize * size; + (index, &rank[start..start + size]) + }) + .collect::>(); + Ok(sort_impl(options, &mut valids, &null_indices, limit, Ord::cmp).into()) +} - sort_valids(options.descending, &mut valids, len, cmp); +#[inline(never)] +fn sort_impl( + options: SortOptions, + valids: &mut [(u32, T)], + nulls: &[u32], + limit: Option, + mut cmp: impl FnMut(T, T) -> Ordering, +) -> Vec { + let v_limit = match (limit, options.nulls_first) { + (Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()), + _ => valids.len(), + }; - // collect results directly into a buffer instead of a vec to avoid another aligned allocation - let result_capacity = len * std::mem::size_of::(); - let mut result = MutableBuffer::new(result_capacity); - // sets len to capacity so we can access the whole buffer as a typed slice - result.resize(result_capacity, 0); - let result_slice: &mut [u32] = result.typed_data_mut(); + match options.descending { + false => sort_unstable_by(valids, v_limit, |a, b| cmp(a.1, b.1)), + true => sort_unstable_by(valids, v_limit, |a, b| cmp(a.1, b.1).reverse()), + } - if options.nulls_first { - let size = nulls_len.min(len); - result_slice[0..size].copy_from_slice(&nulls[0..size]); - if nulls_len < len { - insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); + let len = valids.len() + nulls.len(); + let limit = limit.unwrap_or(len).min(len); + let mut out = Vec::with_capacity(len); + match options.nulls_first { + true => { + out.extend_from_slice(&nulls[..nulls.len().min(limit)]); + let remaining = limit - out.len(); + out.extend(valids.iter().map(|x| x.0).take(remaining)); } - } else { - // nulls last - let size = valids.len().min(len); - insert_valid_values(result_slice, 0, &valids[0..size]); - if len > size { - result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); + false => { + out.extend(valids.iter().map(|x| x.0).take(limit)); + let remaining = limit - out.len(); + out.extend_from_slice(&nulls[..remaining]) } } - - let result_data = unsafe { - ArrayData::new_unchecked( - DataType::UInt32, - len, - Some(0), - None, - 0, - vec![result.into()], - vec![], - ) - }; - - UInt32Array::from(result_data) + out } -// insert valid and nan values in the correct order depending on the descending flag -fn insert_valid_values(result_slice: &mut [u32], offset: usize, valids: &[(u32, T)]) { - let valids_len = valids.len(); - // helper to append the index part of the valid tuples - let append_valids = move |dst_slice: &mut [u32]| { - debug_assert_eq!(dst_slice.len(), valids_len); - dst_slice - .iter_mut() - .zip(valids.iter()) - .for_each(|(dst, src)| *dst = src.0) - }; +/// Computes the rank for a set of child values +fn child_rank(values: &dyn Array, options: SortOptions) -> Result, ArrowError> { + // If parent sort order is descending we need to invert the value of nulls_first so that + // when the parent is sorted based on the produced ranks, nulls are still ordered correctly + let value_options = Some(SortOptions { + descending: false, + nulls_first: options.nulls_first != options.descending, + }); - append_valids(&mut result_slice[offset..offset + valids.len()]); + let sorted_value_indices = sort_to_indices(values, value_options, None)?; + let sorted_indices = sorted_value_indices.values(); + let mut out: Vec<_> = vec![0_u32; sorted_indices.len()]; + for (ix, val) in sorted_indices.iter().enumerate() { + out[*val as usize] = ix as u32; + } + Ok(out) } // Sort run array and return sorted run array. @@ -737,7 +498,7 @@ fn sort_run_downcasted( // encoded back to run array. fn sort_run_to_indices( values: &dyn Array, - options: &SortOptions, + options: SortOptions, limit: Option, ) -> UInt32Array { let run_array = values.as_any().downcast_ref::>().unwrap(); @@ -752,7 +513,7 @@ fn sort_run_to_indices( let consume_runs = |run_length, logical_start| { result.extend(logical_start as u32..(logical_start + run_length) as u32); }; - sort_run_inner(run_array, Some(*options), output_len, consume_runs); + sort_run_inner(run_array, Some(options), output_len, consume_runs); UInt32Array::from(result) } @@ -834,200 +595,6 @@ where (values_indices, run_values) } -/// Sort strings -fn sort_string( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array { - let values = values - .as_any() - .downcast_ref::>() - .unwrap(); - - sort_string_helper( - values, - value_indices, - null_indices, - options, - limit, - |array, idx| array.value(idx as usize), - ) -} - -/// shared implementation between dictionary encoded and plain string arrays -#[inline] -fn sort_string_helper<'a, A: Array, F>( - values: &'a A, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, - value_fn: F, -) -> UInt32Array -where - F: Fn(&'a A, u32) -> &str, -{ - let mut valids = value_indices - .into_iter() - .map(|index| (index, value_fn(values, index))) - .collect::>(); - let mut nulls = null_indices; - let descending = options.descending; - let mut len = values.len(); - - if let Some(limit) = limit { - len = limit.min(len); - } - - sort_valids(descending, &mut valids, len, cmp); - // collect the order of valid tuplies - let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); - - if options.nulls_first { - nulls.append(&mut valid_indices); - nulls.truncate(len); - UInt32Array::from(nulls) - } else { - // no need to sort nulls as they are in the correct order already - valid_indices.append(&mut nulls); - valid_indices.truncate(len); - UInt32Array::from(valid_indices) - } -} - -fn sort_list( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array -where - S: OffsetSizeTrait, -{ - sort_list_inner::(values, value_indices, null_indices, options, limit) -} - -fn sort_list_inner( - values: &dyn Array, - value_indices: Vec, - mut null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array -where - S: OffsetSizeTrait, -{ - let mut valids: Vec<(u32, ArrayRef)> = values - .as_any() - .downcast_ref::() - .map_or_else( - || { - let values = as_generic_list_array::(values); - value_indices - .iter() - .copied() - .map(|index| (index, values.value(index as usize))) - .collect() - }, - |values| { - value_indices - .iter() - .copied() - .map(|index| (index, values.value(index as usize))) - .collect() - }, - ); - - let mut len = values.len(); - let descending = options.descending; - - if let Some(limit) = limit { - len = limit.min(len); - } - sort_valids_array(descending, &mut valids, &mut null_indices, len); - - let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); - if options.nulls_first { - null_indices.append(&mut valid_indices); - null_indices.truncate(len); - UInt32Array::from(null_indices) - } else { - valid_indices.append(&mut null_indices); - valid_indices.truncate(len); - UInt32Array::from(valid_indices) - } -} - -fn sort_binary( - values: &dyn Array, - value_indices: Vec, - mut null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array -where - S: OffsetSizeTrait, -{ - let mut valids: Vec<(u32, &[u8])> = values - .as_any() - .downcast_ref::() - .map_or_else( - || { - let values = as_generic_binary_array::(values); - value_indices - .iter() - .copied() - .map(|index| (index, values.value(index as usize))) - .collect() - }, - |values| { - value_indices - .iter() - .copied() - .map(|index| (index, values.value(index as usize))) - .collect() - }, - ); - - let mut len = values.len(); - let descending = options.descending; - - if let Some(limit) = limit { - len = limit.min(len); - } - - sort_valids(descending, &mut valids, len, cmp); - - let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); - if options.nulls_first { - null_indices.append(&mut valid_indices); - null_indices.truncate(len); - UInt32Array::from(null_indices) - } else { - valid_indices.append(&mut null_indices); - valid_indices.truncate(len); - UInt32Array::from(valid_indices) - } -} - -/// Compare two `Array`s based on the ordering defined in [build_compare] -fn cmp_array(a: &dyn Array, b: &dyn Array) -> Ordering { - let cmp_op = build_compare(a, b).unwrap(); - let length = a.len().max(b.len()); - - for i in 0..length { - let result = cmp_op(i, i); - if result != Ordering::Equal { - return result; - } - } - Ordering::Equal -} - /// One column to be used in lexicographical sort #[derive(Clone, Debug)] pub struct SortColumn { @@ -1146,8 +713,10 @@ pub fn partial_sort(v: &mut [T], limit: usize, mut is_less: F) where F: FnMut(&T, &T) -> Ordering, { - let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less); - before.sort_unstable_by(is_less); + if let Some(n) = limit.checked_sub(1) { + let (before, _mid, _after) = v.select_nth_unstable_by(n, &mut is_less); + before.sort_unstable_by(is_less); + } } type LexicographicalCompareItem<'a> = ( @@ -1228,42 +797,6 @@ impl LexicographicalComparator<'_> { } } -fn sort_valids( - descending: bool, - valids: &mut [(u32, T)], - len: usize, - mut cmp: impl FnMut(T, T) -> Ordering, -) where - T: ?Sized + Copy, -{ - let valids_len = valids.len(); - if !descending { - sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1)); - } else { - sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1).reverse()); - } -} - -fn sort_valids_array( - descending: bool, - valids: &mut [(u32, ArrayRef)], - nulls: &mut [T], - len: usize, -) { - let valids_len = valids.len(); - if !descending { - sort_unstable_by(valids, len.min(valids_len), |a, b| { - cmp_array(a.1.as_ref(), b.1.as_ref()) - }); - } else { - sort_unstable_by(valids, len.min(valids_len), |a, b| { - cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() - }); - // reverse to keep a stable ordering - nulls.reverse(); - } -} - #[cfg(test)] mod tests { use super::*; @@ -1980,7 +1513,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 3, 1, 4, 5, 0], + vec![2, 3, 1, 4, 0, 5], ); // boolean, descending, nulls first @@ -1991,7 +1524,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 3, 1, 4], + vec![0, 5, 2, 3, 1, 4], ); // boolean, descending, nulls first, limit diff --git a/arrow/benches/sort_kernel.rs b/arrow/benches/sort_kernel.rs index 3a3ce4462dff..63e10e0528ba 100644 --- a/arrow/benches/sort_kernel.rs +++ b/arrow/benches/sort_kernel.rs @@ -67,23 +67,37 @@ fn bench_sort_to_indices(array: &dyn Array, limit: Option) { fn add_benchmark(c: &mut Criterion) { let arr = create_primitive_array::(2usize.pow(10), 0.0); - c.bench_function("sort i64 2^10", |b| b.iter(|| bench_sort(&arr))); - - let arr = create_primitive_array::(2usize.pow(12), 0.5); - c.bench_function("sort i64 2^12", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 2^10", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 to indices 2^10", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); let arr = create_primitive_array::(2usize.pow(12), 0.0); - c.bench_function("sort i64 nulls 2^10", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 2^12", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); + + let arr = create_primitive_array::(2usize.pow(10), 0.5); + c.bench_function("sort i32 nulls 2^10", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 nulls to indices 2^10", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); let arr = create_primitive_array::(2usize.pow(12), 0.5); - c.bench_function("sort i64 nulls 2^12", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 nulls 2^12", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 nulls to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); let arr = create_f32_array(2_usize.pow(12), false); + c.bench_function("sort f32 2^12", |b| b.iter(|| bench_sort(&arr))); c.bench_function("sort f32 to indices 2^12", |b| { b.iter(|| bench_sort_to_indices(&arr, None)) }); let arr = create_f32_array(2usize.pow(12), true); + c.bench_function("sort f32 nulls 2^12", |b| b.iter(|| bench_sort(&arr))); c.bench_function("sort f32 nulls to indices 2^12", |b| { b.iter(|| bench_sort_to_indices(&arr, None)) });