diff --git a/aiida/common/log.py b/aiida/common/log.py index a545871422..e993950423 100644 --- a/aiida/common/log.py +++ b/aiida/common/log.py @@ -13,9 +13,10 @@ import collections import contextlib import enum +import io import logging import types -from typing import cast +import typing as t __all__ = ('AIIDA_LOGGER', 'override_log_level') @@ -56,7 +57,7 @@ def report(self, msg: str, *args, **kwargs) -> None: LogLevels = enum.Enum('LogLevels', {key: key for key in LOG_LEVELS}) # type: ignore[misc] -AIIDA_LOGGER = cast(AiidaLoggerType, logging.getLogger('aiida')) +AIIDA_LOGGER = t.cast(AiidaLoggerType, logging.getLogger('aiida')) CLI_ACTIVE: bool | None = None """Flag that is set to ``True`` if the module is imported by ``verdi`` being called.""" @@ -249,3 +250,22 @@ def override_log_level(level=logging.CRITICAL): yield finally: logging.disable(level=logging.NOTSET) + + +@contextlib.contextmanager +def capture_logging(logger: logging.Logger = AIIDA_LOGGER) -> t.Generator[io.StringIO, None, None]: + """Capture logging to a stream in memory. + + Note, this only copies any content that is being logged to a stream in memory. It does not interfere with any other + existing stream handlers. In this sense, this context manager is non-destructive. + + :param logger: The logger whose output to capture. + :returns: A stream to which the logged content is captured. + """ + stream = io.StringIO() + handler = logging.StreamHandler(stream) + logger.addHandler(handler) + try: + yield stream + finally: + logger.removeHandler(handler) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 8be012bbf5..5187ce401e 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -1,5 +1,6 @@ ### python builtins py:class _io.BufferedReader +py:class _io.StringIO py:class AliasedClass py:class asyncio.events.AbstractEventLoop py:class AbstractContextManager diff --git a/tests/common/test_logging.py b/tests/common/test_logging.py index 09f51a5e51..f3a092839b 100644 --- a/tests/common/test_logging.py +++ b/tests/common/test_logging.py @@ -10,6 +10,8 @@ """Tests for the :mod:`aiida.common.log` module.""" import logging +from aiida.common.log import capture_logging + def test_logging_before_dbhandler_loaded(caplog): """Test that logging still works even if no database is loaded. @@ -36,3 +38,12 @@ def test_log_report(caplog): logger.report(msg) assert caplog.record_tuples == [(logger.name, logging.REPORT, msg)] # pylint: disable=no-member + + +def test_capture_logging(): + """Test the :func:`aiida.common.log.capture_logging` function.""" + logger = logging.getLogger() + message = 'Some message' + with capture_logging(logger) as stream: + logging.getLogger().error(message) + assert stream.getvalue().strip() == message