Skip to content

Commit

Permalink
Merge pull request #5966 from kevin-bates/gateway-async-km
Browse files Browse the repository at this point in the history
Update GatewayKernelManager to derive from AsyncMappingKernelManager
  • Loading branch information
kevin-bates authored Feb 2, 2021
2 parents c353da4 + 3e69164 commit 5d96514
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 71 deletions.
93 changes: 40 additions & 53 deletions notebook/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import json

from socket import gaierror
from tornado import gen, web
from tornado import web
from tornado.escape import json_encode, json_decode, url_escape
from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError

from ..services.kernels.kernelmanager import MappingKernelManager
from ..services.kernels.kernelmanager import AsyncMappingKernelManager
from ..services.sessions.sessionmanager import SessionManager

from jupyter_client.kernelspec import KernelSpecManager
Expand Down Expand Up @@ -303,13 +303,12 @@ def load_connection_args(self, **kwargs):
return kwargs


@gen.coroutine
def gateway_request(endpoint, **kwargs):
async def gateway_request(endpoint, **kwargs):
"""Make an async request to kernel gateway endpoint, returns a response """
client = AsyncHTTPClient()
kwargs = GatewayClient.instance().load_connection_args(**kwargs)
try:
response = yield client.fetch(endpoint, **kwargs)
response = await client.fetch(endpoint, **kwargs)
# Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect
# or the server is not running.
# NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes
Expand All @@ -332,10 +331,10 @@ def gateway_request(endpoint, **kwargs):
"url is valid and the Gateway instance is running.".format(GatewayClient.instance().url)
) from e

raise gen.Return(response)
return response


class GatewayKernelManager(MappingKernelManager):
class GatewayKernelManager(AsyncMappingKernelManager):
"""Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway."""

# We'll maintain our own set of kernel ids
Expand Down Expand Up @@ -367,8 +366,7 @@ def _get_kernel_endpoint_url(self, kernel_id=None):

return self.base_endpoint

@gen.coroutine
def start_kernel(self, kernel_id=None, path=None, **kwargs):
async def start_kernel(self, kernel_id=None, path=None, **kwargs):
"""Start a kernel for a session and return its kernel_id.
Parameters
Expand Down Expand Up @@ -403,21 +401,20 @@ def start_kernel(self, kernel_id=None, path=None, **kwargs):

json_body = json_encode({'name': kernel_name, 'env': kernel_env})

response = yield gateway_request(kernel_url, method='POST', body=json_body)
response = await gateway_request(kernel_url, method='POST', body=json_body)
kernel = json_decode(response.body)
kernel_id = kernel['id']
self.log.info("Kernel started: %s" % kernel_id)
self.log.debug("Kernel args: %r" % kwargs)
else:
kernel = yield self.get_kernel(kernel_id)
kernel = await self.get_kernel(kernel_id)
kernel_id = kernel['id']
self.log.info("Using existing kernel: %s" % kernel_id)

self._kernels[kernel_id] = kernel
raise gen.Return(kernel_id)
return kernel_id

@gen.coroutine
def get_kernel(self, kernel_id=None, **kwargs):
async def get_kernel(self, kernel_id=None, **kwargs):
"""Get kernel for kernel_id.
Parameters
Expand All @@ -428,7 +425,7 @@ def get_kernel(self, kernel_id=None, **kwargs):
kernel_url = self._get_kernel_endpoint_url(kernel_id)
self.log.debug("Request kernel at: %s" % kernel_url)
try:
response = yield gateway_request(kernel_url, method='GET')
response = await gateway_request(kernel_url, method='GET')
except web.HTTPError as error:
if error.status_code == 404:
self.log.warn("Kernel not found at: %s" % kernel_url)
Expand All @@ -440,10 +437,9 @@ def get_kernel(self, kernel_id=None, **kwargs):
kernel = json_decode(response.body)
self._kernels[kernel_id] = kernel
self.log.debug("Kernel retrieved: %s" % kernel)
raise gen.Return(kernel)
return kernel

@gen.coroutine
def kernel_model(self, kernel_id):
async def kernel_model(self, kernel_id):
"""Return a dictionary of kernel information described in the
JSON standard model.
Expand All @@ -453,21 +449,19 @@ def kernel_model(self, kernel_id):
The uuid of the kernel.
"""
self.log.debug("RemoteKernelManager.kernel_model: %s", kernel_id)
model = yield self.get_kernel(kernel_id)
raise gen.Return(model)
model = await self.get_kernel(kernel_id)
return model

