Skip to content

Commit

Permalink
Support subclass inits for Url and MultiHostUrl (#1508)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Oct 30, 2024
1 parent 8568136 commit fe73652
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 15 deletions.
2 changes: 1 addition & 1 deletion python/pydantic_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class ErrorTypeInfo(_TypedDict):
"""Example of context values."""


class MultiHostHost(_TypedDict):
class MultiHostHost(_TypedDict, total=False):
"""
A host part of a multi-host URL.
"""
Expand Down
4 changes: 2 additions & 2 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ class Url(SupportsAllComparisons):
scheme: str,
username: str | None = None,
password: str | None = None,
host: str,
host: str | None = None,
port: int | None = None,
path: str | None = None,
query: str | None = None,
Expand All @@ -596,7 +596,7 @@ class Url(SupportsAllComparisons):
scheme: The scheme part of the URL.
username: The username part of the URL, or omit for no username.
password: The password part of the URL, or omit for no password.
host: The host part of the URL.
host: The host part of the URL, or omit for no host.
port: The port part of the URL, or omit for no port.
path: The path part of the URL, or omit for no path.
query: The query part of the URL, or omit for no query.
Expand Down
8 changes: 8 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3655,6 +3655,7 @@ class MyModel:

class UrlSchema(TypedDict, total=False):
type: Required[Literal['url']]
cls: Type[Any]
max_length: int
allowed_schemes: List[str]
host_required: bool # default False
Expand All @@ -3669,6 +3670,7 @@ class UrlSchema(TypedDict, total=False):

def url_schema(
*,
cls: Type[Any] | None = None,
max_length: int | None = None,
allowed_schemes: list[str] | None = None,
host_required: bool | None = None,
Expand All @@ -3693,6 +3695,7 @@ def url_schema(
```
Args:
cls: The class to use for the URL build (a subclass of `pydantic_core.Url`)
max_length: The maximum length of the URL
allowed_schemes: The allowed URL schemes
host_required: Whether the URL must have a host
Expand All @@ -3706,6 +3709,7 @@ def url_schema(
"""
return _dict_not_none(
type='url',
cls=cls,
max_length=max_length,
allowed_schemes=allowed_schemes,
host_required=host_required,
Expand All @@ -3721,6 +3725,7 @@ def url_schema(

class MultiHostUrlSchema(TypedDict, total=False):
type: Required[Literal['multi-host-url']]
cls: Type[Any]
max_length: int
allowed_schemes: List[str]
host_required: bool # default False
Expand All @@ -3735,6 +3740,7 @@ class MultiHostUrlSchema(TypedDict, total=False):

def multi_host_url_schema(
*,
cls: Type[Any] | None = None,
max_length: int | None = None,
allowed_schemes: list[str] | None = None,
host_required: bool | None = None,
Expand All @@ -3759,6 +3765,7 @@ def multi_host_url_schema(
```
Args:
cls: The class to use for the URL build (a subclass of `pydantic_core.MultiHostUrl`)
max_length: The maximum length of the URL
allowed_schemes: The allowed URL schemes
host_required: Whether the URL must have a host
Expand All @@ -3772,6 +3779,7 @@ def multi_host_url_schema(
"""
return _dict_not_none(
type='multi-host-url',
cls=cls,
max_length=max_length,
allowed_schemes=allowed_schemes,
host_required=host_required,
Expand Down
16 changes: 9 additions & 7 deletions src/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ impl PyUrl {
}

#[classmethod]
#[pyo3(signature=(*, scheme, host, username=None, password=None, port=None, path=None, query=None, fragment=None))]
#[pyo3(signature=(*, scheme, host=None, username=None, password=None, port=None, path=None, query=None, fragment=None))]
#[allow(clippy::too_many_arguments)]
pub fn build<'py>(
cls: &Bound<'py, PyType>,
scheme: &str,
host: &str,
host: Option<&str>,
username: Option<&str>,
password: Option<&str>,
port: Option<u16>,
Expand All @@ -172,7 +172,7 @@ impl PyUrl {
let url_host = UrlHostParts {
username: username.map(Into::into),
password: password.map(Into::into),
host: Some(host.into()),
host: host.map(Into::into),
port,
};
let mut url = format!("{scheme}://{url_host}");
Expand Down Expand Up @@ -423,6 +423,7 @@ impl PyMultiHostUrl {
}
}

#[cfg_attr(debug_assertions, derive(Debug))]
pub struct UrlHostParts {
username: Option<String>,
password: Option<String>,
Expand All @@ -440,11 +441,12 @@ impl FromPyObject<'_> for UrlHostParts {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
let py = ob.py();
let dict = ob.downcast::<PyDict>()?;

Ok(UrlHostParts {
username: dict.get_as(intern!(py, "username"))?,
password: dict.get_as(intern!(py, "password"))?,
host: dict.get_as(intern!(py, "host"))?,
port: dict.get_as(intern!(py, "port"))?,
username: dict.get_as::<Option<_>>(intern!(py, "username"))?.flatten(),
password: dict.get_as::<Option<_>>(intern!(py, "password"))?.flatten(),
host: dict.get_as::<Option<_>>(intern!(py, "host"))?.flatten(),
port: dict.get_as::<Option<_>>(intern!(py, "port"))?.flatten(),
})
}
}
Expand Down
69 changes: 64 additions & 5 deletions src/validators/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::str::Chars;

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use pyo3::types::{PyDict, PyList, PyType};

use ahash::AHashSet;
use url::{ParseError, SyntaxViolation, Url};
Expand All @@ -26,6 +26,7 @@ type AllowedSchemas = Option<(AHashSet<String>, String)>;
#[derive(Debug, Clone)]
pub struct UrlValidator {
strict: bool,
cls: Option<Py<PyType>>,
max_length: Option<usize>,
allowed_schemes: AllowedSchemas,
host_required: bool,
Expand All @@ -47,6 +48,7 @@ impl BuildValidator for UrlValidator {

Ok(Self {
strict: is_strict(schema, config)?,
cls: schema.get_as(intern!(schema.py(), "cls"))?,
max_length: schema.get_as(intern!(schema.py(), "max_length"))?,
host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false),
default_host: schema.get_as(intern!(schema.py(), "default_host"))?,
Expand All @@ -59,7 +61,7 @@ impl BuildValidator for UrlValidator {
}
}

impl_py_gc_traverse!(UrlValidator {});
impl_py_gc_traverse!(UrlValidator { cls });

impl Validator for UrlValidator {
fn validate<'py>(
Expand Down Expand Up @@ -93,7 +95,31 @@ impl Validator for UrlValidator {
Ok(()) => {
// Lax rather than strict to preserve V2.4 semantic that str wins over url in union
state.floor_exactness(Exactness::Lax);
Ok(either_url.into_py(py))

if let Some(url_subclass) = &self.cls {
// TODO: we do an extra build for a subclass here, we should avoid this
// in v2.11 for perf reasons, but this is a worthwhile patch for now
// given that we want isinstance to work properly for subclasses of Url
let py_url = match either_url {
EitherUrl::Py(py_url) => py_url.get().clone(),
EitherUrl::Rust(rust_url) => PyUrl::new(rust_url),
};

let py_url = PyUrl::build(
url_subclass.bind(py),
py_url.scheme(),
py_url.host(),
py_url.username(),
py_url.password(),
py_url.port(),
py_url.path().filter(|path| *path != "/"),
py_url.query(),
py_url.fragment(),
)?;
Ok(py_url.into_py(py))
} else {
Ok(either_url.into_py(py))
}
}
Err(error_type) => Err(ValError::new(error_type, input)),
}
Expand Down Expand Up @@ -186,6 +212,7 @@ impl CopyFromPyUrl for EitherUrl<'_> {
#[derive(Debug, Clone)]
pub struct MultiHostUrlValidator {
strict: bool,
cls: Option<Py<PyType>>,
max_length: Option<usize>,
allowed_schemes: AllowedSchemas,
host_required: bool,
Expand Down Expand Up @@ -213,6 +240,7 @@ impl BuildValidator for MultiHostUrlValidator {
}
Ok(Self {
strict: is_strict(schema, config)?,
cls: schema.get_as(intern!(schema.py(), "cls"))?,
max_length: schema.get_as(intern!(schema.py(), "max_length"))?,
allowed_schemes,
host_required: schema.get_as(intern!(schema.py(), "host_required"))?.unwrap_or(false),
Expand All @@ -225,7 +253,7 @@ impl BuildValidator for MultiHostUrlValidator {
}
}

impl_py_gc_traverse!(MultiHostUrlValidator {});
impl_py_gc_traverse!(MultiHostUrlValidator { cls });

impl Validator for MultiHostUrlValidator {
fn validate<'py>(
Expand Down Expand Up @@ -258,7 +286,38 @@ impl Validator for MultiHostUrlValidator {
Ok(()) => {
// Lax rather than strict to preserve V2.4 semantic that str wins over url in union
state.floor_exactness(Exactness::Lax);
Ok(multi_url.into_py(py))

if let Some(url_subclass) = &self.cls {
// TODO: we do an extra build for a subclass here, we should avoid this
// in v2.11 for perf reasons, but this is a worthwhile patch for now
// given that we want isinstance to work properly for subclasses of Url
let py_url = match multi_url {
EitherMultiHostUrl::Py(py_url) => py_url.get().clone(),
EitherMultiHostUrl::Rust(rust_url) => rust_url,
};

let hosts = py_url
.hosts(py)?
.into_iter()
.map(|host| host.extract().expect("host should be a valid UrlHostParts"))
.collect();

let py_url = PyMultiHostUrl::build(
url_subclass.bind(py),
py_url.scheme(),
Some(hosts),
py_url.path().filter(|path| *path != "/"),
py_url.query(),
py_url.fragment(),
None,
None,
None,
None,
)?;
Ok(py_url.into_py(py))
} else {
Ok(multi_url.into_py(py))
}
}
Err(error_type) => Err(ValError::new(error_type, input)),
}
Expand Down
16 changes: 16 additions & 0 deletions tests/validators/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,3 +1305,19 @@ def test_url_build() -> None:
)
assert url == Url('postgresql://testuser:testpassword@127.0.0.1:5432/database?sslmode=require#test')
assert str(url) == 'postgresql://testuser:testpassword@127.0.0.1:5432/database?sslmode=require#test'


def test_url_subclass() -> None:
class UrlSubclass(Url):
pass

validator = SchemaValidator(core_schema.url_schema(cls=UrlSubclass))
assert isinstance(validator.validate_python('http://example.com'), UrlSubclass)


def test_multi_host_url_subclass() -> None:
class MultiHostUrlSubclass(MultiHostUrl):
pass

validator = SchemaValidator(core_schema.multi_host_url_schema(cls=MultiHostUrlSubclass))
assert isinstance(validator.validate_python('http://example.com'), MultiHostUrlSubclass)

0 comments on commit fe73652

Please sign in to comment.