Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loss_fn to IgniteMetric and rename to IgniteMetricHandler #6695

Merged
merged 27 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
324e53f
Add DiceCEMetric
matt3o Jul 3, 2023
4384f4c
WiP: Add unittest for DiceCEMetric
matt3o Jul 3, 2023
94cb06f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2023
48c6ef9
Remove DiceCEMetric
matt3o Jul 4, 2023
6815816
Add IgniteLossMetric
matt3o Jul 4, 2023
2d77048
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2023
84ff449
Merge branch 'dev' into Add_dice_ce_metric
matt3o Jul 10, 2023
2524e0d
Undo previous commits as discussed
matt3o Jul 10, 2023
7c57f00
Add loss_fn support to IgniteMetric
matt3o Jul 10, 2023
14b0748
Delete previously created files
matt3o Jul 10, 2023
a3154ee
Modify IgniteMetric to also support loss_fn
matt3o Jul 10, 2023
0392f60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 10, 2023
a51bd4d
Add tests for IgniteMetric(Handler)
matt3o Jul 10, 2023
d76f0ed
Fix formatting
matt3o Jul 10, 2023
9d93967
Update test cases for IgniteMetric(Handler)
matt3o Jul 11, 2023
974a3eb
Rename IgniteMetric to IgniteMetricHandler
matt3o Jul 11, 2023
acceeab
Rename test_handler_ignite_metric_handler to test_handler_ignite_metric
matt3o Jul 11, 2023
e77e97b
Remove warning
matt3o Jul 11, 2023
9ef3dac
Fix ignite ImportError
matt3o Jul 11, 2023
1388b46
Fix typing
matt3o Jul 11, 2023
91f85ae
Add deprecation warning for IgniteMetric
matt3o Jul 11, 2023
2c5d188
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2023
c8027b2
Add test_handler_ignite_metric to the min_tests list
matt3o Jul 11, 2023
f54a1bf
Fix code formatting
matt3o Jul 12, 2023
966c99a
Fix code formatting and remove debug prints
matt3o Jul 12, 2023
1a4b133
Merge branch 'dev' into Add_dice_ce_metric
wyli Jul 12, 2023
f92824a
Remove commented code
matt3o Jul 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ CSV saver
:members:


Ignite Metric
-------------
.. autoclass:: IgniteMetric
Ignite Metric Handler
---------------------
.. autoclass:: IgniteMetricHandler
:members:


Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .earlystop_handler import EarlyStopHandler
from .garbage_collector import GarbageCollector
from .hausdorff_distance import HausdorffDistance
from .ignite_metric import IgniteMetric
from .ignite_metric import IgniteMetric, IgniteMetricHandler
from .logfile_handler import LogfileHandler
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import ConfusionMatrixMetric
from monai.utils.enums import MetricReduction


class ConfusionMatrix(IgniteMetric):
class ConfusionMatrix(IgniteMetricHandler):
"""
Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import HausdorffDistanceMetric
from monai.utils import MetricReduction


class HausdorffDistance(IgniteMetric):
class HausdorffDistance(IgniteMetricHandler):
"""
Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
65 changes: 56 additions & 9 deletions monai/handlers/ignite_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,26 @@

import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

import torch
from torch.nn.modules.loss import _Loss

from monai.config import IgniteInfo
from monai.metrics import CumulativeIterationMetric
from monai.utils import min_version, optional_import
from monai.metrics import CumulativeIterationMetric, LossMetric
from monai.utils import MetricReduction, deprecated, min_version, optional_import

idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")

if TYPE_CHECKING:
from ignite.engine import Engine
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced
try:
_, has_ignite = optional_import("ignite")
from ignite.engine import Engine
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced
except ImportError:
has_ignite = False

else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base")
Expand All @@ -35,7 +41,7 @@
)


class IgniteMetric(Metric):
matt3o marked this conversation as resolved.
Show resolved Hide resolved
class IgniteMetricHandler(Metric):
"""
Base Metric class based on ignite event handler mechanism.
The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim,
Expand All @@ -44,6 +50,7 @@ class IgniteMetric(Metric):
Args:
metric_fn: callable function or class to compute raw metric results after every iteration.
expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans).
loss_fn: A torch _Loss function which is used to generate the LossMetric
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
Expand All @@ -52,18 +59,37 @@ class IgniteMetric(Metric):
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
save_details: whether to save metric computation details per image, for example: mean_dice of every image.
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
reduction: Argument for the LossMetric, look there for details
get_not_nans: Argument for the LossMetric, look there for details

"""

