Skip to content

Commit

Permalink
Fix pydantic 7715 (#1002)
Browse files Browse the repository at this point in the history
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
  • Loading branch information
sydney-runkle and dmontagu authored Oct 2, 2023
1 parent 4622ed7 commit d93482e
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 31 deletions.
38 changes: 25 additions & 13 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,19 +232,31 @@ impl Validator for DataclassArgsValidator {
}
// found neither, check if there is a default value, otherwise error
(None, None) => {
if let Some(value) =
field
.validator
.default_value(py, Some(field.name.as_str()), state)?
{
set_item!(field, value);
} else {
errors.push(field.lookup_key.error(
ErrorTypeDefaults::Missing,
input,
self.loc_by_alias,
&field.name,
));
match field.validator.default_value(py, Some(field.name.as_str()), state) {
Ok(Some(value)) => {
// Default value exists, and passed validation if required
set_item!(field, value);
},
Ok(None) => {
// This means there was no default value
errors.push(field.lookup_key.error(
ErrorTypeDefaults::Missing,
input,
self.loc_by_alias,
&field.name
));
},
Err(ValError::Omit) => continue,
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
// Note: this will always use the field name even if there is an alias
// However, we don't mind so much because this error can only happen if the
// default value fails validation, which is arguably a developer error.
// We could try to "fix" this in the future if desired.
errors.push(err);
}
}
Err(err) => return Err(err),
}
}
}
Expand Down
36 changes: 27 additions & 9 deletions src/validators/model_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,15 +211,33 @@ impl Validator for ModelFieldsValidator {
Err(err) => return ControlFlow::Break(err.into_owned(py)),
}
continue;
} else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? {
control_flow!(model_dict.set_item(&field.name_py, value))?;
} else {
errors.push(field.lookup_key.error(
ErrorTypeDefaults::Missing,
input,
self.loc_by_alias,
&field.name
));
}

match field.validator.default_value(py, Some(field.name.as_str()), state) {
Ok(Some(value)) => {
// Default value exists, and passed validation if required
control_flow!(model_dict.set_item(&field.name_py, value))?;
},
Ok(None) => {
// This means there was no default value
errors.push(field.lookup_key.error(
ErrorTypeDefaults::Missing,
input,
self.loc_by_alias,
&field.name
));
},
Err(ValError::Omit) => continue,
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
// Note: this will always use the field name even if there is an alias
// However, we don't mind so much because this error can only happen if the
// default value fails validation, which is arguably a developer error.
// We could try to "fix" this in the future if desired.
errors.push(err);
}
}
Err(err) => return ControlFlow::Break(err),
}
}
ControlFlow::Continue(())
Expand Down
38 changes: 29 additions & 9 deletions src/validators/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,35 @@ impl Validator for TypedDictValidator {
Err(err) => return ControlFlow::Break(err.into_owned(py)),
}
continue;
} else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? {
control_flow!(output_dict.set_item(&field.name_py, value))?;
} else if field.required {
errors.push(field.lookup_key.error(
ErrorTypeDefaults::Missing,
input,
self.loc_by_alias,
&field.name
));
}

match field.validator.default_value(py, Some(field.name.as_str()), state) {
Ok(Some(value)) => {
// Default value exists, and passed validation if required
control_flow!(output_dict.set_item(&field.name_py, value))?;
},
Ok(None) => {
// This means there was no default value
if (field.required) {
errors.push(field.lookup_key.error(
ErrorTypeDefaults::Missing,
input,
self.loc_by_alias,
&field.name
));
}
},
Err(ValError::Omit) => continue,
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
// Note: this will always use the field name even if there is an alias
// However, we don't mind so much because this error can only happen if the
// default value fails validation, which is arguably a developer error.
// We could try to "fix" this in the future if desired.
errors.push(err);
}
}
Err(err) => return ControlFlow::Break(err),
}
}
ControlFlow::Continue(())
Expand Down
150 changes: 150 additions & 0 deletions tests/validators/test_with_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,153 @@ def _validator(cls, v, info):
gc.collect()

assert ref() is None


validate_default_raises_examples = [
(
{},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {}},
],
),
(
{'z': 'some str'},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
],
),
(
{'x': None},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'x': None}},
],
),
(
{'x': None, 'z': 'some str'},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
],
),
(
{'y': None},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'y': None}},
],
),
(
{'y': None, 'z': 'some str'},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
],
),
(
{'x': None, 'y': None},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'x': None, 'y': None}},
],
),
(
{'x': None, 'y': None, 'z': 'some str'},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
],
),
(
{'x': 1, 'y': None, 'z': 'some str'},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': 1},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None},
],
),
(
{'x': None, 'y': 1, 'z': 'some str'},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': 1},
],
),
(
{'x': 1, 'y': 1, 'z': 'some str'},
[
{'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': 1},
{'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': 1},
],
),
]


@pytest.mark.parametrize(
'core_schema_constructor,field_constructor',
[
(core_schema.model_fields_schema, core_schema.model_field),
(core_schema.typed_dict_schema, core_schema.typed_dict_field),
],
)
@pytest.mark.parametrize('input_value,expected', validate_default_raises_examples)
def test_validate_default_raises(
core_schema_constructor: Union[core_schema.ModelFieldsSchema, core_schema.TypedDictSchema],
field_constructor: Union[core_schema.model_field, core_schema.typed_dict_field],
input_value: dict,
expected: Any,
) -> None:
def _raise(ex: Exception) -> None:
raise ex()

inner_schema = core_schema.no_info_after_validator_function(
lambda x: _raise(AssertionError), core_schema.nullable_schema(core_schema.int_schema())
)

v = SchemaValidator(
core_schema_constructor(
{
'x': field_constructor(
core_schema.with_default_schema(inner_schema, default=None, validate_default=True)
),
'y': field_constructor(
core_schema.with_default_schema(inner_schema, default=None, validate_default=True)
),
'z': field_constructor(core_schema.str_schema()),
}
)
)

with pytest.raises(ValidationError) as exc_info:
v.validate_python(input_value)
assert exc_info.value.errors(include_url=False, include_context=False) == expected


@pytest.mark.parametrize('input_value,expected', validate_default_raises_examples)
def test_validate_default_raises_dataclass(input_value: dict, expected: Any) -> None:
def _raise(ex: Exception) -> None:
raise ex()

inner_schema = core_schema.no_info_after_validator_function(
lambda x: _raise(AssertionError), core_schema.nullable_schema(core_schema.int_schema())
)

x = core_schema.dataclass_field(
name='x', schema=core_schema.with_default_schema(inner_schema, default=None, validate_default=True)
)
y = core_schema.dataclass_field(
name='y', schema=core_schema.with_default_schema(inner_schema, default=None, validate_default=True)
)
z = core_schema.dataclass_field(name='z', schema=core_schema.str_schema())

v = SchemaValidator(core_schema.dataclass_args_schema('XYZ', [x, y, z]))

with pytest.raises(ValidationError) as exc_info:
v.validate_python(input_value)

assert exc_info.value.errors(include_url=False, include_context=False) == expected

0 comments on commit d93482e

Please sign in to comment.