Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add box and points convert transform #8053

Merged
merged 13 commits into from
Sep 2, 2024
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,12 @@ Utility
:members:
:special-members: __call__

`ApplyTransformToPoints`
""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPoints
:members:
:special-members: __call__

Dictionary Transforms
---------------------

Expand Down Expand Up @@ -2277,6 +2283,12 @@ Utility (Dict)
:members:
:special-members: __call__

`ApplyTransformToPointsd`
"""""""""""""""""""""""""
.. autoclass:: ApplyTransformToPointsd
:members:
:special-members: __call__


MetaTensor
^^^^^^^^^^
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@
from .utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
Expand Down Expand Up @@ -542,6 +543,9 @@
AddExtremePointsChanneld,
AddExtremePointsChannelD,
AddExtremePointsChannelDict,
ApplyTransformToPointsd,
ApplyTransformToPointsD,
ApplyTransformToPointsDict,
AsChannelLastd,
AsChannelLastD,
AsChannelLastDict,
Expand Down
42 changes: 42 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
Expand All @@ -34,6 +35,8 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.functional import (
affine_func,
convert_box_to_points,
convert_points_to_box,
flip,
orientation,
resize,
Expand Down Expand Up @@ -3544,3 +3547,42 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:

else:
return img


class ConvertBoxToPoints(Transform):
    """
    Turn bounding boxes into their corner points.

    The boxes are interpreted according to the configured box mode and handed to
    ``convert_box_to_points``. The returned points have shape (N, 4, 2) for 2D boxes
    or (N, 8, 3) for 3D boxes.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
        """
        Args:
            mode: box mode of the input boxes, given as a string, a BoxMode instance or a BoxMode class.
                ``None`` selects StandardMode.
        """
        super().__init__()
        if mode is None:
            mode = StandardMode
        self.mode = mode

    def __call__(self, data: Any):
        """
        Args:
            data: the boxes to convert; converted to a tensor before processing.

        Returns:
            the corner points, converted to the type of the tensor version of ``data``.
        """
        boxes = convert_to_tensor(data, track_meta=get_track_meta())
        corner_points = convert_box_to_points(boxes, mode=self.mode)
        converted = convert_to_dst_type(corner_points, boxes)
        return converted[0]


class ConvertPointsToBoxes(Transform):
    """
    Collapse sets of corner points back into bounding boxes.

    The returned boxes have shape (N, 6) for 3D or (N, 4) for 2D.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, data: Any):
        """
        Args:
            data: the points to convert; converted to a tensor before processing.

        Returns:
            the boxes, converted to the type of the tensor version of ``data``.
        """
        points = convert_to_tensor(data, track_meta=get_track_meta())
        boxes = convert_points_to_box(points)
        converted = convert_to_dst_type(boxes, points)
        return converted[0]
57 changes: 57 additions & 0 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@

from monai.config import DtypeLike, KeysCollection, SequenceStr
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
Affine,
ConvertBoxToPoints,
ConvertPointsToBoxes,
Flip,
GridDistortion,
GridPatch,
Expand Down Expand Up @@ -2611,6 +2614,60 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ConvertBoxToPointsd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.

    The converted points are stored under ``point_key``. When several ``keys`` are given,
    each result is written to the same ``point_key``, so the last processed key wins.
    """

    backend = ConvertBoxToPoints.backend

    def __init__(
        self,
        keys: KeysCollection,
        point_key: str = "points",
        mode: str | BoxMode | type[BoxMode] | None = StandardMode,
        allow_missing_keys: bool = False,
    ):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            point_key: key to store the point data.
            mode: the mode of the input boxes. Defaults to StandardMode.
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.point_key = point_key
        self.converter = ConvertBoxToPoints(mode=mode)

    def __call__(self, data):
        """
        Args:
            data: mapping holding the boxes under ``keys``.

        Returns:
            a shallow copy of ``data`` with the converted points added under ``point_key``.
        """
        # Work on a shallow copy so the caller's mapping is never modified
        # (and read-only mappings are accepted).
        d = dict(data)
        for key in self.key_iterator(d):
            d[self.point_key] = self.converter(d[key])
        return d


class ConvertPointsToBoxesd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.

    The converted boxes are stored under ``box_key``. When several ``keys`` are given,
    each result is written to the same ``box_key``, so the last processed key wins.
    """

    backend = ConvertPointsToBoxes.backend

    def __init__(self, keys: KeysCollection, box_key: str = "box", allow_missing_keys: bool = False):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            box_key: key to store the box data.
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.box_key = box_key
        self.converter = ConvertPointsToBoxes()

    def __call__(self, data):
        """
        Args:
            data: mapping holding the points under ``keys``.

        Returns:
            a shallow copy of ``data`` with the converted boxes added under ``box_key``.
        """
        # Work on a shallow copy so the caller's mapping is never modified
        # (and read-only mappings are accepted).
        d = dict(data)
        for key in self.key_iterator(d):
            d[self.box_key] = self.converter(d[key])
        return d


SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
Expand Down
68 changes: 67 additions & 1 deletion monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import monai
from monai.config import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import get_boxmode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
Expand All @@ -32,7 +33,7 @@
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack, min, max
from monai.utils import (
LazyAttr,
TraceKeys,
Expand Down Expand Up @@ -610,3 +611,68 @@ def affine_func(
out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out if image_only else (out, affine)


def convert_box_to_points(bbox, mode):
    """
    Convert bounding boxes to their corner points.

    Args:
        bbox: bounding boxes of shape (N, 4) for 2D or (N, 6) for 3D, e.g. in the form of
            [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D.
        mode: the mode specifying how to interpret the bounding boxes; resolved with ``get_boxmode``.

    Returns:
        the corner points of every box, shape (N, 4, 2) for 2D or (N, 8, 3) for 3D.
    """
    mode = get_boxmode(mode)

    # NOTE: relies on ``boxes_to_corners`` returning one (N, 1) column per corner coordinate
    # (4 entries for 2D, 6 for 3D), so all boxes can be handled in a single pass.
    corners = mode.boxes_to_corners(bbox)

    # Each tuple lists which corner coordinates are combined into one output point.
    if len(corners) == 4:
        point_layout = [(0, 1), (2, 1), (2, 3), (0, 3)]
    else:
        point_layout = [
            (0, 1, 2),
            (3, 1, 2),
            (3, 4, 2),
            (0, 4, 2),
            (0, 1, 5),
            (3, 1, 5),
            (3, 4, 5),
            (0, 4, 5),
        ]

    # Every entry is (N, spatial_dims); stacking on dim=1 yields (N, num_points, spatial_dims).
    # Unlike stacking per-box results, this also works when there are zero boxes.
    points = [concatenate([corners[i] for i in layout], axis=1) for layout in point_layout]
    return stack(points, dim=1)


def convert_points_to_box(points):
    """
    Convert points to bounding boxes.

    Args:
        points: points representing the corners of the bounding boxes. Shape (N, 8, 3) or (N, 4, 2).

    Returns:
        bounding boxes formed by the per-box coordinate minima followed by the maxima,
        shape (N, 6) for 3D or (N, 4) for 2D.
    """
    # ``min``/``max`` here are the array-aware helpers imported from
    # ``monai.transforms.utils_pytorch_numpy_unification``, not the Python builtins.
    lower = min(points, dim=1)
    upper = max(points, dim=1)
    # Concatenate the min and max values to get the bounding boxes
    return concatenate([lower, upper], axis=1)
Loading
Loading