Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix additional type hints from Twisted upgrade #9518

Merged
merged 11 commits into from
Mar 3, 2021
1 change: 1 addition & 0 deletions changelog.d/9518.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix incorrect type hints.
18 changes: 9 additions & 9 deletions synapse/http/federation/matrix_federation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
from typing import List, Optional
from typing import Any, Generator, List, Optional

from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
Expand Down Expand Up @@ -116,7 +116,7 @@ def request(
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
) -> Generator[defer.Deferred, Any, defer.Deferred]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We haven't switched this to async since request is part of the IAgent interface.

"""
Args:
method: HTTP method: GET/POST/etc
Expand Down Expand Up @@ -177,17 +177,17 @@ def request(
# We need to make sure the host header is set to the netloc of the
# server and that a user-agent is provided.
if headers is None:
headers = Headers()
request_headers = Headers()
else:
headers = headers.copy()
request_headers = headers.copy()

if not headers.hasHeader(b"host"):
headers.addRawHeader(b"host", parsed_uri.netloc)
if not headers.hasHeader(b"user-agent"):
headers.addRawHeader(b"user-agent", self.user_agent)
if not request_headers.hasHeader(b"host"):
request_headers.addRawHeader(b"host", parsed_uri.netloc)
if not request_headers.hasHeader(b"user-agent"):
request_headers.addRawHeader(b"user-agent", self.user_agent)

res = yield make_deferred_yieldable(
self._agent.request(method, uri, headers, bodyProducer)
self._agent.request(method, uri, request_headers, bodyProducer)
)

return res
Expand Down
6 changes: 3 additions & 3 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
RequestSendFailed: if the Content-Type header is missing or isn't JSON

"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
content_type_headers = headers.getRawHeaders(b"Content-Type")
if content_type_headers is None:
raise RequestSendFailed(
RuntimeError("No Content-Type header received from remote server"),
can_retry=False,
)

c_type = c_type[0].decode("ascii") # only the first header
c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(
Expand Down
29 changes: 18 additions & 11 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import types
import urllib
from http import HTTPStatus
from inspect import isawaitable
from io import BytesIO
from typing import (
Any,
Expand All @@ -30,6 +31,7 @@
Iterable,
Iterator,
List,
Optional,
Pattern,
Tuple,
Union,
Expand Down Expand Up @@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients."""

if f.check(SynapseError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we change this to isinstance(f.value, SynapseError) perhaps?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. This is the proper way to check against twisted failures?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, probably easier (and more likely) for us to remove a # type: ignore in the future than switch away from an isinstance.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my thought is that if we want to make that change we should do is as a discrete step, not bury it in this PR!

error_code = f.value.code
error_dict = f.value.error_dict()
# mypy doesn't understand that f.check asserts the type.
exc = f.value # type: SynapseError # type: ignore
error_code = exc.code
error_dict = exc.error_dict()

logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
Expand All @@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)

# Only respond with an error response if we haven't already started writing,
Expand Down Expand Up @@ -128,7 +132,8 @@ def return_html_error(
`{msg}` placeholders), or a jinja2 template
"""
if f.check(CodeMessageException):
cme = f.value
# mypy doesn't understand that f.check asserts the type.
cme = f.value # type: CodeMessageException # type: ignore
code = cme.code
msg = cme.msg

Expand All @@ -142,7 +147,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
else:
code = HTTPStatus.INTERNAL_SERVER_ERROR
Expand All @@ -151,7 +156,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)

if isinstance(error_template, str):
Expand Down Expand Up @@ -278,7 +283,7 @@ async def _async_render(self, request: Request):
raw_callback_return = method_handler(request)

# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
if isawaitable(raw_callback_return):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return # type: ignore
Expand Down Expand Up @@ -399,8 +404,10 @@ def _get_handler_for_request(
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
# At this point the path must be bytes.
request_path_bytes = request.path # type: bytes # type: ignore
request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests.
request_path = request.path.decode("ascii")
request_method = request.method
if request_method == b"HEAD":
request_method = b"GET"
Expand Down Expand Up @@ -551,7 +558,7 @@ def __init__(
request: Request,
iterator: Iterator[bytes],
):
self._request = request
self._request = request # type: Optional[Request]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We set it to None when we stop producing so we need to consider it optional here.

self._iterator = iterator
self._paused = False

Expand All @@ -563,7 +570,7 @@ def _send_data(self, data: List[bytes]) -> None:
"""
Send a list of bytes as a chunk of a response.
"""
if not data:
if not data or not self._request:
return
self._request.write(b"".join(data))

Expand Down
31 changes: 21 additions & 10 deletions synapse/http/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import contextlib
import logging
import time
from typing import Optional, Union
from typing import Optional, Type, Union

import attr
from zope.interface import implementer
Expand Down Expand Up @@ -57,7 +57,7 @@ class SynapseRequest(Request):

def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
self.site = channel.site
self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0

Expand Down Expand Up @@ -96,25 +96,34 @@ def __repr__(self):
def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq)

def get_redacted_uri(self):
uri = self.uri
def get_redacted_uri(self) -> str:
"""Gets the redacted URI associated with the request (or placeholder if not
method has yet been received).
clokep marked this conversation as resolved.
Show resolved Hide resolved

Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.uri`.

Returns:
The redacted URI as as string.
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""
uri = self.uri # type: Union[bytes, str]
if isinstance(uri, bytes):
uri = self.uri.decode("ascii", errors="replace")
uri = uri.decode("ascii", errors="replace")
return redact_uri(uri)

def get_method(self):
def get_method(self) -> str:
"""Gets the method associated with the request (or placeholder if not
method has yet been received).
clokep marked this conversation as resolved.
Show resolved Hide resolved

Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.method`.

