Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support pagination tokens from sync/messages in the relations API #11952

Merged
merged 9 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11952.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where pagination tokens from `/sync` and `/messages` could not be provided to the `/relations` API.
57 changes: 39 additions & 18 deletions synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,45 @@
PaginationChunk,
RelationPaginationToken,
)
from synapse.types import JsonDict
from synapse.types import JsonDict, RoomStreamToken, StreamToken

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore

logger = logging.getLogger(__name__)


async def _parse_token(
store: "DataStore", token: Optional[str]
) -> Optional[StreamToken]:
"""
For backwards compatibility support RelationPaginationToken, but new pagination
tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
"""
if not token:
return None
# Luckily the format for StreamToken and RelationPaginationToken differ enough
# that they can easily be separated. An "_" appears in the serialization of
# RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
# "-" only for separators.
if "_" in token:
return await StreamToken.from_string(store, token)
else:
relation_token = RelationPaginationToken.from_string(token)
return StreamToken(
room_key=RoomStreamToken(relation_token.topological, relation_token.stream),
presence_key=0,
typing_key=0,
receipt_key=0,
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
groups_key=0,
)


class RelationPaginationServlet(RestServlet):
"""API to paginate relations on an event by topological ordering, optionally
filtered by relation type and event type.
Expand Down Expand Up @@ -88,13 +119,8 @@ async def on_GET(
pagination_chunk = PaginationChunk(chunk=[])
else:
# Return the relations
from_token = None
if from_token_str:
from_token = RelationPaginationToken.from_string(from_token_str)

to_token = None
if to_token_str:
to_token = RelationPaginationToken.from_string(to_token_str)
from_token = await _parse_token(self.store, from_token_str)
to_token = await _parse_token(self.store, to_token_str)

pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
Expand Down Expand Up @@ -125,7 +151,7 @@ async def on_GET(
events, now, bundle_aggregations=aggregations
)

return_value = pagination_chunk.to_dict()
return_value = await pagination_chunk.to_dict(self.store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event

Expand Down Expand Up @@ -216,7 +242,7 @@ async def on_GET(
to_token=to_token,
)

return 200, pagination_chunk.to_dict()
return 200, await pagination_chunk.to_dict(self.store)


class RelationAggregationGroupPaginationServlet(RestServlet):
Expand Down Expand Up @@ -287,13 +313,8 @@ async def on_GET(
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")

from_token = None
if from_token_str:
from_token = RelationPaginationToken.from_string(from_token_str)

to_token = None
if to_token_str:
to_token = RelationPaginationToken.from_string(to_token_str)
from_token = await _parse_token(self.store, from_token_str)
to_token = await _parse_token(self.store, to_token_str)

result = await self.store.get_relations_for_event(
event_id=parent_id,
Expand All @@ -313,7 +334,7 @@ async def on_GET(
now = self.clock.time_msec()
serialized_events = self._event_serializer.serialize_events(events, now)

return_value = result.to_dict()
return_value = await result.to_dict(self.store)
return_value["chunk"] = serialized_events

return 200, return_value
Expand Down
46 changes: 31 additions & 15 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,13 @@
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
from synapse.types import JsonDict
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,8 +95,8 @@ async def get_relations_for_event(
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[RelationPaginationToken] = None,
to_token: Optional[RelationPaginationToken] = None,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.

Expand Down Expand Up @@ -138,8 +135,10 @@ async def get_relations_for_event(
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
from_token=from_token.room_key.as_historical_tuple()
if from_token
else None,
to_token=to_token.room_key.as_historical_tuple() if to_token else None,
engine=self.database_engine,
)

Expand Down Expand Up @@ -177,12 +176,27 @@ def _get_recent_references_for_event_txn(
last_topo_id = row[1]
last_stream_id = row[2]

next_batch = None
# If there are more events, generate the next pagination key.
next_token = None
if len(events) > limit and last_topo_id and last_stream_id:
next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
next_key = RoomStreamToken(last_topo_id, last_stream_id)
if from_token:
next_token = from_token.copy_and_replace("room_key", next_key)
else:
next_token = StreamToken(
room_key=next_key,
presence_key=0,
typing_key=0,
receipt_key=0,
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
groups_key=0,
)

return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
)

return await self.db_pool.runInteraction(
Expand Down Expand Up @@ -676,13 +690,15 @@ async def _get_bundled_aggregation_for_event(

annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
aggregations.annotations = annotations.to_dict()
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)

references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = references.to_dict()
aggregations.references = await references.to_dict(cast("DataStore", self))

# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
Expand Down
15 changes: 9 additions & 6 deletions synapse/storage/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import attr

from synapse.api.errors import SynapseError
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore

logger = logging.getLogger(__name__)


Expand All @@ -39,14 +42,14 @@ class PaginationChunk:
next_batch: Optional[Any] = None
prev_batch: Optional[Any] = None

def to_dict(self) -> Dict[str, Any]:
async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
d = {"chunk": self.chunk}

if self.next_batch:
d["next_batch"] = self.next_batch.to_string()
d["next_batch"] = await self.next_batch.to_string(store)

if self.prev_batch:
d["prev_batch"] = self.prev_batch.to_string()
d["prev_batch"] = await self.prev_batch.to_string(store)

return d

Expand Down Expand Up @@ -75,7 +78,7 @@ def from_string(string: str) -> "RelationPaginationToken":
except ValueError:
raise SynapseError(400, "Invalid relation pagination token")

def to_string(self) -> str:
async def to_string(self, store: "DataStore") -> str:
reivilibre marked this conversation as resolved.
Show resolved Hide resolved
return "%d-%d" % (self.topological, self.stream)

def as_tuple(self) -> Tuple[Any, ...]:
Expand Down Expand Up @@ -105,7 +108,7 @@ def from_string(string: str) -> "AggregationPaginationToken":
except ValueError:
raise SynapseError(400, "Invalid aggregation pagination token")

def to_string(self) -> str:
async def to_string(self, store: "DataStore") -> str:
return "%d-%d" % (self.count, self.stream)

def as_tuple(self) -> Tuple[Any, ...]:
Expand Down
Loading