diff --git a/nbrmd/contentsmanager.py b/nbrmd/contentsmanager.py index a291bbc79..c01acc2ac 100644 --- a/nbrmd/contentsmanager.py +++ b/nbrmd/contentsmanager.py @@ -33,21 +33,44 @@ def _reads(text, as_version, **kwargs): def check_formats(formats): """ Parse, validate and return notebooks extensions - :param formats: a list of notebook extensions, or a comma separated string - :return: list of extensions + :param formats: a list of lists of notebook extensions, + or a semicolon separated string of extension groups, comma separated + :return: list of lists (groups) of notebook extensions """ - if not isinstance(formats, list): - formats = formats.split(',') - - formats = [fmt if fmt.startswith('.') else '.' + fmt - for fmt in formats if fmt != ''] - allowed = nbrmd.NOTEBOOK_EXTENSIONS + ['.nb.py', '.nb.R'] - if not isinstance(formats, list) or not set(formats).issubset(allowed): - raise TypeError("Notebook metadata 'nbrmd_formats' " - "should be subset of {}, but was {}" - "".format(str(allowed), str(formats))) - return formats + # Parse formats represented as strings + if not isinstance(formats, list): + formats = [group.split(',') for group in formats.split(';')] + + expected_format = ("Notebook metadata 'nbrmd_formats' should " + "be a list of extension groups, like 'ipynb,Rmd'. " + "Groups can be separated with colon, for instance: " + "'ipynb,nb.py;script.ipynb,py'") + + validated_formats = [] + for group in formats: + if not isinstance(group, list): + raise TypeError('Group extension should be a list, but was {}.\n{}' + .format(str(group), expected_format)) + validated_group = [] + for fmt in group: + if fmt == '': + continue + if not fmt.startswith('.'): + fmt = '.' 
+ fmt + if not any([fmt.endswith(ext) + for ext in nbrmd.NOTEBOOK_EXTENSIONS]): + raise ValueError('Group extension {} contains {}, ' + 'which does not end with either {}.\n{}' + .format(str(group), fmt, + str(nbrmd.NOTEBOOK_EXTENSIONS), + expected_format)) + validated_group.append(fmt) + + if validated_group: + validated_formats.append(validated_group) + + return validated_formats def file_fmt_ext(path): @@ -85,6 +108,23 @@ def all_nb_extensions(self): 'Can be any of ipynb,Rmd,py,R,nb.py,nb.R comma separated', config=True) + def format_group(self, fmt, nb=None): + """Return the group of extensions that contains 'fmt'""" + nbrmd_formats = ((nb.metadata.get('nbrmd_formats') if nb else None) + or self.default_nbrmd_formats) + + nbrmd_formats = check_formats(nbrmd_formats) + + # Find group that contains the current format + for group in nbrmd_formats: + if fmt in group: + return group + + if ['.ipynb'] in nbrmd_formats: + return [fmt, '.ipynb'] + + return [fmt] + def _read_notebook(self, os_path, as_version=4, load_alternative_format=True): """Read a notebook from an os path.""" @@ -100,51 +140,42 @@ def _read_notebook(self, os_path, as_version=4, if not load_alternative_format: return nb - ext = fmt + fmt_group = self.format_group(fmt, nb) - # Notebook formats: default, notebook metadata, or current extension - nbrmd_formats = (nb.metadata.get('nbrmd_formats') or - self.default_nbrmd_formats) - - nbrmd_formats = check_formats(nbrmd_formats) - - if ext not in nbrmd_formats: - nbrmd_formats.append(ext) - - nbrmd_formats = check_formats(nbrmd_formats) + source_format = fmt + outputs_format = fmt - # Source format is current ext, or is first non .ipynb format - # that is found on disk - source_format = None - if ext != '.ipynb': - source_format = ext + # Source format is first non ipynb format found on disk + if fmt.endswith('.ipynb'): + for alt_fmt in fmt_group: + if not alt_fmt.endswith('.ipynb') and \ + os.path.isfile(file + alt_fmt): + source_format = alt_fmt + break 
+ # Outputs taken from ipynb if in group else: - for fmt in nbrmd_formats: - if fmt != '.ipynb' and os.path.isfile(file + fmt): - source_format = fmt + for alt_fmt in fmt_group: + if alt_fmt.endswith('.ipynb'): + outputs_format = alt_fmt break - nb_outputs = None - if source_format is not None and ext != source_format: + if source_format == outputs_format: + self.log.info('Reading {}'.format(os.path.basename(os_path))) + nb = self._read_notebook(file + source_format, + as_version=as_version, + load_alternative_format=False) + else: self.log.info('Reading SOURCE from {}' .format(os.path.basename(file + source_format))) - self.log.info('Reading OUTPUTS from {}' - .format(os.path.basename(os_path))) - nb_outputs = nb nb = self._read_notebook(file + source_format, as_version=as_version, load_alternative_format=False) - elif ext != '.ipynb' and '.ipynb' in nbrmd_formats \ - and os.path.isfile(file + '.ipynb'): - self.log.info('Reading SOURCE from {}' - .format(os.path.basename(os_path))) self.log.info('Reading OUTPUTS from {}' - .format(os.path.basename(file + '.ipynb'))) - nb_outputs = self._read_notebook(file + '.ipynb', + .format(os.path.basename(file + outputs_format))) + nb_outputs = self._read_notebook(file + outputs_format, as_version=as_version, load_alternative_format=False) - if nb_outputs is not None: combine.combine_inputs_with_outputs(nb, nb_outputs) if self.notary.check_signature(nb_outputs): self.notary.sign(nb) @@ -153,24 +184,13 @@ def _read_notebook(self, os_path, as_version=4, def _save_notebook(self, os_path, nb): """Save a notebook to an os_path.""" - os_file, org_fmt, org_ext = file_fmt_ext(os_path) - - formats = (nb.get('metadata', {}).get('nbrmd_formats') or - self.default_nbrmd_formats) - - formats = check_formats(formats) - - if org_fmt not in formats: - formats.append(org_fmt) - - formats = check_formats(formats) - - for fmt in formats: - os_path_fmt = os_file + fmt + os_file, fmt, _ = file_fmt_ext(os_path) + for alt_fmt in 
self.format_group(fmt, nb): + os_path_fmt = os_file + alt_fmt self.log.info("Saving %s", os.path.basename(os_path_fmt)) - ext = fmt.replace('.nb.', '.') - if ext in self.nb_extensions: - with mock.patch('nbformat.writes', _nbrmd_writes(ext)): + alt_ext = '.' + alt_fmt.split('.')[-1] + if alt_ext in self.nb_extensions: + with mock.patch('nbformat.writes', _nbrmd_writes(alt_ext)): super(RmdFileContentsManager, self) \ ._save_notebook(os_path_fmt, nb) else: @@ -193,18 +213,30 @@ def get(self, path, content=True, type=None, format=None): def trust_notebook(self, path): """Trust the current notebook""" - file, _ = os.path.splitext(path) - super(RmdFileContentsManager, self).trust_notebook(file + '.ipynb') + file, fmt, ext = file_fmt_ext(path) + for alt_fmt in self.format_group(fmt): + if alt_fmt.endswith('.ipynb'): + super(RmdFileContentsManager, self).trust_notebook(file + + alt_fmt) def rename_file(self, old_path, new_path): """Rename the current notebook, as well as its alternative representations""" - old_file, org_fmt, org_ext = file_fmt_ext(old_path) - new_file, new_fmt, new_ext = file_fmt_ext(new_path) - if org_ext in self.all_nb_extensions() and org_ext == new_ext: - for fmt in self.all_nb_extensions() + ['.nb.py', '.nb.R']: - if self.file_exists(old_file + fmt): + old_file, org_fmt, _ = file_fmt_ext(old_path) + new_file, new_fmt, _ = file_fmt_ext(new_path) + nbrmd_formats = check_formats(self.default_nbrmd_formats) + + if org_fmt == new_fmt: + # Find group that contains the current format + fmt_group = [] + for group in nbrmd_formats: + if org_fmt in group: + fmt_group = group + break + + for alt_fmt in fmt_group: + if self.file_exists(old_file + alt_fmt): super(RmdFileContentsManager, self) \ - .rename_file(old_file + fmt, new_file + fmt) + .rename_file(old_file + alt_fmt, new_file + alt_fmt) else: super(RmdFileContentsManager, self).rename_file(old_path, new_path) diff --git a/tests/test_contentsmanager.py b/tests/test_contentsmanager.py index 
5879747a9..647a22cae 100644 --- a/tests/test_contentsmanager.py +++ b/tests/test_contentsmanager.py @@ -21,6 +21,7 @@ def test_load_save_rename(nb_file, tmpdir): tmp_rmd = 'notebook.Rmd' cm = RmdFileContentsManager() + cm.default_nbrmd_formats = 'ipynb,Rmd' cm.root_dir = str(tmpdir) # open ipynb, save Rmd, reopen @@ -51,6 +52,7 @@ def test_load_save_rename_nbpy(nb_file, tmpdir): tmp_nbpy = 'notebook.nb.py' cm = RmdFileContentsManager() + cm.default_nbrmd_formats = 'ipynb,nb.py' cm.root_dir = str(tmpdir) # open ipynb, save nb.py, reopen diff --git a/tests/test_save_multiple.py b/tests/test_save_multiple.py index 5ceffc3dd..6f56c4664 100644 --- a/tests/test_save_multiple.py +++ b/tests/test_save_multiple.py @@ -15,7 +15,7 @@ def test_rmd_is_ok(nb_file, tmpdir): tmp_ipynb = 'notebook.ipynb' tmp_rmd = 'notebook.Rmd' - nb.metadata['nbrmd_formats'] = ['.Rmd'] + nb.metadata['nbrmd_formats'] = 'ipynb,Rmd' cm = RmdFileContentsManager() cm.root_dir = str(tmpdir) @@ -53,7 +53,7 @@ def test_all_files_created(nb_file, tmpdir): tmp_ipynb = 'notebook.ipynb' tmp_rmd = 'notebook.Rmd' tmp_py = 'notebook.py' - nb.metadata['nbrmd_formats'] = ['.ipynb', '.Rmd', '.py'] + nb.metadata['nbrmd_formats'] = 'ipynb,Rmd,py' cm = RmdFileContentsManager() cm.root_dir = str(tmpdir) @@ -109,7 +109,7 @@ def test_no_rmd_on_not_notebook(tmpdir): cm = RmdFileContentsManager() cm.root_dir = str(tmpdir) - cm.default_nbrmd_formats = '.Rmd' + cm.default_nbrmd_formats = 'ipynb,Rmd' with pytest.raises(HTTPError): cm.save(model=dict(type='not notebook', @@ -124,7 +124,7 @@ def test_no_rmd_on_not_v4(tmpdir): cm = RmdFileContentsManager() cm.root_dir = str(tmpdir) - cm.default_nbrmd_formats = '.Rmd' + cm.default_nbrmd_formats = 'ipynb,Rmd' with pytest.raises(NotebookValidationError): cm.save(model=dict(type='notebook',