diff --git a/locust/exception.py b/locust/exception.py index b42cbffd86..66ae64ad10 100644 --- a/locust/exception.py +++ b/locust/exception.py @@ -54,6 +54,22 @@ class RPCError(Exception): """ +class RPCSendError(Exception): + """ + Exception when sending message to client. + + When raised from zmqrpc, sending can be retried or RPC can be reestablished. + """ + + +class RPCReceiveError(Exception): + """ + Exception when receiving message from client is interrupted or message is corrupted. + + When raised from zmqrpc, client connection should be reestablished. + """ + + class AuthCredentialsError(ValueError): """ Exception when the auth credentials provided diff --git a/locust/rpc/zmqrpc.py b/locust/rpc/zmqrpc.py index 8a4634b3f3..2ec3b7c000 100644 --- a/locust/rpc/zmqrpc.py +++ b/locust/rpc/zmqrpc.py @@ -1,7 +1,7 @@ import zmq.green as zmq from .protocol import Message from locust.util.exception_handler import retry -from locust.exception import RPCError +from locust.exception import RPCError, RPCSendError, RPCReceiveError import zmq.error as zmqerr import msgpack.exceptions as msgerr @@ -19,21 +19,21 @@ def send(self, msg): try: self.socket.send(msg.serialize(), zmq.NOBLOCK) except zmqerr.ZMQError as e: - raise RPCError("ZMQ sent failure") from e + raise RPCSendError("ZMQ sent failure") from e @retry() def send_to_client(self, msg): try: self.socket.send_multipart([msg.node_id.encode(), msg.serialize()]) except zmqerr.ZMQError as e: - raise RPCError("ZMQ sent failure") from e + raise RPCSendError("ZMQ sent failure") from e def recv(self): try: data = self.socket.recv() msg = Message.unserialize(data) except msgerr.ExtraData as e: - raise RPCError("ZMQ interrupted message") from e + raise RPCReceiveError("ZMQ interrupted message") from e except zmqerr.ZMQError as e: raise RPCError("ZMQ network broken") from e return msg @@ -42,15 +42,18 @@ def recv_from_client(self): try: data = self.socket.recv_multipart() addr = data[0].decode() - msg = Message.unserialize(data[1]) - except (UnicodeDecodeError, msgerr.ExtraData) as e: - raise RPCError("ZMQ interrupted message") from e + except UnicodeDecodeError as e: + raise RPCReceiveError("ZMQ interrupted or corrupted message") from e except zmqerr.ZMQError as e: raise RPCError("ZMQ network broken") from e + try: + msg = Message.unserialize(data[1]) + except (UnicodeDecodeError, msgerr.ExtraData) as e: + raise RPCReceiveError("ZMQ interrupted or corrupted message") from e return addr, msg - def close(self): - self.socket.close() + def close(self, linger=None): + self.socket.close(linger=linger) class Server(BaseSocket): diff --git a/locust/runners.py b/locust/runners.py index 37156fb86e..4570af5eef 100644 --- a/locust/runners.py +++ b/locust/runners.py @@ -46,7 +46,7 @@ from . import User from locust import __version__ from .dispatch import UsersDispatcher -from .exception import RPCError +from .exception import RPCError, RPCReceiveError, RPCSendError from .log import greenlet_exception_logger from .rpc import ( Message, @@ -946,9 +946,9 @@ def heartbeat_worker(self) -> NoReturn: self.start(user_count=self.target_user_count, spawn_rate=self.spawn_rate) def reset_connection(self) -> None: - logger.info("Reset connection to worker") + logger.info("Resetting RPC server and all client connections.") try: - self.server.close() + self.server.close(linger=0) self.server = rpc.Server(self.master_bind_host, self.master_bind_port) self.connection_broken = False except RPCError as e: @@ -958,12 +958,26 @@ def client_listener(self) -> NoReturn: while True: try: client_id, msg = self.server.recv_from_client() + except RPCReceiveError as e: + logger.error(f"RPCError when receiving from client: {e}. Will reset client {client_id}.") + try: + self.server.send_to_client(Message("reconnect", None, client_id)) + except Exception as e: + logger.error(f"Error sending reconnect message to client: {e}. Will reset RPC server.") + self.connection_broken = True + gevent.sleep(FALLBACK_INTERVAL) + continue + except RPCSendError as e: + logger.error(f"Error sending reconnect message to client: {e}. Will reset RPC server.") + self.connection_broken = True + gevent.sleep(FALLBACK_INTERVAL) + continue except RPCError as e: - if self.clients.ready: - logger.error(f"RPCError found when receiving from client: {e}") + if self.clients.ready or self.clients.spawning or self.clients.running: + logger.error(f"RPCError: {e}. Will reset RPC server.") else: logger.debug( - "RPCError found when receiving from client: %s (but no clients were expected to be connected anyway)" + "RPCError when receiving from client: %s (but no clients were expected to be connected anyway)" % (e) ) self.connection_broken = True @@ -1285,6 +1299,9 @@ def worker(self) -> NoReturn: self.stop() self._send_stats() # send a final report, in case there were any samples not yet reported self.greenlet.kill(block=True) + elif msg.type == "reconnect": + logger.warning("Received reconnect message from master. Resetting RPC connection.") + self.reset_connection() elif msg.type in self.custom_messages: logger.debug(f"Received {msg.type} message from master") self.custom_messages[msg.type](environment=self.environment, msg=msg) diff --git a/locust/test/test_runners.py b/locust/test/test_runners.py index bd65031a24..14b0ee66f8 100644 --- a/locust/test/test_runners.py +++ b/locust/test/test_runners.py @@ -21,10 +21,7 @@ ) from locust.argument_parser import parse_options from locust.env import Environment -from locust.exception import ( - RPCError, - StopUser, -) +from locust.exception import RPCError, StopUser, RPCReceiveError from locust.main import create_environment from locust.rpc import Message from locust.runners import ( @@ -49,6 +46,7 @@ from .util import patch_env NETWORK_BROKEN = "network broken" +BAD_MESSAGE = "bad message" def mocked_rpc(raise_on_close=True): @@ -83,9 +81,11 @@ def recv_from_client(self): msg = Message.unserialize(results) if msg.data == NETWORK_BROKEN: raise RPCError() + if msg.data == BAD_MESSAGE: + raise RPCReceiveError("Bad message") return msg.node_id, msg - def close(self): + def close(self, linger=None): if self.raise_error_on_close: raise RPCError() else: @@ -2923,6 +2923,34 @@ def test_master_discard_first_client_ready(self): self.assertEqual("ack", server.outbox[0][1].type) self.assertEqual(1, len(server.outbox)) + def test_worker_sends_bad_message_to_master(self): + """ + Validate master sends reconnect message to worker when it receives a bad message. + """ + + class TestUser(User): + @task + def my_task(self): + pass + + with mock.patch("locust.rpc.rpc.Server", mocked_rpc()) as server: + master = self.get_runner(user_classes=[TestUser]) + server.mocked_send(Message("client_ready", __version__, "zeh_fake_client1")) + self.assertEqual(1, len(master.clients)) + self.assertTrue( + "zeh_fake_client1" in master.clients, "Could not find fake client in master instance's clients dict" + ) + + master.start(10, 10) + sleep(0.1) + server.mocked_send(Message("stats", BAD_MESSAGE, "zeh_fake_client1")) + self.assertEqual(4, len(server.outbox)) + + # Expected message order in outbox: ack, spawn, reconnect, ack + self.assertEqual( + "reconnect", server.outbox[2][1].type, "Master didn't send worker reconnect message when expected." + ) + class TestWorkerRunner(LocustTestCase): def setUp(self): @@ -3201,6 +3229,54 @@ def my_task(self): worker.quit() + def test_reset_rpc_connection_to_master(self): + """ + Validate worker resets RPC connection to master on "reconnect" message. + """ + + class MyUser(User): + wait_time = constant(1) + + @task + def my_task(self): + pass + + with mock.patch("locust.rpc.rpc.Client", mocked_rpc(raise_on_close=False)) as client: + client_id = id(client) + worker = self.get_runner(environment=Environment(), user_classes=[MyUser], client=client) + client.mocked_send( + Message( + "spawn", + { + "timestamp": 1605538584, + "user_classes_count": {"MyUser": 10}, + "host": "", + "stop_timeout": None, + "parsed_options": {}, + }, + "dummy_client_id", + ) + ) + sleep(0.6) + self.assertEqual(STATE_RUNNING, worker.state) + with self.assertLogs("locust.runners") as capture: + with mock.patch("locust.rpc.rpc.Client.close") as close: + client.mocked_send( + Message( + "reconnect", + None, + "dummy_client_id", + ) + ) + sleep(0) + worker.spawning_greenlet.join() + worker.quit() + close.assert_called_once() + self.assertIn( + "WARNING:locust.runners:Received reconnect message from master. Resetting RPC connection.", + capture.output, + ) + def test_change_user_count_during_spawning(self): class MyUser(User): wait_time = constant(1) diff --git a/locust/test/test_zmqrpc.py b/locust/test/test_zmqrpc.py index abaa8144ad..6ee2b17cb6 100644 --- a/locust/test/test_zmqrpc.py +++ b/locust/test/test_zmqrpc.py @@ -2,7 +2,7 @@ import zmq from locust.rpc import zmqrpc, Message from locust.test.testcases import LocustTestCase -from locust.exception import RPCError +from locust.exception import RPCError, RPCSendError, RPCReceiveError class ZMQRPC_tests(LocustTestCase): @@ -50,5 +50,5 @@ def test_rpc_error(self): with self.assertRaises(RPCError): server = zmqrpc.Server("127.0.0.1", server.port) server.close() - with self.assertRaises(RPCError): + with self.assertRaises(RPCSendError): server.send_to_client(Message("test", "message", "identity"))