Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support MSC3814: Dehydrated devices v2 aka shrivelled sessions #13581

Closed
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13581.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement [MSC3814](https://github.com/matrix-org/matrix-spec-proposals/pull/3814), dehydrated devices v2/shrivelled sessions, with a few changes (as proposed on the MSC) and move [MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697) behind a config flag. Contributed by Nico from Famedly.
8 changes: 8 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# MSC2285 (unstable private read receipts)
self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)

# MSC2697 (device dehydration)
# Enabled by default since this option was added after adding the feature.
self.msc2697_enabled: bool = experimental.get("msc2697_enabled", True)

# MSC3814 (dehydrated devices with SSSS)
# This is an alternative method to achieve the same goals as MSC2697.
self.msc3814_enabled: bool = experimental.get("msc3814_enabled", False)

# MSC3244 (room version capabilities)
self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)

Expand Down
84 changes: 82 additions & 2 deletions synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Dict
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, Optional

from synapse.api.constants import EduTypes, ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
Expand Down Expand Up @@ -46,6 +47,9 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.is_mine = hs.is_mine
if hs.config.experimental.msc3814_enabled:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()

# We only need to poke the federation sender explicitly if its on the
# same instance. Other federation sender instances will get notified by
Expand Down Expand Up @@ -293,3 +297,79 @@ async def send_device_message(
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
self.federation_sender.send_device_messages(destination)

async def get_events_for_dehydrated_device(
self,
requester: Requester,
device_id: str,
since_token: Optional[str],
limit: int,
) -> JsonDict:
"""Fetches up to `limit` events sent to `device_id` starting from `since_token` and returns the new since token."""

user_id = requester.user.to_string()

# TODO(Nico): Figure out who should be allowed to use that endpoint.
# For now we just allow it for yourself and for the dehydrated device.
if device_id != requester.device_id:
dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
if dehydrated_device is not None and device_id != dehydrated_device[0]:
raise SynapseError(
HTTPStatus.FORBIDDEN,
"Can only fetch messages for own device or dehydrated devices",
Codes.UNAUTHORIZED,
nico-famedly marked this conversation as resolved.
Show resolved Hide resolved
)

since_stream_id = 0
if since_token:
if not since_token.startswith("d"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"from parameter %r has an invalid format" % (since_token,),
errcode=Codes.INVALID_PARAM,
)

try:
since_stream_id = int(since_token[1:])
except Exception:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"from parameter %r has an invalid format" % (since_token,),
errcode=Codes.INVALID_PARAM,
)

# if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them)
deleted = await self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug(
"Deleted %d to-device messages up to %d", deleted, since_stream_id
)

to_token = self.event_sources.get_current_token().to_device_key

messages, stream_id = await self.store.get_messages_for_device(
user_id, device_id, since_stream_id, to_token, limit
)

for message in messages:
# We pop here as we shouldn't be sending the message ID down
# `/sync`
message_id = message.pop("message_id", None)
if message_id:
set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)

logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d) for dehydrated device %s",
len(messages),
since_stream_id,
stream_id,
to_token,
device_id,
)

return {
"events": messages,
"next_batch": f"d{stream_id}",
}
56 changes: 50 additions & 6 deletions synapse/rest/client/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

from synapse.api import errors
from synapse.api.errors import NotFoundError
from synapse.http.server import HttpServer
from synapse.http.server import HttpServer, cancellable
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler
Expand Down Expand Up @@ -194,6 +196,8 @@ async def on_PUT(
class DehydratedDeviceServlet(RestServlet):
"""Retrieve or store a dehydrated device.

Implements both MSC2697 and MSC3814.

GET /org.matrix.msc2697.v2/dehydrated_device

HTTP/1.1 200 OK
Expand Down Expand Up @@ -226,14 +230,19 @@ class DehydratedDeviceServlet(RestServlet):

"""

PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())

def __init__(self, hs: "HomeServer"):
def __init__(self, hs: "HomeServer", msc2697: bool = True):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()

self.PATTERNS = client_patterns(
"/org.matrix.msc2697.v2/dehydrated_device$"
if msc2697
else "/org.matrix.msc3814.v1/dehydrated_device$",
releases=(),
)

async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
dehydrated_device = await self.device_handler.get_dehydrated_device(
Expand Down Expand Up @@ -327,9 +336,44 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, result


class DehydratedDeviceEventsServlet(RestServlet):
PATTERNS = client_patterns(
"/org.matrix.msc3814.v1/dehydrated_device/(?P<device_id>[^/]*)/events$",
releases=(),
)

def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_device_message_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main

@cancellable
async def on_GET(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

from_tok = parse_string(request, "from")
limit = parse_integer(request, "limit", 100)

msgs = await self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=device_id,
since_token=from_tok,
limit=limit,
)

return 200, msgs


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DehydratedDeviceServlet(hs).register(http_server)
ClaimDehydratedDeviceServlet(hs).register(http_server)
if hs.config.experimental.msc2697_enabled:
DehydratedDeviceServlet(hs, msc2697=True).register(http_server)
ClaimDehydratedDeviceServlet(hs).register(http_server)
if hs.config.experimental.msc3814_enabled:
DehydratedDeviceServlet(hs, msc2697=False).register(http_server)
DehydratedDeviceEventsServlet(hs).register(http_server)
104 changes: 104 additions & 0 deletions tests/handlers/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

from typing import Optional

from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.util import Clock

from tests import unittest
Expand Down Expand Up @@ -265,6 +267,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
Expand Down Expand Up @@ -342,3 +345,104 @@ def test_dehydrate_and_rehydrate_device(self) -> None:
ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))

self.assertIsNone(ret)

@unittest.override_config(
{"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
)
def test_dehydrate_v2_and_fetch_events(self) -> None:
user_id = "@boris:server"

self.get_success(self.store.register_user(user_id, "foobar"))

# First check if we can store and fetch a dehydrated device
stored_dehydrated_device_id = self.get_success(
self.handler.store_dehydrated_device(
user_id=user_id,
device_data={"device_data": {"foo": "bar"}},
initial_device_display_name="dehydrated device",
)
)

retrieved_device_id, device_data = self.get_success(
self.handler.get_dehydrated_device(user_id=user_id)
)

self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})

# Create a new login for the user
device_id, access_token, _expiration_time, _refresh_token = self.get_success(
self.registration.register_device(
user_id=user_id,
device_id=None,
initial_display_name="new device",
)
)

requester = create_requester(user_id, device_id=device_id)

# Fetching messages for a non existing device should return an error
self.get_failure(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id="not the right device ID",
since_token=None,
limit=10,
),
SynapseError,
)

# Send a message to the dehydrated device
ensureDeferred(
self.message_handler.send_device_message(
requester=requester,
message_type="test.message",
messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}},
)
)
self.pump()

# Fetch the message of the dehydrated device
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
since_token=None,
limit=10,
)
)

self.assertTrue(len(res["next_batch"]) > 1)
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")

# Fetch the message of the dehydrated device again, which should return nothing and delete the old messages
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
since_token=res["next_batch"],
limit=10,
)
)
self.assertTrue(len(res["next_batch"]) > 1)
self.assertEqual(len(res["events"]), 0)

# Fetching messages without since should return nothing, since the messages got deleted
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
since_token=None,
limit=10,
)
)
self.assertTrue(len(res["next_batch"]) > 1)
self.assertEqual(len(res["events"]), 0)

# We don't delete the device when fetch messages for now.
# # make sure that the device ID that we were initially assigned no longer exists
# self.get_failure(
# self.handler.get_device(user_id, device_id),
# NotFoundError,
# )