Skip to content

Commit

Permalink
Merge pull request #905 from MSeal/processMessage
Browse files Browse the repository at this point in the history
Refactored execute preprocessor to have a process_message function
  • Loading branch information
mpacer authored Apr 21, 2019
2 parents 7245701 + 2f9c697 commit 7502276
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 40 deletions.
113 changes: 77 additions & 36 deletions nbconvert/preprocessors/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
from ..utils.exceptions import ConversionException


class CellExecutionComplete(Exception):
"""
Used as a control signal for cell execution across run_cell and
process_message function calls. Raised when all execution requests
are completed and no further messages are expected from the kernel
over zeromq channels.
"""
pass

class CellExecutionError(ConversionException):
"""
Custom exception to propagate exceptions that are raised during
Expand Down Expand Up @@ -401,13 +410,14 @@ def preprocess_cell(self, cell, resources, cell_index):
return cell, resources

reply, outputs = self.run_cell(cell, cell_index)
# Backwards compatability for processes that wrap run_cell
cell.outputs = outputs

cell_allows_errors = (self.allow_errors or "raises-exception"
in cell.metadata.get("tags", []))

if self.force_raise_errors or not cell_allows_errors:
for out in outputs:
for out in cell.outputs:
if out.output_type == 'error':
raise CellExecutionError.from_cell_and_msg(cell, out)
if (reply is not None) and reply['content']['status'] == 'error':
Expand Down Expand Up @@ -488,7 +498,7 @@ def run_cell(self, cell, cell_index=0):
self.log.debug("Executing cell:\n%s", cell.source)
exec_reply = self._wait_for_reply(parent_msg_id, cell)

outs = cell.outputs = []
cell.outputs = []
self.clear_before_next_output = False

while True:
Expand All @@ -509,62 +519,93 @@ def run_cell(self, cell, cell_index=0):
# not an output from our execution
continue

msg_type = msg['msg_type']
self.log.debug("output: %s", msg_type)
content = msg['content']
# Will raise CellExecutionComplete when completed
try:
self.process_message(msg, cell, cell_index)
except CellExecutionComplete:
break

# set the prompt number for the input and the output
if 'execution_count' in content:
cell['execution_count'] = content['execution_count']
# Return cell.outputs still for backwards compatability
return exec_reply, cell.outputs

if msg_type == 'status':
if content['execution_state'] == 'idle':
break
else:
continue
elif msg_type == 'execute_input':
continue
elif msg_type == 'clear_output':
self.clear_output(outs, msg, cell_index)
continue
elif msg_type.startswith('comm'):
self.handle_comm_msg(outs, msg, cell_index)
continue
def process_message(self, msg, cell, cell_index):
"""
Processes a kernel message, updates cell state, and returns the
resulting output object that was appended to cell.outputs.
The input argument `cell` is modified in-place.
display_id = None
if msg_type in {'execute_result', 'display_data', 'update_display_data'}:
display_id = msg['content'].get('transient', {}).get('display_id', None)
if display_id:
self._update_display_id(display_id, msg)
if msg_type == 'update_display_data':
# update_display_data doesn't get recorded
continue
Parameters
----------
msg : dict
The kernel message being processed.
cell : nbformat.NotebookNode
The cell which is currently being processed.
cell_index : int
The position of the cell within the notebook object.
self.output(outs, msg, display_id, cell_index)
Returns
-------
output : dict
The execution output payload (or None for no output).
return exec_reply, outs
Raises
------
CellExecutionComplete
Once a message arrives which indicates computation completeness.
"""
msg_type = msg['msg_type']
self.log.debug("msg_type: %s", msg_type)
content = msg['content']
self.log.debug("content: %s", content)

display_id = content.get('transient', {}).get('display_id', None)
if display_id and msg_type in {'execute_result', 'display_data', 'update_display_data'}:
self._update_display_id(display_id, msg)

# set the prompt number for the input and the output
if 'execution_count' in content:
cell['execution_count'] = content['execution_count']

if msg_type == 'status':
if content['execution_state'] == 'idle':
raise CellExecutionComplete()
elif msg_type == 'clear_output':
self.clear_output(cell.outputs, msg, cell_index)
elif msg_type.startswith('comm'):
self.handle_comm_msg(cell.outputs, msg, cell_index)
# Check for remaining messages we don't process
elif msg_type not in ['execute_input', 'update_display_data']:
# Assign output as our processed "result"
return self.output(cell.outputs, msg, display_id, cell_index)

def output(self, outs, msg, display_id, cell_index):
msg_type = msg['msg_type']
if self.clear_before_next_output:
self.log.debug('Executing delayed clear_output')
outs[:] = []
self.clear_display_id_mapping(cell_index)
self.clear_before_next_output = False

try:
out = output_from_msg(msg)
except ValueError:
self.log.error("unhandled iopub msg: " + msg_type)
return

if self.clear_before_next_output:
self.log.debug('Executing delayed clear_output')
outs[:] = []
self.clear_display_id_mapping(cell_index)
self.clear_before_next_output = False

if display_id:
# record output index in:
# _display_id_map[display_id][cell_idx]
cell_map = self._display_id_map.setdefault(display_id, {})
output_idx_list = cell_map.setdefault(cell_index, [])
output_idx_list.append(len(outs))

outs.append(out)

return out

def clear_output(self, outs, msg, cell_index):
content = msg['content']
if content.get('wait'):
Expand Down
32 changes: 28 additions & 4 deletions nbconvert/preprocessors/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..execute import ExecutePreprocessor, CellExecutionError, executenb

import IPython
from mock import MagicMock
from traitlets import TraitError
from nbformat import NotebookNode
from jupyter_client.kernelspec import KernelSpecManager
Expand Down Expand Up @@ -51,7 +52,6 @@ def _normalize_base64(b64_text):
except (ValueError, TypeError):
return b64_text


class ExecuteTestBase(PreprocessorTestsBase):
def build_preprocessor(self, opts):
"""Make an instance of a preprocessor"""
Expand Down Expand Up @@ -185,18 +185,18 @@ def normalize_output(output):
def assert_notebooks_equal(self, expected, actual):
expected_cells = expected['cells']
actual_cells = actual['cells']
self.assertEqual(len(expected_cells), len(actual_cells))
assert len(expected_cells) == len(actual_cells)

for expected_cell, actual_cell in zip(expected_cells, actual_cells):
expected_outputs = expected_cell.get('outputs', [])
actual_outputs = actual_cell.get('outputs', [])
normalized_expected_outputs = list(map(self.normalize_output, expected_outputs))
normalized_actual_outputs = list(map(self.normalize_output, actual_outputs))
self.assertEqual(normalized_expected_outputs, normalized_actual_outputs)
assert normalized_expected_outputs == normalized_actual_outputs

expected_execution_count = expected_cell.get('execution_count', None)
actual_execution_count = actual_cell.get('execution_count', None)
self.assertEqual(expected_execution_count, actual_execution_count)
assert expected_execution_count == actual_execution_count


def test_constructor(self):
Expand Down Expand Up @@ -395,6 +395,30 @@ def test_custom_kernel_manager(self):
for method, call_count in expected:
self.assertNotEqual(call_count, 0, '{} was called'.format(method))

def test_process_message_wrapper(self):
outputs = []

class WrappedPreProc(ExecutePreprocessor):
def process_message(self, msg, cell, cell_index):
result = super(WrappedPreProc, self).process_message(msg, cell, cell_index)
if result:
outputs.append(result)
return result

current_dir = os.path.dirname(__file__)
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')

with io.open(filename) as f:
input_nb = nbformat.read(f, 4)

original = copy.deepcopy(input_nb)
wpp = WrappedPreProc()
executed = wpp.preprocess(input_nb, {})[0]
assert outputs == [
{'name': 'stdout', 'output_type': 'stream', 'text': 'Hello World\n'}
]
self.assert_notebooks_equal(original, executed)

def test_execute_function(self):
# Test the executenb() convenience API
filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')
Expand Down

0 comments on commit 7502276

Please sign in to comment.