Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support subclass inits for Url and MultiHostUrl #1508

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pydantic_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class ErrorTypeInfo(_TypedDict):
"""Example of context values."""


class MultiHostHost(_TypedDict):
class MultiHostHost(_TypedDict, total=False):
"""
A host part of a multi-host URL.
"""
Expand Down
4 changes: 2 additions & 2 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ class Url(SupportsAllComparisons):
scheme: str,
username: str | None = None,
password: str | None = None,
host: str,
host: str | None = None,
port: int | None = None,
path: str | None = None,
query: str | None = None,
Expand All @@ -596,7 +596,7 @@ class Url(SupportsAllComparisons):
scheme: The scheme part of the URL.
username: The username part of the URL, or omit for no username.
password: The password part of the URL, or omit for no password.
host: The host part of the URL.
host: The host part of the URL, or omit for no host.
port: The port part of the URL, or omit for no port.
path: The path part of the URL, or omit for no path.
query: The query part of the URL, or omit for no query.
Expand Down
8 changes: 8 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3655,6 +3655,7 @@ class MyModel:

class UrlSchema(TypedDict, total=False):
type: Required[Literal['url']]
cls: Type[Any]
max_length: int
allowed_schemes: List[str]
host_required: bool # default False
Expand All @@ -3669,6 +3670,7 @@ class UrlSchema(TypedDict, total=False):

