Skip to content

Commit

Permalink
Add transform to handle empty box as training data (Project-MONAI#6170)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#5990 .

### Description

Add transforms to convert empty box with shape (0,M) or (0,) into (0,4)
or (0,6).
Provide format checking inside detector so users can know how to solve
the format issue with empty box input.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Can Zhao <canz@nvidia.com>
  • Loading branch information
Can-Zhao authored and jak0bw committed Mar 28, 2023
1 parent 2055768 commit 6680469
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 25 deletions.
6 changes: 4 additions & 2 deletions monai/apps/detection/networks/retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,9 @@ def forward(
"""
# 1. Check if input arguments are valid
if self.training:
check_training_targets(input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key)
targets = check_training_targets(
input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key
)
self._check_detector_training_components()

# 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
Expand Down Expand Up @@ -877,7 +879,7 @@ def get_cls_train_sample_per_image(

foreground_idxs_per_image = matched_idxs_per_image >= 0

num_foreground = foreground_idxs_per_image.sum()
num_foreground = int(foreground_idxs_per_image.sum())
num_gt_box = targets_per_image[self.target_box_key].shape[0]

if self.debug:
Expand Down
23 changes: 23 additions & 0 deletions monai/apps/detection/transforms/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
convert_box_to_standard_mode,
get_spatial_dims,
spatial_crop_boxes,
standardize_empty_box,
)
from monai.transforms import Rotate90, SpatialCrop
from monai.transforms.transform import Transform
Expand All @@ -46,6 +47,7 @@
)

__all__ = [
"StandardizeEmptyBox",
"ConvertBoxToStandardMode",
"ConvertBoxMode",
"AffineBox",
Expand All @@ -60,6 +62,27 @@
]


class StandardizeEmptyBox(Transform):
"""
When boxes are empty, this transform standardize it to shape of (0,4) or (0,6).
Args:
spatial_dims: number of spatial dimensions of the bounding boxes.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, spatial_dims: int) -> None:
self.spatial_dims = spatial_dims

def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
boxes: source bounding boxes, Nx4 or Nx6 or 0xM torch tensor or ndarray.
"""
return standardize_empty_box(boxes, spatial_dims=self.spatial_dims)


class ConvertBoxMode(Transform):
"""
This transform converts the boxes in src_mode to the dst_mode.
Expand Down
49 changes: 49 additions & 0 deletions monai/apps/detection/transforms/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
MaskToBox,
RotateBox90,
SpatialCropBox,
StandardizeEmptyBox,
ZoomBox,
)
from monai.apps.detection.transforms.box_ops import convert_box_to_mask
Expand All @@ -51,6 +52,9 @@
from monai.utils.type_conversion import convert_data_type, convert_to_tensor

