Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Finish up type hints for federation client code #15465

Merged
merged 9 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15465.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
6 changes: 0 additions & 6 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ exclude = (?x)
|synapse/storage/schema/
)$

[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False

[mypy-synapse.http.matrixfederationclient]
disallow_untyped_defs = False

[mypy-synapse.metrics._reactor_metrics]
disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
Expand Down
8 changes: 2 additions & 6 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,11 @@ async def backfill(
logger.debug("backfill transaction_data=%r", transaction_data)

if not isinstance(transaction_data, dict):
# TODO we probably want an exception type specific to federation
# client validation.
raise TypeError("Backfill transaction_data is not a dict.")
raise InvalidResponseError("Backfill transaction_data is not a dict.")
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

transaction_data_pdus = transaction_data.get("pdus")
if not isinstance(transaction_data_pdus, list):
# TODO we probably want an exception type specific to federation
# client validation.
raise TypeError("transaction_data.pdus is not a list.")
raise InvalidResponseError("transaction_data.pdus is not a list.")

room_version = await self.store.get_room_version(room_id)

Expand Down
17 changes: 13 additions & 4 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
import urllib
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Expand All @@ -42,18 +43,21 @@
)
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict
from synapse.util import ExceptionBundle

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
Expand Down Expand Up @@ -133,7 +137,7 @@ async def get_event(

async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[JsonDict]:
) -> Optional[Union[JsonDict, list]]:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""Requests `limit` previous PDUs in a given context before list of
PDUs.

Expand Down Expand Up @@ -388,6 +392,7 @@ async def send_leave_v1(
# server was just having a momentary blip, the room will be out of
# sync.
ignore_backoff=True,
parser=LegacyJsonSendParser(),
)

async def send_leave_v2(
Expand Down Expand Up @@ -445,7 +450,11 @@ async def send_invite_v1(
path = _create_v1_path("/invite/%s/%s", room_id, event_id)

return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
destination=destination,
path=path,
data=content,
ignore_backoff=True,
parser=LegacyJsonSendParser(),
)

async def send_invite_v2(
Expand Down
76 changes: 58 additions & 18 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import logging
import random
import sys
import typing
import urllib.parse
from http import HTTPStatus
from io import BytesIO, StringIO
Expand All @@ -30,9 +29,11 @@
Generic,
List,
Optional,
TextIO,
Tuple,
TypeVar,
Union,
cast,
overload,
)

Expand Down Expand Up @@ -183,20 +184,61 @@ def get_json(self) -> Optional[JsonDict]:
return self.json


class JsonParser(ByteParser[Union[JsonDict, list]]):
class _BaseJsonParser(ByteParser[T]):
"""A parser that buffers the response and tries to parse it as JSON."""

CONTENT_TYPE = "application/json"

def __init__(self) -> None:
def __init__(
self, validator: Optional[Callable[[Optional[object]], bool]] = None
) -> None:
"""
Args:
validator: A callable which takes the parsed JSON value and returns
true if the value is valid.
"""
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
self._validator = validator

def write(self, data: bytes) -> int:
return self._binary_wrapper.write(data)

def finish(self) -> Union[JsonDict, list]:
return json_decoder.decode(self._buffer.getvalue())
def finish(self) -> T:
result = json_decoder.decode(self._buffer.getvalue())
if self._validator is not None and not self._validator(result):
raise ValueError(
f"Received incorrect JSON value: {result.__class__.__name__}"
)
return result


class JsonParser(_BaseJsonParser[JsonDict]):
"""A parser that buffers the response and tries to parse it as a JSON object."""

def __init__(self) -> None:
super().__init__(self._validate)

@staticmethod
def _validate(v: Any) -> bool:
return isinstance(v, dict)


class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
"""Ensure the legacy responses of /send_join & /send_leave are correct."""

def __init__(self) -> None:
super().__init__(self._validate)

@staticmethod
def _validate(v: Any) -> bool:
# Match [integer, JSON dict]
return (
isinstance(v, list)
and len(v) == 2
and type(v[0]) == int
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
and isinstance(v[1], dict)
)


async def _handle_response(
Expand Down Expand Up @@ -313,9 +355,7 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""

def __init__(
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
):
def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file

Expand Down Expand Up @@ -793,7 +833,7 @@ async def put_json(
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
...

@overload
Expand Down Expand Up @@ -825,8 +865,8 @@ async def put_json(
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
):
parser: Optional[ByteParser[T]] = None,
) -> Union[JsonDict, T]:
"""Sends the specified json data using PUT

Args:
Expand Down Expand Up @@ -902,7 +942,7 @@ async def put_json(
_sec_timeout = self.default_timeout

if parser is None:
parser = JsonParser()
parser = cast(ByteParser[T], JsonParser())
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

body = await _handle_response(
self.reactor,
Expand All @@ -924,7 +964,7 @@ async def post_json(
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""Sends the specified json data using POST

Args:
Expand Down Expand Up @@ -998,7 +1038,7 @@ async def get_json(
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
...

@overload
Expand All @@ -1024,8 +1064,8 @@ async def get_json(
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
):
parser: Optional[ByteParser[T]] = None,
) -> Union[JsonDict, T]:
"""GETs some json from the given host homeserver and path

Args:
Expand Down Expand Up @@ -1091,7 +1131,7 @@ async def get_json(
_sec_timeout = self.default_timeout

if parser is None:
parser = JsonParser()
parser = cast(ByteParser[T], JsonParser())

body = await _handle_response(
self.reactor,
Expand All @@ -1112,7 +1152,7 @@ async def delete_json(
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
"""Send a DELETE request to the remote expecting some json response

Args:
Expand Down
10 changes: 5 additions & 5 deletions tests/federation/test_complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_join_too_large(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_join_too_large_admin(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_join_too_large_once_joined(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
Expand Down Expand Up @@ -200,7 +200,7 @@ def test_join_too_large_no_admin(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
Expand Down Expand Up @@ -230,7 +230,7 @@ def test_join_too_large_admin(self) -> None:
fed_transport = self.hs.get_federation_transport_client()

# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
Expand Down
6 changes: 3 additions & 3 deletions tests/http/test_matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from synapse.api.errors import RequestSendFailed
from synapse.http.matrixfederationclient import (
JsonParser,
ByteParser,
MatrixFederationHttpClient,
MatrixFederationRequest,
)
Expand Down Expand Up @@ -618,9 +618,9 @@ def test_too_big(self) -> None:
while not test_d.called:
protocol.dataReceived(b"a" * chunk_size)
sent += chunk_size
self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE)

self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE)

f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed)
Expand Down