Skip to content

Commit

Permalink
Fix get_item for Chained Tasks in Classification (#3931)
Browse files Browse the repository at this point in the history
* Fix Task Chain

* Add multi-label case as well

* Add multi-label case as well2

* Add H-label case
  • Loading branch information
harimkang authored Sep 5, 2024
1 parent 706f99b commit 53a7d9a
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 4 deletions.
33 changes: 30 additions & 3 deletions src/otx/core/data/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
img = item.media_as(Image)
img_data, img_shape = self._get_img_data_and_shape(img)

label_anns = [ann for ann in item.annotations if isinstance(ann, Label)]
label_anns = []
for ann in item.annotations:
if isinstance(ann, Label):
label_anns.append(ann)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
if label not in label_anns:
label_anns.append(label)
if len(label_anns) > 1:
msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}"
raise ValueError(msg)
Expand Down Expand Up @@ -71,7 +80,16 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
ignored_labels: list[int] = [] # This should be assigned from item
img_data, img_shape = self._get_img_data_and_shape(img)

label_anns = [ann for ann in item.annotations if isinstance(ann, Label)]
label_anns = []
for ann in item.annotations:
if isinstance(ann, Label):
label_anns.append(ann)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
if label not in label_anns:
label_anns.append(label)
labels = torch.as_tensor([ann.label for ann in label_anns])

