diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e1fc4d59a..fd0885792e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ - Add checks on out argument passed to processors to ensure corrrect dtype and size (#1805) - Internal refactor: Replaced string-based label checks with enum-based checks for improved type safety and consistency (#1692) - Internal refactor: Separate framework into multiple files (#1692) + - Allow the SIRT algorithm to take `initial=None` (#1906) - Testing: - New unit tests for operators and functions to check for in place errors and the behaviour of `out` (#1805) - Updates in SPDHG vs PDHG unit test to reduce test time and adjustments to parameters (#1898) diff --git a/Wrappers/Python/cil/optimisation/algorithms/SIRT.py b/Wrappers/Python/cil/optimisation/algorithms/SIRT.py index 3cf01806f7..3064ae4b53 100644 --- a/Wrappers/Python/cil/optimisation/algorithms/SIRT.py +++ b/Wrappers/Python/cil/optimisation/algorithms/SIRT.py @@ -50,7 +50,7 @@ class SIRT(Algorithm): ---------- initial : DataContainer, default = None - Starting point of the algorithm, default value = Zero DataContainer + Starting point of the algorithm, default value = DataContainer in the domain of the operator allocated with zeros. operator : LinearOperator The operator A. data : DataContainer @@ -91,7 +91,7 @@ class SIRT(Algorithm): """ - def __init__(self, initial, operator, data, lower=None, upper=None, constraint=None, **kwargs): + def __init__(self, initial=None, operator=None, data=None, lower=None, upper=None, constraint=None, **kwargs): super(SIRT, self).__init__(**kwargs) @@ -100,6 +100,23 @@ def __init__(self, initial, operator, data, lower=None, upper=None, constraint=N def set_up(self, initial, operator, data, lower=None, upper=None, constraint=None): """Initialisation of the algorithm""" log.info("%s setting up", self.__class__.__name__) + + warning = 0 + if operator is None: + warning += 1 + msg = "an `operator`" + if data is None: + warning += 10 + if warning > 10: + msg += " and `data`" + else: + msg = "`data`" + if warning > 0: + raise ValueError(f'You must pass {msg} to the SIRT algorithm' ) + + if initial is None: + initial = operator.domain_geometry().allocate(0) + self.x = initial.copy() self.tmp_x = self.x * 0.0 self.operator = operator diff --git a/Wrappers/Python/test/test_algorithms.py b/Wrappers/Python/test/test_algorithms.py index d9d6fa2086..e3f6302d78 100644 --- a/Wrappers/Python/test/test_algorithms.py +++ b/Wrappers/Python/test/test_algorithms.py @@ -979,6 +979,53 @@ def setUp(self): def tearDown(self): pass + + def test_set_up(self): + + initial = self.A2.domain_geometry().allocate(0) + sirt = SIRT(initial=initial, operator=self.A2, data=self.b2, lower=0, upper=1) + + # Test if set_up correctly configures the object + self.assertTrue(sirt.configured) + self.assertIsNotNone(sirt.x) + self.assertIsNotNone(sirt.r) + self.assertIsNotNone(sirt.constraint) + self.assertEqual(sirt.constraint.lower, 0) + self.assertEqual(sirt.constraint.upper, 1) + + + constraint = IndicatorBox(lower=0, upper=1) + sirt = SIRT(initial=None, operator=self.A2, data=self.b2, constraint=constraint) + + # Test if set_up correctly configures the object with constraint + self.assertTrue(sirt.configured) + self.assertEqual(sirt.constraint, constraint) + + + with self.assertRaises(ValueError) as context: + sirt = SIRT(initial=None, operator=None, data=self.b2) + self.assertEqual(str(context.exception), 'You must pass an `operator` to the SIRT algorithm') + + + with self.assertRaises(ValueError) as context: + sirt = SIRT(initial=None, operator=self.A2, data=None) + self.assertEqual(str(context.exception), 'You must pass `data` to the SIRT algorithm') + with self.assertRaises(ValueError) as context: + sirt = SIRT(initial=None, operator=None, data=None) + self.assertEqual(str(context.exception), + 'You must pass an `operator` and `data` to the SIRT algorithm') + + sirt = SIRT(initial=None, operator=self.A2, data=self.b2) + self.assertTrue(sirt.configured) + self.assertIsInstance(sirt.x, ImageData) + self.assertTrue((sirt.x.as_array() == 0).all()) + + + initial = self.A2.domain_geometry().allocate(1) + sirt = SIRT(initial=initial, operator=self.A2, data=self.b2) + self.assertTrue(sirt.configured) + self.assertIsInstance(sirt.x, ImageData) + self.assertTrue((sirt.x.as_array() == 1).all()) def test_update(self): # sirt run 5 iterations