From e50674f7f077f6da46bec9d8ff4d5df4c47d833f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 10 May 2022 08:50:38 -0400 Subject: [PATCH 1/2] Consolidate logic for parsing relations. --- changelog.d/12693.misc | 1 + synapse/events/__init__.py | 44 +++++++++++++++++++++++ synapse/handlers/message.py | 28 ++++++--------- synapse/handlers/relations.py | 18 +++++----- synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/storage/databases/main/events.py | 45 ++++++++++-------------- tests/rest/client/test_sync.py | 8 +++-- 7 files changed, 92 insertions(+), 56 deletions(-) create mode 100644 changelog.d/12693.misc diff --git a/changelog.d/12693.misc b/changelog.d/12693.misc new file mode 100644 index 000000000000..8bd1e1cb0cd5 --- /dev/null +++ b/changelog.d/12693.misc @@ -0,0 +1 @@ +Consolidate parsing of relation information from events. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index c238376caf62..ded2d191d36d 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -15,6 +15,7 @@ # limitations under the License. import abc +import collections.abc import os from typing import ( TYPE_CHECKING, @@ -32,9 +33,11 @@ overload, ) +import attr from typing_extensions import Literal from unpaddedbase64 import encode_base64 +from synapse.api.constants import RelationTypes from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict @@ -287,6 +290,17 @@ def is_historical(self) -> bool: return self._dict.get("historical", False) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventRelation: + # The target event of the relation. + parent_id: str + # The relation type. + rel_type: str + # The aggregation key. Will be None if the rel_type is not m.annotation or is + # not a string. + aggregation_key: Optional[str] + + class EventBase(metaclass=abc.ABCMeta): @property @abc.abstractmethod @@ -415,6 +429,36 @@ def auth_event_ids(self) -> Sequence[str]: """ return [e for e, _ in self._dict["auth_events"]] + def relation(self) -> Optional[_EventRelation]: + """ + Parse the event's relation information. + + Returns: + The event relation information, if it is valid. None, otherwise. + """ + relation = self.content.get("m.relates_to") + if not relation or not isinstance(relation, collections.abc.Mapping): + # No relation information. + return None + + # Relations must have a type and parent event ID. + rel_type = relation.get("rel_type") + if not isinstance(rel_type, str): + return None + + parent_id = relation.get("event_id") + if not isinstance(parent_id, str): + return None + + # Annotations have a key field. + aggregation_key = None + if rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.get("key") + if not isinstance(aggregation_key, str): + aggregation_key = None + + return _EventRelation(parent_id, rel_type, aggregation_key) + def freeze(self) -> None: """'Freeze' the event dict, so it cannot be modified by accident""" diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c28b792e6fe2..57ab2a5bd846 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1056,20 +1056,11 @@ async def _validate_event_relation(self, event: EventBase) -> None: SynapseError if the event is invalid. """ - relation = event.content.get("m.relates_to") + relation = event.relation() if not relation: return - relation_type = relation.get("rel_type") - if not relation_type: - return - - # Ensure the parent is real. - relates_to = relation.get("event_id") - if not relates_to: - return - - parent_event = await self.store.get_event(relates_to, allow_none=True) + parent_event = await self.store.get_event(relation.parent_id, allow_none=True) if parent_event: # And in the same room. if parent_event.room_id != event.room_id: @@ -1078,28 +1069,31 @@ async def _validate_event_relation(self, event: EventBase) -> None: else: # There must be some reason that the client knows the event exists, # see if there are existing relations. If so, assume everything is fine. - if not await self.store.event_is_target_of_relation(relates_to): + if not await self.store.event_is_target_of_relation(relation.parent_id): # Otherwise, the client can't know about the parent event! raise SynapseError(400, "Can't send relation to unknown event") # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an # event multiple times). - if relation_type == RelationTypes.ANNOTATION: - aggregation_key = relation["key"] + if relation.rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.aggregation_key + + if aggregation_key is None: + raise SynapseError(400, "Missing aggregation key") if len(aggregation_key) > 500: raise SynapseError(400, "Aggregation key is too long") already_exists = await self.store.has_user_annotated_event( - relates_to, event.type, aggregation_key, event.sender + relation.parent_id, event.type, aggregation_key, event.sender ) if already_exists: raise SynapseError(400, "Can't send same reaction twice") # Don't attempt to start a thread if the parent event is a relation. - elif relation_type == RelationTypes.THREAD: - if await self.store.event_includes_relation(relates_to): + elif relation.rel_type == RelationTypes.THREAD: + if await self.store.event_includes_relation(relation.parent_id): raise SynapseError( 400, "Cannot start threads from an event with a relation" ) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index c2754ec918de..019d12b75693 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections.abc import logging from typing import ( TYPE_CHECKING, @@ -373,20 +372,21 @@ async def get_bundled_aggregations( if event.is_state(): continue - relates_to = event.content.get("m.relates_to") - relation_type = None - if isinstance(relates_to, collections.abc.Mapping): - relation_type = relates_to.get("rel_type") + relates_to = event.relation() + if relates_to: # An event which is a replacement (ie edit) or annotation (ie, # reaction) may not have any other event related to it. - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + if relates_to.rel_type in ( + RelationTypes.ANNOTATION, + RelationTypes.REPLACE, + ): continue + # Track the event's relation information for later. + relations_by_id[event.event_id] = relates_to.rel_type + # The event should get bundled aggregations. events_by_id[event.event_id] = event - # Track the event's relation information for later. - if isinstance(relation_type, str): - relations_by_id[event.event_id] = relation_type # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 85ddb56c6eb4..534e003c62b9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -77,8 +77,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: return False # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: + relates_to = event.relation() + if relates_to and relates_to.rel_type == RelationTypes.REPLACE: return False # Mark events that have a non-empty string body as unread. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ed29a0a5e2db..52bc1bef66e3 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1815,52 +1815,45 @@ def _handle_event_relations( txn: The current database transaction. event: The event which might have relations. """ - relation = event.content.get("m.relates_to") + relation = event.relation() if not relation: - # No relations + # No relation, nothing to do. return - # Relations must have a type and parent event ID. - rel_type = relation.get("rel_type") - if not isinstance(rel_type, str): - return - - parent_id = relation.get("event_id") - if not isinstance(parent_id, str): - return - - # Annotations have a key field. - aggregation_key = None - if rel_type == RelationTypes.ANNOTATION: - aggregation_key = relation.get("key") - self.db_pool.simple_insert_txn( txn, table="event_relations", values={ "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, + "relates_to_id": relation.parent_id, + "relation_type": relation.rel_type, + "aggregation_key": relation.aggregation_key, }, ) - txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,)) txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate, (parent_id,) + self.store.get_relations_for_event.invalidate, (relation.parent_id,) + ) + txn.call_after( + self.store.get_aggregation_groups_for_event.invalidate, + (relation.parent_id,), ) - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + if relation.rel_type == RelationTypes.REPLACE: + txn.call_after( + self.store.get_applicable_edit.invalidate, (relation.parent_id,) + ) - if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + if relation.rel_type == RelationTypes.THREAD: + txn.call_after( + self.store.get_thread_summary.invalidate, (relation.parent_id,) + ) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and # potentially error-prone) so it is always invalidated. txn.call_after( self.store.get_thread_participated.invalidate, - (parent_id, event.sender), + (relation.parent_id, event.sender), ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 010833764957..8170f20213e3 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -678,12 +678,13 @@ def test_unread_counts(self) -> None: self._check_unread_count(3) # Check that custom events with a body increase the unread counter. - self.helper.send_event( + result = self.helper.send_event( self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, ) + event_id = result["event_id"] self._check_unread_count(4) # Check that edits don't increase the unread counter. @@ -693,7 +694,10 @@ def test_unread_counts(self) -> None: content={ "body": "hello", "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": event_id, + }, }, tok=self.tok2, ) From 1ff43dde7dd87e39305df83f24e01cc1e84ef1bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 13 May 2022 08:09:16 -0400 Subject: [PATCH 2/2] Move to a free-standing function. --- synapse/events/__init__.py | 83 ++++++++++++------------ synapse/handlers/message.py | 4 +- synapse/handlers/relations.py | 4 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/storage/databases/main/events.py | 6 +- 5 files changed, 51 insertions(+), 50 deletions(-) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index ded2d191d36d..39ad2793d98d 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -290,17 +290,6 @@ def is_historical(self) -> bool: return self._dict.get("historical", False) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _EventRelation: - # The target event of the relation. - parent_id: str - # The relation type. - rel_type: str - # The aggregation key. Will be None if the rel_type is not m.annotation or is - # not a string. - aggregation_key: Optional[str] - - class EventBase(metaclass=abc.ABCMeta): @property @abc.abstractmethod @@ -429,36 +418,6 @@ def auth_event_ids(self) -> Sequence[str]: """ return [e for e, _ in self._dict["auth_events"]] - def relation(self) -> Optional[_EventRelation]: - """ - Parse the event's relation information. - - Returns: - The event relation information, if it is valid. None, otherwise. - """ - relation = self.content.get("m.relates_to") - if not relation or not isinstance(relation, collections.abc.Mapping): - # No relation information. - return None - - # Relations must have a type and parent event ID. - rel_type = relation.get("rel_type") - if not isinstance(rel_type, str): - return None - - parent_id = relation.get("event_id") - if not isinstance(parent_id, str): - return None - - # Annotations have a key field. - aggregation_key = None - if rel_type == RelationTypes.ANNOTATION: - aggregation_key = relation.get("key") - if not isinstance(aggregation_key, str): - aggregation_key = None - - return _EventRelation(parent_id, rel_type, aggregation_key) - def freeze(self) -> None: """'Freeze' the event dict, so it cannot be modified by accident""" @@ -659,3 +618,45 @@ def make_event_from_dict( return event_type( event_dict, room_version, internal_metadata_dict or {}, rejected_reason ) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventRelation: + # The target event of the relation. + parent_id: str + # The relation type. + rel_type: str + # The aggregation key. Will be None if the rel_type is not m.annotation or is + # not a string. + aggregation_key: Optional[str] + + +def relation_from_event(event: EventBase) -> Optional[_EventRelation]: + """ + Attempt to parse relation information an event. + + Returns: + The event relation information, if it is valid. None, otherwise. + """ + relation = event.content.get("m.relates_to") + if not relation or not isinstance(relation, collections.abc.Mapping): + # No relation information. + return None + + # Relations must have a type and parent event ID. + rel_type = relation.get("rel_type") + if not isinstance(rel_type, str): + return None + + parent_id = relation.get("event_id") + if not isinstance(parent_id, str): + return None + + # Annotations have a key field. + aggregation_key = None + if rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.get("key") + if not isinstance(aggregation_key, str): + aggregation_key = None + + return _EventRelation(parent_id, rel_type, aggregation_key) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9dd429510f0a..0951b9c71f75 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -44,7 +44,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator @@ -1060,7 +1060,7 @@ async def _validate_event_relation(self, event: EventBase) -> None: SynapseError if the event is invalid. """ - relation = event.relation() + relation = relation_from_event(event) if not relation: return diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 019d12b75693..ab7e54857d56 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -27,7 +27,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.storage.databases.main.relations import _RelatedEvent from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -372,7 +372,7 @@ async def get_bundled_aggregations( if event.is_state(): continue - relates_to = event.relation() + relates_to = relation_from_event(event) if relates_to: # An event which is a replacement (ie edit) or annotation (ie, # reaction) may not have any other event related to it. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 3ce27931acf6..4ac2c546bf2a 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.event_auth import get_user_power_level -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership @@ -78,7 +78,7 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: return False # Exclude edits. - relates_to = event.relation() + relates_to = relation_from_event(event) if relates_to and relates_to.rel_type == RelationTypes.REPLACE: return False diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5c6496c5b3db..42d484dc98d9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -36,8 +36,8 @@ import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events import EventBase, relation_from_event +from synapse.events.snapshot import EventContext from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -1807,7 +1807,7 @@ def _handle_event_relations( txn: The current database transaction. event: The event which might have relations. """ - relation = event.relation() + relation = relation_from_event(event) if not relation: # No relation, nothing to do. return