def __init__(
self, metric_fn: CumulativeIterationMetric, output_transform: Callable = lambda x: x, save_details: bool = True
self,
metric_fn: CumulativeIterationMetric | None = None,
loss_fn: _Loss | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
) -> None:
self._is_reduced: bool = False
self.metric_fn = metric_fn
self.metric_fn: CumulativeIterationMetric = cast(CumulativeIterationMetric, metric_fn)
self.loss_fn = loss_fn
self.save_details = save_details
self._scores: list = []
self._engine: Engine | None = None
self._name: str | None = None

if self.metric_fn is None and self.loss_fn is None:
raise ValueError("Either metric_fn or loss_fn have to be passed.")
if self.metric_fn is not None and self.loss_fn is not None:
raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.")
if self.loss_fn:
self.metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans)
# else:
matt3o marked this conversation as resolved.
Show resolved Hide resolved
# self.metric_fn = cast(CumulativeIterationMetric, metric_fn)

super().__init__(output_transform)

@reinit__is_reduced
Expand Down Expand Up @@ -129,3 +155,24 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override]
self._name = name
if self.save_details and not hasattr(engine.state, "metric_details"):
engine.state.metric_details = {} # type: ignore


@deprecated(since="1.3", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.")
class IgniteMetric(IgniteMetricHandler):
def __init__(
self,
metric_fn: CumulativeIterationMetric | None = None,
loss_fn: _Loss | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
) -> None:
super().__init__(
metric_fn=metric_fn,
loss_fn=loss_fn,
output_transform=output_transform,
save_details=save_details,
reduction=reduction,
get_not_nans=get_not_nans,
)
4 changes: 2 additions & 2 deletions monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import DiceMetric
from monai.utils import MetricReduction


class MeanDice(IgniteMetric):
class MeanDice(IgniteMetricHandler):
"""
Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MeanIoU
from monai.utils import MetricReduction


class MeanIoUHandler(IgniteMetric):
class MeanIoUHandler(IgniteMetricHandler):
"""
Computes IoU score metric from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
6 changes: 3 additions & 3 deletions monai/handlers/metrics_reloaded_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical
from monai.utils.enums import MetricReduction


class MetricsReloadedBinaryHandler(IgniteMetric):
class MetricsReloadedBinaryHandler(IgniteMetricHandler):
"""
Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded.
"""
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class MetricsReloadedCategoricalHandler(IgniteMetric):
class MetricsReloadedCategoricalHandler(IgniteMetricHandler):
"""
Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import PanopticQualityMetric
from monai.utils import MetricReduction


class PanopticQuality(IgniteMetric):
class PanopticQuality(IgniteMetricHandler):
"""
Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
10 changes: 5 additions & 5 deletions monai/handlers/regression_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric
from monai.utils import MetricReduction


class MeanSquaredError(IgniteMetric):
class MeanSquaredError(IgniteMetricHandler):
"""
Computes Mean Squared Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class MeanAbsoluteError(IgniteMetric):
class MeanAbsoluteError(IgniteMetricHandler):
"""
Computes Mean Absolute Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class RootMeanSquaredError(IgniteMetric):
class RootMeanSquaredError(IgniteMetricHandler):
"""
Computes Root Mean Squared Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class PeakSignalToNoiseRatio(IgniteMetric):
class PeakSignalToNoiseRatio(IgniteMetricHandler):
"""
Computes Peak Signal to Noise Ratio from full size Tensor and collects average over batch, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import ROCAUCMetric
from monai.utils import Average


class ROCAUC(IgniteMetric):
class ROCAUC(IgniteMetricHandler):
"""
Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC).
accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`.
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import SurfaceDistanceMetric
from monai.utils import MetricReduction


class SurfaceDistance(IgniteMetric):
class SurfaceDistance(IgniteMetricHandler):
"""
Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def run_testsuit():
"test_handler_early_stop",
"test_handler_garbage_collector",
"test_handler_hausdorff_distance",
"test_handler_ignite_metric",
"test_handler_lr_scheduler",
"test_handler_mean_dice",
"test_handler_panoptic_quality",
Expand Down
Loading