From 999d92c14532ee189e086217a48492d25e250c2f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 13 Aug 2024 11:52:48 -0400 Subject: [PATCH] Add window function as template for others and function builder --- python/datafusion/expr.py | 20 ++++- python/datafusion/functions.py | 25 ++++-- python/datafusion/tests/test_dataframe.py | 91 ++++++++++++---------- src/expr.rs | 95 ++++++++++++++++++++++- src/functions.rs | 8 ++ 5 files changed, 191 insertions(+), 48 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 84bb7022..f37d4576 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -176,7 +176,10 @@ def display_name(self) -> str: This name will not include any CAST expressions. """ import warnings - warnings.warn("deprecated since 40.0.0: use schema_name instead", DeprecationWarning) + + warnings.warn( + "deprecated since 40.0.0: use schema_name instead", DeprecationWarning + ) return self.schema_name() def schema_name(self) -> str: @@ -364,6 +367,10 @@ def is_null(self) -> Expr: """Returns ``True`` if this expression is null.""" return Expr(self.expr.is_null()) + def is_not_null(self) -> Expr: + """Returns ``True`` if this expression is not null.""" + return Expr(self.expr.is_not_null()) + def cast(self, to: pa.DataType[Any]) -> Expr: """Cast to a new data type.""" return Expr(self.expr.cast(to)) @@ -414,6 +421,17 @@ def column_name(self, plan: LogicalPlan) -> str: """Compute the output column name based on the provided logical plan.""" return self.expr.column_name(plan) + def order_by(self, *exprs: Expr) -> ExprFuncBuilder: + return ExprFuncBuilder(self.expr.order_by(list(e.expr for e in exprs))) + + +class ExprFuncBuilder: + def __init__(self, builder: expr_internal.ExprFuncBuilder): + self.builder = builder + + def build(self) -> Expr: + return Expr(self.builder.build()) + class WindowFrame: """Defines a window frame for performing window operations.""" diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 2d3d87ee..af2e29df 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -27,6 +27,11 @@ from datafusion.expr import CaseBuilder, Expr, WindowFrame from datafusion.context import SessionContext +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pyarrow as pa + __all__ = [ "abs", "acos", @@ -246,6 +251,7 @@ "var_pop", "var_samp", "window", + "lead", ] @@ -1011,12 +1017,12 @@ def struct(*args: Expr) -> Expr: return Expr(f.struct(*args)) -def named_struct(name_pairs: list[(str, Expr)]) -> Expr: +def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr: """Returns a struct with the given names and arguments pairs.""" - name_pairs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs] + name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs] # flatten - name_pairs = [x.expr for xs in name_pairs for x in xs] + name_pairs = [x.expr for xs in name_pair_exprs for x in xs] return Expr(f.named_struct(*name_pairs)) @@ -1479,12 +1485,17 @@ def approx_percentile_cont( """Returns the value that is approximately at a given percentile of ``expr``.""" if num_centroids is None: return Expr( - f.approx_percentile_cont(expression.expr, percentile.expr, distinct=distinct, num_centroids=None) + f.approx_percentile_cont( + expression.expr, percentile.expr, distinct=distinct, num_centroids=None + ) ) return Expr( f.approx_percentile_cont( - expression.expr, percentile.expr, distinct=distinct, num_centroids=num_centroids.expr + expression.expr, + percentile.expr, + distinct=distinct, + num_centroids=num_centroids.expr, ) ) @@ -1732,3 +1743,7 @@ def bool_and(arg: Expr, distinct: bool = False) -> Expr: def bool_or(arg: Expr, distinct: bool = False) -> Expr: """Computes the boolean OR of the arguement.""" return Expr(f.bool_or(arg.expr, distinct=distinct)) + + +def lead(arg: Expr, shift_offset: int = 1, default_value: pa.Scalar | None = None): + return Expr(f.lead(arg.expr, shift_offset, default_value)) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 6444d932..8b8dc54c 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -279,57 +279,68 @@ def test_distinct(): data_test_window_functions = [ - ("row", f.window("row_number", [], order_by=[f.order_by(column("c"))]), [2, 1, 3]), - ("rank", f.window("rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2]), - ("dense_rank", f.window("dense_rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2] ), - ("percent_rank", f.window("percent_rank", [], order_by=[f.order_by(column("c"))]), [0.5, 0, 0.5]), - ("cume_dist", f.window("cume_dist", [], order_by=[f.order_by(column("b"))]), [0.3333333333333333, 0.6666666666666666, 1.0]), - ("ntile", f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]), [1, 1, 2]), - ("next", f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]), [5, 6, None]), - ("previous", f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]), [None, 4, 5]), - pytest.param( - "first_value", - f.window( + ("row", f.window("row_number", [], order_by=[f.order_by(column("c"))]), [2, 1, 3]), + ("rank", f.window("rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2]), + ( + "dense_rank", + f.window("dense_rank", [], order_by=[f.order_by(column("c"))]), + [2, 1, 2], + ), + ( + "percent_rank", + f.window("percent_rank", [], order_by=[f.order_by(column("c"))]), + [0.5, 0, 0.5], + ), + ( + "cume_dist", + f.window("cume_dist", [], order_by=[f.order_by(column("b"))]), + [0.3333333333333333, 0.6666666666666666, 1.0], + ), + ( + "ntile", + f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]), + [1, 1, 2], + ), + ( + "next", + f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]), + [5, 6, None], + ), + ("lead", f.lead(column("b")).order_by(column("b").sort()).build(), [5, 6, None]), + ( + "previous", + f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]), + [None, 4, 5], + ), + pytest.param( "first_value", - [column("a")], - order_by=[f.order_by(column("b"))] + f.window("first_value", [column("a")], order_by=[f.order_by(column("b"))]), + [1, 1, 1], + ), + pytest.param( + "last_value", + f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]), + [4, 5, 6], ), - [1, 1, 1], - ), - pytest.param( - "last_value", - f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]), - [4, 5, 6], - ), - pytest.param( - "2nd_value", - f.window( - "nth_value", - [column("b"), literal(2)], - order_by=[f.order_by(column("b"))], + pytest.param( + "2nd_value", + f.window( + "nth_value", + [column("b"), literal(2)], + order_by=[f.order_by(column("b"))], + ), + [None, 5, 5], ), - [None, 5, 5], - ), ] @pytest.mark.parametrize("name,expr,result", data_test_window_functions) def test_window_functions(df, name, expr, result): - df = df.select( - column("a"), - column("b"), - column("c"), - f.alias(expr, name) - ) + df = df.select(column("a"), column("b"), column("c"), f.alias(expr, name)) table = pa.Table.from_batches(df.collect()) - expected = { - "a": [1, 2, 3], - "b": [4, 5, 6], - "c": [8, 5, 8], - name: result - } + expected = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8], name: result} assert table.sort_by("a").to_pydict() == expected diff --git a/src/expr.rs b/src/expr.rs index 487db482..88ba85b0 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -16,10 +16,11 @@ // under the License. use datafusion_expr::utils::exprlist_to_fields; -use datafusion_expr::LogicalPlan; +use datafusion_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan}; use pyo3::{basic::CompareOp, prelude::*}; use std::convert::{From, Into}; use std::sync::Arc; +use window::PyWindowFrame; use arrow::pyarrow::ToPyArrow; use datafusion::arrow::datatypes::{DataType, Field}; @@ -32,7 +33,7 @@ use datafusion_expr::{ lit, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast, }; -use crate::common::data_type::{DataTypeMap, RexType}; +use crate::common::data_type::{DataTypeMap, NullTreatment, RexType}; use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, DataFusionError}; use crate::expr::aggregate_expr::PyAggregateFunction; use crate::expr::binary_expr::PyBinaryExpr; @@ -281,6 +282,10 @@ impl PyExpr { self.expr.clone().is_null().into() } + pub fn is_not_null(&self) -> PyExpr { + self.expr.clone().is_not_null().into() + } + pub fn cast(&self, to: PyArrowType) -> PyExpr { // self.expr.cast_to() requires DFSchema to validate that the cast // is supported, omit that for now @@ -510,6 +515,92 @@ impl PyExpr { pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult { self._column_name(&plan.plan()).map_err(py_runtime_err) } + + // Expression Function Builder functions + + pub fn order_by(&self, order_by: Vec) -> PyExprFuncBuilder { + let order_by = order_by.iter().map(|e| e.expr.clone()).collect(); + self.expr.clone().order_by(order_by).into() + } + + pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder { + self.expr.clone().filter(filter.expr.clone()).into() + } + + pub fn distinct(&self) -> PyExprFuncBuilder { + self.expr.clone().distinct().into() + } + + pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder { + self.expr + .clone() + .null_treatment(Some(null_treatment.into())) + .into() + } + + pub fn partition_by(&self, partition_by: Vec) -> PyExprFuncBuilder { + let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect(); + self.expr.clone().partition_by(partition_by).into() + } + + pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder { + self.expr.clone().window_frame(window_frame.into()).into() + } +} + +#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)] +#[derive(Debug, Clone)] +pub struct PyExprFuncBuilder { + pub builder: ExprFuncBuilder, +} + +impl From for PyExprFuncBuilder { + fn from(builder: ExprFuncBuilder) -> Self { + Self { builder } + } +} + +#[pymethods] +impl PyExprFuncBuilder { + pub fn order_by(&self, order_by: Vec) -> PyExprFuncBuilder { + let order_by = order_by.iter().map(|e| e.expr.clone()).collect(); + self.builder.clone().order_by(order_by).into() + } + + pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder { + self.builder.clone().filter(filter.expr.clone()).into() + } + + pub fn distinct(&self) -> PyExprFuncBuilder { + self.builder.clone().distinct().into() + } + + pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder { + self.builder + .clone() + .null_treatment(Some(null_treatment.into())) + .into() + } + + pub fn partition_by(&self, partition_by: Vec) -> PyExprFuncBuilder { + let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect(); + self.builder.clone().partition_by(partition_by).into() + } + + pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder { + self.builder + .clone() + .window_frame(window_frame.into()) + .into() + } + + pub fn build(&self) -> PyResult { + self.builder + .clone() + .build() + .map(|expr| expr.into()) + .map_err(|err| err.into()) + } } impl PyExpr { diff --git a/src/functions.rs b/src/functions.rs index c53d4ad9..0b56b782 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -16,6 +16,7 @@ // under the License. use datafusion::functions_aggregate::all_default_aggregate_functions; +use datafusion_expr::window_function; use datafusion_expr::ExprFunctionExt; use pyo3::{prelude::*, wrap_pyfunction}; @@ -882,6 +883,11 @@ aggregate_function!(array_agg, functions_aggregate::array_agg::array_agg_udaf); aggregate_function!(max, functions_aggregate::min_max::max_udaf); aggregate_function!(min, functions_aggregate::min_max::min_udaf); +#[pyfunction] +pub fn lead(arg: PyExpr, shift_offset: i64, default_value: Option) -> PyExpr { + window_function::lead(arg.expr, Some(shift_offset), default_value).into() +} + pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(abs))?; m.add_wrapped(wrap_pyfunction!(acos))?; @@ -1066,5 +1072,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(array_slice))?; m.add_wrapped(wrap_pyfunction!(flatten))?; + m.add_wrapped(wrap_pyfunction!(lead))?; + Ok(()) }