Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor EventContext #12689

Merged
merged 9 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12689.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `EventContext` class.
173 changes: 28 additions & 145 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
from frozendict import frozendict
from typing_extensions import Literal

from twisted.internet.defer import Deferred

from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import JsonDict, StateMap

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,6 +58,9 @@ class EventContext:
If ``state_group`` is None (ie, the event is an outlier),
``state_group_before_event`` will always also be ``None``.

delta_before_after: If `state_group` and `state_group_before_event` are not None
then this is the delta of the state between the two groups.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better name we can use? Maybe state_delta_at_event or something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hard because I want to make sure people don't think that it's the state delta since the `prev_group`.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right! It's just the `before_after` part that I find confusing. Before/after what?!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. Yes. Erm, `delta_before_and_after_event`?


prev_group: If it is known, ``state_group``'s prev_group. Note that this being
None does not necessarily mean that ``state_group`` does not have
a prev_group!
Expand All @@ -79,73 +79,47 @@ class EventContext:
app_service: If this event is being sent by a (local) application service, that
app service.

_current_state_ids: The room state map, including this event - ie, the state
in ``state_group``.

(type, state_key) -> event_id

For an outlier, this is {}

Note that this is a private attribute: it should be accessed via
``get_current_state_ids``. _AsyncEventContext impl calculates this
on-demand: it will be None until that happens.

_prev_state_ids: The room state map, excluding this event - ie, the state
in ``state_group_before_event``. For a non-state
event, this will be the same as _current_state_ids.

Note that it is a completely different thing to prev_group!

(type, state_key) -> event_id

For an outlier, this is {}

As with _current_state_ids, this is a private attribute. It should be
accessed via get_prev_state_ids.

partial_state: if True, we may be storing this event with a temporary,
incomplete state.
"""

_storage: "Storage"
rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
_delta_before_after: Optional[StateMap[str]] = None
prev_group: Optional[int] = None
delta_ids: Optional[StateMap[str]] = None
app_service: Optional[ApplicationService] = None

_current_state_ids: Optional[StateMap[str]] = None
_prev_state_ids: Optional[StateMap[str]] = None

partial_state: bool = False

@staticmethod
def with_state(
storage: "Storage",
state_group: Optional[int],
state_group_before_event: Optional[int],
current_state_ids: Optional[StateMap[str]],
prev_state_ids: Optional[StateMap[str]],
delta_before_after: Optional[StateMap[str]],
partial_state: bool,
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
) -> "EventContext":
return EventContext(
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
storage=storage,
state_group=state_group,
state_group_before_event=state_group_before_event,
delta_before_after=delta_before_after,
prev_group=prev_group,
delta_ids=delta_ids,
partial_state=partial_state,
)

@staticmethod
def for_outlier() -> "EventContext":
def for_outlier(
storage: "Storage",
) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(
current_state_ids={},
prev_state_ids={},
)
return EventContext(storage=storage)

async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then
Expand All @@ -158,24 +132,12 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
The serialized event.
"""

# We don't serialize the full state dicts, instead they get pulled out
# of the DB on the other side. However, the other side can't figure out
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None