@gen.coroutine
def list_kernels(self, **kwargs):
async def list_kernels(self, **kwargs):
"""Get a list of kernels."""
kernel_url = self._get_kernel_endpoint_url()
self.log.debug("Request list kernels: %s", kernel_url)
response = yield gateway_request(kernel_url, method='GET')
response = await gateway_request(kernel_url, method='GET')
kernels = json_decode(response.body)
self._kernels = {x['id']: x for x in kernels}
raise gen.Return(kernels)
return kernels

@gen.coroutine
def shutdown_kernel(self, kernel_id, now=False, restart=False):
async def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""Shutdown a kernel by its kernel uuid.
Parameters
Expand All @@ -481,12 +475,11 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""
kernel_url = self._get_kernel_endpoint_url(kernel_id)
self.log.debug("Request shutdown kernel at: %s", kernel_url)
response = yield gateway_request(kernel_url, method='DELETE')
response = await gateway_request(kernel_url, method='DELETE')
self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason)
self.remove_kernel(kernel_id)

@gen.coroutine
def restart_kernel(self, kernel_id, now=False, **kwargs):
async def restart_kernel(self, kernel_id, now=False, **kwargs):
"""Restart a kernel by its kernel uuid.
Parameters
Expand All @@ -496,11 +489,10 @@ def restart_kernel(self, kernel_id, now=False, **kwargs):
"""
kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/restart'
self.log.debug("Request restart kernel at: %s", kernel_url)
response = yield gateway_request(kernel_url, method='POST', body=json_encode({}))
response = await gateway_request(kernel_url, method='POST', body=json_encode({}))
self.log.debug("Restart kernel response: %d %s", response.code, response.reason)

@gen.coroutine
def interrupt_kernel(self, kernel_id, **kwargs):
async def interrupt_kernel(self, kernel_id, **kwargs):
"""Interrupt a kernel by its kernel uuid.
Parameters
Expand All @@ -510,7 +502,7 @@ def interrupt_kernel(self, kernel_id, **kwargs):
"""
kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/interrupt'
self.log.debug("Request interrupt kernel at: %s", kernel_url)
response = yield gateway_request(kernel_url, method='POST', body=json_encode({}))
response = await gateway_request(kernel_url, method='POST', body=json_encode({}))
self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason)

def shutdown_all(self, now=False):
Expand Down Expand Up @@ -565,9 +557,8 @@ def _get_kernelspecs_endpoint_url(self, kernel_name=None):

return self.base_endpoint

@gen.coroutine
def get_all_specs(self):
fetched_kspecs = yield self.list_kernel_specs()
async def get_all_specs(self):
fetched_kspecs = await self.list_kernel_specs()

# get the default kernel name and compare to that of this server.
# If different log a warning and reset the default. However, the
Expand All @@ -583,19 +574,17 @@ def get_all_specs(self):
km.default_kernel_name = remote_default_kernel_name

remote_kspecs = fetched_kspecs.get('kernelspecs')
raise gen.Return(remote_kspecs)
return remote_kspecs

@gen.coroutine
def list_kernel_specs(self):
async def list_kernel_specs(self):
"""Get a list of kernel specs."""
kernel_spec_url = self._get_kernelspecs_endpoint_url()
self.log.debug("Request list kernel specs at: %s", kernel_spec_url)
response = yield gateway_request(kernel_spec_url, method='GET')
response = await gateway_request(kernel_spec_url, method='GET')
kernel_specs = json_decode(response.body)
raise gen.Return(kernel_specs)
return kernel_specs

@gen.coroutine
def get_kernel_spec(self, kernel_name, **kwargs):
async def get_kernel_spec(self, kernel_name, **kwargs):
"""Get kernel spec for kernel_name.
Parameters
Expand All @@ -606,7 +595,7 @@ def get_kernel_spec(self, kernel_name, **kwargs):
kernel_spec_url = self._get_kernelspecs_endpoint_url(kernel_name=str(kernel_name))
self.log.debug("Request kernel spec at: %s" % kernel_spec_url)
try:
response = yield gateway_request(kernel_spec_url, method='GET')
response = await gateway_request(kernel_spec_url, method='GET')
except web.HTTPError as error:
if error.status_code == 404:
# Convert not found to KeyError since that's what the Notebook handler expects
Expand All @@ -620,10 +609,9 @@ def get_kernel_spec(self, kernel_name, **kwargs):
else:
kernel_spec = json_decode(response.body)

raise gen.Return(kernel_spec)
return kernel_spec

@gen.coroutine
def get_kernel_spec_resource(self, kernel_name, path):
async def get_kernel_spec_resource(self, kernel_name, path):
"""Get kernel spec for kernel_name.
Parameters
Expand All @@ -636,22 +624,21 @@ def get_kernel_spec_resource(self, kernel_name, path):
kernel_spec_resource_url = url_path_join(self.base_resource_endpoint, str(kernel_name), str(path))
self.log.debug("Request kernel spec resource '{}' at: {}".format(path, kernel_spec_resource_url))
try:
response = yield gateway_request(kernel_spec_resource_url, method='GET')
response = await gateway_request(kernel_spec_resource_url, method='GET')
except web.HTTPError as error:
if error.status_code == 404:
kernel_spec_resource = None
else:
raise
else:
kernel_spec_resource = response.body
raise gen.Return(kernel_spec_resource)
return kernel_spec_resource


