Skip to content

Commit

Permalink
dice loss (#396)
Browse files Browse the repository at this point in the history
* dice loss

* format code, add docstring and calculate denominator without valid_mask

* minor change

* restore
  • Loading branch information
谢昕辰 authored Mar 11, 2021
1 parent d0a71c1 commit 7e1b24d
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmseg/models/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .accuracy import Accuracy, accuracy
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .lovasz_loss import LovaszLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss

__all__ = [
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss'
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
]
116 changes: 116 additions & 0 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
segmentron/solver/loss.py (Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def dice_loss(pred,
target,
valid_mask,
smooth=1,
exponent=2,
class_weight=None,
ignore_index=-1):
assert pred.shape[0] == target.shape[0]
total_loss = 0
num_classes = pred.shape[1]
for i in range(num_classes):
if i != ignore_index:
dice_loss = binary_dice_loss(
pred[:, i],
target[..., i],
valid_mask=valid_mask,
smooth=smooth,
exponent=exponent)
if class_weight is not None:
dice_loss *= class_weight[i]
total_loss += dice_loss
return total_loss / num_classes


@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
assert pred.shape[0] == target.shape[0]
pred = pred.contiguous().view(pred.shape[0], -1)
target = target.contiguous().view(target.shape[0], -1)
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)

num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth

return 1 - num / den


@LOSSES.register_module()
class DiceLoss(nn.Module):
"""DiceLoss.
This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
Args:
loss_type (str, optional): Binary or multi-class loss.
Default: 'multi_class'. Options are "binary" and "multi_class".
smooth (float): A float number to smooth loss, and avoid NaN error.
Default: 1
exponent (float): An float number to calculate denominator
value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
reduction (str, optional): The method used to reduce the loss. Options
are "none", "mean" and "sum". This parameter only works when
per_image is True. Default: 'mean'.
class_weight (list[float], optional): The weight for each class.
Default: None.
loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
"""

def __init__(self,
loss_type='multi_class',
smooth=1,
exponent=2,
reduction='mean',
class_weight=None,
loss_weight=1.0,
ignore_index=255):
super(DiceLoss, self).__init__()
assert loss_type in ['multi_class', 'binary']
if loss_type == 'multi_class':
self.cls_criterion = dice_loss
else:
self.cls_criterion = binary_dice_loss
self.smooth = smooth
self.exponent = exponent
self.reduction = reduction
self.class_weight = class_weight
self.loss_weight = loss_weight
self.ignore_index = ignore_index

def forward(self, pred, target, avg_factor=None, reduction_override=None):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = pred.new_tensor(self.class_weight)
else:
class_weight = None

pred = F.softmax(pred, dim=1)
one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0))
valid_mask = (target != self.ignore_index).long()

loss = self.loss_weight * self.cls_criterion(
pred,
one_hot_target,
valid_mask=valid_mask,
reduction=reduction,
avg_factor=avg_factor,
smooth=self.smooth,
exponent=self.exponent,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss
40 changes: 40 additions & 0 deletions tests/test_models/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,43 @@ def test_lovasz_loss():
logits = torch.rand(2, 4, 4)
labels = (torch.rand(2, 4, 4)).long()
lovasz_loss(logits, labels, ignore_index=None)


def test_dice_lose():
from mmseg.models import build_loss

# loss_type should be 'binary' or 'multi_class'
with pytest.raises(AssertionError):
loss_cfg = dict(
type='DiceLoss',
loss_type='Binary',
reduction='none',
loss_weight=1.0)
build_loss(loss_cfg)

# test dice loss with loss_type = 'multi_class'
loss_cfg = dict(
type='DiceLoss',
loss_type='multi_class',
reduction='none',
class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0,
ignore_index=1)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 3, 4, 4)
labels = (torch.rand(8, 4, 4) * 3).long()
dice_loss(logits, labels)

# test dice loss with loss_type = 'binary'
loss_cfg = dict(
type='DiceLoss',
loss_type='binary',
smooth=2,
exponent=3,
reduction='sum',
loss_weight=1.0,
ignore_index=0)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(16, 4, 4)
labels = (torch.rand(16, 4, 4)).long()
dice_loss(logits, labels)

0 comments on commit 7e1b24d

Please sign in to comment.