Skip to content

Commit

Permalink
Functional tagged unions (pydantic#171)
Browse files Browse the repository at this point in the history
* Using a function as descriminator for tagged unions

* rename tag_key -> discriminator

* deal correctly with incorrect return value

* tweak error messages

* tweak Discriminator logic

* improve coverage
  • Loading branch information
samuelcolvin authored Jul 17, 2022
1 parent 17db7e3 commit 515ff10
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 122 deletions.
4 changes: 2 additions & 2 deletions pydantic_core/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from datetime import date, datetime, time, timedelta
from typing import Any, Callable, Dict, List, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

if sys.version_info < (3, 11):
from typing_extensions import NotRequired, Required
Expand Down Expand Up @@ -194,7 +194,7 @@ class UnionSchema(TypedDict, total=False):
class TaggedUnionSchema(TypedDict):
type: Literal['tagged-union']
choices: Dict[str, Schema]
tag_key: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
discriminator: Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[str]]]
strict: NotRequired[bool]
ref: NotRequired[str]

Expand Down
26 changes: 22 additions & 4 deletions src/errors/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,16 @@ pub enum ErrorKind {
CallableType,
// ---------------------
// union errors
#[strum(message = "Input key \"{key}\" must match one of the allowed tags {tags}")]
UnionTagNotFound { key: String, tags: String },
#[strum(
message = "Input tag '{tag}' found using {discriminator} does not match any of the expected tags: {expected_tags}"
)]
UnionTagInvalid {
discriminator: String,
tag: String,
expected_tags: String,
},
#[strum(message = "Unable to extract tag using discriminator {discriminator}")]
UnionTagNotFound { discriminator: String },
}

macro_rules! render {
Expand Down Expand Up @@ -287,7 +295,12 @@ impl ErrorKind {
Self::DateTimeObjectInvalid { error } => render!(template, error),
Self::TimeDeltaParsing { error } => render!(template, error),
Self::IsInstanceOf { class } => render!(template, class),
Self::UnionTagNotFound { key, tags } => render!(template, key, tags),
Self::UnionTagInvalid {
discriminator,
tag,
expected_tags,
} => render!(template, discriminator, tag, expected_tags),
Self::UnionTagNotFound { discriminator } => render!(template, discriminator),
_ => template.to_string(),
}
}
Expand Down Expand Up @@ -335,7 +348,12 @@ impl ErrorKind {
Self::DateTimeObjectInvalid { error } => py_dict!(py, error),
Self::TimeDeltaParsing { error } => py_dict!(py, error),
Self::IsInstanceOf { class } => py_dict!(py, class),
Self::UnionTagNotFound { key, tags } => py_dict!(py, key, tags),
Self::UnionTagInvalid {
discriminator,
tag,
expected_tags,
} => py_dict!(py, discriminator, tag, expected_tags),
Self::UnionTagNotFound { discriminator } => py_dict!(py, discriminator),
_ => Ok(None),
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::PyDowncastError;

use crate::input::{Input, JsonInput};

Expand All @@ -20,6 +22,12 @@ impl<'a> From<PyErr> for ValError<'a> {
}
}

impl<'a> From<PyDowncastError<'_>> for ValError<'a> {
fn from(py_downcast: PyDowncastError) -> Self {
Self::InternalErr(PyTypeError::new_err(py_downcast.to_string()))
}
}

impl<'a> From<Vec<ValLineError<'a>>> for ValError<'a> {
fn from(line_errors: Vec<ValLineError<'a>>) -> Self {
Self::LineErrors(line_errors)
Expand Down
68 changes: 32 additions & 36 deletions src/lookup_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ pub enum LookupKey {
impl fmt::Display for LookupKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Simple(key, _) => write!(f, "{}", key),
Self::Choice(key1, key2, _, _) => write!(f, "{} | {}", key1, key2),
Self::Simple(key, _) => write!(f, "'{}'", key),
Self::Choice(key1, key2, _, _) => write!(f, "'{}' | '{}'", key1, key2),
Self::PathChoices(paths) => write!(
f,
"{}",
Expand All @@ -43,42 +43,38 @@ macro_rules! py_string {
}

impl LookupKey {
pub fn from_py(py: Python, field: &PyDict, alt_alias: Option<&str>, name: &str) -> PyResult<Option<Self>> {
if let Some(value) = field.get_item(name) {
if let Ok(alias_py) = value.cast_as::<PyString>() {
let alias: String = alias_py.extract()?;
let alias_py: Py<PyString> = alias_py.into_py(py);
match alt_alias {
Some(alt_alias) => Ok(Some(LookupKey::Choice(
alias,
alt_alias.to_string(),
alias_py,
py_string!(py, alt_alias),
))),
None => Ok(Some(LookupKey::Simple(alias, alias_py))),
}
pub fn from_py(py: Python, value: &PyAny, alt_alias: Option<&str>) -> PyResult<Self> {
if let Ok(alias_py) = value.cast_as::<PyString>() {
let alias: String = alias_py.extract()?;
let alias_py: Py<PyString> = alias_py.into_py(py);
match alt_alias {
Some(alt_alias) => Ok(LookupKey::Choice(
alias,
alt_alias.to_string(),
alias_py,
py_string!(py, alt_alias),
)),
None => Ok(LookupKey::Simple(alias, alias_py)),
}
} else {
let list: &PyList = value.cast_as()?;
let first = match list.get_item(0) {
Ok(v) => v,
Err(_) => return py_error!("Lookup paths must have at least one element"),
};
let mut locs: Vec<Path> = if first.cast_as::<PyString>().is_ok() {
// list of strings rather than list of lists
vec![Self::path_choice(py, list)?]
} else {
let list: &PyList = value.cast_as()?;
let first = match list.get_item(0) {
Ok(v) => v,
Err(_) => return py_error!("\"{}\" must have at least one element", name),
};
let mut locs: Vec<Path> = if first.cast_as::<PyString>().is_ok() {
// list of strings rather than list of lists
vec![Self::path_choice(py, list)?]
} else {
list.iter()
.map(|obj| Self::path_choice(py, obj))
.collect::<PyResult<_>>()?
};
list.iter()
.map(|obj| Self::path_choice(py, obj))
.collect::<PyResult<_>>()?
};

if let Some(alt_alias) = alt_alias {
locs.push(vec![PathItem::S(alt_alias.to_string(), py_string!(py, alt_alias))])
}
Ok(Some(LookupKey::PathChoices(locs)))
if let Some(alt_alias) = alt_alias {
locs.push(vec![PathItem::S(alt_alias.to_string(), py_string!(py, alt_alias))])
}
} else {
Ok(None)
Ok(LookupKey::PathChoices(locs))
}
}

Expand Down Expand Up @@ -219,7 +215,7 @@ pub enum PathItem {
impl fmt::Display for PathItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::S(key, _) => write!(f, "{}", key),
Self::S(key, _) => write!(f, "'{}'", key),
Self::I(key) => write!(f, "{}", key),
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/validators/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,11 @@ impl BuildValidator for TypedDictValidator {
(default, default_factory) => (default, default_factory),
};

let alt_alias = if populate_by_name { Some(field_name) } else { None };
let lookup_key = match LookupKey::from_py(py, field_info, alt_alias, "alias")? {
Some(key) => key,
let lookup_key = match field_info.get_item("alias") {
Some(alias) => {
let alt_alias = if populate_by_name { Some(field_name) } else { None };
LookupKey::from_py(py, alias, alt_alias)?
}
None => LookupKey::from_string(py, field_name),
};
fields.push(TypedDictField {
Expand Down
Loading

0 comments on commit 515ff10

Please sign in to comment.