__all__ = [
"StandardizeEmptyBoxd",
"StandardizeEmptyBoxD",
"StandardizeEmptyBoxDict",
"ConvertBoxModed",
"ConvertBoxModeD",
"ConvertBoxModeDict",
Expand Down Expand Up @@ -95,6 +99,50 @@
DEFAULT_POST_FIX = PostFix.meta()


class StandardizeEmptyBoxd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`.
When boxes are empty, this transform standardize it to shape of (0,4) or (0,6).
Example:
.. code-block:: python
data = {"boxes": torch.ones(0,), "image": torch.ones(1, 128, 128, 128)}
box_converter = StandardizeEmptyBoxd(box_keys=["boxes"], box_ref_image_keys="image")
box_converter(data)
"""

def __init__(self, box_keys: KeysCollection, box_ref_image_keys: str, allow_missing_keys: bool = False) -> None:
"""
Args:
box_keys: Keys to pick data for transformation.
box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached.
allow_missing_keys: don't raise exception if key is missing.
See also :py:class:`monai.apps.detection,transforms.array.ConvertBoxToStandardMode`
"""
super().__init__(box_keys, allow_missing_keys)
box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)
if len(box_ref_image_keys_tuple) > 1:
raise ValueError(
"Please provide a single key for box_ref_image_keys.\
All boxes of box_keys are attached to box_ref_image_keys."
)
self.box_ref_image_keys = box_ref_image_keys

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
spatial_dims = len(d[self.box_ref_image_keys].shape) - 1
self.converter = StandardizeEmptyBox(spatial_dims=spatial_dims)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
return dict(data)


class ConvertBoxModed(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.ConvertBoxMode`.
Expand Down Expand Up @@ -1353,3 +1401,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
RandCropBoxByPosNegLabelD = RandCropBoxByPosNegLabelDict = RandCropBoxByPosNegLabeld
RotateBox90D = RotateBox90Dict = RotateBox90d
RandRotateBox90D = RandRotateBox90Dict = RandRotateBox90d
StandardizeEmptyBoxD = StandardizeEmptyBoxDict = StandardizeEmptyBoxd
29 changes: 23 additions & 6 deletions monai/apps/detection/utils/detector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import torch
import torch.nn.functional as F
from torch import Tensor

from monai.data.box_utils import standardize_empty_box
from monai.transforms.croppad.array import SpatialPad
from monai.transforms.utils import compute_divisible_spatial_size, convert_pad_mode
from monai.utils import PytorchPadMode, ensure_tuple_rep
Expand Down Expand Up @@ -56,7 +58,7 @@ def check_training_targets(
spatial_dims: int,
target_label_key: str,
target_box_key: str,
) -> None:
) -> list[dict[str, Tensor]]:
"""
Validate the input images/targets during training (raise a `ValueError` if invalid).
Expand All @@ -75,7 +77,8 @@ def check_training_targets(
if len(input_images) != len(targets):
raise ValueError(f"len(input_images) should equal to len(targets), got {len(input_images)}, {len(targets)}.")

for target in targets:
for i in range(len(targets)):
target = targets[i]
if (target_label_key not in target.keys()) or (target_box_key not in target.keys()):
raise ValueError(
f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}."
Expand All @@ -85,10 +88,24 @@ def check_training_targets(
if not isinstance(boxes, torch.Tensor):
raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
if len(boxes.shape) != 2 or boxes.shape[-1] != 2 * spatial_dims:
raise ValueError(
f"Expected target boxes to be a tensor " f"of shape [N, {2* spatial_dims}], got {boxes.shape}."
)
return
if boxes.numel() == 0:
warnings.warn(
f"Warning: Given target boxes has shape of {boxes.shape}. "
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2* spatial_dims}])."
)
else:
raise ValueError(
f"Expected target boxes to be a tensor of shape [N, {2* spatial_dims}], got {boxes.shape}.)."
)
if not torch.is_floating_point(boxes):
raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
targets[i][target_box_key] = standardize_empty_box(boxes, spatial_dims=spatial_dims) # type: ignore

labels = target[target_label_key]
if torch.is_floating_point(labels):
warnings.warn(f"Warning: Given target labels is {labels.dtype}. The detector converted it to torch.long.")
targets[i][target_label_key] = labels.long()
return targets


def pad_images(
Expand Down
59 changes: 56 additions & 3 deletions monai/data/box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,19 +395,41 @@ def get_spatial_dims(

# Check the validity of each input and add its corresponding spatial_dims to spatial_dims_set
if boxes is not None:
if int(boxes.shape[1]) not in [4, 6]:
if len(boxes.shape) != 2:
if boxes.shape[0] == 0:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], "
f"got boxes with shape {boxes.shape}. "
f"Please reshape it with boxes = torch.reshape(boxes, [0, 4]) or torch.reshape(boxes, [0, 6])."
)
else:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
)
if int(boxes.shape[1] / 2) not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
)
spatial_dims_set.add(int(boxes.shape[1] / 2))
if points is not None:
if len(points.shape) != 2:
if points.shape[0] == 0:
raise ValueError(
f"Currently we support only points with shape [N,2] or [N,3], "
f"got points with shape {points.shape}. "
f"Please reshape it with points = torch.reshape(points, [0, 2]) or torch.reshape(points, [0, 3])."
)
else:
raise ValueError(
f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
)
if int(points.shape[1]) not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only points with shape [N,2] or [N,3], got boxes with shape {points.shape}."
f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
)
spatial_dims_set.add(int(points.shape[1]))
if corners is not None:
if len(corners) not in [4, 6]:
if len(corners) // 2 not in SUPPORTED_SPATIAL_DIMS:
raise ValueError(
f"Currently we support only boxes with shape [N,4] or [N,6], got box corner tuple with length {len(corners)}."
)
Expand Down Expand Up @@ -494,6 +516,33 @@ def get_boxmode(mode: str | BoxMode | type[BoxMode] | None = None, *args, **kwar
return StandardMode(*args, **kwargs)


def standardize_empty_box(boxes: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
"""
When boxes are empty, this function standardize it to shape of (0,4) or (0,6).
Args:
boxes: bounding boxes, Nx4 or Nx6 or empty torch tensor or ndarray
spatial_dims: number of spatial dimensions of the bounding boxes.
Returns:
bounding boxes with shape (N,4) or (N,6), N can be 0.
Example:
.. code-block:: python
boxes = torch.ones(0,)
standardize_empty_box(boxes, 3)
"""
# convert numpy to tensor if needed
boxes_t, *_ = convert_data_type(boxes, torch.Tensor)
# handle empty box
if boxes_t.shape[0] == 0:
boxes_t = torch.reshape(boxes_t, [0, spatial_dims * 2])
# convert tensor back to numpy if needed
boxes_dst, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)
return boxes_dst


def convert_box_mode(
boxes: NdarrayOrTensor,
src_mode: str | BoxMode | type[BoxMode] | None = None,
Expand Down Expand Up @@ -522,6 +571,10 @@ def convert_box_mode(
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode)
convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode())
"""
# handle empty box
if boxes.shape[0] == 0:
return boxes

