From 7337c482ea42dea0c585a680e5736e7dabcc361e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 28 Aug 2020 16:41:25 +0100 Subject: [PATCH 1/3] Split fetching device keys and signatures into two transactions I think this is simpler (and moves stuff out of the db threads) --- changelog.d/8233.misc | 1 + .../storage/databases/main/end_to_end_keys.py | 108 +++++++++++------- 2 files changed, 65 insertions(+), 44 deletions(-) create mode 100644 changelog.d/8233.misc diff --git a/changelog.d/8233.misc b/changelog.d/8233.misc new file mode 100644 index 000000000000..979c8b227bbc --- /dev/null +++ b/changelog.d/8233.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index cc0b15ae0787..6726c923f685 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -25,6 +25,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -45,8 +46,9 @@ class DeviceKeyLookupResult: # key) and "signatures" (a signature of the structure by the ed25519 key) key_json = attr.ib(type=Optional[str]) - # cross-signing sigs - signatures = attr.ib(type=Optional[Dict], default=None) + # cross-signing sigs on this device. + # dict from (signing user_id)->(signing device_id)->sig + signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict) class EndToEndKeyWorkerStore(SQLBaseStore): @@ -154,22 +156,57 @@ async def get_e2e_device_keys_and_signatures( result = await self.db_pool.runInteraction( "get_e2e_device_keys", - self._get_e2e_device_keys_and_signatures_txn, + self._get_e2e_device_keys_txn, query_list, include_all_devices, include_deleted_devices, ) + # get the (user_id, device_id) tuples to look up cross-signatures for + signature_query = ( + (user_id, device_id) + for user_id, dev in result.items() + for device_id, d in dev.items() + if d is not None + ) + + for batch in batch_iter(signature_query, 50): + cross_sigs_result = await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_for_devices_txn, + batch, + ) + + # add each cross-signing signature to the correct device in the result dict. + for row in cross_sigs_result: + signing_user_id = row["user_id"] + signing_key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + signature = row["signature"] + + target_device_result = result[target_user_id][target_device_id] + target_device_signatures = target_device_result.signatures + + signing_user_signatures = target_device_signatures.setdefault( + signing_user_id, {} + ) + signing_user_signatures[signing_key_id] = signature + log_kv(result) return result - def _get_e2e_device_keys_and_signatures_txn( + def _get_e2e_device_keys_txn( self, txn, query_list, include_all_devices=False, include_deleted_devices=False ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Get information on devices from the database + + The results include the device's keys and self-signatures, but *not* any + cross-signing signatures which have been added subsequently (for which, see + get_e2e_device_keys_and_signatures) + """ query_clauses = [] query_params = [] - signature_query_clauses = [] - signature_query_params = [] if include_all_devices is False: include_deleted_devices = False @@ -180,20 +217,12 @@ def _get_e2e_device_keys_and_signatures_txn( for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) - signature_query_clause = "target_user_id = ?" - signature_query_params.append(user_id) if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) - signature_query_clause += " AND target_device_id = ?" - signature_query_params.append(device_id) - - signature_query_clause += " AND user_id = ?" - signature_query_params.append(user_id) query_clauses.append(query_clause) - signature_query_clauses.append(signature_query_clause) sql = ( "SELECT user_id, device_id, " @@ -221,41 +250,32 @@ def _get_e2e_device_keys_and_signatures_txn( for user_id, device_id in deleted_devices: result.setdefault(user_id, {})[device_id] = None - # get signatures on the device - signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) - - txn.execute(signature_sql, signature_query_params) - rows = self.db_pool.cursor_to_dict(txn) - - # add each cross-signing signature to the correct device in the result dict. - for row in rows: - signing_user_id = row["user_id"] - signing_key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] - signature = row["signature"] - - target_user_result = result.get(target_user_id) - if not target_user_result: - continue + return result - target_device_result = target_user_result.get(target_device_id) - if not target_device_result: - # note that target_device_result will be None for deleted devices. - continue + def _get_e2e_cross_signing_signatures_for_devices_txn( + self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + ) -> List[Dict]: + """Get cross-signing signatures for a given list of devices - target_device_signatures = target_device_result.signatures - if target_device_signatures is None: - target_device_signatures = target_device_result.signatures = {} + Returns signatures made by the owner of the devices. Each entry in the result + is a dict containing the fields from the database ('user_id', 'key_id', + 'target_user_id', 'target_device_id', 'signature'). + """ + signature_query_clauses = [] + signature_query_params = [] - signing_user_signatures = target_device_signatures.setdefault( - signing_user_id, {} + for (user_id, device_id) in device_query: + signature_query_clauses.append( + "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) - signing_user_signatures[signing_key_id] = signature + signature_query_params.extend([user_id, device_id, user_id]) - return result + signature_sql = "SELECT * FROM e2e_cross_signing_signatures WHERE %s" % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) + + txn.execute(signature_sql, signature_query_params) + return self.db_pool.cursor_to_dict(txn) async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str] From b5ae2d95cc586ae908ec63266c4cf1a054c044ca Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 3 Sep 2020 17:20:10 +0100 Subject: [PATCH 2/3] Update synapse/storage/databases/main/end_to_end_keys.py Co-authored-by: Patrick Cloke --- synapse/storage/databases/main/end_to_end_keys.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 6726c923f685..0f12c6a61baa 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -201,10 +201,10 @@ def _get_e2e_device_keys_txn( ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: """Get information on devices from the database - The results include the device's keys and self-signatures, but *not* any - cross-signing signatures which have been added subsequently (for which, see - get_e2e_device_keys_and_signatures) - """ + The results include the device's keys and self-signatures, but *not* any + cross-signing signatures which have been added subsequently (for which, see + get_e2e_device_keys_and_signatures) + """ query_clauses = [] query_params = [] From 9d5ba5867a85f3e8c169c9fd0ef214e51c2c9349 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 3 Sep 2020 17:52:15 +0100 Subject: [PATCH 3/3] Avoid returning unnecessary columns --- .../storage/databases/main/end_to_end_keys.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 0f12c6a61baa..09af03323371 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -135,7 +135,10 @@ async def get_e2e_device_keys_and_signatures( include_all_devices: bool = False, include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: - """Fetch a list of device keys, together with their cross-signatures. + """Fetch a list of device keys + + Any cross-signatures made on the keys by the owner of the device are also + included. Args: query_list: List of pairs of user_ids and device_ids. Device id can be None @@ -178,20 +181,14 @@ async def get_e2e_device_keys_and_signatures( ) # add each cross-signing signature to the correct device in the result dict. - for row in cross_sigs_result: - signing_user_id = row["user_id"] - signing_key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] - signature = row["signature"] - - target_device_result = result[target_user_id][target_device_id] + for (user_id, key_id, device_id, signature) in cross_sigs_result: + target_device_result = result[user_id][device_id] target_device_signatures = target_device_result.signatures signing_user_signatures = target_device_signatures.setdefault( - signing_user_id, {} + user_id, {} ) - signing_user_signatures[signing_key_id] = signature + signing_user_signatures[key_id] = signature log_kv(result) return result @@ -254,12 +251,13 @@ def _get_e2e_device_keys_txn( def _get_e2e_cross_signing_signatures_for_devices_txn( self, txn: Cursor, device_query: Iterable[Tuple[str, str]] - ) -> List[Dict]: + ) -> List[Tuple[str, str, str, str]]: """Get cross-signing signatures for a given list of devices - Returns signatures made by the owner of the devices. Each entry in the result - is a dict containing the fields from the database ('user_id', 'key_id', - 'target_user_id', 'target_device_id', 'signature'). + Returns signatures made by the owners of the devices. + + Returns: a list of results; each entry in the list is a tuple of + (user_id, key_id, target_device_id, signature). """ signature_query_clauses = [] signature_query_params = [] @@ -270,12 +268,15 @@ def _get_e2e_cross_signing_signatures_for_devices_txn( ) signature_query_params.extend([user_id, device_id, user_id]) - signature_sql = "SELECT * FROM e2e_cross_signing_signatures WHERE %s" % ( + signature_sql = """ + SELECT user_id, key_id, target_device_id, signature + FROM e2e_cross_signing_signatures WHERE %s + """ % ( " OR ".join("(" + q + ")" for q in signature_query_clauses) ) txn.execute(signature_sql, signature_query_params) - return self.db_pool.cursor_to_dict(txn) + return txn.fetchall() async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str]