Commit

update h-label head
eunwoosh committed Aug 13, 2024
1 parent 900faaf commit 8e6f8f8
Showing 4 changed files with 46 additions and 13 deletions.
9 changes: 7 additions & 2 deletions src/otx/algo/classification/efficientnet.py
@@ -6,7 +6,8 @@

 from __future__ import annotations
 
-from copy import deepcopy
+from copy import copy, deepcopy
+from math import ceil
 from typing import TYPE_CHECKING, Literal
 
 from torch import Tensor, nn
@@ -269,14 +270,18 @@ def _build_model(self, head_config: dict) -> nn.Module:
             raise TypeError(self.label_info)
 
         backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
+
+        copied_head_config = copy(head_config)
+        copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
+
         return ImageClassifier(
             backbone=backbone,
             neck=nn.Identity(),
             head=HierarchicalCBAMClsHead(
                 in_channels=backbone.num_features,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
-                **head_config,
+                **copied_head_config,
             ),
             optimize_gap=False,
         )
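Note: the `step_size` injected here is just the spatial size of the backbone's final feature map; as the division by 32 suggests, both backbones are assumed to downsample by an overall stride of 32, with each side rounded up. A minimal sketch of the computation (the `feature_map_size` helper is illustrative, not part of the commit):

```python
from math import ceil

def feature_map_size(input_size: tuple[int, int], total_stride: int = 32) -> tuple[int, int]:
    """Spatial size of the final feature map for a backbone with the given total stride."""
    return ceil(input_size[0] / total_stride), ceil(input_size[1] / total_stride)

# The default 224x224 input keeps the head's old 7x7 step size, while a
# rectangular 448x224 input now yields (14, 7) instead of a hard-coded square.
assert feature_map_size((224, 224)) == (7, 7)
assert feature_map_size((448, 224)) == (14, 7)
```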
23 changes: 14 additions & 9 deletions src/otx/algo/classification/heads/hlabel_cls_head.py
@@ -419,7 +419,7 @@ class HierarchicalCBAMClsHead(HierarchicalClsHead):
         thr (float, optional): Predictions with scores under the thresholds are considered
             as negative. Defaults to 0.5.
         init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None.
-        step_size (int, optional): Step size value for HierarchicalCBAMClsHead, Defaults to 7.
+        step_size (int | tuple[int, int], optional): Step size value for HierarchicalCBAMClsHead, Defaults to 7.
     """
 
     def __init__(
@@ -435,7 +435,7 @@ def __init__(
         multilabel_loss: nn.Module | None = None,
         thr: float = 0.5,
         init_cfg: dict | None = None,
-        step_size: int = 7,
+        step_size: int | tuple[int, int] = 7,
         **kwargs,
     ):
         super().__init__(
@@ -452,19 +452,19 @@ def __init__(
             init_cfg=init_cfg,
             **kwargs,
         )
-        self.step_size = step_size
-        self.fc_superclass = nn.Linear(in_channels * step_size * step_size, num_multiclass_heads)
-        self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * step_size * step_size)
+        self.step_size = (step_size, step_size) if isinstance(step_size, int) else tuple(step_size)
+        self.fc_superclass = nn.Linear(in_channels * self.step_size[0] * self.step_size[1], num_multiclass_heads)
+        self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * self.step_size[0] * self.step_size[1])
         self.cbam = CBAM(in_channels)
-        self.fc_subclass = nn.Linear(in_channels * step_size * step_size, num_single_label_classes)
+        self.fc_subclass = nn.Linear(in_channels * self.step_size[0] * self.step_size[1], num_single_label_classes)
 
         self._init_layers()
 
     def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
         """The process before the final classification head."""
         if isinstance(feats, Sequence):
             feats = feats[-1]
-        return feats.view(feats.size(0), self.in_channels * self.step_size * self.step_size)
+        return feats.view(feats.size(0), self.in_channels * self.step_size[0] * self.step_size[1])
 
     def _init_layers(self) -> None:
         """Iniitialize weights of classification head."""
@@ -479,10 +479,15 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
         attention_weights = torch.sigmoid(self.attention_fc(out_superclass))
         attended_features = pre_logits * attention_weights
 
-        attended_features = attended_features.view(pre_logits.size(0), self.in_channels, self.step_size, self.step_size)
+        attended_features = attended_features.view(
+            pre_logits.size(0),
+            self.in_channels,
+            self.step_size[0],
+            self.step_size[1],
+        )
         attended_features = self.cbam(attended_features)
         attended_features = attended_features.view(
             pre_logits.size(0),
-            self.in_channels * self.step_size * self.step_size,
+            self.in_channels * self.step_size[0] * self.step_size[1],
         )
         return self.fc_subclass(attended_features)
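To see the shape flow with a rectangular grid: with `in_channels=64` and `step_size=(14, 7)`, the head flattens a `(B, 64, 14, 7)` feature map, derives attention weights from the superclass logits, re-folds to `(B, 64, 14, 7)` for CBAM, then flattens again for the subclass classifier. A minimal sketch of the same dataflow (an `nn.Identity()` stands in for the real `CBAM`, which is assumed to preserve shapes):

```python
import torch
from torch import nn

B, C, H, W = 8, 64, 14, 7  # batch, in_channels, step_size[0], step_size[1]
num_multiclass_heads, num_single_label_classes = 3, 12

fc_superclass = nn.Linear(C * H * W, num_multiclass_heads)
attention_fc = nn.Linear(num_multiclass_heads, C * H * W)
cbam = nn.Identity()  # stand-in for CBAM(C); shapes are unchanged either way
fc_subclass = nn.Linear(C * H * W, num_single_label_classes)

feats = torch.rand(B, C, H, W)
pre_logits = feats.view(B, C * H * W)        # flatten: (B, C*H*W)
out_superclass = fc_superclass(pre_logits)   # (B, num_multiclass_heads)
attention_weights = torch.sigmoid(attention_fc(out_superclass))
attended = (pre_logits * attention_weights).view(B, C, H, W)  # re-fold for CBAM
attended = cbam(attended).view(B, C * H * W)                  # flatten again
assert fc_subclass(attended).shape == (B, num_single_label_classes)
```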
8 changes: 6 additions & 2 deletions src/otx/algo/classification/mobilenet_v3.py
@@ -6,7 +6,8 @@

 from __future__ import annotations
 
-from copy import deepcopy
+from copy import copy, deepcopy
+from math import ceil
 from typing import TYPE_CHECKING, Any, Literal
 
 import torch
@@ -331,14 +332,17 @@ def _build_model(self, head_config: dict) -> nn.Module:
         if not isinstance(self.label_info, HLabelInfo):
             raise TypeError(self.label_info)
 
+        copied_head_config = copy(head_config)
+        copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
+
         return ImageClassifier(
             backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
             neck=nn.Identity(),
             head=HierarchicalCBAMClsHead(
                 in_channels=960,
                 multiclass_loss=nn.CrossEntropyLoss(),
                 multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
-                **head_config,
+                **copied_head_config,
             ),
             optimize_gap=False,
         )
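Both `_build_model` implementations shallow-copy `head_config` before writing `step_size` into it, so the mutation cannot leak into the shared config the caller passed in. A small sketch of the difference (the dict and helper names are illustrative only):

```python
from copy import copy

head_config = {"num_classes": 12}  # illustrative shared config

def build_without_copy(cfg: dict) -> None:
    cfg["step_size"] = (7, 7)  # mutates the caller's dict in place

def build_with_copy(cfg: dict) -> None:
    cfg = copy(cfg)  # shallow copy: top-level writes stay local
    cfg["step_size"] = (7, 7)

build_with_copy(head_config)
assert "step_size" not in head_config  # original left untouched

build_without_copy(head_config)
assert head_config["step_size"] == (7, 7)  # original polluted
```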
19 changes: 19 additions & 0 deletions tests/unit/algo/classification/heads/test_hlabel_cls_head.py
@@ -169,3 +169,22 @@ def test_pre_logits(self, fxt_hierarchical_cbam_cls_head) -> None:
         input_tensor = torch.rand((8, 64, 7, 7))
         pre_logits = fxt_hierarchical_cbam_cls_head.pre_logits(input_tensor)
         assert pre_logits.shape == (8, 64 * 7 * 7)
+
+    def test_pre_logits_tuple_step_size(self) -> None:
+        head_idx_to_logits_range = {"0": (0, 5), "1": (5, 10), "2": (10, 12)}
+        head = HierarchicalCBAMClsHead(
+            num_multiclass_heads=3,
+            num_multilabel_classes=0,
+            head_idx_to_logits_range=head_idx_to_logits_range,
+            num_single_label_classes=12,
+            empty_multiclass_head_indices=[],
+            in_channels=64,
+            num_classes=12,
+            multiclass_loss=CrossEntropyLoss(),
+            multilabel_loss=None,
+            step_size=(14, 7),
+        )
+
+        input_tensor = torch.rand((8, 64, 14, 7))
+        pre_logits = head.pre_logits(input_tensor)
+        assert pre_logits.shape == (8, 64 * 14 * 7)
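The new test stops at `pre_logits`; a quick end-to-end check would confirm the forward pass also handles the rectangular grid. A sketch reusing `head` from the test above, where the `(8, 12)` output shape is assumed to follow from `num_single_label_classes=12`:

```python
import torch

# Reusing `head` from test_pre_logits_tuple_step_size above:
# forward() should map a (8, 64, 14, 7) feature map to 12 subclass logits per sample.
logits = head(torch.rand((8, 64, 14, 7)))
assert logits.shape == (8, 12)
```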
