Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add some missing type hints to cache datastore. (#12216)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Mar 16, 2022
1 parent 8696560 commit c486fa5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 21 deletions.
1 change: 1 addition & 0 deletions changelog.d/12216.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints for cache storage.
57 changes: 36 additions & 21 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
Expand All @@ -31,6 +32,7 @@
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -82,7 +84,9 @@ async def get_all_updated_caches(
if last_id == current_id:
return [], current_id, False

def get_all_updated_caches_txn(txn):
def get_all_updated_caches_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
Expand All @@ -107,7 +111,9 @@ def get_all_updated_caches_txn(txn):
"get_all_updated_caches", get_all_updated_caches_txn
)

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
Expand Down Expand Up @@ -142,10 +148,11 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):

super().process_replication_rows(stream_name, instance_name, token, rows)

def _process_event_stream_row(self, token, row):
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data

if row.type == EventsStreamEventRow.TypeId:
assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
Expand All @@ -157,9 +164,8 @@ def _process_event_stream_row(self, token, row):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed(
row.data.room_id, token
)
assert isinstance(data, EventsStreamCurrentStateRow)
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)

if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
Expand All @@ -170,15 +176,15 @@ def _process_event_stream_row(self, token, row):

def _invalidate_caches_for_event(
self,
stream_ordering,
event_id,
room_id,
etype,
state_key,
redacts,
relates_to,
backfilled,
):
stream_ordering: int,
event_id: str,
room_id: str,
etype: str,
state_key: Optional[str],
redacts: Optional[str],
relates_to: Optional[str],
backfilled: bool,
) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))

Expand Down Expand Up @@ -207,7 +213,9 @@ def _invalidate_caches_for_event(
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))

async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
Expand All @@ -227,7 +235,12 @@ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ..
keys,
)

def _invalidate_cache_and_stream(self, txn, cache_func, keys):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
Expand All @@ -238,7 +251,9 @@ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

def _invalidate_all_cache_and_stream(self, txn, cache_func):
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
Expand Down Expand Up @@ -279,8 +294,8 @@ def _invalidate_state_caches_and_stream(
)

def _send_invalidation_to_replication(
self, txn, cache_name: str, keys: Optional[Iterable[Any]]
):
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
Expand Down Expand Up @@ -315,7 +330,7 @@ def _send_invalidation_to_replication(
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
"invalidation_ts": self._clock.time_msec(),
},
)

Expand Down

0 comments on commit c486fa5

Please sign in to comment.