diff --git a/src/otx/core/data/dataset/classification.py b/src/otx/core/data/dataset/classification.py index 57170da967b..c5048dd7987 100644 --- a/src/otx/core/data/dataset/classification.py +++ b/src/otx/core/data/dataset/classification.py @@ -34,7 +34,16 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None: img = item.media_as(Image) img_data, img_shape = self._get_img_data_and_shape(img) - label_anns = [ann for ann in item.annotations if isinstance(ann, Label)] + label_anns = [] + for ann in item.annotations: + if isinstance(ann, Label): + label_anns.append(ann) + else: + # If the annotation is not Label, it should be converted to Label. + # For Chained Task: Detection (Bbox) -> Classification (Label) + label = Label(label=ann.label) + if label not in label_anns: + label_anns.append(label) if len(label_anns) > 1: msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}" raise ValueError(msg) @@ -71,7 +80,16 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None: ignored_labels: list[int] = [] # This should be assigned form item img_data, img_shape = self._get_img_data_and_shape(img) - label_anns = [ann for ann in item.annotations if isinstance(ann, Label)] + label_anns = [] + for ann in item.annotations: + if isinstance(ann, Label): + label_anns.append(ann) + else: + # If the annotation is not Label, it should be converted to Label. + # For Chained Task: Detection (Bbox) -> Classification (Label) + label = Label(label=ann.label) + if label not in label_anns: + label_anns.append(label) labels = torch.as_tensor([ann.label for ann in label_anns]) entity = MultilabelClsDataEntity( @@ -179,7 +197,16 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None: ignored_labels: list[int] = [] # This should be assigned form item img_data, img_shape = self._get_img_data_and_shape(img) - label_anns = [ann for ann in item.annotations if isinstance(ann, Label)] + label_anns = [] + for ann in item.annotations: + if isinstance(ann, Label): + label_anns.append(ann) + else: + # If the annotation is not Label, it should be converted to Label. + # For Chained Task: Detection (Bbox) -> Classification (Label) + label = Label(label=ann.label) + if label not in label_anns: + label_anns.append(label) hlabel_labels = self._convert_label_to_hlabel_format(label_anns, ignored_labels) entity = HlabelClsDataEntity( diff --git a/tests/unit/core/data/conftest.py b/tests/unit/core/data/conftest.py index 78ee2b1fb8c..6c3ca23dde6 100644 --- a/tests/unit/core/data/conftest.py +++ b/tests/unit/core/data/conftest.py @@ -96,6 +96,37 @@ def fxt_dm_item(request, tmpdir) -> DatasetItem: ) +@pytest.fixture(params=["bytes", "file"]) +def fxt_dm_item_bbox_only(request, tmpdir) -> DatasetItem: + np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8) + np_img[:, :, 0] = 0 # Set 0 for B channel + np_img[:, :, 1] = 1 # Set 1 for G channel + np_img[:, :, 2] = 2 # Set 2 for R channel + + if request.param == "bytes": + _, np_bytes = cv2.imencode(".png", np_img) + media = Image.from_bytes(np_bytes.tobytes()) + media.path = "" + elif request.param == "file": + fname = str(uuid.uuid4()) + fpath = str(Path(tmpdir) / f"{fname}.png") + cv2.imwrite(fpath, np_img) + media = Image.from_file(fpath) + else: + raise ValueError(request.param) + + return DatasetItem( + id="item", + subset="train", + media=media, + annotations=[ + Bbox(x=0, y=0, w=1, h=1, label=0), + Bbox(x=1, y=0, w=1, h=1, label=0), + Bbox(x=1, y=1, w=1, h=1, label=0), + ], + ) + + @pytest.fixture() def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> MagicMock: mock_dm_subset = mocker.MagicMock(spec=DmDataset) @@ -105,6 +136,15 @@ def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> Magic return mock_dm_subset +@pytest.fixture() +def fxt_mock_det_dm_subset(mocker: MockerFixture, fxt_dm_item_bbox_only: DatasetItem) -> MagicMock: + mock_dm_subset = mocker.MagicMock(spec=DmDataset) + mock_dm_subset.__getitem__.return_value = fxt_dm_item_bbox_only + mock_dm_subset.__len__.return_value = 1 + mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_LABEL_NAMES) + return mock_dm_subset + + @pytest.fixture( params=[ (OTXHlabelClsDataset, HlabelClsDataEntity, {}), diff --git a/tests/unit/core/data/dataset/test_classification.py b/tests/unit/core/data/dataset/test_classification.py index 8bef7ffa4e2..bf2da750d9a 100644 --- a/tests/unit/core/data/dataset/test_classification.py +++ b/tests/unit/core/data/dataset/test_classification.py @@ -5,7 +5,65 @@ from unittest.mock import MagicMock -from otx.core.data.dataset.classification import OTXHlabelClsDataset +from otx.core.data.dataset.classification import ( + HLabelInfo, + OTXHlabelClsDataset, + OTXMulticlassClsDataset, + OTXMultilabelClsDataset, +) +from otx.core.data.entity.classification import HlabelClsDataEntity, MulticlassClsDataEntity, MultilabelClsDataEntity + + +class TestOTXMulticlassClsDataset: + def test_get_item( + self, + fxt_mock_dm_subset, + ) -> None: + dataset = OTXMulticlassClsDataset( + dm_subset=fxt_mock_dm_subset, + transforms=[lambda x: x], + mem_cache_img_max_size=None, + max_refetch=3, + ) + assert isinstance(dataset[0], MulticlassClsDataEntity) + + def test_get_item_from_bbox_dataset( + self, + fxt_mock_det_dm_subset, + ) -> None: + dataset = OTXMulticlassClsDataset( + dm_subset=fxt_mock_det_dm_subset, + transforms=[lambda x: x], + mem_cache_img_max_size=None, + max_refetch=3, + ) + assert isinstance(dataset[0], MulticlassClsDataEntity) + + +class TestOTXMultilabelClsDataset: + def test_get_item( + self, + fxt_mock_dm_subset, + ) -> None: + dataset = OTXMultilabelClsDataset( + dm_subset=fxt_mock_dm_subset, + transforms=[lambda x: x], + mem_cache_img_max_size=None, + max_refetch=3, + ) + assert isinstance(dataset[0], MultilabelClsDataEntity) + + def test_get_item_from_bbox_dataset( + self, + fxt_mock_det_dm_subset, + ) -> None: + dataset = OTXMultilabelClsDataset( + dm_subset=fxt_mock_det_dm_subset, + transforms=[lambda x: x], + mem_cache_img_max_size=None, + max_refetch=3, + ) + assert isinstance(dataset[0], MultilabelClsDataEntity) class TestOTXHlabelClsDataset: @@ -20,3 +78,33 @@ def test_add_ancestors(self, fxt_hlabel_dataset_subset): # Added the ancestor adjusted_anns = hlabel_dataset.dm_subset.get(id=0, subset="train").annotations assert len(adjusted_anns) == 2 + + def test_get_item( + self, + mocker, + fxt_mock_dm_subset, + fxt_mock_hlabelinfo, + ) -> None: + mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo) + dataset = OTXHlabelClsDataset( + dm_subset=fxt_mock_dm_subset, + transforms=[lambda x: x], + mem_cache_img_max_size=None, + max_refetch=3, + ) + assert isinstance(dataset[0], HlabelClsDataEntity) + + def test_get_item_from_bbox_dataset( + self, + mocker, + fxt_mock_det_dm_subset, + fxt_mock_hlabelinfo, + ) -> None: + mocker.patch.object(HLabelInfo, "from_dm_label_groups", return_value=fxt_mock_hlabelinfo) + dataset = OTXHlabelClsDataset( + dm_subset=fxt_mock_det_dm_subset, + transforms=[lambda x: x], + mem_cache_img_max_size=None, + max_refetch=3, + ) + assert isinstance(dataset[0], HlabelClsDataEntity)