Add logger_handler to LrScheduleHandler (Project-MONAI#3570)
* [DLMED] add log handler

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix CI tests

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix CI test

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] test CI

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix logging

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] temp test

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix wrong unit test

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix wrong test cases

Signed-off-by: Nic Ma <nma@nvidia.com>
Nic-Ma authored and Can-Zhao committed Jan 10, 2022
1 parent 8edd498 commit 8ade1f7
Showing 6 changed files with 78 additions and 39 deletions.
6 changes: 6 additions & 0 deletions monai/handlers/lr_schedule_handler.py
@@ -36,6 +36,7 @@ def __init__(
name: Optional[str] = None,
epoch_level: bool = True,
step_transform: Callable[[Engine], Any] = lambda engine: (),
logger_handler: Optional[logging.Handler] = None,
) -> None:
"""
Args:
@@ -47,6 +48,9 @@ def __init__(
`True` is epoch level, `False` is iteration level.
step_transform: a callable that is used to transform the information from `engine`
to expected input data of lr_scheduler.step() function if necessary.
logger_handler: if `print_lr` is True, add additional handler to log the learning rate: save to file, etc.
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.
Raises:
TypeError: When ``step_transform`` is not ``callable``.
@@ -59,6 +63,8 @@ def __init__(
if not callable(step_transform):
raise TypeError(f"step_transform must be callable but is {type(step_transform).__name__}.")
self.step_transform = step_transform
if logger_handler is not None:
self.logger.addHandler(logger_handler)

self._name = name

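For reference, a minimal usage sketch of the new `logger_handler` argument on `LrScheduleHandler` (not part of the commit; the engine, scheduler settings, and log path are illustrative). A dedicated `name` is passed so that `attach` does not swap in the engine's own logger and drop the file handler:

import logging

import torch
from ignite.engine import Engine
from monai.handlers import LrScheduleHandler

# placeholder training step; a real trainer would run a model and return a loss
trainer = Engine(lambda engine, batch: None)

net = torch.nn.PReLU()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

# mirror the "Current learning rate" messages to a file via the new argument
file_handler = logging.FileHandler("lr.log", mode="w")
file_handler.setLevel(logging.INFO)

handler = LrScheduleHandler(lr_scheduler, print_lr=True, name="train_lr_log", logger_handler=file_handler)
handler.attach(trainer)
handler.logger.setLevel(logging.INFO)  # the messages are emitted at INFO level

trainer.run([0] * 8, max_epochs=2)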
3 changes: 2 additions & 1 deletion monai/handlers/stats_handler.py
@@ -82,7 +82,8 @@ def __init__(
tag_name: scalar_value to logger. Defaults to ``'Loss'``.
key_var_format: a formatting string to control the output string format of key: value.
logger_handler: add additional handler to handle the stats data: save to file, etc.
Add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.
"""

self.epoch_print_logger = epoch_print_logger
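A similar sketch for `StatsHandler` (illustrative; `trainer` stands for an ignite `Engine` whose output, after `output_transform`, yields the loss):

import logging

from monai.handlers import StatsHandler

file_handler = logging.FileHandler("train_stats.log", mode="w")  # illustrative path
file_handler.setLevel(logging.INFO)

# a dedicated name keeps the file handler on this handler's own logger
stats_handler = StatsHandler(name="train_log", tag_name="train_loss", logger_handler=file_handler)
stats_handler.attach(trainer)
stats_handler.logger.setLevel(logging.INFO)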
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
@@ -570,7 +570,7 @@ def __init__(
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
additional_info: user can define callable function to extract additional info from input data.
logger_handler: add additional handler to output data: save to file, etc.
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.
Raises:
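The `DataStats` transform accepts the same kind of handler; a small sketch (the log path and the extra-info function are hypothetical):

import logging

import numpy as np
from monai.transforms import DataStats

file_handler = logging.FileHandler("data_stats.log", mode="w")  # illustrative path
file_handler.setLevel(logging.INFO)

stats = DataStats(
    prefix="image",
    additional_info=lambda x: f"mean: {x.mean():.4f}",  # user-defined extra info
    logger_handler=file_handler,
)
img = stats(np.random.rand(1, 64, 64))  # logs stats at INFO level, returns the input unchanged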
2 changes: 1 addition & 1 deletion monai/transforms/utility/dictionary.py
@@ -795,7 +795,7 @@ def __init__(
additional info from input data. it also can be a sequence of string, each element
corresponds to a key in ``keys``.
logger_handler: add additional handler to output data: save to file, etc.
add existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html
all the existing python logging handlers: https://docs.python.org/3/library/logging.handlers.html.
the handler should have a logging level of at least `INFO`.
allow_missing_keys: don't raise exception if key is missing.
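And the dictionary-based counterpart `DataStatsd` (keys and array shapes made up for illustration):

import logging

import numpy as np
from monai.transforms import DataStatsd

file_handler = logging.FileHandler("data_statsd.log", mode="w")  # illustrative path
file_handler.setLevel(logging.INFO)

statsd = DataStatsd(keys=["image", "label"], logger_handler=file_handler)
data = statsd({"image": np.zeros((1, 8, 8)), "label": np.ones((1, 8, 8))})  # logs stats for each key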
48 changes: 37 additions & 11 deletions tests/test_handler_lr_scheduler.py
@@ -10,7 +10,10 @@
# limitations under the License.

import logging
import os
import re
import sys
import tempfile
import unittest

import numpy as np
@@ -24,6 +27,8 @@ class TestHandlerLrSchedule(unittest.TestCase):
def test_content(self):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
data = [0] * 8
test_lr = 0.1
gamma = 0.1

# set up engine
def _train_func(engine, batch):
@@ -41,24 +46,45 @@ def run_validation(engine):
net = torch.nn.PReLU()

def _reduce_lr_on_plateau():
optimizer = torch.optim.SGD(net.parameters(), 0.1)
optimizer = torch.optim.SGD(net.parameters(), test_lr)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)
handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics["val_loss"])
handler.attach(train_engine)
return lr_scheduler
return handler

def _reduce_on_step():
optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
handler = LrScheduleHandler(lr_scheduler)
handler.attach(train_engine)
return lr_scheduler
with tempfile.TemporaryDirectory() as tempdir:
key_to_handler = "test_log_lr"
key_to_print = "Current learning rate"
filename = os.path.join(tempdir, "test_lr.log")
# test with additional logging handler
file_saver = logging.FileHandler(filename, mode="w")
file_saver.setLevel(logging.INFO)

def _reduce_on_step():
optimizer = torch.optim.SGD(net.parameters(), test_lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)
handler = LrScheduleHandler(lr_scheduler, name=key_to_handler, logger_handler=file_saver)
handler.attach(train_engine)
handler.logger.setLevel(logging.INFO)
return handler

schedulers = _reduce_lr_on_plateau(), _reduce_on_step()

train_engine.run(data, max_epochs=5)
file_saver.close()
schedulers[1].logger.removeHandler(file_saver)

schedulers = _reduce_lr_on_plateau(), _reduce_on_step()
with open(filename) as f:
output_str = f.read()
has_key_word = re.compile(f".*{key_to_print}.*")
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

train_engine.run(data, max_epochs=5)
for scheduler in schedulers:
np.testing.assert_allclose(scheduler._last_lr[0], 0.001)
np.testing.assert_allclose(scheduler.lr_scheduler._last_lr[0], 0.001)


if __name__ == "__main__":
56 changes: 31 additions & 25 deletions tests/test_handler_stats.py
@@ -45,18 +45,19 @@ def _update_metric(engine):
# set up testing handler
stats_handler = StatsHandler(name=key_to_handler, logger_handler=log_handler)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [5, 10]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_loss_print(self):
log_stream = StringIO()
@@ -74,18 +75,19 @@ def _train_func(engine, batch):
# set up testing handler
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=log_handler)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_loss_dict(self):
log_stream = StringIO()
@@ -102,21 +104,22 @@ def _train_func(engine, batch):

# set up testing handler
stats_handler = StatsHandler(
name=key_to_handler, output_transform=lambda x: {key_to_print: x}, logger_handler=log_handler
name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]}, logger_handler=log_handler
)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_loss_file(self):
key_to_handler = "test_logging"
@@ -136,18 +139,19 @@ def _train_func(engine, batch):
# set up testing handler
stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)
handler.close()
stats_handler.logger.removeHandler(handler)
with open(filename) as f:
output_str = f.read()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(f".*{key_to_print}.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line):
if idx in [1, 2, 3, 6, 7, 8]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)

def test_exception(self):
# set up engine
@@ -190,17 +194,19 @@ def _update_metric(engine):
name=key_to_handler, state_attributes=["test1", "test2", "test3"], logger_handler=log_handler
)
stats_handler.attach(engine)
stats_handler.logger.setLevel(logging.INFO)

engine.run(range(3), max_epochs=2)

# check logging output
output_str = log_stream.getvalue()
log_handler.close()
grep = re.compile(f".*{key_to_handler}.*")
has_key_word = re.compile(".*State values.*")
for idx, line in enumerate(output_str.split("\n")):
if grep.match(line) and idx in [5, 10]:
self.assertTrue(has_key_word.match(line))
content_count = 0
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
self.assertTrue(content_count > 0)


if __name__ == "__main__":
