diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index dcb0bd3635..8b614c44fc 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -13,14 +13,14 @@ import warnings from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch from torch.nn.modules.loss import _Loss from monai.config import IgniteInfo from monai.metrics import CumulativeIterationMetric, LossMetric -from monai.utils import deprecated, min_version, optional_import +from monai.utils import MetricReduction, deprecated, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") @@ -59,7 +59,8 @@ class IgniteMetricHandler(Metric): https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. save_details: whether to save metric computation details per image, for example: mean_dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. - kwargs: keyword argument that will be passed into the LossMetric + reduction: Argument for the LossMetric, look there for details + get_not_nans: Argument for the LossMetric, look there for details """ @@ -69,10 +70,11 @@ def __init__( loss_fn: _Loss | None = None, output_transform: Callable = lambda x: x, save_details: bool = True, - **kwargs: dict, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, ) -> None: self._is_reduced: bool = False - self.metric_fn = metric_fn + self.metric_fn: CumulativeIterationMetric = cast(CumulativeIterationMetric, metric_fn) self.loss_fn = loss_fn self.save_details = save_details self._scores: list = [] @@ -84,7 +86,9 @@ def __init__( if self.metric_fn is not None and self.loss_fn is not None: raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.") if self.loss_fn: - self.metric_fn = LossMetric(loss_fn=self.loss_fn, **kwargs) + self.metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans) + # else: + # self.metric_fn = cast(CumulativeIterationMetric, metric_fn) super().__init__(output_transform) @@ -161,8 +165,14 @@ def __init__( loss_fn: _Loss | None = None, output_transform: Callable = lambda x: x, save_details: bool = True, - **kwargs: dict, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, ) -> None: super().__init__( - metric_fn=metric_fn, loss_fn=loss_fn, output_transform=output_transform, save_details=save_details, **kwargs + metric_fn=metric_fn, + loss_fn=loss_fn, + output_transform=output_transform, + save_details=save_details, + reduction=reduction, + get_not_nans=get_not_nans, ) diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py index c0e1177ad1..c98bc8ac78 100644 --- a/tests/test_handler_ignite_metric.py +++ b/tests/test_handler_ignite_metric.py @@ -123,9 +123,6 @@ def _val_func(engine, batch): engine.fire_event(Events.ITERATION_COMPLETED) engine.fire_event(Events.EPOCH_COMPLETED) - print(f"{engine.state.metrics['ignite_dice_loss']}") - print(f"{engine.state.metric_details['ignite_dice_loss'].shape}") - print(f"{engine.state.metric_details['ignite_dice_loss']}") assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) @SkipIfNoModule("ignite") @@ -150,9 +147,6 @@ def _val_func(engine, batch): engine.fire_event(Events.ITERATION_COMPLETED) engine.fire_event(Events.EPOCH_COMPLETED) - print(f"{engine.state.metrics['ignite_dice_loss']}") - print(f"{engine.state.metric_details['ignite_dice_loss'].shape}") - print(f"{engine.state.metric_details['ignite_dice_loss']}") assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False) @SkipIfNoModule("ignite")