diff --git a/changelog.d/13581.feature b/changelog.d/13581.feature new file mode 100644 index 000000000000..1db4ab71d361 --- /dev/null +++ b/changelog.d/13581.feature @@ -0,0 +1 @@ +Implement [MSC3814](https://github.com/matrix-org/matrix-spec-proposals/pull/3814), dehydrated devices v2/shrivelled sessions, with a few changes (as proposed on the MSC) and move [MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697) behind a config flag. Contributed by Nico from Famedly. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702b81e636c9..a5d2cf4a893c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,6 +32,17 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) + # MSC2285 (unstable private read receipts) + self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) + + # MSC2697 (device dehydration) + # Enabled by default since this option was added after adding the feature. + self.msc2697_enabled: bool = experimental.get("msc2697_enabled", True) + + # MSC3814 (dehydrated devices with SSSS) + # This is an alternative method to achieve the same goals as MSC2697. + self.msc3814_enabled: bool = experimental.get("msc3814_enabled", False) + # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 444c08bc2eef..7419e7e62a07 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -13,10 +13,11 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Dict, Optional from synapse.api.constants import EduTypes, ToDeviceEventTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -46,6 +47,9 @@ def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine + if hs.config.experimental.msc3814_enabled: + self.event_sources = hs.get_event_sources() + self.device_handler = hs.get_device_handler() # We only need to poke the federation sender explicitly if its on the # same instance. Other federation sender instances will get notified by @@ -293,3 +297,79 @@ async def send_device_message( # Enqueue a new federation transaction to send the new # device messages to each remote destination. self.federation_sender.send_device_messages(destination) + + async def get_events_for_dehydrated_device( + self, + requester: Requester, + device_id: str, + since_token: Optional[str], + limit: int, + ) -> JsonDict: + """Fetches up to `limit` events sent to `device_id` starting from `since_token` and returns the new since token.""" + + user_id = requester.user.to_string() + + # TODO(Nico): Figure out who should be allowed to use that endpoint. + # For now we just allow it for yourself and for the dehydrated device. + if device_id != requester.device_id: + dehydrated_device = await self.device_handler.get_dehydrated_device(user_id) + if dehydrated_device is not None and device_id != dehydrated_device[0]: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "Can only fetch messages for own device or dehydrated devices", + Codes.FORBIDDEN, + ) + + since_stream_id = 0 + if since_token: + if not since_token.startswith("d"): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "from parameter %r has an invalid format" % (since_token,), + errcode=Codes.INVALID_PARAM, + ) + + try: + since_stream_id = int(since_token[1:]) + except Exception: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "from parameter %r has an invalid format" % (since_token,), + errcode=Codes.INVALID_PARAM, + ) + + # if we have a since token, delete any to-device messages before that token + # (since we now know that the device has received them) + deleted = await self.store.delete_messages_for_device( + user_id, device_id, since_stream_id + ) + logger.debug( + "Deleted %d to-device messages up to %d", deleted, since_stream_id + ) + + to_token = self.event_sources.get_current_token().to_device_key + + messages, stream_id = await self.store.get_messages_for_device( + user_id, device_id, since_stream_id, to_token, limit + ) + + for message in messages: + # We pop here as we shouldn't be sending the message ID down + # `/sync` + message_id = message.pop("message_id", None) + if message_id: + set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id) + + logger.debug( + "Returning %d to-device messages between %d and %d (current token: %d) for dehydrated device %s", + len(messages), + since_stream_id, + stream_id, + to_token, + device_id, + ) + + return { + "events": messages, + "next_batch": f"d{stream_id}", + } diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index ed6ce78d4771..c5260aff28a8 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -18,11 +18,13 @@ from synapse.api import errors from synapse.api.errors import NotFoundError -from synapse.http.server import HttpServer +from synapse.http.server import HttpServer, cancellable from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_integer, parse_json_object_from_request, + parse_string, ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns, interactive_auth_handler @@ -194,6 +196,8 @@ async def on_PUT( class DehydratedDeviceServlet(RestServlet): """Retrieve or store a dehydrated device. + Implements both MSC2697 and MSC3814. + GET /org.matrix.msc2697.v2/dehydrated_device HTTP/1.1 200 OK @@ -226,14 +230,19 @@ class DehydratedDeviceServlet(RestServlet): """ - PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) - - def __init__(self, hs: "HomeServer"): + def __init__(self, hs: "HomeServer", msc2697: bool = True): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + self.PATTERNS = client_patterns( + "/org.matrix.msc2697.v2/dehydrated_device$" + if msc2697 + else "/org.matrix.msc3814.v1/dehydrated_device$", + releases=(), + ) + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) dehydrated_device = await self.device_handler.get_dehydrated_device( @@ -327,9 +336,44 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return 200, result +class DehydratedDeviceEventsServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc3814.v1/dehydrated_device/(?P[^/]*)/events$", + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.message_handler = hs.get_device_message_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + + @cancellable + async def on_GET( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + from_tok = parse_string(request, "from") + limit = parse_integer(request, "limit", 100) + + msgs = await self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id=device_id, + since_token=from_tok, + limit=limit, + ) + + return 200, msgs + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) - DehydratedDeviceServlet(hs).register(http_server) - ClaimDehydratedDeviceServlet(hs).register(http_server) + if hs.config.experimental.msc2697_enabled: + DehydratedDeviceServlet(hs, msc2697=True).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) + if hs.config.experimental.msc3814_enabled: + DehydratedDeviceServlet(hs, msc2697=False).register(http_server) + DehydratedDeviceEventsServlet(hs).register(http_server) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b8f..de3cccb5d6d5 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -16,11 +16,13 @@ from typing import Optional +from twisted.internet.defer import ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN from synapse.server import HomeServer +from synapse.types import create_requester from synapse.util import Clock from tests import unittest @@ -265,6 +267,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() + self.message_handler = hs.get_device_message_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -342,3 +345,104 @@ def test_dehydrate_and_rehydrate_device(self) -> None: ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) self.assertIsNone(ret) + + @unittest.override_config( + {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}} + ) + def test_dehydrate_v2_and_fetch_events(self) -> None: + user_id = "@boris:server" + + self.get_success(self.store.register_user(user_id, "foobar")) + + # First check if we can store and fetch a dehydrated device + stored_dehydrated_device_id = self.get_success( + self.handler.store_dehydrated_device( + user_id=user_id, + device_data={"device_data": {"foo": "bar"}}, + initial_device_display_name="dehydrated device", + ) + ) + + retrieved_device_id, device_data = self.get_success( + self.handler.get_dehydrated_device(user_id=user_id) + ) + + self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) + self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) + + # Create a new login for the user + device_id, access_token, _expiration_time, _refresh_token = self.get_success( + self.registration.register_device( + user_id=user_id, + device_id=None, + initial_display_name="new device", + ) + ) + + requester = create_requester(user_id, device_id=device_id) + + # Fetching messages for a non existing device should return an error + self.get_failure( + self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id="not the right device ID", + since_token=None, + limit=10, + ), + SynapseError, + ) + + # Send a message to the dehydrated device + ensureDeferred( + self.message_handler.send_device_message( + requester=requester, + message_type="test.message", + messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}}, + ) + ) + self.pump() + + # Fetch the message of the dehydrated device + res = self.get_success( + self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id=stored_dehydrated_device_id, + since_token=None, + limit=10, + ) + ) + + self.assertTrue(len(res["next_batch"]) > 1) + self.assertEqual(len(res["events"]), 1) + self.assertEqual(res["events"][0]["content"]["body"], "foo") + + # Fetch the message of the dehydrated device again, which should return nothing and delete the old messages + res = self.get_success( + self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id=stored_dehydrated_device_id, + since_token=res["next_batch"], + limit=10, + ) + ) + self.assertTrue(len(res["next_batch"]) > 1) + self.assertEqual(len(res["events"]), 0) + + # Fetching messages without since should return nothing, since the messages got deleted + res = self.get_success( + self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id=stored_dehydrated_device_id, + since_token=None, + limit=10, + ) + ) + self.assertTrue(len(res["next_batch"]) > 1) + self.assertEqual(len(res["events"]), 0) + + # We don't delete the device when fetch messages for now. + # # make sure that the device ID that we were initially assigned no longer exists + # self.get_failure( + # self.handler.get_device(user_id, device_id), + # NotFoundError, + # )