Allow importing xarray objects into a project and always load data as a dataset.

Review notes on this revision of the patch:
* ProjectDataRegistry.import_data resolves path and name first, then applies
  the ignore_existing short-circuit, and only afterwards reads the source file.
  The previous revision loaded the file before the check, so a skipped import
  still paid for (and could fail on) reading the source, unlike the code it
  replaced.
* A regression test covers the skipped import with a source that does not exist.
* The "index" lines of the two files touched by this revision were dropped
  because their post-image blob hashes no longer apply.

diff --git a/glotaran/project/project.py b/glotaran/project/project.py
index cc098e0f5..8e6bc8be5 100644
--- a/glotaran/project/project.py
+++ b/glotaran/project/project.py
@@ -149,7 +149,7 @@ def data(self) -> dict[str, Path]:
         """
         return self._data_registry.items
 
-    def load_data(self, dataset_name: str) -> xr.Dataset | xr.DataArray:
+    def load_data(self, dataset_name: str) -> xr.Dataset:
         """Load a dataset.
 
         Parameters
@@ -170,11 +170,14 @@ def load_data(self, dataset_name: str) -> xr.Dataset | xr.DataArray:
 
         .. # noqa: DAR402
         """
-        return self._data_registry.load_item(dataset_name)
+        dataset = self._data_registry.load_item(dataset_name)
+        if isinstance(dataset, xr.DataArray):
+            dataset = dataset.to_dataset(name="data")
+        return dataset
 
     def import_data(
         self,
-        path: str | Path,
+        dataset: str | Path | xr.Dataset | xr.DataArray,
         name: str | None = None,
         allow_overwrite: bool = False,
         ignore_existing: bool = False,
@@ -183,17 +186,18 @@ def import_data(
 
         Parameters
         ----------
-        path : str | Path
-            The path to the dataset.
+        dataset : str | Path | xr.Dataset | xr.DataArray
+            Dataset instance or path to a dataset.
         name : str | None
-            The name of the dataset.
+            The name of the dataset (needs to be provided when dataset is an xarray instance).
+            Defaults to None.
         allow_overwrite: bool
             Whether to overwrite an existing dataset.
         ignore_existing: bool
             Whether to ignore import if the dataset already exists.
         """
         self._data_registry.import_data(
-            path, name=name, allow_overwrite=allow_overwrite, ignore_existing=ignore_existing
+            dataset, name=name, allow_overwrite=allow_overwrite, ignore_existing=ignore_existing
         )
 
     @property
diff --git a/glotaran/project/project_data_registry.py b/glotaran/project/project_data_registry.py
--- a/glotaran/project/project_data_registry.py
+++ b/glotaran/project/project_data_registry.py
@@ -3,6 +3,8 @@
 
 from pathlib import Path
 
+import xarray as xr
+
 from glotaran.io import load_dataset
 from glotaran.io import save_dataset
 from glotaran.plugin_system.data_io_registration import supported_file_extensions_data_io
@@ -29,7 +31,7 @@ def __init__(self, directory: Path):
 
     def import_data(
         self,
-        path: str | Path,
+        dataset: str | Path | xr.Dataset | xr.DataArray,
         name: str | None = None,
         allow_overwrite: bool = False,
         ignore_existing: bool = False,
@@ -38,25 +40,42 @@ def import_data(
 
         Parameters
         ----------
-        path : str | Path
-            The path to the dataset.
+        dataset : str | Path | xr.Dataset | xr.DataArray
+            Dataset instance or path to a dataset.
         name : str | None
-            The name of the dataset.
+            The name of the dataset (needs to be provided when dataset is an xarray instance).
+            Defaults to None.
         allow_overwrite: bool
             Whether to overwrite an existing dataset.
         ignore_existing: bool
             Whether to ignore import if the dataset already exists.
+
+        Raises
+        ------
+        ValueError
+            When importing from xarray object and not providing a name.
         """
-        path = Path(path)
+        if isinstance(dataset, (xr.DataArray, xr.Dataset)) and name is None:
+            raise ValueError(
+                "When importing data from a 'xarray.Dataset' or 'xarray.DataArray' "
+                "it is required to also pass a name."
+            )
+        if isinstance(dataset, xr.DataArray):
+            dataset = dataset.to_dataset(name="data")
 
-        if path.is_absolute() is False:
-            path = (self.directory.parent / path).resolve()
+        if isinstance(dataset, (str, Path)):
+            dataset = Path(dataset)
 
-        name = name or path.stem
-        data_path = self.directory / f"{name}.nc"
+            if dataset.is_absolute() is False:
+                dataset = (self.directory.parent / dataset).resolve()
+            name = name or dataset.stem
 
+        data_path = self.directory / f"{name}.nc"
         if data_path.exists() and ignore_existing:
             return
 
-        dataset = load_dataset(path)
+        # Only read the source file once we know the import is not skipped, so that
+        # ``ignore_existing`` stays a cheap no-op for datasets which already exist.
+        if isinstance(dataset, Path):
+            dataset = load_dataset(dataset)
         save_dataset(dataset, data_path, allow_overwrite=allow_overwrite)
diff --git a/glotaran/project/test/test_project.py b/glotaran/project/test/test_project.py
--- a/glotaran/project/test/test_project.py
+++ b/glotaran/project/test/test_project.py
@@ -8,6 +8,7 @@
 from typing import Literal
 
 import pytest
+import xarray as xr
 from _pytest.monkeypatch import MonkeyPatch
 from _pytest.recwarn import WarningsRecorder
 from IPython.core.formatters import format_display_data
@@ -197,6 +198,34 @@ def test_import_data(project_folder: Path, project_file: Path, test_data: Path,
     assert data == example_dataset
 
 
+@pytest.mark.parametrize(
+    "data",
+    (
+        xr.DataArray([1]),
+        xr.Dataset({"data": xr.DataArray([1])}),
+    ),
+)
+def test_import_data_xarray(tmp_path: Path, data: xr.Dataset | xr.DataArray):
+    """Loaded data are always a dataset."""
+    project = Project.open(tmp_path)
+    project.import_data(data, name="test_data")
+
+    assert (tmp_path / "data/test_data.nc").is_file() is True
+
+    assert project.load_data("test_data").equals(xr.Dataset({"data": xr.DataArray([1])}))
+
+
+def test_import_data_ignore_existing_skips_loading(tmp_path: Path):
+    """Nothing is loaded when the import is skipped because the dataset exists."""
+    project = Project.open(tmp_path)
+    project.import_data(xr.DataArray([1]), name="test_data")
+
+    # The source path does not exist, so this call would raise if it got loaded.
+    project.import_data("not-existing.nc", name="test_data", ignore_existing=True)
+
+    assert project.load_data("test_data").equals(xr.Dataset({"data": xr.DataArray([1])}))
+
+
 def test_create_scheme(project_file: Path):
     project = Project.open(project_file)
 
@@ -527,6 +556,22 @@ def test_missing_file_errors(tmp_path: Path, project_folder: Path):
 
     project = Project.open(project_folder)
 
+    with pytest.raises(ValueError) as exc_info:
+        project.import_data(xr.Dataset({"data": [1]}))
+
+    assert str(exc_info.value) == (
+        "When importing data from a 'xarray.Dataset' or 'xarray.DataArray' "
+        "it is required to also pass a name."
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        project.import_data(xr.DataArray([1]))
+
+    assert str(exc_info.value) == (
+        "When importing data from a 'xarray.Dataset' or 'xarray.DataArray' "
+        "it is required to also pass a name."
+    )
+
     with pytest.raises(ValueError) as exc_info:
         project.load_data("not-existing")
 