diff --git a/README.rst b/README.rst
index e763b70..5ea757d 100644
--- a/README.rst
+++ b/README.rst
@@ -197,6 +197,14 @@ MPIRE has been benchmarked on three different benchmarks: numerical computation,
 initialization. More details on these benchmarks can be found in this `blog post`_. All code for these benchmarks can
 be found in this project_.
 
+In short, the main reasons why MPIRE is faster are:
+
+- When ``fork`` is available we can make use of copy-on-write shared objects, which reduces the need to copy objects
+  that need to be shared across child processes
+- Workers can hold state across multiple tasks, so you can choose to load a big file or send resources over only once
+  per worker
+- Automatic task chunking
+
 The following graph shows the average normalized results of all three benchmarks. Results for individual benchmarks
 can be found in the `blog post`_. The benchmarks were run on a Linux machine with 20 cores, with disabled
 hyperthreading and 200GB of RAM. For each task, experiments were run with different numbers of processes/workers and results were
diff --git a/docs/changelog.rst b/docs/changelog.rst
index f840e85..80befeb 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -1,6 +1,16 @@
 Changelog
 =========
 
+Master
+------
+
+* MPIRE now handles defunct child processes properly, instead of deadlocking (`#34`_)
+* Added benchmark highlights to the README (`#38`_)
+
+.. _#34: https://github.com/Slimmer-AI/mpire/issues/34
+.. _#38: https://github.com/Slimmer-AI/mpire/issues/38
+
+
 2.3.4
 -----
diff --git a/mpire/comms.py b/mpire/comms.py
index 3d2056f..b737927 100644
--- a/mpire/comms.py
+++ b/mpire/comms.py
@@ -58,8 +58,9 @@ def __init__(self, ctx: mp.context.BaseContext, n_jobs: int) -> None:
         # Array where the child processes can request a restart
         self._worker_done_array = None
 
-        # List of Event objects to indicate whether workers are alive
+        # List of Event objects to indicate whether workers are alive, together with accompanying locks
         self._workers_dead = None
+        self._workers_dead_locks = None
 
         # Queue where the child processes can pass on an encountered exception
         self._exception_queue = None
@@ -112,6 +113,7 @@ def init_comms(self, has_worker_exit: bool, has_progress_bar: bool) -> None:
         self._worker_done_array = self.ctx.Array('b', self.n_jobs, lock=False)
         self._workers_dead = [self.ctx.Event() for _ in range(self.n_jobs)]
         [worker_dead.set() for worker_dead in self._workers_dead]
+        self._workers_dead_locks = [self.ctx.Lock() for _ in range(self.n_jobs)]
 
         # Exception related
         self._exception_queue = self.ctx.JoinableQueue()
@@ -499,6 +501,15 @@ def reset_worker_restart(self, worker_id) -> None:
         """
         self._worker_done_array[worker_id] = False
 
+    def get_worker_dead_lock(self, worker_id: int) -> mp.Lock:
+        """
+        Returns the worker dead lock for a specific worker
+
+        :param worker_id: Worker ID
+        :return: Lock object
+        """
+        return self._workers_dead_locks[worker_id]
+
     def signal_worker_alive(self, worker_id: int) -> None:
         """
         Indicate that a worker is alive
diff --git a/mpire/pool.py b/mpire/pool.py
index d72604a..92c3487 100644
--- a/mpire/pool.py
+++ b/mpire/pool.py
@@ -148,12 +148,15 @@ def _start_workers(self, progress_bar: bool) -> None:
             self._workers.append(self._start_worker(worker_id))
         logger.debug("Workers created")
 
-    def _restart_workers(self) -> List[Any]:
+    def _check_worker_status(self) -> List[Any]:
         """
-        Restarts workers that need to be restarted.
+        Checks the worker status:
+        - If a worker is supposed to be alive but isn't, terminates the pool.
+        - Restarts workers that need to be restarted.
 
         :return: List of unordered results produced by workers
         """
+        # Check restarts
         obtained_results = []
         for worker_id in self._worker_comms.get_worker_restarts():
             # Obtain results from exit results queue (should be done before joining the worker)
@@ -178,6 +181,20 @@ def _restart_workers(self) -> List[Any]:
             # Start new worker
             self._workers[worker_id] = self._start_worker(worker_id)
 
+        # Check that workers that are supposed to be alive are actually alive. If not, a worker died unexpectedly.
+        # Note that a worker can be alive while its alive status is still False. This doesn't really matter, because
+        # we know the worker is alive according to the OS. The only way we know that something bad happened is when a
+        # worker is supposed to be alive, but according to the OS it isn't.
+        for worker_id in range(self.pool_params.n_jobs):
+            with self._worker_comms.get_worker_dead_lock(worker_id):
+                worker_died = self._worker_comms.is_worker_alive(worker_id) and not self._workers[worker_id].is_alive()
+            if worker_died:
+                # We need to add an exception if we're using the progress bar handler
+                if self._worker_comms.has_progress_bar():
+                    self._worker_comms.add_exception(RuntimeError, f"Worker-{worker_id} died unexpectedly")
+                self.terminate()
+                raise RuntimeError(f"Worker-{worker_id} died unexpectedly")
+
         return obtained_results
 
     def _start_worker(self, worker_id: int) -> mp.Process:
@@ -547,8 +564,8 @@ def imap_unordered(self, func: Callable, iterable_of_args: Union[Sized, Iterable
                 except queue.Empty:
                     pass
 
-                # Restart workers if necessary. This can yield intermediate results
-                for results in self._restart_workers():
+                # Check worker status (e.g., restarts). This can yield intermediate results
+                for results in self._check_worker_status():
                     yield from results
                     n_active -= 1
 
@@ -560,8 +577,8 @@ def imap_unordered(self, func: Callable, iterable_of_args: Union[Sized, Iterable
                 except queue.Empty:
                     pass
 
-                # Restart workers if necessary. This can yield intermediate results
-                for results in self._restart_workers():
+                # Check worker status (e.g., restarts). This can yield intermediate results
+                for results in self._check_worker_status():
                     yield from results
                     n_active -= 1
 
@@ -668,7 +685,7 @@ def stop_and_join(self, progress_bar_handler: Optional[ProgressBarHandler] = Non
                 t.join(timeout=0.01)
                 if not t.is_alive():
                     break
-                self._restart_workers()
+                self._check_worker_status()
         logger.debug("Done joining task queues")
 
         # When an exception occurred in the above process (i.e., the worker init function raises), we need to handle
diff --git a/mpire/worker.py b/mpire/worker.py
index aa5c44e..cdc4351 100644
--- a/mpire/worker.py
+++ b/mpire/worker.py
@@ -107,7 +107,7 @@ def _exit_gracefully(self, *_) -> None:
         self.is_running = False
         raise StopWorker
 
-    def _exit_gracefully_windows(self):
+    def _exit_gracefully_windows(self) -> None:
         """
         Windows doesn't fully support signals as Unix-based systems do. Therefore, we have to work around it. This
         function is started in a thread. We wait for a kill signal (Event object) and interrupt the main thread if we
@@ -134,7 +134,8 @@ def run(self) -> None:
             t = Thread(target=self._exit_gracefully_windows)
             t.start()
 
-        self.worker_comms.signal_worker_alive(self.worker_id)
+        with self.worker_comms.get_worker_dead_lock(self.worker_id):
+            self.worker_comms.signal_worker_alive(self.worker_id)
 
         # Set tqdm and dashboard connection details. This is needed for nested pools and in the case forkserver or
         # spawn is used as start method
@@ -235,7 +236,8 @@ def run(self) -> None:
                         self.worker_comms.signal_worker_restart(self.worker_id)
 
         finally:
-            self.worker_comms.signal_worker_dead(self.worker_id)
+            with self.worker_comms.get_worker_dead_lock(self.worker_id):
+                self.worker_comms.signal_worker_dead(self.worker_id)
 
     def _get_func(self, additional_args: List) -> Callable:
         """
@@ -327,7 +329,7 @@ def _run_safely(self, func: Callable, exception_args: Optional[Any] = None,
             # The main process tells us to stop working, shutting down
             raise
 
-        except Exception as err:
+        except (Exception, SystemExit) as err:
             # An exception occurred inside the provided function. Let the signal handler know it shouldn't raise any
             # StopWorker exceptions from the parent process anymore, we got this.
             with self.is_running_lock:
@@ -344,7 +346,7 @@ def _run_safely(self, func: Callable, exception_args: Optional[Any] = None,
             # Carry on
             return results, False
 
-    def _raise(self, args: Any, no_args: bool, err: Exception) -> None:
+    def _raise(self, args: Any, no_args: bool, err: Union[Exception, SystemExit]) -> None:
         """
         Create exception and pass it to the parent process. Let other processes know an exception is set
diff --git a/tests/test_comms.py b/tests/test_comms.py
index a633a61..ace45a4 100644
--- a/tests/test_comms.py
+++ b/tests/test_comms.py
@@ -54,6 +54,7 @@ def test_init_comms(self):
         self.assertListEqual(comms._exit_results_queues, [])
         self.assertIsNone(comms._all_exit_results_obtained)
         self.assertIsNone(comms._worker_done_array)
+        self.assertIsNone(comms._workers_dead_locks)
         self.assertIsNone(comms._workers_dead)
         self.assertIsNone(comms._exception_queue)
         self.assertIsInstance(comms.exception_lock, lock_type)
@@ -82,6 +83,9 @@ def test_init_comms(self):
         for worker_dead in comms._workers_dead:
             self.assertIsInstance(worker_dead, event_type)
             self.assertTrue(worker_dead.is_set())
+        self.assertEqual(len(comms._workers_dead_locks), n_jobs)
+        for worker_dead_lock in comms._workers_dead_locks:
+            self.assertIsInstance(worker_dead_lock, lock_type)
         self.assertIsInstance(comms._exception_queue, joinable_queue_type)
         self.assertFalse(comms._exception_thrown.is_set())
         self.assertFalse(comms._kill_signal_received.is_set())
@@ -142,6 +146,9 @@ def test_init_comms(self):
         for worker_dead in comms._workers_dead:
             self.assertIsInstance(worker_dead, event_type)
             self.assertTrue(worker_dead.is_set())
+        self.assertEqual(len(comms._workers_dead_locks), n_jobs)
+        for worker_dead_lock in comms._workers_dead_locks:
+            self.assertIsInstance(worker_dead_lock, lock_type)
         self.assertIsInstance(comms._exception_queue, joinable_queue_type)
         self.assertFalse(comms._exception_thrown.is_set())
         self.assertFalse(comms._kill_signal_received.is_set())
diff --git a/tests/test_pool.py b/tests/test_pool.py
index f36ba47..05c7c10 100644
--- a/tests/test_pool.py
+++ b/tests/test_pool.py
@@ -1,10 +1,12 @@
 import logging
 import os
+import signal
 import types
 import unittest
 import warnings
 from itertools import product, repeat
 from multiprocessing import Barrier, Value
+from threading import Thread
 from unittest.mock import patch
 
 import numpy as np
@@ -25,7 +27,7 @@ def square(idx, x):
     return idx, x * x
 
 
-def extremely_large_output(idx, x):
+def extremely_large_output(idx, _):
     return idx, os.urandom(1024 * 1024)
 
 
@@ -691,9 +693,9 @@ def test_start_methods(self):
                 pool.map(self._square_daemon, ((X,) for X in repeat(self.test_data, 3)), chunk_size=1)
 
     @staticmethod
-    def _square_daemon(X):
+    def _square_daemon(x):
         with WorkerPool(n_jobs=4) as pool:
-            return pool.map(square, X, chunk_size=1)
+            return pool.map(square, x, chunk_size=1)
 
 
 class CPUPinningTest(unittest.TestCase):
@@ -722,7 +724,7 @@ def test_cpu_pinning(self):
                          (4, [[0, 3]], [[0, 3], [0, 3], [0, 3], [0, 3]])]:
             # The test has been designed for a system with at least 4 cores. We'll skip those test cases where the CPU
             # IDs exceed the number of CPUs.
-            if cpu_ids is not None and np.array(cpu_ids).max() >= cpu_count():
+            if cpu_ids is not None and np.array(cpu_ids).max(initial=0) >= cpu_count():
                 continue
 
             with self.subTest(n_jobs=n_jobs, cpu_ids=cpu_ids), patch('mpire.pool.set_cpu_affinity') as p, \
@@ -1156,6 +1158,52 @@ def test_start_methods(self):
             with self.subTest(function='square_raises_on_idx', map='imap'), self.assertRaises(ValueError):
                 list(pool.imap_unordered(self._square_raises_on_idx, self.test_data, progress_bar=progress_bar))
 
+    def test_defunct_processes_exit(self):
+        """
+        Tests if MPIRE correctly shuts down after a process becomes defunct using exit()
+        """
+        print()
+        for n_jobs, progress_bar, worker_lifespan in [(1, False, None),
+                                                      (3, True, 1),
+                                                      (3, False, 3)]:
+            for start_method in TEST_START_METHODS:
+                # Progress bar on Windows + threading is not supported right now
+                if RUNNING_WINDOWS and start_method == 'threading' and progress_bar:
+                    continue
+                self.logger.debug(f"========== {start_method}, {n_jobs}, {progress_bar}, {worker_lifespan} ==========")
+                with self.subTest(n_jobs=n_jobs, progress_bar=progress_bar, worker_lifespan=worker_lifespan,
+                                  start_method=start_method), self.assertRaises(SystemExit), \
+                        WorkerPool(n_jobs=n_jobs, start_method=start_method) as pool:
+                    pool.map(self._exit, range(100), progress_bar=progress_bar, worker_lifespan=worker_lifespan)
+
+    def test_defunct_processes_kill(self):
+        """
+        Tests if MPIRE correctly shuts down after one process is killed and becomes defunct.
+
+        We kill worker 0. To make sure it's actually alive when that happens, worker 0 sets an event object and then
+        goes into an infinite loop; the kill thread waits until the event is set and then terminates the worker.
+        The other workers are also ensured to do some work, so we can test what happens during restarts
+        """
+        print()
+        for n_jobs, progress_bar, worker_lifespan in [(1, False, None),
+                                                      (3, True, 1),
+                                                      (3, False, 3)]:
+            for start_method in TEST_START_METHODS:
+                # Can't kill threads
+                if start_method == 'threading':
+                    continue
+
+                self.logger.debug(f"========== {start_method}, {n_jobs}, {progress_bar}, {worker_lifespan} ==========")
+                with self.subTest(n_jobs=n_jobs, progress_bar=progress_bar, worker_lifespan=worker_lifespan,
+                                  start_method=start_method), self.assertRaises(RuntimeError), \
+                        WorkerPool(n_jobs=n_jobs, pass_worker_id=True, start_method=start_method) as pool:
+                    events = [pool.ctx.Event() for _ in range(n_jobs)]
+                    kill_thread = Thread(target=self._kill_process, args=(events[0], pool))
+                    kill_thread.start()
+                    pool.set_shared_objects(events)
+                    pool.map(self._worker_0_sleeps_others_square, range(1000), progress_bar=progress_bar,
+                             worker_lifespan=worker_lifespan, chunk_size=1)
+
     @staticmethod
     def _square_raises(_, x):
         raise ValueError(x)
@@ -1166,3 +1214,29 @@ def _square_raises_on_idx(idx, x):
         raise ValueError(x)
     else:
         return idx, x * x
+
+    @staticmethod
+    def _exit(_):
+        exit()
+
+    @staticmethod
+    def _worker_0_sleeps_others_square(worker_id, events, x):
+        """
+        Worker 0 waits until the other workers have at least spun up, then sets its own event and busy-waits forever
+        """
+        if worker_id == 0:
+            [event.wait() for event in events[1:]]
+            events[0].set()
+            while True:
+                pass
+        else:
+            events[worker_id].set()
+            return x * x
+
+    @staticmethod
+    def _kill_process(event, pool):
+        """
+        Waits for the event to be set and then terminates worker 0
+        """
+        event.wait()
+        pool._workers[0].terminate()
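
For completeness, a minimal sketch (not part of the patch) of the behaviour the new tests exercise, assuming MPIRE
with this change is installed. ``WorkerPool`` and ``map`` are MPIRE's public API; ``call_exit`` is an illustrative
helper. A ``SystemExit`` raised inside a worker, e.g. through ``exit()``, now propagates to the parent process instead
of deadlocking the pool, and a worker that is killed outright similarly surfaces as a ``RuntimeError`` in the parent.

    from mpire import WorkerPool

    def call_exit(_):
        # exit() raises SystemExit inside the worker; _run_safely() now catches
        # it and passes it on to the parent process instead of letting the pool hang
        exit()

    if __name__ == '__main__':
        with WorkerPool(n_jobs=3) as pool:
            try:
                pool.map(call_exit, range(100))
            except SystemExit:
                print("Worker exit was propagated; the pool shut down cleanly")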