Skip to content

Commit

Permalink
Add SchemaSerializer.__reduce__ method to enable pickle serializa…
Browse files Browse the repository at this point in the history
…tion (#1006)

Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
edoakes authored Oct 9, 2023
1 parent 8e66bd9 commit b51105a
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 28 deletions.
26 changes: 23 additions & 3 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@ mod ob_type;
mod shared;
mod type_serializers;

#[pyclass(module = "pydantic_core._pydantic_core")]
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
definitions: Definitions<CombinedSerializer>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
// References to the Python schema and config objects are saved to enable
// reconstructing the object for pickle support (see `__reduce__`).
py_schema: Py<PyDict>,
py_config: Option<Py<PyDict>>,
}

impl SchemaSerializer {
Expand Down Expand Up @@ -71,15 +75,19 @@ impl SchemaSerializer {
#[pymethods]
impl SchemaSerializer {
#[new]
pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
let mut definitions_builder = DefinitionsBuilder::new();

let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
py_schema: schema.into_py(py),
py_config: match config {
Some(c) if !c.is_empty() => Some(c.into_py(py)),
_ => None,
},
})
}

Expand Down Expand Up @@ -174,6 +182,14 @@ impl SchemaSerializer {
Ok(py_bytes.into())
}

pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<(PyObject, (PyObject, PyObject))> {
// Enables support for `pickle` serialization.
let py = slf.py();
let cls = slf.get_type().into();
let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py));
Ok((cls, init_args))
}

pub fn __repr__(&self) -> String {
format!(
"SchemaSerializer(serializer={:#?}, definitions={:#?})",
Expand All @@ -182,6 +198,10 @@ impl SchemaSerializer {
}

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
visit.call(&self.py_schema)?;
if let Some(ref py_config) = self.py_config {
visit.call(py_config)?;
}
self.serializer.py_gc_traverse(&visit)?;
self.definitions.py_gc_traverse(&visit)?;
Ok(())
Expand Down
32 changes: 23 additions & 9 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,15 @@ impl PySome {
}
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaValidator {
validator: CombinedValidator,
definitions: Definitions<CombinedValidator>,
schema: PyObject,
// References to the Python schema and config objects are saved to enable
// reconstructing the object for cloudpickle support (see `__reduce__`).
py_schema: Py<PyAny>,
py_config: Option<Py<PyDict>>,
#[pyo3(get)]
title: PyObject,
hide_input_in_errors: bool,
Expand All @@ -121,6 +124,11 @@ impl SchemaValidator {
for val in definitions.values() {
val.get().unwrap().complete()?;
}
let py_schema = schema.into_py(py);
let py_config = match config {
Some(c) if !c.is_empty() => Some(c.into_py(py)),
_ => None,
};
let config_title = match config {
Some(c) => c.get_item("title"),
None => None,
Expand All @@ -134,18 +142,20 @@ impl SchemaValidator {
Ok(Self {
validator,
definitions,
schema: schema.into_py(py),
py_schema,
py_config,
title,
hide_input_in_errors,
validation_error_cause,
})
}

pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<PyObject> {
pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<(PyObject, (PyObject, PyObject))> {
// Enables support for `pickle` serialization.
let py = slf.py();
let args = (slf.try_borrow()?.schema.to_object(py),);
let cls = slf.getattr("__class__")?;
Ok((cls, args).into_py(py))
let cls = slf.get_type().into();
let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py));
Ok((cls, init_args))
}

#[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None))]
Expand Down Expand Up @@ -307,7 +317,10 @@ impl SchemaValidator {

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.validator.py_gc_traverse(&visit)?;
visit.call(&self.schema)?;
visit.call(&self.py_schema)?;
if let Some(ref py_config) = self.py_config {
visit.call(py_config)?;
}
Ok(())
}
}
Expand Down Expand Up @@ -396,7 +409,8 @@ impl<'py> SelfValidator<'py> {
Ok(SchemaValidator {
validator,
definitions,
schema: py.None(),
py_schema: py.None(),
py_config: None,
title: "Self Schema".into_py(py),
hide_input_in_errors: false,
validation_error_cause: false,
Expand Down
50 changes: 50 additions & 0 deletions tests/serializers/test_pickling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import json
import pickle
from datetime import timedelta

import pytest

from pydantic_core import core_schema
from pydantic_core._pydantic_core import SchemaSerializer


def repr_function(value, _info):
return repr(value)


def test_basic_schema_serializer():
s = SchemaSerializer(core_schema.dict_schema())
s = pickle.loads(pickle.dumps(s))
assert s.to_python({'a': 1, b'b': 2, 33: 3}) == {'a': 1, b'b': 2, 33: 3}
assert s.to_python({'a': 1, b'b': 2, 33: 3, True: 4}, mode='json') == {'a': 1, 'b': 2, '33': 3, 'true': 4}
assert s.to_json({'a': 1, b'b': 2, 33: 3, True: 4}) == b'{"a":1,"b":2,"33":3,"true":4}'

assert s.to_python({(1, 2): 3}) == {(1, 2): 3}
assert s.to_python({(1, 2): 3}, mode='json') == {'1,2': 3}
assert s.to_json({(1, 2): 3}) == b'{"1,2":3}'


@pytest.mark.parametrize(
'value,expected_python,expected_json',
[(None, 'None', b'"None"'), (1, '1', b'"1"'), ([1, 2, 3], '[1, 2, 3]', b'"[1, 2, 3]"')],
)
def test_schema_serializer_capturing_function(value, expected_python, expected_json):
# Test a SchemaSerializer that captures a function.
s = SchemaSerializer(
core_schema.any_schema(
serialization=core_schema.plain_serializer_function_ser_schema(repr_function, info_arg=True)
)
)
s = pickle.loads(pickle.dumps(s))
assert s.to_python(value) == expected_python
assert s.to_json(value) == expected_json
assert s.to_python(value, mode='json') == json.loads(expected_json)


def test_schema_serializer_containing_config():
s = SchemaSerializer(core_schema.timedelta_schema(), config={'ser_json_timedelta': 'float'})
s = pickle.loads(pickle.dumps(s))

assert s.to_python(timedelta(seconds=4, microseconds=500_000)) == timedelta(seconds=4, microseconds=500_000)
assert s.to_python(timedelta(seconds=4, microseconds=500_000), mode='json') == 4.5
assert s.to_json(timedelta(seconds=4, microseconds=500_000)) == b'4.5'
4 changes: 2 additions & 2 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ mod tests {
]
}"#;
let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap();
SchemaSerializer::py_new(schema, None).unwrap();
SchemaSerializer::py_new(py, schema, None).unwrap();
});
}

Expand Down Expand Up @@ -77,7 +77,7 @@ a = A()
py.run(code, None, Some(locals)).unwrap();
let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap();
let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap();
let serialized: Vec<u8> = SchemaSerializer::py_new(schema, None)
let serialized: Vec<u8> = SchemaSerializer::py_new(py, schema, None)
.unwrap()
.to_json(py, a, None, None, None, true, false, false, false, false, true, None)
.unwrap()
Expand Down
9 changes: 7 additions & 2 deletions tests/test_garbage_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ class BaseModel:
__schema__: SchemaSerializer

def __init_subclass__(cls) -> None:
cls.__schema__ = SchemaSerializer(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER))
cls.__schema__ = SchemaSerializer(
core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER), config={'ser_json_timedelta': 'float'}
)

cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()

Expand Down Expand Up @@ -56,7 +58,10 @@ class BaseModel:
__validator__: SchemaValidator

def __init_subclass__(cls) -> None:
cls.__validator__ = SchemaValidator(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER))
cls.__validator__ = SchemaValidator(
core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER),
config=core_schema.CoreConfig(extra_fields_behavior='allow'),
)

cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()

Expand Down
12 changes: 0 additions & 12 deletions tests/validators/test_datetime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
import json
import pickle
import platform
import re
from datetime import date, datetime, time, timedelta, timezone, tzinfo
Expand Down Expand Up @@ -480,17 +479,6 @@ def test_tz_constraint_wrong():
validate_core_schema(core_schema.datetime_schema(tz_constraint='wrong'))


def test_tz_pickle() -> None:
"""
https://github.com/pydantic/pydantic-core/issues/589
"""
v = SchemaValidator(core_schema.datetime_schema())
original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15)))
validated = v.validate_python('2022-06-08T12:13:14-12:15')
assert validated == original
assert pickle.loads(pickle.dumps(validated)) == validated == original


def test_tz_hash() -> None:
v = SchemaValidator(core_schema.datetime_schema())
lookup: Dict[datetime, str] = {}
Expand Down
53 changes: 53 additions & 0 deletions tests/validators/test_pickling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pickle
import re
from datetime import datetime, timedelta, timezone

import pytest

from pydantic_core import core_schema, validate_core_schema
from pydantic_core._pydantic_core import SchemaValidator, ValidationError


def test_basic_schema_validator():
v = SchemaValidator(
validate_core_schema(
{'type': 'dict', 'strict': True, 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'int'}}
)
)
v = pickle.loads(pickle.dumps(v))
assert v.validate_python({'1': 2, '3': 4}) == {1: 2, 3: 4}
assert v.validate_python({}) == {}
with pytest.raises(ValidationError, match=re.escape('[type=dict_type, input_value=[], input_type=list]')):
v.validate_python([])


def test_schema_validator_containing_config():
"""
Verify that the config object is not lost during (de)serialization.
"""
v = SchemaValidator(
core_schema.model_fields_schema({'f': core_schema.model_field(core_schema.str_schema())}),
config=core_schema.CoreConfig(extra_fields_behavior='allow'),
)
v = pickle.loads(pickle.dumps(v))

m, model_extra, fields_set = v.validate_python({'f': 'x', 'extra_field': '123'})
assert m == {'f': 'x'}
# If the config was lost during (de)serialization, the below checks would fail as
# the default behavior is to ignore extra fields.
assert model_extra == {'extra_field': '123'}
assert fields_set == {'f', 'extra_field'}

v.validate_assignment(m, 'f', 'y')
assert m == {'f': 'y'}


def test_schema_validator_tz_pickle() -> None:
"""
https://github.com/pydantic/pydantic-core/issues/589
"""
v = SchemaValidator(core_schema.datetime_schema())
original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15)))
validated = v.validate_python('2022-06-08T12:13:14-12:15')
assert validated == original
assert pickle.loads(pickle.dumps(validated)) == validated == original

0 comments on commit b51105a

Please sign in to comment.