Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Sort improve #246

Merged
merged 13 commits into from
Aug 4, 2021
8 changes: 4 additions & 4 deletions benches/growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn add_benchmark(c: &mut Criterion) {
let i32_array = create_primitive_array::<i32>(1026 * 10, DataType::Int32, 0.0);
c.bench_function("growable::primitive::non_null::non_null", |b| {
b.iter(|| {
let mut a = GrowablePrimitive::new(&[&i32_array], false, 1026 * 10);
let mut a = GrowablePrimitive::new(vec![&i32_array], false, 1026 * 10);
values
.clone()
.into_iter()
Expand All @@ -25,7 +25,7 @@ fn add_benchmark(c: &mut Criterion) {
let i32_array = create_primitive_array::<i32>(1026 * 10, DataType::Int32, 0.0);
c.bench_function("growable::primitive::non_null::null", |b| {
b.iter(|| {
let mut a = GrowablePrimitive::new(&[&i32_array], true, 1026 * 10);
let mut a = GrowablePrimitive::new(vec![&i32_array], true, 1026 * 10);
values.clone().into_iter().for_each(|start| {
if start % 2 == 0 {
a.extend_validity(10);
Expand All @@ -41,7 +41,7 @@ fn add_benchmark(c: &mut Criterion) {
let values = values.collect::<Vec<_>>();
c.bench_function("growable::primitive::null::non_null", |b| {
b.iter(|| {
let mut a = GrowablePrimitive::new(&[&i32_array], false, 1026 * 10);
let mut a = GrowablePrimitive::new(vec![&i32_array], false, 1026 * 10);
values
.clone()
.into_iter()
Expand All @@ -50,7 +50,7 @@ fn add_benchmark(c: &mut Criterion) {
});
c.bench_function("growable::primitive::null::null", |b| {
b.iter(|| {
let mut a = GrowablePrimitive::new(&[&i32_array], true, 1026 * 10);
let mut a = GrowablePrimitive::new(vec![&i32_array], true, 1026 * 10);
values.clone().into_iter().for_each(|start| {
if start % 2 == 0 {
a.extend_validity(10);
Expand Down
15 changes: 14 additions & 1 deletion benches/sort_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
extern crate criterion;
use criterion::Criterion;

use arrow2::compute::sort::{lexsort, sort, SortColumn, SortOptions};
use arrow2::compute::sort::{lexsort, sort, sort_to_indices, SortColumn, SortOptions};
use arrow2::util::bench_util::*;
use arrow2::{array::*, datatypes::*};

Expand All @@ -42,6 +42,15 @@ fn bench_sort(arr_a: &dyn Array) {
sort(criterion::black_box(arr_a), &SortOptions::default(), None).unwrap();
}

/// Benchmark body for the limited sort: asks `sort_to_indices` for the `u32`
/// indices of `arr_a` using the default [`SortOptions`] and `Some(100)` as the
/// limit argument.
///
/// The input goes through `criterion::black_box` so the optimizer cannot
/// specialise the call on it. The resulting index array is discarded; an error
/// from the kernel aborts the benchmark via `unwrap`.
fn bench_sort_limit(arr_a: &dyn Array) {
    let input = criterion::black_box(arr_a);
    let options = SortOptions::default();
    let limit = Some(100);

    let _: PrimitiveArray<u32> = sort_to_indices(input, &options, limit).unwrap();
}

fn add_benchmark(c: &mut Criterion) {
(10..=20).step_by(2).for_each(|log2_size| {
let size = 2usize.pow(log2_size);
Expand All @@ -51,6 +60,10 @@ fn add_benchmark(c: &mut Criterion) {
b.iter(|| bench_sort(&arr_a))
});

c.bench_function(&format!("sort-limit 2^{} f32", log2_size), |b| {
b.iter(|| bench_sort_limit(&arr_a))
});

let arr_b = create_primitive_array_with_seed::<f32>(size, DataType::Float32, 0.0, 43);
c.bench_function(&format!("lexsort 2^{} f32", log2_size), |b| {
b.iter(|| bench_lexsort(&arr_a, &arr_b))
Expand Down
2 changes: 1 addition & 1 deletion benches/take_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn create_random_index(size: usize, null_density: f32) -> PrimitiveArray<i32> {
(0..size)
.map(|_| {
if rng.gen::<f32>() > null_density {
let value = rng.gen_range::<i32, _, _>(0i32, size as i32);
let value = rng.gen_range::<i32, _>(0i32..size as i32);
Some(value)
} else {
None
Expand Down
13 changes: 12 additions & 1 deletion src/array/specification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ use crate::{
};

/// Trait describing any type that can be used to index a slot of an array.
pub trait Index: NativeType + NaturalDataType {
pub trait Index: NativeType + NaturalDataType + std::iter::Step {
fn to_usize(&self) -> usize;
fn from_usize(index: usize) -> Option<Self>;
/// Tells callers whether this index type matches `usize`.
///
/// The default answer is `false`; implementors may override it.
///
/// NOTE(review): the `u32` and `u64` overrides visible in this change compare
/// `size_of::<Self>()` with `size_of::<usize>()`, so "matches" presumably means
/// "has the same width as `usize` on the target" — confirm against the call sites.
fn is_usize() -> bool {
false
}
}

/// Trait describing types that can be used as offsets as per Arrow specification.
Expand Down Expand Up @@ -95,6 +98,10 @@ impl Index for u32 {
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}

/// Returns `true` exactly when `Self` occupies the same number of bytes as
/// `usize` on the compilation target.
fn is_usize() -> bool {
    use std::mem::size_of;

    size_of::<usize>() == size_of::<Self>()
}
}

impl Index for u64 {
Expand All @@ -107,6 +114,10 @@ impl Index for u64 {
fn from_usize(value: usize) -> Option<Self> {
Self::try_from(value).ok()
}

/// Returns `true` exactly when `Self` occupies the same number of bytes as
/// `usize` on the compilation target.
fn is_usize() -> bool {
    use std::mem::size_of;

    size_of::<usize>() == size_of::<Self>()
}
}

#[inline]
Expand Down
13 changes: 13 additions & 0 deletions src/buffer/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ impl<T: NativeType> MutableBuffer<T> {
}
}

/// Allocates a new [`MutableBuffer`] of length `len` and a capacity of at least `len`,
/// leaving all of its bytes uninitialized.
#[inline]
pub fn from_len(len: usize) -> Self {
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
let new_capacity = capacity_multiple_of_64::<T>(len);
let ptr = alloc::allocate_aligned(new_capacity);
Self {
ptr,
len,
capacity: new_capacity,
}
}

/// Ensures that this buffer has at least `self.len + additional` bytes. This re-allocates iff
/// `self.len + additional > capacity`.
/// # Example
Expand Down
35 changes: 31 additions & 4 deletions src/compute/merge_sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use crate::error::Result;
/// This representation is useful when building arrays in memory as it allows to memcopy slices of arrays.
/// This is particularly useful in merge-sort because sorted arrays (passed to the merge-sort) are more likely
/// to have contiguous blocks of sorted elements (than by random).
type MergeSlice = (usize, usize, usize);
pub type MergeSlice = (usize, usize, usize);

/// Takes N arrays together through `slices` under the assumption that the slices have
/// a total coverage of the arrays.
Expand Down Expand Up @@ -234,7 +234,7 @@ fn recursive_merge_sort(slices: &[&[MergeSlice]], comparator: &Comparator) -> Ve

// An iterator adapter that merge-sorts two iterators of `MergeSlice` into a single `MergeSlice`
// such that the resulting `MergeSlice`s are ordered according to `comparator`.
struct MergeSortSlices<'a, L, R>
pub struct MergeSortSlices<'a, L, R>
where
L: Iterator<Item = &'a MergeSlice>,
R: Iterator<Item = &'a MergeSlice>,
Expand Down Expand Up @@ -291,6 +291,29 @@ where
None => self.right = None,
}
}

/// Collects the merged slices into a `Vec` so they can be reused.
///
/// * `limit == None`: every slice produced by the iterator is returned.
/// * `limit == Some(n)`: slices are collected until together they cover `n`
///   rows. The slice that crosses the limit is clipped to the rows still
///   missing, so the returned slices never cover more than `n` rows in total.
///   A limit of `0` yields an empty vector.
#[warn(dead_code)]
pub fn to_vec(self, limit: Option<usize>) -> Vec<MergeSlice> {
    let limit = match limit {
        Some(limit) => limit,
        None => return self.collect(),
    };

    // Nothing was requested: do not pull (and merge) a single slice.
    if limit == 0 {
        return Vec::new();
    }

    let mut slices = Vec::with_capacity(limit);
    // Number of rows still needed to reach `limit`.
    let mut remaining = limit;
    for (index, start, len) in self {
        // Clip the slice that crosses the limit instead of pushing it whole,
        // which would return more rows than requested.
        let taken = len.min(remaining);
        slices.push((index, start, taken));
        remaining -= taken;

        // Stop right away so the iterator is not advanced past the limit.
        if remaining == 0 {
            break;
        }
    }

    slices
}
}

impl<'a, L, R> Iterator for MergeSortSlices<'a, L, R>
Expand Down Expand Up @@ -426,7 +449,11 @@ where
/// Given two iterators of slices representing two sets of sorted [`Array`]s, and a `comparator` bound to those [`Array`]s,
/// returns a new iterator of slices denoting how to `take` slices from each of the arrays such that the resulting
/// array is sorted according to `comparator`
fn merge_sort_slices<'a, L: Iterator<Item = &'a MergeSlice>, R: Iterator<Item = &'a MergeSlice>>(
pub fn merge_sort_slices<
'a,
L: Iterator<Item = &'a MergeSlice>,
R: Iterator<Item = &'a MergeSlice>,
>(
lhs: L,
rhs: R,
comparator: &'a Comparator,
Expand All @@ -439,7 +466,7 @@ type Comparator<'a> = Box<dyn Fn(usize, usize, usize, usize) -> Ordering + 'a>;
type IsValid<'a> = Box<dyn Fn(usize) -> bool + 'a>;

/// returns a comparison function between any two arrays of each pair of arrays, according to `SortOptions`.
fn build_comparator<'a>(
pub fn build_comparator<'a>(
pairs: &'a [(&'a [&'a dyn Array], &SortOptions)],
) -> Result<Comparator<'a>> {
// prepare the comparison function of _values_ between all pairs of arrays
Expand Down
2 changes: 1 addition & 1 deletion src/compute/sort/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fn sort_boolean<I: Index>(
if !descending {
valids.sort_by(|a, b| a.1.cmp(&b.1));
} else {
valids.sort_by(|a, b| a.1.cmp(&b.1).reverse());
valids.sort_by(|a, b| b.1.cmp(&a.1));
// reverse to keep a stable ordering
nulls.reverse();
}
Expand Down
87 changes: 35 additions & 52 deletions src/compute/sort/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,31 @@ use super::SortOptions;
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn k_element_sort_inner<I: Index, T, G, F>(
fn k_element_sort_inner<I: Index, T, F>(
indices: &mut [I],
get: G,
values: &[T],
descending: bool,
limit: usize,
mut cmp: F,
) where
G: Fn(usize) -> T,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if descending {
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
let mut compare = |lhs: &I, rhs: &I| unsafe {
let lhs = values.get_unchecked(lhs.to_usize());
let rhs = values.get_unchecked(rhs.to_usize());
cmp(rhs, lhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
};
before.sort_unstable_by(compare);
let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare);
before.sort_unstable_by(&mut compare);
} else {
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
};
let (before, _, _) = indices.select_nth_unstable_by(limit, compare);
let compare = |lhs: &I, rhs: &I| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
let mut compare = |lhs: &I, rhs: &I| unsafe {
let lhs = values.get_unchecked(lhs.to_usize());
let rhs = values.get_unchecked(rhs.to_usize());
cmp(lhs, rhs)
};
before.sort_unstable_by(compare);
let (before, _, _) = indices.select_nth_unstable_by(limit, &mut compare);
before.sort_unstable_by(&mut compare);
}
}

Expand All @@ -55,32 +44,30 @@ fn k_element_sort_inner<I: Index, T, G, F>(
/// * `get` is only called for `0 <= i < limit`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
fn sort_unstable_by<I, T, G, F>(
fn sort_unstable_by<I: Index, T, F>(
indices: &mut [I],
get: G,
values: &[T],
mut cmp: F,
descending: bool,
limit: usize,
) where
I: Index,
G: Fn(usize) -> T,
F: FnMut(&T, &T) -> std::cmp::Ordering,
{
if limit != indices.len() {
return k_element_sort_inner(indices, get, descending, limit, cmp);
return k_element_sort_inner(indices, values, descending, limit, cmp);
}

if descending {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs).reverse()
indices.sort_unstable_by(|lhs, rhs| unsafe {
let lhs = values.get_unchecked(lhs.to_usize());
let rhs = values.get_unchecked(rhs.to_usize());
cmp(rhs, lhs)
})
} else {
indices.sort_unstable_by(|lhs, rhs| {
let lhs = get(lhs.to_usize());
let rhs = get(rhs.to_usize());
cmp(&lhs, &rhs)
indices.sort_unstable_by(|lhs, rhs| unsafe {
let lhs = values.get_unchecked(lhs.to_usize());
let rhs = values.get_unchecked(rhs.to_usize());
cmp(lhs, rhs)
})
}
}
Expand All @@ -90,17 +77,16 @@ fn sort_unstable_by<I, T, G, F>(
/// * `get` is only called for `0 <= i < length`
/// * `cmp` is only called from the co-domain of `get`.
#[inline]
pub(super) fn indices_sorted_unstable_by<I, T, G, F>(
pub(super) fn indices_sorted_unstable_by<I, T, F>(
validity: &Option<Bitmap>,
get: G,
values: &[T],
cmp: F,
length: usize,
options: &SortOptions,
limit: Option<usize>,
) -> PrimitiveArray<I>
where
I: Index,
G: Fn(usize) -> T,
F: Fn(&T, &T) -> std::cmp::Ordering,
{
let descending = options.descending;
Expand All @@ -110,8 +96,7 @@ where
let limit = limit.min(length);

let indices = if let Some(validity) = validity {
let mut indices = MutableBuffer::<I>::from_len_zeroed(length);

let mut indices = MutableBuffer::<I>::from_len(length);
if options.nulls_first {
let mut nulls = 0;
let mut valids = 0;
Expand All @@ -134,12 +119,12 @@ where
// Soundness:
// all indices in `indices` are by construction `< array.len() == values.len()`
// limit is by construction < indices.len()
let limit = limit - validity.null_count();
let limit = limit.saturating_sub(validity.null_count());
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
let indices = &mut indices.as_mut_slice()[validity.null_count()..];
sort_unstable_by(indices, get, cmp, options.descending, limit)
sort_unstable_by(indices, values, cmp, options.descending, limit)
}
} else {
let last_valid_index = length - validity.null_count();
let last_valid_index = length.saturating_sub(validity.null_count());
let mut nulls = 0;
let mut valids = 0;
validity.iter().zip(0..length).for_each(|(x, index)| {
Expand All @@ -157,7 +142,7 @@ where
// limit is by construction <= values.len()
let limit = limit.min(last_valid_index);
let indices = &mut indices.as_mut_slice()[..last_valid_index];
sort_unstable_by(indices, get, cmp, options.descending, limit);
sort_unstable_by(indices, values, cmp, options.descending, limit);
}

indices.truncate(limit);
Expand All @@ -167,19 +152,17 @@ where
} else {
let mut indices = unsafe {
MutableBuffer::from_trusted_len_iter_unchecked(
(0..length).map(|x| I::from_usize(x).unwrap()),
I::from_usize(0).unwrap()..I::from_usize(length).unwrap(),
)
};

// Soundness:
// indices are by construction `< values.len()`
// limit is by construction `< values.len()`
sort_unstable_by(&mut indices, get, cmp, descending, limit);

sort_unstable_by(&mut indices, values, cmp, descending, limit);
indices.truncate(limit);
indices.shrink_to_fit();

indices
};

PrimitiveArray::<I>::from_data(I::DATA_TYPE, indices.into(), None)
}
Loading