Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/aggregates as windows #871

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
Union = expr_internal.Union
Unnest = expr_internal.Unnest
UnnestExpr = expr_internal.UnnestExpr
Window = expr_internal.Window
WindowExpr = expr_internal.WindowExpr

__all__ = [
"Expr",
Expand Down Expand Up @@ -154,6 +154,7 @@
"Partitioning",
"Repartition",
"Window",
"WindowExpr",
"WindowFrame",
"WindowFrameBound",
]
Expand Down Expand Up @@ -542,6 +543,36 @@ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
"""
return ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))

def over(self, window: Window) -> Expr:
    """Apply this aggregate expression over a window.

    The aggregate function held by this expression is converted into a window
    function. ``partition_by`` always defines the partitions; how the remaining
    window parameters are interpreted is decided by the underlying aggregate
    function.

    Args:
        window: Window definition

    Returns:
        A new ``Expr`` wrapping the result of the internal ``over`` call.
    """
    raw_partitions = expr_list_to_raw_expr_list(window._partition_by)
    raw_ordering = sort_list_to_raw_sort_list(window._order_by)

    # Optional pieces of the window definition are forwarded as ``None`` when
    # they were not set on the ``Window`` object.
    frame = window._window_frame
    raw_frame = None if frame is None else frame.window_frame

    nulls = window._null_treatment
    raw_nulls = None if nulls is None else nulls.value

    raw_window_expr = self.expr.over(
        partition_by=raw_partitions,
        order_by=raw_ordering,
        window_frame=raw_frame,
        null_treatment=raw_nulls,
    )
    return Expr(raw_window_expr)


class ExprFuncBuilder:
def __init__(self, builder: expr_internal.ExprFuncBuilder):
Expand Down Expand Up @@ -584,6 +615,30 @@ def build(self) -> Expr:
return Expr(self.builder.build())


class Window:
    """Reusable bundle of window parameters."""

    def __init__(
        self,
        partition_by: Optional[list[Expr]] = None,
        window_frame: Optional[WindowFrame] = None,
        order_by: Optional[list[SortExpr | Expr]] = None,
        null_treatment: Optional[NullTreatment] = None,
    ) -> None:
        """Store the parameters that describe a window.

        Every argument is optional and defaults to ``None``; the values are kept
        exactly as given on private attributes.

        Args:
            partition_by: Partitions for window operation
            window_frame: Define the start and end bounds of the window frame
            order_by: Set ordering
            null_treatment: Indicate how nulls are to be treated
        """
        self._partition_by, self._window_frame = partition_by, window_frame
        self._order_by, self._null_treatment = order_by, null_treatment


class WindowFrame:
"""Defines a window frame for performing window operations."""

Expand Down
75 changes: 54 additions & 21 deletions python/datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
literal,
udf,
)
from datafusion.expr import Window


@pytest.fixture
Expand Down Expand Up @@ -386,38 +387,32 @@ def test_distinct():
),
[-1, -1, None, 7, -1, -1, None],
),
# TODO update all aggregate functions as windows once upstream merges https://github.com/apache/datafusion-python/issues/833
pytest.param(
(
"first_value",
f.window(
"first_value",
[column("a")],
order_by=[f.order_by(column("b"))],
partition_by=[column("c")],
f.first_value(column("a")).over(
Window(partition_by=[column("c")], order_by=[column("b")])
),
[1, 1, 1, 1, 5, 5, 5],
),
pytest.param(
(
"last_value",
f.window("last_value", [column("a")])
.window_frame(WindowFrame("rows", 0, None))
.order_by(column("b"))
.partition_by(column("c"))
.build(),
f.last_value(column("a")).over(
Window(
partition_by=[column("c")],
order_by=[column("b")],
window_frame=WindowFrame("rows", None, None),
)
),
[3, 3, 3, 3, 6, 6, 6],
),
pytest.param(
(
"3rd_value",
f.window(
"nth_value",
[column("b"), literal(3)],
order_by=[f.order_by(column("a"))],
),
f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])),
[None, None, 7, 7, 7, 7, 7],
),
pytest.param(
(
"avg",
f.round(f.window("avg", [column("b")], order_by=[column("a")]), literal(3)),
f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)),
[7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
),
]
Expand Down Expand Up @@ -473,6 +468,44 @@ def test_invalid_window_frame(units, start_bound, end_bound):
WindowFrame(units, start_bound, end_bound)


def test_window_frame_defaults_match_postgres(partitioned_df):
"""Check the default window frame used by ``f.window`` and by ``.over()``."""
# ref: https://github.com/apache/datafusion-python/issues/688

# Explicit frame with no start and no end bound.
window_frame = WindowFrame("rows", None, None)

col_a = column("a")

# Using `f.window` with or without an unbounded window_frame produces the same
# results. These tests are included as a regression check but can be removed when
# f.window() is deprecated in favor of using the .over() approach.
no_frame = f.window("avg", [col_a]).alias("no_frame")
with_frame = f.window("avg", [col_a], window_frame=window_frame).alias("with_frame")
df_1 = partitioned_df.select(col_a, no_frame, with_frame)

# Both columns are expected to hold the average of all values of "a".
expected = {
"a": [0, 1, 2, 3, 4, 5, 6],
"no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
"with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
}

assert df_1.sort(col_a).to_pydict() == expected

# When order is not set, the default frame should be unbounded preceding to
# unbounded following. When order is set, the default frame is unbounded preceding
# to current row.
no_order = f.avg(col_a).over(Window()).alias("over_no_order")
with_order = f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order")
df_2 = partitioned_df.select(col_a, no_order, with_order)

# Without ordering every row sees the overall average; with ordering each row
# sees the running average of "a" up to and including itself.
expected = {
"a": [0, 1, 2, 3, 4, 5, 6],
"over_no_order": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
"over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
}

assert df_2.sort(col_a).to_pydict() == expected


def test_get_dataframe(tmp_path):
ctx = SessionContext()

Expand Down
46 changes: 44 additions & 2 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use datafusion::logical_expr::utils::exprlist_to_fields;
use datafusion::logical_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
use datafusion::logical_expr::{
ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
};
use pyo3::{basic::CompareOp, prelude::*};
use std::convert::{From, Into};
use std::sync::Arc;
Expand All @@ -39,6 +41,7 @@ use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
use crate::expr::literal::PyLiteral;
use crate::functions::add_builder_fns_to_window;
use crate::sql::logical::PyLogicalPlan;

use self::alias::PyAlias;
Expand Down Expand Up @@ -558,6 +561,45 @@ impl PyExpr {
pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
self.expr.clone().window_frame(window_frame.into()).into()
}

#[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
pub fn over(
&self,
partition_by: Option<Vec<PyExpr>>,
window_frame: Option<PyWindowFrame>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
match &self.expr {
Expr::AggregateFunction(agg_fn) => {
let window_fn = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
agg_fn.args.clone(),
));

add_builder_fns_to_window(
window_fn,
partition_by,
window_frame,
order_by,
null_treatment,
)
}
Expr::WindowFunction(_) => add_builder_fns_to_window(
self.expr.clone(),
partition_by,
window_frame,
order_by,
null_treatment,
),
_ => Err(
DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
format!("Using {} with `over` is not allowed. Must use an aggregate or window function.", self.expr.variant_name()),
))
.into(),
),
}
}
}

#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
Expand Down Expand Up @@ -749,7 +791,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<drop_table::PyDropTable>()?;
m.add_class::<repartition::PyPartitioning>()?;
m.add_class::<repartition::PyRepartition>()?;
m.add_class::<window::PyWindow>()?;
m.add_class::<window::PyWindowExpr>()?;
m.add_class::<window::PyWindowFrame>()?;
m.add_class::<window::PyWindowFrameBound>()?;
Ok(())
Expand Down
20 changes: 10 additions & 10 deletions src/expr/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ use super::py_expr_list;

use crate::errors::py_datafusion_err;

#[pyclass(name = "Window", module = "datafusion.expr", subclass)]
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindow {
pub struct PyWindowExpr {
window: Window,
}

Expand Down Expand Up @@ -62,15 +62,15 @@ pub struct PyWindowFrameBound {
frame_bound: WindowFrameBound,
}

impl From<PyWindow> for Window {
fn from(window: PyWindow) -> Window {
impl From<PyWindowExpr> for Window {
fn from(window: PyWindowExpr) -> Window {
window.window
}
}

impl From<Window> for PyWindow {
fn from(window: Window) -> PyWindow {
PyWindow { window }
impl From<Window> for PyWindowExpr {
fn from(window: Window) -> PyWindowExpr {
PyWindowExpr { window }
}
}

Expand All @@ -80,7 +80,7 @@ impl From<WindowFrameBound> for PyWindowFrameBound {
}
}

impl Display for PyWindow {
impl Display for PyWindowExpr {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
Expand All @@ -103,7 +103,7 @@ impl Display for PyWindowFrame {
}

#[pymethods]
impl PyWindow {
impl PyWindowExpr {
/// Returns the schema of the Window
pub fn schema(&self) -> PyResult<PyDFSchema> {
Ok(self.window.schema.as_ref().clone().into())
Expand Down Expand Up @@ -283,7 +283,7 @@ impl PyWindowFrameBound {
}
}

impl LogicalNode for PyWindow {
impl LogicalNode for PyWindowExpr {
fn inputs(&self) -> Vec<PyLogicalPlan> {
vec![self.window.input.as_ref().clone().into()]
}
Expand Down
Loading
Loading