diff --git a/nbconvert/preprocessors/execute.py b/nbconvert/preprocessors/execute.py index 01c8a8120..a4e01c4d3 100644 --- a/nbconvert/preprocessors/execute.py +++ b/nbconvert/preprocessors/execute.py @@ -25,6 +25,15 @@ from ..utils.exceptions import ConversionException +class CellExecutionComplete(Exception): + """ + Used as a control signal for cell execution across run_cell and + process_message function calls. Raised when all execution requests + are completed and no further messages are expected from the kernel + over zeromq channels. + """ + pass + class CellExecutionError(ConversionException): """ Custom exception to propagate exceptions that are raised during @@ -401,13 +410,14 @@ def preprocess_cell(self, cell, resources, cell_index): return cell, resources reply, outputs = self.run_cell(cell, cell_index) + # Backwards compatability for processes that wrap run_cell cell.outputs = outputs cell_allows_errors = (self.allow_errors or "raises-exception" in cell.metadata.get("tags", [])) if self.force_raise_errors or not cell_allows_errors: - for out in outputs: + for out in cell.outputs: if out.output_type == 'error': raise CellExecutionError.from_cell_and_msg(cell, out) if (reply is not None) and reply['content']['status'] == 'error': @@ -488,7 +498,7 @@ def run_cell(self, cell, cell_index=0): self.log.debug("Executing cell:\n%s", cell.source) exec_reply = self._wait_for_reply(parent_msg_id, cell) - outs = cell.outputs = [] + cell.outputs = [] self.clear_before_next_output = False while True: @@ -509,62 +519,93 @@ def run_cell(self, cell, cell_index=0): # not an output from our execution continue - msg_type = msg['msg_type'] - self.log.debug("output: %s", msg_type) - content = msg['content'] + # Will raise CellExecutionComplete when completed + try: + self.process_message(msg, cell, cell_index) + except CellExecutionComplete: + break - # set the prompt number for the input and the output - if 'execution_count' in content: - cell['execution_count'] = content['execution_count'] + # Return cell.outputs still for backwards compatability + return exec_reply, cell.outputs - if msg_type == 'status': - if content['execution_state'] == 'idle': - break - else: - continue - elif msg_type == 'execute_input': - continue - elif msg_type == 'clear_output': - self.clear_output(outs, msg, cell_index) - continue - elif msg_type.startswith('comm'): - self.handle_comm_msg(outs, msg, cell_index) - continue + def process_message(self, msg, cell, cell_index): + """ + Processes a kernel message, updates cell state, and returns the + resulting output object that was appended to cell.outputs. + + The input argument `cell` is modified in-place. - display_id = None - if msg_type in {'execute_result', 'display_data', 'update_display_data'}: - display_id = msg['content'].get('transient', {}).get('display_id', None) - if display_id: - self._update_display_id(display_id, msg) - if msg_type == 'update_display_data': - # update_display_data doesn't get recorded - continue + Parameters + ---------- + msg : dict + The kernel message being processed. + cell : nbformat.NotebookNode + The cell which is currently being processed. + cell_index : int + The position of the cell within the notebook object. - self.output(outs, msg, display_id, cell_index) + Returns + ------- + output : dict + The execution output payload (or None for no output). - return exec_reply, outs + Raises + ------ + CellExecutionComplete + Once a message arrives which indicates computation completeness. + + """ + msg_type = msg['msg_type'] + self.log.debug("msg_type: %s", msg_type) + content = msg['content'] + self.log.debug("content: %s", content) + + display_id = content.get('transient', {}).get('display_id', None) + if display_id and msg_type in {'execute_result', 'display_data', 'update_display_data'}: + self._update_display_id(display_id, msg) + + # set the prompt number for the input and the output + if 'execution_count' in content: + cell['execution_count'] = content['execution_count'] + + if msg_type == 'status': + if content['execution_state'] == 'idle': + raise CellExecutionComplete() + elif msg_type == 'clear_output': + self.clear_output(cell.outputs, msg, cell_index) + elif msg_type.startswith('comm'): + self.handle_comm_msg(cell.outputs, msg, cell_index) + # Check for remaining messages we don't process + elif msg_type not in ['execute_input', 'update_display_data']: + # Assign output as our processed "result" + return self.output(cell.outputs, msg, display_id, cell_index) def output(self, outs, msg, display_id, cell_index): msg_type = msg['msg_type'] - if self.clear_before_next_output: - self.log.debug('Executing delayed clear_output') - outs[:] = [] - self.clear_display_id_mapping(cell_index) - self.clear_before_next_output = False try: out = output_from_msg(msg) except ValueError: self.log.error("unhandled iopub msg: " + msg_type) return + + if self.clear_before_next_output: + self.log.debug('Executing delayed clear_output') + outs[:] = [] + self.clear_display_id_mapping(cell_index) + self.clear_before_next_output = False + if display_id: # record output index in: # _display_id_map[display_id][cell_idx] cell_map = self._display_id_map.setdefault(display_id, {}) output_idx_list = cell_map.setdefault(cell_index, []) output_idx_list.append(len(outs)) + outs.append(out) + return out + def clear_output(self, outs, msg, cell_index): content = msg['content'] if content.get('wait'): diff --git a/nbconvert/preprocessors/tests/test_execute.py b/nbconvert/preprocessors/tests/test_execute.py index 0ea7519a9..ad8bf82c7 100644 --- a/nbconvert/preprocessors/tests/test_execute.py +++ b/nbconvert/preprocessors/tests/test_execute.py @@ -23,6 +23,7 @@ from ..execute import ExecutePreprocessor, CellExecutionError, executenb import IPython +from mock import MagicMock from traitlets import TraitError from nbformat import NotebookNode from jupyter_client.kernelspec import KernelSpecManager @@ -51,7 +52,6 @@ def _normalize_base64(b64_text): except (ValueError, TypeError): return b64_text - class ExecuteTestBase(PreprocessorTestsBase): def build_preprocessor(self, opts): """Make an instance of a preprocessor""" @@ -185,18 +185,18 @@ def normalize_output(output): def assert_notebooks_equal(self, expected, actual): expected_cells = expected['cells'] actual_cells = actual['cells'] - self.assertEqual(len(expected_cells), len(actual_cells)) + assert len(expected_cells) == len(actual_cells) for expected_cell, actual_cell in zip(expected_cells, actual_cells): expected_outputs = expected_cell.get('outputs', []) actual_outputs = actual_cell.get('outputs', []) normalized_expected_outputs = list(map(self.normalize_output, expected_outputs)) normalized_actual_outputs = list(map(self.normalize_output, actual_outputs)) - self.assertEqual(normalized_expected_outputs, normalized_actual_outputs) + assert normalized_expected_outputs == normalized_actual_outputs expected_execution_count = expected_cell.get('execution_count', None) actual_execution_count = actual_cell.get('execution_count', None) - self.assertEqual(expected_execution_count, actual_execution_count) + assert expected_execution_count == actual_execution_count def test_constructor(self): @@ -395,6 +395,30 @@ def test_custom_kernel_manager(self): for method, call_count in expected: self.assertNotEqual(call_count, 0, '{} was called'.format(method)) + def test_process_message_wrapper(self): + outputs = [] + + class WrappedPreProc(ExecutePreprocessor): + def process_message(self, msg, cell, cell_index): + result = super(WrappedPreProc, self).process_message(msg, cell, cell_index) + if result: + outputs.append(result) + return result + + current_dir = os.path.dirname(__file__) + filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb') + + with io.open(filename) as f: + input_nb = nbformat.read(f, 4) + + original = copy.deepcopy(input_nb) + wpp = WrappedPreProc() + executed = wpp.preprocess(input_nb, {})[0] + assert outputs == [ + {'name': 'stdout', 'output_type': 'stream', 'text': 'Hello World\n'} + ] + self.assert_notebooks_equal(original, executed) + def test_execute_function(self): # Test the executenb() convenience API filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb')