Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Aug 19, 2024
1 parent 55dc29b commit 714434b
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 221 deletions.
20 changes: 7 additions & 13 deletions Wrappers/Python/cil/framework/acquisition_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def dimension(self,val):

@property
def geometry(self):
return self._geometry.value
return self._geometry

@geometry.setter
def geometry(self,val):
Expand Down Expand Up @@ -1720,15 +1720,11 @@ def dimension(self):
@property
def shape(self):

shape_dict = {AcquisitionDimensionLabels.CHANNEL.value: self.config.channels.num_channels,
AcquisitionDimensionLabels.ANGLE.value: self.config.angles.num_positions,
AcquisitionDimensionLabels.VERTICAL.value: self.config.panel.num_pixels[1],
AcquisitionDimensionLabels.HORIZONTAL.value: self.config.panel.num_pixels[0]}
shape = []
for label in self.dimension_labels:
shape.append(shape_dict[label])

return tuple(shape)
shape_dict = {AcquisitionDimensionLabels.CHANNEL: self.config.channels.num_channels,
AcquisitionDimensionLabels.ANGLE: self.config.angles.num_positions,
AcquisitionDimensionLabels.VERTICAL: self.config.panel.num_pixels[1],
AcquisitionDimensionLabels.HORIZONTAL: self.config.panel.num_pixels[0]}
return tuple(shape_dict[label] for label in self.dimension_labels)

@property
def dimension_labels(self):
Expand Down Expand Up @@ -1758,10 +1754,8 @@ def dimension_labels(self):

@dimension_labels.setter
def dimension_labels(self, val):

if val is not None:
label_new=[AcquisitionDimensionLabels(x).value for x in val if x in AcquisitionDimensionLabels]
self._dimension_labels = tuple(label_new)
self._dimension_labels = tuple(AcquisitionDimensionLabels(x) for x in val if x in AcquisitionDimensionLabels)

@property
def ndim(self):
Expand Down
36 changes: 12 additions & 24 deletions Wrappers/Python/cil/framework/image_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,24 @@ def VERTICAL(self):
return ImageDimensionLabels.VERTICAL

@property
def shape(self):
shape_dict = {ImageDimensionLabels.CHANNEL.value: self.channels,
ImageDimensionLabels.VERTICAL.value: self.voxel_num_z,
ImageDimensionLabels.HORIZONTAL_Y.value: self.voxel_num_y,
ImageDimensionLabels.HORIZONTAL_X.value: self.voxel_num_x}

shape = []
for label in self.dimension_labels:
shape.append(shape_dict[label])

return tuple(shape)
def shape(self):
shape_dict = {ImageDimensionLabels.CHANNEL: self.channels,
ImageDimensionLabels.VERTICAL: self.voxel_num_z,
ImageDimensionLabels.HORIZONTAL_Y: self.voxel_num_y,
ImageDimensionLabels.HORIZONTAL_X: self.voxel_num_x}
return tuple(shape_dict[label] for label in self.dimension_labels)

@shape.setter
def shape(self, val):
print("Deprecated - shape will be set automatically")

@property
def spacing(self):

spacing_dict = {ImageDimensionLabels.CHANNEL.value: self.channel_spacing,
ImageDimensionLabels.VERTICAL.value: self.voxel_size_z,
ImageDimensionLabels.HORIZONTAL_Y.value: self.voxel_size_y,
ImageDimensionLabels.HORIZONTAL_X.value: self.voxel_size_x}

spacing = []
for label in self.dimension_labels:
spacing.append(spacing_dict[label])

return tuple(spacing)
spacing_dict = {ImageDimensionLabels.CHANNEL: self.channel_spacing,
ImageDimensionLabels.VERTICAL: self.voxel_size_z,
ImageDimensionLabels.HORIZONTAL_Y: self.voxel_size_y,
ImageDimensionLabels.HORIZONTAL_X: self.voxel_size_x}
return tuple(spacing_dict[label] for label in self.dimension_labels)

@property
def length(self):
Expand Down Expand Up @@ -123,8 +112,7 @@ def dimension_labels(self, val):

def set_labels(self, labels):
if labels is not None:
label_new=[ImageDimensionLabels(x).value for x in labels if x in ImageDimensionLabels]
self._dimension_labels = tuple(label_new)
self._dimension_labels = tuple(ImageDimensionLabels(x) for x in labels if x in ImageDimensionLabels)

def __eq__(self, other):

Expand Down
13 changes: 7 additions & 6 deletions Wrappers/Python/cil/framework/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class _StrEnumMeta(EnumType):
def __contains__(self, item: str) -> bool:
try:
key = item.upper()
except AttributeError:
except (AttributeError, TypeError):
return False
return key in self.__members__ or item in self.__members__.values()

