diff --git a/changelog.d/12689.misc b/changelog.d/12689.misc new file mode 100644 index 000000000000..daa484ea3019 --- /dev/null +++ b/changelog.d/12689.misc @@ -0,0 +1 @@ +Refactor `EventContext` class. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 8120c305df14..9ccd24b298bb 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -17,11 +17,8 @@ from frozendict import frozendict from typing_extensions import Literal -from twisted.internet.defer import Deferred - from synapse.appservice import ApplicationService from synapse.events import EventBase -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import JsonDict, StateMap if TYPE_CHECKING: @@ -61,6 +58,9 @@ class EventContext: If ``state_group`` is None (ie, the event is an outlier), ``state_group_before_event`` will always also be ``None``. + state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None + then this is the delta of the state between the two groups. + prev_group: If it is known, ``state_group``'s prev_group. Note that this being None does not necessarily mean that ``state_group`` does not have a prev_group! @@ -79,73 +79,47 @@ class EventContext: app_service: If this event is being sent by a (local) application service, that app service. - _current_state_ids: The room state map, including this event - ie, the state - in ``state_group``. - - (type, state_key) -> event_id - - For an outlier, this is {} - - Note that this is a private attribute: it should be accessed via - ``get_current_state_ids``. _AsyncEventContext impl calculates this - on-demand: it will be None until that happens. - - _prev_state_ids: The room state map, excluding this event - ie, the state - in ``state_group_before_event``. For a non-state - event, this will be the same as _current_state_events. - - Note that it is a completely different thing to prev_group! - - (type, state_key) -> event_id - - For an outlier, this is {} - - As with _current_state_ids, this is a private attribute. It should be - accessed via get_prev_state_ids. - partial_state: if True, we may be storing this event with a temporary, incomplete state. """ + _storage: "Storage" rejected: Union[Literal[False], str] = False _state_group: Optional[int] = None state_group_before_event: Optional[int] = None + _state_delta_due_to_event: Optional[StateMap[str]] = None prev_group: Optional[int] = None delta_ids: Optional[StateMap[str]] = None app_service: Optional[ApplicationService] = None - _current_state_ids: Optional[StateMap[str]] = None - _prev_state_ids: Optional[StateMap[str]] = None - partial_state: bool = False @staticmethod def with_state( + storage: "Storage", state_group: Optional[int], state_group_before_event: Optional[int], - current_state_ids: Optional[StateMap[str]], - prev_state_ids: Optional[StateMap[str]], + state_delta_due_to_event: Optional[StateMap[str]], partial_state: bool, prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ) -> "EventContext": return EventContext( - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + storage=storage, state_group=state_group, state_group_before_event=state_group_before_event, + state_delta_due_to_event=state_delta_due_to_event, prev_group=prev_group, delta_ids=delta_ids, partial_state=partial_state, ) @staticmethod - def for_outlier() -> "EventContext": + def for_outlier( + storage: "Storage", + ) -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" - return EventContext( - current_state_ids={}, - prev_state_ids={}, - ) + return EventContext(storage=storage) async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: """Converts self to a type that can be serialized as JSON, and then @@ -158,24 +132,14 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: The serialized event. """ - # We don't serialize the full state dicts, instead they get pulled out - # of the DB on the other side. However, the other side can't figure out - # the prev_state_ids, so if we're a state event we include the event - # id that we replaced in the state. - if event.is_state(): - prev_state_ids = await self.get_prev_state_ids() - prev_state_id = prev_state_ids.get((event.type, event.state_key)) - else: - prev_state_id = None - return { - "prev_state_id": prev_state_id, - "event_type": event.type, - "event_state_key": event.get_state_key(), "state_group": self._state_group, "state_group_before_event": self.state_group_before_event, "rejected": self.rejected, "prev_group": self.prev_group, + "state_delta_due_to_event": _encode_state_dict( + self._state_delta_due_to_event + ), "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, "partial_state": self.partial_state, @@ -193,16 +157,16 @@ def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": Returns: The event context. """ - context = _AsyncEventContextImpl( + context = EventContext( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. storage=storage, - prev_state_id=input["prev_state_id"], - event_type=input["event_type"], - event_state_key=input["event_state_key"], state_group=input["state_group"], state_group_before_event=input["state_group_before_event"], prev_group=input["prev_group"], + state_delta_due_to_event=_decode_state_dict( + input["state_delta_due_to_event"] + ), delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], partial_state=input.get("partial_state", False), @@ -250,8 +214,15 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]: if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - await self._ensure_fetched() - return self._current_state_ids + assert self._state_delta_due_to_event is not None + + prev_state_ids = await self.get_prev_state_ids() + + if self._state_delta_due_to_event: + prev_state_ids = dict(prev_state_ids) + prev_state_ids.update(self._state_delta_due_to_event) + + return prev_state_ids async def get_prev_state_ids(self) -> StateMap[str]: """ @@ -266,94 +237,10 @@ async def get_prev_state_ids(self) -> StateMap[str]: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - await self._ensure_fetched() - # There *should* be previous state IDs now. - assert self._prev_state_ids is not None - return self._prev_state_ids - - def get_cached_current_state_ids(self) -> Optional[StateMap[str]]: - """Gets the current state IDs if we have them already cached. - - It is an error to access this for a rejected event, since rejected state should - not make it into the room state. This method will raise an exception if - ``rejected`` is set. - - Returns: - Returns None if we haven't cached the state or if state_group is None - (which happens when the associated event is an outlier). - - Otherwise, returns the the current state IDs. - """ - if self.rejected: - raise RuntimeError("Attempt to access state_ids of rejected event") - - return self._current_state_ids - - async def _ensure_fetched(self) -> None: - return None - - -@attr.s(slots=True) -class _AsyncEventContextImpl(EventContext): - """ - An implementation of EventContext which fetches _current_state_ids and - _prev_state_ids from the database on demand. - - Attributes: - - _storage - - _fetching_state_deferred: Resolves when *_state_ids have been calculated. - None if we haven't started calculating yet - - _event_type: The type of the event the context is associated with. - - _event_state_key: The state_key of the event the context is associated with. - - _prev_state_id: If the event associated with the context is a state event, - then `_prev_state_id` is the event_id of the state that was replaced. - """ - - # This needs to have a default as we're inheriting - _storage: "Storage" = attr.ib(default=None) - _prev_state_id: Optional[str] = attr.ib(default=None) - _event_type: str = attr.ib(default=None) - _event_state_key: Optional[str] = attr.ib(default=None) - _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None) - - async def _ensure_fetched(self) -> None: - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background(self._fill_out_state) - - await make_deferred_yieldable(self._fetching_state_deferred) - - async def _fill_out_state(self) -> None: - """Called to populate the _current_state_ids and _prev_state_ids - attributes by loading from the database. - """ - if self.state_group is None: - # No state group means the event is an outlier. Usually the state_ids dicts are also - # pre-set to empty dicts, but they get reset when the context is serialized, so set - # them to empty dicts again here. - self._current_state_ids = {} - self._prev_state_ids = {} - return - - current_state_ids = await self._storage.state.get_state_ids_for_group( - self.state_group + assert self.state_group_before_event is not None + return await self._storage.state.get_state_ids_for_group( + self.state_group_before_event ) - # Set this separately so mypy knows current_state_ids is not None. - self._current_state_ids = current_state_ids - if self._event_state_key is not None: - self._prev_state_ids = dict(current_state_ids) - - key = (self._event_type, self._event_state_key) - if self._prev_state_id: - self._prev_state_ids[key] = self._prev_state_id - else: - self._prev_state_ids.pop(key, None) - else: - self._prev_state_ids = current_state_ids def _encode_state_dict( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 38dc5b1f6edf..be5099b507f6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -659,7 +659,7 @@ async def do_knock( # in the invitee's sync stream. It is stripped out for all other local users. event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] - context = EventContext.for_outlier() + context = EventContext.for_outlier(self.storage) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -848,7 +848,7 @@ async def on_invite_request( ) ) - context = EventContext.for_outlier() + context = EventContext.for_outlier(self.storage) await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -877,7 +877,7 @@ async def do_remotely_reject_invite( await self.federation_client.send_leave(host_list, event) - context = EventContext.for_outlier() + context = EventContext.for_outlier(self.storage) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 6cf927e4ff7b..6d11b32b61db 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1423,7 +1423,7 @@ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: # we're not bothering about room state, so flag the event as an outlier. event.internal_metadata.outlier = True - context = EventContext.for_outlier() + context = EventContext.for_outlier(self._storage) try: validate_event_for_room_version(room_version_obj, event) check_auth_rules_for_event(room_version_obj, event, auth) @@ -1874,10 +1874,10 @@ async def _update_context_for_auth_events( ) return EventContext.with_state( + storage=self._storage, state_group=state_group, state_group_before_event=context.state_group_before_event, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + state_delta_due_to_event=state_updates, prev_group=prev_group, delta_ids=state_updates, partial_state=context.partial_state, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c28b792e6fe2..e47799e7f962 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -757,6 +757,10 @@ async def deduplicate_state_event( The previous version of the event is returned, if it is found in the event context. Otherwise, None is returned. """ + if event.internal_metadata.is_outlier(): + # This can happen due to out of band memberships + return None + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: @@ -1001,7 +1005,7 @@ async def create_new_client_event( # after it is created if builder.internal_metadata.outlier: event.internal_metadata.outlier = True - context = EventContext.for_outlier() + context = EventContext.for_outlier(self.storage) elif ( event.type == EventTypes.MSC2716_INSERTION and state_event_ids diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 60758df01664..730d9cd35463 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -40,5 +40,9 @@ def __init__(self, hs: "HomeServer"): async def handle_push_actions_for_event( self, event: EventBase, context: EventContext ) -> None: + if event.internal_metadata.is_outlier(): + # This can happen due to out of band memberships + return + with Measure(self.clock, "action_for_event_by_user"): await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index cad3b4264007..54e41d537584 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -130,6 +130,7 @@ def __init__(self, hs: "HomeServer"): self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() + self._storage = hs.get_storage() @overload async def get_current_state( @@ -361,10 +362,10 @@ async def compute_event_context( if not event.is_state(): return EventContext.with_state( + storage=self._storage, state_group_before_event=state_group_before_event, state_group=state_group_before_event, - current_state_ids=state_ids_before_event, - prev_state_ids=state_ids_before_event, + state_delta_due_to_event={}, prev_group=state_group_before_event_prev_group, delta_ids=deltas_to_state_group_before_event, partial_state=partial_state, @@ -393,10 +394,10 @@ async def compute_event_context( ) return EventContext.with_state( + storage=self._storage, state_group=state_group_after_event, state_group_before_event=state_group_before_event, - current_state_ids=state_ids_after_event, - prev_state_ids=state_ids_before_event, + state_delta_due_to_event=delta_ids, prev_group=state_group_before_event, delta_ids=delta_ids, partial_state=partial_state, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6c12653bb3c6..f544bcfff07f 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -128,7 +128,6 @@ async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], *, - current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremities: Dict[str, Set[str]], use_negative_stream_ordering: bool = False, @@ -139,8 +138,6 @@ async def _persist_events_and_state_updates( Args: events_and_contexts: - current_state_for_room: Map from room_id to the current state of - the room based on forward extremities state_delta_for_room: Map from room_id to the delta to apply to room state new_forward_extremities: Map from room_id to set of event IDs @@ -215,9 +212,6 @@ async def _persist_events_and_state_updates( event_counter.labels(event.type, origin_type, origin_entity).inc() - for room_id, new_state in current_state_for_room.items(): - self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 97118045a1ad..a7f6338e058d 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -487,12 +487,6 @@ async def _persist_event_batch( # extremities in each room new_forward_extremities: Dict[str, Set[str]] = {} - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room: Dict[str, StateMap[str]] = {} - # map room_id->(to_delete, to_insert) where to_delete is a list # of type/state keys to remove from current state, and to_insert # is a map (type,key)->event_id giving the state delta in each @@ -628,14 +622,8 @@ async def _persist_event_batch( state_delta_for_room[room_id] = delta - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state - await self.persist_events_store._persist_events_and_state_updates( chunk, - current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, new_forward_extremities=new_forward_extremities, use_negative_stream_ordering=backfilled, @@ -733,7 +721,8 @@ async def _get_new_state_after_events( The first state map is the full new current state and the second is the delta to the existing current state. If both are None then - there has been no change. + there has been no change. Either or neither can be None if there + has been a change. The function may prune some old entries from the set of new forward extremities if it's safe to do so. @@ -743,9 +732,6 @@ async def _get_new_state_after_events( the new current state is only returned if we've already calculated it. """ - # map from state_group to ((type, key) -> event_id) state map - state_groups_map = {} - # Map from (prev state group, new state group) -> delta state dict state_group_deltas = {} @@ -759,16 +745,6 @@ async def _get_new_state_after_events( ) continue - if ctx.state_group in state_groups_map: - continue - - # We're only interested in pulling out state that has already - # been cached in the context. We'll pull stuff out of the DB later - # if necessary. - current_state_ids = ctx.get_cached_current_state_ids() - if current_state_ids is not None: - state_groups_map[ctx.state_group] = current_state_ids - if ctx.prev_group: state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids @@ -826,18 +802,14 @@ async def _get_new_state_after_events( delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) if delta_ids is not None: # We have a delta from the existing to new current state, - # so lets just return that. If we happen to already have - # the current state in memory then lets also return that, - # but it doesn't matter if we don't. - new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids, new_latest_event_ids + # so lets just return that. + return None, delta_ids, new_latest_event_ids # Now that we have calculated new_state_groups we need to get # their state IDs so we can resolve to a single state set. - missing_state = new_state_groups - set(state_groups_map) - if missing_state: - group_to_state = await self.state_store._get_state_for_groups(missing_state) - state_groups_map.update(group_to_state) + state_groups_map = await self.state_store._get_state_for_groups( + new_state_groups + ) if len(new_state_groups) == 1: # If there is only one state group, then we know what the current diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 489ba5773672..e64b28f28b86 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -148,7 +148,9 @@ def _test_process_pulled_event_with_missing_state( prev_event.internal_metadata.outlier = True persistence = self.hs.get_storage().persistence self.get_success( - persistence.persist_event(prev_event, EventContext.for_outlier()) + persistence.persist_event( + prev_event, EventContext.for_outlier(self.hs.get_storage()) + ) ) else: diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 401020fd6361..c7661e71868f 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -393,7 +393,7 @@ def _persist(txn): # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( - txn, [(e, EventContext()) for e in events] + txn, [(e, EventContext(self.hs.get_storage())) for e in events] ) # Actually call the function that calculates the auth chain stuff. diff --git a/tests/test_state.py b/tests/test_state.py index e4baa6913746..651ec1c7d4bd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -88,6 +88,9 @@ async def get_state_groups_ids(self, room_id, event_ids): return groups + async def get_state_ids_for_group(self, state_group): + return self._group_to_state[state_group] + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids ): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index d0230f9ebbc5..7a9b01ef9d44 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -234,7 +234,9 @@ def _inject_outlier(self) -> EventBase: event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage.persistence.persist_event(event, EventContext.for_outlier()) + self.storage.persistence.persist_event( + event, EventContext.for_outlier(self.storage) + ) ) return event