Skip to content

Commit

Permalink
combine Acquisition{Types,Dimensions}(StrEnum) => AcquisitionType(Flag)
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Aug 19, 2024
1 parent 73cbdfa commit 96ec5ab
Show file tree
Hide file tree
Showing 17 changed files with 116 additions and 138 deletions.
2 changes: 1 addition & 1 deletion Wrappers/Python/cil/framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
from .processors import DataProcessor, Processor, AX, PixelByPixelDataProcessor, CastDataContainer
from .block import BlockDataContainer, BlockGeometry
from .partitioner import Partitioner
from .labels import AcquisitionDimensionLabels, ImageDimensionLabels, FillTypes, UnitsAngles, AcquisitionTypes, AcquisitionDimensions
from .labels import AcquisitionDimensionLabels, ImageDimensionLabels, FillTypes, UnitsAngles, AcquisitionType
33 changes: 15 additions & 18 deletions Wrappers/Python/cil/framework/acquisition_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import numpy

from .labels import AcquisitionDimensionLabels, UnitsAngles, AcquisitionTypes, FillTypes, AcquisitionDimensions
from .labels import AcquisitionDimensionLabels, UnitsAngles, AcquisitionType, FillTypes
from .acquisition_data import AcquisitionData
from .image_geometry import ImageGeometry

Expand Down Expand Up @@ -186,10 +186,7 @@ class SystemConfiguration(object):

@property
def dimension(self):
if self._dimension == 2:
return AcquisitionDimensions.DIM2.value
else:
return AcquisitionDimensions.DIM3.value
return AcquisitionType.DIM2 if self._dimension == 2 else AcquisitionType.DIM3

@dimension.setter
def dimension(self,val):
Expand All @@ -203,8 +200,8 @@ def geometry(self):
return self._geometry

@geometry.setter
def geometry(self,val):
self._geometry = AcquisitionTypes(val)
def geometry(self, val):
self._geometry = AcquisitionType(val)

