diff --git a/glotaran/project/project.py b/glotaran/project/project.py index c41778de6..617b36491 100644 --- a/glotaran/project/project.py +++ b/glotaran/project/project.py @@ -35,6 +35,8 @@ if TYPE_CHECKING: from collections.abc import Hashable + from glotaran.typing.types import LoadableDataset + TEMPLATE = "version: {gta_version}" PROJECT_FILE_NAME = "project.gta" @@ -203,7 +205,7 @@ def load_data( def import_data( self, - dataset: str | Path | xr.Dataset | xr.DataArray, + dataset: LoadableDataset | Mapping[str, LoadableDataset], dataset_name: str | None = None, allow_overwrite: bool = False, ignore_existing: bool = True, @@ -212,7 +214,7 @@ def import_data( Parameters ---------- - dataset : str | Path | xr.Dataset | xr.DataArray + dataset : LoadableDataset Dataset instance or path to a dataset. dataset_name : str | None The name of the dataset (needs to be provided when dataset is an xarray instance). @@ -222,12 +224,16 @@ def import_data( ignore_existing: bool Whether to ignore import if the dataset already exists. Defaults to ``True``. """ - self._data_registry.import_data( - dataset, - dataset_name=dataset_name, - allow_overwrite=allow_overwrite, - ignore_existing=ignore_existing, - ) + if not isinstance(dataset, Mapping) or isinstance(dataset, (xr.Dataset, xr.DataArray)): + dataset = {dataset_name: dataset} + + for key, value in dataset.items(): + self._data_registry.import_data( + value, + dataset_name=key, + allow_overwrite=allow_overwrite, + ignore_existing=ignore_existing, + ) @property def has_models(self) -> bool: diff --git a/glotaran/project/test/test_project.py b/glotaran/project/test/test_project.py index 67a1e4913..e0cdbeb70 100644 --- a/glotaran/project/test/test_project.py +++ b/glotaran/project/test/test_project.py @@ -16,6 +16,7 @@ from glotaran import __version__ as gta_version from glotaran.builtin.io.yml.utils import load_dict +from glotaran.io import load_dataset from glotaran.io import load_parameters from glotaran.io import save_dataset from glotaran.io import save_parameters @@ -288,6 +289,29 @@ def test_import_data_xarray(tmp_path: Path, data: xr.Dataset | xr.DataArray): assert project.load_data("test_data").equals(xr.Dataset({"data": xr.DataArray([1])})) +@pytest.mark.parametrize( + "data", + ( + xr.DataArray([1]), + xr.Dataset({"data": xr.DataArray([1])}), + ), +) +def test_import_data_mapping(tmp_path: Path, data: xr.Dataset | xr.DataArray): + """Import data as a mapping""" + project = Project.open(tmp_path) + + test_data = tmp_path / "import_data.nc" + save_dataset(example_dataset, test_data) + + project.import_data({"test_data_1": data, "test_data_2": test_data}) + + assert (tmp_path / "data/test_data_1.nc").is_file() is True + assert (tmp_path / "data/test_data_2.nc").is_file() is True + + assert project.load_data("test_data_1").equals(xr.Dataset({"data": xr.DataArray([1])})) + assert project.load_data("test_data_2").equals(load_dataset(test_data)) + + def test_create_scheme(existing_project: Project): scheme = existing_project.create_scheme( model_name="test_model",