def _filter_results(
    direction: str,
    from_token: Optional["RoomStreamToken"],
    to_token: Optional["RoomStreamToken"],
    instance_name: str,
    stream_ordering: int,
) -> bool:
    """Filter results from fetching events in the DB against the given tokens.

    This is necessary to handle the case where the tokens include position
    maps, which we handle by fetching more than necessary from the DB and then
    filtering (rather than attempting to construct a complicated SQL query).

    Args:
        direction: `"f"` treats `from_token` as the lower bound and `to_token`
            as the upper bound; `"b"` swaps the two roles. Any other value
            applies no bounds at all.
        from_token: Token the query started from, or None for no bound.
        to_token: Token the query runs to, or None for no bound.
        instance_name: Name of the instance recorded against the row; used to
            look up the per-instance position in each token.
        stream_ordering: The stream ordering of the row being checked.

    Returns:
        True if the row is strictly after the lower bound and at or before the
        upper bound (for whichever bounds apply), False otherwise.
    """

    # We will have already filtered by the topological tokens, so we don't
    # bother checking topological token bounds again. Compare against None
    # explicitly so that a topological ordering of 0 is still recognised as a
    # topological token.
    if from_token is not None and from_token.topological is not None:
        from_token = None

    if to_token is not None and to_token.topological is not None:
        to_token = None

    if direction == "f":
        lower_token, upper_token = from_token, to_token
    elif direction == "b":
        lower_token, upper_token = to_token, from_token
    else:
        lower_token = upper_token = None

    # The bounds are checked with `is not None` rather than truthiness so that
    # a position of 0 is enforced as a real bound instead of being ignored.
    if lower_token is not None:
        lower_bound = lower_token.get_stream_pos_for_instance(instance_name)
        if stream_ordering <= lower_bound:
            return False

    if upper_token is not None:
        upper_bound = upper_token.get_stream_pos_for_instance(instance_name)
        if upper_bound < stream_ordering:
            return False

    return True
def get_room_max_token(self) -> RoomStreamToken:
    """Build a non-topological `RoomStreamToken` from the stream ID generator.

    The token's stream position is `self._stream_id_gen.get_current_token()`.
    When the generator is a `MultiWriterIdGenerator`, the token additionally
    carries a map of every instance whose position is strictly ahead of that
    value; otherwise the map is empty.

    Returns:
        A `RoomStreamToken` with no topological part.
    """
    min_pos = self._stream_id_gen.get_current_token()

    positions = {}
    if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
        # The `min_pos` is the minimum position that we know all instances
        # have finished persisting to, so we only care about instances whose
        # positions are ahead of that. (Instance positions can be behind the
        # min position as there are times we can work out that the minimum
        # position is ahead of the naive minimum across all current
        # positions. See MultiWriterIdGenerator for details)
        positions = {
            instance: pos
            for instance, pos in self._stream_id_gen.get_positions().items()
            if pos > min_pos
        }

    # Every value kept above is strictly greater than `min_pos`, so the map
    # can never consist solely of `min_pos`; no further collapsing is needed.
    return RoomStreamToken(None, min_pos, positions)
+ """ % ( + order, + ) + txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, stream_ordering in txn + if _filter_results( + "f", from_key, to_key, instance_name, stream_ordering + ) + ][:limit] return rows rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) @@ -432,7 +509,7 @@ def f(txn): [r.event_id for r in rows], get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._set_before_and_after(ret, rows, topo_order=from_key.stream is None) if order.lower() == "desc": ret.reverse() @@ -449,31 +526,39 @@ def f(txn): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = from_key.stream - to_id = to_key.stream - if from_key == to_key: return [] - if from_id: + if from_key: has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) + user_id, int(from_key.stream) ) if not has_changed: return [] def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and the filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT m.event_id, instance_name, stream_ordering + FROM events AS e, room_memberships AS m + WHERE e.event_id = m.event_id + AND m.user_id = ? + AND e.stream_ordering > ? AND e.stream_ordering <= ? 
+ ORDER BY e.stream_ordering ASC + """ + txn.execute(sql, (user_id, min_from_id, max_to_id,)) - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, stream_ordering in txn + if _filter_results( + "f", from_key, to_key, instance_name, stream_ordering + ) + ] return rows @@ -980,11 +1065,44 @@ def _paginate_room_events_txn( else: order = "ASC" + # The bounds for the stream tokens are complicated by the fact the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as return by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + if from_token.topological is not None: + from_bound = from_token.as_tuple() + elif direction == "b": + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound = None + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_tuple() + elif direction == "b": + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.as_tuple(), - to_token=to_token.as_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) @@ -994,7 +1112,8 @@ def _paginate_room_events_txn( bounds += " AND " + filter_clause args.extend(filter_args) - args.append(int(limit)) + # We fetch more events as we'll filter the result set + args.append(int(limit) * 2) select_keywords = "SELECT" join_clause = "" @@ -1016,7 +1135,9 @@ def _paginate_room_events_txn( select_keywords += "DISTINCT" sql = """ - %(select_keywords)s event_id, topological_ordering, stream_ordering + %(select_keywords)s + event_id, 
def get_stream_pos_for_instance(self, instance_name: str) -> int:
    """Return the stream position this token holds for `instance_name`.

    If `instance_map` has no entry for the instance, `self.stream` is
    returned instead.
    """
    try:
        return self.instance_map[instance_name]
    except KeyError:
        return self.stream


def get_max_stream_pos(self) -> int:
    """Return the largest stream position referenced by this token.

    That is the largest value in `instance_map`, or `self.stream` when the
    map is empty. The corresponding "min" position is, by definition, just
    `self.stream`.

    This is used to handle tokens that have a non-empty `instance_map`, and
    so reference stream positions after the `self.stream` position.
    """
    if not self.instance_map:
        return self.stream
    return max(self.instance_map.values())