From a00cfbfdbf3143c8c56d9ea043be3fb69da008ee Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Sep 2024 16:56:43 -0400 Subject: [PATCH] feat: aggregates as windows (#871) * Add to turn any aggregate function into a window function * Rename Window to WindowExpr so we can define Window to mean a window definition to be reused * Add unit test to cover default frames * Improve error report --- python/datafusion/expr.py | 57 ++++++++++++++++- python/datafusion/tests/test_dataframe.py | 75 ++++++++++++++++------- src/expr.rs | 46 +++++++++++++- src/expr/window.rs | 20 +++--- src/functions.rs | 29 +++++---- src/sql/logical.rs | 4 +- 6 files changed, 183 insertions(+), 48 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index fd5e6f04..152aa38d 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -92,7 +92,7 @@ Union = expr_internal.Union Unnest = expr_internal.Unnest UnnestExpr = expr_internal.UnnestExpr -Window = expr_internal.Window +WindowExpr = expr_internal.WindowExpr __all__ = [ "Expr", @@ -154,6 +154,7 @@ "Partitioning", "Repartition", "Window", + "WindowExpr", "WindowFrame", "WindowFrameBound", ] @@ -542,6 +543,36 @@ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder: """ return ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame)) + def over(self, window: Window) -> Expr: + """Turn an aggregate function into a window function. + + This function turns any aggregate function into a window function. With the + exception of ``partition_by``, how each of the parameters is used is determined + by the underlying aggregate function. + + Args: + window: Window definition + """ + partition_by_raw = expr_list_to_raw_expr_list(window._partition_by) + order_by_raw = sort_list_to_raw_sort_list(window._order_by) + window_frame_raw = ( + window._window_frame.window_frame + if window._window_frame is not None + else None + ) + null_treatment_raw = ( + window._null_treatment.value if window._null_treatment is not None else None + ) + + return Expr( + self.expr.over( + partition_by=partition_by_raw, + order_by=order_by_raw, + window_frame=window_frame_raw, + null_treatment=null_treatment_raw, + ) + ) + class ExprFuncBuilder: def __init__(self, builder: expr_internal.ExprFuncBuilder): @@ -584,6 +615,30 @@ def build(self) -> Expr: return Expr(self.builder.build()) +class Window: + """Define reusable window parameters.""" + + def __init__( + self, + partition_by: Optional[list[Expr]] = None, + window_frame: Optional[WindowFrame] = None, + order_by: Optional[list[SortExpr | Expr]] = None, + null_treatment: Optional[NullTreatment] = None, + ) -> None: + """Construct a window definition. + + Args: + partition_by: Partitions for window operation + window_frame: Define the start and end bounds of the window frame + order_by: Set ordering + null_treatment: Indicate how nulls are to be treated + """ + self._partition_by = partition_by + self._window_frame = window_frame + self._order_by = order_by + self._null_treatment = null_treatment + + class WindowFrame: """Defines a window frame for performing window operations.""" diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 90954d09..ad7f728b 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -31,6 +31,7 @@ literal, udf, ) +from datafusion.expr import Window @pytest.fixture @@ -386,38 +387,32 @@ def test_distinct(): ), [-1, -1, None, 7, -1, -1, None], ), - # TODO update all aggregate functions as windows once upstream merges https://github.com/apache/datafusion-python/issues/833 - pytest.param( + ( "first_value", - f.window( - "first_value", - [column("a")], - order_by=[f.order_by(column("b"))], - partition_by=[column("c")], + f.first_value(column("a")).over( + Window(partition_by=[column("c")], order_by=[column("b")]) ), [1, 1, 1, 1, 5, 5, 5], ), - pytest.param( + ( "last_value", - f.window("last_value", [column("a")]) - .window_frame(WindowFrame("rows", 0, None)) - .order_by(column("b")) - .partition_by(column("c")) - .build(), + f.last_value(column("a")).over( + Window( + partition_by=[column("c")], + order_by=[column("b")], + window_frame=WindowFrame("rows", None, None), + ) + ), [3, 3, 3, 3, 6, 6, 6], ), - pytest.param( + ( "3rd_value", - f.window( - "nth_value", - [column("b"), literal(3)], - order_by=[f.order_by(column("a"))], - ), + f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])), [None, None, 7, 7, 7, 7, 7], ), - pytest.param( + ( "avg", - f.round(f.window("avg", [column("b")], order_by=[column("a")]), literal(3)), + f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)), [7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0], ), ] @@ -473,6 +468,44 @@ def test_invalid_window_frame(units, start_bound, end_bound): WindowFrame(units, start_bound, end_bound) +def test_window_frame_defaults_match_postgres(partitioned_df): + # ref: https://github.com/apache/datafusion-python/issues/688 + + window_frame = WindowFrame("rows", None, None) + + col_a = column("a") + + # Using `f.window` with or without an unbounded window_frame produces the same + # results. These tests are included as a regression check but can be removed when + # f.window() is deprecated in favor of using the .over() approach. + no_frame = f.window("avg", [col_a]).alias("no_frame") + with_frame = f.window("avg", [col_a], window_frame=window_frame).alias("with_frame") + df_1 = partitioned_df.select(col_a, no_frame, with_frame) + + expected = { + "a": [0, 1, 2, 3, 4, 5, 6], + "no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], + "with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], + } + + assert df_1.sort(col_a).to_pydict() == expected + + # When order is not set, the default frame should be unounded preceeding to + # unbounded following. When order is set, the default frame is unbounded preceeding + # to current row. + no_order = f.avg(col_a).over(Window()).alias("over_no_order") + with_order = f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order") + df_2 = partitioned_df.select(col_a, no_order, with_order) + + expected = { + "a": [0, 1, 2, 3, 4, 5, 6], + "over_no_order": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], + "over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0], + } + + assert df_2.sort(col_a).to_pydict() == expected + + def test_get_dataframe(tmp_path): ctx = SessionContext() diff --git a/src/expr.rs b/src/expr.rs index 304d147c..49fa4b84 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -16,7 +16,9 @@ // under the License. use datafusion::logical_expr::utils::exprlist_to_fields; -use datafusion::logical_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan}; +use datafusion::logical_expr::{ + ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition, +}; use pyo3::{basic::CompareOp, prelude::*}; use std::convert::{From, Into}; use std::sync::Arc; @@ -39,6 +41,7 @@ use crate::expr::aggregate_expr::PyAggregateFunction; use crate::expr::binary_expr::PyBinaryExpr; use crate::expr::column::PyColumn; use crate::expr::literal::PyLiteral; +use crate::functions::add_builder_fns_to_window; use crate::sql::logical::PyLogicalPlan; use self::alias::PyAlias; @@ -558,6 +561,45 @@ impl PyExpr { pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder { self.expr.clone().window_frame(window_frame.into()).into() } + + #[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))] + pub fn over( + &self, + partition_by: Option>, + window_frame: Option, + order_by: Option>, + null_treatment: Option, + ) -> PyResult { + match &self.expr { + Expr::AggregateFunction(agg_fn) => { + let window_fn = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()), + agg_fn.args.clone(), + )); + + add_builder_fns_to_window( + window_fn, + partition_by, + window_frame, + order_by, + null_treatment, + ) + } + Expr::WindowFunction(_) => add_builder_fns_to_window( + self.expr.clone(), + partition_by, + window_frame, + order_by, + null_treatment, + ), + _ => Err( + DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan( + format!("Using {} with `over` is not allowed. Must use an aggregate or window function.", self.expr.variant_name()), + )) + .into(), + ), + } + } } #[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)] @@ -749,7 +791,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; Ok(()) diff --git a/src/expr/window.rs b/src/expr/window.rs index 950db12a..6486dbb3 100644 --- a/src/expr/window.rs +++ b/src/expr/window.rs @@ -32,9 +32,9 @@ use super::py_expr_list; use crate::errors::py_datafusion_err; -#[pyclass(name = "Window", module = "datafusion.expr", subclass)] +#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] -pub struct PyWindow { +pub struct PyWindowExpr { window: Window, } @@ -62,15 +62,15 @@ pub struct PyWindowFrameBound { frame_bound: WindowFrameBound, } -impl From for Window { - fn from(window: PyWindow) -> Window { +impl From for Window { + fn from(window: PyWindowExpr) -> Window { window.window } } -impl From for PyWindow { - fn from(window: Window) -> PyWindow { - PyWindow { window } +impl From for PyWindowExpr { + fn from(window: Window) -> PyWindowExpr { + PyWindowExpr { window } } } @@ -80,7 +80,7 @@ impl From for PyWindowFrameBound { } } -impl Display for PyWindow { +impl Display for PyWindowExpr { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!( f, @@ -103,7 +103,7 @@ impl Display for PyWindowFrame { } #[pymethods] -impl PyWindow { +impl PyWindowExpr { /// Returns the schema of the Window pub fn schema(&self) -> PyResult { Ok(self.window.schema.as_ref().clone().into()) @@ -283,7 +283,7 @@ impl PyWindowFrameBound { } } -impl LogicalNode for PyWindow { +impl LogicalNode for PyWindowExpr { fn inputs(&self) -> Vec { vec![self.window.input.as_ref().clone().into()] } diff --git a/src/functions.rs b/src/functions.rs index 32f6519f..6f8dd7ad 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -711,14 +711,15 @@ pub fn string_agg( add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment) } -fn add_builder_fns_to_window( +pub(crate) fn add_builder_fns_to_window( window_fn: Expr, partition_by: Option>, + window_frame: Option, order_by: Option>, + null_treatment: Option, ) -> PyResult { - // Since ExprFuncBuilder::new() is private, set an empty partition and then - // override later if appropriate. - let mut builder = window_fn.partition_by(vec![]); + let null_treatment = null_treatment.map(|n| n.into()); + let mut builder = window_fn.null_treatment(null_treatment); if let Some(partition_cols) = partition_by { builder = builder.partition_by( @@ -734,6 +735,10 @@ fn add_builder_fns_to_window( builder = builder.order_by(order_by_cols); } + if let Some(window_frame) = window_frame { + builder = builder.window_frame(window_frame.into()); + } + builder.build().map(|e| e.into()).map_err(|err| err.into()) } @@ -748,7 +753,7 @@ pub fn lead( ) -> PyResult { let window_fn = window_function::lead(arg.expr, Some(shift_offset), default_value); - add_builder_fns_to_window(window_fn, partition_by, order_by) + add_builder_fns_to_window(window_fn, partition_by, None, order_by, None) } #[pyfunction] @@ -762,7 +767,7 @@ pub fn lag( ) -> PyResult { let window_fn = window_function::lag(arg.expr, Some(shift_offset), default_value); - add_builder_fns_to_window(window_fn, partition_by, order_by) + add_builder_fns_to_window(window_fn, partition_by, None, order_by, None) } #[pyfunction] @@ -773,7 +778,7 @@ pub fn row_number( ) -> PyResult { let window_fn = datafusion::functions_window::expr_fn::row_number(); - add_builder_fns_to_window(window_fn, partition_by, order_by) + add_builder_fns_to_window(window_fn, partition_by, None, order_by, None) } #[pyfunction] @@ -784,7 +789,7 @@ pub fn rank( ) -> PyResult { let window_fn = window_function::rank(); - add_builder_fns_to_window(window_fn, partition_by, order_by) + add_builder_fns_to_window(window_fn, partition_by, None, order_by, None) } #[pyfunction] @@ -795,7 +800,7 @@ pub fn dense_rank( ) -> PyResult { let window_fn = window_function::dense_rank(); - add_builder_fns_to_window(window_fn, partition_by, order_by) + add_builder_fns_to_window(window_fn, partition_by, None, order_by, None) } #[pyfunction] @@ -806,7 +811,7 @@ pub fn percent_rank( ) -> PyResult { let window_fn = window_function::percent_rank(); - add_builder_fns_to_window(window_fn, partition_by, order_by) + add_builder_fns_to_window(window_fn, partition_by, None, order_by, None) } #[pyfunction] @@ -817,7 +822,7 @@ pub fn cume_dist( ) -> PyResult { let window_fn = window_function::cume_dist(); - add_builder_fns_to_window(window_fn, partition_by, order_by) + add_builder_fns_to_window(window_fn, partition_by, None, order_by, None) } #[pyfunction] @@ -829,7 +834,7 @@ pub fn ntile( ) -> PyResult { let window_fn = window_function::ntile(arg.into()); - add_builder_fns_to_window(window_fn, partition_by, order_by) + add_builder_fns_to_window(window_fn, partition_by, None, order_by, None) } pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { diff --git a/src/sql/logical.rs b/src/sql/logical.rs index 89655ab7..d00f0af3 100644 --- a/src/sql/logical.rs +++ b/src/sql/logical.rs @@ -34,7 +34,7 @@ use crate::expr::subquery::PySubquery; use crate::expr::subquery_alias::PySubqueryAlias; use crate::expr::table_scan::PyTableScan; use crate::expr::unnest::PyUnnest; -use crate::expr::window::PyWindow; +use crate::expr::window::PyWindowExpr; use datafusion::logical_expr::LogicalPlan; use pyo3::prelude::*; @@ -80,7 +80,7 @@ impl PyLogicalPlan { LogicalPlan::Subquery(plan) => PySubquery::from(plan.clone()).to_variant(py), LogicalPlan::SubqueryAlias(plan) => PySubqueryAlias::from(plan.clone()).to_variant(py), LogicalPlan::Unnest(plan) => PyUnnest::from(plan.clone()).to_variant(py), - LogicalPlan::Window(plan) => PyWindow::from(plan.clone()).to_variant(py), + LogicalPlan::Window(plan) => PyWindowExpr::from(plan.clone()).to_variant(py), LogicalPlan::Repartition(_) | LogicalPlan::Union(_) | LogicalPlan::Statement(_)