Skip to content

Commit

Permalink
use more tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Aug 19, 2024
1 parent 441135d commit 73cbdfa
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 27 deletions.
5 changes: 3 additions & 2 deletions Wrappers/Python/cil/framework/acquisition_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1737,9 +1737,10 @@ def dimension_labels(self):
]

try:
labels = list(self._dimension_labels)
labels = self._dimension_labels
except AttributeError:
labels = labels_default.copy()
labels = labels_default
labels = list(labels)

#remove from list labels where len == 1
#
Expand Down
6 changes: 3 additions & 3 deletions Wrappers/Python/cil/framework/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def dimension_labels(self):
def dimension_labels(self, val):
if val is None:
self._dimension_labels = None
elif len(list(val))==self.number_of_dimensions:
self._dimension_labels = tuple(val)
elif len(val_tuple := tuple(val)) == self.number_of_dimensions:
self._dimension_labels = val_tuple
else:
raise ValueError("dimension_labels expected a list containing {0} strings got {1}".format(self.number_of_dimensions, val))

Expand Down Expand Up @@ -260,7 +260,7 @@ def fill(self, array, **dimension):
else:

axis = [':']* self.number_of_dimensions
dimension_labels = list(self.dimension_labels)
dimension_labels = tuple(self.dimension_labels)
for k,v in dimension.items():
i = dimension_labels.index(k)
axis[i] = v
Expand Down
5 changes: 3 additions & 2 deletions Wrappers/Python/cil/framework/image_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ def dimension_labels(self):
self.voxel_num_x]

try:
labels = list(self._dimension_labels)
labels = self._dimension_labels
except AttributeError:
labels = labels_default.copy()
labels = labels_default
labels = list(labels)

for i, x in enumerate(shape_default):
if x == 0 or x==1:
Expand Down
24 changes: 12 additions & 12 deletions Wrappers/Python/cil/framework/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class ImageDimensionLabels(StrEnum):
HORIZONTAL_Y = auto()

@classmethod
def get_order_for_engine(cls, engine: str, geometry=None) -> list:
def get_order_for_engine(cls, engine: str, geometry=None) -> tuple:
"""
Returns the order of dimensions for a specific engine and geometry.
Expand All @@ -103,13 +103,13 @@ def get_order_for_engine(cls, engine: str, geometry=None) -> list:
geometry: ImageGeometry, optional
If unspecified, the default order is returned.
"""
order = [cls.CHANNEL, cls.VERTICAL, cls.HORIZONTAL_Y, cls.HORIZONTAL_X]
order = cls.CHANNEL, cls.VERTICAL, cls.HORIZONTAL_Y, cls.HORIZONTAL_X
engine_orders = {Backends.ASTRA: order, Backends.TIGRE: order, Backends.CIL: order}
dim_order = engine_orders[Backends(engine)]

if geometry is None:
return dim_order
return [label for label in dim_order if label in geometry.dimension_labels]
return tuple(label for label in dim_order if label in geometry.dimension_labels)

@classmethod
def check_order_for_engine(cls, engine: str, geometry) -> bool:
Expand All @@ -125,11 +125,11 @@ def check_order_for_engine(cls, engine: str, geometry) -> bool:
ValueError if the order of dimensions is incorrect.
"""
order_requested = cls.get_order_for_engine(engine, geometry)
if order_requested == list(geometry.dimension_labels):
if order_requested == tuple(geometry.dimension_labels):
return True
raise ValueError(
f"Expected dimension_label order {order_requested}"
f" got {list(geometry.dimension_labels)}."
f" got {tuple(geometry.dimension_labels)}."
f" Try using `data.reorder('{engine}')` to permute for {engine}")


Expand All @@ -152,7 +152,7 @@ class AcquisitionDimensionLabels(StrEnum):
HORIZONTAL = auto()

@classmethod
def get_order_for_engine(cls, engine: str, geometry=None) -> list:
def get_order_for_engine(cls, engine: str, geometry=None) -> tuple:
"""
Returns the order of dimensions for a specific engine and geometry.
Expand All @@ -162,14 +162,14 @@ def get_order_for_engine(cls, engine: str, geometry=None) -> list:
If unspecified, the default order is returned.
"""
engine_orders = {
Backends.ASTRA: [cls.CHANNEL, cls.VERTICAL, cls.ANGLE, cls.HORIZONTAL],
Backends.TIGRE: [cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL],
Backends.CIL: [cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL]}
Backends.ASTRA: (cls.CHANNEL, cls.VERTICAL, cls.ANGLE, cls.HORIZONTAL),
Backends.TIGRE: (cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL),
Backends.CIL: (cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL)}
dim_order = engine_orders[Backends(engine)]

if geometry is None:
return dim_order
return [label for label in dim_order if label in geometry.dimension_labels]
return tuple(label for label in dim_order if label in geometry.dimension_labels)

@classmethod
def check_order_for_engine(cls, engine: str, geometry) -> bool:
Expand All @@ -185,11 +185,11 @@ def check_order_for_engine(cls, engine: str, geometry) -> bool:
ValueError if the order of dimensions is incorrect.
"""
order_requested = cls.get_order_for_engine(engine, geometry)
if order_requested == list(geometry.dimension_labels):
if order_requested == tuple(geometry.dimension_labels):
return True
raise ValueError(
f"Expected dimension_label order {order_requested},"
f" got {list(geometry.dimension_labels)}."
f" got {tuple(geometry.dimension_labels)}."
f" Try using `data.reorder('{engine}')` to permute for {engine}")


