diff --git a/src/input/input_python.rs b/src/input/input_python.rs index fe688a487..33d7ca296 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -21,7 +21,10 @@ use super::datetime::{ float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; +use super::shared::{ + decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float, + str_as_int, +}; use super::{ py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs, @@ -256,6 +259,8 @@ impl<'a> Input<'a> for PyAny { || self.is_instance(decimal_type.as_ref(py)).unwrap_or_default() } { Ok(self.str()?.into()) + } else if let Some(enum_val) = maybe_as_enum(self) { + Ok(enum_val.str()?.into()) } else { Err(ValError::new(ErrorTypeDefaults::StringType, self)) } @@ -340,6 +345,8 @@ impl<'a> Input<'a> for PyAny { decimal_as_int(self.py(), self, decimal) } else if let Ok(float) = self.extract::() { float_as_int(self, float) + } else if let Some(enum_val) = maybe_as_enum(self) { + Ok(EitherInt::Py(enum_val)) } else { Err(ValError::new(ErrorTypeDefaults::IntType, self)) } @@ -759,6 +766,18 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult Option<&PyAny> { + let py = v.py(); + let enum_meta_object = get_enum_meta_object(py); + let meta_type = v.get_type().get_type(); + if meta_type.is(&enum_meta_object) { + v.getattr(intern!(py, "value")).ok() + } else { + None + } +} + #[cfg(PyPy)] static DICT_KEYS_TYPE: pyo3::once_cell::GILOnceCell> = pyo3::once_cell::GILOnceCell::new(); diff --git a/src/input/shared.rs b/src/input/shared.rs index 1a8e2b61c..105da4bcc 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -1,11 +1,25 @@ use num_bigint::BigInt; -use pyo3::{intern, PyAny, Python}; +use pyo3::sync::GILOnceCell; +use pyo3::{intern, Py, PyAny, Python, ToPyObject}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use super::parse_json::{JsonArray, JsonInput}; use super::{EitherFloat, EitherInt, Input}; +static ENUM_META_OBJECT: GILOnceCell> = GILOnceCell::new(); + +pub fn get_enum_meta_object(py: Python) -> Py { + ENUM_META_OBJECT + .get_or_init(py, || { + py.import(intern!(py, "enum")) + .and_then(|enum_module| enum_module.getattr(intern!(py, "EnumMeta"))) + .unwrap() + .to_object(py) + }) + .clone() +} + pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> { ValError::new( ErrorType::JsonInvalid { diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index fc491f618..109aed3bb 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -259,8 +259,9 @@ impl ObTypeLookup { fn is_enum(&self, op_value: Option<&PyAny>, py_type: &PyType) -> bool { // only test on the type itself, not base types if op_value.is_some() { + let enum_meta_type = self.enum_object.as_ref(py_type.py()).get_type(); let meta_type = py_type.get_type(); - meta_type.is(&self.enum_object) + meta_type.is(enum_meta_type) } else { false } diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 8d5850dc8..dedc2bd93 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -459,3 +459,16 @@ def test_float_subclass() -> None: v_lax = v.validate_python(FloatSubclass(1)) assert v_lax == 1 assert type(v_lax) == int + + +def test_int_subclass_plain_enum() -> None: + v = SchemaValidator({'type': 'int'}) + + from enum import Enum + + class PlainEnum(Enum): + ONE = 1 + + v_lax = v.validate_python(PlainEnum.ONE) + assert v_lax == 1 + assert type(v_lax) == int diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index cab6e5127..bc2102de2 100644 --- a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -249,6 +249,21 @@ def test_lax_subclass(FruitEnum, kwargs): assert repr(p) == "'pear'" +@pytest.mark.parametrize('kwargs', [{}, {'to_lower': True}], ids=repr) +def test_lax_subclass_plain_enum(kwargs): + v = SchemaValidator(core_schema.str_schema(**kwargs)) + + from enum import Enum + + class PlainEnum(Enum): + ONE = 'one' + + p = v.validate_python(PlainEnum.ONE) + assert p == 'one' + assert type(p) is str + assert repr(p) == "'one'" + + def test_subclass_preserved() -> None: class StrSubclass(str): pass