Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3429 Enhance the scalar write logic of TensorBoardStatsHandler #3431

Merged
merged 11 commits into from
Dec 3, 2021
13 changes: 10 additions & 3 deletions monai/handlers/classification_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@ def attach(self, engine: Engine) -> None:
if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self._finalize)

def _started(self, engine: Engine) -> None:
def _started(self, _engine: Engine) -> None:
"""
Initialize internal buffers.

Args:
_engine: Ignite Engine, unused argument.

"""
self._outputs = []
self._filenames = []

Expand All @@ -120,12 +127,12 @@ def __call__(self, engine: Engine) -> None:
o = o.detach()
self._outputs.append(o)

def _finalize(self, engine: Engine) -> None:
def _finalize(self, _engine: Engine) -> None:
"""
All gather classification results from ranks and save to CSV file.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
_engine: Ignite Engine, unused argument.
"""
ws = idist.get_world_size()
if self.save_rank >= ws:
Expand Down
9 changes: 8 additions & 1 deletion monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,14 @@ def attach(self, engine: Engine) -> None:
engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
engine.add_event_handler(Events.EPOCH_COMPLETED, self)

def _started(self, engine: Engine) -> None:
def _started(self, _engine: Engine) -> None:
"""
Initialize internal buffers.

Args:
_engine: Ignite Engine, unused argument.

"""
self._filenames = []

def _get_filenames(self, engine: Engine) -> None:
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ def iteration_completed(self, engine: Engine) -> None:
else:
self._default_iteration_print(engine)

def exception_raised(self, engine: Engine, e: Exception) -> None:
def exception_raised(self, _engine: Engine, e: Exception) -> None:
"""
Handler for train or validation/evaluation exception raised Event.
Print the exception information and traceback. This callback may be skipped because the logic
with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
_engine: Ignite Engine, unused argument.
e: the exception caught in Ignite during engine.run().

"""
Expand Down
35 changes: 29 additions & 6 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,21 @@ def iteration_completed(self, engine: Engine) -> None:
else:
self._default_iteration_writer(engine, self._writer)

def _write_scalar(self, _engine: Engine, writer: SummaryWriter, tag: str, value: Any, step: int) -> None:
"""
Write scale value into TensorBoard.
Default to call `SummaryWriter.add_scalar()`.

Args:
_engine: Ignite Engine, unused argument.
writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler.
tag: tag name in the TensorBoard.
value: value of the scalar data for current step.
step: index of current step.

"""
writer.add_scalar(tag, value, step)

def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
"""
Execute epoch level event write operation.
Expand All @@ -188,11 +203,11 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
summary_dict = engine.state.metrics
for name, value in summary_dict.items():
if is_scalar(value):
writer.add_scalar(name, value, current_epoch)
self._write_scalar(engine, writer, name, value, current_epoch)

if self.state_attributes is not None:
for attr in self.state_attributes:
writer.add_scalar(attr, getattr(engine.state, attr, None), current_epoch)
self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch)
writer.flush()

def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None:
Expand Down Expand Up @@ -221,12 +236,20 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No
" {}:{}".format(name, type(value))
)
continue # not plot multi dimensional output
writer.add_scalar(
name, value.item() if isinstance(value, torch.Tensor) else value, engine.state.iteration
self._write_scalar(
_engine=engine,
writer=writer,
tag=name,
value=value.item() if isinstance(value, torch.Tensor) else value,
step=engine.state.iteration,
)
elif is_scalar(loss): # not printing multi dimensional output
writer.add_scalar(
self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss, engine.state.iteration
self._write_scalar(
_engine=engine,
writer=writer,
tag=self.tag_name,
value=loss.item() if isinstance(loss, torch.Tensor) else loss,
step=engine.state.iteration,
)
else:
warnings.warn(
Expand Down