From 9da703ad3e40056e8eb3b758d40920e784b93718 Mon Sep 17 00:00:00 2001 From: Margaret Duff Date: Wed, 2 Oct 2024 10:00:54 +0000 Subject: [PATCH] Edo's comments --- .../cil/optimisation/algorithms/SPDHG.py | 82 +++++++++++++------ 1 file changed, 58 insertions(+), 24 deletions(-) diff --git a/Wrappers/Python/cil/optimisation/algorithms/SPDHG.py b/Wrappers/Python/cil/optimisation/algorithms/SPDHG.py index c3432cd3d9..2de8f6cb95 100644 --- a/Wrappers/Python/cil/optimisation/algorithms/SPDHG.py +++ b/Wrappers/Python/cil/optimisation/algorithms/SPDHG.py @@ -45,18 +45,18 @@ class SPDHG(Algorithm): A convex function with a "simple" proximal operator : BlockOperator BlockOperator must contain Linear Operators - tau : positive float, optional, default= see note - Step size parameter for primal problem - sigma : list of positive float, optional, default= see note - List of Step size parameters for dual problem - initial : DataContainer, optional, default to a zero DataContainer in the range of the `operator`. + tau : positive float, optional, default=None + Step size parameter for primal problem. If `None` see note. + sigma : list of positive float, optional, default=None + List of Step size parameters for dual problem. If `None` see note. + initial : DataContainer, optional, the default value is a zero DataContainer in the range of the `operator`. Initial point for the SPDHG algorithm gamma : float, optional Parameter controlling the trade-off between the primal and dual step sizes sampler: optional, an instance of a `cil.optimisation.utilities.Sampler` class or another class with the function __next__(self) implemented outputting an integer from {1,...,len(operator)}. - Method of selecting the next index for the SPDHG update. If None, a sampler will be created for random sampling with replacement and each index will have probability = 1/len(operator) - prob_weights: optional, list of floats of length num_indices that sum to 1. Defaults to [1/len(operator)]*len(operator) - Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. Note that this should not be passed if the provided sampler has it as an attribute. + Method of selecting the next index for the SPDHG update. If None, a sampler will be created for random sampling with replacement and each index will have `probability = 1/len(operator)` + prob_weights: optional, list of floats of length `num_indices` that sum to 1. Defaults to `[1/len(operator)]*len(operator)` + Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. Note that this should not be passed if the provided sampler has it as an attribute: if the sampler has a `prob_weight` attribute it will take precedence on this parameter. @@ -87,7 +87,7 @@ class SPDHG(Algorithm): Note ----- - When setting `sigma` and `tau`, there are 4 possible cases considered by setup function: + When setting `sigma` and `tau`, there are 4 possible cases considered by setup function. In all cases the probabilities :math:`p_i` are set by a default or user defined sampler: - Case 1: If neither `sigma` or `tau` are provided then `sigma` is set using the formula: @@ -153,24 +153,29 @@ def set_up(self, f, g, operator, sigma=None, tau=None, self._ndual_subsets = len(self.operator) self._sampler = sampler - self._prob_weights = getattr(self._sampler, 'prob_weights', None) - if prob_weights is not None: - if self._prob_weights is None: - self._prob_weights = prob_weights - else: - raise ValueError( + #Set up the _prob_weights. In preference order they are taken from: the sampler, the prob_weights argument, the deprecated prob argument or set as defualt. + self._prob_weights = getattr(self._sampler, 'prob_weights', None) # from the sampler + if self._prob_weights is None: #from prob_weights + self._prob_weights = prob_weights + elif prob_weights is not None: + raise ValueError( ' You passed a `prob_weights` argument and a sampler with attribute `prob_weights`, please remove the `prob_weights` argument.') - self._deprecated_kwargs(deprecated_kwargs) + self._deprecated_prob(deprecated_kwargs) #from prob argument - if self._prob_weights is None: + if self._prob_weights is None: #set from default self._prob_weights = [1/self._ndual_subsets]*self._ndual_subsets + #Set the sampler if self._sampler is None: self._sampler = Sampler.random_with_replacement( len(operator), prob=self._prob_weights) + #Set the norms of the operators + self._deprecated_norms(deprecated_kwargs) self._norms = operator.get_norms_as_list() + #Check for other kwargs + self._deprecated_else(deprecated_kwargs) self.set_step_sizes(sigma=sigma, tau=tau) @@ -196,7 +201,7 @@ def set_up(self, f, g, operator, sigma=None, tau=None, self.configured = True logging.info("{} configured".format(self.__class__.__name__, )) - def _deprecated_kwargs(self, deprecated_kwargs): + def _deprecated_prob(self, deprecated_kwargs): """ Handle deprecated keyword arguments for backward compatibility. @@ -209,7 +214,7 @@ def _deprecated_kwargs(self, deprecated_kwargs): ----- This method is called by the set_up method. """ - norms = deprecated_kwargs.pop('norms', None) + prob = deprecated_kwargs.pop('prob', None) if prob is not None: @@ -221,14 +226,45 @@ def _deprecated_kwargs(self, deprecated_kwargs): raise ValueError( '`prob` is being deprecated to be replaced with a sampler class and `prob_weights`. You passed a `prob` argument, and either a `prob_weights` argument or a sampler with a `prob_weights` property. Please give only one of the three. ') + + + def _deprecated_norms(self, deprecated_kwargs): + """ + Handle deprecated keyword arguments for backward compatibility. + + Parameters + ---------- + deprecated_kwargs : dict + Dictionary of keyword arguments. + + Notes + ----- + This method is called by the set_up method. + """ + norms = deprecated_kwargs.pop('norms', None) + if norms is not None: self.operator.set_norms(norms) warnings.warn( ' `norms` is being deprecated, use instead the `BlockOperator` function `set_norms`', DeprecationWarning, stacklevel=2) + def _deprecated_else(self, deprecated_kwargs): + """ + Handle deprecated keyword arguments for backward compatibility. + + Parameters + ---------- + deprecated_kwargs : dict + Dictionary of keyword arguments. + + Notes + ----- + This method is called by the set_up method. + """ if deprecated_kwargs: raise ValueError("Additional keyword arguments passed but not used: {}".format(deprecated_kwargs)) - + + @property def sigma(self): return self._sigma @@ -336,9 +372,7 @@ def set_step_sizes(self, sigma=None, tau=None): self._tau = min([value for value in values if value > 1e-8]) else: - if isinstance(tau, Number) and tau > 0: - pass - else: + if not ( isinstance(tau, Number) and tau > 0): raise ValueError( "The step-sizes of SPDHG must be positive, passed tau = {}".format(tau)) @@ -367,7 +401,7 @@ def check_convergence(self): return False return True else: - return False + raise ValueError('Convergence criterion currently can only be checked for scalar values of tau and sigma[i].') def update(self): """ Runs one iteration of SPDHG