Skip to content

Commit

Permalink
implement buffering of messages on last dropped connection
Browse files Browse the repository at this point in the history
- buffer is per-kernel
- session_key is stored because only a single session can resume the buffer and we can't be sure
- on any new connection to a kernel, buffer is flushed.
  If session_key matches, it is replayed.
  Otherwise, it is discarded.
- buffer is an unbounded list for now
  • Loading branch information
minrk committed Oct 3, 2017
1 parent 569bb25 commit 8f8363a
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 16 deletions.
37 changes: 24 additions & 13 deletions notebook/services/kernels/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ def initialize(self):
self._kernel_info_future = Future()
self._close_future = Future()
self.session_key = ''

# TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple
# TODO: Min suggested this should exist on the `Kernel` as well, not in this ZMQChannelsHandler
self.message_buffer = []

# Rate limiting code
self._iopub_window_msg_count = 0
Expand Down Expand Up @@ -259,12 +255,15 @@ def _register_session(self):
self.log.warning("Replacing stale connection: %s", self.session_key)
yield stale_handler.close()
self._open_sessions[self.session_key] = self

def open(self, kernel_id):
super(ZMQChannelsHandler, self).open()
self.kernel_manager.notify_connect(kernel_id)

# on new connections, flush the message buffer
replay_buffer = self.kernel_manager.stop_buffering(kernel_id, self.session_key)

try:
# TODO: if this is a reconnection, we'll replay messages
self.create_stream()
except web.HTTPError as e:
self.log.error("Error opening stream: %s", e)
Expand All @@ -274,9 +273,16 @@ def open(self, kernel_id):
if not stream.closed():
stream.close()
self.close()
else:
for channel, stream in self.channels.items():
stream.on_recv_stream(self._on_zmq_reply)
return

if replay_buffer:
self.log.info("Replaying %s buffered messages", len(replay_buffer))
for channel, msg_list in replay_buffer:
stream = self.channels[channel]
self._on_zmq_reply(stream, msg_list)

for channel, stream in self.channels.items():
stream.on_recv_stream(self._on_zmq_reply)

def on_message(self, msg):
if not self.channels:
Expand All @@ -296,7 +302,7 @@ def on_message(self, msg):
return
stream = self.channels[channel]
self.session.send(stream, msg)

def _on_zmq_reply(self, stream, msg_list):
idents, fed_msg_list = self.session.feed_identities(msg_list)
msg = self.session.deserialize(fed_msg_list)
Expand All @@ -309,7 +315,6 @@ def write_stderr(error_message):
)
msg['channel'] = 'iopub'
self.write_message(json.dumps(msg, default=date_default))

channel = getattr(stream, 'channel', None)
msg_type = msg['header']['msg_type']

Expand Down Expand Up @@ -412,12 +417,11 @@ def close(self):
return self._close_future

def on_close(self):
# TODO: Start buffering messages

self.log.debug("Websocket closed %s", self.session_key)
# unregister myself as an open session (only if it's really me)
if self._open_sessions.get(self.session_key) is self:
self._open_sessions.pop(self.session_key)

km = self.kernel_manager
if self.kernel_id in km:
km.notify_disconnect(self.kernel_id)
Expand All @@ -427,6 +431,13 @@ def on_close(self):
km.remove_restart_callback(
self.kernel_id, self.on_restart_failed, 'dead',
)

# start buffering instead of closing if this was the last connection
if km._kernel_connections[self.kernel_id] == 0:
km.start_buffering(self.kernel_id, self.session_key, self.channels)
self._close_future.set_result(None)
return

# This method can be called twice, once by self.kernel_died and once
# from the WebSocket close event. If the WebSocket connection is
# closed before the ZMQ streams are setup, they could be None.
Expand Down
86 changes: 83 additions & 3 deletions notebook/services/kernels/kernelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

from collections import defaultdict
from functools import partial
import os

from tornado import gen, web
Expand All @@ -15,13 +17,13 @@

from jupyter_client.session import Session
from jupyter_client.multikernelmanager import MultiKernelManager
from traitlets import Bool, Dict, List, Unicode, TraitError, Integer, default, validate
from traitlets import Any, Bool, Dict, List, Unicode, TraitError, Integer, default, validate

