Skip to content

Commit

Permalink
[Enhancement] Support ignore_index for sigmoid BCE (#210)
Browse files (browse the repository at this point in the history)
* [Enhancement] Add args check for ignore_index

* Support ignore_index
  • Loading branch information
xvjiarui authored Nov 17, 2020
1 parent c2608b2 commit 61e1d5c
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 18 deletions.
6 changes: 3 additions & 3 deletions configs/_base_/models/fast_scnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.)),
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
auxiliary_head=[
dict(
type='FCNHead',
Expand All @@ -38,7 +38,7 @@
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
dict(
type='FCNHead',
in_channels=64,
Expand All @@ -50,7 +50,7 @@
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
])

# model training and testing settings
Expand Down
2 changes: 1 addition & 1 deletion configs/fastscnn/fast_scnn_4x8_80k_lr0.12_cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
]

# Re-config the data sampler.
data = dict(samples_per_gpu=8, workers_per_gpu=4)
data = dict(samples_per_gpu=2, workers_per_gpu=4)

# Re-config the optimizer.
optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-5)
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
Default: None.
loss_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss').
ignore_index (int): The label index to be ignored. Default: 255
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Expand Down
43 changes: 31 additions & 12 deletions mmseg/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,25 @@ def cross_entropy(pred,
return loss


def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
    """Expand class-index labels to one-hot targets matching the prediction.

    Args:
        labels (torch.Tensor): Integer class indices with shape [N] or
            [N, H, W].
        label_weights (torch.Tensor | None): Optional weights with the same
            shape as ``labels``. They are broadcast over the class dimension.
        target_shape (torch.Size | tuple[int]): Shape of the prediction,
            i.e. [N, C] or [N, C, H, W].
        ignore_index (int | None): Label value that is excluded from the
            targets and gets zero weight. ``None`` disables ignoring.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The one-hot labels (same dtype as
        ``labels``) and the element-wise weights, both of shape
        ``target_shape``. Weights are zero wherever the label is negative or
        equal to ``ignore_index``.
    """
    bin_labels = labels.new_zeros(target_shape)

    # Negative labels are always invalid; ignore_index is only compared when
    # it is actually set, so ``None`` never reaches a tensor comparison.
    valid_mask = labels >= 0
    if ignore_index is not None:
        valid_mask = valid_mask & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            # [N, H, W] labels -> set channel ``label`` at every valid pixel.
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        # ``expand`` returns a stride-0 view in which many elements share one
        # memory location. An in-place ``*=`` on such a view raises a
        # RuntimeError, so multiply out-of-place to get a fresh tensor (this
        # also leaves the caller's ``label_weights`` untouched).
        bin_label_weights = (
            label_weights.unsqueeze(1).expand(target_shape) * valid_mask)

    return bin_labels, bin_label_weights


Expand All @@ -51,7 +59,8 @@ def binary_cross_entropy(pred,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None):
class_weight=None,
ignore_index=255):
"""Calculate the binary CrossEntropy loss.
Args:
Expand All @@ -63,18 +72,24 @@ def binary_cross_entropy(pred,
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored. Default: 255
Returns:
torch.Tensor: The calculated loss
"""
if pred.dim() != label.dim():
label, weight = _expand_onehot_labels(label, weight, pred.size(-1))
assert (pred.dim() == 2 and label.dim() == 1) or (
pred.dim() == 4 and label.dim() == 3), \
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
'H, W], label shape [N, H, W] are supported'
label, weight = _expand_onehot_labels(label, weight, pred.shape,
ignore_index)

# weighted element-wise losses
if weight is not None:
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), weight=class_weight, reduction='none')
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
Expand All @@ -87,7 +102,8 @@ def mask_cross_entropy(pred,
label,
reduction='mean',
avg_factor=None,
class_weight=None):
class_weight=None,
ignore_index=None):
"""Calculate the CrossEntropy loss for masks.
Args:
Expand All @@ -103,10 +119,13 @@ def mask_cross_entropy(pred,
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other losses.
Default: None.
Returns:
torch.Tensor: The calculated loss
"""
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
Expand Down
12 changes: 11 additions & 1 deletion tests/test_models/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,17 @@ def test_ce_loss():
loss_cls_cfg = dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(0.))
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))

fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
assert torch.allclose(
loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
fake_label[:, 0, 0] = 255
assert torch.allclose(
loss_cls(fake_pred, fake_label, ignore_index=255),
torch.tensor(0.9354),
atol=1e-4)

# TODO test use_mask

Expand Down

0 comments on commit 61e1d5c

Please sign in to comment.