diff --git a/pynq/utils.py b/pynq/utils.py index 0adc836eb1..d095072d58 100755 --- a/pynq/utils.py +++ b/pynq/utils.py @@ -747,7 +747,7 @@ def _create_code(num): return _function_text + call_line -def run_notebook(notebook, root_path=".", timeout=30): +def run_notebook(notebook, root_path=".", timeout=30, prerun=None): """Run a notebook in Jupyter This function will copy all of the files in ``root_path`` to a @@ -766,6 +766,9 @@ def run_notebook(notebook, root_path=".", timeout=30): The root notebook folder (default ".") timeout : int Length of time to run the notebook in seconds (default 30) + prerun : function + Function to run prior to starting the notebook, takes the + temporary copy of root_path as a parameter """ import nbformat @@ -774,6 +777,8 @@ def run_notebook(notebook, root_path=".", timeout=30): workdir = os.path.join(td, 'work') notebook_dir = os.path.join(workdir, os.path.dirname(notebook)) shutil.copytree(root_path, workdir) + if prerun is not None: + prerun(workdir) fullpath = os.path.join(workdir, notebook) with open(fullpath, "r") as f: nb = nbformat.read(f, as_version=4)