Skip to content

Commit

Permalink
Improve perf of sync device lists (element-hq#17216)
Browse files Browse the repository at this point in the history
Re-introduces element-hq#17191, and includes element-hq#17197 and element-hq#17214

The basic idea is to stop calling `get_rooms_for_user` everywhere, and
instead use the table `device_lists_changes_in_room`.

Commits reviewable one-by-one.
  • Loading branch information
erikjohnston authored and Mic92 committed Jun 14, 2024
1 parent 15fb34a commit c282b30
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 62 deletions.
1 change: 1 addition & 0 deletions changelog.d/17216.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve performance of calculating device lists changes in `/sync`.
22 changes: 18 additions & 4 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,32 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:

@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: StrCollection, from_token: StreamToken
self,
user_id: str,
room_ids: StrCollection,
from_token: StreamToken,
now_token: Optional[StreamToken] = None,
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
now_device_lists_key = self.store.get_device_stream_token()
if now_token:
now_device_lists_key = now_token.device_list_key

changed_users = await self.store.get_device_list_changes_in_rooms(
room_ids, from_token.device_list_key
room_ids,
from_token.device_list_key,
now_device_lists_key,
)

if changed_users is not None:
# We also check if the given user has changed their device. If
# they're in no rooms then the above query won't include them.
changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, [user_id]
from_token.device_list_key,
[user_id],
to_key=now_device_lists_key,
)
changed_users.update(changed)
return changed_users
Expand All @@ -190,7 +202,9 @@ async def get_device_changes_in_shared_rooms(
tracked_users.add(user_id)

changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
from_token.device_list_key,
tracked_users,
to_key=now_device_lists_key,
)

return changed
Expand Down
38 changes: 7 additions & 31 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,38 +1886,14 @@ async def _generate_sync_entry_for_device_list(

# Step 1a, check for changes in devices of users we share a room
# with
#
# We do this in two different ways depending on what we have cached.
# If we already have a list of all the users that have changed since
# the last sync then it's likely more efficient to compare the rooms
# they're in with the rooms the syncing user is in.
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
if cache_result.hit:
changed_users = cache_result.entities

result = await self.store.get_rooms_for_users(changed_users)

for changed_user_id, entries in result.items():
# Check if the changed user shares any rooms with the user,
# or if the changed user is the syncing user (as we always
# want to include device list updates of their own devices).
if user_id == changed_user_id or any(
rid in joined_room_ids for rid in entries
):
users_that_have_changed.add(changed_user_id)
else:
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
from_token=since_token,
)
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
from_token=since_token,
now_token=sync_result_builder.now_token,
)
)

# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
Expand Down
15 changes: 9 additions & 6 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ async def on_rdata(
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
if any(row.entity.startswith("@") and not row.is_signature for row in rows):
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)

self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
Expand Down Expand Up @@ -146,12 +155,6 @@ async def on_rdata(
StreamKeyType.TO_DEVICE, token, users=entities
)
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)