from notebook.utils import to_os_path, exists
from notebook._tz import utcnow, isoformat
from ipython_genutils.py3compat import getcwd

from datetime import datetime, timedelta
from datetime import timedelta


class MappingKernelManager(MultiKernelManager):
Expand Down Expand Up @@ -81,6 +83,11 @@ def _update_root_dir(self, proposal):
Only effective if cull_idle_timeout is not 0."""
)

_kernel_buffers = Any()
@default('_kernel_buffers')
def _default_kernel_buffers(self):
return defaultdict(lambda: {'buffer': [], 'session_key': '', 'channels': {}})

#-------------------------------------------------------------------------
# Methods for managing kernels and sessions
#-------------------------------------------------------------------------
Expand Down Expand Up @@ -142,10 +149,82 @@ def start_kernel(self, kernel_id=None, path=None, **kwargs):
# py2-compat
raise gen.Return(kernel_id)

def start_buffering(self, kernel_id, session_key, channels):
"""Start buffering messages for a kernel
Parameters
----------
kernel_id : str
The id of the kernel to stop buffering.
session_key: str
The session_key, if any, that should get the buffer.
If the session_key matches the current buffered session_key,
the buffer will be returned.
channels: dict({'channel': ZMQStream})
The zmq channels whose messages should be buffered.
"""
self.log.info("Starting buffering for %s", session_key)
self._check_kernel_id(kernel_id)
# clear previous buffering state
self.stop_buffering(kernel_id)
buffer_info = self._kernel_buffers[kernel_id]
# record the session key because only one session can buffer
buffer_info['session_key'] = session_key
# TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple
buffer_info['buffer'] = []
buffer_info['channels'] = channels

# forward any future messages to the internal buffer
def buffer_msg(channel, msg_parts):
self.log.debug("Buffering msg on %s:%s", kernel_id, channel)
buffer_info['buffer'].append((channel, msg_parts))

for channel, stream in channels.items():
stream.on_recv(partial(buffer_msg, channel))

def stop_buffering(self, kernel_id, session_key=None):
"""Stop buffering kernel messages
if session_key matches the current buffered session for the kernel,
the buffer will be returned. Otherwise, an empty list will be returned.
Parameters
----------
kernel_id : str
The id of the kernel to stop buffering.
session_key: str, optional
The session_key, if any, that should get the buffer.
If the session_key matches the current buffered session_key,
the buffer will be returned.
"""
self.log.debug("Clearing buffer for %s", kernel_id)
self._check_kernel_id(kernel_id)

if kernel_id not in self._kernel_buffers:
return
buffer_info = self._kernel_buffers.pop(kernel_id)
# close buffering streams
for stream in buffer_info['channels'].values():
if not stream.closed():
stream.on_recv(None)
stream.socket.close()
stream.close()

msg_buffer = buffer_info['buffer']
if msg_buffer and buffer_info['session_key'] != session_key:
self.log.info("Discarding %s buffered messages for %s",
len(msg_buffer), buffer_info['session_key'])
msg_buffer = []

# return previous buffer if it matched the session key
return msg_buffer

def shutdown_kernel(self, kernel_id, now=False):
"""Shutdown a kernel by kernel_id"""
self._check_kernel_id(kernel_id)
self._kernels[kernel_id]._activity_stream.close()
kernel = self._kernels[kernel_id]
kernel._activity_stream.close()
self.stop_buffering(kernel_id)
self._kernel_connections.pop(kernel_id, None)
return super(MappingKernelManager, self).shutdown_kernel(kernel_id, now=now)

Expand Down Expand Up @@ -256,6 +335,7 @@ def record_activity(msg_list):

idents, fed_msg_list = session.feed_identities(msg_list)
msg = session.deserialize(fed_msg_list)

msg_type = msg['header']['msg_type']
self.log.debug("activity on %s: %s", kernel_id, msg_type)
if msg_type == 'status':
Expand Down

0 comments on commit 8f8363a

Please sign in to comment.