Skip to content

Commit

Permalink
Allow initial = None in SIRT (#1906)
Browse files Browse the repository at this point in the history
---------

Signed-off-by: Margaret Duff <43645617+MargaretDuff@users.noreply.github.com>
  • Loading branch information
MargaretDuff authored Aug 29, 2024
1 parent 4f6e3cf commit a7b0ba0
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
- Add checks on out argument passed to processors to ensure corrrect dtype and size (#1805)
- Internal refactor: Replaced string-based label checks with enum-based checks for improved type safety and consistency (#1692)
- Internal refactor: Separate framework into multiple files (#1692)
- Allow the SIRT algorithm to take `initial=None` (#1906)
- Testing:
- New unit tests for operators and functions to check for in place errors and the behaviour of `out` (#1805)
- Updates in SPDHG vs PDHG unit test to reduce test time and adjustments to parameters (#1898)
Expand Down
21 changes: 19 additions & 2 deletions Wrappers/Python/cil/optimisation/algorithms/SIRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SIRT(Algorithm):
----------
initial : DataContainer, default = None
Starting point of the algorithm, default value = Zero DataContainer
Starting point of the algorithm, default value = DataContainer in the domain of the operator allocated with zeros.
operator : LinearOperator
The operator A.
data : DataContainer
Expand Down Expand Up @@ -91,7 +91,7 @@ class SIRT(Algorithm):
"""


def __init__(self, initial, operator, data, lower=None, upper=None, constraint=None, **kwargs):
def __init__(self, initial=None, operator=None, data=None, lower=None, upper=None, constraint=None, **kwargs):

super(SIRT, self).__init__(**kwargs)

Expand All @@ -100,6 +100,23 @@ def __init__(self, initial, operator, data, lower=None, upper=None, constraint=N
def set_up(self, initial, operator, data, lower=None, upper=None, constraint=None):
"""Initialisation of the algorithm"""
log.info("%s setting up", self.__class__.__name__)

warning = 0
if operator is None:
warning += 1
msg = "an `operator`"
if data is None:
warning += 10
if warning > 10:
msg += " and `data`"
else:
msg = "`data`"
if warning > 0:
raise ValueError(f'You must pass {msg} to the SIRT algorithm' )

if initial is None:
initial = operator.domain_geometry().allocate(0)

self.x = initial.copy()
self.tmp_x = self.x * 0.0
self.operator = operator
Expand Down
47 changes: 47 additions & 0 deletions Wrappers/Python/test/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,53 @@ def setUp(self):

def tearDown(self):
pass

def test_set_up(self):

initial = self.A2.domain_geometry().allocate(0)
sirt = SIRT(initial=initial, operator=self.A2, data=self.b2, lower=0, upper=1)

# Test if set_up correctly configures the object
self.assertTrue(sirt.configured)
self.assertIsNotNone(sirt.x)
self.assertIsNotNone(sirt.r)
self.assertIsNotNone(sirt.constraint)
self.assertEqual(sirt.constraint.lower, 0)
self.assertEqual(sirt.constraint.upper, 1)


constraint = IndicatorBox(lower=0, upper=1)
sirt = SIRT(initial=None, operator=self.A2, data=self.b2, constraint=constraint)

# Test if set_up correctly configures the object with constraint
self.assertTrue(sirt.configured)
self.assertEqual(sirt.constraint, constraint)


with self.assertRaises(ValueError) as context:
sirt = SIRT(initial=None, operator=None, data=self.b2)
self.assertEqual(str(context.exception), 'You must pass an `operator` to the SIRT algorithm')


with self.assertRaises(ValueError) as context:
sirt = SIRT(initial=None, operator=self.A2, data=None)
self.assertEqual(str(context.exception), 'You must pass `data` to the SIRT algorithm')
with self.assertRaises(ValueError) as context:
sirt = SIRT(initial=None, operator=None, data=None)
self.assertEqual(str(context.exception),
'You must pass an `operator` and `data` to the SIRT algorithm')

sirt = SIRT(initial=None, operator=self.A2, data=self.b2)
self.assertTrue(sirt.configured)
self.assertIsInstance(sirt.x, ImageData)
self.assertTrue((sirt.x.as_array() == 0).all())


initial = self.A2.domain_geometry().allocate(1)
sirt = SIRT(initial=initial, operator=self.A2, data=self.b2)
self.assertTrue(sirt.configured)
self.assertIsInstance(sirt.x, ImageData)
self.assertTrue((sirt.x.as_array() == 1).all())

def test_update(self):
# sirt run 5 iterations
Expand Down

0 comments on commit a7b0ba0

Please sign in to comment.