diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2425ab35..0d607cf1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,6 +30,9 @@ jobs: - python-version: 3.9 env: TOXENV: black + - python-version: 3.9 + env: + TOXENV: typing steps: - uses: actions/checkout@v2 diff --git a/.gitignore b/.gitignore index a279e1af..714a9be8 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,7 @@ _trial_temp .coverage coverage.xml .cache +.mypy_cache/ +/index.txt +.dmypy.json +.hypothesis/ diff --git a/docs/w3lib.rst b/docs/w3lib.rst index bfde0304..502554ff 100644 --- a/docs/w3lib.rst +++ b/docs/w3lib.rst @@ -26,3 +26,6 @@ w3lib Package .. automodule:: w3lib.url :members: + + .. autoclass:: ParseDataURIResult + :members: diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..d4c7c859 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,12 @@ +[mypy] +exclude = .*flycheck_.* +show_error_codes = True +check_untyped_defs = True + +[mypy-w3lib.*] +# All non-test functions must be typed. +disallow_untyped_defs = True + +[mypy-tests.*] +# Allow test functions to be untyped +disallow_untyped_defs = False diff --git a/tests/test_encoding.py b/tests/test_encoding.py index b9e78922..33d7f110 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -38,14 +38,16 @@ def test_bom(self): utf32le = b"\xff\xfe\x00\x00\x34\x6c\x00\x00" for string in (utf16be, utf16le, utf32be, utf32le): bom_encoding, bom = read_bom(string) + assert bom_encoding is not None + assert bom is not None decoded = string[len(bom) :].decode(bom_encoding) self.assertEqual(water_unicode, decoded) # Body without BOM - enc, bom = read_bom("foo") + enc, bom = read_bom(b"foo") self.assertEqual(enc, None) self.assertEqual(bom, None) # Empty body - enc, bom = read_bom("") + enc, bom = read_bom(b"") self.assertEqual(enc, None) self.assertEqual(bom, None) diff --git a/tests/test_html.py b/tests/test_html.py index 5a092f29..f6ca90d2 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -69,7 +69,7 @@ def test_illegal_entities(self): def test_browser_hack(self): # check browser hack for numeric character references in the 80-9F range self.assertEqual(replace_entities("x&#153;y", encoding="cp1252"), "x\u2122y") - self.assertEqual(replace_entities("x&#x99;y", encoding="cp1252"), u"x\u2122y") + self.assertEqual(replace_entities("x&#x99;y", encoding="cp1252"), "x\u2122y") def test_missing_semicolon(self): for entity, result in ( diff --git a/tests/test_http.py b/tests/test_http.py index 127f4de9..efabb0ab 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,6 +1,11 @@ import unittest from collections import OrderedDict -from w3lib.http import basic_auth_header, headers_dict_to_raw, headers_raw_to_dict +from w3lib.http import ( + HeadersDictInput, + basic_auth_header, + headers_dict_to_raw, + headers_raw_to_dict, +) __doctests__ = ["w3lib.http"] # for trial support @@ -47,7 +52,9 @@ def test_headers_dict_to_raw(self): ) def test_headers_dict_to_raw_listtuple(self): - dct = OrderedDict([(b"Content-type", [b"text/html"]), (b"Accept", [b"gzip"])]) + dct: HeadersDictInput = OrderedDict( + [(b"Content-type", [b"text/html"]), (b"Accept", [b"gzip"])] + ) self.assertEqual( headers_dict_to_raw(dct), b"Content-type: text/html\r\nAccept: gzip" ) @@ -70,12 +77,13 @@ def test_headers_dict_to_raw_listtuple(self): ) def test_headers_dict_to_raw_wrong_values(self): - dct = OrderedDict( + dct: HeadersDictInput = OrderedDict( [ (b"Content-type", 0), ] ) self.assertEqual(headers_dict_to_raw(dct), b"") + 
dct = OrderedDict([(b"Content-type", 1), (b"Accept", [b"gzip"])]) self.assertEqual(headers_dict_to_raw(dct), b"Accept: gzip") diff --git a/tests/test_url.py b/tests/test_url.py index edd816c6..fe9ee999 100644 --- a/tests/test_url.py +++ b/tests/test_url.py @@ -508,7 +508,7 @@ def test_add_or_replace_parameters(self): def test_add_or_replace_parameters_does_not_change_input_param(self): url = "http://domain/test?arg=original" input_param = {"arg": "value"} - new_url = add_or_replace_parameters(url, input_param) # noqa + add_or_replace_parameters(url, input_param) # noqa self.assertEqual(input_param, {"arg": "value"}) def test_url_query_cleaner(self): @@ -817,15 +817,18 @@ def test_non_ascii_percent_encoding_in_paths(self): self.assertEqual( canonicalize_url("http://www.example.com/a do?a=1"), "http://www.example.com/a%20do?a=1", - ), + ) + self.assertEqual( canonicalize_url("http://www.example.com/a %20do?a=1"), "http://www.example.com/a%20%20do?a=1", - ), + ) + self.assertEqual( canonicalize_url("http://www.example.com/a do£.html?a=1"), "http://www.example.com/a%20do%C2%A3.html?a=1", ) + self.assertEqual( canonicalize_url(b"http://www.example.com/a do\xc2\xa3.html?a=1"), "http://www.example.com/a%20do%C2%A3.html?a=1", diff --git a/tests/test_util.py b/tests/test_util.py index 7243d175..088147c0 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -20,7 +20,7 @@ def test_deprecation(self): class ToBytesTestCase(TestCase): def test_type_error(self): with raises(TypeError): - to_bytes(True) + to_bytes(True) # type: ignore class ToNativeStrTestCase(TestCase): @@ -32,7 +32,7 @@ def test_deprecation(self): class ToUnicodeTestCase(TestCase): def test_type_error(self): with raises(TypeError): - to_unicode(True) + to_unicode(True) # type: ignore class UnicodeToStrTestCase(TestCase): diff --git a/tox.ini b/tox.ini index 29e9145f..4e8e4767 100644 --- a/tox.ini +++ b/tox.ini @@ -22,6 +22,15 @@ deps = commands = bandit -r -c .bandit.yml {posargs:w3lib} +[testenv:typing] +basepython = python3 +deps = + # mypy would error if pytest (or its stubs) is not found + pytest + mypy==0.910 +commands = + mypy --show-error-codes {posargs: w3lib tests} + [testenv:flake8] basepython = python3 deps = diff --git a/w3lib/_types.py b/w3lib/_types.py new file mode 100644 index 00000000..84499a6c --- /dev/null +++ b/w3lib/_types.py @@ -0,0 +1,5 @@ +from typing import Union + +# the base class UnicodeError doesn't have attributes like start / end +AnyUnicodeError = Union[UnicodeEncodeError, UnicodeDecodeError] +StrOrBytes = Union[str, bytes] diff --git a/w3lib/encoding.py b/w3lib/encoding.py index 1a231155..32252105 100644 --- a/w3lib/encoding.py +++ b/w3lib/encoding.py @@ -3,11 +3,14 @@ """ import re, codecs, encodings from sys import version_info +from typing import Callable, Match, Optional, Tuple, Union, cast +from w3lib._types import AnyUnicodeError, StrOrBytes +from w3lib.util import to_native_str _HEADER_ENCODING_RE = re.compile(r"charset=([\w-]+)", re.I) -def http_content_type_encoding(content_type): +def http_content_type_encoding(content_type: Optional[str]) -> Optional[str]: """Extract the encoding in the content-type header >>> import w3lib.encoding @@ -21,6 +24,8 @@ def http_content_type_encoding(content_type): if match: return resolve_encoding(match.group(1)) + return None + # regexp for parsing HTTP meta tags _TEMPLATE = r"""%s\s*=\s*["']?\s*%s\s*["']?""" @@ -51,7 +56,7 @@ def http_content_type_encoding(content_type): ) -def 
html_body_declared_encoding(html_body_str): +def html_body_declared_encoding(html_body_str: StrOrBytes) -> Optional[str]: '''Return the encoding specified in meta tags in the html body, or ``None`` if no suitable encoding was found @@ -75,6 +80,7 @@ def html_body_declared_encoding(html_body_str): # html5 suggests the first 1024 bytes are sufficient, we allow for more chunk = html_body_str[:4096] + match: Union[Optional[Match[bytes]], Optional[Match[str]]] if isinstance(chunk, bytes): match = _BODY_ENCODING_BYTES_RE.search(chunk) else: @@ -87,7 +93,9 @@ def html_body_declared_encoding(html_body_str): or match.group("xmlcharset") ) if encoding: - return resolve_encoding(encoding) + return resolve_encoding(to_native_str(encoding)) + + return None # Default encoding translation @@ -118,7 +126,7 @@ def html_body_declared_encoding(html_body_str): } -def _c18n_encoding(encoding): +def _c18n_encoding(encoding: str) -> str: """Canonicalize an encoding name This performs normalization and translates aliases using python's @@ -128,7 +136,7 @@ def _c18n_encoding(encoding): return encodings.aliases.aliases.get(normed, normed) -def resolve_encoding(encoding_alias): +def resolve_encoding(encoding_alias: str) -> Optional[str]: """Return the encoding that `encoding_alias` maps to, or ``None`` if the encoding cannot be interpreted @@ -158,7 +166,7 @@ def resolve_encoding(encoding_alias): _FIRST_CHARS = set(c[0] for (c, _) in _BOM_TABLE) -def read_bom(data): +def read_bom(data: bytes) -> Union[Tuple[None, None], Tuple[str, bytes]]: r"""Read the byte order mark in the text, if present, and return the encoding represented by the BOM and the BOM. @@ -189,10 +197,12 @@ def read_bom(data): # Python decoder doesn't follow unicode standard when handling # bad utf-8 encoded strings. 
see http://bugs.python.org/issue8271 -codecs.register_error("w3lib_replace", lambda exc: ("\ufffd", exc.end)) +codecs.register_error( + "w3lib_replace", lambda exc: ("\ufffd", cast(AnyUnicodeError, exc).end) + ) -def to_unicode(data_str, encoding): +def to_unicode(data_str: bytes, encoding: str) -> str: """Convert a str object to unicode using the encoding given Characters that cannot be converted will be converted to ``\\ufffd`` (the @@ -204,8 +214,11 @@ def html_to_unicode( - content_type_header, html_body_str, default_encoding="utf8", auto_detect_fun=None -): + content_type_header: Optional[str], + html_body_str: bytes, + default_encoding: str = "utf8", + auto_detect_fun: Optional[Callable[[bytes], str]] = None, +) -> Tuple[str, str]: r'''Convert raw html bytes to unicode This attempts to make a reasonable guess at the content encoding of the @@ -273,17 +286,20 @@ if enc is not None: # remove BOM if it agrees with the encoding if enc == bom_enc: + bom = cast(bytes, bom) html_body_str = html_body_str[len(bom) :] elif enc == "utf-16" or enc == "utf-32": # read endianness from BOM, or default to big endian # tools.ietf.org/html/rfc2781 section 4.3 if bom_enc is not None and bom_enc.startswith(enc): enc = bom_enc + bom = cast(bytes, bom) html_body_str = html_body_str[len(bom) :] else: enc += "-be" return enc, to_unicode(html_body_str, enc) if bom_enc is not None: + bom = cast(bytes, bom) return bom_enc, to_unicode(html_body_str[len(bom) :], bom_enc) enc = html_body_declared_encoding(html_body_str) if enc is None and (auto_detect_fun is not None): diff --git a/w3lib/html.py b/w3lib/html.py index 2bea60c9..634d90f5 100644 --- a/w3lib/html.py +++ b/w3lib/html.py @@ -4,10 +4,12 @@ import re from html.entities import name2codepoint +from typing import Iterable, Match, AnyStr, Optional, Pattern, Tuple, Union from urllib.parse import urljoin from w3lib.util import to_unicode from w3lib.url import safe_url_string +from w3lib._types import StrOrBytes _ent_re = re.compile( r"&((?P<named>[a-z\d]+)|#(?P<dec>\d+)|#x(?P<hex>[a-f\d]+))(?P<semicolon>;?)", @@ -26,7 +28,12 @@ HTML5_WHITESPACE = " \t\n\r\x0c" -def replace_entities(text, keep=(), remove_illegal=True, encoding="utf-8"): +def replace_entities( + text: AnyStr, + keep: Iterable[str] = (), + remove_illegal: bool = True, + encoding: str = "utf-8", +) -> str: """Remove entities from the given `text` by converting them to their corresponding unicode character. @@ -54,8 +61,9 @@ """ - def convert_entity(m): + def convert_entity(m: Match) -> str: groups = m.groupdict() + number = None if groups.get("dec"): number = int(groups["dec"], 10) elif groups.get("hex"): @@ -86,11 +94,11 @@ def convert_entity(m): return _ent_re.sub(convert_entity, to_unicode(text, encoding)) -def has_entities(text, encoding=None): +def has_entities(text: AnyStr, encoding: Optional[str] = None) -> bool: return bool(_ent_re.search(to_unicode(text, encoding))) -def replace_tags(text, token="", encoding=None): +def replace_tags(text: AnyStr, token: str = "", encoding: Optional[str] = None) -> str: """Replace all markup tags found in the given `text` by the given token. By default `token` is an empty string so it just removes all tags. 
@@ -116,7 +124,7 @@ def replace_tags(text, token="", encoding=None): _REMOVECOMMENTS_RE = re.compile("<!--.*?(?:-->|$)", re.DOTALL) -def remove_comments(text, encoding=None): +def remove_comments(text: AnyStr, encoding: Optional[str] = None) -> str: """Remove HTML Comments. >>> import w3lib.html @@ -126,11 +134,16 @@ """ - text = to_unicode(text, encoding) - return _REMOVECOMMENTS_RE.sub("", text) + utext = to_unicode(text, encoding) + return _REMOVECOMMENTS_RE.sub("", utext) -def remove_tags(text, which_ones=(), keep=(), encoding=None): +def remove_tags( + text: AnyStr, + which_ones: Iterable[str] = (), + keep: Iterable[str] = (), + encoding: Optional[str] = None, +) -> str: """Remove HTML Tags only. `which_ones` and `keep` are both tuples, there are four cases: @@ -180,14 +193,14 @@ def remove_tags(text, which_ones=(), keep=(), encoding=None): which_ones = {tag.lower() for tag in which_ones} keep = {tag.lower() for tag in keep} - def will_remove(tag): + def will_remove(tag: str) -> bool: tag = tag.lower() if which_ones: return tag in which_ones else: return tag not in keep - def remove_tag(m): + def remove_tag(m: Match) -> str: tag = m.group(1) return "" if will_remove(tag) else m.group(0) @@ -197,7 +210,9 @@ def remove_tag(m): return retags.sub(remove_tag, to_unicode(text, encoding)) -def remove_tags_with_content(text, which_ones=(), encoding=None): +def remove_tags_with_content( + text: AnyStr, which_ones: Iterable[str] = (), encoding: Optional[str] = None +) -> str: """Remove tags and their content. `which_ones` is a tuple of which tags to remove including their content. @@ -211,19 +226,22 @@ """ - text = to_unicode(text, encoding) + utext = to_unicode(text, encoding) if which_ones: tags = "|".join( [r"<%s\b.*?</%s>|<%s\s*/>" % (tag, tag, tag) for tag in which_ones] ) retags = re.compile(tags, re.DOTALL | re.IGNORECASE) - text = retags.sub("", text) - return text + utext = retags.sub("", utext) + return utext def replace_escape_chars( - text, which_ones=("\n", "\t", "\r"), replace_by="", encoding=None -): + text: AnyStr, + which_ones: Iterable[str] = ("\n", "\t", "\r"), + replace_by: StrOrBytes = "", + encoding: Optional[str] = None, +) -> str: """Remove escape characters. `which_ones` is a tuple of which escape characters we want to remove. 
@@ -234,13 +252,18 @@ def replace_escape_chars( """ - text = to_unicode(text, encoding) + utext = to_unicode(text, encoding) for ec in which_ones: - text = text.replace(ec, to_unicode(replace_by, encoding)) - return text + utext = utext.replace(ec, to_unicode(replace_by, encoding)) + return utext -def unquote_markup(text, keep=(), remove_illegal=True, encoding=None): +def unquote_markup( + text: AnyStr, + keep: Iterable[str] = (), + remove_illegal: bool = True, + encoding: Optional[str] = None, +) -> str: """ This function receives markup as a text (always a unicode string or a UTF-8 encoded string) and does the following: @@ -252,7 +275,7 @@ def unquote_markup(text, keep=(), remove_illegal=True, encoding=None): """ - def _get_fragments(txt, pattern): + def _get_fragments(txt: str, pattern: Pattern) -> Iterable[Union[str, Match]]: offset = 0 for match in pattern.finditer(txt): match_s, match_e = match.span(1) @@ -261,9 +284,9 @@ def _get_fragments(txt, pattern): offset = match_e yield txt[offset:] - text = to_unicode(text, encoding) + utext = to_unicode(text, encoding) ret_text = "" - for fragment in _get_fragments(text, _cdata_re): + for fragment in _get_fragments(utext, _cdata_re): if isinstance(fragment, str): # it's not a CDATA (so we try to remove its entities) ret_text += replace_entities( @@ -275,7 +298,9 @@ def _get_fragments(txt, pattern): return ret_text -def get_base_url(text, baseurl="", encoding="utf-8"): +def get_base_url( + text: AnyStr, baseurl: StrOrBytes = "", encoding: str = "utf-8" +) -> str: """Return the base url if declared in the given HTML `text`, relative to the given base url. @@ -283,8 +308,8 @@ def get_base_url(text, baseurl="", encoding="utf-8"): """ - text = to_unicode(text, encoding) - m = _baseurl_re.search(text) + utext = to_unicode(text, encoding) + m = _baseurl_re.search(utext) if m: return urljoin( safe_url_string(baseurl), safe_url_string(m.group(1), encoding=encoding) @@ -294,8 +319,11 @@ def get_base_url(text, baseurl="", encoding="utf-8"): def get_meta_refresh( - text, baseurl="", encoding="utf-8", ignore_tags=("script", "noscript") -): + text: AnyStr, + baseurl: str = "", + encoding: str = "utf-8", + ignore_tags: Iterable[str] = ("script", "noscript"), +) -> Tuple[Optional[float], Optional[str]]: """Return the http-equiv parameter of the HTML meta element from the given HTML text and return a tuple ``(interval, url)`` where interval is an integer containing the delay in seconds (or zero if not present) and url is a @@ -306,13 +334,13 @@ def get_meta_refresh( """ try: - text = to_unicode(text, encoding) + utext = to_unicode(text, encoding) except UnicodeDecodeError: print(text) raise - text = remove_tags_with_content(text, ignore_tags) - text = remove_comments(replace_entities(text)) - m = _meta_refresh_re.search(text) + utext = remove_tags_with_content(utext, ignore_tags) + utext = remove_comments(replace_entities(utext)) + m = _meta_refresh_re.search(utext) if m: interval = float(m.group("int")) url = safe_url_string(m.group("url").strip(" \"'"), encoding) @@ -322,7 +350,7 @@ def get_meta_refresh( return None, None -def strip_html5_whitespace(text): +def strip_html5_whitespace(text: str) -> str: r""" Strip all leading and trailing space characters (as defined in https://www.w3.org/TR/html5/infrastructure.html#space-character). 
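A minimal sketch of caller code, separate from the patch itself, illustrating what the `AnyStr` annotations in w3lib/html.py buy: mypy can now check both the argument type and the always-`str` return type. The variable names below are illustrative only; `replace_entities` and `get_meta_refresh` are the real w3lib.html functions as annotated above.

    from typing import Optional

    from w3lib.html import get_meta_refresh, replace_entities

    # `text: AnyStr` accepts str or bytes; the return type is always str.
    title: str = replace_entities(b"x &le; y &amp; z", encoding="utf-8")

    # The Tuple[Optional[float], Optional[str]] return type forces callers
    # to handle the (None, None) "no meta refresh found" case.
    interval: Optional[float]
    url: Optional[str]
    interval, url = get_meta_refresh(
        '<meta http-equiv="refresh" content="5; url=http://example.com/">'
    )
    if url is not None:
        print(interval, url)
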
diff --git a/w3lib/http.py b/w3lib/http.py index f3793922..4ea31fad 100644 --- a/w3lib/http.py +++ b/w3lib/http.py @@ -1,7 +1,12 @@ from base64 import urlsafe_b64encode +from typing import Any, List, MutableMapping, Optional, AnyStr, Sequence, Union, Mapping +from w3lib.util import to_bytes, to_native_str +HeadersDictInput = Mapping[bytes, Union[Any, Sequence]] +HeadersDictOutput = MutableMapping[bytes, List[bytes]] -def headers_raw_to_dict(headers_raw): + +def headers_raw_to_dict(headers_raw: Optional[bytes]) -> Optional[HeadersDictOutput]: r""" Convert raw headers (single multi-line bytestring) to a dictionary. @@ -30,7 +35,7 @@ def headers_raw_to_dict(headers_raw): headers = headers_raw.splitlines() headers_tuples = [header.split(b":", 1) for header in headers] - result_dict = {} + result_dict: HeadersDictOutput = {} for header_item in headers_tuples: if not len(header_item) == 2: continue @@ -46,7 +51,7 @@ def headers_raw_to_dict(headers_raw): return result_dict -def headers_dict_to_raw(headers_dict): +def headers_dict_to_raw(headers_dict: Optional[HeadersDictInput]) -> Optional[bytes]: r""" Returns a raw HTTP headers representation of headers @@ -78,7 +83,9 @@ def headers_dict_to_raw(headers_dict): return b"\r\n".join(raw_lines) -def basic_auth_header(username, password, encoding="ISO-8859-1"): +def basic_auth_header( + username: AnyStr, password: AnyStr, encoding: str = "ISO-8859-1" +) -> bytes: """ Return an `Authorization` header field value for `HTTP Basic Access Authentication (RFC 2617)`_ @@ -90,10 +97,8 @@ def basic_auth_header(username, password, encoding="ISO-8859-1"): """ - auth = "%s:%s" % (username, password) - if not isinstance(auth, bytes): - # XXX: RFC 2617 doesn't define encoding, but ISO-8859-1 - # seems to be the most widely used encoding here. See also: - # http://greenbytes.de/tech/webdav/draft-ietf-httpauth-basicauth-enc-latest.html - auth = auth.encode(encoding) - return b"Basic " + urlsafe_b64encode(auth) + auth = "%s:%s" % (to_native_str(username), to_native_str(password)) + # XXX: RFC 2617 doesn't define encoding, but ISO-8859-1 + # seems to be the most widely used encoding here. 
See also: + # http://greenbytes.de/tech/webdav/draft-ietf-httpauth-basicauth-enc-latest.html + return b"Basic " + urlsafe_b64encode(to_bytes(auth, encoding=encoding)) diff --git a/w3lib/url.py b/w3lib/url.py index 9a39c98f..71398516 100644 --- a/w3lib/url.py +++ b/w3lib/url.py @@ -8,9 +8,18 @@ import posixpath import re import string -from collections import namedtuple +from typing import ( + cast, + Callable, + Dict, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) from urllib.parse import ( - _coerce_args, parse_qs, parse_qsl, ParseResult, @@ -23,13 +32,16 @@ urlunparse, urlunsplit, ) +from urllib.parse import _coerce_args # type: ignore from urllib.request import pathname2url, url2pathname from w3lib.util import to_unicode +from w3lib._types import AnyUnicodeError, StrOrBytes # error handling function for bytes-to-Unicode decoding errors with URLs -def _quote_byte(error): - return (quote(error.object[error.start : error.end]), error.end) +def _quote_byte(error: UnicodeError) -> Tuple[str, int]: + error = cast(AnyUnicodeError, error) + return (to_unicode(quote(error.object[error.start : error.end])), error.end) codecs.register_error("percentencode", _quote_byte) @@ -49,7 +61,12 @@ def _quote_byte(error): ) # see https://infra.spec.whatwg.org/#ascii-tab-or-newline -def safe_url_string(url, encoding="utf8", path_encoding="utf8", quote_path=True): +def safe_url_string( + url: StrOrBytes, + encoding: str = "utf8", + path_encoding: str = "utf8", + quote_path: bool = True, +) -> str: """Convert the given URL into a legal URL by escaping unsafe characters according to RFC-3986. Also, ASCII tabs and newlines are removed as per https://url.spec.whatwg.org/#url-parsing. @@ -81,9 +98,11 @@ def safe_url_string(url, encoding="utf8", path_encoding="utf8", quote_path=True) # IDNA encoding can fail for too long labels (>63 characters) # or missing labels (e.g. http://.example.com) try: - netloc = parts.netloc.encode("idna").decode() + netloc_bytes = parts.netloc.encode("idna") except UnicodeError: netloc = parts.netloc + else: + netloc = netloc_bytes.decode() # default encoding for path component SHOULD be UTF-8 if quote_path: @@ -105,7 +124,9 @@ def safe_url_string(url, encoding="utf8", path_encoding="utf8", quote_path=True) _parent_dirs = re.compile(r"/?(\.\./)+") -def safe_download_url(url, encoding="utf8", path_encoding="utf8"): +def safe_download_url( + url: StrOrBytes, encoding: str = "utf8", path_encoding: str = "utf8" +) -> str: """Make a url for download. This will call safe_url_string and then strip the fragment, if one exists. The path will be normalised. 
@@ -124,11 +145,16 @@ def safe_download_url(url, encoding="utf8", path_encoding="utf8"): return urlunsplit((scheme, netloc, path, query, "")) -def is_url(text): +def is_url(text: str) -> bool: return text.partition("://")[0] in ("file", "http", "https") -def url_query_parameter(url, parameter, default=None, keep_blank_values=0): +def url_query_parameter( + url: StrOrBytes, + parameter: str, + default: Optional[str] = None, + keep_blank_values: Union[bool, int] = 0, +) -> Optional[str]: """Return the value of a url parameter, given the url and parameter name General case: @@ -157,19 +183,24 @@ """ - queryparams = parse_qs(urlsplit(str(url))[3], keep_blank_values=keep_blank_values) - return queryparams.get(parameter, [default])[0] + queryparams = parse_qs( + urlsplit(str(url))[3], keep_blank_values=bool(keep_blank_values) + ) + if parameter in queryparams: + return queryparams[parameter][0] + else: + return default def url_query_cleaner( - url, - parameterlist=(), - sep="&", - kvsep="=", - remove=False, - unique=True, - keep_fragments=False, -): + url: StrOrBytes, + parameterlist: Union[StrOrBytes, Sequence[StrOrBytes]] = (), + sep: str = "&", + kvsep: str = "=", + remove: bool = False, + unique: bool = True, + keep_fragments: bool = False, +) -> str: """Clean URL arguments leaving only those passed in the parameterlist keeping order >>> import w3lib.url @@ -204,6 +235,8 @@ if isinstance(parameterlist, (str, bytes)): parameterlist = [parameterlist] url, fragment = urldefrag(url) + url = cast(str, url) + fragment = cast(str, fragment) base, _, query = url.partition("?") seen = set() querylist = [] @@ -223,10 +256,10 @@ url = "?".join([base, sep.join(querylist)]) if querylist else base if keep_fragments: url += "#" + fragment - return url + return cast(str, url) -def _add_or_replace_parameters(url, params): +def _add_or_replace_parameters(url: str, params: Dict[str, str]) -> str: parsed = urlsplit(url) current_args = parse_qsl(parsed.query, keep_blank_values=True) @@ -248,7 +281,7 @@ return urlunsplit(parsed._replace(query=query)) -def add_or_replace_parameter(url, name, new_value): +def add_or_replace_parameter(url: str, name: str, new_value: str) -> str: """Add or replace a parameter in a given url >>> import w3lib.url @@ -264,7 +297,7 @@ return _add_or_replace_parameters(url, {name: new_value}) -def add_or_replace_parameters(url, new_parameters): +def add_or_replace_parameters(url: str, new_parameters: Dict[str, str]) -> str: """Add or replace parameters in a given url >>> import w3lib.url @@ -279,7 +312,7 @@ return _add_or_replace_parameters(url, new_parameters) -def path_to_file_uri(path): +def path_to_file_uri(path: str) -> str: """Convert local filesystem path to legal File URIs as described in: http://en.wikipedia.org/wiki/File_URI_scheme """ @@ -289,7 +322,7 @@ return "file:///%s" % x.lstrip("/") -def file_uri_to_path(uri): +def file_uri_to_path(uri: str) -> str: """Convert File URI to local filesystem path according to: http://en.wikipedia.org/wiki/File_URI_scheme """ @@ -297,7 +330,7 @@ return url2pathname(uri_path) -def any_to_uri(uri_or_path): +def any_to_uri(uri_or_path: str) -> str: """If given a path name, return its File URI, otherwise return it unmodified 
""" @@ -342,19 +375,20 @@ def any_to_uri(uri_or_path): ).encode() ) -_ParseDataURIResult = namedtuple( - "ParseDataURIResult", "media_type media_type_parameters data" -) - -def parse_data_uri(uri): - """ +class ParseDataURIResult(NamedTuple): + """Named tuple returned by :func:`parse_data_uri`.""" - Parse a data: URI, returning a 3-tuple of media type, dictionary of media - type parameters, and data. + #: MIME type type and subtype, separated by / (e.g. ``"text/plain"``). + media_type: str + #: MIME type parameters (e.g. ``{"charset": "US-ASCII"}``). + media_type_parameters: Dict[str, str] + #: Data, decoded if it was encoded in base64 format. + data: bytes - """ +def parse_data_uri(uri: StrOrBytes) -> ParseDataURIResult: + """Parse a data: URI into :class:`ParseDataURIResult`.""" if not isinstance(uri, bytes): uri = safe_url_string(uri).encode("ascii") @@ -389,7 +423,7 @@ def parse_data_uri(uri): if m: attribute, value, value_quoted = m.groups() if value_quoted: - value = re.sub(br"\\(.)", r"\1", value_quoted) + value = re.sub(br"\\(.)", rb"\1", value_quoted) media_type_params[attribute.decode()] = value.decode() uri = uri[m.end() :] else: @@ -404,7 +438,7 @@ def parse_data_uri(uri): raise ValueError("invalid data URI") data = base64.b64decode(data) - return _ParseDataURIResult(media_type, media_type_params, data) + return ParseDataURIResult(media_type, media_type_params, data) __all__ = [ @@ -423,7 +457,9 @@ def parse_data_uri(uri): ] -def _safe_ParseResult(parts, encoding="utf8", path_encoding="utf8"): +def _safe_ParseResult( + parts: ParseResult, encoding: str = "utf8", path_encoding: str = "utf8" +) -> Tuple[str, str, str, str, str, str]: # IDNA encoding can fail for too long labels (>63 characters) # or missing labels (e.g. http://.example.com) try: @@ -441,7 +477,12 @@ def _safe_ParseResult(parts, encoding="utf8", path_encoding="utf8"): ) -def canonicalize_url(url, keep_blank_values=True, keep_fragments=False, encoding=None): +def canonicalize_url( + url: Union[StrOrBytes, ParseResult], + keep_blank_values: bool = True, + keep_fragments: bool = False, + encoding: Optional[str] = None, +) -> str: r"""Canonicalize the given url by applying the following procedures: - sort query arguments, first by key, then by value @@ -479,7 +520,7 @@ def canonicalize_url(url, keep_blank_values=True, keep_fragments=False, encoding scheme, netloc, path, params, query, fragment = _safe_ParseResult( parse_url(url), encoding=encoding or "utf8" ) - except UnicodeEncodeError as e: + except UnicodeEncodeError: scheme, netloc, path, params, query, fragment = _safe_ParseResult( parse_url(url), encoding="utf8" ) @@ -529,7 +570,7 @@ def canonicalize_url(url, keep_blank_values=True, keep_fragments=False, encoding ) -def _unquotepath(path): +def _unquotepath(path: str) -> bytes: for reserved in ("2f", "2F", "3f", "3F"): path = path.replace("%" + reserved, "%25" + reserved.upper()) @@ -541,7 +582,9 @@ def _unquotepath(path): return unquote_to_bytes(path) -def parse_url(url, encoding=None): +def parse_url( + url: Union[StrOrBytes, ParseResult], encoding: Optional[str] = None +) -> ParseResult: """Return urlparsed url from the given argument (which could be an already parsed url) """ @@ -550,7 +593,9 @@ def parse_url(url, encoding=None): return urlparse(to_unicode(url, encoding)) -def parse_qsl_to_bytes(qs, keep_blank_values=False): +def parse_qsl_to_bytes( + qs: str, keep_blank_values: bool = False +) -> List[Tuple[bytes, bytes]]: """Parse a query given as a string argument. 
Data are returned as a list of name, value pairs as bytes. @@ -570,7 +615,8 @@ # (at https://hg.python.org/cpython/rev/c38ac7ab8d9a) # except for the unquote(s, encoding, errors) calls replaced # with unquote_to_bytes(s) - qs, _coerce_result = _coerce_args(qs) + coerce_args = cast(Callable[..., Tuple[str, Callable]], _coerce_args) + qs, _coerce_result = coerce_args(qs) pairs = [s2 for s1 in qs.split("&") for s2 in s1.split(";")] r = [] for name_value in pairs: @@ -584,11 +630,11 @@ else: continue if len(nv[1]) or keep_blank_values: - name = nv[0].replace("+", " ") + name: StrOrBytes = nv[0].replace("+", " ") name = unquote_to_bytes(name) name = _coerce_result(name) - value = nv[1].replace("+", " ") + value: StrOrBytes = nv[1].replace("+", " ") value = unquote_to_bytes(value) value = _coerce_result(value) - r.append((name, value)) + r.append((cast(bytes, name), cast(bytes, value))) return r diff --git a/w3lib/util.py b/w3lib/util.py index db8e16e8..70f4ef52 100644 --- a/w3lib/util.py +++ b/w3lib/util.py @@ -1,7 +1,12 @@ from warnings import warn +from typing import Optional +from w3lib._types import StrOrBytes -def str_to_unicode(text, encoding=None, errors="strict"): + +def str_to_unicode( + text: StrOrBytes, encoding: Optional[str] = None, errors: str = "strict" +) -> str: warn( "The w3lib.utils.str_to_unicode function is deprecated and " "will be removed in a future release.", @@ -15,7 +20,9 @@ def str_to_unicode(text, encoding=None, errors="strict"): return text -def unicode_to_str(text, encoding=None, errors="strict"): +def unicode_to_str( + text: StrOrBytes, encoding: Optional[str] = None, errors: str = "strict" +) -> bytes: warn( "The w3lib.utils.unicode_to_str function is deprecated and " "will be removed in a future release.", @@ -29,7 +36,9 @@ def unicode_to_str(text, encoding=None, errors="strict"): return text -def to_unicode(text, encoding=None, errors="strict"): +def to_unicode( + text: StrOrBytes, encoding: Optional[str] = None, errors: str = "strict" +) -> str: """Return the unicode representation of a bytes object `text`. If `text` is already a unicode object, return it as-is.""" if isinstance(text, str): @@ -43,7 +52,9 @@ return text.decode(encoding, errors) -def to_bytes(text, encoding=None, errors="strict"): +def to_bytes( + text: StrOrBytes, encoding: Optional[str] = None, errors: str = "strict" +) -> bytes: """Return the binary representation of `text`. If `text` is already a bytes object, return it as-is.""" if isinstance(text, bytes): @@ -57,7 +68,9 @@ return text.encode(encoding, errors) -def to_native_str(text, encoding=None, errors="strict"): +def to_native_str( + text: StrOrBytes, encoding: Optional[str] = None, errors: str = "strict" +) -> str: """Return str representation of `text`""" warn( "The w3lib.utils.to_native_str function is deprecated and "
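As a usage note, again a sketch rather than part of the patch: the `StrOrBytes` alias makes the conversion helpers in w3lib/util.py symmetric, and it explains the `# type: ignore` comments added to tests/test_util.py above.

    from w3lib.util import to_bytes, to_unicode

    # Both helpers accept StrOrBytes (Union[str, bytes]) and narrow the
    # union with isinstance checks before decoding/encoding.
    assert to_unicode(b"caf\xc3\xa9", "utf-8") == "caf\u00e9"
    assert to_bytes("caf\u00e9", "utf-8") == b"caf\xc3\xa9"

    # to_bytes(True) still raises TypeError at runtime; with the new
    # annotations mypy also rejects it statically, which is why the tests
    # keep the call but mark it with `# type: ignore`.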