diff --git a/jupyter_server/base/zmqhandlers.py b/jupyter_server/base/zmqhandlers.py
index 8b774a25d8..a3abef1658 100644
--- a/jupyter_server/base/zmqhandlers.py
+++ b/jupyter_server/base/zmqhandlers.py
@@ -24,6 +24,64 @@ from .handlers import JupyterHandler
+def serialize_binary_message(msg):
+    """serialize a message as a binary blob
+
+    Header:
+
+    4 bytes: number of msg parts (nbufs) as 32b int
+    4 * nbufs bytes: offset for each buffer as a 32b int
+
+    Offsets are from the start of the buffer, including the header.
+
+    Returns
+    -------
+    The message serialized to bytes.
+
+    """
+    # don't modify msg or buffer list in-place
+    msg = msg.copy()
+    buffers = list(msg.pop("buffers"))
+    if sys.version_info < (3, 4):
+        buffers = [x.tobytes() for x in buffers]
+    bmsg = json.dumps(msg, default=json_default).encode("utf8")
+    buffers.insert(0, bmsg)
+    nbufs = len(buffers)
+    offsets = [4 * (nbufs + 1)]
+    for buf in buffers[:-1]:
+        offsets.append(offsets[-1] + len(buf))
+    offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets)
+    buffers.insert(0, offsets_buf)
+    return b"".join(buffers)
+
+
+def deserialize_binary_message(bmsg):
+    """deserialize a message from a binary blob
+
+    Header:
+
+    4 bytes: number of msg parts (nbufs) as 32b int
+    4 * nbufs bytes: offset for each buffer as a 32b int
+
+    Offsets are from the start of the buffer, including the header.
+
+    Returns
+    -------
+    message dictionary
+    """
+    nbufs = struct.unpack("!i", bmsg[:4])[0]
+    offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)]))
+    offsets.append(None)
+    bufs = []
+    for start, stop in zip(offsets[:-1], offsets[1:]):
+        bufs.append(bmsg[start:stop])
+    msg = json.loads(bufs[0].decode("utf8"))
+    msg["header"] = extract_dates(msg["header"])
+    msg["parent_header"] = extract_dates(msg["parent_header"])
+    msg["buffers"] = bufs[1:]
+    return msg
+
+
 # ping interval for keeping websockets alive (30 seconds)
 WS_PING_INTERVAL = 30000
@@ -155,6 +213,37 @@ def send_error(self, *args, **kwargs):
             # we can close the connection more gracefully.
             self.stream.close()
 
+    def _reserialize_reply(self, msg_or_list, channel=None):
+        """Reserialize a reply message using JSON.
+
+        msg_or_list can be an already-deserialized msg dict or the zmq buffer list.
+        If it is the zmq list, it will be deserialized with self.session.
+
+        This takes the msg list from the ZMQ socket and serializes the result for the websocket.
+        This method should be used by self._on_zmq_reply to build messages that can
+        be sent back to the browser.
+
+        """
+        if isinstance(msg_or_list, dict):
+            # already unpacked
+            msg = msg_or_list
+        else:
+            idents, msg_list = self.session.feed_identities(msg_or_list)
+            msg = self.session.deserialize(msg_list)
+        if channel:
+            msg["channel"] = channel
+        if msg["buffers"]:
+            buf = serialize_binary_message(msg)
+            return buf
+        else:
+            smsg = json.dumps(msg, default=json_default)
+            return cast_unicode(smsg)
+
+    def select_subprotocol(self, subprotocols):
+        selected_subprotocol = "0.0.1" if "0.0.1" in subprotocols else None
+        # None is the default, "legacy" protocol
+        return selected_subprotocol
+
     def _on_zmq_reply(self, stream, msg_list):
         # Sometimes this gets triggered when the on_close method is scheduled in the
         # eventloop but hasn't been called.
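For reference, a minimal, self-contained sketch of the framing these two functions implement: a 4-byte big-endian part count, one 4-byte big-endian offset per part (measured from the start of the whole frame, header included), then the parts back to back. The `frame`/`unframe` names are illustrative only; the real functions additionally restore dates and rebuild the message dict.

```python
import json
import struct

def frame(parts):
    nbufs = len(parts)
    offsets = [4 * (nbufs + 1)]  # first part starts right after the header
    for part in parts[:-1]:
        offsets.append(offsets[-1] + len(part))
    header = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets)
    return header + b"".join(parts)

def unframe(blob):
    nbufs = struct.unpack("!I", blob[:4])[0]
    offsets = list(struct.unpack("!" + "I" * nbufs, blob[4 : 4 * (nbufs + 1)]))
    offsets.append(len(blob))  # sentinel so the last slice runs to the end
    return [blob[a:b] for a, b in zip(offsets[:-1], offsets[1:])]

body = json.dumps({"msg_type": "example"}).encode("utf8")
buf = b"\x00\x01\x02"
assert unframe(frame([body, buf])) == [body, buf]
```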
@@ -163,6 +252,22 @@ def _on_zmq_reply(self, stream, msg_list):
             self.close()
             return
         channel = getattr(stream, "channel", None)
+        try:
+            msg = self._reserialize_reply(msg_list, channel=channel)
+        except Exception:
+            self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
+        else:
+            self.write_message(msg, binary=isinstance(msg, bytes))
+
+    def _on_zmq_reply_0_0_1(self, stream, msg_list):
+        # Sometimes this gets triggered when the on_close method is scheduled in the
+        # eventloop but hasn't been called.
+        if self.ws_connection is None or stream.closed():
+            self.log.warning("zmq message arrived on closed channel")
+            self.close()
+            return
+
+        channel = getattr(stream, "channel", None)
         offsets = []
         curr_sum = 0
         for msg in msg_list:
diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py
index c7927e3eb8..23851c1970 100644
--- a/jupyter_server/services/kernels/handlers.py
+++ b/jupyter_server/services/kernels/handlers.py
@@ -22,6 +22,7 @@
 from ...base.handlers import APIHandler
 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler
+from ...base.zmqhandlers import deserialize_binary_message
 from jupyter_server.utils import ensure_async
 from jupyter_server.utils import url_escape
 from jupyter_server.utils import url_path_join
@@ -457,9 +458,33 @@ def on_message(self, msg):
             # already closed, ignore the message
             self.log.debug("Received message on closed websocket %r", msg)
             return
+
+        if self.selected_subprotocol == "0.0.1":
+            return self.on_message_0_0_1(msg)
+
+        if isinstance(msg, bytes):
+            msg = deserialize_binary_message(msg)
+        else:
+            msg = json.loads(msg)
+        channel = msg.pop("channel", None)
+        if channel is None:
+            self.log.warning("No channel specified, assuming shell: %s", msg)
+            channel = "shell"
+        if channel not in self.channels:
+            self.log.warning("No such channel: %r", channel)
+            return
+        am = self.kernel_manager.allowed_message_types
+        mt = msg["header"]["msg_type"]
+        if am and mt not in am:
+            self.log.warning('Received message of type "%s", which is not allowed. Ignoring.' % mt)
+        else:
+            stream = self.channels[channel]
+            self.session.send(stream, msg)
+
+    def on_message_0_0_1(self, msg):
         layout_len = int.from_bytes(msg[:2], "little")
         layout = json.loads(msg[2:2 + layout_len])
-        msg_list = list(get_msg_list(msg[2 + layout_len:], layout["offsets"]))
+        msg_list = list(get_msg_list_from_zmq(msg[2 + layout_len:], layout["offsets"]))
         channel = layout["channel"]
         if channel not in self.channels:
             self.log.warning("No such channel: %r", channel)
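The v0.0.1 frame that `on_message_0_0_1` parses has a different shape: two little-endian bytes giving the length of a JSON layout header, the layout itself (`{"channel": ..., "offsets": [...]}` with cumulative end positions of each part), then the raw message parts concatenated. A sketch with hypothetical helper names, mirroring the slicing in `get_msg_list_from_zmq` further down:

```python
import json

def build_frame_0_0_1(channel, parts):
    offsets, total = [], 0
    for part in parts:
        total += len(part)
        offsets.append(total)  # cumulative end of each part
    layout = json.dumps({"channel": channel, "offsets": offsets}).encode("utf-8")
    return len(layout).to_bytes(2, "little") + layout + b"".join(parts)

def parse_frame_0_0_1(frame):
    layout_len = int.from_bytes(frame[:2], "little")
    layout = json.loads(frame[2 : 2 + layout_len])
    body, i0, parts = frame[2 + layout_len :], 0, []
    for i1 in layout["offsets"]:
        parts.append(body[i0:i1])
        i0 = i1
    parts.append(body[i0:])  # remainder after the last offset
    return layout["channel"], parts

channel, parts = parse_frame_0_0_1(build_frame_0_0_1("shell", [b"header", b"content"]))
assert channel == "shell"
assert parts == [b"header", b"content", b""]  # trailing slice would carry binary buffers
```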
@@ -489,6 +514,139 @@ def get_msg_field(self, field, value, msg_list):
 
     def _on_zmq_reply(self, stream, msg_list):
+        if self.selected_subprotocol == "0.0.1":
+            return self._on_zmq_reply_0_0_1(stream, msg_list)
+
+        idents, fed_msg_list = self.session.feed_identities(msg_list)
+        msg = self.session.deserialize(fed_msg_list)
+
+        parent = msg["parent_header"]
+
+        def write_stderr(error_message):
+            self.log.warning(error_message)
+            msg = self.session.msg(
+                "stream", content={"text": error_message + "\n", "name": "stderr"}, parent=parent
+            )
+            msg["channel"] = "iopub"
+            self.write_message(json.dumps(msg, default=json_default))
+
+        channel = getattr(stream, "channel", None)
+        msg_type = msg["header"]["msg_type"]
+
+        if channel == "iopub" and msg_type == "error":
+            self._on_error(msg)
+
+        if (
+            channel == "iopub"
+            and msg_type == "status"
+            and msg["content"].get("execution_state") == "idle"
+        ):
+            # reset rate limit counter on status=idle,
+            # to avoid 'Run All' hitting limits prematurely.
+            self._iopub_window_byte_queue = []
+            self._iopub_window_msg_count = 0
+            self._iopub_window_byte_count = 0
+            self._iopub_msgs_exceeded = False
+            self._iopub_data_exceeded = False
+
+        if channel == "iopub" and msg_type not in {"status", "comm_open", "execute_input"}:
+
+            # Remove the counts queued for removal.
+            now = IOLoop.current().time()
+            while len(self._iopub_window_byte_queue) > 0:
+                queued = self._iopub_window_byte_queue[0]
+                if now >= queued[0]:
+                    self._iopub_window_byte_count -= queued[1]
+                    self._iopub_window_msg_count -= 1
+                    del self._iopub_window_byte_queue[0]
+                else:
+                    # This part of the queue hasn't been reached yet, so we can
+                    # abort the loop.
+                    break
+
+            # Increment the bytes and message count
+            self._iopub_window_msg_count += 1
+            if msg_type == "stream":
+                byte_count = sum([len(x) for x in msg_list])
+            else:
+                byte_count = 0
+            self._iopub_window_byte_count += byte_count
+
+            # Queue a removal of the byte and message count for a time in the
+            # future, when we are no longer interested in it.
+            self._iopub_window_byte_queue.append((now + self.rate_limit_window, byte_count))
+
+            # Check the limits, set the limit flags, and reset the
+            # message and data counts.
+            msg_rate = float(self._iopub_window_msg_count) / self.rate_limit_window
+            data_rate = float(self._iopub_window_byte_count) / self.rate_limit_window
+
+            # Check the msg rate
+            if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit:
+                if not self._iopub_msgs_exceeded:
+                    self._iopub_msgs_exceeded = True
+                    write_stderr(
+                        dedent(
+                            """\
+                    IOPub message rate exceeded.
+                    The Jupyter server will temporarily stop sending output
+                    to the client in order to avoid crashing it.
+                    To change this limit, set the config variable
+                    `--ServerApp.iopub_msg_rate_limit`.
+
+                    Current values:
+                    ServerApp.iopub_msg_rate_limit={} (msgs/sec)
+                    ServerApp.rate_limit_window={} (secs)
+                    """.format(
+                                self.iopub_msg_rate_limit, self.rate_limit_window
+                            )
+                        )
+                    )
+            else:
+                # resume once we've got some headroom below the limit
+                if self._iopub_msgs_exceeded and msg_rate < (0.8 * self.iopub_msg_rate_limit):
+                    self._iopub_msgs_exceeded = False
+                    if not self._iopub_data_exceeded:
+                        self.log.warning("iopub messages resumed")
+
+            # Check the data rate
+            if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit:
+                if not self._iopub_data_exceeded:
+                    self._iopub_data_exceeded = True
+                    write_stderr(
+                        dedent(
+                            """\
+                    IOPub data rate exceeded.
+                    The Jupyter server will temporarily stop sending output
+                    to the client in order to avoid crashing it.
+                    To change this limit, set the config variable
+                    `--ServerApp.iopub_data_rate_limit`.
+
+                    Current values:
+                    ServerApp.iopub_data_rate_limit={} (bytes/sec)
+                    ServerApp.rate_limit_window={} (secs)
+                    """.format(
+                                self.iopub_data_rate_limit, self.rate_limit_window
+                            )
+                        )
+                    )
+            else:
+                # resume once we've got some headroom below the limit
+                if self._iopub_data_exceeded and data_rate < (0.8 * self.iopub_data_rate_limit):
+                    self._iopub_data_exceeded = False
+                    if not self._iopub_msgs_exceeded:
+                        self.log.warning("iopub messages resumed")
+
+            # If either of the limit flags are set, do not send the message.
+            if self._iopub_msgs_exceeded or self._iopub_data_exceeded:
+                # we didn't send it, remove the current message from the calculus
+                self._iopub_window_msg_count -= 1
+                self._iopub_window_byte_count -= byte_count
+                self._iopub_window_byte_queue.pop(-1)
+                return
+        super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg)
+
+    def _on_zmq_reply_0_0_1(self, stream, msg_list):
         idents, fed_msg_list = self.session.feed_identities(msg_list)
 
         # parse only what is needed for now
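The accounting above is a sliding-window rate limiter: each counted IOPub message enqueues an (expiry time, byte count) pair, expired pairs are drained on the next message, and the rates are the window totals divided by `rate_limit_window`. A condensed, self-contained sketch of that scheme; the `IOPubWindow` name and limit values are illustrative, and the real handler adds the 0.8x resume hysteresis and stderr warnings shown here:

```python
import time
from collections import deque

class IOPubWindow:
    """Sliding-window message/byte accounting, as in _on_zmq_reply above."""

    def __init__(self, window_secs, msg_rate_limit, data_rate_limit):
        self.window = window_secs
        self.msg_limit = msg_rate_limit
        self.data_limit = data_rate_limit
        self.queue = deque()  # one (expiry_time, byte_count) entry per message
        self.msg_count = 0
        self.byte_count = 0

    def allow(self, nbytes, now=None):
        now = time.monotonic() if now is None else now
        # Drop entries whose expiry time has passed.
        while self.queue and now >= self.queue[0][0]:
            _, expired = self.queue.popleft()
            self.msg_count -= 1
            self.byte_count -= expired
        # Count this message, then check both rates over the window.
        self.queue.append((now + self.window, nbytes))
        self.msg_count += 1
        self.byte_count += nbytes
        if (self.msg_count / self.window > self.msg_limit
                or self.byte_count / self.window > self.data_limit):
            # Over a limit: roll the message back out, as the handler does.
            self.queue.pop()
            self.msg_count -= 1
            self.byte_count -= nbytes
            return False
        return True

window = IOPubWindow(window_secs=3, msg_rate_limit=1000, data_rate_limit=1_000_000)
assert window.allow(512)
```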
@@ -498,15 +656,13 @@ def _on_zmq_reply(self, stream, msg_list):
         content = None
         parent_header = None
 
-        def write_stderr(error_message):
+        def write_stderr(error_message, parent_header):
             self.log.warning(error_message)
-            # parent_header = self.get_msg_field("parent_header", fed_msg_list)
-            # msg = self.session.msg(
-            #     "stream", content={"text": error_message + "\n", "name": "stderr"}, parent=parent_header
-            # )
-            # msg["channel"] = "iopub"
-            # self.write_message(json.dumps(msg, default=json_default))
-            # FIXME: write this message
+            msg = self.session.msg(
+                "stream", content={"text": error_message + "\n", "name": "stderr"}, parent=parent_header
+            )
+            bin_msg = serialize_msg_to_ws(msg, "iopub", self.session.pack)
+            self.write_message(bin_msg, binary=True)
 
         channel = getattr(stream, "channel", None)
 
@@ -572,6 +728,7 @@ def write_stderr(error_message):
             if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit:
                 if not self._iopub_msgs_exceeded:
                     self._iopub_msgs_exceeded = True
+                    parent_header = self.get_msg_field("parent_header", parent_header, fed_msg_list)
                     write_stderr(
                         dedent(
                             """\
@@ -587,7 +744,8 @@ def write_stderr(error_message):
                     """.format(
                                 self.iopub_msg_rate_limit, self.rate_limit_window
                             )
-                        )
+                        ),
+                        parent_header
                     )
             else:
                 # resume once we've got some headroom below the limit
@@ -600,6 +758,7 @@ def write_stderr(error_message):
             if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit:
                 if not self._iopub_data_exceeded:
                     self._iopub_data_exceeded = True
+                    parent_header = self.get_msg_field("parent_header", parent_header, fed_msg_list)
                     write_stderr(
                         dedent(
                             """\
@@ -615,7 +774,8 @@ def write_stderr(error_message):
                     """.format(
                                 self.iopub_data_rate_limit, self.rate_limit_window
                             )
-                        )
+                        ),
+                        parent_header
                    )
             else:
                 # resume once we've got some headroom below the limit
@@ -631,7 +791,7 @@ def write_stderr(error_message):
                 self._iopub_window_byte_count -= byte_count
                 self._iopub_window_byte_queue.pop(-1)
                 return
-        super(ZMQChannelsHandler, self)._on_zmq_reply(stream, fed_msg_list[1:])
+        super(ZMQChannelsHandler, self)._on_zmq_reply_0_0_1(stream, fed_msg_list[1:])
 
     def close(self):
         super(ZMQChannelsHandler, self).close()
@@ -680,10 +840,13 @@ def _send_status_message(self, status):
             # ensures proper ordering on the IOPub channel
             # that all messages from the stopped kernel have been delivered
             iopub.flush()
-        # msg = self.session.msg("status", {"execution_state": status})
-        # msg["channel"] = "iopub"
-        # self.write_message(json.dumps(msg, default=json_default))
-        # FIXME: write this message
+        msg = self.session.msg("status", {"execution_state": status})
+        if not self.selected_subprotocol:
+            msg["channel"] = "iopub"
+            self.write_message(json.dumps(msg, default=json_default))
+        elif self.selected_subprotocol == "0.0.1":
+            bin_msg = serialize_msg_to_ws(msg, "iopub", self.session.pack)
+            self.write_message(bin_msg, binary=True)
 
     def on_kernel_restarted(self):
         self.log.warning("kernel %s restarted", self.kernel_id)
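On the client side the new protocol is opt-in: the browser offers "0.0.1" during the WebSocket handshake, and `select_subprotocol` either accepts it or answers `None` for the legacy JSON protocol. A sketch of that negotiation with a Tornado client; the URL, kernel id, and token are placeholders:

```python
from tornado.ioloop import IOLoop
from tornado.websocket import websocket_connect

async def main():
    # Placeholder URL: fill in the real host, kernel id, and token.
    url = "ws://localhost:8888/api/kernels/KERNEL_ID/channels?token=TOKEN"
    conn = await websocket_connect(url, subprotocols=["0.0.1"])
    # None means the server fell back to the legacy JSON protocol.
    print("negotiated subprotocol:", conn.selected_subprotocol)

IOLoop.current().run_sync(main)
```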
@@ -693,14 +856,42 @@ def on_restart_failed(self):
         self.log.error("kernel %s restarted failed!", self.kernel_id)
         self._send_status_message("dead")
 
+    def _on_error(self, msg):
+        if self.kernel_manager.allow_tracebacks:
+            return
+        msg["content"]["ename"] = "ExecutionError"
+        msg["content"]["evalue"] = "Execution error"
+        msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message]
+
 
-def get_msg_list(msg, offsets):
+def get_msg_list_from_zmq(msg, offsets):
     i0 = 0
     for i1 in offsets:
         yield msg[i0:i1]
         i0 = i1
     yield msg[i0:]
 
+
+def serialize_msg_to_ws(msg, channel, pack):
+    offsets = []
+    curr_sum = 0
+    parts = [
+        pack(msg["header"]),
+        pack(msg["parent_header"]),
+        pack(msg["metadata"]),
+        pack(msg["content"]),
+    ]
+    for part in parts:
+        length = len(part)
+        offsets.append(length + curr_sum)
+        curr_sum += length
+    layout = json.dumps({
+        "channel": channel,
+        "offsets": offsets,
+    }).encode("utf-8")
+    layout_length = len(layout).to_bytes(2, byteorder="little")
+    bin_msg = b"".join([layout_length, layout] + parts)
+    return bin_msg
+
+
 # -----------------------------------------------------------------------------
 # URL to handler mappings
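To see how `serialize_msg_to_ws` and `get_msg_list_from_zmq` fit together, a small round trip, assuming the patched module is importable and using a `json.dumps`-based `pack` as a stand-in for the `self.session.pack` the handler actually passes:

```python
import json

from jupyter_server.services.kernels.handlers import (
    get_msg_list_from_zmq,
    serialize_msg_to_ws,
)

def pack(obj):
    # Stand-in for self.session.pack, which is JSON-based by default.
    return json.dumps(obj).encode("utf-8")

msg = {
    "header": {"msg_type": "status"},
    "parent_header": {},
    "metadata": {},
    "content": {"execution_state": "idle"},
}

frame = serialize_msg_to_ws(msg, "iopub", pack)

# Peel off the layout, then slice the parts exactly as on_message_0_0_1 does.
layout_len = int.from_bytes(frame[:2], "little")
layout = json.loads(frame[2 : 2 + layout_len])
parts = list(get_msg_list_from_zmq(frame[2 + layout_len :], layout["offsets"]))

assert layout["channel"] == "iopub"
assert json.loads(parts[0]) == {"msg_type": "status"}  # header round-trips
assert parts[-1] == b""  # trailing slice is where binary buffers would go
```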