Added callable options for iteration_log and epoch_log in StatsHandler #5965

Merged

Changes from 1 commit
28 changes: 22 additions & 6 deletions monai/handlers/stats_handler.py
@@ -66,8 +66,8 @@ class StatsHandler:

def __init__(
self,
iteration_log: bool = True,
epoch_log: bool = True,
iteration_log: bool | Callable[[Engine, int], bool] = True,
epoch_log: bool | Callable[[Engine, int], bool] = True,
epoch_print_logger: Callable[[Engine], Any] | None = None,
iteration_print_logger: Callable[[Engine], Any] | None = None,
output_transform: Callable = lambda x: x[0],
@@ -80,8 +80,14 @@ def __init__(
"""
Args:
iteration_log: whether to log data when iteration completed, default to `True`.
epoch_log: whether to log data when epoch completed, default to `True`.
iteration_log: whether to log data when an iteration completes, default to `True`. ``iteration_log`` can
also be a function, in which case it is interpreted as an event filter
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
An event filter function takes the engine and the event value (the iteration number) as input and
returns True/False. Event filtering can be used to customize how often iterations are logged.
epoch_log: whether to log data when an epoch completes, default to `True`. ``epoch_log`` can also be
a function, in which case it is interpreted as an event filter. See the ``iteration_log`` argument
for more details.
epoch_print_logger: customized callable printer for epoch level logging.
Must accept parameter "engine", use default printer if None.
iteration_print_logger: customized callable printer for iteration level logging.
@@ -135,9 +141,19 @@ def attach(self, engine: Engine) -> None:
" please call `logging.basicConfig(stream=sys.stdout, level=logging.INFO)` to enable it."
)
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
event = (
Events.ITERATION_COMPLETED(event_filter=self.iteration_log)
if callable(self.iteration_log)
else Events.ITERATION_COMPLETED
)
engine.add_event_handler(event, self.iteration_completed)
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
event = (
Events.EPOCH_COMPLETED(event_filter=self.epoch_log)
if callable(self.epoch_log)
else Events.EPOCH_COMPLETED
)
engine.add_event_handler(event, self.epoch_completed)
if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):
engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)
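
With these changes, passing a callable to ``iteration_log`` or ``epoch_log`` turns it into an Ignite event filter, as described in the updated docstring above. A minimal usage sketch, assuming the MONAI/Ignite imports below; the training function, filter, and tag name are illustrative and not part of this PR:

```python
import torch
from ignite.engine import Engine
from monai.handlers import StatsHandler

def _train_func(engine, batch):
    # placeholder training step; StatsHandler's default output_transform picks element 0 as the loss
    return [torch.tensor(0.0)]

trainer = Engine(_train_func)

def iteration_filter(engine, event):
    # log only the first three iterations of the run (the iteration count is global, not per-epoch)
    return event <= 3

stats_handler = StatsHandler(
    iteration_log=iteration_filter,  # callable -> attached as an event filter on ITERATION_COMPLETED
    epoch_log=True,                  # plain booleans keep the previous behaviour
    tag_name="train_loss",
)
stats_handler.attach(trainer)
trainer.run(range(10), max_epochs=2)  # iteration logs appear only for iterations 1, 2 and 3
```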

137 changes: 80 additions & 57 deletions tests/test_handler_stats.py
@@ -26,74 +26,97 @@

class TestHandlerStats(unittest.TestCase):
def test_metrics_print(self):
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
key_to_handler = "test_logging"
key_to_print = "testing_metric"
def event_filter(_, event):
if event in [1, 2]:
return True
return False

for epoch_log in [True, event_filter]:
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
key_to_handler = "test_logging"
key_to_print = "testing_metric"

# set up engine
def _train_func(engine, batch):
return [torch.tensor(0.0)]
# set up engine
def _train_func(engine, batch):
return [torch.tensor(0.0)]

engine = Engine(_train_func)
engine = Engine(_train_func)

# set up dummy metric
@engine.on(Events.EPOCH_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get(key_to_print, 0.1)
engine.state.metrics[key_to_print] = current_metric + 0.1
# set up dummy metric
@engine.on(Events.EPOCH_COMPLETED)
def _update_metric(engine):
current_metric = engine.state.metrics.get(key_to_print, 0.1)
engine.state.metrics[key_to_print] = current_metric + 0.1

# set up testing handler
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(iteration_log=False, epoch_log=True, name=key_to_handler)
stats_handler.attach(engine)

engine.run(range(3), max_epochs=2)
# set up testing handler
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler)
stats_handler.attach(engine)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)
max_epochs = 4
engine.run(range(3), max_epochs=max_epochs)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
if epoch_log is True:
self.assertTrue(content_count == max_epochs)
else:
self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter

def test_loss_print(self):
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
key_to_handler = "test_logging"
key_to_print = "myLoss"

# set up engine
def _train_func(engine, batch):
return [torch.tensor(0.0)]
def event_filter(_, event):
if event in [1, 3]:
return True
return False

for iteration_log in [True, event_filter]:
log_stream = StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
key_to_handler = "test_logging"
key_to_print = "myLoss"

engine = Engine(_train_func)
# set up engine
def _train_func(engine, batch):
return [torch.tensor(0.0)]

# set up testing handler
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(iteration_log=True, epoch_log=False, name=key_to_handler, tag_name=key_to_print)
stats_handler.attach(engine)
engine = Engine(_train_func)

engine.run(range(3), max_epochs=2)
# set up testing handler
logger = logging.getLogger(key_to_handler)
logger.setLevel(logging.INFO)
logger.addHandler(log_handler)
stats_handler = StatsHandler(
iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print
)
stats_handler.attach(engine)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)
num_iters = 3
max_epochs = 2
engine.run(range(num_iters), max_epochs=max_epochs)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
if iteration_log is True:
self.assertTrue(content_count == num_iters * max_epochs)
else:
self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter

def test_loss_dict(self):
log_stream = StringIO()
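
For reference, the ``attach`` changes above rely on Ignite's filtered-event API: calling ``Events.ITERATION_COMPLETED(event_filter=...)`` yields an event that fires only when the filter returns True. A standalone sketch with plain Ignite (handler and process-function names are illustrative):

```python
from ignite.engine import Engine, Events

def _process(engine, batch):
    return batch  # no-op process function

engine = Engine(_process)

# This handler fires only on even iterations, because of the event filter.
@engine.on(Events.ITERATION_COMPLETED(event_filter=lambda engine, event: event % 2 == 0))
def log_even_iterations(engine):
    print(f"iteration {engine.state.iteration}")

engine.run(range(4), max_epochs=1)  # prints for iterations 2 and 4
```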