def url_schema(
*,
cls: Type[Any] | None = None,
max_length: int | None = None,
allowed_schemes: list[str] | None = None,
host_required: bool | None = None,
Expand All @@ -3693,6 +3695,7 @@ def url_schema(
```

Args:
cls: The class to use for the URL build (a subclass of `pydantic_core.Url`)
max_length: The maximum length of the URL
allowed_schemes: The allowed URL schemes
host_required: Whether the URL must have a host
Expand All @@ -3706,6 +3709,7 @@ def url_schema(
"""
return _dict_not_none(
type='url',
cls=cls,
max_length=max_length,
allowed_schemes=allowed_schemes,
host_required=host_required,
Expand All @@ -3721,6 +3725,7 @@ def url_schema(

class MultiHostUrlSchema(TypedDict, total=False):
type: Required[Literal['multi-host-url']]
cls: Type[Any]
max_length: int
allowed_schemes: List[str]
host_required: bool # default False
Expand All @@ -3735,6 +3740,7 @@ class MultiHostUrlSchema(TypedDict, total=False):

def multi_host_url_schema(
*,
cls: Type[Any] | None = None,
max_length: int | None = None,
allowed_schemes: list[str] | None = None,
host_required: bool | None = None,
Expand All @@ -3759,6 +3765,7 @@ def multi_host_url_schema(
```

Args:
cls: The class to use for the URL build (a subclass of `pydantic_core.MultiHostUrl`)
max_length: The maximum length of the URL
allowed_schemes: The allowed URL schemes
host_required: Whether the URL must have a host
Expand All @@ -3772,6 +3779,7 @@ def multi_host_url_schema(
"""
return _dict_not_none(
type='multi-host-url',
cls=cls,
max_length=max_length,
allowed_schemes=allowed_schemes,
host_required=host_required,
Expand Down
16 changes: 9 additions & 7 deletions src/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ impl PyUrl {
}

#[classmethod]
#[pyo3(signature=(*, scheme, host, username=None, password=None, port=None, path=None, query=None, fragment=None))]
#[pyo3(signature=(*, scheme, host=None, username=None, password=None, port=None, path=None, query=None, fragment=None))]
#[allow(clippy::too_many_arguments)]
pub fn build<'py>(
cls: &Bound<'py, PyType>,
scheme: &str,
host: &str,
host: Option<&str>,
username: Option<&str>,
password: Option<&str>,
port: Option<u16>,
Expand All @@ -172,7 +172,7 @@ impl PyUrl {
let url_host = UrlHostParts {
username: username.map(Into::into),
password: password.map(Into::into),
host: Some(host.into()),
host: host.map(Into::into),
port,
};
let mut url = format!("{scheme}://{url_host}");
Expand Down Expand Up @@ -423,6 +423,7 @@ impl PyMultiHostUrl {
}
}

#[cfg_attr(debug_assertions, derive(Debug))]
pub struct UrlHostParts {
username: Option<String>,
password: Option<String>,
Expand All @@ -440,11 +441,12 @@ impl FromPyObject<'_> for UrlHostParts {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
let py = ob.py();
let dict = ob.downcast::<PyDict>()?;

Ok(UrlHostParts {
username: dict.get_as(intern!(py, "username"))?,
password: dict.get_as(intern!(py, "password"))?,
host: dict.get_as(intern!(py, "host"))?,
port: dict.get_as(intern!(py, "port"))?,
username: dict.get_as::<Option<_>>(intern!(py, "username"))?.flatten(),
password: dict.get_as::<Option<_>>(intern!(py, "password"))?.flatten(),
host: dict.get_as::<Option<_>>(intern!(py, "host"))?.flatten(),
port: dict.get_as::<Option<_>>(intern!(py, "port"))?.flatten(),
})
}
}
Expand Down
69 changes: 64 additions & 5 deletions src/validators/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::str::Chars;

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use pyo3::types::{PyDict, PyList, PyType};

use ahash::AHashSet;
use url::{ParseError, SyntaxViolation, Url};
Expand All @@ -26,6 +26,7 @@ type AllowedSchemas = Option<(AHashSet<String>, String)>;
#[derive(Debug, Clone)]
pub struct UrlValidator {
strict: bool,
cls: Option<Py<PyType>>,
max_length: Option<usize>,
allowed_schemes: AllowedSchemas,
host_required: bool,
Expand All @@ -47,6 +48,7 @@ impl BuildValidator for UrlValidator {

Ok(Self {
strict: is_strict(schema, config)?,
cls: schema.get_as(intern!(schema.py(), "cls"))?,
max_length: schema.get_as(intern!(schema.py(), "max_length"))?,
host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false),
default_host: schema.get_as(intern!(schema.py(), "default_host"))?,
Expand All @@ -59,7 +61,7 @@ impl BuildValidator for UrlValidator {
}
}

impl_py_gc_traverse!(UrlValidator {});
impl_py_gc_traverse!(UrlValidator { cls });

impl Validator for UrlValidator {
fn validate<'py>(
Expand Down Expand Up @@ -93,7 +95,31 @@ impl Validator for UrlValidator {
Ok(()) => {
// Lax rather than strict to preserve V2.4 semantic that str wins over url in union
state.floor_exactness(Exactness::Lax);
Ok(either_url.into_py(py))

if let Some(url_subclass) = &self.cls {
// TODO: we do an extra build for a subclass here, we should avoid this
// in v2.11 for perf reasons, but this is a worthwhile patch for now
// given that we want isinstance to work properly for subclasses of Url
let py_url = match either_url {
EitherUrl::Py(py_url) => py_url.get().clone(),
EitherUrl::Rust(rust_url) => PyUrl::new(rust_url),
};

let py_url = PyUrl::build(
url_subclass.bind(py),
py_url.scheme(),
py_url.host(),
py_url.username(),
py_url.password(),
py_url.port(),
py_url.path().filter(|path| *path != "/"),
py_url.query(),
py_url.fragment(),
)?;
Ok(py_url.into_py(py))
} else {
Ok(either_url.into_py(py))
}
}
Err(error_type) => Err(ValError::new(error_type, input)),
}
Expand Down Expand Up @@ -186,6 +212,7 @@ impl CopyFromPyUrl for EitherUrl<'_> {
#[derive(Debug, Clone)]
pub struct MultiHostUrlValidator {
strict: bool,
cls: Option<Py<PyType>>,
max_length: Option<usize>,
allowed_schemes: AllowedSchemas,
host_required: bool,
Expand Down Expand Up @@ -213,6 +240,7 @@ impl BuildValidator for MultiHostUrlValidator {
}
Ok(Self {
strict: is_strict(schema, config)?,
cls: schema.get_as(intern!(schema.py(), "cls"))?,
max_length: schema.get_as(intern!(schema.py(), "max_length"))?,
allowed_schemes,
host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false),
Expand All @@ -225,7 +253,7 @@ impl BuildValidator for MultiHostUrlValidator {
}
}

impl_py_gc_traverse!(MultiHostUrlValidator {});
impl_py_gc_traverse!(MultiHostUrlValidator { cls });

impl Validator for MultiHostUrlValidator {
fn validate<'py>(
Expand Down Expand Up @@ -258,7 +286,38 @@ impl Validator for MultiHostUrlValidator {
Ok(()) => {
// Lax rather than strict to preserve V2.4 semantic that str wins over url in union
state.floor_exactness(Exactness::Lax);
Ok(multi_url.into_py(py))

if let Some(url_subclass) = &self.cls {
// TODO: we do an extra build for a subclass here, we should avoid this
// in v2.11 for perf reasons, but this is a worthwhile patch for now
// given that we want isinstance to work properly for subclasses of Url
let py_url = match multi_url {
EitherMultiHostUrl::Py(py_url) => py_url.get().clone(),
EitherMultiHostUrl::Rust(rust_url) => rust_url,
};

let hosts = py_url
.hosts(py)?
.into_iter()
.map(|host| host.extract().expect("host should be a valid UrlHostParts"))
.collect();

let py_url = PyMultiHostUrl::build(
url_subclass.bind(py),
py_url.scheme(),
Some(hosts),
py_url.path().filter(|path| *path != "/"),
py_url.query(),
py_url.fragment(),
None,
None,
None,
None,
)?;
Ok(py_url.into_py(py))
} else {
Ok(multi_url.into_py(py))
}
}
Err(error_type) => Err(ValError::new(error_type, input)),
}
Expand Down
16 changes: 16 additions & 0 deletions tests/validators/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,3 +1305,19 @@ def test_url_build() -> None:
)
assert url == Url('postgresql://testuser:testpassword@127.0.0.1:5432/database?sslmode=require#test')
assert str(url) == 'postgresql://testuser:testpassword@127.0.0.1:5432/database?sslmode=require#test'


def test_url_subclass() -> None:
class UrlSubclass(Url):
pass

validator = SchemaValidator(core_schema.url_schema(cls=UrlSubclass))
assert isinstance(validator.validate_python('http://example.com'), UrlSubclass)


def test_multi_host_url_subclass() -> None:
class MultiHostUrlSubclass(MultiHostUrl):
pass

validator = SchemaValidator(core_schema.multi_host_url_schema(cls=MultiHostUrlSubclass))
assert isinstance(validator.validate_python('http://example.com'), MultiHostUrlSubclass)
Loading