Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix typing for @cached wrapped functions (#8240)
Browse files Browse the repository at this point in the history
This requires adding a mypy plugin to fiddle with the type signatures a bit.
  • Loading branch information
erikjohnston authored Sep 3, 2020
1 parent 15c35c2 commit 208e1d3
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 20 deletions.
1 change: 1 addition & 0 deletions changelog.d/8240.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix type hints for functions decorated with `@cached`.
3 changes: 2 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = silent
check_untyped_defs = True
show_error_codes = True
Expand Down Expand Up @@ -51,6 +51,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
tests/replication,
Expand Down
85 changes: 85 additions & 0 deletions scripts-dev/mypy_synapse_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This is a mypy plugin for Synpase to deal with some of the funky typing that
can crop up, e.g the cache descriptors.
"""

from typing import Callable, Optional

from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType


class SynapsePlugin(Plugin):
def get_method_signature_hook(
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
):
return cached_function_method_signature
return None


def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `_CachedFunction.__call__` signature to be correct.
It already has *almost* the correct signature, except:
1. the `self` argument needs to be marked as "bound"; and
2. any `cache_context` argument should be removed.
"""

# First we mark this as a bound function signature.
signature = bind_self(ctx.default_signature)

# Secondly, we remove any "cache_context" args.
#
# Note: We should be only doing this if `cache_context=True` is set, but if
# it isn't then the code will raise an exception when its called anyway, so
# its not the end of the world.
context_arg_index = None
for idx, name in enumerate(signature.arg_names):
if name == "cache_context":
context_arg_index = idx
break

if context_arg_index:
arg_types = list(signature.arg_types)
arg_types.pop(context_arg_index)

arg_names = list(signature.arg_names)
arg_names.pop(context_arg_index)

arg_kinds = list(signature.arg_kinds)
arg_kinds.pop(context_arg_index)

signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
)

return signature


def plugin(version: str):
# This is the entry point of the plugin, and let's us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version
# string.
#
# However, since we pin the version of mypy Synapse uses in CI, we don't
# really care.
return SynapsePlugin
10 changes: 5 additions & 5 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,11 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
if not prevs - seen:
return

latest = await self.store.get_latest_event_ids_in_room(room_id)
latest_list = await self.store.get_latest_event_ids_in_room(room_id)

# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest)
latest = set(latest_list)
latest |= seen

logger.info(
Expand Down Expand Up @@ -781,7 +781,7 @@ async def _process_received_pdu(
# keys across all devices.
current_keys = [
key
for device in cached_devices
for device in cached_devices.values()
for key in device.get("keys", {}).get("keys", {}).values()
]

Expand Down Expand Up @@ -2119,8 +2119,8 @@ async def _check_for_soft_fail(
if backfilled or event.internal_metadata.is_outlier():
return

extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())

if extrem_ids == prev_event_ids:
Expand Down
42 changes: 28 additions & 14 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
import inspect
import logging
import threading
from typing import Any, Tuple, Union, cast
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary

from prometheus_client import Gauge
from typing_extensions import Protocol

from twisted.internet import defer

Expand All @@ -38,17 +37,22 @@

CacheKey = Union[Tuple, Any]

F = TypeVar("F", bound=Callable[..., Any])

class _CachedFunction(Protocol):

class _CachedFunction(Generic[F]):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any

def __name__(self):
...
__name__ = None # type: str

# Note: This function signature is actually fiddled with by the synapse mypy
# plugin to a) make it a bound method, and b) remove any `cache_context` arg.
__call__ = None # type: F


cache_pending_metric = Gauge(
Expand Down Expand Up @@ -123,7 +127,7 @@ def __init__(

self.name = name
self.keylen = keylen
self.thread = None
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
Expand Down Expand Up @@ -662,9 +666,13 @@ def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheCon


def cached(
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
):
return lambda orig: CacheDescriptor(
max_entries: int = 1000,
num_args: Optional[int] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
Expand All @@ -673,8 +681,12 @@ def cached(
iterable=iterable,
)

return cast(Callable[[F], _CachedFunction[F]], func)

def cachedList(cached_method_name, list_name, num_args=None):

def cachedList(
cached_method_name: str, list_name: str, num_args: Optional[int] = None
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
Expand All @@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
cache.
Args:
cached_method_name (str): The name of the single-item lookup method.
cached_method_name: The name of the single-item lookup method.
This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
list_name: The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
num_args: Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
Example:
Expand All @@ -702,9 +714,11 @@ def do_something(self, first_arg):
def batch_do_something(self, first_arg, second_args):
...
"""
return lambda orig: CacheListDescriptor(
func = lambda orig: CacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
)

return cast(Callable[[F], _CachedFunction[F]], func)

0 comments on commit 208e1d3

Please sign in to comment.