From 30eb49d6e74f52f0f788c91a78c4d43813a76851 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Wed, 22 May 2024 16:08:17 +0200 Subject: [PATCH] utils: add test for logging with timing --- molpipeline/pipeline/_skl_pipeline.py | 10 ++++---- molpipeline/utils/logging.py | 24 +++++++++++-------- tests/test_utils/test_logging.py | 34 +++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 15 deletions(-) create mode 100644 tests/test_utils/test_logging.py diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index cfd626f2..cdc85e91 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -4,7 +4,6 @@ from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union -from molpipeline.utils.logging import _print_elapsed_time try: from typing import Self # type: ignore[attr-defined] @@ -34,6 +33,7 @@ PostPredictionTransformation, PostPredictionWrapper, ) +from molpipeline.utils.logging import print_elapsed_time from molpipeline.utils.molpipeline_types import ( AnyElement, AnyPredictor, @@ -242,7 +242,7 @@ def _fit( for step in self._iter(with_final=False, filter_passthrough=False): step_idx, name, transformer = step if transformer is None or transformer == "passthrough": - with _print_elapsed_time("Pipeline", self._log_message(step_idx)): + with print_elapsed_time("Pipeline", self._log_message(step_idx)): continue if hasattr(memory, "location") and memory.location is None: @@ -459,7 +459,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self: """ routed_params = self._check_method_params(method="fit", props=fit_params) Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator != "passthrough": if is_empty(Xt): logger.warning( @@ -530,7 +530,7 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any: routed_params = self._check_method_params(method="fit_transform", props=params) iter_input, iter_label = self._fit(X, y, routed_params) last_step = self._final_estimator - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if last_step == "passthrough": pass elif is_empty(iter_input): @@ -650,7 +650,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any: ) # pylint: disable=invalid-name params_last_step = routed_params[self.steps[-1][0]] - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator == "passthrough": y_pred = iter_input elif is_empty(iter_input): diff --git a/molpipeline/utils/logging.py b/molpipeline/utils/logging.py index 22c12186..fd69fc25 100644 --- a/molpipeline/utils/logging.py +++ b/molpipeline/utils/logging.py @@ -1,4 +1,5 @@ """Logging helper functions.""" + from __future__ import annotations import timeit @@ -7,7 +8,7 @@ from loguru import logger -def _message_with_time(source: str, message: str, time: float): +def _message_with_time(source: str, message: str, time: float) -> str: """Create one line message for logging purposes. Adapted from sklearn's function to stay consistent with the logging style: @@ -17,28 +18,29 @@ def _message_with_time(source: str, message: str, time: float): ---------- source : str String indicating the source or the reference of the message. - message : str Short message. - time : float Time in seconds. """ - start_message = "[%s] " % source + start_message = f"[{source}] " # adapted from joblib.logger.short_format_time without the Windows -.1s # adjustment if time > 60: - time_str = "%4.1fmin" % (time / 60) + time_str = f"{(time / 60):4.1f}min" else: - time_str = " %5.1fs" % time - end_message = " %s, total=%s" % (message, time_str) + time_str = f" {time:5.1f}s" + + end_message = f" {message}, total={time_str}" dots_len = 70 - len(start_message) - len(end_message) return f"{start_message}{dots_len * '.'}{end_message}" @contextmanager -def _print_elapsed_time(source: str, message: str | None = None, use_logger: bool = False): +def print_elapsed_time( + source: str, message: str | None = None, use_logger: bool = False +) -> None: """Log elapsed time to stdout when the context is exited. Adapted from sklearn's function to stay consistent with the logging style: @@ -48,7 +50,6 @@ def _print_elapsed_time(source: str, message: str | None = None, use_logger: boo ---------- source : str String indicating the source or the reference of the message. - message : str, default=None Short message. If None, nothing will be printed. use_logger : bool, default=False @@ -64,7 +65,10 @@ def _print_elapsed_time(source: str, message: str | None = None, use_logger: boo else: start = timeit.default_timer() yield - message_to_print = _message_with_time(source, message, timeit.default_timer() - start) + message_to_print = _message_with_time( + source, message, timeit.default_timer() - start + ) + if use_logger: logger.info(message_to_print) else: diff --git a/tests/test_utils/test_logging.py b/tests/test_utils/test_logging.py new file mode 100644 index 00000000..f737119c --- /dev/null +++ b/tests/test_utils/test_logging.py @@ -0,0 +1,34 @@ +"""Test logging utils.""" + +import io +import unittest +from contextlib import redirect_stdout + +from molpipeline.utils.logging import print_elapsed_time + + +class LoggingUtilsTest(unittest.TestCase): + """Unittest for conversion of sklearn models to json and back.""" + + def test__print_elapsed_time(self) -> None: + """Test message logging with timings work as expected.""" + + # when message is None nothing should be printed + stream1 = io.StringIO() + with redirect_stdout(stream1): + with print_elapsed_time("source", message=None, use_logger=False): + pass + output1 = stream1.getvalue() + self.assertEqual(output1, "") + + # message should be printed in the expected sklearn format + stream2 = io.StringIO() + with redirect_stdout(stream2): + with print_elapsed_time("source", message="my message", use_logger=False): + pass + output2 = stream2.getvalue() + self.assertTrue( + output2.startswith( + "[source] ................................... my message, total=" + ) + )