Skip to content

Commit

Permalink
Revert "fix GeneralizedDiceLoss (Project-MONAI#5468)"
Browse files Browse the repository at this point in the history
This reverts commit e03ecd4.

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Jul 25, 2023
1 parent 2800a76 commit 8bb0fae
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
5 changes: 3 additions & 2 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
w = w + infs * max_values

numer = 2.0 * (intersection * w) + self.smooth_nr
denom = (denominator * w) + self.smooth_dr
final_reduce_dim = 0 if self.batch else 1
numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
f: torch.Tensor = 1.0 - (numer / denom)

if self.reduction == LossReduction.MEAN.value:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_generalized_dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
0.435035,
0.469964,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
{
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
0.3837,
0.414507,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
Expand All @@ -71,7 +71,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
1.5348,
0.829015,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
Expand All @@ -86,7 +86,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
[[[0.210949], [0.295351]], [[0.599976], [0.428522]]],
[[[0.273476]], [[0.555539]]],
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8},
Expand Down Expand Up @@ -114,7 +114,7 @@
"input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1, 1, 0, 0]]]),
},
0.26669,
0.250023,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
Expand All @@ -136,7 +136,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
-8.55485,
-0.097833,
],
]

Expand Down

0 comments on commit 8bb0fae

Please sign in to comment.