Skip to content

Commit

Permalink
Merge pull request #14 from basf/adapt_own_time_elapsed_from_sklearn
Browse files Browse the repository at this point in the history
utils: add own _print_elapsed_time
  • Loading branch information
JochenSiegWork authored May 22, 2024
2 parents f3c7007 + b4cb32a commit 77fca18
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 5 deletions.
12 changes: 7 additions & 5 deletions molpipeline/pipeline/_skl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union


try:
from typing import Self # type: ignore[attr-defined]
except ImportError:
Expand All @@ -17,7 +18,7 @@
from sklearn.base import clone
from sklearn.pipeline import Pipeline as _Pipeline
from sklearn.pipeline import _final_estimator_has, _fit_transform_one
from sklearn.utils import Bunch, _print_elapsed_time
from sklearn.utils import Bunch
from sklearn.utils.metadata_routing import (
_routing_enabled, # pylint: disable=protected-access
)
Expand All @@ -32,6 +33,7 @@
PostPredictionTransformation,
PostPredictionWrapper,
)
from molpipeline.utils.logging import print_elapsed_time
from molpipeline.utils.molpipeline_types import (
AnyElement,
AnyPredictor,
Expand Down Expand Up @@ -240,7 +242,7 @@ def _fit(
for step in self._iter(with_final=False, filter_passthrough=False):
step_idx, name, transformer = step
if transformer is None or transformer == "passthrough":
with _print_elapsed_time("Pipeline", self._log_message(step_idx)):
with print_elapsed_time("Pipeline", self._log_message(step_idx)):
continue

if hasattr(memory, "location") and memory.location is None:
Expand Down Expand Up @@ -457,7 +459,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
"""
routed_params = self._check_method_params(method="fit", props=fit_params)
Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator != "passthrough":
if is_empty(Xt):
logger.warning(
Expand Down Expand Up @@ -528,7 +530,7 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
routed_params = self._check_method_params(method="fit_transform", props=params)
iter_input, iter_label = self._fit(X, y, routed_params)
last_step = self._final_estimator
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if last_step == "passthrough":
pass
elif is_empty(iter_input):
Expand Down Expand Up @@ -648,7 +650,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
) # pylint: disable=invalid-name

params_last_step = routed_params[self.steps[-1][0]]
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator == "passthrough":
y_pred = iter_input
elif is_empty(iter_input):
Expand Down
81 changes: 81 additions & 0 deletions molpipeline/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Logging helper functions."""

from __future__ import annotations

import timeit
from contextlib import contextmanager
from typing import Generator

from loguru import logger


def _message_with_time(source: str, message: str, time: float) -> str:
"""Create one line message for logging purposes.
Adapted from sklearn's function to stay consistent with the logging style:
https://github.com/scikit-learn/scikit-learn/blob/e16a6ddebd527e886fc22105710ee20ce255f9f0/sklearn/utils/_user_interface.py
Parameters
----------
source : str
String indicating the source or the reference of the message.
message : str
Short message.
time : float
Time in seconds.
Returns
-------
str
Message with elapsed time.
"""
start_message = f"[{source}] "

# adapted from joblib.logger.short_format_time without the Windows -.1s
# adjustment
if time > 60:
time_str = f"{(time / 60):4.1f}min"
else:
time_str = f" {time:5.1f}s"

end_message = f" {message}, total={time_str}"
dots_len = 70 - len(start_message) - len(end_message)
return f"{start_message}{dots_len * '.'}{end_message}"


@contextmanager
def print_elapsed_time(
    source: str, message: str | None = None, use_logger: bool = False
) -> Generator[None, None, None]:
    """Time the enclosed block and report the duration when it finishes.

    Adapted from sklearn's function to stay consistent with the logging style:
    https://github.com/scikit-learn/scikit-learn/blob/e16a6ddebd527e886fc22105710ee20ce255f9f0/sklearn/utils/_user_interface.py

    Parameters
    ----------
    source : str
        String indicating the source or the reference of the message.
    message : str, default=None
        Short message. If None, nothing is timed and nothing is reported.
    use_logger : bool, default=False
        If True, the line is emitted with ``logger.info``; otherwise it is
        written with ``print``.

    Yields
    ------
    None
        Control is handed to the body of the ``with`` statement.

    Notes
    -----
    The report is produced by the code that follows the ``yield``, so it is
    only emitted when the ``with`` body completes without raising.
    """
    # Guard clause: without a message there is nothing to measure or report.
    if message is None:
        yield
        return

    started_at = timeit.default_timer()
    yield
    elapsed_seconds = timeit.default_timer() - started_at

    report_line = _message_with_time(source, message, elapsed_seconds)
    if not use_logger:
        print(report_line)
    else:
        logger.info(report_line)
34 changes: 34 additions & 0 deletions tests/test_utils/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Test logging utils."""

import io
import unittest
from contextlib import redirect_stdout

from molpipeline.utils.logging import print_elapsed_time


class LoggingUtilsTest(unittest.TestCase):
    """Unit tests for the logging utilities."""

    def test__print_elapsed_time(self) -> None:
        """Check the stdout output of print_elapsed_time with and without a message."""
        # With message=None the context manager must not write anything.
        silent_buffer = io.StringIO()
        with redirect_stdout(silent_buffer), print_elapsed_time(
            "source", message=None, use_logger=False
        ):
            pass
        self.assertEqual(silent_buffer.getvalue(), "")

        # With a message, one line in the sklearn format is expected. Only the
        # prefix is compared because the measured duration varies between runs.
        timed_buffer = io.StringIO()
        with redirect_stdout(timed_buffer), print_elapsed_time(
            "source", message="my message", use_logger=False
        ):
            pass
        expected_prefix = (
            "[source] ................................... my message, total="
        )
        self.assertTrue(timed_buffer.getvalue().startswith(expected_prefix))

0 comments on commit 77fca18

Please sign in to comment.