Skip to content

Commit

Permalink
perf: use zeroed vec in ewm_mean_by for sorted fastpath (#16265)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored May 16, 2024
1 parent 84ac01b commit 092e3da
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 24 deletions.
18 changes: 18 additions & 0 deletions crates/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
mod any_value;
use arrow::compute::concatenate::concatenate_validities;
use arrow::compute::utils::combine_validities_and;
pub mod flatten;
pub(crate) mod series;
mod supertype;
Expand Down Expand Up @@ -834,6 +836,22 @@ where
}
}

pub fn binary_concatenate_validities<'a, T, B>(
left: &'a ChunkedArray<T>,
right: &'a ChunkedArray<B>,
) -> Option<Bitmap>
where
B: PolarsDataType,
T: PolarsDataType,
{
let (left, right) = align_chunks_binary(left, right);
let left_chunk_refs: Vec<_> = left.chunks().iter().map(|c| &**c).collect();
let left_validity = concatenate_validities(&left_chunk_refs);
let right_chunk_refs: Vec<_> = right.chunks().iter().map(|c| &**c).collect();
let right_validity = concatenate_validities(&right_chunk_refs);
combine_validities_and(left_validity.as_ref(), right_validity.as_ref())
}

pub trait IntoVec<T> {
fn into_vec(self) -> Vec<T>;
}
Expand Down
45 changes: 21 additions & 24 deletions crates/polars-ops/src/series/ops/ewm_by.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use arrow::compute::concatenate::concatenate_validities;
use arrow::compute::utils::combine_validities_and;
use bytemuck::allocation::zeroed_vec;
use num_traits::{Float, FromPrimitive, One, Zero};
use polars_core::prelude::*;
use polars_core::utils::align_chunks_binary;
use polars_core::utils::binary_concatenate_validities;

pub fn ewm_mean_by(
s: &Series,
Expand Down Expand Up @@ -108,12 +106,7 @@ where
});
let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(true));
if (times.null_count() > 0) || (values.null_count() > 0) {
let (times, values) = align_chunks_binary(times, values);
let times_chunk_refs: Vec<_> = times.chunks().iter().map(|c| &**c).collect();
let times_validity = concatenate_validities(&times_chunk_refs);
let values_chunk_refs: Vec<_> = values.chunks().iter().map(|c| &**c).collect();
let values_validity = concatenate_validities(&values_chunk_refs);
let validity = combine_validities_and(times_validity.as_ref(), values_validity.as_ref());
let validity = binary_concatenate_validities(times, values);
arr = arr.with_validity_typed(validity);
}
ChunkedArray::with_chunk(values.name(), arr)
Expand All @@ -129,7 +122,7 @@ where
T: PolarsFloatType,
T::Native: Float + Zero + One,
{
let mut out = Vec::with_capacity(times.len());
let mut out: Vec<_> = zeroed_vec(times.len());

let mut skip_rows: usize = 0;
let mut prev_time: i64 = 0;
Expand All @@ -138,30 +131,34 @@ where
if let (Some(time), Some(value)) = (time, value) {
prev_time = time;
prev_result = value;
out.push(Some(prev_result));
unsafe {
*out.get_unchecked_mut(idx) = prev_result;
}
skip_rows = idx + 1;
break;
} else {
out.push(None)
}
}
values
.iter()
.zip(times.iter())
.enumerate()
.skip(skip_rows)
.for_each(|(value, time)| {
let result_opt = match (time, value) {
(Some(time), Some(value)) => {
let result = update(value, prev_result, time, prev_time, half_life);
prev_time = time;
prev_result = result;
Some(result)
},
_ => None,
.for_each(|(idx, (value, time))| {
if let (Some(time), Some(value)) = (time, value) {
let result = update(value, prev_result, time, prev_time, half_life);
prev_time = time;
prev_result = result;
unsafe {
*out.get_unchecked_mut(idx) = result;
}
};
out.push(result_opt);
});
ChunkedArray::<T>::from_iter_options(values.name(), out.into_iter())
let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(true));
if (times.null_count() > 0) || (values.null_count() > 0) {
let validity = binary_concatenate_validities(times, values);
arr = arr.with_validity_typed(validity);
}
ChunkedArray::with_chunk(values.name(), arr)
}

fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 {
Expand Down

0 comments on commit 092e3da

Please sign in to comment.