def __init__(self, dof, geometry, units='units'):
"""Initialises the system component attributes for the acquisition type
Expand All @@ -213,7 +210,7 @@ def __init__(self, dof, geometry, units='units'):
self.geometry = geometry
self.units = units

if self.geometry == AcquisitionTypes.PARALLEL:
if AcquisitionType.PARALLEL & self.geometry:
self.ray = DirectionVector(dof)
else:
self.source = PositionVector(dof)
Expand Down Expand Up @@ -344,7 +341,7 @@ class Parallel2D(SystemConfiguration):
def __init__ (self, ray_direction, detector_pos, detector_direction_x, rotation_axis_pos, units='units'):
"""Constructor method
"""
super(Parallel2D, self).__init__(dof=2, geometry=AcquisitionTypes.PARALLEL, units=units)
super(Parallel2D, self).__init__(dof=2, geometry=AcquisitionType.PARALLEL, units=units)

#source
self.ray.direction = ray_direction
Expand Down Expand Up @@ -518,7 +515,7 @@ class Parallel3D(SystemConfiguration):
def __init__ (self, ray_direction, detector_pos, detector_direction_x, detector_direction_y, rotation_axis_pos, rotation_axis_direction, units='units'):
"""Constructor method
"""
super(Parallel3D, self).__init__(dof=3, geometry=AcquisitionTypes.PARALLEL, units=units)
super(Parallel3D, self).__init__(dof=3, geometry=AcquisitionType.PARALLEL, units=units)

#source
self.ray.direction = ray_direction
Expand Down Expand Up @@ -803,7 +800,7 @@ class Cone2D(SystemConfiguration):
def __init__ (self, source_pos, detector_pos, detector_direction_x, rotation_axis_pos, units='units'):
"""Constructor method
"""
super(Cone2D, self).__init__(dof=2, geometry=AcquisitionTypes.CONE, units=units)
super(Cone2D, self).__init__(dof=2, geometry=AcquisitionType.CONE, units=units)

#source
self.source.position = source_pos
Expand Down Expand Up @@ -982,7 +979,7 @@ class Cone3D(SystemConfiguration):
def __init__ (self, source_pos, detector_pos, detector_direction_x, detector_direction_y, rotation_axis_pos, rotation_axis_direction, units='units'):
"""Constructor method
"""
super(Cone3D, self).__init__(dof=3, geometry=AcquisitionTypes.CONE, units=units)
super(Cone3D, self).__init__(dof=3, geometry=AcquisitionType.CONE, units=units)

#source
self.source.position = source_pos
Expand Down Expand Up @@ -1819,7 +1816,7 @@ def get_centre_of_rotation(self, distance_units='default', angle_units='radian')
offset = offset_distance/ self.config.panel.pixel_size[0]
offset_units = 'pixels'

if self.dimension == '3D' and self.config.panel.pixel_size[0] != self.config.panel.pixel_size[1]:
if AcquisitionType.DIM3 & self.dimension and self.config.panel.pixel_size[0] != self.config.panel.pixel_size[1]:
#if aspect ratio of pixels isn't 1:1 need to convert angle by new ratio
y_pix = 1 /self.config.panel.pixel_size[1]
x_pix = math.tan(angle_rad)/self.config.panel.pixel_size[0]
Expand Down Expand Up @@ -1884,7 +1881,7 @@ def set_centre_of_rotation(self, offset=0.0, distance_units='default', angle=0.0
else:
raise ValueError("`distance_units` is not recognised. Must be 'default' or 'pixels'. Got {}".format(distance_units))

if self.dimension == '2D':
if AcquisitionType.DIM2 & self.dimension:
self.config.system.set_centre_of_rotation(offset_distance)
else:
self.config.system.set_centre_of_rotation(offset_distance, angle_rad)
Expand Down Expand Up @@ -1921,7 +1918,7 @@ def set_centre_of_rotation_by_slice(self, offset1, slice_index1=None, offset2=No
if not hasattr(self.config.system, 'set_centre_of_rotation'):
raise NotImplementedError()

if self.dimension == '2D':
if AcquisitionType.DIM2 & self.dimension:
if offset2 is not None:
warnings.warn("2D so offset2 is ingored", UserWarning, stacklevel=2)
self.set_centre_of_rotation(offset1)
Expand Down Expand Up @@ -2118,7 +2115,7 @@ def copy(self):
def get_centre_slice(self):
'''returns a 2D AcquisitionGeometry that corresponds to the centre slice of the input'''

if self.dimension == '2D':
if AcquisitionType.DIM2 & self.dimension:
return self

AG_2D = copy.deepcopy(self)
Expand All @@ -2133,7 +2130,7 @@ def get_ImageGeometry(self, resolution=1.0):
num_voxel_xy = int(numpy.ceil(self.config.panel.num_pixels[0] * resolution))
voxel_size_xy = self.config.panel.pixel_size[0] / (resolution * self.magnification)

if self.dimension == '3D':
if AcquisitionType.DIM3 & self.dimension:
num_voxel_z = int(numpy.ceil(self.config.panel.num_pixels[1] * resolution))
voxel_size_z = self.config.panel.pixel_size[1] / (resolution * self.magnification)
else:
Expand Down Expand Up @@ -2161,7 +2158,7 @@ def get_slice(self, channel=None, angle=None, vertical=None, horizontal=None):
geometry_new.config.angles.angle_data = geometry_new.config.angles.angle_data[angle]

if vertical is not None:
if geometry_new.geom_type == AcquisitionTypes.PARALLEL or vertical == 'centre' or abs(geometry_new.pixel_num_v/2 - vertical) < 1e-6:
if AcquisitionType.PARALLEL & geometry_new.geom_type or vertical == 'centre' or abs(geometry_new.pixel_num_v/2 - vertical) < 1e-6:
geometry_new = geometry_new.get_centre_slice()
else:
raise ValueError("Can only subset centre slice geometry on cone-beam data. Expected vertical = 'centre'. Got vertical = {0}".format(vertical))
Expand Down
51 changes: 34 additions & 17 deletions Wrappers/Python/cil/framework/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
from enum import Enum, auto, unique
from enum import Enum, Flag as _Flag, auto, unique
try:
from enum import EnumType
except ImportError: # Python<3.11
Expand All @@ -40,7 +40,6 @@ def _missing_(cls, value: str):
return cls.__members__.get(value.upper(), None)

def __eq__(self, value: str) -> bool:
"""Uses value.upper() for case-insensitivity"""
try:
value = self.__class__[value.upper()]
except (KeyError, ValueError, AttributeError):
Expand All @@ -60,7 +59,6 @@ def _generate_next_value_(name: str, start, count, last_values) -> str:
return name.lower()



class Backends(StrEnum):
"""
Available backends for CIL.
Expand Down Expand Up @@ -228,27 +226,46 @@ class UnitsAngles(StrEnum):
RADIAN = auto()


class AcquisitionTypes(StrEnum):
class _FlagMeta(EnumType):
"""Python<3.12 requires this in a metaclass (rather than directly in Flag)"""
def __contains__(self, item) -> bool:
return item.upper() in self.__members__ if isinstance(item, str) else super().__contains__(item)


