diff --git a/changelog.d/11354.misc b/changelog.d/11354.misc
new file mode 100644
index 000000000000..86594a332db9
--- /dev/null
+++ b/changelog.d/11354.misc
@@ -0,0 +1 @@
+Add type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index b2953974ea10..c133e1715d65 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -26,7 +26,6 @@ exclude = (?x)
   |synapse/storage/databases/__init__.py
   |synapse/storage/databases/main/__init__.py
   |synapse/storage/databases/main/account_data.py
-  |synapse/storage/databases/main/cache.py
   |synapse/storage/databases/main/devices.py
   |synapse/storage/databases/main/e2e_room_keys.py
   |synapse/storage/databases/main/end_to_end_keys.py
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 640ff15277db..8ac345c86852 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -35,6 +35,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.storage.database import DatabasePool, make_conn
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
@@ -57,7 +58,6 @@ from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
 from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
 from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
 from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore
-from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.user_directory import (
     UserDirectoryBackgroundUpdateStore,
 )
@@ -179,10 +179,10 @@ class Store(
     MainStateBackgroundUpdateStore,
     UserDirectoryBackgroundUpdateStore,
     EndToEndKeyBackgroundStore,
-    StatsStore,
     PusherWorkerStore,
     PresenceBackgroundUpdateStore,
     GroupServerWorkerStore,
+    SlavedEventStore,
 ):
     def execute(self, f, *args, **kwargs):
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
@@ -229,6 +229,10 @@ class MockHomeserver:
     def get_instance_name(self):
         return "master"
 
+    def should_send_federation(self) -> bool:
+        "Should this server be sending federation traffic directly?"
+        return False
+
 
 class Porter(object):
     def __init__(self, **kwargs):
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 42238f7f280b..138c9dda0c88 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -28,7 +28,6 @@
 from synapse.config.logger import setup_logging
 from synapse.events import EventBase
 from synapse.handlers.admin import ExfiltrationWriter
-from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -61,7 +60,6 @@ class AdminCmdSlavedStore(
     SlavedPushRuleStore,
     SlavedEventStore,
     SlavedClientIpStore,
-    BaseSlavedStore,
     RoomWorkerStore,
 ):
     pass
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 46f0feff7000..2549a145b6ca 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -47,14 +47,12 @@
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.groups import SlavedGroupServerStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
@@ -114,7 +112,6 @@
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.databases.main.session import SessionStore
-from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
 from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -223,7 +220,6 @@ class GenericWorkerSlavedStore(
     # FIXME(#3714): We need to add UserDirectoryStore as we write directly
     # rather than going via the correct worker.
     UserDirectoryStore,
-    StatsStore,
     UIAuthWorkerStore,
     EndToEndRoomKeyStore,
     PresenceStore,
@@ -236,7 +232,6 @@ class GenericWorkerSlavedStore(
     SlavedPusherStore,
     CensorEventsStore,
     ClientIpWorkerStore,
-    SlavedEventStore,
     SlavedKeyStore,
     RoomWorkerStore,
     DirectoryStore,
@@ -252,7 +247,6 @@ class GenericWorkerSlavedStore(
     TransactionWorkerStore,
     LockStore,
     SessionStore,
-    BaseSlavedStore,
 ):
     # Properties that multiple storage classes define. Tell mypy what the
     # expected type is.
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 7ecb446e7c78..7dac1f9ff394 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
@@ -30,9 +30,7 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen: Optional[
-                MultiWriterIdGenerator
-            ] = MultiWriterIdGenerator(
+            self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
                 stream_name="caches",
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 63ed50caa5eb..864b0719767e 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -16,16 +16,7 @@
 from typing import TYPE_CHECKING
 
 from synapse.storage.database import DatabasePool
-from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
-from synapse.storage.databases.main.event_push_actions import (
-    EventPushActionsWorkerStore,
-)
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.databases.main.relations import RelationsWorkerStore
-from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
-from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -47,15 +38,8 @@
 
 
 class SlavedEventStore(
-    EventFederationWorkerStore,
-    RoomMemberWorkerStore,
-    EventPushActionsWorkerStore,
-    StreamWorkerStore,
     StateGroupWorkerStore,
-    EventsWorkerStore,
-    SignatureWorkerStore,
     UserErasureWorkerStore,
-    RelationsWorkerStore,
     BaseSlavedStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 4d5f86286213..061511d78b1a 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -20,7 +20,7 @@
 from .events import SlavedEventStore
 
 
-class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
+class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d4cab69ebfe5..8611fe492c0a 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -175,7 +175,7 @@ def commit(self) -> None:
     def rollback(self) -> None:
         self.conn.rollback()
 
-    def __enter__(self) -> "Connection":
+    def __enter__(self) -> "LoggingDatabaseConnection":
         self.conn.__enter__()
         return self
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c35a..e5a2369a22cb 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -67,7 +67,6 @@
 from .session import SessionStore
 from .signatures import SignatureStore
 from .state import StateStore
-from .stats import StatsStore
 from .stream import StreamStore
 from .tags import TagsStore
 from .transactions import TransactionWorkerStore
@@ -119,7 +118,6 @@ class DataStore(
     GroupServerStore,
     UserErasureStore,
     MonthlyActiveUsersStore,
-    StatsStore,
     RelationsStore,
     CensorEventsStore,
     UIAuthStore,
@@ -154,7 +152,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
             db_conn, "local_group_updates", "stream_id"
         )
 
-        self._cache_id_gen: Optional[MultiWriterIdGenerator]
         if isinstance(self.database_engine, PostgresEngine):
             # We set the `writers` to an empty list here as we don't care about
             # missing updates over restarts, as we'll not have anything in our
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc63b..066365a8b5c4 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -24,9 +24,22 @@
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
 )
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
+from synapse.storage.databases.main.event_push_actions import (
+    EventPushActionsWorkerStore,
+)
+from synapse.storage.databases.main.relations import RelationsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.databases.main.state_deltas import StateDeltasStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -39,16 +52,35 @@
 # based on the current state when notifying workers over replication.
 CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
+# Corresponds to the (cache_func, keys, invalidation_ts) db columns.
+_CacheData = Tuple[str, Optional[List[str]], Optional[int]]
+
+
+class CacheInvalidationWorkerStore(
+    EventFederationWorkerStore,
+    RelationsWorkerStore,
+    EventPushActionsWorkerStore,
+    StreamWorkerStore,
+    StateDeltasStore,
+    RoomMemberWorkerStore,
+):
+    # This class must be mixed in with a child class which provides the following
+    # attribute. TODO: can we get static analysis to enforce this?
+    _cache_id_gen: Optional[MultiWriterIdGenerator]
 
-class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
 
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+    ) -> Tuple[List[Tuple[int, _CacheData]], int, bool]:
         """Get updates for caches replication stream.
 
         Args:
@@ -73,7 +105,9 @@ async def get_all_updated_caches(
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_caches_txn(txn):
+        def get_all_updated_caches_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, _CacheData]], int, bool]:
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
@@ -85,7 +119,13 @@ def get_all_updated_caches_txn(txn):
                 LIMIT ?
             """
             txn.execute(sql, (last_id, instance_name, limit))
-            updates = [(row[0], row[1:]) for row in txn]
+            updates: List[Tuple[int, _CacheData]] = []
+            row: Tuple[int, str, Optional[List[str]], Optional[int]]
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                updates.append((row[0], row[1:]))
             limited = False
             upto_token = current_id
             if len(updates) >= limit:
@@ -192,7 +232,9 @@ def _invalidate_caches_for_event(
         self.get_aggregation_groups_for_event.invalidate((relates_to,))
         self.get_applicable_edit.invalidate((relates_to,))
 
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+    async def invalidate_cache_and_stream(
+        self, cache_name: str, keys: Tuple[str, ...]
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -212,7 +254,9 @@ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ..
             keys,
         )
 
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+    def _invalidate_cache_and_stream(
+        self, txn: LoggingTransaction, cache_func: _CachedFunction, keys: Iterable[str]
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -223,7 +267,9 @@ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
-    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+    def _invalidate_all_cache_and_stream(
+        self, txn: LoggingTransaction, cache_func: _CachedFunction
+    ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
         """
@@ -231,7 +277,9 @@ def _invalidate_all_cache_and_stream(self, txn, cache_func):
         txn.call_after(cache_func.invalidate_all)
         self._send_invalidation_to_replication(txn, cache_func.__name__, None)
 
-    def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+    def _invalidate_state_caches_and_stream(
+        self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str]
+    ) -> None:
         """Special case invalidation of caches based on current state.
 
         We special case this so that we can batch the cache invalidations into a
@@ -239,8 +287,8 @@ def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
 
         Args:
             txn
-            room_id (str): Room where state changed
-            members_changed (iterable[str]): The user_ids of members that have changed
+            room_id: Room where state changed
+            members_changed: The user_ids of members that have changed
         """
         txn.call_after(self._invalidate_state_caches, room_id, members_changed)
 
@@ -262,8 +310,8 @@ def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
         )
 
     def _send_invalidation_to_replication(
-        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
-    ):
+        self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[str]]
+    ) -> None:
         """Notifies replication that given cache has been invalidated.
 
         Note that this does *not* invalidate the cache locally.
@@ -284,6 +332,7 @@ def _send_invalidation_to_replication(
             # the transaction. However, we want to only get an ID when we want
             # to use it, here, so we need to call __enter__ manually, and have
             # __exit__ called after the transaction finishes.
+            assert self._cache_id_gen is not None
             stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
@@ -298,7 +347,7 @@ def _send_invalidation_to_replication(
                 "instance_name": self._instance_name,
                 "cache_func": cache_name,
                 "keys": keys,
-                "invalidation_ts": self.clock.time_msec(),
+                "invalidation_ts": self._clock.time_msec(),
             },
         )
 
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220d0..0a087fdb2e92 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -17,10 +17,8 @@
 
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.util import json_encoder
 
 if TYPE_CHECKING:
@@ -30,7 +28,7 @@
 logger = logging.getLogger(__name__)
 
 
-class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
+class CensorEventsStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index 68901b43352b..404cd96278be 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -18,15 +18,11 @@
 from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
 
 logger = logging.getLogger(__name__)
 
 
-class EventForwardExtremitiesStore(
-    EventFederationWorkerStore,
-    CacheInvalidationWorkerStore,
-):
+class EventForwardExtremitiesStore(CacheInvalidationWorkerStore):
     async def delete_forward_extremities_for_room(self, room_id: str) -> int:
         """Delete any extra forward extremities for a room.
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4ee..695d1304582b 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -19,7 +19,7 @@
 from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
@@ -71,7 +71,6 @@ class PushRulesWorkerStore(
     PusherWorkerStore,
     RoomMemberWorkerStore,
     EventsWorkerStore,
-    SQLBaseStore,
     metaclass=abc.ABCMeta,
 ):
     """This is an abstract base class where subclasses must implement
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 6c7d6ba50848..556878dd02ca 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1877,7 +1877,7 @@ async def is_guest(self, user_id: str) -> bool:
         return res if res else False
 
 
-class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
+class RegistrationStore(RegistrationBackgroundUpdateStore, StatsStore):
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861c9..f735a4217658 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -100,7 +100,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
-        self.clock = self.hs.get_clock()
         self.stats_enabled = hs.config.stats.stats_enabled
 
         self.stats_delta_processing_lock = DeferredLock()
@@ -601,7 +600,7 @@ def _fetch_current_state_stats(txn):
         local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
 
         await self.update_stats_delta(
-            ts=self.clock.time_msec(),
+            ts=self._clock.time_msec(),
             stats_type="room",
             stats_id=room_id,
             fields={},
@@ -638,7 +637,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn):
         )
 
         await self.update_stats_delta(
-            ts=self.clock.time_msec(),
+            ts=self._clock.time_msec(),
             stats_type="user",
             stats_id=user_id,
             fields={},
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 17532059e9f8..cca1b275b2de 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -146,7 +146,9 @@ def allowed(event: EventBase) -> Optional[EventBase]:
         max_lifetime = retention_policy.get("max_lifetime")
 
         if max_lifetime is not None:
-            oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
+            # TODO: reveal_type(storage.main) yields Any. Can we find a way of
+            # telling mypy that storage.main is a generic `DataStoreT`?
+            oldest_allowed_ts = storage.main._clock.time_msec() - max_lifetime
 
             if event.origin_server_ts < oldest_allowed_ts:
                 return None
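A note on the `CacheInvalidationWorkerStore` change above: declaring `_cache_id_gen: Optional[MultiWriterIdGenerator]` at class level, without assigning it, tells mypy that every instance carries the attribute while leaving it to a concrete subclass to actually set it; the new `assert self._cache_id_gen is not None` in `_send_invalidation_to_replication` then narrows away the `Optional` before use. A minimal self-contained sketch of the pattern (the class names here are illustrative, not part of the patch):

```python
from typing import Optional


class IdGenerator:
    """Stand-in for MultiWriterIdGenerator."""

    def get_next(self) -> int:
        return 1


class CacheMixin:
    # Declared but not assigned: mypy now knows every instance has this
    # attribute, but the mixin relies on a subclass to set it. Nothing
    # enforces that statically -- hence the TODO in the patch.
    _cache_id_gen: Optional[IdGenerator]

    def next_stream_id(self) -> int:
        # Narrow Optional[IdGenerator] to IdGenerator for mypy, and fail
        # loudly at runtime if a subclass forgot to set the attribute.
        assert self._cache_id_gen is not None
        return self._cache_id_gen.get_next()


class ConcreteStore(CacheMixin):
    def __init__(self, postgres: bool) -> None:
        # Only multi-writer (Postgres) deployments get a real generator.
        self._cache_id_gen = IdGenerator() if postgres else None


print(ConcreteStore(postgres=True).next_stream_id())  # -> 1
```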
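The base-class reorderings (`SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore)`, `RegistrationStore(RegistrationBackgroundUpdateStore, StatsStore)`) are presumably needed so that Python can still compute a consistent C3 linearization now that `CacheInvalidationWorkerStore` pulls the big worker stores into almost every store's ancestry: a base listed before one of its own subclasses makes class creation fail outright. A toy illustration, independent of the real store hierarchy:

```python
class A:
    pass


class B(A):
    pass


# Fine: the subclass B precedes its base A, consistent with B's own MRO.
class Ok(B, A):
    pass


# Raises TypeError ("Cannot create a consistent method resolution order"):
# A is listed before its own subclass B, which C3 linearization rejects.
try:
    Bad = type("Bad", (A, B), {})
except TypeError as err:
    print(err)
```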