Skip to content

Commit

Permalink
Transition to async/await
Browse files Browse the repository at this point in the history
  • Loading branch information
kevin-bates committed Mar 20, 2020
1 parent f6ad2d1 commit 10df48f
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 149 deletions.
7 changes: 3 additions & 4 deletions enterprise_gateway/enterprisegatewayapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Distributed under the terms of the Modified BSD License.
"""Enterprise Gateway Jupyter application."""

import asyncio
import errno
import getpass
import logging
Expand All @@ -19,15 +20,13 @@
from zmq.eventloop import ioloop
ioloop.install()

from tornado import httpserver
from tornado import web
from tornado import httpserver, web
from tornado.log import enable_pretty_logging, LogFormatter

from traitlets import default, List, Set, Unicode, Type, Instance, Bool, CBool, Integer, observe
from traitlets.config import Configurable
from jupyter_core.application import JupyterApp, base_aliases
from jupyter_client.kernelspec import KernelSpecManager
from notebook.services.kernels.kernelmanager import MappingKernelManager
from notebook.notebookapp import random_ports
from notebook.utils import url_path_join

Expand Down Expand Up @@ -678,7 +677,7 @@ def shutdown(self):
"""Shuts down all running kernels."""
kids = self.kernel_manager.list_kernel_ids()
for kid in kids:
self.kernel_manager.shutdown_kernel(kid, now=True)
asyncio.get_event_loop().run_until_complete(self.kernel_manager.shutdown_kernel(kid, now=True))

def stop(self):
"""
Expand Down
35 changes: 15 additions & 20 deletions enterprise_gateway/services/kernels/remotemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import uuid

from tornado import gen, web
from tornado import web
from ipython_genutils.py3compat import unicode_type
from ipython_genutils.importstring import import_item
from notebook.services.kernels.kernelmanager import AsyncMappingKernelManager
Expand Down Expand Up @@ -60,8 +60,7 @@ def _refresh_kernel(self, kernel_id):
self.parent.kernel_session_manager.load_session(kernel_id)
return self.parent.kernel_session_manager.start_session(kernel_id)

@gen.coroutine
def start_kernel(self, *args, **kwargs):
async def start_kernel(self, *args, **kwargs):
"""Starts a kernel for a session and return its kernel_id.
Returns
Expand All @@ -73,9 +72,9 @@ def start_kernel(self, *args, **kwargs):
username = KernelSessionManager.get_kernel_username(**kwargs)
self.log.debug("RemoteMappingKernelManager.start_kernel: {kernel_name}, kernel_username: {username}".
format(kernel_name=kwargs['kernel_name'], username=username))
kernel_id = yield super(RemoteMappingKernelManager, self).start_kernel(*args, **kwargs)
kernel_id = await super(RemoteMappingKernelManager, self).start_kernel(*args, **kwargs)
self.parent.kernel_session_manager.create_session(kernel_id, **kwargs)
raise gen.Return(kernel_id)
return kernel_id

def remove_kernel(self, kernel_id):
""" Removes the kernel associated with `kernel_id` from the internal map and deletes the kernel session. """
Expand Down Expand Up @@ -223,8 +222,7 @@ def __init__(self, **kwargs):
if hasattr(self, "cache_ports"):
self.cache_ports = False

@gen.coroutine
def start_kernel(self, **kwargs):
async def start_kernel(self, **kwargs):
"""Starts a kernel in a separate process.
Where the started kernel resides depends on the configured process proxy.
Expand All @@ -237,7 +235,7 @@ def start_kernel(self, **kwargs):
"""
self._get_process_proxy()
self._capture_user_overrides(**kwargs)
yield super(RemoteKernelManager, self).start_kernel(**kwargs)
await super(RemoteKernelManager, self).start_kernel(**kwargs)

def _capture_user_overrides(self, **kwargs):
"""
Expand Down Expand Up @@ -273,8 +271,7 @@ def from_ns(match):
return [pat.sub(from_ns, arg) for arg in cmd]
return cmd

@gen.coroutine
def _launch_kernel(self, kernel_cmd, **kwargs):
async def _launch_kernel(self, kernel_cmd, **kwargs):
# Note: despite the under-bar prefix to this method, the jupyter_client comment says that
# this method should be "[overridden] in a subclass to launch kernel subprocesses differently".
# So that's what we've done.
Expand All @@ -293,21 +290,19 @@ def _launch_kernel(self, kernel_cmd, **kwargs):
del env['KG_AUTH_TOKEN']

self.log.debug("Launching kernel: {} with command: {}".format(self.kernel_spec.display_name, kernel_cmd))
res = yield self.process_proxy.launch_process(kernel_cmd, **kwargs)
raise gen.Return(res)
proxy = await self.process_proxy.launch_process(kernel_cmd, **kwargs)
return proxy

@gen.coroutine
def request_shutdown(self, restart=False):
""" Send a shutdown request via control channel and process proxy (if remote). """
yield super(RemoteKernelManager, self).request_shutdown(restart)
super(RemoteKernelManager, self).request_shutdown(restart)

# If we're using a remote proxy, we need to send the launcher an indication that we're
# shutting down so it can exit its listener thread, if it's using one.
if isinstance(self.process_proxy, RemoteProcessProxy):
self.process_proxy.shutdown_listener()

@gen.coroutine
def restart_kernel(self, now=False, **kwargs):
async def restart_kernel(self, now=False, **kwargs):
"""Restarts a kernel with the arguments that were used to launch it.
This is an automatic restart request (now=True) AND this is associated with a
Expand Down Expand Up @@ -340,7 +335,7 @@ def restart_kernel(self, now=False, **kwargs):
# Use the parent mapping kernel manager so activity monitoring and culling are also shut down
self.parent.shutdown_kernel(kernel_id, now=now)
return
yield super(RemoteKernelManager, self).restart_kernel(now, **kwargs)
await super(RemoteKernelManager, self).restart_kernel(now, **kwargs)
if isinstance(self.process_proxy, RemoteProcessProxy): # for remote kernels...
# Re-establish activity watching...
if self._activity_stream:
Expand All @@ -351,7 +346,7 @@ def restart_kernel(self, now=False, **kwargs):
self.parent.parent.kernel_session_manager.refresh_session(kernel_id)
self.restarting = False

