diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index fb5a8db550e3..8f683cabe6d6 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -49,8 +49,6 @@ pub enum AggregateFunction { ArrayAgg, /// N'th value in a group according to some ordering NthValue, - /// Variance (Sample) - Variance, /// Variance (Population) VariancePop, /// Standard Deviation (Sample) @@ -111,7 +109,6 @@ impl AggregateFunction { ApproxDistinct => "APPROX_DISTINCT", ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", - Variance => "VAR", VariancePop => "VAR_POP", Stddev => "STDDEV", StddevPop => "STDDEV_POP", @@ -169,9 +166,7 @@ impl FromStr for AggregateFunction { "stddev" => AggregateFunction::Stddev, "stddev_pop" => AggregateFunction::StddevPop, "stddev_samp" => AggregateFunction::Stddev, - "var" => AggregateFunction::Variance, "var_pop" => AggregateFunction::VariancePop, - "var_samp" => AggregateFunction::Variance, "regr_slope" => AggregateFunction::RegrSlope, "regr_intercept" => AggregateFunction::RegrIntercept, "regr_count" => AggregateFunction::RegrCount, @@ -235,7 +230,6 @@ impl AggregateFunction { AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { Ok(DataType::Boolean) } - AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), AggregateFunction::VariancePop => { variance_return_type(&coerced_data_types[0]) } @@ -315,7 +309,6 @@ impl AggregateFunction { } AggregateFunction::Avg | AggregateFunction::Sum - | AggregateFunction::Variance | AggregateFunction::VariancePop | AggregateFunction::Stddev | AggregateFunction::StddevPop diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 6bd204c53c61..b7004e200d70 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -173,7 +173,7 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } - 
AggregateFunction::Variance | AggregateFunction::VariancePop => { + AggregateFunction::VariancePop => { if !is_variance_support_arg_type(&input_types[0]) { return plan_err!( "The function {:?} does not support inputs of type {:?}.", diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index cb8ef65420c2..ff02d25ad00b 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -59,6 +59,7 @@ pub mod covariance; pub mod first_last; pub mod median; pub mod sum; +pub mod variance; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; @@ -74,6 +75,7 @@ pub mod expr_fn { pub use super::first_last::last_value; pub use super::median::median; pub use super::sum::sum; + pub use super::variance::var_sample; } /// Returns all default aggregate functions @@ -85,6 +87,7 @@ pub fn all_default_aggregate_functions() -> Vec> { sum::sum_udaf(), covariance::covar_pop_udaf(), median::median_udaf(), + variance::var_samp_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs new file mode 100644 index 000000000000..b5d467d0e780 --- /dev/null +++ b/datafusion/functions-aggregate/src/variance.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`VarianceSample`]: sample variance aggregations. + +use std::fmt::Debug; + +use arrow::{ + array::{ArrayRef, Float64Array, UInt64Array}, + compute::kernels::cast, + datatypes::{DataType, Field}, +}; + +use datafusion_common::{ + downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{ + function::{AccumulatorArgs, StateFieldsArgs}, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Signature, Volatility, +}; +use datafusion_physical_expr_common::aggregate::stats::StatsType; + +make_udaf_expr_and_func!( + VarianceSample, + var_sample, + expression, + "Computes the sample variance.", + var_samp_udaf +); + +pub struct VarianceSample { + signature: Signature, + aliases: Vec, +} + +impl Debug for VarianceSample { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("VarianceSample") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for VarianceSample { + fn default() -> Self { + Self::new() + } +} + +impl VarianceSample { + pub fn new() -> Self { + Self { + aliases: vec![String::from("var_sample"), String::from("var_samp")], + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for VarianceSample { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "var" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Variance requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean"), DataType::Float64, 
true), + Field::new(format_state_name(name, "m2"), DataType::Float64, true), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!("VAR(DISTINCT) aggregations are not available"); + } + + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// An accumulator to compute variance +/// The algorithm used is an online implementation and is numerically stable. It is based on this paper: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. + +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: f64, + mean: f64, + count: u64, + stats_type: StatsType, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + m2: 0_f64, + mean: 0_f64, + count: 0_u64, + stats_type: s_type, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean(&self) -> f64 { + self.mean + } + + pub fn get_m2(&self) -> f64 { + self.m2 + } +} + +impl Accumulator for VarianceAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean), + ScalarValue::from(self.m2), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = downcast_value!(values, Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count + 1; + let delta1 = value - self.mean; + let new_mean = delta1 / new_count as f64 + self.mean; + let delta2 
= value - new_mean; + let new_m2 = self.m2 + delta1 * delta2; + + self.count += 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = downcast_value!(values, Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count - 1; + let delta1 = self.mean - value; + let new_mean = delta1 / new_count as f64 + self.mean; + let delta2 = new_mean - value; + let new_m2 = self.m2 - delta1 * delta2; + + self.count -= 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let means = downcast_value!(states[1], Float64Array); + let m2s = downcast_value!(states[2], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_u64 { + continue; + } + let new_count = self.count + c; + let new_mean = self.mean * self.count as f64 / new_count as f64 + + means.value(i) * c as f64 / new_count as f64; + let delta = self.mean - means.value(i); + let new_m2 = self.m2 + + m2s.value(i) + + delta * delta * self.count as f64 * c as f64 / new_count as f64; + + self.count = new_count; + self.mean = new_mean; + self.m2 = new_m2; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } + }; + + Ok(ScalarValue::Float64(match self.count { + 0 => None, + 1 => { + if let StatsType::Population = self.stats_type { + Some(0.0) + } else { + None + } + } + _ => Some(self.m2 / count as f64), + })) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs 
b/datafusion/physical-expr/src/aggregate/build_in.rs index 813a394d6943..07409dd1f4dc 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -160,14 +160,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Avg, true) => { return not_impl_err!("AVG(DISTINCT) aggregations are not available"); } - (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Variance, true) => { - return not_impl_err!("VAR(DISTINCT) aggregations are not available"); - } (AggregateFunction::VariancePop, false) => Arc::new( expressions::VariancePop::new(input_phy_exprs[0].clone(), name, data_type), ), @@ -367,12 +359,13 @@ pub fn create_aggregate_expr( #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; + use expressions::{StddevPop, VariancePop}; use super::*; use crate::expressions::{ try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, - Max, Min, Stddev, Variance, + Max, Min, Stddev, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -719,44 +712,6 @@ mod tests { Ok(()) } - #[test] - fn test_variance_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Variance]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Variance { - 
assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - #[test] fn test_var_pop_expr() -> Result<()> { let funcs = vec![AggregateFunction::VariancePop]; @@ -782,8 +737,8 @@ mod tests { &input_schema, "c1", )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); + if fun == AggregateFunction::VariancePop { + assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( Field::new("c1", DataType::Float64, true), @@ -820,7 +775,7 @@ mod tests { &input_schema, "c1", )?; - if fun == AggregateFunction::Variance { + if fun == AggregateFunction::Stddev { assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( @@ -858,8 +813,8 @@ mod tests { &input_schema, "c1", )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); + if fun == AggregateFunction::StddevPop { + assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( Field::new("c1", DataType::Float64, true), @@ -987,32 +942,6 @@ mod tests { assert!(observed.is_err()); } - #[test] - fn test_variance_return_type() -> Result<()> { - let observed = AggregateFunction::Variance.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::UInt32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Int64])?; - assert_eq!(DataType::Float64, observed); 
- - Ok(()) - } - - #[test] - fn test_variance_no_utf8() { - let observed = AggregateFunction::Variance.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - #[test] fn test_stddev_return_type() -> Result<()> { let observed = AggregateFunction::Stddev.return_type(&[DataType::Float32])?; diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index 7ae917409a21..3db3c0e3ae5e 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -35,13 +35,6 @@ use datafusion_common::downcast_value; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; -/// VAR and VAR_SAMP aggregate expression -#[derive(Debug)] -pub struct Variance { - name: String, - expr: Arc, -} - /// VAR_POP aggregate expression #[derive(Debug)] pub struct VariancePop { @@ -49,74 +42,6 @@ pub struct VariancePop { expr: Arc, } -impl Variance { - /// Create a new VARIANCE aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of variance just support FLOAT64 data type. 
- assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr, - } - } -} - -impl AggregateExpr for Variance { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean"), - DataType::Float64, - true, - ), - Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Variance { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) - .unwrap_or(false) - } -} - impl VariancePop { /// Create a new VAR_POP aggregate function pub fn new( diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 1e9644f75afe..324699af5b5c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -60,7 +60,7 @@ pub use crate::aggregate::stddev::{Stddev, StddevPop}; pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; -pub use crate::aggregate::variance::{Variance, VariancePop}; +pub use crate::aggregate::variance::VariancePop; pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use 
crate::window::nth_value::NthValue; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index fa95194696dd..f8d229f48dc4 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -479,7 +479,7 @@ enum AggregateFunction { COUNT = 4; APPROX_DISTINCT = 5; ARRAY_AGG = 6; - VARIANCE = 7; + // VARIANCE = 7; VARIANCE_POP = 8; // COVARIANCE = 9; // COVARIANCE_POP = 10; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b0e77eb69eff..6de030679c80 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -539,7 +539,6 @@ impl serde::Serialize for AggregateFunction { Self::Count => "COUNT", Self::ApproxDistinct => "APPROX_DISTINCT", Self::ArrayAgg => "ARRAY_AGG", - Self::Variance => "VARIANCE", Self::VariancePop => "VARIANCE_POP", Self::Stddev => "STDDEV", Self::StddevPop => "STDDEV_POP", @@ -582,7 +581,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "COUNT", "APPROX_DISTINCT", "ARRAY_AGG", - "VARIANCE", "VARIANCE_POP", "STDDEV", "STDDEV_POP", @@ -654,7 +652,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "COUNT" => Ok(AggregateFunction::Count), "APPROX_DISTINCT" => Ok(AggregateFunction::ApproxDistinct), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), - "VARIANCE" => Ok(AggregateFunction::Variance), "VARIANCE_POP" => Ok(AggregateFunction::VariancePop), "STDDEV" => Ok(AggregateFunction::Stddev), "STDDEV_POP" => Ok(AggregateFunction::StddevPop), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6d8a0c305761..e397f3545986 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1923,7 +1923,7 @@ pub enum AggregateFunction { Count = 4, ApproxDistinct = 5, ArrayAgg = 6, - Variance = 7, + /// VARIANCE = 7; VariancePop = 8, /// COVARIANCE = 9; /// COVARIANCE_POP = 10; @@ 
-1966,7 +1966,6 @@ impl AggregateFunction { AggregateFunction::Count => "COUNT", AggregateFunction::ApproxDistinct => "APPROX_DISTINCT", AggregateFunction::ArrayAgg => "ARRAY_AGG", - AggregateFunction::Variance => "VARIANCE", AggregateFunction::VariancePop => "VARIANCE_POP", AggregateFunction::Stddev => "STDDEV", AggregateFunction::StddevPop => "STDDEV_POP", @@ -2005,7 +2004,6 @@ impl AggregateFunction { "COUNT" => Some(Self::Count), "APPROX_DISTINCT" => Some(Self::ApproxDistinct), "ARRAY_AGG" => Some(Self::ArrayAgg), - "VARIANCE" => Some(Self::Variance), "VARIANCE_POP" => Some(Self::VariancePop), "STDDEV" => Some(Self::Stddev), "STDDEV_POP" => Some(Self::StddevPop), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index e2a2f875ea0c..f8a78bdbdced 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -149,7 +149,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Count => Self::Count, protobuf::AggregateFunction::ApproxDistinct => Self::ApproxDistinct, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, - protobuf::AggregateFunction::Variance => Self::Variance, protobuf::AggregateFunction::VariancePop => Self::VariancePop, protobuf::AggregateFunction::Stddev => Self::Stddev, protobuf::AggregateFunction::StddevPop => Self::StddevPop, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d2783305f638..15d0d6dd491d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -120,7 +120,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Count => Self::Count, AggregateFunction::ApproxDistinct => Self::ApproxDistinct, AggregateFunction::ArrayAgg => Self::ArrayAgg, - AggregateFunction::Variance => Self::Variance, AggregateFunction::VariancePop => Self::VariancePop, 
AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, @@ -418,7 +417,6 @@ pub fn serialize_expr( AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, AggregateFunction::VariancePop => { protobuf::AggregateFunction::VariancePop } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 071463614165..834f59abb10d 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -29,7 +29,7 @@ use datafusion::physical_plan::expressions::{ DistinctCount, DistinctSum, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, - StringAgg, Sum, TryCastExpr, Variance, VariancePop, WindowShift, + StringAgg, Sum, TryCastExpr, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -281,8 +281,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Max } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Avg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Variance } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::VariancePop } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 14d72274806d..deae97fecc96 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -31,8 +31,9 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::expr_fn::{covar_pop, covar_samp, first_value}; -use datafusion::functions_aggregate::median::median; +use datafusion::functions_aggregate::expr_fn::{ + covar_pop, covar_samp, first_value, median, var_sample, +}; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -651,6 +652,7 @@ async fn roundtrip_expr_api() -> Result<()> { covar_pop(lit(1.5), lit(2.2)), sum(lit(1)), median(lit(2)), + var_sample(lit(2.2)), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index c652c8041ff1..3b1f0dfd6d89 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -40,7 +40,7 @@ bigdecimal = { workspace = true } bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } clap = { version = "4.4.8", features = ["derive", "env"] } -datafusion = { workspace = true, default-features = true } +datafusion = { workspace = true, default-features = true, features = ["avro"] } datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } futures = { workspace = true } @@ -60,7 +60,13 @@ tokio-postgres = { version = "0.7.7", optional = true } [features] avro = ["datafusion/avro"] -postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] +postgres = [ + "bytes", + "chrono", + "tokio-postgres", + "postgres-types", + "postgres-protocol", +] [dev-dependencies] env_logger = { workspace = true } diff 
--git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 98e64b025b22..56ec0342577f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2338,6 +2338,18 @@ select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; statement ok drop table t; +# variance_f64_1 +statement ok +create table t (c double) as values (1), (2), (3), (4), (5); + +query RT +select var(c), arrow_typeof(var(c)) from t; +---- +2.5 Float64 + +statement ok +drop table t; + # aggregate stddev f64_1 statement ok create table t (c1 double) as values (1), (2); @@ -2494,6 +2506,18 @@ select var(c1), arrow_typeof(var(c1)) from t; statement ok drop table t; +# variance_f64_2 +statement ok +create table t (c double) as values (1.1), (2), (3); + +query RT +select var(c), arrow_typeof(var(c)) from t; +---- +0.903333333333 Float64 + +statement ok +drop table t; + # aggregate variance f64_4 statement ok create table t (c1 double) as values (1.1), (2), (3); @@ -2506,6 +2530,30 @@ select var(c1), arrow_typeof(var(c1)) from t; statement ok drop table t; +# variance_1_input +statement ok +create table t (a double not null) as values (1); + +query RT +select var(a), arrow_typeof(var(a)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# variance_i32_all_nulls +statement ok +create table t (a int) as values (null), (null); + +query RT +select var(a), arrow_typeof(var(a)) from t; +---- +NULL Float64 + +statement ok +drop table t; + # aggregate variance i32 statement ok create table t (c1 int) as values (1), (2), (3), (4), (5); diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index ce738c7a6f3e..1fd8b0a346da 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -518,4 +518,3 @@ set datafusion.optimizer.prefer_hash_join = 
true; statement ok set datafusion.execution.batch_size = 8192; -