Skip to content

Commit

Permalink
Remove background label from RT Info for segmentation task (#4011)
Browse files Browse the repository at this point in the history
* remove background from rt_info

* provide another solution

* fix unit test
  • Loading branch information
kprokofi authored Oct 11, 2024
1 parent 7744c89 commit 7040faf
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _extract_class_mask(item: DatasetItem, img_shape: tuple[int, int], ignore_in
msg = "It is not currently support an ignore index which is more than 255."
raise ValueError(msg, ignore_index)

# fill mask with background label if we have Polygon/Ellipse annotations
# fill mask with background label if we have Polygon/Ellipse/Bbox annotations
fill_value = 0 if isinstance(item.annotations[0], (Ellipse, Polygon, Bbox, RotatedBbox)) else ignore_index
class_mask = np.full(shape=img_shape[:2], fill_value=fill_value, dtype=np.uint8)

Expand Down Expand Up @@ -179,9 +179,9 @@ def __init__(
to_tv_image,
)

if self.has_polygons and "background" not in [label_name.lower() for label_name in self.label_info.label_names]:
if self.has_polygons:
# insert background class at index 0 since polygons represent only objects
self.label_info.label_names.insert(0, "background")
self.label_info.label_names.insert(0, "otx_background_lbl")

self.label_info = SegLabelInfo(
label_names=self.label_info.label_names,
Expand Down
5 changes: 0 additions & 5 deletions src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,11 +1095,6 @@ def model_adapter_parameters(self) -> dict:
def _set_label_info(self, label_info: LabelInfoTypes) -> None:
"""Set this model label information."""
new_label_info = self._dispatch_label_info(label_info)

if self._label_info != new_label_info:
msg = "OVModel strictly does not allow overwrite label_info if they are different each other."
raise ValueError(msg)

self._label_info = new_label_info

def _create_label_info_from_ov_ir(self) -> LabelInfo:
Expand Down
9 changes: 9 additions & 0 deletions src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import copy
import json
from abc import abstractmethod
from collections.abc import Sequence
Expand Down Expand Up @@ -165,12 +166,20 @@ def _customize_outputs(
@property
def _export_parameters(self) -> TaskLevelExportParameters:
"""Defines parameters required to export a particular model implementation."""
if self.label_info.label_names[0] == "otx_background_lbl":
# remove otx background label for export
modified_label_info = copy.deepcopy(self.label_info)
modified_label_info.label_names.pop(0)
else:
modified_label_info = self.label_info

return super()._export_parameters.wrap(
model_type="Segmentation",
task_type="segmentation",
return_soft_prediction=True,
soft_threshold=0.5,
blur_strength=-1,
label_info=modified_label_info,
)

@property
Expand Down
25 changes: 17 additions & 8 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import copy
import csv
import inspect
import logging
Expand Down Expand Up @@ -370,14 +371,22 @@ def test(
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)

if model.label_info != self.datamodule.label_info:
msg = (
"To launch a test pipeline, the label information should be same "
"between the training and testing datasets. "
"Please check whether you use the same dataset: "
f"model.label_info={model.label_info}, "
f"datamodule.label_info={self.datamodule.label_info}"
)
raise ValueError(msg)
if (
self.task == "SEMANTIC_SEGMENTATION"
and "otx_background_lbl" in self.datamodule.label_info.label_names
and (len(self.datamodule.label_info.label_names) - len(model.label_info.label_names) == 1)
):
# workaround for background label
model.label_info = copy.deepcopy(self.datamodule.label_info)
else:
msg = (
"To launch a test pipeline, the label information should be same "
"between the training and testing datasets. "
"Please check whether you use the same dataset: "
f"model.label_info={model.label_info}, "
f"datamodule.label_info={self.datamodule.label_info}"
)
raise ValueError(msg)

self._build_trainer(**kwargs)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/core/data/dataset/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_get_item(
max_refetch=3,
)
assert isinstance(dataset[0], SegDataEntity)
assert "background" in [label_name.lower() for label_name in dataset.label_info.label_names]
assert "otx_background_lbl" in [label_name.lower() for label_name in dataset.label_info.label_names]

def test_get_item_from_bbox_dataset(
self,
Expand All @@ -33,4 +33,4 @@ def test_get_item_from_bbox_dataset(
)
assert isinstance(dataset[0], SegDataEntity)
# OTXSegmentationDataset should add background when getting a dataset which includes only bbox annotations
assert "background" in [label_name.lower() for label_name in dataset.label_info.label_names]
assert "otx_background_lbl" in [label_name.lower() for label_name in dataset.label_info.label_names]

0 comments on commit 7040faf

Please sign in to comment.