Skip to content

Commit

Permalink
Add Feature/321 parameter history (#557)
Browse files Browse the repository at this point in the history
* added parameter history
Closes #557 
Co-authored-by: Jörn Weißenborn <joern.weissenborn@gmail.com>
  • Loading branch information
jsnel authored and s-weigand committed Feb 23, 2021
1 parent bd48fb5 commit 08535b5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
7 changes: 6 additions & 1 deletion glotaran/analysis/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def optimize_problem(problem: Problem, verbose: bool = True) -> Result:
xtol = problem.scheme.xtol
verbose = 2 if verbose else 0
termination_reason = ""
history_index = -1

try:
ls_result = least_squares(
Expand All @@ -60,10 +61,13 @@ def optimize_problem(problem: Problem, verbose: bool = True) -> Result:
warn(f"Optimization failed:\n\n{e}")
termination_reason = str(e)
ls_result = None
history_index = -2

problem.save_parameters_for_history()

return Result(
problem.scheme,
problem.create_result_data(),
problem.create_result_data(history_index=history_index),
problem.parameters,
problem.additional_penalty,
ls_result,
Expand All @@ -73,6 +77,7 @@ def optimize_problem(problem: Problem, verbose: bool = True) -> Result:


def _calculate_penalty(parameters: np.ndarray, labels: List[str] = None, problem: Problem = None):
problem.save_parameters_for_history()
problem.parameters.set_from_label_and_value_arrays(labels, parameters)
problem.reset()
return problem.full_penalty
14 changes: 13 additions & 1 deletion glotaran/analysis/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, scheme: Scheme):
self._filled_dataset_descriptors = None

self.parameters = scheme.parameters.copy()
self._parameter_history = []

# all of the above are always not None

Expand Down Expand Up @@ -144,6 +145,10 @@ def parameters(self, parameters: ParameterGroup):
self._parameters = parameters
self.reset()

@property
def parameter_history(self) -> List[ParameterGroup]:
return self._parameter_history

@property
def grouped(self) -> bool:
return self._grouped
Expand Down Expand Up @@ -255,6 +260,9 @@ def full_penalty(self) -> np.ndarray:
)
return self._full_penalty

def save_parameters_for_history(self):
self._parameter_history.append(self._parameters)

def reset(self):
"""Resets all results and `DatasetDescriptors`. Use after updating parameters."""
self._filled_dataset_descriptors = {
Expand Down Expand Up @@ -810,8 +818,12 @@ def calculate_additional_penalty(self) -> Union[np.ndarray, Dict[str, np.ndarray
self._additional_penalty = None
return self._additional_penalty

def create_result_data(self, copy: bool = True) -> Dict[str, xr.Dataset]:
def create_result_data(
self, copy: bool = True, history_index: int = None
) -> Dict[str, xr.Dataset]:

if history_index is not None and history_index != -1:
self.parameters = self.parameter_history[history_index]
result_data = {label: self._create_result_dataset(label, copy=copy) for label in self.data}

if callable(self.model.finalize_data):
Expand Down

0 comments on commit 08535b5

Please sign in to comment.