Skip to content

Commit

Permalink
Rename IgniteMetric to IgniteMetricHandler
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Hadlich <matthiashadlich@posteo.de>
  • Loading branch information
matt3o committed Jul 11, 2023
1 parent 9d93967 commit 974a3eb
Show file tree
Hide file tree
Showing 13 changed files with 39 additions and 34 deletions.
6 changes: 3 additions & 3 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ CSV saver
:members:


Ignite Metric
-------------
.. autoclass:: IgniteMetric
Ignite Metric Handler
---------------------
.. autoclass:: IgniteMetricHandler
:members:


Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .earlystop_handler import EarlyStopHandler
from .garbage_collector import GarbageCollector
from .hausdorff_distance import HausdorffDistance
from .ignite_metric import IgniteMetric
from .ignite_metric import IgniteMetric, IgniteMetricHandler
from .logfile_handler import LogfileHandler
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import ConfusionMatrixMetric
from monai.utils.enums import MetricReduction


class ConfusionMatrix(IgniteMetric):
class ConfusionMatrix(IgniteMetricHandler):
"""
Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import HausdorffDistanceMetric
from monai.utils import MetricReduction


class HausdorffDistance(IgniteMetric):
class HausdorffDistance(IgniteMetricHandler):
"""
Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
13 changes: 9 additions & 4 deletions monai/handlers/ignite_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

if TYPE_CHECKING:
from ignite.engine import Engine
from ignite.metrics import LossMetric, Metric
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
Expand All @@ -35,8 +35,10 @@
"ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator"
)

warnings.warn("IgniteMetric has been renamed to IgniteMetricHandler")

class IgniteMetric(Metric):

class IgniteMetricHandler(Metric):
"""
Base Metric class based on ignite event handler mechanism.
The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim,
Expand Down Expand Up @@ -74,9 +76,9 @@ def __init__(
self._engine: Engine | None = None
self._name: str | None = None

if self.metric_fn is None and self.loss_fn is None:
if self.metric_fn is None and self.loss_fn is None:
raise ValueError("Either metric_fn or loss_fn have to be passed.")
if self.metric_fn is not None and self.loss_fn is not None:
if self.metric_fn is not None and self.loss_fn is not None:
raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.")
if self.loss_fn:
self.metric_fn = LossMetric(loss_fn=self.loss_fn, **kwargs)
Expand Down Expand Up @@ -146,3 +148,6 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override]
self._name = name
if self.save_details and not hasattr(engine.state, "metric_details"):
engine.state.metric_details = {} # type: ignore


IgniteMetric = IgniteMetricHandler
4 changes: 2 additions & 2 deletions monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import DiceMetric
from monai.utils import MetricReduction


class MeanDice(IgniteMetric):
class MeanDice(IgniteMetricHandler):
"""
Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MeanIoU
from monai.utils import MetricReduction


class MeanIoUHandler(IgniteMetric):
class MeanIoUHandler(IgniteMetricHandler):
"""
Computes IoU score metric from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
6 changes: 3 additions & 3 deletions monai/handlers/metrics_reloaded_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical
from monai.utils.enums import MetricReduction


class MetricsReloadedBinaryHandler(IgniteMetric):
class MetricsReloadedBinaryHandler(IgniteMetricHandler):
"""
Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded.
"""
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class MetricsReloadedCategoricalHandler(IgniteMetric):
class MetricsReloadedCategoricalHandler(IgniteMetricHandler):
"""
Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import PanopticQualityMetric
from monai.utils import MetricReduction


class PanopticQuality(IgniteMetric):
class PanopticQuality(IgniteMetricHandler):
"""
Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
10 changes: 5 additions & 5 deletions monai/handlers/regression_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric
from monai.utils import MetricReduction


class MeanSquaredError(IgniteMetric):
class MeanSquaredError(IgniteMetricHandler):
"""
Computes Mean Squared Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class MeanAbsoluteError(IgniteMetric):
class MeanAbsoluteError(IgniteMetricHandler):
"""
Computes Mean Absolute Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class RootMeanSquaredError(IgniteMetric):
class RootMeanSquaredError(IgniteMetricHandler):
"""
Computes Root Mean Squared Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class PeakSignalToNoiseRatio(IgniteMetric):
class PeakSignalToNoiseRatio(IgniteMetricHandler):
"""
Computes Peak Signal to Noise Ratio from full size Tensor and collects average over batch, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import ROCAUCMetric
from monai.utils import Average


class ROCAUC(IgniteMetric):
class ROCAUC(IgniteMetricHandler):
"""
Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC).
accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`.
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import SurfaceDistanceMetric
from monai.utils import MetricReduction


class SurfaceDistance(IgniteMetric):
class SurfaceDistance(IgniteMetricHandler):
"""
Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/test_handler_ignite_metric_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ignite.engine import Engine, Events
from parameterized import parameterized

from monai.handlers import IgniteMetric, from_engine
from monai.handlers import IgniteMetricHandler, from_engine
from monai.losses import DiceLoss
from monai.metrics import LossMetric
from tests.utils import assert_allclose
Expand Down Expand Up @@ -98,7 +98,7 @@ class TestHandlerIgniteMetricHandler(unittest.TestCase):
def test_metric_fn(self, loss_params, metric_params, handler_params, expected_avg):
loss_fn = DiceLoss(**loss_params)
metric_fn = LossMetric(loss_fn=loss_fn, **metric_params)
ignite_metric = IgniteMetric(metric_fn=metric_fn, **handler_params)
ignite_metric = IgniteMetricHandler(metric_fn=metric_fn, **handler_params)

def _val_func(engine, batch):
pass
Expand All @@ -124,7 +124,7 @@ def _val_func(engine, batch):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_loss_fn(self, loss_params, metric_params, handler_params, expected_avg):
loss_fn = DiceLoss(**loss_params)
ignite_metric = IgniteMetric(loss_fn=loss_fn, **handler_params, **metric_params)
ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, **handler_params, **metric_params)

def _val_func(engine, batch):
pass
Expand All @@ -150,7 +150,7 @@ def _val_func(engine, batch):
@parameterized.expand(TEST_CASES)
def test_dice_loss(self, input_param, input_data, expected_val):
loss_fn = DiceLoss(**input_param)
ignite_metric = IgniteMetric(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"]))
ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"]))

def _val_func(engine, batch):
pass
Expand Down

0 comments on commit 974a3eb

Please sign in to comment.