Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add cache for get_membership_from_event_ids #12272

Merged
merged 8 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12272.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a new cache `_get_membership_from_event_id` to speed up push rule calculations in large rooms.
30 changes: 16 additions & 14 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
Expand Down Expand Up @@ -292,7 +293,7 @@ def _condition_checker(
return True


MemberMap = Dict[str, Tuple[str, str]]
MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
Expand All @@ -306,7 +307,7 @@ class RulesForRoomData:
*only* include data, and not references to e.g. the data stores.
"""

# event_id -> (user_id, state)
# event_id -> EventIdMembership
member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
rules_by_user: RulesByUser = attr.Factory(dict)
Expand Down Expand Up @@ -447,11 +448,10 @@ async def get_rules(

res = self.data.member_map.get(event_id, None)
if res:
user_id, state = res
if state == Membership.JOIN:
rules = self.data.rules_by_user.get(user_id, None)
if res.membership == Membership.JOIN:
rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
ret_rules_by_user[user_id] = rules
ret_rules_by_user[res.user_id] = rules
continue

# If a user has left a room we remove their push rule. If they
Expand Down Expand Up @@ -502,24 +502,26 @@ async def _update_rules_with_member_event_ids(
"""
sequence = self.data.sequence

rows = await self.store.get_membership_from_event_ids(member_event_ids.values())

members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
members = await self.store.get_membership_from_event_ids(
member_event_ids.values()
)

# If the event is a join event then it will be in current state evnts
# If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
members[event_id] = EventIdMembership(
user_id=event.state_key, membership=event.membership
)

if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())

joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
entry.user_id
for entry in members.values()
if entry and entry.membership == Membership.JOIN
}

logger.debug("Joined: %r", joined_user_ids)
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ def _invalidate_caches_for_event(

self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))

# Entries in the `_get_membership_from_event_id` cache are immutable,
# except for the case where we look up an event *before* persisting it.
self._get_membership_from_event_id.invalidate((event_id,))

if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)

Expand Down
7 changes: 7 additions & 0 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,13 @@ def non_null_str_or_none(val: Any) -> Optional[str]:
(event.state_key,),
)

# The `_get_membership_from_event_id` cache is immutable, except for the
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
# case where we look up an event *before* persisting it.
txn.call_after(
self.store._get_membership_from_event_id.invalidate,
(event.event_id,),
)

# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#
Expand Down
37 changes: 33 additions & 4 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"


@attr.s(frozen=True, slots=True, auto_attribs=True)
class EventIdMembership:
"""The user ID and membership recorded for a single membership event.

Used as the per-event value in the mapping returned by
`get_membership_from_event_ids`. Declared frozen, so instances cannot be
modified after construction.
"""

# The user the membership event is about: the `user_id` column of the
# `room_memberships` row, or the event's `state_key` when built from an
# in-memory event.
user_id: str
# The membership value of the event; callers compare it against
# `Membership.JOIN`.
membership: str


class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
Expand Down Expand Up @@ -772,7 +780,7 @@ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
desc="_get_joined_profiles_from_event_ids",
)

return {
Expand Down Expand Up @@ -1000,12 +1008,26 @@ async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:

return set(room_ids)

@cached(max_entries=5000)
async def _get_membership_from_event_id(
self, member_event_id: str
) -> Optional[EventIdMembership]:
"""Per-event cache backing `get_membership_from_event_ids`.

This method exists to hold the cache (at most 5000 entries) that the
`@cachedList` on `get_membership_from_event_ids` refers to by name via
`cached_method_name`. The body is a placeholder that always raises.
NOTE(review): presumably this is never meant to be called directly and
lookups should go through `get_membership_from_event_ids` -- confirm.

Args:
    member_event_id: The event ID the cache entry is keyed on.

Raises:
    NotImplementedError: always, if the body is actually executed.
"""
raise NotImplementedError()

@cachedList(
cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
)
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs."""
) -> Dict[str, Optional[EventIdMembership]]:
"""Get user_id and membership of a set of event IDs.

Returns:
Mapping from event ID to `EventIdMembership` if the event is a
membership event, otherwise the value is None.
"""

return await self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
Expand All @@ -1015,6 +1037,13 @@ async def get_membership_from_event_ids(
desc="get_membership_from_event_ids",
)

return {
row["event_id"]: EventIdMembership(
membership=row["membership"], user_id=row["user_id"]
)
for row in rows
}

async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:
Expand Down
15 changes: 11 additions & 4 deletions synapse/storage/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,8 +1023,13 @@ async def _is_server_still_joined(

# Check if any of the changes that we don't have events for are joins.
if events_to_check:
rows = await self.main_store.get_membership_from_event_ids(events_to_check)
is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
members = await self.main_store.get_membership_from_event_ids(
events_to_check
)
is_still_joined = any(
member and member.membership == Membership.JOIN
for member in members.values()
)
if is_still_joined:
return True

Expand Down Expand Up @@ -1060,9 +1065,11 @@ async def _is_server_still_joined(
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update(
row["user_id"] for row in rows if row["membership"] == Membership.JOIN
member.user_id
for member in members.values()
if member and member.membership == Membership.JOIN
)

return False
Expand Down