diff --git a/src/datumaro/components/annotation.py b/src/datumaro/components/annotation.py index 262dff0139..de89952821 100644 --- a/src/datumaro/components/annotation.py +++ b/src/datumaro/components/annotation.py @@ -4,6 +4,7 @@ from __future__ import annotations +import math from enum import IntEnum from functools import partial from itertools import zip_longest @@ -852,8 +853,7 @@ class RotatedBbox(_Shape): def __init__(self, x, y, w, h, r, *args, **kwargs): kwargs.pop("points", None) # comes from wrap() - # points = x1, y1, x2, y2, x3, y3, x4, y4 - self.__attrs_init__([x, y, x + w, y + h, r], *args, **kwargs) + self.__attrs_init__([x, y, w, h, r], *args, **kwargs) @property def x(self): @@ -865,11 +865,11 @@ def y(self): @property def w(self): - return self.points[2] - self.points[0] + return self.points[2] @property def h(self): - return self.points[3] - self.points[1] + return self.points[3] @property def r(self): @@ -879,10 +879,40 @@ def get_area(self): return self.w * self.h def get_bbox(self): - return [self.x, self.y, self.w, self.h] + points = self.as_polygon() + xs = [p for p in points[0::2]] + ys = [p for p in points[1::2]] + + return [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)] + + def get_rotated_bbox(self): + return [self.x, self.y, self.w, self.h, self.r] def as_polygon(self) -> List[float]: - return self.points + """Convert [center_x, center_y, width, height, rotation] to 8 coordinates for a rotated bounding box.""" + + half_width = self.w / 2 + half_height = self.h / 2 + rot = np.deg2rad(self.r) + + # Calculate coordinates of the four corners + corners = np.array( + [ + [-half_width, -half_height], + [half_width, -half_height], + [half_width, half_height], + [-half_width, half_height], + ] + ) + + # Rotate the corners + transformed = [] + for corner in corners: + x = corner[0] * math.cos(rot) - corner[1] * math.sin(rot) + self.x + y = corner[0] * math.sin(rot) + corner[1] * math.cos(rot) + self.y + transformed.extend([x, y]) + + return transformed def iou(self, other: _Shape) -> Union[float, Literal[-1]]: from datumaro.util.annotation_util import bbox_iou diff --git a/src/datumaro/components/annotations/matcher.py b/src/datumaro/components/annotations/matcher.py index 40af5bb2ff..5e59456bfd 100644 --- a/src/datumaro/components/annotations/matcher.py +++ b/src/datumaro/components/annotations/matcher.py @@ -9,7 +9,7 @@ from datumaro.components.abstracts import IMergerContext from datumaro.components.abstracts.merger import IMatcherContext -from datumaro.components.annotation import Annotation +from datumaro.components.annotation import Annotation, Points from datumaro.util.annotation_util import ( OKS, approximate_line, @@ -371,5 +371,7 @@ def match_annotations(self, sources): @attrs class RotatedBboxMatcher(ShapeMatcher): + sigma: Optional[list] = attrib(default=None) + def distance(self, a, b): - return OKS(a, b, sigma=self.sigma) + return OKS(Points(a.as_polygon()), Points(b.as_polygon()), sigma=self.sigma) diff --git a/src/datumaro/components/task.py b/src/datumaro/components/task.py index 5d02e59ac8..0adc74dc00 100644 --- a/src/datumaro/components/task.py +++ b/src/datumaro/components/task.py @@ -42,8 +42,7 @@ def __init__(self): AnnotationType.points, }, TaskType.detection_rotated: { - AnnotationType.label, - AnnotationType.polygon, + AnnotationType.rotated_bbox, }, TaskType.detection_3d: {AnnotationType.label, AnnotationType.cuboid_3d}, TaskType.segmentation_semantic: { @@ -53,6 +52,7 @@ def __init__(self): TaskType.segmentation_instance: { AnnotationType.label, AnnotationType.bbox, + AnnotationType.rotated_bbox, AnnotationType.ellipse, AnnotationType.polygon, AnnotationType.points, @@ -65,6 +65,7 @@ def __init__(self): TaskType.mixed: { AnnotationType.label, AnnotationType.bbox, + AnnotationType.rotated_bbox, AnnotationType.cuboid_3d, AnnotationType.ellipse, AnnotationType.polygon, diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index 45d97ab44e..53266f78b4 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -17,6 +17,7 @@ PointsCategories, Polygon, PolyLine, + RotatedBbox, ) from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME, DatasetItem @@ -222,6 +223,7 @@ def test_can_match_shapes(self): Points([5, 6], label=0, group=1), Points([6, 8], label=1, group=1), PolyLine([1, 1, 2, 1, 3, 1]), + RotatedBbox(4, 5, 2, 4, 20, label=2), ], ), ], @@ -251,6 +253,7 @@ def test_can_match_shapes(self): Points([5.5, 6.5], label=0, group=2), Points([6, 8], label=1, group=2), PolyLine([1, 1.5, 2, 1.5]), + RotatedBbox(2, 4, 2, 4, 10, label=1), ], ), ], @@ -280,6 +283,7 @@ def test_can_match_shapes(self): Bbox(3, 6, 2, 3, label=2, z_order=4, group=3), Points([4.5, 5.5], label=0, group=3), PolyLine([1, 1.25, 3, 1, 4, 2]), + RotatedBbox(2, 4, 2, 4, 10, label=1), ], ), ], @@ -313,6 +317,8 @@ def test_can_match_shapes(self): Points([5, 6], label=0, group=1), Points([6, 8], label=1, group=1), PolyLine([1, 1.25, 3, 1, 4, 2]), + RotatedBbox(4, 5, 2, 4, 20, label=2), + RotatedBbox(2, 4, 2, 4, 10, label=1), ], ), ], @@ -330,11 +336,21 @@ def test_can_match_shapes(self): sources={2}, ann=source0.get("1").annotations[5], ), + NoMatchingAnnError( + item_id=("1", DEFAULT_SUBSET_NAME), + sources={0}, + ann=source1.get("1").annotations[6], + ), NoMatchingAnnError( item_id=("1", DEFAULT_SUBSET_NAME), sources={1, 2}, ann=source0.get("1").annotations[0], ), + NoMatchingAnnError( + item_id=("1", DEFAULT_SUBSET_NAME), + sources={1, 2}, + ann=source0.get("1").annotations[7], + ), ], sorted( (e for e in merger.errors if isinstance(e, NoMatchingAnnError)),