Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transform to handle empty box as training data #6170

Merged
merged 25 commits into from
Mar 25, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions monai/apps/detection/networks/retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,9 @@ def forward(
"""
# 1. Check if input arguments are valid
if self.training:
check_training_targets(input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key)
targets = check_training_targets(
input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key
)
self._check_detector_training_components()

# 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
Expand Down Expand Up @@ -877,7 +879,7 @@ def get_cls_train_sample_per_image(

foreground_idxs_per_image = matched_idxs_per_image >= 0

num_foreground = foreground_idxs_per_image.sum()
num_foreground = int(foreground_idxs_per_image.sum())
num_gt_box = targets_per_image[self.target_box_key].shape[0]

if self.debug:
Expand Down
23 changes: 23 additions & 0 deletions monai/apps/detection/transforms/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
convert_box_to_standard_mode,
get_spatial_dims,
spatial_crop_boxes,
standardize_empty_box,
)
from monai.transforms import Rotate90, SpatialCrop
from monai.transforms.transform import Transform
Expand All @@ -46,6 +47,7 @@
)

__all__ = [
"StandardizeEmptyBox",
"ConvertBoxToStandardMode",
"ConvertBoxMode",
"AffineBox",
Expand All @@ -60,6 +62,27 @@
]


class StandardizeEmptyBox(Transform):
    """
    Standardize empty bounding boxes to a shape of (0,4) or (0,6).

    The work is delegated to ``standardize_empty_box``, which is called with the
    ``spatial_dims`` value supplied at construction time.

    Args:
        spatial_dims: number of spatial dimensions of the bounding boxes.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, spatial_dims: int) -> None:
        self.spatial_dims = spatial_dims

    def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Args:
            boxes: source bounding boxes, Nx4 or Nx6 or 0xM torch tensor or ndarray.

        Returns:
            the value produced by ``standardize_empty_box`` for ``boxes``.
        """
        standardized = standardize_empty_box(boxes, spatial_dims=self.spatial_dims)
        return standardized


class ConvertBoxMode(Transform):
"""
This transform converts the boxes in src_mode to the dst_mode.
Expand Down
49 changes: 49 additions & 0 deletions monai/apps/detection/transforms/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
MaskToBox,
RotateBox90,
SpatialCropBox,
StandardizeEmptyBox,
ZoomBox,
)
from monai.apps.detection.transforms.box_ops import convert_box_to_mask
Expand All @@ -51,6 +52,9 @@
from monai.utils.type_conversion import convert_data_type, convert_to_tensor

