Skip to content

Commit

Permalink
perf: Dedup binviews up front (#20449)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 25, 2024
1 parent ffc5538 commit 96b7a9a
Showing 1 changed file with 43 additions and 52 deletions.
95 changes: 43 additions & 52 deletions crates/polars-ops/src/chunked_array/gather/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use arrow::array::{Array, BinaryViewArrayGeneric, View, ViewType};
use arrow::bitmap::BitmapBuilder;
use arrow::buffer::Buffer;
use arrow::legacy::trusted_len::TrustedLenPush;
use hashbrown::hash_map::Entry;
use polars_core::prelude::gather::_update_gather_sorted_flag;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
Expand Down Expand Up @@ -431,29 +430,6 @@ unsafe fn take_opt_unchecked_object<const B: u64>(s: &Series, by: &[ChunkId<B>])
builder.to_series()
}

/// Remap `view`'s chunk-local buffer index to an index into the merged
/// `buffers` list, copying the referenced buffer over lazily on first use.
///
/// Views with `length <= 12` store their bytes inline and reference no data
/// buffer, so they are returned unchanged. Longer views have their referenced
/// buffer deduplicated across calls, keyed on (data pointer, length), and
/// `view.buffer_idx` rewritten to the buffer's position in `buffers`.
///
/// * `orig_buffers` - the data buffers of the chunk this view came from.
/// * `buffer_idxs` - dedup map: (buffer data pointer, buffer length) -> index
///   already assigned in `buffers`.
/// * `buffers` - the merged buffer list being accumulated.
///
/// # Safety
/// `view.buffer_idx` must be a valid index into `orig_buffers` whenever
/// `view.length > 12`.
unsafe fn update_view(
    mut view: View,
    orig_buffers: &[Buffer<u8>],
    buffer_idxs: &mut PlHashMap<(*const u8, usize), u32>,
    buffers: &mut Vec<Buffer<u8>>,
) -> View {
    if view.length > 12 {
        // Dedup on pointer + length.
        let orig_buffer = orig_buffers.get_unchecked(view.buffer_idx as usize);
        view.buffer_idx =
            match buffer_idxs.entry((orig_buffer.as_slice().as_ptr(), orig_buffer.len())) {
                Entry::Occupied(o) => *o.get(),
                Entry::Vacant(v) => {
                    // First time we see this buffer: append it and record its
                    // new index for subsequent views referencing it.
                    let buffer_idx = buffers.len() as u32;
                    buffers.push(orig_buffer.clone());
                    v.insert(buffer_idx);
                    buffer_idx
                },
            };
    }
    view
}

unsafe fn take_unchecked_binview<const B: u64, T, V>(
ca: &ChunkedArray<T>,
by: &[ChunkId<B>],
Expand Down Expand Up @@ -497,8 +473,7 @@ where

arc_data_buffers = arr.data_buffers().clone();
} else {
let mut buffer_idxs = PlHashMap::with_capacity(8);
let mut buffers = Vec::with_capacity(8);
let (buffers, buffer_offsets) = dedup_buffers(ca);

validity = if ca.has_nulls() {
let mut validity = BitmapBuilder::with_capacity(by.len());
Expand All @@ -511,12 +486,8 @@ where
validity.push_unchecked(false);
} else {
let view = *arr.views().get_unchecked(array_idx as usize);
views.push_unchecked(update_view(
view,
arr.data_buffers(),
&mut buffer_idxs,
&mut buffers,
));
let view = rewrite_view(view, chunk_idx, &buffer_offsets);
views.push_unchecked(view);
validity.push_unchecked(true);
}
}
Expand All @@ -527,12 +498,8 @@ where

let arr = ca.downcast_get_unchecked(chunk_idx as usize);
let view = *arr.views().get_unchecked(array_idx as usize);
views.push_unchecked(update_view(
view,
arr.data_buffers(),
&mut buffer_idxs,
&mut buffers,
));
let view = rewrite_view(view, chunk_idx, &buffer_offsets);
views.push_unchecked(view);
}
None
};
Expand All @@ -554,6 +521,39 @@ where
out
}

/// Shift a view's chunk-local buffer index into the merged buffer list by
/// adding the owning chunk's base offset (as produced by `dedup_buffers`).
///
/// Views with `length <= 12` keep their bytes inline and reference no data
/// buffer, so they pass through untouched.
///
/// # Safety
/// `chunk_idx` must be in bounds for `buffer_offsets`.
#[allow(clippy::unnecessary_cast)]
#[inline(always)]
unsafe fn rewrite_view(mut view: View, chunk_idx: IdxSize, buffer_offsets: &[u32]) -> View {
    // Inline (short) view: nothing to remap.
    if view.length <= 12 {
        return view;
    }
    // SAFETY: caller guarantees chunk_idx indexes into buffer_offsets.
    let base_offset = *buffer_offsets.get_unchecked(chunk_idx as usize);
    view.buffer_idx += base_offset;
    view
}

fn dedup_buffers<T, V>(ca: &ChunkedArray<T>) -> (Vec<Buffer<u8>>, Vec<u32>)
where
T: PolarsDataType<Array = BinaryViewArrayGeneric<V>>,
V: ViewType + ?Sized,
{
// Dedup buffers up front. Note: don't do this during view update, as this is much more
// costly.
let mut buffers = Vec::with_capacity(ca.chunks().len());
// Dont need to include the length, as we look at the arc pointers, which are immutable.
let mut buffers_dedup = PlHashSet::with_capacity(ca.chunks().len());
let mut buffer_offsets = Vec::with_capacity(ca.chunks().len() + 1);

for arr in ca.downcast_iter() {
let data_buffers = arr.data_buffers();
let arc_ptr = data_buffers.as_ptr();
buffer_offsets.push(buffers.len() as u32);
if buffers_dedup.insert(arc_ptr) {
buffers.extend(data_buffers.iter().cloned())
}
}
(buffers, buffer_offsets)
}

unsafe fn take_unchecked_binview_opt<const B: u64, T, V>(
ca: &ChunkedArray<T>,
by: &[ChunkId<B>],
Expand Down Expand Up @@ -599,8 +599,7 @@ where

arr.data_buffers().clone()
} else {
let mut buffer_idxs = PlHashMap::with_capacity(8);
let mut buffers = Vec::with_capacity(8);
let (buffers, buffer_offsets) = dedup_buffers(ca);

if ca.has_nulls() {
for id in by.iter() {
Expand All @@ -616,12 +615,8 @@ where
validity.push_unchecked(false);
} else {
let view = *arr.views().get_unchecked(array_idx as usize);
views.push_unchecked(update_view(
view,
arr.data_buffers(),
&mut buffer_idxs,
&mut buffers,
));
let view = rewrite_view(view, chunk_idx, &buffer_offsets);
views.push_unchecked(view);
validity.push_unchecked(true);
}
}
Expand All @@ -636,12 +631,8 @@ where
} else {
let arr = ca.downcast_get_unchecked(chunk_idx as usize);
let view = *arr.views().get_unchecked(array_idx as usize);
views.push_unchecked(update_view(
view,
arr.data_buffers(),
&mut buffer_idxs,
&mut buffers,
));
let view = rewrite_view(view, chunk_idx, &buffer_offsets);
views.push_unchecked(view);
validity.push_unchecked(true);
}
}
Expand Down

0 comments on commit 96b7a9a

Please sign in to comment.