diff --git a/docs/source/operators/config-env-debug.md b/docs/source/operators/config-env-debug.md index 79bae20aa..4ab98dc13 100644 --- a/docs/source/operators/config-env-debug.md +++ b/docs/source/operators/config-env-debug.md @@ -27,6 +27,10 @@ The following environment variables may be useful for troubleshooting: EG_POLL_INTERVAL=0.5 The interval (in seconds) to wait before checking poll results again. + EG_RESTART_STATUS_POLL_INTERVAL=1.0 + The interval (in seconds) to wait before polling for the restart status again when a duplicate restart request + for the same kernel is received or when a shutdown request is received while the kernel is still restarting. + EG_REMOVE_CONTAINER=True Used by launch_docker.py, indicates whether the kernel's docker container should be removed following its shutdown. Set this value to 'False' if you want the container diff --git a/enterprise_gateway/services/kernels/remotemanager.py b/enterprise_gateway/services/kernels/remotemanager.py index 5a3e38dfe..0971d3f2c 100644 --- a/enterprise_gateway/services/kernels/remotemanager.py +++ b/enterprise_gateway/services/kernels/remotemanager.py @@ -1,10 +1,11 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. """Kernel managers that operate against a remote process.""" - +import asyncio import os import re import signal +import time import uuid from jupyter_client.ioloop.manager import AsyncIOLoopKernelManager @@ -18,6 +19,9 @@ from ..processproxies.processproxy import LocalProcessProxy, RemoteProcessProxy from ..sessions.kernelsessionmanager import KernelSessionManager +default_kernel_launch_timeout = float(os.getenv("EG_KERNEL_LAUNCH_TIMEOUT", "30")) +kernel_restart_status_poll_interval = float(os.getenv("EG_RESTART_STATUS_POLL_INTERVAL", 1.0)) + def import_item(name): """Import and return ``bar`` given the string ``foo.bar``. 
async def restart_kernel(self, kernel_id):
    """Restart the kernel ``kernel_id`` unless a restart is already in flight.

    When the kernel's ``restarting`` flag is already set, the request is
    treated as a duplicate: this call waits for the in-flight restart to
    finish (or for the wait to time out) and returns without restarting again.
    """
    kernel = self.get_kernel(kernel_id)
    if kernel.restarting:  # assume a duplicate request
        await self.wait_for_restart_finish(kernel_id, "restart")
        self.log.info("Skipping kernel restart as this was duplicate request.")
        return

    # The flag is set here for the duration of the restart; ``finally``
    # guarantees it is cleared even when the restart raises, so later
    # restart/shutdown requests are not left waiting on a stale flag.
    kernel.restarting = True
    try:
        await super().restart_kernel(kernel_id)
    finally:
        kernel.restarting = False


async def shutdown_kernel(self, kernel_id, now=False, restart=False):
    """Shut down the kernel ``kernel_id``, letting an in-progress restart settle first.

    ``now`` and ``restart`` are passed through unchanged to the parent
    implementation.

    Raises:
        web.HTTPError: 404 when the parent shutdown raises ``KeyError``.
    """
    kernel = self.get_kernel(kernel_id)
    if kernel.restarting:
        await self.wait_for_restart_finish(kernel_id, "shutdown")
    try:
        await super().shutdown_kernel(kernel_id, now, restart)
    except KeyError as ke:  # this is hit for multiple shutdown requests
        self.log.exception(f"Exception while shutting down kernel: '{kernel_id}': {ke}")
        # Chain explicitly so the original KeyError stays attached as the cause.
        raise web.HTTPError(404, f"Kernel does not exist: {kernel_id}") from ke


async def wait_for_restart_finish(self, kernel_id, action="shutdown"):
    """Poll the kernel's ``restarting`` flag until it clears or the wait times out.

    The wait is bounded by the kernel's ``kernel_launch_timeout`` and the flag
    is re-checked every ``kernel_restart_status_poll_interval`` seconds. On
    timeout the method logs and returns normally; it never raises for that.

    Args:
        kernel_id: id of the kernel whose restart is awaited.
        action: name of the request that triggered the wait; used only in
            log messages.
    """
    kernel = self.get_kernel(kernel_id)
    timeout = kernel.kernel_launch_timeout
    poll_time = kernel_restart_status_poll_interval
    self.log.info(
        f"Kernel '{kernel_id}' was restarting when {action} request received. Polling every {poll_time} "
        f"seconds for next {timeout} seconds for kernel to complete its restart."
    )
    # Use the monotonic clock: wall-clock time can jump (NTP, manual changes)
    # and would silently shorten or extend the wait.
    deadline = time.monotonic() + timeout
    while kernel.restarting:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            self.log.warning(
                f"Timeout: Exiting restart wait loop in order to {action} kernel '{kernel_id}'."
            )
            break
        # Never sleep past the deadline.
        await asyncio.sleep(min(poll_time, remaining))
@@ -341,6 +386,7 @@ def __init__(self, **kwargs): self.sigint_value = None self.kernel_id = None self.user_overrides = {} + self.kernel_launch_timeout = default_kernel_launch_timeout self.restarting = False # need to track whether we're in a restart situation or not # If this instance supports port caching, then disable cache_ports since we don't need this @@ -412,6 +458,10 @@ def _capture_user_overrides(self, **kwargs): of the kernelspec env stanza that would have otherwise overridden the user-provided values. """ env = kwargs.get("env", {}) + # If KERNEL_LAUNCH_TIMEOUT is passed in the payload, override it. + self.kernel_launch_timeout = float( + env.get("KERNEL_LAUNCH_TIMEOUT", default_kernel_launch_timeout) + ) self.user_overrides.update( { key: value @@ -504,7 +554,8 @@ async def restart_kernel(self, now=False, **kwargs): Any options specified here will overwrite those used to launch the kernel. """ - self.restarting = True + if now: # if auto-restarting (when now is True), indicate we're restarting. + self.restarting = True kernel_id = self.kernel_id or os.path.basename(self.connection_file).replace( "kernel-", "" ).replace(".json", "") @@ -535,7 +586,8 @@ async def restart_kernel(self, now=False, **kwargs): # Refresh persisted state. if self.kernel_session_manager: self.kernel_session_manager.refresh_session(kernel_id) - self.restarting = False + if now: + self.restarting = False async def signal_kernel(self, signum): """