# `all_room_ids` can be large, so let's wake up those streams in batches
for batched_room_ids in batch_iter(all_room_ids, 100):
self.notifier.on_new_event(
Expand Down
89 changes: 68 additions & 21 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import (
AllEntitiesChangedResult,
StreamChangeCache,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
Expand Down Expand Up @@ -132,6 +129,20 @@ def __init__(
prefilled_cache=device_list_prefill,
)

device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_changes_in_room",
entity_column="room_id",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_room_stream_cache = StreamChangeCache(
"DeviceListRoomStreamChangeCache",
min_device_list_room_id,
prefilled_cache=device_list_room_prefill,
)

(
user_signature_stream_prefill,
user_signature_stream_list_id,
Expand Down Expand Up @@ -209,6 +220,13 @@ def _invalidate_caches_for_devices(
row.entity, token
)

def device_lists_in_rooms_have_changed(
    self, room_ids: StrCollection, token: int
) -> None:
    """Record that device lists have changed in the given rooms.

    Each room is marked as changed at `token` in
    `self._device_list_room_stream_cache`.

    Args:
        room_ids: the rooms in which device lists have changed.
        token: the stream position handed to the cache for every room.
    """
    room_cache = self._device_list_room_stream_cache
    for changed_room_id in room_ids:
        room_cache.entity_has_changed(changed_room_id, token)

def get_device_stream_token(self) -> int:
    """Return the current token reported by `self._device_list_id_gen`."""
    id_gen = self._device_list_id_gen
    return id_gen.get_current_token()

Expand Down Expand Up @@ -832,16 +850,6 @@ async def get_cached_devices_for_user(
)
return {device[0]: db_to_json(device[1]) for device in devices}

def get_cached_device_list_changes(
    self,
    from_key: int,
) -> AllEntitiesChangedResult:
    """Ask the device list stream change cache what has changed since
    `from_key`.

    Args:
        from_key: the device list stream position to look for changes after.

    Returns:
        Whatever `self._device_list_stream_cache.get_all_entities_changed`
        returns for `from_key`.
        NOTE(review): the result is annotated as an
        `AllEntitiesChangedResult`, not an Optional; callers appear to check
        its `hit` attribute before reading `entities` - confirm.
    """
    stream_cache = self._device_list_stream_cache
    return stream_cache.get_all_entities_changed(from_key)

@cancellable
async def get_all_devices_changed(
self,
Expand Down Expand Up @@ -1457,7 +1465,7 @@ async def _get_min_device_lists_changes_in_room(self) -> int:

@cancellable
async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int
self, room_ids: Collection[str], from_id: int, to_id: int
) -> Optional[Set[str]]:
"""Return the set of users whose devices have changed in the given rooms
since the given stream ID.
Expand All @@ -1473,9 +1481,15 @@ async def get_device_list_changes_in_rooms(
if min_stream_id > from_id:
return None

changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
room_ids, from_id
)
if not changed_room_ids:
return set()

sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room
WHERE {clause} AND stream_id >= ?
WHERE {clause} AND stream_id > ? AND stream_id <= ?
"""

def _get_device_list_changes_in_rooms_txn(
Expand All @@ -1487,11 +1501,12 @@ def _get_device_list_changes_in_rooms_txn(
return {user_id for user_id, in txn}

changes = set()
for chunk in batch_iter(room_ids, 1000):
for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk
)
args.append(from_id)
args.append(to_id)

changes |= await self.db_pool.runInteraction(
"get_device_list_changes_in_rooms",
Expand All @@ -1502,6 +1517,34 @@ def _get_device_list_changes_in_rooms_txn(

return changes

async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
    """Return the set of rooms where devices have changed in a stream ID
    range.

    Args:
        from_id: exclusive lower bound; only rows with a stream ID strictly
            greater than this are considered.
        to_id: inclusive upper bound on the stream ID.

    Returns:
        The distinct `room_id` values found in `device_lists_changes_in_room`
        within that range.

    Raises:
        Exception: if `from_id` is lower than the value returned by
            `_get_min_device_lists_changes_in_room`, i.e. the given stream ID
            is too old.
    """
    oldest_stream_id = await self._get_min_device_lists_changes_in_room()
    if from_id < oldest_stream_id:
        raise Exception("stream ID is too old")

    # Built from adjacent literals so the statement text stays exactly the
    # same regardless of how this function is indented.
    sql = (
        "\n"
        "SELECT DISTINCT room_id FROM device_lists_changes_in_room\n"
        "WHERE stream_id > ? AND stream_id <= ?\n"
    )

    def _get_all_device_list_changes_txn(
        txn: LoggingTransaction,
    ) -> Set[str]:
        # Bounds are bound as query parameters: (exclusive lower, inclusive
        # upper).
        txn.execute(sql, (from_id, to_id))
        rooms: Set[str] = set()
        for (room_id,) in txn:
            rooms.add(room_id)
        return rooms

    changed_rooms = await self.db_pool.runInteraction(
        "get_all_device_list_changes",
        _get_all_device_list_changes_txn,
    )
    return changed_rooms

async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
Expand Down Expand Up @@ -1962,8 +2005,8 @@ def _update_remote_device_list_cache_txn(
async def add_device_change_to_streams(
self,
user_id: str,
device_ids: Collection[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
Expand Down Expand Up @@ -2122,8 +2165,8 @@ def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
stream_ids: List[int],
context: Dict[str, str],
) -> None:
Expand Down Expand Up @@ -2161,6 +2204,10 @@ def _add_device_outbound_room_poke_txn(
],
)

txn.call_after(
self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
)

async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
Expand Down

0 comments on commit c282b30

Please sign in to comment.