Skip to content

Commit

Permalink
utils: add test for logging with timing
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed May 22, 2024
1 parent d9f4e5e commit 30eb49d
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 15 deletions.
10 changes: 5 additions & 5 deletions molpipeline/pipeline/_skl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union

from molpipeline.utils.logging import _print_elapsed_time

try:
from typing import Self # type: ignore[attr-defined]
Expand Down Expand Up @@ -34,6 +33,7 @@
PostPredictionTransformation,
PostPredictionWrapper,
)
from molpipeline.utils.logging import print_elapsed_time
from molpipeline.utils.molpipeline_types import (
AnyElement,
AnyPredictor,
Expand Down Expand Up @@ -242,7 +242,7 @@ def _fit(
for step in self._iter(with_final=False, filter_passthrough=False):
step_idx, name, transformer = step
if transformer is None or transformer == "passthrough":
with _print_elapsed_time("Pipeline", self._log_message(step_idx)):
with print_elapsed_time("Pipeline", self._log_message(step_idx)):
continue

if hasattr(memory, "location") and memory.location is None:
Expand Down Expand Up @@ -459,7 +459,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
"""
routed_params = self._check_method_params(method="fit", props=fit_params)
Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator != "passthrough":
if is_empty(Xt):
logger.warning(
Expand Down Expand Up @@ -530,7 +530,7 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
routed_params = self._check_method_params(method="fit_transform", props=params)
iter_input, iter_label = self._fit(X, y, routed_params)
last_step = self._final_estimator
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if last_step == "passthrough":
pass
elif is_empty(iter_input):
Expand Down Expand Up @@ -650,7 +650,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
) # pylint: disable=invalid-name

params_last_step = routed_params[self.steps[-1][0]]
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator == "passthrough":
y_pred = iter_input
elif is_empty(iter_input):
Expand Down
24 changes: 14 additions & 10 deletions molpipeline/utils/logging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Logging helper functions."""

from __future__ import annotations

import timeit
Expand All @@ -7,7 +8,7 @@
from loguru import logger


def _message_with_time(source: str, message: str, time: float):
def _message_with_time(source: str, message: str, time: float) -> str:
"""Create one line message for logging purposes.
Adapted from sklearn's function to stay consistent with the logging style:
Expand All @@ -17,28 +18,29 @@ def _message_with_time(source: str, message: str, time: float):
----------
source : str
String indicating the source or the reference of the message.
message : str
Short message.
time : float
Time in seconds.
"""
start_message = "[%s] " % source
start_message = f"[{source}] "

# adapted from joblib.logger.short_format_time without the Windows -.1s
# adjustment
if time > 60:
time_str = "%4.1fmin" % (time / 60)
time_str = f"{(time / 60):4.1f}min"
else:
time_str = " %5.1fs" % time
end_message = " %s, total=%s" % (message, time_str)
time_str = f" {time:5.1f}s"

end_message = f" {message}, total={time_str}"
dots_len = 70 - len(start_message) - len(end_message)
return f"{start_message}{dots_len * '.'}{end_message}"


@contextmanager
def _print_elapsed_time(source: str, message: str | None = None, use_logger: bool = False):
def print_elapsed_time(
source: str, message: str | None = None, use_logger: bool = False
) -> None:
"""Log elapsed time to stdout when the context is exited.
Adapted from sklearn's function to stay consistent with the logging style:
Expand All @@ -48,7 +50,6 @@ def _print_elapsed_time(source: str, message: str | None = None, use_logger: boo
----------
source : str
String indicating the source or the reference of the message.
message : str, default=None
Short message. If None, nothing will be printed.
use_logger : bool, default=False
Expand All @@ -64,7 +65,10 @@ def _print_elapsed_time(source: str, message: str | None = None, use_logger: boo
else:
start = timeit.default_timer()
yield
message_to_print = _message_with_time(source, message, timeit.default_timer() - start)
message_to_print = _message_with_time(
source, message, timeit.default_timer() - start
)

if use_logger:
logger.info(message_to_print)
else:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_utils/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Test logging utils."""

import io
import unittest
from contextlib import redirect_stdout

from molpipeline.utils.logging import print_elapsed_time


class LoggingUtilsTest(unittest.TestCase):
    """Unit tests for the logging helper functions."""

    def test__print_elapsed_time(self) -> None:
        """Test that logging a message with timing works as expected."""
        # When message is None, nothing should be printed.
        stream1 = io.StringIO()
        with redirect_stdout(stream1):
            with print_elapsed_time("source", message=None, use_logger=False):
                pass
        output1 = stream1.getvalue()
        self.assertEqual(output1, "")

        # The message should be printed in the expected sklearn format.
        # Only the prefix is compared: the text after "total=" is the
        # measured elapsed time and differs between runs.
        stream2 = io.StringIO()
        with redirect_stdout(stream2):
            with print_elapsed_time("source", message="my message", use_logger=False):
                pass
        output2 = stream2.getvalue()
        self.assertTrue(
            output2.startswith(
                "[source] ................................... my message, total="
            ),
            # Show the captured text on failure instead of "False is not true".
            msg=f"Unexpected output: {output2!r}",
        )

0 comments on commit 30eb49d

Please sign in to comment.