Skip to content

Commit

Permalink
add matcher & merger
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjuleee committed Apr 19, 2024
1 parent 3a8b710 commit f2791c4
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 10 deletions.
42 changes: 36 additions & 6 deletions src/datumaro/components/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import math
from enum import IntEnum
from functools import partial
from itertools import zip_longest
Expand Down Expand Up @@ -852,8 +853,7 @@ class RotatedBbox(_Shape):

def __init__(self, x, y, w, h, r, *args, **kwargs):
kwargs.pop("points", None) # comes from wrap()
# points = x1, y1, x2, y2, x3, y3, x4, y4
self.__attrs_init__([x, y, x + w, y + h, r], *args, **kwargs)
self.__attrs_init__([x, y, w, h, r], *args, **kwargs)

@property
def x(self):
Expand All @@ -865,11 +865,11 @@ def y(self):

@property
def w(self):
return self.points[2] - self.points[0]
return self.points[2]

@property
def h(self):
return self.points[3] - self.points[1]
return self.points[3]

@property
def r(self):
Expand All @@ -879,10 +879,40 @@ def get_area(self):
return self.w * self.h

def get_bbox(self):
return [self.x, self.y, self.w, self.h]
points = self.as_polygon()
xs = [p for p in points[0::2]]
ys = [p for p in points[1::2]]

return [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)]

def get_rotated_bbox(self):
return [self.x, self.y, self.w, self.h, self.r]

def as_polygon(self) -> List[float]:
return self.points
"""Convert [center_x, center_y, width, height, rotation] to 8 coordinates for a rotated bounding box."""

half_width = self.w / 2
half_height = self.h / 2
rot = np.deg2rad(self.r)

# Calculate coordinates of the four corners
corners = np.array(
[
[-half_width, -half_height],
[half_width, -half_height],
[half_width, half_height],
[-half_width, half_height],
]
)

# Rotate the corners
transformed = []
for corner in corners:
x = corner[0] * math.cos(rot) - corner[1] * math.sin(rot) + self.x
y = corner[0] * math.sin(rot) + corner[1] * math.cos(rot) + self.y
transformed.extend([x, y])

return transformed

def iou(self, other: _Shape) -> Union[float, Literal[-1]]:
from datumaro.util.annotation_util import bbox_iou
Expand Down
6 changes: 4 additions & 2 deletions src/datumaro/components/annotations/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from datumaro.components.abstracts import IMergerContext
from datumaro.components.abstracts.merger import IMatcherContext
from datumaro.components.annotation import Annotation
from datumaro.components.annotation import Annotation, Points
from datumaro.util.annotation_util import (
OKS,
approximate_line,
Expand Down Expand Up @@ -371,5 +371,7 @@ def match_annotations(self, sources):

@attrs
class RotatedBboxMatcher(ShapeMatcher):
sigma: Optional[list] = attrib(default=None)

def distance(self, a, b):
return OKS(a, b, sigma=self.sigma)
return OKS(Points(a.as_polygon()), Points(b.as_polygon()), sigma=self.sigma)
5 changes: 3 additions & 2 deletions src/datumaro/components/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def __init__(self):
AnnotationType.points,
},
TaskType.detection_rotated: {
AnnotationType.label,
AnnotationType.polygon,
AnnotationType.rotated_bbox,
},
TaskType.detection_3d: {AnnotationType.label, AnnotationType.cuboid_3d},
TaskType.segmentation_semantic: {
Expand All @@ -53,6 +52,7 @@ def __init__(self):
TaskType.segmentation_instance: {
AnnotationType.label,
AnnotationType.bbox,
AnnotationType.rotated_bbox,
AnnotationType.ellipse,
AnnotationType.polygon,
AnnotationType.points,
Expand All @@ -65,6 +65,7 @@ def __init__(self):
TaskType.mixed: {
AnnotationType.label,
AnnotationType.bbox,
AnnotationType.rotated_bbox,
AnnotationType.cuboid_3d,
AnnotationType.ellipse,
AnnotationType.polygon,
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
PointsCategories,
Polygon,
PolyLine,
RotatedBbox,
)
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME, DatasetItem
Expand Down Expand Up @@ -222,6 +223,7 @@ def test_can_match_shapes(self):
Points([5, 6], label=0, group=1),
Points([6, 8], label=1, group=1),
PolyLine([1, 1, 2, 1, 3, 1]),
RotatedBbox(4, 5, 2, 4, 20, label=2),
],
),
],
Expand Down Expand Up @@ -251,6 +253,7 @@ def test_can_match_shapes(self):
Points([5.5, 6.5], label=0, group=2),
Points([6, 8], label=1, group=2),
PolyLine([1, 1.5, 2, 1.5]),
RotatedBbox(2, 4, 2, 4, 10, label=1),
],
),
],
Expand Down Expand Up @@ -280,6 +283,7 @@ def test_can_match_shapes(self):
Bbox(3, 6, 2, 3, label=2, z_order=4, group=3),
Points([4.5, 5.5], label=0, group=3),
PolyLine([1, 1.25, 3, 1, 4, 2]),
RotatedBbox(2, 4, 2, 4, 10, label=1),
],
),
],
Expand Down Expand Up @@ -313,6 +317,8 @@ def test_can_match_shapes(self):
Points([5, 6], label=0, group=1),
Points([6, 8], label=1, group=1),
PolyLine([1, 1.25, 3, 1, 4, 2]),
RotatedBbox(4, 5, 2, 4, 20, label=2),
RotatedBbox(2, 4, 2, 4, 10, label=1),
],
),
],
Expand All @@ -330,11 +336,21 @@ def test_can_match_shapes(self):
sources={2},
ann=source0.get("1").annotations[5],
),
NoMatchingAnnError(
item_id=("1", DEFAULT_SUBSET_NAME),
sources={0},
ann=source1.get("1").annotations[6],
),
NoMatchingAnnError(
item_id=("1", DEFAULT_SUBSET_NAME),
sources={1, 2},
ann=source0.get("1").annotations[0],
),
NoMatchingAnnError(
item_id=("1", DEFAULT_SUBSET_NAME),
sources={1, 2},
ann=source0.get("1").annotations[7],
),
],
sorted(
(e for e in merger.errors if isinstance(e, NoMatchingAnnError)),
Expand Down

0 comments on commit f2791c4

Please sign in to comment.