Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #5788 from matrix-org/rav/metaredactions
Browse files Browse the repository at this point in the history
  • Loading branch information
anoadragon453 committed Feb 20, 2020
2 parents c24b899 + fb86217 commit 408959c
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 101 deletions.
1 change: 1 addition & 0 deletions changelog.d/5788.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Correctly handle redactions of redactions.
213 changes: 112 additions & 101 deletions synapse/storage/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
Expand Down Expand Up @@ -342,13 +337,12 @@ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
log_ctx = LoggingContext.current_context()
log_ctx.record_event_fetch(len(missing_events_ids))

# Note that _enqueue_events is also responsible for turning db rows
# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
# _enqueue_events is a bit of a rubbish name but naming is hard.
missing_events = yield self._enqueue_events(
missing_events = yield self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)

Expand Down Expand Up @@ -421,28 +415,28 @@ def _fetch_event_list(self, conn, event_list):
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.
The deferreds are callbacked with a dictionary mapping from event id
to event row. Note that it may well contain additional events that
were not part of this request.
"""
with Measure(self._clock, "_fetch_event_list"):
try:
event_id_lists = list(zip(*event_list))[0]
event_ids = [item for sublist in event_id_lists for item in sublist]
events_to_fetch = set(
event_id for events, _ in event_list for event_id in events
)

row_dict = self._new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)

# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
if not d.called:
try:
with PreserveLoggingContext():
d.callback([res[i] for i in ids if i in res])
except Exception:
logger.exception("Failed to callback")
def fire():
for _, d in event_list:
d.callback(row_dict)

with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
self.hs.get_reactor().callFromThread(fire)
except Exception as e:
logger.exception("do_fetch")

Expand All @@ -457,13 +451,98 @@ def fire(evs, exc):
self.hs.get_reactor().callFromThread(fire, event_list, e)

@defer.inlineCallbacks
def _enqueue_events(self, events, allow_rejected=False):
def _get_events_from_db(self, event_ids, allow_rejected=False):
    """Fetch a set of events, and any events redacting them, from the database.

    Rows are loaded in rounds via `_enqueue_events`. After each round, the ids
    listed in the "redactions" field of the rows which came back, and which
    have not been looked up yet, are queued for the next round. The loop stops
    once a round turns up no new ids, so ids are never looked up twice.

    Every event which is built here is also added to the event cache.

    Args:
        event_ids (Iterable[str]): The event_ids of the events to fetch
        allow_rejected (bool): Whether to include rejected events

    Returns:
        Deferred[Dict[str, _EventCacheEntry]]:
            map from event id to result. May return extra events which
            weren't asked for.
    """
    # event id -> database row (or None when no row came back for that id)
    rows_by_id = {}
    pending = event_ids

    while pending:
        fetched_rows = yield self._enqueue_events(pending)

        # note down what we got for each requested id, and collect the ids of
        # the events which claim to redact them so we can fetch those as well
        claimed_redactions = set()
        for requested_id in pending:
            db_row = fetched_rows.get(requested_id)
            rows_by_id[requested_id] = db_row
            if db_row:
                claimed_redactions.update(db_row["redactions"])

        # only go round again for ids we have not already looked up
        pending = claimed_redactions.difference(rows_by_id.keys())
        if pending:
            logger.debug("Also fetching redaction events %s", pending)

    # turn each usable row into an event object, keyed by event id
    events_by_id = {}
    for ev_id, db_row in rows_by_id.items():
        if not db_row:
            # nothing was found for this id
            continue
        assert db_row["event_id"] == ev_id

        rejected_reason = db_row["rejected_reason"]
        if rejected_reason and not allow_rejected:
            continue

        event_dict = json.loads(db_row["json"])
        metadata_dict = json.loads(db_row["internal_metadata"])

        version = db_row["format_version"]
        if version is None:
            # The event was stored before event format versions were
            # introduced, so it must be a V1 event.
            version = EventFormatVersions.V1

        event_cls = event_type_from_format_version(version)
        events_by_id[ev_id] = event_cls(
            event_dict=event_dict,
            internal_metadata_dict=metadata_dict,
            rejected_reason=rejected_reason,
        )

    # now that every event is available, work out which of them need
    # redacting, and build (and cache) the resulting entries.
    entries = {}
    for ev_id, ev in events_by_id.items():
        maybe_redacted = self._maybe_redact_event_row(
            ev, rows_by_id[ev_id]["redactions"], events_by_id
        )
        entry = _EventCacheEntry(event=ev, redacted_event=maybe_redacted)

        self._get_event_cache.prefill((ev_id,), entry)
        entries[ev_id] = entry

    return entries

@defer.inlineCallbacks
def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
Args:
events (Iterable[str]): events to be fetched.
Returns:
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
May contain events that weren't requested.
"""
if not events:
return {}

events_d = defer.Deferred()
with self._event_fetch_lock:
Expand All @@ -482,32 +561,12 @@ def _enqueue_events(self, events, allow_rejected=False):
"fetch_events", self.runWithConnection, self._do_fetch
)

logger.debug("Loading %d events", len(events))
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
rows = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))

if not allow_rejected:
rows[:] = [r for r in rows if r["rejected_reason"] is None]

res = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self._get_event_from_row,
row["internal_metadata"],
row["json"],
row["redactions"],
rejected_reason=row["rejected_reason"],
format_version=row["format_version"],
)
for row in rows
],
consumeErrors=True,
)
)
row_map = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))

return {e.event.event_id: e for e in res if e}
return row_map

def _fetch_event_rows(self, txn, event_ids):
"""Fetch event rows from the database
Expand Down Expand Up @@ -580,57 +639,16 @@ def _fetch_event_rows(self, txn, event_ids):

return event_dict

@defer.inlineCallbacks
def _get_event_from_row(
self, internal_metadata, js, redactions, format_version, rejected_reason=None
):
"""Parse an event row which has been read from the database

Builds the event object from the stored JSON, asks
`_maybe_redact_event_row` for its redacted form, and stores the resulting
cache entry in the event cache under the event's id.

Args:
internal_metadata (str): json-encoded internal_metadata column
js (str): json-encoded event body from event_json
redactions (list[str]): a list of the events which claim to have redacted
this event, from the redactions table
format_version: (str): the 'format_version' column
rejected_reason (str|None): the reason this event was rejected, if any
Returns:
Deferred[_EventCacheEntry]
"""
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
# note: rebinds the parameter from the encoded string to the decoded value
internal_metadata = json.loads(internal_metadata)

if format_version is None:
# This means that we stored the event before we had the concept
# of an event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1

original_ev = event_type_from_format_version(format_version)(
event_dict=d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)

# the result is yielded on, so this function waits for the redaction check
redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)

cache_entry = _EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)

# make the parsed entry available to later cache lookups
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)

return cache_entry

@defer.inlineCallbacks
def _maybe_redact_event_row(self, original_ev, redactions):
def _maybe_redact_event_row(self, original_ev, redactions, event_map):
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.
Args:
original_ev (EventBase):
redactions (iterable[str]): list of event ids of potential redaction events
event_map (dict[str, EventBase]): other events which have been fetched, in
which we can look up the redaction events. Map from event id to event.
Returns:
EventBase|None: if the event should be redacted, a pruned
Expand All @@ -640,15 +658,9 @@ def _maybe_redact_event_row(self, original_ev, redactions):
# we choose to ignore redactions of m.room.create events.
return None

if original_ev.type == "m.room.redaction":
# ... and redaction events
return None

redaction_map = yield self._get_events_from_cache_or_db(redactions)

for redaction_id in redactions:
redaction_entry = redaction_map.get(redaction_id)
if not redaction_entry:
redaction_event = event_map.get(redaction_id)
if not redaction_event or redaction_event.rejected_reason:
# we don't have the redaction event, or the redaction event was not
# authorized.
logger.debug(
Expand All @@ -658,7 +670,6 @@ def _maybe_redact_event_row(self, original_ev, redactions):
)
continue

redaction_event = redaction_entry.event
if redaction_event.room_id != original_ev.room_id:
logger.debug(
"%s was redacted by %s but redaction was in a different room!",
Expand Down
70 changes: 70 additions & 0 deletions tests/storage/test_redaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from mock import Mock

from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
Expand Down Expand Up @@ -216,3 +218,71 @@ def test_redact_join(self):
},
event.unsigned["redacted_because"],
)

def test_circular_redaction(self):
    """Two redaction events which redact each other can still be fetched."""
    redaction_event_id1 = "$redaction1_id:test"
    redaction_event_id2 = "$redaction2_id:test"

    class EventIdManglingBuilder:
        """Wraps an event builder and overwrites the built event's id.

        This lets the test pick event ids up front, so that each redaction
        can name the other one as its target.
        """

        def __init__(self, base_builder, event_id):
            self._base_builder = base_builder
            self._event_id = event_id

        @defer.inlineCallbacks
        def build(self, prev_event_ids):
            built_event = yield self._base_builder.build(prev_event_ids)
            # force the chosen id onto both the attribute and the event dict
            built_event.event_id = self._event_id
            built_event._event_dict["event_id"] = self._event_id
            return built_event

        @property
        def room_id(self):
            return self._base_builder.room_id

    def create_and_persist_redaction(own_id, target_id):
        # create a redaction event with id `own_id` which redacts
        # `target_id`, and write it to the store
        base_builder = self.event_builder_factory.for_room_version(
            RoomVersions.V1,
            {
                "type": EventTypes.Redaction,
                "sender": self.u_alice.to_string(),
                "room_id": self.room1.to_string(),
                "content": {"reason": "test"},
                "redacts": target_id,
            },
        )
        event, context = self.get_success(
            self.event_creation_handler.create_new_client_event(
                EventIdManglingBuilder(base_builder, own_id)
            )
        )
        self.get_success(self.store.persist_event(event, context))

    # redaction 1 redacts redaction 2, and vice versa
    create_and_persist_redaction(redaction_event_id1, redaction_event_id2)
    create_and_persist_redaction(redaction_event_id2, redaction_event_id1)

    # fetch one of the redactions
    fetched = self.get_success(self.store.get_event(redaction_event_id1))

    # it should have been redacted
    unsigned = fetched.unsigned
    self.assertEqual(unsigned["redacted_by"], redaction_event_id2)
    self.assertEqual(unsigned["redacted_because"].event_id, redaction_event_id2)

0 comments on commit 408959c

Please sign in to comment.