From d1e36a93d2c75ec1a93d4eacea5eba2d49b6f130 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Wei=C3=9Fenborn?= Date: Tue, 12 Oct 2021 20:08:13 +0200 Subject: [PATCH] Refactor Result Saving (#841) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR changes the way results are saved and enables loading and validating results. Change summary - Added custom dataclass serialisation - Models are serializable now - Results can be loaded - Added save/load_result_file - Parameter History added to result Commits: * Added project to darglint, mypy and pydocstyle pre-commit checks * Added Model.as_dict and Model.get_parameters * Added project/dataclass, changed scheme, adapted yml io * Changed result use project/dataclass * Added test for save model yml * Added load and save result file and implemented it with yml * Refactored folder plugin, moved SavingOptions. * Update glotaran/project/result.py * Made variable names in model more consistent. * Added parameter history class and added it to result. * 🔧 Configure darglint to ignore protocol methods * Fixed Parameter doc. * ♻️ Refactored by Sourcery * 🧹 Partial revert of eb430c * 🧹 Partial revert of 761d4b * 🧹 Restored original behavior of save_result folder plugin + added files As a side effect it sets the file paths of saved files in the Result object. 
* 🧹 Removed unused variable name in yml save_model looping over dict * 🩹 Fixed wrong typing usage of any builtin function * 🔧👌 Activated mypy for parameter subpackage and fixed typing issues * 🔧👌🩹 Activated mypy for projectsubpackage and fixed typing issues * 🩹 Fix bug in _create_result when optimizer fails * ♻️ Rename ParameterHistory.{number_records,number_of_records) * 🔧 Raise interrogate threshold from 55% to 59% (current 60.7%) * 🧹 Removed obsolete darglint ignore comments * 🧹Fix typos * 👌 Add annotation to __str__ method * ♻️ Rename project.dataclasses to project.dataclass_helpers * Update glotaran/analysis/optimize.py * Update glotaran/parameter/parameter_group.py * Update glotaran/parameter/parameter_group.py * ♻️ Refactor glotaran/builtin/io/yml/test/test_save_model.py to use tmp_path instead of tmpdir * 🚇 🔧 Skip interrogate in pre-commit CI * 🧹 Renamed 'test_dataclasses.py' to 'test_dataclass_helpers.py' Co-authored-by: Sebastian Weigand Co-authored-by: Sourcery AI <> Co-authored-by: Joris Snellenburg --- .pre-commit-config.yaml | 12 +- glotaran/analysis/optimize.py | 15 +- glotaran/analysis/problem.py | 24 +- glotaran/analysis/problem_grouped.py | 9 +- glotaran/analysis/problem_ungrouped.py | 4 +- glotaran/builtin/io/folder/folder_plugin.py | 105 ++++-- .../io/folder/test/test_folder_plugin.py | 35 +- glotaran/builtin/io/netCDF/netCDF.py | 21 +- .../builtin/io/yml/test/test_save_model.py | 63 ++++ .../builtin/io/yml/test/test_save_result.py | 7 +- .../builtin/io/yml/test/test_save_scheme.py | 58 +++ glotaran/builtin/io/yml/yml.py | 216 +++++------ .../modules/test/test_project_result.py | 17 - ...roject_sheme.py => test_project_scheme.py} | 10 +- glotaran/examples/sequential.py | 2 + glotaran/io/__init__.py | 3 + glotaran/io/interface.py | 2 +- glotaran/model/dataset_model.py | 2 +- glotaran/model/item.py | 32 ++ glotaran/model/model.py | 162 +++++---- glotaran/model/property.py | 45 ++- glotaran/model/test/test_model.py | 57 ++- 
glotaran/parameter/__init__.py | 9 +- glotaran/parameter/parameter.py | 335 ++++++++++++------ glotaran/parameter/parameter_group.py | 288 +++++++++++---- glotaran/parameter/parameter_history.py | 158 +++++++++ .../parameter/test/test_parameter_history.py | 39 ++ .../plugin_system/data_io_registration.py | 3 + .../plugin_system/project_io_registration.py | 23 +- glotaran/project/__init__.py | 4 +- glotaran/project/dataclass_helpers.py | 125 +++++++ glotaran/project/result.py | 253 +++++++++---- glotaran/project/scheme.py | 116 ++++-- .../project/test/test_dataclass_helpers.py | 37 ++ glotaran/project/test/test_result.py | 20 ++ glotaran/project/test/test_scheme.py | 24 +- pyproject.toml | 2 +- requirements_dev.txt | 1 + setup.cfg | 8 +- 39 files changed, 1729 insertions(+), 617 deletions(-) create mode 100644 glotaran/builtin/io/yml/test/test_save_model.py create mode 100644 glotaran/builtin/io/yml/test/test_save_scheme.py rename glotaran/deprecation/modules/test/{test_project_sheme.py => test_project_scheme.py} (87%) create mode 100644 glotaran/parameter/parameter_history.py create mode 100644 glotaran/parameter/test/test_parameter_history.py create mode 100644 glotaran/project/dataclass_helpers.py create mode 100644 glotaran/project/test/test_dataclass_helpers.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2478c3ee0..12c35fa8a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,6 @@ +ci: + skip: [interrogate] + repos: # Formatters - repo: https://github.com/pre-commit/pre-commit-hooks @@ -77,7 +80,7 @@ repos: rev: 6.1.1 hooks: - id: pydocstyle - files: "^glotaran/(plugin_system|utils|deprecation|testing)" + files: "^glotaran/(plugin_system|utils|deprecation|testing|parameter|project)" exclude: "docs|tests?/" # this is needed due to the following issue: # https://github.com/PyCQA/pydocstyle/issues/368 @@ -87,14 +90,14 @@ repos: rev: v1.8.0 hooks: - id: darglint - files: 
"^glotaran/(plugin_system|utils|deprecation|testing)" + files: "^glotaran/(plugin_system|utils|deprecation|testing|parameter|project)" exclude: "docs|tests?/" - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.910 hooks: - id: mypy - files: "^glotaran/(plugin_system|utils|deprecation|testing)" + files: "^glotaran/(plugin_system|utils|deprecation|testing|parameter|project)" exclude: "docs" additional_dependencies: [types-all] @@ -112,6 +115,7 @@ repos: types: [file] types_or: [python, pyi] additional_dependencies: [flake8-docstrings, flake8-print] + exclude: "parameter.py" - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 @@ -133,7 +137,7 @@ repos: hooks: - id: rst-backticks - id: python-check-blanket-noqa - exclude: "docs|tests?" + exclude: "parameter.py|docs|tests?" - id: python-check-blanket-type-ignore exclude: "docs|tests?" - id: python-use-type-annotations diff --git a/glotaran/analysis/optimize.py b/glotaran/analysis/optimize.py index a07169f17..68c6377eb 100644 --- a/glotaran/analysis/optimize.py +++ b/glotaran/analysis/optimize.py @@ -6,6 +6,7 @@ from scipy.optimize import OptimizeResult from scipy.optimize import least_squares +from glotaran import __version__ as glotaran_version from glotaran.analysis.problem import Problem from glotaran.analysis.problem_grouped import GroupedProblem from glotaran.analysis.problem_ungrouped import UngroupedProblem @@ -91,23 +92,21 @@ def _create_result( success = ls_result is not None number_of_function_evaluation = ( - ls_result.nfev if ls_result is not None else len(problem.parameter_history) + ls_result.nfev if success else problem.parameter_history.number_of_records ) number_of_jacobian_evaluation = ls_result.njev if success else None - optimality = ls_result.optimality if success else None + optimality = float(ls_result.optimality) if success else None number_of_data_points = ls_result.fun.size if success else None number_of_variables = ls_result.x.size if success else None degrees_of_freedom = 
number_of_data_points - number_of_variables if success else None - chi_square = np.sum(ls_result.fun ** 2) if success else None + chi_square = float(np.sum(ls_result.fun ** 2)) if success else None reduced_chi_square = chi_square / degrees_of_freedom if success else None - root_mean_square_error = np.sqrt(reduced_chi_square) if success else None + root_mean_square_error = float(np.sqrt(reduced_chi_square)) if success else None jacobian = ls_result.jac if success else None if success: problem.parameters.set_from_label_and_value_arrays(free_parameter_labels, ls_result.x) - problem.reset() - history_index = None if success else -2 - data = problem.create_result_data(history_index=history_index) + data = problem.create_result_data(success) # the optimized parameters are those of the last run if the optimization has crashed parameters = problem.parameters covariance_matrix = None @@ -125,10 +124,12 @@ def _create_result( additional_penalty=problem.additional_penalty, cost=problem.cost, data=data, + glotaran_version=glotaran_version, free_parameter_labels=free_parameter_labels, number_of_function_evaluations=number_of_function_evaluation, initial_parameters=problem.scheme.parameters, optimized_parameters=parameters, + parameter_history=problem.parameter_history, scheme=problem.scheme, success=success, termination_reason=termination_reason, diff --git a/glotaran/analysis/problem.py b/glotaran/analysis/problem.py index 4d10d4c31..91cf0fadc 100644 --- a/glotaran/analysis/problem.py +++ b/glotaran/analysis/problem.py @@ -16,13 +16,19 @@ from glotaran.model import DatasetModel from glotaran.model import Model from glotaran.parameter import ParameterGroup +from glotaran.parameter import ParameterHistory from glotaran.project import Scheme if TYPE_CHECKING: from typing import Hashable -class ParameterError(ValueError): +class InitialParameterError(ValueError): + def __init__(self): + super().__init__("Initial parameters can not be evaluated.") + + +class 
ParameterNotInitializedError(ValueError): def __init__(self): super().__init__("Parameter not initialized") @@ -83,7 +89,7 @@ def __init__(self, scheme: Scheme): self._overwrite_index_dependent = self.model.need_index_dependent() self._parameters = scheme.parameters.copy() - self._parameter_history = [] + self._parameter_history = ParameterHistory() self._model.validate(raise_exception=True) @@ -140,7 +146,7 @@ def parameters(self, parameters: ParameterGroup): self.reset() @property - def parameter_history(self) -> list[ParameterGroup]: + def parameter_history(self) -> ParameterHistory: return self._parameter_history @property @@ -318,13 +324,15 @@ def _add_weight(self, label, dataset): ) dataset.weight[idx] *= weight.value - def create_result_data( - self, copy: bool = True, history_index: int | None = None - ) -> dict[str, xr.Dataset]: + def create_result_data(self, copy: bool = True, success: bool = True) -> dict[str, xr.Dataset]: - if history_index is not None and history_index != -1: - self.parameters = self.parameter_history[history_index] + if not success: + if self.parameter_history.number_of_records > 1: + self.parameters.set_from_history(self.parameter_history, -2) + else: + raise InitialParameterError() + self.reset() self.prepare_result_creation() result_data = {} for label, dataset_model in self.dataset_models.items(): diff --git a/glotaran/analysis/problem_grouped.py b/glotaran/analysis/problem_grouped.py index e26f33bea..31c047f5a 100644 --- a/glotaran/analysis/problem_grouped.py +++ b/glotaran/analysis/problem_grouped.py @@ -2,13 +2,14 @@ import collections import itertools +from typing import Any from typing import Deque import numpy as np import xarray as xr from glotaran.analysis.problem import GroupedProblemDescriptor -from glotaran.analysis.problem import ParameterError +from glotaran.analysis.problem import ParameterNotInitializedError from glotaran.analysis.problem import Problem from glotaran.analysis.problem import ProblemGroup from 
glotaran.analysis.util import CalculatedMatrix @@ -192,7 +193,7 @@ def groups(self) -> dict[str, list[str]]: def calculate_matrices(self): if self._parameters is None: - raise ParameterError + raise ParameterNotInitializedError if self._index_dependent: self.calculate_index_dependent_matrices() else: @@ -308,7 +309,7 @@ def _index_dependent_residual( problem: ProblemGroup, matrix: CalculatedMatrix, clp_labels: str, - index: any, + index: Any, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: reduced_clp_labels = matrix.clp_labels @@ -338,7 +339,7 @@ def _index_dependent_residual( ) return clp_labels, clps, weighted_residual, residual - def _index_independent_residual(self, problem: ProblemGroup, index: any): + def _index_independent_residual(self, problem: ProblemGroup, index: Any): matrix = self.reduced_matrices[problem.group] reduced_clp_labels = matrix.clp_labels matrix = matrix.matrix.copy() diff --git a/glotaran/analysis/problem_ungrouped.py b/glotaran/analysis/problem_ungrouped.py index 8afd5f164..dbe0df5ae 100644 --- a/glotaran/analysis/problem_ungrouped.py +++ b/glotaran/analysis/problem_ungrouped.py @@ -3,7 +3,7 @@ import numpy as np import xarray as xr -from glotaran.analysis.problem import ParameterError +from glotaran.analysis.problem import ParameterNotInitializedError from glotaran.analysis.problem import Problem from glotaran.analysis.util import CalculatedMatrix from glotaran.analysis.util import apply_weight @@ -51,7 +51,7 @@ def calculate_matrices( ]: """Calculates the model matrices.""" if self._parameters is None: - raise ParameterError + raise ParameterNotInitializedError self._matrices = {} self._global_matrices = {} diff --git a/glotaran/builtin/io/folder/folder_plugin.py b/glotaran/builtin/io/folder/folder_plugin.py index 743056851..1745ec7fe 100644 --- a/glotaran/builtin/io/folder/folder_plugin.py +++ b/glotaran/builtin/io/folder/folder_plugin.py @@ -6,13 +6,19 @@ from __future__ import annotations -import os +from pathlib import Path from 
typing import TYPE_CHECKING +from glotaran.io import save_dataset +from glotaran.io import save_model +from glotaran.io import save_parameters +from glotaran.io import save_scheme from glotaran.io.interface import ProjectIoInterface +from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_DEFAULT from glotaran.plugin_system.project_io_registration import register_project_io if TYPE_CHECKING: + from glotaran.plugin_system.project_io_registration import SavingOptions from glotaran.project import Result @@ -24,21 +30,39 @@ class FolderProjectIo(ProjectIoInterface): a markdown summary output and the important data saved to files. """ - def save_result(self, result: Result, result_path: str) -> list[str]: + def save_result( + self, + result: Result, + result_path: str, + *, + saving_options: SavingOptions = SAVING_OPTIONS_DEFAULT, + ) -> list[str]: """Save the result to a given folder. Returns a list with paths of all saved items. - The following files are saved: + The following files are saved if not configured otherwise: * `result.md`: The result with the model formatted as markdown text. + * `model.yml`: Model spec file. + * `scheme.yml`: Scheme spec file. + * `initial_parameters.csv`: Initially used parameters. * `optimized_parameters.csv`: The optimized parameter as csv file. + * `parameter_history.csv`: Parameter changes over the optimization * `{dataset_label}.nc`: The result data for each dataset as NetCDF file. + Note + ---- + As a side effect it populates the file path properties of ``result`` which can be + used in other plugins (e.g. the ``yml`` save_result). + Parameters ---------- result : Result Result instance to be saved. result_path : str The path to the folder in which to save the result. + saving_options : SavingOptions + Options for saving the the result. + Returns ------- @@ -50,25 +74,64 @@ def save_result(self, result: Result, result_path: str) -> list[str]: ValueError If ``result_path`` is a file. 
""" - if not os.path.exists(result_path): - os.makedirs(result_path) - if not os.path.isdir(result_path): - raise ValueError(f"The path '{result_path}' is not a directory.") + result_folder = Path(result_path) + if result_folder.is_file(): + raise ValueError(f"The path '{result_folder}' is not a directory.") + result_folder.mkdir(parents=True, exist_ok=True) paths = [] - - md_path = os.path.join(result_path, "result.md") - with open(md_path, "w") as f: - f.write(str(result.markdown())) - paths.append(md_path) - - csv_path = os.path.join(result_path, "optimized_parameters.csv") - result.optimized_parameters.to_csv(csv_path) - paths.append(csv_path) - - for label, data in result.data.items(): - nc_path = os.path.join(result_path, f"{label}.nc") - data.to_netcdf(nc_path, engine="netcdf4") - paths.append(nc_path) + if saving_options.report: + report_file = result_folder / "result.md" + report_file.write_text(str(result.markdown())) + paths.append(report_file.as_posix()) + + result.scheme.model_file = "model.yml" + save_model( + result.scheme.model, result_folder / result.scheme.model_file, allow_overwrite=True + ) + paths.append((result_folder / result.scheme.model_file).as_posix()) + + result.initial_parameters_file = ( + result.scheme.parameters_file + ) = f"initial_parameters.{saving_options.parameter_format}" + save_parameters( + result.scheme.parameters, + result_folder / result.scheme.parameters_file, + format_name=saving_options.parameter_format, + allow_overwrite=True, + ) + paths.append((result_folder / result.scheme.parameters_file).as_posix()) + + result.optimized_parameters_file = ( + f"optimized_parameters.{saving_options.parameter_format}" + ) + save_parameters( + result.optimized_parameters, + result_folder / result.optimized_parameters_file, + format_name=saving_options.parameter_format, + allow_overwrite=True, + ) + paths.append((result_folder / result.optimized_parameters_file).as_posix()) + + result.scheme_file = "scheme.yml" + 
save_scheme(result.scheme, result_folder / result.scheme_file, allow_overwrite=True) + paths.append((result_folder / result.scheme_file).as_posix()) + + result.parameter_history_file = "parameter_history.csv" + result.parameter_history.to_csv(result_folder / result.parameter_history_file) + paths.append((result_folder / result.parameter_history_file).as_posix()) + + result.data_files = { + label: f"{label}.{saving_options.data_format}" for label in result.data + } + + for label, data_file in result.data_files.items(): + save_dataset( + result.data[label], + result_folder / data_file, + format_name=saving_options.data_format, + allow_overwrite=True, + ) + paths.append((result_folder / data_file).as_posix()) return paths diff --git a/glotaran/builtin/io/folder/test/test_folder_plugin.py b/glotaran/builtin/io/folder/test/test_folder_plugin.py index 282178f2a..710281d17 100644 --- a/glotaran/builtin/io/folder/test/test_folder_plugin.py +++ b/glotaran/builtin/io/folder/test/test_folder_plugin.py @@ -11,38 +11,47 @@ if TYPE_CHECKING: from typing import Literal - from py.path import local as TmpDir - from glotaran.project.result import Result @pytest.mark.parametrize("format_name", ("folder", "legacy")) def test_save_result_folder( - tmpdir: TmpDir, + tmp_path: Path, dummy_result: Result, # noqa: F811 format_name: Literal["folder", "legacy"], ): """Check all files exist.""" - result_dir = Path(tmpdir / "testresult") - save_result(result_path=str(result_dir), format_name=format_name, result=dummy_result) - - assert (result_dir / "result.md").exists() - assert (result_dir / "optimized_parameters.csv").exists() - assert (result_dir / "dataset1.nc").exists() - assert (result_dir / "dataset2.nc").exists() - assert (result_dir / "dataset3.nc").exists() + result_dir = tmp_path / "testresult" + save_paths = save_result( + result_path=str(result_dir), format_name=format_name, result=dummy_result + ) + + wanted_files = [ + "result.md", + "scheme.yml", + "model.yml", + 
"initial_parameters.csv", + "optimized_parameters.csv", + "parameter_history.csv", + "dataset1.nc", + "dataset2.nc", + "dataset3.nc", + ] + for wanted in wanted_files: + assert (result_dir / wanted).exists() + assert (result_dir / wanted).as_posix() in save_paths @pytest.mark.parametrize("format_name", ("folder", "legacy")) def test_save_result_folder_error_path_is_file( - tmpdir: TmpDir, + tmp_path: Path, dummy_result: Result, # noqa: F811 format_name: Literal["folder", "legacy"], ): """Raise error if result_path is a file without extension and overwrite is true.""" - result_dir = Path(tmpdir / "testresult") + result_dir = tmp_path / "testresult" result_dir.touch() with pytest.raises(ValueError, match="The path '.+?' is not a directory."): diff --git a/glotaran/builtin/io/netCDF/netCDF.py b/glotaran/builtin/io/netCDF/netCDF.py index 12fc47693..075e87699 100644 --- a/glotaran/builtin/io/netCDF/netCDF.py +++ b/glotaran/builtin/io/netCDF/netCDF.py @@ -4,8 +4,6 @@ from glotaran.io import DataIoInterface from glotaran.io import register_data_io -from glotaran.project import SavingOptions -from glotaran.project import default_data_filters @register_data_io("nc") @@ -19,21 +17,8 @@ def save_dataset( dataset: xr.Dataset, file_name: str, *, - saving_options: SavingOptions = SavingOptions(), + data_filters: list[str] | None = None, ): - data_to_save = dataset - - data_filter = ( - saving_options.data_filter - if saving_options.data_filter is not None - else default_data_filters[saving_options.level] - ) - - if data_filter is not None: - - data_to_save = xr.Dataset() - for item in data_filter: - data_to_save[item] = dataset[item] - - data_to_save.to_netcdf(file_name) + data_to_save = dataset if data_filters is None else dataset[data_filters] + data_to_save.to_netcdf(file_name, mode="w") diff --git a/glotaran/builtin/io/yml/test/test_save_model.py b/glotaran/builtin/io/yml/test/test_save_model.py new file mode 100644 index 000000000..c51b0438e --- /dev/null +++ 
b/glotaran/builtin/io/yml/test/test_save_model.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from glotaran.examples.sequential import model +from glotaran.io import load_model +from glotaran.io import save_model + +if TYPE_CHECKING: + from pathlib import Path + + +want = """dataset: + dataset1: + initial_concentration: j1 + irf: irf1 + megacomplex: + - m1 +default-megacomplex: decay +initial_concentration: + j1: + compartments: + - s1 + - s2 + - s3 + exclude_from_normalize: [] + parameters: + - j.1 + - j.0 + - j.0 +irf: + irf1: + backsweep: false + center: irf.center + normalize: true + type: gaussian + width: irf.width +k_matrix: + k1: + matrix: + (s2, s1): kinetic.1 + (s3, s2): kinetic.2 + (s3, s3): kinetic.3 +megacomplex: + m1: + dimension: time + k_matrix: + - k1 + type: decay +""" + + +def test_save_model( + tmp_path: Path, +): + """Check all files exist.""" + + model_path = tmp_path / "testmodel.yml" + save_model(file_name=model_path, format_name="yml", model=model) + + assert model_path.is_file() + assert model_path.read_text() == want + assert load_model(model_path).valid() diff --git a/glotaran/builtin/io/yml/test/test_save_result.py b/glotaran/builtin/io/yml/test/test_save_result.py index 1d42ffed2..426f4beae 100644 --- a/glotaran/builtin/io/yml/test/test_save_result.py +++ b/glotaran/builtin/io/yml/test/test_save_result.py @@ -7,19 +7,18 @@ from glotaran.project.test.test_result import dummy_result # noqa: F401 if TYPE_CHECKING: - from py.path import local as TmpDir from glotaran.project.result import Result def test_save_result_yml( - tmpdir: TmpDir, + tmp_path: Path, dummy_result: Result, # noqa: F811 ): """Check all files exist.""" - result_dir = Path(tmpdir / "testresult") - save_result(result_path=result_dir, format_name="yml", result=dummy_result) + result_dir = tmp_path / "testresult" + save_result(result_path=result_dir / "result.yml", result=dummy_result) assert (result_dir / "result.md").exists() 
assert (result_dir / "scheme.yml").exists() diff --git a/glotaran/builtin/io/yml/test/test_save_scheme.py b/glotaran/builtin/io/yml/test/test_save_scheme.py new file mode 100644 index 000000000..f44e909d2 --- /dev/null +++ b/glotaran/builtin/io/yml/test/test_save_scheme.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import xarray as xr + +from glotaran.examples.sequential import dataset +from glotaran.examples.sequential import model +from glotaran.examples.sequential import parameter +from glotaran.io import load_scheme +from glotaran.io import save_dataset +from glotaran.io import save_model +from glotaran.io import save_parameters +from glotaran.io import save_scheme +from glotaran.project import Scheme + +if TYPE_CHECKING: + from pathlib import Path + + +want = """add_svd: true +data_files: + dataset_1: d.nc +ftol: 1.0e-08 +group: null +group_tolerance: 0.0 +gtol: 1.0e-08 +maximum_number_function_evaluations: null +model_file: m.yml +non_negative_least_squares: false +optimization_method: TrustRegionReflection +parameters_file: p.csv +result_path: null +xtol: 1.0e-08 +""" + + +def test_save_scheme(tmp_path: Path): + scheme = Scheme( + model, + parameter, + {"dataset_1": dataset}, + model_file="m.yml", + parameters_file="p.csv", + data_files={"dataset_1": "d.nc"}, + ) + save_model(model, tmp_path / "m.yml") + save_parameters(parameter, tmp_path / "p.csv") + save_dataset(dataset, tmp_path / "d.nc") + scheme_path = tmp_path / "testscheme.yml" + save_scheme(file_name=scheme_path, format_name="yml", scheme=scheme) + + assert scheme_path.is_file() + assert scheme_path.read_text() == want + loaded = load_scheme(scheme_path) + print(loaded.model.validate(loaded.parameters)) + assert loaded.model.valid(loaded.parameters) + assert isinstance(scheme.data["dataset_1"], xr.Dataset) diff --git a/glotaran/builtin/io/yml/yml.py b/glotaran/builtin/io/yml/yml.py index 540c59824..83863a10b 100644 --- a/glotaran/builtin/io/yml/yml.py 
+++ b/glotaran/builtin/io/yml/yml.py @@ -1,29 +1,21 @@ from __future__ import annotations -import dataclasses -import os -import pathlib -from typing import TYPE_CHECKING +from pathlib import Path import yaml from glotaran.deprecation.modules.builtin_io_yml import model_spec_deprecations from glotaran.io import ProjectIoInterface -from glotaran.io import load_dataset -from glotaran.io import load_model -from glotaran.io import load_parameters from glotaran.io import register_project_io -from glotaran.io import save_dataset -from glotaran.io import save_parameters +from glotaran.io import save_result from glotaran.model import Model from glotaran.parameter import ParameterGroup -from glotaran.project import SavingOptions +from glotaran.project import Result from glotaran.project import Scheme +from glotaran.project.dataclass_helpers import asdict +from glotaran.project.dataclass_helpers import fromdict from glotaran.utils.sanitize import sanitize_yaml -if TYPE_CHECKING: - from glotaran.project import Result - @register_project_io(["yml", "yaml", "yml_str"]) class YmlProjectIo(ProjectIoInterface): @@ -41,12 +33,7 @@ def load_model(self, file_name: str) -> Model: The content of the file as dictionary. """ - if self.format == "yml_str": - spec = yaml.safe_load(file_name) - - else: - with open(file_name) as f: - spec = yaml.safe_load(f) + spec = self._load_yml(file_name) model_spec_deprecations(spec) @@ -67,13 +54,41 @@ def load_model(self, file_name: str) -> Model: return Model.from_dict(spec, megacomplex_types=None, default_megacomplex_type=None) + def save_model(self, model: Model, file_name: str): + """Save a Model instance to a spec file. + Parameters + ---------- + model: Model + Model instance to save to specs file. + file_name : str + File to write the model specs to. 
+ """ + model_dict = model.as_dict() + # We replace tuples with strings + for items in model_dict.values(): + if not isinstance(items, (list, dict)): + continue + item_iterator = items if isinstance(items, list) else items.values() + for item in item_iterator: + for prop_name, prop in item.items(): + if isinstance(prop, dict) and any(isinstance(k, tuple) for k in prop): + keys = [f"({k[0]}, {k[1]})" for k in prop] + item[prop_name] = {f"{k}": v for k, v in zip(keys, prop.values())} + _write_dict(file_name, model_dict) + def load_parameters(self, file_name: str) -> ParameterGroup: + """Create a ParameterGroup instance from the specs defined in a file. + Parameters + ---------- + file_name : str + File containing the parameter specs. + Returns + ------- + ParameterGroup + ParameterGroup instance created from the file. + """ - if self.format == "yml_str": - spec = yaml.safe_load(file_name) - else: - with open(file_name) as f: - spec = yaml.safe_load(f) + spec = self._load_yml(file_name) if isinstance(spec, list): return ParameterGroup.from_list(spec) @@ -81,122 +96,51 @@ def load_parameters(self, file_name: str) -> ParameterGroup: return ParameterGroup.from_dict(spec) def load_scheme(self, file_name: str) -> Scheme: - if self.format == "yml_str": - yml = file_name - else: - try: - with open(file_name) as f: - yml = f.read() - except Exception as e: - raise OSError(f"Error opening scheme: {e}") - - try: - scheme = yaml.safe_load(yml) - except Exception as e: - raise ValueError(f"Error parsing scheme: {e}") - - if "model" not in scheme: - raise ValueError("Model file not specified.") - - try: - model = load_model(scheme["model"]) - except Exception as e: - raise ValueError(f"Error loading model: {e}") - - if "parameters" not in scheme: - raise ValueError("Parameters file not specified.") - - try: - parameters = load_parameters(scheme["parameters"]) - except Exception as e: - raise ValueError(f"Error loading parameters: {e}") - - if "data" not in scheme: - raise 
ValueError("No data specified.") - - data = {} - for label, path in scheme["data"].items(): - data_format = scheme.get("data_format", None) - path = str(pathlib.Path(path).resolve()) - - try: - data[label] = load_dataset(path, format_name=data_format) - except Exception as e: - raise ValueError(f"Error loading dataset '{label}': {e}") - - optimization_method = scheme.get("optimization_method", "TrustRegionReflection") - nnls = scheme.get("non-negative-least-squares", False) - nfev = scheme.get("maximum-number-function-evaluations", None) - ftol = scheme.get("ftol", 1e-8) - gtol = scheme.get("gtol", 1e-8) - xtol = scheme.get("xtol", 1e-8) - group = scheme.get("group", False) - group_tolerance = scheme.get("group_tolerance", 0.0) - saving = SavingOptions(**scheme.get("saving", {})) - return Scheme( - model=model, - parameters=parameters, - data=data, - non_negative_least_squares=nnls, - maximum_number_function_evaluations=nfev, - ftol=ftol, - gtol=gtol, - xtol=xtol, - group=group, - group_tolerance=group_tolerance, - optimization_method=optimization_method, - saving=saving, - ) + spec = self._load_yml(file_name) + file_path = Path(file_name) + return fromdict(Scheme, spec, folder=file_path.parent) def save_scheme(self, scheme: Scheme, file_name: str): - _write_dict(file_name, dataclasses.asdict(scheme)) + scheme_dict = asdict(scheme) + _write_dict(file_name, scheme_dict) + + def load_result(self, result_path: str) -> Result: + """Create a :class:`Result` instance from the specs defined in a file. + + Parameters + ---------- + result_path : str | PathLike[str] + Path containing the result data. + + Returns + ------- + Result + :class:`Result` instance created from the saved format. 
+ """ + spec = self._load_yml(result_path) + return fromdict(Result, spec) def save_result(self, result: Result, result_path: str): - options = result.scheme.saving - - if os.path.exists(result_path): - raise FileExistsError(f"The path '{result_path}' is already existing.") - - os.makedirs(result_path) - - if options.report: - md_path = os.path.join(result_path, "result.md") - with open(md_path, "w") as f: - f.write(str(result.markdown())) - - scheme_path = os.path.join(result_path, "scheme.yml") - result_scheme = dataclasses.replace(result.scheme) - result_scheme.model = result_scheme.model.markdown() - result = dataclasses.replace(result) - result.scheme = scheme_path - - parameters_format = options.parameter_format - - initial_parameters_path = os.path.join( - result_path, f"initial_parameters.{parameters_format}" - ) - save_parameters(result.initial_parameters, initial_parameters_path, parameters_format) - result.initial_parameters = initial_parameters_path - result_scheme.parameters = initial_parameters_path - - optimized_parameters_path = os.path.join( - result_path, f"optimized_parameters.{parameters_format}" - ) - save_parameters(result.optimized_parameters, optimized_parameters_path, parameters_format) - result.optimized_parameters = optimized_parameters_path - - dataset_format = options.data_format - for label, dataset in result.data.items(): - dataset_path = os.path.join(result_path, f"{label}.{dataset_format}") - save_dataset(dataset, dataset_path, dataset_format, saving_options=options) - result.data[label] = dataset_path - result_scheme.data[label] = dataset_path - - result_file_path = os.path.join(result_path, "result.yml") - _write_dict(result_file_path, dataclasses.asdict(result)) - result_scheme.result_path = result_file_path - - self.save_scheme(scheme=result_scheme, file_name=scheme_path) + """Write a :class:`Result` instance to a spec file. + + Parameters + ---------- + result : Result + :class:`Result` instance to write. 
+ result_path : str | PathLike[str] + Path to write the result data to. + """ + save_result(result, Path(result_path).parent.as_posix(), format_name="folder") + result_dict = asdict(result) + _write_dict(result_path, result_dict) + + def _load_yml(self, file_name: str) -> dict: + if self.format == "yml_str": + spec = yaml.safe_load(file_name) + else: + with open(file_name) as f: + spec = yaml.safe_load(f) + return spec def _write_dict(file_name: str, d: dict): diff --git a/glotaran/deprecation/modules/test/test_project_result.py b/glotaran/deprecation/modules/test/test_project_result.py index 83d2c5248..0e1e34d9c 100644 --- a/glotaran/deprecation/modules/test/test_project_result.py +++ b/glotaran/deprecation/modules/test/test_project_result.py @@ -9,27 +9,10 @@ from glotaran.project.test.test_result import dummy_result # noqa: F401 if TYPE_CHECKING: - from py.path import local as LocalPath from glotaran.project.result import Result -def test_Result_save_method(tmpdir: LocalPath, dummy_result: Result): # noqa: F811 - """Result.save(result_dir) creates all file""" - result_dir = tmpdir / "dummy" - result_dir.mkdir() - - deprecation_warning_on_call_test_helper( - dummy_result.save, args=[str(result_dir)], raise_exception=True - ) - - assert (result_dir / "result.md").exists() - assert (result_dir / "optimized_parameters.csv").exists() - assert (result_dir / "dataset1.nc").exists() - assert (result_dir / "dataset2.nc").exists() - assert (result_dir / "dataset3.nc").exists() - - def test_Result_get_dataset_method(dummy_result: Result): # noqa: F811 """Result.get_dataset(dataset_label) gives correct dataset.""" diff --git a/glotaran/deprecation/modules/test/test_project_sheme.py b/glotaran/deprecation/modules/test/test_project_scheme.py similarity index 87% rename from glotaran/deprecation/modules/test/test_project_sheme.py rename to glotaran/deprecation/modules/test/test_project_scheme.py index 42ce6daa1..0bf65635a 100644 --- 
a/glotaran/deprecation/modules/test/test_project_sheme.py +++ b/glotaran/deprecation/modules/test/test_project_scheme.py @@ -38,11 +38,11 @@ def test_Scheme_from_yaml_file_method(tmp_path: Path): scheme_path.write_text( f""" - model: {model_path} - parameters: {parameter_path} - non-negative-least-squares: True - maximum-number-function-evaluations: 42 - data: + model_file: {model_path} + parameters_file: {parameter_path} + non_negative_least_squares: True + maximum_number_function_evaluations: 42 + data_files: dataset1: {dataset_path}""" ) diff --git a/glotaran/examples/sequential.py b/glotaran/examples/sequential.py index d0f0f635e..cf7275b68 100644 --- a/glotaran/examples/sequential.py +++ b/glotaran/examples/sequential.py @@ -5,6 +5,7 @@ from glotaran.builtin.megacomplexes.spectral import SpectralMegacomplex from glotaran.model import Model from glotaran.parameter import ParameterGroup +from glotaran.project import Scheme sim_model = Model.from_dict( { @@ -148,3 +149,4 @@ }, megacomplex_types={"decay": DecayMegacomplex}, ) +scheme = Scheme(model=model, parameters=parameter, data={"dataset1": dataset}) diff --git a/glotaran/io/__init__.py b/glotaran/io/__init__.py index 07cd786f6..1bf069a30 100644 --- a/glotaran/io/__init__.py +++ b/glotaran/io/__init__.py @@ -17,6 +17,9 @@ from glotaran.plugin_system.data_io_registration import save_dataset from glotaran.plugin_system.data_io_registration import set_data_plugin from glotaran.plugin_system.data_io_registration import show_data_io_method_help +from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_DEFAULT +from glotaran.plugin_system.project_io_registration import SAVING_OPTIONS_MINIMAL +from glotaran.plugin_system.project_io_registration import SavingOptions from glotaran.plugin_system.project_io_registration import get_project_io_method from glotaran.plugin_system.project_io_registration import load_model from glotaran.plugin_system.project_io_registration import load_parameters diff --git 
a/glotaran/io/interface.py b/glotaran/io/interface.py index 1ea85627e..f6c5e5019 100644 --- a/glotaran/io/interface.py +++ b/glotaran/io/interface.py @@ -218,7 +218,7 @@ def load_result(self, result_path: str) -> Result: """ raise NotImplementedError(f"Cannot read result with format {self.format!r}") - def save_result(self, result: Result, result_path: str): + def save_result(self, result: Result, result_path: str) -> list[str] | None: """Save a Result instance to a spec file (**NOT IMPLEMENTED**). Parameters diff --git a/glotaran/model/dataset_model.py b/glotaran/model/dataset_model.py index 093475ef2..188fccdf6 100644 --- a/glotaran/model/dataset_model.py +++ b/glotaran/model/dataset_model.py @@ -152,7 +152,7 @@ def overwrite_index_dependent(self, index_dependent: bool): def has_global_model(self) -> bool: """Indicates if the dataset model can model the global dimension.""" - return len(self.global_megacomplex) != 0 + return self.global_megacomplex is not None and len(self.global_megacomplex) != 0 def set_coordinates(self, coords: dict[str, np.ndarray]): """Sets the dataset model's coordinates.""" diff --git a/glotaran/model/item.py b/glotaran/model/item.py index f3133090f..8b7a165dd 100644 --- a/glotaran/model/item.py +++ b/glotaran/model/item.py @@ -115,6 +115,9 @@ def decorator(cls): validate = _create_validation_func(cls) setattr(cls, "validate", validate) + as_dict = _create_as_dict_func(cls) + setattr(cls, "as_dict", as_dict) + get_state = _create_get_state_func(cls) setattr(cls, "__getstate__", get_state) @@ -124,6 +127,9 @@ def decorator(cls): fill = _create_fill_func(cls) setattr(cls, "fill", fill) + get_parameters = _create_get_parameters(cls) + setattr(cls, "get_parameters", get_parameters) + mprint = _create_mprint_func(cls) setattr(cls, "mprint", mprint) @@ -280,6 +286,18 @@ def validate(self, model: Model, parameters: ParameterGroup | None = None) -> li return validate +def _create_as_dict_func(cls): + @wrap_func_as_method(cls) + def as_dict(self) 
-> dict: + return { + name: getattr(self.__class__, name).as_dict_value(getattr(self, name)) + for name in self._glotaran_properties + if name != "label" and getattr(self, name) is not None + } + + return as_dict + + def _create_fill_func(cls): @wrap_func_as_method(cls) def fill(self, model: Model, parameters: ParameterGroup) -> cls: @@ -304,6 +322,20 @@ def fill(self, model: Model, parameters: ParameterGroup) -> cls: return fill +def _create_get_parameters(cls): + @wrap_func_as_method(cls) + def get_parameters(self) -> list[str]: + """Returns all parameter full labels of the item.""" + parameters = [] + for name in self._glotaran_properties: + value = getattr(self, name) + prop = getattr(self.__class__, name) + parameters += prop.get_parameters(value) + return parameters + + return get_parameters + + def _create_get_state_func(cls): @wrap_func_as_method(cls) def get_state(self) -> cls: diff --git a/glotaran/model/model.py b/glotaran/model/model.py index 4e8ceb74e..1813c652b 100644 --- a/glotaran/model/model.py +++ b/glotaran/model/model.py @@ -32,7 +32,7 @@ default_dataset_properties = { "megacomplex": List[str], "megacomplex_scale": {"type": List[Parameter], "allow_none": True}, - "global_megacomplex": {"type": List[str], "default": []}, + "global_megacomplex": {"type": List[str], "allow_none": True}, "global_megacomplex_scale": {"type": List[Parameter], "default": None, "allow_none": True}, "scale": {"type": Parameter, "default": None, "allow_none": True}, } @@ -75,6 +75,7 @@ def from_dict( default_megacomplex_type: str | None Overwrite 'default-megacomplex' in ``model_dict`` for testing. 
""" + model_dict = copy.deepcopy(model_dict) if default_megacomplex_type is None: default_megacomplex_type = model_dict.get("default-megacomplex") @@ -84,90 +85,75 @@ def from_dict( for m in model_dict["megacomplex"].values() if "type" in m } - if default_megacomplex_type is not None: + if ( + default_megacomplex_type is not None + and default_megacomplex_type not in megacomplex_types + ): megacomplex_types[default_megacomplex_type] = get_megacomplex(default_megacomplex_type) + if "default-megacomplex" in model_dict: model_dict.pop("default-megacomplex", None) model = cls( megacomplex_types=megacomplex_types, default_megacomplex_type=default_megacomplex_type ) - model_dict_local = copy.deepcopy(model_dict) # TODO: maybe redundant? - # iterate over items - for name, items in list(model_dict_local.items()): + for item_name, items in list(model_dict.items()): - if name not in model._model_items: - warn(f"Unknown model item type '{name}'.") + if item_name not in model.model_items: + warn(f"Unknown model item type '{item_name}'.") continue - is_list = isinstance(getattr(model, name), list) + is_list = isinstance(getattr(model, item_name), list) if is_list: - model._add_list_items(name, items) + model._add_list_items(item_name, items) else: - model._add_dict_items(name, items) + model._add_dict_items(item_name, items) return model - @property - def model_dimension(self): - """Deprecated use ``Scheme.model_dimensions['']`` instead""" - raise_deprecation_error( - deprecated_qual_name_usage="Model.model_dimension", - new_qual_name_usage=("Scheme.model_dimensions['']"), - to_be_removed_in_version="0.7.0", - ) - - @property - def global_dimension(self): - """Deprecated use ``Scheme.global_dimensions['']`` instead""" - raise_deprecation_error( - deprecated_qual_name_usage="Model.global_dimension", - new_qual_name_usage=("Scheme.global_dimensions['']"), - to_be_removed_in_version="0.7.0", - ) - - def _add_dict_items(self, name: str, items: dict): + def _add_dict_items(self, 
item_name: str, items: dict): for label, item in items.items(): - item_cls = self._model_items[name] + item_cls = self.model_items[item_name] is_typed = hasattr(item_cls, "_glotaran_model_item_typed") if is_typed: if "type" not in item and item_cls.get_default_type() is None: - raise ValueError(f"Missing type for attribute '{name}'") + raise ValueError(f"Missing type for attribute '{item_name}'") item_type = item.get("type", item_cls.get_default_type()) types = item_cls._glotaran_model_item_types if item_type not in types: - raise ValueError(f"Unknown type '{item_type}' for attribute '{name}'") + raise ValueError(f"Unknown type '{item_type}' for attribute '{item_name}'") item_cls = types[item_type] item["label"] = label item = item_cls.from_dict(item) - getattr(self, name)[label] = item + getattr(self, item_name)[label] = item - def _add_list_items(self, name: str, items: list): + def _add_list_items(self, item_name: str, items: list): for item in items: - item_cls = self._model_items[name] + item_cls = self.model_items[item_name] is_typed = hasattr(item_cls, "_glotaran_model_item_typed") if is_typed: if "type" not in item: - raise ValueError(f"Missing type for attribute '{name}'") + raise ValueError(f"Missing type for attribute '{item_name}'") item_type = item["type"] if item_type not in item_cls._glotaran_model_item_types: - raise ValueError(f"Unknown type '{item_type}' for attribute '{name}'") + raise ValueError(f"Unknown type '{item_type}' for attribute '{item_name}'") item_cls = item_cls._glotaran_model_item_types[item_type] item = item_cls.from_dict(item) - getattr(self, name).append(item) + getattr(self, item_name).append(item) def _add_megacomplexe_types(self): - for name, megacomplex_type in self._megacomplex_types.items(): + for megacomplex_name, megacomplex_type in self._megacomplex_types.items(): if not issubclass(megacomplex_type, Megacomplex): raise TypeError( - f"Megacomplex type {name}({megacomplex_type}) is not a subclass of Megacomplex" + 
f"Megacomplex type {megacomplex_name}({megacomplex_type}) " + "is not a subclass of Megacomplex" ) self._add_megacomplex_type(megacomplex_type) @@ -178,36 +164,36 @@ def _add_megacomplexe_types(self): def _add_megacomplex_type(self, megacomplex_type: type[Megacomplex]): - for name, item in megacomplex_type.glotaran_model_items().items(): - self._add_model_item(name, item) + for item_name, item in megacomplex_type.glotaran_model_items().items(): + self._add_model_item(item_name, item) - for name, item in megacomplex_type.glotaran_dataset_model_items().items(): - self._add_model_item(name, item) + for item_name, item in megacomplex_type.glotaran_dataset_model_items().items(): + self._add_model_item(item_name, item) - for name, prop in megacomplex_type.glotaran_dataset_properties().items(): - self._add_dataset_property(name, prop) + for property_name, prop in megacomplex_type.glotaran_dataset_properties().items(): + self._add_dataset_property(property_name, prop) - def _add_model_item(self, name: str, item: type): - if name in self._model_items: - if self._model_items[name] != item: + def _add_model_item(self, item_name: str, item: type): + if item_name in self._model_items: + if self.model_items[item_name] != item: raise ModelError( - f"Cannot add item of type {name}. Model item '{name}' was already defined" - "as a different type." + f"Cannot add item of type {item_name}. Model item '{item_name}' " + "was already defined as a different type." 
) return - self._model_items[name] = item + self._model_items[item_name] = item if getattr(item, "_glotaran_has_label"): - setattr(self, f"{name}", {}) + setattr(self, f"{item_name}", {}) else: - setattr(self, f"{name}", []) + setattr(self, f"{item_name}", []) - def _add_dataset_property(self, name: str, dataset_property: dict[str, any]): - if name in self._dataset_properties: + def _add_dataset_property(self, property_name: str, dataset_property: dict[str, any]): + if property_name in self._dataset_properties: known_type = ( - self._dataset_properties[name] + self._dataset_properties[property_name] if not isinstance(self._dataset_properties, dict) - else self._dataset_properties[name]["type"] + else self._dataset_properties[property_name]["type"] ) new_type = ( dataset_property @@ -216,23 +202,41 @@ def _add_dataset_property(self, name: str, dataset_property: dict[str, any]): ) if known_type != new_type: raise ModelError( - f"Cannot add dataset property of type {name} as it was already defined" - "as a different type." + f"Cannot add dataset property of type {property_name} as it was " + "already defined as a different type." 
) return - self._dataset_properties[name] = dataset_property + self._dataset_properties[property_name] = dataset_property def _add_default_items_and_properties(self): - for name, item in default_model_items.items(): - self._add_model_item(name, item) + for item_name, item in default_model_items.items(): + self._add_model_item(item_name, item) - for name, prop in default_dataset_properties.items(): - self._add_dataset_property(name, prop) + for property_name, prop in default_dataset_properties.items(): + self._add_dataset_property(property_name, prop) def _add_dataset_type(self): dataset_model_type = create_dataset_model_type(self._dataset_properties) self._add_model_item("dataset", dataset_model_type) + @property + def model_dimension(self): + """Deprecated use ``Scheme.model_dimensions['']`` instead""" + raise_deprecation_error( + deprecated_qual_name_usage="Model.model_dimension", + new_qual_name_usage=("Scheme.model_dimensions['']"), + to_be_removed_in_version="0.7.0", + ) + + @property + def global_dimension(self): + """Deprecated use ``Scheme.global_dimensions['']`` instead""" + raise_deprecation_error( + deprecated_qual_name_usage="Model.global_dimension", + new_qual_name_usage=("Scheme.global_dimensions['']"), + to_be_removed_in_version="0.7.0", + ) + @property def default_megacomplex(self) -> str: """The default megacomplex used by this model.""" @@ -253,6 +257,28 @@ def global_megacomplex(self) -> dict[str, Megacomplex]: """Alias for `glotaran.model.megacomplex`. 
Needed internally.""" return self.megacomplex + def as_dict(self) -> dict: + model_dict = {"default-megacomplex": self.default_megacomplex} + for item_name in self._model_items: + items = getattr(self, item_name) + if len(items) == 0: + continue + if isinstance(items, list): + model_dict[item_name] = [item.as_dict() for item in items] + else: + model_dict[item_name] = {label: item.as_dict() for label, item in items.items()} + + return model_dict + + def get_parameters(self) -> list[str]: + parameters = [] + for item_name in self.model_items: + items = getattr(self, item_name) + item_iterator = items if isinstance(items, list) else items.values() + for item in item_iterator: + parameters += item.get_parameters() + return parameters + def need_index_dependent(self) -> bool: """Returns true if e.g. clp_relations with intervals are present.""" return any(i.interval is not None for i in self.clp_constraints + self.clp_relations) @@ -282,7 +308,7 @@ def problem_list(self, parameters: ParameterGroup = None) -> list[str]: """ problems = [] - for name in self._model_items: + for name in self.model_items: items = getattr(self, name) if isinstance(items, list): for item in items: @@ -357,7 +383,7 @@ def markdown( string += ", ".join(self._megacomplex_types) string += "\n\n" - for name in self._model_items: + for name in self.model_items: items = getattr(self, name) if not items: continue @@ -380,5 +406,5 @@ def _repr_markdown_(self) -> str: """Special method used by ``ipython`` to render markdown.""" return str(self.markdown(base_heading_level=3)) - def __str__(self): + def __str__(self) -> str: return str(self.markdown()) diff --git a/glotaran/model/property.py b/glotaran/model/property.py index 41c6d4c0f..e93e79c7a 100644 --- a/glotaran/model/property.py +++ b/glotaran/model/property.py @@ -1,9 +1,18 @@ """The model property class.""" +from __future__ import annotations -import typing +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from 
typing import List +from typing import Union from glotaran.model.util import wrap_func_as_method from glotaran.parameter import Parameter +from glotaran.parameter import ParameterGroup + +if TYPE_CHECKING: + from glotaran.model.model import Model class ModelProperty(property): @@ -13,7 +22,7 @@ def __init__(self, cls, name, prop_type, doc, default, allow_none): self._allow_none = allow_none self._determine_if_parameter(prop_type) - self._type = prop_type if not self._is_parameter else typing.Union[str, prop_type] + self._type = prop_type if not self._is_parameter else Union[str, prop_type] @wrap_func_as_method(cls, name=name) def setter(that_self, value: self._type): @@ -45,10 +54,21 @@ def allow_none(self) -> bool: return self._allow_none @property - def property_type(self) -> typing.Type: + def property_type(self) -> type: return self._type - def validate(self, value, model, parameters=None) -> typing.List[str]: + def as_dict_value(self, value): + if value is None: + return None + elif self._is_parameter_value: + return value.full_label + elif self._is_parameter_list: + return [v.full_label for v in value] + elif self._is_parameter_dict: + return {k: v.full_label for k, v in value.items()} + return value + + def validate(self, value: Any, model: Model, parameters: ParameterGroup = None) -> list[str]: if value is None and self.allow_none: return [] @@ -88,7 +108,7 @@ def validate(self, value, model, parameters=None) -> typing.List[str]: return missing_model + missing_parameters - def fill(self, value, model, parameter): + def fill(self, value: Any, model: Model, parameter: ParameterGroup) -> Any: if value is None: return None @@ -119,16 +139,27 @@ def fill(self, value, model, parameter): return value + def get_parameters(self, value: Any) -> list[str]: + if value is None: + return [] + elif self._is_parameter_value: + return [value.full_label] + elif self._is_parameter_list: + return [v.full_label for v in value] + elif self._is_parameter_dict: + return 
[v.full_label for v in value.values()] + return [] + def _determine_if_parameter(self, type): self._is_parameter_value = type is Parameter self._is_parameter_list = ( hasattr(type, "__origin__") - and issubclass(type.__origin__, typing.List) + and issubclass(type.__origin__, List) and type.__args__[0] is Parameter ) self._is_parameter_dict = ( hasattr(type, "__origin__") - and issubclass(type.__origin__, typing.Dict) + and issubclass(type.__origin__, Dict) and type.__args__[1] is Parameter ) self._is_parameter = ( diff --git a/glotaran/model/test/test_model.py b/glotaran/model/test/test_model.py index e34b0585b..14eff9741 100644 --- a/glotaran/model/test/test_model.py +++ b/glotaran/model/test/test_model.py @@ -37,6 +37,18 @@ class MockItem: pass +@model_item( + properties={ + "param": Parameter, + "param_list": List[Parameter], + "param_dict": {"type": Dict[Tuple[str, str], Parameter]}, + "number": int, + }, +) +class MockItemSimple: + pass + + @model_item(has_label=False) class MockItemNoLabel: pass @@ -79,8 +91,13 @@ class MockMegacomplex6(Megacomplex): pass +@megacomplex(dimension="model", model_items={"test_item_simple": MockItemSimple}) +class MockMegacomplex7(Megacomplex): + pass + + @pytest.fixture -def test_model(): +def test_model_dict(): model_dict = { "megacomplex": { "m1": {"test_item1": "t2"}, @@ -128,8 +145,13 @@ def test_model(): }, } model_dict["test_item_dataset"] = model_dict["test_item1"] + return model_dict + + +@pytest.fixture +def test_model(test_model_dict): return Model.from_dict( - model_dict, + test_model_dict, megacomplex_types={ "type1": MockMegacomplex1, "type5": MockMegacomplex5, @@ -354,6 +376,37 @@ def test_fill(test_model: Model, parameter: ParameterGroup): assert t.complex == {} +def test_model_as_dict(): + model_dict = { + "default-megacomplex": "type7", + "megacomplex": { + "m1": {"test_item_simple": "t2", "dimension": "model"}, + }, + "test_item_simple": { + "t1": { + "param": "foo", + "param_list": ["bar", "baz"], + 
"param_dict": {("s1", "s2"): "baz"}, + "number": 21, + }, + }, + "dataset": { + "dataset1": { + "megacomplex": ["m1"], + "scale": "scale_1", + }, + }, + } + model = Model.from_dict( + model_dict, + megacomplex_types={ + "type7": MockMegacomplex7, + }, + ) + as_model_dict = model.as_dict() + assert as_model_dict == model_dict + + def test_model_markdown_base_heading_level(test_model: Model): """base_heading_level applies to all sections.""" assert test_model.markdown().startswith("# Model") diff --git a/glotaran/parameter/__init__.py b/glotaran/parameter/__init__.py index 6ee5c1c96..9eb44b67a 100644 --- a/glotaran/parameter/__init__.py +++ b/glotaran/parameter/__init__.py @@ -1,5 +1,4 @@ -from glotaran.parameter import parameter -from glotaran.parameter import parameter_group - -Parameter = parameter.Parameter -ParameterGroup = parameter_group.ParameterGroup +"""The glotaran parameter package.""" +from glotaran.parameter.parameter import Parameter +from glotaran.parameter.parameter_group import ParameterGroup +from glotaran.parameter.parameter_history import ParameterHistory diff --git a/glotaran/parameter/parameter.py b/glotaran/parameter/parameter.py index f6e454be0..fac046a19 100644 --- a/glotaran/parameter/parameter.py +++ b/glotaran/parameter/parameter.py @@ -53,27 +53,26 @@ def __init__( Parameters ---------- - label : str, optional + label : str The label of the parameter., by default None - full_label : str, optional + full_label : str The label of the parameter with its path in a parameter group prepended. , by default None - expression : str, optional + expression : str Expression to calculate the parameters value from, e.g. if used in relation to another parameter. 
, by default None - maximum : int, optional + maximum : int Upper boundary for the parameter to be varied to., by default np.inf - minimum : int, optional + minimum : int Lower boundary for the parameter to be varied to., by default -np.inf - non_negative : bool, optional + non_negative : bool Whether the parameter should always be bigger than zero., by default False - value : float, optional + value : float Value of the parameter, by default np.nan - vary : bool, optional + vary : bool Whether the parameter should be changed during optimization or not. , by default True """ - self.label = label self.full_label = full_label or "" self.expression = expression @@ -88,7 +87,19 @@ def __init__( @staticmethod def valid_label(label: str) -> bool: - """Returns true if the `label` is valid string.""" + """Check if a label is a valid label for :class:`Parameter`. + + Parameters + ---------- + label : str + The label to validate. + + Returns + ------- + bool + Whether the label is valid. + + """ return VALID_LABEL_REGEX.search(label) is None and label not in RESERVED_LABELS @classmethod @@ -98,16 +109,21 @@ def from_list_or_value( default_options: dict = None, label: str = None, ) -> Parameter: - """Creates a parameter from a list or numeric value. + """Create a parameter from a list or numeric value. Parameters ---------- - value : + value : int | float | list The list or numeric value. - default_options : + default_options : dict A dictionary of default options. - label : + label : str The label of the parameter. + + Returns + ------- + Parameter + The created :class:`Parameter`. 
""" param = cls(label=label) options = None @@ -117,9 +133,9 @@ def from_list_or_value( else: values = sanitize_parameter_list(value) - param.label = _retrieve_from_list_by_type(values, str, label) - param.value = float(_retrieve_from_list_by_type(values, (int, float), 0)) - options = _retrieve_from_list_by_type(values, dict, None) + param.label = _retrieve_item_from_list_by_type(values, str, label) + param.value = float(_retrieve_item_from_list_by_type(values, (int, float), 0)) + options = _retrieve_item_from_list_by_type(values, dict, None) if default_options: param._set_options_from_dict(default_options) @@ -129,19 +145,17 @@ def from_list_or_value( return param def set_from_group(self, group: ParameterGroup): - """Sets all values of the parameter to the values of the corresponding parameter in the group. + """Set all values of the parameter to the values of the corresponding parameter in the group. Notes ----- - For internal use. Parameters ---------- - group : + group : ParameterGroup The :class:`glotaran.parameter.ParameterGroup`. """ - p = group.get(self.full_label) self.expression = p.expression self.maximum = p.maximum @@ -152,6 +166,13 @@ def set_from_group(self, group: ParameterGroup): self.vary = p.vary def _set_options_from_dict(self, options: dict): + """Set the parameter's options from a dictionary. + + Parameters + ---------- + options : dict + A dictionary containing parameter options. + """ if Keys.EXPR in options: self.expression = options[Keys.EXPR] if Keys.NON_NEG in options: @@ -165,7 +186,13 @@ def _set_options_from_dict(self, options: dict): @property def label(self) -> str | None: - """Label of the parameter""" + """Label of the parameter. + + Returns + ------- + str + The label. 
+ """ return self._label @label.setter @@ -178,7 +205,13 @@ def label(self, label: str | None): @property def full_label(self) -> str: - """The label of the parameter with its path in a parameter group prepended.""" + """Label of the parameter with its path in a parameter group prepended. + + Returns + ------- + str + The full label. + """ return self._full_label @full_label.setter @@ -187,13 +220,20 @@ def full_label(self, full_label: str): @property def non_negative(self) -> bool: - r"""Indicates if the parameter is non-negativ. + r"""Indicate if the parameter is non-negativ. If true, the parameter will be transformed with :math:`p' = \log{p}` and :math:`p = \exp{p'}`. + Notes + ----- Always `False` if `expression` is not `None`. - """ # w605 + + Returns + ------- + bool + Whether the parameter is non-negativ. + """ return self._non_negative if self.expression is None else False @non_negative.setter @@ -202,9 +242,16 @@ def non_negative(self, non_negative: bool): @property def vary(self) -> bool: - """Indicates if the parameter should be optimized. + """Indicate if the parameter should be optimized. + Notes + ----- Always `False` if `expression` is not `None`. + + Returns + ------- + bool + Whether the parameter should be optimized. """ return self._vary if self.expression is None else False @@ -214,7 +261,13 @@ def vary(self, vary: bool): @property def maximum(self) -> float: - """The upper bound of the parameter.""" + """Upper bound of the parameter. + + Returns + ------- + float + The upper bound of the parameter. + """ return self._maximum @maximum.setter @@ -232,7 +285,14 @@ def maximum(self, maximum: int | float): @property def minimum(self) -> float: - """The lower bound of the parameter.""" + """Lower bound of the parameter. + + Returns + ------- + float + + The lower bound of the parameter. + """ return self._minimum @minimum.setter @@ -253,6 +313,11 @@ def expression(self) -> str | None: """Expression to calculate the parameters value from. 
This can used to set a relation to another parameter. + + Returns + ------- + str | None + The expression. """ return self._expression @@ -263,7 +328,14 @@ def expression(self, expression: str | None): @property def transformed_expression(self) -> str | None: - """The expression of the parameter transformed for evaluation within a `ParameterGroup`.""" + """Expression of the parameter transformed for evaluation within a `ParameterGroup`. + + Returns + ------- + str | None + The transformed expression. + + """ if self.expression is not None and self._transformed_expression is None: self._transformed_expression = PARAMETER_EXPRESION_REGEX.sub( r"group.get('\g').value", self.expression @@ -271,8 +343,15 @@ def transformed_expression(self) -> str | None: return self._transformed_expression @property - def standard_error(self) -> float: - """The standard error of the optimized parameter.""" + def standard_error(self) -> float: # noqa D401 + """Standard error of the optimized parameter. + + Returns + ------- + float + The standard error of the parameter. + """ + return self._stderr @standard_error.setter @@ -281,8 +360,14 @@ def standard_error(self, standard_error: float): @property def value(self) -> float: - """The value of the parameter""" - return self._getval() + """Value of the parameter. + + Returns + ------- + float + The value of the parameter. + """ + return self._value @value.setter def value(self, value: int | float): @@ -298,8 +383,13 @@ def value(self, value: int | float): self._value = value def get_value_and_bounds_for_optimization(self) -> tuple[float, float, float]: - """Gets the parameter value and bounds with expression and non-negative constraints - applied.""" + """Get the parameter value and bounds with expression and non-negative constraints applied. + + Returns + ------- + tuple[float, float, float] + A tuple containing the value, the lower and the upper bound. 
+ """ value = self.value minimum = self.minimum maximum = self.maximum @@ -312,10 +402,16 @@ def get_value_and_bounds_for_optimization(self) -> tuple[float, float, float]: return value, minimum, maximum def set_value_from_optimization(self, value: float): - """Sets the value from an optimization result and reverses non-negative transformation.""" + """Set the value from an optimization result and reverses non-negative transformation. + + Parameters + ---------- + value : float + Value from optimization. + """ self.value = np.exp(value) if self.non_negative else value - def __getstate__(self): + def __getstate__(self): # noqa D400 """Get state for pickle.""" return ( self.label, @@ -329,7 +425,7 @@ def __getstate__(self): self.vary, ) - def __setstate__(self, state): + def __setstate__(self, state): # noqa D400 """Set state from pickle.""" ( self.label, @@ -343,21 +439,18 @@ def __setstate__(self, state): self.vary, ) = state - def _getval(self) -> float: - return self._value - - def __repr__(self): + def __repr__(self): # noqa D400 """Representation used by repl and tracebacks.""" return ( f"{type(self).__name__}(label={self.label!r}, value={self.value!r}," f" expression={self.expression!r}, vary={self.vary!r})" ) - def __array__(self): + def __array__(self): # noqa D400 """array""" - return np.array(float(self._getval()), dtype=float) + return np.array(float(self._value), dtype=float) - def __str__(self): + def __str__(self) -> str: # noqa D400 """Representation used by print and str.""" return ( f"__{self.label}__: _Value_: {self.value}, _StdErr_: {self.standard_error}, _Min_:" @@ -365,120 +458,134 @@ def __str__(self): f" _Non-Negative_: {self.non_negative}, _Expr_: {self.expression}" ) - def __abs__(self): + def __abs__(self): # noqa D400 """abs""" - return abs(self._getval()) + return abs(self._value) - def __neg__(self): + def __neg__(self): # noqa D400 """neg""" - return -self._getval() + return -self._value - def __pos__(self): + def __pos__(self): # noqa 
D400 """positive""" - return +self._getval() + return +self._value - def __int__(self): + def __int__(self): # noqa D400 """int""" - return int(self._getval()) + return int(self._value) - def __float__(self): + def __float__(self): # noqa D400 """float""" - return float(self._getval()) + return float(self._value) - def __trunc__(self): + def __trunc__(self): # noqa D400 """trunc""" - return self._getval().__trunc__() + return self._value.__trunc__() - def __add__(self, other): + def __add__(self, other): # noqa D400 """+""" - return self._getval() + other + return self._value + other - def __sub__(self, other): + def __sub__(self, other): # noqa D400 """-""" - return self._getval() - other + return self._value - other - def __truediv__(self, other): + def __truediv__(self, other): # noqa D400 """/""" - return self._getval() / other + return self._value / other - def __floordiv__(self, other): + def __floordiv__(self, other): # noqa D400 """//""" - return self._getval() // other + return self._value // other - def __divmod__(self, other): + def __divmod__(self, other): # noqa D400 """divmod""" - return divmod(self._getval(), other) + return divmod(self._value, other) - def __mod__(self, other): + def __mod__(self, other): # noqa D400 """%""" - return self._getval() % other + return self._value % other - def __mul__(self, other): + def __mul__(self, other): # noqa D400 """*""" - return self._getval() * other + return self._value * other - def __pow__(self, other): + def __pow__(self, other): # noqa D400 """**""" - return self._getval() ** other + return self._value ** other - def __gt__(self, other): + def __gt__(self, other): # noqa D400 """>""" - return self._getval() > other + return self._value > other - def __ge__(self, other): + def __ge__(self, other): # noqa D400 """>=""" - return self._getval() >= other + return self._value >= other - def __le__(self, other): + def __le__(self, other): # noqa D400 """<=""" - return self._getval() <= other + return 
self._value <= other - def __lt__(self, other): + def __lt__(self, other): # noqa D400 """<""" - return self._getval() < other + return self._value < other - def __eq__(self, other): + def __eq__(self, other): # noqa D400 """==""" - return self._getval() == other + return self._value == other - def __ne__(self, other): + def __ne__(self, other): # noqa D400 """!=""" - return self._getval() != other + return self._value != other - def __radd__(self, other): + def __radd__(self, other): # noqa D400 """+ (right)""" - return other + self._getval() + return other + self._value - def __rtruediv__(self, other): + def __rtruediv__(self, other): # noqa D400 """/ (right)""" - return other / self._getval() + return other / self._value - def __rdivmod__(self, other): + def __rdivmod__(self, other): # noqa D400 """divmod (right)""" - return divmod(other, self._getval()) + return divmod(other, self._value) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other): # noqa D400 """// (right)""" - return other // self._getval() + return other // self._value - def __rmod__(self, other): + def __rmod__(self, other): # noqa D400 """% (right)""" - return other % self._getval() + return other % self._value - def __rmul__(self, other): + def __rmul__(self, other): # noqa D400 """* (right)""" - return other * self._getval() + return other * self._value - def __rpow__(self, other): + def __rpow__(self, other): # noqa D400 """** (right)""" - return other ** self._getval() + return other ** self._value - def __rsub__(self, other): + def __rsub__(self, other): # noqa D400 """- (right)""" - return other - self._getval() + return other - self._value -def _log_value(value: float): +def _log_value(value: float) -> float: + """Get the logarithm of a value. + + Performs a check for edge cases and migitates numerical issues. + + Parameters + ---------- + value : float + The initial value. + + Returns + ------- + float + The logarithm of the value. 
+ """ if not np.isfinite(value): return value if value == 1: @@ -486,9 +593,27 @@ def _log_value(value: float): return np.log(value) -def _retrieve_from_list_by_type(li: list, t: type | tuple[type, ...], default: Any): - tmp = list(filter(lambda x: isinstance(x, t), li)) +def _retrieve_item_from_list_by_type( + item_list: list, item_type: type | tuple[type, ...], default: Any +) -> Any: + """Retrieve an item from list which matches a given type. + + Parameters + ---------- + item_list : list + The list to retrieve from. + item_type : type | tuple[type, ...] + The item type or tuple of types to match. + default : Any + Returned if no item matches. + + Returns + ------- + Any + + """ + tmp = list(filter(lambda x: isinstance(x, item_type), item_list)) if not tmp: return default - li.remove(tmp[0]) + item_list.remove(tmp[0]) return tmp[0] diff --git a/glotaran/parameter/parameter_group.py b/glotaran/parameter/parameter_group.py index 52131187e..a603ec3c1 100644 --- a/glotaran/parameter/parameter_group.py +++ b/glotaran/parameter/parameter_group.py @@ -1,9 +1,10 @@ -"""The parameter group class""" +"""The parameter group class.""" from __future__ import annotations from copy import copy from textwrap import indent +from typing import TYPE_CHECKING from typing import Generator import asteval @@ -14,6 +15,9 @@ from glotaran.parameter.parameter import Parameter from glotaran.utils.ipython import MarkdownStr +if TYPE_CHECKING: + from glotaran.parameter.parameter_history import ParameterHistory + class ParameterNotFoundException(Exception): """Raised when a Parameter is not found in the Group.""" @@ -23,19 +27,28 @@ def __init__(self, path, label): class ParameterGroup(dict): - def __init__(self, label: str = None, root_group: ParameterGroup = None): - """Represents are group of parameters. Can contain other groups, creating a - tree-like hierarchy. + """Represents are group of parameters. - Parameters - ---------- - label : - The label of the group. 
- """ + Can contain other groups, creating a tree-like hierarchy. + + Parameters + ---------- + label : str + The label of the group. + root_group : ParameterGroup + The root group + + Raises + ------ + ValueError + Raised if the an invalid label is given. + """ + + def __init__(self, label: str = None, root_group: ParameterGroup = None): if label is not None and not Parameter.valid_label(label): raise ValueError(f"'{label}' is not a valid group label.") self._label = label - self._parameters = {} + self._parameters: dict[str, Parameter] = {} self._root_group = root_group self._evaluator = ( asteval.Interpreter(symtable=asteval.make_symbol_table(group=self)) @@ -51,16 +64,21 @@ def from_dict( label: str = None, root_group: ParameterGroup = None, ) -> ParameterGroup: - """Creates a :class:`ParameterGroup` from a dictionary. + """Create a :class:`ParameterGroup` from a dictionary. Parameters ---------- - parameter_dict : + parameter_dict : dict[str, dict | list] A parameter dictionary containing parameters. - label : - The label of root group. - root_group: + label : str + The label of the group. + root_group : ParameterGroup The root group + + Returns + ------- + ParameterGroup + The created :class:`ParameterGroup` """ root = cls(label=label, root_group=root_group) for label, item in parameter_dict.items(): @@ -80,16 +98,21 @@ def from_list( label: str = None, root_group: ParameterGroup = None, ) -> ParameterGroup: - """Creates a :class:`ParameterGroup` from a list. + """Create a :class:`ParameterGroup` from a list. Parameters ---------- - parameter_list : + parameter_list : list[float | list] A parameter list containing parameters - label : - The label of the root group. - root_group: + label : str + The label of the group. 
+ root_group : ParameterGroup The root group + + Returns + ------- + ParameterGroup + The created :class:`ParameterGroup` """ root = cls(label=label, root_group=root_group) @@ -116,8 +139,27 @@ def from_list( @classmethod def from_dataframe(cls, df: pd.DataFrame, source: str = "DataFrame") -> ParameterGroup: - """Creates a :class:`ParameterGroup` from a :class:`pandas.DataFrame`""" + """Create a :class:`ParameterGroup` from a :class:`pandas.DataFrame`. + Parameters + ---------- + df : pd.DataFrame + The source data frame. + source : str + Optional name of the source file, used for error messages. + + Returns + ------- + ParameterGroup + The created parameter group. + + Raises + ------ + ValueError + Raised if the columns 'label' or 'value' doesn't exist. Also raised if the columns + 'minimum', 'maximum' or 'values' contain non numeric values or if the columns + 'non-negative' or 'vary' are no boolean. + """ for column_name in ["label", "value"]: if column_name not in df: raise ValueError(f"Missing column '{column_name}' in '{source}'") @@ -167,17 +209,36 @@ def from_dataframe(cls, df: pd.DataFrame, source: str = "DataFrame") -> Paramete return root @property - def label(self) -> str: - """Label of the group.""" + def label(self) -> str | None: + """Label of the group. + + Returns + ------- + str + The label of the group. + """ return self._label @property - def root_group(self) -> ParameterGroup: - """Root of the group.""" + def root_group(self) -> ParameterGroup | None: + """Root of the group. + + Returns + ------- + ParameterGroup + The root group. + """ return self._root_group def to_dataframe(self) -> pd.DataFrame: - parameter_dict = { + """Create a pandas data frame from the group. + + Returns + ------- + pd.DataFrame + The created data frame. 
+ """ + parameter_dict: dict[str, list[str | float | bool | None]] = { "label": [], "value": [], "minimum": [], @@ -196,25 +257,18 @@ def to_dataframe(self) -> pd.DataFrame: parameter_dict["expression"].append(parameter.expression) return pd.DataFrame(parameter_dict) - def to_csv(self, filename: str, delimiter: str = ","): - """Writes a :class:`ParameterGroup` to a CSV file. - - Parameters - ---------- - filepath : - The path to the CSV file. - delimiter : str - The delimiter of the CSV file. - """ - self.to_dataframe().to_csv(filename, sep=delimiter, na_rep="None", index=False) - def add_parameter(self, parameter: Parameter | list[Parameter]): - """Adds a :class:`Parameter` to the group. + """Add a :class:`Parameter` to the group. Parameters ---------- - parameter : + parameter : Parameter | list[Parameter] The parameter to add. + + Raises + ------ + TypeError + If ``parameter`` or any item of it is not an instance of :class:`Parameter`. """ if not isinstance(parameter, list): parameter = [parameter] @@ -228,19 +282,30 @@ def add_parameter(self, parameter: Parameter | list[Parameter]): self._parameters[p.label] = p def add_group(self, group: ParameterGroup): - """Adds a :class:`ParameterGroup` to the group. + """Add a :class:`ParameterGroup` to the group. Parameters ---------- - group : + group : ParameterGroup The group to add. + + Raises + ------ + TypeError + Raised if the group is not an instance of :class:`ParameterGroup`. """ if not isinstance(group, ParameterGroup): raise TypeError("Group must be glotaran.parameter.ParameterGroup") self[group.label] = group def get_nr_roots(self) -> int: - """Returns the number of roots of the group.""" + """Return the number of roots of the group. + + Returns + ------- + int + The number of roots. 
+ """ n = 0 root = self.root_group while root is not None: @@ -249,34 +314,53 @@ def get_nr_roots(self) -> int: return n def groups(self) -> Generator[ParameterGroup, None, None]: - """Returns a generator over all groups and their subgroups.""" + """Return a generator over all groups and their subgroups. + + Yields + ------ + ParameterGroup + A subgroup of :class:`ParameterGroup`. + """ for group in self: yield from group.groups() def has(self, label: str) -> bool: - """Checks if a parameter with the given label is in the group or in a subgroup. + """Check if a parameter with the given label is in the group or in a subgroup. Parameters ---------- - label : - The label of the parameter, with its path in a parameter group prepended. - """ + label : str + The label of the parameter, with its path in a :class:`ParameterGroup` prepended. + Returns + ------- + bool + Whether a parameter with the given label exists in the group. + """ try: self.get(label) return True except Exception: return False - def get(self, label: str) -> Parameter: - """Gets a :class:`Parameter` by its label. + def get(self, label: str) -> Parameter: # type:ignore[override] + """Get a :class:`Parameter` by its label. Parameters ---------- - label : - The label of the parameter, with its path in a parameter group prepended. + label : str + The label of the parameter, with its path in a :class:`ParameterGroup` prepended. + + Returns + ------- + Parameter + The parameter. + + Raises + ------ + ParameterNotFoundException + Raised if no parameter with the given label exists. """ - # sometimes the spec parser delivers the labels as int label = str(label) @@ -296,6 +380,14 @@ def get(self, label: str) -> Parameter: raise ParameterNotFoundException(path, label) def copy(self) -> ParameterGroup: + """Create a copy of the :class:`ParameterGroup`. + + Returns + ------- + ParameterGroup : + A copy of the :class:`ParameterGroup`. 
+ + """ root = ParameterGroup(label=self.label, root_group=self.root_group) for label, parameter in self._parameters.items(): @@ -307,19 +399,22 @@ def copy(self) -> ParameterGroup: return root def all( - self, root: str = None, separator: str = "." + self, root: str | None = None, separator: str = "." ) -> Generator[tuple[str, Parameter], None, None]: - """Returns a generator over all parameter in the group and it's subgroups together with - their labels. + """Iterate over all parameter in the group and it's subgroups together with their labels. Parameters ---------- - root : + root : str The label of the root group - separator: + separator : str The separator for the parameter labels. - """ + Yields + ------ + tuple[str, Parameter] + A tuple containing the full label of the parameter and the parameter itself. + """ root = f"{root}{self.label}{separator}" if root is not None else "" for label, p in self._parameters.items(): yield (f"{root}{label}", p) @@ -329,13 +424,18 @@ def all( def get_label_value_and_bounds_arrays( self, exclude_non_vary: bool = False ) -> tuple[list[str], np.ndarray, np.ndarray, np.ndarray]: - """Returns a arrays of all parameter labels, values and bounds. + """Return a arrays of all parameter labels, values and bounds. Parameters ---------- - - exclude_non_vary: bool = False + exclude_non_vary: bool If true, parameters with `vary=False` are excluded. + + Returns + ------- + tuple[list[str], np.ndarray, np.ndarray, np.ndarray] + A tuple containing a list of parameter labels and + an array of the values, lower and upper bounds. """ self.update_parameter_expression() @@ -355,8 +455,20 @@ def get_label_value_and_bounds_arrays( return labels, np.asarray(values), np.asarray(lower_bounds), np.asarray(upper_bounds) def set_from_label_and_value_arrays(self, labels: list[str], values: np.ndarray): - """Updates the parameter values from a list of labels and values.""" + """Update the parameter values from a list of labels and values. 
+ Parameters + ---------- + labels : list[str] + A list of parameter labels. + values : np.ndarray + An array of parameter values. + + Raises + ------ + ValueError + Raised if the size of the labels does not match the stize of values. + """ if len(labels) != len(values): raise ValueError( f"Length of labels({len(labels)}) not equal to length of values({len(values)})." @@ -367,8 +479,28 @@ def set_from_label_and_value_arrays(self, labels: list[str], values: np.ndarray) self.update_parameter_expression() + def set_from_history(self, history: ParameterHistory, index: int): + """Update the :class:`ParameterGroup` with values from a parameter history. + + Parameters + ---------- + history : ParameterHistory + The parameter history. + index : int + The history index. + """ + self.set_from_label_and_value_arrays( + history.parameter_labels, history.get_parameters(index) + ) + def update_parameter_expression(self): - """Updates all parameters which have an expression.""" + """Update all parameters which have an expression. + + Raises + ------ + ValueError + Raised if an expression evaluates to a non-numeric value. + """ for label, parameter in self.all(): if parameter.expression is not None: value = self._evaluator(parameter.transformed_expression) @@ -380,9 +512,14 @@ def update_parameter_expression(self): parameter.value = value def markdown(self) -> MarkdownStr: - """Formats the :class:`ParameterGroup` as markdown string. + """Format the :class:`ParameterGroup` as markdown string. This is done by recursing the nested :class:`ParameterGroup` tree. + + Returns + ------- + MarkdownStr : + The markdown representation as string. """ node_indentation = " " * self.get_nr_roots() return_string = "" @@ -424,12 +561,25 @@ def markdown(self) -> MarkdownStr: return MarkdownStr(return_string) def _repr_markdown_(self) -> str: - """Special method used by ``ipython`` to render markdown.""" + """Create a markdown respresentation. 
+ + Special method used by ``ipython`` to render markdown. + + Returns + ------- + str : + The markdown representation as string. + """ return str(self.markdown()) - def __repr__(self): - """Representation used by repl and tracebacks.""" + def __repr__(self) -> str: + """Representation used by repl and tracebacks. + Returns + ------- + str : + A string representation of the :class:`ParameterGroup`. + """ parameter_short_notations = [ [str(parameter.label), parameter.value] for parameter in self._parameters.values() ] @@ -443,6 +593,6 @@ def __repr__(self): else: return super().__repr__() - def __str__(self): + def __str__(self) -> str: """Representation used by print and str.""" return str(self.markdown()) diff --git a/glotaran/parameter/parameter_history.py b/glotaran/parameter/parameter_history.py new file mode 100644 index 000000000..9ab64dd25 --- /dev/null +++ b/glotaran/parameter/parameter_history.py @@ -0,0 +1,158 @@ +"""The glotaran parameter history package.""" +from __future__ import annotations + +import numpy as np +import pandas as pd + +from glotaran.parameter.parameter_group import ParameterGroup + + +class ParameterHistory: + """A class representing a history of parameters.""" + + def __init__(self): + + self._parameter_labels: list[str] = [] + self._parameters: list[np.ndarray] = [] + + @classmethod + def from_dataframe(cls, history_df: pd.DataFrame) -> ParameterHistory: + """Create a history from a pandas data frame. + + Parameters + ---------- + history_df : pd.DataFrame + The source data frame. + + Returns + ------- + ParameterHistory + The created history. + """ + history = cls() + + history._parameter_labels = history_df.columns + + for parameter_values in history_df.values: + history._parameters.append(parameter_values) + + return history + + @classmethod + def from_csv(cls, path: str) -> ParameterHistory: + """Create a history from a csv file. + + Parameters + ---------- + path : str + The path to the csv file. 
+ + Returns + ------- + ParameterHistory + The created history. + """ + df = pd.read_csv(path) + return cls.from_dataframe(df) + + @property + def parameter_labels(self) -> list[str]: + """Return the labels of the parameters in the history. + + Returns + ------- + list[str] + A list of parameter labels. + """ + return self._parameter_labels + + @property + def parameters(self) -> list[np.ndarray]: + """Return the parameters in the history. + + Returns + ------- + list[np.ndarray] + A list of parameters in the history. + """ + return self._parameters + + def __len__(self) -> int: + """Return the number of records in the history.""" + return self.number_of_records + + @property + def number_of_records(self) -> int: + """Return the number of records in the history. + + Returns + ------- + int + The number of records. + """ + return len(self._parameters) + + def to_dataframe(self) -> pd.DataFrame: + """Create a data frame from the history. + + Returns + ------- + pd.DataFrame + The created data frame. + """ + return pd.DataFrame(self._parameters, columns=self.parameter_labels) + + def to_csv(self, file_name: str, delimiter: str = ","): + """Write a :class:`ParameterGroup` to a CSV file. + + Parameters + ---------- + file_name : str + The path to the CSV file. + delimiter : str + The delimiter of the CSV file. + """ + self.to_dataframe().to_csv(file_name, sep=delimiter) + + def append(self, parameter_group: ParameterGroup): + """Append a :class:`ParameterGroup` to the history. + + Parameters + ---------- + parameter_group : ParameterGroup + The group to append. + + Raises + ------ + ValueError + Raised if the parameter labels of the group differs from previous groups. + """ + ( + parameter_labels, + parameter_values, + _, + _, + ) = parameter_group.get_label_value_and_bounds_arrays() + if len(self._parameter_labels) == 0: + self._parameter_labels = parameter_labels + if parameter_labels != self.parameter_labels: + raise ValueError( + "Cannot append parameter group. 
Parameter labels do not match existing." + ) + + self._parameters.append(parameter_values) + + def get_parameters(self, index: int) -> np.ndarray: + """Get parameters for a history index. + + Parameters + ---------- + index : int + The history index. + + Returns + ------- + np.ndarray + The parameter values at the history index as array. + """ + return self._parameters[index] diff --git a/glotaran/parameter/test/test_parameter_history.py b/glotaran/parameter/test/test_parameter_history.py new file mode 100644 index 000000000..8c408fd35 --- /dev/null +++ b/glotaran/parameter/test/test_parameter_history.py @@ -0,0 +1,39 @@ +import numpy as np + +from glotaran.parameter.parameter_group import ParameterGroup +from glotaran.parameter.parameter_history import ParameterHistory + + +def test_parameter_history(): + group0 = ParameterGroup.from_list([["1", 1], ["2", 4]]) + group1 = ParameterGroup.from_list([["1", 2], ["2", 5]]) + group2 = ParameterGroup.from_list([["1", 3], ["2", 6]]) + + history = ParameterHistory() + + history.append(group0) + + assert history.parameter_labels == ["1", "2"] + + assert history.number_of_records == 1 + assert all(history.get_parameters(0) == [1, 4]) + + history.append(group1) + + assert history.number_of_records == 2 + assert all(history.get_parameters(1) == [2, 5]) + + history.append(group2) + + assert history.number_of_records == 3 + assert all(history.get_parameters(2) == [3, 6]) + + df = history.to_dataframe() + + assert all(df.columns == history.parameter_labels) + assert np.all(df.values == history.parameters) + + group2.set_from_history(history, 0) + + assert group2.get("1") == 1 + assert group2.get("2") == 4 diff --git a/glotaran/plugin_system/data_io_registration.py b/glotaran/plugin_system/data_io_registration.py index 732002254..4b4a628c6 100644 --- a/glotaran/plugin_system/data_io_registration.py +++ b/glotaran/plugin_system/data_io_registration.py @@ -201,6 +201,7 @@ def save_dataset( file_name: str | PathLike[str], format_name: 
str = None, *, + data_filters: list[str] | None = None, allow_overwrite: bool = False, **kwargs: Any, ) -> None: @@ -214,6 +215,8 @@ def save_dataset( File to write the data to. format_name : str Format the file should be in, if not provided it will be inferred from the file extension. + data_filters : list[str] | None + Optional list of items in the dataset to be saved. allow_overwrite : bool Whether or not to allow overwriting existing files, by default False **kwargs : Any diff --git a/glotaran/plugin_system/project_io_registration.py b/glotaran/plugin_system/project_io_registration.py index b2cb10110..7bfc11894 100644 --- a/glotaran/plugin_system/project_io_registration.py +++ b/glotaran/plugin_system/project_io_registration.py @@ -8,6 +8,7 @@ """ from __future__ import annotations +from dataclasses import dataclass from typing import TYPE_CHECKING from typing import TypeVar @@ -53,6 +54,19 @@ ) +@dataclass +class SavingOptions: + """A collection of options for result saving.""" + + data_filter: list[str] | None = None + data_format: Literal["nc"] = "nc" + parameter_format: Literal["csv"] = "csv" + report: bool = True + + +SAVING_OPTIONS_DEFAULT = SavingOptions() +SAVING_OPTIONS_MINIMAL = SavingOptions(data_filter=["fitted_data", "residual"], report=False) + PROJECT_IO_METHODS = ( "load_model", "save_model", @@ -397,7 +411,7 @@ def save_result( *, allow_overwrite: bool = False, **kwargs: Any, -) -> None: +) -> list[str] | None: """Write a :class:`Result` instance to a spec file. Parameters @@ -414,12 +428,17 @@ def save_result( **kwargs : Any Additional keyword arguments passes to the ``save_result`` implementation of the project io plugin. + + Returns + ------- + list[str] | None + List of file paths which were saved. 
""" protect_from_overwrite(result_path, allow_overwrite=allow_overwrite) io = get_project_io( format_name or inferr_file_format(result_path, needs_to_exist=False, allow_folder=True) ) - io.save_result( # type: ignore[call-arg] + return io.save_result( # type: ignore[call-arg] result_path=str(result_path), result=result, **kwargs, diff --git a/glotaran/project/__init__.py b/glotaran/project/__init__.py index 0f94b1a7d..93aa57dfa 100644 --- a/glotaran/project/__init__.py +++ b/glotaran/project/__init__.py @@ -1,4 +1,4 @@ +"""The glotaran project package.""" + from glotaran.project.result import Result -from glotaran.project.scheme import SavingOptions from glotaran.project.scheme import Scheme -from glotaran.project.scheme import default_data_filters diff --git a/glotaran/project/dataclass_helpers.py b/glotaran/project/dataclass_helpers.py new file mode 100644 index 000000000..00b25c538 --- /dev/null +++ b/glotaran/project/dataclass_helpers.py @@ -0,0 +1,125 @@ +"""Contains helper methods for dataclasses.""" +from __future__ import annotations + +import dataclasses +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + from typing import Callable + from typing import TypeVar + + DefaultType = TypeVar("DefaultType") + + +def exclude_from_dict_field( + default: DefaultType = dataclasses.MISSING, # type:ignore[assignment] +) -> DefaultType: + """Create a dataclass field with which will be excluded from ``asdict``. + + Parameters + ---------- + default : DefaultType + The default value of the field. + + Returns + ------- + DefaultType + The created field. + """ + return dataclasses.field(default=default, metadata={"exclude_from_dict": True}) + + +def file_representation_field( + target: str, + loader: Callable[[str], Any], + default: DefaultType = dataclasses.MISSING, # type:ignore[assignment] +) -> DefaultType: + """Create a dataclass field with target and loader as metadata. 
+ + Parameters + ---------- + target : str + The name of the represented field. + loader : Callable[[str], Any] + A function to load the target field from a file. + default : DefaultType + The default value of the field. + + Returns + ------- + DefaultType + The created field. + """ + return dataclasses.field(default=default, metadata={"target": target, "loader": loader}) + + +def asdict(dataclass: object) -> dict[str, Any]: + """Create a dictionary containing all fields of the dataclass. + + Parameters + ---------- + dataclass : object + A dataclass instance. + + Returns + ------- + dict[str, Any] : + The dataclass represented as a dictionary. + """ + fields = dataclasses.fields(dataclass) + + dataclass_dict = {} + for field in fields: + if "exclude_from_dict" not in field.metadata: + value = getattr(dataclass, field.name) + dataclass_dict[field.name] = ( + asdict(value) if dataclasses.is_dataclass(value) else value + ) + + return dataclass_dict + + +def fromdict(dataclass_type: type, dataclass_dict: dict[str, Any], folder: Path = None) -> object: + """Create a dataclass instance from a dict and loads all file represented fields. + + Parameters + ---------- + dataclass_type : type + A dataclass type. + dataclass_dict : dict[str, Any] + A dict for instancing the the dataclass. + folder : Path + The root folder for file paths. If ``None`` file paths are consider absolute. + + Returns + ------- + object + Created instance of dataclass_type. 
+ """ + fields = dataclasses.fields(dataclass_type) + + for field in fields: + if "target" in field.metadata and "loader" in field.metadata: + file_path = dataclass_dict.get(field.name) + if file_path is None: + continue + elif isinstance(file_path, list): + dataclass_dict[field.metadata["target"]] = [ + field.metadata["loader"](f if folder is None else folder / f) + for f in file_path + ] + elif isinstance(file_path, dict): + dataclass_dict[field.metadata["target"]] = { + k: field.metadata["loader"](f if folder is None else folder / f) + for k, f in file_path.items() + } + else: + dataclass_dict[field.metadata["target"]] = field.metadata["loader"]( + file_path if folder is None else folder / file_path + ) + elif dataclasses.is_dataclass(field.default) and field.name in dataclass_dict: + dataclass_dict[field.name] = type(field.default)(**dataclass_dict[field.name]) + + return dataclass_type(**dataclass_dict) diff --git a/glotaran/project/result.py b/glotaran/project/result.py index 1b01d55c6..eaf4fbd82 100644 --- a/glotaran/project/result.py +++ b/glotaran/project/result.py @@ -2,6 +2,11 @@ from __future__ import annotations from dataclasses import dataclass +from dataclasses import replace +from typing import Any +from typing import Dict +from typing import List +from typing import cast import numpy as np import xarray as xr @@ -9,21 +14,68 @@ from tabulate import tabulate from glotaran.deprecation import deprecate +from glotaran.io import load_dataset +from glotaran.io import load_parameters +from glotaran.io import load_scheme from glotaran.io import save_result from glotaran.model import Model from glotaran.parameter import ParameterGroup +from glotaran.parameter import ParameterHistory +from glotaran.project.dataclass_helpers import exclude_from_dict_field +from glotaran.project.dataclass_helpers import file_representation_field from glotaran.project.scheme import Scheme from glotaran.utils.ipython import MarkdownStr +class IncompleteResultError(Exception): + 
"""Exception raised if mandatory arguments to create a result are missing. + + Since some mandatory fields of result can be either created from file or by + passing a class instance, the file and instance initialization aren't allowed + to both be None at the same time, but each is allowed to be ``None`` by its own. + """ + + @dataclass class Result: - """The result of a global analysis""" + """The result of a global analysis.""" - additional_penalty: np.ndarray | None - """A vector with the value for each additional penalty, or None""" - cost: ArrayLike - data: dict[str, xr.Dataset] + number_of_function_evaluations: int + """The number of function evaluations.""" + + success: bool + """Indicates if the optimization was successful.""" + + termination_reason: str + """The reason (message when) the optimizer terminated""" + + glotaran_version: str + """The glotaran version used to create the result.""" + + free_parameter_labels: list[str] + """List of labels of the free parameters used in optimization.""" + + scheme: Scheme = cast(Scheme, exclude_from_dict_field(None)) + scheme_file: str | None = file_representation_field("scheme", load_scheme, None) + + initial_parameters: ParameterGroup = cast(ParameterGroup, exclude_from_dict_field(None)) + initial_parameters_file: str | None = file_representation_field( + "initial_parameters", load_parameters, None + ) + + optimized_parameters: ParameterGroup = cast(ParameterGroup, exclude_from_dict_field(None)) + """The optimized parameters, organized in a :class:`ParameterGroup`""" + optimized_parameters_file: str | None = file_representation_field( + "optimized_parameters", load_parameters, None + ) + + parameter_history: ParameterHistory = cast(ParameterHistory, exclude_from_dict_field(None)) + """The parameter history.""" + parameter_history_file: str | None = file_representation_field( + "parameter_history", ParameterHistory.from_csv, None + ) + + data: dict[str, xr.Dataset] = cast(Dict[str, xr.Dataset], 
exclude_from_dict_field(None)) """The resulting data as a dictionary of :xarraydoc:`Dataset`. Notes @@ -31,31 +83,30 @@ class Result: The actual content of the data depends on the actual model and can be found in the documentation for the model. """ - free_parameter_labels: list[str] - """List of labels of the free parameters used in optimization.""" - number_of_function_evaluations: int - """The number of function evaluations.""" - initial_parameters: ParameterGroup - optimized_parameters: ParameterGroup - """The optimized parameters, organized in a :class:`ParameterGroup`""" - scheme: Scheme - success: bool - """Indicates if the optimization was successful.""" - termination_reason: str - """The reason (message when) the optimizer terminated""" + data_files: dict[str, str] | None = file_representation_field("data", load_dataset, None) + + additional_penalty: np.ndarray | None = exclude_from_dict_field(None) + """A vector with the value for each additional penalty, or None""" + + cost: ArrayLike | None = exclude_from_dict_field(None) + """The final cost.""" # The below can be none in case of unsuccessful optimization + chi_square: float | None = None r"""The chi-square of the optimization. :math:`\chi^2 = \sum_i^N [{Residual}_i]^2`.""" - covariance_matrix: ArrayLike | None = None + + covariance_matrix: ArrayLike | None = exclude_from_dict_field(None) """Covariance matrix. 
     The rows and columns are corresponding to :attr:`free_parameter_labels`."""
+
     degrees_of_freedom: int | None = None
     """Degrees of freedom in optimization :math:`N - N_{vars}`."""
-    jacobian: ArrayLike | None = None
+
+    jacobian: ArrayLike | list | None = exclude_from_dict_field(None)
     """Modified Jacobian matrix at the solution
 
     See also: :func:`scipy.optimize.least_squares`
@@ -79,8 +130,54 @@ class Result:
 
     :math:`rms = \sqrt{\chi^2_{red}}`
     """
 
+    def __post_init__(self):
+        """Validate fields and cast attributes to correct type."""
+        self._check_mandatory_fields()
+        if isinstance(self.jacobian, list):
+            self.jacobian = np.array(self.jacobian)
+            self.covariance_matrix = np.array(self.covariance_matrix)
+
+    def _check_mandatory_fields(self):
+        """Check that required fields which can be set from file are not ``None``.
+
+        Raises
+        ------
+        IncompleteResultError
+            If any mandatory field and its file representation is ``None``.
+        """
+        mandatory_fields = [
+            ("scheme", ""),
+            ("initial_parameters", ""),
+            ("optimized_parameters", ""),
+            ("parameter_history", ""),
+            ("data", "s"),
+        ]
+        missing_fields = [
+            (mandatory_field, file_post_fix)
+            for mandatory_field, file_post_fix in mandatory_fields
+            if (
+                getattr(self, mandatory_field) is None
+                and getattr(self, f"{mandatory_field}_file{file_post_fix}") is None
+            )
+        ]
+        if len(missing_fields) != 0:
+            error_message = "Result is missing mandatory fields:\n"
+            for missing_field, file_post_fix in missing_fields:
+                error_message += (
+                    f"  - Required field {missing_field!r} is missing!\n"
+                    f"    Set either {missing_field!r} or '{missing_field}_file{file_post_fix}'."
+                )
+            raise IncompleteResultError(error_message)
+
     @property
     def model(self) -> Model:
+        """Return the model used to fit result.
+
+        Returns
+        -------
+        Model
+            The model instance.
+ """ return self.scheme.model def get_scheme(self) -> Scheme: @@ -99,29 +196,24 @@ def get_scheme(self) -> Scheme: if "weight" in dataset: data[label]["weight"] = dataset.weight - return Scheme( - model=self.model, - parameters=self.optimized_parameters, - data=data, - group_tolerance=self.scheme.group_tolerance, - non_negative_least_squares=self.scheme.non_negative_least_squares, - maximum_number_function_evaluations=self.scheme.maximum_number_function_evaluations, - ftol=self.scheme.ftol, - gtol=self.scheme.gtol, - xtol=self.scheme.xtol, - optimization_method=self.scheme.optimization_method, - ) + return replace(self.scheme, parameters=self.optimized_parameters) def markdown(self, with_model: bool = True, base_heading_level: int = 1) -> MarkdownStr: - """Formats the model as a markdown text. + """Format the model as a markdown text. Parameters ---------- - with_model : + with_model : bool If `True`, the model will be printed with initial and optimized parameters filled in. - """ + base_heading_level : int + The level of the base heading. - general_table_rows = [ + Returns + ------- + MarkdownStr : str + The scheme as markdown string. + """ + general_table_rows: list[list[Any]] = [ ["Number of residual evaluation", self.number_of_function_evaluations], ["Number of variables", self.number_of_variables], ["Number of datapoints", self.number_of_data_points], @@ -170,47 +262,70 @@ def markdown(self, with_model: bool = True, base_heading_level: int = 1) -> Mark return MarkdownStr(result_table) def _repr_markdown_(self) -> str: - """Special method used by ``ipython`` to render markdown.""" + """Return a markdown representation str. + + Special method used by ``ipython`` to render markdown. + + Returns + ------- + str + The scheme as markdown string. 
+ """ return str(self.markdown(base_heading_level=3)) - def __str__(self): + def __str__(self) -> str: + """Overwrite of ``__str__``.""" return str(self.markdown(with_model=False)) - @deprecate( - deprecated_qual_name_usage="glotaran.project.result.Result.save(result_path)", - new_qual_name_usage=( - "glotaran.io.save_result(" - "result=result, result_path=result_path, " - 'format_name="legacy", allow_overwrite=True' - ")" - ), - to_be_removed_in_version="0.6.0", - importable_indices=(2, 1), - ) def save(self, path: str) -> list[str]: - """Saves the result to given folder. + """Save the result to given folder. - Warning - ------- - Deprecated use ``save_result(result_path=result_path, result=result, - format_name="legacy", allow_overwrite=True)`` instead. + Parameters + ---------- + path : str + The path to the folder in which to save the result. + Returns + ------- + list[str] + Paths to all the saved files. + """ + return cast( + List[str], + save_result(result_path=path, result=self, format_name="folder", allow_overwrite=True), + ) - Returns a list with paths of all saved items. - The following files are saved: + def recreate(self) -> Result: + """Recrate a result from the initial parameters. - * `result.md`: The result with the model formatted as markdown text. + Returns + ------- + Result : + The recreated result. + """ + from glotaran.analysis.optimize import optimize - * `optimized_parameters.csv`: The optimized parameter as csv file. + return optimize(self.scheme) - * `{dataset_label}.nc`: The result data for each dataset as NetCDF file. + def verify(self) -> bool: + """Verify a result. - Parameters - ---------- - path : - The path to the folder in which to save the result. + Returns + ------- + bool : + Weather the recreated result is equal to this result. 
""" - save_result(result_path=path, result=self, format_name="legacy", allow_overwrite=True) + recreated = self.recreate() + + if self.root_mean_square_error != recreated.root_mean_square_error: + return False + + for label, dataset in self.data.items(): + for attr, array in dataset.items(): + if not np.allclose(array, recreated.data[label][attr]): + return False + + return True @deprecate( deprecated_qual_name_usage="glotaran.project.result.Result.get_dataset(dataset_label)", @@ -219,7 +334,7 @@ def save(self, path: str) -> list[str]: importable_indices=(2, 2), ) def get_dataset(self, dataset_label: str) -> xr.Dataset: - """Returns the result dataset for the given dataset label. + """Return the result dataset for the given dataset label. Warning ------- @@ -229,8 +344,16 @@ def get_dataset(self, dataset_label: str) -> xr.Dataset: Parameters ---------- - dataset_label : + dataset_label : str The label of the dataset. + + Returns + ------- + xr.Dataset : + The dataset. + + + .. # noqa: DAR401 """ try: return self.data[dataset_label] diff --git a/glotaran/project/scheme.py b/glotaran/project/scheme.py index a308389a8..c62e9b01e 100644 --- a/glotaran/project/scheme.py +++ b/glotaran/project/scheme.py @@ -1,3 +1,4 @@ +"""The module for :class:``Scheme``.""" from __future__ import annotations import warnings @@ -5,7 +6,12 @@ from typing import TYPE_CHECKING from glotaran.deprecation import deprecate +from glotaran.io import load_dataset +from glotaran.io import load_model +from glotaran.io import load_parameters from glotaran.io import load_scheme +from glotaran.project.dataclass_helpers import exclude_from_dict_field +from glotaran.project.dataclass_helpers import file_representation_field from glotaran.utils.ipython import MarkdownStr if TYPE_CHECKING: @@ -17,23 +23,20 @@ from glotaran.model import Model from glotaran.parameter import ParameterGroup -default_data_filters = {"minimal": ["fitted_data", "residual"], "full": None} - @dataclass -class SavingOptions: - 
level: Literal["minimal", "full"] = "full" - data_filter: list[str] | None = None - data_format: str = "nc" - parameter_format: str = "csv" - report: bool = True +class Scheme: + """A scheme is a collection of a model, parameters and a dataset. + A scheme also holds options for optimization. + """ -@dataclass -class Scheme: - model: Model | str - parameters: ParameterGroup | str - data: dict[str, xr.DataArray | xr.Dataset | str] + model: Model = exclude_from_dict_field() + parameters: ParameterGroup = exclude_from_dict_field() + data: dict[str, xr.DataArray | xr.Dataset] = exclude_from_dict_field() + model_file: str | None = file_representation_field("model", load_model, default=None) + parameters_file: str | None = file_representation_field("parameters", load_parameters, None) + data_files: dict[str, str] | None = file_representation_field("data", load_dataset, None) group: bool | None = None group_tolerance: float = 0.0 non_negative_least_squares: bool = False @@ -47,37 +50,66 @@ class Scheme: "Dogbox", "Levenberg-Marquardt", ] = "TrustRegionReflection" - saving: SavingOptions = SavingOptions() result_path: str | None = None def problem_list(self) -> list[str]: - """Returns a list with all problems in the model and missing parameters.""" + """Return a list with all problems in the model and missing parameters. + + Returns + ------- + list[str] + A list of all problems found in the scheme's model. + """ return self.model.problem_list(self.parameters) def validate(self) -> str: - """Returns a string listing all problems in the model and missing parameters.""" + """Return a string listing all problems in the model and missing parameters. + + Returns + ------- + str + A user-friendly string containing all the problems of a model if any. + Defaults to 'Your model is valid.' if no problems are found. 
+ """ return self.model.validate(self.parameters) - def valid(self, parameters: ParameterGroup = None) -> bool: - """Returns `True` if there are no problems with the model or the parameters, - else `False`.""" - return self.model.valid(parameters) + def valid(self) -> bool: + """Check if there are no problems with the model or the parameters. + + Returns + ------- + bool + Whether the scheme is valid. + """ + return self.model.valid(self.parameters) def markdown(self): - """Formats the :class:`Scheme` as markdown string.""" - markdown_str = self.model.markdown(parameters=self.parameters) + """Format the :class:`Scheme` as markdown string. - markdown_str += "\n\n" + Returns + ------- + MarkdownStr + The scheme as markdown string. + """ + model_markdown_str = self.model.markdown(parameters=self.parameters) + + markdown_str = "\n\n" markdown_str += "__Scheme__\n\n" markdown_str += f"* *nnls*: {self.non_negative_least_squares}\n" markdown_str += f"* *nfev*: {self.maximum_number_function_evaluations}\n" markdown_str += f"* *group_tolerance*: {self.group_tolerance}\n" - return MarkdownStr(markdown_str) + return model_markdown_str + MarkdownStr(markdown_str) def is_grouped(self) -> bool: - """Returns whether the scheme should be grouped.""" + """Return whether the scheme should be grouped. + + Returns + ------- + bool + Weather the scheme should be grouped. + """ if self.group is not None and not self.group: return False is_groupable = self.model.is_groupable(self.parameters, self.data) @@ -86,18 +118,33 @@ def is_grouped(self) -> bool: return is_groupable def _repr_markdown_(self) -> str: - """Special method used by ``ipython`` to render markdown.""" + """Return a markdown representation str. + + Special method used by ``ipython`` to render markdown. + + Returns + ------- + str + The scheme as markdown string. 
+ """ return str(self.markdown()) - def __str__(self): + def __str__(self) -> str: """Representation used by print and str.""" return str(self.markdown()) @property def model_dimensions(self) -> dict[str, str]: - """Returns the dataset model's model dimension.""" + """Return the dataset model's model dimension. + + Returns + ------- + dict[str, str] + A dictionary with the dataset labels as key and the model dimension of + the dataset as value. + """ return { - dataset_name: self.model.dataset[dataset_name] + dataset_name: self.model.dataset[dataset_name] # type:ignore[attr-defined] .fill(self.model, self.parameters) .set_data(self.data[dataset_name]) .get_model_dimension() @@ -106,9 +153,16 @@ def model_dimensions(self) -> dict[str, str]: @property def global_dimensions(self) -> dict[str, str]: - """Returns the dataset model's global dimension.""" + """Return the dataset model's global dimension. + + Returns + ------- + dict[str, str] + A dictionary with the dataset labels as key and the global dimension of + the dataset as value. 
+ """ return { - dataset_name: self.model.dataset[dataset_name] + dataset_name: self.model.dataset[dataset_name] # type:ignore[attr-defined] .fill(self.model, self.parameters) .set_data(self.data[dataset_name]) .get_global_dimension() diff --git a/glotaran/project/test/test_dataclass_helpers.py b/glotaran/project/test/test_dataclass_helpers.py new file mode 100644 index 000000000..6dafc69f3 --- /dev/null +++ b/glotaran/project/test/test_dataclass_helpers.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass + +from glotaran.project.dataclass_helpers import asdict +from glotaran.project.dataclass_helpers import exclude_from_dict_field +from glotaran.project.dataclass_helpers import file_representation_field +from glotaran.project.dataclass_helpers import fromdict + + +def dummy_loader(file: str) -> int: + return {"foo.file": 21, "bar.file": 42}[file] + + +def test_serialize_to_file_name_field(): + @dataclass + class DummyDataclass: + foo: int = exclude_from_dict_field() + foo_file: int = file_representation_field("foo", dummy_loader) + bar: int = exclude_from_dict_field(default=42) + bar_file: int = file_representation_field("bar", dummy_loader, default="bar.file") + baz: int = 84 + + dummy_class = DummyDataclass(foo=21, foo_file="foo.file") + + dummy_class_dict = asdict(dummy_class) + + assert "foo" not in dummy_class_dict + assert dummy_class_dict["foo_file"] == "foo.file" + + assert "bar" not in dummy_class_dict + assert dummy_class_dict["bar_file"] == "bar.file" + + assert dummy_class_dict["baz"] == 84 + assert dummy_class_dict["baz"] == dummy_class.baz + + loaded = fromdict(DummyDataclass, dummy_class_dict) + + assert loaded == dummy_class diff --git a/glotaran/project/test/test_result.py b/glotaran/project/test/test_result.py index 60c27fd37..b8a1670a2 100644 --- a/glotaran/project/test/test_result.py +++ b/glotaran/project/test/test_result.py @@ -7,6 +7,7 @@ from glotaran.analysis.simulation import simulate from glotaran.analysis.test.models import 
ThreeDatasetDecay as suite from glotaran.project import Scheme +from glotaran.project.result import IncompleteResultError from glotaran.project.result import Result @@ -48,3 +49,22 @@ def test_result_ipython_rendering(dummy_result: Result): assert "text/markdown" in rendered_markdown_return assert rendered_markdown_return["text/markdown"].startswith("| Optimization Result") + + +def test_result_incomplete_exception(dummy_result: Result): + """Raise error if required fields are missing.""" + + with pytest.raises(IncompleteResultError) as excinfo: + Result(1, True, "foo", "gta", ["1"]) + + for mandatory_field, file_post_fix in [ + ("scheme", ""), + ("initial_parameters", ""), + ("optimized_parameters", ""), + ("parameter_history", ""), + ("data", "s"), + ]: + assert ( + f"Set either '{mandatory_field}' or '{mandatory_field}_file{file_post_fix}'." + in str(excinfo.value) + ) diff --git a/glotaran/project/test/test_scheme.py b/glotaran/project/test/test_scheme.py index 3e9fbcbca..57e04802e 100644 --- a/glotaran/project/test/test_scheme.py +++ b/glotaran/project/test/test_scheme.py @@ -32,19 +32,12 @@ def mock_scheme(tmp_path: Path) -> Scheme: ).to_netcdf(dataset_path) scheme_yml_str = f""" - model: {model_path} - parameters: {parameter_path} - non-negative-least-squares: True - maximum-number-function-evaluations: 42 - data: + model_file: {model_path} + parameters_file: {parameter_path} + non_negative_least_squares: True + maximum_number_function_evaluations: 42 + data_files: dataset1: {dataset_path} - - saving: - level: minimal - data_filter: [a, b, c] - data_format: csv - parameter_format: yaml - report: false """ scheme_path = tmp_path / "scheme.yml" scheme_path.write_text(scheme_yml_str) @@ -53,6 +46,7 @@ def mock_scheme(tmp_path: Path) -> Scheme: def test_scheme(mock_scheme: Scheme): + """Test scheme attributes.""" assert mock_scheme.model is not None assert mock_scheme.model_dimensions["dataset1"] == "time" @@ -67,12 +61,6 @@ def test_scheme(mock_scheme: Scheme): 
assert "dataset1" in mock_scheme.data assert mock_scheme.data["dataset1"].data.shape == (1, 3) - assert mock_scheme.saving.level == "minimal" - assert mock_scheme.saving.data_filter == ["a", "b", "c"] - assert mock_scheme.saving.data_format == "csv" - assert mock_scheme.saving.parameter_format == "yaml" - assert not mock_scheme.saving.report - def test_scheme_ipython_rendering(mock_scheme: Scheme): """Autorendering in ipython""" diff --git a/pyproject.toml b/pyproject.toml index 678b78df6..72ebe1218 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ remove_redundant_aliases = true [tool.interrogate] exclude = ["setup.py", "docs", "*test/*", "benchmark/*"] ignore-init-module = true -fail-under = 55 +fail-under = 59 [tool.nbqa.addopts] flake8 = [ diff --git a/requirements_dev.txt b/requirements_dev.txt index 6df755111..691ba1ea3 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -33,6 +33,7 @@ pytest-env>=0.6.2 pytest-runner>=2.11.1 pytest-benchmark>=3.1.1 pytest-allclose>=1.0.0 +types-dataclasses>=0.1.7 # code quality assurance flake8>=3.8.3 diff --git a/setup.cfg b/setup.cfg index 022bd304e..b584563e6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,7 +79,7 @@ ignore_messages = xarraydoc [darglint] docstring_style = numpy -ignore_regex = test_.+|.*wrapper.*|inject_warn_into_call|.*dummy.*|__(str|eq)__ +ignore_regex = test_.+|.*wrapper.*|inject_warn_into_call|.*dummy.*|__(.+?)__ [pydocstyle] convention = numpy @@ -100,3 +100,9 @@ ignore_errors = False [mypy-glotaran.deprecation.*] ignore_errors = False + +[mypy-glotaran.parameter.*] +ignore_errors = False + +[mypy-glotaran.project.*] +ignore_errors = False