diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py
index 48e3653744..a3acec5a53 100644
--- a/jupyter_server/gateway/managers.py
+++ b/jupyter_server/gateway/managers.py
@@ -92,7 +92,7 @@ async def kernel_model(self, kernel_id):
             The uuid of the kernel.
         """
         model = None
-        km = self.get_kernel(kernel_id)
+        km = self.get_kernel(str(kernel_id))
         if km:
             model = km.kernel
         return model
@@ -166,13 +166,14 @@ async def interrupt_kernel(self, kernel_id, **kwargs):
 
     async def shutdown_all(self, now=False):
         """Shutdown all kernels."""
-        for kernel_id in self._kernels:
+        kids = list(self._kernels)
+        for kernel_id in kids:
             km = self.get_kernel(kernel_id)
             await km.shutdown_kernel(now=now)
             self.remove_kernel(kernel_id)
 
     async def cull_kernels(self):
-        """Override cull_kernels so we can be sure their state is current."""
+        """Override cull_kernels, so we can be sure their state is current."""
         await self.list_kernels()
         await super().cull_kernels()
 
@@ -295,7 +296,7 @@ class GatewaySessionManager(SessionManager):
     kernel_manager = Instance("jupyter_server.gateway.managers.GatewayMappingKernelManager")
 
     async def kernel_culled(self, kernel_id):
-        """Checks if the kernel is still considered alive and returns true if its not found."""
+        """Checks if the kernel is still considered alive and returns true if it's not found."""
         kernel = None
         try:
             km = self.kernel_manager.get_kernel(kernel_id)
@@ -387,7 +388,7 @@ async def refresh_model(self, model=None):
             if isinstance(self.parent, AsyncMappingKernelManager):
                 # Update connections only if there's a mapping kernel manager parent for
                 # this kernel manager. The current kernel manager instance may not have
-                # an parent instance if, say, a server extension is using another application
+                # a parent instance if, say, a server extension is using another application
                 # (e.g., papermill) that uses a KernelManager instance directly.
                 self.parent._kernel_connections[self.kernel_id] = int(model["connections"])
 
@@ -448,8 +449,14 @@ async def shutdown_kernel(self, now=False, restart=False):
 
         if self.has_kernel:
             self.log.debug("Request shutdown kernel at: %s", self.kernel_url)
-            response = await gateway_request(self.kernel_url, method="DELETE")
-            self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason)
+            try:
+                response = await gateway_request(self.kernel_url, method="DELETE")
+                self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason)
+            except web.HTTPError as error:
+                if error.status_code == 404:
+                    self.log.debug("Shutdown kernel response: kernel not found (ignored)")
+                else:
+                    raise
 
     async def restart_kernel(self, **kw):
         """Restarts a kernel via HTTP."""
@@ -518,7 +525,7 @@ def send(self, msg: dict) -> None:
 
     @staticmethod
     def serialize_datetime(dt):
-        if isinstance(dt, (datetime.datetime)):
+        if isinstance(dt, datetime.datetime):
             return dt.timestamp()
         return None
 
@@ -597,7 +604,7 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
         """Starts the channels for this kernel.
 
         For this class, we establish a websocket connection to the destination
-        and setup the channel-based queues on which applicable messages will
+        and set up the channel-based queues on which applicable messages will
         be posted.
         """
 
@@ -608,10 +615,11 @@ async def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, cont
             "channels",
         )
         # Gather cert info in case where ssl is desired...
-        ssl_options = {}
-        ssl_options["ca_certs"] = GatewayClient.instance().ca_certs
-        ssl_options["certfile"] = GatewayClient.instance().client_cert
-        ssl_options["keyfile"] = GatewayClient.instance().client_key
+        ssl_options = {
+            "ca_certs": GatewayClient.instance().ca_certs,
+            "certfile": GatewayClient.instance().client_cert,
+            "keyfile": GatewayClient.instance().client_key,
+        }
 
         self.channel_socket = websocket.create_connection(
             ws_url,
@@ -722,7 +730,7 @@ def _route_responses(self):
                 self._channel_queues[channel].put_nowait(response_message)
 
         except websocket.WebSocketConnectionClosedException:
-            pass  # websocket closure most likely due to shutdown
+            pass  # websocket closure most likely due to shut down
 
         except BaseException as be:
             if not self._channels_stopped:
diff --git a/tests/test_gateway.py b/tests/test_gateway.py
index 4eb3c4a9bf..4134088861 100644
--- a/tests/test_gateway.py
+++ b/tests/test_gateway.py
@@ -135,6 +135,9 @@ async def mock_gateway_request(url, **kwargs):
     # Shutdown existing kernel
     if endpoint.rfind("/api/kernels/") >= 0 and method == "DELETE":
         requested_kernel_id = endpoint.rpartition("/")[2]
+        if requested_kernel_id not in running_kernels:
+            raise HTTPError(404, message="Kernel does not exist: %s" % requested_kernel_id)
+
         running_kernels.pop(
             requested_kernel_id
         )  # Simulate shutdown by removing kernel from running set
@@ -292,6 +295,29 @@ async def test_gateway_kernel_lifecycle(init_gateway, jp_fetch):
     assert await is_kernel_running(jp_fetch, kernel_id) is False
 
 
+@pytest.mark.parametrize("missing_kernel", [True, False])
+async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_kernel):
+    # Validate server shutdown when multiple gateway kernels are present or
+    # we've lost track of at least one (missing) kernel
+
+    # create two kernels
+    k1 = await create_kernel(jp_fetch, "kspec_bar")
+    k2 = await create_kernel(jp_fetch, "kspec_bar")
+
+    # ensure they're considered running
+    assert await is_kernel_running(jp_fetch, k1) is True
+    assert await is_kernel_running(jp_fetch, k2) is True
+
+    if missing_kernel:
+        running_kernels.pop(k1)  # "terminate" kernel w/o our knowledge
+
+    with mocked_gateway:
+        await jp_serverapp.kernel_manager.shutdown_all()
+
+    assert await is_kernel_running(jp_fetch, k1) is False
+    assert await is_kernel_running(jp_fetch, k2) is False
+
+
 #
 # Test methods below...
 #