Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support filtering by relations per MSC3440 #11236

Merged
merged 12 commits into from
Nov 9, 2021
1 change: 1 addition & 0 deletions changelog.d/11236.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support filtering by relation senders & types per [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
115 changes: 84 additions & 31 deletions synapse/api/filtering.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -86,6 +86,9 @@
# cf https://github.com/matrix-org/matrix-doc/pull/2326
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
# MSC3440, filtering by event relations.
"io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
"io.element.relation_types": {"type": "array", "items": {"type": "string"}},
},
}

Expand Down Expand Up @@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:

class Filtering:
def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self.store = hs.get_datastore()

self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})

async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)
return FilterCollection(self._hs, result)

def add_user_filter(
self, user_localpart: str, user_filter: JsonDict
Expand Down Expand Up @@ -191,21 +196,22 @@ def check_valid_filter(self, user_filter_json: JsonDict) -> None:


class FilterCollection:
def __init__(self, filter_json: JsonDict):
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
self._filter_json = filter_json

room_filter_json = self._filter_json.get("room", {})

self._room_filter = Filter(
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
hs,
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")},
)

self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
self._room_account_data = Filter(room_filter_json.get("account_data", {}))
self._presence_filter = Filter(filter_json.get("presence", {}))
self._account_data = Filter(filter_json.get("account_data", {}))
self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
self._presence_filter = Filter(hs, filter_json.get("presence", {}))
self._account_data = Filter(hs, filter_json.get("account_data", {}))

self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
Expand All @@ -232,25 +238,37 @@ def lazy_load_members(self) -> bool:
def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members

def filter_presence(
async def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
return self._presence_filter.filter(events)
return await self._presence_filter.filter(events)

def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._account_data.filter(events)
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return await self._account_data.filter(events)

def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_state_filter.filter(self._room_filter.filter(events))
async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return await self._room_state_filter.filter(
await self._room_filter.filter(events)
)

def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_timeline_filter.filter(self._room_filter.filter(events))
async def filter_room_timeline(
self, events: Iterable[EventBase]
) -> List[EventBase]:
return await self._room_timeline_filter.filter(
await self._room_filter.filter(events)
)

def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return await self._room_ephemeral_filter.filter(
await self._room_filter.filter(events)
)

def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_account_data.filter(self._room_filter.filter(events))
async def filter_room_account_data(
self, events: Iterable[JsonDict]
) -> List[JsonDict]:
return await self._room_account_data.filter(
await self._room_filter.filter(events)
)

def blocks_all_presence(self) -> bool:
return (
Expand All @@ -274,7 +292,9 @@ def blocks_all_room_timeline(self) -> bool:


class Filter:
def __init__(self, filter_json: JsonDict):
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
self._hs = hs
self._store = hs.get_datastore()
self.filter_json = filter_json

self.limit = filter_json.get("limit", 10)
Expand All @@ -297,6 +317,20 @@ def __init__(self, filter_json: JsonDict):
self.labels = filter_json.get("org.matrix.labels", None)
self.not_labels = filter_json.get("org.matrix.not_labels", [])

# Ideally these would be rejected at the endpoint if they were provided
# and not supported, but that would involve modifying the JSON schema
# based on the homeserver configuration.
if hs.config.experimental.msc3440_enabled:
self.relation_senders = self.filter_json.get(
"io.element.relation_senders", None
)
self.relation_types = self.filter_json.get(
"io.element.relation_types", None
)
else:
self.relation_senders = None
self.relation_types = None

def filters_all_types(self) -> bool:
return "*" in self.not_types

Expand All @@ -306,7 +340,7 @@ def filters_all_senders(self) -> bool:
def filters_all_rooms(self) -> bool:
return "*" in self.not_rooms

def check(self, event: FilterEvent) -> bool:
def _check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event.
Args:
Expand Down Expand Up @@ -420,8 +454,30 @@ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:

return room_ids

def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return list(filter(self.check, events))
async def _check_event_relations(
self, events: Iterable[FilterEvent]
) -> List[FilterEvent]:
Comment on lines +457 to +459
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separating this out to do a single query for all events might be a bit premature optimization, but doing a separate query per event seemed quite expensive.

# The event IDs to check, mypy doesn't understand the ifinstance check.
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
event_ids, self.relation_senders, self.relation_types
)
)

return [
event
for event in events
if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
]

async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
result = [event for event in events if self._check(event)]

if self.relation_senders or self.relation_types:
return await self._check_event_relations(result)

return result

def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended.
Expand All @@ -433,7 +489,7 @@ def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
filter: A new filter including the given rooms and the old
filter's rooms.
"""
newFilter = Filter(self.filter_json)
newFilter = Filter(self._hs, self.filter_json)
newFilter.rooms += room_ids
return newFilter

Expand All @@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
return actual_value.startswith(type_prefix)
else:
return actual_value == filter_value


DEFAULT_FILTER_COLLECTION = FilterCollection({})
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needed to be moved somewhere that it can have access to a HomeServer.

2 changes: 1 addition & 1 deletion synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ async def get_messages(

if events:
if event_filter:
events = event_filter.filter(events)
events = await event_filter.filter(events)

events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None)
Expand Down
8 changes: 5 additions & 3 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,8 +1158,10 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
)

if event_filter:
results["events_before"] = event_filter.filter(results["events_before"])
results["events_after"] = event_filter.filter(results["events_after"])
results["events_before"] = await event_filter.filter(
results["events_before"]
)
results["events_after"] = await event_filter.filter(results["events_after"])

results["events_before"] = await filter_evts(results["events_before"])
results["events_after"] = await filter_evts(results["events_after"])
Expand Down Expand Up @@ -1195,7 +1197,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:

state_events = list(state[last_event_id].values())
if event_filter:
state_events = event_filter.filter(state_events)
state_events = await event_filter.filter(state_events)

results["state"] = await filter_evts(state_events)

Expand Down
8 changes: 5 additions & 3 deletions synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def search(
% (set(group_keys) - {"room_id", "sender"},),
)

search_filter = Filter(filter_dict)
search_filter = Filter(self.hs, filter_dict)

# TODO: Search through left rooms too
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
Expand Down Expand Up @@ -242,7 +242,7 @@ async def search(

rank_map.update({r["event"].event_id: r["rank"] for r in results})

filtered_events = search_filter.filter([r["event"] for r in results])
filtered_events = await search_filter.filter([r["event"] for r in results])

events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
Expand Down Expand Up @@ -292,7 +292,9 @@ async def search(

rank_map.update({r["event"].event_id: r["rank"] for r in results})

filtered_events = search_filter.filter([r["event"] for r in results])
filtered_events = await search_filter.filter(
[r["event"] for r in results]
)

events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
Expand Down
18 changes: 10 additions & 8 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ async def _load_filtered_recents(
log_kv({"limited": limited})

if potential_recents:
recents = sync_config.filter_collection.filter_room_timeline(
recents = await sync_config.filter_collection.filter_room_timeline(
potential_recents
)
log_kv({"recents_after_sync_filtering": len(recents)})
Expand Down Expand Up @@ -575,8 +575,8 @@ async def _load_filtered_recents(

log_kv({"loaded_recents": len(events)})

loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
loaded_recents = (
await sync_config.filter_collection.filter_room_timeline(events)
)

log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)})
Expand Down Expand Up @@ -1015,7 +1015,7 @@ async def compute_state_delta(

return {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
for e in await sync_config.filter_collection.filter_room_state(
list(state.values())
)
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
Expand Down Expand Up @@ -1383,7 +1383,7 @@ async def _generate_sync_entry_for_account_data(
sync_config.user
)

account_data_for_user = sync_config.filter_collection.filter_account_data(
account_data_for_user = await sync_config.filter_collection.filter_account_data(
[
{"type": account_data_type, "content": content}
for account_data_type, content in account_data.items()
Expand Down Expand Up @@ -1448,7 +1448,7 @@ async def _generate_sync_entry_for_presence(
# Deduplicate the presence entries so that there's at most one per user
presence = list({p.user_id: p for p in presence}.values())

presence = sync_config.filter_collection.filter_presence(presence)
presence = await sync_config.filter_collection.filter_presence(presence)

sync_result_builder.presence = presence

Expand Down Expand Up @@ -2021,12 +2021,14 @@ async def _generate_room_entry(
)

account_data_events = (
sync_config.filter_collection.filter_room_account_data(
await sync_config.filter_collection.filter_room_account_data(
account_data_events
)
)

ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
ephemeral = await sync_config.filter_collection.filter_room_ephemeral(
ephemeral
)

if not (
always_include
Expand Down
5 changes: 4 additions & 1 deletion synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,7 @@ class RoomEventContextServlet(RestServlet):

def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
Expand All @@ -600,7 +601,9 @@ async def on_GET(
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None

Expand Down
10 changes: 8 additions & 2 deletions synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ class RoomMessageListRestServlet(RestServlet):

def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
Expand All @@ -567,7 +568,9 @@ async def on_GET(
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
Expand Down Expand Up @@ -672,6 +675,7 @@ class RoomEventContextServlet(RestServlet):

def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
Expand All @@ -688,7 +692,9 @@ async def on_GET(
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None

Expand Down
Loading