Skip to content

Commit

Permalink
Edo's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MargaretDuff committed Oct 2, 2024
1 parent 2edd0e8 commit 9da703a
Showing 1 changed file with 58 additions and 24 deletions.
82 changes: 58 additions & 24 deletions Wrappers/Python/cil/optimisation/algorithms/SPDHG.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,18 @@ class SPDHG(Algorithm):
A convex function with a "simple" proximal
operator : BlockOperator
BlockOperator must contain Linear Operators
tau : positive float, optional, default= see note
Step size parameter for primal problem
sigma : list of positive float, optional, default= see note
List of Step size parameters for dual problem
initial : DataContainer, optional, default to a zero DataContainer in the range of the `operator`.
tau : positive float, optional, default=None
Step size parameter for primal problem. If `None` see note.
sigma : list of positive float, optional, default=None
List of Step size parameters for dual problem. If `None` see note.
initial : DataContainer, optional, the default value is a zero DataContainer in the range of the `operator`.
Initial point for the SPDHG algorithm
gamma : float, optional
Parameter controlling the trade-off between the primal and dual step sizes
sampler: optional, an instance of a `cil.optimisation.utilities.Sampler` class or another class with the function __next__(self) implemented outputting an integer from {1,...,len(operator)}.
Method of selecting the next index for the SPDHG update. If None, a sampler will be created for random sampling with replacement and each index will have probability = 1/len(operator)
prob_weights: optional, list of floats of length num_indices that sum to 1. Defaults to [1/len(operator)]*len(operator)
Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. Note that this should not be passed if the provided sampler has it as an attribute.
Method of selecting the next index for the SPDHG update. If None, a sampler will be created for random sampling with replacement and each index will have `probability = 1/len(operator)`
prob_weights: optional, list of floats of length `num_indices` that sum to 1. Defaults to `[1/len(operator)]*len(operator)`
Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. Note that this should not be passed if the provided sampler has it as an attribute: if the sampler has a `prob_weight` attribute it will take precedence on this parameter.
Expand Down Expand Up @@ -87,7 +87,7 @@ class SPDHG(Algorithm):
Note
-----
When setting `sigma` and `tau`, there are 4 possible cases considered by setup function:
When setting `sigma` and `tau`, there are 4 possible cases considered by setup function. In all cases the probabilities :math:`p_i` are set by a default or user defined sampler:
- Case 1: If neither `sigma` or `tau` are provided then `sigma` is set using the formula:
Expand Down Expand Up @@ -153,24 +153,29 @@ def set_up(self, f, g, operator, sigma=None, tau=None,
self._ndual_subsets = len(self.operator)
self._sampler = sampler

self._prob_weights = getattr(self._sampler, 'prob_weights', None)
if prob_weights is not None:
if self._prob_weights is None:
self._prob_weights = prob_weights
else:
raise ValueError(
#Set up the _prob_weights. In preference order they are taken from: the sampler, the prob_weights argument, the deprecated prob argument or set as defualt.
self._prob_weights = getattr(self._sampler, 'prob_weights', None) # from the sampler
if self._prob_weights is None: #from prob_weights
self._prob_weights = prob_weights
elif prob_weights is not None:
raise ValueError(
' You passed a `prob_weights` argument and a sampler with attribute `prob_weights`, please remove the `prob_weights` argument.')

self._deprecated_kwargs(deprecated_kwargs)
self._deprecated_prob(deprecated_kwargs) #from prob argument

if self._prob_weights is None:
if self._prob_weights is None: #set from default
self._prob_weights = [1/self._ndual_subsets]*self._ndual_subsets

#Set the sampler
if self._sampler is None:
self._sampler = Sampler.random_with_replacement(
len(operator), prob=self._prob_weights)

#Set the norms of the operators
self._deprecated_norms(deprecated_kwargs)
self._norms = operator.get_norms_as_list()
#Check for other kwargs
self._deprecated_else(deprecated_kwargs)

self.set_step_sizes(sigma=sigma, tau=tau)

Expand All @@ -196,7 +201,7 @@ def set_up(self, f, g, operator, sigma=None, tau=None,
self.configured = True
logging.info("{} configured".format(self.__class__.__name__, ))

def _deprecated_kwargs(self, deprecated_kwargs):
def _deprecated_prob(self, deprecated_kwargs):
"""
Handle deprecated keyword arguments for backward compatibility.
Expand All @@ -209,7 +214,7 @@ def _deprecated_kwargs(self, deprecated_kwargs):
-----
This method is called by the set_up method.
"""
norms = deprecated_kwargs.pop('norms', None)

prob = deprecated_kwargs.pop('prob', None)

if prob is not None:
Expand All @@ -221,14 +226,45 @@ def _deprecated_kwargs(self, deprecated_kwargs):
raise ValueError(
'`prob` is being deprecated to be replaced with a sampler class and `prob_weights`. You passed a `prob` argument, and either a `prob_weights` argument or a sampler with a `prob_weights` property. Please give only one of the three. ')



def _deprecated_norms(self, deprecated_kwargs):
"""
Handle deprecated keyword arguments for backward compatibility.
Parameters
----------
deprecated_kwargs : dict
Dictionary of keyword arguments.
Notes
-----
This method is called by the set_up method.
"""
norms = deprecated_kwargs.pop('norms', None)

if norms is not None:
self.operator.set_norms(norms)
warnings.warn(
' `norms` is being deprecated, use instead the `BlockOperator` function `set_norms`', DeprecationWarning, stacklevel=2)

def _deprecated_else(self, deprecated_kwargs):
"""
Handle deprecated keyword arguments for backward compatibility.
Parameters
----------
deprecated_kwargs : dict
Dictionary of keyword arguments.
Notes
-----
This method is called by the set_up method.
"""
if deprecated_kwargs:
raise ValueError("Additional keyword arguments passed but not used: {}".format(deprecated_kwargs))



@property
def sigma(self):
return self._sigma
Expand Down Expand Up @@ -336,9 +372,7 @@ def set_step_sizes(self, sigma=None, tau=None):
self._tau = min([value for value in values if value > 1e-8])

else:
if isinstance(tau, Number) and tau > 0:
pass
else:
if not ( isinstance(tau, Number) and tau > 0):
raise ValueError(
"The step-sizes of SPDHG must be positive, passed tau = {}".format(tau))

Expand Down Expand Up @@ -367,7 +401,7 @@ def check_convergence(self):
return False
return True
else:
return False
raise ValueError('Convergence criterion currently can only be checked for scalar values of tau and sigma[i].')

def update(self):
""" Runs one iteration of SPDHG
Expand Down

0 comments on commit 9da703a

Please sign in to comment.