From 441188e0583588a01febadc6f056ce38e2f826da Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Wed, 12 Jun 2024 22:06:50 +0100 Subject: [PATCH 01/18] implement complex --- generate_self_schema.py | 2 +- python/pydantic_core/core_schema.py | 39 ++++++++++ src/errors/types.rs | 3 + src/input/input_abstract.rs | 4 +- src/input/input_json.rs | 32 +++++++- src/input/input_python.rs | 36 ++++++++- src/input/input_string.rs | 10 ++- src/input/return_enums.rs | 29 ++++++- src/input/shared.rs | 7 ++ src/serializers/infer.rs | 24 +++++- src/serializers/ob_type.rs | 8 +- src/serializers/shared.rs | 2 + src/serializers/type_serializers/complex.rs | 83 +++++++++++++++++++++ src/serializers/type_serializers/mod.rs | 1 + src/validators/complex.rs | 39 ++++++++++ src/validators/mod.rs | 3 + tests/serializers/test_complex.py | 35 +++++++++ tests/validators/test_complex.py | 79 ++++++++++++++++++++ 18 files changed, 426 insertions(+), 10 deletions(-) create mode 100644 src/serializers/type_serializers/complex.rs create mode 100644 src/validators/complex.rs create mode 100644 tests/serializers/test_complex.py create mode 100644 tests/validators/test_complex.py diff --git a/generate_self_schema.py b/generate_self_schema.py index acecf19d8..415cd1f97 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -49,7 +49,7 @@ def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: if isinstance(obj, str): return {'type': obj} - elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal): + elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal, complex): return {'type': obj.__name__.lower()} elif is_typeddict(obj): return type_dict_schema(obj, definitions) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 2cb875b23..e44a62d5f 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -742,6 +742,43 @@ def decimal_schema( ) +class ComplexSchema(TypedDict, total=False): + type: Required[Literal['complex']] + ref: str + metadata: Any + serialization: SerSchema + + +def complex_schema( + *, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> ComplexSchema: + """ + Returns a schema that matches a complex value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.complex_schema() + v = SchemaValidator(schema) + assert v.validate_python({'real': 1, 'imag': 2}) == complex(1, 2) + ``` + + Args: + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='complex', + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + class StringSchema(TypedDict, total=False): type: Required[Literal['str']] pattern: str @@ -3777,6 +3814,7 @@ def definition_reference_schema( DefinitionsSchema, DefinitionReferenceSchema, UuidSchema, + ComplexSchema, ] elif False: CoreSchema: TypeAlias = Mapping[str, Any] @@ -3832,6 +3870,7 @@ def definition_reference_schema( 'definitions', 'definition-ref', 'uuid', + 'complex', ] CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'] diff --git a/src/errors/types.rs b/src/errors/types.rs index 57025f9b5..b97f0924b 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -422,6 +422,8 @@ error_types! { DecimalWholeDigits { whole_digits: {ctx_type: u64, ctx_fn: field_from_context}, }, + // Complex errors + ComplexType {}, } macro_rules! render { @@ -564,6 +566,7 @@ impl ErrorType { Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total", Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", + Self::ComplexType { .. } => "Input should be a valid dictionary with exactly two keys, 'real' and 'imag', with float values", } } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 5e915c548..4c6853dde 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -9,7 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath}; use crate::tools::py_err; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; -use super::return_enums::{EitherBytes, EitherInt, EitherString}; +use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString}; use super::{EitherFloat, GenericIterator, ValidationMatch}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -172,6 +172,8 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, ) -> ValMatch>; + + fn validate_complex(&self) -> ValMatch>; } /// The problem to solve here is that iterating collections often returns owned diff --git a/src/input/input_json.rs b/src/input/input_json.rs index f2bf74998..cf0ca732d 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::collections::HashSet; use jiter::{JsonArray, JsonObject, JsonValue, LazyIndexMap}; use pyo3::prelude::*; @@ -16,7 +17,7 @@ use super::datetime::{ float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; use super::input_abstract::{ConsumeIterator, Never, ValMatch}; -use super::return_enums::ValidationMatch; +use super::return_enums::{EitherComplex, ValidationMatch}; use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int}; use super::{ Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input, @@ -296,6 +297,31 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } + + fn validate_complex(&self) -> ValResult>> { + let default = JsonValue::Float(0.0); + match self { + JsonValue::Object(object) => { + let mut allowed_keys = HashSet::from(["real".to_owned(), "imag".to_owned()]); + for key in object.keys() { + let k = &key.to_string(); + if !allowed_keys.remove(k) { + return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); + } + } + let real = object.get("real").unwrap_or(&default).validate_float(true); + let imag = object.get("imag").unwrap_or(&default).validate_float(true); + if real.is_err() || imag.is_err() { + return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); + } + Ok(ValidationMatch::strict(EitherComplex::Complex([ + real.unwrap().into_inner().as_f64(), + imag.unwrap().into_inner().as_f64(), + ]))) + } + _ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), + } + } } /// Required for JSON Object keys so the string can behave like an Input @@ -425,6 +451,10 @@ impl<'py> Input<'py> for str { ) -> ValResult>> { bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } + + fn validate_complex(&self) -> ValResult>> { + Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) + } } impl BorrowInput<'_> for &'_ String { diff --git a/src/input/input_python.rs b/src/input/input_python.rs index b2284efb5..336508747 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -6,8 +6,8 @@ use pyo3::prelude::*; use pyo3::types::PyType; use pyo3::types::{ - PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList, - PyMapping, PySet, PyString, PyTime, PyTuple, + PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, + PyList, PyMapping, PySet, PyString, PyTime, PyTuple, }; use pyo3::PyTypeCheck; @@ -25,6 +25,7 @@ use super::datetime::{ EitherTime, }; use super::input_abstract::ValMatch; +use super::return_enums::EitherComplex; use super::return_enums::{iterate_attributes, iterate_mapping_items, ValidationMatch}; use super::shared::{ decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int, @@ -592,6 +593,37 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) } + + fn validate_complex<'a>(&'a self) -> ValResult>> { + if let Ok(complex) = self.downcast::() { + return Ok(ValidationMatch::exact(EitherComplex::Py(complex.to_owned()))); + } else if let Ok(complex) = self.downcast::() { + let re = complex.get_item("real"); + let im = complex.get_item("imag"); + if complex.len() > 2 || re.is_err() && im.is_err() { + return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); + } + let mut res = [0.0, 0.0]; + if let Some(v) = re.unwrap_or(None) { + if v.is_instance_of::() || v.is_instance_of::() { + let u = v.extract::(); + res[0] = u.unwrap_or(0.0); + } else { + return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); + } + } + if let Some(v) = im.unwrap_or(None) { + if v.is_instance_of::() || v.is_instance_of::() { + let u = v.extract::(); + res[1] = u.unwrap_or(0.0); + } else { + return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); + } + } + return Ok(ValidationMatch::exact(EitherComplex::Complex(res))); + } + Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) + } } impl<'py> BorrowInput<'py> for Bound<'py, PyAny> { diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 3c61cdebc..b031a6da5 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -13,7 +13,8 @@ use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime, }; use super::input_abstract::{Never, ValMatch}; -use super::shared::{str_as_bool, str_as_float, str_as_int}; +use super::return_enums::EitherComplex; +use super::shared::{str_as_bool, str_as_complex, str_as_float, str_as_int}; use super::{ Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input, KeywordArgs, ValidatedDict, ValidationMatch, @@ -217,6 +218,13 @@ impl<'py> Input<'py> for StringMapping<'py> { Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } + + fn validate_complex(&self) -> ValResult>> { + match self { + Self::String(s) => str_as_complex(self, py_string_str(s)?).map(ValidationMatch::strict), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), + } + } } impl<'py> BorrowInput<'py> for StringMapping<'py> { diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 22faaba71..d246e0293 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -12,7 +12,7 @@ use pyo3::intern; use pyo3::prelude::*; #[cfg(not(PyPy))] use pyo3::types::PyFunction; -use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString}; +use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString}; use serde::{ser::Error, Serialize, Serializer}; @@ -716,3 +716,30 @@ impl ToPyObject for Int { } } } + +#[derive(Clone)] +pub enum EitherComplex<'a> { + Complex([f64; 2]), + Py(Bound<'a, PyComplex>), +} + +impl<'a> IntoPy for EitherComplex<'a> { + fn into_py(self, py: Python<'_>) -> PyObject { + match self { + Self::Complex(c) => PyComplex::from_doubles_bound(py, c[0], c[1]).into_py(py), + Self::Py(c) => c.into_py(py), + } + } +} + +impl<'a> EitherComplex<'a> { + pub fn as_f64(&self, py: Python<'_>) -> [f64; 2] { + match self { + EitherComplex::Complex(f) => *f, + EitherComplex::Py(f) => [ + f.getattr(intern!(py, "real")).unwrap().extract().unwrap(), + f.getattr(intern!(py, "imag")).unwrap().extract().unwrap(), + ], + } + } +} diff --git a/src/input/shared.rs b/src/input/shared.rs index 95b9912a5..4753c9fe1 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -8,6 +8,7 @@ use jiter::{JsonErrorType, NumberInt}; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; +use super::return_enums::EitherComplex; use super::{EitherFloat, EitherInt, Input}; static ENUM_META_OBJECT: GILOnceCell> = GILOnceCell::new(); @@ -204,3 +205,9 @@ pub fn decimal_as_int<'py>( } Ok(EitherInt::Py(numerator)) } + +/// parse a complex as a complex +pub fn str_as_complex<'py>(input: &(impl Input<'py> + ?Sized), _str: &str) -> ValResult> { + // TODO + Err(ValError::new(ErrorTypeDefaults::ComplexType, input)) +} diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 2384e8691..f39ad3ee8 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -4,6 +4,7 @@ use pyo3::exceptions::PyTypeError; use pyo3::intern; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; +use pyo3::types::PyComplex; use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple}; use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; @@ -225,6 +226,13 @@ pub(crate) fn infer_to_python_known( } PyList::new_bound(py, items).into_py(py) } + ObType::Complex => { + let dict = value.downcast::()?; + let new_dict = PyDict::new_bound(py); + let _ = new_dict.set_item("real", dict.get_item("real")?); + let _ = new_dict.set_item("imag", dict.get_item("imag")?); + new_dict.into_py(py) + } ObType::Path => value.str()?.into_py(py), ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py), ObType::Unknown => { @@ -273,6 +281,13 @@ pub(crate) fn infer_to_python_known( ); iter.into_py(py) } + ObType::Complex => { + let dict = value.downcast::()?; + let new_dict = PyDict::new_bound(py); + let _ = new_dict.set_item("real", dict.get_item("real")?); + let _ = new_dict.set_item("imag", dict.get_item("imag")?); + new_dict.into_py(py) + } ObType::Unknown => { if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; @@ -401,6 +416,13 @@ pub(crate) fn infer_serialize_known( ObType::None => serializer.serialize_none(), ObType::Int | ObType::IntSubclass => serialize!(Int), ObType::Bool => serialize!(bool), + ObType::Complex => { + let v = value.downcast::().map_err(py_err_se_err)?; + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry(&"real", &v.real())?; + map.serialize_entry(&"imag", &v.imag())?; + map.end() + } ObType::Float | ObType::FloatSubclass => { let v = value.extract::().map_err(py_err_se_err)?; if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null { @@ -650,7 +672,7 @@ pub(crate) fn infer_json_key_known<'a>( } Ok(Cow::Owned(key_build.finish())) } - ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => { + ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => { py_err!(PyTypeError; "`{}` not valid as object key", ob_type) } ObType::Dataclass | ObType::PydanticSerializable => { diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index 1eaa090a8..f2dbf6a45 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -1,8 +1,8 @@ use pyo3::prelude::*; use pyo3::sync::GILOnceCell; use pyo3::types::{ - PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList, - PyNone, PySet, PyString, PyTime, PyTuple, PyType, + PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, + PyIterator, PyList, PyNone, PySet, PyString, PyTime, PyTuple, PyType, }; use pyo3::{intern, PyTypeInfo}; @@ -48,6 +48,7 @@ pub struct ObTypeLookup { pattern_object: PyObject, // uuid type uuid_object: PyObject, + complex: usize, } static TYPE_LOOKUP: GILOnceCell = GILOnceCell::new(); @@ -101,6 +102,7 @@ impl ObTypeLookup { .to_object(py), pattern_object: py.import_bound("re").unwrap().getattr("Pattern").unwrap().to_object(py), uuid_object: py.import_bound("uuid").unwrap().getattr("UUID").unwrap().to_object(py), + complex: PyComplex::type_object_raw(py) as usize, } } @@ -171,6 +173,7 @@ impl ObTypeLookup { ObType::Pattern => self.path_object.as_ptr() as usize == ob_type, ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type, ObType::Unknown => false, + ObType::Complex => self.complex == ob_type, }; if ans { @@ -425,6 +428,7 @@ pub enum ObType { Uuid, // unknown type Unknown, + Complex, } impl PartialEq for ObType { diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index b9e0a727d..4c7fb57b6 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -142,6 +142,7 @@ combined_serializer! { Enum: super::type_serializers::enum_::EnumSerializer; Recursive: super::type_serializers::definitions::DefinitionRefSerializer; Tuple: super::type_serializers::tuple::TupleSerializer; + Complex: super::type_serializers::complex::ComplexSerializer; } } @@ -251,6 +252,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit), } } } diff --git a/src/serializers/type_serializers/complex.rs b/src/serializers/type_serializers/complex.rs new file mode 100644 index 000000000..f7237c9dc --- /dev/null +++ b/src/serializers/type_serializers/complex.rs @@ -0,0 +1,83 @@ +use std::borrow::Cow; + +use pyo3::prelude::*; +use pyo3::types::{PyComplex, PyDict}; + +use serde::ser::SerializeMap; + +use crate::definitions::DefinitionsBuilder; + +use super::{infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode, TypeSerializer}; + +#[derive(Debug, Clone)] +pub struct ComplexSerializer {} + +impl BuildSerializer for ComplexSerializer { + const EXPECTED_TYPE: &'static str = "complex"; + fn build( + _schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + Ok(Self {}.into()) + } +} + +impl_py_gc_traverse!(ComplexSerializer {}); + +impl TypeSerializer for ComplexSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + let py = value.py(); + match value.downcast::() { + Ok(py_complex) => match extra.mode { + SerMode::Json => { + let new_dict = PyDict::new_bound(py); + let _ = new_dict.set_item("real", py_complex.real()); + let _ = new_dict.set_item("imag", py_complex.imag()); + Ok(new_dict.into_py(py)) + } + _ => Ok(value.into_py(py)), + }, + Err(_) => { + extra.warnings.on_fallback_py(self.get_name(), value, extra)?; + infer_to_python(value, include, exclude, extra) + } + } + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + self._invalid_as_json_key(key, extra, "complex") + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + match value.downcast::() { + Ok(py_complex) => { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry(&"real", &py_complex.real())?; + map.serialize_entry(&"imag", &py_complex.imag())?; + map.end() + } + Err(_) => { + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) + } + } + } + + fn get_name(&self) -> &str { + "complex" + } +} diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index da36f0bc1..dabd006a3 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -1,5 +1,6 @@ pub mod any; pub mod bytes; +pub mod complex; pub mod dataclass; pub mod datetime_etc; pub mod decimal; diff --git a/src/validators/complex.rs b/src/validators/complex.rs new file mode 100644 index 000000000..b9b050d74 --- /dev/null +++ b/src/validators/complex.rs @@ -0,0 +1,39 @@ +use pyo3::prelude::*; +use pyo3::types::PyDict; + +use crate::errors::ValResult; +use crate::input::Input; + +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; + +#[derive(Debug)] +pub struct ComplexValidator {} + +impl BuildValidator for ComplexValidator { + const EXPECTED_TYPE: &'static str = "complex"; + fn build( + _schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + Ok(Self {}.into()) + } +} + +impl_py_gc_traverse!(ComplexValidator {}); + +impl Validator for ComplexValidator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + let res = input.validate_complex()?.unpack(state); + Ok(res.into_py(py)) + } + + fn get_name(&self) -> &str { + "complex" + } +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index ede6489ab..dd55cef8a 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -24,6 +24,7 @@ mod bytes; mod call; mod callable; mod chain; +mod complex; mod custom_error; mod dataclass; mod date; @@ -579,6 +580,7 @@ pub fn build_validator( // recursive (self-referencing) models definitions::DefinitionRefValidator, definitions::DefinitionsValidatorBuilder, + complex::ComplexValidator, ) } @@ -732,6 +734,7 @@ pub enum CombinedValidator { DefinitionRef(definitions::DefinitionRefValidator), // input dependent JsonOrPython(json_or_python::JsonOrPython), + Complex(complex::ComplexValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, diff --git a/tests/serializers/test_complex.py b/tests/serializers/test_complex.py new file mode 100644 index 000000000..8154921c5 --- /dev/null +++ b/tests/serializers/test_complex.py @@ -0,0 +1,35 @@ +import json +import math + +import pytest + +from pydantic_core import SchemaSerializer, core_schema + + +@pytest.mark.parametrize( + 'value,substr,expected', + [ + (complex(1, 2), '"real":1.0', {'real': 1.0, 'imag': 2.0}), + (complex(-float('inf'), 2), '"real":-Infinity', {'real': -float('inf'), 'imag': 2.0}), + (complex(float('inf'), 2), '"real":Infinity', {'real': float('inf'), 'imag': 2.0}), + (complex(float('nan'), 2), '"real":NaN', {'real': float('nan'), 'imag': 2.0}), + ], +) +def test_complex_json(value, substr, expected): + v = SchemaSerializer(core_schema.complex_schema()) + c = v.to_python(value) + c_json = v.to_python(value, mode='json') + json_str = v.to_json(value).decode() + c_reloaded = json.loads(json_str) + + assert substr in json_str + assert c.imag == expected['imag'] + + if math.isnan(expected['real']): + assert math.isnan(c.real) + assert math.isnan(c_json['real']) + assert math.isnan(c_reloaded['real']) + else: + assert c.real == expected['real'] + assert c_json['real'] == expected['real'] + assert c_reloaded['real'] == expected['real'] diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py new file mode 100644 index 000000000..0e00fb37e --- /dev/null +++ b/tests/validators/test_complex.py @@ -0,0 +1,79 @@ +import math +import re + +import pytest + +from pydantic_core import SchemaValidator, ValidationError + +from ..conftest import Err, PyAndJson + +EXPECTED_TYPE_ERROR_MESSAGE = ( + "Input should be a valid dictionary with exactly two keys, 'real' and 'imag', with float values" +) + + +def test_dict(py_and_json: PyAndJson): + v = py_and_json({'type': 'complex'}) + assert v.validate_test({'real': 2, 'imag': 4}) == complex(2, 4) + with pytest.raises(ValidationError, match=re.escape('[type=complex_type, input_value=[], input_type=list]')): + v.validate_test([]) + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + (complex(2, 4), complex(2, 4)), + ({'real': 2, 'imag': 4}, complex(2, 4)), + ({'real': 2}, complex(2, 0)), + ({'imag': 2}, complex(0, 2)), + ({}, complex(0, 0)), + ({'real': 'test', 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ('foobar', Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ([], Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ([('x', 'y')], Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ((), Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ((('x', 'y'),), Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ( + (type('Foobar', (), {'x': 1})()), + Err(EXPECTED_TYPE_ERROR_MESSAGE), + ), + ], + ids=repr, +) +def test_complex_cases(input_value, expected): + v = SchemaValidator({'type': 'complex'}) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_python(input_value) + else: + assert v.validate_python(input_value) == expected + + +def test_nan_inf_complex(): + v = SchemaValidator({'type': 'complex'}) + c = v.validate_python({'real': float('nan'), 'imag': float('inf')}) + # c != complex(float('nan'), float('inf')) as nan != nan, + # so we need to examine the values individually + assert math.isnan(c.real) + assert math.isinf(c.imag) + + +def test_json_complex(): + v = SchemaValidator({'type': 'complex'}) + assert v.validate_json('{"real": 2, "imag": 4}') == complex(2, 4) + with pytest.raises(ValidationError) as exc_info: + v.validate_json('1') + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'complex_type', + 'loc': (), + 'msg': EXPECTED_TYPE_ERROR_MESSAGE, + 'input': 1, + } + ] + + +def test_string_complex(): + v = SchemaValidator({'type': 'complex'}) + with pytest.raises(ValidationError, match=re.escape(EXPECTED_TYPE_ERROR_MESSAGE)): + v.validate_strings("{'real': float('nan'), 'imag': 0}") From ca12c2b73193b31e4f0f3af580133ba8e7127da7 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Wed, 12 Jun 2024 22:48:51 +0100 Subject: [PATCH 02/18] fix tests --- python/pydantic_core/core_schema.py | 1 + src/input/input_json.rs | 14 ++++++++------ tests/test_errors.py | 5 +++++ tests/test_schema_functions.py | 1 + 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 0f32f57ef..b641e32c7 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3975,6 +3975,7 @@ def definition_reference_schema( 'decimal_max_digits', 'decimal_max_places', 'decimal_whole_digits', + 'complex_type', ] diff --git a/src/input/input_json.rs b/src/input/input_json.rs index cf0ca732d..a8dbb1b9c 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -311,13 +311,15 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { } let real = object.get("real").unwrap_or(&default).validate_float(true); let imag = object.get("imag").unwrap_or(&default).validate_float(true); - if real.is_err() || imag.is_err() { - return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); + if let Ok(re) = real { + if let Ok(im) = imag { + return Ok(ValidationMatch::strict(EitherComplex::Complex([ + re.into_inner().as_f64(), + im.into_inner().as_f64(), + ]))); + } } - Ok(ValidationMatch::strict(EitherComplex::Complex([ - real.unwrap().into_inner().as_f64(), - imag.unwrap().into_inner().as_f64(), - ]))) + Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) } _ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), } diff --git a/tests/test_errors.py b/tests/test_errors.py index daabf04c6..fcf4ba3e9 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -385,6 +385,11 @@ def f(input_value, info): 'Decimal input should have no more than 1 digit before the decimal point', {'whole_digits': 1}, ), + ( + 'complex_type', + "Input should be a valid dictionary with exactly two keys, 'real' and 'imag', with float values", + None, + ), ] diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index 0971f0b9a..6d96678c8 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -290,6 +290,7 @@ def args(*args, **kwargs): (core_schema.uuid_schema, args(), {'type': 'uuid'}), (core_schema.decimal_schema, args(), {'type': 'decimal'}), (core_schema.decimal_schema, args(multiple_of=5, gt=1.2), {'type': 'decimal', 'multiple_of': 5, 'gt': 1.2}), + (core_schema.complex_schema, args(), {'type': 'complex'}), ] From 1e7dbd971204997c4e51a4896bcbb4bb72f4987f Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Thu, 13 Jun 2024 08:34:56 +0100 Subject: [PATCH 03/18] handle real numbers --- src/input/input_json.rs | 2 ++ src/input/input_python.rs | 18 ++++++++++++++---- src/input/input_string.rs | 2 +- tests/validators/test_complex.py | 11 ++++++++--- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/input/input_json.rs b/src/input/input_json.rs index a8dbb1b9c..0e08b2760 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -321,6 +321,8 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { } Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) } + JsonValue::Float(f) => Ok(ValidationMatch::strict(EitherComplex::Complex([*f, 0.0]))), + JsonValue::Int(f) => Ok(ValidationMatch::strict(EitherComplex::Complex([(*f) as f64, 0.0]))), _ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), } } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 4523af06c..9b0dd883f 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -605,22 +605,32 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { } let mut res = [0.0, 0.0]; if let Some(v) = re.unwrap_or(None) { - if v.is_instance_of::() || v.is_instance_of::() { + if v.is_exact_instance_of::() || v.is_exact_instance_of::() { let u = v.extract::(); - res[0] = u.unwrap_or(0.0); + res[0] = u.unwrap(); } else { return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); } } if let Some(v) = im.unwrap_or(None) { - if v.is_instance_of::() || v.is_instance_of::() { + if v.is_exact_instance_of::() || v.is_exact_instance_of::() { let u = v.extract::(); - res[1] = u.unwrap_or(0.0); + res[1] = u.unwrap(); } else { return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); } } return Ok(ValidationMatch::exact(EitherComplex::Complex(res))); + } else if self.is_exact_instance_of::() { + return Ok(ValidationMatch::exact(EitherComplex::Complex([ + self.extract::().unwrap(), + 0.0, + ]))); + } else if self.is_exact_instance_of::() { + return Ok(ValidationMatch::exact(EitherComplex::Complex([ + self.extract::().unwrap() as f64, + 0.0, + ]))); } Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) } diff --git a/src/input/input_string.rs b/src/input/input_string.rs index b031a6da5..652459d4c 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -222,7 +222,7 @@ impl<'py> Input<'py> for StringMapping<'py> { fn validate_complex(&self) -> ValResult>> { match self { Self::String(s) => str_as_complex(self, py_string_str(s)?).map(ValidationMatch::strict), - Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), } } } diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py index 0e00fb37e..350b0e27f 100644 --- a/tests/validators/test_complex.py +++ b/tests/validators/test_complex.py @@ -27,7 +27,10 @@ def test_dict(py_and_json: PyAndJson): ({'real': 2}, complex(2, 0)), ({'imag': 2}, complex(0, 2)), ({}, complex(0, 0)), + (3, complex(3, 0)), + (2.0, complex(2, 0)), ({'real': 'test', 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ({'real': True, 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), ('foobar', Err(EXPECTED_TYPE_ERROR_MESSAGE)), ([], Err(EXPECTED_TYPE_ERROR_MESSAGE)), ([('x', 'y')], Err(EXPECTED_TYPE_ERROR_MESSAGE)), @@ -61,14 +64,16 @@ def test_nan_inf_complex(): def test_json_complex(): v = SchemaValidator({'type': 'complex'}) assert v.validate_json('{"real": 2, "imag": 4}') == complex(2, 4) + assert v.validate_json('1') == complex(1, 0) + assert v.validate_json('1.0') == complex(1, 0) with pytest.raises(ValidationError) as exc_info: - v.validate_json('1') + v.validate_json('"1"') assert exc_info.value.errors(include_url=False) == [ { 'type': 'complex_type', 'loc': (), 'msg': EXPECTED_TYPE_ERROR_MESSAGE, - 'input': 1, + 'input': '1', } ] @@ -76,4 +81,4 @@ def test_json_complex(): def test_string_complex(): v = SchemaValidator({'type': 'complex'}) with pytest.raises(ValidationError, match=re.escape(EXPECTED_TYPE_ERROR_MESSAGE)): - v.validate_strings("{'real': float('nan'), 'imag': 0}") + v.validate_strings("{'real': 1, 'imag': 0}") From 419ec8de807596c17f2b97bdab7bf5a24571a670 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Sat, 22 Jun 2024 22:16:07 +0100 Subject: [PATCH 04/18] handle complex strings --- src/errors/types.rs | 4 +- src/input/input_abstract.rs | 2 +- src/input/input_json.rs | 36 ++++--------- src/input/input_python.rs | 29 ++-------- src/input/input_string.rs | 7 +-- src/input/shared.rs | 7 --- src/serializers/type_serializers/complex.rs | 32 +++++++---- src/validators/complex.rs | 36 +++++++++++-- src/validators/mod.rs | 2 +- tests/serializers/test_complex.py | 36 +++++++------ tests/validators/test_complex.py | 59 +++++++++++++-------- 11 files changed, 134 insertions(+), 116 deletions(-) diff --git a/src/errors/types.rs b/src/errors/types.rs index b97f0924b..99f6622b5 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -423,6 +423,7 @@ error_types! { whole_digits: {ctx_type: u64, ctx_fn: field_from_context}, }, // Complex errors + ComplexParsing{}, ComplexType {}, } @@ -566,7 +567,8 @@ impl ErrorType { Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total", Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", - Self::ComplexType { .. } => "Input should be a valid dictionary with exactly two keys, 'real' and 'imag', with float values", + Self::ComplexParsing {..} => "Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex", + Self::ComplexType { .. } => "Input should be a valid complex number", } } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 4c6853dde..8e7ee9f7c 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -173,7 +173,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, ) -> ValMatch>; - fn validate_complex(&self) -> ValMatch>; + fn validate_complex(&self, py: Python<'py>) -> ValMatch>; } /// The problem to solve here is that iterating collections often returns owned diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 0e08b2760..d68b9c2ca 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::collections::HashSet; use jiter::{JsonArray, JsonObject, JsonValue, LazyIndexMap}; use pyo3::prelude::*; @@ -9,7 +8,9 @@ use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::input::return_enums::EitherComplex; use crate::lookup_key::{LookupKey, LookupPath}; +use crate::validators::complex::string_to_complex; use crate::validators::decimal::create_decimal; use super::datetime::{ @@ -17,7 +18,7 @@ use super::datetime::{ float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; use super::input_abstract::{ConsumeIterator, Never, ValMatch}; -use super::return_enums::{EitherComplex, ValidationMatch}; +use super::return_enums::ValidationMatch; use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int}; use super::{ Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input, @@ -298,32 +299,15 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { } } - fn validate_complex(&self) -> ValResult>> { - let default = JsonValue::Float(0.0); + fn validate_complex(&self, py: Python<'py>) -> ValResult>> { match self { - JsonValue::Object(object) => { - let mut allowed_keys = HashSet::from(["real".to_owned(), "imag".to_owned()]); - for key in object.keys() { - let k = &key.to_string(); - if !allowed_keys.remove(k) { - return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); - } - } - let real = object.get("real").unwrap_or(&default).validate_float(true); - let imag = object.get("imag").unwrap_or(&default).validate_float(true); - if let Ok(re) = real { - if let Ok(im) = imag { - return Ok(ValidationMatch::strict(EitherComplex::Complex([ - re.into_inner().as_f64(), - im.into_inner().as_f64(), - ]))); - } - } - Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) - } + JsonValue::Str(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex( + &PyString::new_bound(py, s), + self, + )?))), JsonValue::Float(f) => Ok(ValidationMatch::strict(EitherComplex::Complex([*f, 0.0]))), JsonValue::Int(f) => Ok(ValidationMatch::strict(EitherComplex::Complex([(*f) as f64, 0.0]))), - _ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), + _ => Err(ValError::new(ErrorTypeDefaults::ComplexParsing, self)), } } } @@ -456,7 +440,7 @@ impl<'py> Input<'py> for str { bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } - fn validate_complex(&self) -> ValResult>> { + fn validate_complex(&self, _py: Python<'py>) -> ValResult>> { Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) } } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 9b0dd883f..d0a24427b 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -15,6 +15,7 @@ use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; +use crate::validators::complex::string_to_complex; use crate::validators::decimal::{create_decimal, get_decimal_type}; use crate::validators::Exactness; use crate::ArgsKwargs; @@ -594,33 +595,11 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) } - fn validate_complex<'a>(&'a self) -> ValResult>> { + fn validate_complex<'a>(&'a self, _py: Python<'py>) -> ValResult>> { if let Ok(complex) = self.downcast::() { return Ok(ValidationMatch::exact(EitherComplex::Py(complex.to_owned()))); - } else if let Ok(complex) = self.downcast::() { - let re = complex.get_item("real"); - let im = complex.get_item("imag"); - if complex.len() > 2 || re.is_err() && im.is_err() { - return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); - } - let mut res = [0.0, 0.0]; - if let Some(v) = re.unwrap_or(None) { - if v.is_exact_instance_of::() || v.is_exact_instance_of::() { - let u = v.extract::(); - res[0] = u.unwrap(); - } else { - return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); - } - } - if let Some(v) = im.unwrap_or(None) { - if v.is_exact_instance_of::() || v.is_exact_instance_of::() { - let u = v.extract::(); - res[1] = u.unwrap(); - } else { - return Err(ValError::new(ErrorTypeDefaults::ComplexType, self)); - } - } - return Ok(ValidationMatch::exact(EitherComplex::Complex(res))); + } else if let Ok(s) = self.downcast::() { + return Ok(ValidationMatch::exact(EitherComplex::Py(string_to_complex(s, self)?))); } else if self.is_exact_instance_of::() { return Ok(ValidationMatch::exact(EitherComplex::Complex([ self.extract::().unwrap(), diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 652459d4c..d89ddaca8 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -7,6 +7,7 @@ use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult} use crate::input::py_string_str; use crate::lookup_key::{LookupKey, LookupPath}; use crate::tools::safe_repr; +use crate::validators::complex::string_to_complex; use crate::validators::decimal::create_decimal; use super::datetime::{ @@ -14,7 +15,7 @@ use super::datetime::{ }; use super::input_abstract::{Never, ValMatch}; use super::return_enums::EitherComplex; -use super::shared::{str_as_bool, str_as_complex, str_as_float, str_as_int}; +use super::shared::{str_as_bool, str_as_float, str_as_int}; use super::{ Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input, KeywordArgs, ValidatedDict, ValidationMatch, @@ -219,9 +220,9 @@ impl<'py> Input<'py> for StringMapping<'py> { } } - fn validate_complex(&self) -> ValResult>> { + fn validate_complex(&self, _py: Python<'py>) -> ValResult>> { match self { - Self::String(s) => str_as_complex(self, py_string_str(s)?).map(ValidationMatch::strict), + Self::String(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(s, self)?))), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), } } diff --git a/src/input/shared.rs b/src/input/shared.rs index 4753c9fe1..95b9912a5 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -8,7 +8,6 @@ use jiter::{JsonErrorType, NumberInt}; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; -use super::return_enums::EitherComplex; use super::{EitherFloat, EitherInt, Input}; static ENUM_META_OBJECT: GILOnceCell> = GILOnceCell::new(); @@ -205,9 +204,3 @@ pub fn decimal_as_int<'py>( } Ok(EitherInt::Py(numerator)) } - -/// parse a complex as a complex -pub fn str_as_complex<'py>(input: &(impl Input<'py> + ?Sized), _str: &str) -> ValResult> { - // TODO - Err(ValError::new(ErrorTypeDefaults::ComplexType, input)) -} diff --git a/src/serializers/type_serializers/complex.rs b/src/serializers/type_serializers/complex.rs index f7237c9dc..5a525476e 100644 --- a/src/serializers/type_serializers/complex.rs +++ b/src/serializers/type_serializers/complex.rs @@ -3,8 +3,6 @@ use std::borrow::Cow; use pyo3::prelude::*; use pyo3::types::{PyComplex, PyDict}; -use serde::ser::SerializeMap; - use crate::definitions::DefinitionsBuilder; use super::{infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode, TypeSerializer}; @@ -37,10 +35,17 @@ impl TypeSerializer for ComplexSerializer { match value.downcast::() { Ok(py_complex) => match extra.mode { SerMode::Json => { - let new_dict = PyDict::new_bound(py); - let _ = new_dict.set_item("real", py_complex.real()); - let _ = new_dict.set_item("imag", py_complex.imag()); - Ok(new_dict.into_py(py)) + let re = py_complex.real(); + let im = py_complex.imag(); + let mut s = format!("{im}j"); + if re != 0.0 { + let mut sign = ""; + if im >= 0.0 { + sign = "+"; + } + s = format!("{re}{sign}{s}"); + } + Ok(s.into_py(py)) } _ => Ok(value.into_py(py)), }, @@ -65,10 +70,17 @@ impl TypeSerializer for ComplexSerializer { ) -> Result { match value.downcast::() { Ok(py_complex) => { - let mut map = serializer.serialize_map(Some(2))?; - map.serialize_entry(&"real", &py_complex.real())?; - map.serialize_entry(&"imag", &py_complex.imag())?; - map.end() + let re = py_complex.real(); + let im = py_complex.imag(); + let mut s = format!("{im}j"); + if re != 0.0 { + let mut sign = ""; + if im >= 0.0 { + sign = "+"; + } + s = format!("{re}{sign}{s}"); + } + Ok(serializer.collect_str::(&s)?) } Err(_) => { extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; diff --git a/src/validators/complex.rs b/src/validators/complex.rs index b9b050d74..dbd238741 100644 --- a/src/validators/complex.rs +++ b/src/validators/complex.rs @@ -1,11 +1,21 @@ +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::PyDict; +use pyo3::sync::GILOnceCell; +use pyo3::types::{PyComplex, PyDict, PyString, PyType}; -use crate::errors::ValResult; +use crate::errors::{ErrorTypeDefaults, ToErrorValue, ValError, ValResult}; use crate::input::Input; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +static COMPLEX_TYPE: GILOnceCell> = GILOnceCell::new(); + +pub fn get_complex_type(py: Python) -> &Bound<'_, PyType> { + COMPLEX_TYPE + .get_or_init(py, || py.get_type_bound::().into()) + .bind(py) +} + #[derive(Debug)] pub struct ComplexValidator {} @@ -29,7 +39,7 @@ impl Validator for ComplexValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let res = input.validate_complex()?.unpack(state); + let res = input.validate_complex(py)?.unpack(state); Ok(res.into_py(py)) } @@ -37,3 +47,23 @@ impl Validator for ComplexValidator { "complex" } } + +pub(crate) fn string_to_complex<'py>( + arg: &Bound<'py, PyString>, + input: impl ToErrorValue, +) -> ValResult> { + let py = arg.py(); + Ok(get_complex_type(py) + .call1((arg,)) + .map_err(|err| { + // Since arg is a string, the only possible error here is ValueError + // triggered by invalid complex strings and thus only this case is handled. + if err.is_instance_of::(py) { + ValError::new(ErrorTypeDefaults::ComplexParsing, input) + } else { + ValError::InternalErr(err) + } + })? + .downcast::()? + .to_owned()) +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index dd55cef8a..b6aa6f06f 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -24,7 +24,7 @@ mod bytes; mod call; mod callable; mod chain; -mod complex; +pub(crate) mod complex; mod custom_error; mod dataclass; mod date; diff --git a/tests/serializers/test_complex.py b/tests/serializers/test_complex.py index 8154921c5..feab96ac6 100644 --- a/tests/serializers/test_complex.py +++ b/tests/serializers/test_complex.py @@ -1,4 +1,3 @@ -import json import math import pytest @@ -7,29 +6,34 @@ @pytest.mark.parametrize( - 'value,substr,expected', + 'value,expected', [ - (complex(1, 2), '"real":1.0', {'real': 1.0, 'imag': 2.0}), - (complex(-float('inf'), 2), '"real":-Infinity', {'real': -float('inf'), 'imag': 2.0}), - (complex(float('inf'), 2), '"real":Infinity', {'real': float('inf'), 'imag': 2.0}), - (complex(float('nan'), 2), '"real":NaN', {'real': float('nan'), 'imag': 2.0}), + (complex(-1.23e-4, 567.89), '-0.000123+567.89j'), + (complex(0, -1.23), '-1.23j'), + (complex(1.5, 0), '1.5+0j'), + (complex(1, 2), '1+2j'), + (complex(0, 1), 'j'), + (complex(0, 1e-500), '0j'), + (complex(-float('inf'), 2), '-inf+2j'), + (complex(float('inf'), 2), 'inf+2j'), + (complex(float('nan'), 2), 'NaN+2j'), ], ) -def test_complex_json(value, substr, expected): +def test_complex_json(value, expected): v = SchemaSerializer(core_schema.complex_schema()) c = v.to_python(value) c_json = v.to_python(value, mode='json') json_str = v.to_json(value).decode() - c_reloaded = json.loads(json_str) - assert substr in json_str - assert c.imag == expected['imag'] + assert c_json == expected + assert json_str == f'"{expected}"' - if math.isnan(expected['real']): + if math.isnan(value.imag): + assert math.isnan(c.imag) + else: + assert c.imag == value.imag + + if math.isnan(value.real): assert math.isnan(c.real) - assert math.isnan(c_json['real']) - assert math.isnan(c_reloaded['real']) else: - assert c.real == expected['real'] - assert c_json['real'] == expected['real'] - assert c_reloaded['real'] == expected['real'] + assert c.imag == value.imag diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py index 350b0e27f..5cb518960 100644 --- a/tests/validators/test_complex.py +++ b/tests/validators/test_complex.py @@ -5,33 +5,30 @@ from pydantic_core import SchemaValidator, ValidationError -from ..conftest import Err, PyAndJson +from ..conftest import Err -EXPECTED_TYPE_ERROR_MESSAGE = ( - "Input should be a valid dictionary with exactly two keys, 'real' and 'imag', with float values" -) - - -def test_dict(py_and_json: PyAndJson): - v = py_and_json({'type': 'complex'}) - assert v.validate_test({'real': 2, 'imag': 4}) == complex(2, 4) - with pytest.raises(ValidationError, match=re.escape('[type=complex_type, input_value=[], input_type=list]')): - v.validate_test([]) +EXPECTED_PARSE_ERROR_MESSAGE = 'Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex' +EXPECTED_TYPE_ERROR_MESSAGE = 'Input should be a valid complex number' @pytest.mark.parametrize( 'input_value,expected', [ (complex(2, 4), complex(2, 4)), - ({'real': 2, 'imag': 4}, complex(2, 4)), - ({'real': 2}, complex(2, 0)), - ({'imag': 2}, complex(0, 2)), - ({}, complex(0, 0)), + ('2', complex(2, 0)), + ('2j', complex(0, 2)), + ('+1.23e-4-5.67e+8J', complex(1.23e-4, -5.67e8)), + ('1.5-j', complex(1.5, -1)), + ('-j', complex(0, -1)), + ('j', complex(0, 1)), (3, complex(3, 0)), (2.0, complex(2, 0)), + ('1e-700j', complex(0, 0)), + ('', Err(EXPECTED_PARSE_ERROR_MESSAGE)), + ({'real': 2, 'imag': 4}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), ({'real': 'test', 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), ({'real': True, 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), - ('foobar', Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ('foobar', Err(EXPECTED_PARSE_ERROR_MESSAGE)), ([], Err(EXPECTED_TYPE_ERROR_MESSAGE)), ([('x', 'y')], Err(EXPECTED_TYPE_ERROR_MESSAGE)), ((), Err(EXPECTED_TYPE_ERROR_MESSAGE)), @@ -54,31 +51,47 @@ def test_complex_cases(input_value, expected): def test_nan_inf_complex(): v = SchemaValidator({'type': 'complex'}) - c = v.validate_python({'real': float('nan'), 'imag': float('inf')}) + c = v.validate_python('NaN+Infinityj') # c != complex(float('nan'), float('inf')) as nan != nan, # so we need to examine the values individually assert math.isnan(c.real) assert math.isinf(c.imag) +def test_overflow_complex(): + # Python simply converts too large float values to inf, so these strings + # are still valid, even if the numbers are out of range + v = SchemaValidator({'type': 'complex'}) + + c = v.validate_python('5e600j') + assert math.isinf(c.imag) + + c = v.validate_python('-5e600j') + assert math.isinf(c.imag) + + def test_json_complex(): v = SchemaValidator({'type': 'complex'}) - assert v.validate_json('{"real": 2, "imag": 4}') == complex(2, 4) + assert v.validate_json('"-1.23e+4+5.67e-8J"') == complex(-1.23e4, 5.67e-8) assert v.validate_json('1') == complex(1, 0) assert v.validate_json('1.0') == complex(1, 0) + # "1" is a valid complex string + assert v.validate_json('"1"') == complex(1, 0) + with pytest.raises(ValidationError) as exc_info: - v.validate_json('"1"') + v.validate_json('{"real": 2, "imag": 4}') assert exc_info.value.errors(include_url=False) == [ { - 'type': 'complex_type', + 'type': 'complex_parsing', 'loc': (), - 'msg': EXPECTED_TYPE_ERROR_MESSAGE, - 'input': '1', + 'msg': EXPECTED_PARSE_ERROR_MESSAGE, + 'input': {'real': 2, 'imag': 4}, } ] def test_string_complex(): v = SchemaValidator({'type': 'complex'}) - with pytest.raises(ValidationError, match=re.escape(EXPECTED_TYPE_ERROR_MESSAGE)): + assert v.validate_strings('+1.23e-4-5.67e+8J') == complex(1.23e-4, -5.67e8) + with pytest.raises(ValidationError, match=re.escape(EXPECTED_PARSE_ERROR_MESSAGE)): v.validate_strings("{'real': 1, 'imag': 0}") From 930519e0c1dd3a38ca8aa53d0a940073abfd3e97 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Sat, 22 Jun 2024 22:32:02 +0100 Subject: [PATCH 05/18] add test cases --- tests/validators/test_complex.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py index 5cb518960..6211d8fe2 100644 --- a/tests/validators/test_complex.py +++ b/tests/validators/test_complex.py @@ -21,10 +21,12 @@ ('1.5-j', complex(1.5, -1)), ('-j', complex(0, -1)), ('j', complex(0, 1)), + ('\t( -1.23+4.5J )\n', complex(-1.23, 4.5)), (3, complex(3, 0)), (2.0, complex(2, 0)), ('1e-700j', complex(0, 0)), ('', Err(EXPECTED_PARSE_ERROR_MESSAGE)), + ('\t( -1.23+4.5J \n', Err(EXPECTED_PARSE_ERROR_MESSAGE)), ({'real': 2, 'imag': 4}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), ({'real': 'test', 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), ({'real': True, 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), From 8529a720e6fa3e22a8a5aafae100a618c96e5fa0 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Sat, 22 Jun 2024 22:45:41 +0100 Subject: [PATCH 06/18] fix tests --- python/pydantic_core/core_schema.py | 3 ++- tests/serializers/test_complex.py | 2 +- tests/test_errors.py | 7 ++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index b641e32c7..86fe013d7 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -763,7 +763,8 @@ def complex_schema( schema = core_schema.complex_schema() v = SchemaValidator(schema) - assert v.validate_python({'real': 1, 'imag': 2}) == complex(1, 2) + assert v.validate_python('1+2j') == complex(1, 2) + assert v.validate_python(complex(1, 2)) == complex(1, 2) ``` Args: diff --git a/tests/serializers/test_complex.py b/tests/serializers/test_complex.py index feab96ac6..e7c98246a 100644 --- a/tests/serializers/test_complex.py +++ b/tests/serializers/test_complex.py @@ -12,7 +12,7 @@ (complex(0, -1.23), '-1.23j'), (complex(1.5, 0), '1.5+0j'), (complex(1, 2), '1+2j'), - (complex(0, 1), 'j'), + (complex(0, 1), '1j'), (complex(0, 1e-500), '0j'), (complex(-float('inf'), 2), '-inf+2j'), (complex(float('inf'), 2), 'inf+2j'), diff --git a/tests/test_errors.py b/tests/test_errors.py index fcf4ba3e9..748a5878e 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -387,7 +387,12 @@ def f(input_value, info): ), ( 'complex_type', - "Input should be a valid dictionary with exactly two keys, 'real' and 'imag', with float values", + "Input should be a valid complex number", + None, + ), + ( + 'complex_parse', + "Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex", None, ), ] From c68a9756724b96ecd18644c710130201181f46e8 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Sat, 22 Jun 2024 22:50:50 +0100 Subject: [PATCH 07/18] fix typo --- tests/test_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index 748a5878e..0e173c9ea 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -391,7 +391,7 @@ def f(input_value, info): None, ), ( - 'complex_parse', + 'complex_parsing', "Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex", None, ), From c3bc5c72e2c04fd4c2ddd54211b347d887962d0e Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Sat, 22 Jun 2024 22:55:41 +0100 Subject: [PATCH 08/18] add error type --- python/pydantic_core/core_schema.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 86fe013d7..5e65cf9e6 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3977,6 +3977,7 @@ def definition_reference_schema( 'decimal_max_places', 'decimal_whole_digits', 'complex_type', + 'complex_parsing', ] From e589080fb1fcb395473010e405ee7c57d7a947c5 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Sat, 22 Jun 2024 23:04:15 +0100 Subject: [PATCH 09/18] fix tests --- python/pydantic_core/core_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 5e65cf9e6..796389718 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3976,8 +3976,8 @@ def definition_reference_schema( 'decimal_max_digits', 'decimal_max_places', 'decimal_whole_digits', - 'complex_type', 'complex_parsing', + 'complex_type', ] From 90b094640c847df61c0a86539028dd0ae6eb696c Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Sat, 22 Jun 2024 23:07:14 +0100 Subject: [PATCH 10/18] fix tests --- tests/test_errors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index 0e173c9ea..e024ca973 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -387,12 +387,12 @@ def f(input_value, info): ), ( 'complex_type', - "Input should be a valid complex number", + 'Input should be a valid complex number', None, ), ( 'complex_parsing', - "Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex", + 'Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex', None, ), ] From 1ea4bd94248c8e35eca8b0b8ae1943fa452ba202 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Tue, 25 Jun 2024 22:14:57 +0100 Subject: [PATCH 11/18] xfail test for pypy --- tests/validators/test_complex.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py index 6211d8fe2..2eb778ebe 100644 --- a/tests/validators/test_complex.py +++ b/tests/validators/test_complex.py @@ -1,4 +1,5 @@ import math +import platform import re import pytest @@ -21,7 +22,6 @@ ('1.5-j', complex(1.5, -1)), ('-j', complex(0, -1)), ('j', complex(0, 1)), - ('\t( -1.23+4.5J )\n', complex(-1.23, 4.5)), (3, complex(3, 0)), (2.0, complex(2, 0)), ('1e-700j', complex(0, 0)), @@ -51,6 +51,15 @@ def test_complex_cases(input_value, expected): assert v.validate_python(input_value) == expected +@pytest.mark.xfail( + platform.python_implementation() == 'PyPy', + reason='PyPy cannot process this string due to a bug, even if this string is considered valid in python', +) +def test_valid_complex_string_with_space(): + v = SchemaValidator({'type': 'complex'}) + assert v.validate_python('\t( -1.23+4.5J )\n') == complex(-1.23, 4.5) + + def test_nan_inf_complex(): v = SchemaValidator({'type': 'complex'}) c = v.validate_python('NaN+Infinityj') From 7da4284f109f16983dd737f10bcb1cc277aeb668 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Tue, 25 Jun 2024 22:15:48 +0100 Subject: [PATCH 12/18] fix format --- src/errors/types.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/errors/types.rs b/src/errors/types.rs index 99f6622b5..5c2134340 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -423,7 +423,7 @@ error_types! { whole_digits: {ctx_type: u64, ctx_fn: field_from_context}, }, // Complex errors - ComplexParsing{}, + ComplexParsing {}, ComplexType {}, } @@ -568,7 +568,7 @@ impl ErrorType { Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", Self::ComplexParsing {..} => "Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex", - Self::ComplexType { .. } => "Input should be a valid complex number", + Self::ComplexType {..} => "Input should be a valid complex number", } } From 7eb84c83958ff451926effebb79be4cd0748b5f2 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Mon, 5 Aug 2024 00:30:47 +0100 Subject: [PATCH 13/18] strict mode --- python/pydantic_core/core_schema.py | 4 ++ src/errors/types.rs | 8 ++-- src/input/input_abstract.rs | 2 +- src/input/input_json.rs | 27 ++++++++++--- src/input/input_python.rs | 27 ++++++++++--- src/input/input_string.rs | 2 +- src/validators/complex.rs | 18 ++++++--- tests/validators/test_complex.py | 60 +++++++++++++++++++++++++---- 8 files changed, 118 insertions(+), 30 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 796389718..8d882f5ee 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -744,6 +744,7 @@ def decimal_schema( class ComplexSchema(TypedDict, total=False): type: Required[Literal['complex']] + strict: bool ref: str metadata: Any serialization: SerSchema @@ -751,6 +752,7 @@ class ComplexSchema(TypedDict, total=False): def complex_schema( *, + strict: bool | None = None, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None, @@ -768,12 +770,14 @@ def complex_schema( ``` Args: + strict: Whether the value should be a complex object instance or a value that can be converted to a complex object ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ return _dict_not_none( type='complex', + strict=strict, ref=ref, metadata=metadata, serialization=serialization, diff --git a/src/errors/types.rs b/src/errors/types.rs index 5c2134340..848b5f6ab 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -423,8 +423,9 @@ error_types! { whole_digits: {ctx_type: u64, ctx_fn: field_from_context}, }, // Complex errors - ComplexParsing {}, ComplexType {}, + ComplexTypePyStrict {}, + ComplexStrParsing {}, } macro_rules! render { @@ -567,8 +568,9 @@ impl ErrorType { Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total", Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", - Self::ComplexParsing {..} => "Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex", - Self::ComplexType {..} => "Input should be a valid complex number", + Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", + Self::ComplexTypePyStrict {..} => "Input should be a valid Python complex object", + Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", } } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 8e7ee9f7c..f2538fad7 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -173,7 +173,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, ) -> ValMatch>; - fn validate_complex(&self, py: Python<'py>) -> ValMatch>; + fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValMatch>; } /// The problem to solve here is that iterating collections often returns owned diff --git a/src/input/input_json.rs b/src/input/input_json.rs index d68b9c2ca..407212d94 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -299,15 +299,27 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { } } - fn validate_complex(&self, py: Python<'py>) -> ValResult>> { + fn validate_complex(&self, strict: bool, py: Python<'py>) -> ValResult>> { match self { JsonValue::Str(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex( &PyString::new_bound(py, s), self, )?))), - JsonValue::Float(f) => Ok(ValidationMatch::strict(EitherComplex::Complex([*f, 0.0]))), - JsonValue::Int(f) => Ok(ValidationMatch::strict(EitherComplex::Complex([(*f) as f64, 0.0]))), - _ => Err(ValError::new(ErrorTypeDefaults::ComplexParsing, self)), + JsonValue::Float(f) => { + if !strict { + Ok(ValidationMatch::lax(EitherComplex::Complex([*f, 0.0]))) + } else { + Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self)) + } + } + JsonValue::Int(f) => { + if !strict { + Ok(ValidationMatch::lax(EitherComplex::Complex([(*f) as f64, 0.0]))) + } else { + Err(ValError::new(ErrorTypeDefaults::ComplexStrParsing, self)) + } + } + _ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), } } } @@ -440,8 +452,11 @@ impl<'py> Input<'py> for str { bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } - fn validate_complex(&self, _py: Python<'py>) -> ValResult>> { - Err(ValError::new(ErrorTypeDefaults::ComplexType, self)) + fn validate_complex(&self, _strict: bool, py: Python<'py>) -> ValResult>> { + Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex( + self.to_object(py).downcast_bound::(py)?, + self, + )?))) } } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index d0a24427b..2c7321b8c 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -595,18 +595,33 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) } - fn validate_complex<'a>(&'a self, _py: Python<'py>) -> ValResult>> { + fn validate_complex<'a>( + &'a self, + strict: bool, + _py: Python<'py>, + ) -> ValResult>> { if let Ok(complex) = self.downcast::() { - return Ok(ValidationMatch::exact(EitherComplex::Py(complex.to_owned()))); - } else if let Ok(s) = self.downcast::() { - return Ok(ValidationMatch::exact(EitherComplex::Py(string_to_complex(s, self)?))); + return Ok(ValidationMatch::strict(EitherComplex::Py(complex.to_owned()))); + } + if strict { + return Err(ValError::new(ErrorTypeDefaults::ComplexTypePyStrict, self)); + } + + if let Ok(s) = self.downcast::() { + // If input is not a valid complex string, instead of telling users to correct + // the string, it makes more sense to tell them to provide any acceptable value + // since they might have just given values of some incorrect types instead + // of actually trying some complex strings. + if let Ok(c) = string_to_complex(s, self) { + return Ok(ValidationMatch::lax(EitherComplex::Py(c))); + } } else if self.is_exact_instance_of::() { - return Ok(ValidationMatch::exact(EitherComplex::Complex([ + return Ok(ValidationMatch::lax(EitherComplex::Complex([ self.extract::().unwrap(), 0.0, ]))); } else if self.is_exact_instance_of::() { - return Ok(ValidationMatch::exact(EitherComplex::Complex([ + return Ok(ValidationMatch::lax(EitherComplex::Complex([ self.extract::().unwrap() as f64, 0.0, ]))); diff --git a/src/input/input_string.rs b/src/input/input_string.rs index d89ddaca8..87453f5dc 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -220,7 +220,7 @@ impl<'py> Input<'py> for StringMapping<'py> { } } - fn validate_complex(&self, _py: Python<'py>) -> ValResult>> { + fn validate_complex(&self, _strict: bool, _py: Python<'py>) -> ValResult>> { match self { Self::String(s) => Ok(ValidationMatch::strict(EitherComplex::Py(string_to_complex(s, self)?))), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)), diff --git a/src/validators/complex.rs b/src/validators/complex.rs index dbd238741..d1d9f6c35 100644 --- a/src/validators/complex.rs +++ b/src/validators/complex.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use pyo3::sync::GILOnceCell; use pyo3::types::{PyComplex, PyDict, PyString, PyType}; +use crate::build_tools::is_strict; use crate::errors::{ErrorTypeDefaults, ToErrorValue, ValError, ValResult}; use crate::input::Input; @@ -17,16 +18,21 @@ pub fn get_complex_type(py: Python) -> &Bound<'_, PyType> { } #[derive(Debug)] -pub struct ComplexValidator {} +pub struct ComplexValidator { + strict: bool, +} impl BuildValidator for ComplexValidator { const EXPECTED_TYPE: &'static str = "complex"; fn build( - _schema: &Bound<'_, PyDict>, - _config: Option<&Bound<'_, PyDict>>, + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, _definitions: &mut DefinitionsBuilder, ) -> PyResult { - Ok(Self {}.into()) + Ok(Self { + strict: is_strict(schema, config)?, + } + .into()) } } @@ -39,7 +45,7 @@ impl Validator for ComplexValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let res = input.validate_complex(py)?.unpack(state); + let res = input.validate_complex(self.strict, py)?.unpack(state); Ok(res.into_py(py)) } @@ -59,7 +65,7 @@ pub(crate) fn string_to_complex<'py>( // Since arg is a string, the only possible error here is ValueError // triggered by invalid complex strings and thus only this case is handled. if err.is_instance_of::(py) { - ValError::new(ErrorTypeDefaults::ComplexParsing, input) + ValError::new(ErrorTypeDefaults::ComplexStrParsing, input) } else { ValError::InternalErr(err) } diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py index 2eb778ebe..d91abbb7b 100644 --- a/tests/validators/test_complex.py +++ b/tests/validators/test_complex.py @@ -8,8 +8,9 @@ from ..conftest import Err -EXPECTED_PARSE_ERROR_MESSAGE = 'Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex' -EXPECTED_TYPE_ERROR_MESSAGE = 'Input should be a valid complex number' +EXPECTED_PARSE_ERROR_MESSAGE = 'Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex' +EXPECTED_TYPE_ERROR_MESSAGE = 'Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex' +EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE = 'Input should be a valid Python complex object' @pytest.mark.parametrize( @@ -25,12 +26,12 @@ (3, complex(3, 0)), (2.0, complex(2, 0)), ('1e-700j', complex(0, 0)), - ('', Err(EXPECTED_PARSE_ERROR_MESSAGE)), - ('\t( -1.23+4.5J \n', Err(EXPECTED_PARSE_ERROR_MESSAGE)), + ('', Err(EXPECTED_TYPE_ERROR_MESSAGE)), + ('\t( -1.23+4.5J \n', Err(EXPECTED_TYPE_ERROR_MESSAGE)), ({'real': 2, 'imag': 4}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), ({'real': 'test', 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), ({'real': True, 'imag': 1}, Err(EXPECTED_TYPE_ERROR_MESSAGE)), - ('foobar', Err(EXPECTED_PARSE_ERROR_MESSAGE)), + ('foobar', Err(EXPECTED_TYPE_ERROR_MESSAGE)), ([], Err(EXPECTED_TYPE_ERROR_MESSAGE)), ([('x', 'y')], Err(EXPECTED_TYPE_ERROR_MESSAGE)), ((), Err(EXPECTED_TYPE_ERROR_MESSAGE)), @@ -51,6 +52,37 @@ def test_complex_cases(input_value, expected): assert v.validate_python(input_value) == expected +@pytest.mark.parametrize( + 'input_value,expected', + [ + (complex(2, 4), complex(2, 4)), + ('2', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('2j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('+1.23e-4-5.67e+8J', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('1.5-j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('-j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + (3, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + (2.0, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('1e-700j', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('\t( -1.23+4.5J \n', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ({'real': 2, 'imag': 4}, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ({'real': 'test', 'imag': 1}, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ({'real': True, 'imag': 1}, Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ('foobar', Err(EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE)), + ], + ids=repr, +) +def test_complex_strict(input_value, expected): + v = SchemaValidator({'type': 'complex', 'strict': True}) + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_python(input_value) + else: + assert v.validate_python(input_value) == expected + + @pytest.mark.xfail( platform.python_implementation() == 'PyPy', reason='PyPy cannot process this string due to a bug, even if this string is considered valid in python', @@ -93,14 +125,28 @@ def test_json_complex(): v.validate_json('{"real": 2, "imag": 4}') assert exc_info.value.errors(include_url=False) == [ { - 'type': 'complex_parsing', + 'type': 'complex_type', 'loc': (), - 'msg': EXPECTED_PARSE_ERROR_MESSAGE, + 'msg': EXPECTED_TYPE_ERROR_MESSAGE, 'input': {'real': 2, 'imag': 4}, } ] +def test_json_complex_strict(): + v = SchemaValidator({'type': 'complex', 'strict': True}) + assert v.validate_json('"-1.23e+4+5.67e-8J"') == complex(-1.23e4, 5.67e-8) + # "1" is a valid complex string + assert v.validate_json('"1"') == complex(1, 0) + + with pytest.raises(ValidationError, match=re.escape(EXPECTED_PARSE_ERROR_MESSAGE)): + assert v.validate_json('1') == complex(1, 0) + with pytest.raises(ValidationError, match=re.escape(EXPECTED_PARSE_ERROR_MESSAGE)): + assert v.validate_json('1.0') == complex(1, 0) + with pytest.raises(ValidationError, match=re.escape(EXPECTED_TYPE_ERROR_MESSAGE)): + v.validate_json('{"real": 2, "imag": 4}') + + def test_string_complex(): v = SchemaValidator({'type': 'complex'}) assert v.validate_strings('+1.23e-4-5.67e+8J') == complex(1.23e-4, -5.67e8) From ae774c41d666db3b96f4cbf220319d188129674d Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Tue, 6 Aug 2024 21:11:59 +0100 Subject: [PATCH 14/18] update tests --- tests/validators/test_dict.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/validators/test_dict.py b/tests/validators/test_dict.py index 7b3ca19c8..86d73b42a 100644 --- a/tests/validators/test_dict.py +++ b/tests/validators/test_dict.py @@ -46,6 +46,21 @@ def test_dict_cases(input_value, expected): assert v.validate_python(input_value) == expected +def test_dict_complex_key(): + v = SchemaValidator( + {'type': 'dict', 'keys_schema': {'type': 'complex', 'strict': True}, 'values_schema': {'type': 'str'}} + ) + assert v.validate_python({complex(1, 2): '1'}) == {complex(1, 2): '1'} + with pytest.raises(ValidationError, match='Input should be a valid Python complex object'): + assert v.validate_python({'1+2j': b'1'}) == {complex(1, 2): '1'} + + v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'str'}}) + with pytest.raises( + ValidationError, match='Input should be a valid python complex object, a number, or a valid complex string' + ): + v.validate_python({'1+2ja': b'1'}) + + def test_dict_value_error(py_and_json: PyAndJson): v = py_and_json({'type': 'dict', 'values_schema': {'type': 'int'}}) assert v.validate_test({'a': 2, 'b': '4'}) == {'a': 2, 'b': 4} From 9d373ef328da464da016ab7561aa224dd87317b8 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Tue, 6 Aug 2024 21:26:22 +0100 Subject: [PATCH 15/18] fix tests --- python/pydantic_core/core_schema.py | 3 ++- tests/test_errors.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index b49590dd2..868c3ee98 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -4000,8 +4000,9 @@ def definition_reference_schema( 'decimal_max_digits', 'decimal_max_places', 'decimal_whole_digits', - 'complex_parsing', 'complex_type', + 'complex_type_py_strict', + 'complex_str_parsing', ] diff --git a/tests/test_errors.py b/tests/test_errors.py index 5af8b49d1..d516e4a06 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -397,12 +397,17 @@ def f(input_value, info): ), ( 'complex_type', - 'Input should be a valid complex number', + 'Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex', None, ), ( - 'complex_parsing', - 'Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex', + 'complex_type_py_strict', + 'Input should be a valid Python complex object', + None, + ), + ( + 'complex_str_parsing', + 'Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex', None, ), ] From cc53b24d16e8956159f2dcbb7b4b1408a0e73653 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Tue, 6 Aug 2024 21:53:42 +0100 Subject: [PATCH 16/18] update tests --- tests/validators/test_complex.py | 4 ++-- tests/validators/test_dict.py | 40 ++++++++++++++++++++------------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py index d91abbb7b..468b80138 100644 --- a/tests/validators/test_complex.py +++ b/tests/validators/test_complex.py @@ -140,9 +140,9 @@ def test_json_complex_strict(): assert v.validate_json('"1"') == complex(1, 0) with pytest.raises(ValidationError, match=re.escape(EXPECTED_PARSE_ERROR_MESSAGE)): - assert v.validate_json('1') == complex(1, 0) + v.validate_json('1') with pytest.raises(ValidationError, match=re.escape(EXPECTED_PARSE_ERROR_MESSAGE)): - assert v.validate_json('1.0') == complex(1, 0) + v.validate_json('1.0') with pytest.raises(ValidationError, match=re.escape(EXPECTED_TYPE_ERROR_MESSAGE)): v.validate_json('{"real": 2, "imag": 4}') diff --git a/tests/validators/test_dict.py b/tests/validators/test_dict.py index 86d73b42a..530c8d90f 100644 --- a/tests/validators/test_dict.py +++ b/tests/validators/test_dict.py @@ -46,21 +46,6 @@ def test_dict_cases(input_value, expected): assert v.validate_python(input_value) == expected -def test_dict_complex_key(): - v = SchemaValidator( - {'type': 'dict', 'keys_schema': {'type': 'complex', 'strict': True}, 'values_schema': {'type': 'str'}} - ) - assert v.validate_python({complex(1, 2): '1'}) == {complex(1, 2): '1'} - with pytest.raises(ValidationError, match='Input should be a valid Python complex object'): - assert v.validate_python({'1+2j': b'1'}) == {complex(1, 2): '1'} - - v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'str'}}) - with pytest.raises( - ValidationError, match='Input should be a valid python complex object, a number, or a valid complex string' - ): - v.validate_python({'1+2ja': b'1'}) - - def test_dict_value_error(py_and_json: PyAndJson): v = py_and_json({'type': 'dict', 'values_schema': {'type': 'int'}}) assert v.validate_test({'a': 2, 'b': '4'}) == {'a': 2, 'b': 4} @@ -250,3 +235,28 @@ def test_json_dict(): assert exc_info.value.errors(include_url=False) == [ {'type': 'dict_type', 'loc': (), 'msg': 'Input should be an object', 'input': 1} ] + + +def test_dict_complex_key(): + v = SchemaValidator( + {'type': 'dict', 'keys_schema': {'type': 'complex', 'strict': True}, 'values_schema': {'type': 'str'}} + ) + assert v.validate_python({complex(1, 2): '1'}) == {complex(1, 2): '1'} + with pytest.raises(ValidationError, match='Input should be a valid Python complex object'): + assert v.validate_python({'1+2j': b'1'}) == {complex(1, 2): '1'} + + v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'str'}}) + with pytest.raises( + ValidationError, match='Input should be a valid python complex object, a number, or a valid complex string' + ): + v.validate_python({'1+2ja': b'1'}) + + +def test_json_dict_complex_key(): + v = SchemaValidator( + {'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'int'}} + ) + assert v.validate_json('{"1+2j": 2, "-3": 4}') == {complex(1, 2): 2, complex(-3, 0): 4} + assert v.validate_json('{"1+2j": 2, "infj": 4}') == {complex(1, 2): 2, complex(0, float('inf')): 4} + with pytest.raises(ValidationError, match='Input should be a valid complex string'): + v.validate_json('{"1+2j": 2, "": 4}') == {complex(1, 2): 2, complex(0, float('inf')): 4} From 6c8c222912b2791917d31d904710de9a7a3f31e4 Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Tue, 6 Aug 2024 22:05:12 +0100 Subject: [PATCH 17/18] fix tests --- tests/validators/test_dict.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/validators/test_dict.py b/tests/validators/test_dict.py index 530c8d90f..d5839cf6a 100644 --- a/tests/validators/test_dict.py +++ b/tests/validators/test_dict.py @@ -253,9 +253,7 @@ def test_dict_complex_key(): def test_json_dict_complex_key(): - v = SchemaValidator( - {'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'int'}} - ) + v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'int'}}) assert v.validate_json('{"1+2j": 2, "-3": 4}') == {complex(1, 2): 2, complex(-3, 0): 4} assert v.validate_json('{"1+2j": 2, "infj": 4}') == {complex(1, 2): 2, complex(0, float('inf')): 4} with pytest.raises(ValidationError, match='Input should be a valid complex string'): From 89195897f67178a99ba330c7e28fa960c4149b5a Mon Sep 17 00:00:00 2001 From: Huan-Cheng Chang Date: Thu, 8 Aug 2024 20:54:10 +0100 Subject: [PATCH 18/18] use isinstance error --- python/pydantic_core/core_schema.py | 1 - src/errors/types.rs | 2 -- src/input/input_python.rs | 18 ++++++++++++------ tests/test_errors.py | 5 ----- tests/validators/test_complex.py | 2 +- tests/validators/test_dict.py | 2 +- 6 files changed, 14 insertions(+), 16 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 868c3ee98..886194cbc 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -4001,7 +4001,6 @@ def definition_reference_schema( 'decimal_max_places', 'decimal_whole_digits', 'complex_type', - 'complex_type_py_strict', 'complex_str_parsing', ] diff --git a/src/errors/types.rs b/src/errors/types.rs index b1ae7f7bc..ec129a63a 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -428,7 +428,6 @@ error_types! { }, // Complex errors ComplexType {}, - ComplexTypePyStrict {}, ComplexStrParsing {}, } @@ -574,7 +573,6 @@ impl ErrorType { Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", Self::ComplexType {..} => "Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", - Self::ComplexTypePyStrict {..} => "Input should be a valid Python complex object", Self::ComplexStrParsing {..} => "Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex", } } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 5cee4e1b1..46c32a9de 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -10,6 +10,7 @@ use pyo3::types::{ }; use pyo3::PyTypeCheck; +use pyo3::PyTypeInfo; use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; @@ -601,16 +602,21 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) } - fn validate_complex<'a>( - &'a self, - strict: bool, - _py: Python<'py>, - ) -> ValResult>> { + fn validate_complex<'a>(&'a self, strict: bool, py: Python<'py>) -> ValResult>> { if let Ok(complex) = self.downcast::() { return Ok(ValidationMatch::strict(EitherComplex::Py(complex.to_owned()))); } if strict { - return Err(ValError::new(ErrorTypeDefaults::ComplexTypePyStrict, self)); + return Err(ValError::new( + ErrorType::IsInstanceOf { + class: PyComplex::type_object_bound(py) + .qualname() + .and_then(|name| name.extract()) + .unwrap_or_else(|_| "complex".to_owned()), + context: None, + }, + self, + )); } if let Ok(s) = self.downcast::() { diff --git a/tests/test_errors.py b/tests/test_errors.py index d516e4a06..b8265f04e 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -400,11 +400,6 @@ def f(input_value, info): 'Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex', None, ), - ( - 'complex_type_py_strict', - 'Input should be a valid Python complex object', - None, - ), ( 'complex_str_parsing', 'Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex', diff --git a/tests/validators/test_complex.py b/tests/validators/test_complex.py index 468b80138..83c5d416d 100644 --- a/tests/validators/test_complex.py +++ b/tests/validators/test_complex.py @@ -10,7 +10,7 @@ EXPECTED_PARSE_ERROR_MESSAGE = 'Input should be a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex' EXPECTED_TYPE_ERROR_MESSAGE = 'Input should be a valid python complex object, a number, or a valid complex string following the rules at https://docs.python.org/3/library/functions.html#complex' -EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE = 'Input should be a valid Python complex object' +EXPECTED_TYPE_ERROR_PY_STRICT_MESSAGE = 'Input should be an instance of complex' @pytest.mark.parametrize( diff --git a/tests/validators/test_dict.py b/tests/validators/test_dict.py index d5839cf6a..4057ce76e 100644 --- a/tests/validators/test_dict.py +++ b/tests/validators/test_dict.py @@ -242,7 +242,7 @@ def test_dict_complex_key(): {'type': 'dict', 'keys_schema': {'type': 'complex', 'strict': True}, 'values_schema': {'type': 'str'}} ) assert v.validate_python({complex(1, 2): '1'}) == {complex(1, 2): '1'} - with pytest.raises(ValidationError, match='Input should be a valid Python complex object'): + with pytest.raises(ValidationError, match='Input should be an instance of complex'): assert v.validate_python({'1+2j': b'1'}) == {complex(1, 2): '1'} v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'complex'}, 'values_schema': {'type': 'str'}})