Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OV model for keypoint detection #3970

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/otx/algo/common/layers/spp_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ def __init__(
def forward(self, x: Tensor) -> Tensor:
"""Forward."""
x = self.conv1(x)
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast(x.device.type, enabled=False):
x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1)
return self.conv2(x)
2 changes: 1 addition & 1 deletion src/otx/algo/detection/layers/channel_attention_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(

def forward(self, x: Tensor) -> Tensor:
"""Forward function for ChannelAttention."""
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast(x.device.type, enabled=False):
out = self.global_avgpool(x)
out = self.fc(out)
out = self.act(out)
Expand Down
2 changes: 2 additions & 0 deletions src/otx/algo/keypoint_detection/heads/rtmcc_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def to_numpy(self, x: Tensor | tuple[Tensor, Tensor]) -> np.ndarray | tuple[np.n
np.ndarray | tuple: return a tuple of converted numpy array(s)
"""
if isinstance(x, Tensor):
if x.dtype == torch.bfloat16:
x = x.float()
return x.detach().cpu().numpy()
if isinstance(x, tuple) and all(isinstance(i, Tensor) for i in x):
return tuple([self.to_numpy(i) for i in x])
Expand Down
54 changes: 28 additions & 26 deletions src/otx/core/metrics/pck.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _calc_distances(preds: np.ndarray, gts: np.ndarray, mask: np.ndarray, norm_f

Args:
preds (np.ndarray[N, K, D]): Predicted keypoint location.
gts (np.ndarray[N, K, D]): Groundtruth keypoint location.
gts (np.ndarray[N, K, D]): Ground truth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
Expand Down Expand Up @@ -75,7 +75,7 @@ def keypoint_pck_accuracy(
pred: np.ndarray,
gt: np.ndarray,
mask: np.ndarray,
thr: np.ndarray,
thr: float,
norm_factor: np.ndarray,
) -> tuple[np.ndarray, float, int]:
"""Calculate the pose accuracy of PCK for each individual keypoint.
Expand All @@ -99,7 +99,7 @@ def keypoint_pck_accuracy(
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
thr (float): Threshold of PCK calculation.
norm_factor (np.ndarray[N, 2]): Normalization factor for H&W.
norm_factor (np.ndarray[N, 2]): Normalization factor for the keypoints.

Returns:
tuple: A tuple containing keypoint accuracy.
Expand All @@ -117,34 +117,22 @@ def keypoint_pck_accuracy(


class PCKMeasure(Metric):
"""Computes the f-measure (also known as F1-score) for a resultset.

The f-measure is typically used in detection (localization) tasks to obtain a single number that balances precision
and recall.

To determine whether a predicted box matches a ground truth box an overlap measured
is used based on a minimum
intersection-over-union (IoU), by default a value of 0.5 is used.

In addition spurious results are eliminated by applying non-max suppression (NMS) so that two predicted boxes with
IoU > threshold are reduced to one. This threshold can be determined automatically by setting `vary_nms_threshold`
to True.
"""Computes the pose accuracy (also known as PCK) for a resultset.

Args:
label_info (int): Dataclass including label information.
vary_nms_threshold (bool): if True the maximal F-measure is determined by optimizing for different NMS threshold
values. Defaults to False.
cross_class_nms (bool): Whether non-max suppression should be applied cross-class. If True this will eliminate
boxes with sufficient overlap even if they are from different classes. Defaults to False.
dist_threshold (float): Threshold of PCK calculation.
"""

def __init__(
self,
label_info: LabelInfo,
dist_threshold: float = 0.05,
):
super().__init__()

self.label_info: LabelInfo = label_info
self.dist_threshold: float = dist_threshold
self.reset()

@property
Expand Down Expand Up @@ -190,19 +178,33 @@ def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]
def compute(self) -> dict:
"""Compute PCK score metric."""
pred_kpts = np.stack([p[0].cpu().numpy() for p in self.preds])
gt_kpts = np.stack([p[0] for p in self.targets])
kpts_visible = np.stack([p[1] for p in self.targets])

normalize = np.tile(np.array([self.input_size]), (pred_kpts.shape[0], 1))
gt_kpts_processed = []
for p in self.targets:
if len(p[0].shape) == 3 and p[0].shape[0] == 1:
gt_kpts_processed.append(p[0].squeeze())
else:
gt_kpts_processed.append(p[0])
gt_kpts = np.stack(gt_kpts_processed)

kpts_visible = []
for p in self.targets:
if len(p[1].shape) == 3 and p[1].shape[0] == 1:
kpts_visible.append(p[1].squeeze())
else:
kpts_visible.append(p[1])

kpts_visible_stacked = np.stack(kpts_visible)

normalize = np.tile(np.array([self.input_size[::-1]]), (pred_kpts.shape[0], 1))
_, avg_acc, _ = keypoint_pck_accuracy(
pred_kpts,
gt_kpts,
mask=kpts_visible > 0,
thr=0.05,
mask=kpts_visible_stacked > 0,
thr=self.dist_threshold,
norm_factor=normalize,
)

return {"accuracy": Tensor([avg_acc])}
return {"PCK": Tensor([avg_acc])}


def _pck_measure_callable(label_info: LabelInfo) -> PCKMeasure:
Expand Down
102 changes: 96 additions & 6 deletions src/otx/core/model/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
from otx.core.data.entity.keypoint_detection import KeypointDetBatchDataEntity, KeypointDetBatchPredEntity
from otx.core.metrics import MetricCallable, MetricInput
from otx.core.metrics.pck import PCKMeasureCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.export import TaskLevelExportParameters
from otx.core.types.label import LabelInfoTypes

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from model_api.models.utils import DetectedKeypoints
from torch import nn


class OTXKeypointDetectionModel(OTXModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]):
"""Base class for the detection models used in OTX."""
"""Base class for the keypoint detection models used in OTX."""

def __init__(
self,
Expand All @@ -37,8 +38,6 @@ def __init__(
metric: MetricCallable = PCKMeasureCallable,
torch_compile: bool = False,
) -> None:
self.mean = (0.0, 0.0, 0.0)
self.std = (255.0, 255.0, 255.0)
super().__init__(
label_info=label_info,
input_size=input_size,
Expand Down Expand Up @@ -157,6 +156,97 @@ def _export_parameters(self) -> TaskLevelExportParameters:
model_type="keypoint_detection",
task_type="keypoint_detection",
confidence_threshold=self.hparams.get("best_confidence_threshold", None),
iou_threshold=0.5,
tile_config=self.tile_config if self.tile_config.enable_tiler else None,
)

def get_dummy_input(self, batch_size: int = 1) -> KeypointDetBatchDataEntity:
"""Returns a dummy input for key point detection model."""
images = torch.rand(batch_size, 3, *self.input_size)
return KeypointDetBatchDataEntity(
batch_size,
images,
[],
[torch.tensor([0, 0, self.input_size[1], self.input_size[0]])],
labels=[],
bbox_info=[],
keypoints=[],
keypoints_visible=[],
)


class OVKeypointDetectionModel(OVModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]):
"""Keypoint detection model compatible for OpenVINO IR inference.

It can consume OpenVINO IR model path or model name from Intel OMZ repository
and create the OTX keypoint detection model compatible for OTX testing pipeline.
"""

def __init__(
self,
model_name: str,
model_type: str = "keypoint_detection",
async_inference: bool = True,
max_num_requests: int | None = None,
use_throughput_mode: bool = True,
model_api_configuration: dict[str, Any] | None = None,
metric: MetricCallable = PCKMeasureCallable,
**kwargs,
) -> None:
super().__init__(
model_name=model_name,
model_type=model_type,
async_inference=async_inference,
max_num_requests=max_num_requests,
use_throughput_mode=use_throughput_mode,
model_api_configuration=model_api_configuration,
metric=metric,
)

def _customize_outputs(
self,
outputs: list[DetectedKeypoints],
inputs: KeypointDetBatchDataEntity,
) -> KeypointDetBatchPredEntity | OTXBatchLossEntity:
keypoints = []
scores = []
for output in outputs:
keypoints.append(torch.as_tensor(output.keypoints, device=self.device))
scores.append(torch.as_tensor(output.scores, device=self.device))

return KeypointDetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
keypoints=keypoints,
scores=scores,
keypoints_visible=[],
bboxes=[],
labels=[],
bbox_info=[],
)

def configure_metric(self) -> None:
"""Configure the metric."""
super().configure_metric()
self._metric.input_size = (self.model.h, self.model.w)

def _convert_pred_entity_to_compute_metric(
self,
preds: KeypointDetBatchPredEntity,
inputs: KeypointDetBatchDataEntity,
) -> MetricInput:
return {
"preds": [
{
"keypoints": kpt,
"scores": score,
}
for kpt, score in zip(preds.keypoints, preds.scores)
],
"target": [
{
"keypoints": kpt,
"keypoints_visible": kpt_visible,
}
for kpt, kpt_visible in zip(inputs.keypoints, inputs.keypoints_visible)
],
}
1 change: 1 addition & 0 deletions src/otx/engine/utils/auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
OTXTaskType.ANOMALY_CLASSIFICATION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO",
OTXTaskType.ANOMALY_DETECTION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO",
OTXTaskType.ANOMALY_SEGMENTATION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO",
OTXTaskType.KEYPOINT_DETECTION: "otx.core.model.keypoint_detection.OVKeypointDetectionModel",
}


Expand Down
50 changes: 50 additions & 0 deletions src/otx/recipe/keypoint_detection/openvino_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model:
class_path: otx.core.model.keypoint_detection.OVKeypointDetectionModel
init_args:
label_info: 19
model_name: rtm_pose_tiny
model_type: "keypoint_detection"
async_inference: true
use_throughput_mode: true

engine:
task: KEYPOINT_DETECTION
device: cpu

callback_monitor: val/PCK

data: ../_base_/data/keypoint_detection.yaml
overrides:
reset:
- data.train_subset.transforms
- data.val_subset.transforms
- data.test_subset.transforms

data:
stack_images: false
train_subset:
batch_size: 1
num_workers: 2
transforms:
- class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
- class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
init_args:
input_size: $(input_size)

val_subset:
batch_size: 1
num_workers: 2
transforms:
- class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
- class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
init_args:
input_size: $(input_size)

test_subset:
batch_size: 64
num_workers: 2
transforms:
- class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
- class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
init_args:
input_size: $(input_size)
4 changes: 2 additions & 2 deletions src/otx/recipe/keypoint_detection/rtmpose_tiny.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ model:
mode: max
factor: 0.1
patience: 9
monitor: val/accuracy
monitor: val/PCK

engine:
task: KEYPOINT_DETECTION
device: auto

callback_monitor: val/accuracy
callback_monitor: val/PCK

data: ../_base_/data/keypoint_detection.yaml

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ model:
mode: max
factor: 0.1
patience: 9
monitor: val/accuracy
monitor: val/PCK

engine:
task: KEYPOINT_DETECTION
device: auto

callback_monitor: val/accuracy
callback_monitor: val/PCK

data: ../_base_/data/keypoint_detection.yaml

Expand Down
3 changes: 3 additions & 0 deletions tests/integration/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ def test_otx_e2e(

if "yolov9" in model_name:
return # RT-DETR currently is not supported.
if "keypoint" in recipe:
print("Explain is not supported for keypoint detection")
return

tmp_path_test = tmp_path / f"otx_export_xai_{model_name}"
for export_case in fxt_export_list:
Expand Down
1 change: 1 addition & 0 deletions tests/integration/cli/test_export_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def fxt_local_seed() -> int:
"visual_prompting": "test/f1-score",
"zero_shot_visual_prompting": "test/f1-score",
"action_classification": "test/accuracy",
"keypoint_detection": "test/PCK",
}


Expand Down
Loading
Loading