Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add box and points convert transform #8053

Merged
merged 13 commits into from
Sep 2, 2024
36 changes: 36 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,18 @@ Spatial
:members:
:special-members: __call__

`ConvertBoxToPoints`
""""""""""""""""""""
.. autoclass:: ConvertBoxToPoints
:members:
:special-members: __call__

`ConvertPointsToBoxes`
""""""""""""""""""""""
.. autoclass:: ConvertPointsToBoxes
:members:
:special-members: __call__


Smooth Field
^^^^^^^^^^^^
Expand Down Expand Up @@ -1222,6 +1234,12 @@ Utility
:members:
:special-members: __call__

`ApplyTransformToPoints`
""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPoints
:members:
:special-members: __call__

Dictionary Transforms
---------------------

Expand Down Expand Up @@ -1973,6 +1991,18 @@ Spatial (Dict)
:members:
:special-members: __call__

`ConvertBoxToPointsd`
"""""""""""""""""""""
.. autoclass:: ConvertBoxToPointsd
:members:
:special-members: __call__

`ConvertPointsToBoxesd`
"""""""""""""""""""""""
.. autoclass:: ConvertPointsToBoxesd
:members:
:special-members: __call__


Smooth Field (Dict)
^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -2277,6 +2307,12 @@ Utility (Dict)
:members:
:special-members: __call__

`ApplyTransformToPointsd`
"""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPointsd
:members:
:special-members: __call__


