Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lax_str and lax_int support for enum values not inherited from str/int #1015

Merged
merged 9 commits into from
Oct 26, 2023
Merged
21 changes: 20 additions & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use super::datetime::{
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
EitherTime,
};
use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float,
str_as_int,
};
use super::{
py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
Expand Down Expand Up @@ -256,6 +259,8 @@ impl<'a> Input<'a> for PyAny {
|| self.is_instance(decimal_type.as_ref(py)).unwrap_or_default()
} {
Ok(self.str()?.into())
} else if let Some(enum_val) = maybe_as_enum(self) {
Ok(enum_val.str()?.into())
} else {
Err(ValError::new(ErrorTypeDefaults::StringType, self))
}
Expand Down Expand Up @@ -340,6 +345,8 @@ impl<'a> Input<'a> for PyAny {
decimal_as_int(self.py(), self, decimal)
} else if let Ok(float) = self.extract::<f64>() {
float_as_int(self, float)
} else if let Some(enum_val) = maybe_as_enum(self) {
Ok(EitherInt::Py(enum_val))
Comment on lines +348 to +349
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidhewitt I'd like to get your thoughts but I fear this would slow down all int parsing error cases right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like it. Any suggestions on a cheaper bypass check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adriangb I think I fixed the performance issue by using interned strings, so we avoid unnecessary allocations. Would you mind taking another look?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as performance goes I don't think maybe_as_enum is very expensive; apart from a one-time import of the enum.EnumMeta object the .get_type() and .is() methods are just pointer operations so should be extremely cheap.

} else {
Err(ValError::new(ErrorTypeDefaults::IntType, self))
}
Expand Down Expand Up @@ -759,6 +766,18 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult<Option<Cow<
}
}

/// Utility for extracting an enum value, if possible.
fn maybe_as_enum(v: &PyAny) -> Option<&PyAny> {
let py = v.py();
let enum_meta_object = get_enum_meta_object(py);
let meta_type = v.get_type().get_type();
if meta_type.is(&enum_meta_object) {
v.getattr(intern!(py, "value")).ok()
} else {
None
}
}

#[cfg(PyPy)]
static DICT_KEYS_TYPE: pyo3::once_cell::GILOnceCell<Py<PyType>> = pyo3::once_cell::GILOnceCell::new();

Expand Down
16 changes: 15 additions & 1 deletion src/input/shared.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
use num_bigint::BigInt;
use pyo3::{intern, PyAny, Python};
use pyo3::sync::GILOnceCell;
use pyo3::{intern, Py, PyAny, Python, ToPyObject};

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};

use super::parse_json::{JsonArray, JsonInput};
use super::{EitherFloat, EitherInt, Input};

static ENUM_META_OBJECT: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

pub fn get_enum_meta_object(py: Python) -> Py<PyAny> {
ENUM_META_OBJECT
.get_or_init(py, || {
py.import(intern!(py, "enum"))
.and_then(|enum_module| enum_module.getattr(intern!(py, "EnumMeta")))
.unwrap()
.to_object(py)
})
.clone()
}

pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> {
ValError::new(
ErrorType::JsonInvalid {
Expand Down
3 changes: 2 additions & 1 deletion src/serializers/ob_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ impl ObTypeLookup {
fn is_enum(&self, op_value: Option<&PyAny>, py_type: &PyType) -> bool {
// only test on the type itself, not base types
if op_value.is_some() {
let enum_meta_type = self.enum_object.as_ref(py_type.py()).get_type();
let meta_type = py_type.get_type();
meta_type.is(&self.enum_object)
meta_type.is(enum_meta_type)
} else {
false
}
Expand Down
13 changes: 13 additions & 0 deletions tests/validators/test_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,16 @@ def test_float_subclass() -> None:
v_lax = v.validate_python(FloatSubclass(1))
assert v_lax == 1
assert type(v_lax) == int


def test_int_subclass_plain_enum() -> None:
v = SchemaValidator({'type': 'int'})

from enum import Enum

class PlainEnum(Enum):
ONE = 1

v_lax = v.validate_python(PlainEnum.ONE)
assert v_lax == 1
assert type(v_lax) == int
15 changes: 15 additions & 0 deletions tests/validators/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,21 @@ def test_lax_subclass(FruitEnum, kwargs):
assert repr(p) == "'pear'"


@pytest.mark.parametrize('kwargs', [{}, {'to_lower': True}], ids=repr)
def test_lax_subclass_plain_enum(kwargs):
v = SchemaValidator(core_schema.str_schema(**kwargs))

from enum import Enum

class PlainEnum(Enum):
ONE = 'one'

p = v.validate_python(PlainEnum.ONE)
assert p == 'one'
assert type(p) is str
assert repr(p) == "'one'"


def test_subclass_preserved() -> None:
class StrSubclass(str):
pass
Expand Down