From 8a20588435371a79fb8266bbd769e22780c264ee Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 2 Sep 2024 11:47:12 +0200 Subject: [PATCH] feat(python): Make expressions containing Python UDFs serializable (#18135) --- crates/polars-plan/src/client/check.rs | 76 +------------ crates/polars-plan/src/dsl/expr.rs | 3 - crates/polars-plan/src/dsl/expr_dyn_fn.rs | 100 ++++++++++++++++-- crates/polars-plan/src/dsl/python_udf.rs | 69 +++++++++--- .../unit/cloud/test_prepare_cloud_plan.py | 67 +++++++----- py-polars/tests/unit/test_serde.py | 10 ++ 6 files changed, 200 insertions(+), 125 deletions(-) diff --git a/crates/polars-plan/src/client/check.rs b/crates/polars-plan/src/client/check.rs index 08cbb6cb7319..a01addd9231d 100644 --- a/crates/polars-plan/src/client/check.rs +++ b/crates/polars-plan/src/client/check.rs @@ -1,30 +1,15 @@ use polars_core::error::{polars_err, PolarsResult}; use polars_io::path_utils::is_cloud_url; -use crate::dsl::Expr; use crate::plans::options::SinkType; -use crate::plans::{DslFunction, DslPlan, FileScan, FunctionIR}; +use crate::plans::{DslPlan, FileScan}; /// Assert that the given [`DslPlan`] is eligible to be executed on Polars Cloud. pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> { - let mut expr_stack = vec![]; for plan_node in dsl.into_iter() { match plan_node { - DslPlan::MapFunction { function, .. } => match function { - DslFunction::FunctionIR(FunctionIR::Opaque { .. }) => { - return ineligible_error("contains opaque function") - }, - #[cfg(feature = "python")] - DslFunction::OpaquePython { .. } => { - return ineligible_error("contains Python function") - }, - _ => (), - }, #[cfg(feature = "python")] DslPlan::PythonScan { .. } => return ineligible_error("contains Python scan"), - DslPlan::GroupBy { apply: Some(_), .. } => { - return ineligible_error("contains Python function in group by operation") - }, DslPlan::Scan { paths, .. } if paths.lock().unwrap().0.iter().any(|p| !is_cloud_url(p)) => { @@ -39,23 +24,7 @@ pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> { return ineligible_error("contains sink to non-cloud location"); } }, - plan => { - plan.get_expr(&mut expr_stack); - - for expr in expr_stack.drain(..) { - for expr_node in expr.into_iter() { - match expr_node { - Expr::AnonymousFunction { .. } => { - return ineligible_error("contains anonymous function") - }, - Expr::RenameAlias { .. } => { - return ineligible_error("contains custom name remapping") - }, - _ => (), - } - } - } - }, + _ => (), } } Ok(()) @@ -101,47 +70,6 @@ impl DslPlan { PythonScan { .. } => (), } } - - fn get_expr<'a>(&'a self, scratch: &mut Vec<&'a Expr>) { - use DslPlan::*; - match self { - Filter { predicate, .. } => scratch.push(predicate), - Scan { predicate, .. } => { - if let Some(expr) = predicate { - scratch.push(expr) - } - }, - DataFrameScan { filter, .. } => { - if let Some(expr) = filter { - scratch.push(expr) - } - }, - Select { expr, .. } => scratch.extend(expr), - HStack { exprs, .. } => scratch.extend(exprs), - Sort { by_column, .. } => scratch.extend(by_column), - GroupBy { keys, aggs, .. } => { - scratch.extend(keys); - scratch.extend(aggs); - }, - Join { - left_on, right_on, .. - } => { - scratch.extend(left_on); - scratch.extend(right_on); - }, - Cache { .. } - | Distinct { .. } - | Slice { .. } - | MapFunction { .. } - | Union { .. } - | HConcat { .. } - | ExtContext { .. } - | Sink { .. } - | IR { .. } => (), - #[cfg(feature = "python")] - PythonScan { .. } => (), - } - } } pub struct DslPlanIter<'a> { diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 4553f87360e0..0c40baf3dfee 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -143,8 +143,6 @@ pub enum Expr { Len, /// Take the nth column in the `DataFrame` Nth(i64), - // skipped fields must be last otherwise serde fails in pickle - #[cfg_attr(feature = "serde", serde(skip))] RenameAlias { function: SpecialEq>, expr: Arc, @@ -157,7 +155,6 @@ pub enum Expr { /// function to apply function: SpecialEq>, /// output dtype of the function - #[cfg_attr(feature = "serde", serde(skip))] output_type: GetOutput, options: FunctionOptions, }, diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index 911a0c4308b2..1a498b1fecfe 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -1,5 +1,6 @@ use std::fmt::Formatter; use std::ops::Deref; +use std::sync::Arc; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -17,7 +18,7 @@ pub trait SeriesUdf: Send + Sync { fn call_udf(&self, s: &mut [Series]) -> PolarsResult>; fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { - polars_bail!(ComputeError: "serialize not supported for this 'opaque' function") + polars_bail!(ComputeError: "serialization not supported for this 'opaque' function") } // Needed for python functions. After they are deserialized we first check if they @@ -46,30 +47,29 @@ impl Serialize for SpecialEq> { #[cfg(feature = "serde")] impl<'a> Deserialize<'a> for SpecialEq> { - fn deserialize(_deserializer: D) -> std::result::Result + fn deserialize(deserializer: D) -> std::result::Result where D: Deserializer<'a>, { use serde::de::Error; #[cfg(feature = "python")] { - use crate::dsl::python_udf::MAGIC_BYTE_MARK; - let buf = Vec::::deserialize(_deserializer)?; + let buf = Vec::::deserialize(deserializer)?; - if buf.starts_with(MAGIC_BYTE_MARK) { + if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { let udf = python_udf::PythonUdfExpression::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; Ok(SpecialEq::new(udf)) } else { Err(D::Error::custom( - "deserialize not supported for this 'opaque' function", + "deserialization not supported for this 'opaque' function", )) } } #[cfg(not(feature = "python"))] { Err(D::Error::custom( - "deserialize not supported for this 'opaque' function", + "deserialization not supported for this 'opaque' function", )) } } @@ -125,9 +125,16 @@ impl Default for SpecialEq> { pub trait RenameAliasFn: Send + Sync { fn call(&self, name: &PlSmallStr) -> PolarsResult; + + fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { + polars_bail!(ComputeError: "serialization not supported for this renaming function") + } } -impl PolarsResult + Send + Sync> RenameAliasFn for F { +impl RenameAliasFn for F +where + F: Fn(&PlSmallStr) -> PolarsResult + Send + Sync, +{ fn call(&self, name: &PlSmallStr) -> PolarsResult { self(name) } @@ -250,6 +257,10 @@ pub trait FunctionOutputField: Send + Sync { cntxt: Context, fields: &[Field], ) -> PolarsResult; + + fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { + polars_bail!(ComputeError: "serialization not supported for this output field") + } } pub type GetOutput = SpecialEq>; @@ -344,3 +355,76 @@ where self(input_schema, cntxt, fields) } } + +#[cfg(feature = "serde")] +impl Serialize for GetOutput { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + use serde::ser::Error; + let mut buf = vec![]; + self.0 + .try_serialize(&mut buf) + .map_err(|e| S::Error::custom(format!("{e}")))?; + serializer.serialize_bytes(&buf) + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for GetOutput { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + use serde::de::Error; + #[cfg(feature = "python")] + { + let buf = Vec::::deserialize(deserializer)?; + + if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { + let get_output = python_udf::PythonGetOutput::try_deserialize(&buf) + .map_err(|e| D::Error::custom(format!("{e}")))?; + Ok(SpecialEq::new(get_output)) + } else { + Err(D::Error::custom( + "deserialization not supported for this output field", + )) + } + } + #[cfg(not(feature = "python"))] + { + Err(D::Error::custom( + "deserialization not supported for this output field", + )) + } + } +} + +#[cfg(feature = "serde")] +impl Serialize for SpecialEq> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + use serde::ser::Error; + let mut buf = vec![]; + self.0 + .try_serialize(&mut buf) + .map_err(|e| S::Error::custom(format!("{e}")))?; + serializer.serialize_bytes(&buf) + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for SpecialEq> { + fn deserialize(_deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + use serde::de::Error; + Err(D::Error::custom( + "deserialization not supported for this renaming function", + )) + } +} diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index d31a659c7be7..b105f62df482 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -5,6 +5,7 @@ use polars_core::datatypes::{DataType, Field}; use polars_core::error::*; use polars_core::frame::DataFrame; use polars_core::prelude::Series; +use polars_core::schema::Schema; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; @@ -17,14 +18,14 @@ use super::expr_dyn_fn::*; use crate::constants::MAP_LIST_NAME; use crate::prelude::*; -// Will be overwritten on python polar start up. +// Will be overwritten on Python Polars start up. pub static mut CALL_SERIES_UDF_PYTHON: Option< fn(s: Series, lambda: &PyObject) -> PolarsResult, > = None; pub static mut CALL_DF_UDF_PYTHON: Option< fn(s: DataFrame, lambda: &PyObject) -> PolarsResult, > = None; -pub(super) const MAGIC_BYTE_MARK: &[u8] = "POLARS_PYTHON_UDF".as_bytes(); +pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes(); #[derive(Clone, Debug)] pub struct PythonFunction(pub PyObject); @@ -141,7 +142,7 @@ impl PythonUdfExpression { .unwrap(); let arg = (PyBytes::new_bound(py, remainder),); let python_function = pickle.call1(arg).map_err(from_pyerr)?; - Ok(Arc::new(PythonUdfExpression::new( + Ok(Arc::new(Self::new( python_function.into(), output_type, is_elementwise, @@ -229,6 +230,54 @@ impl SeriesUdf for PythonUdfExpression { } } +/// Serializable version of [`GetOutput`] for Python UDFs. +pub struct PythonGetOutput { + return_dtype: Option, +} + +impl PythonGetOutput { + pub fn new(return_dtype: Option) -> Self { + Self { return_dtype } + } + + #[cfg(feature = "serde")] + pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { + // Skip header. + debug_assert!(buf.starts_with(MAGIC_BYTE_MARK)); + let buf = &buf[MAGIC_BYTE_MARK.len()..]; + + let mut reader = Cursor::new(buf); + let return_dtype: Option = + ciborium::de::from_reader(&mut reader).map_err(map_err)?; + + Ok(Arc::new(Self::new(return_dtype)) as Arc) + } +} + +impl FunctionOutputField for PythonGetOutput { + fn get_field( + &self, + _input_schema: &Schema, + _cntxt: Context, + fields: &[Field], + ) -> PolarsResult { + // Take the name of first field, just like [`GetOutput::map_field`]. + let name = fields[0].name(); + let return_dtype = match self.return_dtype { + Some(ref dtype) => dtype.clone(), + None => DataType::Unknown(Default::default()), + }; + Ok(Field::new(name.clone(), return_dtype)) + } + + #[cfg(feature = "serde")] + fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { + buf.extend_from_slice(MAGIC_BYTE_MARK); + ciborium::ser::into_writer(&self.return_dtype, &mut *buf).unwrap(); + Ok(()) + } +} + impl Expr { pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr { let (collect_groups, name) = if agg_list { @@ -241,16 +290,10 @@ impl Expr { let returns_scalar = func.returns_scalar; let return_dtype = func.output_type.clone(); - let output_type = GetOutput::map_field(move |fld| { - Ok(match return_dtype { - Some(ref dt) => Field::new(fld.name().clone(), dt.clone()), - None => { - let mut fld = fld.clone(); - fld.coerce(DataType::Unknown(Default::default())); - fld - }, - }) - }); + + let output_field = PythonGetOutput::new(return_dtype); + let output_type = SpecialEq::new(Arc::new(output_field) as Arc); + let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT; if returns_scalar { flags |= FunctionFlags::RETURNS_SCALAR; diff --git a/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py b/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py index a49a0d65e701..d99bab04ef7d 100644 --- a/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py +++ b/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py @@ -6,7 +6,7 @@ import polars as pl from polars._utils.cloud import prepare_cloud_plan -from polars.exceptions import InvalidOperationError +from polars.exceptions import ComputeError, InvalidOperationError CLOUD_SOURCE = "s3://my-nonexistent-bucket/dataset" @@ -28,20 +28,6 @@ def test_prepare_cloud_plan(lf: pl.LazyFrame) -> None: assert isinstance(deserialized, pl.LazyFrame) -def test_prepare_cloud_plan_optimization_toggle() -> None: - lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) - - with pytest.raises(TypeError, match="unexpected keyword argument"): - prepare_cloud_plan(lf, nonexistent_optimization=False) - - result = prepare_cloud_plan(lf, projection_pushdown=False) - assert isinstance(result, bytes) - - # TODO: How to check that this optimization was toggled correctly? - deserialized = pl.LazyFrame.deserialize(BytesIO(result)) - assert isinstance(deserialized, pl.LazyFrame) - - @pytest.mark.parametrize( "lf", [ @@ -51,12 +37,6 @@ def test_prepare_cloud_plan_optimization_toggle() -> None: pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).select( pl.col("b").map_batches(lambda x: sum(x)) ), - pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( - pl.col("a").name.map(lambda x: x.upper()) - ), - pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( - pl.col("a").name.map_fields(lambda x: x.upper()) - ), pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).map_batches(lambda x: x), pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) .group_by("a") @@ -67,14 +47,31 @@ def test_prepare_cloud_plan_optimization_toggle() -> None: pl.scan_parquet(CLOUD_SOURCE).filter( pl.col("a") < pl.lit(1).map_elements(lambda x: x + 1) ), + pl.LazyFrame({"a": [[1, 2], [3, 4, 5]], "b": [3, 4]}).select( + pl.col("a").map_elements(lambda x: sum(x), return_dtype=pl.Int64) + ), ], ) -def test_prepare_cloud_plan_fail_on_udf(lf: pl.LazyFrame) -> None: - with pytest.raises( - InvalidOperationError, - match="logical plan ineligible for execution on Polars Cloud", - ): - prepare_cloud_plan(lf) +def test_prepare_cloud_plan_udf(lf: pl.LazyFrame) -> None: + result = prepare_cloud_plan(lf) + assert isinstance(result, bytes) + + deserialized = pl.LazyFrame.deserialize(BytesIO(result)) + assert isinstance(deserialized, pl.LazyFrame) + + +def test_prepare_cloud_plan_optimization_toggle() -> None: + lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) + + with pytest.raises(TypeError, match="unexpected keyword argument"): + prepare_cloud_plan(lf, nonexistent_optimization=False) + + result = prepare_cloud_plan(lf, projection_pushdown=False) + assert isinstance(result, bytes) + + # TODO: How to check that this optimization was toggled correctly? + deserialized = pl.LazyFrame.deserialize(BytesIO(result)) + assert isinstance(deserialized, pl.LazyFrame) @pytest.mark.parametrize( @@ -107,3 +104,19 @@ def test_prepare_cloud_plan_fail_on_python_scan(tmp_path: Path) -> None: match="logical plan ineligible for execution on Polars Cloud", ): prepare_cloud_plan(lf) + + +@pytest.mark.parametrize( + "lf", + [ + pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( + pl.col("a").name.map(lambda x: x.upper()) + ), + pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( + pl.col("a").name.map_fields(lambda x: x.upper()) + ), + ], +) +def test_prepare_cloud_plan_fail_on_serialization(lf: pl.LazyFrame) -> None: + with pytest.raises(ComputeError, match="serialization not supported"): + prepare_cloud_plan(lf) diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 869e75e3bf64..c1a79025690e 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import pickle from datetime import datetime, timedelta @@ -207,3 +208,12 @@ def test_serde_data_type_instantiated_with_attributes() -> None: deserialized = pickle.loads(serialized) assert deserialized == dtype assert isinstance(deserialized, pl.DataType) + + +def test_serde_udf() -> None: + lf = pl.LazyFrame({"a": [[1, 2], [3, 4, 5]], "b": [3, 4]}).select( + pl.col("a").map_elements(lambda x: sum(x), return_dtype=pl.Int32) + ) + result = pl.LazyFrame.deserialize(io.BytesIO(lf.serialize())) + + assert_frame_equal(lf, result)