MetaTensor
^^^^^^^^^^
Expand Down
12 changes: 12 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@
from .spatial.array import (
Affine,
AffineGrid,
ConvertBoxToPoints,
ConvertPointsToBoxes,
Flip,
GridDistortion,
GridPatch,
Expand Down Expand Up @@ -427,6 +429,12 @@
Affined,
AffineD,
AffineDict,
ConvertBoxToPointsd,
ConvertBoxToPointsD,
ConvertBoxToPointsDict,
ConvertPointsToBoxesd,
ConvertPointsToBoxesD,
ConvertPointsToBoxesDict,
Flipd,
FlipD,
FlipDict,
Expand Down Expand Up @@ -503,6 +511,7 @@
from .utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
Expand Down Expand Up @@ -542,6 +551,9 @@
AddExtremePointsChanneld,
AddExtremePointsChannelD,
AddExtremePointsChannelDict,
ApplyTransformToPointsd,
ApplyTransformToPointsD,
ApplyTransformToPointsDict,
AsChannelLastd,
AsChannelLastD,
AsChannelLastDict,
Expand Down
44 changes: 44 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
Expand All @@ -34,6 +35,8 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.functional import (
affine_func,
convert_box_to_points,
convert_points_to_box,
flip,
orientation,
resize,
Expand Down Expand Up @@ -3544,3 +3547,44 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:

else:
return img


class ConvertBoxToPoints(Transform):
    """
    Turn axis-aligned bounding boxes into the coordinates of their corner points.

    The input holds N boxes in an array of shape (N, C), where C is [x1, y1, x2, y2] for 2D or
    [x1, y1, z1, x2, y2, z2] for 3D. The box mode given at construction controls how the box
    values are interpreted. The output has shape (N, 4, 2) for 2D or (N, 8, 3) for 3D.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
        """
        Args:
            mode: box mode of the input, given as a string, a BoxMode instance or a BoxMode class.
                ``None`` selects StandardMode.
        """
        super().__init__()
        self.mode = mode if mode is not None else StandardMode

    def __call__(self, data: Any):
        """
        Args:
            data: the boxes to convert; it is first converted to a tensor.

        Returns:
            the corner points computed by ``convert_box_to_points``, passed through
            ``convert_to_dst_type`` with the tensor-converted input as the destination.
        """
        boxes = convert_to_tensor(data, track_meta=get_track_meta())
        corner_points = convert_box_to_points(boxes, mode=self.mode)
        converted = convert_to_dst_type(corner_points, boxes)
        return converted[0]


class ConvertPointsToBoxes(Transform):
    """
    Collapse sets of corner points into axis-aligned bounding boxes.

    The input holds the corner points of N boxes: shape (N, 8, 3) for the 8 corners of a 3D cuboid
    or (N, 4, 2) for the 4 corners of a 2D rectangle.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self) -> None:
        """The transform takes no configuration."""
        super().__init__()

    def __call__(self, data: Any):
        """
        Args:
            data: the corner points to convert; it is first converted to a tensor.

        Returns:
            the boxes computed by ``convert_points_to_box``, passed through
            ``convert_to_dst_type`` with the tensor-converted input as the destination.
        """
        corner_points = convert_to_tensor(data, track_meta=get_track_meta())
        boxes = convert_points_to_box(corner_points)
        converted = convert_to_dst_type(boxes, corner_points)
        return converted[0]
60 changes: 60 additions & 0 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@

from monai.config import DtypeLike, KeysCollection, SequenceStr
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
Affine,
ConvertBoxToPoints,
ConvertPointsToBoxes,
Flip,
GridDistortion,
GridPatch,
Expand Down Expand Up @@ -2611,6 +2614,61 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ConvertBoxToPointsd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.

    The converted points are stored under ``point_key``; the entries under ``keys`` are left
    untouched. If several keys are given, each conversion overwrites ``point_key``, so the result
    of the last processed key is the one that remains.
    """

    backend = ConvertBoxToPoints.backend

    def __init__(
        self,
        keys: KeysCollection,
        point_key: str = "points",
        mode: str | BoxMode | type[BoxMode] | None = StandardMode,
        allow_missing_keys: bool = False,
    ):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            point_key: key to store the point data.
            mode: the mode of the input boxes. Defaults to StandardMode.
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.point_key = point_key
        self.converter = ConvertBoxToPoints(mode=mode)

    def __call__(self, data):
        """
        Args:
            data: mapping that holds the box data under ``keys``.

        Returns:
            a shallow copy of ``data`` as a ``dict`` with the converted points added under
            ``point_key``. The input mapping itself is not modified.
        """
        d = dict(data)
        for key in self.key_iterator(d):
            # write to the copy, not to the caller's mapping (which may be read-only)
            d[self.point_key] = self.converter(d[key])
        return d


class ConvertPointsToBoxesd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.

    The converted boxes are stored under ``box_key``; the entries under ``keys`` are left
    untouched. If several keys are given, each conversion overwrites ``box_key``, so the result
    of the last processed key is the one that remains.
    """

    backend = ConvertPointsToBoxes.backend

    def __init__(self, keys: KeysCollection, box_key: str = "box", allow_missing_keys: bool = False):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            box_key: key to store the box data.
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.box_key = box_key
        self.converter = ConvertPointsToBoxes()

    def __call__(self, data):
        """
        Args:
            data: mapping that holds the corner points under ``keys``.

        Returns:
            a shallow copy of ``data`` as a ``dict`` with the converted boxes added under
            ``box_key``. The input mapping itself is not modified.
        """
        d = dict(data)
        for key in self.key_iterator(d):
            # write to the copy, not to the caller's mapping (which may be read-only)
            d[self.box_key] = self.converter(d[key])
        return d


SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
Expand All @@ -2635,3 +2693,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
GridPatchD = GridPatchDict = GridPatchd
RandGridPatchD = RandGridPatchDict = RandGridPatchd
RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
# "D" and "Dict" suffixed aliases for the box/points conversion dictionary transforms.
ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd
71 changes: 70 additions & 1 deletion monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import monai
from monai.config import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import get_boxmode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
Expand All @@ -32,7 +33,7 @@
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
from monai.utils import (
LazyAttr,
TraceKeys,
Expand Down Expand Up @@ -610,3 +611,71 @@ def affine_func(
out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out if image_only else (out, affine)


def convert_box_to_points(bbox, mode):
    """
    Converts axis-aligned bounding boxes to their corner points.

    All boxes are converted in one batched call, so an input with zero boxes yields an empty
    result instead of raising.

    Args:
        bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or
            [x1, y1, z1, x2, y2, z2] for 3D for each box.
        mode: The mode specifying how to interpret the bounding box; resolved with ``get_boxmode``.

    Returns:
        The corner points of every box, of shape (N, 4, 2) for 2D or (N, 8, 3) for 3D.
    """
    mode = get_boxmode(mode)

    # One entry per coordinate (min values first, then max values); each entry is expected to
    # hold that coordinate for all boxes with shape (N, 1).
    corners = mode.boxes_to_corners(bbox)

    # Each tuple lists which coordinate entries make up one corner point, in output order.
    if len(corners) == 4:
        corner_layout = [(0, 1), (2, 1), (2, 3), (0, 3)]
    else:
        corner_layout = [
            (0, 1, 2),
            (3, 1, 2),
            (3, 4, 2),
            (0, 4, 2),
            (0, 1, 5),
            (3, 1, 5),
            (3, 4, 5),
            (0, 4, 5),
        ]

    # (N, D) per corner, then stacked along a new axis 1 to give (N, num_corners, D).
    corner_points = [concatenate([corners[idx] for idx in layout], axis=1) for layout in corner_layout]
    return stack(corner_points, dim=1)


def convert_points_to_box(points):
    """
    Converts corner points to axis-aligned bounding boxes.

    Args:
        points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of
            a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.

    Returns:
        The per-box minimum coordinates followed by the per-box maximum coordinates, joined
        along axis 1.
    """
    # Imported under aliases so the Python builtins ``min``/``max`` are not shadowed here.
    from monai.transforms.utils_pytorch_numpy_unification import max as _reduce_max, min as _reduce_min

    # Reduce over the corner axis: smallest coordinates first, largest second.
    lower_then_upper = [_reduce_min(points, dim=1), _reduce_max(points, dim=1)]
    return concatenate(lower_then_upper, axis=1)
Loading
Loading