Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Label Info handling #4127

Merged
merged 19 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/otx/algo/classification/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from otx.algo.classification.backbones import OTXMobileNetV3
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelNonLinearClsHead,
SemiSLLinearClsHead,
Expand Down Expand Up @@ -313,14 +313,12 @@ def _build_model(self, head_config: dict) -> nn.Module:

copied_head_config = copy(head_config)
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
in_channels = 960 if self.mode == "large" else 576

return HLabelClassifier(
backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=960,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=in_channels),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
Expand Down
9 changes: 3 additions & 6 deletions src/otx/algo/classification/timm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
Expand Down Expand Up @@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
Expand Down
9 changes: 3 additions & 6 deletions src/otx/algo/classification/torchvision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from otx.algo.classification.backbones.torchvision import TorchvisionBackbone, TVModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.core.data.entity.classification import (
Expand Down Expand Up @@ -315,11 +315,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained)
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.in_features,
**head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.in_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
Expand Down
8 changes: 2 additions & 6 deletions src/otx/algo/classification/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
MultiLabelLinearClsHead,
SemiSLVisionTransformerClsHead,
VisionTransformerClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.utils import get_classification_layers
from otx.algo.explain.explain_algo import ViTReciproCAM, feature_vector_fn
from otx.algo.utils.support_otx_v1 import OTXv1Helper
Expand Down Expand Up @@ -466,11 +466,7 @@ def _build_model(self, head_config: dict) -> nn.Module:
return HLabelClassifier(
backbone=vit_backbone,
neck=None,
head=HierarchicalCBAMClsHead(
in_channels=vit_backbone.embed_dim,
step_size=1,
**head_config,
),
head=HierarchicalLinearClsHead(**head_config, in_channels=vit_backbone.embed_dim),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
init_cfg=init_cfg,
Expand Down
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
) -> None:
super().__init__(
dm_subset,
Expand Down
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
) -> None:
self.task_type = task_type
super().__init__(
Expand Down
6 changes: 5 additions & 1 deletion src/otx/core/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
sovrasov marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
self.dm_subset = dm_subset
self.transforms = transforms
Expand All @@ -92,8 +93,11 @@
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.to_tv_image = to_tv_image
self.data_format = data_format

if self.dm_subset.categories():
if self.dm_subset.categories() and data_format == "arrow":
self.label_info = LabelInfo.from_dm_label_groups_arrow(self.dm_subset.categories()[AnnotationType.label])

Check warning on line 99 in src/otx/core/data/dataset/base.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/base.py#L99

Added line #L99 was not covered by tests
elif self.dm_subset.categories():
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
else:
self.label_info = NullLabelInfo()
Expand Down
30 changes: 20 additions & 10 deletions src/otx/core/data/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,15 @@
self.dm_categories = self.dm_subset.categories()[AnnotationType.label]

# Hlabel classification used HLabelInfo to insert the HLabelData.
self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)
if self.data_format == "arrow":
# arrow format stores label IDs as names, have to deal with that here
self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories)

Check warning on line 133 in src/otx/core/data/dataset/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/classification.py#L133

Added line #L133 was not covered by tests
else:
self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)

self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names))
self.id_to_name_mapping[""] = ""

