From ee40dbc985e45f5c2699626e0f8f350b7f4989ac Mon Sep 17 00:00:00 2001 From: Ciprian Anton Date: Wed, 29 Jun 2022 18:28:21 +0300 Subject: [PATCH] Notify ChannelQueue that the response router thread is finishing (#896) --- jupyter_server/gateway/managers.py | 20 ++++++++--- tests/test_gateway.py | 56 +++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 71441546ae..4be1c00571 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -498,12 +498,14 @@ def cleanup_resources(self, restart=False): class ChannelQueue(Queue): channel_name: Optional[str] = None + response_router_finished: bool def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger): super().__init__() self.channel_name = channel_name self.channel_socket = channel_socket self.log = log + self.response_router_finished = False async def _async_get(self, timeout=None): if timeout is None: @@ -516,6 +518,8 @@ async def _async_get(self, timeout=None): try: return self.get(block=False) except Empty: + if self.response_router_finished: + raise RuntimeError("Response router had finished") if monotonic() > end_time: raise await asyncio.sleep(0) @@ -598,16 +602,16 @@ class GatewayKernelClient(AsyncKernelClient): # flag for whether execute requests should be allowed to call raw_input: allow_stdin = False _channels_stopped: bool - _channel_queues: Optional[dict] + _channel_queues: Optional[Dict[str, ChannelQueue]] _control_channel: Optional[ChannelQueue] _hb_channel: Optional[ChannelQueue] _stdin_channel: Optional[ChannelQueue] _iopub_channel: Optional[ChannelQueue] _shell_channel: Optional[ChannelQueue] - def __init__(self, **kwargs): + def __init__(self, kernel_id, **kwargs): super().__init__(**kwargs) - self.kernel_id = kwargs["kernel_id"] + self.kernel_id = kernel_id self.channel_socket: Optional[websocket.WebSocket] = None self.response_router: Optional[Thread] = None self._channels_stopped = False @@ -644,13 +648,14 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont enable_multithread=True, sslopt=ssl_options, ) - self.response_router = Thread(target=self._route_responses) - self.response_router.start() await ensure_async( super().start_channels(shell=shell, iopub=iopub, stdin=stdin, hb=hb, control=control) ) + self.response_router = Thread(target=self._route_responses) + self.response_router.start() + def stop_channels(self): """Stops all the running channels for this kernel. @@ -753,6 +758,11 @@ def _route_responses(self): if not self._channels_stopped: self.log.warning(f"Unexpected exception encountered ({be})") + # Notify channel queues that this thread had finished and no more messages are being received + assert self._channel_queues is not None + for channel_queue in self._channel_queues.values(): + channel_queue.response_router_finished = True + self.log.debug("Response router thread exiting...") diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 8c2a221de4..b45ed982db 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -14,7 +14,11 @@ from tornado.httpclient import HTTPRequest, HTTPResponse from tornado.web import HTTPError -from jupyter_server.gateway.managers import ChannelQueue, GatewayClient +from jupyter_server.gateway.managers import ( + ChannelQueue, + GatewayClient, + GatewayKernelManager, +) from jupyter_server.utils import ensure_async from .utils import expected_http_error @@ -164,6 +168,15 @@ async def mock_gateway_request(url, **kwargs): mock_http_user = "alice" +def mock_websocket_create_connection(recv_side_effect=None): + def helper(*args, **kwargs): + mock = MagicMock() + mock.recv = MagicMock(side_effect=recv_side_effect) + return mock + + return helper + + @pytest.fixture def init_gateway(monkeypatch): """Initializes the server for use as a gateway client.""" @@ -321,6 +334,39 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke assert await is_kernel_running(jp_fetch, k2) is False +@patch("websocket.create_connection", mock_websocket_create_connection(recv_side_effect=Exception)) +async def test_kernel_client_response_router_notifies_channel_queue_when_finished( + init_gateway, jp_serverapp, jp_fetch +): + # create + kernel_id = await create_kernel(jp_fetch, "kspec_bar") + + # get kernel manager + km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id) + + # create kernel client + kc = km.client() + + await ensure_async(kc.start_channels()) + + with pytest.raises(RuntimeError): + await kc.iopub_channel.get_msg(timeout=10) + + all_channels = [ + kc.shell_channel, + kc.iopub_channel, + kc.stdin_channel, + kc.hb_channel, + kc.control_channel, + ] + assert all(channel.response_router_finished if True else False for channel in all_channels) + + await ensure_async(kc.stop_channels()) + + # delete + await delete_kernel(jp_fetch, kernel_id) + + async def test_channel_queue_get_msg_with_invalid_timeout(): queue = ChannelQueue("iopub", MagicMock(), logging.getLogger()) @@ -352,6 +398,14 @@ async def test_channel_queue_get_msg_with_existing_item(): assert received_message == sent_message +async def test_channel_queue_get_msg_when_response_router_had_finished(): + queue = ChannelQueue("iopub", MagicMock(), logging.getLogger()) + queue.response_router_finished = True + + with pytest.raises(RuntimeError): + await queue.get_msg() + + # # Test methods below... #