Skip to content

Commit

Permalink
SIRT update objective fix and catch in-place errors in the constraint (
Browse files Browse the repository at this point in the history
…#1658)



---------

Signed-off-by: Margaret Duff <43645617+MargaretDuff@users.noreply.github.com>
  • Loading branch information
MargaretDuff authored Feb 26, 2024
1 parent 8d50c65 commit 7466a4f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 10 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
- Added `Callback` (abstract base class), `ProgressCallback`, `TextProgressCallback`, `LogfileCallback`
- Deprecated `Algorithm.run(callback: Callable)`
- Added `Algorithm.run(callbacks: list[Callback])`
- New unit tests have been implemented for operators and functions to check for in place errors and the behaviour of `out`.
- Bug fix for missing factor of 1/2 in SIRT update objective and catch in place errors in the SIRT constraint

* 23.1.0
- Fix bug in IndicatorBox proximal_conjugate
Expand All @@ -45,7 +47,7 @@
- Added warmstart capability to proximal evaluation of the CIL TotalVariation function.
- Bug fix in the LinearOperator norm with an additional flag for the algorithm linearOperator.PowerMethod
- Tidied up documentation in the framework folder
- New unit tests have been implemented for operators and functions to check for in place errors and the behaviour of `out`.


* 23.0.1
- Fix bug with NikonReader requiring ROI to be set in constructor.
Expand Down
17 changes: 10 additions & 7 deletions Wrappers/Python/cil/optimisation/algorithms/SIRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.functions import IndicatorBox
from cil.framework import BlockDataContainer
from cil.utilities.errors import InPlaceError
from numpy import inf
import numpy
import logging
Expand All @@ -35,7 +36,7 @@ class SIRT(Algorithm):
The SIRT algorithm is
.. math:: x^{k+1} = \mathrm{proj}_{C}( x^{k} + \omega * D ( A^{T} ( M * (b - Ax) ) ) ),
.. math:: x^{k+1} = \mathrm{proj}_{C}( x^{k} + \omega * D ( A^{T} ( M * (b - Ax^{k}) ) ) ),
where,
:math:`M = \frac{1}{A*\mathbb{1}}`,
Expand Down Expand Up @@ -82,7 +83,7 @@ class SIRT(Algorithm):
Examples
--------
.. math:: \underset{x}{\mathrm{argmin}} \| x - d\|^{2}
.. math:: \underset{x}{\mathrm{argmin}} \frac{1}{2}\| x - d\|^{2}
>>> sirt = SIRT(initial = ig.allocate(0), operator = A, data = d, max_iteration = 5)
Expand Down Expand Up @@ -198,15 +199,17 @@ def update(self):
self.operator.adjoint(self.r, out=self.tmp_x)
self.x.sapyb(1.0, self.tmp_x, self._Dscaled, out=self.x)


if self.constraint is not None:
# IndicatorBox allows inplace operation for proximal
self.constraint.proximal(self.x, tau=1, out=self.x)
try:
self.constraint.proximal(self.x, tau=1, out=self.x)
except InPlaceError:
self.x=self.constraint.proximal(self.x, tau=1)

def update_objective(self):
r"""Returns the objective
.. math:: \|A x - b\|^{2}
.. math:: \frac{1}{2}\|A x - b\|^{2}
"""
self.loss.append(self.r.squared_norm())
self.loss.append(0.5*self.r.squared_norm())

28 changes: 26 additions & 2 deletions Wrappers/Python/test/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def test_FISTA_Denoising(self):



class TestSIRT(unittest.TestCase):
class TestSIRT(CCPiTestClass):


def setUp(self):
Expand Down Expand Up @@ -757,7 +757,31 @@ def test_SIRT_remove_nan_or_inf_with_BlockDataContainer(self):

self.assertFalse(np.any(sirt.D == np.inf))


def test_SIRT_with_TV(self):
data = dataexample.SIMPLE_PHANTOM_2D.get(size=(128,128))
ig = data.geometry
A=IdentityOperator(ig)
constraint=TotalVariation(warm_start=False, max_iteration=100)
initial=ig.allocate('random', seed=5)
sirt = SIRT(initial = initial, operator=A, data=data, max_iteration=2, constraint=constraint)
sirt.run(2, verbose=0)
f=LeastSquares(A,data, c=0.5)
fista=FISTA(initial=initial,f=f, g=constraint, max_iteration=1000)
fista.run(100, verbose=0)
self.assertNumpyArrayAlmostEqual(fista.x.as_array(), sirt.x.as_array())

def test_SIRT_with_TV_warm_start(self):
data = dataexample.SIMPLE_PHANTOM_2D.get(size=(128,128))
ig = data.geometry
A=IdentityOperator(ig)
constraint=1e6*TotalVariation(warm_start=True, max_iteration=100)
initial=ig.allocate('random', seed=5)
sirt = SIRT(initial = initial, operator=A, data=data, max_iteration=150, constraint=constraint)
sirt.run(25, verbose=0)

self.assertNumpyArrayAlmostEqual(sirt.x.as_array(), ig.allocate(0.25).as_array(),3)


class TestSPDHG(unittest.TestCase):

@unittest.skipUnless(has_astra, "cil-astra not available")
Expand Down

0 comments on commit 7466a4f

Please sign in to comment.