Skip to content

Commit

Permalink
Fix code formatting and remove debug prints
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Hadlich <matthiashadlich@posteo.de>
  • Loading branch information
matt3o committed Jul 12, 2023
1 parent f54a1bf commit eaf7fd0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
28 changes: 20 additions & 8 deletions monai/handlers/ignite_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@

import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

import torch
from torch.nn.modules.loss import _Loss

from monai.config import IgniteInfo
from monai.metrics import CumulativeIterationMetric, LossMetric
from monai.utils import deprecated, min_version, optional_import
from monai.utils import MetricReduction, deprecated, min_version, optional_import

idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")

Expand Down Expand Up @@ -59,7 +59,8 @@ class IgniteMetricHandler(Metric):
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
save_details: whether to save metric computation details per image, for example: mean_dice of every image.
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
kwargs: keyword argument that will be passed into the LossMetric
reduction: Argument for the LossMetric, look there for details
get_not_nans: Argument for the LossMetric, look there for details
"""

Expand All @@ -69,10 +70,11 @@ def __init__(
loss_fn: _Loss | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
**kwargs: dict,
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
) -> None:
self._is_reduced: bool = False
self.metric_fn = metric_fn
self.metric_fn: CumulativeIterationMetric = cast(CumulativeIterationMetric, metric_fn)
self.loss_fn = loss_fn
self.save_details = save_details
self._scores: list = []
Expand All @@ -84,7 +86,11 @@ def __init__(
if self.metric_fn is not None and self.loss_fn is not None:
raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.")
if self.loss_fn:
self.metric_fn = LossMetric(loss_fn=self.loss_fn, **kwargs)
self.metric_fn = LossMetric(
loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans
)
# else:
# self.metric_fn = cast(CumulativeIterationMetric, metric_fn)

super().__init__(output_transform)

Expand Down Expand Up @@ -161,8 +167,14 @@ def __init__(
loss_fn: _Loss | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
**kwargs: dict,
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
) -> None:
super().__init__(
metric_fn=metric_fn, loss_fn=loss_fn, output_transform=output_transform, save_details=save_details, **kwargs
metric_fn=metric_fn,
loss_fn=loss_fn,
output_transform=output_transform,
save_details=save_details,
reduction=reduction,
get_not_nans=get_not_nans,
)
6 changes: 0 additions & 6 deletions tests/test_handler_ignite_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ def _val_func(engine, batch):
engine.fire_event(Events.ITERATION_COMPLETED)

engine.fire_event(Events.EPOCH_COMPLETED)
print(f"{engine.state.metrics['ignite_dice_loss']}")
print(f"{engine.state.metric_details['ignite_dice_loss'].shape}")
print(f"{engine.state.metric_details['ignite_dice_loss']}")
assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)

@SkipIfNoModule("ignite")
Expand All @@ -150,9 +147,6 @@ def _val_func(engine, batch):
engine.fire_event(Events.ITERATION_COMPLETED)

engine.fire_event(Events.EPOCH_COMPLETED)
print(f"{engine.state.metrics['ignite_dice_loss']}")
print(f"{engine.state.metric_details['ignite_dice_loss'].shape}")
print(f"{engine.state.metric_details['ignite_dice_loss']}")
assert_allclose(engine.state.metrics["ignite_dice_loss"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)

@SkipIfNoModule("ignite")
Expand Down

0 comments on commit eaf7fd0

Please sign in to comment.