__all__ = [
"StandardizeEmptyBoxd",
"StandardizeEmptyBoxD",
"StandardizeEmptyBoxDict",
"ConvertBoxModed",
"ConvertBoxModeD",
"ConvertBoxModeDict",
Expand Down Expand Up @@ -95,6 +99,50 @@
DEFAULT_POST_FIX = PostFix.meta()


class StandardizeEmptyBoxd(MapTransform, InvertibleTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`.

    When boxes are empty, this transform standardizes them to a shape of (0,4) or (0,6).

    Example:
        .. code-block:: python

            data = {"boxes": torch.ones(0,), "image": torch.ones(1, 128, 128, 128)}
            box_converter = StandardizeEmptyBoxd(box_keys=["boxes"], box_ref_image_keys="image")
            box_converter(data)
    """

    def __init__(self, box_keys: KeysCollection, box_ref_image_keys: str, allow_missing_keys: bool = False) -> None:
        """
        Args:
            box_keys: Keys to pick data for transformation.
            box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached.
            allow_missing_keys: don't raise exception if key is missing.

        Raises:
            ValueError: if ``box_ref_image_keys`` does not resolve to exactly one key.

        See also :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`
        """
        super().__init__(box_keys, allow_missing_keys)
        box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)
        # Exactly one reference key is required; an empty collection is rejected here
        # instead of failing later with an unrelated lookup error in ``__call__``.
        if len(box_ref_image_keys_tuple) != 1:
            raise ValueError(
                "Please provide a single key for box_ref_image_keys. "
                "All boxes of box_keys are attached to box_ref_image_keys."
            )
        # Keep the key itself (not the container it may have arrived in) so that a
        # one-element list/tuple can still be used to index the data dictionary.
        self.box_ref_image_keys = box_ref_image_keys_tuple[0]

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        """
        Standardize the boxes stored under every key of ``box_keys``.

        Args:
            data: dictionary holding the boxes and the reference image.

        Returns:
            a shallow copy of ``data`` whose box entries went through ``StandardizeEmptyBox``.
        """
        d = dict(data)
        # One axis of the reference image is not counted as spatial
        # (presumably the leading channel axis, as in the class example -- confirm with callers).
        spatial_dims = len(d[self.box_ref_image_keys].shape) - 1
        self.converter = StandardizeEmptyBox(spatial_dims=spatial_dims)
        for key in self.key_iterator(d):
            d[key] = self.converter(d[key])
        return d

    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        """
        Return a shallow copy of ``data``; no value is modified.
        """
        return dict(data)


class ConvertBoxModed(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.ConvertBoxMode`.
Expand Down Expand Up @@ -1353,3 +1401,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
RandCropBoxByPosNegLabelD = RandCropBoxByPosNegLabelDict = RandCropBoxByPosNegLabeld
RotateBox90D = RotateBox90Dict = RotateBox90d
RandRotateBox90D = RandRotateBox90Dict = RandRotateBox90d
StandardizeEmptyBoxD = StandardizeEmptyBoxDict = StandardizeEmptyBoxd
29 changes: 23 additions & 6 deletions monai/apps/detection/utils/detector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import torch
import torch.nn.functional as F
from torch import Tensor

from monai.data.box_utils import standardize_empty_box
from monai.transforms.croppad.array import SpatialPad
from monai.transforms.utils import compute_divisible_spatial_size, convert_pad_mode
from monai.utils import PytorchPadMode, ensure_tuple_rep
Expand Down Expand Up @@ -56,7 +58,7 @@ def check_training_targets(
spatial_dims: int,
target_label_key: str,
target_box_key: str,
) -> None:
) -> list[dict[str, Tensor]]:
"""
Validate the input images/targets during training (raise a `ValueError` if invalid).

Expand All @@ -75,7 +77,8 @@ def check_training_targets(
if len(input_images) != len(targets):
raise ValueError(f"len(input_images) should equal to len(targets), got {len(input_images)}, {len(targets)}.")

for target in targets:
for i in range(len(targets)):
target = targets[i]
if (target_label_key not in target.keys()) or (target_box_key not in target.keys()):
raise ValueError(
f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}."
Expand All @@ -85,10 +88,24 @@ def check_training_targets(
if not isinstance(boxes, torch.Tensor):
raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
if len(boxes.shape) != 2 or boxes.shape[-1] != 2 * spatial_dims:
raise ValueError(
f"Expected target boxes to be a tensor " f"of shape [N, {2* spatial_dims}], got {boxes.shape}."
)
return
if boxes.numel() == 0:
warnings.warn(
f"Warning: Given target boxes has shape of {boxes.shape}. "
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2* spatial_dims}])."
)
else:
raise ValueError(
f"Expected target boxes to be a tensor of shape [N, {2* spatial_dims}], got {boxes.shape}.)."
)
if not torch.is_floating_point(boxes):
raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
targets[i][target_box_key] = standardize_empty_box(boxes, spatial_dims=spatial_dims) # type: ignore

labels = target[target_label_key]
if torch.is_floating_point(labels):
warnings.warn(f"Warning: Given target labels is {labels.dtype}. The detector converted it to torch.long.")
targets[i][target_label_key] = labels.long()
return targets


def pad_images(
Expand Down
59 changes: 56 additions & 3 deletions monai/data/box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,19 +395,41 @@ def get_spatial_dims(

# Check the validity of each input and add its corresponding spatial_dims to spatial_dims_set
if boxes is not None:
if int(boxes.shape[1]) not in [4, 6]:
if len(boxes.shape) != 2:
if boxes.shape[0] == 0:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], "
f"got boxes with shape {boxes.shape}. "
f"Please reshape it with boxes = torch.reshape(boxes, [0, 4]) or torch.reshape(boxes, [0, 6])."
)
else:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
)
if int(boxes.shape[1] / 2) not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
)
spatial_dims_set.add(int(boxes.shape[1] / 2))
if points is not None:
if len(points.shape) != 2:
if points.shape[0] == 0:
raise ValueError(
f"Currently we support only points with shape [N,2] or [N,3], "
f"got points with shape {points.shape}. "
f"Please reshape it with points = torch.reshape(points, [0, 2]) or torch.reshape(points, [0, 3])."
)
else:
raise ValueError(
f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
)
if int(points.shape[1]) not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only points with shape [N,2] or [N,3], got boxes with shape {points.shape}."
f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
)
spatial_dims_set.add(int(points.shape[1]))
if corners is not None:
if len(corners) not in [4, 6]:
if len(corners) // 2 not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], got box corner tuple with length {len(corners)}."
)
Expand Down Expand Up @@ -494,6 +516,33 @@ def get_boxmode(mode: str | BoxMode | type[BoxMode] | None = None, *args, **kwar
return StandardMode(*args, **kwargs)


def standardize_empty_box(boxes: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
    """
    When boxes are empty, this function standardizes them to a shape of (0,4) or (0,6).

    A box container is treated as empty when it holds zero elements, whatever its
    shape is (for example ``(0,)``, ``(0, M)`` or ``(N, 0)``). Non-empty input is
    passed through the type conversions without being reshaped.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 or empty torch tensor or ndarray
        spatial_dims: number of spatial dimensions of the bounding boxes.

    Returns:
        bounding boxes with shape (N,4) or (N,6), N can be 0.

    Example:
        .. code-block:: python

            boxes = torch.ones(0,)
            standardize_empty_box(boxes, 3)
    """
    # convert numpy to tensor if needed
    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)
    # handle empty box: test the element count rather than only the first dimension,
    # so every zero-element layout is reshaped to (0, 2 * spatial_dims)
    if boxes_t.numel() == 0:
        boxes_t = torch.reshape(boxes_t, [0, spatial_dims * 2])
    # convert tensor back to numpy if needed
    boxes_dst, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)
    return boxes_dst


