
Fix limit logic for EventsStream (#7358)
* Factor out functions for injecting events into database

I want to add some more flexibility to the tools for injecting events into the
database, and I don't want to clutter up HomeserverTestCase with them, so let's
factor them out to a new file.

* Rework TestReplicationDataHandler

This wasn't very easy to work with: the mock wrapping was largely superfluous,
and it's useful to be able to inspect the received rows and to clear out the
received list.

* Fix AssertionErrors being thrown by EventsStream

Part of the problem was that there was an off-by-one error in the assertion,
but also the limit logic was too simple. Fix it all up and add some tests.
richvdh committed Apr 29, 2020
1 parent eeef963 commit c2e1a21
Showing 14 changed files with 658 additions and 67 deletions.
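
The assertion problem in the third bullet is worth spelling out. The old stream-side code (see the synapse/replication/tcp/streams/events.py diff below) asserted len(state_rows) < target_row_count, but the underlying query uses LIMIT target_row_count and can legitimately return exactly that many rows, so the assert fires on precisely the case the following branch was meant to handle. A minimal sketch of the failure mode, with illustrative names rather than anything taken from this commit:

def old_limit_check(state_rows, target_row_count):
    # mimics the pre-commit logic: the assertion rejects a full batch, which is
    # exactly the case the next line was written to detect
    assert len(state_rows) < target_row_count
    return len(state_rows) == target_row_count  # can never be True once the assert passes

# a query ending in "LIMIT 3" may legitimately return exactly 3 rows
rows = [(101, "!room:a"), (102, "!room:a"), (103, "!room:b")]
try:
    old_limit_check(rows, 3)
except AssertionError:
    print("AssertionError: the crash described in the commit message")

Rather than simply relaxing the assertion, the commit moves the limiting into the storage method, which truncates on a stream-id boundary and reports a limited flag itself (see the events_worker.py diff below).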
1 change: 1 addition & 0 deletions changelog.d/7358.bugfix
@@ -0,0 +1 @@
+Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
4 changes: 3 additions & 1 deletion synapse/replication/tcp/handler.py
@@ -87,7 +87,9 @@ def __init__(self, hs):
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]
 
-        self._position_linearizer = Linearizer("replication_position")
+        self._position_linearizer = Linearizer(
+            "replication_position", clock=self._clock
+        )
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
22 changes: 8 additions & 14 deletions synapse/replication/tcp/streams/events.py
@@ -170,22 +170,16 @@ async def _update_function(
         limited = False
         upper_limit = current_token
 
-        # next up is the state delta table
-
-        state_rows = await self._store.get_all_updated_current_state_deltas(
+        # next up is the state delta table.
+        (
+            state_rows,
+            upper_limit,
+            state_rows_limited,
+        ) = await self._store.get_all_updated_current_state_deltas(
             from_token, upper_limit, target_row_count
-        )  # type: List[Tuple]
-
-        # again, if we've hit the limit there, we'll need to limit the other sources
-        assert len(state_rows) < target_row_count
-        if len(state_rows) == target_row_count:
-            assert state_rows[-1][0] <= upper_limit
-            upper_limit = state_rows[-1][0]
-            limited = True
+        )
 
-        # FIXME: is it a given that there is only one row per stream_id in the
-        # state_deltas table (so that we can be sure that we have got all of the
-        # rows for upper_limit)?
+        limited = limited or state_rows_limited
 
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
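
The store method now owns the limiting and hands back a (updates, new_upper_limit, limited) triple, so the stream only has to fold the flag into its own limited state. A rough sketch of how a caller can page through such a source until nothing is held back, using invented helper names (drain_current_state_deltas, process) rather than the real replication machinery:

async def drain_current_state_deltas(
    store, from_token, current_token, batch_size, process
):
    """Illustrative only: page through get_all_updated_current_state_deltas
    until the store reports that nothing up to current_token was held back."""
    upper_limit = current_token
    while True:
        rows, upper_limit, limited = await store.get_all_updated_current_state_deltas(
            from_token, upper_limit, batch_size
        )
        for row in rows:
            process(row)
        if not limited:
            return upper_limit
        # rows above upper_limit were held back; start the next batch there
        from_token = upper_limit
        upper_limit = current_token

The real EventsStream does a single pass per poll: each source can pull upper_limit down for the sources queried after it, the combined limited flag is reported upwards, and the replication layer polls again from the returned position.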
5 changes: 5 additions & 0 deletions synapse/server.pyi
@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
 import synapse.server_notices.server_notices_sender
 import synapse.state
 import synapse.storage
+from synapse.events.builder import EventBuilderFactory
 
 class HomeServer(object):
     @property
@@ -121,3 +122,7 @@ class HomeServer(object):
         pass
     def get_instance_id(self) -> str:
         pass
+    def get_event_builder_factory(self) -> EventBuilderFactory:
+        pass
+    def get_storage(self) -> synapse.storage.Storage:
+        pass
64 changes: 60 additions & 4 deletions synapse/storage/data_stores/main/events_worker.py
@@ -19,7 +19,7 @@
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 from canonicaljson import json
 from constantly import NamedConstant, Names
@@ -1084,18 +1084,74 @@ def get_all_new_backfill_event_rows(txn):
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )
 
-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+    async def get_all_updated_current_state_deltas(
+        self, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple], int, bool]:
+        """Fetch updates from current_state_delta_stream
+        Args:
+            from_token: The previous stream token. Updates from this stream id will
+                be excluded.
+            to_token: The current stream token (ie the upper limit). Updates up to this
+                stream id will be included (modulo the 'limit' param)
+            target_row_count: The number of rows to try to return. If more rows are
+                available, we will set 'limited' in the result. In the event of a large
+                batch, we may return more rows than this.
+        Returns:
+            A triplet `(updates, new_last_token, limited)`, where:
+                * `updates` is a list of database tuples.
+                * `new_last_token` is the new position in stream.
+                * `limited` is whether there are more updates to fetch.
+        """
+
         def get_all_updated_current_state_deltas_txn(txn):
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC LIMIT ?
             """
-            txn.execute(sql, (from_token, to_token, limit))
+            txn.execute(sql, (from_token, to_token, target_row_count))
             return txn.fetchall()
 
-        return self.db.runInteraction(
+        def get_deltas_for_stream_id_txn(txn, stream_id):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE stream_id = ?
+            """
+            txn.execute(sql, [stream_id])
+            return txn.fetchall()
+
+        # we need to make sure that, for every stream id in the results, we get *all*
+        # the rows with that stream id.
+
+        rows = await self.db.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
+        )  # type: List[Tuple]
+
+        # if we've got fewer rows than the limit, we're good
+        if len(rows) < target_row_count:
+            return rows, to_token, False
+
+        # we hit the limit, so reduce the upper limit so that we exclude the stream id
+        # of the last row in the result.
+        assert rows[-1][0] <= to_token
+        to_token = rows[-1][0] - 1
+
+        # search backwards through the list for the point to truncate
+        for idx in range(len(rows) - 1, 0, -1):
+            if rows[idx - 1][0] <= to_token:
+                return rows[:idx], to_token, True
+
+        # bother. We didn't get a full set of changes for even a single
+        # stream id. let's run the query again, without a row limit, but for
+        # just one stream id.
+        to_token += 1
+        rows = await self.db.runInteraction(
+            "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
         )
+        return rows, to_token, True
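
The heart of the new storage method is the truncation rule: if the LIMITed query came back full, the rows for the final stream id may be incomplete, so everything for that stream id is dropped and the returned token is pulled back to just before it. A caller holding the returned token has therefore seen every row for every stream id up to that token. A self-contained sketch of just that step, with made-up rows whose first element is the stream id (the real method also falls back to an unlimited re-query when a single stream id is bigger than the whole batch):

from typing import List, Tuple

def truncate_to_complete_stream_ids(
    rows: List[Tuple], to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
    # illustrative restatement of the logic in get_all_updated_current_state_deltas;
    # no database access and no fallback re-query for an oversized stream id
    if len(rows) < target_row_count:
        # the batch wasn't full, so nothing can have been cut off
        return rows, to_token, False

    # a full batch: the last stream id may be truncated, so exclude it
    to_token = rows[-1][0] - 1

    # walk backwards to find where the rows for that last stream id start
    for idx in range(len(rows) - 1, 0, -1):
        if rows[idx - 1][0] <= to_token:
            return rows[:idx], to_token, True

    # every row shared one stream id; the real method re-queries that single
    # stream id with no row limit instead of giving up
    raise NotImplementedError("single stream id exceeded the batch size")

# stream ids 7, 8, 9, 9 with a target of 4: both rows for stream id 9 are dropped,
# since a LIMITed query cannot promise that it returned all of them
rows = [(7, "!a"), (8, "!a"), (9, "!a"), (9, "!b")]
print(truncate_to_complete_stream_ids(rows, to_token=50, target_row_count=4))
# -> ([(7, '!a'), (8, '!a')], 8, True)

This is also what retires the FIXME carried by the old events.py code: callers no longer need to assume there is only one row per stream id in current_state_delta_stream.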
41 changes: 24 additions & 17 deletions tests/replication/tcp/streams/_base.py
@@ -12,10 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-from typing import Optional
 
-from mock import Mock
+import logging
+from typing import Any, Dict, List, Optional, Tuple
 
 import attr
 
@@ -25,6 +24,7 @@
 
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.http.site import SynapseRequest
+from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -65,9 +65,7 @@ def prepare(self, reactor, clock, hs):
         # databases objects are the same.
         self.worker_hs.get_datastore().db = hs.get_datastore().db
 
-        self.test_handler = Mock(
-            wraps=TestReplicationDataHandler(self.worker_hs.get_datastore())
-        )
+        self.test_handler = self._build_replication_data_handler()
         self.worker_hs.replication_data_handler = self.test_handler
 
         repl_handler = ReplicationCommandHandler(self.worker_hs)
@@ -78,6 +76,9 @@ def prepare(self, reactor, clock, hs):
         self._client_transport = None
         self._server_transport = None
 
+    def _build_replication_data_handler(self):
+        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+
     def reconnect(self):
         if self._client_transport:
             self.client.close()
@@ -174,22 +175,28 @@ def assert_request_is_get_repl_stream_updates(
 class TestReplicationDataHandler(ReplicationDataHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
-    def __init__(self, hs):
-        super().__init__(hs)
-        self.streams = set()
-        self._received_rdata_rows = []
+    def __init__(self, store: BaseSlavedStore):
+        super().__init__(store)
+
+        # streams to subscribe to: map from stream id to position
+        self.stream_positions = {}  # type: Dict[str, int]
+
+        # list of received (stream_name, token, row) tuples
+        self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
 
     def get_streams_to_replicate(self):
-        positions = {s: 0 for s in self.streams}
-        for stream, token, _ in self._received_rdata_rows:
-            if stream in self.streams:
-                positions[stream] = max(token, positions.get(stream, 0))
-        return positions
+        return self.stream_positions
 
     async def on_rdata(self, stream_name, token, rows):
         await super().on_rdata(stream_name, token, rows)
         for r in rows:
-            self._received_rdata_rows.append((stream_name, token, r))
+            self.received_rdata_rows.append((stream_name, token, r))
 
+        if (
+            stream_name in self.stream_positions
+            and token > self.stream_positions[stream_name]
+        ):
+            self.stream_positions[stream_name] = token
+
 
 @attr.s()
@@ -221,7 +228,7 @@ def __init__(self, reactor: IReactorTime):
         super().__init__()
         self.reactor = reactor
 
-        self._pull_to_push_producer = None
+        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
 
     def registerProducer(self, producer, streaming):
         # Convert pull producers to push producer.
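
With the Mock wrapper gone, tests talk to the handler directly: seed stream_positions with the streams to subscribe to, then read (and clear) received_rdata_rows once rows have been replicated. A hypothetical fragment in the style of BaseStreamTestCase; the replicate() helper and the event-injection step are assumptions, not part of this diff:

class EventsStreamSmokeTest(BaseStreamTestCase):
    # hypothetical illustration of the reworked TestReplicationDataHandler API
    def test_event_rows_are_collected(self):
        # subscribe to the events stream from position 0
        self.test_handler.stream_positions["events"] = 0
        self.reconnect()

        # ... inject some events on the master homeserver here ...

        self.replicate()  # assumed helper that pumps rows over the connection

        # every row arrives as a (stream_name, token, row) tuple
        for stream_name, token, row in self.test_handler.received_rdata_rows:
            self.assertEqual(stream_name, "events")

        # and the list can simply be cleared between phases of the test
        self.test_handler.received_rdata_rows.clear()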