Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314)
Browse files Browse the repository at this point in the history
Experimental support for MSC3983 is behind a configuration flag.
If enabled, for users which are exclusively owned by an application
service then the appservice will be queried for one-time keys *if*
there are none uploaded to Synapse.
  • Loading branch information
clokep authored Mar 28, 2023
1 parent 57481ca commit 5282ba1
Show file tree
Hide file tree
Showing 9 changed files with 355 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog.d/15314.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).
56 changes: 56 additions & 0 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,62 @@ async def push_bulk(
failed_transactions_counter.labels(service.id).inc()
return False

async def claim_client_keys(
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from an application service.
Args:
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A tuple of:
A map of user ID -> a map device ID -> a map of key ID -> JSON dict.
A copy of the input which has not been fulfilled because the
appservice doesn't support this endpoint or has not returned
data for that tuple.
"""
if service.url is None:
return {}, query

# This is required by the configuration.
assert service.hs_token is not None

# Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {}
for user_id, device, algorithm in query:
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)

uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
try:
response = await self.post_json_get_json(
uri,
body,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
except CodeMessageException as e:
# The appservice doesn't support this endpoint.
if e.code == 404 or e.code == 405:
return {}, query
logger.warning("claim_keys to %s received %s", uri, e.code)
return {}, query
except Exception as ex:
logger.warning("claim_keys to %s threw exception %s", uri, ex)
return {}, query

# Check if the appservice fulfilled all of the queried user/device/algorithms
# or if some are still missing.
#
# TODO This places a lot of faith in the response shape being correct.
missing = [
(user_id, device, algorithm)
for user_id, device, algorithm in query
if algorithm not in response.get(user_id, {}).get(device, [])
]

return response, missing

def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
Expand Down
5 changes: 5 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
"msc3202_transaction_extensions", False
)

# MSC3983: Proxying OTK claim requests to exclusive ASes.
self.msc3983_appservice_otk_claims: bool = experimental.get(
"msc3983_appservice_otk_claims", False
)

# MSC3706 (server-side support for partial state in /send_join responses)
# Synapse will always serve partial state responses to requests using the stable
# query parameter `omit_members`. If this flag is set, Synapse will also serve
Expand Down
20 changes: 10 additions & 10 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_decoder, unwrapFirstError
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name
Expand Down Expand Up @@ -135,6 +135,7 @@ def __init__(self, hs: "HomeServer"):
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._room_member_handler = hs.get_room_member_handler()
self._e2e_keys_handler = hs.get_e2e_keys_handler()

self._state_storage_controller = hs.get_storage_controllers().state

Expand Down Expand Up @@ -1012,15 +1013,14 @@ async def on_claim_client_keys(
query.append((user_id, device_id, algorithm))

log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self.store.claim_e2e_one_time_keys(query)

json_result: Dict[str, Dict[str, dict]] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json_decoder.decode(json_str)
}
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)

json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}

logger.info(
"Claimed one-time-keys: %s",
Expand Down
74 changes: 73 additions & 1 deletion synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)

from prometheus_client import Counter

Expand Down Expand Up @@ -829,3 +838,66 @@ async def _check_user_exists(self, user_id: str) -> bool:
if unknown_user:
return await self.query_user_exists(user_id)
return True

async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
]:
"""Claim one time keys from application services.
Args:
query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A tuple of:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
providing OTKs).
"""
services = self.store.get_app_services()

# Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
missing = []
for user_id, device, algorithm in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
missing.append((user_id, device, algorithm))
continue

# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append(
(user_id, device, algorithm)
)
continue

# Query each service in parallel.
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
self.appservice_api.claim_client_keys,
# We know this must be an app service.
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
service_query,
)
for service_id, service_query in query_by_appservice.items()
],
consumeErrors=True,
)
)

# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a list
# and the caller combines them.
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
for success, result in results:
if success:
claimed_keys.append(result[0])
missing.extend(result[1])

return claimed_keys, missing
57 changes: 49 additions & 8 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple

Expand Down Expand Up @@ -53,6 +52,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()

Expand Down Expand Up @@ -88,6 +88,10 @@ def __init__(self, hs: "HomeServer"):
max_count=10,
)

self._query_appservices_for_otks = (
hs.config.experimental.msc3983_appservice_otk_claims
)

@trace
@cancellable
async def query_devices(
Expand Down Expand Up @@ -542,6 +546,42 @@ async def on_federation_query_client_keys(

return ret

async def claim_local_one_time_keys(
self, local_query: List[Tuple[str, str, str]]
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
1. Attempt to claim OTKs from the database.
2. Ask application services if they provide OTKs.
3. Attempt to fetch fallback keys from the database.
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""

otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)

# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
if self._query_appservices_for_otks:
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
appservice_results = []

# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)

# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, *appservice_results, fallback_results)

@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
Expand All @@ -561,17 +601,18 @@ async def claim_one_time_keys(
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))

results = await self.store.claim_e2e_one_time_keys(local_query)
results = await self.claim_local_one_time_keys(local_query)

# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}

# Remote failures.
failures: Dict[str, JsonDict] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json_decoder.decode(json_str)
}

@trace
async def claim_client_keys(destination: str) -> None:
Expand Down
36 changes: 27 additions & 9 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
Expand Down Expand Up @@ -1028,14 +1028,17 @@ def get_device_stream_token(self) -> int:

async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, str]]]:
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Take a list of one time keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
A tuple pf:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
A copy of the input which has not been fulfilled.
"""

@trace
Expand Down Expand Up @@ -1115,7 +1118,8 @@ def _claim_e2e_one_time_key_returning(
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json

results: Dict[str, Dict[str, Dict[str, str]]] = {}
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str]] = []
for user_id, device_id, algorithm in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
Expand All @@ -1138,11 +1142,25 @@ def _claim_e2e_one_time_key_returning(
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
device_results[claim_row[0]] = claim_row[1]
continue
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
else:
missing.append((user_id, device_id, algorithm))

return results, missing

async def claim_e2e_fallback_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Take a list of fallback keys out of the database.
# No one-time key available, so see if there's a fallback
# key
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
keyvalues={
Expand Down Expand Up @@ -1179,7 +1197,7 @@ def _claim_e2e_one_time_key_returning(
)

device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = key_json
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)

return results

Expand Down
Loading

0 comments on commit 5282ba1

Please sign in to comment.