Skip to content

Commit

Permalink
Add loss_fn to IgniteMetric and rename to IgniteMetricHandler (#6695)
Browse files Browse the repository at this point in the history
### Description

As explained in #6693 I would like to use the DiceCELoss as a train
metric as well. This branch adds a very crude but working version of
that.
The added tests, which I copied from the MeanDice metric, do still fail.
It would be cool if someone more experienced could check, what needs to
be done there.
I ran the code with my full DeepEdit setup and it appears to be working
just fine there.

No formatting checks done so far - I want to find out first if this code
is useful for others. Docs not updated yet either.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Matthias Hadlich <matthiashadlich@posteo.de>
  • Loading branch information
matt3o authored Jul 13, 2023
1 parent 48a86b2 commit c1c0cdc
Show file tree
Hide file tree
Showing 14 changed files with 273 additions and 35 deletions.
6 changes: 3 additions & 3 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ CSV saver
:members:


Ignite Metric
-------------
.. autoclass:: IgniteMetric
Ignite Metric Handler
---------------------
.. autoclass:: IgniteMetricHandler
:members:


Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .earlystop_handler import EarlyStopHandler
from .garbage_collector import GarbageCollector
from .hausdorff_distance import HausdorffDistance
from .ignite_metric import IgniteMetric
from .ignite_metric import IgniteMetric, IgniteMetricHandler
from .logfile_handler import LogfileHandler
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import ConfusionMatrixMetric
from monai.utils.enums import MetricReduction


class ConfusionMatrix(IgniteMetric):
class ConfusionMatrix(IgniteMetricHandler):
"""
Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import HausdorffDistanceMetric
from monai.utils import MetricReduction


class HausdorffDistance(IgniteMetric):
class HausdorffDistance(IgniteMetricHandler):
"""
Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
63 changes: 54 additions & 9 deletions monai/handlers/ignite_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,26 @@

import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

import torch
from torch.nn.modules.loss import _Loss

from monai.config import IgniteInfo
from monai.metrics import CumulativeIterationMetric
from monai.utils import min_version, optional_import
from monai.metrics import CumulativeIterationMetric, LossMetric
from monai.utils import MetricReduction, deprecated, min_version, optional_import

idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")

if TYPE_CHECKING:
from ignite.engine import Engine
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced
try:
_, has_ignite = optional_import("ignite")
from ignite.engine import Engine
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced
except ImportError:
has_ignite = False

else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base")
Expand All @@ -35,7 +41,7 @@
)


class IgniteMetric(Metric):
class IgniteMetricHandler(Metric):
"""
Base Metric class based on ignite event handler mechanism.
The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim,
Expand All @@ -44,6 +50,7 @@ class IgniteMetric(Metric):
Args:
metric_fn: callable function or class to compute raw metric results after every iteration.
expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans).
loss_fn: A torch _Loss function which is used to generate the LossMetric
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
Expand All @@ -52,18 +59,35 @@ class IgniteMetric(Metric):
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
save_details: whether to save metric computation details per image, for example: mean_dice of every image.
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
reduction: Argument for the LossMetric, look there for details
get_not_nans: Argument for the LossMetric, look there for details
"""

def __init__(
self, metric_fn: CumulativeIterationMetric, output_transform: Callable = lambda x: x, save_details: bool = True
self,
metric_fn: CumulativeIterationMetric | None = None,
loss_fn: _Loss | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
) -> None:
self._is_reduced: bool = False
self.metric_fn = metric_fn
self.metric_fn: CumulativeIterationMetric = cast(CumulativeIterationMetric, metric_fn)
self.loss_fn = loss_fn
self.save_details = save_details
self._scores: list = []
self._engine: Engine | None = None
self._name: str | None = None

if self.metric_fn is None and self.loss_fn is None:
raise ValueError("Either metric_fn or loss_fn have to be passed.")
if self.metric_fn is not None and self.loss_fn is not None:
raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.")
if self.loss_fn:
self.metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans)

super().__init__(output_transform)

@reinit__is_reduced
Expand Down Expand Up @@ -129,3 +153,24 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override]
self._name = name
if self.save_details and not hasattr(engine.state, "metric_details"):
engine.state.metric_details = {} # type: ignore


@deprecated(since="1.3", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.")
class IgniteMetric(IgniteMetricHandler):
def __init__(
self,
metric_fn: CumulativeIterationMetric | None = None,
loss_fn: _Loss | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
) -> None:
super().__init__(
metric_fn=metric_fn,
loss_fn=loss_fn,
output_transform=output_transform,
save_details=save_details,
reduction=reduction,
get_not_nans=get_not_nans,
)
4 changes: 2 additions & 2 deletions monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import DiceMetric
from monai.utils import MetricReduction


class MeanDice(IgniteMetric):
class MeanDice(IgniteMetricHandler):
"""
Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MeanIoU
from monai.utils import MetricReduction


class MeanIoUHandler(IgniteMetric):
class MeanIoUHandler(IgniteMetricHandler):
"""
Computes IoU score metric from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
6 changes: 3 additions & 3 deletions monai/handlers/metrics_reloaded_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical
from monai.utils.enums import MetricReduction


class MetricsReloadedBinaryHandler(IgniteMetric):
class MetricsReloadedBinaryHandler(IgniteMetricHandler):
"""
Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded.
"""
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class MetricsReloadedCategoricalHandler(IgniteMetric):
class MetricsReloadedCategoricalHandler(IgniteMetricHandler):
"""
Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import PanopticQualityMetric
from monai.utils import MetricReduction


class PanopticQuality(IgniteMetric):
class PanopticQuality(IgniteMetricHandler):
"""
Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
10 changes: 5 additions & 5 deletions monai/handlers/regression_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric
from monai.utils import MetricReduction


class MeanSquaredError(IgniteMetric):
class MeanSquaredError(IgniteMetricHandler):
"""
Computes Mean Squared Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class MeanAbsoluteError(IgniteMetric):
class MeanAbsoluteError(IgniteMetricHandler):
"""
Computes Mean Absolute Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class RootMeanSquaredError(IgniteMetric):
class RootMeanSquaredError(IgniteMetricHandler):
"""
Computes Root Mean Squared Error from full size Tensor and collects average over batch, iterations.
"""
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)


class PeakSignalToNoiseRatio(IgniteMetric):
class PeakSignalToNoiseRatio(IgniteMetricHandler):
"""
Computes Peak Signal to Noise Ratio from full size Tensor and collects average over batch, iterations.
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import ROCAUCMetric
from monai.utils import Average


class ROCAUC(IgniteMetric):
class ROCAUC(IgniteMetricHandler):
"""
Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC).
accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`.
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetric
from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import SurfaceDistanceMetric
from monai.utils import MetricReduction


class SurfaceDistance(IgniteMetric):
class SurfaceDistance(IgniteMetricHandler):
"""
Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations.
"""
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def run_testsuit():
"test_handler_early_stop",
"test_handler_garbage_collector",
"test_handler_hausdorff_distance",
"test_handler_ignite_metric",
"test_handler_lr_scheduler",
"test_handler_mean_dice",
"test_handler_panoptic_quality",
Expand Down
Loading

0 comments on commit c1c0cdc

Please sign in to comment.