Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add cache for get_membership_from_event_ids #12272

Merged
merged 8 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12272.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a new cache `_get_membership_from_event_id` to speed up push rule calculations in large rooms.
30 changes: 16 additions & 14 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
Expand Down Expand Up @@ -292,7 +293,7 @@ def _condition_checker(
return True


MemberMap = Dict[str, Tuple[str, str]]
MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
Expand All @@ -306,7 +307,7 @@ class RulesForRoomData:
*only* include data, and not references to e.g. the data stores.
"""

# event_id -> (user_id, state)
# event_id -> EventIdMembership
member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
rules_by_user: RulesByUser = attr.Factory(dict)
Expand Down Expand Up @@ -447,11 +448,10 @@ async def get_rules(

res = self.data.member_map.get(event_id, None)
if res:
user_id, state = res
if state == Membership.JOIN:
rules = self.data.rules_by_user.get(user_id, None)
if res.membership == Membership.JOIN:
rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
ret_rules_by_user[user_id] = rules
ret_rules_by_user[res.user_id] = rules
continue

# If a user has left a room we remove their push rule. If they
Expand Down Expand Up @@ -502,24 +502,26 @@ async def _update_rules_with_member_event_ids(
"""
sequence = self.data.sequence

rows = await self.store.get_membership_from_event_ids(member_event_ids.values())

members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
members = await self.store.get_membership_from_event_ids(
member_event_ids.values()
)

# If the event is a join event then it will be in current state evnts
# If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
members[event_id] = EventIdMembership(
user_id=event.state_key, membership=event.membership
)

if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())

joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
entry.user_id
for entry in members.values()
if entry and entry.membership == Membership.JOIN
}

logger.debug("Joined: %r", joined_user_ids)
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ def _invalidate_caches_for_event(

self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))

# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
self._get_membership_from_event_id.invalidate((event_id,))

if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)

Expand Down
7 changes: 7 additions & 0 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,13 @@ def non_null_str_or_none(val: Any) -> Optional[str]:
(event.state_key,),
)

# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
txn.call_after(
self.store._get_membership_from_event_id.invalidate,
(event.event_id,),
)

# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#
Expand Down
37 changes: 33 additions & 4 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"


@attr.s(frozen=True, slots=True, auto_attribs=True)
class EventIdMembership:
"""Returned by `get_membership_from_event_ids`"""

user_id: str
membership: str


class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
Expand Down Expand Up @@ -772,7 +780,7 @@ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
desc="_get_joined_profiles_from_event_ids",
)

return {
Expand Down Expand Up @@ -1000,12 +1008,26 @@ async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:

return set(room_ids)

@cached(max_entries=5000)
async def _get_membership_from_event_id(
self, member_event_id: str
) -> Optional[EventIdMembership]:
raise NotImplementedError()

@cachedList(
cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
)
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs."""
) -> Dict[str, Optional[EventIdMembership]]:
"""Get user_id and membership of a set of event IDs.

Returns:
Mapping from event ID to `EventIdMembership` if the event is a
membership event, otherwise the value is None.
"""

return await self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
Expand All @@ -1015,6 +1037,13 @@ async def get_membership_from_event_ids(
desc="get_membership_from_event_ids",
)

return {
row["event_id"]: EventIdMembership(
membership=row["membership"], user_id=row["user_id"]
)
for row in rows
}

async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:
Expand Down
15 changes: 11 additions & 4 deletions synapse/storage/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,8 +1023,13 @@ async def _is_server_still_joined(

# Check if any of the changes that we don't have events for are joins.
if events_to_check:
rows = await self.main_store.get_membership_from_event_ids(events_to_check)
is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
members = await self.main_store.get_membership_from_event_ids(
events_to_check
)
is_still_joined = any(
member and member.membership == Membership.JOIN
for member in members.values()
)
if is_still_joined:
return True

Expand Down Expand Up @@ -1060,9 +1065,11 @@ async def _is_server_still_joined(
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update(
row["user_id"] for row in rows if row["membership"] == Membership.JOIN
member.user_id
for member in members.values()
if member and member.membership == Membership.JOIN
)

return False
Expand Down