diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index 7da7f7f50d..270083f717 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -29,9 +29,9 @@ CSV saver
     :members:


-Ignite Metric
--------------
-.. autoclass:: IgniteMetric
+Ignite Metric Handler
+---------------------
+.. autoclass:: IgniteMetricHandler
     :members:


diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index f032191043..641f9aae7d 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -20,7 +20,7 @@
 from .earlystop_handler import EarlyStopHandler
 from .garbage_collector import GarbageCollector
 from .hausdorff_distance import HausdorffDistance
-from .ignite_metric import IgniteMetric
+from .ignite_metric import IgniteMetric, IgniteMetricHandler
 from .logfile_handler import LogfileHandler
 from .lr_schedule_handler import LrScheduleHandler
 from .mean_dice import MeanDice
diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py
index e8f8b6b112..89c0f5551f 100644
--- a/monai/handlers/confusion_matrix.py
+++ b/monai/handlers/confusion_matrix.py
@@ -13,12 +13,12 @@

 from collections.abc import Callable

-from monai.handlers.ignite_metric import IgniteMetric
+from monai.handlers.ignite_metric import IgniteMetricHandler
 from monai.metrics import ConfusionMatrixMetric
 from monai.utils.enums import MetricReduction


-class ConfusionMatrix(IgniteMetric):
+class ConfusionMatrix(IgniteMetricHandler):
     """
     Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations.
     """
diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py
index ef4136906c..669bf30068 100644
--- a/monai/handlers/hausdorff_distance.py
+++ b/monai/handlers/hausdorff_distance.py
@@ -13,12 +13,12 @@

 from collections.abc import Callable

-from monai.handlers.ignite_metric import IgniteMetric
+from monai.handlers.ignite_metric import IgniteMetricHandler
 from monai.metrics import HausdorffDistanceMetric
 from monai.utils import MetricReduction


-class HausdorffDistance(IgniteMetric):
+class HausdorffDistance(IgniteMetricHandler):
     """
     Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.
     """
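# Editor's note: illustrative sketch, not part of this changeset. Every concrete handler
# above (ConfusionMatrix, HausdorffDistance, and the others updated below) now subclasses
# the renamed IgniteMetricHandler base. A custom handler wrapping a CumulativeIterationMetric
# would follow the same pattern; "MyDiceHandler" is a hypothetical name used only here.
from __future__ import annotations

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import DiceMetric
from monai.utils import MetricReduction


class MyDiceHandler(IgniteMetricHandler):
    """Hypothetical handler mirroring the built-in MeanDice wrapper."""

    def __init__(
        self,
        include_background: bool = True,
        reduction: MetricReduction | str = MetricReduction.MEAN,
        output_transform: Callable = lambda x: x,
        save_details: bool = True,
    ) -> None:
        # build the underlying cumulative metric, then hand it to the renamed base class
        metric_fn = DiceMetric(include_background=include_background, reduction=reduction)
        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)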
""" diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 822da0aa18..3f30b62c0c 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -13,20 +13,26 @@ import warnings from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch +from torch.nn.modules.loss import _Loss from monai.config import IgniteInfo -from monai.metrics import CumulativeIterationMetric -from monai.utils import min_version, optional_import +from monai.metrics import CumulativeIterationMetric, LossMetric +from monai.utils import MetricReduction, deprecated, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") if TYPE_CHECKING: - from ignite.engine import Engine - from ignite.metrics import Metric - from ignite.metrics.metric import reinit__is_reduced + try: + _, has_ignite = optional_import("ignite") + from ignite.engine import Engine + from ignite.metrics import Metric + from ignite.metrics.metric import reinit__is_reduced + except ImportError: + has_ignite = False + else: Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base") @@ -35,7 +41,7 @@ ) -class IgniteMetric(Metric): +class IgniteMetricHandler(Metric): """ Base Metric class based on ignite event handler mechanism. The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim, @@ -44,6 +50,7 @@ class IgniteMetric(Metric): Args: metric_fn: callable function or class to compute raw metric results after every iteration. expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans). + loss_fn: A torch _Loss function which is used to generate the LossMetric output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. @@ -52,18 +59,35 @@ class IgniteMetric(Metric): https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean_dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. 
+        reduction: argument for the internal `LossMetric`, see `LossMetric` for details.
+        get_not_nans: argument for the internal `LossMetric`, see `LossMetric` for details.

     """

     def __init__(
-        self, metric_fn: CumulativeIterationMetric, output_transform: Callable = lambda x: x, save_details: bool = True
+        self,
+        metric_fn: CumulativeIterationMetric | None = None,
+        loss_fn: _Loss | None = None,
+        output_transform: Callable = lambda x: x,
+        save_details: bool = True,
+        reduction: MetricReduction | str = MetricReduction.MEAN,
+        get_not_nans: bool = False,
     ) -> None:
         self._is_reduced: bool = False
-        self.metric_fn = metric_fn
+        self.metric_fn: CumulativeIterationMetric = cast(CumulativeIterationMetric, metric_fn)
+        self.loss_fn = loss_fn
         self.save_details = save_details
         self._scores: list = []
         self._engine: Engine | None = None
         self._name: str | None = None
+
+        if self.metric_fn is None and self.loss_fn is None:
+            raise ValueError("Either metric_fn or loss_fn has to be passed.")
+        if self.metric_fn is not None and self.loss_fn is not None:
+            raise ValueError("Either metric_fn or loss_fn has to be passed, but not both.")
+        if self.loss_fn:
+            self.metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans)
+
         super().__init__(output_transform)

     @reinit__is_reduced
@@ -129,3 +153,24 @@ def attach(self, engine: Engine, name: str) -> None:  # type: ignore[override]
         self._name = name
         if self.save_details and not hasattr(engine.state, "metric_details"):
             engine.state.metric_details = {}  # type: ignore
+
+
+@deprecated(since="1.3", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.")
+class IgniteMetric(IgniteMetricHandler):
+    def __init__(
+        self,
+        metric_fn: CumulativeIterationMetric | None = None,
+        loss_fn: _Loss | None = None,
+        output_transform: Callable = lambda x: x,
+        save_details: bool = True,
+        reduction: MetricReduction | str = MetricReduction.MEAN,
+        get_not_nans: bool = False,
+    ) -> None:
+        super().__init__(
+            metric_fn=metric_fn,
+            loss_fn=loss_fn,
+            output_transform=output_transform,
+            save_details=save_details,
+            reduction=reduction,
+            get_not_nans=get_not_nans,
+        )
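# Editor's note: illustrative sketch, not part of this changeset. With the new loss_fn
# argument, a loss such as DiceLoss can be attached to an ignite Engine directly; per the
# __init__ above, the handler wraps it in a LossMetric internally. This assumes ignite is
# installed; the metric name "dice_loss" below is only an example.
import torch
from ignite.engine import Engine, Events

from monai.handlers import IgniteMetricHandler, from_engine
from monai.losses import DiceLoss

handler = IgniteMetricHandler(
    loss_fn=DiceLoss(include_background=True, reduction="none"),
    output_transform=from_engine(["pred", "label"]),
)

engine = Engine(lambda e, b: None)  # dummy process function, enough for a demonstration
handler.attach(engine=engine, name="dice_loss")

# feed one iteration's output, then trigger the update and aggregation events
engine.state.output = {
    "pred": torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]]),
    "label": torch.tensor([[[[0.0, 1.0]], [[0.0, 1.0]]]]),
}
engine.fire_event(Events.ITERATION_COMPLETED)
engine.fire_event(Events.EPOCH_COMPLETED)
print(engine.state.metrics["dice_loss"])

# passing both metric_fn and loss_fn, or neither, raises ValueError per the checks above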
""" diff --git a/monai/handlers/metrics_reloaded_handler.py b/monai/handlers/metrics_reloaded_handler.py index dbec98256b..b526d21ee0 100644 --- a/monai/handlers/metrics_reloaded_handler.py +++ b/monai/handlers/metrics_reloaded_handler.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical from monai.utils.enums import MetricReduction -class MetricsReloadedBinaryHandler(IgniteMetric): +class MetricsReloadedBinaryHandler(IgniteMetricHandler): """ Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded. """ @@ -65,7 +65,7 @@ def __init__( super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) -class MetricsReloadedCategoricalHandler(IgniteMetric): +class MetricsReloadedCategoricalHandler(IgniteMetricHandler): """ Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded. """ diff --git a/monai/handlers/panoptic_quality.py b/monai/handlers/panoptic_quality.py index 4bf561826c..80b0f71390 100644 --- a/monai/handlers/panoptic_quality.py +++ b/monai/handlers/panoptic_quality.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import PanopticQualityMetric from monai.utils import MetricReduction -class PanopticQuality(IgniteMetric): +class PanopticQuality(IgniteMetricHandler): """ Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations. """ diff --git a/monai/handlers/regression_metrics.py b/monai/handlers/regression_metrics.py index fee7238491..c5702a7e59 100644 --- a/monai/handlers/regression_metrics.py +++ b/monai/handlers/regression_metrics.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric from monai.utils import MetricReduction -class MeanSquaredError(IgniteMetric): +class MeanSquaredError(IgniteMetricHandler): """ Computes Mean Squared Error from full size Tensor and collects average over batch, iterations. """ @@ -51,7 +51,7 @@ def __init__( super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) -class MeanAbsoluteError(IgniteMetric): +class MeanAbsoluteError(IgniteMetricHandler): """ Computes Mean Absolute Error from full size Tensor and collects average over batch, iterations. """ @@ -84,7 +84,7 @@ def __init__( super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) -class RootMeanSquaredError(IgniteMetric): +class RootMeanSquaredError(IgniteMetricHandler): """ Computes Root Mean Squared Error from full size Tensor and collects average over batch, iterations. """ @@ -117,7 +117,7 @@ def __init__( super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) -class PeakSignalToNoiseRatio(IgniteMetric): +class PeakSignalToNoiseRatio(IgniteMetricHandler): """ Computes Peak Signal to Noise Ratio from full size Tensor and collects average over batch, iterations. 
""" diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index a521a4cc06..deef17d83c 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import ROCAUCMetric from monai.utils import Average -class ROCAUC(IgniteMetric): +class ROCAUC(IgniteMetricHandler): """ Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`. diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index eb80b41a07..1a002d2e73 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -13,12 +13,12 @@ from collections.abc import Callable -from monai.handlers.ignite_metric import IgniteMetric +from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import SurfaceDistanceMetric from monai.utils import MetricReduction -class SurfaceDistance(IgniteMetric): +class SurfaceDistance(IgniteMetricHandler): """ Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations. """ diff --git a/tests/min_tests.py b/tests/min_tests.py index f553dc4a50..e3b09e7c84 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -83,6 +83,7 @@ def run_testsuit(): "test_handler_early_stop", "test_handler_garbage_collector", "test_handler_hausdorff_distance", + "test_handler_ignite_metric", "test_handler_lr_scheduler", "test_handler_mean_dice", "test_handler_panoptic_quality", diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py new file mode 100644 index 0000000000..c98bc8ac78 --- /dev/null +++ b/tests/test_handler_ignite_metric.py @@ -0,0 +1,192 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py
new file mode 100644
index 0000000000..c98bc8ac78
--- /dev/null
+++ b/tests/test_handler_ignite_metric.py
@@ -0,0 +1,192 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.handlers import IgniteMetric, IgniteMetricHandler, from_engine
+from monai.losses import DiceLoss
+from monai.metrics import LossMetric
+from tests.utils import SkipIfNoModule, assert_allclose, optional_import
+
+try:
+    _, has_ignite = optional_import("ignite")
+    from ignite.engine import Engine, Events
+except ImportError:
+    has_ignite = False
+
+
+TEST_CASE_1 = [
+    {"reduction": "none", "include_background": True},
+    {},
+    {"output_transform": from_engine(["pred", "label"])},
+    0.25,
+]
+TEST_CASE_2 = [
+    {"reduction": "mean", "include_background": False},
+    {},
+    {"output_transform": from_engine(["pred", "label"])},
+    0.5,
+]
+TEST_CASE_3 = [
+    {"reduction": "none"},
+    {"reduction": "mean_channel"},
+    {"output_transform": from_engine(["pred", "label"])},
+    torch.Tensor([0.5, 0]),
+]
+
+TEST_CASES = [
+    [
+        {"include_background": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6},
+        {
+            "input": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+        },
+        0,
+    ],
+    [
+        {"include_background": False, "smooth_nr": 1e-6, "smooth_dr": 1e-6},
+        {
+            "input": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+        },
+        0,
+    ],
+    [
+        {"include_background": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6},
+        {
+            "input": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+            "target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]),
+        },
+        1,
+    ],
+    [
+        {"include_background": False, "smooth_nr": 1e-6, "smooth_dr": 1e-6},
+        {
+            "input": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]),
+            "target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+        },
+        1,
+    ],
+    [
+        {"include_background": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6},
+        {
+            "input": torch.tensor([[[[0.0, 1.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]]]),
+            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+        },
+        0.333333,
+    ],
+    [
+        {"include_background": False, "smooth_nr": 1e-6, "smooth_dr": 1e-6},
+        {
+            "input": torch.tensor([[[[0.0, 1.0], [0.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+            "target": torch.tensor([[[[1.0, 0.0], [1.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]),
+        },
+        0,
+    ],
+]
+
+
+class TestHandlerIgniteMetricHandler(unittest.TestCase):
+    @SkipIfNoModule("ignite")
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+    def test_metric_fn(self, loss_params, metric_params, handler_params, expected_avg):
+        loss_fn = DiceLoss(**loss_params)
+        metric_fn = LossMetric(loss_fn=loss_fn, **metric_params)
+        ignite_metric = IgniteMetricHandler(metric_fn=metric_fn, **handler_params)
+
+        def _val_func(engine, batch):
+            pass
+
+        engine = Engine(_val_func)
+        ignite_metric.attach(engine=engine, name="ignite_dice_loss")
+        y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])
+        y = torch.tensor([[[[0.0, 1.0]], [[0.0, 1.0]]]])
+        engine.state.output = {"pred": y_pred, "label": y}
+        engine.fire_event(Events.ITERATION_COMPLETED)
+
+        y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])
+        y = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])
+        engine.state.output = {"pred": y_pred, "label": y}
+        engine.fire_event(Events.ITERATION_COMPLETED)
+
+        engine.fire_event(Events.EPOCH_COMPLETED)
+        assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)
+
+    @SkipIfNoModule("ignite")
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+    def test_loss_fn(self, loss_params, metric_params, handler_params, expected_avg):
+        loss_fn = DiceLoss(**loss_params)
+        ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, **handler_params, **metric_params)
+
+        def _val_func(engine, batch):
+            pass
+
+        engine = Engine(_val_func)
+        ignite_metric.attach(engine=engine, name="ignite_dice_loss")
+        y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])
+        y = torch.tensor([[[[0.0, 1.0]], [[0.0, 1.0]]]])
+        engine.state.output = {"pred": y_pred, "label": y}
+        engine.fire_event(Events.ITERATION_COMPLETED)
+
+        y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])
+        y = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])
+        engine.state.output = {"pred": y_pred, "label": y}
+        engine.fire_event(Events.ITERATION_COMPLETED)
+
+        engine.fire_event(Events.EPOCH_COMPLETED)
+        assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)
+
+    @SkipIfNoModule("ignite")
+    @parameterized.expand(TEST_CASES)
+    def test_dice_loss(self, input_param, input_data, expected_val):
+        loss_fn = DiceLoss(**input_param)
+        ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"]))
+
+        def _val_func(engine, batch):
+            pass
+
+        engine = Engine(_val_func)
+        ignite_metric.attach(engine=engine, name="ignite_dice_loss")
+        y_pred = input_data["input"]
+        y = input_data["target"]
+        engine.state.output = {"pred": y_pred, "label": y}
+        engine.fire_event(Events.ITERATION_COMPLETED)
+
+        engine.fire_event(Events.EPOCH_COMPLETED)
+        assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_val, atol=1e-4, rtol=1e-4, type_test=False)
+
+    @SkipIfNoModule("ignite")
+    @parameterized.expand(TEST_CASES[0:2])
+    def test_old_ignite_metric(self, input_param, input_data, expected_val):
+        loss_fn = DiceLoss(**input_param)
+        ignite_metric = IgniteMetric(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"]))
+
+        def _val_func(engine, batch):
+            pass
+
+        engine = Engine(_val_func)
+        ignite_metric.attach(engine=engine, name="ignite_dice_loss")
+        y_pred = input_data["input"]
+        y = input_data["target"]
+        engine.state.output = {"pred": y_pred, "label": y}
+        engine.fire_event(Events.ITERATION_COMPLETED)
+
+        engine.fire_event(Events.EPOCH_COMPLETED)
+        assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_val, atol=1e-4, rtol=1e-4, type_test=False)
+
+
+if __name__ == "__main__":
+    unittest.main()
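# Editor's note: illustrative sketch, not part of this changeset. The old IgniteMetric name
# remains importable as a thin, deprecated subclass, so existing code keeps working while
# emitting a deprecation warning (the exact warning category depends on the defaults of
# monai.utils.deprecated).
import warnings

from monai.handlers import IgniteMetric, IgniteMetricHandler
from monai.losses import DiceLoss

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style = IgniteMetric(loss_fn=DiceLoss())  # deprecated alias, still functional
print([str(w.message) for w in caught])  # expect a hint pointing to IgniteMetricHandler

new_style = IgniteMetricHandler(loss_fn=DiceLoss())  # preferred spelling going forward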