Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more websocket connection tests and fix bugs #1085

Merged
merged 2 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions jupyter_server/services/kernels/connection/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ def handle_outgoing_message(self, stream: str, outgoing_msg: list) -> None:
else:
msg = self.session.deserialize(fed_msg_list)

if isinstance(stream, str):
stream = self.channels[stream]

channel = getattr(stream, "channel", None)
parts = fed_msg_list[1:]

Expand Down Expand Up @@ -534,7 +537,7 @@ def _reserialize_reply(self, msg_or_list, channel=None):
return json.dumps(msg, default=json_default)

def select_subprotocol(self, subprotocols):
preferred_protocol = self.settings.get("kernel_ws_protocol")
preferred_protocol = self.kernel_ws_protocol
if preferred_protocol is None:
preferred_protocol = "v1.kernel.websocket.jupyter.org"
elif preferred_protocol == "":
Expand Down Expand Up @@ -792,7 +795,7 @@ def on_restart_failed(self):
self._send_status_message("dead")

def _on_error(self, channel, msg, msg_list):
if self.kernel_manager.allow_tracebacks:
if self.multi_kernel_manager.allow_tracebacks:
return

if channel == "iopub":
Expand Down
46 changes: 46 additions & 0 deletions tests/services/kernels/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import asyncio
import json
from unittest.mock import MagicMock

from jupyter_client.jsonutil import json_clean, json_default
from jupyter_client.session import Session
from tornado.httpserver import HTTPRequest
from tornado.websocket import WebSocketHandler

from jupyter_server.serverapp import ServerApp
from jupyter_server.services.kernels.connection.channels import (
ZMQChannelsWebsocketConnection,
)


async def test_websocket_connection(jp_serverapp):
app: ServerApp = jp_serverapp
kernel_id = await app.kernel_manager.start_kernel()
kernel = app.kernel_manager.get_kernel(kernel_id)
request = HTTPRequest("foo", "GET")
request.connection = MagicMock()
handler = WebSocketHandler(app.web_app, request)
handler.ws_connection = MagicMock()
handler.ws_connection.is_closing = lambda: False
conn = ZMQChannelsWebsocketConnection(parent=kernel, websocket_handler=handler)
await conn.prepare()
conn.connect()
await asyncio.wrap_future(conn.nudge())
session: Session = kernel.session
msg = session.msg("data_pub", content={"a": "b"})
data = json.dumps(
json_clean(msg),
default=json_default,
ensure_ascii=False,
allow_nan=False,
)
conn.handle_incoming_message(data)
conn.handle_outgoing_message("iopub", session.serialize(msg))
assert (
conn.select_subprotocol(["v1.kernel.websocket.jupyter.org"])
== "v1.kernel.websocket.jupyter.org"
)
conn.write_stderr("test", {})
conn.on_kernel_restarted()
conn.on_restart_failed()
conn._on_error("shell", msg, session.serialize(msg))