diff --git a/Wrappers/Python/cil/framework/acquisition_geometry.py b/Wrappers/Python/cil/framework/acquisition_geometry.py index c52278af32..c373614869 100644 --- a/Wrappers/Python/cil/framework/acquisition_geometry.py +++ b/Wrappers/Python/cil/framework/acquisition_geometry.py @@ -200,7 +200,7 @@ def dimension(self,val): @property def geometry(self): - return self._geometry.value + return self._geometry @geometry.setter def geometry(self,val): @@ -1720,15 +1720,11 @@ def dimension(self): @property def shape(self): - shape_dict = {AcquisitionDimensionLabels.CHANNEL.value: self.config.channels.num_channels, - AcquisitionDimensionLabels.ANGLE.value: self.config.angles.num_positions, - AcquisitionDimensionLabels.VERTICAL.value: self.config.panel.num_pixels[1], - AcquisitionDimensionLabels.HORIZONTAL.value: self.config.panel.num_pixels[0]} - shape = [] - for label in self.dimension_labels: - shape.append(shape_dict[label]) - - return tuple(shape) + shape_dict = {AcquisitionDimensionLabels.CHANNEL: self.config.channels.num_channels, + AcquisitionDimensionLabels.ANGLE: self.config.angles.num_positions, + AcquisitionDimensionLabels.VERTICAL: self.config.panel.num_pixels[1], + AcquisitionDimensionLabels.HORIZONTAL: self.config.panel.num_pixels[0]} + return tuple(shape_dict[label] for label in self.dimension_labels) @property def dimension_labels(self): @@ -1758,10 +1754,8 @@ def dimension_labels(self): @dimension_labels.setter def dimension_labels(self, val): - if val is not None: - label_new=[AcquisitionDimensionLabels(x).value for x in val if x in AcquisitionDimensionLabels] - self._dimension_labels = tuple(label_new) + self._dimension_labels = tuple(AcquisitionDimensionLabels(x) for x in val if x in AcquisitionDimensionLabels) @property def ndim(self): diff --git a/Wrappers/Python/cil/framework/image_geometry.py b/Wrappers/Python/cil/framework/image_geometry.py index b9c99a68ed..ab9049ac83 100644 --- a/Wrappers/Python/cil/framework/image_geometry.py +++ b/Wrappers/Python/cil/framework/image_geometry.py @@ -56,17 +56,12 @@ def VERTICAL(self): return ImageDimensionLabels.VERTICAL @property - def shape(self): - shape_dict = {ImageDimensionLabels.CHANNEL.value: self.channels, - ImageDimensionLabels.VERTICAL.value: self.voxel_num_z, - ImageDimensionLabels.HORIZONTAL_Y.value: self.voxel_num_y, - ImageDimensionLabels.HORIZONTAL_X.value: self.voxel_num_x} - - shape = [] - for label in self.dimension_labels: - shape.append(shape_dict[label]) - - return tuple(shape) + def shape(self): + shape_dict = {ImageDimensionLabels.CHANNEL: self.channels, + ImageDimensionLabels.VERTICAL: self.voxel_num_z, + ImageDimensionLabels.HORIZONTAL_Y: self.voxel_num_y, + ImageDimensionLabels.HORIZONTAL_X: self.voxel_num_x} + return tuple(shape_dict[label] for label in self.dimension_labels) @shape.setter def shape(self, val): @@ -74,17 +69,11 @@ def shape(self, val): @property def spacing(self): - - spacing_dict = {ImageDimensionLabels.CHANNEL.value: self.channel_spacing, - ImageDimensionLabels.VERTICAL.value: self.voxel_size_z, - ImageDimensionLabels.HORIZONTAL_Y.value: self.voxel_size_y, - ImageDimensionLabels.HORIZONTAL_X.value: self.voxel_size_x} - - spacing = [] - for label in self.dimension_labels: - spacing.append(spacing_dict[label]) - - return tuple(spacing) + spacing_dict = {ImageDimensionLabels.CHANNEL: self.channel_spacing, + ImageDimensionLabels.VERTICAL: self.voxel_size_z, + ImageDimensionLabels.HORIZONTAL_Y: self.voxel_size_y, + ImageDimensionLabels.HORIZONTAL_X: self.voxel_size_x} + return tuple(spacing_dict[label] for label in self.dimension_labels) @property def length(self): @@ -123,8 +112,7 @@ def dimension_labels(self, val): def set_labels(self, labels): if labels is not None: - label_new=[ImageDimensionLabels(x).value for x in labels if x in ImageDimensionLabels] - self._dimension_labels = tuple(label_new) + self._dimension_labels = tuple(ImageDimensionLabels(x) for x in labels if x in ImageDimensionLabels) def __eq__(self, other): diff --git a/Wrappers/Python/cil/framework/labels.py b/Wrappers/Python/cil/framework/labels.py index 8fb7c27c0d..9eac7530c6 100644 --- a/Wrappers/Python/cil/framework/labels.py +++ b/Wrappers/Python/cil/framework/labels.py @@ -27,7 +27,7 @@ class _StrEnumMeta(EnumType): def __contains__(self, item: str) -> bool: try: key = item.upper() - except AttributeError: + except (AttributeError, TypeError): return False return key in self.__members__ or item in self.__members__.values() @@ -42,9 +42,10 @@ def _missing_(cls, value: str): def __eq__(self, value: str) -> bool: """Uses value.upper() for case-insensitivity""" try: - return super().__eq__(self.__class__[value.upper()]) - except (KeyError, ValueError): - return False + value = self.__class__[value.upper()] + except (KeyError, ValueError, AttributeError): + pass + return super().__eq__(value) def __hash__(self) -> int: """consistent hashing for dictionary keys""" @@ -104,7 +105,7 @@ def get_order_for_engine(cls, engine: str, geometry=None) -> list: """ order = [cls.CHANNEL, cls.VERTICAL, cls.HORIZONTAL_Y, cls.HORIZONTAL_X] engine_orders = {Backends.ASTRA: order, Backends.TIGRE: order, Backends.CIL: order} - dim_order = engine_orders[Backends[engine.upper()]] + dim_order = engine_orders[Backends(engine)] if geometry is None: return dim_order @@ -164,7 +165,7 @@ def get_order_for_engine(cls, engine: str, geometry=None) -> list: Backends.ASTRA: [cls.CHANNEL, cls.VERTICAL, cls.ANGLE, cls.HORIZONTAL], Backends.TIGRE: [cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL], Backends.CIL: [cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL]} - dim_order = engine_orders[Backends[engine.upper()]] + dim_order = engine_orders[Backends(engine)] if geometry is None: return dim_order diff --git a/Wrappers/Python/cil/io/NEXUSDataWriter.py b/Wrappers/Python/cil/io/NEXUSDataWriter.py index daa9c94927..a3839d702c 100644 --- a/Wrappers/Python/cil/io/NEXUSDataWriter.py +++ b/Wrappers/Python/cil/io/NEXUSDataWriter.py @@ -145,23 +145,19 @@ def write(self): ds_data.write_direct(self.data.array) # set up dataset attributes - if (isinstance(self.data, ImageData)): - ds_data.attrs['data_type'] = 'ImageData' - else: - ds_data.attrs['data_type'] = 'AcquisitionData' + ds_data.attrs['data_type'] = 'ImageData' if isinstance(self.data, ImageData) else 'AcquisitionData' for i in range(self.data.number_of_dimensions): - ds_data.attrs['dim{}'.format(i)] = self.data.dimension_labels[i] - - if (isinstance(self.data, AcquisitionData)): + ds_data.attrs[f'dim{i}'] = str(self.data.dimension_labels[i]) + if isinstance(self.data, AcquisitionData): # create group to store configuration f.create_group('entry1/tomo_entry/config') f.create_group('entry1/tomo_entry/config/source') f.create_group('entry1/tomo_entry/config/detector') f.create_group('entry1/tomo_entry/config/rotation_axis') - ds_data.attrs['geometry'] = self.data.geometry.config.system.geometry + ds_data.attrs['geometry'] = str(self.data.geometry.config.system.geometry) ds_data.attrs['dimension'] = self.data.geometry.config.system.dimension ds_data.attrs['num_channels'] = self.data.geometry.config.channels.num_channels diff --git a/Wrappers/Python/cil/io/ZEISSDataReader.py b/Wrappers/Python/cil/io/ZEISSDataReader.py index 41e7ab1908..3bce8b3d0d 100644 --- a/Wrappers/Python/cil/io/ZEISSDataReader.py +++ b/Wrappers/Python/cil/io/ZEISSDataReader.py @@ -127,17 +127,16 @@ def set_up(self, if roi is not None: if metadata['data geometry'] == 'acquisition': - allowed_labels = [item.value for item in AcquisitionDimensionLabels] - zeiss_data_order = {'angle':0, 'vertical':1, 'horizontal':2} + zeiss_data_order = {AcquisitionDimensionLabels.ANGLE: 0, + AcquisitionDimensionLabels.VERTICAL: 1, + AcquisitionDimensionLabels.HORIZONTAL: 2} else: - allowed_labels = [item.value for item in ImageDimensionLabels] - zeiss_data_order = {'vertical':0, 'horizontal_y':1, 'horizontal_x':2} + zeiss_data_order = {ImageDimensionLabels.VERTICAL: 0, + ImageDimensionLabels.HORIZONTAL_Y: 1, + ImageDimensionLabels.HORIZONTAL_X: 2} # check roi labels and create tuple for slicing for key in roi.keys(): - if key not in allowed_labels: - raise Exception("Wrong label, got {0}. Expected dimension labels in {1}, {2}, {3}".format(key,**allowed_labels)) - idx = zeiss_data_order[key] if roi[key] != -1: for i, x in enumerate(roi[key]): @@ -289,4 +288,3 @@ def get_geometry(self): def get_metadata(self): '''return the metadata of the file''' return self._metadata - diff --git a/Wrappers/Python/cil/plugins/astra/utilities/convert_geometry_to_astra_vec_2D.py b/Wrappers/Python/cil/plugins/astra/utilities/convert_geometry_to_astra_vec_2D.py index ce6738ed19..105ce588c6 100644 --- a/Wrappers/Python/cil/plugins/astra/utilities/convert_geometry_to_astra_vec_2D.py +++ b/Wrappers/Python/cil/plugins/astra/utilities/convert_geometry_to_astra_vec_2D.py @@ -53,7 +53,7 @@ def convert_geometry_to_astra_vec_2D(volume_geometry, sinogram_geometry_in): panel = sinogram_geometry.config.panel #get units - degrees = angles.angle_unit == UnitsAngles.DEGREE.value + degrees = angles.angle_unit == UnitsAngles.DEGREE #create a 2D astra geom from 2D CIL geometry, 2D astra geometry has axis flipped compared to 3D volume_geometry_temp = volume_geometry.copy() diff --git a/Wrappers/Python/cil/plugins/astra/utilities/convert_geometry_to_astra_vec_3D.py b/Wrappers/Python/cil/plugins/astra/utilities/convert_geometry_to_astra_vec_3D.py index ac33c127af..5a0e846b64 100644 --- a/Wrappers/Python/cil/plugins/astra/utilities/convert_geometry_to_astra_vec_3D.py +++ b/Wrappers/Python/cil/plugins/astra/utilities/convert_geometry_to_astra_vec_3D.py @@ -55,7 +55,7 @@ def convert_geometry_to_astra_vec_3D(volume_geometry, sinogram_geometry_in): panel = sinogram_geometry.config.panel #get units - degrees = angles.angle_unit == UnitsAngles.DEGREE.value + degrees = angles.angle_unit == UnitsAngles.DEGREE if sinogram_geometry.dimension == '2D': #create a 3D astra geom from 2D CIL geometry diff --git a/Wrappers/Python/test/test_DataContainer.py b/Wrappers/Python/test/test_DataContainer.py index a97d1a89ce..273bfb5bb6 100644 --- a/Wrappers/Python/test/test_DataContainer.py +++ b/Wrappers/Python/test/test_DataContainer.py @@ -865,19 +865,19 @@ def error_message(function_name, test_name): expected = expected_func(data.as_array(), axis=1) expected_dimension_labels = data.dimension_labels[0],data.dimension_labels[2] numpy.testing.assert_almost_equal(result.as_array(), expected, err_msg=error_message(function_name, "'with 1 axis'")) - numpy.testing.assert_equal(result.dimension_labels, expected_dimension_labels, err_msg=error_message(function_name, "'with 1 axis'")) + self.assertEqual(result.dimension_labels, expected_dimension_labels, f"{function_name} 'with 1 axis'") # test specifying axis with an int result = test_func(axis=1) numpy.testing.assert_almost_equal(result.as_array(), expected, err_msg=error_message(function_name, "'with 1 axis'")) - numpy.testing.assert_equal(result.dimension_labels,expected_dimension_labels, err_msg=error_message(function_name, "'with 1 axis'")) + self.assertEqual(result.dimension_labels,expected_dimension_labels, f"{function_name} 'with 1 axis'") # test specifying function in 2 axes result = test_func(axis=(data.dimension_labels[0],data.dimension_labels[1])) numpy.testing.assert_almost_equal(result.as_array(), expected_func(data.as_array(), axis=(0,1)), err_msg=error_message(function_name, "'with 2 axes'")) - numpy.testing.assert_equal(result.dimension_labels,(data.dimension_labels[2],), err_msg=error_message(function_name, "'with 2 axes'")) + self.assertEqual(result.dimension_labels, (data.dimension_labels[2],), f"{function_name} 'with 2 axes'") # test specifying function in 2 axes with an int result = test_func(axis=(0,1)) numpy.testing.assert_almost_equal(result.as_array(), expected_func(data.as_array(), axis=(0,1)), err_msg=error_message(function_name, "'with 2 axes'")) - numpy.testing.assert_equal(result.dimension_labels,(data.dimension_labels[2],), err_msg=error_message(function_name, "'with 2 axes'")) + self.assertEqual(result.dimension_labels, (data.dimension_labels[2],), f"{function_name} 'with 2 axes'") # test specifying function in 3 axes result = test_func(axis=(data.dimension_labels[0],data.dimension_labels[1],data.dimension_labels[2])) numpy.testing.assert_almost_equal(result, expected_func(data.as_array()), err_msg=error_message(function_name, "'with 3 axes'")) @@ -885,10 +885,10 @@ def error_message(function_name, test_name): expected_array = expected_func(data.as_array(), axis = 0) test_func(axis=0, out=out) numpy.testing.assert_almost_equal(out.as_array(), expected_array, err_msg=error_message(function_name, "'of out argument'")) - numpy.testing.assert_equal(out.dimension_labels, (data.dimension_labels[1],data.dimension_labels[2]), err_msg=error_message(function_name, "'of out argument'")) + self.assertEqual(out.dimension_labels, (data.dimension_labels[1],data.dimension_labels[2]), f"{function_name} 'of out argument'") test_func(axis=data.dimension_labels[0], out=out) numpy.testing.assert_almost_equal(out.as_array(), expected_array, err_msg=error_message(function_name, "'of out argument'")) - numpy.testing.assert_equal(out.dimension_labels, (data.dimension_labels[1],data.dimension_labels[2]), err_msg=error_message(function_name, "'of out argument'")) + self.assertEqual(out.dimension_labels, (data.dimension_labels[1],data.dimension_labels[2]), f"{function_name} 'of out argument'") # test providing a numpy array to out out = numpy.zeros((2,2), dtype=data.dtype) test_func(axis=0, out=out) diff --git a/Wrappers/Python/test/test_labels.py b/Wrappers/Python/test/test_labels.py index e9c70e2fec..bc3417ab14 100644 --- a/Wrappers/Python/test/test_labels.py +++ b/Wrappers/Python/test/test_labels.py @@ -15,229 +15,161 @@ # # Authors: # CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt +import unittest import numpy as np -import unittest - -from cil.framework.labels import (_LabelsBase, - FillTypes, UnitsAngles, - AcquisitionTypes, AcquisitionDimensions, +from cil.framework import AcquisitionGeometry, ImageGeometry +from cil.framework.labels import (StrEnum, + FillTypes, UnitsAngles, + AcquisitionTypes, AcquisitionDimensions, ImageDimensionLabels, AcquisitionDimensionLabels, Backends) -from cil.framework import AcquisitionGeometry, ImageGeometry class Test_Lables(unittest.TestCase): - - def test_base_labels(self): - - out_gold = AcquisitionDimensions.DIM3 - - input_good = ["3D", AcquisitionDimensions.DIM3] - input_bad = ["bad_str", "DIM3", UnitsAngles.DEGREE] - - for item in input_good: + def test_labels_strenum(self): + for item in ("3D", "DIM3", AcquisitionDimensions.DIM3): out = AcquisitionDimensions(item) - self.assertEqual(out, out_gold) + self.assertEqual(out, AcquisitionDimensions.DIM3) self.assertTrue(isinstance(out, AcquisitionDimensions)) - - for item in input_bad: + for item in ("bad_str", "4D", "DIM4", UnitsAngles.DEGREE): with self.assertRaises(ValueError): AcquisitionDimensions(item) - - def test_labels_eq(self): - self.assertTrue(_LabelsBase.__eq__(AcquisitionDimensions.DIM3, "3D")) - self.assertTrue(_LabelsBase.__eq__(AcquisitionDimensions.DIM3, AcquisitionDimensions.DIM3)) - - self.assertFalse(_LabelsBase.__eq__(AcquisitionDimensions.DIM3, "DIM3")) - self.assertFalse(_LabelsBase.__eq__(AcquisitionDimensions.DIM3, "2D")) - self.assertFalse(_LabelsBase.__eq__(AcquisitionDimensions.DIM3, AcquisitionDimensions.DIM2)) - self.assertFalse(_LabelsBase.__eq__(AcquisitionDimensions.DIM3, AcquisitionDimensions)) - + def test_labels_strenum_eq(self): + for i in ("3D", "DIM3", AcquisitionDimensions.DIM3): + self.assertEqual(AcquisitionDimensions.DIM3, i) + self.assertEqual(i, AcquisitionDimensions.DIM3) + for i in ("2D", "DIM2", AcquisitionDimensions.DIM2, AcquisitionDimensions): + self.assertNotEqual(AcquisitionDimensions.DIM3, i) def test_labels_contains(self): - self.assertTrue(_LabelsBase.__contains__(AcquisitionDimensions, "3D")) - self.assertTrue(_LabelsBase.__contains__(AcquisitionDimensions, AcquisitionDimensions.DIM3)) - self.assertTrue(_LabelsBase.__contains__(AcquisitionDimensions, AcquisitionDimensions.DIM2)) - - self.assertFalse(_LabelsBase.__contains__(AcquisitionDimensions, "DIM3")) - self.assertFalse(_LabelsBase.__contains__(AcquisitionDimensions, AcquisitionDimensions)) - + for i in ("3D", "DIM3", AcquisitionDimensions.DIM3, AcquisitionDimensions.DIM2): + self.assertIn(i, AcquisitionDimensions) + for i in ("4D", "DIM4", AcquisitionDimensions): + self.assertNotIn(i, AcquisitionDimensions) def test_backends(self): - self.assertTrue('astra' in Backends) - self.assertTrue('cil' in Backends) - self.assertTrue('tigre' in Backends) - self.assertTrue(Backends.ASTRA in Backends) - self.assertTrue(Backends.CIL in Backends) - self.assertTrue(Backends.TIGRE in Backends) + for i in ('ASTRA', 'CIL', 'TIGRE'): + self.assertIn(i, Backends) + self.assertIn(i.lower(), Backends) + self.assertIn(getattr(Backends, i), Backends) def test_fill_types(self): - self.assertTrue('random' in FillTypes) - self.assertTrue('random_int' in FillTypes) - self.assertTrue(FillTypes.RANDOM in FillTypes) - self.assertTrue(FillTypes.RANDOM_INT in FillTypes) - + for i in ('RANDOM', 'RANDOM_INT'): + self.assertIn(i, FillTypes) + self.assertIn(i.lower(), FillTypes) + self.assertIn(getattr(FillTypes, i), FillTypes) + def test_units_angles(self): - self.assertTrue('degree' in UnitsAngles) - self.assertTrue('radian' in UnitsAngles) - self.assertTrue(UnitsAngles.DEGREE in UnitsAngles) - self.assertTrue(UnitsAngles.RADIAN in UnitsAngles) + for i in ('DEGREE', 'RADIAN'): + self.assertIn(i, UnitsAngles) + self.assertIn(i.lower(), UnitsAngles) + self.assertIn(getattr(UnitsAngles, i), UnitsAngles) def test_acquisition_type(self): - self.assertTrue('parallel' in AcquisitionTypes) - self.assertTrue('cone' in AcquisitionTypes) - self.assertTrue(AcquisitionTypes.PARALLEL in AcquisitionTypes) - self.assertTrue(AcquisitionTypes.CONE in AcquisitionTypes) + for i in ('PARALLEL', 'CONE'): + self.assertIn(i, AcquisitionTypes) + self.assertIn(i.lower(), AcquisitionTypes) + self.assertIn(getattr(AcquisitionTypes, i), AcquisitionTypes) def test_acquisition_dimension(self): - self.assertTrue('2D' in AcquisitionDimensions) - self.assertTrue('3D' in AcquisitionDimensions) - self.assertTrue(AcquisitionDimensions.DIM2 in AcquisitionDimensions) - self.assertTrue(AcquisitionDimensions.DIM3 in AcquisitionDimensions) + for i in ('2D', '3D'): + self.assertIn(i, AcquisitionDimensions) + for i in ('DIM2', 'DIM3'): + self.assertIn(i, AcquisitionDimensions) + self.assertIn(i.lower(), AcquisitionDimensions) + self.assertIn(getattr(AcquisitionDimensions, i), AcquisitionDimensions) def test_image_dimension_labels(self): - self.assertTrue('channel' in ImageDimensionLabels) - self.assertTrue('vertical' in ImageDimensionLabels) - self.assertTrue('horizontal_x' in ImageDimensionLabels) - self.assertTrue('horizontal_y' in ImageDimensionLabels) - self.assertTrue(ImageDimensionLabels.CHANNEL in ImageDimensionLabels) - self.assertTrue(ImageDimensionLabels.VERTICAL in ImageDimensionLabels) - self.assertTrue(ImageDimensionLabels.HORIZONTAL_X in ImageDimensionLabels) - self.assertTrue(ImageDimensionLabels.HORIZONTAL_Y in ImageDimensionLabels) + for i in ('CHANNEL', 'VERTICAL', 'HORIZONTAL_X', 'HORIZONTAL_Y'): + self.assertIn(i, ImageDimensionLabels) + self.assertIn(i.lower(), ImageDimensionLabels) + self.assertIn(getattr(ImageDimensionLabels, i), ImageDimensionLabels) def test_image_dimension_labels_default_order(self): - - order_gold = [ImageDimensionLabels.CHANNEL, ImageDimensionLabels.VERTICAL, ImageDimensionLabels.HORIZONTAL_Y, ImageDimensionLabels.HORIZONTAL_X] - - order = ImageDimensionLabels.get_order_for_engine("cil") - self.assertEqual(order,order_gold ) - - order = ImageDimensionLabels.get_order_for_engine("tigre") - self.assertEqual(order,order_gold) - - order = ImageDimensionLabels.get_order_for_engine("astra") - self.assertEqual(order, order_gold) - - with self.assertRaises(ValueError): - order = AcquisitionDimensionLabels.get_order_for_engine("bad_engine") + order_gold = [ImageDimensionLabels.CHANNEL, 'VERTICAL', 'horizontal_y', 'HORIZONTAL_X'] + for i in ('CIL', 'TIGRE', 'ASTRA'): + self.assertSequenceEqual(ImageDimensionLabels.get_order_for_engine(i), order_gold) + with self.assertRaises((KeyError, ValueError)): + AcquisitionDimensionLabels.get_order_for_engine("bad_engine") def test_image_dimension_labels_get_order(self): ig = ImageGeometry(4, 8, 1, channels=2) ig.set_labels(['channel', 'horizontal_y', 'horizontal_x']) # for 2D all engines have the same order - order_gold = [ImageDimensionLabels.CHANNEL, ImageDimensionLabels.HORIZONTAL_Y, ImageDimensionLabels.HORIZONTAL_X] - order = ImageDimensionLabels.get_order_for_engine("cil", ig) - self.assertEqual(order, order_gold) - - order = ImageDimensionLabels.get_order_for_engine("tigre", ig) - self.assertEqual(order, order_gold) - - order = ImageDimensionLabels.get_order_for_engine("astra", ig) - self.assertEqual(order, order_gold) + order_gold = [ImageDimensionLabels.CHANNEL, 'HORIZONTAL_Y', 'horizontal_x'] + self.assertSequenceEqual(ImageDimensionLabels.get_order_for_engine('cil', ig), order_gold) + self.assertSequenceEqual(ImageDimensionLabels.get_order_for_engine('tigre', ig), order_gold) + self.assertSequenceEqual(ImageDimensionLabels.get_order_for_engine('astra', ig), order_gold) def test_image_dimension_labels_check_order(self): ig = ImageGeometry(4, 8, 1, channels=2) ig.set_labels(['horizontal_x', 'horizontal_y', 'channel']) - with self.assertRaises(ValueError): - ImageDimensionLabels.check_order_for_engine("cil", ig) - - with self.assertRaises(ValueError): - ImageDimensionLabels.check_order_for_engine("tigre", ig) - - with self.assertRaises(ValueError): - ImageDimensionLabels.check_order_for_engine("astra", ig) + for i in ('cil', 'tigre', 'astra'): + with self.assertRaises(ValueError): + ImageDimensionLabels.check_order_for_engine(i, ig) ig.set_labels(['channel', 'horizontal_y', 'horizontal_x']) - self.assertTrue( ImageDimensionLabels.check_order_for_engine("cil", ig)) - self.assertTrue( ImageDimensionLabels.check_order_for_engine("tigre", ig)) - self.assertTrue( ImageDimensionLabels.check_order_for_engine("astra", ig)) + self.assertTrue(ImageDimensionLabels.check_order_for_engine("cil", ig)) + self.assertTrue(ImageDimensionLabels.check_order_for_engine("tigre", ig)) + self.assertTrue(ImageDimensionLabels.check_order_for_engine("astra", ig)) def test_acquisition_dimension_labels(self): - self.assertTrue('channel' in AcquisitionDimensionLabels) - self.assertTrue('angle' in AcquisitionDimensionLabels) - self.assertTrue('vertical' in AcquisitionDimensionLabels) - self.assertTrue('horizontal' in AcquisitionDimensionLabels) - self.assertTrue(AcquisitionDimensionLabels.CHANNEL in AcquisitionDimensionLabels) - self.assertTrue(AcquisitionDimensionLabels.ANGLE in AcquisitionDimensionLabels) - self.assertTrue(AcquisitionDimensionLabels.VERTICAL in AcquisitionDimensionLabels) - self.assertTrue(AcquisitionDimensionLabels.HORIZONTAL in AcquisitionDimensionLabels) + for i in ('CHANNEL', 'ANGLE', 'VERTICAL', 'HORIZONTAL'): + self.assertIn(i, AcquisitionDimensionLabels) + self.assertIn(i.lower(), AcquisitionDimensionLabels) + self.assertIn(getattr(AcquisitionDimensionLabels, i), AcquisitionDimensionLabels) def test_acquisition_dimension_labels_default_order(self): - order = AcquisitionDimensionLabels.get_order_for_engine("cil") - self.assertEqual(order, [AcquisitionDimensionLabels.CHANNEL, AcquisitionDimensionLabels.ANGLE, AcquisitionDimensionLabels.VERTICAL, AcquisitionDimensionLabels.HORIZONTAL]) + self.assertEqual(AcquisitionDimensionLabels.get_order_for_engine('CIL'), [AcquisitionDimensionLabels.CHANNEL, 'ANGLE', 'vertical', 'HORIZONTAL']) + self.assertEqual(AcquisitionDimensionLabels.get_order_for_engine(Backends.TIGRE), ['CHANNEL', 'ANGLE', 'VERTICAL', 'HORIZONTAL']) + self.assertEqual(AcquisitionDimensionLabels.get_order_for_engine('astra'), ['CHANNEL', 'VERTICAL', 'ANGLE', 'HORIZONTAL']) - order = AcquisitionDimensionLabels.get_order_for_engine("tigre") - self.assertEqual(order, [AcquisitionDimensionLabels.CHANNEL, AcquisitionDimensionLabels.ANGLE, AcquisitionDimensionLabels.VERTICAL, AcquisitionDimensionLabels.HORIZONTAL]) - - order = AcquisitionDimensionLabels.get_order_for_engine("astra") - self.assertEqual(order, [AcquisitionDimensionLabels.CHANNEL, AcquisitionDimensionLabels.VERTICAL, AcquisitionDimensionLabels.ANGLE, AcquisitionDimensionLabels.HORIZONTAL]) - - with self.assertRaises(ValueError): - order = AcquisitionDimensionLabels.get_order_for_engine("bad_engine") + with self.assertRaises((KeyError, ValueError)): + AcquisitionDimensionLabels.get_order_for_engine("bad_engine") def test_acquisition_dimension_labels_get_order(self): - ag = AcquisitionGeometry.create_Parallel2D()\ .set_angles(np.arange(0,16 , 1), angle_unit="degree")\ .set_panel(4)\ .set_channels(8)\ .set_labels(['angle', 'horizontal', 'channel']) - - # for 2D all engines have the same order - order_gold = [AcquisitionDimensionLabels.CHANNEL, AcquisitionDimensionLabels.ANGLE, AcquisitionDimensionLabels.HORIZONTAL] - order = AcquisitionDimensionLabels.get_order_for_engine("cil", ag) - self.assertEqual(order, order_gold) - - order = AcquisitionDimensionLabels.get_order_for_engine("tigre", ag) - self.assertEqual(order, order_gold) - - order = AcquisitionDimensionLabels.get_order_for_engine("astra", ag) - self.assertEqual(order, order_gold) + # for 2D all engines have the same order + order_gold = [AcquisitionDimensionLabels.CHANNEL, 'ANGLE', 'horizontal'] + self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine('CIL', ag), order_gold) + self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine('TIGRE', ag), order_gold) + self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine('ASTRA', ag), order_gold) ag = AcquisitionGeometry.create_Parallel3D()\ .set_angles(np.arange(0,16 , 1), angle_unit="degree")\ .set_panel((4,2))\ .set_labels(['angle', 'horizontal', 'vertical']) - - - order_gold = [AcquisitionDimensionLabels.ANGLE, AcquisitionDimensionLabels.VERTICAL, AcquisitionDimensionLabels.HORIZONTAL] - order = AcquisitionDimensionLabels.get_order_for_engine("cil", ag) - self.assertEqual(order, order_gold) - - order = AcquisitionDimensionLabels.get_order_for_engine("tigre", ag) - self.assertEqual(order, order_gold) - - order_gold = [AcquisitionDimensionLabels.VERTICAL, AcquisitionDimensionLabels.ANGLE, AcquisitionDimensionLabels.HORIZONTAL] - order = AcquisitionDimensionLabels.get_order_for_engine("astra", ag) - self.assertEqual(order, order_gold) + order_gold = [AcquisitionDimensionLabels.ANGLE, 'VERTICAL', 'horizontal'] + self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine("cil", ag), order_gold) + self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine("tigre", ag), order_gold) + order_gold = [AcquisitionDimensionLabels.VERTICAL, 'ANGLE', 'horizontal'] + self.assertSequenceEqual(AcquisitionDimensionLabels.get_order_for_engine("astra", ag), order_gold) def test_acquisition_dimension_labels_check_order(self): - ag = AcquisitionGeometry.create_Parallel3D()\ .set_angles(np.arange(0,16 , 1), angle_unit="degree")\ .set_panel((8,4))\ .set_channels(2)\ .set_labels(['angle', 'horizontal', 'channel', 'vertical']) - - with self.assertRaises(ValueError): - AcquisitionDimensionLabels.check_order_for_engine("cil", ag) - with self.assertRaises(ValueError): - AcquisitionDimensionLabels.check_order_for_engine("tigre", ag) - - with self.assertRaises(ValueError): - AcquisitionDimensionLabels.check_order_for_engine("astra", ag) + for i in ('cil', 'tigre', 'astra'): + with self.assertRaises(ValueError): + AcquisitionDimensionLabels.check_order_for_engine(i, ag) ag.set_labels(['channel', 'angle', 'vertical', 'horizontal']) - self.assertTrue( AcquisitionDimensionLabels.check_order_for_engine("cil", ag)) - self.assertTrue( AcquisitionDimensionLabels.check_order_for_engine("tigre", ag)) + self.assertTrue(AcquisitionDimensionLabels.check_order_for_engine("cil", ag)) + self.assertTrue(AcquisitionDimensionLabels.check_order_for_engine("tigre", ag)) ag.set_labels(['channel', 'vertical', 'angle', 'horizontal']) - self.assertTrue( AcquisitionDimensionLabels.check_order_for_engine("astra", ag)) + self.assertTrue(AcquisitionDimensionLabels.check_order_for_engine("astra", ag))