Skip to content

Commit

Permalink
Update reduction for channel cases
Browse files Browse the repository at this point in the history
  • Loading branch information
surajpaib committed Sep 5, 2024
1 parent 2bc6dd8 commit 8274060
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
7 changes: 6 additions & 1 deletion monai/metrics/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@ def __init__(
self.include_background = include_background
self.reduction = look_up_option(reduction, MetricReduction)
self.weight_type = look_up_option(weight_type, Weight)
self.sum_over_classes = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN
self.sum_over_classes = self.reduction in {
MetricReduction.SUM,
MetricReduction.MEAN,
MetricReduction.MEAN_CHANNEL,
MetricReduction.SUM_CHANNEL,
}

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Expand Down
26 changes: 23 additions & 3 deletions tests/test_compute_generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
]

# remove background
TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background)
TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) with GeneralizedDiceScore
{
"y_pred": torch.tensor(
[
Expand All @@ -47,8 +47,9 @@
]
),
"include_background": False,
"reduction": "mean_batch",
},
[0.416667],
[0.583333, 0.333333],
]

# should return 0 for both cases
Expand Down Expand Up @@ -129,6 +130,25 @@
[[1.0000, 1.0000], [1.0000, 1.0000]],
]

TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2)
{"include_background": True, "reduction": "mean_channel"},
{
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
},
[0.545455, 0.545455],
]


class TestComputeGeneralizedDiceScore(unittest.TestCase):

Expand Down Expand Up @@ -162,7 +182,7 @@ def test_value_class(self, input_data, expected_value):
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# Aggregation tests
@parameterized.expand([TEST_CASE_4, TEST_CASE_5])
@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_10])
def test_nans_class(self, params, input_data, expected_value):
generalized_dice_score = GeneralizedDiceScore(**params)
generalized_dice_score(**input_data)
Expand Down

0 comments on commit 8274060

Please sign in to comment.