Skip to content

Commit

Permalink
perf: improved numeric fill_(forward/backward) (#16475)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored May 25, 2024
1 parent 6de7422 commit a1ff7ee
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 17 deletions.
30 changes: 18 additions & 12 deletions crates/polars-arrow/src/legacy/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,27 @@ impl<T: NativeType> FromTrustedLenIterator<T> for PrimitiveArray<T> {
}
}

impl<T: NativeType> FromIteratorReversed<T> for PrimitiveArray<T> {
impl<T> FromIteratorReversed<T> for Vec<T> {
fn from_trusted_len_iter_rev<I: TrustedLen<Item = T>>(iter: I) -> Self {
let size = iter.size_hint().1.unwrap();

let mut vals: Vec<T> = Vec::with_capacity(size);
unsafe {
// Set to end of buffer.
let mut ptr = vals.as_mut_ptr().add(size);

iter.for_each(|item| {
ptr = ptr.sub(1);
std::ptr::write(ptr, item);
});
vals.set_len(size)
let len = iter.size_hint().1.unwrap();
let mut out: Vec<T> = Vec::with_capacity(len);
let mut idx = len;
for x in iter {
debug_assert!(idx > 0);
idx -= 1;
out.as_mut_ptr().add(idx).write(x);
}
debug_assert!(idx == 0);
out.set_len(len);
out
}
}
}

impl<T: NativeType> FromIteratorReversed<T> for PrimitiveArray<T> {
fn from_trusted_len_iter_rev<I: TrustedLen<Item = T>>(iter: I) -> Self {
let vals: Vec<T> = iter.collect_reversed();
PrimitiveArray::new(ArrowDataType::from(T::PRIMITIVE), vals.into(), None)
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/trusted_len.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ unsafe impl<A: TrustedLen> TrustedLen for std::iter::StepBy<A> {}
unsafe impl<I, St, F, B> TrustedLen for Scan<I, St, F>
where
F: FnMut(&mut St, I::Item) -> Option<B>,
I: TrustedLen + Iterator<Item = B>,
I: TrustedLen,
{
}

Expand Down
70 changes: 68 additions & 2 deletions crates/polars-core/src/chunked_array/ops/fill_null.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use arrow::bitmap::MutableBitmap;
use arrow::legacy::kernels::set::set_at_nulls;
use arrow::legacy::trusted_len::FromIteratorReversed;
use arrow::legacy::utils::FromTrustedLenIterator;
use bytemuck::Zeroable;
use num_traits::{Bounded, NumCast, One, Zero};

use crate::prelude::*;
Expand Down Expand Up @@ -280,6 +282,70 @@ macro_rules! impl_fill_backward_limit {
}};
}

fn fill_forward_numeric<'a, T, I>(ca: &'a ChunkedArray<T>) -> ChunkedArray<T>
where
T: PolarsDataType,
&'a ChunkedArray<T>: IntoIterator<IntoIter = I>,
I: TrustedLen + Iterator<Item = Option<T::Physical<'a>>>,
T::ZeroablePhysical<'a>: LocalCopy,
{
// Compute values.
let values: Vec<T::ZeroablePhysical<'a>> = ca
.into_iter()
.scan(T::ZeroablePhysical::zeroed(), |prev, v| {
*prev = v.map(|v| v.into()).unwrap_or(prev.cheap_clone());
Some(prev.cheap_clone())
})
.collect_trusted();

// Compute bitmask.
let num_start_nulls = ca.first_non_null().unwrap_or(ca.len());
let mut bm = MutableBitmap::with_capacity(ca.len());
bm.extend_constant(num_start_nulls, false);
bm.extend_constant(ca.len() - num_start_nulls, true);
ChunkedArray::from_chunk_iter_like(
ca,
[
T::Array::from_zeroable_vec(values, ca.dtype().to_arrow(true))
.with_validity_typed(Some(bm.into())),
],
)
}

fn fill_backward_numeric<'a, T, I>(ca: &'a ChunkedArray<T>) -> ChunkedArray<T>
where
T: PolarsDataType,
&'a ChunkedArray<T>: IntoIterator<IntoIter = I>,
I: TrustedLen + Iterator<Item = Option<T::Physical<'a>>> + DoubleEndedIterator,
T::ZeroablePhysical<'a>: LocalCopy,
{
// Compute values.
let values: Vec<T::ZeroablePhysical<'a>> = ca
.into_iter()
.rev()
.scan(T::ZeroablePhysical::zeroed(), |prev, v| {
*prev = v.map(|v| v.into()).unwrap_or(prev.cheap_clone());
Some(prev.cheap_clone())
})
.collect_reversed();

// Compute bitmask.
let num_end_nulls = ca
.last_non_null()
.map(|i| ca.len() - 1 - i)
.unwrap_or(ca.len());
let mut bm = MutableBitmap::with_capacity(ca.len());
bm.extend_constant(ca.len() - num_end_nulls, true);
bm.extend_constant(num_end_nulls, false);
ChunkedArray::from_chunk_iter_like(
ca,
[
T::Array::from_zeroable_vec(values, ca.dtype().to_arrow(true))
.with_validity_typed(Some(bm.into())),
],
)
}

fn fill_null_numeric<T>(
ca: &ChunkedArray<T>,
strategy: FillNullStrategy,
Expand All @@ -293,9 +359,9 @@ where
return Ok(ca.clone());
}
let mut out = match strategy {
FillNullStrategy::Forward(None) => fill_forward(ca),
FillNullStrategy::Forward(None) => fill_forward_numeric(ca),
FillNullStrategy::Forward(Some(limit)) => fill_forward_limit(ca, limit),
FillNullStrategy::Backward(None) => fill_backward(ca),
FillNullStrategy::Backward(None) => fill_backward_numeric(ca),
FillNullStrategy::Backward(Some(limit)) => fill_backward_limit(ca, limit),
FillNullStrategy::Min => {
ca.fill_null_with_values(ChunkAgg::min(ca).ok_or_else(err_fill_null)?)?
Expand Down
18 changes: 16 additions & 2 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,10 +1866,24 @@ def test_fill_nan() -> None:
assert df.fill_nan(2.0).dtypes == [pl.Float64, pl.Datetime]


def test_forward_fill() -> None:
df = pl.DataFrame({"a": [1.0, None, 3.0]})
fill = df.select(pl.col("a").forward_fill())["a"]
assert_series_equal(fill, pl.Series("a", [1, 1, 3]).cast(pl.Float64))

df = pl.DataFrame({"a": [None, 1, None]})
fill = df.select(pl.col("a").forward_fill())["a"]
assert_series_equal(fill, pl.Series("a", [None, 1, 1]).cast(pl.Int64))


def test_backward_fill() -> None:
df = pl.DataFrame({"a": [1.0, None, 3.0]})
col_a_backward_fill = df.select([pl.col("a").backward_fill()])["a"]
assert_series_equal(col_a_backward_fill, pl.Series("a", [1, 3, 3]).cast(pl.Float64))
fill = df.select(pl.col("a").backward_fill())["a"]
assert_series_equal(fill, pl.Series("a", [1, 3, 3]).cast(pl.Float64))

df = pl.DataFrame({"a": [None, 1, None]})
fill = df.select(pl.col("a").backward_fill())["a"]
assert_series_equal(fill, pl.Series("a", [1, 1, None]).cast(pl.Int64))


def test_shrink_to_fit() -> None:
Expand Down

0 comments on commit a1ff7ee

Please sign in to comment.