def signal_kernel(self, signum):
async def signal_kernel(self, signum):
"""Sends signal `signum` to the kernel process. """
if self.has_kernel:
if signum == signal.SIGINT:
Expand Down Expand Up @@ -383,7 +378,6 @@ def signal_kernel(self, signum):
else:
raise RuntimeError("Cannot signal kernel. No kernel is running!")

@gen.coroutine
def cleanup(self, connection_file=True):
"""Clean up resources when the kernel is shut down"""

Expand All @@ -394,7 +388,8 @@ def cleanup(self, connection_file=True):
if self.process_proxy:
self.process_proxy.cleanup()
self.process_proxy = None
yield super(RemoteKernelManager, self).cleanup(connection_file)

super(RemoteKernelManager, self).cleanup(connection_file)

def write_connection_file(self):
"""Write connection info to JSON dict in self.connection_file if the kernel is local.
Expand Down
40 changes: 17 additions & 23 deletions enterprise_gateway/services/processproxies/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
# Distributed under the terms of the Modified BSD License.
"""Code related to managing kernels running in Conductor clusters."""

import asyncio
import os
import signal
import json
import time
import subprocess
import socket
import re

from tornado import gen
import signal
import socket
import subprocess

from jupyter_client import launch_kernel, localinterfaces

Expand All @@ -36,10 +34,9 @@ def __init__(self, kernel_manager, proxy_config):
self.conductor_endpoint = proxy_config.get('conductor_endpoint',
kernel_manager.parent.parent.conductor_endpoint)

@gen.coroutine
def launch_process(self, kernel_cmd, **kwargs):
async def launch_process(self, kernel_cmd, **kwargs):
"""Launches the specified process within a Conductor cluster environment."""
yield super(ConductorClusterProcessProxy, self).launch_process(kernel_cmd, **kwargs)
await super(ConductorClusterProcessProxy, self).launch_process(kernel_cmd, **kwargs)
# Get cred from process env
env_dict = dict(os.environ.copy())
if env_dict and 'EGO_SERVICE_CREDENTIAL' in env_dict:
Expand All @@ -58,8 +55,8 @@ def launch_process(self, kernel_cmd, **kwargs):
self.env = kwargs.get('env')
self.log.debug("Conductor cluster kernel launched using Conductor endpoint: {}, pid: {}, Kernel ID: {}, "
"cmd: '{}'".format(self.conductor_endpoint, self.local_proc.pid, self.kernel_id, kernel_cmd))
yield self.confirm_remote_startup()
raise gen.Return(self)
await self.confirm_remote_startup()
return self

def _update_launch_info(self, kernel_cmd, **kwargs):
""" Dynamically assemble the spark-submit configuration passed from NB2KG."""
Expand Down Expand Up @@ -116,8 +113,7 @@ def send_signal(self, signum):
else:
return super(ConductorClusterProcessProxy, self).send_signal(signum)

@gen.coroutine
def kill(self):
async def kill(self):
"""Kill a kernel.
:return: None if the application existed and is not in RUNNING state, False otherwise.
"""
Expand All @@ -130,7 +126,7 @@ def kill(self):
i = 1
state = self._query_app_state_by_driver_id(self.driver_id)
while state not in ConductorClusterProcessProxy.final_states and i <= max_poll_attempts:
yield gen.sleep(poll_interval)
await asyncio.sleep(poll_interval)
state = self._query_app_state_by_driver_id(self.driver_id)
i = i + 1

Expand All @@ -141,7 +137,7 @@ def kill(self):

self.log.debug("ConductorClusterProcessProxy.kill, application ID: {}, kernel ID: {}, state: {}"
.format(self.application_id, self.kernel_id, state))
raise gen.Return(result)
return result

def cleanup(self):
# we might have a defunct process (if using waitAppCompletion = false) - so poll, kill, wait when we have
Expand Down Expand Up @@ -176,8 +172,7 @@ def _parse_driver_submission_id(self, submission_response):
self.driver_id = driver_id[0]
self.log.debug("Driver ID: {}".format(driver_id[0]))

@gen.coroutine
def confirm_remote_startup(self):
async def confirm_remote_startup(self):
""" Confirms the application is in a started state before returning. Should post-RUNNING states be
unexpectedly encountered ('FINISHED', 'KILLED', 'RECLAIMED') then we must throw, otherwise the rest
of the gateway will believe it's talking to a valid kernel.
Expand All @@ -191,7 +186,7 @@ def confirm_remote_startup(self):
output = self.local_proc.stderr.read().decode("utf-8")
self._parse_driver_submission_id(output)
i += 1
yield self.handle_timeout()
await self.handle_timeout()

