From dfd4442da8332116c7e1f9fcd9de7c6da856442b Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 3 Apr 2024 09:26:31 +0800 Subject: [PATCH] Make FirstValue an UDAF, Change `AggregateUDFImpl::accumulator` signature, support ORDER BY for UDAFs (#9874) * first draft Signed-off-by: jayzhan211 * clippy fix Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * use one vector for ordering req Signed-off-by: jayzhan211 * add sort exprs to accumulator Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix doc test Signed-off-by: jayzhan211 * change to ref Signed-off-by: jayzhan211 * fix typo Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * move schema and logical ordering exprs Signed-off-by: jayzhan211 * remove redudant info Signed-off-by: jayzhan211 * rename Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * add ignore nulls Signed-off-by: jayzhan211 * fix conflict Signed-off-by: jayzhan211 * backup Signed-off-by: jayzhan211 * complete return_type Signed-off-by: jayzhan211 * complete replace Signed-off-by: jayzhan211 * split to first value udf Signed-off-by: jayzhan211 * replace accumulator Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * small fix Signed-off-by: jayzhan211 * remove ordering types Signed-off-by: jayzhan211 * make state fields more flexible Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * replace done Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * rm comments Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * rm test1 Signed-off-by: jayzhan211 * fix state fields Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * args struct for accumulator Signed-off-by: jayzhan211 * simplify Signed-off-by: jayzhan211 * add sig Signed-off-by: jayzhan211 * add comments Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix docs Signed-off-by: jayzhan211 * use exprs utils Signed-off-by: jayzhan211 * rm state type Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 19 ++- datafusion/core/src/execution/context/mod.rs | 20 +++ datafusion/core/src/physical_planner.rs | 50 ++++--- .../user_defined/user_defined_aggregates.rs | 20 +-- datafusion/expr/src/expr.rs | 3 +- datafusion/expr/src/expr_fn.rs | 128 ++++++++++++++++-- datafusion/expr/src/function.rs | 41 +++++- datafusion/expr/src/tree_node/expr.rs | 1 + datafusion/expr/src/udaf.rs | 101 ++++++++------ .../optimizer/src/analyzer/type_coercion.rs | 15 +- .../optimizer/src/common_subexpr_eliminate.rs | 4 +- .../physical-expr/src/aggregate/build_in.rs | 1 + .../physical-expr/src/aggregate/first_last.rs | 58 +++++++- .../physical-expr/src/aggregate/utils.rs | 2 +- datafusion/physical-expr/src/lib.rs | 2 + .../physical-plan/src/aggregates/mod.rs | 2 + datafusion/physical-plan/src/udaf.rs | 75 ++++++---- datafusion/physical-plan/src/windows/mod.rs | 15 +- .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/physical_plan/mod.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 1 + .../tests/cases/roundtrip_physical_plan.rs | 6 +- datafusion/sql/src/expr/function.rs | 11 +- .../substrait/src/logical_plan/consumer.rs | 2 +- 24 files changed, 450 insertions(+), 134 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 10164a850bfb..342a23b6e73d 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow_schema::{Field, Schema}; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; use datafusion_physical_expr::NullState; use std::{any::Any, sync::Arc}; @@ -30,7 +31,8 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, + function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl, + GroupsAccumulator, Signature, }; /// This example shows how to use the full AggregateUDFImpl API to implement a user @@ -85,13 +87,21 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// is supported, DataFusion will use this row oriented /// accumulator when the aggregate function is used as a window function /// or when there are only aggregates (no GROUP BY columns) in the plan. - fn accumulator(&self, _arg: &DataType) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(GeometricMean::new())) } /// This is the description of the state. accumulator's state() must match the types here. - fn state_type(&self, _return_type: &DataType) -> Result> { - Ok(vec![DataType::Float64, DataType::UInt32]) + fn state_fields( + &self, + _name: &str, + value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + Ok(vec![ + Field::new("prod", value_type, true), + Field::new("n", DataType::UInt32, true), + ]) } /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` @@ -191,7 +201,6 @@ impl Accumulator for GeometricMean { // create local session context with an in-memory table fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::datasource::MemTable; // define a schema. let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f8bf0d2ee1e5..4eaaf94ecf5d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -69,11 +69,14 @@ use datafusion_common::{ OwnedTableReference, SchemaReference, }; use datafusion_execution::registry::SerializerRegistry; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::{create_first_value, Signature, Volatility}; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, var_provider::is_system_variables, Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; +use datafusion_physical_expr::create_first_value_accumulator; use datafusion_sql::{ parser::{CopyToSource, CopyToStatement, DFParser}, planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel}, @@ -82,6 +85,7 @@ use datafusion_sql::{ use async_trait::async_trait; use chrono::{DateTime, Utc}; +use log::debug; use parking_lot::RwLock; use sqlparser::dialect::dialect_from_str; use url::Url; @@ -1451,6 +1455,22 @@ impl SessionState { datafusion_functions_array::register_all(&mut new_self) .expect("can not register array expressions"); + let first_value = create_first_value( + "FIRST_VALUE", + Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), + Arc::new(create_first_value_accumulator), + ); + + match new_self.register_udaf(Arc::new(first_value)) { + Ok(Some(existing_udaf)) => { + debug!("Overwrite existing UDAF: {}", existing_udaf.name()); + } + Ok(None) => {} + Err(err) => { + panic!("Failed to register UDAF: {}", err); + } + } + new_self } /// Returns new [`SessionState`] using the provided diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4733c1433ad0..275d639a7a9f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -247,24 +247,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { distinct, args, filter, - order_by, + order_by: _, null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { create_function_physical_name(func_def.name(), *distinct, args) } AggregateFunctionDefinition::UDF(fun) => { - // TODO: Add support for filter and order by in AggregateUDF + // TODO: Add support for filter by in AggregateUDF if filter.is_some() { return exec_err!( "aggregate expression with filter is not supported" ); } - if order_by.is_some() { - return exec_err!( - "aggregate expression with order_by is not supported" - ); - } + let names = args .iter() .map(|e| create_physical_name(e, false)) @@ -1667,20 +1663,22 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; - let order_by = match order_by { - Some(e) => Some(create_physical_sort_exprs( - e, - logical_input_schema, - execution_props, - )?), - None => None, - }; + let ignore_nulls = null_treatment .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; + let ordering_reqs: Vec = + physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, @@ -1690,16 +1688,30 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( name, ignore_nulls, )?; - (agg_expr, filter, order_by) + (agg_expr, filter, physical_sort_exprs) } AggregateFunctionDefinition::UDF(fun) => { + let sort_exprs = order_by.clone().unwrap_or(vec![]); + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; + let ordering_reqs: Vec = + physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, &args, + &sort_exprs, + &ordering_reqs, physical_input_schema, name, - ); - (agg_expr?, filter, order_by) + ignore_nulls, + )?; + (agg_expr, filter, physical_sort_exprs) } AggregateFunctionDefinition::Name(_) => { return internal_err!( diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index a58a8cf51681..6085fca8761f 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -45,7 +45,8 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, + create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, + SimpleAggregateUDF, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -491,7 +492,7 @@ impl TimeSum { // Returns the same type as its input let return_type = timestamp_type.clone(); - let state_type = vec![timestamp_type.clone()]; + let state_fields = vec![Field::new("sum", timestamp_type, true)]; let volatility = Volatility::Immutable; @@ -505,7 +506,7 @@ impl TimeSum { return_type, volatility, accumulator, - state_type, + state_fields, )); // register the selector as "time_sum" @@ -591,6 +592,11 @@ impl FirstSelector { fn register(ctx: &mut SessionContext) { let return_type = Self::output_datatype(); let state_type = Self::state_datatypes(); + let state_fields = state_type + .into_iter() + .enumerate() + .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .collect::>(); // Possible input signatures let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; @@ -607,7 +613,7 @@ impl FirstSelector { Signature::one_of(signatures, volatility), return_type, accumulator, - state_type, + state_fields, )); // register the selector as "first" @@ -717,15 +723,11 @@ impl AggregateUDFImpl for TestGroupsAccumulator { Ok(DataType::UInt64) } - fn accumulator(&self, _arg: &DataType) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); } - fn state_type(&self, _return_type: &DataType) -> Result> { - Ok(vec![DataType::UInt64]) - } - fn groups_accumulator_supported(&self) -> bool { true } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7ede4cd8ffc9..427c3fde7c0d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -577,6 +577,7 @@ impl AggregateFunction { distinct: bool, filter: Option>, order_by: Option>, + null_treatment: Option, ) -> Self { Self { func_def: AggregateFunctionDefinition::UDF(udf), @@ -584,7 +585,7 @@ impl AggregateFunction { distinct, filter, order_by, - null_treatment: None, + null_treatment, } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index db9eb84c2180..5294ca754532 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -21,15 +21,17 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Placeholder, ScalarFunction, TryCast, }; -use crate::function::PartitionEvaluatorFactory; +use crate::function::{ + AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, +}; +use crate::udaf::format_state_name; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, - logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, - BuiltinScalarFunction, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, - ScalarUDF, Signature, Volatility, + logical_plan::Subquery, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, + Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{Column, Result}; use std::any::Any; use std::fmt::Debug; @@ -695,16 +697,32 @@ pub fn create_udaf( ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + let state_fields = state_type + .into_iter() + .enumerate() + .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .collect::>(); AggregateUDF::from(SimpleAggregateUDF::new( name, input_type, return_type, volatility, accumulator, - state_type, + state_fields, )) } +/// Creates a new UDAF with a specific signature, state type and return type. +/// The signature and state type must match the `Accumulator's implementation`. +/// TOOD: We plan to move aggregate function to its own crate. This function will be deprecated then. +pub fn create_first_value( + name: &str, + signature: Signature, + accumulator: AccumulatorFactoryFunction, +) -> AggregateUDF { + AggregateUDF::from(FirstValue::new(name, signature, accumulator)) +} + /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. pub struct SimpleAggregateUDF { @@ -712,7 +730,7 @@ pub struct SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_type: Vec, + state_fields: Vec, } impl Debug for SimpleAggregateUDF { @@ -734,7 +752,7 @@ impl SimpleAggregateUDF { return_type: DataType, volatility: Volatility, accumulator: AccumulatorFactoryFunction, - state_type: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -743,7 +761,7 @@ impl SimpleAggregateUDF { signature, return_type, accumulator, - state_type, + state_fields, } } @@ -752,7 +770,7 @@ impl SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_type: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); Self { @@ -760,7 +778,7 @@ impl SimpleAggregateUDF { signature, return_type, accumulator, - state_type, + state_fields, } } } @@ -782,12 +800,92 @@ impl AggregateUDFImpl for SimpleAggregateUDF { Ok(self.return_type.clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { - (self.accumulator)(arg) + fn accumulator( + &self, + acc_args: AccumulatorArgs, + ) -> Result> { + (self.accumulator)(acc_args) + } + + fn state_fields( + &self, + _name: &str, + _value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + Ok(self.state_fields.clone()) + } +} + +pub struct FirstValue { + name: String, + signature: Signature, + accumulator: AccumulatorFactoryFunction, +} + +impl Debug for FirstValue { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("FirstValue") + .field("name", &self.name) + .field("signature", &self.signature) + .field("accumulator", &"") + .finish() + } +} + +impl FirstValue { + pub fn new( + name: impl Into, + signature: Signature, + accumulator: AccumulatorFactoryFunction, + ) -> Self { + let name = name.into(); + Self { + name, + signature, + accumulator, + } + } +} + +impl AggregateUDFImpl for FirstValue { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn accumulator( + &self, + acc_args: AccumulatorArgs, + ) -> Result> { + (self.accumulator)(acc_args) } - fn state_type(&self, _return_type: &DataType) -> Result> { - Ok(self.state_type.clone()) + fn state_fields( + &self, + name: &str, + value_type: DataType, + ordering_fields: Vec, + ) -> Result> { + let mut fields = vec![Field::new( + format_state_name(name, "first_value"), + value_type, + true, + )]; + fields.extend(ordering_fields); + fields.push(Field::new("is_set", DataType::Boolean, true)); + Ok(fields) } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index adf4dd3fef20..7598c805adf6 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,8 +17,9 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, ColumnarValue, PartitionEvaluator}; -use arrow::datatypes::DataType; +use crate::ColumnarValue; +use crate::{Accumulator, Expr, PartitionEvaluator}; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::Result; use std::sync::Arc; @@ -37,10 +38,40 @@ pub type ScalarFunctionImplementation = pub type ReturnTypeFunction = Arc Result> + Send + Sync>; -/// Factory that returns an accumulator for the given aggregate, given -/// its return datatype. +/// Arguments passed to create an accumulator +pub struct AccumulatorArgs<'a> { + // default arguments + /// the return type of the function + pub data_type: &'a DataType, + /// the schema of the input arguments + pub schema: &'a Schema, + /// whether to ignore nulls + pub ignore_nulls: bool, + + // ordering arguments + /// the expressions of `order by`, if no ordering is required, this will be an empty slice + pub sort_exprs: &'a [Expr], +} + +impl<'a> AccumulatorArgs<'a> { + pub fn new( + data_type: &'a DataType, + schema: &'a Schema, + ignore_nulls: bool, + sort_exprs: &'a [Expr], + ) -> Self { + Self { + data_type, + schema, + ignore_nulls, + sort_exprs, + } + } +} + +/// Factory that returns an accumulator for the given aggregate function. pub type AccumulatorFactoryFunction = - Arc Result> + Send + Sync>; + Arc Result> + Send + Sync>; /// Factory that creates a PartitionEvaluator for the given window /// function diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 1c672851e9b5..0909d8f662f6 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -379,6 +379,7 @@ impl TreeNode for Expr { false, new_filter, new_order_by, + null_treatment, ))) } AggregateFunctionDefinition::Name(_) => { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index c46dd9cd3a6f..ba80f39dde43 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,16 +17,16 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use crate::function::AccumulatorArgs; use crate::groups_accumulator::GroupsAccumulator; use crate::{Accumulator, Expr}; -use crate::{ - AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, -}; -use arrow::datatypes::DataType; +use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{not_impl_err, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; +use std::vec; /// Logical representation of a user-defined [aggregate function] (UDAF). /// @@ -90,14 +90,12 @@ impl AggregateUDF { signature: &Signature, return_type: &ReturnTypeFunction, accumulator: &AccumulatorFactoryFunction, - state_type: &StateTypeFunction, ) -> Self { Self::new_from_impl(AggregateUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), accumulator: accumulator.clone(), - state_type: state_type.clone(), }) } @@ -131,12 +129,14 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { + // TODO: Support dictinct, filter, order by and null_treatment Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( Arc::new(self.clone()), args, false, None, None, + None, )) } @@ -166,16 +166,21 @@ impl AggregateUDF { self.inner.return_type(args) } - /// Return an accumulator the given aggregate, given - /// its return datatype. - pub fn accumulator(&self, return_type: &DataType) -> Result> { - self.inner.accumulator(return_type) + /// Return an accumulator the given aggregate, given its return datatype + pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + self.inner.accumulator(acc_args) } - /// Return the type of the intermediate state used by this aggregator, given - /// its return datatype. Supports multi-phase aggregations - pub fn state_type(&self, return_type: &DataType) -> Result> { - self.inner.state_type(return_type) + /// Return the fields of the intermediate state used by this aggregator, given + /// its state name, value type and ordering fields. See [`AggregateUDFImpl::state_fields`] + /// for more details. Supports multi-phase aggregations + pub fn state_fields( + &self, + name: &str, + value_type: DataType, + ordering_fields: Vec, + ) -> Result> { + self.inner.state_fields(name, value_type, ordering_fields) } /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. @@ -213,8 +218,10 @@ where /// # use std::any::Any; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; -/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; +/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs}; +/// # use arrow::datatypes::Schema; +/// # use arrow::datatypes::Field; /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature @@ -240,9 +247,12 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType) -> Result> { unimplemented!() } -/// fn state_type(&self, _return_type: &DataType) -> Result> { -/// Ok(vec![DataType::Float64, DataType::UInt32]) +/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } +/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec) -> Result> { +/// Ok(vec![ +/// Field::new("value", value_type, true), +/// Field::new("ordering", DataType::UInt32, true) +/// ]) /// } /// } /// @@ -269,15 +279,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return a new [`Accumulator`] that aggregates values for a specific /// group during query execution. - fn accumulator(&self, arg: &DataType) -> Result>; + /// + /// `acc_args`: the arguments to the accumulator. See [`AccumulatorArgs`] for more details. + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result>; + + /// Return the fields of the intermediate state. + /// + /// name: the name of the state + /// + /// value_type: the type of the value, it should be the result of the `return_type` + /// + /// ordering_fields: the fields used for ordering, empty if no ordering expression is provided + fn state_fields( + &self, + name: &str, + value_type: DataType, + ordering_fields: Vec, + ) -> Result> { + let value_fields = vec![Field::new( + format_state_name(name, "value"), + value_type, + true, + )]; - /// Return the type used to serialize the [`Accumulator`]'s intermediate state. - /// See [`Accumulator::state()`] for more details - fn state_type(&self, return_type: &DataType) -> Result>; + Ok(value_fields.into_iter().chain(ordering_fields).collect()) + } /// If the aggregate expression has a specialized /// [`GroupsAccumulator`] implementation. If this returns true, - /// `[Self::create_groups_accumulator`] will be called. + /// `[Self::create_groups_accumulator]` will be called. fn groups_accumulator_supported(&self) -> bool { false } @@ -337,12 +367,8 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.return_type(arg_types) } - fn accumulator(&self, arg: &DataType) -> Result> { - self.inner.accumulator(arg) - } - - fn state_type(&self, return_type: &DataType) -> Result> { - self.inner.state_type(return_type) + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + self.inner.accumulator(acc_args) } fn aliases(&self) -> &[String] { @@ -361,8 +387,6 @@ pub struct AggregateUDFLegacyWrapper { return_type: ReturnTypeFunction, /// actual implementation accumulator: AccumulatorFactoryFunction, - /// the accumulator's state's description as a function of the return type - state_type: StateTypeFunction, } impl Debug for AggregateUDFLegacyWrapper { @@ -394,12 +418,13 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { Ok(res.as_ref().clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { - (self.accumulator)(arg) + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + (self.accumulator)(acc_args) } +} - fn state_type(&self, return_type: &DataType) -> Result> { - let res = (self.state_type)(return_type)?; - Ok(res.as_ref().clone()) - } +/// returns the name of the state +/// TODO: Remove duplicated function in physical-expr +pub(crate) fn format_state_name(name: &str, state_name: &str) -> String { + format!("{name}[{state_name}]") } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index b7b7c4f20e4a..fbbd9a945673 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -366,7 +366,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )?; Ok(Transformed::yes(Expr::AggregateFunction( expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, + fun, + new_expr, + false, + filter, + order_by, + null_treatment, ), ))) } @@ -896,6 +901,7 @@ mod test { false, None, None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; @@ -906,7 +912,6 @@ mod test { fn agg_udaf_invalid_input() -> Result<()> { let empty = empty(); let return_type = DataType::Float64; - let state_type = vec![DataType::UInt64, DataType::Float64]; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::::default())); let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( @@ -914,7 +919,10 @@ mod test { Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), return_type, accumulator, - state_type, + vec![ + Field::new("count", DataType::UInt64, true), + Field::new("avg", DataType::Float64, true), + ], )); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), @@ -922,6 +930,7 @@ mod test { false, None, None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 77613aa66293..c3c0569df707 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -806,7 +806,6 @@ mod test { let return_type = DataType::UInt32; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); - let state_type = vec![DataType::UInt32]; let udf_agg = |inner: Expr| { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature( @@ -814,12 +813,13 @@ mod test { Signature::exact(vec![DataType::UInt32], Volatility::Stable), return_type.clone(), accumulator.clone(), - state_type.clone(), + vec![Field::new("value", DataType::UInt32, true)], ))), vec![inner], false, None, None, + None, )) }; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index cee679863870..c549e6219375 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -367,6 +367,7 @@ pub fn create_aggregate_expr( input_phy_types[0].clone(), ordering_req.to_vec(), ordering_types, + vec![], ) .with_ignore_nulls(ignore_nulls), ), diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 6d6e32a14987..26bd219f65f0 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, ordering_fields}; -use crate::expressions::format_state_name; +use crate::expressions::{self, format_state_name}; use crate::{ reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, }; @@ -29,11 +29,13 @@ use crate::{ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; +use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::Accumulator; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{Accumulator, Expr}; /// FIRST_VALUE aggregate expression #[derive(Debug, Clone)] @@ -45,6 +47,7 @@ pub struct FirstValue { ordering_req: LexOrdering, requirement_satisfied: bool, ignore_nulls: bool, + state_fields: Vec, } impl FirstValue { @@ -55,6 +58,7 @@ impl FirstValue { input_data_type: DataType, ordering_req: LexOrdering, order_by_data_types: Vec, + state_fields: Vec, ) -> Self { let requirement_satisfied = ordering_req.is_empty(); Self { @@ -65,6 +69,7 @@ impl FirstValue { ordering_req, requirement_satisfied, ignore_nulls: false, + state_fields, } } @@ -149,6 +154,10 @@ impl AggregateExpr for FirstValue { } fn state_fields(&self) -> Result> { + if !self.state_fields.is_empty() { + return Ok(self.state_fields.clone()); + } + let mut fields = vec![Field::new( format_state_name(&self.name, "first_value"), self.input_data_type.clone(), @@ -384,6 +393,50 @@ impl Accumulator for FirstValueAccumulator { } } +pub fn create_first_value_accumulator( + acc_args: AccumulatorArgs, +) -> Result> { + let mut all_sort_orders = vec![]; + + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in acc_args.sort_exprs { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::col(name, acc_args.schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); + } + + let ordering_req = all_sort_orders; + + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + let requirement_satisfied = ordering_req.is_empty(); + + FirstValueAccumulator::try_new( + acc_args.data_type, + &ordering_dtypes, + ordering_req, + acc_args.ignore_nulls, + ) + .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) +} + /// LAST_VALUE aggregate expression #[derive(Debug, Clone)] pub struct LastValue { @@ -471,6 +524,7 @@ impl LastValue { input_data_type, reverse_order_bys(&ordering_req), order_by_data_types, + vec![], ) } } diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index 60d59c16be5f..613f6118e907 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -188,7 +188,7 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { } /// Construct corresponding fields for lexicographical ordering requirement expression -pub(crate) fn ordering_fields( +pub fn ordering_fields( ordering_req: &[PhysicalSortExpr], // Data type of each expression in the ordering requirement data_types: &[DataType], diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 7819d5116160..655771270a6b 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -58,3 +58,5 @@ pub use sort_expr::{ PhysicalSortRequirement, }; pub use utils::{reverse_order_bys, split_conjunction}; + +pub use aggregate::first_last::create_first_value_accumulator; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index e263876b07d5..f8ad03bf6d97 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -2026,6 +2026,7 @@ mod tests { DataType::Float64, ordering_req.clone(), vec![DataType::Float64], + vec![], ))] } else { vec![Arc::new(LastValue::new( @@ -2209,6 +2210,7 @@ mod tests { DataType::Float64, sort_expr_reverse.clone(), vec![DataType::Float64], + vec![], )), Arc::new(LastValue::new( col_b.clone(), diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index fd9279dfd552..74a5603c0c81 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -17,22 +17,20 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use datafusion_expr::GroupsAccumulator; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{Expr, GroupsAccumulator}; use fmt::Debug; use std::any::Any; use std::fmt; -use arrow::{ - datatypes::Field, - datatypes::{DataType, Schema}, -}; +use arrow::datatypes::{DataType, Field, Schema}; -use super::{expressions::format_state_name, Accumulator, AggregateExpr}; +use super::{Accumulator, AggregateExpr}; use datafusion_common::{not_impl_err, Result}; pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; -use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; +use datafusion_physical_expr::aggregate::utils::{down_cast_any_ref, ordering_fields}; use std::sync::Arc; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. @@ -40,19 +38,34 @@ use std::sync::Arc; pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], - input_schema: &Schema, + sort_exprs: &[Expr], + ordering_req: &[PhysicalSortExpr], + schema: &Schema, name: impl Into, + ignore_nulls: bool, ) -> Result> { let input_exprs_types = input_phy_exprs .iter() - .map(|arg| arg.data_type(input_schema)) + .map(|arg| arg.data_type(schema)) .collect::>>()?; + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let ordering_fields = ordering_fields(ordering_req, &ordering_types); + Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), data_type: fun.return_type(&input_exprs_types)?, name: name.into(), + schema: schema.clone(), + sort_exprs: sort_exprs.to_vec(), + ordering_req: ordering_req.to_vec(), + ignore_nulls, + ordering_fields, })) } @@ -64,6 +77,13 @@ pub struct AggregateFunctionExpr { /// Output / return type of this aggregate data_type: DataType, name: String, + schema: Schema, + // The logical order by expressions + sort_exprs: Vec, + // The physical order by expressions + ordering_req: LexOrdering, + ignore_nulls: bool, + ordering_fields: Vec, } impl AggregateFunctionExpr { @@ -84,21 +104,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - let fields = self - .fun - .state_type(&self.data_type)? - .iter() - .enumerate() - .map(|(i, data_type)| { - Field::new( - format_state_name(&self.name, &format!("{i}")), - data_type.clone(), - true, - ) - }) - .collect::>(); - - Ok(fields) + self.fun.state_fields( + self.name(), + self.data_type.clone(), + self.ordering_fields.clone(), + ) } fn field(&self) -> Result { @@ -106,11 +116,18 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - self.fun.accumulator(&self.data_type) + let acc_args = AccumulatorArgs::new( + &self.data_type, + &self.schema, + self.ignore_nulls, + &self.sort_exprs, + ); + + self.fun.accumulator(acc_args) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.fun.accumulator(&self.data_type)?; + let accumulator = self.create_accumulator()?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to @@ -175,6 +192,10 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_groups_accumulator(&self) -> Result> { self.fun.create_groups_accumulator() } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } } impl PartialEq for AggregateFunctionExpr { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 21f42f41fb5c..c5c845614c7b 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -92,8 +92,19 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = - udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; + // TODO: Ordering not supported for Window UDFs yet + let sort_exprs = &[]; + let ordering_req = &[]; + + let aggregate = udaf::create_aggregate_expr( + fun.as_ref(), + args, + sort_exprs, + ordering_req, + input_schema, + name, + ignore_nulls, + )?; window_expr_from_aggregate_expr( partition_by, order_by, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3694418412b1..6a536b2fa375 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1389,6 +1389,7 @@ pub fn parse_expr( false, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_vec_expr(&pb.order_by, registry, codec)?, + None, ))) } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 00dacffe06c2..4d5d6cadad17 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -517,7 +517,11 @@ impl AsExecutionPlan for PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name) + // TODO: `order by` is not supported for UDAF yet + let sort_exprs = &[]; + let ordering_req = &[]; + let ignore_nulls = false; + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 4cd133dc21d4..f136e314559b 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1772,6 +1772,7 @@ fn roundtrip_aggregate_udf() { false, Some(Box::new(lit(true))), None, + None, )); let ctx = SessionContext::new(); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 0238291c77e1..5dacf692e904 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -412,14 +412,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let return_type = DataType::Int64; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); - let state_type = vec![DataType::Int64]; let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "example", Signature::exact(vec![DataType::Int64], Volatility::Immutable), return_type, accumulator, - state_type, + vec![Field::new("value", DataType::Int64, true)], )); let ctx = SessionContext::new(); @@ -431,8 +430,11 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let aggregates: Vec> = vec![udaf::create_aggregate_expr( &udaf, &[col("b", &schema)?], + &[], + &[], &schema, "example_agg", + false, )?]; roundtrip_test_with_context( diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 582404b29749..e97eb1a32b12 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -221,9 +221,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { + let order_by = + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; + let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; + // TODO: Support filter and distinct for UDAFs return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fm, args, false, None, None, + fm, + args, + false, + None, + order_by, + null_treatment, ))); } diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index e68f3f992817..73782ab27f71 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -754,7 +754,7 @@ pub async fn from_substrait_agg_func( // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by), + expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) {