entity = MultilabelClsDataEntity(
Expand Down Expand Up @@ -179,7 +197,16 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
ignored_labels: list[int] = [] # This should be assigned from item
img_data, img_shape = self._get_img_data_and_shape(img)

label_anns = [ann for ann in item.annotations if isinstance(ann, Label)]
label_anns = []
for ann in item.annotations:
if isinstance(ann, Label):
label_anns.append(ann)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
if label not in label_anns:
label_anns.append(label)
hlabel_labels = self._convert_label_to_hlabel_format(label_anns, ignored_labels)

entity = HlabelClsDataEntity(
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/core/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,37 @@ def fxt_dm_item(request, tmpdir) -> DatasetItem:
)


@pytest.fixture(params=["bytes", "file"])
def fxt_dm_item_bbox_only(request, tmpdir) -> DatasetItem:
    """Build a train-subset ``DatasetItem`` annotated with bounding boxes only.

    The fixture is parametrized over how the image media is backed:
    ``"bytes"`` encodes the image to PNG in memory, ``"file"`` writes it as a
    PNG under ``tmpdir`` and loads it from that path.

    Returns:
        A ``DatasetItem`` with id ``"item"`` holding three ``Bbox`` annotations
        that all carry label 0 and no ``Label`` annotation.

    Raises:
        ValueError: If ``request.param`` is neither ``"bytes"`` nor ``"file"``.
    """
    # 10x10 image whose channels hold the constants 0 (B), 1 (G) and 2 (R).
    pixels = np.zeros(shape=(10, 10, 3), dtype=np.uint8)
    for channel in range(3):
        pixels[:, :, channel] = channel

    backing = request.param
    if backing == "file":
        png_path = str(Path(tmpdir) / f"{uuid.uuid4()!s}.png")
        cv2.imwrite(png_path, pixels)
        media = Image.from_file(png_path)
    elif backing == "bytes":
        _, encoded = cv2.imencode(".png", pixels)
        media = Image.from_bytes(encoded.tobytes())
        # In-memory media is explicitly given an empty path.
        media.path = ""
    else:
        raise ValueError(backing)

    # Three unit-sized boxes at different corners, all sharing the same label.
    top_left_corners = ((0, 0), (1, 0), (1, 1))
    boxes = [Bbox(x=left, y=top, w=1, h=1, label=0) for left, top in top_left_corners]

    return DatasetItem(
        id="item",
        subset="train",
        media=media,
        annotations=boxes,
    )


@pytest.fixture()
def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> MagicMock:
mock_dm_subset = mocker.MagicMock(spec=DmDataset)
Expand All @@ -105,6 +136,15 @@ def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> Magic
return mock_dm_subset


@pytest.fixture()
def fxt_mock_det_dm_subset(mocker: MockerFixture, fxt_dm_item_bbox_only: DatasetItem) -> MagicMock:
    """Provide a ``DmDataset``-spec'd mock subset serving the bbox-only item.

    The mock reports a length of 1, returns ``fxt_dm_item_bbox_only`` for any
    index, and answers ``categories()[...]`` with label categories built from
    ``_LABEL_NAMES``.
    """
    det_subset = mocker.MagicMock(spec=DmDataset)
    det_subset.__len__.return_value = 1
    det_subset.__getitem__.return_value = fxt_dm_item_bbox_only

    label_categories = LabelCategories.from_iterable(_LABEL_NAMES)
    det_subset.categories().__getitem__.return_value = label_categories
    return det_subset


@pytest.fixture(
params=[
(OTXHlabelClsDataset, HlabelClsDataEntity, {}),
Expand Down
90 changes: 89 additions & 1 deletion tests/unit/core/data/dataset/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,65 @@

from unittest.mock import MagicMock

from otx.core.data.dataset.classification import OTXHlabelClsDataset
from otx.core.data.dataset.classification import (
HLabelInfo,
OTXHlabelClsDataset,
OTXMulticlassClsDataset,
OTXMultilabelClsDataset,
)
from otx.core.data.entity.classification import HlabelClsDataEntity, MulticlassClsDataEntity, MultilabelClsDataEntity


class TestOTXMulticlassClsDataset:
    """Item-fetching checks for ``OTXMulticlassClsDataset``."""

    def test_get_item(self, fxt_mock_dm_subset) -> None:
        """Indexing a dataset built on the regular mocked subset yields a multi-class entity."""
        multiclass_dataset = OTXMulticlassClsDataset(
            dm_subset=fxt_mock_dm_subset,
            transforms=[lambda x: x],  # identity transform: keep the entity untouched
            mem_cache_img_max_size=None,
            max_refetch=3,
        )
        first_entity = multiclass_dataset[0]
        assert isinstance(first_entity, MulticlassClsDataEntity)

    def test_get_item_from_bbox_dataset(self, fxt_mock_det_dm_subset) -> None:
        """A subset carrying only bounding boxes must still yield a multi-class entity."""
        # Chained-task case: the mocked item holds bounding boxes rather than
        # Label annotations, and item fetching is expected to succeed anyway.
        multiclass_dataset = OTXMulticlassClsDataset(
            dm_subset=fxt_mock_det_dm_subset,
            transforms=[lambda x: x],  # identity transform: keep the entity untouched
            mem_cache_img_max_size=None,
            max_refetch=3,
        )
        first_entity = multiclass_dataset[0]
        assert isinstance(first_entity, MulticlassClsDataEntity)


class TestOTXMultilabelClsDataset:
    """Item-fetching checks for ``OTXMultilabelClsDataset``."""

    def test_get_item(self, fxt_mock_dm_subset) -> None:
        """Indexing a dataset built on the regular mocked subset yields a multi-label entity."""
        multilabel_dataset = OTXMultilabelClsDataset(
            dm_subset=fxt_mock_dm_subset,
            transforms=[lambda x: x],  # identity transform: keep the entity untouched
            mem_cache_img_max_size=None,
            max_refetch=3,
        )
        first_entity = multilabel_dataset[0]
        assert isinstance(first_entity, MultilabelClsDataEntity)

    def test_get_item_from_bbox_dataset(self, fxt_mock_det_dm_subset) -> None:
        """A subset carrying only bounding boxes must still yield a multi-label entity."""
        # Chained-task case: the mocked item holds bounding boxes rather than
        # Label annotations, and item fetching is expected to succeed anyway.
        multilabel_dataset = OTXMultilabelClsDataset(
            dm_subset=fxt_mock_det_dm_subset,
            transforms=[lambda x: x],  # identity transform: keep the entity untouched
            mem_cache_img_max_size=None,
            max_refetch=3,
        )
        first_entity = multilabel_dataset[0]
        assert isinstance(first_entity, MultilabelClsDataEntity)


class TestOTXHlabelClsDataset:
Expand All @@ -20,3 +78,33 @@ def test_add_ancestors(self, fxt_hlabel_dataset_subset):
# Added the ancestor
adjusted_anns = hlabel_dataset.dm_subset.get(id=0, subset="train").annotations
assert len(adjusted_anns) == 2

def test_get_item(self, mocker, fxt_mock_dm_subset, fxt_mock_hlabelinfo) -> None:
    """Indexing a dataset built on the regular mocked subset yields an H-label entity."""
    # Replace HLabelInfo.from_dm_label_groups so it returns the prepared
    # fixture instead of running against the mocked subset.
    mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo)

    hlabel_dataset = OTXHlabelClsDataset(
        dm_subset=fxt_mock_dm_subset,
        transforms=[lambda x: x],  # identity transform: keep the entity untouched
        mem_cache_img_max_size=None,
        max_refetch=3,
    )
    first_entity = hlabel_dataset[0]
    assert isinstance(first_entity, HlabelClsDataEntity)

def test_get_item_from_bbox_dataset(self, mocker, fxt_mock_det_dm_subset, fxt_mock_hlabelinfo) -> None:
    """A subset carrying only bounding boxes must still yield an H-label entity."""
    # Replace HLabelInfo.from_dm_label_groups so it returns the prepared
    # fixture instead of running against the mocked subset.
    mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo)

    # Chained-task case: the mocked item holds bounding boxes rather than
    # Label annotations, and item fetching is expected to succeed anyway.
    hlabel_dataset = OTXHlabelClsDataset(
        dm_subset=fxt_mock_det_dm_subset,
        transforms=[lambda x: x],  # identity transform: keep the entity untouched
        mem_cache_img_max_size=None,
        max_refetch=3,
    )
    first_entity = hlabel_dataset[0]
    assert isinstance(first_entity, HlabelClsDataEntity)

0 comments on commit 53a7d9a

Please sign in to comment.