diff --git a/changelog.md b/changelog.md
index 603ea12ce..58cb4c954 100644
--- a/changelog.md
+++ b/changelog.md
@@ -23,6 +23,7 @@
 - ✨ Allow usage of subfolders in project API for parameters, models and data (#1232)
 - ✨ Allow import of xarray objects in project API import_data (#1235)
 - 🩹 Add number_of_clps to result and correct degrees_of_freedom calculation (#1249)
+- 👌 Improve Project API data handling (#1257)
 
 ### 🩹 Bug fixes
 
diff --git a/glotaran/project/project.py b/glotaran/project/project.py
index 7cc5ea179..157ddaf68 100644
--- a/glotaran/project/project.py
+++ b/glotaran/project/project.py
@@ -35,6 +35,8 @@
 if TYPE_CHECKING:
     from collections.abc import Hashable
 
+    from glotaran.typing.types import LoadableDataset
+
 TEMPLATE = "version: {gta_version}"
 
 PROJECT_FILE_NAME = "project.gta"
@@ -203,16 +205,16 @@ def load_data(
 
     def import_data(
         self,
-        dataset: str | Path | xr.Dataset | xr.DataArray,
+        dataset: LoadableDataset | Mapping[str, LoadableDataset],
         dataset_name: str | None = None,
         allow_overwrite: bool = False,
-        ignore_existing: bool = False,
+        ignore_existing: bool = True,
     ):
-        """Import a dataset.
+        """Import a dataset by saving it as an .nc file in the project's data folder.
 
         Parameters
         ----------
-        dataset : str | Path | xr.Dataset | xr.DataArray
+        dataset : LoadableDataset
             Dataset instance or path to a dataset.
         dataset_name : str | None
             The name of the dataset (needs to be provided when dataset is an xarray instance).
@@ -220,14 +222,19 @@ def import_data(
         allow_overwrite: bool
             Whether to overwrite an existing dataset.
         ignore_existing: bool
-            Whether to ignore import if the dataset already exists.
+            Whether to skip import if the dataset already exists and allow_overwrite is False.
+            Defaults to ``True``.
         """
-        self._data_registry.import_data(
-            dataset,
-            dataset_name=dataset_name,
-            allow_overwrite=allow_overwrite,
-            ignore_existing=ignore_existing,
-        )
+        if not isinstance(dataset, Mapping) or isinstance(dataset, (xr.Dataset, xr.DataArray)):
+            dataset = {dataset_name: dataset}
+
+        for key, value in dataset.items():
+            self._data_registry.import_data(
+                value,
+                dataset_name=key,
+                allow_overwrite=allow_overwrite,
+                ignore_existing=ignore_existing,
+            )
 
     @property
     def has_models(self) -> bool:
@@ -623,6 +630,7 @@ def create_scheme(
         parameters_name: str,
         maximum_number_function_evaluations: int | None = None,
         clp_link_tolerance: float = 0.0,
+        data_lookup_override: Mapping[str, LoadableDataset] | None = None,
     ) -> Scheme:
         """Create a scheme for optimization.
 
@@ -636,18 +644,33 @@ def create_scheme(
             The maximum number of function evaluations.
         clp_link_tolerance : float
             The CLP link tolerance.
+        data_lookup_override: Mapping[str, LoadableDataset] | None
+            Allows to bypass the default dataset lookup in the project ``data`` folder and use a
+            different dataset for the optimization without changing the model. This is especially
+            useful when working with preprocessed data. Defaults to ``None``.
 
         Returns
         -------
         Scheme
             The created scheme.
         """
+        if data_lookup_override is None:
+            data_lookup_override = {}
         loaded_model = self.load_model(model_name)
-        data = {dataset: self.load_data(dataset) for dataset in loaded_model.dataset}
+        data_lookup_override = {
+            dataset_name: dataset_value
+            for dataset_name, dataset_value in data_lookup_override.items()
+            if dataset_name in loaded_model.dataset
+        }
+        data = {
+            dataset_name: self.data[dataset_name]
+            for dataset_name in loaded_model.dataset
+            if dataset_name not in data_lookup_override
+        }
         return Scheme(
             model=loaded_model,
             parameters=self.load_parameters(parameters_name),
-            data=data,
+            data=data | data_lookup_override,
             maximum_number_function_evaluations=maximum_number_function_evaluations,
             clp_link_tolerance=clp_link_tolerance,
         )
@@ -659,6 +682,7 @@ def optimize(
         result_name: str | None = None,
         maximum_number_function_evaluations: int | None = None,
         clp_link_tolerance: float = 0.0,
+        data_lookup_override: Mapping[str, LoadableDataset] | None = None,
     ) -> Result:
         """Optimize a model.
 
@@ -674,6 +698,10 @@ def optimize(
             The maximum number of function evaluations.
         clp_link_tolerance : float
             The CLP link tolerance.
+        data_lookup_override: Mapping[str, LoadableDataset] | None
+            Allows to bypass the default dataset lookup in the project ``data`` folder and use a
+            different dataset for the optimization without changing the model. This is especially
+            useful when working with preprocessed data. Defaults to ``None``.
 
         Returns
         -------
@@ -683,7 +711,11 @@ def optimize(
         from glotaran.optimization.optimize import optimize
 
         scheme = self.create_scheme(
-            model_name, parameters_name, maximum_number_function_evaluations, clp_link_tolerance
+            model_name=model_name,
+            parameters_name=parameters_name,
+            maximum_number_function_evaluations=maximum_number_function_evaluations,
+            clp_link_tolerance=clp_link_tolerance,
+            data_lookup_override=data_lookup_override,
         )
         result = optimize(scheme)
diff --git a/glotaran/project/project_data_registry.py b/glotaran/project/project_data_registry.py
index 34978d41f..111dddef5 100644
--- a/glotaran/project/project_data_registry.py
+++ b/glotaran/project/project_data_registry.py
@@ -73,6 +73,6 @@ def import_data(
             dataset = load_dataset(dataset)
 
         data_path = self.directory / f"{dataset_name}.nc"
-        if data_path.exists() and ignore_existing:
+        if data_path.exists() and ignore_existing and allow_overwrite is False:
             return
         save_dataset(dataset, data_path, allow_overwrite=allow_overwrite)
diff --git a/glotaran/project/test/test_project.py b/glotaran/project/test/test_project.py
index 9967e41b2..943198baa 100644
--- a/glotaran/project/test/test_project.py
+++ b/glotaran/project/test/test_project.py
@@ -16,6 +16,7 @@
 
 from glotaran import __version__ as gta_version
 from glotaran.builtin.io.yml.utils import load_dict
+from glotaran.io import load_dataset
 from glotaran.io import load_parameters
 from glotaran.io import save_dataset
 from glotaran.io import save_parameters
@@ -30,6 +31,7 @@
 from glotaran.testing.simulated_data.sequential_spectral_decay import (
     PARAMETERS as example_parameter,
 )
+from glotaran.typing.types import LoadableDataset
 from glotaran.utils.io import chdir_context
 
@@ -248,11 +250,10 @@ def test_import_data(tmp_path: Path, name: str | None):
     save_dataset(example_dataset, test_data)
 
     project.import_data(test_data, dataset_name=name)
 
-    with pytest.raises(FileExistsError):
-        project.import_data(test_data, dataset_name=name)
+    project.import_data(test_data, dataset_name=name)
 
-    project.import_data(test_data, dataset_name=name, allow_overwrite=True)
-    project.import_data(test_data, dataset_name=name, ignore_existing=True)
+    with pytest.raises(FileExistsError):
+        project.import_data(test_data, dataset_name=name, ignore_existing=False)
 
     data_folder = tmp_path / "test_project/data"
 
@@ -288,6 +289,41 @@ def test_import_data_xarray(tmp_path: Path, data: xr.Dataset | xr.DataArray):
     assert project.load_data("test_data").equals(xr.Dataset({"data": xr.DataArray([1])}))
 
 
+def test_import_data_allow_overwrite(existing_project: Project):
+    """Overwrite data when ``allow_overwrite==True``."""
+
+    dummy_data = xr.Dataset({"data": xr.DataArray([1])})
+
+    assert not existing_project.load_data("dataset_1").equals(dummy_data)
+
+    existing_project.import_data(dummy_data, dataset_name="dataset_1", allow_overwrite=True)
+
+    assert existing_project.load_data("dataset_1").equals(dummy_data)
+
+
+@pytest.mark.parametrize(
+    "data",
+    (
+        xr.DataArray([1]),
+        xr.Dataset({"data": xr.DataArray([1])}),
+    ),
+)
+def test_import_data_mapping(tmp_path: Path, data: xr.Dataset | xr.DataArray):
+    """Import data as a mapping"""
+    project = Project.open(tmp_path)
+
+    test_data = tmp_path / "import_data.nc"
+    save_dataset(example_dataset, test_data)
+
+    project.import_data({"test_data_1": data, "test_data_2": test_data})
+
+    assert (tmp_path / "data/test_data_1.nc").is_file() is True
+    assert (tmp_path / "data/test_data_2.nc").is_file() is True
+
+    assert project.load_data("test_data_1").equals(xr.Dataset({"data": xr.DataArray([1])}))
+    assert project.load_data("test_data_2").equals(load_dataset(test_data))
+
+
 def test_create_scheme(existing_project: Project):
     scheme = existing_project.create_scheme(
         model_name="test_model",
@@ -300,6 +336,31 @@ def test_create_scheme(existing_project: Project):
     assert scheme.maximum_number_function_evaluations == 1
 
 
+@pytest.mark.parametrize(
+    "data",
+    (xr.DataArray([1]), xr.Dataset({"data": xr.DataArray([1])}), "file"),
+)
+def test_create_scheme_data_lookup_override(
+    tmp_path: Path, existing_project: Project, data: LoadableDataset
+):
+    """Test data_lookup_override functionality."""
+
+    if data == "file":
+        data = tmp_path / "file_data.nc"
+        save_dataset(xr.Dataset({"data": xr.DataArray([1])}), data)
+
+    scheme = existing_project.create_scheme(
+        model_name="test_model",
+        parameters_name="test_parameters",
+        data_lookup_override={"dataset_1": data},
+    )
+
+    assert len(scheme.data) == 1
+    assert "dataset_1" in scheme.data
+    assert "dataset_1" in scheme.model.dataset
+    assert scheme.data["dataset_1"].equals(xr.Dataset({"data": xr.DataArray([1])}))
+
+
 @pytest.mark.parametrize("result_name", ["test", None])
 def test_run_optimization(existing_project: Project, result_name: str | None):
     assert existing_project.has_models
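For context, a minimal usage sketch of the Project API surface this patch changes. The project folder, model, parameter, dataset, and file names (`my_project`, `my_model`, `my_parameters`, `dataset_1`, `run_1.nc`) are hypothetical placeholders, not taken from this diff:

```python
import xarray as xr

from glotaran.project import Project

# Hypothetical project folder (the tests open a fresh tmp_path the same way).
project = Project.open("my_project")

# New: import_data also accepts a mapping of dataset name -> loadable dataset
# (path, xr.Dataset, or xr.DataArray); each entry is saved as an .nc file in
# the project's data folder.
project.import_data(
    {
        "dataset_1": "measurements/run_1.nc",  # hypothetical input file
        "dataset_2": xr.Dataset({"data": xr.DataArray([1])}),
    }
)

# Changed default: ignore_existing=True, so re-importing an existing dataset
# is now a silent no-op instead of raising FileExistsError. Overwriting still
# requires an explicit allow_overwrite=True.
project.import_data(
    xr.Dataset({"data": xr.DataArray([2])}),
    dataset_name="dataset_1",
    allow_overwrite=True,
)

# New: data_lookup_override bypasses the data-folder lookup for selected
# datasets, e.g. to optimize against preprocessed data without writing it
# back into the project.
preprocessed = project.load_data("dataset_1")  # preprocessing would go here
result = project.optimize(
    model_name="my_model",
    parameters_name="my_parameters",
    data_lookup_override={"dataset_1": preprocessed},
)
```

Note the interplay in `project_data_registry.import_data`: the early return now also requires `allow_overwrite is False`, so an explicit `allow_overwrite=True` takes precedence over the new `ignore_existing=True` default.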