class GatewaySessionManager(SessionManager):
kernel_manager = Instance('notebook.gateway.managers.GatewayKernelManager')

@gen.coroutine
def kernel_culled(self, kernel_id):
async def kernel_culled(self, kernel_id):
"""Checks if the kernel is still considered alive and returns true if its not found. """
kernel = yield self.kernel_manager.get_kernel(kernel_id)
raise gen.Return(kernel is None)
kernel = await self.kernel_manager.get_kernel(kernel_id)
return kernel is None
35 changes: 17 additions & 18 deletions notebook/tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def generate_model(name):
return model


@gen.coroutine
def mock_gateway_request(url, **kwargs):
async def mock_gateway_request(url, **kwargs):
method = 'GET'
if kwargs['method']:
method = kwargs['method']
Expand All @@ -51,17 +50,17 @@ def mock_gateway_request(url, **kwargs):
# Fetch all kernelspecs
if endpoint.endswith('/api/kernelspecs') and method == 'GET':
response_buf = StringIO(json.dumps(kernelspecs))
response = yield maybe_future(HTTPResponse(request, 200, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 200, buffer=response_buf))
return response

# Fetch named kernelspec
if endpoint.rfind('/api/kernelspecs/') >= 0 and method == 'GET':
requested_kernelspec = endpoint.rpartition('/')[2]
kspecs = kernelspecs.get('kernelspecs')
if requested_kernelspec in kspecs:
response_buf = StringIO(json.dumps(kspecs.get(requested_kernelspec)))
response = yield maybe_future(HTTPResponse(request, 200, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 200, buffer=response_buf))
return response
else:
raise HTTPError(404, message='Kernelspec does not exist: %s' % requested_kernelspec)

Expand All @@ -75,8 +74,8 @@ def mock_gateway_request(url, **kwargs):
model = generate_model(name)
running_kernels[model.get('id')] = model # Register model as a running kernel
response_buf = StringIO(json.dumps(model))
response = yield maybe_future(HTTPResponse(request, 201, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 201, buffer=response_buf))
return response

# Fetch list of running kernels
if endpoint.endswith('/api/kernels') and method == 'GET':
Expand All @@ -85,24 +84,24 @@ def mock_gateway_request(url, **kwargs):
model = running_kernels.get(kernel_id)
kernels.append(model)
response_buf = StringIO(json.dumps(kernels))
response = yield maybe_future(HTTPResponse(request, 200, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 200, buffer=response_buf))
return response

# Interrupt or restart existing kernel
if endpoint.rfind('/api/kernels/') >= 0 and method == 'POST':
requested_kernel_id, sep, action = endpoint.rpartition('/api/kernels/')[2].rpartition('/')

if action == 'interrupt':
if requested_kernel_id in running_kernels:
response = yield maybe_future(HTTPResponse(request, 204))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 204))
return response
else:
raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)
elif action == 'restart':
if requested_kernel_id in running_kernels:
response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id)))
response = yield maybe_future(HTTPResponse(request, 204, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 204, buffer=response_buf))
return response
else:
raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)
else:
Expand All @@ -112,16 +111,16 @@ def mock_gateway_request(url, **kwargs):
if endpoint.rfind('/api/kernels/') >= 0 and method == 'DELETE':
requested_kernel_id = endpoint.rpartition('/')[2]
running_kernels.pop(requested_kernel_id) # Simulate shutdown by removing kernel from running set
response = yield maybe_future(HTTPResponse(request, 204))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 204))
return response

# Fetch existing kernel
if endpoint.rfind('/api/kernels/') >= 0 and method == 'GET':
requested_kernel_id = endpoint.rpartition('/')[2]
if requested_kernel_id in running_kernels:
response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id)))
response = yield maybe_future(HTTPResponse(request, 200, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 200, buffer=response_buf))
return response
else:
raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)

Expand Down

0 comments on commit 5d96514

Please sign in to comment.