From 61e1d5c814be05c02fef9a5c6c0b629bbf2908c9 Mon Sep 17 00:00:00 2001
From: Jerry Jiarui XU
Date: Tue, 17 Nov 2020 00:14:03 -0800
Subject: [PATCH] [Enhancement] Support ignore_index for sigmoid BCE (#210)

* [Enhancement] Add args check for ignore_index

* Support ignore_index
---
 configs/_base_/models/fast_scnn.py          |  6 +--
 .../fast_scnn_4x8_80k_lr0.12_cityscapes.py  |  2 +-
 mmseg/models/decode_heads/decode_head.py    |  3 +-
 mmseg/models/losses/cross_entropy_loss.py   | 43 +++++++++++++------
 tests/test_models/test_losses.py            | 12 +++++-
 5 files changed, 48 insertions(+), 18 deletions(-)

diff --git a/configs/_base_/models/fast_scnn.py b/configs/_base_/models/fast_scnn.py
index 67ee0d39a6..06cd83979d 100644
--- a/configs/_base_/models/fast_scnn.py
+++ b/configs/_base_/models/fast_scnn.py
@@ -25,7 +25,7 @@
         norm_cfg=norm_cfg,
         align_corners=False,
         loss_decode=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.)),
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
     auxiliary_head=[
         dict(
             type='FCNHead',
@@ -38,7 +38,7 @@
             concat_input=False,
             align_corners=False,
             loss_decode=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
         dict(
             type='FCNHead',
             in_channels=64,
@@ -50,7 +50,7 @@
             concat_input=False,
             align_corners=False,
             loss_decode=dict(
-                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
     ])
 # model training and testing settings
diff --git a/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py b/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
index 53fcfc4203..3d9c999937 100644
--- a/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
+++ b/configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
@@ -4,7 +4,7 @@
 ]
 
 # Re-config the data sampler.
-data = dict(samples_per_gpu=8, workers_per_gpu=4)
+data = dict(samples_per_gpu=2, workers_per_gpu=4)
 
 # Re-config the optimizer.
 optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)
diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py
index 0f58c80e9b..86b9b63f43 100644
--- a/mmseg/models/decode_heads/decode_head.py
+++ b/mmseg/models/decode_heads/decode_head.py
@@ -35,7 +35,8 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
             Default: None.
         loss_decode (dict): Config of decode loss.
             Default: dict(type='CrossEntropyLoss').
-        ignore_index (int): The label index to be ignored. Default: 255
+        ignore_index (int | None): The label index to be ignored. When using
+            masked BCE loss, ignore_index should be set to None. Default: 255
         sampler (dict|None): The config of segmentation map sampler.
             Default: None.
         align_corners (bool): align_corners argument of F.interpolate.
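
With use_sigmoid=True, CrossEntropyLoss dispatches to the binary_cross_entropy
branch patched below, and extra keyword arguments such as ignore_index are
forwarded to it. A minimal usage sketch, assuming an environment with this
patch applied; the tensor shapes and values are illustrative only:

import torch
from mmseg.models import build_loss

# use_sigmoid=True selects the binary_cross_entropy branch of
# CrossEntropyLoss; ignore_index is passed through to that branch.
loss_fn = build_loss(
    dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))

pred = torch.randn(2, 19, 8, 8)           # [N, C, H, W] logits
label = torch.randint(0, 19, (2, 8, 8))   # [N, H, W] class indices
label[:, 0, 0] = 255                      # pixels to be ignored
loss = loss_fn(pred, label, ignore_index=255)
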
diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py
index dcd9f1c894..44798421aa 100644
--- a/mmseg/models/losses/cross_entropy_loss.py
+++ b/mmseg/models/losses/cross_entropy_loss.py
@@ -32,17 +32,25 @@ def cross_entropy(pred,
     return loss
 
 
-def _expand_onehot_labels(labels, label_weights, label_channels):
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
     """Expand onehot labels to match the size of prediction."""
-    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
-    inds = torch.nonzero(labels >= 1, as_tuple=False).squeeze()
-    if inds.numel() > 0:
-        bin_labels[inds, labels[inds] - 1] = 1
+    bin_labels = labels.new_zeros(target_shape)
+    valid_mask = (labels >= 0) & (labels != ignore_index)
+    inds = torch.nonzero(valid_mask, as_tuple=True)
+
+    if inds[0].numel() > 0:
+        if labels.dim() == 3:
+            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+        else:
+            bin_labels[inds[0], labels[valid_mask]] = 1
+
+    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
     if label_weights is None:
-        bin_label_weights = None
+        bin_label_weights = valid_mask
     else:
-        bin_label_weights = label_weights.view(-1, 1).expand(
-            label_weights.size(0), label_channels)
+        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+        bin_label_weights *= valid_mask
+
     return bin_labels, bin_label_weights
 
 
@@ -51,7 +59,8 @@ def binary_cross_entropy(pred,
                          weight=None,
                          reduction='mean',
                          avg_factor=None,
-                         class_weight=None):
+                         class_weight=None,
+                         ignore_index=255):
     """Calculate the binary CrossEntropy loss.
 
     Args:
@@ -63,18 +72,24 @@ def binary_cross_entropy(pred,
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
         class_weight (list[float], optional): The weight for each class.
+        ignore_index (int | None): The label index to be ignored. Default: 255
 
     Returns:
         torch.Tensor: The calculated loss
     """
     if pred.dim() != label.dim():
-        label, weight = _expand_onehot_labels(label, weight, pred.size(-1))
+        assert (pred.dim() == 2 and label.dim() == 1) or (
+                pred.dim() == 4 and label.dim() == 3), \
+            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
+            'H, W], label shape [N, H, W] are supported'
+        label, weight = _expand_onehot_labels(label, weight, pred.shape,
+                                              ignore_index)
 
     # weighted element-wise losses
     if weight is not None:
         weight = weight.float()
     loss = F.binary_cross_entropy_with_logits(
-        pred, label.float(), weight=class_weight, reduction='none')
+        pred, label.float(), pos_weight=class_weight, reduction='none')
     # do the reduction for the weighted loss
     loss = weight_reduce_loss(
         loss, weight, reduction=reduction, avg_factor=avg_factor)
@@ -87,7 +102,8 @@ def mask_cross_entropy(pred,
                        label,
                        reduction='mean',
                        avg_factor=None,
-                       class_weight=None):
+                       class_weight=None,
+                       ignore_index=None):
     """Calculate the CrossEntropy loss for masks.
 
     Args:
@@ -103,10 +119,13 @@ def mask_cross_entropy(pred,
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
         class_weight (list[float], optional): The weight for each class.
+        ignore_index (None): Placeholder, to be consistent with other loss.
+            Default: None.
 
     Returns:
         torch.Tensor: The calculated loss
     """
+    assert ignore_index is None, 'BCE loss does not support ignore_index'
     # TODO: handle these two reserved arguments
     assert reduction == 'mean' and avg_factor is None
     num_rois = pred.size()[0]
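
The reworked _expand_onehot_labels is the core of the change: pixels whose
label is negative or equal to ignore_index get an all-zero one-hot target and,
through the returned weights, contribute nothing to the loss. A standalone
re-statement of that logic on a toy input (illustrative only, not the library
function itself):

import torch

labels = torch.tensor([[[1, 255], [0, 2]]])   # [N=1, H=2, W=2], 255 ignored
target_shape = (1, 3, 2, 2)                   # [N, C=3, H, W], i.e. pred.shape
ignore_index = 255

bin_labels = labels.new_zeros(target_shape)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(valid_mask, as_tuple=True)
# Scatter a 1 into the channel named by each valid label.
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
# Ignored pixels get weight 0 across all channels, valid pixels weight 1.
weights = valid_mask.unsqueeze(1).expand(target_shape).float()
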
diff --git a/tests/test_models/test_losses.py b/tests/test_models/test_losses.py
index edae6bfd16..32b3d067a3 100644
--- a/tests/test_models/test_losses.py
+++ b/tests/test_models/test_losses.py
@@ -71,7 +71,17 @@ def test_ce_loss():
     loss_cls_cfg = dict(
         type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
     loss_cls = build_loss(loss_cls_cfg)
-    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(0.))
+    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
+
+    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+    fake_label = torch.ones(2, 8, 8).long()
+    assert torch.allclose(
+        loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
+    fake_label[:, 0, 0] = 255
+    assert torch.allclose(
+        loss_cls(fake_pred, fake_label, ignore_index=255),
+        torch.tensor(0.9354),
+        atol=1e-4)
 
     # TODO test use_mask
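
The expected values in the new test can be checked by hand. With logits of 0.5
everywhere, the per-element BCE is ln(1 + e^-0.5) ~= 0.4741 for target 1 and
ln(1 + e^0.5) ~= 0.9741 for target 0, so each pixel averages
(0.4741 + 20 * 0.9741) / 21 ~= 0.9503 over the 21 channels. Ignoring 2 of the
128 pixels zeroes their contribution while the mean still divides by all
elements, giving 0.9503 * 126 / 128 ~= 0.9354. A plain-PyTorch sketch of the
same arithmetic, independent of mmseg and illustrative only:

import torch
import torch.nn.functional as F

pred = torch.full((2, 21, 8, 8), 0.5)
label = F.one_hot(torch.ones(2, 8, 8).long(), 21).permute(0, 3, 1, 2).float()
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
print(loss.mean())                           # ~0.9503

mask = torch.ones(2, 8, 8)
mask[:, 0, 0] = 0                            # the two ignored pixels
mask = mask[:, None, :, :].expand_as(loss)
print((loss * mask).sum() / loss.numel())    # ~0.9354, matching the test
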