From 324e53f2a1a829d755637ef96579636c5caebf36 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 3 Jul 2023 19:52:46 +0200 Subject: [PATCH 01/25] Add DiceCEMetric Signed-off-by: Matthias Hadlich --- monai/handlers/__init__.py | 1 + monai/handlers/dice_ce_metric.py | 64 ++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 monai/handlers/dice_ce_metric.py diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index f032191043..b58e883260 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -42,3 +42,4 @@ from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler +from .dice_ce_metric import DiceCEMetric diff --git a/monai/handlers/dice_ce_metric.py b/monai/handlers/dice_ce_metric.py new file mode 100644 index 0000000000..f3fe30fb75 --- /dev/null +++ b/monai/handlers/dice_ce_metric.py @@ -0,0 +1,64 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from monai.handlers.ignite_metric import IgniteMetric +from monai.metrics import DiceMetric +from monai.utils import MetricReduction +from monai.losses import DiceCELoss +from monai.metrics import LossMetric + +class DiceCEMetric(IgniteMetric): + """ + Computes DiceCE score metric from full size Tensor and collects average over batch, class-channels, iterations. + """ + + def __init__( + self, + # include_background: bool = True, + reduction: MetricReduction | str = MetricReduction.MEAN, + # num_classes: int | None = None, + output_transform: Callable = lambda x: x, + save_details: bool = True, + *args, + **kwargs + ) -> None: + """ + + Args: + include_background: whether to include dice computation on the first channel of the predicted output. + Defaults to True. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + num_classes: number of input channels (always including the background). When this is None, + ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are + single-channel class indices and the number of classes is not automatically inferred from data. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. 
+ `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: mean dice of every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key. + args: Arguments for the DiceCELoss + + See also: + :py:meth:`monai.metrics.meandice.compute_dice` + """ + # metric_fn = DiceMetric(include_background=include_background, reduction=reduction, num_classes=num_classes) + loss_function = DiceCELoss(*args, **kwargs) + metric_fn = LossMetric(loss_fn=loss_function, reduction=reduction, get_not_nans=False) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) From 4384f4c51b5d7735a1cdeb09cbb9106799bc878b Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 3 Jul 2023 20:06:30 +0200 Subject: [PATCH 02/25] WiP: Add unittest for DiceCEMetric Signed-off-by: Matthias Hadlich --- tests/test_handler_dice_ce_metric.py | 98 ++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 tests/test_handler_dice_ce_metric.py diff --git a/tests/test_handler_dice_ce_metric.py b/tests/test_handler_dice_ce_metric.py new file mode 100644 index 0000000000..4b4059affa --- /dev/null +++ b/tests/test_handler_dice_ce_metric.py @@ -0,0 +1,98 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from ignite.engine import Engine, Events +from parameterized import parameterized + +from monai.handlers import DiceCEMetric, from_engine +from tests.utils import assert_allclose + +TEST_CASE_1 = [{"include_background": True, "output_transform": from_engine(["pred", "label"])}, 0.813259, (4, 2)] +TEST_CASE_2 = [{"include_background": False, "output_transform": from_engine(["pred", "label"])}, 0.813259, (4, 1)] +TEST_CASE_3 = [ + {"reduction": "mean_channel", "output_transform": from_engine(["pred", "label"])}, + torch.Tensor([0.313262, 2.313251, 0.313262, 0.313262]), + (4, 2), +] + + +class TestHandlerDiceCEMetric(unittest.TestCase): + # TODO test multi node averaged dice + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_compute(self, input_params, expected_avg, details_shape): + dice_metric = DiceCEMetric(**input_params) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + dice_metric.attach(engine=engine, name="dice_ce_metric") + # test input a list of channel-first tensor + y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] + y = torch.Tensor([[[0], [1]], [[0], [1]]]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] + y = torch.Tensor([[[0], [1]], [[1], [0]]]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + assert_allclose(engine.state.metrics["dice_ce_metric"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) + self.assertTupleEqual(tuple(engine.state.metric_details["dice_ce_metric"].shape), details_shape) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_shape_mismatch(self, input_params, _expected_avg, _details_shape): + dice_metric = DiceCEMetric(**input_params) + with self.assertRaises((AssertionError, ValueError)): + y_pred = torch.Tensor([[0, 1], [1, 0]]) + y = torch.ones((2, 3)) + dice_metric.update([y_pred, y]) + + with self.assertRaises((AssertionError, ValueError)): + y_pred = torch.Tensor([[0, 1], [1, 0]]) + y = torch.ones((3, 2)) + dice_metric.update([y_pred, y]) + + # @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + # def test_compute_n_class(self, input_params, expected_avg, details_shape): + # dice_metric = DiceCEMetric(num_classes=2, **input_params) + + # def _val_func(engine, batch): + # pass + + # engine = Engine(_val_func) + # dice_metric.attach(engine=engine, name="dice_ce_metric") + # # test input a list of channel-first tensor + # y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] + # y = torch.Tensor([[[0], [1]], [[0], [1]]]) + # engine.state.output = {"pred": y_pred, "label": y} + # engine.fire_event(Events.ITERATION_COMPLETED) + + # y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] # class indices y_pred + # y = torch.Tensor([[[1]], [[0]]]) # class indices y + # engine.state.output = {"pred": y_pred, "label": y} + # engine.fire_event(Events.ITERATION_COMPLETED) + + # engine.fire_event(Events.EPOCH_COMPLETED) + # assert_allclose(engine.state.metrics["dice_ce_metric"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) + # self.assertTupleEqual(tuple(engine.state.metric_details["dice_ce_metric"].shape), details_shape) + + +if __name__ == "__main__": + unittest.main() From 94cb06fffc5ca6632a645b1ebcf9b33581a70c82 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Jul 2023 18:23:44 +0000 Subject: [PATCH 03/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/handlers/dice_ce_metric.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/handlers/dice_ce_metric.py b/monai/handlers/dice_ce_metric.py index f3fe30fb75..78a8797330 100644 --- a/monai/handlers/dice_ce_metric.py +++ b/monai/handlers/dice_ce_metric.py @@ -14,7 +14,6 @@ from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric -from monai.metrics import DiceMetric from monai.utils import MetricReduction from monai.losses import DiceCELoss from monai.metrics import LossMetric From 48c6ef9993a969c2c92ca6eeecef6059e95f46b4 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 4 Jul 2023 13:12:45 +0200 Subject: [PATCH 04/25] Remove DiceCEMetric Signed-off-by: Matthias Hadlich --- monai/handlers/__init__.py | 2 +- monai/handlers/dice_ce_metric.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index b58e883260..e1588fedb1 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -42,4 +42,4 @@ from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler -from .dice_ce_metric import DiceCEMetric +from .ignite_loss_metric import IgniteLossMetric diff --git a/monai/handlers/dice_ce_metric.py b/monai/handlers/dice_ce_metric.py index 78a8797330..549b992ae9 100644 --- a/monai/handlers/dice_ce_metric.py +++ b/monai/handlers/dice_ce_metric.py @@ -18,24 +18,27 @@ from monai.losses import DiceCELoss from monai.metrics import LossMetric -class DiceCEMetric(IgniteMetric): +from torch.nn.modules.loss import _Loss + +class LossMetricIgnite(IgniteMetric): """ Computes DiceCE score metric from full size Tensor and collects average over batch, class-channels, iterations. """ def __init__( self, - # include_background: bool = True, + loss_fn: _Loss, reduction: MetricReduction | str = MetricReduction.MEAN, - # num_classes: int | None = None, output_transform: Callable = lambda x: x, save_details: bool = True, - *args, - **kwargs + # *args, + # **kwargs ) -> None: """ Args: + loss_fn: a callable function that takes ``y_pred`` and optionally ``y`` as input (in the "batch-first" format), + returns a "batch-first" tensor of loss values. include_background: whether to include dice computation on the first channel of the predicted output. Defaults to True. 
reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, @@ -57,7 +60,7 @@ def __init__( See also: :py:meth:`monai.metrics.meandice.compute_dice` """ - # metric_fn = DiceMetric(include_background=include_background, reduction=reduction, num_classes=num_classes) - loss_function = DiceCELoss(*args, **kwargs) - metric_fn = LossMetric(loss_fn=loss_function, reduction=reduction, get_not_nans=False) + self.loss_fn = loss_fn + # loss_function = DiceCELoss(*args, **kwargs) + metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=False) super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) From 6815816c55a6d154c1f66fd5f9aa5a16035ab7c5 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 4 Jul 2023 13:13:35 +0200 Subject: [PATCH 05/25] Add IgniteLossMetric Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_loss_metric.py | 67 ++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 monai/handlers/ignite_loss_metric.py diff --git a/monai/handlers/ignite_loss_metric.py b/monai/handlers/ignite_loss_metric.py new file mode 100644 index 0000000000..50cbb485a5 --- /dev/null +++ b/monai/handlers/ignite_loss_metric.py @@ -0,0 +1,67 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from monai.handlers.ignite_metric import IgniteMetric +from monai.metrics import DiceMetric +from monai.utils import MetricReduction +from monai.losses import DiceCELoss +from monai.metrics import LossMetric + +from torch.nn.modules.loss import _Loss + +class IgniteLossMetric(IgniteMetric): + """ + Computes DiceCE score metric from full size Tensor and collects average over batch, class-channels, iterations. + """ + + def __init__( + self, + loss_fn: _Loss, + reduction: MetricReduction | str = MetricReduction.MEAN, + output_transform: Callable = lambda x: x, + save_details: bool = True, + # *args, + # **kwargs + ) -> None: + """ + + Args: + loss_fn: a callable function that takes ``y_pred`` and optionally ``y`` as input (in the "batch-first" format), + returns a "batch-first" tensor of loss values. + include_background: whether to include dice computation on the first channel of the predicted output. + Defaults to True. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + num_classes: number of input channels (always including the background). When this is None, + ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are + single-channel class indices and the number of classes is not automatically inferred from data. 
+ output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: mean dice of every image. + default to True, will save to `engine.state.metric_details` dict with the metric name as key. + args: Arguments for the DiceCELoss + + See also: + :py:meth:`monai.metrics.meandice.compute_dice` + """ + self.loss_fn = loss_fn + # loss_function = DiceCELoss(*args, **kwargs) + metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=False) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) From 2d770489eefedf18e374ba03418b30fbae33b12c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Jul 2023 11:14:57 +0000 Subject: [PATCH 06/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/handlers/dice_ce_metric.py | 1 - monai/handlers/ignite_loss_metric.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/monai/handlers/dice_ce_metric.py b/monai/handlers/dice_ce_metric.py index 549b992ae9..8264868d78 100644 --- a/monai/handlers/dice_ce_metric.py +++ b/monai/handlers/dice_ce_metric.py @@ -15,7 +15,6 @@ from monai.handlers.ignite_metric import IgniteMetric from monai.utils import MetricReduction -from monai.losses import DiceCELoss from monai.metrics import LossMetric from torch.nn.modules.loss import _Loss diff --git a/monai/handlers/ignite_loss_metric.py b/monai/handlers/ignite_loss_metric.py index 50cbb485a5..ec8bd4e33e 100644 --- a/monai/handlers/ignite_loss_metric.py +++ b/monai/handlers/ignite_loss_metric.py @@ -14,9 +14,7 @@ from collections.abc import Callable from monai.handlers.ignite_metric import IgniteMetric -from monai.metrics import DiceMetric from monai.utils import MetricReduction -from monai.losses import DiceCELoss from monai.metrics import LossMetric from torch.nn.modules.loss import _Loss From 2524e0daad4c74d91c8d0769bce416bcc800d485 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 10 Jul 2023 08:33:37 +0200 Subject: [PATCH 07/25] Undo previous commits as discussed Signed-off-by: Matthias Hadlich --- monai/handlers/__init__.py | 1 - tests/test_handler_dice_ce_metric.py | 98 ---------------------------- 2 files changed, 99 deletions(-) delete mode 100644 tests/test_handler_dice_ce_metric.py diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index e1588fedb1..f032191043 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -42,4 +42,3 @@ from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler -from .ignite_loss_metric import IgniteLossMetric diff --git a/tests/test_handler_dice_ce_metric.py b/tests/test_handler_dice_ce_metric.py deleted file mode 100644 index 4b4059affa..0000000000 
--- a/tests/test_handler_dice_ce_metric.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from ignite.engine import Engine, Events -from parameterized import parameterized - -from monai.handlers import DiceCEMetric, from_engine -from tests.utils import assert_allclose - -TEST_CASE_1 = [{"include_background": True, "output_transform": from_engine(["pred", "label"])}, 0.813259, (4, 2)] -TEST_CASE_2 = [{"include_background": False, "output_transform": from_engine(["pred", "label"])}, 0.813259, (4, 1)] -TEST_CASE_3 = [ - {"reduction": "mean_channel", "output_transform": from_engine(["pred", "label"])}, - torch.Tensor([0.313262, 2.313251, 0.313262, 0.313262]), - (4, 2), -] - - -class TestHandlerDiceCEMetric(unittest.TestCase): - # TODO test multi node averaged dice - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_compute(self, input_params, expected_avg, details_shape): - dice_metric = DiceCEMetric(**input_params) - - def _val_func(engine, batch): - pass - - engine = Engine(_val_func) - dice_metric.attach(engine=engine, name="dice_ce_metric") - # test input a list of channel-first tensor - y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] - y = torch.Tensor([[[0], [1]], [[0], [1]]]) - engine.state.output = {"pred": y_pred, "label": y} - engine.fire_event(Events.ITERATION_COMPLETED) - - y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] - y = torch.Tensor([[[0], [1]], [[1], [0]]]) - engine.state.output = {"pred": y_pred, "label": y} - engine.fire_event(Events.ITERATION_COMPLETED) - - engine.fire_event(Events.EPOCH_COMPLETED) - assert_allclose(engine.state.metrics["dice_ce_metric"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) - self.assertTupleEqual(tuple(engine.state.metric_details["dice_ce_metric"].shape), details_shape) - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape_mismatch(self, input_params, _expected_avg, _details_shape): - dice_metric = DiceCEMetric(**input_params) - with self.assertRaises((AssertionError, ValueError)): - y_pred = torch.Tensor([[0, 1], [1, 0]]) - y = torch.ones((2, 3)) - dice_metric.update([y_pred, y]) - - with self.assertRaises((AssertionError, ValueError)): - y_pred = torch.Tensor([[0, 1], [1, 0]]) - y = torch.ones((3, 2)) - dice_metric.update([y_pred, y]) - - # @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - # def test_compute_n_class(self, input_params, expected_avg, details_shape): - # dice_metric = DiceCEMetric(num_classes=2, **input_params) - - # def _val_func(engine, batch): - # pass - - # engine = Engine(_val_func) - # dice_metric.attach(engine=engine, name="dice_ce_metric") - # # test input a list of channel-first tensor - # y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] - # y = torch.Tensor([[[0], [1]], [[0], [1]]]) - # engine.state.output = {"pred": y_pred, "label": y} - # engine.fire_event(Events.ITERATION_COMPLETED) - - # y_pred 
= [torch.Tensor([[1]]), torch.Tensor([[0]])] # class indices y_pred - # y = torch.Tensor([[[1]], [[0]]]) # class indices y - # engine.state.output = {"pred": y_pred, "label": y} - # engine.fire_event(Events.ITERATION_COMPLETED) - - # engine.fire_event(Events.EPOCH_COMPLETED) - # assert_allclose(engine.state.metrics["dice_ce_metric"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) - # self.assertTupleEqual(tuple(engine.state.metric_details["dice_ce_metric"].shape), details_shape) - - -if __name__ == "__main__": - unittest.main() From 7c57f0011e5adc902143ce933afb11e35420f9a7 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 10 Jul 2023 09:44:47 +0200 Subject: [PATCH 08/25] Add loss_fn support to IgniteMetric Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 822da0aa18..41dce53537 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -16,16 +16,17 @@ from typing import TYPE_CHECKING, Any import torch +from torch.nn.modules.loss import _Loss from monai.config import IgniteInfo from monai.metrics import CumulativeIterationMetric -from monai.utils import min_version, optional_import +from monai.utils import min_version, optional_import, MetricReduction idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine - from ignite.metrics import Metric + from ignite.metrics import Metric, LossMetric from ignite.metrics.metric import reinit__is_reduced else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") @@ -44,6 +45,7 @@ class IgniteMetric(Metric): Args: metric_fn: callable function or class to compute raw metric results after every iteration. expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans). + loss_fn: A torch _Loss function which is used to generate the LossMetric output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. @@ -52,18 +54,25 @@ class IgniteMetric(Metric): https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean_dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. 
+ kwargs: keyword argument that will be passed into the LossMetric """ - def __init__( - self, metric_fn: CumulativeIterationMetric, output_transform: Callable = lambda x: x, save_details: bool = True + self, metric_fn: CumulativeIterationMetric = None, loss_fn: _Loss = None, output_transform: Callable = lambda x: x, save_details: bool = True, **kwargs, ) -> None: self._is_reduced: bool = False self.metric_fn = metric_fn + self.loss_fn = loss_fn self.save_details = save_details self._scores: list = [] self._engine: Engine | None = None self._name: str | None = None + + if bool(self.metric_fn) == bool(self.loss_fn): + raise ValueError(f"Either metric_fn or loss_fn have to be passed, but not both.") + if self.loss_fn: + self.metric_fn = LossMetric(loss_fn=self.loss_fn, **kwargs) + super().__init__(output_transform) @reinit__is_reduced From 14b07489160ccd7ae07f4e1d8c5bf3ab1b7138ed Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 10 Jul 2023 13:26:07 +0200 Subject: [PATCH 09/25] Delete previously created files Signed-off-by: Matthias Hadlich --- monai/handlers/dice_ce_metric.py | 65 ---------------------------- monai/handlers/ignite_loss_metric.py | 65 ---------------------------- 2 files changed, 130 deletions(-) delete mode 100644 monai/handlers/dice_ce_metric.py delete mode 100644 monai/handlers/ignite_loss_metric.py diff --git a/monai/handlers/dice_ce_metric.py b/monai/handlers/dice_ce_metric.py deleted file mode 100644 index 8264868d78..0000000000 --- a/monai/handlers/dice_ce_metric.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from collections.abc import Callable - -from monai.handlers.ignite_metric import IgniteMetric -from monai.utils import MetricReduction -from monai.metrics import LossMetric - -from torch.nn.modules.loss import _Loss - -class LossMetricIgnite(IgniteMetric): - """ - Computes DiceCE score metric from full size Tensor and collects average over batch, class-channels, iterations. - """ - - def __init__( - self, - loss_fn: _Loss, - reduction: MetricReduction | str = MetricReduction.MEAN, - output_transform: Callable = lambda x: x, - save_details: bool = True, - # *args, - # **kwargs - ) -> None: - """ - - Args: - loss_fn: a callable function that takes ``y_pred`` and optionally ``y`` as input (in the "batch-first" format), - returns a "batch-first" tensor of loss values. - include_background: whether to include dice computation on the first channel of the predicted output. - Defaults to True. - reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, - available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. - num_classes: number of input channels (always including the background). When this is None, - ``y_pred.shape[1]`` will be used. 
This option is useful when both ``y_pred`` and ``y`` are - single-channel class indices and the number of classes is not automatically inferred from data. - output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then - construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or - lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - `engine.state` and `output_transform` inherit from the ignite concept: - https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: - https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. - save_details: whether to save metric computation details per image, for example: mean dice of every image. - default to True, will save to `engine.state.metric_details` dict with the metric name as key. - args: Arguments for the DiceCELoss - - See also: - :py:meth:`monai.metrics.meandice.compute_dice` - """ - self.loss_fn = loss_fn - # loss_function = DiceCELoss(*args, **kwargs) - metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=False) - super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/handlers/ignite_loss_metric.py b/monai/handlers/ignite_loss_metric.py deleted file mode 100644 index ec8bd4e33e..0000000000 --- a/monai/handlers/ignite_loss_metric.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from collections.abc import Callable - -from monai.handlers.ignite_metric import IgniteMetric -from monai.utils import MetricReduction -from monai.metrics import LossMetric - -from torch.nn.modules.loss import _Loss - -class IgniteLossMetric(IgniteMetric): - """ - Computes DiceCE score metric from full size Tensor and collects average over batch, class-channels, iterations. - """ - - def __init__( - self, - loss_fn: _Loss, - reduction: MetricReduction | str = MetricReduction.MEAN, - output_transform: Callable = lambda x: x, - save_details: bool = True, - # *args, - # **kwargs - ) -> None: - """ - - Args: - loss_fn: a callable function that takes ``y_pred`` and optionally ``y`` as input (in the "batch-first" format), - returns a "batch-first" tensor of loss values. - include_background: whether to include dice computation on the first channel of the predicted output. - Defaults to True. - reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, - available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. - num_classes: number of input channels (always including the background). When this is None, - ``y_pred.shape[1]`` will be used. 
This option is useful when both ``y_pred`` and ``y`` are - single-channel class indices and the number of classes is not automatically inferred from data. - output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then - construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or - lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - `engine.state` and `output_transform` inherit from the ignite concept: - https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: - https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. - save_details: whether to save metric computation details per image, for example: mean dice of every image. - default to True, will save to `engine.state.metric_details` dict with the metric name as key. - args: Arguments for the DiceCELoss - - See also: - :py:meth:`monai.metrics.meandice.compute_dice` - """ - self.loss_fn = loss_fn - # loss_function = DiceCELoss(*args, **kwargs) - metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=False) - super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) From a3154ee50be70a8efb1b67380ea27fbe46f2e7c5 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 10 Jul 2023 13:26:39 +0200 Subject: [PATCH 10/25] Modify IgniteMetric to also support loss_fn Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 41dce53537..934dd847ea 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -19,14 +19,14 @@ from torch.nn.modules.loss import _Loss from monai.config import IgniteInfo -from monai.metrics import CumulativeIterationMetric -from monai.utils import min_version, optional_import, MetricReduction +from monai.metrics import CumulativeIterationMetric, LossMetric +from monai.utils import MetricReduction, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine - from ignite.metrics import Metric, LossMetric + from ignite.metrics import LossMetric, Metric from ignite.metrics.metric import reinit__is_reduced else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") @@ -57,8 +57,14 @@ class IgniteMetric(Metric): kwargs: keyword argument that will be passed into the LossMetric """ + def __init__( - self, metric_fn: CumulativeIterationMetric = None, loss_fn: _Loss = None, output_transform: Callable = lambda x: x, save_details: bool = True, **kwargs, + self, + metric_fn: CumulativeIterationMetric = None, + loss_fn: _Loss = None, + output_transform: Callable = lambda x: x, + save_details: bool = True, + **kwargs, ) -> None: self._is_reduced: bool = False self.metric_fn = metric_fn @@ -67,12 +73,14 @@ def __init__( self._scores: list = [] self._engine: Engine | None = None self._name: str | None = None - - if bool(self.metric_fn) == bool(self.loss_fn): + + if self.metric_fn is None and self.loss_fn is None: + raise ValueError(f"Either metric_fn or loss_fn have to be passed.") + if self.metric_fn is not None and self.loss_fn is not None: raise ValueError(f"Either metric_fn or loss_fn have to be passed, but not both.") if self.loss_fn: 
self.metric_fn = LossMetric(loss_fn=self.loss_fn, **kwargs) - + super().__init__(output_transform) @reinit__is_reduced From 0392f600b6aec068eb47b42c19baafdbc478650f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Jul 2023 11:28:47 +0000 Subject: [PATCH 11/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/handlers/ignite_metric.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 934dd847ea..108133157a 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -20,7 +20,7 @@ from monai.config import IgniteInfo from monai.metrics import CumulativeIterationMetric, LossMetric -from monai.utils import MetricReduction, min_version, optional_import +from monai.utils import min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") @@ -75,9 +75,9 @@ def __init__( self._name: str | None = None if self.metric_fn is None and self.loss_fn is None: - raise ValueError(f"Either metric_fn or loss_fn have to be passed.") + raise ValueError("Either metric_fn or loss_fn have to be passed.") if self.metric_fn is not None and self.loss_fn is not None: - raise ValueError(f"Either metric_fn or loss_fn have to be passed, but not both.") + raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.") if self.loss_fn: self.metric_fn = LossMetric(loss_fn=self.loss_fn, **kwargs) From a51bd4db56535a2ae5ef84a7237bcc6dd9037809 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 10 Jul 2023 13:44:58 +0200 Subject: [PATCH 12/25] Add tests for IgniteMetric(Handler) Signed-off-by: Matthias Hadlich --- tests/test_handler_ignite_metric_handler.py | 107 ++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 tests/test_handler_ignite_metric_handler.py diff --git a/tests/test_handler_ignite_metric_handler.py b/tests/test_handler_ignite_metric_handler.py new file mode 100644 index 0000000000..239d7b2c97 --- /dev/null +++ b/tests/test_handler_ignite_metric_handler.py @@ -0,0 +1,107 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from ignite.engine import Engine, Events +from parameterized import parameterized + +from monai.handlers import IgniteMetric, from_engine +from monai.metrics import LossMetric +from tests.utils import assert_allclose +from monai.losses import DiceLoss + +TEST_CASE_1 = [{"include_background": True}, {"output_transform": from_engine(["pred", "label"])}, 0.75, (4, 2)] +TEST_CASE_2 = [{"include_background": False}, {"output_transform": from_engine(["pred", "label"])}, 0.66666, (4, 1)] +TEST_CASE_3 = [ + {"reduction": "none"}, + {"reduction": "mean_channel", "output_transform": from_engine(["pred", "label"])}, + torch.Tensor([1.0, 0.0, 1.0, 1.0]), + (4, 2), +] + + +# test loss_fn +# test metric_fn +# compare loss_fn to metric_fn +# compare dice loss to dice metric + +class TestHandlerIgniteMetricHandler(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_metric_fn(self, loss_params, metric_params, expected_avg, details_shape): + loss_fn = DiceLoss(**loss_params) + metric_fn = LossMetric(loss_fn=loss_fn) + ignite_metric = IgniteMetric(metric_fn=metric_fn, **metric_params) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + ignite_metric.attach(engine=engine, name="ignite_dice_loss") + # test input a list of channel-first tensor + y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] + y = torch.Tensor([[[0], [1]], [[0], [1]]]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] + y = torch.Tensor([[[0], [1]], [[1], [0]]]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) + self.assertTupleEqual(tuple(engine.state.metric_details["ignite_dice_loss"].shape), details_shape) + + # @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + # def test_shape_mismatch(self, input_params, _expected_avg, _details_shape): + # dice_metric = MeanDice(**input_params) + # with self.assertRaises((AssertionError, ValueError)): + # y_pred = torch.Tensor([[0, 1], [1, 0]]) + # y = torch.ones((2, 3)) + # dice_metric.update([y_pred, y]) + + # with self.assertRaises((AssertionError, ValueError)): + # y_pred = torch.Tensor([[0, 1], [1, 0]]) + # y = torch.ones((3, 2)) + # dice_metric.update([y_pred, y]) + + # @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + # def test_compute_n_class(self, input_params, expected_avg, details_shape): + # dice_metric = MeanDice(num_classes=2, **input_params) + + # def _val_func(engine, batch): + # pass + + # engine = Engine(_val_func) + # dice_metric.attach(engine=engine, name="mean_dice") + # # test input a list of channel-first tensor + # y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] + # y = torch.Tensor([[[0], [1]], [[0], [1]]]) + # engine.state.output = {"pred": y_pred, "label": y} + # engine.fire_event(Events.ITERATION_COMPLETED) + + # y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] # class indices y_pred + # y = torch.Tensor([[[1]], [[0]]]) # class indices y + # engine.state.output = {"pred": y_pred, "label": y} + # engine.fire_event(Events.ITERATION_COMPLETED) + + # engine.fire_event(Events.EPOCH_COMPLETED) + # assert_allclose(engine.state.metrics["mean_dice"], expected_avg, atol=1e-4, rtol=1e-4, 
type_test=False) + # self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape) + + +if __name__ == "__main__": + unittest.main() From d76f0ed09109a8f53b32e29058deb7144a41f679 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 10 Jul 2023 13:46:42 +0200 Subject: [PATCH 13/25] Fix formatting Signed-off-by: Matthias Hadlich --- tests/test_handler_ignite_metric_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_handler_ignite_metric_handler.py b/tests/test_handler_ignite_metric_handler.py index 239d7b2c97..3e7c059334 100644 --- a/tests/test_handler_ignite_metric_handler.py +++ b/tests/test_handler_ignite_metric_handler.py @@ -18,14 +18,14 @@ from parameterized import parameterized from monai.handlers import IgniteMetric, from_engine +from monai.losses import DiceLoss from monai.metrics import LossMetric from tests.utils import assert_allclose -from monai.losses import DiceLoss TEST_CASE_1 = [{"include_background": True}, {"output_transform": from_engine(["pred", "label"])}, 0.75, (4, 2)] TEST_CASE_2 = [{"include_background": False}, {"output_transform": from_engine(["pred", "label"])}, 0.66666, (4, 1)] TEST_CASE_3 = [ - {"reduction": "none"}, + {"reduction": "none"}, {"reduction": "mean_channel", "output_transform": from_engine(["pred", "label"])}, torch.Tensor([1.0, 0.0, 1.0, 1.0]), (4, 2), @@ -37,8 +37,8 @@ # compare loss_fn to metric_fn # compare dice loss to dice metric -class TestHandlerIgniteMetricHandler(unittest.TestCase): +class TestHandlerIgniteMetricHandler(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_metric_fn(self, loss_params, metric_params, expected_avg, details_shape): loss_fn = DiceLoss(**loss_params) From 9d939671e7a9801472b3fe579b233605d2e666d6 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 11 Jul 2023 10:03:17 +0200 Subject: [PATCH 14/25] Update test cases for IgniteMetric(Handler) Signed-off-by: Matthias Hadlich --- tests/test_handler_ignite_metric_handler.py | 175 +++++++++++++------- 1 file changed, 119 insertions(+), 56 deletions(-) diff --git a/tests/test_handler_ignite_metric_handler.py b/tests/test_handler_ignite_metric_handler.py index 3e7c059334..e597decb96 100644 --- a/tests/test_handler_ignite_metric_handler.py +++ b/tests/test_handler_ignite_metric_handler.py @@ -22,85 +22,148 @@ from monai.metrics import LossMetric from tests.utils import assert_allclose -TEST_CASE_1 = [{"include_background": True}, {"output_transform": from_engine(["pred", "label"])}, 0.75, (4, 2)] -TEST_CASE_2 = [{"include_background": False}, {"output_transform": from_engine(["pred", "label"])}, 0.66666, (4, 1)] +TEST_CASE_1 = [ + {"reduction": "none", "include_background": True}, + {}, + {"output_transform": from_engine(["pred", "label"])}, + 0.25, +] +TEST_CASE_2 = [ + {"reduction": "mean", "include_background": False}, + {}, + {"output_transform": from_engine(["pred", "label"])}, + 0.5, +] TEST_CASE_3 = [ {"reduction": "none"}, - {"reduction": "mean_channel", "output_transform": from_engine(["pred", "label"])}, - torch.Tensor([1.0, 0.0, 1.0, 1.0]), - (4, 2), + {"reduction": "mean_channel"}, + {"output_transform": from_engine(["pred", "label"])}, + torch.Tensor([0.5, 0]), ] - -# test loss_fn -# test metric_fn -# compare loss_fn to metric_fn -# compare dice loss to dice metric +TEST_CASES = [ + [ + {"include_background": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, + { + "input": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), 
+ "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), + }, + 0, + ], + [ + {"include_background": False, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, + { + "input": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), + }, + 0, + ], + [ + {"include_background": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, + { + "input": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), + "target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]), + }, + 1, + ], + [ + {"include_background": False, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, + { + "input": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]), + "target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]), + }, + 1, + ], + [ + {"include_background": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, + { + "input": torch.tensor([[[[0.0, 1.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), + }, + 0.333333, + ], + [ + {"include_background": False, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, + { + "input": torch.tensor([[[[0.0, 1.0], [0.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]), + }, + 0, + ], +] class TestHandlerIgniteMetricHandler(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_metric_fn(self, loss_params, metric_params, expected_avg, details_shape): + def test_metric_fn(self, loss_params, metric_params, handler_params, expected_avg): + loss_fn = DiceLoss(**loss_params) + metric_fn = LossMetric(loss_fn=loss_fn, **metric_params) + ignite_metric = IgniteMetric(metric_fn=metric_fn, **handler_params) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + ignite_metric.attach(engine=engine, name="ignite_dice_loss") + y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]]) + y = torch.tensor([[[[0.0, 1.0]], [[0.0, 1.0]]]]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]]) + y = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]]) + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + print(f"{engine.state.metrics['ignite_dice_loss']}") + print(f"{engine.state.metric_details['ignite_dice_loss'].shape}") + print(f"{engine.state.metric_details['ignite_dice_loss']}") + assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_loss_fn(self, loss_params, metric_params, handler_params, expected_avg): loss_fn = DiceLoss(**loss_params) - metric_fn = LossMetric(loss_fn=loss_fn) - ignite_metric = IgniteMetric(metric_fn=metric_fn, **metric_params) + ignite_metric = IgniteMetric(loss_fn=loss_fn, **handler_params, **metric_params) def _val_func(engine, batch): pass engine = Engine(_val_func) ignite_metric.attach(engine=engine, name="ignite_dice_loss") - # test input a list of channel-first tensor - y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] - y = torch.Tensor([[[0], [1]], [[0], [1]]]) + y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]]) + y = torch.tensor([[[[0.0, 1.0]], [[0.0, 1.0]]]]) engine.state.output = {"pred": y_pred, "label": 
y} engine.fire_event(Events.ITERATION_COMPLETED) - y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] - y = torch.Tensor([[[0], [1]], [[1], [0]]]) + y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]]) + y = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]]) engine.state.output = {"pred": y_pred, "label": y} engine.fire_event(Events.ITERATION_COMPLETED) engine.fire_event(Events.EPOCH_COMPLETED) + print(f"{engine.state.metrics['ignite_dice_loss']}") + print(f"{engine.state.metric_details['ignite_dice_loss'].shape}") + print(f"{engine.state.metric_details['ignite_dice_loss']}") assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) - self.assertTupleEqual(tuple(engine.state.metric_details["ignite_dice_loss"].shape), details_shape) - - # @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - # def test_shape_mismatch(self, input_params, _expected_avg, _details_shape): - # dice_metric = MeanDice(**input_params) - # with self.assertRaises((AssertionError, ValueError)): - # y_pred = torch.Tensor([[0, 1], [1, 0]]) - # y = torch.ones((2, 3)) - # dice_metric.update([y_pred, y]) - - # with self.assertRaises((AssertionError, ValueError)): - # y_pred = torch.Tensor([[0, 1], [1, 0]]) - # y = torch.ones((3, 2)) - # dice_metric.update([y_pred, y]) - - # @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - # def test_compute_n_class(self, input_params, expected_avg, details_shape): - # dice_metric = MeanDice(num_classes=2, **input_params) - - # def _val_func(engine, batch): - # pass - - # engine = Engine(_val_func) - # dice_metric.attach(engine=engine, name="mean_dice") - # # test input a list of channel-first tensor - # y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] - # y = torch.Tensor([[[0], [1]], [[0], [1]]]) - # engine.state.output = {"pred": y_pred, "label": y} - # engine.fire_event(Events.ITERATION_COMPLETED) - - # y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])] # class indices y_pred - # y = torch.Tensor([[[1]], [[0]]]) # class indices y - # engine.state.output = {"pred": y_pred, "label": y} - # engine.fire_event(Events.ITERATION_COMPLETED) - - # engine.fire_event(Events.EPOCH_COMPLETED) - # assert_allclose(engine.state.metrics["mean_dice"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) - # self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape) + + @parameterized.expand(TEST_CASES) + def test_dice_loss(self, input_param, input_data, expected_val): + loss_fn = DiceLoss(**input_param) + ignite_metric = IgniteMetric(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"])) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + ignite_metric.attach(engine=engine, name="ignite_dice_loss") + y_pred = input_data["input"] + y = input_data["target"] + engine.state.output = {"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_val, atol=1e-4, rtol=1e-4, type_test=False) if __name__ == "__main__": From 974a3ebaffe0fde54900e320ceb25a6bf407fa55 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 11 Jul 2023 11:06:58 +0200 Subject: [PATCH 15/25] Rename IgniteMetric to IgniteMetricHandler Signed-off-by: Matthias Hadlich --- docs/source/handlers.rst | 6 +++--- monai/handlers/__init__.py | 2 +- monai/handlers/confusion_matrix.py | 4 ++-- monai/handlers/hausdorff_distance.py | 4 ++-- monai/handlers/ignite_metric.py | 13 
+++++++++---- monai/handlers/mean_dice.py | 4 ++-- monai/handlers/mean_iou.py | 4 ++-- monai/handlers/metrics_reloaded_handler.py | 6 +++--- monai/handlers/panoptic_quality.py | 4 ++-- monai/handlers/regression_metrics.py | 10 +++++----- monai/handlers/roc_auc.py | 4 ++-- monai/handlers/surface_distance.py | 4 ++-- tests/test_handler_ignite_metric_handler.py | 8 ++++---- 13 files changed, 39 insertions(+), 34 deletions(-) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 7da7f7f50d..270083f717 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -29,9 +29,9 @@ CSV saver :members: -Ignite Metric -------------- -.. autoclass:: IgniteMetric +Ignite Metric Handler +--------------------- +.. autoclass:: IgniteMetricHandler :members: diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index f032191043..641f9aae7d 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -20,7 +20,7 @@ from .earlystop_handler import EarlyStopHandler from .garbage_collector import GarbageCollector from .hausdorff_distance import HausdorffDistance -from .ignite_metric import IgniteMetric +from .ignite_metric import IgniteMetric, IgniteMetricHandler from .logfile_handler import LogfileHandler from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index e8f8b6b112..89c0f5551f 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import ConfusionMatrixMetric from monai.utils.enums import MetricReduction -class ConfusionMatrix(IgniteMetric): +class ConfusionMatrix(IgniteMetricHandler): """ Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations. """ diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index ef4136906c..669bf30068 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import HausdorffDistanceMetric from monai.utils import MetricReduction -class HausdorffDistance(IgniteMetric): +class HausdorffDistance(IgniteMetricHandler): """ Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations. """ diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 108133157a..6458a9d429 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from ignite.engine import Engine - from ignite.metrics import LossMetric, Metric + from ignite.metrics import Metric from ignite.metrics.metric import reinit__is_reduced else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") @@ -35,8 +35,10 @@ "ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator" ) +warnings.warn("IgniteMetric has been renamed to IgniteMetricHandler") -class IgniteMetric(Metric): + +class IgniteMetricHandler(Metric): """ Base Metric class based on ignite event handler mechanism. 
The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim, @@ -74,9 +76,9 @@ def __init__( self._engine: Engine | None = None self._name: str | None = None - if self.metric_fn is None and self.loss_fn is None: + if self.metric_fn is None and self.loss_fn is None: raise ValueError("Either metric_fn or loss_fn have to be passed.") - if self.metric_fn is not None and self.loss_fn is not None: + if self.metric_fn is not None and self.loss_fn is not None: raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.") if self.loss_fn: self.metric_fn = LossMetric(loss_fn=self.loss_fn, **kwargs) @@ -146,3 +148,6 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override] self._name = name if self.save_details and not hasattr(engine.state, "metric_details"): engine.state.metric_details = {} # type: ignore + + +IgniteMetric = IgniteMetricHandler diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index ebd6d1aabb..7b532a95ed 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import DiceMetric from monai.utils import MetricReduction -class MeanDice(IgniteMetric): +class MeanDice(IgniteMetricHandler): """ Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations. """ diff --git a/monai/handlers/mean_iou.py b/monai/handlers/mean_iou.py index 2fc0d5f8ab..894a9185a3 100644 --- a/monai/handlers/mean_iou.py +++ b/monai/handlers/mean_iou.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import MeanIoU from monai.utils import MetricReduction -class MeanIoUHandler(IgniteMetric): +class MeanIoUHandler(IgniteMetricHandler): """ Computes IoU score metric from full size Tensor and collects average over batch, class-channels, iterations. """ diff --git a/monai/handlers/metrics_reloaded_handler.py b/monai/handlers/metrics_reloaded_handler.py index dbec98256b..b526d21ee0 100644 --- a/monai/handlers/metrics_reloaded_handler.py +++ b/monai/handlers/metrics_reloaded_handler.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical from monai.utils.enums import MetricReduction -class MetricsReloadedBinaryHandler(IgniteMetric): +class MetricsReloadedBinaryHandler(IgniteMetricHandler): """ Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded. """ @@ -65,7 +65,7 @@ def __init__( super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) -class MetricsReloadedCategoricalHandler(IgniteMetric): +class MetricsReloadedCategoricalHandler(IgniteMetricHandler): """ Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded. 
""" diff --git a/monai/handlers/panoptic_quality.py b/monai/handlers/panoptic_quality.py index 4bf561826c..80b0f71390 100644 --- a/monai/handlers/panoptic_quality.py +++ b/monai/handlers/panoptic_quality.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import PanopticQualityMetric from monai.utils import MetricReduction -class PanopticQuality(IgniteMetric): +class PanopticQuality(IgniteMetricHandler): """ Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations. """ diff --git a/monai/handlers/regression_metrics.py b/monai/handlers/regression_metrics.py index fee7238491..c5702a7e59 100644 --- a/monai/handlers/regression_metrics.py +++ b/monai/handlers/regression_metrics.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric from monai.utils import MetricReduction -class MeanSquaredError(IgniteMetric): +class MeanSquaredError(IgniteMetricHandler): """ Computes Mean Squared Error from full size Tensor and collects average over batch, iterations. """ @@ -51,7 +51,7 @@ def __init__( super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) -class MeanAbsoluteError(IgniteMetric): +class MeanAbsoluteError(IgniteMetricHandler): """ Computes Mean Absolute Error from full size Tensor and collects average over batch, iterations. """ @@ -84,7 +84,7 @@ def __init__( super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) -class RootMeanSquaredError(IgniteMetric): +class RootMeanSquaredError(IgniteMetricHandler): """ Computes Root Mean Squared Error from full size Tensor and collects average over batch, iterations. """ @@ -117,7 +117,7 @@ def __init__( super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) -class PeakSignalToNoiseRatio(IgniteMetric): +class PeakSignalToNoiseRatio(IgniteMetricHandler): """ Computes Peak Signal to Noise Ratio from full size Tensor and collects average over batch, iterations. """ diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index a521a4cc06..deef17d83c 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import ROCAUCMetric from monai.utils import Average -class ROCAUC(IgniteMetric): +class ROCAUC(IgniteMetricHandler): """ Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`. 
diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index eb80b41a07..1a002d2e73 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import SurfaceDistanceMetric from monai.utils import MetricReduction -class SurfaceDistance(IgniteMetric): +class SurfaceDistance(IgniteMetricHandler): """ Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations. """ diff --git a/tests/test_handler_ignite_metric_handler.py b/tests/test_handler_ignite_metric_handler.py index e597decb96..fb6226ad76 100644 --- a/tests/test_handler_ignite_metric_handler.py +++ b/tests/test_handler_ignite_metric_handler.py @@ -17,7 +17,7 @@ from ignite.engine import Engine, Events from parameterized import parameterized -from monai.handlers import IgniteMetric, from_engine +from monai.handlers import IgniteMetricHandler, from_engine from monai.losses import DiceLoss from monai.metrics import LossMetric from tests.utils import assert_allclose @@ -98,7 +98,7 @@ class TestHandlerIgniteMetricHandler(unittest.TestCase): def test_metric_fn(self, loss_params, metric_params, handler_params, expected_avg): loss_fn = DiceLoss(**loss_params) metric_fn = LossMetric(loss_fn=loss_fn, **metric_params) - ignite_metric = IgniteMetric(metric_fn=metric_fn, **handler_params) + ignite_metric = IgniteMetricHandler(metric_fn=metric_fn, **handler_params) def _val_func(engine, batch): pass @@ -124,7 +124,7 @@ def _val_func(engine, batch): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_loss_fn(self, loss_params, metric_params, handler_params, expected_avg): loss_fn = DiceLoss(**loss_params) - ignite_metric = IgniteMetric(loss_fn=loss_fn, **handler_params, **metric_params) + ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, **handler_params, **metric_params) def _val_func(engine, batch): pass @@ -150,7 +150,7 @@ def _val_func(engine, batch): @parameterized.expand(TEST_CASES) def test_dice_loss(self, input_param, input_data, expected_val): loss_fn = DiceLoss(**input_param) - ignite_metric = IgniteMetric(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"])) + ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"])) def _val_func(engine, batch): pass From acceeaba9533e8a699a14f79d9b2209b0fc229d2 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 11 Jul 2023 11:07:42 +0200 Subject: [PATCH 16/25] Rename test_handler_ignite_metric_handler to test_handler_ignite_metric Signed-off-by: Matthias Hadlich --- ...ler_ignite_metric_handler.py => test_handler_ignite_metric.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_handler_ignite_metric_handler.py => test_handler_ignite_metric.py} (100%) diff --git a/tests/test_handler_ignite_metric_handler.py b/tests/test_handler_ignite_metric.py similarity index 100% rename from tests/test_handler_ignite_metric_handler.py rename to tests/test_handler_ignite_metric.py From e77e97bed074f51a2b5f936437fc42743840d196 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 11 Jul 2023 11:18:23 +0200 Subject: [PATCH 17/25] Remove warning Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/handlers/ignite_metric.py 
b/monai/handlers/ignite_metric.py index 6458a9d429..bde560ee6d 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -35,8 +35,6 @@ "ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator" ) -warnings.warn("IgniteMetric has been renamed to IgniteMetricHandler") - class IgniteMetricHandler(Metric): """ From 9ef3dac3e02e56ee6a6aea0494c8d5b3e070ffb9 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 11 Jul 2023 11:36:03 +0200 Subject: [PATCH 18/25] Fix ignite ImportError Signed-off-by: Matthias Hadlich --- tests/test_handler_ignite_metric.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py index fb6226ad76..2571860bd4 100644 --- a/tests/test_handler_ignite_metric.py +++ b/tests/test_handler_ignite_metric.py @@ -14,13 +14,19 @@ import unittest import torch -from ignite.engine import Engine, Events from parameterized import parameterized from monai.handlers import IgniteMetricHandler, from_engine from monai.losses import DiceLoss from monai.metrics import LossMetric -from tests.utils import assert_allclose +from tests.utils import SkipIfNoModule, assert_allclose, optional_import + +try: + _, has_ignite = optional_import("ignite") + from ignite.engine import Engine, Events +except ImportError: + has_ignite = False + TEST_CASE_1 = [ {"reduction": "none", "include_background": True}, @@ -94,6 +100,7 @@ class TestHandlerIgniteMetricHandler(unittest.TestCase): + @SkipIfNoModule("ignite") @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_metric_fn(self, loss_params, metric_params, handler_params, expected_avg): loss_fn = DiceLoss(**loss_params) @@ -121,6 +128,7 @@ def _val_func(engine, batch): print(f"{engine.state.metric_details['ignite_dice_loss']}") assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) + @SkipIfNoModule("ignite") @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_loss_fn(self, loss_params, metric_params, handler_params, expected_avg): loss_fn = DiceLoss(**loss_params) @@ -147,6 +155,7 @@ def _val_func(engine, batch): print(f"{engine.state.metric_details['ignite_dice_loss']}") assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) + @SkipIfNoModule("ignite") @parameterized.expand(TEST_CASES) def test_dice_loss(self, input_param, input_data, expected_val): loss_fn = DiceLoss(**input_param) From 1388b460bfb3c927dbe3bdfdca87f3df908cf29f Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 11 Jul 2023 13:19:48 +0200 Subject: [PATCH 19/25] Fix typing Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index bde560ee6d..c3dccef5e0 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -13,7 +13,7 @@ import warnings from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict import torch from torch.nn.modules.loss import _Loss @@ -60,11 +60,11 @@ class IgniteMetricHandler(Metric): def __init__( self, - metric_fn: CumulativeIterationMetric = None, - loss_fn: _Loss = None, + metric_fn: CumulativeIterationMetric | None = None, + loss_fn: _Loss | None = None, output_transform: 
Callable = lambda x: x, save_details: bool = True, - **kwargs, + **kwargs: Dict, ) -> None: self._is_reduced: bool = False self.metric_fn = metric_fn From 91f85ae1e9a15211e48f35961322f78628f41915 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 11 Jul 2023 13:57:10 +0200 Subject: [PATCH 20/25] Add deprecation warning for IgniteMetric Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 20 ++++++++++++++------ tests/test_handler_ignite_metric.py | 21 ++++++++++++++++++++- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index c3dccef5e0..3283ad7296 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -13,21 +13,26 @@ import warnings from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List import torch from torch.nn.modules.loss import _Loss from monai.config import IgniteInfo from monai.metrics import CumulativeIterationMetric, LossMetric -from monai.utils import min_version, optional_import +from monai.utils import deprecated, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") if TYPE_CHECKING: - from ignite.engine import Engine - from ignite.metrics import Metric - from ignite.metrics.metric import reinit__is_reduced + try: + _, has_ignite = optional_import("ignite") + from ignite.engine import Engine + from ignite.metrics import Metric + from ignite.metrics.metric import reinit__is_reduced + except ImportError: + has_ignite = False + else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base") @@ -148,4 +153,7 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override] engine.state.metric_details = {} # type: ignore -IgniteMetric = IgniteMetricHandler +@deprecated(since="1.3", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.") +class IgniteMetric(IgniteMetricHandler): + def __init__(self, *args: List, **kwargs: Dict): + super().__init__(*args, **kwargs) diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py index 2571860bd4..c0e1177ad1 100644 --- a/tests/test_handler_ignite_metric.py +++ b/tests/test_handler_ignite_metric.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.handlers import IgniteMetricHandler, from_engine +from monai.handlers import IgniteMetric, IgniteMetricHandler, from_engine from monai.losses import DiceLoss from monai.metrics import LossMetric from tests.utils import SkipIfNoModule, assert_allclose, optional_import @@ -174,6 +174,25 @@ def _val_func(engine, batch): engine.fire_event(Events.EPOCH_COMPLETED) assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_val, atol=1e-4, rtol=1e-4, type_test=False) + @SkipIfNoModule("ignite") + @parameterized.expand(TEST_CASES[0:2]) + def test_old_ignite_metric(self, input_param, input_data, expected_val): + loss_fn = DiceLoss(**input_param) + ignite_metric = IgniteMetric(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"])) + + def _val_func(engine, batch): + pass + + engine = Engine(_val_func) + ignite_metric.attach(engine=engine, name="ignite_dice_loss") + y_pred = input_data["input"] + y = input_data["target"] + engine.state.output = 
{"pred": y_pred, "label": y} + engine.fire_event(Events.ITERATION_COMPLETED) + + engine.fire_event(Events.EPOCH_COMPLETED) + assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_val, atol=1e-4, rtol=1e-4, type_test=False) + if __name__ == "__main__": unittest.main() From 2c5d188a38c9e69f75e4dbad43744cd73c747821 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jul 2023 11:58:24 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 3283ad7296..e45f3a1c7a 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -13,7 +13,7 @@ import warnings from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any import torch from torch.nn.modules.loss import _Loss @@ -69,7 +69,7 @@ def __init__( loss_fn: _Loss | None = None, output_transform: Callable = lambda x: x, save_details: bool = True, - **kwargs: Dict, + **kwargs: dict, ) -> None: self._is_reduced: bool = False self.metric_fn = metric_fn @@ -155,5 +155,5 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override] @deprecated(since="1.3", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.") class IgniteMetric(IgniteMetricHandler): - def __init__(self, *args: List, **kwargs: Dict): + def __init__(self, *args: list, **kwargs: dict): super().__init__(*args, **kwargs) From c8027b26f763118c63392458b912845b6d2a3af0 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 11 Jul 2023 14:34:47 +0200 Subject: [PATCH 22/25] Add test_handler_ignite_metric to the min_tests list Signed-off-by: Matthias Hadlich --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index f553dc4a50..e3b09e7c84 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -83,6 +83,7 @@ def run_testsuit(): "test_handler_early_stop", "test_handler_garbage_collector", "test_handler_hausdorff_distance", + "test_handler_ignite_metric", "test_handler_lr_scheduler", "test_handler_mean_dice", "test_handler_panoptic_quality", From f54a1bf5fd625dab84283da176fe1d8f5ae014c9 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Wed, 12 Jul 2023 17:51:32 +0200 Subject: [PATCH 23/25] Fix code formatting Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index e45f3a1c7a..dcb0bd3635 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -155,5 +155,14 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override] @deprecated(since="1.3", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.") class IgniteMetric(IgniteMetricHandler): - def __init__(self, *args: list, **kwargs: dict): - super().__init__(*args, **kwargs) + def __init__( + self, + metric_fn: CumulativeIterationMetric | None = None, + loss_fn: _Loss | None = None, + output_transform: Callable = lambda x: x, + save_details: bool = True, + **kwargs: dict, + ) -> None: + super().__init__( + metric_fn=metric_fn, loss_fn=loss_fn, 
output_transform=output_transform, save_details=save_details, **kwargs + ) From 966c99aaa300b38e4143b106a9f8a81bbab14604 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Wed, 12 Jul 2023 20:21:42 +0200 Subject: [PATCH 24/25] Fix code formatting and remove debug prints Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 26 ++++++++++++++++++-------- tests/test_handler_ignite_metric.py | 6 ------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index dcb0bd3635..8b614c44fc 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -13,14 +13,14 @@ import warnings from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch from torch.nn.modules.loss import _Loss from monai.config import IgniteInfo from monai.metrics import CumulativeIterationMetric, LossMetric -from monai.utils import deprecated, min_version, optional_import +from monai.utils import MetricReduction, deprecated, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") @@ -59,7 +59,8 @@ class IgniteMetricHandler(Metric): https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean_dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. - kwargs: keyword argument that will be passed into the LossMetric + reduction: Argument for the LossMetric, look there for details + get_not_nans: Argument for the LossMetric, look there for details """ @@ -69,10 +70,11 @@ def __init__( loss_fn: _Loss | None = None, output_transform: Callable = lambda x: x, save_details: bool = True, - **kwargs: dict, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, ) -> None: self._is_reduced: bool = False - self.metric_fn = metric_fn + self.metric_fn: CumulativeIterationMetric = cast(CumulativeIterationMetric, metric_fn) self.loss_fn = loss_fn self.save_details = save_details self._scores: list = [] @@ -84,7 +86,9 @@ def __init__( if self.metric_fn is not None and self.loss_fn is not None: raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.") if self.loss_fn: - self.metric_fn = LossMetric(loss_fn=self.loss_fn, **kwargs) + self.metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans) + # else: + # self.metric_fn = cast(CumulativeIterationMetric, metric_fn) super().__init__(output_transform) @@ -161,8 +165,14 @@ def __init__( loss_fn: _Loss | None = None, output_transform: Callable = lambda x: x, save_details: bool = True, - **kwargs: dict, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, ) -> None: super().__init__( - metric_fn=metric_fn, loss_fn=loss_fn, output_transform=output_transform, save_details=save_details, **kwargs + metric_fn=metric_fn, + loss_fn=loss_fn, + output_transform=output_transform, + save_details=save_details, + reduction=reduction, + get_not_nans=get_not_nans, ) diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py index c0e1177ad1..c98bc8ac78 100644 --- a/tests/test_handler_ignite_metric.py +++ b/tests/test_handler_ignite_metric.py @@ -123,9 +123,6 @@ def _val_func(engine, batch): 
engine.fire_event(Events.ITERATION_COMPLETED) engine.fire_event(Events.EPOCH_COMPLETED) - print(f"{engine.state.metrics['ignite_dice_loss']}") - print(f"{engine.state.metric_details['ignite_dice_loss'].shape}") - print(f"{engine.state.metric_details['ignite_dice_loss']}") assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) @SkipIfNoModule("ignite") @@ -150,9 +147,6 @@ def _val_func(engine, batch): engine.fire_event(Events.ITERATION_COMPLETED) engine.fire_event(Events.EPOCH_COMPLETED) - print(f"{engine.state.metrics['ignite_dice_loss']}") - print(f"{engine.state.metric_details['ignite_dice_loss'].shape}") - print(f"{engine.state.metric_details['ignite_dice_loss']}") assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) @SkipIfNoModule("ignite") From f92824a2c75bf070016fadec80d5753d3776e653 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Thu, 13 Jul 2023 08:42:30 +0200 Subject: [PATCH 25/25] Remove commented code Signed-off-by: Matthias Hadlich --- monai/handlers/ignite_metric.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 8b614c44fc..3f30b62c0c 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -87,8 +87,6 @@ def __init__( raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.") if self.loss_fn: self.metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans) - # else: - # self.metric_fn = cast(CumulativeIterationMetric, metric_fn) super().__init__(output_transform)
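For reference, the following standalone sketch (not part of any patch above) mirrors the pattern used in tests/test_handler_ignite_metric.py: it wires an IgniteMetricHandler driven by a DiceLoss into an ignite Engine and reads the aggregated value back from engine.state.metrics. The tensor values and the metric name "dice_loss" are illustrative only, and the snippet assumes pytorch-ignite is installed together with a MONAI version that ships IgniteMetricHandler.

# Illustrative usage sketch, modelled on the unit tests above; not part of the patch series.
import torch
from ignite.engine import Engine, Events

from monai.handlers import IgniteMetricHandler, from_engine
from monai.losses import DiceLoss

# Wrap DiceLoss so it is accumulated like a metric over iterations;
# reduction="none" leaves per-item values for the handler's LossMetric to reduce.
handler = IgniteMetricHandler(
    loss_fn=DiceLoss(include_background=True, reduction="none"),
    output_transform=from_engine(["pred", "label"]),
    save_details=True,
)

def _step(engine, batch):
    # a real pipeline would compute predictions here; this test-style
    # example injects engine.state.output manually instead
    pass

engine = Engine(_step)
handler.attach(engine=engine, name="dice_loss")

# batch-first, channel-first probability maps, shape (1, 1, 2, 2); values are made up
y_pred = torch.tensor([[[[0.9, 0.1], [0.8, 0.2]]]])
y = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])

engine.state.output = {"pred": y_pred, "label": y}
engine.fire_event(Events.ITERATION_COMPLETED)  # accumulate this iteration
engine.fire_event(Events.EPOCH_COMPLETED)      # reduce and publish the result

print(engine.state.metrics["dice_loss"])         # aggregated DiceLoss value
print(engine.state.metric_details["dice_loss"])  # per-image details (save_details=True)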