Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit a750347

Browse files
committed
Ensure forked threads are not allowed.
1 parent 1fff047 commit a750347

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

synapse/handlers/message.py

+7
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,13 @@ async def _validate_event_relation(self, event: EventBase) -> None:
10481048
if already_exists:
10491049
raise SynapseError(400, "Can't send same reaction twice")
10501050

1051+
# If this relation is a thread, then ensure thread head is not part of
1052+
# a thread already.
1053+
elif relation_type == RelationTypes.THREAD:
1054+
already_thread = await self.store.get_event_thread(relates_to)
1055+
if already_thread:
1056+
raise SynapseError(400, "Can't fork threads")
1057+
10511058
@measure_func("handle_new_client_event")
10521059
async def handle_new_client_event(
10531060
self,

synapse/storage/databases/main/relations.py

+31
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,37 @@ def _get_if_user_has_annotated_event(txn):
372372
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
373373
)
374374

375+
async def get_event_thread(self, event_id: str) -> Optional[str]:
376+
"""Return an event's thread.
377+
378+
Args:
379+
event_id: The event being used as the start of a new thread.
380+
381+
Returns:
382+
The thread ID of the event.
383+
"""
384+
385+
sql = """
386+
SELECT relates_to_id FROM event_relations
387+
WHERE
388+
event_id = ?
389+
AND relation_type = ?
390+
LIMIT 1;
391+
"""
392+
393+
def _get_thread_id(txn) -> Optional[str]:
394+
txn.execute(
395+
sql,
396+
(
397+
event_id,
398+
RelationTypes.THREAD,
399+
),
400+
)
401+
402+
return txn.fetchone()
403+
404+
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
405+
375406

376407
class RelationsStore(RelationsWorkerStore):
377408
pass

tests/rest/client/test_relations.py

+19
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,25 @@ def test_deny_double_react(self):
119119
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
120120
self.assertEquals(400, channel.code, channel.json_body)
121121

122+
def test_deny_forked_thread(self):
123+
"""It is invalid to start a thread off a thread."""
124+
channel = self._send_relation(
125+
RelationTypes.THREAD,
126+
"m.room.message",
127+
content={"msgtype": "m.text", "body": "foo"},
128+
parent_id=self.parent_id,
129+
)
130+
self.assertEquals(200, channel.code, channel.json_body)
131+
parent_id = channel.json_body["event_id"]
132+
133+
channel = self._send_relation(
134+
RelationTypes.THREAD,
135+
"m.room.message",
136+
content={"msgtype": "m.text", "body": "foo"},
137+
parent_id=parent_id,
138+
)
139+
self.assertEquals(400, channel.code, channel.json_body)
140+
122141
def test_basic_paginate_relations(self):
123142
"""Tests that calling pagination API correctly the latest relations."""
124143
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")

0 commit comments

Comments
 (0)