From 5282ba1e2bbff2635dc09aec45fd42a56c1a4545 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Mar 2023 14:26:27 -0400 Subject: [PATCH] Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314) Experimental support for MSC3983 is behind a configuration flag. If enabled, for users which are exclusively owned by an application service then the appservice will be queried for one-time keys *if* there are none uploaded to Synapse. --- changelog.d/15314.feature | 1 + synapse/appservice/api.py | 56 ++++++++++++++ synapse/config/experimental.py | 5 ++ synapse/federation/federation_server.py | 20 ++--- synapse/handlers/appservice.py | 74 +++++++++++++++++- synapse/handlers/e2e_keys.py | 57 ++++++++++++-- .../storage/databases/main/end_to_end_keys.py | 36 ++++++--- tests/appservice/test_api.py | 59 ++++++++++++++ tests/handlers/test_e2e_keys.py | 76 ++++++++++++++++++- 9 files changed, 355 insertions(+), 29 deletions(-) create mode 100644 changelog.d/15314.feature diff --git a/changelog.d/15314.feature b/changelog.d/15314.feature new file mode 100644 index 000000000000..68b289b0cc7c --- /dev/null +++ b/changelog.d/15314.feature @@ -0,0 +1 @@ +Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)). diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 4812fb44961f..51ee0e79df10 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -388,6 +388,62 @@ async def push_bulk( failed_transactions_counter.labels(service.id).inc() return False + async def claim_client_keys( + self, service: "ApplicationService", query: List[Tuple[str, str, str]] + ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + """Claim one time keys from an application service. + + Args: + query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A tuple of: + A map of user ID -> a map device ID -> a map of key ID -> JSON dict. + + A copy of the input which has not been fulfilled because the + appservice doesn't support this endpoint or has not returned + data for that tuple. + """ + if service.url is None: + return {}, query + + # This is required by the configuration. + assert service.hs_token is not None + + # Create the expected payload shape. + body: Dict[str, Dict[str, List[str]]] = {} + for user_id, device, algorithm in query: + body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) + + uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" + try: + response = await self.post_json_get_json( + uri, + body, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, + ) + except CodeMessageException as e: + # The appservice doesn't support this endpoint. + if e.code == 404 or e.code == 405: + return {}, query + logger.warning("claim_keys to %s received %s", uri, e.code) + return {}, query + except Exception as ex: + logger.warning("claim_keys to %s threw exception %s", uri, ex) + return {}, query + + # Check if the appservice fulfilled all of the queried user/device/algorithms + # or if some are still missing. + # + # TODO This places a lot of faith in the response shape being correct. + missing = [ + (user_id, device, algorithm) + for user_id, device, algorithm in query + if algorithm not in response.get(user_id, {}).get(device, []) + ] + + return response, missing + def _serialize( self, service: "ApplicationService", events: Iterable[EventBase] ) -> List[JsonDict]: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 99dcd27c7426..53e6fc2b54d6 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -74,6 +74,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: "msc3202_transaction_extensions", False ) + # MSC3983: Proxying OTK claim requests to exclusive ASes. + self.msc3983_appservice_otk_claims: bool = experimental.get( + "msc3983_appservice_otk_claims", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) # Synapse will always serve partial state responses to requests using the stable # query parameter `omit_members`. If this flag is set, Synapse will also serve diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6d99845de561..64e99292ec04 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -86,7 +86,7 @@ from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary from synapse.types import JsonDict, StateMap, get_domain_from_id -from synapse.util import json_decoder, unwrapFirstError +from synapse.util import unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name @@ -135,6 +135,7 @@ def __init__(self, hs: "HomeServer"): self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() self._room_member_handler = hs.get_room_member_handler() + self._e2e_keys_handler = hs.get_e2e_keys_handler() self._state_storage_controller = hs.get_storage_controllers().state @@ -1012,15 +1013,14 @@ async def on_claim_client_keys( query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = await self.store.claim_e2e_one_time_keys(query) - - json_result: Dict[str, Dict[str, dict]] = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_str in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_str) - } + results = await self._e2e_keys_handler.claim_local_one_time_keys(query) + + json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for result in results: + for user_id, device_keys in result.items(): + for device_id, keys in device_keys.items(): + for key_id, key in keys.items(): + json_result.setdefault(user_id, {})[device_id] = {key_id: key} logger.info( "Claimed one-time-keys: %s", diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index ec3ab968e925..953df4d9cd26 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from prometheus_client import Counter @@ -829,3 +838,66 @@ async def _check_user_exists(self, user_id: str) -> bool: if unknown_user: return await self.query_user_exists(user_id) return True + + async def claim_e2e_one_time_keys( + self, query: Iterable[Tuple[str, str, str]] + ) -> Tuple[ + Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]] + ]: + """Claim one time keys from application services. + + Args: + query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A tuple of: + An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + + A copy of the input which has not been fulfilled (either because + they are not appservice users or the appservice does not support + providing OTKs). + """ + services = self.store.get_app_services() + + # Partition the users by appservice. + query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {} + missing = [] + for user_id, device, algorithm in query: + if not self.store.get_if_app_services_interested_in_user(user_id): + missing.append((user_id, device, algorithm)) + continue + + # Find the associated appservice. + for service in services: + if service.is_exclusive_user(user_id): + query_by_appservice.setdefault(service.id, []).append( + (user_id, device, algorithm) + ) + continue + + # Query each service in parallel. + results = await make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.claim_client_keys, + # We know this must be an app service. + self.store.get_app_service_by_id(service_id), # type: ignore[arg-type] + service_query, + ) + for service_id, service_query in query_by_appservice.items() + ], + consumeErrors=True, + ) + ) + + # Patch together the results -- they are all independent (since they + # require exclusive control over the users). They get returned as a list + # and the caller combines them. + claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = [] + for success, result in results: + if success: + claimed_keys.append(result[0]) + missing.extend(result[1]) + + return claimed_keys, missing diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 4e9c8d8db0c7..9e7c2c45b557 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple @@ -53,6 +52,7 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() + self._appservice_handler = hs.get_application_service_handler() self.is_mine = hs.is_mine self.clock = hs.get_clock() @@ -88,6 +88,10 @@ def __init__(self, hs: "HomeServer"): max_count=10, ) + self._query_appservices_for_otks = ( + hs.config.experimental.msc3983_appservice_otk_claims + ) + @trace @cancellable async def query_devices( @@ -542,6 +546,42 @@ async def on_federation_query_client_keys( return ret + async def claim_local_one_time_keys( + self, local_query: List[Tuple[str, str, str]] + ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: + """Claim one time keys for local users. + + 1. Attempt to claim OTKs from the database. + 2. Ask application services if they provide OTKs. + 3. Attempt to fetch fallback keys from the database. + + Args: + local_query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + """ + + otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) + + # If the application services have not provided any keys via the C-S + # API, query it directly for one-time keys. + if self._query_appservices_for_otks: + ( + appservice_results, + not_found, + ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) + else: + appservice_results = [] + + # For each user that does not have a one-time keys available, see if + # there is a fallback key. + fallback_results = await self.store.claim_e2e_fallback_keys(not_found) + + # Return the results in order, each item from the input query should + # only appear once in the combined list. + return (otk_results, *appservice_results, fallback_results) + @trace async def claim_one_time_keys( self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] @@ -561,17 +601,18 @@ async def claim_one_time_keys( set_tag("local_key_query", str(local_query)) set_tag("remote_key_query", str(remote_queries)) - results = await self.store.claim_e2e_one_time_keys(local_query) + results = await self.claim_local_one_time_keys(local_query) # A map of user ID -> device ID -> key ID -> key. json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for result in results: + for user_id, device_keys in result.items(): + for device_id, keys in device_keys.items(): + for key_id, key in keys.items(): + json_result.setdefault(user_id, {})[device_id] = {key_id: key} + + # Remote failures. failures: Dict[str, JsonDict] = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_str in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_str) - } @trace async def claim_client_keys(destination: str) -> None: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index a3b6c8ae8e82..dc7768c50cab 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -51,7 +51,7 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter @@ -1028,14 +1028,17 @@ def get_device_stream_token(self) -> int: async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] - ) -> Dict[str, Dict[str, Dict[str, str]]]: + ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: """Take a list of one time keys out of the database. Args: query_list: An iterable of tuples of (user ID, device ID, algorithm). Returns: - A map of user ID -> a map device ID -> a map of key ID -> JSON bytes. + A tuple pf: + A map of user ID -> a map device ID -> a map of key ID -> JSON. + + A copy of the input which has not been fulfilled. """ @trace @@ -1115,7 +1118,8 @@ def _claim_e2e_one_time_key_returning( key_id, key_json = otk_row return f"{algorithm}:{key_id}", key_json - results: Dict[str, Dict[str, Dict[str, str]]] = {} + results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + missing: List[Tuple[str, str, str]] = [] for user_id, device_id, algorithm in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that @@ -1138,11 +1142,25 @@ def _claim_e2e_one_time_key_returning( device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[claim_row[0]] = claim_row[1] - continue + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + else: + missing.append((user_id, device_id, algorithm)) + + return results, missing + + async def claim_e2e_fallback_keys( + self, query_list: Iterable[Tuple[str, str, str]] + ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: + """Take a list of fallback keys out of the database. - # No one-time key available, so see if there's a fallback - # key + Args: + query_list: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A map of user ID -> a map device ID -> a map of key ID -> JSON. + """ + results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for user_id, device_id, algorithm in query_list: row = await self.db_pool.simple_select_one( table="e2e_fallback_keys_json", keyvalues={ @@ -1179,7 +1197,7 @@ def _claim_e2e_one_time_key_returning( ) device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) - device_results[f"{algorithm}:{key_id}"] = key_json + device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) return results diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 9d183b733ec2..0dd02b7d5839 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -105,3 +105,62 @@ async def get_json( ) self.assertEqual(self.request_url, URL_LOCATION) self.assertEqual(result, SUCCESS_RESULT_LOCATION) + + def test_claim_keys(self) -> None: + """ + Tests that the /keys/claim response is properly parsed for missing + keys. + """ + + RESPONSE: JsonDict = { + "@alice:example.org": { + "DEVICE_1": { + "signed_curve25519:AAAAHg": { + # We don't really care about the content of the keys, + # they get passed back transparently. + }, + "signed_curve25519:BBBBHg": {}, + }, + "DEVICE_2": {"signed_curve25519:CCCCHg": {}}, + }, + } + + async def post_json_get_json( + uri: str, + post_json: Any, + headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], + ) -> JsonDict: + # Ensure the access token is passed as both a header and query arg. + if not headers.get("Authorization"): + raise RuntimeError("Access token not provided") + + self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) + return RESPONSE + + # We assign to a method, which mypy doesn't like. + self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment] + + MISSING_KEYS = [ + # Known user, known device, missing algorithm. + ("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"), + # Known user, missing device. + ("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"), + # Unknown user. + ("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"), + ] + + claimed_keys, missing = self.get_success( + self.api.claim_client_keys( + self.service, + [ + # Found devices + ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"), + ("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"), + ("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"), + ] + + MISSING_KEYS, + ) + ) + + self.assertEqual(claimed_keys, RESPONSE) + self.assertEqual(missing, MISSING_KEYS) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 6b4cba65d069..4ff04fc66bb7 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -23,18 +23,24 @@ from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.appservice import ApplicationService from synapse.handlers.device import DeviceHandler from synapse.server import HomeServer +from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.types import JsonDict from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable +from tests.unittest import override_config class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(federation_client=mock.Mock()) + self.appservice_api = mock.Mock() + return self.setup_test_homeserver( + federation_client=mock.Mock(), application_service_api=self.appservice_api + ) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() @@ -941,3 +947,71 @@ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> Non # The two requests to the local homeserver should be identical. self.assertEqual(response_1, response_2) + + @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}}) + def test_query_appservice(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id_1 = "xyz" + fallback_key = {"alg1:k1": "fallback_key1"} + device_id_2 = "abc" + otk = {"alg1:k2": "key2"} + + # Inject an appservice interested in this user. + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + self.hs.get_datastores().main.services_cache = [appservice] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [appservice] + ) + + # Setup a response, but only for device 2. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")]) + ) + + # we shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(res, []) + + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # claiming an OTK when no OTKs are available should ask the appservice, then + # query the fallback keys. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + { + "one_time_keys": { + local_user: {device_id_1: "alg1", device_id_2: "alg1"} + } + }, + timeout=None, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: fallback_key, device_id_2: otk} + }, + }, + )