diff --git a/changelog.d/384.feature.rst b/changelog.d/384.feature.rst new file mode 100644 index 00000000..bf7d4903 --- /dev/null +++ b/changelog.d/384.feature.rst @@ -0,0 +1 @@ +The new :mod:`treq.cookies` module provides helper functions for working with `http.cookiejar.Cookie` and `CookieJar` objects. diff --git a/changelog.d/384.misc.rst b/changelog.d/384.misc.rst deleted file mode 100644 index e69de29b..00000000 diff --git a/docs/api.rst b/docs/api.rst index b01ee010..2be0cb42 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -89,6 +89,13 @@ Authentication .. autoexception:: UnknownAuthConfig +Cookies +------- + +.. module:: treq.cookies + +.. autofunction:: scoped_cookie + Test Helpers ------------ diff --git a/docs/examples/using_cookies.py b/docs/examples/using_cookies.py index 2e107277..4805f890 100644 --- a/docs/examples/using_cookies.py +++ b/docs/examples/using_cookies.py @@ -4,19 +4,16 @@ import treq -def main(reactor, *args): - d = treq.get('https://httpbin.org/cookies/set?hello=world') +async def main(reactor): + resp = await treq.get("https://httpbin.org/cookies/set?hello=world") - def _get_jar(resp): - jar = resp.cookies() + jar = resp.cookies() + [cookie] = treq.cookies.raid(jar, domain="httpbin.org", name="hello") + print("The server set our hello cookie to: {}".format(cookie.value)) - print('The server set our hello cookie to: {}'.format(jar['hello'])) + await treq.get("https://httpbin.org/cookies", cookies=jar).addCallback( + print_response + ) - return treq.get('https://httpbin.org/cookies', cookies=jar) - d.addCallback(_get_jar) - d.addCallback(print_response) - - return d - -react(main, []) +react(main) diff --git a/src/treq/client.py b/src/treq/client.py index 8a6bf634..7c172c69 100644 --- a/src/treq/client.py +++ b/src/treq/client.py @@ -2,7 +2,7 @@ import mimetypes import uuid from collections import abc -from http.cookiejar import Cookie, CookieJar +from http.cookiejar import CookieJar from json import dumps as json_dumps from typing import ( Any, @@ -20,6 +20,7 @@ from hyperlink import DecodedURL, EncodedURL from requests.cookies import merge_cookies +from treq.cookies import scoped_cookie from twisted.internet.defer import Deferred from twisted.internet.interfaces import IProtocol from twisted.python.components import proxyForInterface, registerAdapter @@ -78,39 +79,7 @@ def _scoped_cookiejar_from_dict( if cookie_dict is None: return cookie_jar for k, v in cookie_dict.items(): - secure = url_object.scheme == "https" - port_specified = not ( - (url_object.scheme == "https" and url_object.port == 443) - or (url_object.scheme == "http" and url_object.port == 80) - ) - port = str(url_object.port) if port_specified else None - domain = url_object.host - netscape_domain = domain if "." in domain else domain + ".local" - - cookie_jar.set_cookie( - Cookie( - # Scoping - domain=netscape_domain, - port=port, - secure=secure, - port_specified=port_specified, - # Contents - name=k, - value=v, - # Constant/always-the-same stuff - version=0, - path="/", - expires=None, - discard=False, - comment=None, - comment_url=None, - rfc2109=False, - path_specified=False, - domain_specified=False, - domain_initial_dot=False, - rest={}, - ) - ) + cookie_jar.set_cookie(scoped_cookie(url_object, k, v)) return cookie_jar diff --git a/src/treq/cookies.py b/src/treq/cookies.py new file mode 100644 index 00000000..e9079a2e --- /dev/null +++ b/src/treq/cookies.py @@ -0,0 +1,101 @@ +""" +Convenience helpers for :mod:`http.cookiejar` +""" + +from typing import Union, Iterable, Optional +from http.cookiejar import Cookie, CookieJar + +from hyperlink import EncodedURL + + +def scoped_cookie(origin: Union[str, EncodedURL], name: str, value: str) -> Cookie: + """ + Create a cookie scoped to a given URL's origin. + + You can insert the result directly into a `CookieJar`, like:: + + jar = CookieJar() + jar.set_cookie(scoped_cookie("https://example.tld", "flavor", "chocolate")) + + await treq.get("https://domain.example", cookies=jar) + + :param origin: + A URL that specifies the domain and port number of the cookie. + + If the protocol is HTTP*S* the cookie is marked ``Secure``, meaning + it will not be attached to HTTP requests. Otherwise the cookie will be + attached to both HTTP and HTTPS requests + + :param name: Name of the cookie. + + :param value: Value of the cookie. + + .. note:: + + This does not scope the cookies to any particular path, only the + host, port, and scheme of the given URL. + """ + if isinstance(origin, EncodedURL): + url_object = origin + else: + url_object = EncodedURL.from_text(origin) + + secure = url_object.scheme == "https" + port_specified = not ( + (url_object.scheme == "https" and url_object.port == 443) + or (url_object.scheme == "http" and url_object.port == 80) + ) + port = str(url_object.port) if port_specified else None + domain = url_object.host + netscape_domain = domain if "." in domain else domain + ".local" + return Cookie( + # Scoping + domain=netscape_domain, + port=port, + secure=secure, + port_specified=port_specified, + # Contents + name=name, + value=value, + # Constant/always-the-same stuff + version=0, + path="/", + expires=None, + discard=False, + comment=None, + comment_url=None, + rfc2109=False, + path_specified=False, + domain_specified=False, + domain_initial_dot=False, + rest={}, + ) + + +def raid( + jar: CookieJar, *, domain: str, name: Optional[str] = None +) -> Iterable[Cookie]: + """ + Search the cookie jar for matching cookies. + + This is O(n) on the number of cookies in the jar. + + :param jar: The `CookieJar` (or subclass thereof) to search. + + :param domain: + Domain, as in the URL, to match. ``.local`` is appended to + a bare hostname. Subdomains are not matched (i.e., searching + for ``foo.bar.tld`` won't return a cookie set for ``bar.tld``). + + :param name: Cookie name to match (exactly) + + :param path: URL path to match (exactly) + """ + netscape_domain = domain if "." in domain else domain + ".local" + + for c in jar: + if c.domain != netscape_domain: + continue + if name is not None and c.name != name: + continue + yield c diff --git a/src/treq/test/test_cookies.py b/src/treq/test/test_cookies.py new file mode 100644 index 00000000..16895103 --- /dev/null +++ b/src/treq/test/test_cookies.py @@ -0,0 +1,260 @@ +from http.cookiejar import CookieJar, Cookie + +import attrs +from twisted.internet.testing import StringTransport +from twisted.internet.interfaces import IProtocol +from twisted.trial.unittest import SynchronousTestCase +from twisted.python.failure import Failure +from twisted.web.client import ResponseDone +from twisted.web.http_headers import Headers +from twisted.web.iweb import IClientRequest, IResponse +from zope.interface import implementer + +from treq._agentspy import agent_spy, RequestRecord +from treq.client import HTTPClient +from treq.cookies import scoped_cookie, raid + + +@implementer(IClientRequest) +@attrs.define +class _ClientRequest: + absoluteURI: bytes + headers: Headers + method: bytes + + +@implementer(IResponse) +class QuickResponse: + """A response that immediately delivers the body.""" + + version = (b"HTTP", 1, 1) + code = 200 + phrase = "OK" + previousResponse = None + + def __init__(self, record: RequestRecord, headers: Headers, body: bytes = b"") -> None: + self.request = _ClientRequest(record.uri, record.headers or Headers(), record.method) + self.headers = headers + self.length = len(body) + self._body = body + + def deliverBody(self, protocol: IProtocol) -> None: + t = StringTransport() + protocol.makeConnection(t) + if t.producerState != "producing": + raise NotImplementedError("pausing IPushProducer") + protocol.dataReceived(self._body) + protocol.connectionLost(Failure(ResponseDone())) + + def setPreviousResponse(self, response: IResponse) -> None: + raise NotImplementedError + + + +class ScopedCookieTests(SynchronousTestCase): + """Test `treq.cookies.scoped_cookie()`""" + + def test_http(self) -> None: + """Scoping an HTTP origin produces a non-Secure cookie.""" + c = scoped_cookie("http://foo.bar", "x", "y") + self.assertEqual(c.domain, "foo.bar") + self.assertIsNone(c.port) + self.assertFalse(c.port_specified) + self.assertFalse(c.secure) + + def test_https(self) -> None: + """ + Scoping to an HTTPS origin produces a Secure cookie that + won't be sent to HTTP origins. + """ + c = scoped_cookie("https://foo.bar", "x", "y") + self.assertEqual(c.domain, "foo.bar") + self.assertIsNone(c.port) + self.assertFalse(c.port_specified) + self.assertTrue(c.secure) + + def test_port(self) -> None: + """ + Setting a non-default port produces a cookie with that port. + """ + c = scoped_cookie("https://foo.bar:4433", "x", "y") + self.assertEqual(c.domain, "foo.bar") + self.assertEqual(c.port, "4433") + self.assertTrue(c.port_specified) + self.assertTrue(c.secure) + + def test_hostname(self) -> None: + """ + When the origin has a bare hostname, a ``.local`` suffix is applied + to form the cookie domain. + """ + c = scoped_cookie("http://mynas", "x", "y") + self.assertEqual(c.domain, "mynas.local") + + +class RaidTests(SynchronousTestCase): + """Test `treq.cookies.raid()`""" + + def test_domain(self) -> None: + """`raid()` filters by domain.""" + jar = CookieJar() + jar.set_cookie(scoped_cookie("http://an.example", "http", "a")) + jar.set_cookie(scoped_cookie("https://an.example", "https", "b")) + jar.set_cookie(scoped_cookie("https://f.an.example", "subdomain", "c")) + jar.set_cookie(scoped_cookie("https://f.an.example", "https", "d")) + jar.set_cookie(scoped_cookie("https://host", "v", "n")) + + self.assertEqual( + {(c.name, c.value) for c in raid(jar, domain="an.example")}, + {("http", "a"), ("https", "b")}, + ) + self.assertEqual( + {(c.name, c.value) for c in raid(jar, domain="f.an.example")}, + {("subdomain", "c"), ("https", "d")}, + ) + self.assertEqual( + {(c.name, c.value) for c in raid(jar, domain="host")}, + {("v", "n")}, + ) + + def test_name(self) -> None: + """`raid()` filters by cookie name.""" + jar = CookieJar() + jar.set_cookie(scoped_cookie("https://host", "a", "1")) + jar.set_cookie(scoped_cookie("https://host", "b", "2")) + + self.assertEqual({c.value for c in raid(jar, domain="host", name="a")}, {"1"}) + self.assertEqual({c.value for c in raid(jar, domain="host", name="b")}, {"2"}) + + +class HTTPClientCookieTests(SynchronousTestCase): + """Test how HTTPClient's request methods handle the *cookies* argument.""" + + def setUp(self) -> None: + self.agent, self.requests = agent_spy() + self.cookiejar = CookieJar() + self.client = HTTPClient(self.agent, self.cookiejar) + + def test_cookies_in_jars(self) -> None: + """ + Issuing a request with cookies merges them into the client's cookie jar. + Cookies received in a response are also merged into the client's cookie jar. + """ + self.cookiejar.set_cookie( + Cookie( + domain="twisted.example", + port=None, + secure=True, + port_specified=False, + name="a", + value="b", + version=0, + path="/", + expires=None, + discard=False, + comment=None, + comment_url=None, + rfc2109=False, + path_specified=False, + domain_specified=False, + domain_initial_dot=False, + rest={}, + ) + ) + d = self.client.request("GET", "https://twisted.example", cookies={"b": "c"}) + self.assertNoResult(d) + + [request] = self.requests + assert request.headers is not None + self.assertEqual(request.headers.getRawHeaders("Cookie"), ["a=b; b=c"]) + + request.deferred.callback(QuickResponse(request, Headers({"Set-Cookie": ["a=c"]}))) + + response = self.successResultOf(d) + expected = {"a": "c", "b": "c"} + self.assertEqual({c.name: c.value for c in self.cookiejar}, expected) + self.assertEqual({c.name: c.value for c in response.cookies()}, expected) + + def test_cookies_pass_jar(self) -> None: + """ + Passing the *cookies* argument to `HTTPClient.request()` updates + the client's cookie jar and sends cookies with the request. Upon + receipt of the response the client's cookie jar is updated. + """ + self.cookiejar.set_cookie(scoped_cookie("https://tx.example", "a", "a")) + self.cookiejar.set_cookie(scoped_cookie("http://tx.example", "p", "q")) + self.cookiejar.set_cookie(scoped_cookie("https://rx.example", "b", "b")) + + jar = CookieJar() + jar.set_cookie(scoped_cookie("https://tx.example", "a", "b")) + jar.set_cookie(scoped_cookie("https://rx.example", "a", "c")) + + d = self.client.request("GET", "https://tx.example", cookies=jar) + self.assertNoResult(d) + + self.assertEqual( + {(c.domain, c.name, c.value) for c in self.cookiejar}, + { + ("tx.example", "a", "b"), + ("tx.example", "p", "q"), + ("rx.example", "a", "c"), + ("rx.example", "b", "b"), + }, + ) + + [request] = self.requests + assert request.headers is not None + self.assertEqual(request.headers.getRawHeaders("Cookie"), ["a=b; p=q"]) + + def test_cookies_dict(self) -> None: + """ + Passing a dict for the *cookies* argument to `HTTPClient.request()` + creates cookies that are bound to the + + the client's cookie jar and sends cookies with the request. Upon + receipt of the response the client's cookie jar is updated. + """ + d = self.client.request("GET", "https://twisted.example", cookies={"a": "b"}) + self.assertNoResult(d) + + [cookie] = self.cookiejar + self.assertEqual(cookie.name, "a") + self.assertEqual(cookie.value, "b") + # Attributes inferred from the URL: + self.assertEqual(cookie.domain, "twisted.example") + self.assertFalse(cookie.port_specified) + self.assertTrue(cookie.secure) + + [request] = self.requests + assert request.headers is not None + self.assertEqual(request.headers.getRawHeaders("Cookie"), ["a=b"]) + + def test_response_cookies(self) -> None: + """ + The `_Request.cookies()` method returns a copy of the request + cookiejar merged with any cookies from the response. This jar + matches the client cookiejar at the instant the request was + received. + """ + self.cookiejar.set_cookie(scoped_cookie("http://twisted.example", "a", "1")) + self.cookiejar.set_cookie(scoped_cookie("https://twisted.example", "b", "1")) + + d = self.client.request("GET", "https://twisted.example") + [request] = self.requests + request.deferred.callback(QuickResponse(request, Headers({"Set-Cookie": ["a=2; Secure"]}))) + response = self.successResultOf(d) + + # The client jar was updated. + [a] = raid(self.cookiejar, domain="twisted.example", name="a") + self.assertEqual(a.value, "2") + self.assertTrue(a.secure, True) + + responseJar = response.cookies() + self.assertIsNot(self.cookiejar, responseJar) # It's a copy. + self.assertIsNot(self.cookiejar, response.cookies()) # Another copy. + + # They contain the same cookies. + self.assertEqual( + {(c.name, c.value, c.secure) for c in self.cookiejar}, + {(c.name, c.value, c.secure) for c in response.cookies()}, + )