Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added callable options for iteration_log and epoch_log in TensorBoard and MLFlow #5976

Merged
merged 7 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions monai/handlers/mlflow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ class MLFlowHandler:
to log data to a directory. The URI defaults to path `mlruns`.
for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri.
iteration_log: whether to log data to MLFlow when iteration completed, default to `True`.
``iteration_log`` can be also a function and it will be interpreted as an event filter
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
Event filter function accepts as input engine and event value (iteration) and should return True/False.
epoch_log: whether to log data to MLFlow when epoch completed, default to `True`.
``epoch_log`` can be also a function and it will be interpreted as an event filter.
See ``iteration_log`` argument for more details.
epoch_logger: customized callable logger for epoch level logging with MLFlow.
Must accept parameter "engine", use default logger if None.
iteration_logger: customized callable logger for iteration level logging with MLFlow.
Expand Down Expand Up @@ -98,8 +103,8 @@ class MLFlowHandler:
def __init__(
self,
tracking_uri: str | None = None,
iteration_log: bool = True,
epoch_log: bool = True,
iteration_log: bool | Callable[[Engine, int], bool] = True,
epoch_log: bool | Callable[[Engine, int], bool] = True,
epoch_logger: Callable[[Engine], Any] | None = None,
iteration_logger: Callable[[Engine], Any] | None = None,
output_transform: Callable = lambda x: x[0],
Expand Down Expand Up @@ -159,9 +164,15 @@ def attach(self, engine: Engine) -> None:
if not engine.has_event_handler(self.start, Events.STARTED):
engine.add_event_handler(Events.STARTED, self.start)
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
event = Events.ITERATION_COMPLETED
if callable(self.iteration_log): # substitute event with new one using filter callable
event = event(event_filter=self.iteration_log)
engine.add_event_handler(event, self.iteration_completed)
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
event = Events.EPOCH_COMPLETED
if callable(self.epoch_log): # substitute event with new one using filter callable
event = event(event_filter=self.epoch_log)
engine.add_event_handler(event, self.epoch_completed)
if not engine.has_event_handler(self.complete, Events.COMPLETED):
engine.add_event_handler(Events.COMPLETED, self.complete)
if self.close_on_complete and (not engine.has_event_handler(self.close, Events.COMPLETED)):
Expand Down
37 changes: 30 additions & 7 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch

from monai.config import IgniteInfo
from monai.utils import is_scalar, min_version, optional_import
from monai.utils import deprecated_arg, is_scalar, min_version, optional_import
from monai.visualize import plot_2d_or_3d_image

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
Expand Down Expand Up @@ -87,12 +87,14 @@ class TensorBoardStatsHandler(TensorBoardHandler):

"""

@deprecated_arg("epoch_interval", since="1.1", removed="1.3")
@deprecated_arg("iteration_interval", since="1.1", removed="1.3")
wyli marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
summary_writer: SummaryWriter | SummaryWriterX | None = None,
log_dir: str = "./runs",
iteration_log: bool = True,
epoch_log: bool = True,
iteration_log: bool | Callable[[Engine, int], bool] = True,
epoch_log: bool | Callable[[Engine, int], bool] = True,
epoch_event_writer: Callable[[Engine, Any], Any] | None = None,
epoch_interval: int = 1,
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
iteration_event_writer: Callable[[Engine, Any], Any] | None = None,
Expand All @@ -108,13 +110,20 @@ def __init__(
default to create a new TensorBoard writer.
log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.
iteration_log: whether to write data to TensorBoard when iteration completed, default to `True`.
``iteration_log`` can be also a function and it will be interpreted as an event filter
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
Event filter function accepts as input engine and event value (iteration) and should return True/False.
epoch_log: whether to write data to TensorBoard when epoch completed, default to `True`.
``epoch_log`` can be also a function and it will be interpreted as an event filter.
See ``iteration_log`` argument for more details.
epoch_event_writer: customized callable TensorBoard writer for epoch level.
Must accept parameter "engine" and "summary_writer", use default event writer if None.
epoch_interval: the epoch interval at which the epoch_event_writer is called. Defaults to 1.
``epoch_interval`` must be 1 if ``epoch_log`` is callable.
iteration_event_writer: customized callable TensorBoard writer for iteration level.
Must accept parameter "engine" and "summary_writer", use default event writer if None.
iteration_interval: the iteration interval at which the iteration_event_writer is called. Defaults to 1.
``iteration_interval`` must be 1 if ``iteration_log`` is callable.
output_transform: a callable that is used to transform the
``ignite.engine.state.output`` into a scalar to plot, or a dictionary of {key: scalar}.
In the latter case, the output string will be formatted as key: value.
Expand All @@ -131,6 +140,12 @@ def __init__(
when epoch completed.
tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``.
"""
if callable(iteration_log) and iteration_interval > 1:
raise ValueError("If iteration_log is callable, then iteration_interval should be 1")

if callable(epoch_log) and epoch_interval > 1:
raise ValueError("If epoch_log is callable, then epoch_interval should be 1")

super().__init__(summary_writer=summary_writer, log_dir=log_dir)
self.iteration_log = iteration_log
self.epoch_log = epoch_log
Expand All @@ -152,11 +167,19 @@ def attach(self, engine: Engine) -> None:

"""
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(
Events.ITERATION_COMPLETED(every=self.iteration_interval), self.iteration_completed
)
event = Events.ITERATION_COMPLETED
if callable(self.iteration_log): # substitute event with new one using filter callable
event = event(event_filter=self.iteration_log)
elif self.iteration_interval > 1:
event = event(every=self.iteration_interval)
engine.add_event_handler(event, self.iteration_completed)
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.epoch_interval), self.epoch_completed)
event = Events.EPOCH_COMPLETED
if callable(self.epoch_log): # substitute event with new one using filter callable
event = event(event_filter=self.epoch_log)
elif self.epoch_log > 1:
event = event(every=self.epoch_interval)
engine.add_event_handler(event, self.epoch_completed)

def epoch_completed(self, engine: Engine) -> None:
"""
Expand Down
4 changes: 2 additions & 2 deletions monai/utils/deprecate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def deprecated_arg(
else:
# compare the numbers
is_deprecated = since is not None and version_leq(since, version_val)
is_removed = removed is not None and version_leq(removed, version_val)
is_removed = removed is not None and version_val != f"{sys.maxsize}" and version_leq(removed, version_val)

def _decorator(func):
argname = f"{func.__module__} {func.__qualname__}:{name}"
Expand Down Expand Up @@ -284,7 +284,7 @@ def deprecated_arg_default(
else:
# compare the numbers
is_deprecated = since is not None and version_leq(since, version_val)
is_replaced = replaced is not None and version_leq(replaced, version_val)
is_replaced = replaced is not None and version_val != f"{sys.maxsize}" and version_leq(replaced, version_val)

def _decorator(func):
argname = f"{func.__module__} {func.__qualname__}:{name}"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_arg_except2_unknown(self):
def afoo4(a, b=None):
pass

self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))
afoo4(1, b=2)

def test_arg_except3_unknown(self):
"""
Expand All @@ -246,8 +246,8 @@ def test_arg_except3_unknown(self):
def afoo4(a, b=None, **kwargs):
pass

self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))
self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2, c=3))
afoo4(1, b=2)
afoo4(1, b=2, c=3)

def test_replacement_arg(self):
"""
Expand Down
90 changes: 90 additions & 0 deletions tests/test_handler_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,25 @@
import tempfile
import unittest
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock

import numpy as np
from ignite.engine import Engine, Events
from parameterized import parameterized

from monai.handlers import MLFlowHandler
from monai.utils import path_to_uri


def get_event_filter(e):
def event_filter(_, event):
if event in e:
return True
return False

return event_filter


def dummy_train(tracking_folder):
tempdir = tempfile.mkdtemp()

Expand Down Expand Up @@ -95,6 +106,85 @@ def _update_metric(engine):
# check logging output
self.assertTrue(len(glob.glob(test_path)) > 0)

@parameterized.expand([[True], [get_event_filter([1, 2])]])
def test_metrics_track_mock(self, epoch_log):
experiment_param = {"backbone": "efficientnet_b0"}
with tempfile.TemporaryDirectory() as tempdir:
# set up engine
def _train_func(engine, batch):
return [batch + 1.0]

engine = Engine(_train_func)

# set up dummy metric
@engine.on(Events.EPOCH_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get("acc", 0.1)
engine.state.metrics["acc"] = current_metric + 0.1
engine.state.test = current_metric

# set up testing handler
test_path = os.path.join(tempdir, "mlflow_test")
handler = MLFlowHandler(
iteration_log=False,
epoch_log=epoch_log,
tracking_uri=path_to_uri(test_path),
state_attributes=["test"],
experiment_param=experiment_param,
close_on_complete=True,
)
handler._default_epoch_log = MagicMock()
handler.attach(engine)

max_epochs = 4
engine.run(range(3), max_epochs=max_epochs)
handler.close()
# check logging output
if epoch_log is True:
self.assertEqual(handler._default_epoch_log.call_count, max_epochs)
else:
self.assertEqual(handler._default_epoch_log.call_count, 2) # 2 = len([1, 2]) from event_filter

@parameterized.expand([[True], [get_event_filter([1, 3])]])
def test_metrics_track_iters_mock(self, iteration_log):
experiment_param = {"backbone": "efficientnet_b0"}
with tempfile.TemporaryDirectory() as tempdir:
# set up engine
def _train_func(engine, batch):
return [batch + 1.0]

engine = Engine(_train_func)

# set up dummy metric
@engine.on(Events.EPOCH_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get("acc", 0.1)
engine.state.metrics["acc"] = current_metric + 0.1
engine.state.test = current_metric

# set up testing handler
test_path = os.path.join(tempdir, "mlflow_test")
handler = MLFlowHandler(
iteration_log=iteration_log,
epoch_log=False,
tracking_uri=path_to_uri(test_path),
state_attributes=["test"],
experiment_param=experiment_param,
close_on_complete=True,
)
handler._default_iteration_log = MagicMock()
handler.attach(engine)

num_iters = 3
max_epochs = 2
engine.run(range(num_iters), max_epochs=max_epochs)
handler.close()
# check logging output
if iteration_log is True:
self.assertEqual(handler._default_iteration_log.call_count, num_iters * max_epochs)
else:
self.assertEqual(handler._default_iteration_log.call_count, 2) # 2 = len([1, 3]) from event_filter

def test_multi_thread(self):
test_uri_list = ["monai_mlflow_test1", "monai_mlflow_test2"]
with ThreadPoolExecutor(2, "Training") as executor:
Expand Down
89 changes: 89 additions & 0 deletions tests/test_handler_tb_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,36 @@
import glob
import tempfile
import unittest
from unittest.mock import MagicMock

from ignite.engine import Engine, Events
from parameterized import parameterized

from monai.handlers import TensorBoardStatsHandler
from monai.utils import optional_import

SummaryWriter, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")


def get_event_filter(e):
def event_filter(_, event):
if event in e:
return True
return False

return event_filter


@unittest.skipUnless(has_tb, "Requires SummaryWriter installation")
class TestHandlerTBStats(unittest.TestCase):
def test_args_validation(self):
with self.assertWarns(FutureWarning):
with self.assertRaisesRegex(ValueError, expected_regex="iteration_interval should be 1"):
TensorBoardStatsHandler(log_dir=".", iteration_log=get_event_filter([1, 2]), iteration_interval=2)

with self.assertRaisesRegex(ValueError, expected_regex="epoch_interval should be 1"):
TensorBoardStatsHandler(log_dir=".", epoch_log=get_event_filter([1, 2]), epoch_interval=2)

def test_metrics_print(self):
with tempfile.TemporaryDirectory() as tempdir:
# set up engine
Expand All @@ -47,6 +66,35 @@ def _update_metric(engine):
# check logging output
self.assertTrue(len(glob.glob(tempdir)) > 0)

@parameterized.expand([[True], [get_event_filter([1, 2])]])
def test_metrics_print_mock(self, epoch_log):
with tempfile.TemporaryDirectory() as tempdir:
# set up engine
def _train_func(engine, batch):
return [batch + 1.0]

engine = Engine(_train_func)

# set up dummy metric
@engine.on(Events.EPOCH_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get("acc", 0.1)
engine.state.metrics["acc"] = current_metric + 0.1

# set up testing handler
stats_handler = TensorBoardStatsHandler(log_dir=tempdir, iteration_log=False, epoch_log=epoch_log)
stats_handler._default_epoch_writer = MagicMock()
stats_handler.attach(engine)

max_epochs = 4
engine.run(range(3), max_epochs=max_epochs)
stats_handler.close()
# check logging output
if epoch_log is True:
self.assertEqual(stats_handler._default_epoch_writer.call_count, max_epochs)
else:
self.assertEqual(stats_handler._default_epoch_writer.call_count, 2) # 2 = len([1, 2]) from event_filter

def test_metrics_writer(self):
with tempfile.TemporaryDirectory() as tempdir:
# set up engine
Expand Down Expand Up @@ -78,6 +126,47 @@ def _update_metric(engine):
# check logging output
self.assertTrue(len(glob.glob(tempdir)) > 0)

@parameterized.expand([[True], [get_event_filter([1, 3])]])
def test_metrics_writer_mock(self, iteration_log):
with tempfile.TemporaryDirectory() as tempdir:
# set up engine
def _train_func(engine, batch):
return [batch + 1.0]

engine = Engine(_train_func)

# set up dummy metric
@engine.on(Events.EPOCH_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get("acc", 0.1)
engine.state.metrics["acc"] = current_metric + 0.1
engine.state.test = current_metric

# set up testing handler
writer = SummaryWriter(log_dir=tempdir)
stats_handler = TensorBoardStatsHandler(
summary_writer=writer,
iteration_log=iteration_log,
epoch_log=False,
output_transform=lambda x: {"loss": x[0] * 2.0},
global_epoch_transform=lambda x: x * 3.0,
state_attributes=["test"],
)
stats_handler._default_iteration_writer = MagicMock()
stats_handler.attach(engine)

num_iters = 3
max_epochs = 2
engine.run(range(num_iters), max_epochs=max_epochs)
writer.close()

if iteration_log is True:
self.assertEqual(stats_handler._default_iteration_writer.call_count, num_iters * max_epochs)
else:
self.assertEqual(
stats_handler._default_iteration_writer.call_count, 2
) # 2 = len([1, 3]) from event_filter


if __name__ == "__main__":
unittest.main()