Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Make _get_e2e_device_keys_and_signatures_txn return an attrs #8224

Merged
merged 2 commits into from
Sep 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8224.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,17 +293,17 @@ async def _get_device_update_edus_by_remote(
prev_id = stream_id

if device is not None:
key_json = device.get("key_json", None)
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)

if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)

device_display_name = device.get("device_display_name", None)
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
Expand Down
52 changes: 36 additions & 16 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json

from twisted.enterprise.adbapi import Connection
Expand All @@ -33,6 +34,21 @@
from synapse.handlers.e2e_keys import SignatureListItem


@attr.s
class DeviceKeyLookupResult:
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""

display_name = attr.ib(type=Optional[str])

# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])

# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a Dict[str, str]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well yes it probably is. It's also going away in a couple of commits time 😇



class EndToEndKeyWorkerStore(SQLBaseStore):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
Expand Down Expand Up @@ -61,17 +77,17 @@ def _get_e2e_device_keys_for_federation_query_txn(
for device_id, device in user_devices.items():
result = {"device_id": device_id}

key_json = device.get("key_json", None)
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)

if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)

device_display_name = device.get("device_display_name", None)
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name

Expand Down Expand Up @@ -109,13 +125,13 @@ async def get_e2e_device_keys_for_cs_api(
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = db_to_json(device_info.pop("key_json"))
r = db_to_json(device_info.key_json)
r["unsigned"] = {}
display_name = device_info["device_display_name"]
display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
if "signatures" in device_info:
for sig_user_id, sigs in device_info["signatures"].items():
if device_info.signatures:
for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
Expand All @@ -126,7 +142,7 @@ async def get_e2e_device_keys_for_cs_api(
@trace
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[Dict]]]:
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)

Expand Down Expand Up @@ -161,7 +177,7 @@ def _get_e2e_device_keys_and_signatures_txn(

sql = (
"SELECT user_id, device_id, "
" d.display_name AS device_display_name, "
" d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
Expand All @@ -172,13 +188,14 @@ def _get_e2e_device_keys_and_signatures_txn(
)

txn.execute(sql, query_params)
rows = self.db_pool.cursor_to_dict(txn)

result = {}
for row in rows:
result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, key_json
)

if include_deleted_devices:
for user_id, device_id in deleted_devices:
Expand Down Expand Up @@ -209,7 +226,10 @@ def _get_e2e_device_keys_and_signatures_txn(
# note that target_device_result will be None for deleted devices.
continue

target_device_signatures = target_device_result.setdefault("signatures", {})
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}

signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
Expand Down