Skip to content

Commit

Permalink
Notify ChannelQueue that the response router thread is finishing (jup…
Browse files Browse the repository at this point in the history
  • Loading branch information
CiprianAnton authored Jun 29, 2022
1 parent 4f1e09e commit ee40dbc
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 6 deletions.
20 changes: 15 additions & 5 deletions jupyter_server/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,12 +498,14 @@ def cleanup_resources(self, restart=False):
class ChannelQueue(Queue):

channel_name: Optional[str] = None
response_router_finished: bool

def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger):
super().__init__()
self.channel_name = channel_name
self.channel_socket = channel_socket
self.log = log
self.response_router_finished = False

async def _async_get(self, timeout=None):
if timeout is None:
Expand All @@ -516,6 +518,8 @@ async def _async_get(self, timeout=None):
try:
return self.get(block=False)
except Empty:
if self.response_router_finished:
raise RuntimeError("Response router had finished")
if monotonic() > end_time:
raise
await asyncio.sleep(0)
Expand Down Expand Up @@ -598,16 +602,16 @@ class GatewayKernelClient(AsyncKernelClient):
# flag for whether execute requests should be allowed to call raw_input:
allow_stdin = False
_channels_stopped: bool
_channel_queues: Optional[dict]
_channel_queues: Optional[Dict[str, ChannelQueue]]
_control_channel: Optional[ChannelQueue]
_hb_channel: Optional[ChannelQueue]
_stdin_channel: Optional[ChannelQueue]
_iopub_channel: Optional[ChannelQueue]
_shell_channel: Optional[ChannelQueue]

def __init__(self, **kwargs):
def __init__(self, kernel_id, **kwargs):
super().__init__(**kwargs)
self.kernel_id = kwargs["kernel_id"]
self.kernel_id = kernel_id
self.channel_socket: Optional[websocket.WebSocket] = None
self.response_router: Optional[Thread] = None
self._channels_stopped = False
Expand Down Expand Up @@ -644,13 +648,14 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
enable_multithread=True,
sslopt=ssl_options,
)
self.response_router = Thread(target=self._route_responses)
self.response_router.start()

await ensure_async(
super().start_channels(shell=shell, iopub=iopub, stdin=stdin, hb=hb, control=control)
)

self.response_router = Thread(target=self._route_responses)
self.response_router.start()

def stop_channels(self):
"""Stops all the running channels for this kernel.
Expand Down Expand Up @@ -753,6 +758,11 @@ def _route_responses(self):
if not self._channels_stopped:
self.log.warning(f"Unexpected exception encountered ({be})")

# Notify channel queues that this thread had finished and no more messages are being received
assert self._channel_queues is not None
for channel_queue in self._channel_queues.values():
channel_queue.response_router_finished = True

self.log.debug("Response router thread exiting...")


Expand Down
56 changes: 55 additions & 1 deletion tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from tornado.httpclient import HTTPRequest, HTTPResponse
from tornado.web import HTTPError

from jupyter_server.gateway.managers import ChannelQueue, GatewayClient
from jupyter_server.gateway.managers import (
ChannelQueue,
GatewayClient,
GatewayKernelManager,
)
from jupyter_server.utils import ensure_async

from .utils import expected_http_error
Expand Down Expand Up @@ -164,6 +168,15 @@ async def mock_gateway_request(url, **kwargs):
mock_http_user = "alice"


def mock_websocket_create_connection(recv_side_effect=None):
def helper(*args, **kwargs):
mock = MagicMock()
mock.recv = MagicMock(side_effect=recv_side_effect)
return mock

return helper


@pytest.fixture
def init_gateway(monkeypatch):
"""Initializes the server for use as a gateway client."""
Expand Down Expand Up @@ -321,6 +334,39 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke
assert await is_kernel_running(jp_fetch, k2) is False


@patch("websocket.create_connection", mock_websocket_create_connection(recv_side_effect=Exception))
async def test_kernel_client_response_router_notifies_channel_queue_when_finished(
init_gateway, jp_serverapp, jp_fetch
):
# create
kernel_id = await create_kernel(jp_fetch, "kspec_bar")

# get kernel manager
km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id)

# create kernel client
kc = km.client()

await ensure_async(kc.start_channels())

with pytest.raises(RuntimeError):
await kc.iopub_channel.get_msg(timeout=10)

all_channels = [
kc.shell_channel,
kc.iopub_channel,
kc.stdin_channel,
kc.hb_channel,
kc.control_channel,
]
assert all(channel.response_router_finished if True else False for channel in all_channels)

await ensure_async(kc.stop_channels())

# delete
await delete_kernel(jp_fetch, kernel_id)


async def test_channel_queue_get_msg_with_invalid_timeout():
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())

Expand Down Expand Up @@ -352,6 +398,14 @@ async def test_channel_queue_get_msg_with_existing_item():
assert received_message == sent_message


async def test_channel_queue_get_msg_when_response_router_had_finished():
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
queue.response_router_finished = True

with pytest.raises(RuntimeError):
await queue.get_msg()


#
# Test methods below...
#
Expand Down

0 comments on commit ee40dbc

Please sign in to comment.