Skip to content

Commit

Permalink
Fix union validation logic when extra='allow' (#1334)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Jun 20, 2024
1 parent fcc77f8 commit a65f327
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ impl Validator for ModelValidator {
for field_name in validated_fields_set {
fields_set.add(field_name)?;
}
state.add_fields_set(fields_set.len());
}

force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?;
Expand Down Expand Up @@ -244,11 +243,9 @@ impl ModelValidator {
};
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?;
state.add_fields_set(fields_set.len());
} else {
let (model_dict, model_extra, fields_set): (Bound<PyAny>, Bound<PyAny>, Bound<PyAny>) =
output.extract(py)?;
state.add_fields_set(fields_set.len().unwrap_or(0));
set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?;
}
self.call_post_init(py, self_instance.clone(), input, state.extra())
Expand Down Expand Up @@ -287,11 +284,10 @@ impl ModelValidator {
};
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?;
state.add_fields_set(fields_set.len());
} else {
let (model_dict, model_extra, val_fields_set) = output.extract(py)?;
let (model_dict, model_extra, val_fields_set): (Bound<PyAny>, Bound<PyAny>, Bound<PyAny>) =
output.extract(py)?;
let fields_set = existing_fields_set.unwrap_or(&val_fields_set);
state.add_fields_set(fields_set.len().unwrap_or(0));
set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?;
}
self.call_post_init(py, instance, input, state.extra())
Expand Down
3 changes: 3 additions & 0 deletions src/validators/model_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ impl Validator for ModelFieldsValidator {
let mut model_extra_dict_op: Option<Bound<PyDict>> = None;
let mut errors: Vec<ValLineError> = Vec::with_capacity(self.fields.len());
let mut fields_set_vec: Vec<Py<PyString>> = Vec::with_capacity(self.fields.len());
let mut fields_set_count: usize = 0;

// we only care about which keys have been used if we're iterating over the object for extra after
// the first pass
Expand Down Expand Up @@ -184,6 +185,7 @@ impl Validator for ModelFieldsValidator {
Ok(value) => {
model_dict.set_item(&field.name_py, value)?;
fields_set_vec.push(field.name_py.clone_ref(py));
fields_set_count += 1;
}
Err(ValError::Omit) => continue,
Err(ValError::LineErrors(line_errors)) => {
Expand Down Expand Up @@ -327,6 +329,7 @@ impl Validator for ModelFieldsValidator {
Err(ValError::LineErrors(errors))
} else {
let fields_set = PySet::new_bound(py, &fields_set_vec)?;
state.add_fields_set(fields_set_count);

// if we have extra=allow, but we didn't create a dict because we were validating
// from attributes, set it now so __pydantic_extra__ is always a dict if extra=allow
Expand Down
4 changes: 4 additions & 0 deletions src/validators/validation_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ pub enum Exactness {
pub struct ValidationState<'a, 'py> {
pub recursion_guard: &'a mut RecursionState,
pub exactness: Option<Exactness>,
// This is used as a tie-breaking mechanism for union validation.
// Note: the count of the fields set is not always equivalent to the length of the
// `model_fields_set` attached to a model. `model_fields_set` includes extra fields
// when extra='allow', whereas this tally does not.
pub fields_set_count: Option<usize>,
// deliberately make Extra readonly
extra: Extra<'a, 'py>,
Expand Down
53 changes: 53 additions & 0 deletions tests/validators/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,3 +1280,56 @@ class ModelB:
)
assert isinstance(result, ModelB)
assert isinstance(result.b, SubModelW)


@pytest.mark.parametrize('extra_behavior', ['forbid', 'ignore', 'allow'])
def test_smart_union_extra_behavior(extra_behavior) -> None:
class Foo:
foo: str = 'foo'

class Bar:
bar: str = 'bar'

class Model:
x: Union[Foo, Bar]

validator = SchemaValidator(
core_schema.model_schema(
Model,
core_schema.model_fields_schema(
fields={
'x': core_schema.model_field(
core_schema.union_schema(
[
core_schema.model_schema(
Foo,
core_schema.model_fields_schema(
fields={
'foo': core_schema.model_field(
core_schema.with_default_schema(core_schema.str_schema(), default='foo')
)
}
),
extra_behavior=extra_behavior,
),
core_schema.model_schema(
Bar,
core_schema.model_fields_schema(
fields={
'bar': core_schema.model_field(
core_schema.with_default_schema(core_schema.str_schema(), default='bar')
)
}
),
extra_behavior=extra_behavior,
),
]
)
)
}
),
)
)

assert isinstance(validator.validate_python({'x': {'foo': 'foo'}}).x, Foo)
assert isinstance(validator.validate_python({'x': {'bar': 'bar'}}).x, Bar)

0 comments on commit a65f327

Please sign in to comment.