if self.label_info.num_multiclass_heads == 0:
msg = "The number of multiclass heads should be larger than 0."
raise ValueError(msg)
Expand All @@ -149,14 +157,16 @@
"""

def _label_idx_to_name(idx: int) -> str:
return self.label_info.label_names[idx]
return self.dm_categories[idx].name

def _label_name_to_idx(name: str) -> int:
indices = [idx for idx, val in enumerate(self.label_info.label_names) if val == name]
return indices[0]

def _get_label_group_idx(label_name: str) -> int:
if isinstance(self.label_info, HLabelInfo):
if self.data_format == "arrow":
return self.label_info.class_to_group_idx[self.id_to_name_mapping[label_name]][0]

Check warning on line 169 in src/otx/core/data/dataset/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/classification.py#L169

Added line #L169 was not covered by tests
return self.label_info.class_to_group_idx[label_name][0]
msg = f"self.label_info should have HLabelInfo type, got {type(self.label_info)}"
raise ValueError(msg)
Expand Down Expand Up @@ -256,18 +266,18 @@
class_indices[i] = -1

for ann in label_anns:
ann_name = self.dm_categories.items[ann.label].name
ann_parent = self.dm_categories.items[ann.label].parent
if self.data_format == "arrow":
# skips unknown labels for instance, the empty one
if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping:
continue
ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name]

Check warning on line 273 in src/otx/core/data/dataset/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/classification.py#L271-L273

Added lines #L271 - L273 were not covered by tests
else:
ann_name = self.dm_categories.items[ann.label].name
group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name]
(parent_group_idx, parent_in_group_idx) = (
self.label_info.class_to_group_idx[ann_parent] if ann_parent else (None, None)
)

if group_idx < num_multiclass_heads:
class_indices[group_idx] = in_group_idx
if parent_group_idx is not None and parent_in_group_idx is not None:
class_indices[parent_group_idx] = parent_in_group_idx
elif not ignored_labels or ann.label not in ignored_labels:
elif ann.label not in ignored_labels:

Check warning on line 280 in src/otx/core/data/dataset/classification.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/data/dataset/classification.py#L280

Added line #L280 was not covered by tests
class_indices[num_multiclass_heads + in_group_idx] = 1
else:
class_indices[num_multiclass_heads + in_group_idx] = -1
Expand Down
4 changes: 3 additions & 1 deletion src/otx/core/data/dataset/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def __init__(
self.dm_subset = self._get_single_bbox_dataset(dm_subset)

if self.dm_subset.categories():
kp_labels = self.dm_subset.categories()[AnnotationType.points][0].labels
self.label_info = LabelInfo(
label_names=self.dm_subset.categories()[AnnotationType.points][0].labels,
label_names=kp_labels,
label_groups=[],
label_ids=[str(i) for i in range(len(kp_labels))],
)
else:
self.label_info = NullLabelInfo()
Expand Down
2 changes: 2 additions & 0 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
stack_images: bool = True,
to_tv_image: bool = True,
ignore_index: int = 255,
data_format: str = "",
) -> None:
super().__init__(
dm_subset,
Expand All @@ -187,6 +188,7 @@ def __init__(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
ignore_index=ignore_index,
label_ids=self.label_info.label_ids,
)
self.ignore_index = ignore_index

Expand Down
2 changes: 2 additions & 0 deletions src/otx/core/data/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def create( # noqa: PLR0911
dm_subset: DmDataset,
cfg_subset: SubsetConfig,
mem_cache_handler: MemCacheHandlerBase,
data_format: str,
mem_cache_img_max_size: tuple[int, int] | None = None,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
Expand All @@ -85,6 +86,7 @@ def create( # noqa: PLR0911
common_kwargs = {
"dm_subset": dm_subset,
"transforms": transforms,
"data_format": data_format,
"mem_cache_handler": mem_cache_handler,
"mem_cache_img_max_size": mem_cache_img_max_size,
"image_color_channel": image_color_channel,
Expand Down
10 changes: 3 additions & 7 deletions src/otx/core/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,6 @@ def __init__( # noqa: PLR0913
self.subsets: dict[str, OTXDataset] = {}
self.save_hyperparameters(ignore=["input_size"])

# TODO (Jaeguk): This is workaround for a bug in Datumaro.
# These lines should be removed after next datumaro release.
# https://github.com/openvinotoolkit/datumaro/pull/1223/files
from datumaro.plugins.data_formats.video import VIDEO_EXTENSIONS

VIDEO_EXTENSIONS.append(".mp4")

dataset = DmDataset.import_from(self.data_root, format=self.data_format)
if self.task != "H_LABEL_CLS":
dataset = pre_filtering(
Expand Down Expand Up @@ -193,6 +186,7 @@ def __init__( # noqa: PLR0913
dm_subset=dm_subset.as_dataset(),
cfg_subset=config_mapping[name],
mem_cache_handler=mem_cache_handler,
data_format=self.data_format,
mem_cache_img_max_size=mem_cache_img_max_size,
image_color_channel=image_color_channel,
stack_images=stack_images,
Expand Down Expand Up @@ -237,6 +231,7 @@ def __init__( # noqa: PLR0913
include_polygons=include_polygons,
ignore_index=ignore_index,
vpm_config=vpm_config,
data_format=self.data_format,
)
self.subsets[transform_key] = unlabeled_dataset
else:
Expand All @@ -251,6 +246,7 @@ def __init__( # noqa: PLR0913
include_polygons=include_polygons,
ignore_index=ignore_index,
vpm_config=vpm_config,
data_format=self.data_format,
)
self.subsets[name] = unlabeled_dataset

Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/data/pre_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def remove_unused_labels(dataset: DmDataset, data_format: str, ignore_index: int
used_labels = [0, *used_labels]
if data_format == "common_semantic_segmentation_with_subset_dirs" and len(original_categories) < len(used_labels):
msg = (
"There are labeles mismatch in dataset categories and actuall categories comes from semantic masks."
"There are labels mismatch in dataset categories and actual categories comes from semantic masks."
"Please, check `dataset_meta.json` file."
)
raise ValueError(msg)
Expand Down
8 changes: 6 additions & 2 deletions src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,11 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
if isinstance(label_info, int):
return LabelInfo.from_num_classes(num_classes=label_info)
if isinstance(label_info, Sequence) and all(isinstance(name, str) for name in label_info):
return LabelInfo(label_names=label_info, label_groups=[label_info])
return LabelInfo(
label_names=label_info,
label_groups=[label_info],
label_ids=[str(i) for i in range(len(label_info))],
)
if isinstance(label_info, LabelInfo):
return label_info

Expand Down Expand Up @@ -1113,7 +1117,7 @@ def _create_label_info_from_ov_ir(self) -> LabelInfo:
)

logger.warning(msg)
return LabelInfo(label_names=label_names, label_groups=[label_names])
return LabelInfo(label_names=label_names, label_groups=[label_names], label_ids=[])

msg = "Cannot construct LabelInfo from OpenVINO IR. Please check this model is trained by OTX."
raise ValueError(msg)
Expand Down
6 changes: 5 additions & 1 deletion src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,11 @@
if isinstance(label_info, int):
return SegLabelInfo.from_num_classes(num_classes=label_info)
if isinstance(label_info, Sequence) and all(isinstance(name, str) for name in label_info):
return SegLabelInfo(label_names=label_info, label_groups=[label_info])
return SegLabelInfo(

Check warning on line 248 in src/otx/core/model/segmentation.py

View check run for this annotation

Codecov / codecov/patch

src/otx/core/model/segmentation.py#L248

Added line #L248 was not covered by tests
label_names=label_info,
label_groups=[label_info],
label_ids=[str(i) for i in range(len(label_info))],
)
if isinstance(label_info, SegLabelInfo):
return label_info

Expand Down
6 changes: 4 additions & 2 deletions src/otx/core/types/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,15 @@ def to_metadata(self) -> dict[tuple[str, str], str]:
all_label_ids = "None "
for lbl in self.label_info.label_names:
all_labels += lbl.replace(" ", "_") + " "
all_label_ids += lbl.replace(" ", "_") + " "
for lbl_id in self.label_info.label_ids:
all_label_ids += lbl_id + " "
else:
all_labels = ""
all_label_ids = ""
for lbl in self.label_info.label_names:
all_labels += lbl.replace(" ", "_") + " "
all_label_ids += lbl.replace(" ", "_") + " "
for lbl_id in self.label_info.label_ids:
all_label_ids += lbl_id + " "

metadata = {
# Common
Expand Down
Loading
Loading