@unique
class Flag(_Flag, metaclass=_FlagMeta):
"""Case-insensitive Flag"""
@classmethod
def _missing_(cls, value):
return cls.__members__.get(value.upper(), None) if isinstance(value, str) else super()._missing_(value)

def __eq__(self, value: str) -> bool:
return super().__eq__(self.__class__[value.upper()] if isinstance(value, str) else value)


class AcquisitionType(Flag):
"""
Available acquisition types.
Available acquisition types & dimensions.
Attributes
----------
PARALLEL: Parallel beam.
CONE: Cone beam.
DIM2: 2D acquisition.
DIM3: 3D acquisition.
"""
PARALLEL = auto()
CONE = auto()
DIM2 = auto()
DIM3 = auto()


class AcquisitionDimensions(StrEnum):
"""
Available acquisition dimensions.
Attributes
----------
DIM2 ('2D'): 2D acquisition.
DIM3 ('3D'): 3D acquisition.
"""
DIM2 = "2D"
DIM3 = "3D"
@classmethod
def _missing_(cls, value):
"""2D/3D aliases"""
if isinstance(value, str):
value = {'2D': 'DIM2', '3D': 'DIM3'}.get(value.upper(), value)
return super()._missing_(value)

def __str__(self) -> str:
"""2D/3D special handling"""
return '2D' if self == self.DIM2 else '3D' if self == self.DIM3 else self.name
7 changes: 4 additions & 3 deletions Wrappers/Python/cil/io/NEXUSDataWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import numpy as np
import os
from cil.framework import AcquisitionData, AcquisitionGeometry, ImageData, ImageGeometry
from cil.framework import AcquisitionData, ImageData
from cil.framework.labels import AcquisitionType
from cil.version import version
import datetime
from cil.io import utilities
Expand Down Expand Up @@ -158,7 +159,7 @@ def write(self):
f.create_group('entry1/tomo_entry/config/rotation_axis')

ds_data.attrs['geometry'] = str(self.data.geometry.config.system.geometry)
ds_data.attrs['dimension'] = self.data.geometry.config.system.dimension
ds_data.attrs['dimension'] = str(self.data.geometry.config.system.dimension)
ds_data.attrs['num_channels'] = self.data.geometry.config.channels.num_channels

