Skip to content

Commit

Permalink
More annotations that I forgot to break up into multiple commits.
Browse files Browse the repository at this point in the history
- validators.py
 - Finish annotating return types.
 - Change ensure_one_of takes a `Collection`, not a `Container`, since it needs to be iterable within `UnpermittedComponentError.__init__`.
 - Change `authority_is_valid` to permit None as an input; continuation of making sure is_valid allowing None propogates. Also, this behavior is depended on elsewhere in the library (just one spot, I think).
- parseresult.py
 - Add variable annotations to `ParseResultMixin`, and make sure _generate_authority is allowed to return `None`.
 - Fix `ParseResultBytes.copy_with` not accepting an int for port.
 - Annotate return type for `authority_from`.
- misc.py
 - Use common base for `URIReference` and `IRIReference` as annotation for `merge_path` and remove circular import.
- exceptions.py
 - Annotate everything.
- _mixin.py
 - Add variable annotations to `URIMixin`; they're under a TYPE_CHECKING block so that only the subclasse's annotations can be found in cases of introspection. Might be overkill.
 - Use `uri.URIReference` to annotate parameters for various functions.
  - TODO: Check if these are potentially too wide, since `IRIReference` also exists and inherits from `URIMixin`?
  - Use hacky "typing.cast within an elided if block" trick to improve typing within `URIMixin.resolve_with`.
  • Loading branch information
Sachaa-Thanasius committed Jun 17, 2024
1 parent 9862d65 commit 3c22353
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 36 deletions.
22 changes: 19 additions & 3 deletions src/rfc3986/_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import exceptions as exc
from . import misc
from . import normalizers
from . import uri
from . import validators
from ._typing_compat import Self as _Self

Expand All @@ -20,6 +21,14 @@ class _AuthorityInfo(t.TypedDict):
class URIMixin:
"""Mixin with all shared methods for URIs and IRIs."""

if t.TYPE_CHECKING:
scheme: t.Optional[str]
authority: t.Optional[str]
path: t.Optional[str]
query: t.Optional[str]
fragment: t.Optional[str]
encoding: str

def authority_info(self) -> _AuthorityInfo:
"""Return a dictionary with the ``userinfo``, ``host``, and ``port``.
Expand Down Expand Up @@ -251,7 +260,7 @@ def fragment_is_valid(self, require: bool = False) -> bool:
)
return validators.fragment_is_valid(self.fragment, require)

def normalized_equality(self, other_ref) -> bool:
def normalized_equality(self, other_ref: "uri.URIReference") -> bool:
"""Compare this URIReference to another URIReference.
:param URIReference other_ref: (required), The reference with which
Expand All @@ -261,7 +270,11 @@ def normalized_equality(self, other_ref) -> bool:
"""
return tuple(self.normalize()) == tuple(other_ref.normalize())

def resolve_with(self, base_uri, strict: bool = False) -> _Self:
def resolve_with(
self,
base_uri: t.Union[str, "uri.URIReference"],
strict: bool = False,
) -> _Self:
"""Use an absolute URI Reference to resolve this relative reference.
Assuming this is a relative reference that you would like to resolve,
Expand All @@ -280,6 +293,9 @@ def resolve_with(self, base_uri, strict: bool = False) -> _Self:
if not isinstance(base_uri, URIMixin):
base_uri = type(self).from_string(base_uri)

if t.TYPE_CHECKING:
base_uri = t.cast(uri.URIReference, base_uri)

