diff --git a/changelog.d/12477.misc b/changelog.d/12477.misc new file mode 100644 index 000000000000..e793d08e5e3f --- /dev/null +++ b/changelog.d/12477.misc @@ -0,0 +1 @@ +Add some type hints to datastore. \ No newline at end of file diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 46042b2bf7af..8120c305df14 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,6 +15,7 @@ import attr from frozendict import frozendict +from typing_extensions import Literal from twisted.internet.defer import Deferred @@ -106,7 +107,7 @@ class EventContext: incomplete state. """ - rejected: Union[bool, str] = False + rejected: Union[Literal[False], str] = False _state_group: Optional[int] = None state_group_before_event: Optional[int] = None prev_group: Optional[int] = None diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ad611b2c0bb2..6c12653bb3c6 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -49,7 +49,7 @@ from synapse.storage.engines.postgres import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator -from synapse.types import StateMap, get_domain_from_id +from synapse.types import JsonDict, StateMap, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically @@ -235,7 +235,9 @@ async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[st """ results: List[str] = [] - def _get_events_which_are_prevs_txn(txn, batch): + def _get_events_which_are_prevs_txn( + txn: LoggingTransaction, batch: Collection[str] + ) -> None: sql = """ SELECT prev_event_id, internal_metadata FROM event_edges @@ -285,7 +287,9 @@ async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str] # and their prev events. existing_prevs = set() - def _get_prevs_before_rejected_txn(txn, batch): + def _get_prevs_before_rejected_txn( + txn: LoggingTransaction, batch: Collection[str] + ) -> None: to_recursively_check = batch while to_recursively_check: @@ -515,7 +519,7 @@ def _persist_event_auth_chain_txn( @classmethod def _add_chain_cover_index( cls, - txn, + txn: LoggingTransaction, db_pool: DatabasePool, event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], @@ -809,7 +813,7 @@ def _add_chain_cover_index( @staticmethod def _allocate_chain_ids( - txn, + txn: LoggingTransaction, db_pool: DatabasePool, event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], @@ -943,7 +947,7 @@ def _persist_transaction_ids_txn( self, txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], - ): + ) -> None: """Persist the mapping from transaction IDs to event IDs (if defined).""" to_insert = [] @@ -997,7 +1001,7 @@ def _update_current_state_txn( txn: LoggingTransaction, state_delta_by_room: Dict[str, DeltaState], stream_id: int, - ): + ) -> None: for room_id, delta_state in state_delta_by_room.items(): to_delete = delta_state.to_delete to_insert = delta_state.to_insert @@ -1155,7 +1159,7 @@ def _update_current_state_txn( txn, room_id, members_changed ) - def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): + def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None: """Update the room version in the database based off current state events. @@ -1189,7 +1193,7 @@ def _update_forward_extremities_txn( txn: LoggingTransaction, new_forward_extremities: Dict[str, Set[str]], max_stream_order: int, - ): + ) -> None: for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} @@ -1254,9 +1258,9 @@ def _filter_events_and_contexts_for_duplicates( def _update_room_depths_txn( self, - txn, + txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], - ): + ) -> None: """Update min_depth for each room Args: @@ -1385,7 +1389,7 @@ def _store_event_txn( # nothing to do here return - def event_dict(event): + def event_dict(event: EventBase) -> JsonDict: d = event.get_dict() d.pop("redacted", None) d.pop("redacted_because", None) @@ -1476,18 +1480,20 @@ def event_dict(event): ), ) - def _store_rejected_events_txn(self, txn, events_and_contexts): + def _store_rejected_events_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> List[Tuple[EventBase, EventContext]]: """Add rows to the 'rejections' table for received events which were rejected Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting Returns: - list[(EventBase, EventContext)] new list, without the rejected - events. + new list, without the rejected events. """ # Remove the rejected events from the list now that we've added them # to the events table and the events_json table. @@ -1508,7 +1514,7 @@ def _update_metadata_tables_txn( events_and_contexts: List[Tuple[EventBase, EventContext]], all_events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, - ): + ) -> None: """Update all the miscellaneous tables for new events Args: @@ -1602,7 +1608,11 @@ def _update_metadata_tables_txn( # Prefill the event cache self._add_to_cache(txn, events_and_contexts) - def _add_to_cache(self, txn, events_and_contexts): + def _add_to_cache( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> None: to_prefill = [] rows = [] @@ -1633,7 +1643,7 @@ def _add_to_cache(self, txn, events_and_contexts): if not row["rejects"] and not row["redacts"]: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - def prefill(): + def prefill() -> None: for cache_entry in to_prefill: self.store._get_event_cache.set( (cache_entry.event.event_id,), cache_entry @@ -1663,19 +1673,24 @@ def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: ) def insert_labels_for_event_txn( - self, txn, event_id, labels, room_id, topological_ordering - ): + self, + txn: LoggingTransaction, + event_id: str, + labels: List[str], + room_id: str, + topological_ordering: int, + ) -> None: """Store the mapping between an event's ID and its labels, with one row per (event_id, label) tuple. Args: - txn (LoggingTransaction): The transaction to execute. - event_id (str): The event's ID. - labels (list[str]): A list of text labels. - room_id (str): The ID of the room the event was sent to. - topological_ordering (int): The position of the event in the room's topology. + txn: The transaction to execute. + event_id: The event's ID. + labels: A list of text labels. + room_id: The ID of the room the event was sent to. + topological_ordering: The position of the event in the room's topology. """ - return self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", keys=("event_id", "label", "room_id", "topological_ordering"), @@ -1684,25 +1699,32 @@ def insert_labels_for_event_txn( ], ) - def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + def _insert_event_expiry_txn( + self, txn: LoggingTransaction, event_id: str, expiry_ts: int + ) -> None: """Save the expiry timestamp associated with a given event ID. Args: - txn (LoggingTransaction): The database transaction to use. - event_id (str): The event ID the expiry timestamp is associated with. - expiry_ts (int): The timestamp at which to expire (delete) the event. + txn: The database transaction to use. + event_id: The event ID the expiry timestamp is associated with. + expiry_ts: The timestamp at which to expire (delete) the event. """ - return self.db_pool.simple_insert_txn( + self.db_pool.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, ) def _store_room_members_txn( - self, txn, events, *, inhibit_local_membership_updates: bool = False - ): + self, + txn: LoggingTransaction, + events: List[EventBase], + *, + inhibit_local_membership_updates: bool = False, + ) -> None: """ Store a room member in the database. + Args: txn: The transaction to use. events: List of events to store. @@ -1742,6 +1764,7 @@ def non_null_str_or_none(val: Any) -> Optional[str]: ) for event in events: + assert event.internal_metadata.stream_ordering is not None txn.call_after( self.store._membership_stream_cache.entity_has_changed, event.state_key, @@ -1838,7 +1861,9 @@ def _handle_event_relations( (parent_id, event.sender), ) - def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): + def _handle_insertion_event( + self, txn: LoggingTransaction, event: EventBase + ) -> None: """Handles keeping track of insertion events and edges/connections. Part of MSC2716. @@ -1899,7 +1924,7 @@ def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): }, ) - def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase): + def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None: """Handles inserting the batch edges/connections between the batch event and an insertion event. Part of MSC2716. @@ -1999,25 +2024,29 @@ def _handle_redact_relations( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): self.store_event_search_txn( txn, event, "content.topic", event.content["topic"] ) - def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("name"), str): self.store_event_search_txn( txn, event, "content.name", event.content["name"] ) - def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_message_txn( + self, txn: LoggingTransaction, event: EventBase + ) -> None: if isinstance(event.content.get("body"), str): self.store_event_search_txn( txn, event, "content.body", event.content["body"] ) - def _store_retention_policy_for_room_txn(self, txn, event): + def _store_retention_policy_for_room_txn( + self, txn: LoggingTransaction, event: EventBase + ) -> None: if not event.is_state(): logger.debug("Ignoring non-state m.room.retention event") return @@ -2077,8 +2106,11 @@ def store_event_search_txn( ) def _set_push_actions_for_event_and_users_txn( - self, txn, events_and_contexts, all_events_and_contexts - ): + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + all_events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> None: """Handles moving push actions from staging table to main event_push_actions table for all events in `events_and_contexts`. @@ -2086,12 +2118,10 @@ def _set_push_actions_for_event_and_users_txn( from the push action staging area. Args: - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. + events_and_contexts: events we are persisting + all_events_and_contexts: all events that we were going to persist. + This includes events we've already persisted, etc, that wouldn't + appear in events_and_context. """ # Only non outlier events will have push actions associated with them, @@ -2160,7 +2190,9 @@ def _set_push_actions_for_event_and_users_txn( ), ) - def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): + def _remove_push_actions_for_event_id_txn( + self, txn: LoggingTransaction, room_id: str, event_id: str + ) -> None: # Sad that we have to blow away the cache for the whole room here txn.call_after( self.store.get_unread_event_push_actions_by_room_for_user.invalidate, @@ -2171,7 +2203,9 @@ def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): (room_id, event_id), ) - def _store_rejections_txn(self, txn, event_id, reason): + def _store_rejections_txn( + self, txn: LoggingTransaction, event_id: str, reason: str + ) -> None: self.db_pool.simple_insert_txn( txn, table="rejections", @@ -2183,8 +2217,10 @@ def _store_rejections_txn(self, txn, event_id, reason): ) def _store_event_state_mappings_txn( - self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] - ): + self, + txn: LoggingTransaction, + events_and_contexts: Collection[Tuple[EventBase, EventContext]], + ) -> None: state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): @@ -2241,7 +2277,9 @@ def _store_event_state_mappings_txn( state_group_id, ) - def _update_min_depth_for_room_txn(self, txn, room_id, depth): + def _update_min_depth_for_room_txn( + self, txn: LoggingTransaction, room_id: str, depth: int + ) -> None: min_depth = self.store._get_min_depth_interaction(txn, room_id) if min_depth is not None and depth >= min_depth: @@ -2254,7 +2292,9 @@ def _update_min_depth_for_room_txn(self, txn, room_id, depth): values={"min_depth": depth}, ) - def _handle_mult_prev_events(self, txn, events): + def _handle_mult_prev_events( + self, txn: LoggingTransaction, events: List[EventBase] + ) -> None: """ For the given event, update the event edges table and forward and backward extremities tables. @@ -2272,7 +2312,9 @@ def _handle_mult_prev_events(self, txn, events): self._update_backward_extremeties(txn, events) - def _update_backward_extremeties(self, txn, events): + def _update_backward_extremeties( + self, txn: LoggingTransaction, events: List[EventBase] + ) -> None: """Updates the event_backward_extremities tables based on the new/updated events being persisted. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 3c49e7ec98e2..78e0773b2a88 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple import attr @@ -27,7 +27,7 @@ LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -149,7 +149,9 @@ def __init__( self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings ) - async def _background_reindex_search(self, progress, batch_size): + async def _background_reindex_search( + self, progress: JsonDict, batch_size: int + ) -> int: # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -157,7 +159,7 @@ async def _background_reindex_search(self, progress, batch_size): TYPES = ["m.room.name", "m.room.message", "m.room.topic"] - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id, room_id, type, json, " " origin_server_ts FROM events" @@ -255,12 +257,14 @@ def reindex_search_txn(txn): return result - async def _background_reindex_gin_search(self, progress, batch_size): + async def _background_reindex_gin_search( + self, progress: JsonDict, batch_size: int + ) -> int: """This handles old synapses which used GIST indexes, if any; converting them back to be GIN as per the actual schema. """ - def create_index(conn): + def create_index(conn: LoggingDatabaseConnection) -> None: conn.rollback() # we have to set autocommit, because postgres refuses to @@ -299,7 +303,9 @@ def create_index(conn): ) return 1 - async def _background_reindex_search_order(self, progress, batch_size): + async def _background_reindex_search_order( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -307,7 +313,7 @@ async def _background_reindex_search_order(self, progress, batch_size): if not have_added_index: - def create_index(conn): + def create_index(conn: LoggingDatabaseConnection) -> None: conn.rollback() conn.set_session(autocommit=True) c = conn.cursor() @@ -336,7 +342,7 @@ def create_index(conn): pg, ) - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]: sql = ( "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," " origin_server_ts = e.origin_server_ts" @@ -644,7 +650,8 @@ async def search_rooms( else: raise Exception("Unrecognized database engine") - args.append(limit) + # mypy expects to append only a `str`, not an `int` + args.append(limit) # type: ignore[arg-type] results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args @@ -705,7 +712,7 @@ async def _find_highlights_in_postgres( A set of strings. """ - def f(txn): + def f(txn: LoggingTransaction) -> Set[str]: highlight_words = set() for event in events: # As a hack we simply join values of all possible keys. This is @@ -759,11 +766,11 @@ def f(txn): return await self.db_pool.runInteraction("_find_highlights", f) -def _to_postgres_options(options_dict): +def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -def _parse_query(database_engine, search_term): +def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. We use this so that we can add prefix matching, which isn't something