Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Wire up new token when fetching events streams
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Oct 1, 2020
1 parent b2172da commit d7da8ca
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 36 deletions.
200 changes: 164 additions & 36 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
Expand Down Expand Up @@ -209,6 +210,49 @@ def _make_generic_sql_bound(
)


def _filter_results(
direction: str,
from_token: Optional[RoomStreamToken],
to_token: Optional[RoomStreamToken],
instance_name: str,
stream_ordering: int,
) -> bool:
"""Filter results from fetching events in the DB against the given tokens.
This is necessary to handle the case where the tokens include positions
maps, which we handle by fetching more than necessary from the DB and then
filtering (rather than attempting to construct a complicated SQL query).
"""

# We will have already filtered by the topological tokens, so we don't
# bother checking topological token bounds again.
if from_token and from_token.topological:
from_token = None

if to_token and to_token.topological:
to_token = None

lower_bound = None
if direction == "f" and from_token:
lower_bound = from_token.get_stream_pos_for_instance(instance_name)
elif direction == "b" and to_token:
lower_bound = to_token.get_stream_pos_for_instance(instance_name)

if lower_bound and stream_ordering <= lower_bound:
return False

upper_bound = None
if direction == "b" and from_token:
upper_bound = from_token.get_stream_pos_for_instance(instance_name)
elif direction == "f" and to_token:
upper_bound = to_token.get_stream_pos_for_instance(instance_name)

if upper_bound and upper_bound < stream_ordering:
return False

return True


def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
Expand Down Expand Up @@ -306,7 +350,26 @@ def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()

def get_room_max_token(self) -> RoomStreamToken:
    """Get the current maximum token for the events stream.

    Returns a `RoomStreamToken` whose stream field is the position that
    all writer instances are known to have persisted up to, together with
    an instance map holding any writer positions that are ahead of that
    minimum.
    """
    # Defect fixed: the scraped block retained the pre-change line
    # `return RoomStreamToken(None, self.get_room_max_stream_ordering())`
    # before the new body, which made everything below it unreachable.
    min_pos = self._stream_id_gen.get_current_token()

    positions = {}
    if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
        # The `min_pos` is the minimum position that we know all instances
        # have finished persisting to, so we only care about instances whose
        # positions are ahead of that. (Instance positions can be behind the
        # min position as there are times we can work out that the minimum
        # position is ahead of the naive minimum across all current
        # positions. See MultiWriterIdGenerator for details)
        positions = {
            i: p
            for i, p in self._stream_id_gen.get_positions().items()
            if p > min_pos
        }

    # NOTE(review): `positions` only keeps entries strictly greater than
    # `min_pos`, so this normalisation looks unreachable — confirm against
    # MultiWriterIdGenerator semantics before removing it.
    if set(positions.values()) == {min_pos}:
        positions = {}

    return RoomStreamToken(None, min_pos, positions)

async def get_room_events_stream_for_rooms(
self,
Expand Down Expand Up @@ -405,25 +468,39 @@ async def get_room_events_stream_for_room(
if from_key == to_key:
return [], from_key

from_id = from_key.stream
to_id = to_key.stream

has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
has_changed = self._events_stream_cache.has_entity_changed(
room_id, from_key.stream
)

if not has_changed:
return [], from_key

def f(txn):
sql = (
"SELECT event_id, stream_ordering FROM events WHERE"
" room_id = ?"
" AND not outlier"
" AND stream_ordering > ? AND stream_ordering <= ?"
" ORDER BY stream_ordering %s LIMIT ?"
) % (order,)
txn.execute(sql, (room_id, from_id, to_id, limit))

rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
max_to_id = to_key.get_max_stream_pos()

sql = """
SELECT event_id, instance_name, stream_ordering
FROM events
WHERE
room_id = ?
AND not outlier
AND stream_ordering > ? AND stream_ordering <= ?
ORDER BY stream_ordering %s LIMIT ?
""" % (
order,
)
txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))

rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, stream_ordering in txn
if _filter_results(
"f", from_key, to_key, instance_name, stream_ordering
)
][:limit]
return rows

rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
Expand All @@ -432,7 +509,7 @@ def f(txn):
[r.event_id for r in rows], get_prev_content=True
)