try:
self._validator.validate(base_uri)
except exc.ValidationError:
Expand Down Expand Up @@ -388,6 +404,6 @@ def copy_with(
for key, value in list(attributes.items()):
if value is misc.UseExisting:
del attributes[key]
uri = self._replace(**attributes)
uri: "uri.URIReference" = self._replace(**attributes)
uri.encoding = self.encoding
return uri
22 changes: 15 additions & 7 deletions src/rfc3986/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Exceptions module for rfc3986."""
import typing as t

from . import compat
from . import uri


class RFC3986Exception(Exception):
Expand All @@ -11,7 +14,7 @@ class RFC3986Exception(Exception):
class InvalidAuthority(RFC3986Exception):
"""Exception when the authority string is invalid."""

def __init__(self, authority):
def __init__(self, authority: t.Union[str, bytes]) -> None:
"""Initialize the exception with the invalid authority."""
super().__init__(
f"The authority ({compat.to_str(authority)}) is not valid."
Expand All @@ -21,15 +24,15 @@ def __init__(self, authority):
class InvalidPort(RFC3986Exception):
"""Exception when the port is invalid."""

def __init__(self, port):
def __init__(self, port: str) -> None:
"""Initialize the exception with the invalid port."""
super().__init__(f'The port ("{port}") is not valid.')


class ResolutionError(RFC3986Exception):
"""Exception to indicate a failure to resolve a URI."""

def __init__(self, uri):
def __init__(self, uri: "uri.URIReference") -> None:
"""Initialize the error with the failed URI."""
super().__init__(
"{} does not meet the requirements for resolution.".format(
Expand All @@ -47,7 +50,7 @@ class ValidationError(RFC3986Exception):
class MissingComponentError(ValidationError):
"""Exception raised when a required component is missing."""

def __init__(self, uri, *component_names):
def __init__(self, uri: "uri.URIReference", *component_names: str) -> None:
"""Initialize the error with the missing component name."""
verb = "was"
if len(component_names) > 1:
Expand All @@ -66,7 +69,12 @@ def __init__(self, uri, *component_names):
class UnpermittedComponentError(ValidationError):
"""Exception raised when a component has an unpermitted value."""

def __init__(self, component_name, component_value, allowed_values):
def __init__(
self,
component_name: str,
component_value: t.Any,
allowed_values: t.Collection[t.Any],
) -> None:
"""Initialize the error with the unpermitted component."""
super().__init__(
"{} was required to be one of {!r} but was {!r}".format(
Expand All @@ -86,7 +94,7 @@ def __init__(self, component_name, component_value, allowed_values):
class PasswordForbidden(ValidationError):
"""Exception raised when a URL has a password in the userinfo section."""

def __init__(self, uri):
def __init__(self, uri: t.Union[str, "uri.URIReference"]) -> None:
"""Initialize the error with the URI that failed validation."""
unsplit = getattr(uri, "unsplit", lambda: uri)
super().__init__(
Expand All @@ -100,7 +108,7 @@ def __init__(self, uri):
class InvalidComponentsError(ValidationError):
"""Exception raised when one or more components are invalid."""

def __init__(self, uri, *component_names):
def __init__(self, uri: "uri.URIReference", *component_names: str) -> None:
"""Initialize the error with the invalid component name(s)."""
verb = "was"
if len(component_names) > 1:
Expand Down
6 changes: 1 addition & 5 deletions src/rfc3986/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@

from . import abnf_regexp

if t.TYPE_CHECKING:
# Break an import loop.
from . import uri


class URIReferenceBase(t.NamedTuple):
"""The namedtuple used as a superclass of URIReference and IRIReference."""
Expand Down Expand Up @@ -130,7 +126,7 @@ class URIReferenceBase(t.NamedTuple):


# Path merger as defined in http://tools.ietf.org/html/rfc3986#section-5.2.3
def merge_paths(base_uri: "uri.URIReference", relative_path: str) -> str:
def merge_paths(base_uri: URIReferenceBase, relative_path: str) -> str:
"""Merge a base URI's path with a relative URI's path."""
if base_uri.path is None and base_uri.authority is not None:
return "/" + relative_path
Expand Down
20 changes: 17 additions & 3 deletions src/rfc3986/parseresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,21 @@


class ParseResultMixin(t.Generic[t.AnyStr]):
if t.TYPE_CHECKING:
userinfo: t.Optional[t.AnyStr]
host: t.Optional[t.AnyStr]
port: t.Optional[int]
query: t.Optional[t.AnyStr]
encoding: str

@property
def authority(self) -> t.Optional[t.AnyStr]:
...

def _generate_authority(
self,
attributes: t.Dict[str, t.Optional[t.AnyStr]],
) -> str:
) -> t.Optional[str]:
# I swear I did not align the comparisons below. That's just how they
# happened to align based on pep8 and attribute lengths.
userinfo, host, port = (
Expand Down Expand Up @@ -402,7 +413,7 @@ def copy_with(
scheme: t.Optional[t.Union[str, bytes]] = misc.UseExisting,
userinfo: t.Optional[t.Union[str, bytes]] = misc.UseExisting,
host: t.Optional[t.Union[str, bytes]] = misc.UseExisting,
port: t.Optional[t.Union[str, bytes]] = misc.UseExisting,
port: t.Optional[t.Union[int, str, bytes]] = misc.UseExisting,
path: t.Optional[t.Union[str, bytes]] = misc.UseExisting,
query: t.Optional[t.Union[str, bytes]] = misc.UseExisting,
fragment: t.Optional[t.Union[str, bytes]] = misc.UseExisting,
Expand Down Expand Up @@ -490,7 +501,10 @@ def split_authority(
return userinfo, host, port


def authority_from(reference: "uri.URIReference", strict: bool):
def authority_from(
reference: "uri.URIReference",
strict: bool,
) -> t.Tuple[t.Optional[str], t.Optional[str], t.Optional[int]]:
try:
subauthority = reference.authority_info()
except exceptions.InvalidAuthority:
Expand Down
39 changes: 21 additions & 18 deletions src/rfc3986/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import misc
from . import normalizers
from . import uri
from ._typing_compat import Self as _Self


class Validator:
Expand Down Expand Up @@ -51,13 +52,13 @@ class Validator:
["scheme", "userinfo", "host", "port", "path", "query", "fragment"]
)

def __init__(self):
def __init__(self) -> None:
"""Initialize our default validations."""
self.allowed_schemes: set[str] = set()
self.allowed_hosts: set[str] = set()
self.allowed_ports: set[str] = set()
self.allow_password = True
self.required_components = {
self.allowed_schemes: t.Set[str] = set()
self.allowed_hosts: t.Set[str] = set()
self.allowed_ports: t.Set[str] = set()
self.allow_password: bool = True
self.required_components: t.Dict[str, bool] = {
"scheme": False,
"userinfo": False,
"host": False,
Expand All @@ -66,9 +67,11 @@ def __init__(self):
"query": False,
"fragment": False,
}
self.validated_components = self.required_components.copy()
self.validated_components: t.Dict[
str, bool
] = self.required_components.copy()

def allow_schemes(self, *schemes: str):
def allow_schemes(self, *schemes: str) -> _Self:
"""Require the scheme to be one of the provided schemes.
.. versionadded:: 1.0
Expand All @@ -84,7 +87,7 @@ def allow_schemes(self, *schemes: str):
self.allowed_schemes.add(normalizers.normalize_scheme(scheme))
return self

def allow_hosts(self, *hosts: str):
def allow_hosts(self, *hosts: str) -> _Self:
"""Require the host to be one of the provided hosts.
.. versionadded:: 1.0
Expand All @@ -100,7 +103,7 @@ def allow_hosts(self, *hosts: str):
self.allowed_hosts.add(normalizers.normalize_host(host))
return self

def allow_ports(self, *ports: str):
def allow_ports(self, *ports: str) -> _Self:
"""Require the port to be one of the provided ports.
.. versionadded:: 1.0
Expand All @@ -118,7 +121,7 @@ def allow_ports(self, *ports: str):
self.allowed_ports.add(port)
return self

def allow_use_of_password(self):
def allow_use_of_password(self) -> _Self:
"""Allow passwords to be present in the URI.
.. versionadded:: 1.0
Expand All @@ -131,7 +134,7 @@ def allow_use_of_password(self):
self.allow_password = True
return self

def forbid_use_of_password(self):
def forbid_use_of_password(self) -> _Self:
"""Prevent passwords from being included in the URI.
.. versionadded:: 1.0
Expand All @@ -144,7 +147,7 @@ def forbid_use_of_password(self):
self.allow_password = False
return self

def check_validity_of(self, *components: str):
def check_validity_of(self, *components: str) -> _Self:
"""Check the validity of the components provided.
This can be specified repeatedly.
Expand All @@ -167,7 +170,7 @@ def check_validity_of(self, *components: str):
)
return self

def require_presence_of(self, *components: str):
def require_presence_of(self, *components: str) -> _Self:
"""Require the components provided.
This can be specified repeatedly.
Expand All @@ -190,7 +193,7 @@ def require_presence_of(self, *components: str):
)
return self

def validate(self, uri: "uri.URIReference"):
def validate(self, uri: "uri.URIReference") -> None:
"""Check a URI for conditions specified on this validator.
.. versionadded:: 1.0
Expand Down Expand Up @@ -244,7 +247,7 @@ def check_password(uri: "uri.URIReference") -> None:


def ensure_one_of(
allowed_values: t.Container[object],
allowed_values: t.Collection[object],
uri: "uri.URIReference",
attribute: str,
) -> None:
Expand All @@ -261,7 +264,7 @@ def ensure_one_of(
def ensure_required_components_exist(
uri: "uri.URIReference",
required_components: t.Iterable[str],
):
) -> None:
"""Assert that all required components are present in the URI."""
missing_components = sorted(
component
Expand Down Expand Up @@ -294,7 +297,7 @@ def is_valid(


def authority_is_valid(
authority: str,
authority: t.Optional[str],
host: t.Optional[str] = None,
require: bool = False,
) -> bool:
Expand Down

0 comments on commit 3c22353

Please sign in to comment.