Skip to content

Commit

Permalink
Sliding Sync: Add E2EE extension (MSC3884) (#17454)
Browse files Browse the repository at this point in the history
  • Loading branch information
MadLittleMods authored Jul 22, 2024
1 parent d221512 commit de05a64
Show file tree
Hide file tree
Showing 9 changed files with 1,023 additions and 34 deletions.
1 change: 1 addition & 0 deletions changelog.d/17454.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add E2EE extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
17 changes: 13 additions & 4 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
ScheduledTask,
Expand Down Expand Up @@ -214,7 +215,7 @@ async def get_device_changes_in_shared_rooms(
@cancellable
async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
) -> JsonDict:
) -> DeviceListUpdates:
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
"""
Expand Down Expand Up @@ -341,11 +342,19 @@ async def get_user_ids_changed(
possibly_joined = set()
possibly_left = set()

result = {"changed": list(possibly_joined), "left": list(possibly_left)}
device_list_updates = DeviceListUpdates(
changed=possibly_joined,
left=possibly_left,
)

log_kv(result)
log_kv(
{
"changed": device_list_updates.changed,
"left": device_list_updates.left,
}
)

return result
return device_list_updates

async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
if not self.hs.is_mine(UserID.from_string(user_id)):
Expand Down
107 changes: 92 additions & 15 deletions synapse/handlers/sliding_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,18 @@
#
import logging
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Any,
Dict,
Final,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)

import attr
from immutabledict import immutabledict
Expand All @@ -33,6 +44,7 @@
from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
from synapse.storage.roommember import MemberSummary
from synapse.types import (
DeviceListUpdates,
JsonDict,
PersistedEventPosition,
Requester,
Expand Down Expand Up @@ -343,6 +355,7 @@ def __init__(self, hs: "HomeServer"):
self.notifier = hs.get_notifier()
self.event_sources = hs.get_event_sources()
self.relations_handler = hs.get_relations_handler()
self.device_handler = hs.get_device_handler()
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync

async def wait_for_sync_for_user(
Expand Down Expand Up @@ -371,10 +384,6 @@ async def wait_for_sync_for_user(
# auth_blocking will occur)
await self.auth_blocking.check_auth_blocking(requester=requester)

# TODO: If the To-Device extension is enabled and we have a `from_token`, delete
# any to-device messages before that token (since we now know that the device
# has received them). (see sync v2 for how to do this)

# If we're working with a user-provided token, we need to make sure to wait for
# this worker to catch up with the token so we don't skip past any incoming
# events or future events if the user is nefariously, manually modifying the
Expand Down Expand Up @@ -617,7 +626,9 @@ async def handle_room(room_id: str) -> None:
await concurrently_execute(handle_room, relevant_room_map, 10)

extensions = await self.get_extensions_response(
sync_config=sync_config, to_token=to_token
sync_config=sync_config,
from_token=from_token,
to_token=to_token,
)

return SlidingSyncResult(
Expand Down Expand Up @@ -1776,48 +1787,64 @@ async def get_extensions_response(
self,
sync_config: SlidingSyncConfig,
to_token: StreamToken,
from_token: Optional[StreamToken],
) -> SlidingSyncResult.Extensions:
"""Handle extension requests.
Args:
sync_config: Sync configuration
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""

if sync_config.extensions is None:
return SlidingSyncResult.Extensions()