f.create_dataset('entry1/tomo_entry/config/detector/direction_x',
Expand Down Expand Up @@ -192,7 +193,7 @@ def write(self):
ds_data.attrs['pixel_size_h'] = self.data.geometry.config.panel.pixel_size[0]
ds_data.attrs['panel_origin'] = self.data.geometry.config.panel.origin

if self.data.geometry.config.system.dimension == '3D':
if AcquisitionType.DIM3 & self.data.geometry.config.system.dimension:
f.create_dataset('entry1/tomo_entry/config/detector/direction_y',
(self.data.geometry.config.system.detector.direction_y.shape),
dtype = 'float32',
Expand Down
6 changes: 3 additions & 3 deletions Wrappers/Python/cil/io/NikonDataReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt

from cil.framework import AcquisitionData, AcquisitionGeometry
from cil.framework import AcquisitionGeometry
from cil.framework.labels import AcquisitionType
from cil.io.TIFF import TIFFStackReader
import warnings
import numpy as np
import os

Expand Down Expand Up @@ -334,7 +334,7 @@ def get_geometry(self):
def get_roi(self):
'''returns the roi'''
roi = self._roi_par[:]
if self._ag.dimension == '2D':
if AcquisitionType.DIM2 & self._ag.dimension:
roi.pop(1)

roidict = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import logging

from cil.framework import BlockGeometry, AcquisitionDimensionLabels, ImageDimensionLabels
from cil.framework import BlockGeometry, AcquisitionDimensionLabels, ImageDimensionLabels, AcquisitionType
from cil.optimisation.operators import BlockOperator, LinearOperator, ChannelwiseOperator
from cil.plugins.astra.operators import AstraProjector2D, AstraProjector3D

Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(self,
if device == 'gpu':
operator = AstraProjector3D(volume_geometry_sc,
sinogram_geometry_sc)
elif self.sinogram_geometry.dimension == '2D':
elif AcquisitionType.DIM2 & self.sinogram_geometry.dimension:
operator = AstraProjector2D(volume_geometry_sc,
sinogram_geometry_sc,
device=device)
Expand Down
6 changes: 2 additions & 4 deletions Wrappers/Python/cil/plugins/astra/processors/FBP.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
import warnings

from cil.framework import DataProcessor, ImageDimensionLabels, AcquisitionDimensionLabels
from cil.framework import DataProcessor, ImageDimensionLabels, AcquisitionDimensionLabels, AcquisitionType
from cil.plugins.astra.processors.FBP_Flexible import FBP_Flexible
from cil.plugins.astra.processors.FDK_Flexible import FDK_Flexible
from cil.plugins.astra.processors.FBP_Flexible import FBP_CPU
Expand Down Expand Up @@ -81,7 +79,7 @@ def __init__(self, image_geometry=None, acquisition_geometry=None, device='gpu')
if acquisition_geometry.geom_type == 'cone':
raise NotImplementedError("Cannot process cone-beam data without a GPU")

if acquisition_geometry.dimension == '2D':
if AcquisitionType.DIM2 & acquisition_geometry.dimension:
processor = FBP_CPU(image_geometry, acquisition_geometry)
else:
raise NotImplementedError("Cannot process 3D data without a GPU")
Expand Down
6 changes: 3 additions & 3 deletions Wrappers/Python/cil/plugins/astra/processors/FBP_Flexible.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt


from cil.framework import AcquisitionGeometry, Processor, ImageData
from cil.framework import AcquisitionGeometry, Processor, ImageData, AcquisitionType
from cil.plugins.astra.processors.FDK_Flexible import FDK_Flexible
from cil.plugins.astra.utilities import convert_geometry_to_astra_vec_3D, convert_geometry_to_astra
import logging
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(self, volume_geometry,
detector_position = sino_geom_cone.config.system.detector.position
detector_direction_x = sino_geom_cone.config.system.detector.direction_x

if sino_geom_cone.dimension == '2D':
if AcquisitionType.DIM2 & sino_geom_cone.dimension:
tmp = AcquisitionGeometry.create_Cone2D(cone_source, detector_position, detector_direction_x)
else:
detector_direction_y = sino_geom_cone.config.system.detector.direction_y
Expand Down Expand Up @@ -141,7 +141,7 @@ def check_input(self, dataset):
raise ValueError("Expected input data to be parallel beam geometry , got {0}"\
.format(self.sinogram_geometry.geom_type))

if self.sinogram_geometry.dimension != '2D':
if not AcquisitionType.DIM2 & self.sinogram_geometry.dimension:
raise ValueError("Expected input data to be 2D , got {0}"\
.format(self.sinogram_geometry.dimension))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import astra
import numpy as np
from cil.framework import UnitsAngles
from cil.framework import AcquisitionType, UnitsAngles

def convert_geometry_to_astra(volume_geometry, sinogram_geometry):
"""
Expand All @@ -39,13 +39,8 @@ def convert_geometry_to_astra(volume_geometry, sinogram_geometry):
The ASTRA vol_geom and proj_geom
"""

# determine if the geometry is 2D or 3D

if sinogram_geometry.pixel_num_v > 1:
dimension = '3D'
else:
dimension = '2D'
dimension = AcquisitionType.DIM3 if sinogram_geometry.pixel_num_v > 1 else AcquisitionType.DIM2

#get units

Expand All @@ -54,7 +49,7 @@ def convert_geometry_to_astra(volume_geometry, sinogram_geometry):
else:
angles_rad = sinogram_geometry.config.angles.angle_data

if dimension == '2D':
if AcquisitionType.DIM2 & dimension:
vol_geom = astra.create_vol_geom(volume_geometry.voxel_num_y,
volume_geometry.voxel_num_x,
volume_geometry.get_min_x(),
Expand All @@ -77,7 +72,7 @@ def convert_geometry_to_astra(volume_geometry, sinogram_geometry):
else:
NotImplemented

elif dimension == '3D':
elif AcquisitionType.DIM3 & dimension:
vol_geom = astra.create_vol_geom(volume_geometry.voxel_num_y,
volume_geometry.voxel_num_x,
volume_geometry.voxel_num_z,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import astra
import numpy as np
from cil.framework import UnitsAngles
from cil.framework import AcquisitionType, UnitsAngles

def convert_geometry_to_astra_vec_3D(volume_geometry, sinogram_geometry_in):

Expand Down Expand Up @@ -57,7 +57,7 @@ def convert_geometry_to_astra_vec_3D(volume_geometry, sinogram_geometry_in):
#get units
degrees = angles.angle_unit == UnitsAngles.DEGREE

if sinogram_geometry.dimension == '2D':
if AcquisitionType.DIM2 & sinogram_geometry.dimension:
#create a 3D astra geom from 2D CIL geometry
volume_geometry_temp = volume_geometry.copy()
volume_geometry_temp.voxel_num_z = 1
Expand Down
Loading

0 comments on commit 96ec5ab

Please sign in to comment.