From 90dfb74d321d676f72acd2402f20869ba93a21f7 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 31 Jul 2024 12:56:17 +0100 Subject: [PATCH] clean up some string handling cases --- src/input/input_python.rs | 40 ++++++++++++++++++++++----------------- src/serializers/config.rs | 2 +- src/serializers/fields.rs | 14 +++++++------- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 3de712272..4245029ba 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::str::from_utf8; use pyo3::intern; @@ -144,12 +143,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), } } else if let Ok(py_byte_array) = self.downcast::() { - // Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated, - // and we immediately copy the bytes into a new Python string - match from_utf8(unsafe { py_byte_array.as_bytes() }) { - // Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the - // final output needs to be Python anyway. - Ok(s) => Ok(PyString::new_bound(self.py(), s).into()), + match bytearray_to_str(py_byte_array) { + Ok(py_str) => Ok(py_str.into()), Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), } } else if coerce_numbers_to_str && !self.is_exact_instance_of::() && { @@ -204,8 +199,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { } if !strict { - if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? { - return str_as_bool(self, &cow_str).map(ValidationMatch::lax); + if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? { + return str_as_bool(self, s).map(ValidationMatch::lax); } else if let Some(int) = extract_i64(self) { return int_as_bool(self, int).map(ValidationMatch::lax); } else if let Ok(float) = self.extract::() { @@ -241,8 +236,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { 'lax: { if !strict { - return if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? { - str_as_int(self, &cow_str) + return if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? { + str_as_int(self, s) } else if self.is_exact_instance_of::() { float_as_int(self, self.extract::()?) } else if let Ok(decimal) = self.strict_decimal(self.py()) { @@ -283,9 +278,9 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { } if !strict { - if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? { + if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? { // checking for bytes and string is fast, so do this before isinstance(float) - return str_as_float(self, &cow_str).map(ValidationMatch::lax); + return str_as_float(self, s).map(ValidationMatch::lax); } } @@ -630,13 +625,12 @@ fn from_attributes_applicable(obj: &Bound<'_, PyAny>) -> bool { } /// Utility for extracting a string from a PyAny, if possible. -fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> ValResult>> { +fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> ValResult> { if let Ok(py_string) = v.downcast::() { - let str = py_string_str(py_string)?; - Ok(Some(Cow::Borrowed(str))) + py_string_str(py_string).map(Some) } else if let Ok(bytes) = v.downcast::() { match from_utf8(bytes.as_bytes()) { - Ok(s) => Ok(Some(Cow::Owned(s.to_string()))), + Ok(s) => Ok(Some(s)), Err(_) => Err(ValError::new(unicode_error, v)), } } else { @@ -644,6 +638,18 @@ fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> Val } } +/// Decode a Python bytearray to a Python string. +/// +/// Using Python's built-in machinery for this should be efficient and avoids questions around +/// safety of concurrent mutation of the bytearray (by leaving that to the Python interpreter). +fn bytearray_to_str<'py>(bytearray: &Bound<'py, PyByteArray>) -> PyResult> { + let py = bytearray.py(); + let py_string = bytearray + .call_method1(intern!(py, "decode"), (intern!(py, "utf-8"),))? + .downcast_into()?; + Ok(py_string) +} + /// Utility for extracting an enum value, if possible. fn maybe_as_enum<'py>(v: &Bound<'py, PyAny>) -> Option> { let py = v.py(); diff --git a/src/serializers/config.rs b/src/serializers/config.rs index 5421ed920..ad571a529 100644 --- a/src/serializers/config.rs +++ b/src/serializers/config.rs @@ -77,7 +77,7 @@ macro_rules! serialization_mode { return Ok(Self::default()); }; let raw_mode = config_dict.get_as::>(intern!(config_dict.py(), $config_key))?; - raw_mode.map_or_else(|| Ok(Self::default()), |raw| Self::from_str(&raw.to_cow()?)) + raw_mode.map_or_else(|| Ok(Self::default()), |raw| Self::from_str(raw.to_str()?)) } } diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index bf72ec218..f4f8910eb 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -159,7 +159,7 @@ impl GeneralFieldsSerializer { for result in main_iter { let (key, value) = result?; let key_str = key_str(&key)?; - let op_field = self.fields.get(key_str.as_ref()); + let op_field = self.fields.get(key_str); if extra.exclude_none && value.is_none() { if let Some(field) = op_field { if field.required { @@ -169,7 +169,7 @@ impl GeneralFieldsSerializer { continue; } let field_extra = Extra { - field_name: Some(&key_str), + field_name: Some(key_str), ..extra }; if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? { @@ -236,13 +236,13 @@ impl GeneralFieldsSerializer { } let key_str = key_str(&key).map_err(py_err_se_err)?; let field_extra = Extra { - field_name: Some(&key_str), + field_name: Some(key_str), ..extra }; let filter = self.filter.key_filter(&key, include, exclude).map_err(py_err_se_err)?; if let Some((next_include, next_exclude)) = filter { - if let Some(field) = self.fields.get(key_str.as_ref()) { + if let Some(field) = self.fields.get(key_str) { if let Some(ref serializer) = field.serializer { if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? { let s = PydanticSerializer::new( @@ -252,7 +252,7 @@ impl GeneralFieldsSerializer { next_exclude.as_ref(), &field_extra, ); - let output_key = field.get_key_json(&key_str, &field_extra); + let output_key = field.get_key_json(key_str, &field_extra); map.serialize_entry(&output_key, &s)?; } } @@ -446,8 +446,8 @@ impl TypeSerializer for GeneralFieldsSerializer { } } -fn key_str<'a>(key: &'a Bound<'_, PyAny>) -> PyResult> { - key.downcast::()?.to_cow() +fn key_str<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<&'a str> { + key.downcast::()?.to_str() } fn dict_items<'py>(