Add logger_handler to LrScheduleHandler #3570

Merged
15 commits merged on Jan 5, 2022
6 changes: 6 additions & 0 deletions monai/handlers/lr_schedule_handler.py
@@ -36,6 +36,7 @@ def __init__(
name: Optional[str] = None,
epoch_level: bool = True,
step_transform: Callable[[Engine], Any] = lambda engine: (),
logger_handler: Optional[logging.Handler] = None,
) -> None:
"""
Args:
@@ -47,6 +48,9 @@ def __init__(
`True` is epoch level, `False` is iteration level.
step_transform: a callable that is used to transform the information from `engine`
to expected input data of lr_scheduler.step() function if necessary.
logger_handler: if `print_lr` is True, add an additional handler to log the learning rate: save to file, etc.
all the existing python logging handlers are supported: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.

Raises:
TypeError: When ``step_transform`` is not ``callable``.
@@ -59,6 +63,8 @@ def __init__(
if not callable(step_transform):
raise TypeError(f"step_transform must be callable but is {type(step_transform).__name__}.")
self.step_transform = step_transform
if logger_handler is not None:
self.logger.addHandler(logger_handler)

self._name = name

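For reference, a minimal usage sketch of the new `logger_handler` argument. The log file path, logger name, and toy training loop are illustrative, not part of this PR; it assumes the `LrScheduleHandler` and Ignite `Engine` APIs shown in this diff:

```python
import logging

import torch
from ignite.engine import Engine
from monai.handlers import LrScheduleHandler

net = torch.nn.PReLU()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

# send the "Current learning rate" messages to a file in addition to the default stream
file_handler = logging.FileHandler("lr.log", mode="w")
file_handler.setLevel(logging.INFO)

handler = LrScheduleHandler(
    lr_scheduler,
    print_lr=True,
    name="train_lr_log",  # hypothetical logger name; naming the logger keeps the extra handler after attach()
    logger_handler=file_handler,
)

trainer = Engine(lambda engine, batch: None)  # dummy process function, one iteration per list item
handler.attach(trainer)
handler.logger.setLevel(logging.INFO)  # records below INFO never reach the added handler
trainer.run([0] * 8, max_epochs=5)
```

This mirrors the pattern used in `tests/test_handler_lr_scheduler.py` below, where the test asserts that the log file contains the "Current learning rate" messages.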
3 changes: 2 additions & 1 deletion monai/handlers/stats_handler.py
@@ -82,7 +82,8 @@ def __init__(
tag_name: scalar_value to logger. Defaults to ``'Loss'``.
key_var_format: a formatting string to control the output string format of key: value.
logger_handler: add additional handler to handle the stats data: save to file, etc.
Add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers are supported: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.
"""

self.epoch_print_logger = epoch_print_logger
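A similar sketch for `StatsHandler`, whose docstring is only clarified here (the file name and tag are illustrative; it mirrors `tests/test_handler_stats.py` below):

```python
import logging

import torch
from ignite.engine import Engine
from monai.handlers import StatsHandler

file_handler = logging.FileHandler("train_stats.log", mode="w")
file_handler.setLevel(logging.INFO)

# a named logger keeps the extra handler after attach(); tag_name labels the logged scalar
stats_handler = StatsHandler(name="train_log", tag_name="train_loss", logger_handler=file_handler)

engine = Engine(lambda e, b: torch.tensor(0.1))  # toy step returning a scalar "loss"
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)  # the file handler only sees records at INFO or above
engine.run(range(3), max_epochs=2)
```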
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
@@ -570,7 +570,7 @@ def __init__(
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
additional_info: user can define callable function to extract additional info from input data.
logger_handler: add additional handler to output data: save to file, etc.
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers are supported: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.

Raises:
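For the array transform, a minimal sketch of routing `DataStats` output to a file (the path and the `additional_info` callable are illustrative; the data passes through the transform unchanged):

```python
import logging

import numpy as np
from monai.transforms import DataStats

file_handler = logging.FileHandler("data_stats.log", mode="w")
file_handler.setLevel(logging.INFO)

stats = DataStats(
    prefix="image",
    data_shape=True,
    value_range=True,
    additional_info=lambda x: f"mean: {float(np.mean(x)):.4f}",  # user-defined extra info
    logger_handler=file_handler,
)
img = np.random.rand(1, 64, 64).astype(np.float32)
img = stats(img)  # returns the input unchanged; shape, value range, and mean are logged
```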
2 changes: 1 addition & 1 deletion monai/transforms/utility/dictionary.py
@@ -795,7 +795,7 @@ def __init__(
additional info from input data. it also can be a sequence of string, each element
corresponds to a key in ``keys``.
logger_handler: add additional handler to output data: save to file, etc.
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers are supported: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.
allow_missing_keys: don't raise exception if key is missing.

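And the dictionary-based counterpart, `DataStatsd`, under the same assumptions (keys, prefixes, and file name are illustrative):

```python
import logging

import numpy as np
from monai.transforms import DataStatsd

file_handler = logging.FileHandler("data_statsd.log", mode="w")
file_handler.setLevel(logging.INFO)

stats_d = DataStatsd(
    keys=["image", "label"],
    prefix=["image", "label"],   # one prefix per key
    value_range=True,
    logger_handler=file_handler,
    allow_missing_keys=True,     # don't raise if a sample lacks one of the keys
)
sample = {"image": np.zeros((1, 32, 32), dtype=np.float32)}  # "label" is missing and will be skipped
sample = stats_d(sample)
```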
48 changes: 37 additions & 11 deletions tests/test_handler_lr_scheduler.py
@@ -10,7 +10,10 @@
# limitations under the License.

import logging
import os
import re
import sys
import tempfile
import unittest

import numpy as np
@@ -24,6 +27,8 @@ class TestHandlerLrSchedule(unittest.TestCase):
def test_content(self):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
data = [0] * 8
test_lr = 0.1
gamma = 0.1

# set up engine
def _train_func(engine, batch):
@@ -41,24 +46,45 @@ def run_validation(engine):
net = torch.nn.PReLU()

def _reduce_lr_on_plateau():
optimizer = torch.optim.SGD(net.parameters(), 0.1)
optimizer = torch.optim.SGD(net.parameters(), test_lr)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"])
handler.attach(train_engine)
return lr_scheduler
return handler

def _reduce_on_step():
optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
handler = LrScheduleHandler(lr_scheduler)
handler.attach(train_engine)
return lr_scheduler
with tempfile.TemporaryDirectory() as tempdir:
key_to_handler = "test_log_lr"
key_to_print = "Current learning rate"
filename = os.path.join(tempdir, "test_lr.log")
# test with additional logging handler
file_saver = logging.FileHandler(filename, mode="w")
file_saver.setLevel(logging.INFO)

def _reduce_on_step():
optimizer = torch.optim.SGD(net.parameters(), test_lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)
handler = LrScheduleHandler(lr_scheduler, name=key_to_handler, logger_handler=file_saver)
handler.attach(train_engine)
handler.logger.setLevel(logging.INFO)
return handler

schedulers = _reduce_lr_on_plateau(), _reduce_on_step()

train_engine.run(data, max_epochs=5)
file_saver.close()
schedulers[1].logger.removeHandler(file_saver)

schedulers = _reduce_lr_on_plateau(), _reduce_on_step()
with open(filename) as f:
output_str = f.read()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

train_engine.run(data, max_epochs=5)
for scheduler in schedulers:
np.testing.assert_allclose(scheduler._last_lr[0], 0.001)
np.testing.assert_allclose(scheduler.lr_scheduler._last_lr[0], 0.001)


if __name__ == "__main__":
11 changes: 6 additions & 5 deletions tests/test_handler_stats.py
@@ -136,18 +136,19 @@ def _train_func(engine, batch):
# set up testing handler
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)
handler.close()
stats_handler.logger.removeHandler(handler)
with open(filename) as f:
output_str = f.read()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_exception(self):
# set up engine