self._set_before_and_after(ret, rows, topo_order=from_id is None)
self._set_before_and_after(ret, rows, topo_order=from_key.stream is None)

if order.lower() == "desc":
ret.reverse()
Expand All @@ -449,31 +526,39 @@ def f(txn):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
from_id = from_key.stream
to_id = to_key.stream

if from_key == to_key:
return []

if from_id:
if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id)
user_id, int(from_key.stream)
)
if not has_changed:
return []

def f(txn):
sql = (
"SELECT m.event_id, stream_ordering FROM events AS e,"
" room_memberships AS m"
" WHERE e.event_id = m.event_id"
" AND m.user_id = ?"
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
)
txn.execute(sql, (user_id, from_id, to_id))
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
max_to_id = to_key.get_max_stream_pos()

sql = """
SELECT m.event_id, instance_name, stream_ordering
FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id
AND m.user_id = ?
AND e.stream_ordering > ? AND e.stream_ordering <= ?
ORDER BY e.stream_ordering ASC
"""
txn.execute(sql, (user_id, min_from_id, max_to_id,))

rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, stream_ordering in txn
if _filter_results(
"f", from_key, to_key, instance_name, stream_ordering
)
]

return rows

Expand Down Expand Up @@ -980,11 +1065,44 @@ def _paginate_room_events_txn(
else:
order = "ASC"

# The bounds for the stream tokens are complicated by the fact
# that we need to handle the instance_map part of the tokens. We do this
# by fetching all events between the min stream token and the maximum
# stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
# then filtering the results.
if from_token.topological is not None:
from_bound = from_token.as_tuple()
elif direction == "b":
from_bound = (
None,
from_token.get_max_stream_pos(),
)
else:
from_bound = (
None,
from_token.stream,
)

to_bound = None
if to_token:
if to_token.topological is not None:
to_bound = to_token.as_tuple()
elif direction == "b":
to_bound = (
None,
to_token.stream,
)
else:
to_bound = (
None,
to_token.get_max_stream_pos(),
)

bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=from_token.as_tuple(),
to_token=to_token.as_tuple() if to_token else None,
from_token=from_bound,
to_token=to_bound,
engine=self.database_engine,
)

Expand All @@ -994,7 +1112,8 @@ def _paginate_room_events_txn(
bounds += " AND " + filter_clause
args.extend(filter_args)

args.append(int(limit))
# We fetch more events as we'll filter the result set
args.append(int(limit) * 2)

select_keywords = "SELECT"
join_clause = ""
Expand All @@ -1016,7 +1135,9 @@ def _paginate_room_events_txn(
select_keywords += "DISTINCT"

sql = """
%(select_keywords)s event_id, topological_ordering, stream_ordering
%(select_keywords)s
event_id, instance_name,
topological_ordering, stream_ordering
FROM events
%(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s
Expand All @@ -1031,7 +1152,14 @@ def _paginate_room_events_txn(

txn.execute(sql, args)

rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
# Filter the result set.
rows = [
_EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
direction, from_token, to_token, instance_name, stream_ordering
)
][:limit]

if rows:
topo = rows[-1].topological_ordering
Expand Down
15 changes: 15 additions & 0 deletions synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,21 @@ def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
def as_tuple(self) -> Tuple[Optional[int], int]:
    """Return this token as a `(topological, stream)` pair."""
    return (self.topological, self.stream)

def get_stream_pos_for_instance(self, instance_name: str) -> int:
    """Return the stream position this token holds for the given writer
    instance, falling back to the default stream position for instances
    that have no entry in the instance map.
    """
    try:
        return self.instance_map[instance_name]
    except KeyError:
        return self.stream

def get_max_stream_pos(self) -> int:
    """Return the highest stream position referenced in this token.

    This is the maximum across the instance map, or `self.stream` when the
    map is empty; the corresponding "min" position is, by definition, just
    `self.stream`. Used to handle tokens with a non-empty `instance_map`,
    which reference stream positions after the `self.stream` position.
    """
    instance_positions = self.instance_map.values()
    if not instance_positions:
        return self.stream
    return max(instance_positions)

async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
Expand Down

0 comments on commit d7da8ca

Please sign in to comment.