Skip to content

Commit

Permalink
Add window function as template for others and function builder
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Aug 13, 2024
1 parent 9108f18 commit 999d92c
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 48 deletions.
20 changes: 19 additions & 1 deletion python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,10 @@ def display_name(self) -> str:
This name will not include any CAST expressions.
"""
import warnings
warnings.warn("deprecated since 40.0.0: use schema_name instead", DeprecationWarning)

warnings.warn(
"deprecated since 40.0.0: use schema_name instead", DeprecationWarning
)
return self.schema_name()

def schema_name(self) -> str:
Expand Down Expand Up @@ -364,6 +367,10 @@ def is_null(self) -> Expr:
"""Returns ``True`` if this expression is null."""
return Expr(self.expr.is_null())

def is_not_null(self) -> Expr:
    """Return an expression that is ``True`` where this expression is not null."""
    not_null = self.expr.is_not_null()
    return Expr(not_null)

def cast(self, to: pa.DataType[Any]) -> Expr:
    """Cast this expression to the data type ``to``.

    Args:
        to: Target data type, passed unchanged to the wrapped expression's
            ``cast`` method.

    Returns:
        A new :py:class:`Expr` wrapping the result of the cast.
    """
    casted = self.expr.cast(to)
    return Expr(casted)
Expand Down Expand Up @@ -414,6 +421,17 @@ def column_name(self, plan: LogicalPlan) -> str:
"""Compute the output column name based on the provided logical plan."""
return self.expr.column_name(plan)

def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
    """Start a function builder from this expression with an ordering applied.

    Args:
        exprs: Expressions to order by. Their wrapped internal expressions are
            forwarded, in the order given, to the underlying ``order_by``.

    Returns:
        An :py:class:`ExprFuncBuilder`; call ``build()`` on it to obtain the
        resulting :py:class:`Expr`.
    """
    return ExprFuncBuilder(self.expr.order_by([e.expr for e in exprs]))


class ExprFuncBuilder:
    """Python wrapper around an internal expression function builder.

    Instances are produced by builder-style expression methods such as
    ``order_by``; call :py:meth:`build` to turn the builder back into an
    :py:class:`Expr`.
    """

    def __init__(self, builder: expr_internal.ExprFuncBuilder) -> None:
        """Wrap an internal builder object.

        Args:
            builder: The internal ``ExprFuncBuilder`` that calls are delegated to.
        """
        self.builder = builder

    def build(self) -> Expr:
        """Build the wrapped builder and return the result as an :py:class:`Expr`."""
        return Expr(self.builder.build())


class WindowFrame:
"""Defines a window frame for performing window operations."""
Expand Down
25 changes: 20 additions & 5 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
from datafusion.expr import CaseBuilder, Expr, WindowFrame
from datafusion.context import SessionContext

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import pyarrow as pa

__all__ = [
"abs",
"acos",
Expand Down Expand Up @@ -246,6 +251,7 @@
"var_pop",
"var_samp",
"window",
"lead",
]


Expand Down Expand Up @@ -1011,12 +1017,12 @@ def struct(*args: Expr) -> Expr:
return Expr(f.struct(*args))


def named_struct(name_pairs: list[(str, Expr)]) -> Expr:
def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr:
"""Returns a struct with the given names and arguments pairs."""
name_pairs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]

# flatten
name_pairs = [x.expr for xs in name_pairs for x in xs]
name_pairs = [x.expr for xs in name_pair_exprs for x in xs]
return Expr(f.named_struct(*name_pairs))


Expand Down Expand Up @@ -1479,12 +1485,17 @@ def approx_percentile_cont(
"""Returns the value that is approximately at a given percentile of ``expr``."""
if num_centroids is None:
return Expr(
f.approx_percentile_cont(expression.expr, percentile.expr, distinct=distinct, num_centroids=None)
f.approx_percentile_cont(
expression.expr, percentile.expr, distinct=distinct, num_centroids=None
)
)

return Expr(
f.approx_percentile_cont(
expression.expr, percentile.expr, distinct=distinct, num_centroids=num_centroids.expr
expression.expr,
percentile.expr,
distinct=distinct,
num_centroids=num_centroids.expr,
)
)

Expand Down Expand Up @@ -1732,3 +1743,7 @@ def bool_and(arg: Expr, distinct: bool = False) -> Expr:
def bool_or(arg: Expr, distinct: bool = False) -> Expr:
    """Computes the boolean OR of the argument."""
    inner = f.bool_or(arg.expr, distinct=distinct)
    return Expr(inner)


def lead(
    arg: Expr, shift_offset: int = 1, default_value: pa.Scalar | None = None
) -> Expr:
    """Create a ``lead`` window function expression.

    All three arguments are forwarded positionally to the internal ``lead``
    binding. The returned expression can be refined with builder methods such
    as ``order_by(...)`` followed by ``build()``.

    Args:
        arg: Expression the function is evaluated over.
        shift_offset: Offset handed to the internal ``lead``; defaults to 1.
        default_value: Optional scalar handed to the internal ``lead``;
            presumably the value used when no row exists at the requested
            offset -- confirm against the DataFusion ``lead`` documentation.

    Returns:
        An :py:class:`Expr` wrapping the internal ``lead`` expression.
    """
    return Expr(f.lead(arg.expr, shift_offset, default_value))
91 changes: 51 additions & 40 deletions python/datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,57 +279,68 @@ def test_distinct():


data_test_window_functions = [
("row", f.window("row_number", [], order_by=[f.order_by(column("c"))]), [2, 1, 3]),
("rank", f.window("rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2]),
("dense_rank", f.window("dense_rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2] ),
("percent_rank", f.window("percent_rank", [], order_by=[f.order_by(column("c"))]), [0.5, 0, 0.5]),
("cume_dist", f.window("cume_dist", [], order_by=[f.order_by(column("b"))]), [0.3333333333333333, 0.6666666666666666, 1.0]),
("ntile", f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]), [1, 1, 2]),
("next", f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]), [5, 6, None]),
("previous", f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]), [None, 4, 5]),
pytest.param(
"first_value",
f.window(
("row", f.window("row_number", [], order_by=[f.order_by(column("c"))]), [2, 1, 3]),
("rank", f.window("rank", [], order_by=[f.order_by(column("c"))]), [2, 1, 2]),
(
"dense_rank",
f.window("dense_rank", [], order_by=[f.order_by(column("c"))]),
[2, 1, 2],
),
(
"percent_rank",
f.window("percent_rank", [], order_by=[f.order_by(column("c"))]),
[0.5, 0, 0.5],
),
(
"cume_dist",
f.window("cume_dist", [], order_by=[f.order_by(column("b"))]),
[0.3333333333333333, 0.6666666666666666, 1.0],
),
(
"ntile",
f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]),
[1, 1, 2],
),
(
"next",
f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]),
[5, 6, None],
),
("lead", f.lead(column("b")).order_by(column("b").sort()).build(), [5, 6, None]),
(
"previous",
f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]),
[None, 4, 5],
),
pytest.param(
"first_value",
[column("a")],
order_by=[f.order_by(column("b"))]
f.window("first_value", [column("a")], order_by=[f.order_by(column("b"))]),
[1, 1, 1],
),
pytest.param(
"last_value",
f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]),
[4, 5, 6],
),
[1, 1, 1],
),
pytest.param(
"last_value",
f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]),
[4, 5, 6],
),
pytest.param(
"2nd_value",
f.window(
"nth_value",
[column("b"), literal(2)],
order_by=[f.order_by(column("b"))],
pytest.param(
"2nd_value",
f.window(
"nth_value",
[column("b"), literal(2)],
order_by=[f.order_by(column("b"))],
),
[None, 5, 5],
),
[None, 5, 5],
),
]


@pytest.mark.parametrize("name,expr,result", data_test_window_functions)
def test_window_functions(df, name, expr, result):
df = df.select(
column("a"),
column("b"),
column("c"),
f.alias(expr, name)
)
df = df.select(column("a"), column("b"), column("c"), f.alias(expr, name))

table = pa.Table.from_batches(df.collect())

expected = {
"a": [1, 2, 3],
"b": [4, 5, 6],
"c": [8, 5, 8],
name: result
}
expected = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8], name: result}

assert table.sort_by("a").to_pydict() == expected

Expand Down
95 changes: 93 additions & 2 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
// under the License.

use datafusion_expr::utils::exprlist_to_fields;
use datafusion_expr::LogicalPlan;
use datafusion_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
use pyo3::{basic::CompareOp, prelude::*};
use std::convert::{From, Into};
use std::sync::Arc;
use window::PyWindowFrame;

use arrow::pyarrow::ToPyArrow;
use datafusion::arrow::datatypes::{DataType, Field};
Expand All @@ -32,7 +33,7 @@ use datafusion_expr::{
lit, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast,
};

use crate::common::data_type::{DataTypeMap, RexType};
use crate::common::data_type::{DataTypeMap, NullTreatment, RexType};
use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, DataFusionError};
use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
Expand Down Expand Up @@ -281,6 +282,10 @@ impl PyExpr {
self.expr.clone().is_null().into()
}

/// Returns an expression that checks whether this expression is not null.
pub fn is_not_null(&self) -> PyExpr {
    let inner = self.expr.clone();
    inner.is_not_null().into()
}

pub fn cast(&self, to: PyArrowType<DataType>) -> PyExpr {
// self.expr.cast_to() requires DFSchema to validate that the cast
// is supported, omit that for now
Expand Down Expand Up @@ -510,6 +515,92 @@ impl PyExpr {
/// Computes this expression's column name against `plan`, mapping any
/// error through `py_runtime_err`.
pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
    let logical_plan = plan.plan();
    self._column_name(&logical_plan).map_err(py_runtime_err)
}

// Expression Function Builder functions

/// Starts an expression function builder from a clone of this expression
/// with the given `order_by` expressions applied.
pub fn order_by(&self, order_by: Vec<PyExpr>) -> PyExprFuncBuilder {
    // The vector is owned, so move each inner expression out instead of cloning it.
    let order_by = order_by.into_iter().map(|e| e.expr).collect();
    self.expr.clone().order_by(order_by).into()
}

/// Starts an expression function builder from a clone of this expression
/// with `filter` applied.
pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
    // `filter` is owned, so its inner expression is moved rather than cloned.
    self.expr.clone().filter(filter.expr).into()
}

/// Starts an expression function builder from a clone of this expression
/// with `distinct` applied.
pub fn distinct(&self) -> PyExprFuncBuilder {
    let base = self.expr.clone();
    base.distinct().into()
}

/// Starts an expression function builder from a clone of this expression
/// with the given null treatment applied.
pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
    let base = self.expr.clone();
    base.null_treatment(Some(null_treatment.into())).into()
}

/// Starts an expression function builder from a clone of this expression
/// with the given `partition_by` expressions applied.
pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
    // The vector is owned, so move each inner expression out instead of cloning it.
    let partition_by = partition_by.into_iter().map(|e| e.expr).collect();
    self.expr.clone().partition_by(partition_by).into()
}

/// Starts an expression function builder from a clone of this expression
/// with the given window frame applied.
pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
    let base = self.expr.clone();
    base.window_frame(window_frame.into()).into()
}
}

/// Python-exposed wrapper around a DataFusion `ExprFuncBuilder`.
///
/// Registered with Python under the name `ExprFuncBuilder` in the
/// `datafusion.expr` module.
#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
#[derive(Debug, Clone)]
pub struct PyExprFuncBuilder {
    /// The wrapped builder that every method delegates to.
    pub builder: ExprFuncBuilder,
}

impl From<ExprFuncBuilder> for PyExprFuncBuilder {
fn from(builder: ExprFuncBuilder) -> Self {
Self { builder }
}
}

// Each method takes `&self` and works on a clone of the wrapped builder,
// so the receiver is left untouched and a new wrapper is returned.
#[pymethods]
impl PyExprFuncBuilder {
    /// Returns a new builder with the given `order_by` expressions applied.
    pub fn order_by(&self, order_by: Vec<PyExpr>) -> PyExprFuncBuilder {
        // The vector is owned, so move each inner expression out instead of cloning it.
        let order_by = order_by.into_iter().map(|e| e.expr).collect();
        self.builder.clone().order_by(order_by).into()
    }

    /// Returns a new builder with `filter` applied.
    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
        // `filter` is owned, so its inner expression is moved rather than cloned.
        self.builder.clone().filter(filter.expr).into()
    }

    /// Returns a new builder with `distinct` applied.
    pub fn distinct(&self) -> PyExprFuncBuilder {
        self.builder.clone().distinct().into()
    }

    /// Returns a new builder with the given null treatment applied.
    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
        self.builder
            .clone()
            .null_treatment(Some(null_treatment.into()))
            .into()
    }

    /// Returns a new builder with the given `partition_by` expressions applied.
    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
        // The vector is owned, so move each inner expression out instead of cloning it.
        let partition_by = partition_by.into_iter().map(|e| e.expr).collect();
        self.builder.clone().partition_by(partition_by).into()
    }

    /// Returns a new builder with the given window frame applied.
    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
        self.builder
            .clone()
            .window_frame(window_frame.into())
            .into()
    }

    /// Builds the wrapped builder into an expression.
    ///
    /// # Errors
    ///
    /// Returns the underlying builder's error, converted into the Python
    /// error type, if building fails.
    pub fn build(&self) -> PyResult<PyExpr> {
        self.builder
            .clone()
            .build()
            .map(Into::into)
            .map_err(Into::into)
    }
}

impl PyExpr {
Expand Down
8 changes: 8 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion_expr::window_function;
use datafusion_expr::ExprFunctionExt;
use pyo3::{prelude::*, wrap_pyfunction};

Expand Down Expand Up @@ -882,6 +883,11 @@ aggregate_function!(array_agg, functions_aggregate::array_agg::array_agg_udaf);
aggregate_function!(max, functions_aggregate::min_max::max_udaf);
aggregate_function!(min, functions_aggregate::min_max::min_udaf);

#[pyfunction]
pub fn lead(arg: PyExpr, shift_offset: i64, default_value: Option<ScalarValue>) -> PyExpr {
window_function::lead(arg.expr, Some(shift_offset), default_value).into()
}

pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(abs))?;
m.add_wrapped(wrap_pyfunction!(acos))?;
Expand Down Expand Up @@ -1066,5 +1072,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_slice))?;
m.add_wrapped(wrap_pyfunction!(flatten))?;

m.add_wrapped(wrap_pyfunction!(lead))?;

Ok(())
}

0 comments on commit 999d92c

Please sign in to comment.