Skip to content

Commit

Permalink
Add `over` to turn any aggregate function into a window function
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Sep 12, 2024
1 parent 8aebaea commit e4df79a
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 34 deletions.
37 changes: 37 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,43 @@ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
"""
return ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))

def over(
    self,
    partition_by: Optional[list[Expr]] = None,
    window_frame: Optional[WindowFrame] = None,
    order_by: Optional[list[SortExpr | Expr]] = None,
    null_treatment: Optional[NullTreatment] = None,
) -> Expr:
    """Turn an aggregate function into a window function.

    This function turns any aggregate function into a window function. With the
    exception of ``partition_by``, how each of the parameters is used is determined
    by the underlying aggregate function.

    Args:
        partition_by: Expressions to partition the window frame on
        window_frame: Specify the window frame parameters
        order_by: Set ordering within the window frame
        null_treatment: Set how to handle null values

    Returns:
        A new ``Expr`` wrapping the result of the underlying ``over`` call.
    """
    # Convert the Python-level wrappers into the raw objects that the
    # underlying expression's ``over`` expects.
    raw_partitions = expr_list_to_raw_expr_list(partition_by)
    raw_ordering = sort_list_to_raw_sort_list(order_by)

    # Optional arguments are forwarded as ``None`` when they were not supplied.
    raw_frame = None if window_frame is None else window_frame.window_frame
    raw_null_treatment = None if null_treatment is None else null_treatment.value

    windowed = self.expr.over(
        partition_by=raw_partitions,
        order_by=raw_ordering,
        window_frame=raw_frame,
        null_treatment=raw_null_treatment,
    )
    return Expr(windowed)


class ExprFuncBuilder:
def __init__(self, builder: expr_internal.ExprFuncBuilder):
Expand Down
34 changes: 13 additions & 21 deletions python/datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,38 +386,30 @@ def test_distinct():
),
[-1, -1, None, 7, -1, -1, None],
),
# TODO update all aggregate functions as windows once upstream merges https://github.com/apache/datafusion-python/issues/833
pytest.param(
(
"first_value",
f.window(
"first_value",
[column("a")],
order_by=[f.order_by(column("b"))],
partition_by=[column("c")],
f.first_value(column("a")).over(
partition_by=[column("c")], order_by=[column("b")]
),
[1, 1, 1, 1, 5, 5, 5],
),
pytest.param(
(
"last_value",
f.window("last_value", [column("a")])
.window_frame(WindowFrame("rows", 0, None))
.order_by(column("b"))
.partition_by(column("c"))
.build(),
f.last_value(column("a")).over(
partition_by=[column("c")],
order_by=[column("b")],
window_frame=WindowFrame("rows", None, None),
),
[3, 3, 3, 3, 6, 6, 6],
),
pytest.param(
(
"3rd_value",
f.window(
"nth_value",
[column("b"), literal(3)],
order_by=[f.order_by(column("a"))],
),
f.nth_value(column("b"), 3).over(order_by=[column("a")]),
[None, None, 7, 7, 7, 7, 7],
),
pytest.param(
(
"avg",
f.round(f.window("avg", [column("b")], order_by=[column("a")]), literal(3)),
f.round(f.avg(column("b")).over(order_by=[column("a")]), literal(3)),
[7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
),
]
Expand Down
44 changes: 43 additions & 1 deletion src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use datafusion::logical_expr::utils::exprlist_to_fields;
use datafusion::logical_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
use datafusion::logical_expr::{
ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
};
use pyo3::{basic::CompareOp, prelude::*};
use std::convert::{From, Into};
use std::sync::Arc;
Expand All @@ -39,6 +41,7 @@ use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
use crate::expr::literal::PyLiteral;
use crate::functions::add_builder_fns_to_window;
use crate::sql::logical::PyLogicalPlan;

use self::alias::PyAlias;
Expand Down Expand Up @@ -547,6 +550,45 @@ impl PyExpr {
pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
self.expr.clone().window_frame(window_frame.into()).into()
}

#[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
pub fn over(
&self,
partition_by: Option<Vec<PyExpr>>,
window_frame: Option<PyWindowFrame>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
match &self.expr {
Expr::AggregateFunction(agg_fn) => {
let window_fn = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
agg_fn.args.clone(),
));

add_builder_fns_to_window(
window_fn,
partition_by,
window_frame,
order_by,
null_treatment,
)
}
Expr::WindowFunction(_) => add_builder_fns_to_window(
self.expr.clone(),
partition_by,
window_frame,
order_by,
null_treatment,
),
_ => Err(
DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
"Using `over` requires an aggregate function.".to_string(),
))
.into(),
),
}
}
}

#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
Expand Down
31 changes: 19 additions & 12 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use std::ptr::null;

Check warning on line 18 in src/functions.rs

View workflow job for this annotation

GitHub Actions / test-matrix (3.10, stable)

unused import: `std::ptr::null`

use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::logical_expr::window_function;
use datafusion::logical_expr::ExprFunctionExt;
Expand Down Expand Up @@ -711,14 +713,15 @@ pub fn string_agg(
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
}

fn add_builder_fns_to_window(
pub(crate) fn add_builder_fns_to_window(
window_fn: Expr,
partition_by: Option<Vec<PyExpr>>,
window_frame: Option<PyWindowFrame>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
// Since ExprFuncBuilder::new() is private, set an empty partition and then
// override later if appropriate.
let mut builder = window_fn.partition_by(vec![]);
let null_treatment = null_treatment.map(|n| n.into());
let mut builder = window_fn.null_treatment(null_treatment);

if let Some(partition_cols) = partition_by {
builder = builder.partition_by(
Expand All @@ -734,6 +737,10 @@ fn add_builder_fns_to_window(
builder = builder.order_by(order_by_cols);
}

if let Some(window_frame) = window_frame {
builder = builder.window_frame(window_frame.into());
}

builder.build().map(|e| e.into()).map_err(|err| err.into())
}

Expand All @@ -748,7 +755,7 @@ pub fn lead(
) -> PyResult<PyExpr> {
let window_fn = window_function::lead(arg.expr, Some(shift_offset), default_value);

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -762,7 +769,7 @@ pub fn lag(
) -> PyResult<PyExpr> {
let window_fn = window_function::lag(arg.expr, Some(shift_offset), default_value);

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -773,7 +780,7 @@ pub fn row_number(
) -> PyResult<PyExpr> {
let window_fn = datafusion::functions_window::expr_fn::row_number();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -784,7 +791,7 @@ pub fn rank(
) -> PyResult<PyExpr> {
let window_fn = window_function::rank();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -795,7 +802,7 @@ pub fn dense_rank(
) -> PyResult<PyExpr> {
let window_fn = window_function::dense_rank();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -806,7 +813,7 @@ pub fn percent_rank(
) -> PyResult<PyExpr> {
let window_fn = window_function::percent_rank();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -817,7 +824,7 @@ pub fn cume_dist(
) -> PyResult<PyExpr> {
let window_fn = window_function::cume_dist();

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

#[pyfunction]
Expand All @@ -829,7 +836,7 @@ pub fn ntile(
) -> PyResult<PyExpr> {
let window_fn = window_function::ntile(arg.into());

add_builder_fns_to_window(window_fn, partition_by, order_by)
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
}

pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand Down

0 comments on commit e4df79a

Please sign in to comment.