def convert_box_mode(
boxes: NdarrayOrTensor,
src_mode: str | BoxMode | type[BoxMode] | None = None,
Expand Down Expand Up @@ -522,6 +571,10 @@ def convert_box_mode(
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode)
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode())
"""
# handle empty box
if boxes.shape[0] == 0:
return boxes

src_boxmode = get_boxmode(src_mode)
dst_boxmode = get_boxmode(dst_mode)

Expand Down
29 changes: 15 additions & 14 deletions tests/test_retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,21 @@ def test_retina_detector_resnet_backbone_shape(self, input_param, input_shape):

detector.set_atss_matcher()
detector.set_hard_negative_sampler(10, 0.5)
gt_box_start = torch.randint(2, (3, input_param["spatial_dims"])).to(torch.float16)
gt_box_end = gt_box_start + torch.randint(1, 10, (3, input_param["spatial_dims"]))
one_target = {
"boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
"labels": torch.randint(input_param["num_classes"], (3,)),
}
with train_mode(detector):
input_data = torch.randn(input_shape)
targets = [one_target] * len(input_data)
result = detector.forward(input_data, targets)

input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
targets = [one_target] * len(input_data)
result = detector.forward(input_data, targets)
for num_gt_box in [0, 3]: # test for both empty and non-empty boxes
gt_box_start = torch.randint(2, (num_gt_box, input_param["spatial_dims"])).to(torch.float16)
gt_box_end = gt_box_start + torch.randint(1, 10, (num_gt_box, input_param["spatial_dims"]))
one_target = {
"boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
"labels": torch.randint(input_param["num_classes"], (num_gt_box,)),
}
with train_mode(detector):
input_data = torch.randn(input_shape)
targets = [one_target] * len(input_data)
result = detector.forward(input_data, targets)

input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
targets = [one_target] * len(input_data)
result = detector.forward(input_data, targets)

@parameterized.expand(TEST_CASES)
def test_naive_retina_detector_shape(self, input_param, input_shape):
Expand Down