Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type annotations to SimpleHttpClient #8372

Merged
merged 3 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8372.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations to `SimpleHttpClient`.
169 changes: 113 additions & 56 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
import urllib
from io import BytesIO
from typing import Any, BinaryIO, Dict, Iterable, List, Optional, Tuple, Union

import treq
from canonicaljson import encode_canonical_json
Expand All @@ -37,6 +38,7 @@
from twisted.web.client import Agent, HTTPConnectionPool, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse

from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import (
Expand All @@ -57,6 +59,12 @@
"synapse_http_client_responses", "", ["method", "code"]
)

# the type of the headers list, to be passed to the t.w.h.Headers
RawHeaders = Dict[Union[str, bytes], List[Union[str, bytes]]]

# the type of the query params, to be passed into `urlencode`
QueryParams = Dict[Union[str, bytes], Union[str, bytes, Iterable[Union[str, bytes]]]]


def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
"""
Expand Down Expand Up @@ -285,13 +293,26 @@ def __getattr__(_self, attr):
ip_blacklist=self._ip_blacklist,
)

async def request(self, method, uri, data=None, headers=None):
async def request(
self,
method: str,
uri: str,
data: Optional[bytes] = None,
headers: Optional[Headers] = None,
) -> IResponse:
"""
Args:
method (str): HTTP method to use.
uri (str): URI to query.
data (bytes): Data to send in the request body, if applicable.
headers (t.w.http_headers.Headers): Request headers.
method: HTTP method to use.
uri: URI to query.
data: Data to send in the request body, if applicable.
headers: Request headers.

Returns:
Response object, once the headers have been read.

Raises:
RequestTimedOutError if the request times out before the headers are read

"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
Expand Down Expand Up @@ -324,6 +345,8 @@ async def request(self, method, uri, data=None, headers=None):
headers=headers,
**self._extra_treq_args
)
# we use our own timeout mechanism rather than treq's as a workaround
# for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
request_deferred,
60,
Expand Down Expand Up @@ -353,18 +376,26 @@ async def request(self, method, uri, data=None, headers=None):
set_tag("error_reason", e.args[0])
raise

async def post_urlencoded_get_json(self, uri, args={}, headers=None):
async def post_urlencoded_get_json(
self,
uri: str,
args: Dict[str, Union[str, List[str]]] = {},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this meant to be the QueryParams type? I think not since it is used differently than get_raw does, but wanted to double check!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hadn't really thought about it tbh. It's worth noting that they aren't actually query params, and as you say it's used differently. I just copied the old type def from the docstring here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to leave it this way then. 👍

headers: Optional[RawHeaders] = None,
) -> Any:
"""
Args:
uri (str):
args (dict[str, str|List[str]]): query params
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: uri to query
args: parameters to be url-encoded in the body
headers: a map from header name to a list of values for that header

Returns:
object: parsed json
parsed json

Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException: On a non-2xx HTTP response.

ValueError: if the response was not JSON
Expand Down Expand Up @@ -398,19 +429,24 @@ async def post_urlencoded_get_json(self, uri, args={}, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

async def post_json_get_json(self, uri, post_json, headers=None):
async def post_json_get_json(
self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
) -> Any:
"""

Args:
uri (str):
post_json (object):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: URI to query.
post_json: request body, to be encoded as json
headers: a map from header name to a list of values for that header

Returns:
object: parsed json
parsed json

Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException: On a non-2xx HTTP response.

ValueError: if the response was not JSON
Expand Down Expand Up @@ -440,21 +476,22 @@ async def post_json_get_json(self, uri, post_json, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

async def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
async def get_json(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.

Args:
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: The URI to request, not including query parameters
args: A dictionary used to create query string
headers: a map from header name to a list of values for that header
Returns:
Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException On a non-2xx HTTP response.

ValueError: if the response was not JSON
Expand All @@ -466,22 +503,27 @@ async def get_json(self, uri, args={}, headers=None):
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))

async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
async def put_json(
self,
uri: str,
json_body: Any,
args: QueryParams = {},
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.

Args:
uri (str): The URI to request, not including query parameters
json_body (dict): The JSON to put in the HTTP body,
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: The URI to request, not including query parameters
json_body: The JSON to put in the HTTP body,
args: A dictionary used to create query strings
headers: a map from header name to a list of values for that header
Returns:
Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException On a non-2xx HTTP response.

ValueError: if the response was not JSON
Expand Down Expand Up @@ -513,21 +555,23 @@ async def put_json(self, uri, json_body, args={}, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

async def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
async def get_raw(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
) -> bytes:
"""Gets raw text from the given URI.

Args:
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: The URI to request, not including query parameters
args: A dictionary used to create query strings
headers: a map from header name to a list of values for that header
Returns:
Succeeds when we get *any* 2xx HTTP response, with the
Succeeds when we get a 2xx HTTP response, with the
HTTP body as bytes.
Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
Expand All @@ -552,16 +596,29 @@ async def get_raw(self, uri, args={}, headers=None):
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.

async def get_file(self, url, output_stream, max_size=None, headers=None):
async def get_file(
self,
url: str,
output_stream: BinaryIO,
max_size: Optional[int] = None,
headers: Optional[RawHeaders] = None,
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
url: The URL to GET
output_stream: File to write the response body to.
headers: A map from header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
A tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.

Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

SynapseError: if the response is not a 2xx, the remote file is too large, or
another exception happens during the download.
"""

actual_headers = {b"User-Agent": [self.user_agent]}
Expand Down