Skip to content

Commit

Permalink
6765 update GeneralizedDiceLoss (#6775)
Browse files Browse the repository at this point in the history
Fixes #6765

### Description
as discussed in #6765, when `batch=True` the loss should still return 1
aggregated value instead of C channels.
#5466 is not actually
achievable with this formulation.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Jul 26, 2023
1 parent 3b56e7f commit 6f9cf6b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
6 changes: 4 additions & 2 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def __init__(
smooth_dr: a small constant added to the denominator to avoid nan.
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, intersection over union is computed from each item in the batch.
If True, the class-weighted intersection and union areas are first summed across the batches.
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand Down Expand Up @@ -360,8 +361,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
w = w + infs * max_values

numer = 2.0 * (intersection * w) + self.smooth_nr
denom = (denominator * w) + self.smooth_dr
final_reduce_dim = 0 if self.batch else 1
numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
f: torch.Tensor = 1.0 - (numer / denom)

if self.reduction == LossReduction.MEAN.value:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_generalized_dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
0.435035,
0.469964,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
{
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
0.3837,
0.414507,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
Expand All @@ -71,7 +71,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
1.5348,
0.829015,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
Expand All @@ -86,7 +86,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
[[[0.210949], [0.295351]], [[0.599976], [0.428522]]],
[[[0.273476]], [[0.555539]]],
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8},
Expand Down Expand Up @@ -114,7 +114,7 @@
"input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1, 1, 0, 0]]]),
},
0.26669,
0.250023,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
Expand All @@ -136,7 +136,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
-8.55485,
-0.097833,
],
]

Expand Down

0 comments on commit 6f9cf6b

Please sign in to comment.