Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix task chain for Det -> Cls / Seg #4105

Merged
merged 10 commits into from
Nov 8, 2024
21 changes: 19 additions & 2 deletions src/otx/core/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from abc import abstractmethod
from collections.abc import Iterable
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Generic, Iterator, List, Union
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, List, Union

import cv2
import numpy as np
Expand Down Expand Up @@ -92,6 +92,7 @@
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.to_tv_image = to_tv_image

if self.dm_subset.categories():
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
else:
Expand Down Expand Up @@ -141,7 +142,11 @@
msg = f"Reach the maximum refetch number ({self.max_refetch})"
raise RuntimeError(msg)

def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, int]]:
def _get_img_data_and_shape(
kprokofi marked this conversation as resolved.
Show resolved Hide resolved
self,
img: Image,
roi: dict[str, Any] | None = None,
) -> tuple[np.ndarray, tuple[int, int]]:
key = img.path if isinstance(img, ImageFromFile) else id(img)

if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None:
Expand All @@ -158,6 +163,18 @@
msg = "Cannot get image data"
raise RuntimeError(msg)

if roi:
# extract ROI from image
shape = roi["shape"]
h, w = img_data.shape[:2]
x1, y1, x2, y2 = (

Check warning on line 170 in src/otx/core/data/dataset/base.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/base.py#L168-L170

Added lines #L168 - L170 were not covered by tests
kprokofi marked this conversation as resolved.
Show resolved Hide resolved
np.trunc(shape["x1"] * w),
np.trunc(shape["y1"] * h),
np.ceil(shape["x2"] * w),
np.ceil(shape["y2"] * h),
)
img_data = img_data[int(y1) : int(y2), int(x1) : int(x2)]

Check warning on line 176 in src/otx/core/data/dataset/base.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/base.py#L176

Added line #L176 was not covered by tests

img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8))

return img_data, img_data.shape[:2]
Expand Down
24 changes: 12 additions & 12 deletions src/otx/core/data/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@
def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
img_data, img_shape = self._get_img_data_and_shape(img)
roi = item.attributes.get("roi", None)
img_data, img_shape = self._get_img_data_and_shape(img, roi)
if roi:
# extract labels from ROI
labels_ids = [

Check warning on line 39 in src/otx/core/data/dataset/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/classification.py#L39

Added line #L39 was not covered by tests
label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION"
]
label_anns = [self.label_info.label_names.index(label_id) for label_id in labels_ids]

Check warning on line 42 in src/otx/core/data/dataset/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/classification.py#L42

Added line #L42 was not covered by tests
else:
# extract labels from annotations
label_anns = [ann.label for ann in item.annotations if isinstance(ann, Label)]

label_anns = []
for ann in item.annotations:
if isinstance(ann, Label):
label_anns.append(ann)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
if label not in label_anns:
label_anns.append(label)
if len(label_anns) > 1:
msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}"
raise ValueError(msg)
Expand All @@ -56,7 +56,7 @@
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
),
labels=torch.as_tensor([ann.label for ann in label_anns]),
labels=torch.as_tensor(label_anns),
)

return self._apply_transforms(entity)
Expand Down
Loading