diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 3e45d899ec..41bb4ae79a 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -976,6 +976,18 @@ Spatial :members: :special-members: __call__ +`ConvertBoxToPoints` +"""""""""""""""""""" +.. autoclass:: ConvertBoxToPoints + :members: + :special-members: __call__ + +`ConvertPointsToBoxes` +"""""""""""""""""""""" +.. autoclass:: ConvertPointsToBoxes + :members: + :special-members: __call__ + Smooth Field ^^^^^^^^^^^^ @@ -1222,6 +1234,12 @@ Utility :members: :special-members: __call__ +`ApplyTransformToPoints` +"""""""""""""""""""""""" +.. autoclass:: ApplyTransformToPoints + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -1973,6 +1991,18 @@ Spatial (Dict) :members: :special-members: __call__ +`ConvertBoxToPointsd` +""""""""""""""""""""" +.. autoclass:: ConvertBoxToPointsd + :members: + :special-members: __call__ + +`ConvertPointsToBoxesd` +""""""""""""""""""""""" +.. autoclass:: ConvertPointsToBoxesd + :members: + :special-members: __call__ + Smooth Field (Dict) ^^^^^^^^^^^^^^^^^^^ @@ -2277,6 +2307,12 @@ Utility (Dict) :members: :special-members: __call__ +`ApplyTransformToPointsd` +""""""""""""""""""""""""" +.. autoclass:: ApplyTransformToPointsd + :members: + :special-members: __call__ + MetaTensor ^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index f37016e63f..2cdd965c91 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -396,6 +396,8 @@ from .spatial.array import ( Affine, AffineGrid, + ConvertBoxToPoints, + ConvertPointsToBoxes, Flip, GridDistortion, GridPatch, @@ -427,6 +429,12 @@ Affined, AffineD, AffineDict, + ConvertBoxToPointsd, + ConvertBoxToPointsD, + ConvertBoxToPointsDict, + ConvertPointsToBoxesd, + ConvertPointsToBoxesD, + ConvertPointsToBoxesDict, Flipd, FlipD, FlipDict, @@ -503,6 +511,7 @@ from .utility.array import ( AddCoordinateChannels, AddExtremePointsChannel, + ApplyTransformToPoints, AsChannelLast, CastToType, ClassesToIndices, @@ -542,6 +551,9 @@ AddExtremePointsChanneld, AddExtremePointsChannelD, AddExtremePointsChannelDict, + ApplyTransformToPointsd, + ApplyTransformToPointsD, + ApplyTransformToPointsDict, AsChannelLastd, AsChannelLastD, AsChannelLastDict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3739a83e71..6e39fb2e19 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -25,6 +25,7 @@ from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor +from monai.data.box_utils import BoxMode, StandardMode from monai.data.meta_obj import get_track_meta, set_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine @@ -34,6 +35,8 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.functional import ( affine_func, + convert_box_to_points, + convert_points_to_box, flip, orientation, resize, @@ -3544,3 +3547,44 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: else: return img + + +class ConvertBoxToPoints(Transform): + """ + Converts an axis-aligned bounding box to points. It can automatically convert the boxes to the points based on the box mode. + Bounding boxes of the shape (N, C) for N boxes. 
C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D for each box.
+    The returned points have shape (N, 4, 2) for 2D or (N, 8, 3) for 3D.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
+        """
+        Args:
+            mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode.
+        """
+        super().__init__()
+        self.mode = StandardMode if mode is None else mode
+
+    def __call__(self, data: Any):
+        data = convert_to_tensor(data, track_meta=get_track_meta())
+        points = convert_box_to_points(data, mode=self.mode)
+        return convert_to_dst_type(points, data)[0]
+
+
+class ConvertPointsToBoxes(Transform):
+    """
+    Converts points to an axis-aligned bounding box.
+    The input points represent the corners of a bounding box, with shape (N, 8, 3) for the 8 corners of a
+    3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __call__(self, data: Any):
+        data = convert_to_tensor(data, track_meta=get_track_meta())
+        box = convert_points_to_box(data)
+        return convert_to_dst_type(box, data)[0]
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py
index 01fadcfb69..82dee15c7c 100644
--- a/monai/transforms/spatial/dictionary.py
+++ b/monai/transforms/spatial/dictionary.py
@@ -26,6 +26,7 @@
 
 from monai.config import DtypeLike, KeysCollection, SequenceStr
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.networks.layers.simplelayers import GaussianFilter
@@ -33,6 +34,8 @@
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.array import (
     Affine,
+    ConvertBoxToPoints,
+    ConvertPointsToBoxes,
     Flip,
     GridDistortion,
     GridPatch,
@@ -2611,6 +2614,61 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
         return d
 
 
+class ConvertBoxToPointsd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.
+    """
+
+    backend = ConvertBoxToPoints.backend
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        point_key="points",
+        mode: str | BoxMode | type[BoxMode] | None = StandardMode,
+        allow_missing_keys: bool = False,
+    ):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            point_key: key to store the point data.
+            mode: the mode of the input boxes. Defaults to StandardMode.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.point_key = point_key
+        self.converter = ConvertBoxToPoints(mode=mode)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[self.point_key] = self.converter(d[key])
+        return d
+
+
+class ConvertPointsToBoxesd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.
+    """
+
+    def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            box_key: key to store the box data.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.box_key = box_key
+        self.converter = ConvertPointsToBoxes()
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[self.box_key] = self.converter(d[key])
+        return d
+
+
 SpatialResampleD = SpatialResampleDict = SpatialResampled
 ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
 SpacingD = SpacingDict = Spacingd
@@ -2635,3 +2693,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
 GridPatchD = GridPatchDict = GridPatchd
 RandGridPatchD = RandGridPatchDict = RandGridPatchd
 RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
+ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
+ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd
diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py
index 22726f06a5..b693e7d023 100644
--- a/monai/transforms/spatial/functional.py
+++ b/monai/transforms/spatial/functional.py
@@ -24,6 +24,7 @@
 import monai
 from monai.config import USE_COMPILED
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import get_boxmode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
@@ -32,7 +33,7 @@
 from monai.transforms.intensity.array import GaussianSmooth
 from monai.transforms.inverse import TraceableTransform
 from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
-from monai.transforms.utils_pytorch_numpy_unification import allclose
+from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
 from monai.utils import (
     LazyAttr,
     TraceKeys,
@@ -610,3 +611,71 @@ def affine_func(
     out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
     out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
     return out if image_only else (out, affine)
+
+
+def convert_box_to_points(bbox, mode):
+    """
+    Converts an axis-aligned bounding box to points.
+
+    Args:
+        bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2]
+            for 3D for each box.
+        mode: The mode specifying how to interpret the bounding box.
+
+    Returns:
+        A tensor of corner points with shape (N, 4, 2) for 2D or (N, 8, 3) for 3D.
+ """ + + mode = get_boxmode(mode) + + points_list = [] + for _num in range(bbox.shape[0]): + corners = mode.boxes_to_corners(bbox[_num : _num + 1]) + if len(corners) == 4: + points_list.append( + concatenate( + [ + concatenate([corners[0], corners[1]], axis=1), + concatenate([corners[2], corners[1]], axis=1), + concatenate([corners[2], corners[3]], axis=1), + concatenate([corners[0], corners[3]], axis=1), + ], + axis=0, + ) + ) + else: + points_list.append( + concatenate( + [ + concatenate([corners[0], corners[1], corners[2]], axis=1), + concatenate([corners[3], corners[1], corners[2]], axis=1), + concatenate([corners[3], corners[4], corners[2]], axis=1), + concatenate([corners[0], corners[4], corners[2]], axis=1), + concatenate([corners[0], corners[1], corners[5]], axis=1), + concatenate([corners[3], corners[1], corners[5]], axis=1), + concatenate([corners[3], corners[4], corners[5]], axis=1), + concatenate([corners[0], corners[4], corners[5]], axis=1), + ], + axis=0, + ) + ) + + return stack(points_list, dim=0) + + +def convert_points_to_box(points): + """ + Converts points to an axis-aligned bounding box. + + Args: + points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of + a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle. + """ + from monai.transforms.utils_pytorch_numpy_unification import max, min + + mins = min(points, dim=1) + maxs = max(points, dim=1) + # Concatenate the min and max values to get the bounding boxes + bboxes = concatenate([mins, maxs], axis=1) + + return bboxes diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 5dfbcb0e91..fee546bea3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,7 +31,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor -from monai.data.utils import is_no_channel, no_collation +from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps from monai.networks.layers.simplelayers import ( ApplyFilter, EllipticalFilter, @@ -42,16 +42,17 @@ SharpenFilter, median_filter, ) -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, TraceableTransform from monai.transforms.traits import MultiSampleTrait from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform from monai.transforms.utils import ( + apply_affine_to_points, extreme_points_to_image, get_extreme_points, map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices +from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices from monai.utils import ( MetaKeys, TraceKeys, @@ -66,7 +67,7 @@ ) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype +from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -106,6 +107,7 @@ "ToCupy", "ImageFilter", "RandImageFilter", + "ApplyTransformToPoints", ] @@ -1715,3 +1717,133 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None 
= None) -> Nd if self._do_transform: img = self.filter(img) return img + + +class ApplyTransformToPoints(InvertibleTransform, Transform): + """ + Transform points between image coordinates and world coordinates. + The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels + and N denotes the number of points. It will return a tensor with the same shape as the input. + + Args: + dtype: The desired data type for the output. + affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates + from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary + Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when + applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly. + The matrix is always converted to float64 for computation, which can be computationally + expensive when applied to a large number of points. + If None, will try to use the affine matrix from the input data. + invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``. + Typically, the affine matrix is derived from an image and represents its location in world space, + while the points are in world coordinates. A value of ``True`` represents transforming these + world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation. + affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system + or you're using `ITKReader` with `affine_lps_to_ras=True`. + This ensures the correct application of the affine transformation between LPS (left-posterior-superior) + and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine + matrix are in the same coordinate system. + + Use Cases: + - Transforming points between world space and image space, and vice versa. + - Automatically handling inverse transformations between image space and world space. + - If points have an existing affine transformation, the class computes and + applies the required delta affine transformation. + + """ + + def __init__( + self, + dtype: DtypeLike | torch.dtype | None = None, + affine: torch.Tensor | None = None, + invert_affine: bool = True, + affine_lps_to_ras: bool = False, + ) -> None: + self.dtype = dtype + self.affine = affine + self.invert_affine = invert_affine + self.affine_lps_to_ras = affine_lps_to_ras + + def transform_coordinates( + self, data: torch.Tensor, affine: torch.Tensor | None = None + ) -> tuple[torch.Tensor, dict]: + """ + Transform coordinates using an affine transformation matrix. + + Args: + data: The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation, + which can be computationally expensive when applied to a large number of points. + + Returns: + Transformed coordinates. 
+ """ + data = convert_to_tensor(data, track_meta=get_track_meta()) + # applied_affine is the affine transformation matrix that has already been applied to the point data + applied_affine = getattr(data, "affine", None) + + if affine is None and self.invert_affine: + raise ValueError("affine must be provided when invert_affine is True.") + + affine = applied_affine if affine is None else affine + affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine + original_affine: torch.Tensor = affine + if self.affine_lps_to_ras: + affine = orientation_ras_lps(affine) + + # the final affine transformation matrix that will be applied to the point data + _affine: torch.Tensor = affine + if self.invert_affine: + _affine = linalg_inv(affine) + if applied_affine is not None: + # consider the affine transformation already applied to the data in the world space + # and compute delta affine + _affine = _affine @ linalg_inv(applied_affine) + out = apply_affine_to_points(data, _affine, dtype=self.dtype) + + extra_info = { + "invert_affine": self.invert_affine, + "dtype": get_dtype_string(self.dtype), + "image_affine": original_affine, # record for inverse operation + "affine_lps_to_ras": self.affine_lps_to_ras, + } + xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine) + meta_info = TraceableTransform.track_transform_meta( + data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info() + ) + + return out, meta_info + + def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None): + """ + Args: + data: The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``. 
+ """ + if data.ndim != 3 or data.shape[-1] not in (2, 3): + raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.") + affine = self.affine if affine is None else affine + if affine is not None and affine.shape not in ((3, 3), (4, 4)): + raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.") + + out, meta_info = self.transform_coordinates(data, affine) + + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(data) + # Create inverse transform + dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] + invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"] + affine = transform[TraceKeys.EXTRA_INFO]["image_affine"] + affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"] + inverse_transform = ApplyTransformToPoints( + dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + # Apply inverse + with inverse_transform.trace_transform(False): + data = inverse_transform(data, affine) + + return data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 2475060f4e..1279ca93ab 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -35,6 +35,7 @@ from monai.transforms.utility.array import ( AddCoordinateChannels, AddExtremePointsChannel, + ApplyTransformToPoints, AsChannelLast, CastToType, ClassesToIndices, @@ -180,6 +181,9 @@ "ClassesToIndicesd", "ClassesToIndicesD", "ClassesToIndicesDict", + "ApplyTransformToPointsd", + "ApplyTransformToPointsD", + "ApplyTransformToPointsDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -1744,6 +1748,75 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ApplyTransformToPointsd(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`. + The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + The output has the same shape as the input. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + refer_key: The key of the reference item used for transformation. + It can directly refer to an affine or an image from which the affine can be derived. + dtype: The desired data type for the output. + affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates + from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary + Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when + applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly. + The matrix is always converted to float64 for computation, which can be computationally + expensive when applied to a large number of points. + If None, will try to use the affine matrix from the refer data. + invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``. + Typically, the affine matrix is derived from the image, while the points are in world coordinates. + If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``. + affine_lps_to_ras: Defaults to ``False``. 
Set to `True` if your point data is in the RAS coordinate system + or you're using `ITKReader` with `affine_lps_to_ras=True`. + This ensures the correct application of the affine transformation between LPS (left-posterior-superior) + and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine + matrix are in the same coordinate system. + allow_missing_keys: Don't raise exception if key is missing. + """ + + def __init__( + self, + keys: KeysCollection, + refer_key: str | None = None, + dtype: DtypeLike | torch.dtype = torch.float64, + affine: torch.Tensor | None = None, + invert_affine: bool = True, + affine_lps_to_ras: bool = False, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + self.refer_key = refer_key + self.converter = ApplyTransformToPoints( + dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]): + d = dict(data) + if self.refer_key is not None: + if self.refer_key in d: + refer_data = d[self.refer_key] + else: + raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.") + else: + refer_data = None + affine = getattr(refer_data, "affine", refer_data) + for key in self.key_iterator(d): + coords = d[key] + d[key] = self.converter(coords, affine) + return d + + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter.inverse(d[key]) + return d + + RandImageFilterD = RandImageFilterDict = RandImageFilterd ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd @@ -1784,3 +1857,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N RandCuCIMD = RandCuCIMDict = RandCuCIMd AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd +ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1d1f070568..b1f1bbd0f6 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -27,6 +27,7 @@ import monai from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.data.utils import to_affine_nd from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose @@ -35,6 +36,7 @@ from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, ascontiguousarray, + concatenate, cumsum, isfinite, nonzero, @@ -2555,5 +2557,26 @@ def distance_transform_edt( return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] +def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None): + """ + apply affine transformation to a set of points. + + Args: + data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4). + dtype: output data dtype. 
+ """ + data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64) + affine = to_affine_nd(data_.shape[-1], affine) + + homogeneous: torch.Tensor = concatenate((data_, torch.ones((data_.shape[0], data_.shape[1], 1))), axis=2) # type: ignore + transformed_homogeneous = torch.matmul(homogeneous, affine.T) + transformed_coordinates = transformed_homogeneous[:, :, :-1] + out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype) + + return out + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 03fa1ceed1..4e36e3cd47 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -148,6 +148,7 @@ dtype_numpy_to_torch, dtype_torch_to_numpy, get_dtype, + get_dtype_string, get_equivalent_dtype, get_numpy_dtype_from_string, get_torch_dtype_from_string, diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index e4f97fc4a6..420e935b33 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -33,6 +33,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "get_dtype_string", "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", @@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype: return type(data) +def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str: + """Get a string representation of the dtype.""" + if isinstance(dtype, torch.dtype): + return str(dtype)[6:] + return str(dtype)[3:] + + def convert_to_tensor( data: Any, dtype: DtypeLike | torch.dtype = None, diff --git a/tests/test_apply_transform_to_points.py b/tests/test_apply_transform_to_points.py new file mode 100644 index 0000000000..0c16603996 --- /dev/null +++ b/tests/test_apply_transform_to_points.py @@ -0,0 +1,81 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.utility.array import ApplyTransformToPoints +from monai.utils import set_determinism + +set_determinism(seed=0) + +DATA_2D = torch.rand(1, 64, 64) +DATA_3D = torch.rand(1, 64, 64, 64) +POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]]) +POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]]) +POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]]) +POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) +POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) +POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) +AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) +AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]) + +TEST_CASES = [ + [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], + [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, True, POINT_2D_IMAGE_RAS], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + [ + MetaTensor(DATA_3D, affine=AFFINE_2), + MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2), + None, + False, + False, + POINT_3D_WORLD, + ], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS], +] + +TEST_CASES_WRONG = [ + [POINT_2D_WORLD, True, None], + [POINT_2D_WORLD.unsqueeze(0), False, None], + [POINT_3D_WORLD[..., 0:1], False, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])], +] + + +class TestCoordinateTransform(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output): + transform = ApplyTransformToPoints( + dtype=torch.int64, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + affine = image.affine if image is not None else None + output = transform(points, affine) + self.assertTrue(torch.allclose(output, expected_output)) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out, points)) + + @parameterized.expand(TEST_CASES_WRONG) + def test_wrong_input(self, input, invert_affine, affine): + transform = ApplyTransformToPoints(dtype=torch.int64, invert_affine=invert_affine) + with self.assertRaises(ValueError): + transform(input, affine) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py new file mode 100644 index 0000000000..4cedfa9d66 --- /dev/null +++ b/tests/test_apply_transform_to_pointsd.py @@ -0,0 +1,133 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.utility.dictionary import ApplyTransformToPointsd +from monai.utils import set_determinism + +set_determinism(seed=0) + +DATA_2D = torch.rand(1, 64, 64) +DATA_3D = torch.rand(1, 64, 64, 64) +POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]]) +POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]]) +POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]]) +POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) +POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) +POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) + +TEST_CASES = [ + [ + MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_2D_WORLD, + None, + True, + False, + POINT_2D_IMAGE, + ], + [ + None, + MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + None, + False, + False, + POINT_2D_WORLD, + ], + [ + None, + MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), + False, + False, + POINT_2D_WORLD, + ], + [ + MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_2D_WORLD, + None, + True, + True, + POINT_2D_IMAGE_RAS, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_3D_WORLD, + None, + True, + False, + POINT_3D_IMAGE, + ], + ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + None, + False, + False, + POINT_3D_WORLD, + ], + [ + MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])), + POINT_3D_WORLD, + None, + True, + True, + POINT_3D_IMAGE_RAS, + ], +] + +TEST_CASES_WRONG = [ + [POINT_2D_WORLD, True, None], + [POINT_2D_WORLD.unsqueeze(0), False, None], + [POINT_3D_WORLD[..., 0:1], False, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])], +] + + +class TestCoordinateTransform(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output): + data = { + "image": image, + "point": points, + "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]), + } + refer_key = "image" if (image is not None and image != "affine") else image + transform = ApplyTransformToPointsd( + keys="point", + refer_key=refer_key, + dtype=torch.int64, + affine=affine, + invert_affine=invert_affine, + 
affine_lps_to_ras=affine_lps_to_ras, + ) + output = transform(data) + + self.assertTrue(torch.allclose(output["point"], expected_output)) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out["point"], points)) + + @parameterized.expand(TEST_CASES_WRONG) + def test_wrong_input(self, input, invert_affine, affine): + transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine) + with self.assertRaises(ValueError): + transform({"point": input}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_convert_box_points.py b/tests/test_convert_box_points.py new file mode 100644 index 0000000000..5e3d7ee645 --- /dev/null +++ b/tests/test_convert_box_points.py @@ -0,0 +1,121 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data.box_utils import convert_box_to_standard_mode +from monai.transforms.spatial.array import ConvertBoxToPoints, ConvertPointsToBoxes +from tests.utils import assert_allclose + +TEST_CASE_POINTS_2D = [ + [ + torch.tensor([[10, 20, 30, 40], [50, 60, 70, 80]]), + "xyxy", + torch.tensor([[[10, 20], [30, 20], [30, 40], [10, 40]], [[50, 60], [70, 60], [70, 80], [50, 80]]]), + ], + [torch.tensor([[10, 20, 20, 20]]), "ccwh", torch.tensor([[[0, 10], [20, 10], [20, 30], [0, 30]]])], +] +TEST_CASE_POINTS_3D = [ + [ + torch.tensor([[10, 20, 30, 40, 50, 60], [70, 80, 90, 100, 110, 120]]), + "xyzxyz", + torch.tensor( + [ + [ + [10, 20, 30], + [40, 20, 30], + [40, 50, 30], + [10, 50, 30], + [10, 20, 60], + [40, 20, 60], + [40, 50, 60], + [10, 50, 60], + ], + [ + [70, 80, 90], + [100, 80, 90], + [100, 110, 90], + [70, 110, 90], + [70, 80, 120], + [100, 80, 120], + [100, 110, 120], + [70, 110, 120], + ], + ] + ), + ], + [ + torch.tensor([[10, 20, 30, 10, 10, 10]]), + "cccwhd", + torch.tensor( + [ + [ + [5, 15, 25], + [15, 15, 25], + [15, 25, 25], + [5, 25, 25], + [5, 15, 35], + [15, 15, 35], + [15, 25, 35], + [5, 25, 35], + ] + ] + ), + ], + [ + torch.tensor([[10, 20, 30, 40, 50, 60]]), + "xxyyzz", + torch.tensor( + [ + [ + [10, 30, 50], + [20, 30, 50], + [20, 40, 50], + [10, 40, 50], + [10, 30, 60], + [20, 30, 60], + [20, 40, 60], + [10, 40, 60], + ] + ] + ), + ], +] + +TEST_CASES = TEST_CASE_POINTS_2D + TEST_CASE_POINTS_3D + + +class TestConvertBoxToPoints(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_convert_box_to_points(self, boxes, mode, expected_points): + transform = ConvertBoxToPoints(mode=mode) + converted_points = transform(boxes) + assert_allclose(converted_points, expected_points, type_test=False) + + +class TestConvertPointsToBoxes(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_convert_box_to_points(self, boxes, mode, points): + transform = ConvertPointsToBoxes() + converted_boxes = transform(points) + expected_boxes = convert_box_to_standard_mode(boxes, mode) + 
assert_allclose(converted_boxes, expected_boxes, type_test=False) + + +if __name__ == "__main__": + unittest.main()
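
A minimal usage sketch of how the transforms added in this diff could be combined, assuming the patch is installed; the image size, affine, box values, and dictionary keys below are hypothetical and chosen only for illustration:

import torch

from monai.data import MetaTensor
from monai.transforms import ApplyTransformToPointsd, ConvertBoxToPointsd, ConvertPointsToBoxesd

# hypothetical sample: a 3D image whose affine scales by 2, one box in "xyzxyz" mode,
# and one channel of world-space points with shape (C=1, N=2, 3)
image = MetaTensor(torch.rand(1, 64, 64, 64), affine=torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0])))
data = {
    "image": image,
    "box": torch.tensor([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]]),
    "point": torch.tensor([[[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]]]),
}

# box -> 8 corner points -> standard-mode box (round trip through the new spatial transforms)
data = ConvertBoxToPointsd(keys="box", point_key="box_points", mode="xyzxyz")(data)
data = ConvertPointsToBoxesd(keys="box_points", box_key="box_standard")(data)
print(data["box_points"].shape)  # torch.Size([1, 8, 3])

# map the world-space points into image coordinates using the affine taken from `refer_key`,
# then invert the tracked operation to recover the original world-space coordinates
to_image_space = ApplyTransformToPointsd(keys="point", refer_key="image", dtype=torch.float64)
data = to_image_space(data)
data = to_image_space.inverse(data)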