From 77f429fe00e2be1cd7c80009054d74e53230dabd Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Sun, 1 Aug 2021 11:51:02 +0800 Subject: [PATCH 01/13] Make merge_sort_slices MergeSortSlices public --- src/compute/merge_sort/mod.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index 3dad0afdcb2..ab8d91f789d 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -234,7 +234,7 @@ fn recursive_merge_sort(slices: &[&[MergeSlice]], comparator: &Comparator) -> Ve // An iterator adapter that merge-sorts two iterators of `MergeSlice` into a single `MergeSlice` // such that the resulting `MergeSlice`s are ordered according to `comparator`. -struct MergeSortSlices<'a, L, R> +pub struct MergeSortSlices<'a, L, R> where L: Iterator, R: Iterator, @@ -426,7 +426,11 @@ where /// Given two iterators of slices representing two sets of sorted [`Array`]s, and a `comparator` bound to those [`Array`]s, /// returns a new iterator of slices denoting how to `take` slices from each of the arrays such that the resulting /// array is sorted according to `comparator` -fn merge_sort_slices<'a, L: Iterator, R: Iterator>( +pub fn merge_sort_slices< + 'a, + L: Iterator, + R: Iterator, +>( lhs: L, rhs: R, comparator: &'a Comparator, @@ -439,7 +443,7 @@ type Comparator<'a> = Box Ordering + 'a>; type IsValid<'a> = Box bool + 'a>; /// returns a comparison function between any two arrays of each pair of arrays, according to `SortOptions`. -fn build_comparator<'a>( +pub fn build_comparator<'a>( pairs: &'a [(&'a [&'a dyn Array], &SortOptions)], ) -> Result> { // prepare the comparison function of _values_ between all pairs of arrays From ea2f82d35fc4973117873d814e83831def1ccaff Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Sun, 1 Aug 2021 13:15:46 +0800 Subject: [PATCH 02/13] Make MergeSlice pub --- src/compute/merge_sort/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index ab8d91f789d..e3c2d4feddd 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -77,7 +77,7 @@ use crate::error::Result; /// This representation is useful when building arrays in memory as it allows to memcopy slices of arrays. /// This is particularly useful in merge-sort because sorted arrays (passed to the merge-sort) are more likely /// to have contiguous blocks of sorted elements (than by random). -type MergeSlice = (usize, usize, usize); +pub type MergeSlice = (usize, usize, usize); /// Takes N arrays together through `slices` under the assumption that the slices have /// a total coverage of the arrays. From 5677dd3dfb2dd44ee1a7e3502e85f79a4f617e64 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Mon, 2 Aug 2021 22:42:18 +0800 Subject: [PATCH 03/13] improve partial sort --- src/compute/sort/common.rs | 71 +++++++++++++-------------- src/compute/sort/primitive/indices.rs | 18 +++---- src/compute/sort/utf8.rs | 23 ++++++--- 3 files changed, 58 insertions(+), 54 deletions(-) diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index c36763831bb..a72fdde2413 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -11,40 +11,39 @@ use super::SortOptions; /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn k_element_sort_inner( +fn k_element_sort_inner( indices: &mut [I], - get: G, + values: &[T], descending: bool, limit: usize, mut cmp: F, ) where - G: Fn(usize) -> T, F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - let compare = |lhs: &I, rhs: &I| { - let lhs = get(lhs.to_usize()); - let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs).reverse() + let compare = |lhs: &I, rhs: &I| unsafe { + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); + cmp(lhs, rhs).reverse() }; let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &I, rhs: &I| { - let lhs = get(lhs.to_usize()); - let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs).reverse() + let compare = |lhs: &I, rhs: &I| unsafe { + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); + cmp(lhs, rhs).reverse() }; before.sort_unstable_by(compare); } else { - let compare = |lhs: &I, rhs: &I| { - let lhs = get(lhs.to_usize()); - let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs) + let compare = |lhs: &I, rhs: &I| unsafe { + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); + cmp(lhs, rhs) }; let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &I, rhs: &I| { - let lhs = get(lhs.to_usize()); - let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs) + let compare = |lhs: &I, rhs: &I| unsafe { + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); + cmp(lhs, rhs) }; before.sort_unstable_by(compare); } @@ -55,32 +54,31 @@ fn k_element_sort_inner( /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn sort_unstable_by( +fn sort_unstable_by( indices: &mut [I], - get: G, + values: &[T], mut cmp: F, descending: bool, limit: usize, ) where I: Index, - G: Fn(usize) -> T, F: FnMut(&T, &T) -> std::cmp::Ordering, { if limit != indices.len() { - return k_element_sort_inner(indices, get, descending, limit, cmp); + return k_element_sort_inner(indices, values, descending, limit, cmp); } if descending { - indices.sort_unstable_by(|lhs, rhs| { - let lhs = get(lhs.to_usize()); - let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs).reverse() + indices.sort_unstable_by(|lhs, rhs| unsafe { + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); + cmp(lhs, rhs).reverse() }) } else { - indices.sort_unstable_by(|lhs, rhs| { - let lhs = get(lhs.to_usize()); - let rhs = get(rhs.to_usize()); - cmp(&lhs, &rhs) + indices.sort_unstable_by(|lhs, rhs| unsafe { + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); + cmp(lhs, rhs) }) } } @@ -90,9 +88,9 @@ fn sort_unstable_by( /// * `get` is only called for `0 <= i < length` /// * `cmp` is only called from the co-domain of `get`. #[inline] -pub(super) fn indices_sorted_unstable_by( +pub(super) fn indices_sorted_unstable_by( validity: &Option, - get: G, + values: &[T], cmp: F, length: usize, options: &SortOptions, @@ -100,7 +98,6 @@ pub(super) fn indices_sorted_unstable_by( ) -> PrimitiveArray where I: Index, - G: Fn(usize) -> T, F: Fn(&T, &T) -> std::cmp::Ordering, { let descending = options.descending; @@ -136,7 +133,7 @@ where // limit is by construction < indices.len() let limit = limit - validity.null_count(); let indices = &mut indices.as_mut_slice()[validity.null_count()..]; - sort_unstable_by(indices, get, cmp, options.descending, limit) + sort_unstable_by(indices, values, cmp, options.descending, limit) } } else { let last_valid_index = length - validity.null_count(); @@ -157,7 +154,7 @@ where // limit is by construction <= values.len() let limit = limit.min(last_valid_index); let indices = &mut indices.as_mut_slice()[..last_valid_index]; - sort_unstable_by(indices, get, cmp, options.descending, limit); + sort_unstable_by(indices, values, cmp, options.descending, limit); } indices.truncate(limit); @@ -174,7 +171,7 @@ where // Soundness: // indices are by construction `< values.len()` // limit is by construction `< values.len()` - sort_unstable_by(&mut indices, get, cmp, descending, limit); + sort_unstable_by(&mut indices, values, cmp, descending, limit); indices.truncate(limit); indices.shrink_to_fit(); diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index e7c79947b4c..022c7c842ee 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -18,16 +18,14 @@ where T: NativeType, F: Fn(&T, &T) -> std::cmp::Ordering, { - unsafe { - common::indices_sorted_unstable_by( - array.validity(), - |x: usize| *array.values().as_slice().get_unchecked(x), - cmp, - array.len(), - options, - limit, - ) - } + common::indices_sorted_unstable_by( + array.validity(), + array.values().as_slice(), + cmp, + array.len(), + options, + limit, + ) } #[cfg(test)] diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs index 20f2d65fea7..ad5972b0391 100644 --- a/src/compute/sort/utf8.rs +++ b/src/compute/sort/utf8.rs @@ -9,9 +9,13 @@ pub(super) fn indices_sorted_unstable_by( options: &SortOptions, limit: Option, ) -> PrimitiveArray { - let get = |idx| unsafe { array.value_unchecked(idx as usize) }; + let values = unsafe { + (0..array.len()) + .map(|idx| array.value_unchecked(idx as usize)) + .collect::>() + }; let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); - common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) + common::indices_sorted_unstable_by(array.validity(), &values, cmp, array.len(), options, limit) } pub(super) fn indices_sorted_unstable_by_dictionary( @@ -27,11 +31,16 @@ pub(super) fn indices_sorted_unstable_by_dictionary>() .unwrap(); - let get = |idx| unsafe { - let index = keys.value_unchecked(idx as usize); - // Note: there is no check that the keys are within bounds of the dictionary. - dict.value(index.to_usize().unwrap()) + let values = unsafe { + (0..array.len()) + .map(|idx| { + let index = keys.value_unchecked(idx as usize); + // Note: there is no check that the keys are within bounds of the dictionary. + dict.value(index.to_usize().unwrap()) + }) + .collect::>() }; + let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); - common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) + common::indices_sorted_unstable_by(array.validity(), &values, cmp, array.len(), options, limit) } From b701b94474c8eb56b2a76277cb36e15a83b453c8 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Tue, 3 Aug 2021 07:14:11 +0800 Subject: [PATCH 04/13] improve from_usize --- benches/growable.rs | 8 ++--- benches/take_kernels.rs | 2 +- src/compute/sort/common.rs | 67 +++++++++++++++++++------------------- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/benches/growable.rs b/benches/growable.rs index bedb89bd5a4..e0a4d2426ae 100644 --- a/benches/growable.rs +++ b/benches/growable.rs @@ -14,7 +14,7 @@ fn add_benchmark(c: &mut Criterion) { let i32_array = create_primitive_array::(1026 * 10, DataType::Int32, 0.0); c.bench_function("growable::primitive::non_null::non_null", |b| { b.iter(|| { - let mut a = GrowablePrimitive::new(&[&i32_array], false, 1026 * 10); + let mut a = GrowablePrimitive::new(vec![&i32_array], false, 1026 * 10); values .clone() .into_iter() @@ -25,7 +25,7 @@ fn add_benchmark(c: &mut Criterion) { let i32_array = create_primitive_array::(1026 * 10, DataType::Int32, 0.0); c.bench_function("growable::primitive::non_null::null", |b| { b.iter(|| { - let mut a = GrowablePrimitive::new(&[&i32_array], true, 1026 * 10); + let mut a = GrowablePrimitive::new(vec![&i32_array], true, 1026 * 10); values.clone().into_iter().for_each(|start| { if start % 2 == 0 { a.extend_validity(10); @@ -41,7 +41,7 @@ fn add_benchmark(c: &mut Criterion) { let values = values.collect::>(); c.bench_function("growable::primitive::null::non_null", |b| { b.iter(|| { - let mut a = GrowablePrimitive::new(&[&i32_array], false, 1026 * 10); + let mut a = GrowablePrimitive::new(vec![&i32_array], false, 1026 * 10); values .clone() .into_iter() @@ -50,7 +50,7 @@ fn add_benchmark(c: &mut Criterion) { }); c.bench_function("growable::primitive::null::null", |b| { b.iter(|| { - let mut a = GrowablePrimitive::new(&[&i32_array], true, 1026 * 10); + let mut a = GrowablePrimitive::new(vec![&i32_array], true, 1026 * 10); values.clone().into_iter().for_each(|start| { if start % 2 == 0 { a.extend_validity(10); diff --git a/benches/take_kernels.rs b/benches/take_kernels.rs index 2ab746b443c..9303a719f34 100644 --- a/benches/take_kernels.rs +++ b/benches/take_kernels.rs @@ -35,7 +35,7 @@ fn create_random_index(size: usize, null_density: f32) -> PrimitiveArray { (0..size) .map(|_| { if rng.gen::() > null_density { - let value = rng.gen_range::(0i32, size as i32); + let value = rng.gen_range::(0i32..size as i32); Some(value) } else { None diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index a72fdde2413..25615c6150a 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -11,8 +11,8 @@ use super::SortOptions; /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn k_element_sort_inner( - indices: &mut [I], +fn k_element_sort_inner( + indices: &mut [usize], values: &[T], descending: bool, limit: usize, @@ -21,28 +21,28 @@ fn k_element_sort_inner( F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - let compare = |lhs: &I, rhs: &I| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); + let compare = |lhs: &usize, rhs: &usize| unsafe { + let lhs = values.get_unchecked(*lhs); + let rhs = values.get_unchecked(*rhs); cmp(lhs, rhs).reverse() }; let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &I, rhs: &I| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); + let compare = |lhs: &usize, rhs: &usize| unsafe { + let lhs = values.get_unchecked(*lhs); + let rhs = values.get_unchecked(*rhs); cmp(lhs, rhs).reverse() }; before.sort_unstable_by(compare); } else { - let compare = |lhs: &I, rhs: &I| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); + let compare = |lhs: &usize, rhs: &usize| unsafe { + let lhs = values.get_unchecked(*lhs); + let rhs = values.get_unchecked(*rhs); cmp(lhs, rhs) }; let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &I, rhs: &I| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); + let compare = |lhs: &usize, rhs: &usize| unsafe { + let lhs = values.get_unchecked(*lhs); + let rhs = values.get_unchecked(*rhs); cmp(lhs, rhs) }; before.sort_unstable_by(compare); @@ -54,14 +54,13 @@ fn k_element_sort_inner( /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn sort_unstable_by( - indices: &mut [I], +fn sort_unstable_by( + indices: &mut [usize], values: &[T], mut cmp: F, descending: bool, limit: usize, ) where - I: Index, F: FnMut(&T, &T) -> std::cmp::Ordering, { if limit != indices.len() { @@ -70,14 +69,14 @@ fn sort_unstable_by( if descending { indices.sort_unstable_by(|lhs, rhs| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); + let lhs = values.get_unchecked(*lhs); + let rhs = values.get_unchecked(*rhs); cmp(lhs, rhs).reverse() }) } else { indices.sort_unstable_by(|lhs, rhs| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); + let lhs = values.get_unchecked(*lhs); + let rhs = values.get_unchecked(*rhs); cmp(lhs, rhs) }) } @@ -107,7 +106,7 @@ where let limit = limit.min(length); let indices = if let Some(validity) = validity { - let mut indices = MutableBuffer::::from_len_zeroed(length); + let mut indices = vec![length; 0usize]; if options.nulls_first { let mut nulls = 0; @@ -117,10 +116,10 @@ where .zip(0..length) .for_each(|(is_valid, index)| { if is_valid { - indices[validity.null_count() + valids] = I::from_usize(index).unwrap(); + indices[validity.null_count() + valids] = index; valids += 1; } else { - indices[nulls] = I::from_usize(index).unwrap(); + indices[nulls] = index; nulls += 1; } }); @@ -141,10 +140,10 @@ where let mut valids = 0; validity.iter().zip(0..length).for_each(|(x, index)| { if x { - indices[valids] = I::from_usize(index).unwrap(); + indices[valids] = index; valids += 1; } else { - indices[last_valid_index + nulls] = I::from_usize(index).unwrap(); + indices[last_valid_index + nulls] = index; nulls += 1; } }); @@ -162,12 +161,7 @@ where indices } else { - let mut indices = unsafe { - MutableBuffer::from_trusted_len_iter_unchecked( - (0..length).map(|x| I::from_usize(x).unwrap()), - ) - }; - + let mut indices = vec![length; 0usize]; // Soundness: // indices are by construction `< values.len()` // limit is by construction `< values.len()` @@ -175,8 +169,13 @@ where indices.truncate(limit); indices.shrink_to_fit(); - indices }; - PrimitiveArray::::from_data(I::DATA_TYPE, indices.into(), None) + let mut buffer_indices = MutableBuffer::::with_capacity(indices.len()); + unsafe { + buffer_indices.extend_from_trusted_len_iter_unchecked( + indices.iter().map(|c| I::from_usize(*c).unwrap()), + ); + } + PrimitiveArray::::from_data(I::DATA_TYPE, buffer_indices.into(), None) } From d351b1ef3b7718be6538446c5dbc7d69447e45c3 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Tue, 3 Aug 2021 07:22:31 +0800 Subject: [PATCH 05/13] improve from_usize --- src/compute/sort/common.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index 25615c6150a..bf1b3e65127 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -161,12 +161,11 @@ where indices } else { - let mut indices = vec![length; 0usize]; + let mut indices: Vec = (0..length as usize).collect(); // Soundness: // indices are by construction `< values.len()` // limit is by construction `< values.len()` sort_unstable_by(&mut indices, values, cmp, descending, limit); - indices.truncate(limit); indices.shrink_to_fit(); indices From cefb4c94d7011564fd32b106e564d509223424dd Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Tue, 3 Aug 2021 07:44:27 +0800 Subject: [PATCH 06/13] improve reverse --- src/compute/sort/common.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index bf1b3e65127..76cfb1dbd6d 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -24,13 +24,13 @@ fn k_element_sort_inner( let compare = |lhs: &usize, rhs: &usize| unsafe { let lhs = values.get_unchecked(*lhs); let rhs = values.get_unchecked(*rhs); - cmp(lhs, rhs).reverse() + cmp(rhs, lhs) }; let (before, _, _) = indices.select_nth_unstable_by(limit, compare); let compare = |lhs: &usize, rhs: &usize| unsafe { let lhs = values.get_unchecked(*lhs); let rhs = values.get_unchecked(*rhs); - cmp(lhs, rhs).reverse() + cmp(rhs, lhs) }; before.sort_unstable_by(compare); } else { From 1cf0850f83c76be519a39ca2765cb0f9c2d1168f Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Tue, 3 Aug 2021 09:39:00 +0800 Subject: [PATCH 07/13] add std::iter::Step trait to Index --- benches/sort_kernel.rs | 15 +++++- src/array/specification.rs | 13 ++++- src/compute/sort/boolean.rs | 2 +- src/compute/sort/common.rs | 77 +++++++++++++----------------- src/compute/sort/mod.rs | 2 +- src/compute/sort/primitive/sort.rs | 4 +- src/lib.rs | 2 + 7 files changed, 66 insertions(+), 49 deletions(-) diff --git a/benches/sort_kernel.rs b/benches/sort_kernel.rs index 30d89c441cb..e78e4462b04 100644 --- a/benches/sort_kernel.rs +++ b/benches/sort_kernel.rs @@ -19,7 +19,7 @@ extern crate criterion; use criterion::Criterion; -use arrow2::compute::sort::{lexsort, sort, SortColumn, SortOptions}; +use arrow2::compute::sort::{lexsort, sort, sort_to_indices, SortColumn, SortOptions}; use arrow2::util::bench_util::*; use arrow2::{array::*, datatypes::*}; @@ -42,6 +42,15 @@ fn bench_sort(arr_a: &dyn Array) { sort(criterion::black_box(arr_a), &SortOptions::default(), None).unwrap(); } +fn bench_sort_limit(arr_a: &dyn Array) { + let _: PrimitiveArray = sort_to_indices( + criterion::black_box(arr_a), + &SortOptions::default(), + Some(100), + ) + .unwrap(); +} + fn add_benchmark(c: &mut Criterion) { (10..=20).step_by(2).for_each(|log2_size| { let size = 2usize.pow(log2_size); @@ -51,6 +60,10 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_sort(&arr_a)) }); + c.bench_function(&format!("sort-limit 2^{} f32", log2_size), |b| { + b.iter(|| bench_sort_limit(&arr_a)) + }); + let arr_b = create_primitive_array_with_seed::(size, DataType::Float32, 0.0, 43); c.bench_function(&format!("lexsort 2^{} f32", log2_size), |b| { b.iter(|| bench_lexsort(&arr_a, &arr_b)) diff --git a/src/array/specification.rs b/src/array/specification.rs index 70576cb8621..f9a82526b3b 100644 --- a/src/array/specification.rs +++ b/src/array/specification.rs @@ -8,9 +8,12 @@ use crate::{ }; /// Trait describing any type that can be used to index a slot of an array. -pub trait Index: NativeType + NaturalDataType { +pub trait Index: NativeType + NaturalDataType + std::iter::Step { fn to_usize(&self) -> usize; fn from_usize(index: usize) -> Option; + fn is_usize() -> bool { + false + } } /// Trait describing types that can be used as offsets as per Arrow specification. @@ -95,6 +98,10 @@ impl Index for u32 { fn from_usize(value: usize) -> Option { Self::try_from(value).ok() } + + fn is_usize() -> bool { + std::mem::size_of::() == std::mem::size_of::() + } } impl Index for u64 { @@ -107,6 +114,10 @@ impl Index for u64 { fn from_usize(value: usize) -> Option { Self::try_from(value).ok() } + + fn is_usize() -> bool { + std::mem::size_of::() == std::mem::size_of::() + } } #[inline] diff --git a/src/compute/sort/boolean.rs b/src/compute/sort/boolean.rs index 2e7082a5901..bb7faec6efa 100644 --- a/src/compute/sort/boolean.rs +++ b/src/compute/sort/boolean.rs @@ -26,7 +26,7 @@ pub fn sort_boolean( if !descending { valids.sort_by(|a, b| a.1.cmp(&b.1)); } else { - valids.sort_by(|a, b| a.1.cmp(&b.1).reverse()); + valids.sort_by(|a, b| b.1.cmp(&a.1)); // reverse to keep a stable ordering nulls.reverse(); } diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index 76cfb1dbd6d..18ac986516b 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -11,8 +11,8 @@ use super::SortOptions; /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn k_element_sort_inner( - indices: &mut [usize], +fn k_element_sort_inner( + indices: &mut [I], values: &[T], descending: bool, limit: usize, @@ -21,31 +21,21 @@ fn k_element_sort_inner( F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - let compare = |lhs: &usize, rhs: &usize| unsafe { - let lhs = values.get_unchecked(*lhs); - let rhs = values.get_unchecked(*rhs); + let mut compare = |lhs: &I, rhs: &I| unsafe { + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); cmp(rhs, lhs) }; - let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &usize, rhs: &usize| unsafe { - let lhs = values.get_unchecked(*lhs); - let rhs = values.get_unchecked(*rhs); - cmp(rhs, lhs) - }; - before.sort_unstable_by(compare); + let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare); + before.sort_unstable_by(&mut compare); } else { - let compare = |lhs: &usize, rhs: &usize| unsafe { - let lhs = values.get_unchecked(*lhs); - let rhs = values.get_unchecked(*rhs); + let mut compare = |lhs: &I, rhs: &I| unsafe { + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); cmp(lhs, rhs) }; - let (before, _, _) = indices.select_nth_unstable_by(limit, compare); - let compare = |lhs: &usize, rhs: &usize| unsafe { - let lhs = values.get_unchecked(*lhs); - let rhs = values.get_unchecked(*rhs); - cmp(lhs, rhs) - }; - before.sort_unstable_by(compare); + let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare); + before.sort_unstable_by(&mut compare); } } @@ -54,8 +44,8 @@ fn k_element_sort_inner( /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn sort_unstable_by( - indices: &mut [usize], +fn sort_unstable_by( + indices: &mut [I], values: &[T], mut cmp: F, descending: bool, @@ -69,14 +59,14 @@ fn sort_unstable_by( if descending { indices.sort_unstable_by(|lhs, rhs| unsafe { - let lhs = values.get_unchecked(*lhs); - let rhs = values.get_unchecked(*rhs); - cmp(lhs, rhs).reverse() + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); + cmp(rhs, lhs) }) } else { indices.sort_unstable_by(|lhs, rhs| unsafe { - let lhs = values.get_unchecked(*lhs); - let rhs = values.get_unchecked(*rhs); + let lhs = values.get_unchecked(lhs.to_usize()); + let rhs = values.get_unchecked(rhs.to_usize()); cmp(lhs, rhs) }) } @@ -106,8 +96,10 @@ where let limit = limit.min(length); let indices = if let Some(validity) = validity { - let mut indices = vec![length; 0usize]; - + let mut indices = MutableBuffer::::with_capacity(length); + unsafe { + indices.set_len(length); + } if options.nulls_first { let mut nulls = 0; let mut valids = 0; @@ -116,10 +108,10 @@ where .zip(0..length) .for_each(|(is_valid, index)| { if is_valid { - indices[validity.null_count() + valids] = index; + indices[validity.null_count() + valids] = I::from_usize(index).unwrap(); valids += 1; } else { - indices[nulls] = index; + indices[nulls] = I::from_usize(index).unwrap(); nulls += 1; } }); @@ -140,10 +132,10 @@ where let mut valids = 0; validity.iter().zip(0..length).for_each(|(x, index)| { if x { - indices[valids] = index; + indices[valids] = I::from_usize(index).unwrap(); valids += 1; } else { - indices[last_valid_index + nulls] = index; + indices[last_valid_index + nulls] = I::from_usize(index).unwrap(); nulls += 1; } }); @@ -161,7 +153,11 @@ where indices } else { - let mut indices: Vec = (0..length as usize).collect(); + let mut indices = unsafe { + MutableBuffer::from_trusted_len_iter_unchecked( + I::from_usize(0).unwrap()..I::from_usize(length).unwrap(), + ) + }; // Soundness: // indices are by construction `< values.len()` // limit is by construction `< values.len()` @@ -170,11 +166,6 @@ where indices.shrink_to_fit(); indices }; - let mut buffer_indices = MutableBuffer::::with_capacity(indices.len()); - unsafe { - buffer_indices.extend_from_trusted_len_iter_unchecked( - indices.iter().map(|c| I::from_usize(*c).unwrap()), - ); - } - PrimitiveArray::::from_data(I::DATA_TYPE, buffer_indices.into(), None) + + PrimitiveArray::::from_data(I::DATA_TYPE, indices.into(), None) } diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 32cf63166c6..781263f45dc 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -387,7 +387,7 @@ where if !options.descending { valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref())) } else { - valids.sort_by(|a, b| cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()) + valids.sort_by(|a, b| cmp_array(b.1.as_ref(), a.1.as_ref())) } let values = valids.iter().map(|tuple| tuple.0); diff --git a/src/compute/sort/primitive/sort.rs b/src/compute/sort/primitive/sort.rs index 197ea299416..6fcedd80343 100644 --- a/src/compute/sort/primitive/sort.rs +++ b/src/compute/sort/primitive/sort.rs @@ -34,7 +34,7 @@ where F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(x, y).reverse()); + let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(y, x)); before.sort_unstable_by(|x, y| cmp(x, y)); } else { let (before, _, _) = values.select_nth_unstable_by(limit, |x, y| cmp(x, y)); @@ -52,7 +52,7 @@ where } if descending { - values.sort_unstable_by(|x, y| cmp(x, y).reverse()); + values.sort_unstable_by(|x, y| cmp(y, x)); } else { values.sort_unstable_by(cmp); }; diff --git a/src/lib.rs b/src/lib.rs index 060b6278883..1131cd15166 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(step_trait)] + pub mod alloc; pub mod array; pub mod bitmap; From 0ec21227c0a44dd14fbe38bd7b38871ae6a83cd7 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Tue, 3 Aug 2021 10:28:02 +0800 Subject: [PATCH 08/13] add to_vec to MergeSortSlices --- src/buffer/mutable.rs | 13 +++++++++++++ src/compute/merge_sort/mod.rs | 23 +++++++++++++++++++++++ src/compute/sort/common.rs | 9 +++------ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/buffer/mutable.rs b/src/buffer/mutable.rs index 7cf5cdae394..42ac115d0c0 100644 --- a/src/buffer/mutable.rs +++ b/src/buffer/mutable.rs @@ -93,6 +93,19 @@ impl MutableBuffer { } } + /// Allocates a new [MutableBuffer] with `len` and capacity to be at least `len` where + /// all bytes are not initialized + #[inline] + pub fn from_len(len: usize) -> Self { + let new_capacity = capacity_multiple_of_64::(len); + let ptr = alloc::allocate_aligned(new_capacity); + Self { + ptr, + len, + capacity: new_capacity, + } + } + /// Ensures that this buffer has at least `self.len + additional` bytes. This re-allocates iff /// `self.len + additional > capacity`. /// # Example diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index e3c2d4feddd..3b97b2babc0 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -291,6 +291,29 @@ where None => self.right = None, } } + + /// Collect the MergeSortSlices to be a vec for reusing + #[warn(dead_code)] + pub fn to_vec(self, limit: Option) -> Vec { + match limit { + Some(limit) => { + let mut v = Vec::with_capacity(limit); + let mut current_len = 0; + for (index, start, len) in self { + v.push((index, start, len)); + + if len + current_len >= limit { + break; + } else { + current_len += len; + } + } + + v + } + None => self.into_iter().collect(), + } + } } impl<'a, L, R> Iterator for MergeSortSlices<'a, L, R> diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index 18ac986516b..f1c79971764 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -96,10 +96,7 @@ where let limit = limit.min(length); let indices = if let Some(validity) = validity { - let mut indices = MutableBuffer::::with_capacity(length); - unsafe { - indices.set_len(length); - } + let mut indices = MutableBuffer::::from_len(length); if options.nulls_first { let mut nulls = 0; let mut valids = 0; @@ -122,12 +119,12 @@ where // Soundness: // all indices in `indices` are by construction `< array.len() == values.len()` // limit is by construction < indices.len() - let limit = limit - validity.null_count(); + let limit = limit.saturating_sub(validity.null_count()); let indices = &mut indices.as_mut_slice()[validity.null_count()..]; sort_unstable_by(indices, values, cmp, options.descending, limit) } } else { - let last_valid_index = length - validity.null_count(); + let last_valid_index = length.saturating_sub(validity.null_count()); let mut nulls = 0; let mut valids = 0; validity.iter().zip(0..length).for_each(|(x, index)| { From f5ea531f202972f84d9eb75719651c5fdbb59df0 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Tue, 3 Aug 2021 15:16:05 +0800 Subject: [PATCH 09/13] Apply suggestions --- Cargo.toml | 1 + benches/sort_kernel.rs | 8 +++- src/array/specification.rs | 58 +++++++++++++++++++++++- src/buffer/mutable.rs | 12 ++--- src/compute/sort/common.rs | 64 +++++++++++++-------------- src/compute/sort/primitive/indices.rs | 19 ++++---- src/compute/sort/utf8.rs | 22 +++------ src/lib.rs | 2 - 8 files changed, 118 insertions(+), 68 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 55eeab15e41..0fd835c5bf1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,7 @@ criterion = "0.3" flate2 = "1" doc-comment = "0.3" crossbeam-channel = "0.5.1" +pprof = { version = "0.5.0", features = ["flamegraph", "criterion"] } [features] default = ["io_csv", "io_json", "io_ipc", "io_ipc_compression", "io_json_integration", "io_print", "io_parquet", "regex", "merge_sort", "ahash", "benchmarks", "compute"] diff --git a/benches/sort_kernel.rs b/benches/sort_kernel.rs index e78e4462b04..415e6c2a54b 100644 --- a/benches/sort_kernel.rs +++ b/benches/sort_kernel.rs @@ -23,6 +23,8 @@ use arrow2::compute::sort::{lexsort, sort, sort_to_indices, SortColumn, SortOpti use arrow2::util::bench_util::*; use arrow2::{array::*, datatypes::*}; +use pprof::criterion::{Output, PProfProfiler}; + fn bench_lexsort(arr_a: &dyn Array, array_b: &dyn Array) { let columns = vec![ SortColumn { @@ -87,5 +89,9 @@ fn add_benchmark(c: &mut Criterion) { }); } -criterion_group!(benches, add_benchmark); +criterion_group! { + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); + targets = add_benchmark +} criterion_main!(benches); diff --git a/src/array/specification.rs b/src/array/specification.rs index f9a82526b3b..0684d9a6f54 100644 --- a/src/array/specification.rs +++ b/src/array/specification.rs @@ -3,17 +3,19 @@ use std::convert::TryFrom; use num::Num; use crate::{ - buffer::Buffer, + buffer::{Buffer, MutableBuffer}, types::{NativeType, NaturalDataType}, }; /// Trait describing any type that can be used to index a slot of an array. -pub trait Index: NativeType + NaturalDataType + std::iter::Step { +pub trait Index: NativeType + NaturalDataType { fn to_usize(&self) -> usize; fn from_usize(index: usize) -> Option; fn is_usize() -> bool { false } + + fn buffer_from_range(start: usize, end: usize) -> Option>; } /// Trait describing types that can be used as offsets as per Arrow specification. @@ -74,6 +76,19 @@ impl Index for i32 { fn from_usize(value: usize) -> Option { Self::try_from(value).ok() } + + fn buffer_from_range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => unsafe { + Some(MutableBuffer::::from_trusted_len_iter_unchecked( + start..end, + )) + }, + _ => None, + } + } } impl Index for i64 { @@ -86,6 +101,19 @@ impl Index for i64 { fn from_usize(value: usize) -> Option { Self::try_from(value).ok() } + + fn buffer_from_range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => unsafe { + Some(MutableBuffer::::from_trusted_len_iter_unchecked( + start..end, + )) + }, + _ => None, + } + } } impl Index for u32 { @@ -102,6 +130,19 @@ impl Index for u32 { fn is_usize() -> bool { std::mem::size_of::() == std::mem::size_of::() } + + fn buffer_from_range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => unsafe { + Some(MutableBuffer::::from_trusted_len_iter_unchecked( + start..end, + )) + }, + _ => None, + } + } } impl Index for u64 { @@ -118,6 +159,19 @@ impl Index for u64 { fn is_usize() -> bool { std::mem::size_of::() == std::mem::size_of::() } + + fn buffer_from_range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => unsafe { + Some(MutableBuffer::::from_trusted_len_iter_unchecked( + start..end, + )) + }, + _ => None, + } + } } #[inline] diff --git a/src/buffer/mutable.rs b/src/buffer/mutable.rs index 42ac115d0c0..e811901f21b 100644 --- a/src/buffer/mutable.rs +++ b/src/buffer/mutable.rs @@ -96,14 +96,10 @@ impl MutableBuffer { /// Allocates a new [MutableBuffer] with `len` and capacity to be at least `len` where /// all bytes are not initialized #[inline] - pub fn from_len(len: usize) -> Self { - let new_capacity = capacity_multiple_of_64::(len); - let ptr = alloc::allocate_aligned(new_capacity); - Self { - ptr, - len, - capacity: new_capacity, - } + pub unsafe fn from_len(len: usize) -> Self { + let mut buffer = MutableBuffer::with_capacity(len); + buffer.set_len(len); + buffer } /// Ensures that this buffer has at least `self.len + additional` bytes. This re-allocates iff diff --git a/src/compute/sort/common.rs b/src/compute/sort/common.rs index f1c79971764..7e909e01cd6 100644 --- a/src/compute/sort/common.rs +++ b/src/compute/sort/common.rs @@ -11,28 +11,29 @@ use super::SortOptions; /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn k_element_sort_inner( +fn k_element_sort_inner( indices: &mut [I], - values: &[T], + get: G, descending: bool, limit: usize, mut cmp: F, ) where + G: Fn(usize) -> T, F: FnMut(&T, &T) -> std::cmp::Ordering, { if descending { - let mut compare = |lhs: &I, rhs: &I| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); - cmp(rhs, lhs) + let mut compare = |lhs: &I, rhs: &I| { + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); + cmp(&rhs, &lhs) }; let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare); before.sort_unstable_by(&mut compare); } else { - let mut compare = |lhs: &I, rhs: &I| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); - cmp(lhs, rhs) + let mut compare = |lhs: &I, rhs: &I| { + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); + cmp(&lhs, &rhs) }; let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare); before.sort_unstable_by(&mut compare); @@ -44,30 +45,32 @@ fn k_element_sort_inner( /// * `get` is only called for `0 <= i < limit` /// * `cmp` is only called from the co-domain of `get`. #[inline] -fn sort_unstable_by( +fn sort_unstable_by( indices: &mut [I], - values: &[T], + get: G, mut cmp: F, descending: bool, limit: usize, ) where + I: Index, + G: Fn(usize) -> T, F: FnMut(&T, &T) -> std::cmp::Ordering, { if limit != indices.len() { - return k_element_sort_inner(indices, values, descending, limit, cmp); + return k_element_sort_inner(indices, get, descending, limit, cmp); } if descending { - indices.sort_unstable_by(|lhs, rhs| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); - cmp(rhs, lhs) + indices.sort_unstable_by(|lhs, rhs| { + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); + cmp(&rhs, &lhs) }) } else { - indices.sort_unstable_by(|lhs, rhs| unsafe { - let lhs = values.get_unchecked(lhs.to_usize()); - let rhs = values.get_unchecked(rhs.to_usize()); - cmp(lhs, rhs) + indices.sort_unstable_by(|lhs, rhs| { + let lhs = get(lhs.to_usize()); + let rhs = get(rhs.to_usize()); + cmp(&lhs, &rhs) }) } } @@ -77,9 +80,9 @@ fn sort_unstable_by( /// * `get` is only called for `0 <= i < length` /// * `cmp` is only called from the co-domain of `get`. #[inline] -pub(super) fn indices_sorted_unstable_by( +pub(super) fn indices_sorted_unstable_by( validity: &Option, - values: &[T], + get: G, cmp: F, length: usize, options: &SortOptions, @@ -87,6 +90,7 @@ pub(super) fn indices_sorted_unstable_by( ) -> PrimitiveArray where I: Index, + G: Fn(usize) -> T, F: Fn(&T, &T) -> std::cmp::Ordering, { let descending = options.descending; @@ -96,7 +100,7 @@ where let limit = limit.min(length); let indices = if let Some(validity) = validity { - let mut indices = MutableBuffer::::from_len(length); + let mut indices = unsafe { MutableBuffer::::from_len(length) }; if options.nulls_first { let mut nulls = 0; let mut valids = 0; @@ -121,7 +125,7 @@ where // limit is by construction < indices.len() let limit = limit.saturating_sub(validity.null_count()); let indices = &mut indices.as_mut_slice()[validity.null_count()..]; - sort_unstable_by(indices, values, cmp, options.descending, limit) + sort_unstable_by(indices, get, cmp, options.descending, limit) } } else { let last_valid_index = length.saturating_sub(validity.null_count()); @@ -142,7 +146,7 @@ where // limit is by construction <= values.len() let limit = limit.min(last_valid_index); let indices = &mut indices.as_mut_slice()[..last_valid_index]; - sort_unstable_by(indices, values, cmp, options.descending, limit); + sort_unstable_by(indices, get, cmp, options.descending, limit); } indices.truncate(limit); @@ -150,15 +154,11 @@ where indices } else { - let mut indices = unsafe { - MutableBuffer::from_trusted_len_iter_unchecked( - I::from_usize(0).unwrap()..I::from_usize(length).unwrap(), - ) - }; + let mut indices = Index::buffer_from_range(0, length).unwrap(); // Soundness: // indices are by construction `< values.len()` // limit is by construction `< values.len()` - sort_unstable_by(&mut indices, values, cmp, descending, limit); + sort_unstable_by(&mut indices, get, cmp, descending, limit); indices.truncate(limit); indices.shrink_to_fit(); indices diff --git a/src/compute/sort/primitive/indices.rs b/src/compute/sort/primitive/indices.rs index 022c7c842ee..d3d7f76bd36 100644 --- a/src/compute/sort/primitive/indices.rs +++ b/src/compute/sort/primitive/indices.rs @@ -18,14 +18,17 @@ where T: NativeType, F: Fn(&T, &T) -> std::cmp::Ordering, { - common::indices_sorted_unstable_by( - array.validity(), - array.values().as_slice(), - cmp, - array.len(), - options, - limit, - ) + let values = array.values().as_slice(); + unsafe { + common::indices_sorted_unstable_by( + array.validity(), + |x: usize| *values.get_unchecked(x), + cmp, + array.len(), + options, + limit, + ) + } } #[cfg(test)] diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs index ad5972b0391..17a75aef4e9 100644 --- a/src/compute/sort/utf8.rs +++ b/src/compute/sort/utf8.rs @@ -9,13 +9,9 @@ pub(super) fn indices_sorted_unstable_by( options: &SortOptions, limit: Option, ) -> PrimitiveArray { - let values = unsafe { - (0..array.len()) - .map(|idx| array.value_unchecked(idx as usize)) - .collect::>() - }; + let get = |idx| unsafe { array.value_unchecked(idx as usize) }; let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); - common::indices_sorted_unstable_by(array.validity(), &values, cmp, array.len(), options, limit) + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) } pub(super) fn indices_sorted_unstable_by_dictionary( @@ -31,16 +27,12 @@ pub(super) fn indices_sorted_unstable_by_dictionary>() .unwrap(); - let values = unsafe { - (0..array.len()) - .map(|idx| { - let index = keys.value_unchecked(idx as usize); - // Note: there is no check that the keys are within bounds of the dictionary. - dict.value(index.to_usize().unwrap()) - }) - .collect::>() + let get = |idx| unsafe { + let index = keys.value_unchecked(idx as usize); + // Note: there is no check that the keys are within bounds of the dictionary. + dict.value(index.to_usize().unwrap()) }; let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); - common::indices_sorted_unstable_by(array.validity(), &values, cmp, array.len(), options, limit) + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) } diff --git a/src/lib.rs b/src/lib.rs index 1131cd15166..060b6278883 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(step_trait)] - pub mod alloc; pub mod array; pub mod bitmap; From 8190d1142cc2d0c56491998fbebed99633677bca Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Tue, 3 Aug 2021 15:23:32 +0800 Subject: [PATCH 10/13] Remove pprof, because it's not working on windows --- Cargo.toml | 1 - benches/sort_kernel.rs | 8 +------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0fd835c5bf1..55eeab15e41 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,6 @@ criterion = "0.3" flate2 = "1" doc-comment = "0.3" crossbeam-channel = "0.5.1" -pprof = { version = "0.5.0", features = ["flamegraph", "criterion"] } [features] default = ["io_csv", "io_json", "io_ipc", "io_ipc_compression", "io_json_integration", "io_print", "io_parquet", "regex", "merge_sort", "ahash", "benchmarks", "compute"] diff --git a/benches/sort_kernel.rs b/benches/sort_kernel.rs index 415e6c2a54b..e78e4462b04 100644 --- a/benches/sort_kernel.rs +++ b/benches/sort_kernel.rs @@ -23,8 +23,6 @@ use arrow2::compute::sort::{lexsort, sort, sort_to_indices, SortColumn, SortOpti use arrow2::util::bench_util::*; use arrow2::{array::*, datatypes::*}; -use pprof::criterion::{Output, PProfProfiler}; - fn bench_lexsort(arr_a: &dyn Array, array_b: &dyn Array) { let columns = vec![ SortColumn { @@ -89,9 +87,5 @@ fn add_benchmark(c: &mut Criterion) { }); } -criterion_group! { - name = benches; - config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = add_benchmark -} +criterion_group!(benches, add_benchmark); criterion_main!(benches); From 35f594b3da4e3126a19e5df7b7faaa5fa6fb2913 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Wed, 4 Aug 2021 12:43:51 +0800 Subject: [PATCH 11/13] [arrow2] fix && add tests for to_vec --- src/compute/merge_sort/mod.rs | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index 3b97b2babc0..3f4736a24aa 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -300,13 +300,13 @@ where let mut v = Vec::with_capacity(limit); let mut current_len = 0; for (index, start, len) in self { - v.push((index, start, len)); - if len + current_len >= limit { + v.push((index, start, limit - current_len)); break; } else { - current_len += len; + v.push((index, start, len)); } + current_len += len; } v @@ -585,6 +585,22 @@ mod tests { Ok(()) } + #[test] + fn test_merge_slices_to_vec() -> Result<()> { + let a0: &dyn Array = &Int32Array::from_slice(&[0, 2, 4, 6, 8]); + let a1: &dyn Array = &Int32Array::from_slice(&[1, 3, 5, 7, 9]); + + let options = SortOptions::default(); + let arrays = vec![a0, a1]; + let pairs = vec![(arrays.as_ref(), &options)]; + let comparator = build_comparator(&pairs)?; + + let slices = merge_sort_slices(once(&(0, 0, 5)), once(&(1, 0, 5)), &comparator); + let vec = slices.to_vec(Some(5)); + assert_eq!(vec, [(0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1), (0, 2, 1)]); + Ok(()) + } + #[test] fn test_merge_4_i32() -> Result<()> { let a0: &dyn Array = &Int32Array::from_slice(&[0, 1]); From 1c94af48235d8f7c46b2d41db2c8466bccb3daa3 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Wed, 4 Aug 2021 14:53:38 +0800 Subject: [PATCH 12/13] [sort] add shrink_to_fit after all truncates --- src/compute/sort/lex_sort.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compute/sort/lex_sort.rs b/src/compute/sort/lex_sort.rs index 0a596bfd068..3a15a10b675 100644 --- a/src/compute/sort/lex_sort.rs +++ b/src/compute/sort/lex_sort.rs @@ -177,6 +177,7 @@ pub fn lexsort_to_indices( let (before, _, _) = values.select_nth_unstable_by(limit, lex_comparator); before.sort_unstable_by(lex_comparator); values.truncate(limit); + values.shrink_to_fit(); } else { values.sort_unstable_by(lex_comparator); } From 22d58cf543d53eac8fdc7c38d44129996c1aca85 Mon Sep 17 00:00:00 2001 From: sundy-li <543950155@qq.com> Date: Wed, 4 Aug 2021 14:53:44 +0800 Subject: [PATCH 13/13] [sort] add shrink_to_fit after all truncates --- src/compute/sort/boolean.rs | 1 + src/compute/sort/primitive/sort.rs | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/compute/sort/boolean.rs b/src/compute/sort/boolean.rs index bb7faec6efa..6afdd551ec1 100644 --- a/src/compute/sort/boolean.rs +++ b/src/compute/sort/boolean.rs @@ -45,6 +45,7 @@ pub fn sort_boolean( // un-efficient; there are much more performant ways of sorting nulls above, anyways. if let Some(limit) = limit { values.truncate(limit); + values.shrink_to_fit(); } PrimitiveArray::::from_data(I::DATA_TYPE, values.into(), None) diff --git a/src/compute/sort/primitive/sort.rs b/src/compute/sort/primitive/sort.rs index 6fcedd80343..de6aa58867b 100644 --- a/src/compute/sort/primitive/sort.rs +++ b/src/compute/sort/primitive/sort.rs @@ -125,6 +125,7 @@ where }; // values are sorted, we can now truncate the remaining. buffer.truncate(limit); + buffer.shrink_to_fit(); (buffer.into(), new_validity.into()) } @@ -154,6 +155,7 @@ where sort_values(&mut buffer.as_mut_slice(), cmp, options.descending, limit); buffer.truncate(limit); + buffer.shrink_to_fit(); (buffer.into(), None) };