return {
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.get_state_key(),
"state_group": self._state_group,
"state_group_before_event": self.state_group_before_event,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_before_after": _encode_state_dict(self._delta_before_after),
"delta_ids": _encode_state_dict(self.delta_ids),
"app_service_id": self.app_service.id if self.app_service else None,
"partial_state": self.partial_state,
Expand All @@ -193,16 +155,14 @@ def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
Returns:
The event context.
"""
context = _AsyncEventContextImpl(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So happy to see this logic disappear. 🔥

context = EventContext(
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
storage=storage,
prev_state_id=input["prev_state_id"],
event_type=input["event_type"],
event_state_key=input["event_state_key"],
state_group=input["state_group"],
state_group_before_event=input["state_group_before_event"],
prev_group=input["prev_group"],
delta_before_after=_decode_state_dict(input["delta_before_after"]),
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"],
partial_state=input.get("partial_state", False),
Expand Down Expand Up @@ -250,8 +210,15 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")

await self._ensure_fetched()
return self._current_state_ids
assert self._delta_before_after is not None

prev_state_ids = await self.get_prev_state_ids()

if self._delta_before_after:
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update(self._delta_before_after)

return prev_state_ids

async def get_prev_state_ids(self) -> StateMap[str]:
"""
Expand All @@ -266,94 +233,10 @@ async def get_prev_state_ids(self) -> StateMap[str]:
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
await self._ensure_fetched()
# There *should* be previous state IDs now.
assert self._prev_state_ids is not None
return self._prev_state_ids

def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
"""Gets the current state IDs if we have them already cached.

It is an error to access this for a rejected event, since rejected state should
not make it into the room state. This method will raise an exception if
``rejected`` is set.

Returns:
Returns None if we haven't cached the state or if state_group is None
(which happens when the associated event is an outlier).

Otherwise, returns the current state IDs.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")

return self._current_state_ids

async def _ensure_fetched(self) -> None:
return None


@attr.s(slots=True)
class _AsyncEventContextImpl(EventContext):
"""
An implementation of EventContext which fetches _current_state_ids and
_prev_state_ids from the database on demand.

Attributes:

_storage

_fetching_state_deferred: Resolves when *_state_ids have been calculated.
None if we haven't started calculating yet

_event_type: The type of the event the context is associated with.

_event_state_key: The state_key of the event the context is associated with.

_prev_state_id: If the event associated with the context is a state event,
then `_prev_state_id` is the event_id of the state that was replaced.
"""

# This needs to have a default as we're inheriting
_storage: "Storage" = attr.ib(default=None)
_prev_state_id: Optional[str] = attr.ib(default=None)
_event_type: str = attr.ib(default=None)
_event_state_key: Optional[str] = attr.ib(default=None)
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)

async def _ensure_fetched(self) -> None:
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)

await make_deferred_yieldable(self._fetching_state_deferred)

async def _fill_out_state(self) -> None:
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
# No state group means the event is an outlier. Usually the state_ids dicts are also
# pre-set to empty dicts, but they get reset when the context is serialized, so set
# them to empty dicts again here.
self._current_state_ids = {}
self._prev_state_ids = {}
return

current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
assert self.state_group_before_event is not None
return await self._storage.state.get_state_ids_for_group(
self.state_group_before_event
)
# Set this separately so mypy knows current_state_ids is not None.
self._current_state_ids = current_state_ids
if self._event_state_key is not None:
self._prev_state_ids = dict(current_state_ids)

key = (self._event_type, self._event_state_key)
if self._prev_state_id:
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids.pop(key, None)
else:
self._prev_state_ids = current_state_ids


def _encode_state_dict(
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ async def do_knock(
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]

context = EventContext.for_outlier()
context = EventContext.for_outlier(self.storage)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down Expand Up @@ -848,7 +848,7 @@ async def on_invite_request(
)
)

context = EventContext.for_outlier()
context = EventContext.for_outlier(self.storage)
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down Expand Up @@ -877,7 +877,7 @@ async def do_remotely_reject_invite(

await self.federation_client.send_leave(host_list, event)

context = EventContext.for_outlier()
context = EventContext.for_outlier(self.storage)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True

context = EventContext.for_outlier()
context = EventContext.for_outlier(self._storage)
try:
validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth)
Expand Down Expand Up @@ -1874,10 +1874,10 @@ async def _update_context_for_auth_events(
)

return EventContext.with_state(
storage=self._storage,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
delta_before_after=state_updates,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
Expand Down
6 changes: 5 additions & 1 deletion synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,10 @@ async def deduplicate_state_event(
The previous version of the event is returned, if it is found in the
event context. Otherwise, None is returned.
"""
if event.internal_metadata.is_outlier():
# This can happen due to out of band memberships
return None

prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
Expand Down Expand Up @@ -1001,7 +1005,7 @@ async def create_new_client_event(
# after it is created
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
context = EventContext.for_outlier()
context = EventContext.for_outlier(self.storage)
elif (
event.type == EventTypes.MSC2716_INSERTION
and state_event_ids
Expand Down
4 changes: 4 additions & 0 deletions synapse/push/action_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,9 @@ def __init__(self, hs: "HomeServer"):
async def handle_push_actions_for_event(
self, event: EventBase, context: EventContext
) -> None:
if event.internal_metadata.is_outlier():
# This can happen due to out of band memberships
return

with Measure(self.clock, "action_for_event_by_user"):
await self.bulk_evaluator.action_for_event_by_user(event, context)
9 changes: 5 additions & 4 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(self, hs: "HomeServer"):
self.state_store = hs.get_storage().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage = hs.get_storage()

@overload
async def get_current_state(
Expand Down Expand Up @@ -361,10 +362,10 @@ async def compute_event_context(

if not event.is_state():
return EventContext.with_state(
storage=self._storage,
state_group_before_event=state_group_before_event,
state_group=state_group_before_event,
current_state_ids=state_ids_before_event,
prev_state_ids=state_ids_before_event,
delta_before_after={},
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
partial_state=partial_state,
Expand Down Expand Up @@ -393,10 +394,10 @@ async def compute_event_context(
)

return EventContext.with_state(
storage=self._storage,
state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
current_state_ids=state_ids_after_event,
prev_state_ids=state_ids_before_event,
delta_before_after=delta_ids,
prev_group=state_group_before_event,
delta_ids=delta_ids,
partial_state=partial_state,
Expand Down
Loading