if self._get_application_id(True):
# Once we have an application ID, start monitoring state, obtain assigned host and get connection info
Expand All @@ -207,7 +202,7 @@ def confirm_remote_startup(self):
format(i, app_state, self.assigned_host, self.kernel_id, self.application_id))

if self.assigned_host != '':
ready_to_connect = yield self.receive_connection_info()
ready_to_connect = await self.receive_connection_info()
else:
self.detect_launch_failure()

Expand All @@ -227,10 +222,9 @@ def _get_application_state(self):
self.assigned_ip = socket.gethostbyname(self.assigned_host)
return app_state

@gen.coroutine
def handle_timeout(self):
async def handle_timeout(self):
"""Checks to see if the kernel launch timeout has been exceeded while awaiting connection info."""
yield gen.sleep(poll_interval)
await asyncio.sleep(poll_interval)
time_interval = RemoteProcessProxy.get_time_diff(self.start_time, RemoteProcessProxy.get_current_time())

if time_interval > self.kernel_launch_timeout:
Expand All @@ -245,7 +239,7 @@ def handle_timeout(self):
else:
reason = "App {} is WAITING, but waited too long ({} secs) to get connection file". \
format(self.application_id, self.kernel_launch_timeout)
self.kill()
await self.kill()
timeout_message = "KernelID: '{}' launch timeout due to: {}".format(self.kernel_id, reason)
self.log_and_raise(http_status_code=error_http_code, reason=timeout_message)

Expand Down
21 changes: 8 additions & 13 deletions enterprise_gateway/services/processproxies/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
# Distributed under the terms of the Modified BSD License.
"""Code related to managing kernels running in containers."""

import abc
import os
import signal
import abc

import urllib3 # docker ends up using this and it causes lots of noise, so turn off warnings

from tornado import gen

from jupyter_client import launch_kernel, localinterfaces

from .processproxy import RemoteProcessProxy
Expand Down Expand Up @@ -53,8 +50,7 @@ def _determine_kernel_images(self, proxy_config):
self.kernel_executor_image = proxy_config.get('executor_image_name')
self.kernel_executor_image = os.environ.get('KERNEL_EXECUTOR_IMAGE', self.kernel_executor_image)

@gen.coroutine
def launch_process(self, kernel_cmd, **kwargs):
async def launch_process(self, kernel_cmd, **kwargs):
"""Launches the specified process within the container environment."""
# Set env before superclass call so we see these in the debug output

Expand All @@ -67,7 +63,7 @@ def launch_process(self, kernel_cmd, **kwargs):

self._enforce_uid_gid_blacklists(**kwargs)

yield super(ContainerProcessProxy, self).launch_process(kernel_cmd, **kwargs)
await super(ContainerProcessProxy, self).launch_process(kernel_cmd, **kwargs)

self.local_proc = launch_kernel(kernel_cmd, **kwargs)
self.pid = self.local_proc.pid
Expand All @@ -76,8 +72,8 @@ def launch_process(self, kernel_cmd, **kwargs):
self.log.info("{}: kernel launched. Kernel image: {}, KernelID: {}, cmd: '{}'"
.format(self.__class__.__name__, self.kernel_image, self.kernel_id, kernel_cmd))

yield self.confirm_remote_startup()
raise gen.Return(self)
await self.confirm_remote_startup()
return self

def _enforce_uid_gid_blacklists(self, **kwargs):
"""Determine UID and GID with which to launch container and ensure they do not appear in blacklist."""
Expand Down Expand Up @@ -155,20 +151,19 @@ def cleanup(self):
self.kill()
super(ContainerProcessProxy, self).cleanup()

@gen.coroutine
def confirm_remote_startup(self):
async def confirm_remote_startup(self):
"""Confirms the container has started and returned necessary connection information."""
self.start_time = RemoteProcessProxy.get_current_time()
i = 0
ready_to_connect = False # we're ready to connect when we have a connection file to use
while not ready_to_connect:
i += 1
yield self.handle_timeout()
await self.handle_timeout()

container_status = self.get_container_status(str(i))
if container_status:
if self.assigned_host != '':
ready_to_connect = yield self.receive_connection_info()
ready_to_connect = await self.receive_connection_info()
self.pid = 0 # We won't send process signals for kubernetes lifecycle management
self.pgid = 0
else:
Expand Down
Loading

0 comments on commit 10df48f

Please sign in to comment.