to_device_response = None
if sync_config.extensions.to_device:
to_device_response = await self.get_to_device_extensions_response(
if sync_config.extensions.to_device is not None:
to_device_response = await self.get_to_device_extension_response(
sync_config=sync_config,
to_device_request=sync_config.extensions.to_device,
to_token=to_token,
)

return SlidingSyncResult.Extensions(to_device=to_device_response)
e2ee_response = None
if sync_config.extensions.e2ee is not None:
e2ee_response = await self.get_e2ee_extension_response(
sync_config=sync_config,
e2ee_request=sync_config.extensions.e2ee,
to_token=to_token,
from_token=from_token,
)

async def get_to_device_extensions_response(
return SlidingSyncResult.Extensions(
to_device=to_device_response,
e2ee=e2ee_response,
)

async def get_to_device_extension_response(
self,
sync_config: SlidingSyncConfig,
to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
to_token: StreamToken,
) -> SlidingSyncResult.Extensions.ToDeviceExtension:
) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
"""Handle to-device extension (MSC3885)
Args:
sync_config: Sync configuration
to_device_request: The to-device extension from the request
to_token: The point in the stream to sync up to.
"""

user_id = sync_config.user.to_string()
device_id = sync_config.device_id

# Skip if the extension is not enabled
if not to_device_request.enabled:
return None

# Check that this request has a valid device ID (not all requests have
# to belong to a device, and so device_id is None), and that the
# extension is enabled.
if device_id is None or not to_device_request.enabled:
# to belong to a device, and so device_id is None)
if device_id is None:
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{to_token.to_device_key}",
events=[],
Expand Down Expand Up @@ -1868,3 +1895,53 @@ async def get_to_device_extensions_response(
next_batch=f"{stream_id}",
events=messages,
)

async def get_e2ee_extension_response(
self,
sync_config: SlidingSyncConfig,
e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
to_token: StreamToken,
from_token: Optional[StreamToken],
) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
"""Handle E2EE device extension (MSC3884)
Args:
sync_config: Sync configuration
e2ee_request: The e2ee extension from the request
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
user_id = sync_config.user.to_string()
device_id = sync_config.device_id

# Skip if the extension is not enabled
if not e2ee_request.enabled:
return None

device_list_updates: Optional[DeviceListUpdates] = None
if from_token is not None:
# TODO: This should take into account the `from_token` and `to_token`
device_list_updates = await self.device_handler.get_user_ids_changed(
user_id=user_id,
from_token=from_token,
)

device_one_time_keys_count: Mapping[str, int] = {}
device_unused_fallback_key_types: Sequence[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
device_unused_fallback_key_types = (
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)

return SlidingSyncResult.Extensions.E2eeExtension(
device_list_updates=device_list_updates,
device_one_time_keys_count=device_one_time_keys_count,
device_unused_fallback_key_types=device_unused_fallback_key_types,
)
10 changes: 8 additions & 2 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,15 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

user_id = requester.user.to_string()

results = await self.device_handler.get_user_ids_changed(user_id, from_token)
device_list_updates = await self.device_handler.get_user_ids_changed(
user_id, from_token
)

response: JsonDict = {}
response["changed"] = list(device_list_updates.changed)
response["left"] = list(device_list_updates.left)

return 200, results
return 200, response


class OneTimeKeyServlet(RestServlet):
Expand Down
32 changes: 29 additions & 3 deletions synapse/rest/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,15 +1081,41 @@ async def encode_rooms(
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
) -> JsonDict:
result = {}
serialized_extensions: JsonDict = {}

if extensions.to_device is not None:
result["to_device"] = {
serialized_extensions["to_device"] = {
"next_batch": extensions.to_device.next_batch,
"events": extensions.to_device.events,
}

return result
if extensions.e2ee is not None:
serialized_extensions["e2ee"] = {
# We always include this because
# https://github.com/vector-im/element-android/issues/3725. The spec
# isn't terribly clear on when this can be omitted and how a client
# would tell the difference between "no keys present" and "nothing
# changed" in terms of whole field absent / individual key type entry
# absent Corresponding synapse issue:
# https://github.com/matrix-org/synapse/issues/10456
"device_one_time_keys_count": extensions.e2ee.device_one_time_keys_count,
# https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
# states that this field should always be included, as long as the
# server supports the feature.
"device_unused_fallback_key_types": extensions.e2ee.device_unused_fallback_key_types,
}

if extensions.e2ee.device_list_updates is not None:
serialized_extensions["e2ee"]["device_lists"] = {}

serialized_extensions["e2ee"]["device_lists"]["changed"] = list(
extensions.e2ee.device_list_updates.changed
)
serialized_extensions["e2ee"]["device_lists"]["left"] = list(
extensions.e2ee.device_list_updates.left
)

return serialized_extensions


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
Expand Down
7 changes: 4 additions & 3 deletions synapse/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,11 +1219,12 @@ class ReadReceipt:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListUpdates:
"""
An object containing a diff of information regarding other users' device lists, intended for
a recipient to carry out device list tracking.
An object containing a diff of information regarding other users' device lists,
intended for a recipient to carry out device list tracking.
Attributes:
changed: A set of users whose device lists have changed recently.
changed: A set of users who have updated their device identity or
cross-signing keys, or who now share an encrypted room with.
left: A set of users who the recipient no longer needs to track the device lists of.
Typically when those users no longer share any end-to-end encryption enabled rooms.
"""
Expand Down
48 changes: 45 additions & 3 deletions synapse/types/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#
#
from enum import Enum
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Dict, Final, List, Mapping, Optional, Sequence, Tuple

import attr
from typing_extensions import TypedDict
Expand All @@ -31,7 +31,7 @@
from pydantic import Extra

from synapse.events import EventBase
from synapse.types import JsonDict, JsonMapping, StreamToken, UserID
from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, StreamToken, UserID
from synapse.types.rest.client import SlidingSyncBody

if TYPE_CHECKING:
Expand Down Expand Up @@ -264,6 +264,7 @@ class Extensions:
Attributes:
to_device: The to-device extension (MSC3885)
e2ee: The E2EE device extension (MSC3884)
"""

@attr.s(slots=True, frozen=True, auto_attribs=True)
Expand All @@ -282,10 +283,51 @@ class ToDeviceExtension:
def __bool__(self) -> bool:
return bool(self.events)

@attr.s(slots=True, frozen=True, auto_attribs=True)
class E2eeExtension:
"""The E2EE device extension (MSC3884)
Attributes:
device_list_updates: List of user_ids whose devices have changed or left (only
present on incremental syncs).
device_one_time_keys_count: Map from key algorithm to the number of
unclaimed one-time keys currently held on the server for this device. If
an algorithm is unlisted, the count for that algorithm is assumed to be
zero. If this entire parameter is missing, the count for all algorithms
is assumed to be zero.
device_unused_fallback_key_types: List of unused fallback key algorithms
for this device.
"""

# Only present on incremental syncs
device_list_updates: Optional[DeviceListUpdates]
device_one_time_keys_count: Mapping[str, int]
device_unused_fallback_key_types: Sequence[str]

def __bool__(self) -> bool:
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
#
# Also related:
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
default_otk = self.device_one_time_keys_count.get("signed_curve25519")
more_than_default_otk = len(self.device_one_time_keys_count) > 1 or (
default_otk is not None and default_otk > 0
)

return bool(
more_than_default_otk
or self.device_list_updates
or self.device_unused_fallback_key_types
)

to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None

def __bool__(self) -> bool:
return bool(self.to_device)
return bool(self.to_device or self.e2ee)

next_pos: StreamToken
lists: Dict[str, SlidingWindowList]
Expand Down
Loading

0 comments on commit de05a64

Please sign in to comment.