Skip to content

Commit

Permalink
Int extraction (#1155)
Browse files Browse the repository at this point in the history
samuelcolvin authored Jan 15, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 5d3aa43 commit d7cf72d
Showing 11 changed files with 64 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -36,3 +36,6 @@ node_modules/
/foobar.py
/python/pydantic_core/*.so
/src/self_schema.py

# samply
/profile.json
2 changes: 1 addition & 1 deletion src/errors/types.rs
Original file line number Diff line number Diff line change
@@ -786,7 +786,7 @@ impl From<Int> for Number {

impl FromPyObject<'_> for Number {
fn extract(obj: &PyAny) -> PyResult<Self> {
if let Ok(int) = extract_i64(obj) {
if let Some(int) = extract_i64(obj) {
Ok(Number::Int(int))
} else if let Ok(float) = obj.extract::<f64>() {
Ok(Number::Float(float))
2 changes: 1 addition & 1 deletion src/errors/value_exception.rs
Original file line number Diff line number Diff line change
@@ -122,7 +122,7 @@ impl PydanticCustomError {
let key: &PyString = key.downcast()?;
if let Ok(py_str) = value.downcast::<PyString>() {
message = message.replace(&format!("{{{}}}", key.to_str()?), py_str.to_str()?);
} else if let Ok(value_int) = extract_i64(value) {
} else if let Some(value_int) = extract_i64(value) {
message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string());
} else {
// fallback for anything else just in case
10 changes: 5 additions & 5 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
@@ -96,7 +96,7 @@ impl AsLocItem for PyAny {
fn as_loc_item(&self) -> LocItem {
if let Ok(py_str) = self.downcast::<PyString>() {
py_str.to_string_lossy().as_ref().into()
} else if let Ok(key_int) = extract_i64(self) {
} else if let Some(key_int) = extract_i64(self) {
key_int.into()
} else {
safe_repr(self).to_string().into()
@@ -292,7 +292,7 @@ impl<'a> Input<'a> for PyAny {
if !strict {
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
return str_as_bool(self, &cow_str).map(ValidationMatch::lax);
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
return int_as_bool(self, int).map(ValidationMatch::lax);
} else if let Ok(float) = self.extract::<f64>() {
if let Ok(int) = float_as_int(self, float) {
@@ -635,7 +635,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if PyBool::is_exact_type_of(self) {
Err(ValError::new(ErrorTypeDefaults::TimeType, self))
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
int_as_time(self, int, 0)
} else if let Ok(float) = self.extract::<f64>() {
float_as_time(self, float)
@@ -669,7 +669,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if PyBool::is_exact_type_of(self) {
Err(ValError::new(ErrorTypeDefaults::DatetimeType, self))
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
int_as_datetime(self, int, 0)
} else if let Ok(float) = self.extract::<f64>() {
float_as_datetime(self, float)
@@ -706,7 +706,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior)
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
Ok(int_as_duration(self, int)?.into())
} else if let Ok(float) = self.extract::<f64>() {
Ok(float_as_duration(self, float)?.into())
6 changes: 3 additions & 3 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ use pyo3::PyTypeInfo;
use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult};
use crate::tools::py_err;
use crate::tools::{extract_i64, py_err};
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};

use super::input_string::StringMapping;
@@ -863,7 +863,7 @@ pub enum EitherInt<'a> {
impl<'a> EitherInt<'a> {
pub fn upcast(py_any: &'a PyAny) -> ValResult<Self> {
// Safety: we know that py_any is a python int
if let Ok(int_64) = py_any.extract::<i64>() {
if let Some(int_64) = extract_i64(py_any) {
Ok(Self::I64(int_64))
} else {
let big_int: BigInt = py_any.extract()?;
@@ -1021,7 +1021,7 @@ impl<'a> Rem for &'a Int {

impl<'a> FromPyObject<'a> for Int {
fn extract(obj: &'a PyAny) -> PyResult<Self> {
if let Ok(i) = obj.extract::<i64>() {
if let Some(i) = extract_i64(obj) {
Ok(Int::I64(i))
} else if let Ok(b) = obj.extract::<BigInt>() {
Ok(Int::Big(b))
2 changes: 1 addition & 1 deletion src/lookup_key.rs
Original file line number Diff line number Diff line change
@@ -429,7 +429,7 @@ impl PathItem {
} else {
Ok(Self::Pos(usize_key))
}
} else if let Ok(int_key) = extract_i64(obj) {
} else if let Some(int_key) = extract_i64(obj) {
if index == 0 {
py_err!(PyTypeError; "The first item in an alias path should be a string")
} else {
5 changes: 4 additions & 1 deletion src/serializers/infer.rs
Original file line number Diff line number Diff line change
@@ -123,7 +123,10 @@ pub(crate) fn infer_to_python_known(
// `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types
ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py),
// have to do this to make sure subclasses of for example str are upcast to `str`
ObType::IntSubclass => extract_i64(value)?.into_py(py),
ObType::IntSubclass => match extract_i64(value) {
Some(v) => v.into_py(py),
None => return py_err!(PyTypeError; "expected int, got {}", safe_repr(value)),
},
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>()?;
if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null {
4 changes: 2 additions & 2 deletions src/serializers/type_serializers/literal.rs
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ impl BuildSerializer for LiteralSerializer {
repr_args.push(item.repr()?.extract()?);
if let Ok(bool) = item.downcast::<PyBool>() {
expected_py.append(bool)?;
} else if let Ok(int) = extract_i64(item) {
} else if let Some(int) = extract_i64(item) {
expected_int.insert(int);
} else if let Ok(py_str) = item.downcast::<PyString>() {
expected_str.insert(py_str.to_str()?.to_string());
@@ -79,7 +79,7 @@ impl LiteralSerializer {
fn check<'a>(&self, value: &'a PyAny, extra: &Extra) -> PyResult<OutputValue<'a>> {
if extra.check.enabled() {
if !self.expected_int.is_empty() && !PyBool::is_type_of(value) {
if let Ok(int) = extract_i64(value) {
if let Some(int) = extract_i64(value) {
if self.expected_int.contains(&int) {
return Ok(OutputValue::OkInt(int));
}
28 changes: 21 additions & 7 deletions src/tools.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::borrow::Cow;

use pyo3::exceptions::{PyKeyError, PyTypeError};
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyInt, PyString};
use pyo3::{intern, FromPyObject, PyTypeInfo};
use pyo3::types::{PyDict, PyString};
use pyo3::{ffi, intern, FromPyObject};

pub trait SchemaDict<'py> {
fn get_as<T>(&'py self, key: &PyString) -> PyResult<Option<T>>
@@ -99,10 +99,24 @@ pub fn safe_repr(v: &PyAny) -> Cow<str> {
}
}

pub fn extract_i64(v: &PyAny) -> PyResult<i64> {
if PyInt::is_type_of(v) {
v.extract()
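/// Extract an `i64` from a Python object by calling `PyLong_AsLong` directly, returning `None`
/// (and clearing the pending Python error) when the conversion fails. This is a fast path, see
/// https://github.com/PyO3/pyo3/pull/3742#discussion_r1451763928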
#[cfg(not(any(target_pointer_width = "32", windows, PyPy)))]
pub fn extract_i64(obj: &PyAny) -> Option<i64> {
let val = unsafe { ffi::PyLong_AsLong(obj.as_ptr()) };
if val == -1 && PyErr::occurred(obj.py()) {
unsafe { ffi::PyErr_Clear() };
None
} else {
py_err!(PyTypeError; "expected int, got {}", safe_repr(v))
Some(val)
}
}

/// Fallback `i64` extraction, compiled only for 32-bit targets, Windows and PyPy.
///
/// Returns `Some(value)` only when `v` is an instance of Python `int` and its value
/// extracts as an `i64`. Any other object yields `None`, and an extraction error on an
/// `int` (presumably a value outside the `i64` range — confirm) is discarded and also
/// reported as `None`.
#[cfg(any(target_pointer_width = "32", windows, PyPy))]
pub fn extract_i64(v: &PyAny) -> Option<i64> {
    // Guard: anything that is not a Python `int` (or subclass) is rejected up front.
    if !v.is_instance_of::<pyo3::types::PyInt>() {
        return None;
    }
    // The error itself is not needed by callers, only whether extraction succeeded.
    v.extract::<i64>().ok()
}
12 changes: 12 additions & 0 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1232,6 +1232,18 @@ def test_strict_int(benchmark):
benchmark(v.validate_python, 42)


@pytest.mark.benchmark(group='strict_int')
def test_strict_int_fails(benchmark):
    """Benchmark the failure path of a strict int validator given a non-int input (an empty tuple)."""
    validator = SchemaValidator(core_schema.int_schema(strict=True))

    def validate_rejected_input():
        # The ValidationError is swallowed so that the benchmark can time the failing call.
        try:
            validator.validate_python(())
        except ValidationError:
            pass

    benchmark(validate_rejected_input)


@pytest.mark.benchmark(group='int_range')
def test_int_range(benchmark):
v = SchemaValidator(core_schema.int_schema(gt=0, lt=100))
15 changes: 11 additions & 4 deletions tests/validators/test_int.py
Original file line number Diff line number Diff line change
@@ -29,6 +29,8 @@
('123456789123456.00001', Err('Input should be a valid integer, unable to parse string as an integer')),
(int(1e10), int(1e10)),
(i64_max, i64_max),
(i64_max + 1, i64_max + 1),
(i64_max * 2, i64_max * 2),
pytest.param(
12.5,
Err('Input should be a valid integer, got a number with a fractional part [type=int_from_float'),
@@ -106,10 +108,15 @@ def test_int(input_value, expected):
@pytest.mark.parametrize(
'input_value,expected',
[
(Decimal('1'), 1),
(Decimal('1.0'), 1),
(i64_max, i64_max),
(i64_max + 1, i64_max + 1),
pytest.param(Decimal('1'), 1),
pytest.param(Decimal('1.0'), 1),
pytest.param(i64_max, i64_max, id='i64_max'),
pytest.param(i64_max + 1, i64_max + 1, id='i64_max+1'),
pytest.param(
-1,
Err('Input should be greater than 0 [type=greater_than, input_value=-1, input_type=int]'),
id='-1',
),
(
-i64_max + 1,
Err('Input should be greater than 0 [type=greater_than, input_value=-9223372036854775806, input_type=int]'),

0 comments on commit d7cf72d

Please sign in to comment.