Skip to content

Commit

Permalink
Add keypoints OV model
Browse files Browse the repository at this point in the history
Add integration tests for keypoint detection task (#3945)

* Fix export

* Add integrastion tests for keypoint detection

* Update export

* Add kp det to gh workflow

* Fix ruff

* Move std mean to class vars

Cleanup

Implement dummy input for kp

Update export inf test

Upgrade MAPI

Cleanup

Fix CLI tests

Cleanup

Fix kp OV model data config

Fix dummy input docstring

Co-authored-by: Prokofiev Kirill <kirill.prokofiev@intel.com>

Add unit tests for kp det model

Fix ruff

Update pck docs

Fix typos in tests

Add uts for PCK metric

Rename kaypoints det metric

Del extra transform in p ove model

Fix unit test naming

Fix cpu train/eval

Fix autocast deprecation warnings
  • Loading branch information
sovrasov committed Sep 21, 2024
1 parent e5de44c commit c3f2ed1
Show file tree
Hide file tree
Showing 13 changed files with 420 additions and 37 deletions.
2 changes: 1 addition & 1 deletion src/otx/algo/common/layers/spp_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ def __init__(
def forward(self, x: Tensor) -> Tensor:
"""Forward."""
x = self.conv1(x)
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast(x.device.type, enabled=False):
x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1)
return self.conv2(x)
2 changes: 1 addition & 1 deletion src/otx/algo/detection/layers/channel_attention_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(

def forward(self, x: Tensor) -> Tensor:
"""Forward function for ChannelAttention."""
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast(x.device.type, enabled=False):
out = self.global_avgpool(x)
out = self.fc(out)
out = self.act(out)
Expand Down
2 changes: 2 additions & 0 deletions src/otx/algo/keypoint_detection/heads/rtmcc_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def to_numpy(self, x: Tensor | tuple[Tensor, Tensor]) -> np.ndarray | tuple[np.n
np.ndarray | tuple: return a tuple of converted numpy array(s)
"""
if isinstance(x, Tensor):
if x.dtype == torch.bfloat16:
x = x.float()
return x.detach().cpu().numpy()
if isinstance(x, tuple) and all(isinstance(i, Tensor) for i in x):
return tuple([self.to_numpy(i) for i in x])
Expand Down
52 changes: 27 additions & 25 deletions src/otx/core/metrics/pck.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _calc_distances(preds: np.ndarray, gts: np.ndarray, mask: np.ndarray, norm_f
Args:
preds (np.ndarray[N, K, D]): Predicted keypoint location.
gts (np.ndarray[N, K, D]): Groundtruth keypoint location.
gts (np.ndarray[N, K, D]): Ground truth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.
Expand Down Expand Up @@ -75,7 +75,7 @@ def keypoint_pck_accuracy(
pred: np.ndarray,
gt: np.ndarray,
mask: np.ndarray,
thr: np.ndarray,
thr: float,
norm_factor: np.ndarray,
) -> tuple[np.ndarray, float, int]:
"""Calculate the pose accuracy of PCK for each individual keypoint.
Expand Down Expand Up @@ -117,34 +117,22 @@ def keypoint_pck_accuracy(


class PCKMeasure(Metric):
"""Computes the f-measure (also known as F1-score) for a resultset.
The f-measure is typically used in detection (localization) tasks to obtain a single number that balances precision
and recall.
To determine whether a predicted box matches a ground truth box an overlap measured
is used based on a minimum
intersection-over-union (IoU), by default a value of 0.5 is used.
In addition spurious results are eliminated by applying non-max suppression (NMS) so that two predicted boxes with
IoU > threshold are reduced to one. This threshold can be determined automatically by setting `vary_nms_threshold`
to True.
"""Computes the pose accuracy (also known as PCK) for a resultset.
Args:
label_info (int): Dataclass including label information.
vary_nms_threshold (bool): if True the maximal F-measure is determined by optimizing for different NMS threshold
values. Defaults to False.
cross_class_nms (bool): Whether non-max suppression should be applied cross-class. If True this will eliminate
boxes with sufficient overlap even if they are from different classes. Defaults to False.
dist_threshold (float): Threshold of PCK calculation.
"""

def __init__(
self,
label_info: LabelInfo,
dist_threshold: float = 0.05,
):
super().__init__()

self.label_info: LabelInfo = label_info
self.dist_threshold: float = dist_threshold
self.reset()

def reset(self) -> None:
Expand Down Expand Up @@ -174,19 +162,33 @@ def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]
def compute(self) -> dict:
"""Compute PCK score metric."""
pred_kpts = np.stack([p[0].cpu().numpy() for p in self.preds])
gt_kpts = np.stack([p[0] for p in self.targets])
kpts_visible = np.stack([p[1] for p in self.targets])

normalize = np.tile(np.array([[256, 192]]), (pred_kpts.shape[0], 1))
gt_kpts_processed = []
for p in self.targets:
if len(p[0].shape) == 3 and p[0].shape[0] == 1:
gt_kpts_processed.append(p[0].squeeze())
else:
gt_kpts_processed.append(p[0])
gt_kpts = np.stack(gt_kpts_processed)

kpts_visible = []
for p in self.targets:
if len(p[1].shape) == 3 and p[1].shape[0] == 1:
kpts_visible.append(p[1].squeeze())
else:
kpts_visible.append(p[1])

kpts_visible_stacked = np.stack(kpts_visible)

normalize = np.tile(np.array([self.input_size[::-1]]), (pred_kpts.shape[0], 1))
_, avg_acc, _ = keypoint_pck_accuracy(
pred_kpts,
gt_kpts,
mask=kpts_visible > 0,
thr=0.05,
mask=kpts_visible_stacked > 0,
thr=self.dist_threshold,
norm_factor=normalize,
)

return {"accuracy": Tensor([avg_acc])}
return {"PCK": Tensor([avg_acc])}


def _pck_measure_callable(label_info: LabelInfo) -> PCKMeasure:
Expand Down
102 changes: 96 additions & 6 deletions src/otx/core/model/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
from otx.core.data.entity.keypoint_detection import KeypointDetBatchDataEntity, KeypointDetBatchPredEntity
from otx.core.metrics import MetricCallable, MetricInput
from otx.core.metrics.pck import PCKMeasureCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.export import TaskLevelExportParameters
from otx.core.types.label import LabelInfoTypes

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from model_api.models.utils import DetectedKeypoints
from torch import nn


class OTXKeypointDetectionModel(OTXModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]):
"""Base class for the detection models used in OTX."""
"""Base class for the keypoint detection models used in OTX."""

def __init__(
self,
Expand All @@ -37,8 +38,6 @@ def __init__(
metric: MetricCallable = PCKMeasureCallable,
torch_compile: bool = False,
) -> None:
self.mean = (0.0, 0.0, 0.0)
self.std = (255.0, 255.0, 255.0)
super().__init__(
label_info=label_info,
input_size=input_size,
Expand Down Expand Up @@ -152,6 +151,97 @@ def _export_parameters(self) -> TaskLevelExportParameters:
model_type="keypoint_detection",
task_type="keypoint_detection",
confidence_threshold=self.hparams.get("best_confidence_threshold", None),
iou_threshold=0.5,
tile_config=self.tile_config if self.tile_config.enable_tiler else None,
)

def get_dummy_input(self, batch_size: int = 1) -> KeypointDetBatchDataEntity:
"""Returns a dummy input for key point detection model."""
images = torch.rand(batch_size, 3, *self.input_size)
return KeypointDetBatchDataEntity(
batch_size,
images,
[],
[torch.tensor([0, 0, self.input_size[1], self.input_size[0]])],
labels=[],
bbox_info=[],
keypoints=[],
keypoints_visible=[],
)


class OVKeypointDetectionModel(OVModel[KeypointDetBatchDataEntity, KeypointDetBatchPredEntity]):
"""Keypoint detection model compatible for OpenVINO IR inference.
It can consume OpenVINO IR model path or model name from Intel OMZ repository
and create the OTX keypoint detection model compatible for OTX testing pipeline.
"""

def __init__(
self,
model_name: str,
model_type: str = "keypoint_detection",
async_inference: bool = True,
max_num_requests: int | None = None,
use_throughput_mode: bool = True,
model_api_configuration: dict[str, Any] | None = None,
metric: MetricCallable = PCKMeasureCallable,
**kwargs,
) -> None:
super().__init__(
model_name=model_name,
model_type=model_type,
async_inference=async_inference,
max_num_requests=max_num_requests,
use_throughput_mode=use_throughput_mode,
model_api_configuration=model_api_configuration,
metric=metric,
)

def _customize_outputs(
self,
outputs: list[DetectedKeypoints],
inputs: KeypointDetBatchDataEntity,
) -> KeypointDetBatchPredEntity | OTXBatchLossEntity:
keypoints = []
scores = []
for output in outputs:
keypoints.append(torch.as_tensor(output.keypoints, device=self.device))
scores.append(torch.as_tensor(output.scores, device=self.device))

return KeypointDetBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
keypoints=keypoints,
scores=scores,
keypoints_visible=[],
bboxes=[],
labels=[],
bbox_info=[],
)

def configure_metric(self) -> None:
"""Configure the metric."""
super().configure_metric()
self._metric.input_size = (self.model.h, self.model.w)

def _convert_pred_entity_to_compute_metric(
self,
preds: KeypointDetBatchPredEntity,
inputs: KeypointDetBatchDataEntity,
) -> MetricInput:
return {
"preds": [
{
"keypoints": kpt,
"scores": score,
}
for kpt, score in zip(preds.keypoints, preds.scores)
],
"target": [
{
"keypoints": kpt,
"keypoints_visible": kpt_visible,
}
for kpt, kpt_visible in zip(inputs.keypoints, inputs.keypoints_visible)
],
}
1 change: 1 addition & 0 deletions src/otx/engine/utils/auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
OTXTaskType.ANOMALY_CLASSIFICATION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO",
OTXTaskType.ANOMALY_DETECTION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO",
OTXTaskType.ANOMALY_SEGMENTATION: "otx.algo.anomaly.openvino_model.AnomalyOpenVINO",
OTXTaskType.KEYPOINT_DETECTION: "otx.core.model.keypoint_detection.OVKeypointDetectionModel",
}


Expand Down
50 changes: 50 additions & 0 deletions src/otx/recipe/keypoint_detection/openvino_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model:
class_path: otx.core.model.keypoint_detection.OVKeypointDetectionModel
init_args:
label_info: 19
model_name: rtm_pose_tiny
model_type: "keypoint_detection"
async_inference: true
use_throughput_mode: true

engine:
task: KEYPOINT_DETECTION
device: cpu

callback_monitor: val/PCK

data: ../_base_/data/keypoint_detection.yaml
overrides:
reset:
- data.train_subset.transforms
- data.val_subset.transforms
- data.test_subset.transforms

data:
stack_images: false
train_subset:
batch_size: 1
num_workers: 2
transforms:
- class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
- class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
init_args:
input_size: $(input_size)

val_subset:
batch_size: 1
num_workers: 2
transforms:
- class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
- class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
init_args:
input_size: $(input_size)

test_subset:
batch_size: 64
num_workers: 2
transforms:
- class_path: otx.core.data.transform_libs.torchvision.GetBBoxCenterScale
- class_path: otx.core.data.transform_libs.torchvision.TopdownAffine
init_args:
input_size: $(input_size)
4 changes: 2 additions & 2 deletions src/otx/recipe/keypoint_detection/rtmpose_tiny.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ model:
mode: max
factor: 0.1
patience: 9
monitor: val/accuracy
monitor: val/PCK

engine:
task: KEYPOINT_DETECTION
device: auto

callback_monitor: val/accuracy
callback_monitor: val/PCK

data: ../_base_/data/keypoint_detection.yaml

Expand Down
Loading

0 comments on commit c3f2ed1

Please sign in to comment.