diff --git a/conftest.py b/conftest.py
index c1e9ba4a9..1e9b5a141 100644
--- a/conftest.py
+++ b/conftest.py
@@ -31,6 +31,7 @@
 import os
 import pathlib
 import shutil
+import signal
 import sys
 import tempfile
 import typing as t
@@ -206,6 +207,26 @@ def alloc_specs() -> t.Dict[str, t.Any]:
     return specs
 
 
+def _reset_signal(signalnum: int):
+    """SmartSim will set/overwrite signal handlers on occasion. This function
+    returns a generator function that can be used as a fixture to automatically
+    reset the signal handler to what it was at the beginning of the test suite,
+    keeping tests atomic.
+    """
+    original = signal.getsignal(signalnum)
+
+    def _reset():
+        yield
+        signal.signal(signalnum, original)
+
+    return _reset
+
+
+_reset_signal_interrupt = pytest.fixture(
+    _reset_signal(signal.SIGINT), autouse=True, scope="function"
+)
+
+
 @pytest.fixture
 def wlmutils() -> t.Type[WLMUtils]:
     return WLMUtils
diff --git a/doc/changelog.rst b/doc/changelog.rst
index 72ee67594..3d1bd3db9 100644
--- a/doc/changelog.rst
+++ b/doc/changelog.rst
@@ -19,6 +19,7 @@ To be released at some future point in time
 Description
 
 - Update watchdog dependency
+- Add option to build Torch backend without the Intel Math Kernel Library
 - Fix ReadTheDocs build issue
 - Promote device options to an Enum
 - Update telemetry monitor, add telemetry collectors
@@ -33,10 +34,16 @@ Description
 - Fix publishing of development docs
 - Update Experiment API typing
 - Minor enhancements to test suite
+- Improve SmartSim experiment signal handlers
 
 Detailed Notes
 
 - Update watchdog dependency from 3.x to 4.x, fix new type issues (SmartSim-PR540_)
+- Add an option to smart build "--torch_with_mkl"/"--no_torch_with_mkl" to
+  prevent Torch from trying to link in the Intel Math Kernel Library. This
+  is needed because on machines that have the Intel compilers installed,
+  Torch will unconditionally try to link in this library, but the link fails
+  because the linking flags are incorrect. (SmartSim-PR538_)
 - Change type_extension and pydantic versions in readthedocs environment to
   enable docs build. (SmartSim-PR537_)
 - Promote devices to a dedicated Enum type throughout the SmartSim code base.
@@ -78,12 +85,19 @@
   undefined. (SmartSim-PR521_)
 - Remove previously deprecated behavior present in test suite on machines with
   Slurm and Open MPI. (SmartSim-PR520_)
+- When calling ``Experiment.start``, SmartSim would register a signal handler
+  that would capture an interrupt signal (^C) to kill any jobs launched through
+  its ``JobManager``. This would replace the default (or user-defined) signal
+  handler. SmartSim will now attempt to kill any launched jobs before calling
+  the previously registered signal handler. (SmartSim-PR535_)
 
 .. _SmartSim-PR540: https://github.com/CrayLabs/SmartSim/pull/540
+.. _SmartSim-PR538: https://github.com/CrayLabs/SmartSim/pull/538
 .. _SmartSim-PR537: https://github.com/CrayLabs/SmartSim/pull/537
 .. _SmartSim-PR498: https://github.com/CrayLabs/SmartSim/pull/498
 .. _SmartSim-PR460: https://github.com/CrayLabs/SmartSim/pull/460
 .. _SmartSim-PR512: https://github.com/CrayLabs/SmartSim/pull/512
+.. _SmartSim-PR535: https://github.com/CrayLabs/SmartSim/pull/535
 .. _SmartSim-PR529: https://github.com/CrayLabs/SmartSim/pull/529
 .. _SmartSim-PR522: https://github.com/CrayLabs/SmartSim/pull/522
 .. _SmartSim-PR521: https://github.com/CrayLabs/SmartSim/pull/521
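The conftest.py change at the top of this patch builds the autouse fixture by calling pytest.fixture directly on the generator function returned by _reset_signal, capturing the handler at import time. A rough decorator-form equivalent (illustrative only, not taken from the patch; the _ORIGINAL_SIGINT_HANDLER name is invented here):

import signal

import pytest

# capture the handler that was installed when the test session started
_ORIGINAL_SIGINT_HANDLER = signal.getsignal(signal.SIGINT)


@pytest.fixture(autouse=True, scope="function")
def _reset_signal_interrupt():
    yield
    # restore the session-start handler even if the test replaced it
    signal.signal(signal.SIGINT, _ORIGINAL_SIGINT_HANDLER)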
diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py
index 08a1a6138..ab982ac1b 100644
--- a/smartsim/_core/_cli/build.py
+++ b/smartsim/_core/_cli/build.py
@@ -139,6 +139,7 @@ def build_redis_ai(
     torch_dir: t.Union[str, Path, None] = None,
     libtf_dir: t.Union[str, Path, None] = None,
     verbose: bool = False,
+    torch_with_mkl: bool = True,
 ) -> None:
     # make sure user isn't trying to do something silly on MacOS
     if build_env.PLATFORM == "darwin" and device == Device.GPU:
@@ -186,6 +187,7 @@ def build_redis_ai(
         build_tf=use_tf,
         build_onnx=use_onnx,
         verbose=verbose,
+        torch_with_mkl=torch_with_mkl,
     )
 
     if rai_builder.is_built:
@@ -414,6 +416,7 @@ def execute(
             args.torch_dir,
             args.libtensorflow_dir,
             verbose=verbose,
+            torch_with_mkl=args.torch_with_mkl,
         )
     except (SetupError, BuildError) as e:
         logger.error(str(e))
@@ -496,3 +499,10 @@ def configure_parser(parser: argparse.ArgumentParser) -> None:
         default=False,
         help="Build KeyDB instead of Redis",
     )
+
+    parser.add_argument(
+        "--no_torch_with_mkl",
+        dest="torch_with_mkl",
+        action="store_false",
+        help="Do not build Torch with Intel MKL",
+    )
diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py
index 47f12d044..d0dbc5a6a 100644
--- a/smartsim/_core/_install/builder.py
+++ b/smartsim/_core/_install/builder.py
@@ -28,6 +28,7 @@
 import concurrent.futures
 import enum
+import fileinput
 import itertools
 import os
 import platform
@@ -53,8 +54,7 @@
 # TODO: check cmake version and use system if possible to avoid conflicts
 TRedisAIBackendStr = t.Literal["tensorflow", "torch", "onnxruntime", "tflite"]
-
-
+_PathLike = t.Union[str, "os.PathLike[str]"]
 _T = t.TypeVar("_T")
 _U = t.TypeVar("_U")
@@ -369,7 +369,7 @@ class _RAIBuildDependency(ABC):
     def __rai_dependency_name__(self) -> str: ...
 
     @abstractmethod
-    def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: ...
+    def __place_for_rai__(self, target: _PathLike) -> Path: ...
 
     @staticmethod
     @abstractmethod
@@ -377,7 +377,7 @@ def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
 
 
 def _place_rai_dep_at(
-    target: t.Union[str, "os.PathLike[str]"], verbose: bool
+    target: _PathLike, verbose: bool
 ) -> t.Callable[[_RAIBuildDependency], Path]:
     def _place(dep: _RAIBuildDependency) -> Path:
         if verbose:
@@ -410,6 +410,7 @@ def __init__(
         build_onnx: bool = False,
         jobs: int = 1,
         verbose: bool = False,
+        torch_with_mkl: bool = True,
     ) -> None:
         super().__init__(
             build_env or {},
@@ -428,6 +429,9 @@ def __init__(
         self.libtf_dir = libtf_dir
         self.torch_dir = torch_dir
 
+        # extra configuration options
+        self.torch_with_mkl = torch_with_mkl
+
         # Sanity checks
         self._validate_platform()
 
@@ -517,8 +521,8 @@ def _get_deps_to_fetch_for(
         # DLPack is always required
         fetchable_deps: t.List[_RAIBuildDependency] = [_DLPackRepository("v0.5_RAI")]
         if self.fetch_torch:
-            pt_dep = _choose_pt_variant(os_)
-            fetchable_deps.append(pt_dep(arch, device, "2.0.1"))
+            pt_dep = _choose_pt_variant(os_)(arch, device, "2.0.1", self.torch_with_mkl)
+            fetchable_deps.append(pt_dep)
         if self.fetch_tf:
             fetchable_deps.append(_TFArchive(os_, arch, device, "2.13.1"))
         if self.fetch_onnx:
@@ -755,7 +759,7 @@ def url(self) -> str: ...
 class _WebGitRepository(_WebLocation):
     def clone(
         self,
-        target: t.Union[str, "os.PathLike[str]"],
+        target: _PathLike,
         depth: t.Optional[int] = None,
         branch: t.Optional[str] = None,
     ) -> None:
@@ -785,7 +789,7 @@ def url(self) -> str:
     def __rai_dependency_name__(self) -> str:
         return f"dlpack@{self.url}"
 
-    def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path:
+    def __place_for_rai__(self, target: _PathLike) -> Path:
         target = Path(target) / "dlpack"
         self.clone(target, branch=self.version, depth=1)
         if not target.is_dir():
@@ -799,7 +803,7 @@ def name(self) -> str:
         _, name = self.url.rsplit("/", 1)
         return name
 
-    def download(self, target: t.Union[str, "os.PathLike[str]"]) -> Path:
+    def download(self, target: _PathLike) -> Path:
         target = Path(target)
         if target.is_dir():
             target = target / self.name
@@ -809,28 +813,22 @@ def download(self, target: t.Union[str, "os.PathLike[str]"]) -> Path:
 
 class _ExtractableWebArchive(_WebArchive, ABC):
     @abstractmethod
-    def _extract_download(
-        self, download_path: Path, target: t.Union[str, "os.PathLike[str]"]
-    ) -> None: ...
+    def _extract_download(self, download_path: Path, target: _PathLike) -> None: ...
 
-    def extract(self, target: t.Union[str, "os.PathLike[str]"]) -> None:
+    def extract(self, target: _PathLike) -> None:
         with tempfile.TemporaryDirectory() as tmp_dir:
             arch_path = self.download(tmp_dir)
             self._extract_download(arch_path, target)
 
 
 class _WebTGZ(_ExtractableWebArchive):
-    def _extract_download(
-        self, download_path: Path, target: t.Union[str, "os.PathLike[str]"]
-    ) -> None:
+    def _extract_download(self, download_path: Path, target: _PathLike) -> None:
         with tarfile.open(download_path, "r") as tgz_file:
             tgz_file.extractall(target)
 
 
 class _WebZip(_ExtractableWebArchive):
-    def _extract_download(
-        self, download_path: Path, target: t.Union[str, "os.PathLike[str]"]
-    ) -> None:
+    def _extract_download(self, download_path: Path, target: _PathLike) -> None:
         with zipfile.ZipFile(download_path, "r") as zip_file:
             zip_file.extractall(target)
@@ -840,6 +838,7 @@ class _PTArchive(_WebZip, _RAIBuildDependency):
     architecture: Architecture
     device: Device
     version: str
+    with_mkl: bool
 
     @staticmethod
     def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
@@ -854,7 +853,20 @@ def supported_platforms() -> t.Sequence[t.Tuple[OperatingSystem, Architecture]]:
     def __rai_dependency_name__(self) -> str:
         return f"libtorch@{self.url}"
 
-    def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path:
+    @staticmethod
+    def _patch_out_mkl(libtorch_root: Path) -> None:
+        _modify_source_files(
+            libtorch_root / "share/cmake/Caffe2/public/mkl.cmake",
+            r"find_package\(MKL QUIET\)",
+            "# find_package(MKL QUIET)",
+        )
+
+    def extract(self, target: _PathLike) -> None:
+        super().extract(target)
+        if not self.with_mkl:
+            self._patch_out_mkl(Path(target))
+
+    def __place_for_rai__(self, target: _PathLike) -> Path:
         self.extract(target)
         target = Path(target) / "libtorch"
         if not target.is_dir():
@@ -964,7 +976,7 @@ def url(self) -> str:
     def __rai_dependency_name__(self) -> str:
         return f"libtensorflow@{self.url}"
 
-    def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path:
+    def __place_for_rai__(self, target: _PathLike) -> Path:
         target = Path(target) / "libtensorflow"
         target.mkdir()
         self.extract(target)
@@ -1010,7 +1022,7 @@ def url(self) -> str:
     def __rai_dependency_name__(self) -> str:
         return f"onnxruntime@{self.url}"
 
-    def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path:
+    def __place_for_rai__(self, target: _PathLike) -> Path:
         target = Path(target).resolve() / "onnxruntime"
         self.extract(target)
         try:
@@ -1051,3 +1063,13 @@ def config_git_command(plat: Platform, cmd: t.Sequence[str]) -> t.List[str]:
             + cmd[where:]
         )
     return cmd
+
+
+def _modify_source_files(
+    files: t.Union[_PathLike, t.Iterable[_PathLike]], regex: str, replacement: str
+) -> None:
+    compiled_regex = re.compile(regex)
+    with fileinput.input(files=files, inplace=True) as handles:
+        for line in handles:
+            line = compiled_regex.sub(replacement, line)
+            print(line, end="")
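For readers who want to see the effect of _PTArchive._patch_out_mkl in isolation, here is a standalone sketch of the same comment-out approach (illustrative only, not part of the patch; it assumes an already-extracted libtorch tree at ./libtorch):

import fileinput
import re
from pathlib import Path

# comment out `find_package(MKL QUIET)` so CMake never tries to link MKL
mkl_cmake = Path("libtorch") / "share/cmake/Caffe2/public/mkl.cmake"
pattern = re.compile(r"find_package\(MKL QUIET\)")
with fileinput.input(files=str(mkl_cmake), inplace=True) as handle:
    for line in handle:
        # with inplace=True, anything printed replaces the original line
        print(pattern.sub("# find_package(MKL QUIET)", line), end="")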
diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py
index 5c1de5cc2..989d66d2c 100644
--- a/smartsim/_core/control/controller.py
+++ b/smartsim/_core/control/controller.py
@@ -43,7 +43,11 @@
 from smartsim._core.utils.network import get_ip_from_host
 
 from ..._core.launcher.step import Step
-from ..._core.utils.helpers import unpack_colo_db_identifier, unpack_db_identifier
+from ..._core.utils.helpers import (
+    SignalInterceptionStack,
+    unpack_colo_db_identifier,
+    unpack_db_identifier,
+)
 from ..._core.utils.redis import (
     db_is_active,
     set_ml_model,
@@ -71,6 +75,8 @@
 from .manifest import LaunchedManifest, LaunchedManifestBuilder, Manifest
 
 if t.TYPE_CHECKING:
+    from types import FrameType
+
     from ..utils.serialize import TStepLaunchMetaData
 
 
@@ -113,8 +119,11 @@ def start(
             execution of all jobs.
         """
         self._jobs.kill_on_interrupt = kill_on_interrupt
+
         # register custom signal handler for ^C (SIGINT)
-        signal.signal(signal.SIGINT, self._jobs.signal_interrupt)
+        SignalInterceptionStack.get(signal.SIGINT).push_unique(
+            self._jobs.signal_interrupt
+        )
         launched = self._launch(exp_name, exp_path, manifest)
 
         # start the job manager thread if not already started
@@ -132,7 +141,7 @@ def start(
         # block until all non-database jobs are complete
         if block:
             # poll handles its own keyboard interrupt as
-            # it may be called seperately
+            # it may be called separately
             self.poll(5, True, kill_on_interrupt=kill_on_interrupt)
 
     @property
diff --git a/smartsim/_core/control/jobmanager.py b/smartsim/_core/control/jobmanager.py
index 89363d520..4910b8311 100644
--- a/smartsim/_core/control/jobmanager.py
+++ b/smartsim/_core/control/jobmanager.py
@@ -350,9 +350,9 @@ def set_db_hosts(self, orchestrator: Orchestrator) -> None:
                 self.db_jobs[dbnode.name].hosts = dbnode.hosts
 
     def signal_interrupt(self, signo: int, _frame: t.Optional[FrameType]) -> None:
+        """Custom handler for whenever SIGINT is received"""
         if not signo:
             logger.warning("Received SIGINT with no signal number")
-        """Custom handler for whenever SIGINT is received"""
         if self.actively_monitoring and len(self) > 0:
             if self.kill_on_interrupt:
                 for _, job in self().items():
diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py
index 9ae319883..b9e79e250 100644
--- a/smartsim/_core/utils/helpers.py
+++ b/smartsim/_core/utils/helpers.py
@@ -28,7 +28,9 @@
 A file of helper functions for SmartSim
 """
 import base64
+import collections.abc
 import os
+import signal
 import typing as t
 import uuid
 from datetime import datetime
@@ -38,6 +40,12 @@
 
 from smartsim._core._install.builder import TRedisAIBackendStr as _TRedisAIBackendStr
 
+if t.TYPE_CHECKING:
+    from types import FrameType
+
+
+_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object]
+
 
 def unpack_db_identifier(db_id: str, token: str) -> t.Tuple[str, str]:
     """Unpack the unformatted database identifier
@@ -302,3 +310,93 @@ def decode_cmd(encoded_cmd: str) -> t.List[str]:
     cleaned_cmd = decoded_cmd.decode("ascii").split("|")
 
     return cleaned_cmd
+
+
+# TODO: Remove the ``type: ignore`` comment here when Python 3.8 support is dropped
+# ``collections.abc.Collection`` is not subscriptable until Python 3.9
+@t.final
+class SignalInterceptionStack(collections.abc.Collection):  # type: ignore[type-arg]
+    """Registers a stack of unique callables to be called when a signal is
+    received before calling the original signal handler.
+    """
+
+    def __init__(
+        self,
+        signalnum: int,
+        callbacks: t.Optional[t.Iterable[_TSignalHandlerFn]] = None,
+    ) -> None:
+        """Set up a ``SignalInterceptionStack`` for a particular signal number.
+
+        .. note::
+            This class typically should not be instantiated directly as it will
+            change the registered signal handler regardless of whether a signal
+            interception stack is already present. Instead, it is generally
+            best to create or get a signal interception stack for a particular
+            signal number via the `get` factory method.
+
+        :param signalnum: The signal number to intercept
+        :type signalnum: int
+        :param callbacks: An iterable of functions to call upon receiving the signal
+        :type callbacks: t.Iterable[_TSignalHandlerFn] | None
+        """
+        self._callbacks = list(callbacks) if callbacks else []
+        self._original = signal.signal(signalnum, self)
+
+    def __call__(self, signalnum: int, frame: t.Optional["FrameType"]) -> None:
+        """Handle the signal on which the interception stack was registered.
+        End by calling the originally registered signal handler (if present).
+
+        :param frame: The current stack frame
+        :type frame: FrameType | None
+        """
+        for fn in self:
+            fn(signalnum, frame)
+        if callable(self._original):
+            self._original(signalnum, frame)
+
+    def __contains__(self, obj: object) -> bool:
+        return obj in self._callbacks
+
+    def __iter__(self) -> t.Iterator[_TSignalHandlerFn]:
+        return reversed(self._callbacks)
+
+    def __len__(self) -> int:
+        return len(self._callbacks)
+
+    @classmethod
+    def get(cls, signalnum: int) -> "SignalInterceptionStack":
+        """Fetch an existing ``SignalInterceptionStack`` or create a new one
+        for a particular signal number.
+
+        :param signalnum: The signal number for which the signal interception
+            stack should be registered
+        :type signalnum: int
+        :returns: The existing or created signal interception stack
+        :rtype: SignalInterceptionStack
+        """
+        handler = signal.getsignal(signalnum)
+        if isinstance(handler, cls):
+            return handler
+        return cls(signalnum, [])
+
+    def push(self, fn: _TSignalHandlerFn) -> None:
+        """Add a callback to the signal interception stack.
+
+        :param fn: A callable to add to the unique signal stack
+        :type fn: _TSignalHandlerFn
+        """
+        self._callbacks.append(fn)
+
+    def push_unique(self, fn: _TSignalHandlerFn) -> bool:
+        """Add a callback to the signal interception stack if and only if the
+        callback is not already present.
+
+        :param fn: A callable to add to the unique signal stack
+        :type fn: _TSignalHandlerFn
+        :returns: True if the callback was added, False if the callback was
+            already present
+        :rtype: bool
+        """
+        if did_push := fn not in self:
+            self.push(fn)
+        return did_push
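To make the intended use of SignalInterceptionStack concrete, here is a small usage sketch mirroring what Controller.start now does (illustrative only, not part of the patch; the _cleanup callback is a hypothetical stand-in for JobManager.signal_interrupt):

import signal

from smartsim._core.utils.helpers import SignalInterceptionStack


def _cleanup(signum, frame):
    # hypothetical stand-in for JobManager.signal_interrupt
    print("stopping launched jobs")


# Chain the cleanup in front of whatever handler was already registered for
# SIGINT; push_unique() is a no-op if the same callback was pushed before.
SignalInterceptionStack.get(signal.SIGINT).push_unique(_cleanup)

# On ^C the stacked callbacks run in LIFO order, then the originally
# registered handler (e.g. the default KeyboardInterrupt behavior) runs last.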
diff --git a/tests/install/test_builder.py b/tests/install/test_builder.py
index c69a083d1..feaf7e54f 100644
--- a/tests/install/test_builder.py
+++ b/tests/install/test_builder.py
@@ -27,8 +27,7 @@
 
 import functools
 import pathlib
-import platform
-import threading
+import textwrap
 import time
 
 import pytest
@@ -254,13 +253,13 @@ def test_PTArchiveMacOSX_url():
     pt_version = RAI_VERSIONS.torch
 
     pt_linux_cpu = build._PTArchiveLinux(
-        build.Architecture.X64, build.Device.CPU, pt_version
+        build.Architecture.X64, build.Device.CPU, pt_version, False
    )
     x64_prefix = "https://download.pytorch.org/libtorch/"
     assert x64_prefix in pt_linux_cpu.url
 
     pt_macosx_cpu = build._PTArchiveMacOSX(
-        build.Architecture.ARM64, build.Device.CPU, pt_version
+        build.Architecture.ARM64, build.Device.CPU, pt_version, False
     )
     arm64_prefix = "https://github.com/CrayLabs/ml_lib_builder/releases/download/"
     assert arm64_prefix in pt_macosx_cpu.url
@@ -269,7 +268,7 @@ def test_PTArchiveMacOSX_gpu_error():
     with pytest.raises(build.BuildError, match="support GPU on Mac OSX"):
         build._PTArchiveMacOSX(
-            build.Architecture.ARM64, build.Device.GPU, RAI_VERSIONS.torch
+            build.Architecture.ARM64, build.Device.GPU, RAI_VERSIONS.torch, False
         ).url
 
 
@@ -370,3 +369,36 @@ def test_valid_platforms():
 )
 def test_git_commands_are_configered_correctly_for_platforms(plat, cmd, expected_cmd):
     assert build.config_git_command(plat, cmd) == expected_cmd
+
+
+def test_modify_source_files(p_test_dir):
+    def make_text_blurb(food):
+        return textwrap.dedent(f"""\
+            My favorite food is {food}
+            {food} is an important part of a healthy breakfast
+            {food} {food} {food} {food}
+            This line should be unchanged!
+            --> {food} <--
+            """)
+
+    original_word = "SPAM"
+    mutated_word = "EGGS"
+
+    source_files = []
+    for i in range(3):
+        source_file = p_test_dir / f"test_{i}"
+        source_file.touch()
+        source_file.write_text(make_text_blurb(original_word))
+        source_files.append(source_file)
+    # Modify a single file
+    build._modify_source_files(source_files[0], original_word, mutated_word)
+    assert source_files[0].read_text() == make_text_blurb(mutated_word)
+    assert source_files[1].read_text() == make_text_blurb(original_word)
+    assert source_files[2].read_text() == make_text_blurb(original_word)
+
+    # Modify multiple files
+    build._modify_source_files(
+        (source_files[1], source_files[2]), original_word, mutated_word
+    )
+    assert source_files[1].read_text() == make_text_blurb(mutated_word)
+    assert source_files[2].read_text() == make_text_blurb(mutated_word)
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 025f53d32..523ed7191 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -24,6 +24,9 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import collections
+import signal
+
 import pytest
 
 from smartsim._core.utils import helpers
@@ -68,3 +71,115 @@ def test_encode_raises_on_empty():
 def test_decode_raises_on_empty():
     with pytest.raises(ValueError):
         helpers.decode_cmd("")
+
+
+class MockSignal:
+    def __init__(self):
+        self.signal_handlers = collections.defaultdict(lambda: signal.SIG_IGN)
+
+    def signal(self, signalnum, handler):
+        orig = self.getsignal(signalnum)
+        self.signal_handlers[signalnum] = handler
+        return orig
+
+    def getsignal(self, signalnum):
+        return self.signal_handlers[signalnum]
+
+
+@pytest.fixture
+def mock_signal(monkeypatch):
+    mock_signal = MockSignal()
+    monkeypatch.setattr(helpers, "signal", mock_signal)
+    yield mock_signal
+
+
+def test_signal_intercept_stack_will_register_itself_with_callback_fn(mock_signal):
+    callback = lambda num, frame: ...
+    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
+    stack.push(callback)
+    assert isinstance(stack, helpers.SignalInterceptionStack)
+    assert stack is mock_signal.signal_handlers[signal.NSIG]
+    assert len(stack) == 1
+    assert list(stack)[0] == callback
+
+
+def test_signal_intercept_stack_keeps_track_of_previous_handlers(mock_signal):
+    default_handler = lambda num, frame: ...
+    mock_signal.signal_handlers[signal.NSIG] = default_handler
+    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
+    stack.push(lambda n, f: ...)
+    assert stack._original is default_handler
+
+
+def test_signal_intercept_stacks_are_registered_per_signal_number(mock_signal):
+    handler = lambda num, frame: ...
+    stack_1 = helpers.SignalInterceptionStack.get(signal.NSIG)
+    stack_1.push(handler)
+    stack_2 = helpers.SignalInterceptionStack.get(signal.NSIG + 1)
+    stack_2.push(handler)
+
+    assert mock_signal.signal_handlers[signal.NSIG] is stack_1
+    assert mock_signal.signal_handlers[signal.NSIG + 1] is stack_2
+    assert stack_1 is not stack_2
+    assert list(stack_1) == list(stack_2) == [handler]
+
+
+def test_signal_intercept_handlers_will_not_overwrite_if_handler_already_exists(
+    mock_signal,
+):
+    handler_1 = lambda num, frame: ...
+    handler_2 = lambda num, frame: ...
+    stack_1 = helpers.SignalInterceptionStack.get(signal.NSIG)
+    stack_1.push(handler_1)
+    stack_2 = helpers.SignalInterceptionStack.get(signal.NSIG)
+    stack_2.push(handler_2)
+    assert stack_1 is stack_2 is mock_signal.signal_handlers[signal.NSIG]
+    assert list(stack_1) == [handler_2, handler_1]
+
+
+def test_signal_intercept_stack_can_add_multiple_instances_of_the_same_handler(
+    mock_signal,
+):
+    handler = lambda num, frame: ...
+    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
+    stack.push(handler)
+    stack.push(handler)
+    assert list(stack) == [handler, handler]
+
+
+def test_signal_intercept_stack_enforces_that_unique_push_handlers_are_unique(
+    mock_signal,
+):
+    handler = lambda num, frame: ...
+    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
+    assert stack.push_unique(handler)
+    assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(handler)
+    assert list(stack) == [handler]
+
+
+def test_signal_intercept_stack_enforces_that_unique_push_method_handlers_are_unique(
+    mock_signal,
+):
+    class C:
+        def fn(self, num, frame): ...
+
+    c1 = C()
+    c2 = C()
+    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
+    stack.push_unique(c1.fn)
+    assert helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c2.fn)
+    assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c1.fn)
+    assert list(stack) == [c2.fn, c1.fn]
+
+
+def test_signal_handler_calls_functions_in_reverse_order(mock_signal):
+    called_list = []
+    default = lambda num, frame: called_list.append("default")
+    handler_1 = lambda num, frame: called_list.append("handler_1")
+    handler_2 = lambda num, frame: called_list.append("handler_2")
+
+    mock_signal.signal_handlers[signal.NSIG] = default
+    helpers.SignalInterceptionStack.get(signal.NSIG).push(handler_1)
+    helpers.SignalInterceptionStack.get(signal.NSIG).push(handler_2)
+    mock_signal.signal_handlers[signal.NSIG](signal.NSIG, None)
+    assert called_list == ["handler_2", "handler_1", "default"]
diff --git a/tests/test_interrupt.py b/tests/test_interrupt.py
index 28c48e0db..61dc5b8c0 100644
--- a/tests/test_interrupt.py
+++ b/tests/test_interrupt.py
@@ -65,20 +65,21 @@ def test_interrupt_blocked_jobs(test_dir):
     )
     ensemble.set_path(test_dir)
     num_jobs = 1 + len(ensemble)
-    try:
-        pid = os.getpid()
-        keyboard_interrupt_thread = Thread(
-            name="sigint_thread", target=keyboard_interrupt, args=(pid,)
-        )
-        keyboard_interrupt_thread.start()
+    pid = os.getpid()
+    keyboard_interrupt_thread = Thread(
+        name="sigint_thread", target=keyboard_interrupt, args=(pid,)
+    )
+    keyboard_interrupt_thread.start()
+
+    with pytest.raises(KeyboardInterrupt):
         exp.start(model, ensemble, block=True, kill_on_interrupt=True)
-    except KeyboardInterrupt:
-        time.sleep(2)  # allow time for jobs to be stopped
-        active_jobs = exp._control._jobs.jobs
-        active_db_jobs = exp._control._jobs.db_jobs
-        completed_jobs = exp._control._jobs.completed
-        assert len(active_jobs) + len(active_db_jobs) == 0
-        assert len(completed_jobs) == num_jobs
+
+    time.sleep(2)  # allow time for jobs to be stopped
+    active_jobs = exp._control._jobs.jobs
+    active_db_jobs = exp._control._jobs.db_jobs
+    completed_jobs = exp._control._jobs.completed
+    assert len(active_jobs) + len(active_db_jobs) == 0
+    assert len(completed_jobs) == num_jobs
 
 
 def test_interrupt_multi_experiment_unblocked_jobs(test_dir):
@@ -106,20 +107,22 @@
     )
     ensemble.set_path(test_dir)
     jobs_per_experiment[i] = 1 + len(ensemble)
-    try:
-        pid = os.getpid()
-        keyboard_interrupt_thread = Thread(
-            name="sigint_thread", target=keyboard_interrupt, args=(pid,)
-        )
-        keyboard_interrupt_thread.start()
+
+    pid = os.getpid()
+    keyboard_interrupt_thread = Thread(
+        name="sigint_thread", target=keyboard_interrupt, args=(pid,)
+    )
+    keyboard_interrupt_thread.start()
+
+    with pytest.raises(KeyboardInterrupt):
         for experiment in experiments:
             experiment.start(model, ensemble, block=False, kill_on_interrupt=True)
-        time.sleep(9)  # since jobs aren't blocked, wait for SIGINT
-    except KeyboardInterrupt:
-        time.sleep(2)  # allow time for jobs to be stopped
-        for i, experiment in enumerate(experiments):
-            active_jobs = experiment._control._jobs.jobs
-            active_db_jobs = experiment._control._jobs.db_jobs
-            completed_jobs = experiment._control._jobs.completed
-            assert len(active_jobs) + len(active_db_jobs) == 0
-            assert len(completed_jobs) == jobs_per_experiment[i]
+        keyboard_interrupt_thread.join()  # since jobs aren't blocked, wait for SIGINT
+
+    time.sleep(2)  # allow time for jobs to be stopped
+    for i, experiment in enumerate(experiments):
+        active_jobs = experiment._control._jobs.jobs
+        active_db_jobs = experiment._control._jobs.db_jobs
+        completed_jobs = experiment._control._jobs.completed
+        assert len(active_jobs) + len(active_db_jobs) == 0
+        assert len(completed_jobs) == jobs_per_experiment[i]