Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve perf of sync device lists #17216

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/17216.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve performance of calculating device lists changes in `/sync`.
22 changes: 18 additions & 4 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,32 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:

@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: StrCollection, from_token: StreamToken
self,
user_id: str,
room_ids: StrCollection,
from_token: StreamToken,
now_token: Optional[StreamToken] = None,
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
now_device_lists_key = self.store.get_device_stream_token()
if now_token:
now_device_lists_key = now_token.device_list_key

changed_users = await self.store.get_device_list_changes_in_rooms(
room_ids, from_token.device_list_key
room_ids,
from_token.device_list_key,
now_device_lists_key,
)

if changed_users is not None:
# We also check if the given user has changed their device. If
# they're in no rooms then the above query won't include them.
changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, [user_id]
from_token.device_list_key,
[user_id],
to_key=now_device_lists_key,
)
changed_users.update(changed)
return changed_users
Expand All @@ -190,7 +202,9 @@ async def get_device_changes_in_shared_rooms(
tracked_users.add(user_id)

changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
from_token.device_list_key,
tracked_users,
to_key=now_device_lists_key,
)

return changed
Expand Down
38 changes: 7 additions & 31 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,38 +1886,14 @@ async def _generate_sync_entry_for_device_list(

# Step 1a, check for changes in devices of users we share a room
# with
#
# We do this in two different ways depending on what we have cached.
# If we already have a list of all the user that have changed since
# the last sync then it's likely more efficient to compare the rooms
# they're in with the rooms the syncing user is in.
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
if cache_result.hit:
changed_users = cache_result.entities

result = await self.store.get_rooms_for_users(changed_users)

for changed_user_id, entries in result.items():
# Check if the changed user shares any rooms with the user,
# or if the changed user is the syncing user (as we always
# want to include device list updates of their own devices).
if user_id == changed_user_id or any(
rid in joined_room_ids for rid in entries
):
users_that_have_changed.add(changed_user_id)
else:
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
from_token=since_token,
)
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
from_token=since_token,
now_token=sync_result_builder.now_token,
)
)

# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
Expand Down
15 changes: 9 additions & 6 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ async def on_rdata(
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
if any(row.entity.startswith("@") and not row.is_signature for row in rows):
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)

self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
Expand Down Expand Up @@ -146,12 +155,6 @@ async def on_rdata(
StreamKeyType.TO_DEVICE, token, users=entities
)
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)

# `all_room_ids` can be large, so let's wake up those streams in batches
for batched_room_ids in batch_iter(all_room_ids, 100):
self.notifier.on_new_event(
Expand Down
89 changes: 68 additions & 21 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import (
AllEntitiesChangedResult,
StreamChangeCache,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
Expand Down Expand Up @@ -132,6 +129,20 @@ def __init__(
prefilled_cache=device_list_prefill,
)

device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_changes_in_room",
entity_column="room_id",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_room_stream_cache = StreamChangeCache(
"DeviceListRoomStreamChangeCache",
min_device_list_room_id,
prefilled_cache=device_list_room_prefill,
)

(
user_signature_stream_prefill,
user_signature_stream_list_id,
Expand Down Expand Up @@ -209,6 +220,13 @@ def _invalidate_caches_for_devices(
row.entity, token
)

def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
) -> None:
"Record that device lists have changed in rooms"
for room_id in room_ids:
self._device_list_room_stream_cache.entity_has_changed(room_id, token)

def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()

Expand Down Expand Up @@ -832,16 +850,6 @@ async def get_cached_devices_for_user(
)
return {device[0]: db_to_json(device[1]) for device in devices}

def get_cached_device_list_changes(
self,
from_key: int,
) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""

return self._device_list_stream_cache.get_all_entities_changed(from_key)

@cancellable
async def get_all_devices_changed(
self,
Expand Down Expand Up @@ -1457,7 +1465,7 @@ async def _get_min_device_lists_changes_in_room(self) -> int:

@cancellable
async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int
self, room_ids: Collection[str], from_id: int, to_id: int
) -> Optional[Set[str]]:
"""Return the set of users whose devices have changed in the given rooms
since the given stream ID.
Expand All @@ -1473,9 +1481,15 @@ async def get_device_list_changes_in_rooms(
if min_stream_id > from_id:
return None

changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
room_ids, from_id
)
if not changed_room_ids:
return set()

sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room
WHERE {clause} AND stream_id >= ?
WHERE {clause} AND stream_id > ? AND stream_id <= ?
"""

def _get_device_list_changes_in_rooms_txn(
Expand All @@ -1487,11 +1501,12 @@ def _get_device_list_changes_in_rooms_txn(
return {user_id for user_id, in txn}

changes = set()
for chunk in batch_iter(room_ids, 1000):
for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk
)
args.append(from_id)
args.append(to_id)

changes |= await self.db_pool.runInteraction(
"get_device_list_changes_in_rooms",
Expand All @@ -1502,6 +1517,34 @@ def _get_device_list_changes_in_rooms_txn(

return changes

async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
"""Return the set of rooms where devices have changed since the given
stream ID.

Will raise an exception if the given stream ID is too old.
"""

min_stream_id = await self._get_min_device_lists_changes_in_room()

if min_stream_id > from_id:
raise Exception("stream ID is too old")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really materially important, but feels like this should be a more specific exception type


sql = """
SELECT DISTINCT room_id FROM device_lists_changes_in_room
WHERE stream_id > ? AND stream_id <= ?
"""

def _get_all_device_list_changes_txn(
txn: LoggingTransaction,
) -> Set[str]:
txn.execute(sql, (from_id, to_id))
return {room_id for room_id, in txn}

return await self.db_pool.runInteraction(
"get_all_device_list_changes",
_get_all_device_list_changes_txn,
)

async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
Expand Down Expand Up @@ -1962,8 +2005,8 @@ def _update_remote_device_list_cache_txn(
async def add_device_change_to_streams(
self,
user_id: str,
device_ids: Collection[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
Expand Down Expand Up @@ -2122,8 +2165,8 @@ def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
stream_ids: List[int],
context: Dict[str, str],
) -> None:
Expand Down Expand Up @@ -2161,6 +2204,10 @@ def _add_device_outbound_room_poke_txn(
],
)

txn.call_after(
self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
)

async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
Expand Down
Loading