Expand Down
18 changes: 10 additions & 8 deletions Wrappers/Python/test/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_image_dimension_labels(self):
self.assertIn(getattr(ImageDimensionLabels, i), ImageDimensionLabels)

def test_image_dimension_labels_default_order(self):
order_gold = [ImageDimensionLabels.CHANNEL, 'VERTICAL', 'horizontal_y', 'HORIZONTAL_X']
order_gold = ImageDimensionLabels.CHANNEL, 'VERTICAL', 'horizontal_y', 'HORIZONTAL_X'
for i in ('CIL', 'TIGRE', 'ASTRA'):
self.assertSequenceEqual(ImageDimensionLabels.get_order_for_engine(i), order_gold)

Expand All @@ -100,7 +100,7 @@ def test_image_dimension_labels_get_order(self):
ig.set_labels(['channel', 'horizontal_y', 'horizontal_x'])

# for 2D all engines have the same order
order_gold = [ImageDimensionLabels.CHANNEL, 'HORIZONTAL_Y', 'horizontal_x']
order_gold = ImageDimensionLabels.CHANNEL, 'HORIZONTAL_Y', 'horizontal_x'
self.assertSequenceEqual(ImageDimensionLabels.get_order_for_engine('cil', ig), order_gold)
self.assertSequenceEqual(ImageDimensionLabels.get_order_for_engine('tigre', ig), order_gold)
self.assertSequenceEqual(ImageDimensionLabels.get_order_for_engine('astra', ig), order_gold)
Expand All @@ -125,9 +125,11 @@ def test_acquisition_dimension_labels(self):
self.assertIn(getattr(AcquisitionDimensionLabels, i), AcquisitionDimensionLabels)

def test_acquisition_dimension_labels_default_order(self):
self.assertEqual(AcquisitionDimensionLabels.get_order_for_engine('CIL'), [AcquisitionDimensionLabels.CHANNEL, 'ANGLE', 'vertical', 'HORIZONTAL'])
self.assertEqual(AcquisitionDimensionLabels.get_order_for_engine(Backends.TIGRE), ['CHANNEL', 'ANGLE', 'VERTICAL', 'HORIZONTAL'])
self.assertEqual(AcquisitionDimensionLabels.get_order_for_engine('astra'), ['CHANNEL', 'VERTICAL', 'ANGLE', 'HORIZONTAL'])
gold = AcquisitionDimensionLabels.CHANNEL, 'ANGLE', 'vertical', 'HORIZONTAL'
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine('CIL'), gold)
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine(Backends.TIGRE), gold)
gold = 'CHANNEL', 'VERTICAL', 'ANGLE', 'HORIZONTAL'
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine('astra'), gold)

with self.assertRaises((KeyError, ValueError)):
AcquisitionDimensionLabels.get_order_for_engine("bad_engine")
Expand All @@ -140,7 +142,7 @@ def test_acquisition_dimension_labels_get_order(self):
.set_labels(['angle', 'horizontal', 'channel'])

# for 2D all engines have the same order
order_gold = [AcquisitionDimensionLabels.CHANNEL, 'ANGLE', 'horizontal']
order_gold = AcquisitionDimensionLabels.CHANNEL, 'ANGLE', 'horizontal'
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine('CIL', ag), order_gold)
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine('TIGRE', ag), order_gold)
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine('ASTRA', ag), order_gold)
Expand All @@ -150,10 +152,10 @@ def test_acquisition_dimension_labels_get_order(self):
.set_panel((4,2))\
.set_labels(['angle', 'horizontal', 'vertical'])

order_gold = [AcquisitionDimensionLabels.ANGLE, 'VERTICAL', 'horizontal']
order_gold = AcquisitionDimensionLabels.ANGLE, 'VERTICAL', 'horizontal'
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine("cil", ag), order_gold)
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine("tigre", ag), order_gold)
order_gold = [AcquisitionDimensionLabels.VERTICAL, 'ANGLE', 'horizontal']
order_gold = AcquisitionDimensionLabels.VERTICAL, 'ANGLE', 'horizontal'
self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine("astra", ag), order_gold)

def test_acquisition_dimension_labels_check_order(self):
Expand Down

0 comments on commit 73cbdfa

Please sign in to comment.