From 1df442ffe361d988a425d50af3419b656f2b871a Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 7 Jun 2024 07:12:36 +0100 Subject: [PATCH] feat(rust): Expose a few more expression nodes in the expression IR (#16781) --- py-polars/src/lazyframe/visitor/expr_nodes.rs | 49 ++++++++++++++----- py-polars/src/lazyframe/visitor/nodes.rs | 12 ++--- py-polars/src/lib.rs | 1 + 3 files changed, 44 insertions(+), 18 deletions(-) diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index fc807d64402c..91f826d1d0a9 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -14,6 +14,7 @@ use polars_time::prelude::RollingGroupOptions; use pyo3::exceptions::PyNotImplementedError; use pyo3::prelude::*; +use crate::series::PySeries; use crate::Wrap; #[pyclass] @@ -342,6 +343,16 @@ pub struct Function { options: PyObject, } +#[pyclass] +pub struct Slice { + #[pyo3(get)] + input: usize, + #[pyo3(get)] + offset: usize, + #[pyo3(get)] + length: usize, +} + #[pyclass] pub struct Len {} @@ -545,9 +556,18 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { value: Wrap(lit.to_any_value().unwrap()).to_object(py), dtype, }, - Duration(_, _) => return Err(PyNotImplementedError::new_err("duration literal")), - Time(_) => return Err(PyNotImplementedError::new_err("time literal")), - Series(_) => return Err(PyNotImplementedError::new_err("series literal")), + Duration(v, _) => Literal { + value: v.to_object(py), + dtype, + }, + Time(ns) => Literal { + value: ns.to_object(py), + dtype, + }, + Series(s) => Literal { + value: PySeries::new((**s).clone()).into_py(py), + dtype, + }, } } .into_py(py), @@ -995,9 +1015,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { return Err(PyNotImplementedError::new_err("search sorted")) }, FunctionExpr::Range(_) => return Err(PyNotImplementedError::new_err("range")), - FunctionExpr::DateOffset => { - return Err(PyNotImplementedError::new_err("date offset")) - }, + FunctionExpr::DateOffset => ("offset_by",).to_object(py), FunctionExpr::Trigonometry(trigfun) => match trigfun { TrigonometricFunction::Cos => ("cos",), TrigonometricFunction::Cot => ("cot",), @@ -1107,13 +1125,11 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { parallel: _, name: _, } => return Err(PyNotImplementedError::new_err("value counts")), - FunctionExpr::UniqueCounts => { - return Err(PyNotImplementedError::new_err("unique counts")) - }, + FunctionExpr::UniqueCounts => ("unique_counts",).to_object(py), FunctionExpr::ApproxNUnique => { return Err(PyNotImplementedError::new_err("approx nunique")) }, - FunctionExpr::Coalesce => return Err(PyNotImplementedError::new_err("coalesce")), + FunctionExpr::Coalesce => ("coalesce",).to_object(py), FunctionExpr::ShrinkType => { return Err(PyNotImplementedError::new_err("shrink type")) }, @@ -1134,7 +1150,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::Log { base: _ } => return Err(PyNotImplementedError::new_err("log")), FunctionExpr::Log1p => return Err(PyNotImplementedError::new_err("log1p")), FunctionExpr::Exp => return Err(PyNotImplementedError::new_err("exp")), - FunctionExpr::Unique(_) => return Err(PyNotImplementedError::new_err("unique")), + FunctionExpr::Unique(maintain_order) => ("unique", maintain_order).to_object(py), FunctionExpr::Round { decimals } => ("round", decimals).to_object(py), FunctionExpr::RoundSF { digits } => ("round_sig_figs", digits).to_object(py), FunctionExpr::Floor => ("floor",).to_object(py), @@ -1262,7 +1278,16 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { .into_py(py) }, AExpr::Wildcard => return Err(PyNotImplementedError::new_err("wildcard")), - AExpr::Slice { .. } => return Err(PyNotImplementedError::new_err("slice")), + AExpr::Slice { + input, + offset, + length, + } => Slice { + input: input.0, + offset: offset.0, + length: length.0, + } + .into_py(py), AExpr::Nth(_) => return Err(PyNotImplementedError::new_err("nth")), AExpr::Len => Len {}.into_py(py), }; diff --git a/py-polars/src/lazyframe/visitor/nodes.rs b/py-polars/src/lazyframe/visitor/nodes.rs index c6310960c654..07e4c71947a8 100644 --- a/py-polars/src/lazyframe/visitor/nodes.rs +++ b/py-polars/src/lazyframe/visitor/nodes.rs @@ -128,7 +128,7 @@ pub struct Select { #[pyo3(get)] expr: Vec, #[pyo3(get)] - options: (), //ProjectionOptions, + should_broadcast: bool, } #[pyclass] @@ -195,7 +195,7 @@ pub struct HStack { #[pyo3(get)] exprs: Vec, #[pyo3(get)] - options: (), // ProjectionOptions, + should_broadcast: bool, } #[pyclass] @@ -338,11 +338,11 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { input, expr, schema: _, - options: _, + options, } => Select { expr: expr.iter().map(|e| e.into()).collect(), input: input.0, - options: (), + should_broadcast: options.should_broadcast, } .into_py(py), IR::Sort { @@ -428,11 +428,11 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { input, exprs, schema: _, - options: _, + options, } => HStack { input: input.0, exprs: exprs.iter().map(|e| e.into()).collect(), - options: (), + should_broadcast: options.should_broadcast, } .into_py(py), IR::Reduce { diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 5d8de7cd5df0..569c09f4e689 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -103,6 +103,7 @@ fn _expr_nodes(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); + m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap();