Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support complex numbers #1331

Merged
merged 20 commits into from
Aug 15, 2024
2 changes: 1 addition & 1 deletion generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901
if isinstance(obj, str):
return {'type': obj}
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal, complex):
return {'type': obj.__name__.lower()}
elif is_typeddict(obj):
return type_dict_schema(obj, definitions)
Expand Down
42 changes: 42 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,44 @@ def decimal_schema(
)


class ComplexSchema(TypedDict, total=False):
    # Discriminator for this schema; the only key marked as required.
    type: Required[Literal['complex']]
    # The keys below are optional because the class is declared with total=False.
    ref: str
    metadata: Any
    serialization: SerSchema


def complex_schema(
    *,
    ref: str | None = None,
    metadata: Any = None,
    serialization: SerSchema | None = None,
) -> ComplexSchema:
    """
    Build a core schema that matches a complex value, e.g.:

    ```py
    from pydantic_core import SchemaValidator, core_schema

    schema = core_schema.complex_schema()
    v = SchemaValidator(schema)
    assert v.validate_python('1+2j') == complex(1, 2)
    assert v.validate_python(complex(1, 2)) == complex(1, 2)
    ```

    Args:
        ref: optional unique identifier of the schema, used to reference the schema in other places
        metadata: Any other information you want to include with the schema, not used by pydantic-core
        serialization: Custom serialization schema
    """
    # Collect every candidate key first; `_dict_not_none` receives the same keyword
    # arguments, in the same order, as a direct call would pass.
    schema_fields = {
        'type': 'complex',
        'ref': ref,
        'metadata': metadata,
        'serialization': serialization,
    }
    return _dict_not_none(**schema_fields)


class StringSchema(TypedDict, total=False):
type: Required[Literal['str']]
pattern: Union[str, Pattern[str]]
Expand Down Expand Up @@ -3777,6 +3815,7 @@ def definition_reference_schema(
DefinitionsSchema,
DefinitionReferenceSchema,
UuidSchema,
ComplexSchema,
]
elif False:
CoreSchema: TypeAlias = Mapping[str, Any]
Expand Down Expand Up @@ -3832,6 +3871,7 @@ def definition_reference_schema(
'definitions',
'definition-ref',
'uuid',
'complex',
]

CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
Expand Down Expand Up @@ -3936,6 +3976,8 @@ def definition_reference_schema(
'decimal_max_digits',
'decimal_max_places',
'decimal_whole_digits',
'complex_parsing',
'complex_type',
]


Expand Down
5 changes: 5 additions & 0 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ error_types! {
DecimalWholeDigits {
whole_digits: {ctx_type: u64, ctx_fn: field_from_context},
},
// Complex errors
ComplexParsing{},
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
ComplexType {},
}

macro_rules! render {
Expand Down Expand Up @@ -564,6 +567,8 @@ impl ErrorType {
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
Self::ComplexParsing {..} => "Input should be a valid complex string following the rule at https://docs.python.org/3/library/functions.html#complex",
Self::ComplexType { .. } => "Input should be a valid complex number",
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::py_err;

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString};
use super::{EitherFloat, GenericIterator, ValidationMatch};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -172,6 +172,8 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValMatch<EitherTimedelta<'py>>;

/// Attempts to interpret this input as a complex number.
///
/// On success the value is returned as an `EitherComplex` wrapped in the match level
/// chosen by the implementation; on failure a validation error is returned.
fn validate_complex(&self, py: Python<'py>) -> ValMatch<EitherComplex<'py>>;
}

/// The problem to solve here is that iterating collections often returns owned
Expand Down
18 changes: 18 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;
use strum::EnumMessage;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::input::return_enums::EitherComplex;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;

use super::datetime::{
Expand Down Expand Up @@ -296,6 +298,18 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
_ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}

/// Validates a JSON value as a complex number.
///
/// * A JSON string is parsed with `string_to_complex` and reported as a strict match;
///   a parse failure propagates that function's error.
/// * A JSON float or integer is promoted to a complex number with a zero imaginary part.
///   This is a conversion, so it is reported as a lax match rather than a strict one.
/// * Any other JSON value is rejected with `ComplexType`, matching the other `Input`
///   implementations: such a value is the wrong kind of input, not a malformed complex
///   string, so the "valid complex string" message of `ComplexParsing` would mislead.
fn validate_complex(&self, py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
    match self {
        JsonValue::Str(s) => {
            let parsed = string_to_complex(&PyString::new_bound(py, s), self)?;
            Ok(ValidationMatch::strict(EitherComplex::Py(parsed)))
        }
        JsonValue::Float(f) => Ok(ValidationMatch::lax(EitherComplex::Complex([*f, 0.0]))),
        // `i64 as f64` rounds to the nearest representable float for very large magnitudes.
        JsonValue::Int(i) => Ok(ValidationMatch::lax(EitherComplex::Complex([*i as f64, 0.0]))),
        _ => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
    }
}
}

/// Required for JSON Object keys so the string can behave like an Input
Expand Down Expand Up @@ -425,6 +439,10 @@ impl<'py> Input<'py> for str {
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>> {
bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax)
}

// A bare `str` input is never accepted as a complex number by this implementation:
// every call is rejected with a `ComplexType` error and the `Python` token is unused.
fn validate_complex(&self, _py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
Err(ValError::new(ErrorTypeDefaults::ComplexType, self))
}
}

impl BorrowInput<'_> for &'_ String {
Expand Down
25 changes: 23 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ use pyo3::prelude::*;

use pyo3::types::PyType;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
PyMapping, PySet, PyString, PyTime, PyTuple,
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator,
PyList, PyMapping, PySet, PyString, PyTime, PyTuple,
};

use pyo3::PyTypeCheck;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::tools::{extract_i64, safe_repr};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::ArgsKwargs;
Expand All @@ -25,6 +26,7 @@ use super::datetime::{
EitherTime,
};
use super::input_abstract::ValMatch;
use super::return_enums::EitherComplex;
use super::return_enums::{iterate_attributes, iterate_mapping_items, ValidationMatch};
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
Expand Down Expand Up @@ -592,6 +594,25 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {

Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self))
}

/// Validates a Python object as a complex number.
///
/// * A `complex` instance is returned unchanged as an exact match.
/// * A `str` is parsed with `string_to_complex`; a parse failure propagates its error.
/// * An exact `float` or `int` becomes a complex number with a zero imaginary part.
///
/// String and numeric inputs are conversions, so they are reported as lax matches. In a
/// union such as `complex | int` an integer input is therefore not preferred as a complex
/// value. Anything else, including an `int` too large to be represented as a float, is
/// rejected with `ComplexType`.
fn validate_complex<'a>(&'a self, _py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
    if let Ok(complex) = self.downcast::<PyComplex>() {
        return Ok(ValidationMatch::exact(EitherComplex::Py(complex.to_owned())));
    }
    if let Ok(s) = self.downcast::<PyString>() {
        return Ok(ValidationMatch::lax(EitherComplex::Py(string_to_complex(s, self)?)));
    }
    if self.is_exact_instance_of::<PyFloat>() || self.is_exact_instance_of::<PyInt>() {
        // Extract straight to `f64`. Going through `i64` and unwrapping would panic on
        // Python integers outside the `i64` range; here an unrepresentable integer makes
        // the extraction fail and falls through to the type error below.
        if let Ok(real) = self.extract::<f64>() {
            return Ok(ValidationMatch::lax(EitherComplex::Complex([real, 0.0])));
        }
    }
    Err(ValError::new(ErrorTypeDefaults::ComplexType, self))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have a think here about what behaviour we expect in strict mode. Probably we only want to accept complex objects in strict. Do we want to accept integers / floats in "strict" mode? In lax, we can accept everything for sure.

Also, similarly the ValidationMatch::exact calls should probably be downgraded to either strict or lax - these will set the level of priority in e.g. complex | int union. I think most would agree that complex | int would expect an integer value to not be cast to a complex. We might want some according tests for some of these union cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strict mode is strict in terms of typing, right? I was mostly thinking from the perspective of mathematics when I added float and int. I suppose then we should only allow complex objects in strict mode? Then ValidationMatch::exact is the same as strict. This sounds reasonable to me, but I'm not exactly sure how different exact and strict usually are.

One more question about strict mode: how does it apply to other input types? Is it like when the input is in JSON, we only accept strings in strict mode?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose then we should only allow complex objects in strict mode?

I think so, yes (note that in JSON it's less straightforward as you note). In JSON I think we only accept strings in strict mode, yes.

Then ValidationMatch::exact is the same as strict.

IIRC in other cases I made subclasses strict, if complex objects allow subclassing. I'd check what I did for ints or floats.

}

impl<'py> BorrowInput<'py> for Bound<'py, PyAny> {
Expand Down
9 changes: 9 additions & 0 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}
use crate::input::py_string_str;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
};
use super::input_abstract::{Never, ValMatch};
use super::return_enums::EitherComplex;
use super::shared::{str_as_bool, str_as_float, str_as_int};
use super::{
Arguments, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, Input,
Expand Down Expand Up @@ -217,6 +219,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}

/// Validates a string-mapping input as a complex number.
///
/// A string leaf is parsed with `string_to_complex` and reported as a strict match;
/// a nested mapping cannot represent a complex number and yields `ComplexType`.
fn validate_complex(&self, _py: Python<'py>) -> ValResult<ValidationMatch<EitherComplex<'py>>> {
    match self {
        Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ComplexType, self)),
        Self::String(text) => {
            let parsed = string_to_complex(text, self)?;
            Ok(ValidationMatch::strict(EitherComplex::Py(parsed)))
        }
    }
}
}

impl<'py> BorrowInput<'py> for StringMapping<'py> {
Expand Down
29 changes: 28 additions & 1 deletion src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use pyo3::intern;
use pyo3::prelude::*;
#[cfg(not(PyPy))]
use pyo3::types::PyFunction;
use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};
use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString};

use serde::{ser::Error, Serialize, Serializer};

Expand Down Expand Up @@ -716,3 +716,30 @@ impl ToPyObject for Int {
}
}
}

/// A complex number held either as plain Rust floats or as an existing Python object.
#[derive(Clone)]
pub enum EitherComplex<'a> {
/// The value as `[real, imaginary]`.
Complex([f64; 2]),
/// An already-constructed Python `complex` object.
Py(Bound<'a, PyComplex>),
}

impl<'a> IntoPy<PyObject> for EitherComplex<'a> {
    /// Converts the value into a Python object: an existing Python `complex` is passed
    /// through, while a pair of floats is used to build a new one.
    fn into_py(self, py: Python<'_>) -> PyObject {
        match self {
            Self::Py(existing) => existing.into_py(py),
            Self::Complex([real, imag]) => PyComplex::from_doubles_bound(py, real, imag).into_py(py),
        }
    }
}

impl<'a> EitherComplex<'a> {
    /// Returns the value as `[real, imaginary]`.
    ///
    /// For the Python-backed variant the parts are read through the `PyComplex`
    /// accessors, which return plain floats and cannot fail, rather than through
    /// attribute lookups that had to be unwrapped. The `Python` token is kept in the
    /// signature for existing callers but is no longer needed.
    pub fn as_f64(&self, _py: Python<'_>) -> [f64; 2] {
        match self {
            EitherComplex::Complex(parts) => *parts,
            EitherComplex::Py(complex) => [complex.real(), complex.imag()],
        }
    }
}
24 changes: 23 additions & 1 deletion src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyComplex;
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};

use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
Expand Down Expand Up @@ -226,6 +227,13 @@ pub(crate) fn infer_to_python_known(
}
PyList::new_bound(py, items).into_py(py)
}
ObType::Complex => {
    // The value is a Python `complex`, not a dict, so downcast to `PyComplex` and
    // read its parts directly. Errors from `set_item` are propagated, not dropped.
    let complex = value.downcast::<PyComplex>()?;
    let new_dict = PyDict::new_bound(py);
    new_dict.set_item("real", complex.real())?;
    new_dict.set_item("imag", complex.imag())?;
    new_dict.into_py(py)
}
ObType::Path => value.str()?.into_py(py),
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
ObType::Unknown => {
Expand Down Expand Up @@ -274,6 +282,13 @@ pub(crate) fn infer_to_python_known(
);
iter.into_py(py)
}
ObType::Complex => {
    // The value is a Python `complex`, not a dict, so downcast to `PyComplex` and
    // read its parts directly. Errors from `set_item` are propagated, not dropped.
    let complex = value.downcast::<PyComplex>()?;
    let new_dict = PyDict::new_bound(py);
    new_dict.set_item("real", complex.real())?;
    new_dict.set_item("imag", complex.imag())?;
    new_dict.into_py(py)
}
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
Expand Down Expand Up @@ -402,6 +417,13 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
ObType::None => serializer.serialize_none(),
ObType::Int | ObType::IntSubclass => serialize!(Int),
ObType::Bool => serialize!(bool),
ObType::Complex => {
let v = value.downcast::<PyComplex>().map_err(py_err_se_err)?;
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry(&"real", &v.real())?;
map.serialize_entry(&"imag", &v.imag())?;
map.end()
}
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>().map_err(py_err_se_err)?;
type_serializers::float::serialize_f64(v, serializer, extra.config.inf_nan_mode)
Expand Down Expand Up @@ -647,7 +669,7 @@ pub(crate) fn infer_json_key_known<'a>(
}
Ok(Cow::Owned(key_build.finish()))
}
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => {
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
}
ObType::Dataclass | ObType::PydanticSerializable => {
Expand Down
8 changes: 6 additions & 2 deletions src/serializers/ob_type.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
PyNone, PySet, PyString, PyTime, PyTuple, PyType,
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt,
PyIterator, PyList, PyNone, PySet, PyString, PyTime, PyTuple, PyType,
};
use pyo3::{intern, PyTypeInfo};

Expand Down Expand Up @@ -48,6 +48,7 @@ pub struct ObTypeLookup {
pattern_object: PyObject,
// uuid type
uuid_object: PyObject,
complex: usize,
}

static TYPE_LOOKUP: GILOnceCell<ObTypeLookup> = GILOnceCell::new();
Expand Down Expand Up @@ -101,6 +102,7 @@ impl ObTypeLookup {
.to_object(py),
pattern_object: py.import_bound("re").unwrap().getattr("Pattern").unwrap().to_object(py),
uuid_object: py.import_bound("uuid").unwrap().getattr("UUID").unwrap().to_object(py),
complex: PyComplex::type_object_raw(py) as usize,
}
}

Expand Down Expand Up @@ -171,6 +173,7 @@ impl ObTypeLookup {
// Compare against the cached `re.Pattern` type, not the path type.
ObType::Pattern => self.pattern_object.as_ptr() as usize == ob_type,
ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type,
ObType::Unknown => false,
ObType::Complex => self.complex == ob_type,
};

if ans {
Expand Down Expand Up @@ -425,6 +428,7 @@ pub enum ObType {
Uuid,
// unknown type
Unknown,
Complex,
}

impl PartialEq for ObType {
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ combined_serializer! {
Enum: super::type_serializers::enum_::EnumSerializer;
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
Tuple: super::type_serializers::tuple::TupleSerializer;
Complex: super::type_serializers::complex::ComplexSerializer;
}
}

Expand Down Expand Up @@ -251,6 +252,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
}
}
}
Expand Down
Loading
Loading