Expand All @@ -42,9 +42,10 @@ def _missing_(cls, value: str):
def __eq__(self, value: str) -> bool:
"""Uses value.upper() for case-insensitivity"""
try:
return super().__eq__(self.__class__[value.upper()])
except (KeyError, ValueError):
return False
value = self.__class__[value.upper()]
except (KeyError, ValueError, AttributeError):
pass
return super().__eq__(value)

def __hash__(self) -> int:
"""consistent hashing for dictionary keys"""
Expand Down Expand Up @@ -104,7 +105,7 @@ def get_order_for_engine(cls, engine: str, geometry=None) -> list:
"""
order = [cls.CHANNEL, cls.VERTICAL, cls.HORIZONTAL_Y, cls.HORIZONTAL_X]
engine_orders = {Backends.ASTRA: order, Backends.TIGRE: order, Backends.CIL: order}
dim_order = engine_orders[Backends[engine.upper()]]
dim_order = engine_orders[Backends(engine)]

if geometry is None:
return dim_order
Expand Down Expand Up @@ -164,7 +165,7 @@ def get_order_for_engine(cls, engine: str, geometry=None) -> list:
Backends.ASTRA: [cls.CHANNEL, cls.VERTICAL, cls.ANGLE, cls.HORIZONTAL],
Backends.TIGRE: [cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL],
Backends.CIL: [cls.CHANNEL, cls.ANGLE, cls.VERTICAL, cls.HORIZONTAL]}
dim_order = engine_orders[Backends[engine.upper()]]
dim_order = engine_orders[Backends(engine)]

if geometry is None:
return dim_order
Expand Down
12 changes: 4 additions & 8 deletions Wrappers/Python/cil/io/NEXUSDataWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,23 +145,19 @@ def write(self):
ds_data.write_direct(self.data.array)

# set up dataset attributes
if (isinstance(self.data, ImageData)):
ds_data.attrs['data_type'] = 'ImageData'
else:
ds_data.attrs['data_type'] = 'AcquisitionData'
ds_data.attrs['data_type'] = 'ImageData' if isinstance(self.data, ImageData) else 'AcquisitionData'

for i in range(self.data.number_of_dimensions):
ds_data.attrs['dim{}'.format(i)] = self.data.dimension_labels[i]

if (isinstance(self.data, AcquisitionData)):
ds_data.attrs[f'dim{i}'] = str(self.data.dimension_labels[i])

if isinstance(self.data, AcquisitionData):
# create group to store configuration
f.create_group('entry1/tomo_entry/config')
f.create_group('entry1/tomo_entry/config/source')
f.create_group('entry1/tomo_entry/config/detector')
f.create_group('entry1/tomo_entry/config/rotation_axis')

ds_data.attrs['geometry'] = self.data.geometry.config.system.geometry
ds_data.attrs['geometry'] = str(self.data.geometry.config.system.geometry)
ds_data.attrs['dimension'] = self.data.geometry.config.system.dimension
ds_data.attrs['num_channels'] = self.data.geometry.config.channels.num_channels

Expand Down
14 changes: 6 additions & 8 deletions Wrappers/Python/cil/io/ZEISSDataReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,16 @@ def set_up(self,

if roi is not None:
if metadata['data geometry'] == 'acquisition':
allowed_labels = [item.value for item in AcquisitionDimensionLabels]
zeiss_data_order = {'angle':0, 'vertical':1, 'horizontal':2}
zeiss_data_order = {AcquisitionDimensionLabels.ANGLE: 0,
AcquisitionDimensionLabels.VERTICAL: 1,
AcquisitionDimensionLabels.HORIZONTAL: 2}
else:
allowed_labels = [item.value for item in ImageDimensionLabels]
zeiss_data_order = {'vertical':0, 'horizontal_y':1, 'horizontal_x':2}
zeiss_data_order = {ImageDimensionLabels.VERTICAL: 0,
ImageDimensionLabels.HORIZONTAL_Y: 1,
ImageDimensionLabels.HORIZONTAL_X: 2}

# check roi labels and create tuple for slicing
for key in roi.keys():
if key not in allowed_labels:
raise Exception("Wrong label, got {0}. Expected dimension labels in {1}, {2}, {3}".format(key,**allowed_labels))

idx = zeiss_data_order[key]
if roi[key] != -1:
for i, x in enumerate(roi[key]):
Expand Down Expand Up @@ -289,4 +288,3 @@ def get_geometry(self):
def get_metadata(self):
'''return the metadata of the file'''
return self._metadata

Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def convert_geometry_to_astra_vec_2D(volume_geometry, sinogram_geometry_in):
panel = sinogram_geometry.config.panel

#get units
degrees = angles.angle_unit == UnitsAngles.DEGREE.value
degrees = angles.angle_unit == UnitsAngles.DEGREE

#create a 2D astra geom from 2D CIL geometry, 2D astra geometry has axis flipped compared to 3D
volume_geometry_temp = volume_geometry.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def convert_geometry_to_astra_vec_3D(volume_geometry, sinogram_geometry_in):
panel = sinogram_geometry.config.panel

#get units
degrees = angles.angle_unit == UnitsAngles.DEGREE.value
degrees = angles.angle_unit == UnitsAngles.DEGREE

if sinogram_geometry.dimension == '2D':
#create a 3D astra geom from 2D CIL geometry
Expand Down
12 changes: 6 additions & 6 deletions Wrappers/Python/test/test_DataContainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,30 +865,30 @@ def error_message(function_name, test_name):
expected = expected_func(data.as_array(), axis=1)
expected_dimension_labels = data.dimension_labels[0],data.dimension_labels[2]
numpy.testing.assert_almost_equal(result.as_array(), expected, err_msg=error_message(function_name, "'with 1 axis'"))
numpy.testing.assert_equal(result.dimension_labels, expected_dimension_labels, err_msg=error_message(function_name, "'with 1 axis'"))
self.assertEqual(result.dimension_labels, expected_dimension_labels, f"{function_name} 'with 1 axis'")
# test specifying axis with an int
result = test_func(axis=1)
numpy.testing.assert_almost_equal(result.as_array(), expected, err_msg=error_message(function_name, "'with 1 axis'"))
numpy.testing.assert_equal(result.dimension_labels,expected_dimension_labels, err_msg=error_message(function_name, "'with 1 axis'"))
self.assertEqual(result.dimension_labels,expected_dimension_labels, f"{function_name} 'with 1 axis'")
# test specifying function in 2 axes
result = test_func(axis=(data.dimension_labels[0],data.dimension_labels[1]))
numpy.testing.assert_almost_equal(result.as_array(), expected_func(data.as_array(), axis=(0,1)), err_msg=error_message(function_name, "'with 2 axes'"))
numpy.testing.assert_equal(result.dimension_labels,(data.dimension_labels[2],), err_msg=error_message(function_name, "'with 2 axes'"))
self.assertEqual(result.dimension_labels, (data.dimension_labels[2],), f"{function_name} 'with 2 axes'")
# test specifying function in 2 axes with an int
result = test_func(axis=(0,1))
numpy.testing.assert_almost_equal(result.as_array(), expected_func(data.as_array(), axis=(0,1)), err_msg=error_message(function_name, "'with 2 axes'"))
numpy.testing.assert_equal(result.dimension_labels,(data.dimension_labels[2],), err_msg=error_message(function_name, "'with 2 axes'"))
self.assertEqual(result.dimension_labels, (data.dimension_labels[2],), f"{function_name} 'with 2 axes'")
# test specifying function in 3 axes
result = test_func(axis=(data.dimension_labels[0],data.dimension_labels[1],data.dimension_labels[2]))
numpy.testing.assert_almost_equal(result, expected_func(data.as_array()), err_msg=error_message(function_name, "'with 3 axes'"))
# test providing a DataContainer to out
expected_array = expected_func(data.as_array(), axis = 0)
test_func(axis=0, out=out)
numpy.testing.assert_almost_equal(out.as_array(), expected_array, err_msg=error_message(function_name, "'of out argument'"))
numpy.testing.assert_equal(out.dimension_labels, (data.dimension_labels[1],data.dimension_labels[2]), err_msg=error_message(function_name, "'of out argument'"))
self.assertEqual(out.dimension_labels, (data.dimension_labels[1],data.dimension_labels[2]), f"{function_name} 'of out argument'")
test_func(axis=data.dimension_labels[0], out=out)
numpy.testing.assert_almost_equal(out.as_array(), expected_array, err_msg=error_message(function_name, "'of out argument'"))
numpy.testing.assert_equal(out.dimension_labels, (data.dimension_labels[1],data.dimension_labels[2]), err_msg=error_message(function_name, "'of out argument'"))
self.assertEqual(out.dimension_labels, (data.dimension_labels[1],data.dimension_labels[2]), f"{function_name} 'of out argument'")
# test providing a numpy array to out
out = numpy.zeros((2,2), dtype=data.dtype)
test_func(axis=0, out=out)
Expand Down
Loading

0 comments on commit 714434b

Please sign in to comment.