Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Store the thread ID with the receipt.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Sep 12, 2022
1 parent 18c60ed commit 5af7a44
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 37 deletions.
1 change: 1 addition & 0 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
receipt.receipt_type,
receipt.user_id,
receipt.event_ids,
receipt.thread_id,
receipt.data,
)

Expand Down
25 changes: 18 additions & 7 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ def simple_upsert_txn(
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
Expand All @@ -1213,7 +1214,12 @@ def simple_upsert_txn(

if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
)
else:
return self.simple_upsert_txn_emulated(
Expand All @@ -1222,6 +1228,7 @@ def simple_upsert_txn(
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
lock=lock,
)

Expand All @@ -1232,6 +1239,7 @@ def simple_upsert_txn_emulated(
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
Expand Down Expand Up @@ -1259,14 +1267,15 @@ def _getwhere(key: str) -> str:
else:
return "%s = ?" % (key,)

where = [_getwhere(k) for k in keyvalues]
if where_clause:
where.append(where_clause)

if not values:
# If `values` is empty, then all of the values we care about are in
# the unique key, so there is nothing to UPDATE. We can just do a
# SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % (
table,
" AND ".join(_getwhere(k) for k in keyvalues),
)
sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs)
if txn.fetchall():
Expand All @@ -1277,7 +1286,7 @@ def _getwhere(key: str) -> str:
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join(_getwhere(k) for k in keyvalues),
" AND ".join(where),
)
sqlargs = list(values.values()) + list(keyvalues.values())

Expand Down Expand Up @@ -1307,6 +1316,7 @@ def simple_upsert_txn_native_upsert(
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
) -> bool:
"""
Use the native UPSERT functionality in PostgreSQL.
Expand All @@ -1331,11 +1341,12 @@ def simple_upsert_txn_native_upsert(
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)

sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s") % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
f"WHERE {where_clause}" if where_clause else "",
latter,
)
txn.execute(sql, list(allvalues.values()))
Expand Down
71 changes: 53 additions & 18 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def _insert_linearized_receipt_txn(
receipt_type: str,
user_id: str,
event_id: str,
thread_id: Optional[str],
data: JsonDict,
stream_id: int,
) -> Optional[int]:
Expand All @@ -657,12 +658,27 @@ def _insert_linearized_receipt_txn(
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
if stream_ordering is not None:
sql = (
"SELECT stream_ordering, event_id FROM events"
" INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
if thread_id is None:
thread_clause = "r.thread_id IS NULL"
thread_args = ()
else:
thread_clause = "r.thread_id = ?"
thread_args = (thread_id,)

sql = f"""
SELECT stream_ordering, event_id FROM events
INNER JOIN receipts_linearized AS r USING (event_id, room_id)
WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause}
"""
txn.execute(
sql,
(
room_id,
receipt_type,
user_id,
)
+ thread_args,
)
txn.execute(sql, (room_id, receipt_type, user_id))

for so, eid in txn:
if int(so) >= stream_ordering:
Expand All @@ -682,20 +698,27 @@ def _insert_linearized_receipt_txn(
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)

keyvalues = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
where_clause = ""
if thread_id is None:
where_clause = "thread_id IS NULL"
else:
keyvalues["thread_id"] = thread_id

self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
},
keyvalues=keyvalues,
values={
"stream_id": stream_id,
"event_id": event_id,
"data": json_encoder.encode(data),
"thread_id": None,
},
where_clause=where_clause,
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
lock=False,
Expand Down Expand Up @@ -747,6 +770,7 @@ async def insert_receipt(
receipt_type: str,
user_id: str,
event_ids: List[str],
thread_id: Optional[str],
data: dict,
) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Expand Down Expand Up @@ -779,6 +803,7 @@ async def insert_receipt(
receipt_type,
user_id,
linearized_event_id,
thread_id,
data,
stream_id=stream_id,
# Read committed is actually beneficial here because we check for a receipt with
Expand All @@ -793,7 +818,8 @@ async def insert_receipt(

now = self._clock.time_msec()
logger.debug(
"RR for event %s in %s (%i ms old)",
"Receipt %s for event %s in %s (%i ms old)",
receipt_type,
linearized_event_id,
room_id,
now - event_ts,
Expand All @@ -806,6 +832,7 @@ async def insert_receipt(
receipt_type,
user_id,
event_ids,
thread_id,
data,
)

Expand All @@ -820,6 +847,7 @@ def _insert_graph_receipt_txn(
receipt_type: str,
user_id: str,
event_ids: List[str],
thread_id: Optional[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts
Expand All @@ -831,19 +859,26 @@ def _insert_graph_receipt_txn(
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

keyvalues = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
where_clause = ""
if thread_id is None:
where_clause = "thread_id IS NULL"
else:
keyvalues["thread_id"] = thread_id

self.db_pool.simple_upsert_txn(
txn,
table="receipts_graph",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
},
keyvalues=keyvalues,
values={
"event_ids": json_encoder.encode(event_ids),
"data": json_encoder.encode(data),
"thread_id": None,
},
where_clause=where_clause,
# receipts_graph has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
lock=False,
Expand Down
1 change: 1 addition & 0 deletions tests/handlers/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def test_sending_read_receipt_batches_to_application_services(self):
receipt_type="m.read",
user_id=self.local_user,
event_ids=[f"$eventid_{i}"],
thread_id=None,
data={},
)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/replication/slave/storage/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_push_actions_for_user(self, send_receipt: bool):
if send_receipt:
self.get_success(
self.master_store.insert_receipt(
ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
)
)

Expand Down
14 changes: 12 additions & 2 deletions tests/replication/tcp/streams/test_receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ def test_receipt(self):
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastores().main.insert_receipt(
"!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
"!room:blue",
"m.read",
USER_ID,
["$event:blue"],
thread_id=None,
data={"a": 1},
)
)
self.replicate()
Expand All @@ -57,7 +62,12 @@ def test_receipt(self):

self.get_success(
self.hs.get_datastores().main.insert_receipt(
"!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
"!room2:blue",
"m.read",
USER_ID,
["$event2:foo"],
thread_id=None,
data={"a": 2},
)
)
self.replicate()
Expand Down
4 changes: 3 additions & 1 deletion tests/storage/test_event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _mark_read(event_id: str) -> None:
"m.read",
user_id=user_id,
event_ids=[event_id],
thread_id=None,
data={},
)
)
Expand Down Expand Up @@ -262,13 +263,14 @@ def _create_event(
def _rotate() -> None:
self.get_success(self.store._rotate_notifs())

def _mark_read(event_id: str) -> None:
def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
event_ids=[event_id],
thread_id=thread_id,
data={},
)
)
Expand Down
Loading

0 comments on commit 5af7a44

Please sign in to comment.