Skip to content

Commit

Permalink
Add box and points convert transform (#8053)
Browse files Browse the repository at this point in the history
Add box and points convert transform
Cherrypick ApplyTransformToPoints

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 2, 2024
1 parent c9f8d32 commit 7219ee7
Show file tree
Hide file tree
Showing 13 changed files with 799 additions and 5 deletions.
36 changes: 36 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,18 @@ Spatial
:members:
:special-members: __call__

`ConvertBoxToPoints`
""""""""""""""""""""
.. autoclass:: ConvertBoxToPoints
:members:
:special-members: __call__

`ConvertPointsToBoxes`
""""""""""""""""""""""
.. autoclass:: ConvertPointsToBoxes
:members:
:special-members: __call__


Smooth Field
^^^^^^^^^^^^
Expand Down Expand Up @@ -1222,6 +1234,12 @@ Utility
:members:
:special-members: __call__

`ApplyTransformToPoints`
""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPoints
:members:
:special-members: __call__

Dictionary Transforms
---------------------

Expand Down Expand Up @@ -1973,6 +1991,18 @@ Spatial (Dict)
:members:
:special-members: __call__

`ConvertBoxToPointsd`
"""""""""""""""""""""
.. autoclass:: ConvertBoxToPointsd
:members:
:special-members: __call__

`ConvertPointsToBoxesd`
"""""""""""""""""""""""
.. autoclass:: ConvertPointsToBoxesd
:members:
:special-members: __call__


Smooth Field (Dict)
^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -2277,6 +2307,12 @@ Utility (Dict)
:members:
:special-members: __call__

`ApplyTransformToPointsd`
"""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPointsd
:members:
:special-members: __call__


MetaTensor
^^^^^^^^^^
Expand Down
12 changes: 12 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@
from .spatial.array import (
Affine,
AffineGrid,
ConvertBoxToPoints,
ConvertPointsToBoxes,
Flip,
GridDistortion,
GridPatch,
Expand Down Expand Up @@ -427,6 +429,12 @@
Affined,
AffineD,
AffineDict,
ConvertBoxToPointsd,
ConvertBoxToPointsD,
ConvertBoxToPointsDict,
ConvertPointsToBoxesd,
ConvertPointsToBoxesD,
ConvertPointsToBoxesDict,
Flipd,
FlipD,
FlipDict,
Expand Down Expand Up @@ -503,6 +511,7 @@
from .utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
Expand Down Expand Up @@ -542,6 +551,9 @@
AddExtremePointsChanneld,
AddExtremePointsChannelD,
AddExtremePointsChannelDict,
ApplyTransformToPointsd,
ApplyTransformToPointsD,
ApplyTransformToPointsDict,
AsChannelLastd,
AsChannelLastD,
AsChannelLastDict,
Expand Down
44 changes: 44 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
Expand All @@ -34,6 +35,8 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.functional import (
affine_func,
convert_box_to_points,
convert_points_to_box,
flip,
orientation,
resize,
Expand Down Expand Up @@ -3544,3 +3547,44 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:

else:
return img


class ConvertBoxToPoints(Transform):
"""
Converts an axis-aligned bounding box to points. It can automatically convert the boxes to the points based on the box mode.
Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D for each box.
Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
"""
Args:
mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode.
"""
super().__init__()
self.mode = StandardMode if mode is None else mode

def __call__(self, data: Any):
data = convert_to_tensor(data, track_meta=get_track_meta())
points = convert_box_to_points(data, mode=self.mode)
return convert_to_dst_type(points, data)[0]


class ConvertPointsToBoxes(Transform):
"""
Converts points to an axis-aligned bounding box.
Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of a 3D cuboid or
(N, 4, 2) for the 4 corners of a 2D rectangle.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self) -> None:
super().__init__()

def __call__(self, data: Any):
data = convert_to_tensor(data, track_meta=get_track_meta())
box = convert_points_to_box(data)
return convert_to_dst_type(box, data)[0]
60 changes: 60 additions & 0 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@

from monai.config import DtypeLike, KeysCollection, SequenceStr
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
Affine,
ConvertBoxToPoints,
ConvertPointsToBoxes,
Flip,
GridDistortion,
GridPatch,
Expand Down Expand Up @@ -2611,6 +2614,61 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ConvertBoxToPointsd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.
"""

backend = ConvertBoxToPoints.backend

def __init__(
self,
keys: KeysCollection,
point_key="points",
mode: str | BoxMode | type[BoxMode] | None = StandardMode,
allow_missing_keys: bool = False,
):
"""
Args:
keys: keys of the corresponding items to be transformed.
point_key: key to store the point data.
mode: the mode of the input boxes. Defaults to StandardMode.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.point_key = point_key
self.converter = ConvertBoxToPoints(mode=mode)

def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
data[self.point_key] = self.converter(d[key])
return data


class ConvertPointsToBoxesd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.
"""

def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False):
"""
Args:
keys: keys of the corresponding items to be transformed.
box_key: key to store the box data.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.box_key = box_key
self.converter = ConvertPointsToBoxes()

def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
data[self.box_key] = self.converter(d[key])
return data


SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
Expand All @@ -2635,3 +2693,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
GridPatchD = GridPatchDict = GridPatchd
RandGridPatchD = RandGridPatchDict = RandGridPatchd
RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd
71 changes: 70 additions & 1 deletion monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import monai
from monai.config import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import get_boxmode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
Expand All @@ -32,7 +33,7 @@
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
from monai.utils import (
LazyAttr,
TraceKeys,
Expand Down Expand Up @@ -610,3 +611,71 @@ def affine_func(
out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out if image_only else (out, affine)


def convert_box_to_points(bbox, mode):
"""
Converts an axis-aligned bounding box to points.
Args:
mode: The mode specifying how to interpret the bounding box.
bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2]
for 3D for each box. Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D.
Returns:
sequence of points representing the corners of the bounding box.
"""

mode = get_boxmode(mode)

points_list = []
for _num in range(bbox.shape[0]):
corners = mode.boxes_to_corners(bbox[_num : _num + 1])
if len(corners) == 4:
points_list.append(
concatenate(
[
concatenate([corners[0], corners[1]], axis=1),
concatenate([corners[2], corners[1]], axis=1),
concatenate([corners[2], corners[3]], axis=1),
concatenate([corners[0], corners[3]], axis=1),
],
axis=0,
)
)
else:
points_list.append(
concatenate(
[
concatenate([corners[0], corners[1], corners[2]], axis=1),
concatenate([corners[3], corners[1], corners[2]], axis=1),
concatenate([corners[3], corners[4], corners[2]], axis=1),
concatenate([corners[0], corners[4], corners[2]], axis=1),
concatenate([corners[0], corners[1], corners[5]], axis=1),
concatenate([corners[3], corners[1], corners[5]], axis=1),
concatenate([corners[3], corners[4], corners[5]], axis=1),
concatenate([corners[0], corners[4], corners[5]], axis=1),
],
axis=0,
)
)

return stack(points_list, dim=0)


def convert_points_to_box(points):
"""
Converts points to an axis-aligned bounding box.
Args:
points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of
a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.
"""
from monai.transforms.utils_pytorch_numpy_unification import max, min

mins = min(points, dim=1)
maxs = max(points, dim=1)
# Concatenate the min and max values to get the bounding boxes
bboxes = concatenate([mins, maxs], axis=1)

return bboxes
Loading

0 comments on commit 7219ee7

Please sign in to comment.