Skip to content

Commit 57e6991

Browse files
authored
Validate bytes based on ser_json_bytes (#1308)
1 parent 40b8a94 commit 57e6991

19 files changed

+223
-22
lines changed

Cargo.lock

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ num-bigint = "0.4.6"
4747
python3-dll-a = "0.2.10"
4848
uuid = "1.9.1"
4949
jiter = { version = "0.5", features = ["python"] }
50+
hex = "0.4.3"
5051

5152
[lib]
5253
name = "_pydantic_core"

python/pydantic_core/_pydantic_core.pyi

+4-4
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def to_json(
352352
exclude_none: bool = False,
353353
round_trip: bool = False,
354354
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
355-
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
355+
bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8',
356356
inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants',
357357
serialize_unknown: bool = False,
358358
fallback: Callable[[Any], Any] | None = None,
@@ -373,7 +373,7 @@ def to_json(
373373
exclude_none: Whether to exclude fields that have a value of `None`.
374374
round_trip: Whether to enable serialization and validation round-trip support.
375375
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
376-
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
376+
bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`.
377377
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`.
378378
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
379379
`"<Unserializable {value_type} object>"` will be used.
@@ -427,7 +427,7 @@ def to_jsonable_python(
427427
exclude_none: bool = False,
428428
round_trip: bool = False,
429429
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
430-
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
430+
bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8',
431431
inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants',
432432
serialize_unknown: bool = False,
433433
fallback: Callable[[Any], Any] | None = None,
@@ -448,7 +448,7 @@ def to_jsonable_python(
448448
exclude_none: Whether to exclude fields that have a value of `None`.
449449
round_trip: Whether to enable serialization and validation round-trip support.
450450
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
451-
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
451+
bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`.
452452
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`.
453453
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
454454
`"<Unserializable {value_type} object>"` will be used.

python/pydantic_core/core_schema.py

+3
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class CoreConfig(TypedDict, total=False):
7070
ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'.
7171
ser_json_inf_nan: The serialization option for infinity and NaN values
7272
in float fields. Default is 'null'.
73+
val_json_bytes: The validation option for `bytes` values, complementing ser_json_bytes. Default is 'utf8'.
7374
hide_input_in_errors: Whether to hide input data from `ValidationError` representation.
7475
validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError.
7576
Requires exceptiongroup backport pre Python 3.11.
@@ -107,6 +108,7 @@ class CoreConfig(TypedDict, total=False):
107108
ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
108109
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
109110
ser_json_inf_nan: Literal['null', 'constants', 'strings'] # default: 'null'
111+
val_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
110112
# used to hide input data from ValidationError repr
111113
hide_input_in_errors: bool
112114
validation_error_cause: bool # default: False
@@ -3904,6 +3906,7 @@ def definition_reference_schema(
39043906
'bytes_type',
39053907
'bytes_too_short',
39063908
'bytes_too_long',
3909+
'bytes_invalid_encoding',
39073910
'value_error',
39083911
'assertion_error',
39093912
'literal_error',

src/errors/types.rs

+10
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,10 @@ error_types! {
290290
BytesTooLong {
291291
max_length: {ctx_type: usize, ctx_fn: field_from_context},
292292
},
293+
BytesInvalidEncoding {
294+
encoding: {ctx_type: String, ctx_fn: field_from_context},
295+
encoding_error: {ctx_type: String, ctx_fn: field_from_context},
296+
},
293297
// ---------------------
294298
// python errors from functions
295299
ValueError {
@@ -515,6 +519,7 @@ impl ErrorType {
515519
Self::BytesType {..} => "Input should be a valid bytes",
516520
Self::BytesTooShort {..} => "Data should have at least {min_length} byte{expected_plural}",
517521
Self::BytesTooLong {..} => "Data should have at most {max_length} byte{expected_plural}",
522+
Self::BytesInvalidEncoding { .. } => "Data should be valid {encoding}: {encoding_error}",
518523
Self::ValueError {..} => "Value error, {error}",
519524
Self::AssertionError {..} => "Assertion failed, {error}",
520525
Self::CustomError {..} => "", // custom errors are handled separately
@@ -664,6 +669,11 @@ impl ErrorType {
664669
let expected_plural = plural_s(*max_length);
665670
to_string_render!(tmpl, max_length, expected_plural)
666671
}
672+
Self::BytesInvalidEncoding {
673+
encoding,
674+
encoding_error,
675+
..
676+
} => render!(tmpl, encoding, encoding_error),
667677
Self::ValueError { error, .. } => {
668678
let error = &error
669679
.as_ref()

src/input/input_abstract.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use pyo3::{intern, prelude::*};
77
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
88
use crate::lookup_key::{LookupKey, LookupPath};
99
use crate::tools::py_err;
10+
use crate::validators::ValBytesMode;
1011

1112
use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
1213
use super::return_enums::{EitherBytes, EitherInt, EitherString};
@@ -71,7 +72,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
7172

7273
fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;
7374

74-
fn validate_bytes<'a>(&'a self, strict: bool) -> ValMatch<EitherBytes<'a, 'py>>;
75+
fn validate_bytes<'a>(&'a self, strict: bool, mode: ValBytesMode) -> ValMatch<EitherBytes<'a, 'py>>;
7576

7677
fn validate_bool(&self, strict: bool) -> ValMatch<bool>;
7778

src/input/input_json.rs

+19-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use strum::EnumMessage;
1010
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
1111
use crate::lookup_key::{LookupKey, LookupPath};
1212
use crate::validators::decimal::create_decimal;
13+
use crate::validators::ValBytesMode;
1314

1415
use super::datetime::{
1516
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
@@ -106,9 +107,16 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
106107
}
107108
}
108109

109-
fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
110+
fn validate_bytes<'a>(
111+
&'a self,
112+
_strict: bool,
113+
mode: ValBytesMode,
114+
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
110115
match self {
111-
JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_bytes().into())),
116+
JsonValue::Str(s) => match mode.deserialize_string(s) {
117+
Ok(b) => Ok(ValidationMatch::strict(b)),
118+
Err(e) => Err(ValError::new(e, self)),
119+
},
112120
_ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
113121
}
114122
}
@@ -342,8 +350,15 @@ impl<'py> Input<'py> for str {
342350
Ok(ValidationMatch::strict(self.into()))
343351
}
344352

345-
fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
346-
Ok(ValidationMatch::strict(self.as_bytes().into()))
353+
fn validate_bytes<'a>(
354+
&'a self,
355+
_strict: bool,
356+
mode: ValBytesMode,
357+
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
358+
match mode.deserialize_string(self) {
359+
Ok(b) => Ok(ValidationMatch::strict(b)),
360+
Err(e) => Err(ValError::new(e, self)),
361+
}
347362
}
348363

349364
fn validate_bool(&self, _strict: bool) -> ValResult<ValidationMatch<bool>> {

src/input/input_python.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError,
1717
use crate::tools::{extract_i64, safe_repr};
1818
use crate::validators::decimal::{create_decimal, get_decimal_type};
1919
use crate::validators::Exactness;
20+
use crate::validators::ValBytesMode;
2021
use crate::ArgsKwargs;
2122

2223
use super::datetime::{
@@ -174,7 +175,11 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
174175
Err(ValError::new(ErrorTypeDefaults::StringType, self))
175176
}
176177

177-
fn validate_bytes<'a>(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
178+
fn validate_bytes<'a>(
179+
&'a self,
180+
strict: bool,
181+
mode: ValBytesMode,
182+
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
178183
if let Ok(py_bytes) = self.downcast_exact::<PyBytes>() {
179184
return Ok(ValidationMatch::exact(py_bytes.into()));
180185
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
@@ -185,7 +190,10 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
185190
if !strict {
186191
return if let Ok(py_str) = self.downcast::<PyString>() {
187192
let str = py_string_str(py_str)?;
188-
Ok(str.as_bytes().into())
193+
match mode.deserialize_string(str) {
194+
Ok(b) => Ok(b),
195+
Err(e) => Err(ValError::new(e, self)),
196+
}
189197
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
190198
Ok(py_byte_array.to_vec().into())
191199
} else {

src/input/input_string.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::input::py_string_str;
88
use crate::lookup_key::{LookupKey, LookupPath};
99
use crate::tools::safe_repr;
1010
use crate::validators::decimal::create_decimal;
11+
use crate::validators::ValBytesMode;
1112

1213
use super::datetime::{
1314
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
@@ -105,9 +106,16 @@ impl<'py> Input<'py> for StringMapping<'py> {
105106
}
106107
}
107108

108-
fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
109+
fn validate_bytes<'a>(
110+
&'a self,
111+
_strict: bool,
112+
mode: ValBytesMode,
113+
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
109114
match self {
110-
Self::String(s) => py_string_str(s).map(|b| ValidationMatch::strict(b.as_bytes().into())),
115+
Self::String(s) => py_string_str(s).and_then(|b| match mode.deserialize_string(b) {
116+
Ok(b) => Ok(ValidationMatch::strict(b)),
117+
Err(e) => Err(ValError::new(e, self)),
118+
}),
111119
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
112120
}
113121
}

src/lib.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use std::sync::OnceLock;
77
use jiter::{map_json_error, PartialMode, PythonParse, StringCacheMode};
88
use pyo3::exceptions::PyTypeError;
99
use pyo3::{prelude::*, sync::GILOnceCell};
10+
use serializers::BytesMode;
11+
use validators::ValBytesMode;
1012

1113
// parse this first to get access to the contained macro
1214
#[macro_use]
@@ -55,7 +57,7 @@ pub fn from_json<'py>(
5557
allow_partial: bool,
5658
) -> PyResult<Bound<'py, PyAny>> {
5759
let v_match = data
58-
.validate_bytes(false)
60+
.validate_bytes(false, ValBytesMode { ser: BytesMode::Utf8 })
5961
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
6062
let json_either_bytes = v_match.into_inner();
6163
let json_bytes = json_either_bytes.as_slice();

src/serializers/config.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub trait FromConfig {
5252
macro_rules! serialization_mode {
5353
($name:ident, $config_key:expr, $($variant:ident => $value:expr),* $(,)?) => {
5454
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
55-
pub(crate) enum $name {
55+
pub enum $name {
5656
#[default]
5757
$($variant,)*
5858
}
@@ -183,9 +183,7 @@ impl BytesMode {
183183
Err(e) => Err(Error::custom(e.to_string())),
184184
},
185185
Self::Base64 => serializer.serialize_str(&base64::engine::general_purpose::URL_SAFE.encode(bytes)),
186-
Self::Hex => {
187-
serializer.serialize_str(&bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}")))
188-
}
186+
Self::Hex => serializer.serialize_str(hex::encode(bytes).as_str()),
189187
}
190188
}
191189
}

src/serializers/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use pyo3::{PyTraverseError, PyVisit};
88
use crate::definitions::{Definitions, DefinitionsBuilder};
99
use crate::py_gc::PyGcTraverse;
1010

11+
pub(crate) use config::BytesMode;
1112
use config::SerializationConfig;
1213
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
1314
use extra::{CollectWarnings, SerRecursionState, WarningsMode};

src/validators/bytes.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ use crate::input::Input;
88

99
use crate::tools::SchemaDict;
1010

11+
use super::config::ValBytesMode;
1112
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
1213

1314
#[derive(Debug, Clone)]
1415
pub struct BytesValidator {
1516
strict: bool,
17+
bytes_mode: ValBytesMode,
1618
}
1719

1820
impl BuildValidator for BytesValidator {
@@ -31,6 +33,7 @@ impl BuildValidator for BytesValidator {
3133
} else {
3234
Ok(Self {
3335
strict: is_strict(schema, config)?,
36+
bytes_mode: ValBytesMode::from_config(config)?,
3437
}
3538
.into())
3639
}
@@ -47,7 +50,7 @@ impl Validator for BytesValidator {
4750
state: &mut ValidationState<'_, 'py>,
4851
) -> ValResult<PyObject> {
4952
input
50-
.validate_bytes(state.strict_or(self.strict))
53+
.validate_bytes(state.strict_or(self.strict), self.bytes_mode)
5154
.map(|m| m.unpack(state).into_py(py))
5255
}
5356

@@ -59,6 +62,7 @@ impl Validator for BytesValidator {
5962
#[derive(Debug, Clone)]
6063
pub struct BytesConstrainedValidator {
6164
strict: bool,
65+
bytes_mode: ValBytesMode,
6266
max_length: Option<usize>,
6367
min_length: Option<usize>,
6468
}
@@ -72,7 +76,9 @@ impl Validator for BytesConstrainedValidator {
7276
input: &(impl Input<'py> + ?Sized),
7377
state: &mut ValidationState<'_, 'py>,
7478
) -> ValResult<PyObject> {
75-
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?.unpack(state);
79+
let either_bytes = input
80+
.validate_bytes(state.strict_or(self.strict), self.bytes_mode)?
81+
.unpack(state);
7682
let len = either_bytes.len()?;
7783

7884
if let Some(min_length) = self.min_length {
@@ -110,6 +116,7 @@ impl BytesConstrainedValidator {
110116
let py = schema.py();
111117
Ok(Self {
112118
strict: is_strict(schema, config)?,
119+
bytes_mode: ValBytesMode::from_config(config)?,
113120
min_length: schema.get_as(intern!(py, "min_length"))?,
114121
max_length: schema.get_as(intern!(py, "max_length"))?,
115122
}

src/validators/config.rs

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use std::borrow::Cow;
2+
use std::str::FromStr;
3+
4+
use base64::Engine;
5+
use pyo3::types::{PyDict, PyString};
6+
use pyo3::{intern, prelude::*};
7+
8+
use crate::errors::ErrorType;
9+
use crate::input::EitherBytes;
10+
use crate::serializers::BytesMode;
11+
use crate::tools::SchemaDict;
12+
13+
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
14+
pub struct ValBytesMode {
15+
pub ser: BytesMode,
16+
}
17+
18+
impl ValBytesMode {
19+
pub fn from_config(config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
20+
let Some(config_dict) = config else {
21+
return Ok(Self::default());
22+
};
23+
let raw_mode = config_dict.get_as::<Bound<'_, PyString>>(intern!(config_dict.py(), "val_json_bytes"))?;
24+
let ser_mode = raw_mode.map_or_else(|| Ok(BytesMode::default()), |raw| BytesMode::from_str(&raw.to_cow()?))?;
25+
Ok(Self { ser: ser_mode })
26+
}
27+
28+
pub fn deserialize_string<'py>(self, s: &str) -> Result<EitherBytes<'_, 'py>, ErrorType> {
29+
match self.ser {
30+
BytesMode::Utf8 => Ok(EitherBytes::Cow(Cow::Borrowed(s.as_bytes()))),
31+
BytesMode::Base64 => match base64::engine::general_purpose::URL_SAFE.decode(s) {
32+
Ok(bytes) => Ok(EitherBytes::from(bytes)),
33+
Err(err) => Err(ErrorType::BytesInvalidEncoding {
34+
encoding: "base64".to_string(),
35+
encoding_error: err.to_string(),
36+
context: None,
37+
}),
38+
},
39+
BytesMode::Hex => match hex::decode(s) {
40+
Ok(vec) => Ok(EitherBytes::from(vec)),
41+
Err(err) => Err(ErrorType::BytesInvalidEncoding {
42+
encoding: "hex".to_string(),
43+
encoding_error: err.to_string(),
44+
context: None,
45+
}),
46+
},
47+
}
48+
}
49+
}

0 commit comments

Comments
 (0)