Skip to content

Commit af1ffd6

Browse files
Nic-Mamonai-bot
andauthored
Add base class for TensorBoard handlers (#1573)
* [DLMED] add base class for TensorBoard handlers Signed-off-by: Nic Ma <nma@nvidia.com> * [MONAI] python code formatting Signed-off-by: monai-bot <monai.miccai2019@gmail.com> * [DLMED] enhance integration test with TensorBoard writer Signed-off-by: Nic Ma <nma@nvidia.com> * [MONAI] python code formatting Signed-off-by: monai-bot <monai.miccai2019@gmail.com> * [DLMED] add close method Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] Enhance MetricsSaver doc-strings Signed-off-by: Nic Ma <nma@nvidia.com> Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
1 parent 910b1d4 commit af1ffd6

8 files changed

+60
-17
lines changed

docs/source/handlers.rst

+3
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ Training stats handler
8585

8686
Tensorboard handlers
8787
--------------------
88+
.. autoclass:: TensorBoardHandler
89+
:members:
90+
8891
.. autoclass:: TensorBoardStatsHandler
8992
:members:
9093

monai/handlers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .smartcache_handler import SmartCacheHandler
2525
from .stats_handler import StatsHandler
2626
from .surface_distance import SurfaceDistance
27-
from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler
27+
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
2828
from .utils import (
2929
evenly_divisible_all_gather,
3030
stopping_fn_from_loss,

monai/handlers/metrics_saver.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,22 @@ class MetricsSaver:
3434
"*" - save all the existing metrics in `engine.state.metrics` dict into separate files.
3535
list of strings - specify the expected metrics to save.
3636
default to "*" to save all the metrics into `metrics.csv`.
37-
metric_details: expected metric details to save into files, for example: mean dice
38-
of every channel of every image in the validation dataset.
39-
the data in `engine.state.metric_details` must contain at least 2 dims: (batch, classes, ...),
37+
metric_details: expected metric details to save into files, the data comes from
38+
`engine.state.metric_details`, which should be provided by different `Metrics`,
39+
typically, it's some intermediate values in metric computation.
40+
for example: mean dice of every channel of every image in the validation dataset.
41+
it must contain at least 2 dims: (batch, classes, ...),
4042
if not, will unsequeeze to 2 dims.
4143
this arg can be: None, "*" or list of strings.
42-
None - don't save any metrics into files.
43-
"*" - save all the existing metrics in `engine.state.metric_details` dict into separate files.
44-
list of strings - specify the expected metrics to save.
45-
if not None, every metric will save a separate `{metric name}_raw.csv` file.
44+
None - don't save any metric_details into files.
45+
"*" - save all the existing metric_details in `engine.state.metric_details` dict into separate files.
46+
list of strings - specify the metric_details of expected metrics to save.
47+
if not None, every metric_details array will save a separate `{metric name}_raw.csv` file.
4648
batch_transform: callable function to extract the meta_dict from input batch data if saving metric details.
4749
used to extract filenames from input dict data.
48-
summary_ops: expected computation operations to generate the summary report.
50+
summary_ops: expected computation operations to generate the summary report based on specified metric_details.
4951
it can be: None, "*" or list of strings.
50-
None - don't generate summary report for every expected metric_details
52+
None - don't generate summary report for every specified metric_details
5153
"*" - generate summary report for every metric_details with all the supported operations.
5254
list of strings - generate summary report for every metric_details with specified operations, they
5355
should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`].

monai/handlers/tensorboard_handlers.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,38 @@
2929
DEFAULT_TAG = "Loss"
3030

3131

32-
class TensorBoardStatsHandler:
32+
class TensorBoardHandler:
33+
"""
34+
Base class for the handlers to write data into TensorBoard.
35+
36+
Args:
37+
summary_writer: user can specify TensorBoard SummaryWriter,
38+
default to create a new writer.
39+
log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.
40+
41+
"""
42+
43+
def __init__(self, summary_writer: Optional[SummaryWriter] = None, log_dir: str = "./runs"):
44+
if summary_writer is None:
45+
self._writer = SummaryWriter(log_dir=log_dir)
46+
self.internal_writer = True
47+
else:
48+
self._writer = summary_writer
49+
self.internal_writer = False
50+
51+
def attach(self, engine: Engine) -> None:
52+
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
53+
54+
def close(self):
55+
"""
56+
Close the summary writer if created in this TensorBoard handler.
57+
58+
"""
59+
if self.internal_writer:
60+
self._writer.close()
61+
62+
63+
class TensorBoardStatsHandler(TensorBoardHandler):
3364
"""
3465
TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics.
3566
It's can be used for any Ignite Engine(trainer, validator and evaluator).
@@ -71,7 +102,7 @@ def __init__(
71102
when plotting epoch vs metric curves.
72103
tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``.
73104
"""
74-
self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer
105+
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
75106
self.epoch_event_writer = epoch_event_writer
76107
self.iteration_event_writer = iteration_event_writer
77108
self.output_transform = output_transform
@@ -176,7 +207,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No
176207
writer.flush()
177208

178209

179-
class TensorBoardImageHandler:
210+
class TensorBoardImageHandler(TensorBoardHandler):
180211
"""
181212
TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images.
182213
2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch,
@@ -229,7 +260,7 @@ def __init__(
229260
max_channels: number of channels to plot.
230261
max_frames: number of frames for 2D-t plot.
231262
"""
232-
self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer
263+
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
233264
self.interval = interval
234265
self.epoch_level = epoch_level
235266
self.batch_transform = batch_transform

monai/handlers/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def write_metrics_reports(
128128
images: name or path of every input image corresponding to the metric_details data.
129129
if None, will use index number as the filename of every input image.
130130
metrics: a dictionary of (metric name, metric value) pairs.
131-
metric_details: a dictionary of (metric name, metric raw values) pairs,
131+
metric_details: a dictionary of (metric name, metric raw values) pairs, usually, it comes from metrics computation,
132132
for example, the raw value can be the mean_dice of every channel of every input image.
133133
summary_ops: expected computation operations to generate the summary report.
134134
it can be: None, "*" or list of strings.

tests/test_handler_tb_image.py

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _train_func(engine, batch):
4040

4141
data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape)))
4242
engine.run(data, epoch_length=10, max_epochs=1)
43+
stats_handler.close()
4344

4445
self.assertTrue(len(glob.glob(tempdir)) > 0)
4546

tests/test_handler_tb_stats.py

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def _update_metric(engine):
3939
stats_handler = TensorBoardStatsHandler(log_dir=tempdir)
4040
stats_handler.attach(engine)
4141
engine.run(range(3), max_epochs=2)
42+
stats_handler.close()
4243
# check logging output
4344
self.assertTrue(len(glob.glob(tempdir)) > 0)
4445

@@ -64,6 +65,7 @@ def _update_metric(engine):
6465
)
6566
stats_handler.attach(engine)
6667
engine.run(range(3), max_epochs=2)
68+
writer.close()
6769
# check logging output
6870
self.assertTrue(len(glob.glob(tempdir)) > 0)
6971

tests/test_integration_workflows.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import torch
2424
from ignite.metrics import Accuracy
25+
from torch.utils.tensorboard import SummaryWriter
2526

2627
import monai
2728
from monai.data import create_test_image_3d
@@ -105,6 +106,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
105106
loss = monai.losses.DiceLoss(sigmoid=True)
106107
opt = torch.optim.Adam(net.parameters(), 1e-3)
107108
lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
109+
summary_writer = SummaryWriter(log_dir=root_dir)
108110

109111
val_post_transforms = Compose(
110112
[
@@ -123,7 +125,7 @@ def _forward_completed(self, engine):
123125

124126
val_handlers = [
125127
StatsHandler(output_transform=lambda x: None),
126-
TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None),
128+
TensorBoardStatsHandler(summary_writer=summary_writer, output_transform=lambda x: None),
127129
TensorBoardImageHandler(
128130
log_dir=root_dir, batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"]
129131
),
@@ -176,7 +178,9 @@ def _optimizer_completed(self, engine):
176178
LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
177179
ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
178180
StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
179-
TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]),
181+
TensorBoardStatsHandler(
182+
summary_writer=summary_writer, tag_name="train_loss", output_transform=lambda x: x["loss"]
183+
),
180184
CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
181185
_TestTrainIterEvents(),
182186
]

0 commit comments

Comments
 (0)