Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Consolidate logic for parsing relations.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed May 10, 2022
1 parent 5c00151 commit e50674f
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 56 deletions.
1 change: 1 addition & 0 deletions changelog.d/12693.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Consolidate parsing of relation information from events.
44 changes: 44 additions & 0 deletions synapse/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import abc
import collections.abc
import os
from typing import (
TYPE_CHECKING,
Expand All @@ -32,9 +33,11 @@
overload,
)

import attr
from typing_extensions import Literal
from unpaddedbase64 import encode_base64

from synapse.api.constants import RelationTypes
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
Expand Down Expand Up @@ -287,6 +290,17 @@ def is_historical(self) -> bool:
return self._dict.get("historical", False)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventRelation:
# The target event of the relation.
parent_id: str
# The relation type.
rel_type: str
# The aggregation key. Will be None if the rel_type is not m.annotation or is
# not a string.
aggregation_key: Optional[str]


class EventBase(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
Expand Down Expand Up @@ -415,6 +429,36 @@ def auth_event_ids(self) -> Sequence[str]:
"""
return [e for e, _ in self._dict["auth_events"]]

def relation(self) -> Optional[_EventRelation]:
"""
Parse the event's relation information.
Returns:
The event relation information, if it is valid. None, otherwise.
"""
relation = self.content.get("m.relates_to")
if not relation or not isinstance(relation, collections.abc.Mapping):
# No relation information.
return None

# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
if not isinstance(rel_type, str):
return None

parent_id = relation.get("event_id")
if not isinstance(parent_id, str):
return None

# Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")
if not isinstance(aggregation_key, str):
aggregation_key = None

return _EventRelation(parent_id, rel_type, aggregation_key)

def freeze(self) -> None:
"""'Freeze' the event dict, so it cannot be modified by accident"""

Expand Down
28 changes: 11 additions & 17 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,20 +1056,11 @@ async def _validate_event_relation(self, event: EventBase) -> None:
SynapseError if the event is invalid.
"""

relation = event.content.get("m.relates_to")
relation = event.relation()
if not relation:
return

relation_type = relation.get("rel_type")
if not relation_type:
return

# Ensure the parent is real.
relates_to = relation.get("event_id")
if not relates_to:
return

parent_event = await self.store.get_event(relates_to, allow_none=True)
parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
if parent_event:
# And in the same room.
if parent_event.room_id != event.room_id:
Expand All @@ -1078,28 +1069,31 @@ async def _validate_event_relation(self, event: EventBase) -> None:
else:
# There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine.
if not await self.store.event_is_target_of_relation(relates_to):
if not await self.store.event_is_target_of_relation(relation.parent_id):
# Otherwise, the client can't know about the parent event!
raise SynapseError(400, "Can't send relation to unknown event")

# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
# event multiple times).
if relation_type == RelationTypes.ANNOTATION:
aggregation_key = relation["key"]
if relation.rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.aggregation_key

if aggregation_key is None:
raise SynapseError(400, "Missing aggregation key")

if len(aggregation_key) > 500:
raise SynapseError(400, "Aggregation key is too long")

already_exists = await self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender
relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")

# Don't attempt to start a thread if the parent event is a relation.
elif relation_type == RelationTypes.THREAD:
if await self.store.event_includes_relation(relates_to):
elif relation.rel_type == RelationTypes.THREAD:
if await self.store.event_includes_relation(relation.parent_id):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
)
Expand Down
18 changes: 9 additions & 9 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import logging
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -373,20 +372,21 @@ async def get_bundled_aggregations(
if event.is_state():
continue

relates_to = event.content.get("m.relates_to")
relation_type = None
if isinstance(relates_to, collections.abc.Mapping):
relation_type = relates_to.get("rel_type")
relates_to = event.relation()
if relates_to:
# An event which is a replacement (ie edit) or annotation (ie,
# reaction) may not have any other event related to it.
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
if relates_to.rel_type in (
RelationTypes.ANNOTATION,
RelationTypes.REPLACE,
):
continue

# Track the event's relation information for later.
relations_by_id[event.event_id] = relates_to.rel_type

# The event should get bundled aggregations.
events_by_id[event.event_id] = event
# Track the event's relation information for later.
if isinstance(relation_type, str):
relations_by_id[event.event_id] = relation_type

# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
Expand Down
4 changes: 2 additions & 2 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
return False

# Exclude edits.
relates_to = event.content.get("m.relates_to", {})
if relates_to.get("rel_type") == RelationTypes.REPLACE:
relates_to = event.relation()
if relates_to and relates_to.rel_type == RelationTypes.REPLACE:
return False

# Mark events that have a non-empty string body as unread.
Expand Down
45 changes: 19 additions & 26 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,52 +1815,45 @@ def _handle_event_relations(
txn: The current database transaction.
event: The event which might have relations.
"""
relation = event.content.get("m.relates_to")
relation = event.relation()
if not relation:
# No relations
# No relation, nothing to do.
return

# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
if not isinstance(rel_type, str):
return

parent_id = relation.get("event_id")
if not isinstance(parent_id, str):
return

# Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")

self.db_pool.simple_insert_txn(
txn,
table="event_relations",
values={
"event_id": event.event_id,
"relates_to_id": parent_id,
"relation_type": rel_type,
"aggregation_key": aggregation_key,
"relates_to_id": relation.parent_id,
"relation_type": relation.rel_type,
"aggregation_key": relation.aggregation_key,
},
)

txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
txn.call_after(
self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
self.store.get_relations_for_event.invalidate, (relation.parent_id,)
)
txn.call_after(
self.store.get_aggregation_groups_for_event.invalidate,
(relation.parent_id,),
)

if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
if relation.rel_type == RelationTypes.REPLACE:
txn.call_after(
self.store.get_applicable_edit.invalidate, (relation.parent_id,)
)

if rel_type == RelationTypes.THREAD:
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
if relation.rel_type == RelationTypes.THREAD:
txn.call_after(
self.store.get_thread_summary.invalidate, (relation.parent_id,)
)
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
(parent_id, event.sender),
(relation.parent_id, event.sender),
)

def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
Expand Down
8 changes: 6 additions & 2 deletions tests/rest/client/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,13 @@ def test_unread_counts(self) -> None:
self._check_unread_count(3)

# Check that custom events with a body increase the unread counter.
self.helper.send_event(
result = self.helper.send_event(
self.room_id,
"org.matrix.custom_type",
{"body": "hello"},
tok=self.tok2,
)
event_id = result["event_id"]
self._check_unread_count(4)

# Check that edits don't increase the unread counter.
Expand All @@ -693,7 +694,10 @@ def test_unread_counts(self) -> None:
content={
"body": "hello",
"msgtype": "m.text",
"m.relates_to": {"rel_type": RelationTypes.REPLACE},
"m.relates_to": {
"rel_type": RelationTypes.REPLACE,
"event_id": event_id,
},
},
tok=self.tok2,
)
Expand Down

0 comments on commit e50674f

Please sign in to comment.