diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 92e88098ac..70041ad1c9 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -62,7 +62,12 @@ def __init__( self.include_background = include_background self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) - self.sum_over_classes = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN + self.sum_over_classes = self.reduction in { + MetricReduction.SUM, + MetricReduction.MEAN, + MetricReduction.MEAN_CHANNEL, + MetricReduction.SUM_CHANNEL, + } def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index fd7745245c..448a17eaa5 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -32,7 +32,7 @@ ] # remove background -TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background) +TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -47,8 +47,9 @@ ] ), "include_background": False, + "reduction": "mean_batch", }, - [0.416667], + [0.583333, 0.333333], ] # should return 0 for both cases @@ -129,6 +130,25 @@ [[1.0000, 1.0000], [1.0000, 1.0000]], ] +TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2) + {"include_background": True, "reduction": "mean_channel"}, + { + "y_pred": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), + "y": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + ] + ), + }, + [0.545455, 0.545455], +] + class TestComputeGeneralizedDiceScore(unittest.TestCase): @@ -162,7 +182,7 @@ def test_value_class(self, input_data, expected_value): np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) # Aggregation tests - @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_10]) def test_nans_class(self, params, input_data, expected_value): generalized_dice_score = GeneralizedDiceScore(**params) generalized_dice_score(**input_data)