Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up some string handling cases #1381

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::str::from_utf8;

use pyo3::intern;
Expand Down Expand Up @@ -144,12 +143,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)),
}
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
// Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated,
// and we immediately copy the bytes into a new Python string
match from_utf8(unsafe { py_byte_array.as_bytes() }) {
// Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the
// final output needs to be Python anyway.
Ok(s) => Ok(PyString::new_bound(self.py(), s).into()),
match bytearray_to_str(py_byte_array) {
Ok(py_str) => Ok(py_str.into()),
Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)),
}
} else if coerce_numbers_to_str && !self.is_exact_instance_of::<PyBool>() && {
Expand Down Expand Up @@ -204,8 +199,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
}

if !strict {
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
return str_as_bool(self, &cow_str).map(ValidationMatch::lax);
if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
return str_as_bool(self, s).map(ValidationMatch::lax);
} else if let Some(int) = extract_i64(self) {
return int_as_bool(self, int).map(ValidationMatch::lax);
} else if let Ok(float) = self.extract::<f64>() {
Expand Down Expand Up @@ -241,8 +236,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {

'lax: {
if !strict {
return if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? {
str_as_int(self, &cow_str)
return if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? {
str_as_int(self, s)
} else if self.is_exact_instance_of::<PyFloat>() {
float_as_int(self, self.extract::<f64>()?)
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
Expand Down Expand Up @@ -283,9 +278,9 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
}

if !strict {
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? {
if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? {
// checking for bytes and string is fast, so do this before isinstance(float)
return str_as_float(self, &cow_str).map(ValidationMatch::lax);
return str_as_float(self, s).map(ValidationMatch::lax);
}
}

Expand Down Expand Up @@ -630,20 +625,31 @@ fn from_attributes_applicable(obj: &Bound<'_, PyAny>) -> bool {
}

/// Utility for extracting a string from a PyAny, if possible.
fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> ValResult<Option<Cow<'a, str>>> {
fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> ValResult<Option<&'a str>> {
if let Ok(py_string) = v.downcast::<PyString>() {
let str = py_string_str(py_string)?;
Ok(Some(Cow::Borrowed(str)))
py_string_str(py_string).map(Some)
} else if let Ok(bytes) = v.downcast::<PyBytes>() {
match from_utf8(bytes.as_bytes()) {
Ok(s) => Ok(Some(Cow::Owned(s.to_string()))),
Ok(s) => Ok(Some(s)),
Err(_) => Err(ValError::new(unicode_error, v)),
}
} else {
Ok(None)
}
}

/// Convert a Python `bytearray` into a Python `str` by decoding it as UTF-8.
///
/// The decoding is delegated to the object's own `decode` method, called with the
/// single argument `"utf-8"`, so the bytes are never borrowed on the Rust side; any
/// question of the bytearray being mutated while it is read is left to the Python
/// interpreter.
///
/// # Errors
///
/// Returns the `PyErr` produced by the `decode` call if it fails (presumably a
/// `UnicodeDecodeError` for invalid UTF-8), or the error from the downcast if the
/// value returned by `decode` is not a `PyString`.
fn bytearray_to_str<'py>(bytearray: &Bound<'py, PyByteArray>) -> PyResult<Bound<'py, PyString>> {
    let py = bytearray.py();
    // Both the method name and the encoding name are interned Python strings.
    let decoded = bytearray.call_method1(intern!(py, "decode"), (intern!(py, "utf-8"),))?;
    // `decode` yields a generic object; narrow it to a string, converting the downcast
    // error into a `PyErr` via `?`.
    let as_py_str = decoded.downcast_into::<PyString>()?;
    Ok(as_py_str)
}

/// Utility for extracting an enum value, if possible.
fn maybe_as_enum<'py>(v: &Bound<'py, PyAny>) -> Option<Bound<'py, PyAny>> {
let py = v.py();
Expand Down
2 changes: 1 addition & 1 deletion src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ macro_rules! serialization_mode {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<Bound<'_, PyString>>(intern!(config_dict.py(), $config_key))?;
raw_mode.map_or_else(|| Ok(Self::default()), |raw| Self::from_str(&raw.to_cow()?))
raw_mode.map_or_else(|| Ok(Self::default()), |raw| Self::from_str(raw.to_str()?))
}
}

Expand Down
14 changes: 7 additions & 7 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl GeneralFieldsSerializer {
for result in main_iter {
let (key, value) = result?;
let key_str = key_str(&key)?;
let op_field = self.fields.get(key_str.as_ref());
let op_field = self.fields.get(key_str);
if extra.exclude_none && value.is_none() {
if let Some(field) = op_field {
if field.required {
Expand All @@ -169,7 +169,7 @@ impl GeneralFieldsSerializer {
continue;
}
let field_extra = Extra {
field_name: Some(&key_str),
field_name: Some(key_str),
..extra
};
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
Expand Down Expand Up @@ -236,13 +236,13 @@ impl GeneralFieldsSerializer {
}
let key_str = key_str(&key).map_err(py_err_se_err)?;
let field_extra = Extra {
field_name: Some(&key_str),
field_name: Some(key_str),
..extra
};

let filter = self.filter.key_filter(&key, include, exclude).map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = filter {
if let Some(field) = self.fields.get(key_str.as_ref()) {
if let Some(field) = self.fields.get(key_str) {
if let Some(ref serializer) = field.serializer {
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
let s = PydanticSerializer::new(
Expand All @@ -252,7 +252,7 @@ impl GeneralFieldsSerializer {
next_exclude.as_ref(),
&field_extra,
);
let output_key = field.get_key_json(&key_str, &field_extra);
let output_key = field.get_key_json(key_str, &field_extra);
map.serialize_entry(&output_key, &s)?;
}
}
Expand Down Expand Up @@ -446,8 +446,8 @@ impl TypeSerializer for GeneralFieldsSerializer {
}
}

fn key_str<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<Cow<'a, str>> {
key.downcast::<PyString>()?.to_cow()
fn key_str<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<&'a str> {
key.downcast::<PyString>()?.to_str()
}

fn dict_items<'py>(
Expand Down
Loading