Returns:
str
The request method as as string.
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""
method = self.method
method = self.method # type: Union[bytes, str]
if isinstance(method, bytes):
method = self.method.decode("ascii")
return self.method.decode("ascii")
return method

def render(self, resrc):
Expand Down Expand Up @@ -432,7 +441,9 @@ def __init__(

assert config.http_options is not None
proxied = config.http_options.x_forwarded
self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
self.requestFactory = (
XForwardedForRequest if proxied else SynapseRequest
) # type: Type[Request]
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")

Expand Down
6 changes: 4 additions & 2 deletions synapse/logging/_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.python.failure import Failure

Expand Down Expand Up @@ -121,7 +121,9 @@ def __init__(
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
endpoint = TCP4ClientEndpoint(
_reactor, self.host, self.port
) # type: IStreamClientEndpoint
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else:
Expand Down
11 changes: 6 additions & 5 deletions synapse/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def collect(self):
REGISTRY.register(ReactorLastSeenMetric())


def runUntilCurrentTimer(func):
def runUntilCurrentTimer(reactor, func):
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
Expand Down Expand Up @@ -590,13 +590,14 @@ def f(*args, **kwargs):

try:
# Ensure the reactor has all the attributes we expect
reactor.runUntilCurrent
reactor._newTimedCalls
reactor.threadCallQueue
reactor.seconds # type: ignore
reactor.runUntilCurrent # type: ignore
reactor._newTimedCalls # type: ignore
reactor.threadCallQueue # type: ignore

# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore

# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC,
Expand Down
4 changes: 2 additions & 2 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple

from twisted.internet import defer

Expand Down Expand Up @@ -307,7 +307,7 @@ async def complete_sso_login_async(
@defer.inlineCallbacks
def get_state_events_in_room(
self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
) -> defer.Deferred:
) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""Gets current state events for the given room.

(This is exposed for compatibility with the old SpamCheckerApi. We should
Expand Down
5 changes: 3 additions & 2 deletions synapse/push/httppusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union

from prometheus_client import Counter

from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.internet.interfaces import IDelayedCall

from synapse.api.constants import EventTypes
from synapse.events import EventBase
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusher_config.failing_since
self.timed_call = None
self.timed_call = None # type: Optional[IDelayedCall]
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
self._pusherpool = hs.get_pusherpool()
Expand Down
4 changes: 1 addition & 3 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ def __init__(self, hs: "HomeServer"):

# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = (
{}
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]]

async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
Expand Down
3 changes: 2 additions & 1 deletion synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import twisted.internet.base
import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS

Expand Down Expand Up @@ -403,7 +404,7 @@ def get_room_shutdown_handler(self) -> RoomShutdownHandler:
return RoomShutdownHandler(self)

@cache_in_self
def get_sendmail(self) -> sendmail:
def get_sendmail(self) -> Callable[..., defer.Deferred]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started adding a better type signature here, but sendmail is fairly complicated and can take bytes and str as almost every argument.

return sendmail

@cache_in_self
Expand Down
Loading