src_boxmode = get_boxmode(src_mode)
dst_boxmode = get_boxmode(dst_mode)

Expand Down
29 changes: 15 additions & 14 deletions tests/test_retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,21 @@ def test_retina_detector_resnet_backbone_shape(self, input_param, input_shape):

detector.set_atss_matcher()
detector.set_hard_negative_sampler(10, 0.5)
gt_box_start = torch.randint(2, (3, input_param["spatial_dims"])).to(torch.float16)
gt_box_end = gt_box_start + torch.randint(1, 10, (3, input_param["spatial_dims"]))
one_target = {
"boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
"labels": torch.randint(input_param["num_classes"], (3,)),
}
with train_mode(detector):
input_data = torch.randn(input_shape)
targets = [one_target] * len(input_data)
result = detector.forward(input_data, targets)

input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
targets = [one_target] * len(input_data)
result = detector.forward(input_data, targets)
for num_gt_box in [0, 3]: # test for both empty and non-empty boxes
gt_box_start = torch.randint(2, (num_gt_box, input_param["spatial_dims"])).to(torch.float16)
gt_box_end = gt_box_start + torch.randint(1, 10, (num_gt_box, input_param["spatial_dims"]))
one_target = {
"boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
"labels": torch.randint(input_param["num_classes"], (num_gt_box,)),
}
with train_mode(detector):
input_data = torch.randn(input_shape)
targets = [one_target] * len(input_data)
result = detector.forward(input_data, targets)

input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
targets = [one_target] * len(input_data)
result = detector.forward(input_data, targets)

@parameterized.expand(TEST_CASES)
def test_naive_retina_detector_shape(self, input_param, input_shape):
Expand Down

0 comments on commit 6680469

Please sign in to comment.