diff --git a/pyproject.toml b/pyproject.toml index be4335d7..76e6956d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,3 +83,9 @@ module = [ disallow_untyped_defs = false check_untyped_defs = false ignore_missing_imports = true + +[[tool.mypy.overrides]] + module = [ + "treq._multipart", + ] + disallow_untyped_defs = false diff --git a/setup.py b/setup.py index 24ba08f6..12a15ca4 100644 --- a/setup.py +++ b/setup.py @@ -31,12 +31,10 @@ python_requires=">=3.7", install_requires=[ "incremental", - "requests >= 2.1.0", "hyperlink >= 21.0.0", "Twisted[tls] >= 22.10.0", # For #11635 "attrs", "typing_extensions >= 3.10.0", - "multipart", ], extras_require={ "dev": [ diff --git a/src/treq/_multipart.py b/src/treq/_multipart.py new file mode 100644 index 00000000..6a9d976d --- /dev/null +++ b/src/treq/_multipart.py @@ -0,0 +1,561 @@ +# -*- coding: utf-8 -*- + +""" +Parser for multipart/form-data +============================== + +This module provides a parser for the multipart/form-data format. It can read +from a file, a socket or a WSGI environment. The parser can be used to replace +cgi.FieldStorage to work around its limitations. + +..note:: + + Copyright (c) 2010, Marcel Hellkamp. + Inspired by the Werkzeug library: http://werkzeug.pocoo.org/ + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. +""" + +import re +from io import BytesIO +from typing import IO +from tempfile import TemporaryFile +from urllib.parse import parse_qs +from wsgiref.headers import Headers +from collections.abc import MutableMapping as DictMixin + + +__author__ = "Marcel Hellkamp" +__version__ = "0.2.5" +__license__ = "MIT" +__all__ = ["MultipartError", "MultipartParser", "MultipartPart", "parse_form_data"] + + +############################################################################## +################################ Helper & Misc ############################### +############################################################################## +# Some of these were copied from bottle: https://bottlepy.org + + +# --------- +# MultiDict +# --------- + + +class MultiDict(DictMixin): + """A dict that remembers old values for each key. + HTTP headers may repeat with differing values, + such as Set-Cookie. We need to remember all + values. + """ + + def __init__(self, *args, **kwargs): + self.dict = dict() + for k, v in dict(*args, **kwargs).items(): + self[k] = v + + def __len__(self): + return len(self.dict) + + def __iter__(self): + return iter(self.dict) + + def __contains__(self, key): + return key in self.dict + + def __delitem__(self, key): + del self.dict[key] + + def keys(self): + return self.dict.keys() + + def __getitem__(self, key): + return self.get(key, KeyError, -1) + + def __setitem__(self, key, value): + self.append(key, value) + + def append(self, key, value): + self.dict.setdefault(key, []).append(value) + + def replace(self, key, value): + self.dict[key] = [value] + + def getall(self, key): + return self.dict.get(key) or [] + + def get(self, key, default=None, index=-1): + if key not in self.dict and default != KeyError: + return [default][index] + + return self.dict[key][index] + + def iterallitems(self): + for key, values in self.dict.items(): + for value in values: + yield key, value + + +def to_bytes(data, enc="utf8"): + if isinstance(data, str): + data = data.encode(enc) + + return data + + +def copy_file(stream, target, maxread=-1, buffer_size=2**16): + """Read from :stream and write to :target until :maxread or EOF.""" + size, read = 0, stream.read + + while True: + to_read = buffer_size if maxread < 0 else min(buffer_size, maxread - size) + part = read(to_read) + + if not part: + return size + + target.write(part) + size += len(part) + + +# ------------- +# Header Parser +# ------------- + + +_special = re.escape('()<>@,;:"\\/[]?={} \t') +_re_special = re.compile(r"[%s]" % _special) +_quoted_string = r'"(?:\\.|[^"])*"' # Quoted string +_value = r"(?:[^%s]+|%s)" % (_special, _quoted_string) # Save or quoted string +_option = r"(?:;|^)\s*([^%s]+)\s*=\s*(%s)" % (_special, _value) +_re_option = re.compile(_option) # key=value part of an Content-Type like header + + +def header_quote(val): + if not _re_special.search(val): + return val + + return '"' + val.replace("\\", "\\\\").replace('"', '\\"') + '"' + + +def header_unquote(val, filename=False): + if val[0] == val[-1] == '"': + val = val[1:-1] + + if val[1:3] == ":\\" or val[:2] == "\\\\": + val = val.split("\\")[-1] # fix ie6 bug: full path --> filename + + return val.replace("\\\\", "\\").replace('\\"', '"') + + return val + + +def parse_options_header(header, options=None): + if ";" not in header: + return header.lower().strip(), {} + + content_type, tail = header.split(";", 1) + options = options or {} + + for match in _re_option.finditer(tail): + key = match.group(1).lower() + value = header_unquote(match.group(2), key == "filename") + options[key] = value + + return content_type, options + + +############################################################################## +################################## Multipart ################################# +############################################################################## + + +class MultipartError(ValueError): + pass + + +class MultipartParser(object): + def __init__( + self, + stream, + boundary, + content_length=-1, + disk_limit=2**30, + mem_limit=2**20, + memfile_limit=2**18, + buffer_size=2**16, + charset="latin1", + ): + """Parse a multipart/form-data byte stream. This object is an iterator + over the parts of the message. + + :param stream: A file-like stream. Must implement ``.read(size)``. + :param boundary: The multipart boundary as a byte string. + :param content_length: The maximum number of bytes to read. + """ + self.stream = stream + self.boundary = boundary + self.content_length = content_length + self.disk_limit = disk_limit + self.memfile_limit = memfile_limit + self.mem_limit = min(mem_limit, self.disk_limit) + self.buffer_size = min(buffer_size, self.mem_limit) + self.charset = charset + + if self.buffer_size - 6 < len(boundary): # "--boundary--\r\n" + raise MultipartError("Boundary does not fit into buffer_size.") + + self._done = [] + self._part_iter = None + + def __iter__(self): + """Iterate over the parts of the multipart message.""" + if not self._part_iter: + self._part_iter = self._iterparse() + + for part in self._done: + yield part + + for part in self._part_iter: + self._done.append(part) + yield part + + def parts(self): + """Returns a list with all parts of the multipart message.""" + return list(self) + + def get(self, name, default=None): + """Return the first part with that name or a default value (None).""" + for part in self: + if name == part.name: + return part + + return default + + def get_all(self, name): + """Return a list of parts with that name.""" + return [p for p in self if p.name == name] + + def _lineiter(self): + """Iterate over a binary file-like object line by line. Each line is + returned as a (line, line_ending) tuple. If the line does not fit + into self.buffer_size, line_ending is empty and the rest of the line + is returned with the next iteration. + """ + read = self.stream.read + maxread, maxbuf = self.content_length, self.buffer_size + buffer = b"" # buffer for the last (partial) line + + while True: + data = read(maxbuf if maxread < 0 else min(maxbuf, maxread)) + maxread -= len(data) + lines = (buffer + data).splitlines(True) + len_first_line = len(lines[0]) + + # be sure that the first line does not become too big + if len_first_line > self.buffer_size: + # at the same time don't split a '\r\n' accidentally + if len_first_line == self.buffer_size + 1 and lines[0].endswith( + b"\r\n" + ): + splitpos = self.buffer_size - 1 + else: + splitpos = self.buffer_size + lines[:1] = [lines[0][:splitpos], lines[0][splitpos:]] + + if data: + buffer = lines[-1] + lines = lines[:-1] + + for line in lines: + if line.endswith(b"\r\n"): + yield line[:-2], b"\r\n" + elif line.endswith(b"\n"): + yield line[:-1], b"\n" + elif line.endswith(b"\r"): + yield line[:-1], b"\r" + else: + yield line, b"" + + if not data: + break + + def _iterparse(self): + lines, line = self._lineiter(), "" + separator = b"--" + to_bytes(self.boundary) + terminator = b"--" + to_bytes(self.boundary) + b"--" + + # Consume first boundary. Ignore any preamble, as required by RFC + # 2046, section 5.1.1. + for line, nl in lines: + if line in (separator, terminator): + break + else: + raise MultipartError("Stream does not contain boundary") + + # Check for empty data + if line == terminator: + for _ in lines: + raise MultipartError("Data after end of stream") + return + + # For each part in stream... + mem_used, disk_used = 0, 0 # Track used resources to prevent DoS + is_tail = False # True if the last line was incomplete (cutted) + + opts = { + "buffer_size": self.buffer_size, + "memfile_limit": self.memfile_limit, + "charset": self.charset, + } + + part = MultipartPart(**opts) + + for line, nl in lines: + if line == terminator and not is_tail: + part.file.seek(0) + yield part + break + + elif line == separator and not is_tail: + if part.is_buffered(): + mem_used += part.size + else: + disk_used += part.size + part.file.seek(0) + + yield part + + part = MultipartPart(**opts) + + else: + is_tail = not nl # The next line continues this one + try: + part.feed(line, nl) + + if part.is_buffered(): + if part.size + mem_used > self.mem_limit: + raise MultipartError("Memory limit reached.") + elif part.size + disk_used > self.disk_limit: + raise MultipartError("Disk limit reached.") + except MultipartError: + part.close() + raise + else: + # If we run off the end of the loop, the current MultipartPart + # will not have been yielded, so it's our responsibility to + # close it. + part.close() + + if line != terminator: + raise MultipartError("Unexpected end of multipart stream.") + + +class MultipartPart(object): + file: IO[bytes] + + def __init__(self, buffer_size=2**16, memfile_limit=2**18, charset="latin1"): + self.headerlist = [] + self.headers = None + self.file = False # type:ignore + self.size = 0 + self._buf = b"" + self.disposition = None + self.name = None + self.filename = None + self.content_type = None + self.charset = charset + self.memfile_limit = memfile_limit + self.buffer_size = buffer_size + + def feed(self, line, nl=""): + if self.file: + return self.write_body(line, nl) + + return self.write_header(line, nl) + + def write_header(self, line, nl): + line = line.decode(self.charset) + + if not nl: + raise MultipartError("Unexpected end of line in header.") + + if not line.strip(): # blank line -> end of header segment + self.finish_header() + elif line[0] in " \t" and self.headerlist: + name, value = self.headerlist.pop() + self.headerlist.append((name, value + line.strip())) + else: + if ":" not in line: + raise MultipartError("Syntax error in header: No colon.") + + name, value = line.split(":", 1) + self.headerlist.append((name.strip(), value.strip())) + + def write_body(self, line, nl): + if not line and not nl: + return # This does not even flush the buffer + + self.size += len(line) + len(self._buf) + self.file.write(self._buf + line) + self._buf = nl + + if self.content_length > 0 and self.size > self.content_length: + raise MultipartError("Size of body exceeds Content-Length header.") + + if self.size > self.memfile_limit and isinstance(self.file, BytesIO): + # TODO: What about non-file uploads that exceed the memfile_limit? + self.file, old = TemporaryFile(mode="w+b"), self.file + old.seek(0) + copy_file(old, self.file, self.size, self.buffer_size) + + def finish_header(self): + self.file = BytesIO() + self.headers = Headers(self.headerlist) + content_disposition = self.headers.get("Content-Disposition", "") + content_type = self.headers.get("Content-Type", "") + + if not content_disposition: + raise MultipartError("Content-Disposition header is missing.") + + self.disposition, self.options = parse_options_header(content_disposition) + self.name = self.options.get("name") + self.filename = self.options.get("filename") + self.content_type, options = parse_options_header(content_type) + self.charset = options.get("charset") or self.charset + self.content_length = int(self.headers.get("Content-Length", "-1")) + + def is_buffered(self): + """Return true if the data is fully buffered in memory.""" + return isinstance(self.file, BytesIO) + + @property + def value(self): + """Data decoded with the specified charset""" + + return self.raw.decode(self.charset) + + @property + def raw(self): + """Data without decoding""" + pos = self.file.tell() + self.file.seek(0) + + try: + val = self.file.read() + except IOError: + raise + finally: + self.file.seek(pos) + + return val + + def save_as(self, path): + with open(path, "wb") as fp: + pos = self.file.tell() + + try: + self.file.seek(0) + size = copy_file(self.file, fp) + finally: + self.file.seek(pos) + + return size + + def close(self): + if self.file: + self.file.close() + self.file = False # type:ignore + + +############################################################################## +#################################### WSGI #################################### +############################################################################## + + +def parse_form_data(environ, charset="utf8", strict=False, **kwargs): + """Parse form data from an environ dict and return a (forms, files) tuple. + Both tuple values are dictionaries with the form-field name as a key + (unicode) and lists as values (multiple values per key are possible). + The forms-dictionary contains form-field values as unicode strings. + The files-dictionary contains :class:`MultipartPart` instances, either + because the form-field was a file-upload or the value is too big to fit + into memory limits. + + :param environ: An WSGI environment dict. + :param charset: The charset to use if unsure. (default: utf8) + :param strict: If True, raise :exc:`MultipartError` on any parsing + errors. These are silently ignored by default. + """ + + forms, files = MultiDict(), MultiDict() + + try: + if environ.get("REQUEST_METHOD", "GET").upper() not in ("POST", "PUT"): + raise MultipartError("Request method other than POST or PUT.") + content_length = int(environ.get("CONTENT_LENGTH", "-1")) + content_type = environ.get("CONTENT_TYPE", "") + + if not content_type: + raise MultipartError("Missing Content-Type header.") + + content_type, options = parse_options_header(content_type) + stream = environ.get("wsgi.input") or BytesIO() + kwargs["charset"] = charset = options.get("charset", charset) + + if content_type == "multipart/form-data": + boundary = options.get("boundary", "") + + if not boundary: + raise MultipartError("No boundary for multipart/form-data.") + + for part in MultipartParser(stream, boundary, content_length, **kwargs): + if part.filename or not part.is_buffered(): + files[part.name] = part + else: # TODO: Big form-fields are in the files dict. really? + forms[part.name] = part.value + + elif content_type in ( + "application/x-www-form-urlencoded", + "application/x-url-encoded", + ): + mem_limit = kwargs.get("mem_limit", 2**20) + if content_length > mem_limit: + raise MultipartError("Request too big. Increase MAXMEM.") + + data = stream.read(mem_limit).decode(charset) + + if stream.read(1): # These is more that does not fit mem_limit + raise MultipartError("Request too big. Increase MAXMEM.") + + data = parse_qs(data, keep_blank_values=True, encoding=charset) + + for key, values in data.items(): + for value in values: + forms[key] = value + else: + raise MultipartError("Unsupported content type.") + + except MultipartError: + if strict: + for part in files.values(): + part.close() + raise + + return forms, files diff --git a/src/treq/_types.py b/src/treq/_types.py index b758f5e9..673e08cd 100644 --- a/src/treq/_types.py +++ b/src/treq/_types.py @@ -1,6 +1,7 @@ # Copyright (c) The treq Authors. # See LICENSE for details. import io +from .cookies import TreqieJar from http.cookiejar import CookieJar from typing import Any, Dict, Iterable, List, Mapping, Tuple, Union @@ -48,6 +49,7 @@ class _ITreqReactor(IReactorTCP, IReactorTime, IReactorPluggableNameResolver): ] _CookiesType = Union[ + TreqieJar, CookieJar, Mapping[str, str], ] diff --git a/src/treq/api.py b/src/treq/api.py index 19e551f9..6730296f 100644 --- a/src/treq/api.py +++ b/src/treq/api.py @@ -1,9 +1,23 @@ from __future__ import absolute_import, division, print_function +from typing import Callable, Concatenate, ParamSpec, TypeVar + from twisted.web.client import Agent, HTTPConnectionPool +from treq._types import _URLType from treq.client import HTTPClient +P = ParamSpec("P") +R = TypeVar("R") + + +def _like( + method: Callable[Concatenate[HTTPClient, _URLType, P], R] +) -> Callable[ + [Callable[Concatenate[_URLType, P], R]], Callable[Concatenate[_URLType, P], R] +]: + return lambda x: x + def head(url, **kwargs): """ @@ -14,6 +28,7 @@ def head(url, **kwargs): return _client(kwargs).head(url, _stacklevel=4, **kwargs) +@_like(HTTPClient.get) def get(url, headers=None, **kwargs): """ Make a ``GET`` request. diff --git a/src/treq/client.py b/src/treq/client.py index 7c172c69..db17a583 100644 --- a/src/treq/client.py +++ b/src/treq/client.py @@ -1,3 +1,5 @@ +# -*- test-case-name: treq.test.test_client -*- +from __future__ import annotations import io import mimetypes import uuid @@ -7,20 +9,21 @@ from typing import ( Any, Callable, + Concatenate, Iterable, Iterator, List, Mapping, Optional, + ParamSpec, Tuple, + TypeVar, Union, ) from urllib.parse import quote_plus from urllib.parse import urlencode as _urlencode from hyperlink import DecodedURL, EncodedURL -from requests.cookies import merge_cookies -from treq.cookies import scoped_cookie from twisted.internet.defer import Deferred from twisted.internet.interfaces import IProtocol from twisted.python.components import proxyForInterface, registerAdapter @@ -50,8 +53,14 @@ _URLType, ) from treq.auth import add_auth +from treq.cookies import scoped_cookie from treq.response import _Response +from .cookies import TreqieJar + +P = ParamSpec("P") +R = TypeVar("R") + class _Nothing: """Type of the sentinel `_NOTHING`""" @@ -67,7 +76,7 @@ def urlencode(query: _ParamsType, doseq: bool) -> bytes: def _scoped_cookiejar_from_dict( url_object: EncodedURL, cookie_dict: Optional[Mapping[str, str]] -) -> CookieJar: +) -> TreqieJar: """ Create a CookieJar from a dictionary whose cookies are all scoped to the given URL's origin. @@ -75,7 +84,7 @@ def _scoped_cookiejar_from_dict( @note: This does not scope the cookies to any particular path, only the host, port, and scheme of the given URL. """ - cookie_jar = CookieJar() + cookie_jar = TreqieJar() if cookie_dict is None: return cookie_jar for k, v in cookie_dict.items(): @@ -83,6 +92,12 @@ def _scoped_cookiejar_from_dict( return cookie_jar +def _merge_cookies(left: TreqieJar, right: CookieJar) -> TreqieJar: + for cookie in right: + left.set_cookie(cookie) + return left + + class _BodyBufferingProtocol(proxyForInterface(IProtocol)): # type: ignore def __init__(self, original, buffer, finished): self.original = original @@ -130,26 +145,29 @@ def deliverBody(self, protocol): self._waiters.append(protocol) +P2 = ParamSpec("P2") + + +def _like(c: Callable[Concatenate[HTTPClient, str, _URLType, P], R]) -> Callable[ + [Callable[Concatenate[HTTPClient, _URLType, P], R]], + Callable[Concatenate[HTTPClient, _URLType, P], R], +]: + return lambda x: x + + class HTTPClient: def __init__( self, agent: IAgent, - cookiejar: Optional[CookieJar] = None, + cookiejar: Optional[TreqieJar] = None, data_to_body_producer: Callable[[Any], IBodyProducer] = IBodyProducer, ) -> None: self._agent = agent if cookiejar is None: - cookiejar = CookieJar() + cookiejar = TreqieJar() self._cookiejar = cookiejar self._data_to_body_producer = data_to_body_producer - def get(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]": - """ - See :func:`treq.get()`. - """ - kwargs.setdefault("_stacklevel", 3) - return self.request("GET", url, **kwargs) - def put( self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any ) -> "Deferred[_Response]": @@ -246,7 +264,7 @@ def request( if not isinstance(cookies, CookieJar): cookies = _scoped_cookiejar_from_dict(parsed_url, cookies) - merge_cookies(self._cookiejar, cookies) + _merge_cookies(self._cookiejar, cookies) wrapped_agent: IAgent = CookieAgent(self._agent, self._cookiejar) if allow_redirects: @@ -283,6 +301,16 @@ def gotResult(result): return d.addCallback(_Response, self._cookiejar) + @_like(request) + def get(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]": + """ + See :func:`treq.get()`. + """ + kwargs.setdefault("_stacklevel", 3) + return self.request("GET", url, **kwargs) + + reveal_type(get) + def _request_headers( self, headers: Optional[_HeadersType], stacklevel: int ) -> Headers: diff --git a/src/treq/content.py b/src/treq/content.py index e3f4aaad..60a0694d 100644 --- a/src/treq/content.py +++ b/src/treq/content.py @@ -1,7 +1,11 @@ +""" +Utilities related to retrieving the contents of the response-body. +""" + import json from typing import Any, Callable, FrozenSet, List, Optional, cast -import multipart # type: ignore +from ._multipart import parse_options_header from twisted.internet.defer import Deferred, succeed from twisted.internet.protocol import Protocol, connectionDone from twisted.python.failure import Failure @@ -29,7 +33,7 @@ def _encoding_from_headers(headers: Headers) -> Optional[str]: # This seems to be the choice browsers make when encountering multiple # content-type headers. - media_type, params = multipart.parse_options_header(content_types[-1]) + media_type, params = parse_options_header(content_types[-1]) charset = params.get("charset") if charset: diff --git a/src/treq/cookies.py b/src/treq/cookies.py index 20abac5f..442e04cd 100644 --- a/src/treq/cookies.py +++ b/src/treq/cookies.py @@ -1,3 +1,4 @@ +# -*- test-case-name: treq.test.test_integration -*- """ Convenience helpers for :mod:`http.cookiejar` """ @@ -8,6 +9,14 @@ from hyperlink import EncodedURL +class TreqieJar(CookieJar): + def __getitem__(self, name: str) -> str: + for cookie in self: + if cookie.name == name and cookie.value is not None: + return cookie.value + raise KeyError(name) + + def scoped_cookie(origin: Union[str, EncodedURL], name: str, value: str) -> Cookie: """ Create a cookie scoped to a given URL's origin. diff --git a/src/treq/response.py b/src/treq/response.py index df6010da..13d5ad64 100644 --- a/src/treq/response.py +++ b/src/treq/response.py @@ -1,5 +1,4 @@ from typing import Any, Callable, List -from requests.cookies import cookiejar_from_dict from http.cookiejar import CookieJar from twisted.internet.defer import Deferred from twisted.python import reflect @@ -16,7 +15,7 @@ class _Response(proxyForInterface(IResponse)): # type: ignore """ original: IResponse - _cookiejar: CookieJar + _cookiejar: TreqieJar def __init__(self, original: IResponse, cookiejar: CookieJar): self.original = original @@ -107,11 +106,7 @@ def cookies(self) -> CookieJar: """ Get a copy of this response's cookies. """ - # NB: This actually returns a RequestsCookieJar, but we type it as a - # regular CookieJar because we want to ditch requests as a dependency. - # Full deprecation deprecation will require a subclass or wrapper that - # warns about the RequestCookieJar extensions. - jar: CookieJar = cookiejar_from_dict({}) + jar = CookieJar() for cookie in self._cookiejar: jar.set_cookie(cookie) diff --git a/src/treq/test/test_cookies.py b/src/treq/test/test_cookies.py index 41946d52..31707544 100644 --- a/src/treq/test/test_cookies.py +++ b/src/treq/test/test_cookies.py @@ -1,18 +1,19 @@ -from http.cookiejar import CookieJar, Cookie +from http.cookiejar import Cookie, CookieJar import attrs -from twisted.internet.testing import StringTransport +from treq._agentspy import RequestRecord, agent_spy +from treq.client import HTTPClient +from treq.cookies import scoped_cookie, search from twisted.internet.interfaces import IProtocol -from twisted.trial.unittest import SynchronousTestCase +from twisted.internet.testing import StringTransport from twisted.python.failure import Failure +from twisted.trial.unittest import SynchronousTestCase from twisted.web.client import ResponseDone from twisted.web.http_headers import Headers from twisted.web.iweb import IClientRequest, IResponse from zope.interface import implementer -from treq._agentspy import agent_spy, RequestRecord -from treq.client import HTTPClient -from treq.cookies import scoped_cookie, search +from ..cookies import TreqieJar @implementer(IClientRequest) @@ -135,7 +136,7 @@ class HTTPClientCookieTests(SynchronousTestCase): def setUp(self) -> None: self.agent, self.requests = agent_spy() - self.cookiejar = CookieJar() + self.cookiejar = TreqieJar() self.client = HTTPClient(self.agent, self.cookiejar) def test_cookies_in_jars(self) -> None: diff --git a/src/treq/test/test_multipart.py b/src/treq/test/test_multipart.py index 999f1afd..7b123006 100644 --- a/src/treq/test/test_multipart.py +++ b/src/treq/test/test_multipart.py @@ -5,7 +5,7 @@ from io import BytesIO -from multipart import MultipartParser # type: ignore +from .._multipart import MultipartParser from twisted.trial import unittest from zope.interface.verify import verifyObject diff --git a/src/treq/test/test_treq_integration.py b/src/treq/test/test_treq_integration.py index 1518fb46..0bef51ce 100644 --- a/src/treq/test/test_treq_integration.py +++ b/src/treq/test/test_treq_integration.py @@ -1,24 +1,25 @@ +from __future__ import annotations from io import BytesIO +from typing import Callable, Concatenate, ParamSpec, TypeVar -from twisted.python.url import URL - -from twisted.trial.unittest import TestCase +import treq +from treq.test.util import DEBUG, skip_on_windows_because_of_199 +from twisted.internet import reactor from twisted.internet.defer import CancelledError, inlineCallbacks +from twisted.internet.ssl import Certificate, trustRootFromCertificates from twisted.internet.task import deferLater -from twisted.internet import reactor from twisted.internet.tcp import Client -from twisted.internet.ssl import Certificate, trustRootFromCertificates - -from twisted.web.client import (Agent, BrowserLikePolicyForHTTPS, - HTTPConnectionPool, ResponseFailed) - -from treq.test.util import DEBUG, skip_on_windows_because_of_199 +from twisted.python.url import URL +from twisted.trial.unittest import TestCase +from twisted.web.client import ( + Agent, + BrowserLikePolicyForHTTPS, + HTTPConnectionPool, + ResponseFailed, +) from .local_httpbin.parent import _HTTPBinProcess -import treq - - skip = skip_on_windows_because_of_199() @@ -26,25 +27,30 @@ def print_response(response): if DEBUG: print() - print('---') + print("---") print(response.code) print(response.headers) print(response.request.headers) text = yield treq.text_content(response) print(text) - print('---') + print("---") -def with_baseurl(method): - def _request(self, url, *args, **kwargs): - return method(self.baseurl + url, - *args, - agent=self.agent, - pool=self.pool, - **kwargs) +P = ParamSpec("P") +R = TypeVar("R") - return _request +def with_baseurl( + method: Callable[Concatenate[str, P], R] +) -> Callable[Concatenate[TreqIntegrationTests, str, P], R]: + def _request( + self: TreqIntegrationTests, url: str, *args: P.args, **kwargs: P.kwargs + ) -> R: + return method( + self.baseurl + url, *args, agent=self.agent, pool=self.pool, **kwargs + ) + + return _request class TreqIntegrationTests(TestCase): get = with_baseurl(treq.get) @@ -58,11 +64,10 @@ class TreqIntegrationTests(TestCase): @inlineCallbacks def setUp(self): - description = yield self._httpbin_process.server_description( - reactor) - self.baseurl = URL(scheme=u"http", - host=description.host, - port=description.port).asText() + description = yield self._httpbin_process.server_description(reactor) + self.baseurl = URL( + scheme="http", host=description.host, port=description.port + ).asText() self.agent = Agent(reactor) self.pool = HTTPConnectionPool(reactor, False) @@ -82,85 +87,83 @@ def _check_fds(_): @inlineCallbacks def assert_data(self, response, expected_data): body = yield treq.json_content(response) - self.assertIn('data', body) - self.assertEqual(body['data'], expected_data) + self.assertIn("data", body) + self.assertEqual(body["data"], expected_data) @inlineCallbacks def assert_sent_header(self, response, header, expected_value): body = yield treq.json_content(response) - self.assertIn(header, body['headers']) - self.assertEqual(body['headers'][header], expected_value) + self.assertIn(header, body["headers"]) + self.assertEqual(body["headers"][header], expected_value) @inlineCallbacks def test_get(self): - response = yield self.get('/get') + response = yield self.get("/get") self.assertEqual(response.code, 200) yield print_response(response) @inlineCallbacks def test_get_headers(self): - response = yield self.get('/get', {b'X-Blah': [b'Foo', b'Bar']}) + response = yield self.get("/get", {b"X-Blah": [b"Foo", b"Bar"]}) self.assertEqual(response.code, 200) - yield self.assert_sent_header(response, 'X-Blah', 'Foo,Bar') + yield self.assert_sent_header(response, "X-Blah", "Foo,Bar") yield print_response(response) @inlineCallbacks def test_get_headers_unicode(self): - response = yield self.get('/get', {u'X-Blah': [u'Foo', b'Bar']}) + response = yield self.get("/get", {"X-Blah": ["Foo", b"Bar"]}) self.assertEqual(response.code, 200) - yield self.assert_sent_header(response, 'X-Blah', 'Foo,Bar') + yield self.assert_sent_header(response, "X-Blah", "Foo,Bar") yield print_response(response) @inlineCallbacks def test_get_302_absolute_redirect(self): - response = yield self.get( - '/redirect-to?url={0}/get'.format(self.baseurl)) + response = yield self.get("/redirect-to?url={0}/get".format(self.baseurl)) self.assertEqual(response.code, 200) yield print_response(response) @inlineCallbacks def test_get_302_relative_redirect(self): - response = yield self.get('/relative-redirect/1') + response = yield self.get("/relative-redirect/1") self.assertEqual(response.code, 200) yield print_response(response) @inlineCallbacks def test_get_302_redirect_disallowed(self): - response = yield self.get('/redirect/1', allow_redirects=False) + response = yield self.get("/redirect/1", allow_redirects=False) self.assertEqual(response.code, 302) yield print_response(response) @inlineCallbacks def test_head(self): - response = yield self.head('/get') + response = yield self.head("/get") body = yield treq.content(response) - self.assertEqual(b'', body) + self.assertEqual(b"", body) yield print_response(response) @inlineCallbacks def test_head_302_absolute_redirect(self): - response = yield self.head( - '/redirect-to?url={0}/get'.format(self.baseurl)) + response = yield self.head("/redirect-to?url={0}/get".format(self.baseurl)) self.assertEqual(response.code, 200) yield print_response(response) @inlineCallbacks def test_head_302_relative_redirect(self): - response = yield self.head('/relative-redirect/1') + response = yield self.head("/relative-redirect/1") self.assertEqual(response.code, 200) yield print_response(response) @inlineCallbacks def test_head_302_redirect_disallowed(self): - response = yield self.head('/redirect/1', allow_redirects=False) + response = yield self.head("/redirect/1", allow_redirects=False) self.assertEqual(response.code, 302) yield print_response(response) @inlineCallbacks def test_post(self): - response = yield self.post('/post', b'Hello!') + response = yield self.post("/post", b"Hello!") self.assertEqual(response.code, 200) - yield self.assert_data(response, 'Hello!') + yield self.assert_data(response, "Hello!") yield print_response(response) @inlineCallbacks @@ -174,70 +177,66 @@ def read(*args, **kwargs): return BytesIO.read(*args, **kwargs) response = yield self.post( - '/post', - data={"a": "b"}, - files={"file1": FileLikeObject(b"file")}) + "/post", data={"a": "b"}, files={"file1": FileLikeObject(b"file")} + ) self.assertEqual(response.code, 200) body = yield treq.json_content(response) - self.assertEqual('b', body['form']['a']) - self.assertEqual('file', body['files']['file1']) + self.assertEqual("b", body["form"]["a"]) + self.assertEqual("file", body["files"]["file1"]) yield print_response(response) @inlineCallbacks def test_post_headers(self): response = yield self.post( - '/post', - b'{msg: "Hello!"}', - headers={'Content-Type': ['application/json']} + "/post", b'{msg: "Hello!"}', headers={"Content-Type": ["application/json"]} ) self.assertEqual(response.code, 200) - yield self.assert_sent_header( - response, 'Content-Type', 'application/json') + yield self.assert_sent_header(response, "Content-Type", "application/json") yield self.assert_data(response, '{msg: "Hello!"}') yield print_response(response) @inlineCallbacks def test_put(self): - response = yield self.put('/put', data=b'Hello!') + response = yield self.put("/put", data=b"Hello!") yield print_response(response) @inlineCallbacks def test_patch(self): - response = yield self.patch('/patch', data=b'Hello!') + response = yield self.patch("/patch", data=b"Hello!") self.assertEqual(response.code, 200) - yield self.assert_data(response, 'Hello!') + yield self.assert_data(response, "Hello!") yield print_response(response) @inlineCallbacks def test_delete(self): - response = yield self.delete('/delete') + response = yield self.delete("/delete") self.assertEqual(response.code, 200) yield print_response(response) @inlineCallbacks def test_gzip(self): - response = yield self.get('/gzip') + response = yield self.get("/gzip") self.assertEqual(response.code, 200) yield print_response(response) json = yield treq.json_content(response) - self.assertTrue(json['gzipped']) + self.assertTrue(json["gzipped"]) @inlineCallbacks def test_basic_auth(self): - response = yield self.get('/basic-auth/treq/treq', - auth=('treq', 'treq')) + response = yield self.get("/basic-auth/treq/treq", auth=("treq", "treq")) self.assertEqual(response.code, 200) yield print_response(response) json = yield treq.json_content(response) - self.assertTrue(json['authenticated']) - self.assertEqual(json['user'], 'treq') + self.assertTrue(json["authenticated"]) + self.assertEqual(json["user"], "treq") @inlineCallbacks def test_failed_basic_auth(self): - response = yield self.get('/basic-auth/treq/treq', - auth=('not-treq', 'not-treq')) + response = yield self.get( + "/basic-auth/treq/treq", auth=("not-treq", "not-treq") + ) self.assertEqual(response.code, 401) yield print_response(response) @@ -246,26 +245,25 @@ def test_timeout(self): """ Verify a timeout fires if a request takes too long. """ - yield self.assertFailure(self.get('/delay/2', timeout=1), - CancelledError, - ResponseFailed) + yield self.assertFailure( + self.get("/delay/2", timeout=1), CancelledError, ResponseFailed + ) @inlineCallbacks def test_cookie(self): - response = yield self.get('/cookies', cookies={'hello': 'there'}) + response = yield self.get("/cookies", cookies={"hello": "there"}) self.assertEqual(response.code, 200) yield print_response(response) json = yield treq.json_content(response) - self.assertEqual(json['cookies']['hello'], 'there') + self.assertEqual(json["cookies"]["hello"], "there") - @inlineCallbacks - def test_set_cookie(self): - response = yield self.get('/cookies/set', - allow_redirects=False, - params={'hello': 'there'}) + async def test_set_cookie(self) -> None: + response = await self.get( + "/cookies/set", allow_redirects=False, params={"hello": "there"} + ) # self.assertEqual(response.code, 200) - yield print_response(response) - self.assertEqual(response.cookies()['hello'], 'there') + await print_response(response) + self.assertEqual(response.cookies()["hello"], "there") class HTTPSTreqIntegrationTests(TreqIntegrationTests): @@ -273,11 +271,10 @@ class HTTPSTreqIntegrationTests(TreqIntegrationTests): @inlineCallbacks def setUp(self): - description = yield self._httpbin_process.server_description( - reactor) - self.baseurl = URL(scheme=u"https", - host=description.host, - port=description.port).asText() + description = yield self._httpbin_process.server_description(reactor) + self.baseurl = URL( + scheme="https", host=description.host, port=description.port + ).asText() root = trustRootFromCertificates( [Certificate.loadPEM(description.cacert)], diff --git a/tox.ini b/tox.ini index 1ce2bece..664343da 100644 --- a/tox.ini +++ b/tox.ini @@ -11,9 +11,9 @@ extras = dev deps = coverage - twisted_lowest: Twisted==22.10.0 - twisted_latest: Twisted - twisted_trunk: https://github.com/twisted/twisted/archive/trunk.zip + twisted_lowest: Twisted[tls]==22.10.0 + twisted_latest: Twisted[tls] + twisted_trunk: Twisted[tls]@https://github.com/twisted/twisted/archive/trunk.zip setenv = # Avoid unnecessary network access when creating virtualenvs for speed. VIRTUALENV_NO_DOWNLOAD=1 @@ -30,7 +30,6 @@ basepython = python3.12 deps = mypy==1.0.1 mypy-zope==0.9.1 - types-requests commands = mypy \ --cache-dir="{toxworkdir}/mypy_cache" \ @@ -81,4 +80,4 @@ commands = # This is a minimal Black-compatible config. # See https://black.readthedocs.io/en/stable/compatible_configs.html#flake8 max-line-length = 88 -extend-ignore = E203, W503 +extend-ignore = E203, W503, E266