diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
index 3e7e5d30b9ed..b94d724a5185 100644
--- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
+++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
@@ -88,6 +88,10 @@ where
         )
     }
 
+    fn _sum_as_f64(&self) -> f64 {
+        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
+    }
+
     fn min(&self) -> Option {
         if self.null_count() == self.len() {
             return None;
@@ -216,13 +220,11 @@ where
     }
 
     fn mean(&self) -> Option {
-        if self.null_count() == self.len() {
+        let count = self.len() - self.null_count();
+        if count == 0 {
             return None;
         }
-
-        let len = (self.len() - self.null_count()) as f64;
-        let sum: f64 = self.downcast_iter().map(float_sum::sum_arr_as_f64).sum();
-        Some(sum / len)
+        Some(self._sum_as_f64() / count as f64)
     }
 }
 
diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs
index 615d2af96b70..b252d23814eb 100644
--- a/crates/polars-core/src/chunked_array/ops/mod.rs
+++ b/crates/polars-core/src/chunked_array/ops/mod.rs
@@ -242,6 +242,8 @@ pub trait ChunkAgg {
         None
     }
 
+    fn _sum_as_f64(&self) -> f64;
+
     fn min(&self) -> Option {
         None
     }
diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs
index 49da460464b8..aae8a5837af8 100644
--- a/crates/polars-core/src/series/implementations/boolean.rs
+++ b/crates/polars-core/src/series/implementations/boolean.rs
@@ -166,6 +166,10 @@ impl SeriesTrait for SeriesWrap {
         ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series())
     }
 
+    fn _sum_as_f64(&self) -> f64 {
+        self.0.sum().unwrap() as f64
+    }
+
     fn mean(&self) -> Option {
         self.0.mean()
     }
diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs
index 3726043c4f13..834449e73992 100644
--- a/crates/polars-core/src/series/implementations/date.rs
+++ b/crates/polars-core/src/series/implementations/date.rs
@@ -170,6 +170,10 @@ impl SeriesTrait for SeriesWrap {
         (a.into_date().into_series(), b.into_date().into_series())
     }
 
+    fn _sum_as_f64(&self) -> f64 {
+        self.0._sum_as_f64()
+    }
+
     fn mean(&self) -> Option {
         self.0.mean()
     }
diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs
index 59c733c8d1e9..a6a5f111d541 100644
--- a/crates/polars-core/src/series/implementations/datetime.rs
+++ b/crates/polars-core/src/series/implementations/datetime.rs
@@ -176,6 +176,10 @@ impl SeriesTrait for SeriesWrap {
         )
     }
 
+    fn _sum_as_f64(&self) -> f64 {
+        self.0._sum_as_f64()
+    }
+
     fn mean(&self) -> Option {
         self.0.mean()
     }
diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs
index 98f579a95e8f..30125ccc15b6 100644
--- a/crates/polars-core/src/series/implementations/decimal.rs
+++ b/crates/polars-core/src/series/implementations/decimal.rs
@@ -382,6 +382,10 @@ impl SeriesTrait for SeriesWrap {
         }))
     }
 
+    fn _sum_as_f64(&self) -> f64 {
+        self.0._sum_as_f64() / self.scale_factor() as f64
+    }
+
     fn mean(&self) -> Option {
         self.0.mean().map(|v| v / self.scale_factor() as f64)
     }
diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs
index 35751b722485..73d2e4f730fb 100644
--- a/crates/polars-core/src/series/implementations/duration.rs
+++ b/crates/polars-core/src/series/implementations/duration.rs
@@ -288,6 +288,10 @@ impl SeriesTrait for SeriesWrap {
         (a, b)
     }
 
+    fn _sum_as_f64(&self) -> f64 {
+        self.0._sum_as_f64()
+    }
+
     fn mean(&self) -> Option {
         self.0.mean()
     }
diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs
index c3f545573d48..cc52d73cdc60 100644
--- a/crates/polars-core/src/series/implementations/floats.rs
+++ b/crates/polars-core/src/series/implementations/floats.rs
@@ -215,6 +215,10 @@ macro_rules! impl_dyn_series {
                 ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series())
             }
 
+            fn _sum_as_f64(&self) -> f64 {
+                self.0._sum_as_f64()
+            }
+
             fn mean(&self) -> Option {
                 self.0.mean()
             }
diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs
index 6b6dd08f36cc..3e4e41395b0b 100644
--- a/crates/polars-core/src/series/implementations/mod.rs
+++ b/crates/polars-core/src/series/implementations/mod.rs
@@ -318,6 +318,10 @@ macro_rules! impl_dyn_series {
                 ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series())
             }
 
+            fn _sum_as_f64(&self) -> f64 {
+                self.0._sum_as_f64()
+            }
+
             fn mean(&self) -> Option {
                 self.0.mean()
             }
diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs
index 137de2f31961..3808f7d977af 100644
--- a/crates/polars-core/src/series/implementations/time.rs
+++ b/crates/polars-core/src/series/implementations/time.rs
@@ -145,6 +145,10 @@ impl SeriesTrait for SeriesWrap {
         (a.into_series(), b.into_series())
     }
 
+    fn _sum_as_f64(&self) -> f64 {
+        self.0._sum_as_f64()
+    }
+
     fn mean(&self) -> Option {
         self.0.mean()
     }
diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs
index 933a600d4cb2..b5b60c5eff33 100644
--- a/crates/polars-core/src/series/series_trait.rs
+++ b/crates/polars-core/src/series/series_trait.rs
@@ -292,6 +292,11 @@ pub trait SeriesTrait:
         }
     }
 
+    /// Returns the sum of the array as an f64.
+    fn _sum_as_f64(&self) -> f64 {
+        invalid_operation_panic!(_sum_as_f64, self)
+    }
+
     /// Returns the mean value in the array
     /// Returns an option because the array is nullable.
     fn mean(&self) -> Option {
diff --git a/crates/polars-expr/src/reduce/convert.rs b/crates/polars-expr/src/reduce/convert.rs
index 8859b37b25c5..279d77d6eb67 100644
--- a/crates/polars-expr/src/reduce/convert.rs
+++ b/crates/polars-expr/src/reduce/convert.rs
@@ -16,21 +16,25 @@ pub fn into_reduction(
     expr_arena: &mut Arena,
     schema: &Schema,
 ) -> PolarsResult<(Box, Node)> {
-    let e = expr_arena.get(node);
-    let field = e.to_field(schema, Context::Default, expr_arena)?;
+    let get_dt = |node| {
+        expr_arena
+            .get(node)
+            .to_dtype(schema, Context::Default, expr_arena)
+    };
     let out = match expr_arena.get(node) {
         AExpr::Agg(agg) => match agg {
-            IRAggExpr::Sum(node) => (
-                Box::new(SumReduce::new(field.dtype.clone())) as Box,
-                *node,
+            IRAggExpr::Sum(input) => (
+                Box::new(SumReduce::new(get_dt(*input)?)) as Box,
+                *input,
             ),
             IRAggExpr::Min {
                 propagate_nans,
                 input,
             } => {
-                if *propagate_nans && field.dtype.is_float() {
+                let dt = get_dt(*input)?;
+                if *propagate_nans && dt.is_float() {
                     feature_gated!("propagate_nans", {
-                        let out: Box = match field.dtype {
+                        let out: Box = match dt {
                             DataType::Float32 => Box::new(NanMinReduce::::new()),
                             DataType::Float64 => Box::new(NanMinReduce::::new()),
                             _ => unreachable!(),
@@ -39,7 +43,7 @@ pub fn into_reduction(
                     })
                 } else {
                     (
-                        Box::new(MinReduce::new(field.dtype.clone())) as Box,
+                        Box::new(MinReduce::new(dt.clone())) as Box,
                         *input,
                     )
                 }
@@ -48,9 +52,10 @@ pub fn into_reduction(
                 propagate_nans,
                 input,
             } => {
-                if *propagate_nans && field.dtype.is_float() {
+                let dt = get_dt(*input)?;
+                if *propagate_nans && dt.is_float() {
                     feature_gated!("propagate_nans", {
-                        let out: Box = match field.dtype {
+                        let out: Box = match dt {
                             DataType::Float32 => Box::new(NanMaxReduce::::new()),
                             DataType::Float64 => Box::new(NanMaxReduce::::new()),
                             _ => unreachable!(),
@@ -58,11 +63,11 @@ pub fn into_reduction(
                         (out, *input)
                     })
                 } else {
-                    (Box::new(MaxReduce::new(field.dtype.clone())) as _, *input)
+                    (Box::new(MaxReduce::new(dt.clone())) as _, *input)
                 }
             },
             IRAggExpr::Mean(input) => {
-                let out: Box = Box::new(MeanReduce::new(field.dtype.clone()));
+                let out: Box = Box::new(MeanReduce::new(get_dt(*input)?));
                 (out, *input)
             },
             _ => unreachable!(),
diff --git a/crates/polars-expr/src/reduce/mean.rs b/crates/polars-expr/src/reduce/mean.rs
index dc79c2658714..e8b19b342de6 100644
--- a/crates/polars-expr/src/reduce/mean.rs
+++ b/crates/polars-expr/src/reduce/mean.rs
@@ -29,10 +29,9 @@ pub struct MeanReduceState {
 
 impl ReductionState for MeanReduceState {
     fn update(&mut self, batch: &Series) -> PolarsResult<()> {
-        // TODO: don't go through mean but add sum_as_f64 to series trait.
         let count = batch.len() as u64 - batch.null_count() as u64;
         self.count += count;
-        self.sum += batch.mean().unwrap_or(0.0) * count as f64;
+        self.sum += batch._sum_as_f64();
         Ok(())
     }
 