feat(python): Make expressions containing Python UDFs serializable (#…
stinodego authored Sep 2, 2024
1 parent 7168479 commit 8a20588
Showing 6 changed files with 200 additions and 125 deletions.
76 changes: 2 additions & 74 deletions crates/polars-plan/src/client/check.rs
@@ -1,30 +1,15 @@
use polars_core::error::{polars_err, PolarsResult};
use polars_io::path_utils::is_cloud_url;

use crate::dsl::Expr;
use crate::plans::options::SinkType;
use crate::plans::{DslFunction, DslPlan, FileScan, FunctionIR};
use crate::plans::{DslPlan, FileScan};

/// Assert that the given [`DslPlan`] is eligible to be executed on Polars Cloud.
pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> {
let mut expr_stack = vec![];
for plan_node in dsl.into_iter() {
match plan_node {
DslPlan::MapFunction { function, .. } => match function {
DslFunction::FunctionIR(FunctionIR::Opaque { .. }) => {
return ineligible_error("contains opaque function")
},
#[cfg(feature = "python")]
DslFunction::OpaquePython { .. } => {
return ineligible_error("contains Python function")
},
_ => (),
},
#[cfg(feature = "python")]
DslPlan::PythonScan { .. } => return ineligible_error("contains Python scan"),
DslPlan::GroupBy { apply: Some(_), .. } => {
return ineligible_error("contains Python function in group by operation")
},
DslPlan::Scan { paths, .. }
if paths.lock().unwrap().0.iter().any(|p| !is_cloud_url(p)) =>
{
@@ -39,23 +24,7 @@ pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> {
return ineligible_error("contains sink to non-cloud location");
}
},
plan => {
plan.get_expr(&mut expr_stack);

for expr in expr_stack.drain(..) {
for expr_node in expr.into_iter() {
match expr_node {
Expr::AnonymousFunction { .. } => {
return ineligible_error("contains anonymous function")
},
Expr::RenameAlias { .. } => {
return ineligible_error("contains custom name remapping")
},
_ => (),
}
}
}
},
_ => (),
}
}
Ok(())
@@ -101,47 +70,6 @@ impl DslPlan {
PythonScan { .. } => (),
}
}

fn get_expr<'a>(&'a self, scratch: &mut Vec<&'a Expr>) {
use DslPlan::*;
match self {
Filter { predicate, .. } => scratch.push(predicate),
Scan { predicate, .. } => {
if let Some(expr) = predicate {
scratch.push(expr)
}
},
DataFrameScan { filter, .. } => {
if let Some(expr) = filter {
scratch.push(expr)
}
},
Select { expr, .. } => scratch.extend(expr),
HStack { exprs, .. } => scratch.extend(exprs),
Sort { by_column, .. } => scratch.extend(by_column),
GroupBy { keys, aggs, .. } => {
scratch.extend(keys);
scratch.extend(aggs);
},
Join {
left_on, right_on, ..
} => {
scratch.extend(left_on);
scratch.extend(right_on);
},
Cache { .. }
| Distinct { .. }
| Slice { .. }
| MapFunction { .. }
| Union { .. }
| HConcat { .. }
| ExtContext { .. }
| Sink { .. }
| IR { .. } => (),
#[cfg(feature = "python")]
PythonScan { .. } => (),
}
}
}

pub struct DslPlanIter<'a> {
3 changes: 0 additions & 3 deletions crates/polars-plan/src/dsl/expr.rs
@@ -143,8 +143,6 @@ pub enum Expr {
Len,
/// Take the nth column in the `DataFrame`
Nth(i64),
// skipped fields must be last otherwise serde fails in pickle
#[cfg_attr(feature = "serde", serde(skip))]
RenameAlias {
function: SpecialEq<Arc<dyn RenameAliasFn>>,
expr: Arc<Expr>,
@@ -157,7 +155,6 @@
/// function to apply
function: SpecialEq<Arc<dyn SeriesUdf>>,
/// output dtype of the function
#[cfg_attr(feature = "serde", serde(skip))]
output_type: GetOutput,
options: FunctionOptions,
},
100 changes: 92 additions & 8 deletions crates/polars-plan/src/dsl/expr_dyn_fn.rs
@@ -1,5 +1,6 @@
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -17,7 +18,7 @@ pub trait SeriesUdf: Send + Sync {
fn call_udf(&self, s: &mut [Series]) -> PolarsResult<Option<Series>>;

fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
polars_bail!(ComputeError: "serialize not supported for this 'opaque' function")
polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
}

// Needed for python functions. After they are deserialized we first check if they
@@ -46,30 +47,29 @@ impl Serialize for SpecialEq<Arc<dyn SeriesUdf>> {

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn SeriesUdf>> {
fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'a>,
{
use serde::de::Error;
#[cfg(feature = "python")]
{
use crate::dsl::python_udf::MAGIC_BYTE_MARK;
let buf = Vec::<u8>::deserialize(_deserializer)?;
let buf = Vec::<u8>::deserialize(deserializer)?;

if buf.starts_with(MAGIC_BYTE_MARK) {
if buf.starts_with(python_udf::MAGIC_BYTE_MARK) {
let udf = python_udf::PythonUdfExpression::try_deserialize(&buf)
.map_err(|e| D::Error::custom(format!("{e}")))?;
Ok(SpecialEq::new(udf))
} else {
Err(D::Error::custom(
"deserialize not supported for this 'opaque' function",
"deserialization not supported for this 'opaque' function",
))
}
}
#[cfg(not(feature = "python"))]
{
Err(D::Error::custom(
"deserialize not supported for this 'opaque' function",
"deserialization not supported for this 'opaque' function",
))
}
}
@@ -125,9 +125,16 @@ impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {

pub trait RenameAliasFn: Send + Sync {
fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr>;

fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
polars_bail!(ComputeError: "serialization not supported for this renaming function")
}
}

impl<F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync> RenameAliasFn for F {
impl<F> RenameAliasFn for F
where
F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync,
{
fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
self(name)
}
@@ -250,6 +257,10 @@ pub trait FunctionOutputField: Send + Sync {
cntxt: Context,
fields: &[Field],
) -> PolarsResult<Field>;

fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
polars_bail!(ComputeError: "serialization not supported for this output field")
}
}

pub type GetOutput = SpecialEq<Arc<dyn FunctionOutputField>>;
@@ -344,3 +355,76 @@ where
self(input_schema, cntxt, fields)
}
}

#[cfg(feature = "serde")]
impl Serialize for GetOutput {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::Error;
let mut buf = vec![];
self.0
.try_serialize(&mut buf)
.map_err(|e| S::Error::custom(format!("{e}")))?;
serializer.serialize_bytes(&buf)
}
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for GetOutput {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'a>,
{
use serde::de::Error;
#[cfg(feature = "python")]
{
let buf = Vec::<u8>::deserialize(deserializer)?;

if buf.starts_with(python_udf::MAGIC_BYTE_MARK) {
let get_output = python_udf::PythonGetOutput::try_deserialize(&buf)
.map_err(|e| D::Error::custom(format!("{e}")))?;
Ok(SpecialEq::new(get_output))
} else {
Err(D::Error::custom(
"deserialization not supported for this output field",
))
}
}
#[cfg(not(feature = "python"))]
{
Err(D::Error::custom(
"deserialization not supported for this output field",
))
}
}
}

#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn RenameAliasFn>> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::Error;
let mut buf = vec![];
self.0
.try_serialize(&mut buf)
.map_err(|e| S::Error::custom(format!("{e}")))?;
serializer.serialize_bytes(&buf)
}
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn RenameAliasFn>> {
fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'a>,
{
use serde::de::Error;
Err(D::Error::custom(
"deserialization not supported for this renaming function",
))
}
}
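
The serde support added in this file follows a single pattern: rather than skipping the trait-object fields, each wrapper (`SpecialEq<Arc<dyn SeriesUdf>>`, `GetOutput`, `SpecialEq<Arc<dyn RenameAliasFn>>`) gets a manual `Serialize` impl that delegates to a fallible `try_serialize` method on the trait, so only implementations that opt in (here, the Python UDF types) actually serialize. Below is a minimal, self-contained sketch of that delegation; `NameMapper`, `Suffix`, and `Wrapper` are illustrative stand-ins, not Polars types, and the `ciborium` call in `main` merely mirrors the format this crate happens to use.

use std::sync::Arc;

use serde::ser::Error;
use serde::{Serialize, Serializer};

trait NameMapper: Send + Sync {
    fn call(&self, name: &str) -> String;

    // By default, opaque functions cannot be serialized.
    fn try_serialize(&self, _buf: &mut Vec<u8>) -> Result<(), String> {
        Err("serialization not supported for this renaming function".into())
    }
}

struct Suffix(String);

impl NameMapper for Suffix {
    fn call(&self, name: &str) -> String {
        format!("{name}{}", self.0)
    }

    // This implementation opts in by writing its state into the buffer.
    fn try_serialize(&self, buf: &mut Vec<u8>) -> Result<(), String> {
        buf.extend_from_slice(self.0.as_bytes());
        Ok(())
    }
}

struct Wrapper(Arc<dyn NameMapper>);

impl Serialize for Wrapper {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // Delegate to the trait; anything that cannot serialize itself
        // surfaces as a serde error instead of being silently dropped.
        let mut buf = Vec::new();
        self.0.try_serialize(&mut buf).map_err(S::Error::custom)?;
        serializer.serialize_bytes(&buf)
    }
}

fn main() {
    let wrapper = Wrapper(Arc::new(Suffix("_suffix".to_string())));
    let mut out = Vec::new();
    // Any serde format works; ciborium (CBOR) matches what polars-plan uses.
    ciborium::ser::into_writer(&wrapper, &mut out).unwrap();
    println!("{} -> {} serialized bytes", wrapper.0.call("col"), out.len());
}

Serializing into an intermediate byte buffer keeps the trait impls format-agnostic: the payload is handed to serde as raw bytes via `serialize_bytes`, and the matching `Deserialize` impls read those bytes back before deciding whether they recognize the payload.
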
69 changes: 56 additions & 13 deletions crates/polars-plan/src/dsl/python_udf.rs
@@ -5,6 +5,7 @@ use polars_core::datatypes::{DataType, Field};
use polars_core::error::*;
use polars_core::frame::DataFrame;
use polars_core::prelude::Series;
use polars_core::schema::Schema;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedBytes;
use pyo3::types::PyBytes;
@@ -17,14 +18,14 @@ use super::expr_dyn_fn::*;
use crate::constants::MAP_LIST_NAME;
use crate::prelude::*;

// Will be overwritten on python polar start up.
// Will be overwritten on Python Polars start up.
pub static mut CALL_SERIES_UDF_PYTHON: Option<
fn(s: Series, lambda: &PyObject) -> PolarsResult<Series>,
> = None;
pub static mut CALL_DF_UDF_PYTHON: Option<
fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
> = None;
pub(super) const MAGIC_BYTE_MARK: &[u8] = "POLARS_PYTHON_UDF".as_bytes();
pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes();

#[derive(Clone, Debug)]
pub struct PythonFunction(pub PyObject);
@@ -141,7 +142,7 @@ impl PythonUdfExpression {
.unwrap();
let arg = (PyBytes::new_bound(py, remainder),);
let python_function = pickle.call1(arg).map_err(from_pyerr)?;
Ok(Arc::new(PythonUdfExpression::new(
Ok(Arc::new(Self::new(
python_function.into(),
output_type,
is_elementwise,
@@ -229,6 +230,54 @@ impl SeriesUdf for PythonUdfExpression {
}
}

/// Serializable version of [`GetOutput`] for Python UDFs.
pub struct PythonGetOutput {
return_dtype: Option<DataType>,
}

impl PythonGetOutput {
pub fn new(return_dtype: Option<DataType>) -> Self {
Self { return_dtype }
}

#[cfg(feature = "serde")]
pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
// Skip header.
debug_assert!(buf.starts_with(MAGIC_BYTE_MARK));
let buf = &buf[MAGIC_BYTE_MARK.len()..];

let mut reader = Cursor::new(buf);
let return_dtype: Option<DataType> =
ciborium::de::from_reader(&mut reader).map_err(map_err)?;

Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>)
}
}

impl FunctionOutputField for PythonGetOutput {
fn get_field(
&self,
_input_schema: &Schema,
_cntxt: Context,
fields: &[Field],
) -> PolarsResult<Field> {
// Take the name of the first field, just like [`GetOutput::map_field`].
let name = fields[0].name();
let return_dtype = match self.return_dtype {
Some(ref dtype) => dtype.clone(),
None => DataType::Unknown(Default::default()),
};
Ok(Field::new(name.clone(), return_dtype))
}

#[cfg(feature = "serde")]
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
buf.extend_from_slice(MAGIC_BYTE_MARK);
ciborium::ser::into_writer(&self.return_dtype, &mut *buf).unwrap();
Ok(())
}
}

impl Expr {
pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr {
let (collect_groups, name) = if agg_list {
@@ -241,16 +290,10 @@

let returns_scalar = func.returns_scalar;
let return_dtype = func.output_type.clone();
let output_type = GetOutput::map_field(move |fld| {
Ok(match return_dtype {
Some(ref dt) => Field::new(fld.name().clone(), dt.clone()),
None => {
let mut fld = fld.clone();
fld.coerce(DataType::Unknown(Default::default()));
fld
},
})
});

let output_field = PythonGetOutput::new(return_dtype);
let output_type = SpecialEq::new(Arc::new(output_field) as Arc<dyn FunctionOutputField>);

let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
if returns_scalar {
flags |= FunctionFlags::RETURNS_SCALAR;
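
The `PythonGetOutput` methods above use the same framing as `PythonUdfExpression`: the payload starts with the magic byte marker and is followed by a CBOR body written with `ciborium`, so the generic `Deserialize` impls can tell recognizable payloads apart from arbitrary opaque functions. A standalone sketch of that round trip, with toy names (`MARK` and `Payload` standing in for the marker and the `Option<DataType>` state):

use std::io::Cursor;

use serde::{Deserialize, Serialize};

// Stand-in for MAGIC_BYTE_MARK; requires the `serde` (with derive) and `ciborium` crates.
const MARK: &[u8] = b"PLPYUDF";

#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Payload {
    return_dtype: Option<String>,
}

fn serialize(payload: &Payload) -> Vec<u8> {
    // Header first, then the CBOR-encoded state appended to the same buffer.
    let mut buf = Vec::from(MARK);
    ciborium::ser::into_writer(payload, &mut buf).unwrap();
    buf
}

fn deserialize(buf: &[u8]) -> Option<Payload> {
    // Reject anything that does not carry the expected header, then decode the rest.
    let rest = buf.strip_prefix(MARK)?;
    ciborium::de::from_reader(Cursor::new(rest)).ok()
}

fn main() {
    let original = Payload { return_dtype: Some("Int64".to_string()) };
    let bytes = serialize(&original);
    assert_eq!(deserialize(&bytes), Some(original));
}

This is the same check the `Deserialize` impls in expr_dyn_fn.rs perform with `buf.starts_with(python_udf::MAGIC_BYTE_MARK)` before handing the buffer to `try_deserialize`.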