Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extract vs cast_as #2

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 36 additions & 36 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyList};

macro_rules! dict_get {
macro_rules! get_as {
($dict:ident, $key:expr, $type:ty) => {
match $dict.get_item($key) {
Some(t) => Some(<$type>::extract(t)?),
Expand Down Expand Up @@ -45,7 +45,7 @@ struct BooleanSchema {
impl SchemaType for BooleanSchema {
fn build(dict: &PyDict) -> PyResult<Self> {
Ok(BooleanSchema {
const_: dict_get!(dict, "const", bool),
const_: get_as!(dict, "const", bool),
})
}

Expand All @@ -70,13 +70,13 @@ struct IntegerSchema {
impl SchemaType for IntegerSchema {
fn build(dict: &PyDict) -> PyResult<Self> {
Ok(IntegerSchema {
enum_: dict_get!(dict, "enum", Vec<i64>),
const_: dict_get!(dict, "const", i64),
multiple_of: dict_get!(dict, "multiple_of", i64),
maximum: dict_get!(dict, "maximum", i64),
exclusive_maximum: dict_get!(dict, "exclusive_maximum", i64),
minimum: dict_get!(dict, "minimum", i64),
exclusive_minimum: dict_get!(dict, "exclusive_minimum", i64),
enum_: get_as!(dict, "enum", Vec<i64>),
const_: get_as!(dict, "const", i64),
multiple_of: get_as!(dict, "multiple_of", i64),
maximum: get_as!(dict, "maximum", i64),
exclusive_maximum: get_as!(dict, "exclusive_maximum", i64),
minimum: get_as!(dict, "minimum", i64),
exclusive_minimum: get_as!(dict, "exclusive_minimum", i64),
})
}

Expand All @@ -101,13 +101,13 @@ struct NumberSchema {
impl SchemaType for NumberSchema {
fn build(dict: &PyDict) -> PyResult<Self> {
Ok(NumberSchema {
enum_: dict_get!(dict, "enum", Vec<f64>),
const_: dict_get!(dict, "const", f64),
multiple_of: dict_get!(dict, "multiple_of", f64),
minimum: dict_get!(dict, "minimum", f64),
exclusive_minimum: dict_get!(dict, "exclusive_minimum", f64),
maximum: dict_get!(dict, "maximum", f64),
exclusive_maximum: dict_get!(dict, "exclusive_maximum", f64),
enum_: get_as!(dict, "enum", Vec<f64>),
const_: get_as!(dict, "const", f64),
multiple_of: get_as!(dict, "multiple_of", f64),
minimum: get_as!(dict, "minimum", f64),
exclusive_minimum: get_as!(dict, "exclusive_minimum", f64),
maximum: get_as!(dict, "maximum", f64),
exclusive_maximum: get_as!(dict, "exclusive_maximum", f64),
})
}

Expand All @@ -130,11 +130,11 @@ struct StringSchema {
impl SchemaType for StringSchema {
fn build(dict: &PyDict) -> PyResult<Self> {
Ok(StringSchema {
enum_: dict_get!(dict, "enum", Vec<String>),
const_: dict_get!(dict, "const", String),
pattern: dict_get!(dict, "pattern", String),
min_length: dict_get!(dict, "min_length", usize),
max_length: dict_get!(dict, "max_length", usize),
enum_: get_as!(dict, "enum", Vec<String>),
const_: get_as!(dict, "const", String),
pattern: get_as!(dict, "pattern", String),
min_length: get_as!(dict, "min_length", usize),
max_length: get_as!(dict, "max_length", usize),
})
}

Expand Down Expand Up @@ -180,24 +180,24 @@ struct ArraySchema {
impl SchemaType for ArraySchema {
fn build(dict: &PyDict) -> PyResult<Self> {
Ok(ArraySchema {
enum_: dict_get!(dict, "enum", Vec<Schema>),
enum_: get_as!(dict, "enum", Vec<Schema>),
items: match dict.get_item("items") {
Some(t) => Some(Box::new(Schema::extract(t)?)),
Some(t) => Some(Box::new(t.extract()?)),
None => None,
},
prefix_items: dict_get!(dict, "prefix_items", Vec<Schema>),
prefix_items: get_as!(dict, "prefix_items", Vec<Schema>),
contains: match dict.get_item("contains") {
Some(t) => Some(Box::new(Schema::extract(t)?)),
Some(t) => Some(Box::new(t.extract()?)),
None => None,
},
unique_items: match dict.get_item("unique_items") {
Some(t) => bool::extract(t)?,
None => false,
},
min_items: dict_get!(dict, "min_items", usize),
max_items: dict_get!(dict, "max_items", usize),
min_contains: dict_get!(dict, "min_contains", usize),
max_contains: dict_get!(dict, "max_contains", usize),
min_items: get_as!(dict, "min_items", usize),
max_items: get_as!(dict, "max_items", usize),
min_contains: get_as!(dict, "min_contains", usize),
max_contains: get_as!(dict, "max_contains", usize),
})
}

Expand Down Expand Up @@ -264,7 +264,7 @@ impl SchemaType for ObjectSchema {
properties.push(SchemaProperty {
key: key.to_string(),
required: required.contains(&key.to_string()),
schema: Schema::extract(value)?,
schema: value.extract()?,
validator,
});
}
Expand All @@ -273,11 +273,11 @@ impl SchemaType for ObjectSchema {
None => Vec::new(),
},
additional_properties: match dict.get_item("additional_properties") {
Some(t) => Some(Box::new(Schema::extract(t)?)),
Some(t) => Some(Box::new(t.extract()?)),
None => None,
},
min_properties: dict_get!(dict, "min_properties", usize),
max_properties: dict_get!(dict, "max_properties", usize),
min_properties: get_as!(dict, "min_properties", usize),
max_properties: get_as!(dict, "max_properties", usize),
})
}

Expand Down Expand Up @@ -320,7 +320,7 @@ enum Schema {

impl SchemaType for Schema {
fn build(dict: &PyDict) -> PyResult<Self> {
let type_ = match dict_get!(dict, "type", String) {
let type_ = match get_as!(dict, "type", String) {
Some(type_) => type_,
None => {
return Err(PyKeyError::new_err("'type' is required"));
Expand Down Expand Up @@ -367,8 +367,8 @@ pub struct SchemaValidator {
#[pymethods]
impl SchemaValidator {
#[new]
fn py_new(py: Python, schema: PyObject) -> PyResult<Self> {
let schema: Schema = schema.extract(py)?;
fn py_new(schema_dict: &PyDict) -> PyResult<Self> {
let schema = Schema::build(schema_dict)?;
Ok(Self { schema })
}

Expand Down
10 changes: 5 additions & 5 deletions src/validators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use pyo3::types::{PyBytes, PyDict, PyInt, PyList, PyString};

#[pyfunction]
pub fn validate_str(v: &PyAny) -> PyResult<String> {
if let Ok(str) = v.cast_as::<PyString>() {
if let Ok(str) = v.extract::<&PyString>() {
str.extract()
} else if let Ok(bytes) = v.cast_as::<PyBytes>() {
} else if let Ok(bytes) = v.extract::<&PyBytes>() {
Ok(std::str::from_utf8(bytes.as_bytes())?.to_string())
} else if let Ok(int) = v.cast_as::<PyInt>() {
} else if let Ok(int) = v.extract::<&PyInt>() {
Ok(i64::extract(int)?.to_string())
} else if let Ok(float) = f64::extract(v) {
// don't cast_as here so Decimals are covered - internally f64:extract uses PyFloat_AsDouble
Expand Down Expand Up @@ -101,9 +101,9 @@ pub fn validate_str_recursive<'py>(
to_lower: bool,
to_upper: bool,
) -> PyResult<&'py PyAny> {
if let Ok(list) = value.cast_as::<PyList>() {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can reproduce this locally and it turns out that exactly these two calls to extract make the difference. (The other ones above do not.)

(It would be nice to loose the second here to keep this PR focused on the extract-vs-cast-as issue.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks so much for looking into it.

Is there anything I can do to help here?

On balance, to avoid profiling every change, I guess I'd prefer to use cast_as where possible - I feels more semantically correct and AFAIK it's never slower (???). So I'll just close this PR.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have reproduced the problem in an isolated PyO3 benchmark now. So I think you can close this and also revoke my access to the repository. Everything else will need to be handled in the PyO3 repository.

if let Ok(list) = value.extract::<&PyList>() {
validate_str_list(py, list, min_length, max_length, strip_whitespace, to_lower, to_upper)
} else if let Ok(dict) = value.cast_as::<PyDict>() {
} else if let Ok(dict) = value.extract::<&PyDict>() {
validate_str_dict(py, dict, min_length, max_length, strip_whitespace, to_lower, to_upper)
} else {
validate_str_full(py, value, min_length, max_length, strip_whitespace, to_lower, to_upper)
Expand Down