Skip to content

Commit

Permalink
refactor(rust): Fix mean reduction in new-streaming (#18572)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Sep 5, 2024
1 parent c85b338 commit 76b8f46
Show file tree
Hide file tree
Showing 13 changed files with 64 additions and 19 deletions.
12 changes: 7 additions & 5 deletions crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ where
)
}

fn _sum_as_f64(&self) -> f64 {
self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
}

fn min(&self) -> Option<T::Native> {
if self.null_count() == self.len() {
return None;
Expand Down Expand Up @@ -216,13 +220,11 @@ where
}

fn mean(&self) -> Option<f64> {
if self.null_count() == self.len() {
let count = self.len() - self.null_count();
if count == 0 {
return None;
}

let len = (self.len() - self.null_count()) as f64;
let sum: f64 = self.downcast_iter().map(float_sum::sum_arr_as_f64).sum();
Some(sum / len)
Some(self._sum_as_f64() / count as f64)
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ pub trait ChunkAgg<T> {
None
}

fn _sum_as_f64(&self) -> f64;

fn min(&self) -> Option<T> {
None
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/implementations/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series())
}

fn _sum_as_f64(&self) -> f64 {
self.0.sum().unwrap() as f64
}

fn mean(&self) -> Option<f64> {
self.0.mean()
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/implementations/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ impl SeriesTrait for SeriesWrap<DateChunked> {
(a.into_date().into_series(), b.into_date().into_series())
}

fn _sum_as_f64(&self) -> f64 {
self.0._sum_as_f64()
}

fn mean(&self) -> Option<f64> {
self.0.mean()
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/implementations/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ impl SeriesTrait for SeriesWrap<DatetimeChunked> {
)
}

fn _sum_as_f64(&self) -> f64 {
self.0._sum_as_f64()
}

fn mean(&self) -> Option<f64> {
self.0.mean()
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/implementations/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ impl SeriesTrait for SeriesWrap<DecimalChunked> {
}))
}

fn _sum_as_f64(&self) -> f64 {
self.0._sum_as_f64() / self.scale_factor() as f64
}

fn mean(&self) -> Option<f64> {
self.0.mean().map(|v| v / self.scale_factor() as f64)
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ impl SeriesTrait for SeriesWrap<DurationChunked> {
(a, b)
}

fn _sum_as_f64(&self) -> f64 {
self.0._sum_as_f64()
}

fn mean(&self) -> Option<f64> {
self.0.mean()
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ macro_rules! impl_dyn_series {
ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series())
}

fn _sum_as_f64(&self) -> f64 {
self.0._sum_as_f64()
}

fn mean(&self) -> Option<f64> {
self.0.mean()
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ macro_rules! impl_dyn_series {
ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series())
}

fn _sum_as_f64(&self) -> f64 {
self.0._sum_as_f64()
}

fn mean(&self) -> Option<f64> {
self.0.mean()
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-core/src/series/implementations/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ impl SeriesTrait for SeriesWrap<TimeChunked> {
(a.into_series(), b.into_series())
}

fn _sum_as_f64(&self) -> f64 {
self.0._sum_as_f64()
}

fn mean(&self) -> Option<f64> {
self.0.mean()
}
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ pub trait SeriesTrait:
}
}

/// Returns the sum of the array as an f64.
fn _sum_as_f64(&self) -> f64 {
invalid_operation_panic!(_sum_as_f64, self)
}

/// Returns the mean value in the array
/// Returns an option because the array is nullable.
fn mean(&self) -> Option<f64> {
Expand Down
29 changes: 17 additions & 12 deletions crates/polars-expr/src/reduce/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,25 @@ pub fn into_reduction(
expr_arena: &mut Arena<AExpr>,
schema: &Schema,
) -> PolarsResult<(Box<dyn Reduction>, Node)> {
let e = expr_arena.get(node);
let field = e.to_field(schema, Context::Default, expr_arena)?;
let get_dt = |node| {
expr_arena
.get(node)
.to_dtype(schema, Context::Default, expr_arena)
};
let out = match expr_arena.get(node) {
AExpr::Agg(agg) => match agg {
IRAggExpr::Sum(node) => (
Box::new(SumReduce::new(field.dtype.clone())) as Box<dyn Reduction>,
*node,
IRAggExpr::Sum(input) => (
Box::new(SumReduce::new(get_dt(*input)?)) as Box<dyn Reduction>,
*input,
),
IRAggExpr::Min {
propagate_nans,
input,
} => {
if *propagate_nans && field.dtype.is_float() {
let dt = get_dt(*input)?;
if *propagate_nans && dt.is_float() {
feature_gated!("propagate_nans", {
let out: Box<dyn Reduction> = match field.dtype {
let out: Box<dyn Reduction> = match dt {
DataType::Float32 => Box::new(NanMinReduce::<Float32Type>::new()),
DataType::Float64 => Box::new(NanMinReduce::<Float64Type>::new()),
_ => unreachable!(),
Expand All @@ -39,7 +43,7 @@ pub fn into_reduction(
})
} else {
(
Box::new(MinReduce::new(field.dtype.clone())) as Box<dyn Reduction>,
Box::new(MinReduce::new(dt.clone())) as Box<dyn Reduction>,
*input,
)
}
Expand All @@ -48,21 +52,22 @@ pub fn into_reduction(
propagate_nans,
input,
} => {
if *propagate_nans && field.dtype.is_float() {
let dt = get_dt(*input)?;
if *propagate_nans && dt.is_float() {
feature_gated!("propagate_nans", {
let out: Box<dyn Reduction> = match field.dtype {
let out: Box<dyn Reduction> = match dt {
DataType::Float32 => Box::new(NanMaxReduce::<Float32Type>::new()),
DataType::Float64 => Box::new(NanMaxReduce::<Float64Type>::new()),
_ => unreachable!(),
};
(out, *input)
})
} else {
(Box::new(MaxReduce::new(field.dtype.clone())) as _, *input)
(Box::new(MaxReduce::new(dt.clone())) as _, *input)
}
},
IRAggExpr::Mean(input) => {
let out: Box<dyn Reduction> = Box::new(MeanReduce::new(field.dtype.clone()));
let out: Box<dyn Reduction> = Box::new(MeanReduce::new(get_dt(*input)?));
(out, *input)
},
_ => unreachable!(),
Expand Down
3 changes: 1 addition & 2 deletions crates/polars-expr/src/reduce/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ pub struct MeanReduceState {

impl ReductionState for MeanReduceState {
fn update(&mut self, batch: &Series) -> PolarsResult<()> {
// TODO: don't go through mean but add sum_as_f64 to series trait.
let count = batch.len() as u64 - batch.null_count() as u64;
self.count += count;
self.sum += batch.mean().unwrap_or(0.0) * count as f64;
self.sum += batch._sum_as_f64();
Ok(())
}

Expand Down

0 comments on commit 76b8f46

Please sign in to comment.