Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3429 Enhance the scalar write logic of TensorBoardStatsHandler #3431

Merged
merged 11 commits into from
Dec 3, 2021
35 changes: 29 additions & 6 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,21 @@ def iteration_completed(self, engine: Engine) -> None:
else:
self._default_iteration_writer(engine, self._writer)

def _write_scalar(self, engine: Engine, writer: SummaryWriter, tag: str, value: Any, step: int) -> None:
"""
Write scale value into TensorBoard.
Default to call `SummaryWriter.add_scalar()`, subclass can override it for more complicated logic.
Nic-Ma marked this conversation as resolved.
Show resolved Hide resolved

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler.
tag: tag name in the TensorBoard.
value: value of the scalar data for current step.
step: index of current step.

"""
writer.add_scalar(tag, value, step)

def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
"""
Execute epoch level event write operation.
Expand All @@ -188,11 +203,11 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
summary_dict = engine.state.metrics
for name, value in summary_dict.items():
if is_scalar(value):
writer.add_scalar(name, value, current_epoch)
self._write_scalar(engine, writer, name, value, current_epoch)

if self.state_attributes is not None:
for attr in self.state_attributes:
writer.add_scalar(attr, getattr(engine.state, attr, None), current_epoch)
self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch)
writer.flush()

def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None:
Expand Down Expand Up @@ -221,12 +236,20 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No
" {}:{}".format(name, type(value))
)
continue # not plot multi dimensional output
writer.add_scalar(
name, value.item() if isinstance(value, torch.Tensor) else value, engine.state.iteration
self._write_scalar(
engine=engine,
writer=writer,
tag=name,
value=value.item() if isinstance(value, torch.Tensor) else value,
step=engine.state.iteration,
)
elif is_scalar(loss): # not printing multi dimensional output
writer.add_scalar(
self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss, engine.state.iteration
self._write_scalar(
engine=engine,
writer=writer,
tag=self.tag_name,
value=loss.item() if isinstance(loss, torch.Tensor) else loss,
step=engine.state.iteration,
)
else:
warnings.warn(
Expand Down