👌 Improve Project API data handling (#1257)
This PR improves the data handling of the Project API in notebook workflows.

Since notebooks are often executed as a whole (`Run All`), the default of `ignore_existing=False` in `project.import_data` forces users to pass `ignore_existing=True` on every call in an actual case study, which makes the code less readable and adds a lot of redundancy. A default of `ignore_existing=True` is more sensible (IMHO `ignore_existing=False` only makes sense in a CLI or GUI context).
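
For illustration, this is the boilerplate the old default forces on notebook users; a minimal sketch reusing the dataset from the example below, where a second `Run All` would otherwise raise `FileExistsError`:
```py
# With ignore_existing=False (old default) the flag has to be spelled out on
# every call, or re-running the cell fails on the already imported file.
project.import_data(
    "measured_data/Npq2_220219_800target3fasea.ascii",
    dataset_name="TA",
    ignore_existing=True,
)
```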

Having to import each dataset one by one also clutters the code quite a lot:
```py
project.import_data("measured_data/Npq2_220219_800target3fasea.ascii", dataset_name="TA")
project.import_data(guide_s1, dataset_name="guide_s1")
project.import_data(guide_s2, dataset_name="guide_s2")
project.import_data(guide_s3, dataset_name="guide_s3")
project.import_data(guide_s4, dataset_name="guide_s4")
project.import_data(guide_s5, dataset_name="guide_s5")
project.import_data(guide_s6, dataset_name="guide_s6")
project.import_data(guide_s7, dataset_name="guide_s7")
project.import_data(guide_s8, dataset_name="guide_s8")
```
This is why it is a lot more convenient to also allow passing a mapping (especially since that mapping could have been defined and used for plotting even before importing glotaran at all):
```py
my_datasets = {
    "TA": "measured_data/Npq2_220219_800target3fasea.ascii",
    "guide_s1": guide_s1,
    "guide_s2": guide_s2,
    "guide_s3": guide_s3,
    "guide_s4": guide_s4,
    "guide_s5": guide_s5,
    "guide_s6": guide_s6,
    "guide_s7": guide_s7,
    "guide_s8": guide_s8,
}
project.import_data(my_datasets)
```

Lastly, users might want to swap out datasets in the optimization without touching the model definition, for example to use an averaged dataset for a quicker feedback loop, or to apply some other kind of preprocessing/correction to the data and compare results with the exact same model.

```py
project.optimize("my_model", "my_parameters", data_lookup_override={"TA": averaged_data})
```
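
The same override is also available when only building a scheme; a sketch using the placeholder names from above:
```py
scheme = project.create_scheme(
    "my_model",
    "my_parameters",
    data_lookup_override={"TA": averaged_data},
)
```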

* 👌 Change default of ignore_existing in import_data to True

This allows using a notebook workflow without cluttering it with 'ignore_existing=True' all over the place

* 👌 Allow importing data from a mapping

* 👌 Add data_overwrite option to Project.optimize and Project.create_scheme

* 👌 Changed data_overwrite to data_lookup_overwrite to clarify that data are not overwritten

* 🚧📚 Added change to changelog

* 🩹 Fix allow_overwrite=True not having any effect if ignore_existing=True

* 🧹 Renamed data_lookup_overwrite to data_lookup_override

* 👌 Apply suggestions from code review

Co-authored-by: Joris Snellenburg <jsnel@users.noreply.github.com>

---------

Co-authored-by: Joris Snellenburg <jsnel@users.noreply.github.com>
s-weigand and jsnel authored Feb 24, 2023
1 parent 44e9e2a commit ad8f53e
Showing 4 changed files with 113 additions and 19 deletions.
1 change: 1 addition & 0 deletions changelog.md
@@ -23,6 +23,7 @@
- ✨ Allow usage of subfolders in project API for parameters, models and data (#1232)
- ✨ Allow import of xarray objects in project API import_data (#1235)
- 🩹 Add number_of_clps to result and correct degrees_of_freedom calculation (#1249)
- 👌 Improve Project API data handling (#1257)

### 🩹 Bug fixes

60 changes: 46 additions & 14 deletions glotaran/project/project.py
@@ -35,6 +35,8 @@
if TYPE_CHECKING:
from collections.abc import Hashable

from glotaran.typing.types import LoadableDataset

TEMPLATE = "version: {gta_version}"

PROJECT_FILE_NAME = "project.gta"
@@ -203,31 +205,36 @@ def load_data(

def import_data(
self,
dataset: str | Path | xr.Dataset | xr.DataArray,
dataset: LoadableDataset | Mapping[str, LoadableDataset],
dataset_name: str | None = None,
allow_overwrite: bool = False,
ignore_existing: bool = False,
ignore_existing: bool = True,
):
"""Import a dataset.
"""Import a dataset by saving it as an .nc file in the project's data folder.
Parameters
----------
dataset : str | Path | xr.Dataset | xr.DataArray
dataset : LoadableDataset
Dataset instance or path to a dataset.
dataset_name : str | None
The name of the dataset (needs to be provided when dataset is an xarray instance).
Defaults to None.
allow_overwrite: bool
Whether to overwrite an existing dataset.
ignore_existing: bool
Whether to ignore import if the dataset already exists.
Whether to skip import if the dataset already exists and allow_overwrite is False.
Defaults to ``True``.
"""
self._data_registry.import_data(
dataset,
dataset_name=dataset_name,
allow_overwrite=allow_overwrite,
ignore_existing=ignore_existing,
)
if not isinstance(dataset, Mapping) or isinstance(dataset, (xr.Dataset, xr.DataArray)):
dataset = {dataset_name: dataset}

for key, value in dataset.items():
self._data_registry.import_data(
value,
dataset_name=key,
allow_overwrite=allow_overwrite,
ignore_existing=ignore_existing,
)

@property
def has_models(self) -> bool:
@@ -623,6 +630,7 @@ def create_scheme(
parameters_name: str,
maximum_number_function_evaluations: int | None = None,
clp_link_tolerance: float = 0.0,
data_lookup_override: Mapping[str, LoadableDataset] | None = None,
) -> Scheme:
"""Create a scheme for optimization.
@@ -636,18 +644,33 @@
The maximum number of function evaluations.
clp_link_tolerance : float
The CLP link tolerance.
data_lookup_override: Mapping[str, LoadableDataset] | None
Allows to bypass the default dataset lookup in the project ``data`` folder and use a
different dataset for the optimization without changing the model. This is especially
useful when working with preprocessed data. Defaults to ``None``.
Returns
-------
Scheme
The created scheme.
"""
if data_lookup_override is None:
data_lookup_override = {}
loaded_model = self.load_model(model_name)
data = {dataset: self.load_data(dataset) for dataset in loaded_model.dataset}
data_lookup_override = {
dataset_name: dataset_value
for dataset_name, dataset_value in data_lookup_override.items()
if dataset_name in loaded_model.dataset
}
data = {
dataset_name: self.data[dataset_name]
for dataset_name in loaded_model.dataset
if dataset_name not in data_lookup_override
}
return Scheme(
model=loaded_model,
parameters=self.load_parameters(parameters_name),
data=data,
data=data | data_lookup_override,
maximum_number_function_evaluations=maximum_number_function_evaluations,
clp_link_tolerance=clp_link_tolerance,
)
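
A small self-contained illustration (not part of the diff; file names are placeholders) of the lookup logic in the hunk above: overrides are restricted to datasets the model actually defines, the remaining datasets come from the project, and the two mappings are merged with the override taking precedence.
```py
model_datasets = ["dataset_1", "dataset_2"]
data_lookup_override = {"dataset_1": "averaged.nc", "unrelated": "ignored.nc"}

# Keep only overrides for datasets the model uses.
override = {k: v for k, v in data_lookup_override.items() if k in model_datasets}
# Everything else is looked up in the project's data folder.
project_data = {k: f"data/{k}.nc" for k in model_datasets if k not in override}

print(project_data | override)
# {'dataset_2': 'data/dataset_2.nc', 'dataset_1': 'averaged.nc'}
```
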
@@ -659,6 +682,7 @@ def optimize(
result_name: str | None = None,
maximum_number_function_evaluations: int | None = None,
clp_link_tolerance: float = 0.0,
data_lookup_override: Mapping[str, LoadableDataset] | None = None,
) -> Result:
"""Optimize a model.
@@ -674,6 +698,10 @@
The maximum number of function evaluations.
clp_link_tolerance : float
The CLP link tolerance.
data_lookup_override: Mapping[str, LoadableDataset] | None
Allows to bypass the default dataset lookup in the project ``data`` folder and use a
different dataset for the optimization without changing the model. This is especially
useful when working with preprocessed data. Defaults to ``None``.
Returns
-------
@@ -683,7 +711,11 @@
from glotaran.optimization.optimize import optimize

scheme = self.create_scheme(
model_name, parameters_name, maximum_number_function_evaluations, clp_link_tolerance
model_name=model_name,
parameters_name=parameters_name,
maximum_number_function_evaluations=maximum_number_function_evaluations,
clp_link_tolerance=clp_link_tolerance,
data_lookup_override=data_lookup_override,
)
result = optimize(scheme)

2 changes: 1 addition & 1 deletion glotaran/project/project_data_registry.py
@@ -73,6 +73,6 @@ def import_data(
dataset = load_dataset(dataset)

data_path = self.directory / f"{dataset_name}.nc"
if data_path.exists() and ignore_existing:
if data_path.exists() and ignore_existing and allow_overwrite is False:
return
save_dataset(dataset, data_path, allow_overwrite=allow_overwrite)
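
Together with the new default, the fixed registry check gives the behaviour exercised by the tests below; a hedged usage sketch where `my_data` is a placeholder dataset:
```py
project.import_data(my_data, dataset_name="TA")  # first import, file is written
project.import_data(my_data, dataset_name="TA")  # skipped, ignore_existing defaults to True
project.import_data(my_data, dataset_name="TA", allow_overwrite=True)  # overwrites (was skipped before the fix)
project.import_data(my_data, dataset_name="TA", ignore_existing=False)  # raises FileExistsError
```
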
69 changes: 65 additions & 4 deletions glotaran/project/test/test_project.py
@@ -16,6 +16,7 @@

from glotaran import __version__ as gta_version
from glotaran.builtin.io.yml.utils import load_dict
from glotaran.io import load_dataset
from glotaran.io import load_parameters
from glotaran.io import save_dataset
from glotaran.io import save_parameters
@@ -30,6 +31,7 @@
from glotaran.testing.simulated_data.sequential_spectral_decay import (
PARAMETERS as example_parameter,
)
from glotaran.typing.types import LoadableDataset
from glotaran.utils.io import chdir_context


@@ -248,11 +250,10 @@ def test_import_data(tmp_path: Path, name: str | None):
save_dataset(example_dataset, test_data)

project.import_data(test_data, dataset_name=name)
with pytest.raises(FileExistsError):
project.import_data(test_data, dataset_name=name)
project.import_data(test_data, dataset_name=name)

project.import_data(test_data, dataset_name=name, allow_overwrite=True)
project.import_data(test_data, dataset_name=name, ignore_existing=True)
with pytest.raises(FileExistsError):
project.import_data(test_data, dataset_name=name, ignore_existing=False)

data_folder = tmp_path / "test_project/data"

@@ -288,6 +289,41 @@ def test_import_data_xarray(tmp_path: Path, data: xr.Dataset | xr.DataArray):
assert project.load_data("test_data").equals(xr.Dataset({"data": xr.DataArray([1])}))


def test_import_data_allow_overwrite(existing_project: Project):
"""Overwrite data when ``allow_overwrite==True``."""

dummy_data = xr.Dataset({"data": xr.DataArray([1])})

assert not existing_project.load_data("dataset_1").equals(dummy_data)

existing_project.import_data(dummy_data, dataset_name="dataset_1", allow_overwrite=True)

assert existing_project.load_data("dataset_1").equals(dummy_data)


@pytest.mark.parametrize(
"data",
(
xr.DataArray([1]),
xr.Dataset({"data": xr.DataArray([1])}),
),
)
def test_import_data_mapping(tmp_path: Path, data: xr.Dataset | xr.DataArray):
"""Import data as a mapping"""
project = Project.open(tmp_path)

test_data = tmp_path / "import_data.nc"
save_dataset(example_dataset, test_data)

project.import_data({"test_data_1": data, "test_data_2": test_data})

assert (tmp_path / "data/test_data_1.nc").is_file() is True
assert (tmp_path / "data/test_data_2.nc").is_file() is True

assert project.load_data("test_data_1").equals(xr.Dataset({"data": xr.DataArray([1])}))
assert project.load_data("test_data_2").equals(load_dataset(test_data))


def test_create_scheme(existing_project: Project):
scheme = existing_project.create_scheme(
model_name="test_model",
@@ -300,6 +336,31 @@ def test_create_scheme(existing_project: Project):
assert scheme.maximum_number_function_evaluations == 1


@pytest.mark.parametrize(
"data",
(xr.DataArray([1]), xr.Dataset({"data": xr.DataArray([1])}), "file"),
)
def test_create_scheme_data_lookup_override(
tmp_path: Path, existing_project: Project, data: LoadableDataset
):
"""Test data_lookup_override functionality."""

if data == "file":
data = tmp_path / "file_data.nc"
save_dataset(xr.Dataset({"data": xr.DataArray([1])}), data)

scheme = existing_project.create_scheme(
model_name="test_model",
parameters_name="test_parameters",
data_lookup_override={"dataset_1": data},
)

assert len(scheme.data) == 1
assert "dataset_1" in scheme.data
assert "dataset_1" in scheme.model.dataset
assert scheme.data["dataset_1"].equals(xr.Dataset({"data": xr.DataArray([1])}))


@pytest.mark.parametrize("result_name", ["test", None])
def test_run_optimization(existing_project: Project, result_name: str | None):
assert existing_project.has_models
