Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Separate creating an event context and persisting the event in the fed handler #9800

Merged
merged 13 commits into from
Apr 14, 2021
1 change: 1 addition & 0 deletions changelog.d/9800.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.
178 changes: 113 additions & 65 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@

@attr.s(slots=True)
class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events
"""Holds information about a received event, ready for passing to _auth_and_persist_events

Attributes:
event: the received event
Expand Down Expand Up @@ -808,7 +808,10 @@ async def _process_received_pdu(
logger.debug("Processing event: %s", event)

try:
await self._handle_new_event(origin, event, state=state)
context = await self.state_handler.compute_event_context(
event, old_state=state
)
await self._auth_and_persist_event(origin, event, context, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)

Expand Down Expand Up @@ -1011,7 +1014,9 @@ async def backfill(
)

if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
await self._auth_and_persist_events(
dest, room_id, ev_infos, backfilled=True
)

# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
Expand All @@ -1024,10 +1029,12 @@ async def backfill(
# non-outliers
assert not event.internal_metadata.is_outlier()

context = await self.state_handler.compute_event_context(event)

# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
await self._handle_new_event(dest, event, backfilled=True)
await self._auth_and_persist_event(dest, event, context, backfilled=True)

return events

Expand Down Expand Up @@ -1361,7 +1368,7 @@ async def get_event(event_id: str):

event_infos.append(_NewEventInfo(event, None, auth))

await self._handle_new_events(
await self._auth_and_persist_events(
destination,
room_id,
event_infos,
Expand Down Expand Up @@ -1667,10 +1674,11 @@ async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin

context = await self._handle_new_event(origin, event)
context = await self.state_handler.compute_event_context(event)
context = await self._auth_and_persist_event(origin, event, context)

logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
"on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
Expand Down Expand Up @@ -1879,10 +1887,11 @@ async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:

event.internal_metadata.outlier = False

await self._handle_new_event(origin, event)
context = await self.state_handler.compute_event_context(event)
await self._auth_and_persist_event(origin, event, context)

logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
"on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
Expand Down Expand Up @@ -1990,16 +1999,47 @@ async def get_persisted_pdu(
async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context)

async def _handle_new_event(
async def _auth_and_persist_event(
self,
origin: str,
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
"""
Process an event by performing auth checks and then persisting to the database.

Args:
origin: The host the event originates from.
event: The event itself.
context:
The event context.

NB that this function potentially modifies it.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event

Normally, our calculated auth_events based on the state of the room
at the event's position in the DAG, though occasionally (eg if the
event is an outlier), may be the auth events claimed by the remote
server.
clokep marked this conversation as resolved.
Show resolved Hide resolved
backfilled: True if the event was backfilled.

Returns:
The event context.
"""
context = await self._check_event_auth(
origin,
event,
context,
state=state,
auth_events=auth_events,
backfilled=backfilled,
)

try:
Expand All @@ -2023,7 +2063,7 @@ async def _handle_new_event(

return context

async def _handle_new_events(
async def _auth_and_persist_events(
self,
origin: str,
room_id: str,
Expand All @@ -2041,9 +2081,13 @@ async def _handle_new_events(
async def prep(ev_info: _NewEventInfo):
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = await self._prep_event(
res = await self.state_handler.compute_event_context(
event, old_state=ev_info.state
)
res = await self._check_event_auth(
origin,
event,
res,
state=ev_info.state,
auth_events=ev_info.auth_events,
backfilled=backfilled,
Expand Down Expand Up @@ -2178,49 +2222,6 @@ async def _persist_auth_tree(
room_id, [(event, new_event_context)]
)

async def _prep_event(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)

if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}

# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c

context = await self.do_auth(origin, event, context, auth_events=auth_events)

if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)

if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)

# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)

return context

async def _check_for_soft_fail(
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
) -> None:
Expand Down Expand Up @@ -2331,19 +2332,28 @@ async def on_get_missing_events(

return missing_events

async def do_auth(
async def _check_event_auth(
self,
origin: str,
event: EventBase,
context: EventContext,
auth_events: MutableStateMap[EventBase],
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
"""
Checks whether an event should be rejected (for failing auth checks).

Args:
origin:
event:
origin: The host the event originates from.
event: The event itself.
context:
The event context.

NB that this function potentially modifies it.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event

Expand All @@ -2353,12 +2363,34 @@ async def do_auth(
server.

Also NB that this function adds entries to it.

If this is not provided, it is calculated from the previous state IDs.
backfilled: True if the event was backfilled.

Returns:
updated context object
The updated context object.
"""
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]

if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}

# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c

try:
context = await self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
Expand All @@ -2380,6 +2412,17 @@ async def do_auth(
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR

if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)

if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)

# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)

richvdh marked this conversation as resolved.
Show resolved Hide resolved
return context

async def _update_auth_events_and_context_for_auth(
Expand All @@ -2389,7 +2432,7 @@ async def _update_auth_events_and_context_for_auth(
context: EventContext,
auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""Helper for do_auth. See there for docs.
"""Helper for _check_event_auth. See there for docs.

Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if
Expand Down Expand Up @@ -2469,9 +2512,14 @@ async def _update_auth_events_and_context_for_auth(
e.internal_metadata.outlier = True

logger.debug(
"do_auth %s missing_auth: %s", event.event_id, e.event_id
"_check_event_auth %s missing_auth: %s",
event.event_id,
e.event_id,
)
context = await self.state_handler.compute_event_context(e)
await self._auth_and_persist_event(
origin, e, context, auth_events=auth
)
await self._handle_new_event(origin, e, auth_events=auth)

if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
Expand Down
6 changes: 4 additions & 2 deletions tests/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ def setUp(self):
)

self.handler = self.homeserver.get_federation_handler()
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
context
self.handler._check_event_auth = (
lambda origin, event, context, state, auth_events, backfilled: succeed(
context
)
)
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
Expand Down