From 76b6655f900397a09092bbf60ef502a9156edfd4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 18 Jul 2024 16:27:54 -0400 Subject: [PATCH 01/21] Moving over AggregateExt to ExprFunctionExt and adding in function settings for window functions --- datafusion-examples/examples/expr_api.rs | 4 +- datafusion/core/tests/expr_api/mod.rs | 2 +- datafusion/expr/src/expr.rs | 8 +- datafusion/expr/src/expr_fn.rs | 247 +++++++++++++++++- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udaf.rs | 177 +------------ .../functions-aggregate/src/first_last.rs | 2 +- .../optimizer/src/optimize_projections/mod.rs | 2 +- .../src/replace_distinct_aggregate.rs | 2 +- .../src/single_distinct_to_groupby.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 2 +- datafusion/sql/src/unparser/expr.rs | 2 +- docs/source/user-guide/expressions.md | 2 +- 13 files changed, 260 insertions(+), 194 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index a48171c625a8..eb08fe6593ca 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{ExprFunctionExt, ColumnarValue, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. 
/// @@ -95,7 +95,7 @@ fn expr_fn_demo() -> Result<()> { let agg = first_value.call(vec![col("price")]); assert_eq!(agg.to_string(), "first_value(price)"); - // You can use the AggregateExt trait to create more complex aggregates + // You can use the ExprFunctionExt trait to create more complex aggregates // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) let agg = first_value .call(vec![col("price")]) diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index f36f2d539845..d76b3c9dc1ec 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; use datafusion_common::{assert_contains, DFSchema, ScalarValue}; -use datafusion_expr::AggregateExt; +use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index e3620501d9a8..33d229b61174 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -289,9 +289,9 @@ pub enum Expr { /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. /// - /// See also [`AggregateExt`] to set these fields. + /// See also [`ExprFunctionExt`] to set these fields. /// - /// [`AggregateExt`]: crate::udaf::AggregateExt + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. 
WindowFunction(WindowFunction), @@ -641,9 +641,9 @@ impl AggregateFunctionDefinition { /// Aggregate function /// -/// See also [`AggregateExt`] to set these fields on `Expr` +/// See also [`ExprFunctionExt`] to set these fields on `Expr` /// -/// [`AggregateExt`]: crate::udaf::AggregateExt +/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9187e8352205..289f62f873bf 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,7 +19,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - Placeholder, TryCast, Unnest, + Placeholder, TryCast, Unnest, WindowFunction, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -30,12 +30,13 @@ use crate::{ AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; -use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; +use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl}; use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{Column, Result, ScalarValue}; +use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; use std::ops::Not; @@ -664,6 +665,246 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) } + +/// Extensions for configuring [`Expr::AggregateFunction`] +/// +/// Adds methods to [`Expr`] that make it easy to set optional aggregate options +/// such as `ORDER BY`, `FILTER` and `DISTINCT` +/// 
+/// # Example +/// ```no_run +/// # use datafusion_common::Result; +/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; +/// # use sqlparser::ast::NullTreatment; +/// # fn count(arg: Expr) -> Expr { todo!{} } +/// # fn first_value(arg: Expr) -> Expr { todo!{} } +/// # fn main() -> Result<()> { +/// use datafusion_expr::ExprFunctionExt; +/// +/// // Create COUNT(x FILTER y > 5) +/// let agg = count(col("x")) +/// .filter(col("y").gt(lit(5))) +/// .build()?; +/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) +/// let sort_expr = col("y").sort(true, true); +/// let agg = first_value(col("x")) +/// .order_by(vec![sort_expr]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// # Ok(()) +/// # } +/// ``` +pub trait ExprFunctionExt { + /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] + fn order_by(self, order_by: Vec) -> ExprFuncBuilder; + /// Add `FILTER ` + fn filter(self, filter: Expr) -> ExprFuncBuilder; + /// Add `DISTINCT` + fn distinct(self) -> ExprFuncBuilder; + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder; + // Add `PARTITION BY` + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; + // Add appropriate window frame conditions + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; +} + +#[derive(Debug, Clone)] +pub enum ExprFuncKind { + Aggregate(AggregateFunction), + Window(WindowFunction), +} + +/// Implementation of [`ExprFunctionExt`]. 
+/// +/// See [`ExprFunctionExt`] for usage and examples +#[derive(Debug, Clone)] +pub struct ExprFuncBuilder { + fun: Option, + order_by: Option>, + filter: Option, + distinct: bool, + null_treatment: Option, + partition_by: Option>, + window_frame: Option, +} + +impl ExprFuncBuilder { + /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] + + fn new(fun: Option) -> Self { + Self { + fun, + order_by: None, + filter: None, + distinct: false, + null_treatment: None, + partition_by: None, + window_frame: None, + } + } + + /// Updates and returns the in progress [`Expr::AggregateFunction`] + /// + /// # Errors: + /// + /// Returns an error of this builder [`ExprFunctionExt`] was used with an + /// `Expr` variant other than [`Expr::AggregateFunction`] + pub fn build(self) -> Result { + let Self { + fun, + order_by, + filter, + distinct, + null_treatment, + partition_by, + window_frame, + } = self; + + let Some(fun) = fun else { + return plan_err!( + "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction" + ); + }; + + if let Some(order_by) = &order_by { + for expr in order_by.iter() { + if !matches!(expr, Expr::Sort(_)) { + return plan_err!( + "ORDER BY expressions must be Expr::Sort, found {expr:?}" + ); + } + } + } + + let fun_expr = match fun { + ExprFuncKind::Aggregate(mut udaf) => { + udaf.order_by = order_by; + udaf.filter = filter.map(Box::new); + udaf.distinct = distinct; + udaf.null_treatment = null_treatment; + Expr::AggregateFunction(udaf) + } + ExprFuncKind::Window(mut udwf) => { + let has_order_by = order_by.as_ref().map(|o| o.len() > 0); + udwf.order_by = order_by.unwrap_or_default(); + udwf.partition_by = partition_by.unwrap_or_default(); + udwf.window_frame = window_frame.unwrap_or(WindowFrame::new(has_order_by)); + udwf.null_treatment = null_treatment; + Expr::WindowFunction(udwf) + } + }; + + Ok(fun_expr) + } + + /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] + pub fn order_by(mut self, 
order_by: Vec) -> ExprFuncBuilder { + self.order_by = Some(order_by); + self + } + + /// Add `FILTER ` + pub fn filter(mut self, filter: Expr) -> ExprFuncBuilder { + self.filter = Some(filter); + self + } + + /// Add `DISTINCT` + pub fn distinct(mut self) -> ExprFuncBuilder { + self.distinct = true; + self + } + + /// Add `RESPECT NULLS` or `IGNORE NULLS` + pub fn null_treatment(mut self, null_treatment: NullTreatment) -> ExprFuncBuilder { + self.null_treatment = Some(null_treatment); + self + } + + pub fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { + self.partition_by = Some(partition_by); + self + } + + pub fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { + self.window_frame = Some(window_frame); + self + } +} + +impl ExprFunctionExt for Expr { + fn order_by(self, order_by: Vec) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), + Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.order_by = Some(order_by); + } + builder + } + fn filter(self, filter: Expr) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.filter = Some(filter); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn distinct(self) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.distinct = true; + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), + Expr::WindowFunction(udwf) => 
ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.null_treatment = Some(null_treatment); + } + builder + } + + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.partition_by = Some(partition_by); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.window_frame = Some(window_frame); + builder + } + _ => ExprFuncBuilder::new(None), + } + } +} + + #[cfg(test)] mod test { use super::*; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e1943c890e7c..354e795fe64d 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -86,7 +86,7 @@ pub use signature::{ }; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 1657e034fbe2..29267f30100a 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -24,9 +24,8 @@ use std::sync::Arc; use std::vec; use arrow::datatypes::{DataType, Field}; -use sqlparser::ast::NullTreatment; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; use crate::expr::AggregateFunction; use crate::function::{ @@ -655,177 +654,3 @@ impl AggregateUDFImpl for 
AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } - -/// Extensions for configuring [`Expr::AggregateFunction`] -/// -/// Adds methods to [`Expr`] that make it easy to set optional aggregate options -/// such as `ORDER BY`, `FILTER` and `DISTINCT` -/// -/// # Example -/// ```no_run -/// # use datafusion_common::Result; -/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; -/// # use sqlparser::ast::NullTreatment; -/// # fn count(arg: Expr) -> Expr { todo!{} } -/// # fn first_value(arg: Expr) -> Expr { todo!{} } -/// # fn main() -> Result<()> { -/// use datafusion_expr::AggregateExt; -/// -/// // Create COUNT(x FILTER y > 5) -/// let agg = count(col("x")) -/// .filter(col("y").gt(lit(5))) -/// .build()?; -/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) -/// let sort_expr = col("y").sort(true, true); -/// let agg = first_value(col("x")) -/// .order_by(vec![sort_expr]) -/// .null_treatment(NullTreatment::IgnoreNulls) -/// .build()?; -/// # Ok(()) -/// # } -/// ``` -pub trait AggregateExt { - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(self, order_by: Vec) -> AggregateBuilder; - /// Add `FILTER ` - fn filter(self, filter: Expr) -> AggregateBuilder; - /// Add `DISTINCT` - fn distinct(self) -> AggregateBuilder; - /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder; -} - -/// Implementation of [`AggregateExt`]. 
-/// -/// See [`AggregateExt`] for usage and examples -#[derive(Debug, Clone)] -pub struct AggregateBuilder { - udaf: Option, - order_by: Option>, - filter: Option, - distinct: bool, - null_treatment: Option, -} - -impl AggregateBuilder { - /// Create a new `AggregateBuilder`, see [`AggregateExt`] - - fn new(udaf: Option) -> Self { - Self { - udaf, - order_by: None, - filter: None, - distinct: false, - null_treatment: None, - } - } - - /// Updates and returns the in progress [`Expr::AggregateFunction`] - /// - /// # Errors: - /// - /// Returns an error of this builder [`AggregateExt`] was used with an - /// `Expr` variant other than [`Expr::AggregateFunction`] - pub fn build(self) -> Result { - let Self { - udaf, - order_by, - filter, - distinct, - null_treatment, - } = self; - - let Some(mut udaf) = udaf else { - return plan_err!( - "AggregateExt can only be used with Expr::AggregateFunction" - ); - }; - - if let Some(order_by) = &order_by { - for expr in order_by.iter() { - if !matches!(expr, Expr::Sort(_)) { - return plan_err!( - "ORDER BY expressions must be Expr::Sort, found {expr:?}" - ); - } - } - } - - udaf.order_by = order_by; - udaf.filter = filter.map(Box::new); - udaf.distinct = distinct; - udaf.null_treatment = null_treatment; - Ok(Expr::AggregateFunction(udaf)) - } - - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - pub fn order_by(mut self, order_by: Vec) -> AggregateBuilder { - self.order_by = Some(order_by); - self - } - - /// Add `FILTER ` - pub fn filter(mut self, filter: Expr) -> AggregateBuilder { - self.filter = Some(filter); - self - } - - /// Add `DISTINCT` - pub fn distinct(mut self) -> AggregateBuilder { - self.distinct = true; - self - } - - /// Add `RESPECT NULLS` or `IGNORE NULLS` - pub fn null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder { - self.null_treatment = Some(null_treatment); - self - } -} - -impl AggregateExt for Expr { - fn order_by(self, order_by: Vec) -> AggregateBuilder 
{ - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.order_by = Some(order_by); - builder - } - _ => AggregateBuilder::new(None), - } - } - fn filter(self, filter: Expr) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.filter = Some(filter); - builder - } - _ => AggregateBuilder::new(None), - } - } - fn distinct(self) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.distinct = true; - builder - } - _ => AggregateBuilder::new(None), - } - } - fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.null_treatment = Some(null_treatment); - builder - } - _ => AggregateBuilder::new(None), - } - } -} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 0e619bacef82..862bd8c1378a 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,7 +31,7 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, + Accumulator, ExprFunctionExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 58c1ae297b02..9f04a01a3377 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ 
b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -806,7 +806,7 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; - use datafusion_expr::AggregateExt; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index fcd33be618f7..430517121f2a 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -23,7 +23,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder}; +use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index f2b4abdd6cbd..d776e6598cbe 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -354,7 +354,7 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; - use datafusion_expr::AggregateExt; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3476d5d042cc..b9aa4773b812 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,7 
+60,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue, + Accumulator, ExprFunctionExt, AggregateFunction, AggregateUDF, ColumnarValue, ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 2f7854c1a183..fc41f3302adf 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1342,7 +1342,7 @@ mod tests { table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_expr::{interval_month_day_nano_lit, AggregateExt}; + use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 6e693a0e7087..60036e440ffb 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -308,7 +308,7 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Function Builder -You can also use the `AggregateExt` trait to more easily build Aggregate arguments `Expr`. +You can also use the `ExprFunctionExt` trait to more easily build Aggregate arguments `Expr`. See `datafusion-examples/examples/expr_api.rs` for example usage. From 4b124c6e9f141e5b958da0bf9af8ca7720885a71 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 19 Jul 2024 08:30:26 -0400 Subject: [PATCH 02/21] Switch WindowFrame to only need the window function definition and arguments. 
Other parameters will be set via the ExprFuncBuilder --- datafusion/core/src/dataframe/mod.rs | 10 +- datafusion/core/tests/dataframe/mod.rs | 13 +-- datafusion/expr/src/expr.rs | 44 ++++++-- datafusion/expr/src/expr_fn.rs | 4 +- datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/tree_node.rs | 20 ++-- datafusion/expr/src/utils.rs | 80 +++----------- datafusion/expr/src/window_function.rs | 104 ++++++++++++++++++ .../src/analyzer/count_wildcard_rule.rs | 11 +- .../optimizer/src/analyzer/type_coercion.rs | 20 ++-- .../optimizer/src/optimize_projections/mod.rs | 16 +-- .../simplify_expressions/expr_simplifier.rs | 14 +-- .../proto/src/logical_plan/from_proto.rs | 30 +---- .../tests/cases/roundtrip_logical_plan.rs | 49 ++------- datafusion/sql/src/expr/function.rs | 32 +++--- 15 files changed, 219 insertions(+), 229 deletions(-) create mode 100644 datafusion/expr/src/window_function.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index fb28b5c1ab47..2c5ba4b38ee0 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,8 +1696,7 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, - Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, ScalarFunctionImplementation, Volatility, WindowFunctionDefinition }; use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; @@ -1866,12 +1865,7 @@ mod tests { WindowFunctionDefinition::BuiltInWindowFunction( BuiltInWindowFunction::FirstValue, ), - vec![col("aggregate_test_100.c1")], - vec![col("aggregate_test_100.c2")], - vec![], - WindowFrame::new(None), - None, - )); + 
vec![col("aggregate_test_100.c1")])).partition_by(vec![col("aggregate_test_100.c2")]).build().unwrap(); let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index d68b80691917..11787b609872 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,9 +54,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition }; use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; @@ -182,16 +180,13 @@ async fn test_count_wildcard_on_window() -> Result<()> { .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], + vec![wildcard()])).order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]).window_frame( WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )).build().unwrap() + ])? .explain(false, false)? 
.collect() .await?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 33d229b61174..8ab7c89e2cc8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,8 +28,7 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::{ - aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator, - Signature, + aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF }; use crate::{window_frame, Volatility}; @@ -769,6 +768,30 @@ impl fmt::Display for WindowFunctionDefinition { } } +impl From for WindowFunctionDefinition { + fn from(value: aggregate_function::AggregateFunction) -> Self { + Self::AggregateFunction(value) + } +} + +impl From for WindowFunctionDefinition { + fn from(value: BuiltInWindowFunction) -> Self { + Self::BuiltInWindowFunction(value) + } +} + +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { + Self::AggregateUDF(value) + } +} + +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { + Self::WindowUDF(value) + } +} + /// Window function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { @@ -789,20 +812,17 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression pub fn new( - fun: WindowFunctionDefinition, + fun: impl Into, args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: window_frame::WindowFrame, - null_treatment: Option, + ) -> Self { Self { - fun, + fun: fun.into(), args, - partition_by, - order_by, - window_frame, - null_treatment, + partition_by: Vec::default(), + order_by: Vec::default(), + window_frame: WindowFrame::new(None), + null_treatment: None, } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 289f62f873bf..9e5a402a0a9f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ 
b/datafusion/expr/src/expr_fn.rs @@ -746,12 +746,12 @@ impl ExprFuncBuilder { } } - /// Updates and returns the in progress [`Expr::AggregateFunction`] + /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] /// /// # Errors: /// /// Returns an error of this builder [`ExprFunctionExt`] was used with an - /// `Expr` variant other than [`Expr::AggregateFunction`] + /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] pub fn build(self) -> Result { let Self { fun, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 354e795fe64d..0a5cf4653a22 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -60,6 +60,7 @@ pub mod type_coercion; pub mod utils; pub mod var_provider; pub mod window_frame; +pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f1df8609f903..3d7a72180ca6 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -22,7 +22,7 @@ use crate::expr::{ Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; -use crate::Expr; +use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, @@ -294,14 +294,18 @@ impl TreeNode for Expr { transform_vec(order_by, &mut f) )? 
.update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new( + let mut builder = Expr::WindowFunction(WindowFunction::new( fun, - new_args, - new_partition_by, - new_order_by, - window_frame, - null_treatment, - )) + new_args)).partition_by(new_partition_by).order_by(new_order_by).window_frame(window_frame); + if let Some(n) = null_treatment { + builder = builder.null_treatment(n) + } + builder.build().unwrap() + // new_partition_by, + // new_order_by, + // window_frame, + // null_treatment, + // )) }), Expr::AggregateFunction(AggregateFunction { args, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 889aa0952e51..908209614056 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1252,9 +1252,7 @@ impl AggregateOrderSensitivity { mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, WindowFrame, - WindowFunctionDefinition, + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, WindowFrame, WindowFunctionDefinition }; #[test] @@ -1269,36 +1267,16 @@ mod tests { fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("name")])); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("name")])); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), - vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); 
+ vec![col("name")])); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("age")])); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key = vec![]; @@ -1316,36 +1294,16 @@ mod tests { Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + vec![col("name")])).order_by(vec![age_asc.clone(), name_desc.clone()]).build().unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("name")])); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), - vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + vec![col("name")])).order_by(vec![age_asc.clone(), name_desc.clone()]).build().unwrap(); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - vec![], - vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + vec![col("age")])).order_by(vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()]).build().unwrap(); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1372,27 +1330,19 @@ mod tests { let exprs = &[ 
Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![ + vec![col("name")])).order_by(vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), + ]).window_frame(WindowFrame::new(Some(false))) + .build().unwrap(), Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - vec![], - vec![ + vec![col("age")])).order_by(vec![ Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), + ]).window_frame(WindowFrame::new(Some(false))) + .build().unwrap(), ]; let expected = vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs new file mode 100644 index 000000000000..0fa1d4168655 --- /dev/null +++ b/datafusion/expr/src/window_function.rs @@ -0,0 +1,104 @@ +use datafusion_common::ScalarValue; + +use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; + + + +/// Create an expression to represent the `row_number` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn row_number() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::RowNumber, vec![]) +} + +/// Create an expression to represent the `rank` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn rank() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::Rank, vec![]) +} + +/// Create an expression to represent the `dense_rank` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub 
fn dense_rank() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::DenseRank, vec![]) +} + +/// Create an expression to represent the `percent_rank` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn percent_rank() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::PercentRank, vec![]) +} + +/// Create an expression to represent the `cume_dist` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn cume_dist() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![]) +} + +/// Create an expression to represent the `ntile` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn ntile(arg: Expr) -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg]) +} + +/// Create an expression to represent the `lag` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn lag( + arg: Expr, + shift_offset: Option, + default_value: Option, +) -> WindowFunction { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + WindowFunction::new( + BuiltInWindowFunction::Lag, + vec![arg, shift_offset_lit, default_lit], + ) +} + +/// Create an expression to represent the `lead` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn lead( + arg: Expr, + shift_offset: Option, + default_value: Option, +) -> WindowFunction { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + WindowFunction::new( + BuiltInWindowFunction::Lead, + vec![arg, shift_offset_lit, default_lit], + ) +} + +/// Create an expression to represent the `first_value` window function +/// +/// Note: call [`WindowFunction::build]` to 
create an [`Expr`] +pub fn first_value(arg: Expr) -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![arg]) +} + +/// Create an expression to represent the `last_value` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn last_value(arg: Expr) -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::LastValue, vec![arg]) +} + +/// Create an expression to represent the `nth_value` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn nth_value(arg: Expr, n: i64) -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::NthValue, vec![arg, n.lit()]) +} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fa8aeb86ed31..2dc050eae6f6 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -101,6 +101,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, @@ -222,16 +223,12 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( + vec![wildcard()])).order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]).window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )).build()? + ])? .project(vec![count(wildcard())])? 
.build()?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 50fb1b8193ce..5e28ade1ed43 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -46,9 +46,7 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, AggregateUDF, Expr, ExprSchemable, LogicalPlan, - Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits }; use crate::analyzer::AnalyzerRule; @@ -466,14 +464,16 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => args, }; - Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( + Ok(Transformed::yes({ + let mut builder = Expr::WindowFunction(WindowFunction::new( fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - )))) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); + if let Some(n) = null_treatment { + builder = builder.null_treatment(n); + } + builder.build()? 
+ } + )) } Expr::Alias(_) | Expr::Column(_) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 9f04a01a3377..787146d90b00 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -815,7 +815,7 @@ mod tests { lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFrame, + Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; @@ -1918,21 +1918,11 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.a")], - vec![col("test.b")], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("test.a")])).partition_by(vec![col("test.b")]).build().unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.b")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("test.b")])); let col1 = col(max1.display_name()?); let col2 = col(max2.display_name()?); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 56556f387d1b..64751e72b9ee 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3858,12 +3858,7 @@ mod tests { let window_function_expr = Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( udwf, - vec![], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3874,12 +3869,7 @@ mod tests { let 
window_function_expr = Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( udwf, - vec![], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b6b556a8ed6b..fd6d19d2fd08 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -25,6 +25,7 @@ use datafusion_common::{ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; +use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, @@ -300,7 +301,6 @@ pub fn parse_expr( ) })?; // TODO: support proto for null treatment - let null_treatment = None; regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { @@ -314,12 +314,7 @@ pub fn parse_expr( registry, "expr", codec, - )?], - partition_by, - order_by, - window_frame, - None, - ))) + )?])).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -335,12 +330,7 @@ pub fn parse_expr( expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), - args, - partition_by, - order_by, - window_frame, - null_treatment, - ))) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -354,12 +344,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( 
expr::WindowFunctionDefinition::AggregateUDF(udaf_function), - args, - partition_by, - order_by, - window_frame, - None, - ))) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -373,12 +358,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), - args, - partition_by, - order_by, - window_frame, - None, - ))) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b9aa4773b812..bb7f52e5769a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2047,24 +2047,14 @@ fn roundtrip_window() { WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(WindowFrame::new(Some(false))).build().unwrap(); // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(WindowFrame::new(Some(false))).build().unwrap(); // 3. 
with window_frame with row numbers let range_number_frame = WindowFrame::new_bounds( @@ -2077,12 +2067,7 @@ fn roundtrip_window() { WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![], - vec![col("col1")], - vec![col("col2")], - range_number_frame, - None, - )); + vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(range_number_frame).build().unwrap(); // 4. test with AggregateFunction let row_number_frame = WindowFrame::new_bounds( @@ -2093,12 +2078,7 @@ fn roundtrip_window() { let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); // 5. test with AggregateUDF #[derive(Debug)] @@ -2142,12 +2122,7 @@ fn roundtrip_window() { let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), - vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); ctx.register_udaf(dummy_agg); // 6. 
test with WindowUDF @@ -2218,21 +2193,11 @@ fn roundtrip_window() { let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), - vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), - vec![col("col1")], - vec![], - vec![], - row_number_frame.clone(), - None, - )); + vec![col("col1")])).window_frame(row_number_frame.clone()).build().unwrap(); ctx.register_udwf(dummy_window_udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 4804752d8389..8f324f7ced7f 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -24,7 +24,7 @@ use datafusion_common::{ use datafusion_expr::planner::PlannerResult; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, + expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition }; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, @@ -326,23 +326,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = self.function_args_to_expr(args, schema, planner_context)?; - Expr::WindowFunction(expr::WindowFunction::new( + let mut builder = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args, - partition_by, - order_by, - window_frame, - null_treatment, - )) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); + if let Some(n) = null_treatment { + builder = builder.null_treatment(n); + }; + 
builder.build().unwrap() } - _ => Expr::WindowFunction(expr::WindowFunction::new( - fun, - self.function_args_to_expr(args, schema, planner_context)?, - partition_by, - order_by, - window_frame, - null_treatment, - )), + _ => { + let mut builder = Expr::WindowFunction(expr::WindowFunction::new( + fun, + self.function_args_to_expr(args, schema, planner_context)?)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); + if let Some(n) = null_treatment { + builder = builder.null_treatment(n); + } + builder.build().unwrap() + }, }; return Ok(expr); } From 31a82ddc9cd43f64441c2766d3e1fc0404a6d301 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 19 Jul 2024 08:45:23 -0400 Subject: [PATCH 03/21] Changing null_treatment to take an option, but this is mostly for code cleanliness and not strictly required --- datafusion/expr/src/expr_fn.rs | 10 +++++----- datafusion/expr/src/tree_node.rs | 16 +++------------- .../optimizer/src/analyzer/type_coercion.rs | 10 ++-------- datafusion/sql/src/expr/function.rs | 16 ++++------------ 4 files changed, 14 insertions(+), 38 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9e5a402a0a9f..fcb9b81662c3 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -704,7 +704,7 @@ pub trait ExprFunctionExt { /// Add `DISTINCT` fn distinct(self) -> ExprFuncBuilder; /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder; + fn null_treatment(self, null_treatment: impl Into>) -> ExprFuncBuilder; // Add `PARTITION BY` fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; // Add appropriate window frame conditions @@ -821,8 +821,8 @@ impl ExprFuncBuilder { } /// Add `RESPECT NULLS` or `IGNORE NULLS` - pub fn null_treatment(mut self, null_treatment: NullTreatment) -> ExprFuncBuilder { - self.null_treatment = Some(null_treatment); + pub fn null_treatment(mut self, null_treatment: impl 
Into>) -> ExprFuncBuilder { + self.null_treatment = null_treatment.into(); self } @@ -869,14 +869,14 @@ impl ExprFunctionExt for Expr { _ => ExprFuncBuilder::new(None), } } - fn null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder { + fn null_treatment(self, null_treatment: impl Into>) -> ExprFuncBuilder { let mut builder = match self { Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), _ => ExprFuncBuilder::new(None), }; if builder.fun.is_some() { - builder.null_treatment = Some(null_treatment); + builder.null_treatment = null_treatment.into(); } builder } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 3d7a72180ca6..f262613b2295 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -293,20 +293,10 @@ impl TreeNode for Expr { order_by, transform_vec(order_by, &mut f) )? - .update_data(|(new_args, new_partition_by, new_order_by)| { - let mut builder = Expr::WindowFunction(WindowFunction::new( + .update_data(|(new_args, new_partition_by, new_order_by)| Expr::WindowFunction(WindowFunction::new( fun, - new_args)).partition_by(new_partition_by).order_by(new_order_by).window_frame(window_frame); - if let Some(n) = null_treatment { - builder = builder.null_treatment(n) - } - builder.build().unwrap() - // new_partition_by, - // new_order_by, - // window_frame, - // null_treatment, - // )) - }), + new_args)).partition_by(new_partition_by).order_by(new_order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() + ), Expr::AggregateFunction(AggregateFunction { args, func_def, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 5e28ade1ed43..38c9637148f3 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ 
b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -464,15 +464,9 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => args, }; - Ok(Transformed::yes({ - let mut builder = Expr::WindowFunction(WindowFunction::new( + Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( fun, - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); - if let Some(n) = null_treatment { - builder = builder.null_treatment(n); - } - builder.build()? - } + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build()? )) } Expr::Alias(_) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 8f324f7ced7f..efa914bfae7f 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -326,22 +326,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = self.function_args_to_expr(args, schema, planner_context)?; - let mut builder = Expr::WindowFunction(expr::WindowFunction::new( + Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); - if let Some(n) = null_treatment { - builder = builder.null_treatment(n); - }; - builder.build().unwrap() + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() } _ => { - let mut builder = Expr::WindowFunction(expr::WindowFunction::new( + Expr::WindowFunction(expr::WindowFunction::new( fun, - self.function_args_to_expr(args, schema, planner_context)?)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); - if let Some(n) = null_treatment { - builder = builder.null_treatment(n); - } - builder.build().unwrap() + self.function_args_to_expr(args, schema, 
planner_context)?)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() }, }; return Ok(expr); From d290d2e675920c5eec42e0ace57e387d9ab66e50 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 19 Jul 2024 08:49:46 -0400 Subject: [PATCH 04/21] Moving functions in ExprFuncBuilder over to be explicitly implementing ExprFunctionExt trait so we can guarantee a consistent user experience no matter which they call on the Expr and which on the builder --- datafusion/expr/src/expr_fn.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index fcb9b81662c3..492c1476ec77 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -799,39 +799,42 @@ impl ExprFuncBuilder { Ok(fun_expr) } +} + +impl ExprFunctionExt for ExprFuncBuilder { /// Add `ORDER BY ` /// /// Note: `order_by` must be [`Expr::Sort`] - pub fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { self.order_by = Some(order_by); self } /// Add `FILTER ` - pub fn filter(mut self, filter: Expr) -> ExprFuncBuilder { + fn filter(mut self, filter: Expr) -> ExprFuncBuilder { self.filter = Some(filter); self } /// Add `DISTINCT` - pub fn distinct(mut self) -> ExprFuncBuilder { + fn distinct(mut self) -> ExprFuncBuilder { self.distinct = true; self } /// Add `RESPECT NULLS` or `IGNORE NULLS` - pub fn null_treatment(mut self, null_treatment: impl Into>) -> ExprFuncBuilder { + fn null_treatment(mut self, null_treatment: impl Into>) -> ExprFuncBuilder { self.null_treatment = null_treatment.into(); self } - pub fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { + fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { self.partition_by = Some(partition_by); self } - pub fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { + fn 
window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { self.window_frame = Some(window_frame); self } From d267d0ea32b3628d4aa62a64d20f2df3981fd99b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 19 Jul 2024 08:52:23 -0400 Subject: [PATCH 05/21] Apply cargo fmt --- datafusion-examples/examples/expr_api.rs | 2 +- datafusion/core/src/dataframe/mod.rs | 6 +- datafusion/core/tests/dataframe/mod.rs | 17 +++-- datafusion/expr/src/expr.rs | 9 +-- datafusion/expr/src/expr_fn.rs | 53 +++++++++----- datafusion/expr/src/tree_node.rs | 13 ++-- datafusion/expr/src/utils.rs | 71 +++++++++++++------ datafusion/expr/src/window_function.rs | 2 - .../functions-aggregate/src/first_last.rs | 4 +- .../src/analyzer/count_wildcard_rule.rs | 15 ++-- .../optimizer/src/analyzer/type_coercion.rs | 15 ++-- .../optimizer/src/optimize_projections/mod.rs | 9 ++- .../simplify_expressions/expr_simplifier.rs | 14 ++-- .../proto/src/logical_plan/from_proto.rs | 32 +++++++-- .../tests/cases/roundtrip_logical_plan.rs | 56 ++++++++++++--- datafusion/sql/src/expr/function.rs | 27 +++++-- 16 files changed, 246 insertions(+), 99 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index eb08fe6593ca..0eb823302acf 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{ExprFunctionExt, ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. 
/// diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2c5ba4b38ee0..49b638946e78 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1865,7 +1865,11 @@ mod tests { WindowFunctionDefinition::BuiltInWindowFunction( BuiltInWindowFunction::FirstValue, ), - vec![col("aggregate_test_100.c1")])).partition_by(vec![col("aggregate_test_100.c2")]).build().unwrap(); + vec![col("aggregate_test_100.c1")], + )) + .partition_by(vec![col("aggregate_test_100.c2")]) + .build() + .unwrap(); let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 11787b609872..3ff653230965 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -180,13 +180,16 @@ async fn test_count_wildcard_on_window() -> Result<()> { .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()])).order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]).window_frame( - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )).build().unwrap() - ])? + vec![wildcard()], + )) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap()])? .explain(false, false)? 
.collect() .await?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8ab7c89e2cc8..694f849f12aa 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,7 +28,8 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::{ - aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF + aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, + ExprSchemable, Operator, Signature, WindowFrame, WindowUDF, }; use crate::{window_frame, Volatility}; @@ -811,11 +812,7 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression - pub fn new( - fun: impl Into, - args: Vec, - - ) -> Self { + pub fn new(fun: impl Into, args: Vec) -> Self { Self { fun: fun.into(), args, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 492c1476ec77..ff0a7d589c4c 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -30,7 +30,9 @@ use crate::{ AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; -use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl}; +use crate::{ + AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, +}; use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; @@ -665,7 +667,6 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) } - /// Extensions for configuring [`Expr::AggregateFunction`] /// /// Adds methods to [`Expr`] that make it easy to set optional aggregate options @@ -704,7 +705,10 @@ pub trait ExprFunctionExt { /// Add `DISTINCT` fn distinct(self) -> ExprFuncBuilder; /// Add `RESPECT NULLS` or `IGNORE NULLS` 
- fn null_treatment(self, null_treatment: impl Into>) -> ExprFuncBuilder; + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder; // Add `PARTITION BY` fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; // Add appropriate window frame conditions @@ -791,7 +795,8 @@ impl ExprFuncBuilder { let has_order_by = order_by.as_ref().map(|o| o.len() > 0); udwf.order_by = order_by.unwrap_or_default(); udwf.partition_by = partition_by.unwrap_or_default(); - udwf.window_frame = window_frame.unwrap_or(WindowFrame::new(has_order_by)); + udwf.window_frame = + window_frame.unwrap_or(WindowFrame::new(has_order_by)); udwf.null_treatment = null_treatment; Expr::WindowFunction(udwf) } @@ -802,7 +807,6 @@ impl ExprFuncBuilder { } impl ExprFunctionExt for ExprFuncBuilder { - /// Add `ORDER BY ` /// /// Note: `order_by` must be [`Expr::Sort`] @@ -824,7 +828,10 @@ impl ExprFunctionExt for ExprFuncBuilder { } /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(mut self, null_treatment: impl Into>) -> ExprFuncBuilder { + fn null_treatment( + mut self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { self.null_treatment = null_treatment.into(); self } @@ -833,7 +840,7 @@ impl ExprFunctionExt for ExprFuncBuilder { self.partition_by = Some(partition_by); self } - + fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { self.window_frame = Some(window_frame); self @@ -843,8 +850,12 @@ impl ExprFunctionExt for ExprFuncBuilder { impl ExprFunctionExt for Expr { fn order_by(self, order_by: Vec) -> ExprFuncBuilder { let mut builder = match self { - Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), - Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } _ => 
ExprFuncBuilder::new(None), }; if builder.fun.is_some() { @@ -855,7 +866,8 @@ impl ExprFunctionExt for Expr { fn filter(self, filter: Expr) -> ExprFuncBuilder { match self { Expr::AggregateFunction(udaf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); builder.filter = Some(filter); builder } @@ -865,17 +877,25 @@ impl ExprFunctionExt for Expr { fn distinct(self) -> ExprFuncBuilder { match self { Expr::AggregateFunction(udaf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); builder.distinct = true; builder } _ => ExprFuncBuilder::new(None), } } - fn null_treatment(self, null_treatment: impl Into>) -> ExprFuncBuilder { + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { let mut builder = match self { - Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), - Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } _ => ExprFuncBuilder::new(None), }; if builder.fun.is_some() { @@ -883,7 +903,7 @@ impl ExprFunctionExt for Expr { } builder } - + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { @@ -894,7 +914,7 @@ impl ExprFunctionExt for Expr { _ => ExprFuncBuilder::new(None), } } - + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { @@ -907,7 +927,6 @@ impl ExprFunctionExt for Expr { } } - #[cfg(test)] mod test { use super::*; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 
f262613b2295..a97b9f010f79 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -293,10 +293,15 @@ impl TreeNode for Expr { order_by, transform_vec(order_by, &mut f) )? - .update_data(|(new_args, new_partition_by, new_order_by)| Expr::WindowFunction(WindowFunction::new( - fun, - new_args)).partition_by(new_partition_by).order_by(new_order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() - ), + .update_data(|(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() + }), Expr::AggregateFunction(AggregateFunction { args, func_def, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 908209614056..2ef1597abfd1 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1252,7 +1252,9 @@ impl AggregateOrderSensitivity { mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, WindowFrame, WindowFunctionDefinition + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, + test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1267,16 +1269,20 @@ mod tests { fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])); + vec![col("name")], + )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])); + vec![col("name")], + )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), - vec![col("name")])); + vec![col("name")], + )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")])); + vec![col("age")], + )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key = vec![]; @@ -1294,16 +1300,33 @@ mod tests { Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])).order_by(vec![age_asc.clone(), name_desc.clone()]).build().unwrap(); + vec![col("name")], + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])); + vec![col("name")], + )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), - vec![col("name")])).order_by(vec![age_asc.clone(), name_desc.clone()]).build().unwrap(); + vec![col("name")], + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")])).order_by(vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()]).build().unwrap(); + vec![col("age")], + )) + .order_by(vec![ + name_desc.clone(), + age_asc.clone(), + created_at_desc.clone(), + ]) + .build() + .unwrap(); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1330,19 +1353,27 @@ mod tests { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])).order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ]).window_frame(WindowFrame::new(Some(false))) - .build().unwrap(), + vec![col("name")], + )) + .order_by(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), + ]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(), Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")])).order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]).window_frame(WindowFrame::new(Some(false))) - .build().unwrap(), + vec![col("age")], + )) + .order_by(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), + Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), + ]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(), ]; let expected = vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 0fa1d4168655..f61c9110ffc9 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -2,8 +2,6 @@ use datafusion_common::ScalarValue; use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; - - /// Create an expression to represent the `row_number` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 
862bd8c1378a..39f1944452af 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,8 +31,8 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, ExprFunctionExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, - TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, + Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 2dc050eae6f6..338268e299da 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -223,12 +223,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()])).order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]).window_frame(WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )).build()? - ])? + vec![wildcard()], + )) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build()?])? .project(vec![count(wildcard())])? 
.build()?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 38c9637148f3..75dbb4d1adcd 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -46,7 +46,10 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits + is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, + type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, + LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -464,9 +467,13 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => args, }; - Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( - fun, - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build()? 
+ Ok(Transformed::yes( + Expr::WindowFunction(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, )) } Expr::Alias(_) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 787146d90b00..16abf93f3807 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -1918,11 +1918,16 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.a")])).partition_by(vec![col("test.b")]).build().unwrap(); + vec![col("test.a")], + )) + .partition_by(vec![col("test.b")]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.b")])); + vec![col("test.b")], + )); let col1 = col(max1.display_name()?); let col2 = col(max2.display_name()?); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 64751e72b9ee..38dfbb3ed551 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3855,10 +3855,9 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( - udwf, - vec![])); + let window_function_expr = Expr::WindowFunction( + datafusion_expr::expr::WindowFunction::new(udwf, vec![]), + ); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3866,10 +3865,9 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( 
WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( - udwf, - vec![])); + let window_function_expr = Expr::WindowFunction( + datafusion_expr::expr::WindowFunction::new(udwf, vec![]), + ); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index fd6d19d2fd08..f3f0e603e060 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -314,7 +314,13 @@ pub fn parse_expr( registry, "expr", codec, - )?])).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) + )?], + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -330,7 +336,13 @@ pub fn parse_expr( expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) + args, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -344,7 +356,13 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) + args, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match 
&expr.fun_definition { @@ -358,7 +376,13 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) + args, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index bb7f52e5769a..019428666e65 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,7 +60,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, ExprFunctionExt, AggregateFunction, AggregateUDF, ColumnarValue, + Accumulator, AggregateFunction, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -2047,14 +2047,26 @@ fn roundtrip_window() { WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(WindowFrame::new(Some(false))).build().unwrap(); + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 2. 
with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(WindowFrame::new(Some(false))).build().unwrap(); + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 3. with window_frame with row numbers let range_number_frame = WindowFrame::new_bounds( @@ -2067,7 +2079,13 @@ fn roundtrip_window() { WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(range_number_frame).build().unwrap(); + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(range_number_frame) + .build() + .unwrap(); // 4. test with AggregateFunction let row_number_frame = WindowFrame::new_bounds( @@ -2078,7 +2096,13 @@ fn roundtrip_window() { let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); + vec![col("col1")], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); // 5. 
test with AggregateUDF #[derive(Debug)] @@ -2122,7 +2146,13 @@ fn roundtrip_window() { let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), - vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); + vec![col("col1")], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udaf(dummy_agg); // 6. test with WindowUDF @@ -2193,11 +2223,21 @@ fn roundtrip_window() { let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), - vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); + vec![col("col1")], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), - vec![col("col1")])).window_frame(row_number_frame.clone()).build().unwrap(); + vec![col("col1")], + )) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udwf(dummy_window_udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index efa914bfae7f..1777c8d81153 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -24,7 +24,8 @@ use datafusion_common::{ use datafusion_expr::planner::PlannerResult; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition + expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + WindowFunctionDefinition, 
}; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, @@ -328,13 +329,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() + args, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() } - _ => { - Expr::WindowFunction(expr::WindowFunction::new( - fun, - self.function_args_to_expr(args, schema, planner_context)?)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() - }, + _ => Expr::WindowFunction(expr::WindowFunction::new( + fun, + self.function_args_to_expr(args, schema, planner_context)?, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap(), }; return Ok(expr); } From 2e758add162243eaeb6975f8f20aa9faddbc0755 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 23 Jul 2024 08:31:03 -0400 Subject: [PATCH 06/21] Add deprecated trait AggregateExt so that users get a warning but still builds --- datafusion/expr/src/lib.rs | 5 +++++ datafusion/expr/src/udaf.rs | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 0a5cf4653a22..bb148145c749 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -17,6 +17,9 @@ // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] +// TODO When the deprecated trait AggregateExt is removed, remove this unstable feature. +#![feature(trait_alias)] + //! [DataFusion](https://github.com/apache/datafusion) //! is an extensible query execution framework that uses //! 
[Apache Arrow](https://arrow.apache.org) as its in-memory format. @@ -88,6 +91,8 @@ pub use signature::{ pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; +#[allow(deprecated)] +pub use udaf::AggregateExt; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 29267f30100a..5fc893243ad5 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -34,7 +34,7 @@ use crate::function::{ use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; -use crate::{Accumulator, Expr}; +use crate::{Accumulator, Expr, ExprFunctionExt}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; /// Logical representation of a user-defined [aggregate function] (UDAF). 
@@ -654,3 +654,7 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } + +#[deprecated(since = "40.0.0", note="Use ExprFunctionExt instead.")] +// pub trait AggregateExt : ExprFunctionExt {} +pub trait AggregateExt = ExprFunctionExt; \ No newline at end of file From fbde31fcc48ed8ecd0d05723bf0473c0ac559e82 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 23 Jul 2024 09:12:00 -0400 Subject: [PATCH 07/21] Window helper functions should return Expr --- datafusion/expr/src/window_function.rs | 48 +++++++++++++------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index f61c9110ffc9..81d2bccbecea 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -5,43 +5,43 @@ use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; /// Create an expression to represent the `row_number` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn row_number() -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::RowNumber, vec![]) +pub fn row_number() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::RowNumber, vec![])) } /// Create an expression to represent the `rank` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn rank() -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::Rank, vec![]) +pub fn rank() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Rank, vec![])) } /// Create an expression to represent the `dense_rank` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn dense_rank() -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::DenseRank, vec![]) +pub fn dense_rank() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::DenseRank, vec![])) } /// Create 
an expression to represent the `percent_rank` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn percent_rank() -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::PercentRank, vec![]) +pub fn percent_rank() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::PercentRank, vec![])) } /// Create an expression to represent the `cume_dist` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn cume_dist() -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![]) +pub fn cume_dist() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![])) } /// Create an expression to represent the `ntile` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn ntile(arg: Expr) -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg]) +pub fn ntile(arg: Expr) -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) } /// Create an expression to represent the `lag` window function @@ -51,15 +51,15 @@ pub fn lag( arg: Expr, shift_offset: Option, default_value: Option, -) -> WindowFunction { +) -> Expr { let shift_offset_lit = shift_offset .map(|v| v.lit()) .unwrap_or(ScalarValue::Null.lit()); let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( BuiltInWindowFunction::Lag, vec![arg, shift_offset_lit, default_lit], - ) + )) } /// Create an expression to represent the `lead` window function @@ -69,34 +69,34 @@ pub fn lead( arg: Expr, shift_offset: Option, default_value: Option, -) -> WindowFunction { +) -> Expr { let shift_offset_lit = shift_offset .map(|v| v.lit()) .unwrap_or(ScalarValue::Null.lit()); let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( 
BuiltInWindowFunction::Lead, vec![arg, shift_offset_lit, default_lit], - ) + )) } /// Create an expression to represent the `first_value` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn first_value(arg: Expr) -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![arg]) +pub fn first_value(arg: Expr) -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![arg])) } /// Create an expression to represent the `last_value` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn last_value(arg: Expr) -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::LastValue, vec![arg]) +pub fn last_value(arg: Expr) -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::LastValue, vec![arg])) } /// Create an expression to represent the `nth_value` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] -pub fn nth_value(arg: Expr, n: i64) -> WindowFunction { - WindowFunction::new(BuiltInWindowFunction::NthValue, vec![arg, n.lit()]) +pub fn nth_value(arg: Expr, n: i64) -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::NthValue, vec![arg, n.lit()])) } From 99f1c79f77fc032c3a7de16042bbe6af29f19313 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 23 Jul 2024 09:12:18 -0400 Subject: [PATCH 08/21] Update documentation to show window function example --- datafusion/expr/src/expr_fn.rs | 45 ++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ff0a7d589c4c..c4227cdabb75 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -667,32 +667,39 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) } -/// Extensions for configuring [`Expr::AggregateFunction`] +/// 
Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] /// -/// Adds methods to [`Expr`] that make it easy to set optional aggregate options +/// Adds methods to [`Expr`] that make it easy to set optional options /// such as `ORDER BY`, `FILTER` and `DISTINCT` /// /// # Example /// ```no_run -/// # use datafusion_common::Result; -/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; -/// # use sqlparser::ast::NullTreatment; -/// # fn count(arg: Expr) -> Expr { todo!{} } -/// # fn first_value(arg: Expr) -> Expr { todo!{} } +/// # use datafusion::{ +/// # common::Result, +/// # functions_aggregate::{count::count, expr_fn::first_value}, +/// # logical_expr::window_function::percent_rank, +/// # prelude::{col, lit, ExprFunctionExt}, +/// # sql::sqlparser::ast::NullTreatment, +/// # }; +/// /// # fn main() -> Result<()> { -/// use datafusion_expr::ExprFunctionExt; +/// // Create an aggregate count, filtering on column y > 5 +/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?; +/// +/// // Find the first value in an aggregate sorted by column y +/// let sort_expr = col("y").sort(true, true); +/// let agg = first_value(col("x"), None) +/// .order_by(vec![sort_expr]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; /// -/// // Create COUNT(x FILTER y > 5) -/// let agg = count(col("x")) -/// .filter(col("y").gt(lit(5))) -/// .build()?; -/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) -/// let sort_expr = col("y").sort(true, true); -/// let agg = first_value(col("x")) -/// .order_by(vec![sort_expr]) -/// .null_treatment(NullTreatment::IgnoreNulls) -/// .build()?; -/// # Ok(()) +/// // Create a window expression for percent rank partitioned on column a +/// let window = percent_rank() +/// .partition_by(vec![col("a")]) +/// .order_by(vec![col("b")]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// # Ok(()) /// # } /// ``` pub trait ExprFunctionExt { From 
fd9ebdf4033ccf7825c2a40ebd2a5d1c9634c0bd Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 23 Jul 2024 09:14:23 -0400 Subject: [PATCH 09/21] Add license info --- datafusion/expr/src/window_function.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 81d2bccbecea..08cf50c2f6cd 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ use datafusion_common::ScalarValue; use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; From 4344a9f0149f1b74869bbbbf2166ee2d636ab600 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 24 Jul 2024 07:26:00 -0400 Subject: [PATCH 10/21] Update comments that are no longer applicable --- datafusion/expr/src/window_function.rs | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 08cf50c2f6cd..a5bc65c7294f 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -20,50 +20,36 @@ use datafusion_common::ScalarValue; use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; /// Create an expression to represent the `row_number` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn row_number() -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::RowNumber, vec![])) } /// Create an expression to represent the `rank` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn rank() -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Rank, vec![])) } /// Create an expression to represent the `dense_rank` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn dense_rank() -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::DenseRank, vec![])) } /// Create an expression to represent the `percent_rank` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn percent_rank() -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::PercentRank, vec![])) } /// Create an expression to represent the `cume_dist` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn cume_dist() -> Expr { 
Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![])) } /// Create an expression to represent the `ntile` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn ntile(arg: Expr) -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) } /// Create an expression to represent the `lag` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn lag( arg: Expr, shift_offset: Option, @@ -80,8 +66,6 @@ pub fn lag( } /// Create an expression to represent the `lead` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn lead( arg: Expr, shift_offset: Option, @@ -98,22 +82,16 @@ pub fn lead( } /// Create an expression to represent the `first_value` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn first_value(arg: Expr) -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![arg])) } /// Create an expression to represent the `last_value` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn last_value(arg: Expr) -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::LastValue, vec![arg])) } /// Create an expression to represent the `nth_value` window function -/// -/// Note: call [`WindowFunction::build]` to create an [`Expr`] pub fn nth_value(arg: Expr, n: i64) -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::NthValue, vec![arg, n.lit()])) } From a154ddc6e94e3518ed7fbd4784f196be3bedbdc4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 24 Jul 2024 07:26:22 -0400 Subject: [PATCH 11/21] Remove first_value and last_value since these are already implemented in the aggregate functions --- datafusion/expr/src/window_function.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/datafusion/expr/src/window_function.rs 
b/datafusion/expr/src/window_function.rs index a5bc65c7294f..e3d27d97d88f 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -81,16 +81,6 @@ pub fn lead( )) } -/// Create an expression to represent the `first_value` window function -pub fn first_value(arg: Expr) -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![arg])) -} - -/// Create an expression to represent the `last_value` window function -pub fn last_value(arg: Expr) -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::LastValue, vec![arg])) -} - /// Create an expression to represent the `nth_value` window function pub fn nth_value(arg: Expr, n: i64) -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::NthValue, vec![arg, n.lit()])) From 64cbc3696eab1ba05e4b458351ca770b7dda1caf Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 24 Jul 2024 07:56:05 -0400 Subject: [PATCH 12/21] Update to use WindowFunction::new to set additional parameters for order_by using ExprFunctionExt --- datafusion-examples/examples/advanced_udwf.rs | 11 ++++----- datafusion-examples/examples/simple_udwf.rs | 11 ++++----- datafusion/expr/src/expr.rs | 24 ++++++++++++++++++- datafusion/expr/src/udwf.rs | 23 ++++++------------ 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 11fb6f6ccc48..5a723f952a51 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -216,12 +216,11 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call( - vec![col("speed")], // smooth_it(speed) - vec![col("car")], // PARTITION BY car - vec![col("time").sort(true, true)], // ORDER BY time ASC - WindowFrame::new(None), - ); + let window_expr = smooth_it.call(vec![col("speed")]) // 
smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .window_frame(WindowFrame::new(None)) + .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; // print the results diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index 563f02cee6a6..3f25419438c9 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -118,12 +118,11 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call( - vec![col("speed")], // smooth_it(speed) - vec![col("car")], // PARTITION BY car - vec![col("time").sort(true, true)], // ORDER BY time ASC - WindowFrame::new(None), - ); + let window_expr = smooth_it.call(vec![col("speed")]) // smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .window_frame(WindowFrame::new(None)) + .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; // print the results diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 694f849f12aa..349c88e751a5 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -794,6 +794,27 @@ impl From> for WindowFunctionDefinition { } /// Window function +/// +/// Holds the actual actual function to call +/// [`window_function::WindowFunction`] as well as its arguments +/// (`args`) and the contents of the `OVER` clause: +/// +/// 1. `PARTITION BY` +/// 2. `ORDER BY` +/// 3. Window frame (e.g. 
`ROWS 1 PRECEDING AND 1 FOLLOWING`) +/// +/// See [`Self::build`] to create an [`Expr`] +/// +/// # Example +/// ```/// # use datafusion_expr::expr::WindowFunction; +/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c) +/// let expr: Expr = Expr::WindowFunction( +/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")]) +/// ) +/// .with_partition_by(vec![col("b")]) +/// .with_order_by(vec![col("b")]) +/// .build()?; +/// ``` #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function @@ -811,7 +832,8 @@ pub struct WindowFunction { } impl WindowFunction { - /// Create a new Window expression + /// Create a new Window expression with the specified arguments and an + /// empty `OVER` clause pub fn new(fun: impl Into, args: Vec) -> Self { Self { fun: fun.into(), diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 1a6b21e3dd29..e527447a9e1f 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -28,9 +28,10 @@ use arrow::datatypes::DataType; use datafusion_common::Result; +use crate::expr::WindowFunction; use crate::{ function::WindowFunctionSimplification, Expr, PartitionEvaluator, - PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, + PartitionEvaluatorFactory, ReturnTypeFunction, Signature, }; /// Logical representation of a user-defined window function (UDWF) @@ -123,28 +124,18 @@ impl WindowUDF { Self::new_from_impl(AliasedWindowUDFImpl::new(Arc::clone(&self.inner), aliases)) } - /// creates a [`Expr`] that calls the window function given - /// the `partition_by`, `order_by`, and `window_frame` definition + /// creates a [`Expr`] that calls the window function with default + /// values for order_by, partition_by, window_frame. See [`ExprFunctionExt`] + /// for details on setting these values. /// /// This utility allows using the UDWF without requiring access to /// the registry, such as with the DataFrame API. 
pub fn call( &self, - args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: WindowFrame, - ) -> Expr { + args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(crate::expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: None, - }) + Expr::WindowFunction(WindowFunction::new(fun, args)) } /// Returns this function's name From 532f2626ee3c250f808729b312986937c58ef6e9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 24 Jul 2024 07:57:04 -0400 Subject: [PATCH 13/21] Apply cargo fmt --- datafusion-examples/examples/advanced_udwf.rs | 7 ++++--- datafusion-examples/examples/simple_udwf.rs | 7 ++++--- datafusion/expr/src/lib.rs | 3 +-- datafusion/expr/src/udaf.rs | 4 ++-- datafusion/expr/src/udwf.rs | 4 +--- datafusion/expr/src/window_function.rs | 20 +++++++++++++++---- 6 files changed, 28 insertions(+), 17 deletions(-) diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 5a723f952a51..ec0318a561b9 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -216,9 +216,10 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call(vec![col("speed")]) // smooth_it(speed) - .partition_by(vec![col("car")]) // PARTITION BY car - .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + let window_expr = smooth_it + .call(vec![col("speed")]) // smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC .window_frame(WindowFrame::new(None)) .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index 3f25419438c9..22dfbbbf0c3a 100644 --- 
a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -118,9 +118,10 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call(vec![col("speed")]) // smooth_it(speed) - .partition_by(vec![col("car")]) // PARTITION BY car - .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + let window_expr = smooth_it + .call(vec![col("speed")]) // smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC .window_frame(WindowFrame::new(None)) .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index bb148145c749..612d7538fcc1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -16,7 +16,6 @@ // under the License. // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] - // TODO When the deprecated trait AggregateExt is removed, remove this unstable feature. 
#![feature(trait_alias)] @@ -90,9 +89,9 @@ pub use signature::{ }; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; #[allow(deprecated)] pub use udaf::AggregateExt; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 5fc893243ad5..d2e628cd24d5 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -655,6 +655,6 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { } } -#[deprecated(since = "40.0.0", note="Use ExprFunctionExt instead.")] +#[deprecated(since = "40.0.0", note = "Use ExprFunctionExt instead.")] // pub trait AggregateExt : ExprFunctionExt {} -pub trait AggregateExt = ExprFunctionExt; \ No newline at end of file +pub trait AggregateExt = ExprFunctionExt; diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index e527447a9e1f..2ce6647b866c 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -130,9 +130,7 @@ impl WindowUDF { /// /// This utility allows using the UDWF without requiring access to /// the registry, such as with the DataFrame API. 
- pub fn call( - &self, - args: Vec) -> Expr { + pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); Expr::WindowFunction(WindowFunction::new(fun, args)) diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index e3d27d97d88f..5e81464d39c2 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -21,7 +21,10 @@ use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; /// Create an expression to represent the `row_number` window function pub fn row_number() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::RowNumber, vec![])) + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::RowNumber, + vec![], + )) } /// Create an expression to represent the `rank` window function @@ -31,12 +34,18 @@ pub fn rank() -> Expr { /// Create an expression to represent the `dense_rank` window function pub fn dense_rank() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::DenseRank, vec![])) + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::DenseRank, + vec![], + )) } /// Create an expression to represent the `percent_rank` window function pub fn percent_rank() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::PercentRank, vec![])) + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::PercentRank, + vec![], + )) } /// Create an expression to represent the `cume_dist` window function @@ -83,5 +92,8 @@ pub fn lead( /// Create an expression to represent the `nth_value` window function pub fn nth_value(arg: Expr, n: i64) -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::NthValue, vec![arg, n.lit()])) + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::NthValue, + vec![arg, n.lit()], + )) } From acfcece2d8cd31bbfd3db6e50d772d2ea068ef27 Mon Sep 17 00:00:00 2001 
From: Andrew Lamb Date: Wed, 24 Jul 2024 14:45:10 -0400 Subject: [PATCH 14/21] Fix up clippy --- datafusion/expr/src/expr_fn.rs | 2 +- datafusion/expr/src/lib.rs | 4 ---- datafusion/expr/src/udaf.rs | 8 ++------ 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c4227cdabb75..f2cd19ef1215 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -799,7 +799,7 @@ impl ExprFuncBuilder { Expr::AggregateFunction(udaf) } ExprFuncKind::Window(mut udwf) => { - let has_order_by = order_by.as_ref().map(|o| o.len() > 0); + let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); udwf.order_by = order_by.unwrap_or_default(); udwf.partition_by = partition_by.unwrap_or_default(); udwf.window_frame = diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 612d7538fcc1..0a5cf4653a22 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -16,8 +16,6 @@ // under the License. // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// TODO When the deprecated trait AggregateExt is removed, remove this unstable feature. -#![feature(trait_alias)] //! [DataFusion](https://github.com/apache/datafusion) //! 
is an extensible query execution framework that uses @@ -89,8 +87,6 @@ pub use signature::{ }; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -#[allow(deprecated)] -pub use udaf::AggregateExt; pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 9c963d4a680a..e6604903757d 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -34,7 +34,7 @@ use crate::function::{ use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; -use crate::{Accumulator, Expr, ExprFunctionExt}; +use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; /// Logical representation of a user-defined [aggregate function] (UDAF). @@ -653,8 +653,4 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { (self.accumulator)(acc_args) } -} - -#[deprecated(since = "40.0.0", note = "Use ExprFunctionExt instead.")] -// pub trait AggregateExt : ExprFunctionExt {} -pub trait AggregateExt = ExprFunctionExt; +} \ No newline at end of file From 039f427f069e66fcf5691bb52df5069b2aeff3d1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 14:54:03 -0400 Subject: [PATCH 15/21] fix doc example --- datafusion/expr/src/expr.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 858e65374e1b..8ca15da0c384 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -806,14 +806,17 @@ impl From> for WindowFunctionDefinition { /// See [`Self::build`] to create an [`Expr`] /// /// # Example -/// ```/// # use datafusion_expr::expr::WindowFunction; +/// ``` +/// # use 
datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt}; +/// # use datafusion_expr::expr::WindowFunction; /// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c) -/// let expr: Expr = Expr::WindowFunction( +/// let expr = Expr::WindowFunction( /// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")]) /// ) -/// .with_partition_by(vec![col("b")]) -/// .with_order_by(vec![col("b")]) -/// .build()?; +/// .partition_by(vec![col("b")]) +/// .order_by(vec![col("c").sort(true, true)]) +/// .build() +/// .unwrap(); /// ``` #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { From 75e364a22c9c4a754f490fc47f609bd5d806c3b5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 14:54:24 -0400 Subject: [PATCH 16/21] fmt --- datafusion/core/src/dataframe/mod.rs | 3 ++- datafusion/core/tests/dataframe/mod.rs | 4 +++- datafusion/expr/src/udaf.rs | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 49b638946e78..ea437cc99a33 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,7 +1696,8 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, ScalarFunctionImplementation, Volatility, WindowFunctionDefinition + cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, + ScalarFunctionImplementation, Volatility, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index a844693dcb05..d83a47ceb069 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,7 +54,9 @@ use 
datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition + cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e6604903757d..8867a478f790 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -653,4 +653,4 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { (self.accumulator)(acc_args) } -} \ No newline at end of file +} From 1a801f64550047d75cd1dd2f06664c78ffc7b79b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 14:58:43 -0400 Subject: [PATCH 17/21] doc tweaks --- datafusion/expr/src/expr.rs | 4 ++++ datafusion/expr/src/expr_fn.rs | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8ca15da0c384..2a7089d8e530 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -60,6 +60,10 @@ use sqlparser::ast::NullTreatment; /// use the fluent APIs in [`crate::expr_fn`] such as [`col`] and [`lit`], or /// methods such as [`Expr::alias`], [`Expr::cast_to`], and [`Expr::Like`]). /// +/// See also [`ExprFunctionExt`] for creating aggregate and window functions. 
+/// +/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt +/// /// # Schema Access /// /// See [`ExprSchemable::get_type`] to access the [`DataType`] and nullability diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index f2cd19ef1215..7bf08ac31f67 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -744,7 +744,6 @@ pub struct ExprFuncBuilder { impl ExprFuncBuilder { /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] - fn new(fun: Option) -> Self { Self { fun, @@ -761,7 +760,7 @@ impl ExprFuncBuilder { /// /// # Errors: /// - /// Returns an error of this builder [`ExprFunctionExt`] was used with an + /// Returns an error if this builder [`ExprFunctionExt`] was used with an /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] pub fn build(self) -> Result { let Self { From d6898725463dc5006a2503e0460ce31d0f82efaa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 15:09:55 -0400 Subject: [PATCH 18/21] more doc tweaks --- datafusion/expr/src/expr.rs | 2 ++ datafusion/expr/src/expr_fn.rs | 52 ++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 2a7089d8e530..75ee201b5b46 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -287,6 +287,8 @@ pub enum Expr { /// This expression is guaranteed to have a fixed type. TryCast(TryCast), /// A sort expression, that can be used to sort values. + /// + /// See [Expr::sort] for more details Sort(Sort), /// Represents the call of a scalar function with a set of arguments. 
ScalarFunction(ScalarFunction), diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 7bf08ac31f67..1f51cded2239 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -674,31 +674,35 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// /// # Example /// ```no_run -/// # use datafusion::{ -/// # common::Result, -/// # functions_aggregate::{count::count, expr_fn::first_value}, -/// # logical_expr::window_function::percent_rank, -/// # prelude::{col, lit, ExprFunctionExt}, -/// # sql::sqlparser::ast::NullTreatment, -/// # }; -/// +/// # use datafusion_common::Result; +/// # use datafusion_expr::test::function_stub::count; +/// # use sqlparser::ast::NullTreatment; +/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col}; +/// # use datafusion_expr::window_function::percent_rank; +/// # // first_value is an aggregate function in another crate +/// # fn first_value(_arg: Expr) -> Expr { +/// # unimplemented!() } /// # fn main() -> Result<()> { -/// // Create an aggregate count, filtering on column y > 5 -/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?; +/// // Create an aggregate count, filtering on column y > 5 +/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?; /// -/// // Find the first value in an aggregate sorted by column y -/// let sort_expr = col("y").sort(true, true); -/// let agg = first_value(col("x"), None) -/// .order_by(vec![sort_expr]) -/// .null_treatment(NullTreatment::IgnoreNulls) -/// .build()?; +/// // Find the first value in an aggregate sorted by column y +/// // equivalent to: +/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)` +/// let sort_expr = col("y").sort(true, true); +/// let agg = first_value(col("x")) +/// .order_by(vec![sort_expr]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; /// -/// // Create a window expression for percent rank partitioned on column a -/// let window = percent_rank() -/// 
.partition_by(vec![col("a")]) -/// .order_by(vec![col("b")]) -/// .null_treatment(NullTreatment::IgnoreNulls) -/// .build()?; +/// // Create a window expression for percent rank partitioned on column a +/// // equivalent to: +/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST IGNORE NULLS)` +/// let window = percent_rank() +/// .partition_by(vec![col("a")]) +/// .order_by(vec![col("b").sort(true, true)]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; /// # Ok(()) /// # } /// ``` @@ -716,9 +720,9 @@ pub trait ExprFunctionExt { self, null_treatment: impl Into>, ) -> ExprFuncBuilder; - // Add `PARTITION BY` + /// Add `PARTITION BY` fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; - // Add appropriate window frame conditions + /// Add appropriate window frame conditions fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; } From 4884c8aa13c4f48a3cbcc7b0d3c24dfd9f5e36dd Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 15:14:35 -0400 Subject: [PATCH 19/21] fix up links --- datafusion/expr/src/expr.rs | 7 ++----- datafusion/expr/src/udwf.rs | 11 +++++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 75ee201b5b46..68d5504eea48 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -801,16 +801,13 @@ impl From> for WindowFunctionDefinition { /// Window function /// -/// Holds the actual actual function to call -/// [`window_function::WindowFunction`] as well as its arguments -/// (`args`) and the contents of the `OVER` clause: +/// Holds the actual function to call [`WindowFunction`] as well as its +/// arguments (`args`) and the contents of the `OVER` clause: /// /// 1. `PARTITION BY` /// 2. `ORDER BY` /// 3. Window frame (e.g. 
`ROWS 1 PRECEDING AND 1 FOLLOWING`) /// -/// See [`Self::build`] to create an [`Expr`] -/// /// # Example /// ``` /// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt}; diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 2ce6647b866c..af8c89870770 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -125,11 +125,14 @@ impl WindowUDF { } /// creates a [`Expr`] that calls the window function with default - /// values for order_by, partition_by, window_frame. See [`ExprFunctionExt`] - /// for details on setting these values. + /// values for `order_by`, `partition_by`, `window_frame`. /// - /// This utility allows using the UDWF without requiring access to - /// the registry, such as with the DataFrame API. + /// See [`ExprFunctionExt`] for details on setting these values. + /// + /// This utility allows using a user defined window function without + /// requiring access to the registry, such as with the DataFrame API. 
+ /// + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); From 9bfd1dd0cb0a25ffb8159c541bdcc79923676f0c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 15:42:00 -0400 Subject: [PATCH 20/21] fix integration test --- .../proto/tests/cases/roundtrip_logical_plan.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 12477013893d..7a4de4f61a38 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2075,7 +2075,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) + .order_by(vec![col("col2").sort(true, false)]) .window_frame(WindowFrame::new(Some(false))) .build() .unwrap(); @@ -2088,7 +2088,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) + .order_by(vec![col("col2").sort(false, true)]) .window_frame(WindowFrame::new(Some(false))) .build() .unwrap(); @@ -2107,7 +2107,7 @@ fn roundtrip_window() { vec![], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) + .order_by(vec![col("col2").sort(false, false)]) .window_frame(range_number_frame) .build() .unwrap(); @@ -2124,7 +2124,7 @@ fn roundtrip_window() { vec![col("col1")], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) + .order_by(vec![col("col2").sort(true, true)]) .window_frame(row_number_frame.clone()) .build() .unwrap(); @@ -2174,7 +2174,7 @@ fn roundtrip_window() { vec![col("col1")], )) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) + .order_by(vec![col("col2").sort(true, true)]) .window_frame(row_number_frame.clone()) .build() .unwrap(); @@ -2251,7 +2251,7 @@ fn roundtrip_window() { vec![col("col1")], 
)) .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) + .order_by(vec![col("col2").sort(true, true)]) .window_frame(row_number_frame.clone()) .build() .unwrap(); From 77726ebd0be9c64f2e1d5e188d46d435f7c9c163 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 15:46:24 -0400 Subject: [PATCH 21/21] fix another doc example --- datafusion/expr/src/udwf.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index af8c89870770..5abce013dfb6 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -202,7 +202,7 @@ where /// # use std::any::Any; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; /// #[derive(Debug, Clone)] /// struct SmoothIt { @@ -236,12 +236,13 @@ where /// let smooth_it = WindowUDF::from(SmoothIt::new()); /// /// // Call the function `add_one(col)` -/// let expr = smooth_it.call( -/// vec![col("speed")], // smooth_it(speed) -/// vec![col("car")], // PARTITION BY car -/// vec![col("time").sort(true, true)], // ORDER BY time ASC -/// WindowFrame::new(None), -/// ); +/// // smooth_it(speed) OVER (PARTITION BY car ORDER BY time ASC) +/// let expr = smooth_it.call(vec![col("speed")]) +/// .partition_by(vec![col("car")]) +/// .order_by(vec![col("time").sort(true, true)]) +/// .window_frame(WindowFrame::new(None)) +/// .build() +/// .unwrap(); /// ``` pub trait WindowUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object