From 619d64b8609e1cd136435f172efae534f38c4f88 Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Fri, 5 Apr 2024 10:34:52 -0700 Subject: [PATCH 1/2] Optionally skip building Torch with Intel MKL (#538) On systems that have the Intel Compilers and/or the Intel Math Kernel Library installed, the Caffe2 package that comes with Torch will unconditionally try to link in the MKL during the Torch backend build. This, however, can lead to two types of failures: - Problems when compiling the Torch backend because the linker does not include the path to the MKL library - Loading the Torch backend into RedisAI fails because the user does not expect to need to have the MKL library loaded. To alleviate this, a new option "--no_torch_with_mkl" has been added to the `smart build` command that modifies the mkl.cmake file to prevent the detection of MKL. [ committed by @ashao ] [ reviewed by @MattToast and @al-rigazzi ] --- doc/changelog.rst | 7 ++++ smartsim/_core/_cli/build.py | 10 +++++ smartsim/_core/_install/builder.py | 66 ++++++++++++++++++++---------- tests/install/test_builder.py | 42 ++++++++++++++++--- 4 files changed, 98 insertions(+), 27 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 3e73101e1..5e7b89f0f 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -18,6 +18,7 @@ To be released at some future point in time Description +- Add option to build Torch backend without the Intel Math Kernel Library - Fix ReadTheDocs build issue - Promote device options to an Enum - Update telemetry monitor, add telemetry collectors @@ -35,6 +36,11 @@ Description Detailed Notes +- Add an option to smart build "--torch_with_mkl"/"--no_torch_with_mkl" to + prevent Torch from trying to link in the Intel Math Kernel Library. This + is needed because on machines that have the Intel compilers installed, the + Torch will unconditionally try to link in this library, however fails + because the linking flags are incorrect. 
(SmartSim-PR538_) - Change type_extension and pydantic versions in readthedocs environment to enable docs build. (SmartSim-PR537_) - Promote devices to a dedicated Enum type throughout the SmartSim code base. @@ -77,6 +83,7 @@ Detailed Notes - Remove previously deprecated behavior present in test suite on machines with Slurm and Open MPI. (SmartSim-PR520_) +.. _SmartSim-PR538: https://github.com/CrayLabs/SmartSim/pull/538 .. _SmartSim-PR537: https://github.com/CrayLabs/SmartSim/pull/537 .. _SmartSim-PR498: https://github.com/CrayLabs/SmartSim/pull/498 .. _SmartSim-PR460: https://github.com/CrayLabs/SmartSim/pull/460 diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 08a1a6138..ab982ac1b 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -139,6 +139,7 @@ def build_redis_ai( torch_dir: t.Union[str, Path, None] = None, libtf_dir: t.Union[str, Path, None] = None, verbose: bool = False, + torch_with_mkl: bool = True, ) -> None: # make sure user isn't trying to do something silly on MacOS if build_env.PLATFORM == "darwin" and device == Device.GPU: @@ -186,6 +187,7 @@ def build_redis_ai( build_tf=use_tf, build_onnx=use_onnx, verbose=verbose, + torch_with_mkl=torch_with_mkl, ) if rai_builder.is_built: @@ -414,6 +416,7 @@ def execute( args.torch_dir, args.libtensorflow_dir, verbose=verbose, + torch_with_mkl=args.torch_with_mkl, ) except (SetupError, BuildError) as e: logger.error(str(e)) @@ -496,3 +499,10 @@ def configure_parser(parser: argparse.ArgumentParser) -> None: default=False, help="Build KeyDB instead of Redis", ) + + parser.add_argument( + "--no_torch_with_mkl", + dest="torch_with_mkl", + action="store_false", + help="Do not build Torch with Intel MKL", + ) diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index 47f12d044..d0dbc5a6a 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -28,6 +28,7 @@ import concurrent.futures import 
enum +import fileinput import itertools import os import platform @@ -53,8 +54,7 @@ # TODO: check cmake version and use system if possible to avoid conflicts TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime", "tflite"] - - +_PathLike = t.Union[str, "os.PathLike[str]"] _T = t.TypeVar("_T") _U = t.TypeVar("_U") @@ -369,7 +369,7 @@ class _RAIBuildDependency(ABC): def __rai_dependency_name__(self) -> str: ... @abstractmethod - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: ... + def __place_for_rai__(self, target: _PathLike) -> Path: ... @staticmethod @abstractmethod @@ -377,7 +377,7 @@ def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: def _place_rai_dep_at( - target: t.Union[str, "os.PathLike[str]"], verbose: bool + target: _PathLike, verbose: bool ) -> t.Callable[[_RAIBuildDependency], Path]: def _place(dep: _RAIBuildDependency) -> Path: if verbose: @@ -410,6 +410,7 @@ def __init__( build_onnx: bool = False, jobs: int = 1, verbose: bool = False, + torch_with_mkl: bool = True, ) -> None: super().__init__( build_env or {}, @@ -428,6 +429,9 @@ def __init__( self.libtf_dir = libtf_dir self.torch_dir = torch_dir + # extra configuration options + self.torch_with_mkl = torch_with_mkl + # Sanity checks self._validate_platform() @@ -517,8 +521,8 @@ def _get_deps_to_fetch_for( # DLPack is always required fetchable_deps: t.List[_RAIBuildDependency] = [_DLPackRepository("v0.5_RAI")] if self.fetch_torch: - pt_dep = _choose_pt_variant(os_) - fetchable_deps.append(pt_dep(arch, device, "2.0.1")) + pt_dep = _choose_pt_variant(os_)(arch, device, "2.0.1", self.torch_with_mkl) + fetchable_deps.append(pt_dep) if self.fetch_tf: fetchable_deps.append(_TFArchive(os_, arch, device, "2.13.1")) if self.fetch_onnx: @@ -755,7 +759,7 @@ def url(self) -> str: ... 
class _WebGitRepository(_WebLocation): def clone( self, - target: t.Union[str, "os.PathLike[str]"], + target: _PathLike, depth: t.Optional[int] = None, branch: t.Optional[str] = None, ) -> None: @@ -785,7 +789,7 @@ def url(self) -> str: def __rai_dependency_name__(self) -> str: return f"dlpack@{self.url}" - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + def __place_for_rai__(self, target: _PathLike) -> Path: target = Path(target) / "dlpack" self.clone(target, branch=self.version, depth=1) if not target.is_dir(): @@ -799,7 +803,7 @@ def name(self) -> str: _, name = self.url.rsplit("/", 1) return name - def download(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + def download(self, target: _PathLike) -> Path: target = Path(target) if target.is_dir(): target = target / self.name @@ -809,28 +813,22 @@ def download(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: class _ExtractableWebArchive(_WebArchive, ABC): @abstractmethod - def _extract_download( - self, download_path: Path, target: t.Union[str, "os.PathLike[str]"] - ) -> None: ... + def _extract_download(self, download_path: Path, target: _PathLike) -> None: ... 
- def extract(self, target: t.Union[str, "os.PathLike[str]"]) -> None: + def extract(self, target: _PathLike) -> None: with tempfile.TemporaryDirectory() as tmp_dir: arch_path = self.download(tmp_dir) self._extract_download(arch_path, target) class _WebTGZ(_ExtractableWebArchive): - def _extract_download( - self, download_path: Path, target: t.Union[str, "os.PathLike[str]"] - ) -> None: + def _extract_download(self, download_path: Path, target: _PathLike) -> None: with tarfile.open(download_path, "r") as tgz_file: tgz_file.extractall(target) class _WebZip(_ExtractableWebArchive): - def _extract_download( - self, download_path: Path, target: t.Union[str, "os.PathLike[str]"] - ) -> None: + def _extract_download(self, download_path: Path, target: _PathLike) -> None: with zipfile.ZipFile(download_path, "r") as zip_file: zip_file.extractall(target) @@ -840,6 +838,7 @@ class _PTArchive(_WebZip, _RAIBuildDependency): architecture: Architecture device: Device version: str + with_mkl: bool @staticmethod def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: @@ -854,7 +853,20 @@ def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]: def __rai_dependency_name__(self) -> str: return f"libtorch@{self.url}" - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + @staticmethod + def _patch_out_mkl(libtorch_root: Path) -> None: + _modify_source_files( + libtorch_root / "share/cmake/Caffe2/public/mkl.cmake", + r"find_package\(MKL QUIET\)", + "# find_package(MKL QUIET)", + ) + + def extract(self, target: _PathLike) -> None: + super().extract(target) + if not self.with_mkl: + self._patch_out_mkl(Path(target)) + + def __place_for_rai__(self, target: _PathLike) -> Path: self.extract(target) target = Path(target) / "libtorch" if not target.is_dir(): @@ -964,7 +976,7 @@ def url(self) -> str: def __rai_dependency_name__(self) -> str: return f"libtensorflow@{self.url}" - def __place_for_rai__(self, target: 
t.Union[str, "os.PathLike[str]"]) -> Path: + def __place_for_rai__(self, target: _PathLike) -> Path: target = Path(target) / "libtensorflow" target.mkdir() self.extract(target) @@ -1010,7 +1022,7 @@ def url(self) -> str: def __rai_dependency_name__(self) -> str: return f"onnxruntime@{self.url}" - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: + def __place_for_rai__(self, target: _PathLike) -> Path: target = Path(target).resolve() / "onnxruntime" self.extract(target) try: @@ -1051,3 +1063,13 @@ def config_git_command(plat: Platform, cmd: t.Sequence[str]) -> t.List[str]: + cmd[where:] ) return cmd + + +def _modify_source_files( + files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str +) -> None: + compiled_regex = re.compile(regex) + with fileinput.input(files=files, inplace=True) as handles: + for line in handles: + line = compiled_regex.sub(replacement, line) + print(line, end="") diff --git a/tests/install/test_builder.py b/tests/install/test_builder.py index c69a083d1..feaf7e54f 100644 --- a/tests/install/test_builder.py +++ b/tests/install/test_builder.py @@ -27,8 +27,7 @@ import functools import pathlib -import platform -import threading +import textwrap import time import pytest @@ -254,13 +253,13 @@ def test_PTArchiveMacOSX_url(): pt_version = RAI_VERSIONS.torch pt_linux_cpu = build._PTArchiveLinux( - build.Architecture.X64, build.Device.CPU, pt_version + build.Architecture.X64, build.Device.CPU, pt_version, False ) x64_prefix = "https://download.pytorch.org/libtorch/" assert x64_prefix in pt_linux_cpu.url pt_macosx_cpu = build._PTArchiveMacOSX( - build.Architecture.ARM64, build.Device.CPU, pt_version + build.Architecture.ARM64, build.Device.CPU, pt_version, False ) arm64_prefix = "https://github.com/CrayLabs/ml_lib_builder/releases/download/" assert arm64_prefix in pt_macosx_cpu.url @@ -269,7 +268,7 @@ def test_PTArchiveMacOSX_url(): def test_PTArchiveMacOSX_gpu_error(): with 
pytest.raises(build.BuildError, match="support GPU on Mac OSX"): build._PTArchiveMacOSX( - build.Architecture.ARM64, build.Device.GPU, RAI_VERSIONS.torch + build.Architecture.ARM64, build.Device.GPU, RAI_VERSIONS.torch, False ).url @@ -370,3 +369,36 @@ def test_valid_platforms(): ) def test_git_commands_are_configered_correctly_for_platforms(plat, cmd, expected_cmd): assert build.config_git_command(plat, cmd) == expected_cmd + + +def test_modify_source_files(p_test_dir): + def make_text_blurb(food): + return textwrap.dedent(f"""\ + My favorite food is {food} + {food} is an important part of a healthy breakfast + {food} {food} {food} {food} + This line should be unchanged! + --> {food} <-- + """) + + original_word = "SPAM" + mutated_word = "EGGS" + + source_files = [] + for i in range(3): + source_file = p_test_dir / f"test_{i}" + source_file.touch() + source_file.write_text(make_text_blurb(original_word)) + source_files.append(source_file) + # Modify a single file + build._modify_source_files(source_files[0], original_word, mutated_word) + assert source_files[0].read_text() == make_text_blurb(mutated_word) + assert source_files[1].read_text() == make_text_blurb(original_word) + assert source_files[2].read_text() == make_text_blurb(original_word) + + # Modify multiple files + build._modify_source_files( + (source_files[1], source_files[2]), original_word, mutated_word + ) + assert source_files[1].read_text() == make_text_blurb(mutated_word) + assert source_files[2].read_text() == make_text_blurb(mutated_word) From 505de50fea8cbff22f2a3a02c51b8a7a7eb7ec59 Mon Sep 17 00:00:00 2001 From: Matt Drozt Date: Fri, 5 Apr 2024 10:59:39 -0700 Subject: [PATCH 2/2] Enhanced Signal Management (#535) Fixes unfalsifiable test that tests SmartSim's custom SIGINT signal handler. Adds infrastructure to make the test pass again. 
[ committed by @MattToast ] [ reviewed by @ashao ] --- conftest.py | 21 +++++ doc/changelog.rst | 7 ++ smartsim/_core/control/controller.py | 15 +++- smartsim/_core/control/jobmanager.py | 2 +- smartsim/_core/utils/helpers.py | 98 +++++++++++++++++++++++ tests/test_helpers.py | 115 +++++++++++++++++++++++++++ tests/test_interrupt.py | 59 +++++++------- 7 files changed, 285 insertions(+), 32 deletions(-) diff --git a/conftest.py b/conftest.py index c1e9ba4a9..1e9b5a141 100644 --- a/conftest.py +++ b/conftest.py @@ -31,6 +31,7 @@ import os import pathlib import shutil +import signal import sys import tempfile import typing as t @@ -206,6 +207,26 @@ def alloc_specs() -> t.Dict[str, t.Any]: return specs +def _reset_signal(signalnum: int): + """SmartSim will set/overwrite signals on occasion. This function will + return a generator that can be used as a fixture to automatically reset the + signal handler to what it was at the beginning of the test suite to keep + tests atomic. + """ + original = signal.getsignal(signalnum) + + def _reset(): + yield + signal.signal(signalnum, original) + + return _reset + + +_reset_signal_interrupt = pytest.fixture( + _reset_signal(signal.SIGINT), autouse=True, scope="function" +) + + @pytest.fixture def wlmutils() -> t.Type[WLMUtils]: return WLMUtils diff --git a/doc/changelog.rst b/doc/changelog.rst index 5e7b89f0f..ec22f7acd 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -33,6 +33,7 @@ Description - Fix publishing of development docs - Update Experiment API typing - Minor enhancements to test suite +- Improve SmartSim experiment signal handlers Detailed Notes @@ -82,12 +83,18 @@ Detailed Notes undefined. (SmartSim-PR521_) - Remove previously deprecated behavior present in test suite on machines with Slurm and Open MPI. (SmartSim-PR520_) +- When calling ``Experiment.start`` SmartSim would register a signal handler + that would capture an interrupt signal (^C) to kill any jobs launched through + its ``JobManager``. 
This would replace the default (or user defined) signal + handler. SmartSim will now attempt to kill any launched jobs before calling + the previously registered signal handler. (SmartSim-PR535_) .. _SmartSim-PR538: https://github.com/CrayLabs/SmartSim/pull/538 .. _SmartSim-PR537: https://github.com/CrayLabs/SmartSim/pull/537 .. _SmartSim-PR498: https://github.com/CrayLabs/SmartSim/pull/498 .. _SmartSim-PR460: https://github.com/CrayLabs/SmartSim/pull/460 .. _SmartSim-PR512: https://github.com/CrayLabs/SmartSim/pull/512 +.. _SmartSim-PR535: https://github.com/CrayLabs/SmartSim/pull/535 .. _SmartSim-PR529: https://github.com/CrayLabs/SmartSim/pull/529 .. _SmartSim-PR522: https://github.com/CrayLabs/SmartSim/pull/522 .. _SmartSim-PR521: https://github.com/CrayLabs/SmartSim/pull/521 diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index 5c1de5cc2..989d66d2c 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -43,7 +43,11 @@ from smartsim._core.utils.network import get_ip_from_host from ..._core.launcher.step import Step -from ..._core.utils.helpers import unpack_colo_db_identifier, unpack_db_identifier +from ..._core.utils.helpers import ( + SignalInterceptionStack, + unpack_colo_db_identifier, + unpack_db_identifier, +) from ..._core.utils.redis import ( db_is_active, set_ml_model, @@ -71,6 +75,8 @@ from .manifest import LaunchedManifest, LaunchedManifestBuilder, Manifest if t.TYPE_CHECKING: + from types import FrameType + from ..utils.serialize import TStepLaunchMetaData @@ -113,8 +119,11 @@ def start( execution of all jobs. 
""" self._jobs.kill_on_interrupt = kill_on_interrupt + # register custom signal handler for ^C (SIGINT) - signal.signal(signal.SIGINT, self._jobs.signal_interrupt) + SignalInterceptionStack.get(signal.SIGINT).push_unique( + self._jobs.signal_interrupt + ) launched = self._launch(exp_name, exp_path, manifest) # start the job manager thread if not already started @@ -132,7 +141,7 @@ def start( # block until all non-database jobs are complete if block: # poll handles its own keyboard interrupt as - # it may be called seperately + # it may be called separately self.poll(5, True, kill_on_interrupt=kill_on_interrupt) @property diff --git a/smartsim/_core/control/jobmanager.py b/smartsim/_core/control/jobmanager.py index 89363d520..4910b8311 100644 --- a/smartsim/_core/control/jobmanager.py +++ b/smartsim/_core/control/jobmanager.py @@ -350,9 +350,9 @@ def set_db_hosts(self, orchestrator: Orchestrator) -> None: self.db_jobs[dbnode.name].hosts = dbnode.hosts def signal_interrupt(self, signo: int, _frame: t.Optional[FrameType]) -> None: + """Custom handler for whenever SIGINT is received""" if not signo: logger.warning("Received SIGINT with no signal number") - """Custom handler for whenever SIGINT is received""" if self.actively_monitoring and len(self) > 0: if self.kill_on_interrupt: for _, job in self().items(): diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index 9ae319883..b9e79e250 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -28,7 +28,9 @@ A file of helper functions for SmartSim """ import base64 +import collections.abc import os +import signal import typing as t import uuid from datetime import datetime @@ -38,6 +40,12 @@ from smartsim._core._install.builder import TRedisAIBackendStr as _TRedisAIBackendStr +if t.TYPE_CHECKING: + from types import FrameType + + +_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object] + def unpack_db_identifier(db_id: str, token: str) -> 
t.Tuple[str, str]: """Unpack the unformatted database identifier @@ -302,3 +310,93 @@ def decode_cmd(encoded_cmd: str) -> t.List[str]: cleaned_cmd = decoded_cmd.decode("ascii").split("|") return cleaned_cmd + + +# TODO: Remove the ``type: ignore`` comment here when Python 3.8 support is dropped +# ``collections.abc.Collection`` is not subscriptable until Python 3.9 +@t.final +class SignalInterceptionStack(collections.abc.Collection): # type: ignore[type-arg] + """Registers a stack of unique callables to be called when a signal is + received before calling the original signal handler. + """ + + def __init__( + self, + signalnum: int, + callbacks: t.Optional[t.Iterable[_TSignalHandlerFn]] = None, + ) -> None: + """Set up a ``SignalInterceptionStack`` for particular signal number. + + .. note:: + This class typically should not be instanced directly as it will + change the registered signal handler regardless of if a signal + interception stack is already present. Instead, it is generally + best to create or get a signal interception stack for a particular + signal number via the `get` factory method. + + :param signalnum: The signal number to intercept + :type signalnum: int + :param callbacks: A iterable of functions to call upon receiving the signal + :type callbacks: t.Iterable[_TSignalHandlerFn] | None + """ + self._callbacks = list(callbacks) if callbacks else [] + self._original = signal.signal(signalnum, self) + + def __call__(self, signalnum: int, frame: t.Optional["FrameType"]) -> None: + """Handle the signal on which the interception stack was registered. + End by calling the originally registered signal hander (if present). 
+ + :param frame: The current stack frame + :type frame: FrameType | None + """ + for fn in self: + fn(signalnum, frame) + if callable(self._original): + self._original(signalnum, frame) + + def __contains__(self, obj: object) -> bool: + return obj in self._callbacks + + def __iter__(self) -> t.Iterator[_TSignalHandlerFn]: + return reversed(self._callbacks) + + def __len__(self) -> int: + return len(self._callbacks) + + @classmethod + def get(cls, signalnum: int) -> "SignalInterceptionStack": + """Fetch an existing ``SignalInterceptionStack`` or create a new one + for a particular signal number. + + :param signalnum: The singal number of the signal interception stack + should be registered + :type signalnum: int + :returns: The existing or created signal interception stack + :rtype: SignalInterceptionStack + """ + handler = signal.getsignal(signalnum) + if isinstance(handler, cls): + return handler + return cls(signalnum, []) + + def push(self, fn: _TSignalHandlerFn) -> None: + """Add a callback to the signal interception stack. + + :param fn: A callable to add to the unique signal stack + :type fn: _TSignalHandlerFn + """ + self._callbacks.append(fn) + + def push_unique(self, fn: _TSignalHandlerFn) -> bool: + """Add a callback to the signal interception stack if and only if the + callback is not already present. + + :param fn: A callable to add to the unique signal stack + :type fn: _TSignalHandlerFn + :returns: True if the callback was added, False if the callback was + already present + :rtype: bool + """ + if did_push := fn not in self: + self.push(fn) + return did_push diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 025f53d32..523ed7191 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -24,6 +24,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import collections +import signal + import pytest from smartsim._core.utils import helpers @@ -68,3 +71,115 @@ def test_encode_raises_on_empty(): def test_decode_raises_on_empty(): with pytest.raises(ValueError): helpers.decode_cmd("") + + +class MockSignal: + def __init__(self): + self.signal_handlers = collections.defaultdict(lambda: signal.SIG_IGN) + + def signal(self, signalnum, handler): + orig = self.getsignal(signalnum) + self.signal_handlers[signalnum] = handler + return orig + + def getsignal(self, signalnum): + return self.signal_handlers[signalnum] + + +@pytest.fixture +def mock_signal(monkeypatch): + mock_signal = MockSignal() + monkeypatch.setattr(helpers, "signal", mock_signal) + yield mock_signal + + +def test_signal_intercept_stack_will_register_itself_with_callback_fn(mock_signal): + callback = lambda num, frame: ... + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + stack.push(callback) + assert isinstance(stack, helpers.SignalInterceptionStack) + assert stack is mock_signal.signal_handlers[signal.NSIG] + assert len(stack) == 1 + assert list(stack)[0] == callback + + +def test_signal_intercept_stack_keeps_track_of_previous_handlers(mock_signal): + default_handler = lambda num, frame: ... + mock_signal.signal_handlers[signal.NSIG] = default_handler + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + stack.push(lambda n, f: ...) + assert stack._original is default_handler + + +def test_signal_intercept_stacks_are_registered_per_signal_number(mock_signal): + handler = lambda num, frame: ... 
+ stack_1 = helpers.SignalInterceptionStack.get(signal.NSIG) + stack_1.push(handler) + stack_2 = helpers.SignalInterceptionStack.get(signal.NSIG + 1) + stack_2.push(handler) + + assert mock_signal.signal_handlers[signal.NSIG] is stack_1 + assert mock_signal.signal_handlers[signal.NSIG + 1] is stack_2 + assert stack_1 is not stack_2 + assert list(stack_1) == list(stack_2) == [handler] + + +def test_signal_intercept_handlers_will_not_overwrite_if_handler_already_exists( + mock_signal, +): + handler_1 = lambda num, frame: ... + handler_2 = lambda num, frame: ... + stack_1 = helpers.SignalInterceptionStack.get(signal.NSIG) + stack_1.push(handler_1) + stack_2 = helpers.SignalInterceptionStack.get(signal.NSIG) + stack_2.push(handler_2) + assert stack_1 is stack_2 is mock_signal.signal_handlers[signal.NSIG] + assert list(stack_1) == [handler_2, handler_1] + + +def test_signal_intercept_stack_can_add_multiple_instances_of_the_same_handler( + mock_signal, +): + handler = lambda num, frame: ... + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + stack.push(handler) + stack.push(handler) + assert list(stack) == [handler, handler] + + +def test_signal_intercept_stack_enforces_that_unique_push_handlers_are_unique( + mock_signal, +): + handler = lambda num, frame: ... + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + assert stack.push_unique(handler) + assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(handler) + assert list(stack) == [handler] + + +def test_signal_intercept_stack_enforces_that_unique_push_method_handlers_are_unique( + mock_signal, +): + class C: + def fn(num, frame): ... 
+ + c1 = C() + c2 = C() + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + stack.push_unique(c1.fn) + assert helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c2.fn) + assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c1.fn) + assert list(stack) == [c2.fn, c1.fn] + + +def test_signal_handler_calls_functions_in_reverse_order(mock_signal): + called_list = [] + default = lambda num, frame: called_list.append("default") + handler_1 = lambda num, frame: called_list.append("handler_1") + handler_2 = lambda num, frame: called_list.append("handler_2") + + mock_signal.signal_handlers[signal.NSIG] = default + helpers.SignalInterceptionStack.get(signal.NSIG).push(handler_1) + helpers.SignalInterceptionStack.get(signal.NSIG).push(handler_2) + mock_signal.signal_handlers[signal.NSIG](signal.NSIG, None) + assert called_list == ["handler_2", "handler_1", "default"] diff --git a/tests/test_interrupt.py b/tests/test_interrupt.py index 28c48e0db..61dc5b8c0 100644 --- a/tests/test_interrupt.py +++ b/tests/test_interrupt.py @@ -65,20 +65,21 @@ def test_interrupt_blocked_jobs(test_dir): ) ensemble.set_path(test_dir) num_jobs = 1 + len(ensemble) - try: - pid = os.getpid() - keyboard_interrupt_thread = Thread( - name="sigint_thread", target=keyboard_interrupt, args=(pid,) - ) - keyboard_interrupt_thread.start() + pid = os.getpid() + keyboard_interrupt_thread = Thread( + name="sigint_thread", target=keyboard_interrupt, args=(pid,) + ) + keyboard_interrupt_thread.start() + + with pytest.raises(KeyboardInterrupt): exp.start(model, ensemble, block=True, kill_on_interrupt=True) - except KeyboardInterrupt: - time.sleep(2) # allow time for jobs to be stopped - active_jobs = exp._control._jobs.jobs - active_db_jobs = exp._control._jobs.db_jobs - completed_jobs = exp._control._jobs.completed - assert len(active_jobs) + len(active_db_jobs) == 0 - assert len(completed_jobs) == num_jobs + + time.sleep(2) # allow time for jobs to be stopped + active_jobs = 
exp._control._jobs.jobs + active_db_jobs = exp._control._jobs.db_jobs + completed_jobs = exp._control._jobs.completed + assert len(active_jobs) + len(active_db_jobs) == 0 + assert len(completed_jobs) == num_jobs def test_interrupt_multi_experiment_unblocked_jobs(test_dir): @@ -106,20 +107,22 @@ def test_interrupt_multi_experiment_unblocked_jobs(test_dir): ) ensemble.set_path(test_dir) jobs_per_experiment[i] = 1 + len(ensemble) - try: - pid = os.getpid() - keyboard_interrupt_thread = Thread( - name="sigint_thread", target=keyboard_interrupt, args=(pid,) - ) - keyboard_interrupt_thread.start() + + pid = os.getpid() + keyboard_interrupt_thread = Thread( + name="sigint_thread", target=keyboard_interrupt, args=(pid,) + ) + keyboard_interrupt_thread.start() + + with pytest.raises(KeyboardInterrupt): for experiment in experiments: experiment.start(model, ensemble, block=False, kill_on_interrupt=True) - time.sleep(9) # since jobs aren't blocked, wait for SIGINT - except KeyboardInterrupt: - time.sleep(2) # allow time for jobs to be stopped - for i, experiment in enumerate(experiments): - active_jobs = experiment._control._jobs.jobs - active_db_jobs = experiment._control._jobs.db_jobs - completed_jobs = experiment._control._jobs.completed - assert len(active_jobs) + len(active_db_jobs) == 0 - assert len(completed_jobs) == jobs_per_experiment[i] + keyboard_interrupt_thread.join() # since jobs aren't blocked, wait for SIGINT + + time.sleep(2) # allow time for jobs to be stopped + for i, experiment in enumerate(experiments): + active_jobs = experiment._control._jobs.jobs + active_db_jobs = experiment._control._jobs.db_jobs + completed_jobs = experiment._control._jobs.completed + assert len(active_jobs) + len(active_db_jobs) == 0 + assert len(completed_jobs) == jobs_per_experiment[i]