Skip to content

Commit

Permalink
👌 Add data_overwrite option to Result.optimize and Result.create_scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
s-weigand committed Feb 24, 2023
1 parent d4293f7 commit 6b6a35d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
31 changes: 28 additions & 3 deletions glotaran/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ def create_scheme(
parameters_name: str,
maximum_number_function_evaluations: int | None = None,
clp_link_tolerance: float = 0.0,
data_overwrite: Mapping[str, LoadableDataset] | None = None,
) -> Scheme:
"""Create a scheme for optimization.
Expand All @@ -642,18 +643,33 @@ def create_scheme(
The maximum number of function evaluations.
clp_link_tolerance : float
The CLP link tolerance.
data_overwrite: Mapping[str, LoadableDataset] | None
Allows to bypass the default dataset lookup in the project ``data`` folder and use a
different dataset for the optimization without changing the model. This especially
useful when working with preprocessed data. Defaults to ``None``.
Returns
-------
Scheme
The created scheme.
"""
if data_overwrite is None:
data_overwrite = {}
loaded_model = self.load_model(model_name)
data = {dataset: self.load_data(dataset) for dataset in loaded_model.dataset}
data_overwrite = {
dataset_name: dataset_value
for dataset_name, dataset_value in data_overwrite.items()
if dataset_name in loaded_model.dataset
}
data = {
dataset_name: self.data[dataset_name]
for dataset_name in loaded_model.dataset
if dataset_name not in data_overwrite
}
return Scheme(
model=loaded_model,
parameters=self.load_parameters(parameters_name),
data=data,
data=data | data_overwrite,
maximum_number_function_evaluations=maximum_number_function_evaluations,
clp_link_tolerance=clp_link_tolerance,
)
Expand All @@ -665,6 +681,7 @@ def optimize(
result_name: str | None = None,
maximum_number_function_evaluations: int | None = None,
clp_link_tolerance: float = 0.0,
data_overwrite: Mapping[str, LoadableDataset] | None = None,
) -> Result:
"""Optimize a model.
Expand All @@ -680,6 +697,10 @@ def optimize(
The maximum number of function evaluations.
clp_link_tolerance : float
The CLP link tolerance.
data_overwrite: Mapping[str, LoadableDataset] | None
Allows to bypass the default dataset lookup in the project ``data`` folder and use a
different dataset for the optimization without changing the model. This especially
useful when working with preprocessed data. Defaults to ``None``.
Returns
-------
Expand All @@ -689,7 +710,11 @@ def optimize(
from glotaran.optimization.optimize import optimize

scheme = self.create_scheme(
model_name, parameters_name, maximum_number_function_evaluations, clp_link_tolerance
model_name=model_name,
parameters_name=parameters_name,
maximum_number_function_evaluations=maximum_number_function_evaluations,
clp_link_tolerance=clp_link_tolerance,
data_overwrite=data_overwrite,
)
result = optimize(scheme)

Expand Down
26 changes: 26 additions & 0 deletions glotaran/project/test/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from glotaran.testing.simulated_data.sequential_spectral_decay import (
PARAMETERS as example_parameter,
)
from glotaran.typing.types import LoadableDataset
from glotaran.utils.io import chdir_context


Expand Down Expand Up @@ -324,6 +325,31 @@ def test_create_scheme(existing_project: Project):
assert scheme.maximum_number_function_evaluations == 1


@pytest.mark.parametrize(
"data",
(xr.DataArray([1]), xr.Dataset({"data": xr.DataArray([1])}), "file"),
)
def test_create_scheme_data_overwrite(
tmp_path: Path, existing_project: Project, data: LoadableDataset
):
"""Test data_overwrite functionality."""

if data == "file":
data = tmp_path / "file_data.nc"
save_dataset(xr.Dataset({"data": xr.DataArray([1])}), data)

scheme = existing_project.create_scheme(
model_name="test_model",
parameters_name="test_parameters",
data_overwrite={"dataset_1": data},
)

assert len(scheme.data) == 1
assert "dataset_1" in scheme.data
assert "dataset_1" in scheme.model.dataset
assert scheme.data["dataset_1"].equals(xr.Dataset({"data": xr.DataArray([1])}))


@pytest.mark.parametrize("result_name", ["test", None])
def test_run_optimization(existing_project: Project, result_name: str | None):
assert existing_project.has_models
Expand Down

0 comments on commit 6b6a35d

Please sign in to comment.