diff --git a/otx/algorithms/segmentation/adapters/mmseg/__init__.py b/otx/algorithms/segmentation/adapters/mmseg/__init__.py index ff3a9626547..9c368fd2f2c 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/__init__.py +++ b/otx/algorithms/segmentation/adapters/mmseg/__init__.py @@ -17,7 +17,6 @@ from .datasets import MPASegDataset from .models import ( - ClassIncrEncoderDecoder, ConstantScalarScheduler, CrossEntropyLossWithIgnore, CustomFCNHead, @@ -57,6 +56,5 @@ "DetConB", "CrossEntropyLossWithIgnore", "SupConDetConB", - "ClassIncrEncoderDecoder", "MeanTeacherSegmentor", ] diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/__init__.py b/otx/algorithms/segmentation/adapters/mmseg/models/__init__.py index 5a3d18db740..0475eff3aff 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/models/__init__.py +++ b/otx/algorithms/segmentation/adapters/mmseg/models/__init__.py @@ -24,7 +24,6 @@ StepScalarScheduler, ) from .segmentors import ( - ClassIncrEncoderDecoder, DetConB, MeanTeacherSegmentor, SupConDetConB, @@ -43,7 +42,6 @@ "DetConB", "CrossEntropyLossWithIgnore", "SupConDetConB", - "ClassIncrEncoderDecoder", "MeanTeacherSegmentor", "DetConHead", ] diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/heads/mixin.py b/otx/algorithms/segmentation/adapters/mmseg/models/heads/mixin.py deleted file mode 100644 index ff8a5eef05c..00000000000 --- a/otx/algorithms/segmentation/adapters/mmseg/models/heads/mixin.py +++ /dev/null @@ -1,348 +0,0 @@ -"""Modules for aggregator and loss mix.""" -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# -import torch -import torch.nn.functional as F -from mmcv.runner import force_fp32 -from mmseg.core import add_prefix -from mmseg.models.losses import accuracy -from mmseg.ops import resize -from torch import nn - -from otx.algorithms.segmentation.adapters.mmseg.models.utils import ( - AngularPWConv, - IterativeAggregator, - LossEqualizer, - normalize, -) -from otx.algorithms.segmentation.adapters.mmseg.utils import ( - get_valid_label_mask_per_batch, -) - -# pylint: disable=abstract-method, unused-argument, keyword-arg-before-vararg - - -class SegmentOutNormMixin(nn.Module): - """SegmentOutNormMixin.""" - - def __init__(self, *args, enable_out_seg=True, enable_out_norm=False, **kwargs): - super().__init__(*args, **kwargs) - - self.enable_out_seg = enable_out_seg - self.enable_out_norm = enable_out_norm - - if enable_out_seg: - if enable_out_norm: - self.conv_seg = AngularPWConv(self.channels, self.out_channels, clip_output=True) - else: - self.conv_seg = None - - def cls_seg(self, feat): - """Classify each pixel.""" - if self.dropout is not None: - feat = self.dropout(feat) - if self.enable_out_norm: - feat = normalize(feat, dim=1, p=2) - if self.conv_seg is not None: - return self.conv_seg(feat) - return feat - - -class AggregatorMixin(nn.Module): - """A class for creating an aggregator.""" - - def __init__( - self, - *args, - enable_aggregator=False, - aggregator_min_channels=None, - aggregator_merge_norm=None, - aggregator_use_concat=False, - **kwargs, - ): - - in_channels = kwargs.get("in_channels") - in_index = kwargs.get("in_index") - norm_cfg = kwargs.get("norm_cfg") - conv_cfg = kwargs.get("conv_cfg") - input_transform = kwargs.get("input_transform") - - aggregator = None - if enable_aggregator: - assert isinstance(in_channels, (tuple, list)) - assert len(in_channels) > 1 - - aggregator = IterativeAggregator( - in_channels=in_channels, - min_channels=aggregator_min_channels, - conv_cfg=conv_cfg, - 
norm_cfg=norm_cfg, - merge_norm=aggregator_merge_norm, - use_concat=aggregator_use_concat, - ) - - aggregator_min_channels = aggregator_min_channels if aggregator_min_channels is not None else 0 - # change arguments temporarily - kwargs["in_channels"] = max(in_channels[0], aggregator_min_channels) - kwargs["input_transform"] = None - if in_index is not None: - kwargs["in_index"] = in_index[0] - - super().__init__(*args, **kwargs) - - self.aggregator = aggregator - # re-define variables - self.in_channels = in_channels - self.input_transform = input_transform - self.in_index = in_index - - def _transform_inputs(self, inputs): - inputs = super()._transform_inputs(inputs) - if self.aggregator is not None: - inputs = self.aggregator(inputs)[0] - return inputs - - -class MixLossMixin(nn.Module): - """Loss mixing module.""" - - @staticmethod - def _mix_loss(logits, target, ignore_index=255): - num_samples = logits.size(0) - assert num_samples % 2 == 0 - - with torch.no_grad(): - probs = F.softmax(logits, dim=1) - probs_a, probs_b = torch.split(probs, num_samples // 2) - mean_probs = 0.5 * (probs_a + probs_b) - trg_probs = torch.cat([mean_probs, mean_probs], dim=0) - - log_probs = torch.log_softmax(logits, dim=1) - losses = torch.sum(trg_probs * log_probs, dim=1).neg() - - valid_mask = target != ignore_index - valid_losses = torch.where(valid_mask, losses, torch.zeros_like(losses)) - - return valid_losses.mean() - - @force_fp32(apply_to=("seg_logit",)) - def losses(self, seg_logit, seg_label, train_cfg, *args, **kwargs): - """Loss computing.""" - loss = super().losses(seg_logit, seg_label, train_cfg, *args, **kwargs) - if train_cfg.get("mix_loss", None) and train_cfg.mix_loss.get("enable", False): - mix_loss = self._mix_loss(seg_logit, seg_label, ignore_index=self.ignore_index) - - mix_loss_weight = train_cfg.mix_loss.get("weight", 1.0) - loss["loss_mix"] = mix_loss_weight * mix_loss - - return loss - - -class PixelWeightsMixin(nn.Module): - """PixelWeightsMixin.""" - - def __init__(self, enable_loss_equalizer=False, loss_target="gt_semantic_seg", *args, **kwargs): - super().__init__(*args, **kwargs) - - self.enable_loss_equalizer = enable_loss_equalizer - self.loss_target = loss_target - - self.loss_equalizer = None - if enable_loss_equalizer: - self.loss_equalizer = LossEqualizer() - - self.forward_output = None - - @property - def loss_target_name(self): - """Return loss target name.""" - return self.loss_target - - @property - def last_scale(self): - """Return the last scale.""" - if not isinstance(self.loss_decode, nn.ModuleList): - losses_decode = [self.loss_decode] - else: - losses_decode = self.loss_decode - - num_losses = len(losses_decode) - if num_losses <= 0: - return 1.0 - - loss_module = losses_decode[0] - if not hasattr(loss_module, "last_scale"): - return 1.0 - - return loss_module.last_scale - - def set_step_params(self, init_iter, epoch_size): - """Set step parameters.""" - if not isinstance(self.loss_decode, nn.ModuleList): - losses_decode = [self.loss_decode] - else: - losses_decode = self.loss_decode - - for loss_module in losses_decode: - if hasattr(loss_module, "set_step_params"): - loss_module.set_step_params(init_iter, epoch_size) - - def forward_train( - self, - inputs, - img_metas, - gt_semantic_seg, - train_cfg, - pixel_weights=None, - return_logits=False, - ): - """Forward function for training. - - Args: - inputs (list[Tensor]): List of multi-level img features. 
- img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - gt_semantic_seg (Tensor): Semantic segmentation masks - used if the architecture supports semantic segmentation task. - train_cfg (dict): The training config. - pixel_weights (Tensor): Pixels weights. - return_logits (bool): Flag to retun the logit with losses. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - - seg_logits = self.forward(inputs) - losses = self.losses(seg_logits, gt_semantic_seg, train_cfg, pixel_weights) - - if return_logits: - logits = self.forward_output if self.forward_output is not None else seg_logits - return losses, logits - return losses - - @force_fp32(apply_to=("seg_logit",)) - def losses(self, seg_logit, seg_label, train_cfg, pixel_weights=None): - """Compute segmentation loss.""" - - loss = dict() - - seg_logit = resize( - input=seg_logit, - size=seg_label.shape[2:], - mode="bilinear", - align_corners=self.align_corners, - ) - - seg_label = seg_label.squeeze(1) - - if not isinstance(self.loss_decode, nn.ModuleList): - losses_decode = [self.loss_decode] - else: - losses_decode = self.loss_decode - - out_losses = dict() - for loss_idx, loss_module in enumerate(losses_decode): - loss_value, loss_meta = loss_module(seg_logit, seg_label, pixel_weights=pixel_weights) - - loss_name = loss_module.name + f"-{loss_idx}" - out_losses[loss_name] = loss_value - loss.update(add_prefix(loss_meta, loss_name)) - - if self.enable_loss_equalizer and len(losses_decode) > 1: - out_losses = self.loss_equalizer.reweight(out_losses) - - for loss_name, loss_value in out_losses.items(): - loss[loss_name] = loss_value - - loss["loss_seg"] = sum(out_losses.values()) - loss["acc_seg"] = accuracy(seg_logit, seg_label) - - return loss - - -class PixelWeightsMixin2(PixelWeightsMixin): - """Pixel weight mixin class.""" - - def forward_train( - self, - inputs, - img_metas, - gt_semantic_seg, - train_cfg, - pixel_weights=None, - return_logits=False, - ): - """Forward function for training. - - Args: - inputs (list[Tensor]): List of multi-level img features. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', 'img_norm_cfg', - and 'ignored_labels'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - gt_semantic_seg (Tensor): Semantic segmentation masks - used if the architecture supports semantic segmentation task. - train_cfg (dict): The training config. - pixel_weights (Tensor): Pixels weights. - return_logits (bool): Flag to retun the logit with losses. 
- - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - seg_logits = self(inputs) - valid_label_mask = get_valid_label_mask_per_batch(img_metas, self.num_classes) - losses = self.losses( - seg_logits, gt_semantic_seg, train_cfg, valid_label_mask=valid_label_mask, pixel_weights=pixel_weights - ) - - if return_logits: - logits = self.forward_output if self.forward_output is not None else seg_logits - return losses, logits - return losses - - @force_fp32(apply_to=("seg_logit",)) - def losses( - self, seg_logit, seg_label, train_cfg, valid_label_mask, pixel_weights=None - ): # pylint: disable=arguments-renamed - """Compute segmentation loss.""" - - loss = dict() - - seg_logit = resize( - input=seg_logit, - size=seg_label.shape[2:], - mode="bilinear", - align_corners=self.align_corners, - ) - - seg_label = seg_label.squeeze(1) - - if not isinstance(self.loss_decode, nn.ModuleList): - losses_decode = [self.loss_decode] - else: - losses_decode = self.loss_decode - - out_losses = dict() - for loss_idx, loss_module in enumerate(losses_decode): - loss_value, loss_meta = loss_module(seg_logit, seg_label, valid_label_mask, pixel_weights=pixel_weights) - - loss_name = loss_module.name + f"-{loss_idx}" - out_losses[loss_name] = loss_value - loss.update(add_prefix(loss_meta, loss_name)) - - if self.enable_loss_equalizer and len(losses_decode) > 1: - out_losses = self.loss_equalizer.reweight(out_losses) - - for loss_name, loss_value in out_losses.items(): - loss[loss_name] = loss_value - - loss["loss_seg"] = sum(out_losses.values()) - loss["acc_seg"] = accuracy(seg_logit, seg_label) - - return loss diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/losses/base_pixel_loss.py b/otx/algorithms/segmentation/adapters/mmseg/models/losses/base_pixel_loss.py deleted file mode 100644 index 1d0d69e27c8..00000000000 --- a/otx/algorithms/segmentation/adapters/mmseg/models/losses/base_pixel_loss.py +++ /dev/null @@ -1,163 +0,0 @@ -"""Base pixel loss.""" -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -from abc import abstractmethod - -import torch -import torch.nn.functional as F -from mmseg.models.losses.utils import weight_reduce_loss - -from otx.algorithms.segmentation.adapters.mmseg.utils.builder import ( - build_scalar_scheduler, -) - -from .base_weighted_loss import BaseWeightedLoss - - -def entropy(p, dim=1, keepdim=False): - """Calculates the entropy.""" - return -torch.where(p > 0.0, p * p.log(), torch.zeros_like(p)).sum(dim=dim, keepdim=keepdim) - - -class BasePixelLoss(BaseWeightedLoss): - """Base pixel loss.""" - - def __init__(self, scale_cfg=None, pr_product=False, conf_penalty_weight=None, border_reweighting=False, **kwargs): - super().__init__(**kwargs) - - self._enable_pr_product = pr_product - self._border_reweighting = border_reweighting - - self._reg_weight_scheduler = build_scalar_scheduler(conf_penalty_weight) - self._scale_scheduler = build_scalar_scheduler(scale_cfg, default_value=1.0) - - self._last_scale = 0.0 - self._last_reg_weight = 0.0 - - @property - def last_scale(self): - """Return last_scale.""" - return self._last_scale - - @property - def last_reg_weight(self): - """Return last_reg_weight.""" - return self._last_reg_weight - - @property - def with_regularization(self): - """Check regularization use.""" - return self._reg_weight_scheduler is not None - - @property - def with_pr_product(self): - """Check pr_product.""" - return self._enable_pr_product - - @property - def with_border_reweighting(self): - """Check border 
reweighting.""" - return self._border_reweighting - - @staticmethod - def _pr_product(prod): - alpha = torch.sqrt(1.0 - prod.pow(2.0)) - out_prod = alpha.detach() * prod + prod.detach() * (1.0 - alpha) - - return out_prod - - @staticmethod - def _regularization(logits, scale, weight): - probs = F.softmax(scale * logits, dim=1) - entropy_values = entropy(probs, dim=1) - out_values = -weight * entropy_values - - return out_values - - @staticmethod - def _sparsity(values, valid_mask): - with torch.no_grad(): - valid_values = values[valid_mask] - sparsity = 1.0 - valid_values.count_nonzero() / max(1.0, valid_mask.sum()) - return sparsity.item() - - @staticmethod - def _pred_stat(output, labels, valid_mask, window_size=5, min_group_ratio=0.6): - assert window_size > 1 - assert 0.0 < min_group_ratio < 1.0 - - min_group_size = int(min_group_ratio * window_size * window_size) - assert min_group_size > 0 - - with torch.no_grad(): - predictions = torch.argmax(output, dim=1) - invalid_pred_mask = valid_mask & (predictions != labels) - - group_sizes = F.avg_pool2d( - invalid_pred_mask.float(), - kernel_size=window_size, - stride=1, - padding=(window_size - 1) // 2, - divisor_override=1, - ) - large_group_mask = invalid_pred_mask & (group_sizes >= min_group_size) - - num_target = torch.sum(large_group_mask, dim=(1, 2)) - num_total = torch.sum(invalid_pred_mask, dim=(1, 2)) - out_ratio = torch.mean(num_target / num_total.clamp_min(1)) - - return out_ratio.item() - - def _forward( - self, output, labels, avg_factor=None, pixel_weights=None, reduction_override=None - ): # pylint: disable=too-many-locals - assert reduction_override in (None, "none", "mean", "sum") - reduction = reduction_override if reduction_override else self.reduction - - self._last_scale = self._scale_scheduler(self.iter, self.epoch_size) - - if self.with_pr_product: - output = self._pr_product(output) - - num_classes = output.size(1) - valid_labels = torch.clamp(labels, 0, num_classes - 1) - valid_mask = labels != self.ignore_index - - losses, updated_output = self._calculate(output, valid_labels, self._last_scale) - - if self.with_regularization: - self._last_reg_weight = self._reg_weight_scheduler(self.iter, self.epoch_size) - regularization = self._regularization(updated_output, self._last_scale, self._last_reg_weight) - losses = torch.clamp_min(losses + regularization, 0.0) - - if self.with_border_reweighting: - assert pixel_weights is not None - losses = pixel_weights.squeeze(1) * losses - - losses = torch.where(valid_mask, losses, torch.zeros_like(losses)) - raw_sparsity = self._sparsity(losses, valid_mask) - invalid_ratio = self._pred_stat(output, labels, valid_mask) - - weight, weight_sparsity = None, 0.0 - if self.sampler is not None: - weight = self.sampler(losses, output, valid_labels, valid_mask) - weight_sparsity = self._sparsity(weight, valid_mask) - - loss = weight_reduce_loss(losses, weight=weight, reduction=reduction, avg_factor=avg_factor) - - meta = dict( - weight=self.last_loss_weight, - reg_weight=self.last_reg_weight, - scale=self.last_scale, - raw_sparsity=raw_sparsity, - weight_sparsity=weight_sparsity, - invalid_ratio=invalid_ratio, - ) - - return loss, meta - - @abstractmethod - def _calculate(self, output, labels, scale): - pass diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/losses/base_weighted_loss.py b/otx/algorithms/segmentation/adapters/mmseg/models/losses/base_weighted_loss.py deleted file mode 100644 index 2487eed40fe..00000000000 --- 
a/otx/algorithms/segmentation/adapters/mmseg/models/losses/base_weighted_loss.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Base weighted loss function for semantic segmentation.""" -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -from abc import ABCMeta, abstractmethod - -import torch -from mmseg.core import build_pixel_sampler -from scipy.special import erfinv # pylint: disable=no-name-in-module -from torch import nn - -from otx.algorithms.segmentation.adapters.mmseg.utils.builder import ( - build_scalar_scheduler, -) - - -# pylint: disable=too-many-instance-attributes, unused-argument -class BaseWeightedLoss(nn.Module, metaclass=ABCMeta): - """Base class for loss. - - All subclass should overwrite the ``_forward()`` method which returns the - normal loss without loss weights. - - Args: - loss_weight (float or dict): Factor scalar multiplied on the loss. - Default: 1.0. - """ - - def __init__( - self, - reduction="mean", - loss_weight=1.0, - ignore_index=255, - sampler=None, - loss_jitter_prob=None, - loss_jitter_momentum=0.1, - **kwargs - ): - super().__init__() - - self.reduction = reduction - self.ignore_index = ignore_index - - self.sampler = None - if sampler is not None: - self.sampler = build_pixel_sampler(sampler, ignore_index=ignore_index) - - self._smooth_loss = None - self._jitter_sigma_factor = None - self._loss_jitter_momentum = loss_jitter_momentum - assert 0.0 < self._loss_jitter_momentum < 1.0 - if loss_jitter_prob is not None: - assert 0.0 < loss_jitter_prob < 1.0 - self._jitter_sigma_factor = 1.0 / ((2.0**0.5) * erfinv(1.0 - 2.0 * loss_jitter_prob)) - - self._loss_weight_scheduler = build_scalar_scheduler(loss_weight, default_value=1.0) - - self._iter = 0 - self._last_loss_weight = 0 - self._epoch_size = 1 - - def set_step_params(self, init_iter, epoch_size): - """Set step parameters.""" - assert init_iter >= 0 - assert epoch_size > 0 - - self._iter = init_iter - self._epoch_size = epoch_size - - @property - def with_loss_jitter(self): - """Check loss jitter.""" - return self._jitter_sigma_factor is not None - - @property - def iter(self): - """Return iteration.""" - return self._iter - - @property - def epoch_size(self): - """Return epoch size.""" - return self._epoch_size - - @property - def last_loss_weight(self): - """Return last loss weight.""" - return self._last_loss_weight - - @abstractmethod - def _forward(self, *args, **kwargs): - pass - - def forward(self, *args, **kwargs): - """Defines the computation performed at every call. - - Args: - *args: The positional arguments for the corresponding - loss. - **kwargs: The keyword arguments for the corresponding - loss. - - Returns: - torch.Tensor: The calculated loss. 
- """ - - loss, meta = self._forward(*args, **kwargs) - # make sure meta data are tensor as well for aggregation - # when parsing loss in sgementator - for key, val in meta.items(): - meta[key] = torch.tensor(val, dtype=loss.dtype, device=loss.device) - - if self.with_loss_jitter and loss.numel() == 1: - if self._smooth_loss is None: - self._smooth_loss = loss.item() - else: - self._smooth_loss = ( - 1.0 - self._loss_jitter_momentum - ) * self._smooth_loss + self._loss_jitter_momentum * loss.item() - - jitter_sigma = self._jitter_sigma_factor * abs(self._smooth_loss) - jitter_point = torch.normal(0.0, jitter_sigma, [], device=loss.device, dtype=loss.dtype) - loss = (loss - jitter_point).abs() + jitter_point - - self._last_loss_weight = self._loss_weight_scheduler(self.iter, self.epoch_size) - out_loss = self._last_loss_weight * loss - - self._iter += 1 - - return out_loss, meta diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/losses/otx_pixel_base.py b/otx/algorithms/segmentation/adapters/mmseg/models/losses/otx_pixel_base.py deleted file mode 100644 index b9f68a5ffc4..00000000000 --- a/otx/algorithms/segmentation/adapters/mmseg/models/losses/otx_pixel_base.py +++ /dev/null @@ -1,76 +0,0 @@ -"""OTX pixel loss.""" -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import torch -from mmseg.models.losses.utils import weight_reduce_loss - -from .base_pixel_loss import BasePixelLoss - -# pylint: disable=too-many-function-args, too-many-locals - - -class OTXBasePixelLoss(BasePixelLoss): # pylint: disable=abstract-method - """OTXBasePixelLoss.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _forward( - self, - output, - labels, - valid_label_mask, - avg_factor=None, - pixel_weights=None, - reduction_override=None, - ): # pylint: disable=arguments-renamed - assert reduction_override in (None, "none", "mean", "sum") - reduction = reduction_override if reduction_override else self.reduction - - self._last_scale = self._scale_scheduler(self.iter, self.epoch_size) - - if self.with_pr_product: - output = self._pr_product(output) - - import numpy as np - - _labels = labels.cpu().detach().numpy() - _labels[np.where((_labels == self.ignore_index))] = 0 - num_classes = output.size(1) - valid_labels = torch.clamp(labels, 0, num_classes - 1) - valid_mask = labels != self.ignore_index - - losses, updated_output = self._calculate(output, _labels, valid_label_mask, self._last_scale) - - if self.with_regularization: - self._last_reg_weight = self._reg_weight_scheduler(self.iter, self.epoch_size) - regularization = self._regularization(updated_output, self._last_scale, self._last_reg_weight) - losses = torch.clamp_min(losses + regularization, 0.0) - - if self.with_border_reweighting: - assert pixel_weights is not None - losses = pixel_weights.squeeze(1) * losses - - losses = torch.where(valid_mask, losses, torch.zeros_like(losses)) - raw_sparsity = self._sparsity(losses, valid_mask) - invalid_ratio = self._pred_stat(output, labels, valid_mask) - - weight, weight_sparsity = None, 0.0 - if self.sampler is not None: - weight = self.sampler.sample(output, valid_labels, losses, valid_mask) - weight_sparsity = self._sparsity(weight, valid_mask) - - loss = weight_reduce_loss(losses, weight=weight, reduction=reduction, avg_factor=avg_factor) - - meta = dict( - weight=self.last_loss_weight, - reg_weight=self.last_reg_weight, - scale=self.last_scale, - raw_sparsity=raw_sparsity, - weight_sparsity=weight_sparsity, - invalid_ratio=invalid_ratio, - ) - 
- return loss, meta diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/__init__.py b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/__init__.py index d953b628f81..5e3af6f2a9e 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/__init__.py +++ b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/__init__.py @@ -14,8 +14,7 @@ # See the License for the specific language governing permissions # and limitations under the License. -from .class_incr_encoder_decoder import ClassIncrEncoderDecoder from .detcon import DetConB, SupConDetConB from .mean_teacher_segmentor import MeanTeacherSegmentor -__all__ = ["DetConB", "SupConDetConB", "ClassIncrEncoderDecoder", "MeanTeacherSegmentor"] +__all__ = ["DetConB", "SupConDetConB", "MeanTeacherSegmentor"] diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/class_incr_encoder_decoder.py b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/class_incr_encoder_decoder.py deleted file mode 100644 index f225d875152..00000000000 --- a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/class_incr_encoder_decoder.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Encoder-decoder for incremental learning.""" - -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import functools - -import torch -from mmseg.models import SEGMENTORS -from mmseg.utils import get_root_logger - -from otx.algorithms.common.utils.task_adapt import map_class_names - -from .mixin import PixelWeightsMixin -from .otx_encoder_decoder import OTXEncoderDecoder - - -@SEGMENTORS.register_module() -class ClassIncrEncoderDecoder(PixelWeightsMixin, OTXEncoderDecoder): - """Encoder-decoder for incremental learning.""" - - def __init__(self, *args, task_adapt=None, **kwargs): - super().__init__(*args, **kwargs) - - # Hook for class-sensitive weight loading - assert task_adapt is not None, "When using task_adapt, task_adapt must be set." - - self._register_load_state_dict_pre_hook( - functools.partial( - self.load_state_dict_pre_hook, - self, # model - task_adapt["dst_classes"], # model_classes - task_adapt["src_classes"], # chkpt_classes - ) - ) - - def forward_train( - self, - img, - img_metas, - gt_semantic_seg, - aux_img=None, - **kwargs, - ): # pylint: disable=arguments-renamed - """Forward function for training. - - Args: - img (Tensor): Input images. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - gt_semantic_seg (Tensor): Semantic segmentation masks - used if the architecture supports semantic segmentation task. - aux_img (Tensor): Auxiliary images. - **kwargs (Any): Addition keyword arguments. 
- - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - if aux_img is not None: - mix_loss_enabled = False - mix_loss_cfg = self.train_cfg.get("mix_loss", None) - if mix_loss_cfg is not None: - mix_loss_enabled = mix_loss_cfg.get("enable", False) - if mix_loss_enabled: - self.train_cfg.mix_loss.enable = mix_loss_enabled - - if self.train_cfg.mix_loss.enable: - img = torch.cat([img, aux_img], dim=0) - gt_semantic_seg = torch.cat([gt_semantic_seg, gt_semantic_seg], dim=0) - - return super().forward_train(img, img_metas, gt_semantic_seg, **kwargs) - - @staticmethod - def load_state_dict_pre_hook( - model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs - ): # pylint: disable=too-many-locals, unused-argument - """Modify input state_dict according to class name matching before weight loading.""" - logger = get_root_logger("INFO") - logger.info(f"----------------- ClassIncrEncoderDecoder.load_state_dict_pre_hook() called w/ prefix: {prefix}") - - # Dst to src mapping index - model_classes = list(model_classes) - chkpt_classes = list(chkpt_classes) - model2chkpt = map_class_names(model_classes, chkpt_classes) - logger.info(f"{chkpt_classes} -> {model_classes} ({model2chkpt})") - - model_dict = model.state_dict() - param_names = [ - "decode_head.conv_seg.weight", - "decode_head.conv_seg.bias", - ] - for model_name in param_names: - chkpt_name = prefix + model_name - if model_name not in model_dict or chkpt_name not in chkpt_dict: - logger.info(f"Skipping weight copy: {chkpt_name}") - continue - - # Mix weights - model_param = model_dict[model_name].clone() - chkpt_param = chkpt_dict[chkpt_name] - for model_key, c in enumerate(model2chkpt): - if c >= 0: - model_param[model_key].copy_(chkpt_param[c]) - - # Replace checkpoint weight by mixed weights - chkpt_dict[chkpt_name] = model_param diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/detcon.py b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/detcon.py index c083537d08a..b6e41ae1a1a 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/detcon.py +++ b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/detcon.py @@ -26,7 +26,7 @@ from otx.algorithms.common.utils.logger import get_logger -from .class_incr_encoder_decoder import OTXEncoderDecoder +from .otx_encoder_decoder import OTXEncoderDecoder logger = get_logger() diff --git a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/mixin.py b/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/mixin.py deleted file mode 100644 index 2f7aa3b3fa7..00000000000 --- a/otx/algorithms/segmentation/adapters/mmseg/models/segmentors/mixin.py +++ /dev/null @@ -1,192 +0,0 @@ -"""Modules for decode and loss reweighting/mix.""" -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# -from mmseg.core import add_prefix -from mmseg.models.builder import build_loss -from mmseg.ops import resize -from torch import nn - -from otx.algorithms.segmentation.adapters.mmseg.models.utils import LossEqualizer - -# pylint: disable=too-many-locals - - -class PixelWeightsMixin: - """PixelWeightsMixin.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._init_train_components(self.train_cfg) - self.feature_maps = None - - def _init_train_components(self, train_cfg): - if train_cfg is None: - self.mutual_losses = None - self.loss_equalizer = None - return - - mutual_loss_configs = train_cfg.get("mutual_loss") - if mutual_loss_configs: - if 
isinstance(mutual_loss_configs, dict): - mutual_loss_configs = [mutual_loss_configs] - - self.mutual_losses = nn.ModuleList() - for mutual_loss_config in mutual_loss_configs: - self.mutual_losses.append(build_loss(mutual_loss_config)) - else: - self.mutual_losses = None - - loss_reweighting_config = train_cfg.get("loss_reweighting") - if loss_reweighting_config: - self.loss_equalizer = LossEqualizer(**loss_reweighting_config) - else: - self.loss_equalizer = None - - @staticmethod - def _get_argument_by_name(trg_name, arguments): - assert trg_name in arguments.keys() - return arguments[trg_name] - - def set_step_params(self, init_iter, epoch_size): - """Sets the step params for the current object's decode head.""" - self.decode_head.set_step_params(init_iter, epoch_size) - - if getattr(self, "auxiliary_head", None) is not None: - if isinstance(self.auxiliary_head, nn.ModuleList): - for aux_head in self.auxiliary_head: - aux_head.set_step_params(init_iter, epoch_size) - else: - self.auxiliary_head.set_step_params(init_iter, epoch_size) - - def _decode_head_forward_train(self, x, img_metas, pixel_weights=None, **kwargs): - """Run forward train in decode head.""" - trg_map = self._get_argument_by_name(self.decode_head.loss_target_name, kwargs) - loss_decode, logits_decode = self.decode_head.forward_train( - x, - img_metas, - trg_map, - train_cfg=self.train_cfg, - pixel_weights=pixel_weights, - return_logits=True, - ) - - scale = self.decode_head.last_scale - scaled_logits_decode = scale * logits_decode - - name_prefix = "decode" - - losses, meta = dict(), dict() - losses.update(add_prefix(loss_decode, name_prefix)) - meta[f"{name_prefix}_scaled_logits"] = scaled_logits_decode - - return losses, meta - - def _auxiliary_head_forward_train(self, x, img_metas, **kwargs): - - losses, meta = dict(), dict() - if isinstance(self.auxiliary_head, nn.ModuleList): - for idx, aux_head in enumerate(self.auxiliary_head): - trg_map = self._get_argument_by_name(aux_head.loss_target_name, kwargs) - loss_aux, logits_aux = aux_head.forward_train( - x, - img_metas, - trg_map, - train_cfg=self.train_cfg, - return_logits=True, - ) - - scale = aux_head.last_scale - scaled_logits_aux = scale * logits_aux - - name_prefix = f"aux_{idx}" - losses.update(add_prefix(loss_aux, name_prefix)) - meta[f"{name_prefix}_scaled_logits"] = scaled_logits_aux - else: - trg_map = self._get_argument_by_name(self.auxiliary_head.loss_target_name, kwargs) - loss_aux, logits_aux = self.auxiliary_head.forward_train( - x, - img_metas, - trg_map, - train_cfg=self.train_cfg, - return_logits=True, - ) - - scale = self.auxiliary_head.last_scale - scaled_logits_aux = scale * logits_aux - - name_prefix = "aux" - losses.update(add_prefix(loss_aux, name_prefix)) - meta[f"{name_prefix}_scaled_logits"] = scaled_logits_aux - - return losses, meta - - def forward_train(self, img, img_metas, gt_semantic_seg, pixel_weights=None, **kwargs): - """Forward function for training. - - Args: - img (Tensor): Input images. - img_metas (list[dict]): List of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys see - `mmseg/datasets/pipelines/formatting.py:Collect`. - gt_semantic_seg (Tensor): Semantic segmentation masks - used if the architecture supports semantic segmentation task. - pixel_weights (Tensor): Pixels weights. - **kwargs (Any): Addition keyword arguments. 
- - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - - losses = dict() - - features = self.extract_feat(img) - - loss_decode, meta_decode = self._decode_head_forward_train( - features, img_metas, pixel_weights, gt_semantic_seg=gt_semantic_seg, **kwargs - ) - losses.update(loss_decode) - - if self.with_auxiliary_head: - loss_aux, meta_aux = self._auxiliary_head_forward_train( - features, img_metas, gt_semantic_seg=gt_semantic_seg, **kwargs - ) - losses.update(loss_aux) - - if self.mutual_losses is not None and self.with_auxiliary_head: - meta = dict() - meta.update(meta_decode) - meta.update(meta_aux) - - out_mutual_losses = dict() - for mutual_loss_idx, mutual_loss in enumerate(self.mutual_losses): - logits_a = self._get_argument_by_name(mutual_loss.trg_a_name, meta) - logits_b = self._get_argument_by_name(mutual_loss.trg_b_name, meta) - - logits_a = resize( - input=logits_a, size=gt_semantic_seg.shape[2:], mode="bilinear", align_corners=self.align_corners - ) - logits_b = resize( - input=logits_b, size=gt_semantic_seg.shape[2:], mode="bilinear", align_corners=self.align_corners - ) - - mutual_labels = gt_semantic_seg.squeeze(1) - mutual_loss_value, mutual_loss_meta = mutual_loss(logits_a, logits_b, mutual_labels) - - mutual_loss_name = mutual_loss.name + f"-{mutual_loss_idx}" - out_mutual_losses[mutual_loss_name] = mutual_loss_value - losses[mutual_loss_name] = mutual_loss_value - losses.update(add_prefix(mutual_loss_meta, mutual_loss_name)) - - losses["loss_mutual"] = sum(out_mutual_losses.values()) - - if self.loss_equalizer is not None: - unweighted_losses = {loss_name: loss for loss_name, loss in losses.items() if "loss" in loss_name} - weighted_losses = self.loss_equalizer.reweight(unweighted_losses) - - for loss_name, loss_value in weighted_losses.items(): - losses[loss_name] = loss_value - - return losses diff --git a/tests/unit/algorithms/segmentation/adapters/mmseg/utils/test_config_utils.py b/tests/unit/algorithms/segmentation/adapters/mmseg/utils/test_config_utils.py index b5c1c2e7b5c..c0fd84683a0 100644 --- a/tests/unit/algorithms/segmentation/adapters/mmseg/utils/test_config_utils.py +++ b/tests/unit/algorithms/segmentation/adapters/mmseg/utils/test_config_utils.py @@ -21,7 +21,7 @@ def _create_dummy_config() -> Config: config: dict = dict( model=dict( - type="ClassIncrEncoderDecoder", + type="OTXEncoderDecoder", backbone=dict(), decode_head=dict( norm_cfg=dict(type="BN", requires_grad=True),
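Reviewer note: one user-visible effect of this change is that code which referenced the removed ClassIncrEncoderDecoder by its registered name must now point at OTXEncoderDecoder, as the test_config_utils.py update above shows. The sketch below is a minimal, hypothetical illustration of that migration, not part of this diff: the config field values are placeholders, and OTXEncoderDecoder is not claimed to be a drop-in equivalent, since the class-incremental checkpoint-remapping hook existed only in the removed class.

# Hypothetical migration sketch (illustrative only; not taken from this diff).
# Direct imports move to the remaining module, i.e. the absolute form of the
# relative import that detcon.py now uses.
from otx.algorithms.segmentation.adapters.mmseg.models.segmentors.otx_encoder_decoder import (
    OTXEncoderDecoder,
)

# Config-based construction: only the registered type name changes; the field
# values below are placeholders, not a real recipe. A call such as mmseg's
# build_segmentor(model_cfg) would resolve the name via the SEGMENTORS registry.
model_cfg = dict(
    type="OTXEncoderDecoder",  # was: type="ClassIncrEncoderDecoder"
    backbone=dict(),
    decode_head=dict(norm_cfg=dict(type="BN", requires_grad=True)),
)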