-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* dice loss * format code, add docstring and calculate denominator without valid_mask * minor change * restore
- Loading branch information
谢昕辰
authored
Mar 11, 2021
1 parent
d0a71c1
commit 7e1b24d
Showing
3 changed files
with
158 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
from .accuracy import Accuracy, accuracy | ||
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, | ||
cross_entropy, mask_cross_entropy) | ||
from .dice_loss import DiceLoss | ||
from .lovasz_loss import LovaszLoss | ||
from .utils import reduce_loss, weight_reduce_loss, weighted_loss | ||
|
||
__all__ = [ | ||
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', | ||
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', | ||
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss' | ||
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ | ||
segmentron/solver/loss.py (Apache-2.0 License)""" | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from ..builder import LOSSES | ||
from .utils import weighted_loss | ||
|
||
|
||
@weighted_loss | ||
def dice_loss(pred, | ||
target, | ||
valid_mask, | ||
smooth=1, | ||
exponent=2, | ||
class_weight=None, | ||
ignore_index=-1): | ||
assert pred.shape[0] == target.shape[0] | ||
total_loss = 0 | ||
num_classes = pred.shape[1] | ||
for i in range(num_classes): | ||
if i != ignore_index: | ||
dice_loss = binary_dice_loss( | ||
pred[:, i], | ||
target[..., i], | ||
valid_mask=valid_mask, | ||
smooth=smooth, | ||
exponent=exponent) | ||
if class_weight is not None: | ||
dice_loss *= class_weight[i] | ||
total_loss += dice_loss | ||
return total_loss / num_classes | ||
|
||
|
||
@weighted_loss | ||
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards): | ||
assert pred.shape[0] == target.shape[0] | ||
pred = pred.contiguous().view(pred.shape[0], -1) | ||
target = target.contiguous().view(target.shape[0], -1) | ||
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1) | ||
|
||
num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth | ||
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth | ||
|
||
return 1 - num / den | ||
|
||
|
||
@LOSSES.register_module() | ||
class DiceLoss(nn.Module): | ||
"""DiceLoss. | ||
This loss is proposed in `V-Net: Fully Convolutional Neural Networks for | ||
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_. | ||
Args: | ||
loss_type (str, optional): Binary or multi-class loss. | ||
Default: 'multi_class'. Options are "binary" and "multi_class". | ||
smooth (float): A float number to smooth loss, and avoid NaN error. | ||
Default: 1 | ||
exponent (float): An float number to calculate denominator | ||
value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. | ||
reduction (str, optional): The method used to reduce the loss. Options | ||
are "none", "mean" and "sum". This parameter only works when | ||
per_image is True. Default: 'mean'. | ||
class_weight (list[float], optional): The weight for each class. | ||
Default: None. | ||
loss_weight (float, optional): Weight of the loss. Default to 1.0. | ||
ignore_index (int | None): The label index to be ignored. Default: 255. | ||
""" | ||
|
||
def __init__(self, | ||
loss_type='multi_class', | ||
smooth=1, | ||
exponent=2, | ||
reduction='mean', | ||
class_weight=None, | ||
loss_weight=1.0, | ||
ignore_index=255): | ||
super(DiceLoss, self).__init__() | ||
assert loss_type in ['multi_class', 'binary'] | ||
if loss_type == 'multi_class': | ||
self.cls_criterion = dice_loss | ||
else: | ||
self.cls_criterion = binary_dice_loss | ||
self.smooth = smooth | ||
self.exponent = exponent | ||
self.reduction = reduction | ||
self.class_weight = class_weight | ||
self.loss_weight = loss_weight | ||
self.ignore_index = ignore_index | ||
|
||
def forward(self, pred, target, avg_factor=None, reduction_override=None): | ||
assert reduction_override in (None, 'none', 'mean', 'sum') | ||
reduction = ( | ||
reduction_override if reduction_override else self.reduction) | ||
if self.class_weight is not None: | ||
class_weight = pred.new_tensor(self.class_weight) | ||
else: | ||
class_weight = None | ||
|
||
pred = F.softmax(pred, dim=1) | ||
one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0)) | ||
valid_mask = (target != self.ignore_index).long() | ||
|
||
loss = self.loss_weight * self.cls_criterion( | ||
pred, | ||
one_hot_target, | ||
valid_mask=valid_mask, | ||
reduction=reduction, | ||
avg_factor=avg_factor, | ||
smooth=self.smooth, | ||
exponent=self.exponent, | ||
class_weight=class_weight, | ||
ignore_index=self.ignore_index) | ||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters