From 0065f62d8c808230e324b42bee472ffb9032d591 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Thu, 9 Dec 2021 20:59:25 +0000 Subject: [PATCH] Add type hints to `synapse/storage/databases/main/end_to_end_keys.py` --- changelog.d/11551.misc | 1 + mypy.ini | 4 +- synapse/storage/databases/main/__init__.py | 3 - .../storage/databases/main/end_to_end_keys.py | 211 ++++++++++++------ 4 files changed, 150 insertions(+), 69 deletions(-) create mode 100644 changelog.d/11551.misc diff --git a/changelog.d/11551.misc b/changelog.d/11551.misc new file mode 100644 index 000000000000..d451940bf216 --- /dev/null +++ b/changelog.d/11551.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. diff --git a/mypy.ini b/mypy.ini index 1caf807e8505..5772757d552f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -29,7 +29,6 @@ exclude = (?x) |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/e2e_room_keys.py - |synapse/storage/databases/main/end_to_end_keys.py |synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_push_actions.py |synapse/storage/databases/main/events_bg_updates.py @@ -187,6 +186,9 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.directory] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.end_to_end_keys] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.events_worker] disallow_untyped_defs = True diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 9ff2d8d8c35a..065145c0d280 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -143,9 +143,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): ("device_lists_outbound_pokes", "stream_id"), ], ) - self._cross_signing_id_gen = StreamIdGenerator( - db_conn, "e2e_cross_signing_keys", "stream_id" - ) self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b06c1dc45b2d..57b5ffbad32b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,19 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + cast, +) import attr from canonicaljson import encode_canonical_json -from twisted.enterprise.adbapi import Connection - from synapse.api.constants import DeviceKeyAlgorithms from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -50,7 +63,12 @@ class DeviceKeyLookupResult: class EndToEndKeyBackgroundStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -62,8 +80,13 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer" ) -class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): +class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._allow_device_name_lookup_over_federation = ( @@ -124,7 +147,7 @@ async def get_e2e_device_keys_for_cs_api( # Build the result structure, un-jsonify the results, and add the # "unsigned" section - rv = {} + rv: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_keys in results.items(): rv[user_id] = {} for device_id, device_info in device_keys.items(): @@ -195,6 +218,10 @@ async def get_e2e_device_keys_and_signatures( # add each cross-signing signature to the correct device in the result dict. for (user_id, key_id, device_id, signature) in cross_sigs_result: target_device_result = result[user_id][device_id] + # We've only looked up cross-signatures for non-deleted devices with key + # data. + assert target_device_result is not None + assert target_device_result.keys is not None target_device_signatures = target_device_result.keys.setdefault( "signatures", {} ) @@ -207,7 +234,11 @@ async def get_e2e_device_keys_and_signatures( return result def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False + self, + txn: LoggingTransaction, + query_list: Collection[Tuple[str, str]], + include_all_devices: bool = False, + include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: """Get information on devices from the database @@ -263,7 +294,7 @@ def _get_e2e_device_keys_txn( return result def _get_e2e_cross_signing_signatures_for_devices_txn( - self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]] ) -> List[Tuple[str, str, str, str]]: """Get cross-signing signatures for a given list of devices @@ -289,7 +320,17 @@ def _get_e2e_cross_signing_signatures_for_devices_txn( ) txn.execute(signature_sql, signature_query_params) - return txn.fetchall() + return cast( + List[ + Tuple[ + str, + str, + str, + str, + ] + ], + txn.fetchall(), + ) async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str] @@ -335,7 +376,7 @@ async def add_e2e_one_time_keys( new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ - def _add_e2e_one_time_keys(txn): + def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("new_keys", new_keys) @@ -375,7 +416,7 @@ async def count_e2e_one_time_keys( A mapping from algorithm to number of keys for that algorithm. """ - def _count_e2e_one_time_keys(txn): + def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]: sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ?" @@ -421,7 +462,11 @@ async def set_e2e_fallback_keys( ) def _set_e2e_fallback_keys_txn( - self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + fallback_keys: JsonDict, ) -> None: # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad @@ -483,7 +528,7 @@ async def get_e2e_unused_fallback_key_types( async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """Returns a user's cross-signing key. Args: @@ -504,7 +549,7 @@ async def get_e2e_cross_signing_key( return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id): + def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -517,7 +562,7 @@ def _get_bare_e2e_cross_signing_keys(self, user_id): ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -531,32 +576,35 @@ async def _get_bare_e2e_cross_signing_keys_bulk( their user ID will map to None. """ - return await self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, ) + # The `Optional` comes from the `@cachedList` decorator. + return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) + def _get_bare_e2e_cross_signing_keys_bulk_txn( self, - txn: Connection, + txn: LoggingTransaction, user_ids: Iterable[str], - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Dict[str, JsonDict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_ids (list[str]): the users whose keys are being requested + txn: db connection + user_ids: the users whose keys are being requested Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. If a user's cross-signing keys were not found, their user - ID will not be in the dict. + Mapping from user ID to key type to key data. + If a user's cross-signing keys were not found, their user ID will not be in + the dict. """ - result = {} + result: Dict[str, Dict[str, JsonDict]] = {} for user_chunk in batch_iter(user_ids, 100): clause, params = make_in_list_sql_clause( @@ -596,43 +644,48 @@ def _get_bare_e2e_cross_signing_keys_bulk_txn( user_id = row["user_id"] key_type = row["keytype"] key = db_to_json(row["keydata"]) - user_info = result.setdefault(user_id, {}) - user_info[key_type] = key + user_keys = result.setdefault(user_id, {}) + user_keys[key_type] = key return result def _get_e2e_cross_signing_signatures_txn( self, - txn: Connection, - keys: Dict[str, Dict[str, dict]], + txn: LoggingTransaction, + keys: Dict[str, Optional[Dict[str, JsonDict]]], from_user_id: str, - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing signatures made by a user on a set of keys. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - keys (dict[str, dict[str, dict]]): a map of user ID to key type to - key data. This dict will be modified to add signatures. - from_user_id (str): fetch the signatures made by this user + txn: db connection + keys: a map of user ID to key type to key data. + This dict will be modified to add signatures. + from_user_id: fetch the signatures made by this user Returns: - dict[str, dict[str, dict]]: mapping from user ID to key type to key - data. The return value will be the same as the keys argument, - with the modifications included. + Mapping from user ID to key type to key data. + The return value will be the same as the keys argument, with the + modifications included. """ # find out what cross-signing keys (a.k.a. devices) we need to get # signatures for. This is a map of (user_id, device_id) to key type # (device_id is the key's public part). - devices = {} + devices: Dict[Tuple[str, str], str] = {} - for user_id, user_info in keys.items(): - if user_info is None: + for user_id, user_keys in keys.items(): + if user_keys is None: continue - for key_type, key in user_info.items(): + for key_type, key in user_keys.items(): device_id = None for k in key["keys"].values(): device_id = k + # `key` ought to be a `CrossSigningKey`, whose .keys property is a + # dictionary with a single entry: + # "algorithm:base64_public_key": "base64_public_key" + # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing + assert isinstance(device_id, str) devices[(user_id, device_id)] = key_type for batch in batch_iter(devices.keys(), size=100): @@ -656,15 +709,20 @@ def _get_e2e_cross_signing_signatures_txn( # and add the signatures to the appropriate keys for row in rows: - key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] + key_id: str = row["key_id"] + target_user_id: str = row["target_user_id"] + target_device_id: str = row["target_device_id"] key_type = devices[(target_user_id, target_device_id)] # We need to copy everything, because the result may have come # from the cache. dict.copy only does a shallow copy, so we # need to recursively copy the dicts that will be modified. - user_info = keys[target_user_id] = keys[target_user_id].copy() - target_user_key = user_info[key_type] = user_info[key_type].copy() + user_keys = keys[target_user_id] + # `user_keys` cannot be `None` because we only fetched signatures for + # users with keys + assert user_keys is not None + user_keys = keys[target_user_id] = user_keys.copy() + + target_user_key = user_keys[key_type] = user_keys[key_type].copy() if "signatures" in target_user_key: signatures = target_user_key["signatures"] = target_user_key[ "signatures" @@ -683,7 +741,7 @@ def _get_e2e_cross_signing_signatures_txn( async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Dict[str, dict]]]: + ) -> Dict[str, Optional[Dict[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -741,7 +799,9 @@ async def get_all_user_signature_changes_for_remotes( if last_id == current_id: return [], current_id, False - def _get_all_user_signature_changes_for_remotes_txn(txn): + def _get_all_user_signature_changes_for_remotes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, from_user_id AS user_id FROM user_signature_stream @@ -785,7 +845,7 @@ async def claim_e2e_one_time_keys( @trace def _claim_e2e_one_time_key_simple( - txn, user_id: str, device_id: str, algorithm: str + txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str ) -> Optional[Tuple[str, str]]: """Claim OTK for device for DBs that don't support RETURNING. @@ -825,7 +885,7 @@ def _claim_e2e_one_time_key_simple( @trace def _claim_e2e_one_time_key_returning( - txn, user_id: str, device_id: str, algorithm: str + txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str ) -> Optional[Tuple[str, str]]: """Claim OTK for device for DBs that support RETURNING. @@ -860,7 +920,7 @@ def _claim_e2e_one_time_key_returning( key_id, key_json = otk_row return f"{algorithm}:{key_id}", key_json - results = {} + results: Dict[str, Dict[str, Dict[str, str]]] = {} for user_id, device_id, algorithm in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that @@ -930,6 +990,18 @@ def _claim_e2e_one_time_key_returning( class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._cross_signing_id_gen = StreamIdGenerator( + db_conn, "e2e_cross_signing_keys", "stream_id" + ) + async def set_e2e_device_keys( self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict ) -> bool: @@ -937,7 +1009,7 @@ async def set_e2e_device_keys( or the keys were already in the database. """ - def _set_e2e_device_keys_txn(txn): + def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("time_now", time_now) @@ -973,7 +1045,7 @@ def _set_e2e_device_keys_txn(txn): ) async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: - def delete_e2e_keys_by_device_txn(txn): + def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: log_kv( { "message": "Deleting keys for device", @@ -1012,17 +1084,24 @@ def delete_e2e_keys_by_device_txn(txn): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): + def _set_e2e_cross_signing_key_txn( + self, + txn: LoggingTransaction, + user_id: str, + key_type: str, + key: JsonDict, + stream_id: int, + ) -> None: """Set a user's cross-signing key. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user to set the signing key for - key_type (str): the type of key that is being set: either 'master' + txn: db connection + user_id: the user to set the signing key for + key_type: the type of key that is being set: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - key (dict): the key data - stream_id (int) + key: the key data + stream_id """ # the 'key' dict will look something like: # { @@ -1075,13 +1154,15 @@ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id) txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - async def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key( + self, user_id: str, key_type: str, key: JsonDict + ) -> None: """Set a user's cross-signing key. Args: - user_id (str): the user to set the user-signing key for - key_type (str): the type of cross-signing key to set - key (dict): the key data + user_id: the user to set the user-signing key for + key_type: the type of cross-signing key to set + key: the key data """ async with self._cross_signing_id_gen.get_next() as stream_id: