diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0794e85dd138..f2da30c3bf9c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -370,7 +370,7 @@ jobs: # Run Complement - run: | set -o pipefail - go test -v -json -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt + go test -v -json -tags synapse_blacklist,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: diff --git a/.gitignore b/.gitignore index 9034c01f089d..6997d21ae447 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,9 @@ __pycache__/ /media_store/ /uploads +# For direnv users +/.envrc + # IDEs /.idea/ /.ropeproject/ diff --git a/changelog.d/11881.feature b/changelog.d/11881.feature new file mode 100644 index 000000000000..392294ffc335 --- /dev/null +++ b/changelog.d/11881.feature @@ -0,0 +1 @@ +Send device list changes to application services as specified by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/pull/3202), using unstable prefixes. The `msc3202_transaction_extensions` experimental homeserver config option must be enabled and `org.matrix.msc3202: true` must be present in the application service registration file for device list changes to be sent. The "left" field is currently always empty. \ No newline at end of file diff --git a/changelog.d/12165.misc b/changelog.d/12165.misc new file mode 100644 index 000000000000..4b80b0562e59 --- /dev/null +++ b/changelog.d/12165.misc @@ -0,0 +1 @@ +Remove lingering unstable references to MSC2403 (knocking). diff --git a/changelog.d/12193.misc b/changelog.d/12193.misc new file mode 100644 index 000000000000..a721254d224c --- /dev/null +++ b/changelog.d/12193.misc @@ -0,0 +1 @@ +Omit sending "offline" presence updates to application services after they are initially configured. \ No newline at end of file diff --git a/changelog.d/12271.doc b/changelog.d/12271.doc new file mode 100644 index 000000000000..d9696fc5d5fa --- /dev/null +++ b/changelog.d/12271.doc @@ -0,0 +1 @@ +Clarify documentation for running SyTest against Synapse, including use of Postgres and worker mode. \ No newline at end of file diff --git a/changelog.d/12293.removal b/changelog.d/12293.removal new file mode 100644 index 000000000000..25214a4b4944 --- /dev/null +++ b/changelog.d/12293.removal @@ -0,0 +1 @@ +Remove the unused and unstable `/aggregations` endpoint which was removed from [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/changelog.d/12302.feature b/changelog.d/12302.feature new file mode 100644 index 000000000000..603fa2d23a45 --- /dev/null +++ b/changelog.d/12302.feature @@ -0,0 +1 @@ +Add a module callback to react to new 3PID (email address, phone number) associations. diff --git a/changelog.d/12330.misc b/changelog.d/12330.misc new file mode 100644 index 000000000000..9f333e718a86 --- /dev/null +++ b/changelog.d/12330.misc @@ -0,0 +1 @@ +Avoid trying to calculate the state at outlier events. diff --git a/changelog.d/12331.doc b/changelog.d/12331.doc new file mode 100644 index 000000000000..ec0ca3ea9531 --- /dev/null +++ b/changelog.d/12331.doc @@ -0,0 +1 @@ +Update dead links in `check-newsfragment.sh` to point to the correct documentation URL. 
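As context for the MSC3202 changelog entry above (`changelog.d/11881.feature`), the sketch below shows the shape of the unstable fields this change adds to appservice transactions. The field names are the ones used in `synapse/appservice/api.py` later in this diff; the surrounding body and the user ID are illustrative only.

```python
# Fragment of an appservice transaction body, as assembled by
# ApplicationServiceApi.push_bulk when the homeserver enables
# msc3202_transaction_extensions and the appservice registration sets
# org.matrix.msc3202: true. Prefixes are unstable pending MSC3202.
body = {
    "events": [],  # the usual list of PDUs
    "org.matrix.msc3202.device_lists": {
        # Users whose device lists the appservice should resync.
        "changed": ["@alice:example.org"],
        # Users the appservice may stop tracking; always empty for now,
        # per the changelog entry above.
        "left": [],
    },
}
```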
diff --git a/changelog.d/12333.bugfix b/changelog.d/12333.bugfix new file mode 100644 index 000000000000..2c073a77d58e --- /dev/null +++ b/changelog.d/12333.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug affecting URL previews that would generate a 500 response instead of a 403 if the previewed URL includes a port that isn't allowed by the relevant blacklist. diff --git a/changelog.d/12334.misc b/changelog.d/12334.misc new file mode 100644 index 000000000000..10a57e23b3af --- /dev/null +++ b/changelog.d/12334.misc @@ -0,0 +1 @@ +Remove the `tox` packaging job: it will be redundant once #11537 lands. diff --git a/changelog.d/12335.misc b/changelog.d/12335.misc new file mode 100644 index 000000000000..39ea3611596c --- /dev/null +++ b/changelog.d/12335.misc @@ -0,0 +1 @@ +Ignore `.envrc` for `direnv` users. diff --git a/changelog.d/12336.misc b/changelog.d/12336.misc new file mode 100644 index 000000000000..0aecd543f920 --- /dev/null +++ b/changelog.d/12336.misc @@ -0,0 +1 @@ +Remove the (broadly unused, dev-only) dockerfile for pg tests. diff --git a/changelog.d/12338.misc b/changelog.d/12338.misc new file mode 100644 index 000000000000..376089f32767 --- /dev/null +++ b/changelog.d/12338.misc @@ -0,0 +1 @@ +Refactor relations code to remove an unnecessary class. diff --git a/docker/Dockerfile-pgtests b/docker/Dockerfile-pgtests deleted file mode 100644 index b94484ea7fd6..000000000000 --- a/docker/Dockerfile-pgtests +++ /dev/null @@ -1,30 +0,0 @@ -# Use the Sytest image that comes with a lot of the build dependencies -# pre-installed -FROM matrixdotorg/sytest:focal - -# The Sytest image doesn't come with python, so install that -RUN apt-get update && apt-get -qq install -y python3 python3-dev python3-pip - -# We need tox to run the tests in run_pg_tests.sh -RUN python3 -m pip install tox - -# Initialise the db -RUN su -c '/usr/lib/postgresql/10/bin/initdb -D /var/lib/postgresql/data -E "UTF-8" --lc-collate="C.UTF-8" --lc-ctype="C.UTF-8" --username=postgres' postgres - -# Add a user with our UID and GID so that files get created on the host owned -# by us, not root. -ARG UID -ARG GID -RUN groupadd --gid $GID user -RUN useradd --uid $UID --gid $GID --groups sudo --no-create-home user - -# Ensure we can start postgres by sudo-ing as the postgres user. -RUN apt-get update && apt-get -qq install -y sudo -RUN echo "user ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers - -ADD run_pg_tests.sh /run_pg_tests.sh -# Use the "exec form" of ENTRYPOINT (https://docs.docker.com/engine/reference/builder/#entrypoint) -# so that we can `docker run` this container and pass arguments to pg_tests.sh -ENTRYPOINT ["/run_pg_tests.sh"] - -USER user diff --git a/docker/README-testing.md b/docker/README-testing.md index 6a5baf9e2835..b0105092758b 100644 --- a/docker/README-testing.md +++ b/docker/README-testing.md @@ -78,7 +78,7 @@ the root of your Complement checkout and run: docker build -t matrixdotorg/complement-synapse-workers -f dockerfiles/SynapseWorkers.Dockerfile dockerfiles ``` -This will build an image with the tag `complement-synapse`, which can be handed to +This will build an image with the tag `complement-synapse-workers`, which can be handed to Complement for testing via the `COMPLEMENT_BASE_IMAGE` environment variable. Refer to [Complement's documentation](https://github.com/matrix-org/complement/#running) for how to run the tests, as well as the various available command line flags. 
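Later in this diff, `docs/modules/third_party_rules_callbacks.md` documents the new `on_threepid_bind` module callback (see also `changelog.d/12302.feature` above). As a companion, here is a minimal sketch of a module registering that callback; the class name, config handling and logging are illustrative, while the `register_third_party_rules_callbacks` keyword argument is the one added to the module API in this diff.

```python
import logging

logger = logging.getLogger(__name__)


class ExampleThreepidAuditModule:
    """A hypothetical module that logs new local 3PID associations."""

    def __init__(self, config: dict, api):
        # `api` is Synapse's ModuleApi; modules are instantiated with
        # their parsed config and the module API object.
        self._api = api
        api.register_third_party_rules_callbacks(
            on_threepid_bind=self.on_threepid_bind,
        )

    async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
        # Called after an association is created on the local homeserver;
        # not called for binds performed through an identity server.
        logger.info("User %s bound %s threepid %s", user_id, medium, address)
```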
diff --git a/docker/run_pg_tests.sh b/docker/run_pg_tests.sh deleted file mode 100755 index b22b6ef16b7e..000000000000 --- a/docker/run_pg_tests.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env bash - -# This script runs the PostgreSQL tests inside a Docker container. It expects -# the relevant source files to be mounted into /src (done automatically by the -# caller script). It will set up the database, run it, and then use the tox -# configuration to run the tests. - -set -e - -# Set PGUSER so Synapse's tests know what user to connect to the database with -export PGUSER=postgres - -# Start the database -sudo -u postgres /usr/lib/postgresql/10/bin/pg_ctl -w -D /var/lib/postgresql/data start - -# Run the tests -cd /src -export TRIAL_FLAGS="-j 4" -tox --workdir=./.tox-pg-container -e py37-postgres "$@" diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md index fe29470f2612..4a8e4020127e 100644 --- a/docs/development/contributing_guide.md +++ b/docs/development/contributing_guide.md @@ -215,27 +215,6 @@ export SYNAPSE_POSTGRES_PASSWORD=mydevenvpassword trial ``` -#### Prebuilt container - -Since configuring PostgreSQL can be fiddly, we can make use of a pre-made -Docker container to set up PostgreSQL and run our tests for us. To do so, run - -```shell -scripts-dev/test_postgresql.sh -``` - -Any extra arguments to the script will be passed to `tox` and then to `trial`, -so we can run a specific test in this container with e.g. - -```shell -scripts-dev/test_postgresql.sh tests.replication.test_sharded_event_persister.EventPersisterShardTestCase -``` - -The container creates a folder in your Synapse checkout called -`.tox-pg-container` and uses this as a tox environment. The output of any -`trial` runs goes into `_trial_temp` in your synapse source directory — the same -as running `trial` directly on your host machine. - ## Run the integration tests ([Sytest](https://github.com/matrix-org/sytest)). The integration tests are a more comprehensive suite of tests. They @@ -249,8 +228,14 @@ configuration: ```sh $ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:buster ``` +(Note that the paths must be full paths! You could also write `$(realpath relative/path)` if needed.) + +This configuration should generally cover your needs. + +- To run with Postgres, supply the `-e POSTGRES=1 -e MULTI_POSTGRES=1` environment flags. +- To run with Synapse in worker mode, supply the `-e WORKERS=1 -e REDIS=1` environment flags (in addition to the Postgres flags). -This configuration should generally cover your needs. For more details about other configurations, see [documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md). +For more details about other configurations, see the [Docker-specific documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md). ## Run the integration tests ([Complement](https://github.com/matrix-org/complement)). diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index 1d3c39967faa..e1a5b6524fb4 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -247,6 +247,24 @@ admin API. If multiple modules implement this callback, Synapse runs them all in order. 
+### `on_threepid_bind` + +_First introduced in Synapse v1.56.0_ + +```python +async def on_threepid_bind(user_id: str, medium: str, address: str) -> None: +``` + +Called after creating an association between a local user and a third-party identifier +(email address, phone number). The module is given the Matrix ID of the user the +association is for, as well as the medium (`email` or `msisdn`) and address of the +third-party identifier. + +Note that this callback is _not_ called after a successful association on an _identity +server_. + +If multiple modules implement this callback, Synapse runs them all in order. + ## Example The example below is a module that implements the third-party rules callback diff --git a/scripts-dev/check-newsfragment.sh b/scripts-dev/check-newsfragment.sh index 493558ad651b..effea0929c93 100755 --- a/scripts-dev/check-newsfragment.sh +++ b/scripts-dev/check-newsfragment.sh @@ -19,7 +19,7 @@ if ! git diff --quiet FETCH_HEAD... -- debian; then if git diff --quiet FETCH_HEAD... -- debian/changelog; then echo "Updates to debian directory, but no update to the changelog." >&2 echo "!! Please see the contributing guide for help writing your changelog entry:" >&2 - echo "https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#debian-changelog" >&2 + echo "https://matrix-org.github.io/synapse/latest/development/contributing_guide.html#debian-changelog" >&2 exit 1 fi fi @@ -32,7 +32,7 @@ fi # Print a link to the contributing guide if the user makes a mistake CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing your changelog entry: -https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog" +https://matrix-org.github.io/synapse/latest/development/contributing_guide.html#changelog" # If check-newsfragment returns a non-zero exit code, print the contributing guide and exit python -m towncrier.check --compare-with=origin/develop || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1) diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 0a79a4063f5a..d1b59ff0401b 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -71,4 +71,4 @@ fi # Run the tests! echo "Images built; running complement" -go test -v -tags synapse_blacklist,msc2403,msc2716,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2716,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... diff --git a/scripts-dev/test_postgresql.sh b/scripts-dev/test_postgresql.sh deleted file mode 100755 index 43cfa256e4da..000000000000 --- a/scripts-dev/test_postgresql.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env bash - -# This script builds the Docker image to run the PostgreSQL tests, and then runs -# the tests. It uses a dedicated tox environment so that we don't have to -# rebuild it each time. - -# Command line arguments to this script are forwarded to "tox" and then to "trial". - -set -e - -# Build, and tag -docker build docker/ \ - --build-arg "UID=$(id -u)" \ - --build-arg "GID=$(id -g)" \ - -f docker/Dockerfile-pgtests \ - -t synapsepgtests - -# Run, mounting the current directory into /src -docker run --rm -it -v "$(pwd):/src" -v synapse-pg-test-tox:/tox synapsepgtests "$@" diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 07ec95f1d67e..d23d9221bc5b 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +23,13 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id +from synapse.types import ( + DeviceListUpdates, + GroupID, + JsonDict, + UserID, + get_domain_from_id, +) from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -400,6 +407,7 @@ def __init__( to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ): self.service = service self.id = id @@ -408,6 +416,7 @@ def __init__( self.to_device_messages = to_device_messages self.one_time_key_counts = one_time_key_counts self.unused_fallback_keys = unused_fallback_keys + self.device_list_summary = device_list_summary async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -424,6 +433,7 @@ async def send(self, as_api: "ApplicationServiceApi") -> bool: to_device_messages=self.to_device_messages, one_time_key_counts=self.one_time_key_counts, unused_fallback_keys=self.unused_fallback_keys, + device_list_summary=self.device_list_summary, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 98fe354014c4..0cdbb04bfbe3 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -27,7 +28,7 @@ from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient -from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -225,6 +226,7 @@ async def push_bulk( to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, txn_id: Optional[int] = None, ) -> bool: """ @@ -268,6 +270,7 @@ async def push_bulk( } ) + # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: if one_time_key_counts: body[ @@ -277,6 +280,11 @@ async def push_bulk( body[ "org.matrix.msc3202.device_unused_fallback_keys" ] = unused_fallback_keys + if device_list_summary: + body["org.matrix.msc3202.device_lists"] = { + "changed": list(device_list_summary.changed), + "left": list(device_list_summary.left), + } try: await self.put_json( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index a6084b9c355d..3b49e6071677 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -72,7 +72,7 @@ from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main import DataStore -from synapse.types import JsonDict +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import Clock if TYPE_CHECKING: @@ -122,6 +122,7 @@ def enqueue_for_appservice( events: Optional[Collection[EventBase]] = None, ephemeral: Optional[Collection[JsonDict]] = None, to_device_messages: Optional[Collection[JsonDict]] = None, + device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ Enqueue some data to be sent off to an application service. @@ -133,10 +134,18 @@ def enqueue_for_appservice( to_device_messages: The to-device messages to send. These differ from normal to-device messages sent to clients, as they have 'to_device_id' and 'to_user_id' fields. + device_list_summary: A summary of users that the application service either needs + to refresh the device lists of, or those that the application service need no + longer track the device lists of. """ # We purposefully allow this method to run with empty events/ephemeral # collections, so that callers do not need to check iterable size themselves. 
- if not events and not ephemeral and not to_device_messages: + if ( + not events + and not ephemeral + and not to_device_messages + and not device_list_summary + ): return if events: @@ -147,6 +156,10 @@ def enqueue_for_appservice( self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend( to_device_messages ) + if device_list_summary: + self.queuer.queued_device_list_summaries.setdefault( + appservice.id, [] + ).append(device_list_summary) # Kick off a new application service transaction self.queuer.start_background_request(appservice) @@ -169,6 +182,8 @@ def __init__( self.queued_ephemeral: Dict[str, List[JsonDict]] = {} # dict of {service_id: [to_device_message_json]} self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} + # dict of {service_id: [device_list_summary]} + self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} # the appservices which currently have a transaction in flight self.requests_in_flight: Set[str] = set() @@ -212,7 +227,35 @@ async def _send_request(self, service: ApplicationService) -> None: ] del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION] - if not events and not ephemeral and not to_device_messages_to_send: + # Consolidate any pending device list summaries into a single, up-to-date + # summary. + # Note: this code assumes that in a single DeviceListUpdates, a user will + # never be in both "changed" and "left" sets. + device_list_summary = DeviceListUpdates() + for summary in self.queued_device_list_summaries.get(service.id, []): + # For every user in the incoming "changed" set: + # * Remove them from the existing "left" set if necessary + # (as we need to start tracking them again) + # * Add them to the existing "changed" set if necessary. + device_list_summary.left.difference_update(summary.changed) + device_list_summary.changed.update(summary.changed) + + # For every user in the incoming "left" set: + # * Remove them from the existing "changed" set if necessary + # (we no longer need to track them) + # * Add them to the existing "left" set if necessary. + device_list_summary.changed.difference_update(summary.left) + device_list_summary.left.update(summary.left) + self.queued_device_list_summaries.clear() + + if ( + not events + and not ephemeral + and not to_device_messages_to_send + # DeviceListUpdates is True if either the 'changed' or 'left' sets have + # at least one entry, otherwise False + and not device_list_summary + ): return one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None @@ -240,6 +283,7 @@ async def _send_request(self, service: ApplicationService) -> None: to_device_messages_to_send, one_time_key_counts, unused_fallback_keys, + device_list_summary, ) except Exception: logger.exception("AS request failed") @@ -322,6 +366,7 @@ async def send( to_device_messages: Optional[List[JsonDict]] = None, one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, + device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -336,6 +381,7 @@ async def send( appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. 
""" try: txn = await self.store.create_appservice_txn( @@ -345,6 +391,7 @@ async def send( to_device_messages=to_device_messages or [], one_time_key_counts=one_time_key_counts or {}, unused_fallback_keys=unused_fallback_keys or {}, + device_list_summary=device_list_summary or DeviceListUpdates(), ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 439bfe15269c..ada165f23854 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -170,6 +170,7 @@ def _load_appservice( # When enabled, appservice transactions contain the following information: # - device One-Time Key counts # - device unused fallback key usage states + # - device list changes msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) if not isinstance(msc3202_transaction_extensions, bool): raise ValueError( diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 064db4487c85..d6bb1f752b52 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -59,8 +59,9 @@ def read_config(self, config: JsonDict, **kwargs): "msc3202_device_masquerading", False ) - # Portion of MSC3202 related to transaction extensions: - # sending one-time key counts and fallback key usage to application services. + # The portion of MSC3202 related to transaction extensions: + # sending device list changes, one-time key counts and fallback key + # usage to application services. self.msc3202_transaction_extensions: bool = experimental.get( "msc3202_transaction_extensions", False ) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index bfca454f510d..ef68e2028220 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -42,6 +42,7 @@ CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] +ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -169,6 +170,7 @@ def __init__(self, hs: "HomeServer"): self._on_user_deactivation_status_changed_callbacks: List[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = [] + self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = [] def register_third_party_rules_callbacks( self, @@ -187,6 +189,7 @@ def register_third_party_rules_callbacks( on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, + on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -221,6 +224,9 @@ def register_third_party_rules_callbacks( on_user_deactivation_status_changed, ) + if on_threepid_bind is not None: + self._on_threepid_bind_callbacks.append(on_threepid_bind) + async def check_event_allowed( self, event: EventBase, context: EventContext ) -> Tuple[bool, Optional[dict]]: @@ -479,3 +485,23 @@ async def on_user_deactivation_status_changed( logger.exception( "Failed to run module API callback %s: %s", callback, e ) + + async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None: + """Called after a threepid association has been verified and stored. 
+ + Note that this callback is called when an association is created on the + local homeserver, not when it's created on an identity server (and then kept track + of so that it can be unbound on the same IS later on). + + Args: + user_id: the user being associated with the threepid. + medium: the threepid's medium. + address: the threepid's address. + """ + for callback in self._on_threepid_bind_callbacks: + try: + await callback(user_id, medium, address) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index bd913e524e7b..316c4b677ce1 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -33,7 +33,13 @@ wrap_as_background_process, ) from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.types import ( + DeviceListUpdates, + JsonDict, + RoomAlias, + RoomStreamToken, + UserID, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure @@ -58,6 +64,9 @@ def __init__(self, hs: "HomeServer"): self._msc2409_to_device_messages_enabled = ( hs.config.experimental.msc2409_to_device_messages_enabled ) + self._msc3202_transaction_extensions_enabled = ( + hs.config.experimental.msc3202_transaction_extensions + ) self.current_max = 0 self.is_processing = False @@ -204,9 +213,9 @@ def notify_interested_services_ephemeral( Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key", "presence_key" or - "to_device_key". Any other value for `stream_key` will cause this function - to return early. + `stream_key` can be "typing_key", "receipt_key", "presence_key", + "to_device_key" or "device_list_key". Any other value for `stream_key` + will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into receiving them by setting `push_ephemeral` to true in their registration @@ -230,6 +239,7 @@ def notify_interested_services_ephemeral( "receipt_key", "presence_key", "to_device_key", + "device_list_key", ): return @@ -253,15 +263,37 @@ def notify_interested_services_ephemeral( ): return + # Ignore device lists if the feature flag is not enabled + if ( + stream_key == "device_list_key" + and not self._msc3202_transaction_extensions_enabled + ): + return + # Check whether there are any appservices which have registered to receive # ephemeral events. # # Note that whether these events are actually relevant to these appservices # is decided later on. 
+ services = self.store.get_app_services() services = [ service - for service in self.store.get_app_services() - if service.supports_ephemeral + for service in services + # Different stream keys require different support booleans + if ( + stream_key + in ( + "typing_key", + "receipt_key", + "presence_key", + "to_device_key", + ) + and service.supports_ephemeral + ) + or ( + stream_key == "device_list_key" + and service.msc3202_transaction_extensions + ) ] if not services: # Bail out early if none of the target appservices have explicitly registered @@ -336,6 +368,20 @@ async def _notify_interested_services_ephemeral( service, "to_device", new_token ) + elif stream_key == "device_list_key": + device_list_summary = await self._get_device_list_summary( + service, new_token + ) + if device_list_summary: + self.scheduler.enqueue_for_appservice( + service, device_list_summary=device_list_summary + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_appservice_stream_type_pos( + service, "device_list", new_token + ) + async def _handle_typing( self, service: ApplicationService, new_token: int ) -> List[JsonDict]: @@ -542,6 +588,96 @@ async def _get_to_device_messages( return message_payload + async def _get_device_list_summary( + self, + appservice: ApplicationService, + new_key: int, + ) -> DeviceListUpdates: + """ + Retrieve a list of users who have changed their device lists. + + Args: + appservice: The application service to retrieve device list changes for. + new_key: The stream key of the device list change that triggered this method call. + + Returns: + A set of device list updates, comprised of users that the appservice needs to: + * resync the device list of, and + * stop tracking the device list of. + """ + # Fetch the last successfully processed device list update stream ID + # for this appservice. + from_key = await self.store.get_type_stream_id_for_appservice( + appservice, "device_list" + ) + + # Fetch the users who have modified their device list since then. + users_with_changed_device_lists = ( + await self.store.get_users_whose_devices_changed(from_key, to_key=new_key) + ) + + # Filter out any users the application service is not interested in + # + # For each user who changed their device list, we want to check whether this + # appservice would be interested in the change. + filtered_users_with_changed_device_lists = { + user_id + for user_id in users_with_changed_device_lists + if await self._is_appservice_interested_in_device_lists_of_user( + appservice, user_id + ) + } + + # Create a summary of "changed" and "left" users. + # TODO: Calculate "left" users. + device_list_summary = DeviceListUpdates( + changed=filtered_users_with_changed_device_lists + ) + + return device_list_summary + + async def _is_appservice_interested_in_device_lists_of_user( + self, + appservice: ApplicationService, + user_id: str, + ) -> bool: + """ + Returns whether a given application service is interested in the device list + updates of a given user. + + The application service is interested in the user's device list updates if any + of the following are true: + * The user is the appservice's sender localpart user. + * The user is in the appservice's user namespace. + * At least one member of one room that the user is a part of is in the + appservice's user namespace. + * The appservice is explicitly (via room ID or alias) interested in at + least one room that the user is in. + + Args: + appservice: The application service to gauge interest of.
+ user_id: The ID of the user whose device list interest is in question. + + Returns: + True if the application service is interested in the user's device lists, False + otherwise. + """ + # This method checks against both the sender localpart user as well as if the + # user is in the appservice's user namespace. + if appservice.is_interested_in_user(user_id): + return True + + # Determine whether any of the rooms the user is in justifies sending this + # device list update to the application service. + room_ids = await self.store.get_rooms_for_user(user_id) + for room_id in room_ids: + # This method covers checking room members for appservice interest as well as + # room ID and alias checks. + if await appservice.is_interested_in_room(room_id, self.store): + return True + + return False + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3e29c96a49e5..86991d26ce79 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -211,6 +211,7 @@ def __init__(self, hs: "HomeServer"): self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.auth.password_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled + self._third_party_rules = hs.get_third_party_event_rules() # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. @@ -1505,6 +1506,8 @@ async def add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) + await self._third_party_rules.on_threepid_bind(user_id, medium, address) + async def delete_threepid( self, user_id: str, medium: str, address: str, id_server: Optional[str] = None ) -> bool: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 4bd87709f373..567afc910f27 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -469,6 +469,12 @@ async def process_remote_join( if context.rejected: raise SynapseError(400, "Join event was rejected") + # the remote server is responsible for sending our join event to the rest + # of the federation. Indeed, attempting to do so will result in problems + # when we try to look up the state before the join (to get the server list) + # and discover that we do not have it. + event.internal_metadata.proactively_send = False + return await self.persist_events_and_notify(room_id, [(event, context)]) async def backfill( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 34d9411bbf61..dace31d87e17 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1625,7 +1625,7 @@ async def get_new_events( # We'll actually pull the presence updates for these users at the end. interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set() - if from_key: + if from_key is not None: # First get all users that have had a presence update updated_users = stream_change_cache.get_all_entities_changed(from_key) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 73217d135d7c..a36936b5206c 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast +from typing import TYPE_CHECKING, Dict, Iterable, Optional import attr from frozendict import frozendict @@ -25,7 +25,6 @@ if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -116,7 +115,7 @@ async def get_relations( if event is None: raise SynapseError(404, "Unknown parent event.") - pagination_chunk = await self._main_store.get_relations_for_event( + related_events, next_token = await self._main_store.get_relations_for_event( event_id=event_id, event=event, room_id=room_id, @@ -129,9 +128,7 @@ async def get_relations( to_token=to_token, ) - events = await self._main_store.get_events_as_list( - [c["event_id"] for c in pagination_chunk.chunk] - ) + events = await self._main_store.get_events_as_list(related_events) events = await filter_events_for_client( self._storage, user_id, events, is_peeking=(member_event_id is None) @@ -152,9 +149,16 @@ async def get_relations( events, now, bundle_aggregations=aggregations ) - return_value = await pagination_chunk.to_dict(self._main_store) - return_value["chunk"] = serialized_events - return_value["original_event"] = original_event + return_value = { + "chunk": serialized_events, + "original_event": original_event, + } + + if next_token: + return_value["next_batch"] = await next_token.to_string(self._main_store) + + if from_token: + return_value["prev_batch"] = await from_token.to_string(self._main_store) return return_value @@ -193,16 +197,21 @@ async def _get_bundled_aggregation_for_event( annotations = await self._main_store.get_aggregation_groups_for_event( event_id, room_id ) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) + if annotations: + aggregations.annotations = {"chunk": annotations} - references = await self._main_store.get_relations_for_event( + references, next_token = await self._main_store.get_relations_for_event( event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) + if references: + aggregations.references = { + "chunk": [{"event_id": event_id} for event_id in references] + } + + if next_token: + aggregations.references["next_batch"] = await next_token.to_string( + self._main_store + ) # Store the bundled aggregations in the event metadata for later use. return aggregations diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bceafca3b1c6..303c38c7460e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,17 +13,7 @@ # limitations under the License. 
import itertools import logging -from typing import ( - TYPE_CHECKING, - Any, - Collection, - Dict, - FrozenSet, - List, - Optional, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -41,6 +31,7 @@ from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( + DeviceListUpdates, JsonDict, MutableStateMap, Requester, @@ -184,21 +175,6 @@ def __bool__(self) -> bool: return bool(self.join or self.invite or self.leave) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceLists: - """ - Attributes: - changed: List of user_ids whose devices may have changed - left: List of user_ids whose devices we no longer track - """ - - changed: Collection[str] - left: Collection[str] - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - @attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined @@ -240,7 +216,7 @@ class SyncResult: knocked: List[KnockedSyncResult] archived: List[ArchivedSyncResult] to_device: List[JsonDict] - device_lists: DeviceLists + device_lists: DeviceListUpdates device_one_time_keys_count: JsonDict device_unused_fallback_key_types: List[str] groups: Optional[GroupsSyncResult] @@ -1264,8 +1240,8 @@ async def _generate_sync_entry_for_device_list( newly_joined_or_invited_or_knocked_users: Set[str], newly_left_rooms: Set[str], newly_left_users: Set[str], - ) -> DeviceLists: - """Generate the DeviceLists section of sync + ) -> DeviceListUpdates: + """Generate the DeviceListUpdates section of sync Args: sync_result_builder @@ -1383,9 +1359,11 @@ async def _generate_sync_entry_for_device_list( if any(e.room_id in joined_rooms for e in entries): newly_left_users.discard(user_id) - return DeviceLists(changed=users_that_have_changed, left=newly_left_users) + return DeviceListUpdates( + changed=users_that_have_changed, left=newly_left_users + ) else: - return DeviceLists(changed=[], left=[]) + return DeviceListUpdates() async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ba9755f08b8c..3c7dcca74daf 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -62,6 +62,7 @@ ON_CREATE_ROOM_CALLBACK, ON_NEW_EVENT_CALLBACK, ON_PROFILE_UPDATE_CALLBACK, + ON_THREEPID_BIND_CALLBACK, ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) from synapse.handlers.account_validity import ( @@ -293,6 +294,7 @@ def register_third_party_rules_callbacks( on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, + on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, ) -> None: """Registers callbacks for third party event rules capabilities. 
@@ -308,6 +310,7 @@ def register_third_party_rules_callbacks( check_can_deactivate_user=check_can_deactivate_user, on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, + on_threepid_bind=on_threepid_bind, ) def register_presence_router_callbacks( diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index c16078b187ee..55c96a2af3d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -12,22 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This class implements the proposed relation APIs from MSC 1849. - -Since the MSC has not been approved all APIs here are unstable and may change at -any time to reflect changes in the MSC. -""" - import logging from typing import TYPE_CHECKING, Optional, Tuple -from synapse.api.constants import RelationTypes -from synapse.api.errors import SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import AggregationPaginationToken from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: @@ -93,166 +84,5 @@ async def on_GET( return 200, result -class RelationAggregationPaginationServlet(RestServlet): - """API to paginate aggregation groups of relations, e.g. paginate the - types and counts of the reactions on the events. - - Example request and response: - - GET /rooms/{room_id}/aggregations/{parent_id} - - { - chunk: [ - { - "type": "m.reaction", - "key": "👍", - "count": 3 - } - ] - } - """ - - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" - "(/(?P[^/]*)(/(?P[^/]*))?)?$", - releases=(), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.event_handler = hs.get_event_handler() - - async def on_GET( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: Optional[str] = None, - event_type: Optional[str] = None, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. 
- event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - - if relation_type not in (RelationTypes.ANNOTATION, None): - raise SynapseError( - 400, f"Relation type must be '{RelationTypes.ANNOTATION}'" - ) - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - room_id=room_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - return 200, await pagination_chunk.to_dict(self.store) - - -class RelationAggregationGroupPaginationServlet(RestServlet): - """API to paginate within an aggregation group of relations, e.g. paginate - all the 👍 reactions on an event. - - Example request and response: - - GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 - - { - chunk: [ - { - "type": "m.reaction", - "content": { - "m.relates_to": { - "rel_type": "m.annotation", - "key": "👍" - } - } - }, - ... - ] - } - """ - - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" - "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)$", - releases=(), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self._relations_handler = hs.get_relations_handler() - - async def on_GET( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - key: str, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if relation_type != RelationTypes.ANNOTATION: - raise SynapseError(400, "Relation type must be 'annotation'") - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) - - result = await self._relations_handler.get_relations( - requester=requester, - event_id=parent_id, - room_id=room_id, - relation_type=relation_type, - event_type=event_type, - aggregation_key=key, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - return 200, result - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) - RelationAggregationPaginationServlet(hs).register(http_server) - RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index d47af8ead6b7..50383bdbd1c5 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -200,12 +200,17 @@ async def _async_render_GET(self, request: SynapseRequest) -> None: match = False continue + # Some attributes might not be parsed as strings by urlsplit (such as the + # port, which is parsed as an int). 
Because we use match functions that + # expect strings, we want to make sure that's what we give them. + value_str = str(value) + if pattern.startswith("^"): - if not re.match(pattern, getattr(url_tuple, attrib)): + if not re.match(pattern, value_str): match = False continue else: - if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern): + if not fnmatch.fnmatch(value_str, pattern): match = False continue if match: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 06944465582b..55e1ab099d53 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -29,7 +29,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict +from synapse.types import DeviceListUpdates, JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import _CacheContext, cached @@ -217,6 +217,7 @@ async def create_appservice_txn( to_device_messages: List[JsonDict], one_time_key_counts: TransactionOneTimeKeyCounts, unused_fallback_keys: TransactionUnusedFallbackKeys, + device_list_summary: DeviceListUpdates, ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -231,6 +232,7 @@ async def create_appservice_txn( appservice devices in the transaction. unused_fallback_keys: Lists of unused fallback keys for relevant appservice devices in the transaction. + device_list_summary: The device list summary to include in the transaction. Returns: A new transaction. @@ -268,6 +270,7 @@ def _create_appservice_txn(txn): to_device_messages=to_device_messages, one_time_key_counts=one_time_key_counts, unused_fallback_keys=unused_fallback_keys, + device_list_summary=device_list_summary, ) return await self.db_pool.runInteraction( @@ -359,8 +362,8 @@ def _get_oldest_unsent_txn(txn): events = await self.get_events_as_list(event_ids) - # TODO: to-device messages, one-time key counts and unused fallback keys - # are not yet populated for catch-up transactions. + # TODO: to-device messages, one-time key counts, device list summaries and unused + # fallback keys are not yet populated for catch-up transactions. # We likely want to populate those for reliability. return AppServiceTransaction( service=service, @@ -370,6 +373,7 @@ def _get_oldest_unsent_txn(txn): to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: @@ -430,7 +434,7 @@ def get_new_events_for_appservice_txn(txn): async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence", "to_device"): + if type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -446,7 +450,8 @@ def get_type_stream_id_for_appservice_txn(txn): ) last_stream_id = txn.fetchone() if last_stream_id is None or last_stream_id[0] is None: # no row exists - return 0 + # Stream tokens always start from 1, to avoid foot guns around `0` being falsey. 
+ return 1 else: return int(last_stream_id[0]) @@ -457,7 +462,7 @@ def get_type_stream_id_for_appservice_txn(txn): async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence", "to_device"): + if stream_type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3b3a089b7627..f08f7834d39e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -681,42 +681,64 @@ def get_cached_device_list_changes( return self._device_list_stream_cache.get_all_entities_changed(from_key) async def get_users_whose_devices_changed( - self, from_key: int, user_ids: Iterable[str] + self, + from_key: int, + user_ids: Optional[Iterable[str]] = None, + to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key: The device lists stream token - user_ids: The user IDs to query for devices. + from_key: The minimum device lists stream token to query device list changes for, + exclusive. + user_ids: If provided, only check if these users have changed their device lists. + Otherwise changes from all users are returned. + to_key: The maximum device lists stream token to query device list changes for, + inclusive. Returns: - The set of user_ids whose devices have changed since `from_key` + The set of user_ids whose devices have changed since `from_key` (exclusive) + until `to_key` (inclusive). """ - # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + if user_ids is None: + # Get set of all users that have had device list changes since 'from_key' + user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( + from_key + ) + else: + # The same as above, but filter results to only those users in 'user_ids' + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) - if not to_check: + if not user_ids_to_check: return set() def _get_users_whose_devices_changed_txn(txn): changes = set() - sql = """ + stream_id_where_clause = "stream_id > ?" + sql_args = [from_key] + + if to_key: + stream_id_where_clause += " AND stream_id <= ?" + sql_args.append(to_key) + + sql = f""" SELECT DISTINCT user_id FROM device_lists_stream - WHERE stream_id > ? 
+ WHERE {stream_id_where_clause} AND """ - for chunk in batch_iter(to_check, 100): + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk ) - txn.execute(sql + clause, (from_key,) + tuple(args)) + txn.execute(sql + clause, sql_args + args) changes.update(user_id for user_id, in txn) return changes diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b2295fd51f60..64a78081402e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -26,8 +26,6 @@ cast, ) -import attr - from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -39,8 +37,7 @@ ) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import RoomStreamToken, StreamToken +from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -73,7 +70,7 @@ async def get_relations_for_event( direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> PaginationChunk: + ) -> Tuple[List[str], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -90,8 +87,10 @@ async def get_relations_for_event( to_token: Fetch rows up to the given token, or up to the end if None. Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. + A tuple of: + A list of related event IDs + + The next stream token, if one exists. """ # We don't use `event_id`, it's there so that we can cache based on # it. The `event_id` must match the `event.event_id`. @@ -146,7 +145,7 @@ async def get_relations_for_event( def _get_recent_references_for_event_txn( txn: LoggingTransaction, - ) -> PaginationChunk: + ) -> Tuple[List[str], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) last_topo_id = None @@ -156,7 +155,7 @@ def _get_recent_references_for_event_txn( # Do not include edits for redacted events as they leak event # content. if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append({"event_id": row[0]}) + events.append(row[0]) last_topo_id = row[2] last_stream_id = row[3] @@ -179,9 +178,7 @@ def _get_recent_references_for_event_txn( groups_key=0, ) - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token - ) + return events[:limit], next_token return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn @@ -252,15 +249,8 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool: @cached(tree=True) async def get_aggregation_groups_for_event( - self, - event_id: str, - room_id: str, - event_type: Optional[str] = None, - limit: int = 5, - direction: str = "b", - from_token: Optional[AggregationPaginationToken] = None, - to_token: Optional[AggregationPaginationToken] = None, - ) -> PaginationChunk: + self, event_id: str, room_id: str, limit: int = 5 + ) -> List[JsonDict]: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. 
@@ -270,79 +260,36 @@ async def get_aggregation_groups_for_event( Args: event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. - event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. - direction: Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token: Fetch rows from the given token, or from the start if None. - to_token: Fetch rows up to the given token, or up to the end if None. Returns: List of groups of annotations that match. Each row is a dict with `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [ + where_args = [ event_id, room_id, RelationTypes.ANNOTATION, + limit, ] - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) - - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] - to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] - engine=self.database_engine, - ) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + SELECT type, aggregation_key, COUNT(DISTINCT sender) FROM event_relations INNER JOIN events USING (event_id) - WHERE {where_clause} + WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + ORDER BY COUNT(*) DESC LIMIT ? - """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) + """ def _get_aggregation_groups_for_event_txn( txn: LoggingTransaction, - ) -> PaginationChunk: - txn.execute(sql, where_args + [limit + 1]) - - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) + ) -> List[JsonDict]: + txn.execute(sql, where_args) - if len(events) <= limit: - next_batch = None - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) + return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py deleted file mode 100644 index fba270150b63..000000000000 --- a/synapse/storage/relations.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple - -import attr - -from synapse.api.errors import SynapseError -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.storage.databases.main import DataStore - -logger = logging.getLogger(__name__) - - -@attr.s(slots=True, auto_attribs=True) -class PaginationChunk: - """Returned by relation pagination APIs. - - Attributes: - chunk: The rows returned by pagination - next_batch: Token to fetch next set of results with, if - None then there are no more results. - prev_batch: Token to fetch previous set of results with, if - None then there are no previous results. - """ - - chunk: List[JsonDict] - next_batch: Optional[Any] = None - prev_batch: Optional[Any] = None - - async def to_dict(self, store: "DataStore") -> Dict[str, Any]: - d = {"chunk": self.chunk} - - if self.next_batch: - d["next_batch"] = await self.next_batch.to_string(store) - - if self.prev_batch: - d["prev_batch"] = await self.prev_batch.to_string(store) - - return d - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class AggregationPaginationToken: - """Pagination token for relation aggregation pagination API. - - As the results are order by count and then MAX(stream_ordering) of the - aggregation groups, we can just use them as our pagination token. - - Attributes: - count: The count of relations in the boundary group. - stream: The MAX stream ordering in the boundary group. - """ - - count: int - stream: int - - @staticmethod - def from_string(string: str) -> "AggregationPaginationToken": - try: - c, s = string.split("-") - return AggregationPaginationToken(int(c), int(s)) - except ValueError: - raise SynapseError(400, "Invalid aggregation pagination token") - - async def to_string(self, store: "DataStore") -> str: - return "%d-%d" % (self.count, self.stream) - - def as_tuple(self) -> Tuple[Any, ...]: - return attr.astuple(self) diff --git a/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql new file mode 100644 index 000000000000..7590e34b94f3 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to track the device list changes stream id that this application +-- service has been caught up to. + +-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible +-- state is useful for determining if we've ever sent traffic for a stream type +-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for +-- one way this can be used.
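The NULL-is-meaningful convention described above matters on the reading side. A minimal consumer-side sketch of the distinction (this helper is hypothetical, not part of this PR):

```python
from typing import Optional


def describe_device_list_position(stream_id: Optional[int]) -> str:
    # Hypothetical helper: None (SQL NULL) means we have never sent device
    # list changes to this appservice, which is distinct from any integer
    # position, including 0.
    if stream_id is None:
        return "never sent device list changes"
    return f"caught up to device list stream id {stream_id}"
```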
+ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT; \ No newline at end of file diff --git a/synapse/types.py b/synapse/types.py index 5ce2a5b0a5ee..500597b3a488 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -25,6 +25,7 @@ Match, MutableMapping, Optional, + Set, Tuple, Type, TypeVar, @@ -748,6 +749,30 @@ class ReadReceipt: data: JsonDict +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceListUpdates: + """ + An object containing a diff of information regarding other users' device lists, intended for + a recipient to carry out device list tracking. + + Attributes: + changed: A set of users whose device lists have changed recently. + left: A set of users whose device lists the recipient no longer needs to track, + typically because those users no longer share any rooms with end-to-end encryption enabled. + """ + + # We need to use a factory here: otherwise the `set` default would be + # evaluated only once, at class definition time, rather than at each + # object instantiation, and every DeviceListUpdates instance would end + # up sharing the same set objects. + # Also see: don't define mutable default arguments. + changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + def get_verify_key_from_cross_signing_key(key_info): """Get the key ID and signedjson verify key from a cross-signing key dict diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 1cbb059357fa..0b22afdc7598 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -24,6 +24,7 @@ ) from synapse.logging.context import make_deferred_yieldable from synapse.server import HomeServer +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -70,6 +71,7 @@ def test_single_service_up_txn_sent(self): to_device_messages=[], # txn made and saved one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -96,6 +98,7 @@ def test_single_service_down(self): to_device_messages=[], # txn made and saved one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(0, txn.send.call_count) # txn not sent though self.assertEqual(0, txn.complete.call_count) # or completed @@ -124,6 +127,7 @@ def test_single_service_up_txn_not_sent(self): to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made self.assertEqual(1, self.recoverer.recover.call_count) # and invoked @@ -225,7 +229,9 @@ def test_send_single_event_no_queue(self): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) + self.txn_ctrl.send.assert_called_once_with( + service, [event], [], [], None, None, DeviceListUpdates() + ) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -240,12 +246,14 @@ def test_send_single_event_with_queue(self): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - 
self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [event], [], [], None, None, DeviceListUpdates() + ) self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) self.txn_ctrl.send.assert_called_with( - service, [event2, event3], [], [], None, None + service, [event2, event3], [], [], None, None, DeviceListUpdates() ) self.assertEqual(2, self.txn_ctrl.send.call_count) @@ -272,15 +280,21 @@ def do_send(*args, **kwargs): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv1, [srv_1_event], [], [], None, None, DeviceListUpdates() + ) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event], [], [], None, None, DeviceListUpdates() + ) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event2], [], [], None, None, DeviceListUpdates() + ) self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -300,17 +314,17 @@ def do_send(*args, **kwargs): # Expect the first event to be sent immediately. self.txn_ctrl.send.assert_called_with( - service, [event_list[0]], [], [], None, None + service, [event_list[0]], [], [], None, None, DeviceListUpdates() ) srv_1_defer.callback(service) # Then send the next 100 events self.txn_ctrl.send.assert_called_with( - service, event_list[1:101], [], [], None, None + service, event_list[1:101], [], [], None, None, DeviceListUpdates() ) srv_2_defer.callback(service) # Then the final 99 events self.txn_ctrl.send.assert_called_with( - service, event_list[101:], [], [], None, None + service, event_list[101:], [], [], None, None, DeviceListUpdates() ) self.assertEqual(3, self.txn_ctrl.send.call_count) @@ -320,7 +334,7 @@ def test_send_single_ephemeral_no_queue(self): event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], None, None + service, [], event_list, [], None, None, DeviceListUpdates() ) def test_send_multiple_ephemeral_no_queue(self): @@ -329,7 +343,7 @@ def test_send_multiple_ephemeral_no_queue(self): event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], None, None + service, [], event_list, [], None, None, DeviceListUpdates() ) def test_send_single_ephemeral_with_queue(self): @@ -345,13 +359,21 @@ def test_send_single_ephemeral_with_queue(self): # Send more events: expect send() to NOT be called multiple times. 
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [], event_list_1, [], None, None, DeviceListUpdates() + ) self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [], None, None + service, + [], + event_list_2 + event_list_3, + [], + None, + None, + DeviceListUpdates(), ) self.assertEqual(2, self.txn_ctrl.send.call_count) @@ -365,8 +387,10 @@ def test_send_large_txns_ephemeral(self): event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], first_chunk, [], None, None + service, [], first_chunk, [], None, None, DeviceListUpdates() ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [], second_chunk, [], None, None, DeviceListUpdates() + ) self.assertEqual(2, self.txn_ctrl.send.call_count) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 648a01618e8b..d21c11b716cd 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -23,7 +23,7 @@ from synapse.types import RoomAlias from tests.test_utils import event_injection -from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config +from tests.unittest import FederatingHomeserverTestCase, TestCase class KnockingStrippedStateEventHelperMixin(TestCase): @@ -221,7 +221,6 @@ async def _check_event_auth(origin, event, context, *args, **kwargs): return super().prepare(reactor, clock, homeserver) - @override_config({"experimental_features": {"msc2403_enabled": True}}) def test_room_state_returned_when_knocking(self): """ Tests that specific, stripped state events from a room are returned after diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index cead9f90df56..8c72cf6b308b 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -15,6 +15,8 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock +from parameterized import parameterized + from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -471,6 +473,7 @@ def test_application_services_receive_local_to_device(self): to_device_messages, _otks, _fbks, + _device_list_summary, ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent @@ -583,7 +586,15 @@ def test_application_services_receive_bursts_of_to_device(self): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -627,6 +638,114 @@ def _register_application_service( return appservice +class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase): + """ + Tests that the ApplicationServicesHandler sends 
device list updates to application + services correctly. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Allow us to modify cached feature flags mid-test + self.as_handler = hs.get_application_service_handler() + + # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that + # will be sent over the wire + self.put_json = simple_async_mock() + hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] + + # Mock out application services, and allow defining our own in tests + self._services: List[ApplicationService] = [] + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) + + # Test across a variety of configuration values + @parameterized.expand( + [ + (True, True, True), + (True, False, False), + (False, True, False), + (False, False, False), + ] + ) + def test_application_service_receives_device_list_updates( + self, + experimental_feature_enabled: bool, + as_supports_txn_extensions: bool, + as_should_receive_device_list_updates: bool, + ): + """ + Tests that an application service is notified when a user in its + namespace changes their device list. + + The arguments are populated by the `parameterized.expand` decorator above. + + Args: + experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental + feature is enabled. This feature must be enabled for device list changes to be + sent to application services. + as_supports_txn_extensions: Whether the application service has explicitly registered + to receive information defined by MSC3202 - which includes device list changes. + as_should_receive_device_list_updates: Whether we expect the AS to receive the + device list changes.
+ """ + # Change whether the experimental feature is enabled or disabled before making + # device list changes + self.as_handler._msc3202_transaction_extensions_enabled = ( + experimental_feature_enabled + ) + + # Create an appservice that is interested in "local_user" + appservice = ApplicationService( + token=random_string(10), + hostname="example.com", + id=random_string(10), + sender="@as:example.com", + rate_limited=False, + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@local_user:.+", + "exclusive": False, + } + ], + }, + supports_ephemeral=True, + msc3202_transaction_extensions=as_supports_txn_extensions, + # Must be set for Synapse to try pushing data to the AS + hs_token="abcde", + url="some_url", + ) + + # Register the application service + self._services.append(appservice) + + # Register a user on the homeserver + self.local_user = self.register_user("local_user", "password") + self.local_user_token = self.login("local_user", "password") + + if as_should_receive_device_list_updates: + # Ensure that the resulting JSON uses the unstable prefix and contains the + # expected users + self.put_json.assert_called_once() + json_body = self.put_json.call_args[1]["json_body"] + + # Our application service should have received a device list update with + # "local_user" in the "changed" list + device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {}) + self.assertEqual([], device_list_dict["left"]) + self.assertEqual([self.local_user], device_list_dict["changed"]) + + else: + # No device list changes should have been sent out + self.put_json.assert_not_called() + + class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): # Argument indices for pulling out arguments from a `send_mock`. ARG_OTK_COUNTS = 4 diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index fe97a0b3dde1..419eef166ac4 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import urllib.parse from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch @@ -145,16 +144,6 @@ def _get_bundled_aggregations(self) -> JsonDict: self.assertEquals(200, channel.code, channel.json_body) return channel.json_body["unsigned"].get("m.relations", {}) - def _get_aggregations(self) -> List[JsonDict]: - """Request /aggregations on the parent ID and includes the returned chunk.""" - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - return channel.json_body["chunk"] - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: """ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. 
@@ -264,43 +253,6 @@ def test_deny_forked_thread(self) -> None: expected_response_code=400, ) - def test_aggregation(self) -> None: - """Test that annotations get correctly aggregated.""" - - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_must_be_annotation(self) -> None: - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations" - f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(400, channel.code, channel.json_body) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -394,15 +346,6 @@ def test_ignore_invalid_room(self) -> None: self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - # And when fetching aggregations. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - # And for bundled aggregations. channel = self.make_request( "GET", @@ -717,15 +660,6 @@ def test_unknown_relations(self) -> None: self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) - # But unknown relations can be directly queried. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -941,131 +875,6 @@ def test_pagination_from_sync_and_messages(self) -> None: annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] ) - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. 
- access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - class BundledAggregationsTestCase(BaseRelationsTestCase): """ @@ -1453,10 +1262,6 @@ def test_redact_relation_annotation(self) -> None: {"chunk": [{"type": "m.reaction", "key": "a", 
"count": 2}]}, ) - # Both relations appear in the aggregation. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) - # Redact one of the reactions. self._redact(to_redact_event_id) @@ -1469,10 +1274,6 @@ def test_redact_relation_annotation(self) -> None: {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - # The unredacted aggregation should still exist. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. @@ -1578,10 +1379,6 @@ def test_redact_parent_annotation(self) -> None: self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.ANNOTATION, relations) - # The aggregation should exist. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}]) - # Redact the original event. self._redact(self.parent_id) @@ -1594,10 +1391,6 @@ def test_redact_parent_annotation(self) -> None: {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - # There's nothing to aggregate. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index f0f3a54f8234..3cebbd18f04c 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -341,7 +341,6 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: hs, self.room_id, self.user_id ) - @override_config({"experimental_features": {"msc2403_enabled": True}}) def test_knock_room_state(self) -> None: """Tests that /sync returns state from a room after knocking on it.""" # Knock on a room diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e7de67e3a3b0..5eb0f243f747 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -896,3 +896,44 @@ def test_check_can_shutdown_room(self) -> None: # Check that the mock was called with the right room ID self.assertEqual(args[1], self.room_id) + + def test_on_threepid_bind(self) -> None: + """Tests that the on_threepid_bind module callback is called correctly after + associating a 3PID to an account. + """ + # Register a mocked callback. + threepid_bind_mock = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + + # Check that the shutdown was blocked + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. 
+ threepid_bind_mock.assert_called_once() + args = threepid_bind_mock.call_args[0] + + # Check that the mock was called with the right parameters + self.assertEqual(args, (user_id, "email", "foo@example.com")) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 5148c39874e2..3b24d0ace622 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -17,7 +17,7 @@ import os import re from typing import Any, Dict, Optional, Sequence, Tuple, Type -from urllib.parse import urlencode +from urllib.parse import quote, urlencode from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address @@ -69,7 +69,6 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: "2001:800::/21", ) config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) - config["url_preview_url_blacklist"] = [] config["url_preview_accept_language"] = [ "en-UK", "en-US;q=0.9", @@ -1123,3 +1122,43 @@ def test_cache_expiry(self) -> None: os.path.exists(path), f"{os.path.relpath(path, self.media_store_path)} was not deleted", ) + + @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]}) + def test_blacklist_port(self) -> None: + """Tests that blacklisting URLs with a port makes previewing such URLs + fail with a 403 error and doesn't impact other previews. + """ + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + bad_url = quote("http://matrix.org:8888/foo") + good_url = quote("http://matrix.org/foo") + + channel = self.make_request( + "GET", + "preview_url?url=" + bad_url, + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 403, channel.result) + + channel = self.make_request( + "GET", + "preview_url?url=" + good_url, + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ee599f433667..97de9a59e0e5 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -267,7 +268,9 @@ def test_create_appservice_txn_first( events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) ) self.assertEqual(txn.id, 1) @@ -283,7 +286,9 @@ def test_create_appservice_txn_older_last_txn( self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) self.assertEqual(txn.id, 9646) self.assertEqual(txn.events, events) @@ -296,7 +301,9 @@ def 
test_create_appservice_txn_up_to_date_last_txn( events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) self.assertEqual(txn.id, 9644) self.assertEqual(txn.events, events) @@ -320,7 +327,9 @@ def test_create_appservice_txn_up_fuzzing( self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) self.assertEqual(txn.id, 9644) self.assertEqual(txn.events, events) @@ -476,12 +485,12 @@ def test_get_type_stream_id_for_appservice_no_value(self) -> None: value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) - self.assertEqual(value, 0) + self.assertEqual(value, 1) value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "presence") ) - self.assertEqual(value, 0) + self.assertEqual(value, 1) def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( diff --git a/tox.ini b/tox.ini index 7a2f13b58b2b..8b9ccd672fca 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = packaging, py37, py38, py39, py310 +envlist = py37, py38, py39, py310 # we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208 minversion = 2.3.2 @@ -103,11 +103,3 @@ setenv = commands = python -m synmark {posargs:} -[testenv:packaging] -skip_install = true -usedevelop = false -deps = - check-manifest -commands = - check-manifest -
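Taken together, the scheduler, handler and storage changes above thread a `DeviceListUpdates` summary through every appservice transaction. A short sketch of the semantics the new type gives callers; the payload construction is illustrative only, while the unstable key matches the one asserted in the handler tests above:

```python
from synapse.types import DeviceListUpdates

# An empty summary is falsey (via __bool__), so callers can cheaply skip
# sending device list data when there is nothing to report.
summary = DeviceListUpdates()
assert not summary

summary = DeviceListUpdates(changed={"@local_user:example.com"})
assert summary

# The handler tests expect this unstable MSC3202 key in the transaction
# body; building it from a summary might look like this (illustrative only):
payload = {
    "org.matrix.msc3202.device_lists": {
        "changed": sorted(summary.changed),
        "left": sorted(summary.left),
    }
}
```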