
Add keypoint detection recipe for single object cases (#3903)
* add rtmpose_tiny for single obj

* add rtmpose_tiny for single obj

* modify test subset name

* fix unit test

* update recipe with reset
wonjuleee authored Aug 28, 2024
1 parent 0dc7a29 commit 0d6799c
Showing 6 changed files with 140 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/otx/algo/keypoint_detection/heads/rtmcc_head.py
@@ -192,7 +192,7 @@ def loss(self, x: tuple[Tensor], entity: KeypointDetBatchDataEntity) -> dict:
             mask=mask,
         )

-        loss_pose = torch.tensor(avg_acc, device=device)
+        loss_pose = -1 * torch.tensor(avg_acc, device=device)
         losses.update(loss_pose=loss_pose)

         return losses
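The one-line change above flips the sign of the accuracy so it behaves like a loss. A minimal sketch of the effect with an illustrative value (`avg_acc` here is hypothetical, not from the repo):

```python
import torch

# Hypothetical mean keypoint accuracy in [0, 1]; higher is better.
avg_acc = 0.82
device = torch.device("cpu")

# Negating it gives a scalar that decreases as accuracy improves, matching
# the "lower is better" convention used for logging and checkpointing.
loss_pose = -1 * torch.tensor(avg_acc, device=device)
print(loss_pose)  # tensor(-0.8200)
```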
4 changes: 2 additions & 2 deletions src/otx/core/data/dataset/keypoint_detection.py
@@ -104,8 +104,8 @@ def _get_item_impl(self, index: int) -> KeypointDetDataEntity | None:
         ).reshape(-1, 2)
         keypoints_visible = np.minimum(1, keypoints)[..., 0]

-        bbox_center = (bboxes[0, 2:] + bboxes[0, :2]) * 0.5
-        bbox_scale = (bboxes[0, 2:] - bboxes[0, :2]) * 1.25
+        bbox_center = np.array(img_shape) / 2.0
+        bbox_scale = np.array(img_shape)
         bbox_rotation = 0.0

         entity = KeypointDetDataEntity(
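Since the bbox-derived center and scale are now produced by the new `GetBBoxCenterScale` transform (added below), the dataset falls back to whole-image defaults. Illustrative numbers, with a made-up `img_shape`:

```python
import numpy as np

img_shape = (256, 192)  # hypothetical (height, width) of the source image

bbox_center = np.array(img_shape) / 2.0  # [128.,  96.] -> image midpoint
bbox_scale = np.array(img_shape)         # [256, 192]  -> full image extent
bbox_rotation = 0.0
```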
62 changes: 42 additions & 20 deletions src/otx/core/data/transform_libs/torchvision.py
@@ -3109,6 +3109,46 @@ def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None:
         return inputs


+class GetBBoxCenterScale(tvt_v2.Transform):
+    """Convert bboxes from [x1, y1, x2, y2] to center and scale.
+
+    The center is the coordinates of the bbox center, and the scale is the
+    bbox width and height scaled by the padding factor.
+
+    Required Keys:
+        - bboxes
+
+    Modified Keys:
+        - bbox_info.center
+        - bbox_info.scale
+
+    Args:
+        padding (float): The bbox padding factor that will be multiplied into
+            the bbox scale. Defaults to 1.25.
+    """
+
+    def __init__(self, padding: float = 1.25) -> None:
+        super().__init__()
+
+        self.padding = padding
+
+    def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None:
+        """Transform function to add bbox_infos from bboxes for the keypoint detection task."""
+        assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet."  # noqa: S101
+        inputs = _inputs[0]
+
+        bbox = inputs.bboxes[0].numpy()
+        inputs.bbox_info.center = (bbox[2:] + bbox[:2]) * 0.5
+        inputs.bbox_info.scale = (bbox[2:] - bbox[:2]) * self.padding
+
+        return inputs
+
+    def __repr__(self) -> str:
+        """Print the basic information of the transform.
+
+        Returns:
+            str: Formatted string.
+        """
+        return self.__class__.__name__ + f"(padding={self.padding})"
+
+
 class RandomBBoxTransform(tvt_v2.Transform):
     r"""Randomly shift, resize and rotate the bounding boxes.
@@ -3202,16 +3242,7 @@ def _get_transform_params(self) -> tuple:
         return offset, scale, rotate

     def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None:
-        """The transform function of :class:`RandomBboxTransform`.
-
-        See ``transform()`` method of :class:`BaseTransform` for details.
-
-        Args:
-            results (dict): The result dict
-
-        Returns:
-            dict: The result dict.
-        """
+        """Transform function to adjust bbox_infos randomly."""
         assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet."  # noqa: S101
         inputs = _inputs[0]
@@ -3378,16 +3409,7 @@ def _get_warp_image(
         return torch.from_numpy(warped_image).permute(2, 0, 1)

     def __call__(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None:
-        """The transform function of :class:`TopdownAffine`.
-
-        See ``transform()`` method of :class:`BaseTransform` for details.
-
-        Args:
-            results (dict): The result dict
-
-        Returns:
-            dict: The result dict.
-        """
+        """Transform function to warp the image with the affine matrix."""
         assert len(_inputs) == 1, "[tmp] Multiple entity is not supported yet."  # noqa: S101
         inputs = _inputs[0]
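A hand-computed sketch of what `GetBBoxCenterScale.__call__` produces for an xyxy box, mirroring the arithmetic above (the box values are illustrative):

```python
import numpy as np

padding = 1.25  # the transform's default
bbox = np.array([10.0, 20.0, 50.0, 100.0])  # hypothetical [x1, y1, x2, y2]

center = (bbox[2:] + bbox[:2]) * 0.5     # [30., 60.]  -> midpoint of the box
scale = (bbox[2:] - bbox[:2]) * padding  # [50., 100.] -> padded width/height
```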
3 changes: 3 additions & 0 deletions src/otx/recipe/_base_/data/keypoint_detection.yaml
@@ -12,6 +12,7 @@ train_subset:
   subset_name: train
   batch_size: 32
   transforms:
+    - class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
     - class_path: otx.core.data.transform_libs.torchvision.RandomBBoxTransform
     - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
       init_args:
@@ -30,6 +31,7 @@ val_subset:
   subset_name: val
   batch_size: 32
   transforms:
+    - class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
     - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
       init_args:
         input_size: $(input_size)
@@ -45,6 +47,7 @@ test_subset:
   subset_name: test
   batch_size: 32
   transforms:
+    - class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
     - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
       init_args:
         input_size: $(input_size)
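Ordering matters in these pipelines: `GetBBoxCenterScale` must run before `TopdownAffine`, which reads `bbox_info.center` and `bbox_info.scale` to build its warp. A minimal sketch using the classes from this PR (the `input_size` value is arbitrary):

```python
from otx.core.data.transform_libs.torchvision import (
    Compose,
    GetBBoxCenterScale,
    TopdownAffine,
)

val_transforms = Compose(
    [
        GetBBoxCenterScale(),                  # fills bbox_info.center / .scale
        TopdownAffine(input_size=(512, 512)),  # warps the crop those values define
    ],
)
```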
81 changes: 81 additions & 0 deletions src/otx/recipe/keypoint_detection/rtmpose_tiny_single_obj.yaml
@@ -0,0 +1,81 @@
model:
  class_path: otx.algo.keypoint_detection.rtmpose.RTMPoseTiny
  init_args:
    label_info: 17

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.001
    weight_decay: 0.0001

scheduler:
  class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
  init_args:
    num_warmup_steps: 3
    main_scheduler_callable:
      class_path: lightning.pytorch.cli.ReduceLROnPlateau
      init_args:
        mode: max
        factor: 0.1
        patience: 9
        monitor: val/accuracy

engine:
  task: KEYPOINT_DETECTION
  device: auto

callback_monitor: val/accuracy

data: ../_base_/data/keypoint_detection.yaml

overrides:
  gradient_clip_val: 35.0
  reset:
    - data.train_subset.transforms
    - data.val_subset.transforms
    - data.test_subset.transforms
  input_size:
    - 512
    - 512
  train_subset:
    transforms:
      - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
        init_args:
          input_size: $(input_size)
      - class_path: otx.core.data.transform_libs.torchvision.YOLOXHSVRandomAug
        init_args:
          is_numpy_to_tvtensor: true
      - class_path: torchvision.transforms.v2.ToDtype
        init_args:
          dtype: ${as_torch_dtype:torch.float32}
      - class_path: torchvision.transforms.v2.Normalize
        init_args:
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]
  val_subset:
    transforms:
      - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
        init_args:
          input_size: $(input_size)
          is_numpy_to_tvtensor: true
      - class_path: torchvision.transforms.v2.ToDtype
        init_args:
          dtype: ${as_torch_dtype:torch.float32}
      - class_path: torchvision.transforms.v2.Normalize
        init_args:
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]
  test_subset:
    transforms:
      - class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
        init_args:
          input_size: $(input_size)
          is_numpy_to_tvtensor: true
      - class_path: torchvision.transforms.v2.ToDtype
        init_args:
          dtype: ${as_torch_dtype:torch.float32}
      - class_path: torchvision.transforms.v2.Normalize
        init_args:
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]
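One way to exercise the new recipe, sketched against the OTX 2.x `Engine` API (the `from_config` call and the `data_root` path are assumptions for illustration, not part of this PR):

```python
from otx.engine import Engine

# data_root is a placeholder; point it at a single-object keypoint dataset.
engine = Engine.from_config(
    config_path="src/otx/recipe/keypoint_detection/rtmpose_tiny_single_obj.yaml",
    data_root="data/coco_single_person",
)
engine.train()
```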
14 changes: 11 additions & 3 deletions tests/unit/core/data/transform_libs/test_torchvision.py
@@ -20,8 +20,10 @@
 from otx.core.data.transform_libs.torchvision import (
     CachedMixUp,
     CachedMosaic,
+    Compose,
     DecodeVideo,
     FilterAnnotations,
+    GetBBoxCenterScale,
     MinIoURandomCrop,
     PackVideo,
     Pad,
@@ -901,12 +903,18 @@ def keypoint_det_entity(self) -> KeypointDetDataEntity:
             labels=torch.LongTensor([0]),
             keypoints=tv_tensors.TVTensor(np.array([[0, 4], [4, 2], [2, 6], [6, 0]])),
             keypoints_visible=tv_tensors.TVTensor(np.array([1, 1, 1, 0])),
-            bbox_info=BboxInfo(center=np.array([3.5, 3.5]), scale=np.array([8.75, 8.75]), rotation=0),
+            bbox_info=BboxInfo(center=np.array([5, 5]), scale=np.array([10, 10]), rotation=0),
         )

     def test_forward(self, keypoint_det_entity) -> None:
-        transform = TopdownAffine(input_size=(5, 5))
+        transform = Compose(
+            [
+                GetBBoxCenterScale(),
+                TopdownAffine(input_size=(5, 5)),
+            ],
+        )
         results = transform(deepcopy(keypoint_det_entity))

         assert hasattr(results, "keypoints")
+        assert np.array_equal(results.bbox_info.center, np.array([3.5, 3.5]))
+        assert np.array_equal(results.bbox_info.scale, np.array([8.75, 8.75]))
         assert results.keypoints.shape == (4, 2)
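A quick hand-check of the expected values in the new asserts, assuming the fixture's (elided) bbox is `[0, 0, 7, 7]` in xyxy form and the default padding of 1.25:

```python
import numpy as np

bbox = np.array([0.0, 0.0, 7.0, 7.0])  # assumed fixture bbox (not shown in this diff)

center = (bbox[2:] + bbox[:2]) * 0.5   # [3.5, 3.5]
scale = (bbox[2:] - bbox[:2]) * 1.25   # [8.75, 8.75]
assert np.allclose(center, [3.5, 3.5])
assert np.allclose(scale, [8.75, 8.75])
```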
