
πŸ‘Œ Improve Project API data handling #1257

Merged
merged 8 commits on Feb 24, 2023
1 change: 1 addition & 0 deletions changelog.md
@@ -23,6 +23,7 @@
- ✨ Allow usage of subfolders in project API for parameters, models and data (#1232)
- ✨ Allow import of xarray objects in project API import_data (#1235)
- 🩹 Add number_of_clps to result and correct degrees_of_freedom calculation (#1249)
- πŸ‘Œ Improve Project API data handling (#1257)

### 🩹 Bug fixes

60 changes: 46 additions & 14 deletions glotaran/project/project.py
@@ -35,6 +35,8 @@
if TYPE_CHECKING:
from collections.abc import Hashable

from glotaran.typing.types import LoadableDataset

TEMPLATE = "version: {gta_version}"

PROJECT_FILE_NAME = "project.gta"
@@ -203,31 +205,36 @@ def load_data(

def import_data(
self,
dataset: str | Path | xr.Dataset | xr.DataArray,
dataset: LoadableDataset | Mapping[str, LoadableDataset],
dataset_name: str | None = None,
allow_overwrite: bool = False,
ignore_existing: bool = False,
ignore_existing: bool = True,
):
"""Import a dataset.
"""Import a dataset by saving it as an .nc file in the project's data folder.

Parameters
----------
dataset : str | Path | xr.Dataset | xr.DataArray
dataset : LoadableDataset
Dataset instance or path to a dataset.
dataset_name : str | None
The name of the dataset (needs to be provided when dataset is an xarray instance).
Defaults to None.
allow_overwrite: bool
Whether to overwrite an existing dataset.
ignore_existing: bool
Whether to ignore import if the dataset already exists.
Whether to skip import if the dataset already exists and allow_overwrite is False.
Defaults to ``True``.
"""
self._data_registry.import_data(
dataset,
dataset_name=dataset_name,
allow_overwrite=allow_overwrite,
ignore_existing=ignore_existing,
)
if not isinstance(dataset, Mapping) or isinstance(dataset, (xr.Dataset, xr.DataArray)):
dataset = {dataset_name: dataset}

for key, value in dataset.items():
self._data_registry.import_data(
value,
dataset_name=key,
allow_overwrite=allow_overwrite,
ignore_existing=ignore_existing,
)
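
A minimal usage sketch of the extended ``import_data`` signature; the project folder, file path and dataset names below are hypothetical and not taken from this PR:

```python
from pathlib import Path

import xarray as xr

from glotaran.project import Project

project = Project.open(Path("my_project"))  # hypothetical project folder

# Single dataset: dataset_name is required for in-memory xarray objects.
project.import_data(xr.Dataset({"data": xr.DataArray([1])}), dataset_name="dataset_1")

# Mapping form: keys become dataset names, values may be paths or xarray objects.
project.import_data(
    {
        "dataset_1": xr.Dataset({"data": xr.DataArray([1])}),
        "dataset_2": Path("measurements/run_2.nc"),  # hypothetical file on disk
    }
)

# With ignore_existing=True (the new default) re-importing is a silent no-op;
# pass allow_overwrite=True to replace the existing .nc file instead.
project.import_data(
    xr.Dataset({"data": xr.DataArray([2])}), dataset_name="dataset_1", allow_overwrite=True
)
```

Each key of the mapping is passed through as ``dataset_name``, so a mapping entry behaves exactly like a separate ``import_data`` call.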

@property
def has_models(self) -> bool:
@@ -623,6 +630,7 @@ def create_scheme(
parameters_name: str,
maximum_number_function_evaluations: int | None = None,
clp_link_tolerance: float = 0.0,
data_lookup_override: Mapping[str, LoadableDataset] | None = None,
) -> Scheme:
"""Create a scheme for optimization.

@@ -636,18 +644,33 @@
The maximum number of function evaluations.
clp_link_tolerance : float
The CLP link tolerance.
data_lookup_override: Mapping[str, LoadableDataset] | None
Allows bypassing the default dataset lookup in the project ``data`` folder and using a
different dataset for the optimization without changing the model. This is especially
useful when working with preprocessed data. Defaults to ``None``.

Returns
-------
Scheme
The created scheme.
"""
if data_lookup_override is None:
data_lookup_override = {}
loaded_model = self.load_model(model_name)
data = {dataset: self.load_data(dataset) for dataset in loaded_model.dataset}
data_lookup_override = {
dataset_name: dataset_value
for dataset_name, dataset_value in data_lookup_override.items()
if dataset_name in loaded_model.dataset
}
data = {
dataset_name: self.data[dataset_name]
for dataset_name in loaded_model.dataset
if dataset_name not in data_lookup_override
}
return Scheme(
model=loaded_model,
parameters=self.load_parameters(parameters_name),
data=data,
data=data | data_lookup_override,
maximum_number_function_evaluations=maximum_number_function_evaluations,
clp_link_tolerance=clp_link_tolerance,
)
@@ -659,6 +682,7 @@ def optimize(
result_name: str | None = None,
maximum_number_function_evaluations: int | None = None,
clp_link_tolerance: float = 0.0,
data_lookup_override: Mapping[str, LoadableDataset] | None = None,
) -> Result:
"""Optimize a model.

@@ -674,6 +698,10 @@
The maximum number of function evaluations.
clp_link_tolerance : float
The CLP link tolerance.
data_lookup_override: Mapping[str, LoadableDataset] | None
Allows bypassing the default dataset lookup in the project ``data`` folder and using a
different dataset for the optimization without changing the model. This is especially
useful when working with preprocessed data. Defaults to ``None``.

Returns
-------
@@ -683,7 +711,11 @@
from glotaran.optimization.optimize import optimize

scheme = self.create_scheme(
model_name, parameters_name, maximum_number_function_evaluations, clp_link_tolerance
model_name=model_name,
parameters_name=parameters_name,
maximum_number_function_evaluations=maximum_number_function_evaluations,
clp_link_tolerance=clp_link_tolerance,
data_lookup_override=data_lookup_override,
)
result = optimize(scheme)

2 changes: 1 addition & 1 deletion glotaran/project/project_data_registry.py
@@ -73,6 +73,6 @@ def import_data(
dataset = load_dataset(dataset)

data_path = self.directory / f"{dataset_name}.nc"
if data_path.exists() and ignore_existing:
if data_path.exists() and ignore_existing and allow_overwrite is False:
return
save_dataset(dataset, data_path, allow_overwrite=allow_overwrite)
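
The guard now short-circuits only when no overwrite was requested; a standalone sketch of the resulting decision logic (the helper name and structure are illustrative, not part of the PR):

```python
from pathlib import Path


def should_skip_import(data_path: Path, ignore_existing: bool, allow_overwrite: bool) -> bool:
    """Return True if the import should be silently skipped.

    Skips only when the target file exists, skipping is allowed (ignore_existing)
    and the caller did not explicitly ask to overwrite (allow_overwrite).
    """
    return data_path.exists() and ignore_existing and not allow_overwrite
```

With ``allow_overwrite=True`` the call always falls through to ``save_dataset``, which handles the overwrite; with ``ignore_existing=False`` an existing file makes ``save_dataset`` raise ``FileExistsError``, matching the updated ``test_import_data`` below.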
69 changes: 65 additions & 4 deletions glotaran/project/test/test_project.py
@@ -16,6 +16,7 @@

from glotaran import __version__ as gta_version
from glotaran.builtin.io.yml.utils import load_dict
from glotaran.io import load_dataset
from glotaran.io import load_parameters
from glotaran.io import save_dataset
from glotaran.io import save_parameters
@@ -30,6 +31,7 @@
from glotaran.testing.simulated_data.sequential_spectral_decay import (
PARAMETERS as example_parameter,
)
from glotaran.typing.types import LoadableDataset
from glotaran.utils.io import chdir_context


@@ -248,11 +250,10 @@ def test_import_data(tmp_path: Path, name: str | None):
save_dataset(example_dataset, test_data)

project.import_data(test_data, dataset_name=name)
with pytest.raises(FileExistsError):
project.import_data(test_data, dataset_name=name)
project.import_data(test_data, dataset_name=name)

project.import_data(test_data, dataset_name=name, allow_overwrite=True)
project.import_data(test_data, dataset_name=name, ignore_existing=True)
with pytest.raises(FileExistsError):
project.import_data(test_data, dataset_name=name, ignore_existing=False)

data_folder = tmp_path / "test_project/data"

@@ -288,6 +289,41 @@ def test_import_data_xarray(tmp_path: Path, data: xr.Dataset | xr.DataArray):
assert project.load_data("test_data").equals(xr.Dataset({"data": xr.DataArray([1])}))


def test_import_data_allow_overwrite(existing_project: Project):
"""Overwrite data when ``allow_overwrite==True``."""

dummy_data = xr.Dataset({"data": xr.DataArray([1])})

assert not existing_project.load_data("dataset_1").equals(dummy_data)

existing_project.import_data(dummy_data, dataset_name="dataset_1", allow_overwrite=True)

assert existing_project.load_data("dataset_1").equals(dummy_data)


@pytest.mark.parametrize(
"data",
(
xr.DataArray([1]),
xr.Dataset({"data": xr.DataArray([1])}),
),
)
def test_import_data_mapping(tmp_path: Path, data: xr.Dataset | xr.DataArray):
"""Import data as a mapping"""
project = Project.open(tmp_path)

test_data = tmp_path / "import_data.nc"
save_dataset(example_dataset, test_data)

project.import_data({"test_data_1": data, "test_data_2": test_data})

assert (tmp_path / "data/test_data_1.nc").is_file() is True
assert (tmp_path / "data/test_data_2.nc").is_file() is True

assert project.load_data("test_data_1").equals(xr.Dataset({"data": xr.DataArray([1])}))
assert project.load_data("test_data_2").equals(load_dataset(test_data))


def test_create_scheme(existing_project: Project):
scheme = existing_project.create_scheme(
model_name="test_model",
@@ -300,6 +336,31 @@ def test_create_scheme(existing_project: Project):
assert scheme.maximum_number_function_evaluations == 1


@pytest.mark.parametrize(
"data",
(xr.DataArray([1]), xr.Dataset({"data": xr.DataArray([1])}), "file"),
)
def test_create_scheme_data_lookup_override(
tmp_path: Path, existing_project: Project, data: LoadableDataset
):
"""Test data_lookup_override functionality."""

if data == "file":
data = tmp_path / "file_data.nc"
save_dataset(xr.Dataset({"data": xr.DataArray([1])}), data)

scheme = existing_project.create_scheme(
model_name="test_model",
parameters_name="test_parameters",
data_lookup_override={"dataset_1": data},
)

assert len(scheme.data) == 1
assert "dataset_1" in scheme.data
assert "dataset_1" in scheme.model.dataset
assert scheme.data["dataset_1"].equals(xr.Dataset({"data": xr.DataArray([1])}))


@pytest.mark.parametrize("result_name", ["test", None])
def test_run_optimization(existing_project: Project, result_name: str | None):
assert existing_project.has_models