From 7861b7608015dfc39c4104e790ebdbef5b5d848d Mon Sep 17 00:00:00 2001
From: Sybren Jansen
Date: Tue, 2 Jul 2024 11:03:11 +0200
Subject: [PATCH 1/2] 110

When a worker dies unexpectedly (e.g., due to an OOM error) while it was
working on an apply task, the worker will now be restarted and the other
workers will continue their work. The task that caused the death will be
set to failed
---
 docs/changelog.rst    |  4 ++++
 mpire/async_result.py | 52 ++++++++++++++++++++++++++++++-------------
 mpire/comms.py        | 11 ++++++++-
 mpire/pool.py         | 32 +++++++++++++++++---------
 tests/test_pool.py    | 44 ++++++++++++++++++++++++++++--------
 5 files changed, 107 insertions(+), 36 deletions(-)

diff --git a/docs/changelog.rst b/docs/changelog.rst
index 09879ce..7140689 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -7,7 +7,11 @@ Unreleased
 * Expanded error message in case of unexpected worker death (`#130`_)
 * The progress bar will now show ``Keyboard interrupt`` when a keyboard interrupt is raised to distinguish it from
   other exceptions
+* When a worker dies unexpectedly (e.g., due to an OOM error) while it was working on an ``apply`` task, the worker
+  will now be restarted and the other workers will continue their work. The task that caused the death will be set to
+  failed (`#110`_)

+.. _#110: https://github.com/sybrenjansen/mpire/issues/110
 .. _#130: https://github.com/sybrenjansen/mpire/issues/130

 2.10.2

diff --git a/mpire/async_result.py b/mpire/async_result.py
index be7b206..b8b9932 100644
--- a/mpire/async_result.py
+++ b/mpire/async_result.py
@@ -2,19 +2,34 @@
 import itertools
 import queue
 import threading
+from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Union

-from mpire.comms import EXIT_FUNC, INIT_FUNC
+from mpire.comms import EXIT_FUNC, MAIN_PROCESS

 job_counter = itertools.count()


-class AsyncResult:
+class JobType(Enum):
+    MAIN = 1
+    INIT = 2
+    MAP = 3
+    EXIT = 4
+    APPLY = 5

-    """ Adapted from ``multiprocessing.pool.ApplyResult``. """

-    def __init__(self, cache: Dict, callback: Optional[Callable], error_callback: Optional[Callable],
-                 job_id: Optional[int] = None, delete_from_cache: bool = True, timeout: Optional[float] = None) -> None:
+class AsyncResult:
+    """Adapted from ``multiprocessing.pool.ApplyResult``."""
+
+    def __init__(
+        self,
+        cache: Dict,
+        callback: Optional[Callable],
+        error_callback: Optional[Callable],
+        job_id: Optional[int] = None,
+        delete_from_cache: bool = True,
+        timeout: Optional[float] = None,
+    ) -> None:
         """
         :param cache: Cache for storing intermediate results
         :param callback: Callback function to call when the task is finished. The callback function receives the output
@@ -32,6 +47,7 @@ def __init__(self, cache: Dict, callback: Optional[Callable], error_callback: Op
         self._delete_from_cache = delete_from_cache
         self._timeout = timeout
+        self.type = JobType.APPLY
         self.job_id = next(job_counter) if job_id is None else job_id
         self._ready_event = threading.Event()
         self._success = None
@@ -103,11 +119,11 @@ def _set(self, success: bool, result: Any) -> None:

 class UnorderedAsyncResultIterator:
+    """Stores results of a task and provides an iterator to obtain the results in an unordered fashion"""

-    """ Stores results of a task and provides an iterator to obtain the results in an unordered fashion """
-
-    def __init__(self, cache: Dict, n_tasks: Optional[int], job_id: Optional[int] = None,
-                 timeout: Optional[float] = None) -> None:
+    def __init__(
+        self, cache: Dict, n_tasks: Optional[int], job_id: Optional[int] = None, timeout: Optional[float] = None
+    ) -> None:
         """
         :param cache: Cache for storing intermediate results
         :param n_tasks: Number of tasks that will be executed. If None, we don't know the length yet
@@ -119,6 +135,7 @@ def __init__(self, cache: Dict, n_tasks: Optional[int], job_id: Optional[int] =
         self._n_tasks = None
         self._timeout = timeout
+        self.type = JobType.MAP
         self.job_id = next(job_counter) if job_id is None else job_id
         self._items = collections.deque()
         self._condition = threading.Condition(lock=threading.Lock())
@@ -202,8 +219,9 @@ def set_length(self, length: int) -> None:
         """
         if self._n_tasks is not None:
             if self._n_tasks != length:
-                raise ValueError(f"Length of iterator has already been set to {self._n_tasks}, "
-                                 f"but is now set to {length}")
+                raise ValueError(
+                    f"Length of iterator has already been set to {self._n_tasks}, but is now set to {length}"
+                )
             # Length has already been set. No need to do anything
             return
@@ -228,8 +246,10 @@ def remove_from_cache(self) -> None:

 class AsyncResultWithExceptionGetter(AsyncResult):

     def __init__(self, cache: Dict, job_id: int) -> None:
-        super().__init__(cache, callback=None, error_callback=None, job_id=job_id, delete_from_cache=False,
-                         timeout=None)
+        super().__init__(
+            cache, callback=None, error_callback=None, job_id=job_id, delete_from_cache=False, timeout=None
+        )
+        self.type = JobType.MAIN if job_id == MAIN_PROCESS else JobType.INIT

     def get_exception(self) -> Exception:
         """
@@ -251,6 +271,7 @@ class UnorderedAsyncExitResultIterator(UnorderedAsyncResultIterator):

     def __init__(self, cache: Dict) -> None:
         super().__init__(cache, n_tasks=None, job_id=EXIT_FUNC, timeout=None)
+        self.type = JobType.EXIT

     def get_results(self) -> List[Any]:
         """
@@ -270,5 +291,6 @@ def reset(self) -> None:
         self._got_exception.clear()

-AsyncResultType = Union[AsyncResult, AsyncResultWithExceptionGetter, UnorderedAsyncResultIterator,
-                        UnorderedAsyncExitResultIterator]
+AsyncResultType = Union[
+    AsyncResult, AsyncResultWithExceptionGetter, UnorderedAsyncResultIterator, UnorderedAsyncExitResultIterator
+]

diff --git a/mpire/comms.py b/mpire/comms.py
index c0d351d..b82db9b 100644
--- a/mpire/comms.py
+++ b/mpire/comms.py
@@ -121,7 +121,7 @@ def __init__(self, ctx: mp.context.BaseContext, n_jobs: int, order_tasks: bool)
         self._worker_restart_array: Optional[mp.Array] = None
         self._worker_restart_condition = self.ctx.Condition(self.ctx.Lock())

-        # List of Event objects to indicate whether workers are alive
+        # Array to indicate whether workers are alive, which is used to check whether a worker was terminated by the OS
         self._workers_dead: Optional[mp.Array] = None

         # Array where the child processes indicate when they started a task, worker_init, and worker_exit used for
@@ -198,6 +198,15 @@ def init_comms(self) -> None:
         self.reset_progress()
         self._initialized = True

+    def reinit_comms_for_worker(self, worker_id: int) -> None:
+        """
+        Reinitialize the comms for a worker. This is used when a worker is restarted after an unexpected death.
+
+        For some reason, recreating the worker running task value ensures the worker doesn't get stuck when it's
+        restarted after such a death. For normal restarts, this is not necessary.
+        """
+        self._worker_running_task[worker_id] = self.ctx.Value(ctypes.c_bool, False, lock=self.ctx.RLock())

     def reset_progress(self) -> None:
         """

diff --git a/mpire/pool.py b/mpire/pool.py
index f11902e..46a7480 100644
--- a/mpire/pool.py
+++ b/mpire/pool.py
@@ -13,7 +13,7 @@
     np = None
     NUMPY_INSTALLED = False

-from mpire.async_result import (AsyncResult, AsyncResultType, AsyncResultWithExceptionGetter,
+from mpire.async_result import (AsyncResult, AsyncResultType, AsyncResultWithExceptionGetter, JobType,
                                 UnorderedAsyncExitResultIterator, UnorderedAsyncResultIterator)
 from mpire.comms import EXIT_FUNC, INIT_FUNC, MAIN_PROCESS, POISON_PILL, WorkerComms
 from mpire.context import DEFAULT_START_METHOD, RUNNING_WINDOWS
@@ -298,26 +298,36 @@ def _unexpected_death_handler(self) -> None:
             # we just wait a bit and try again.
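             # Check each worker: if the comms still consider the worker alive while the process object reports it
             # is not, the worker was terminated by the OS (e.g., the OOM killer) rather than by MPIRE itself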
             for worker_id in range(len(self._workers)):
                 try:
-                    worker_died = (self._worker_comms.is_worker_alive(worker_id) and
-                                   not self._workers[worker_id].is_alive())
+                    worker_died = (
+                        self._worker_comms.is_worker_alive(worker_id) and not self._workers[worker_id].is_alive()
+                    )
                 except ValueError:
                     worker_died = False

                 if worker_died:
+                    self._worker_comms.signal_worker_dead(worker_id)
+
+                    # Obtain the task it was working on and set it to failed
                     job_id = self._worker_comms.get_worker_working_on_job(worker_id)
-                    self._worker_comms.signal_exception_thrown(job_id)
+                    job_type = self._cache[job_id].type
                     err = RuntimeError(
                         f"Worker-{worker_id} died unexpectedly. This usually means the OS/kernel killed the process "
                         "due to running out of memory"
                     )
-
-                    # When a worker dies unexpectedly, the pool shuts down and we set all tasks that haven't completed
-                    # yet to failed
-                    job_ids = set(self._cache.keys()) - {MAIN_PROCESS}
-                    for job_id in job_ids:
-                        self._cache[job_id]._set(success=False, result=err)
-                    return
+                    self._cache[job_id]._set(success=False, result=err)
+
+                    if job_type == JobType.APPLY:
+                        # When a worker of an apply task dies unexpectedly, we restart the worker and continue
+                        self._worker_comms.reinit_comms_for_worker(worker_id)
+                        self._start_worker(worker_id)
+                    else:
+                        # When a worker of a map task dies unexpectedly, the pool shuts down and we set all tasks
+                        # that haven't completed yet to failed
+                        self._worker_comms.signal_exception_thrown(job_id)
+                        job_ids = set(self._cache.keys()) - {MAIN_PROCESS}
+                        for job_id in job_ids:
+                            self._cache[job_id]._set(success=False, result=err)
+                        return

             # Check this every once in a while
             time.sleep(0.1)

diff --git a/tests/test_pool.py b/tests/test_pool.py
index 3c1921a..83e744b 100644
--- a/tests/test_pool.py
+++ b/tests/test_pool.py
@@ -1381,9 +1381,7 @@ def test_defunct_processes_exit(self):
         Tests if MPIRE correctly shuts down after a process becomes defunct using exit()
         """
         print()
-        for n_jobs, progress_bar, worker_lifespan in [(1, False, None),
-                                                      (3, True, 1),
-                                                      (3, False, 3)]:
+        for n_jobs, progress_bar, worker_lifespan in [(1, False, None), (3, True, 1), (3, False, 3)]:
             for start_method in TEST_START_METHODS:
                 # Progress bar on Windows + threading is not supported right now
                 if RUNNING_WINDOWS and start_method == 'threading' and progress_bar:
@@ -1394,18 +1392,16 @@ def test_defunct_processes_exit(self):
                     WorkerPool(n_jobs=n_jobs, start_method=start_method) as pool:
                 pool.map(self._exit, range(100), progress_bar=progress_bar, worker_lifespan=worker_lifespan)

-    def test_defunct_processes_kill(self):
+    def test_defunct_processes_kill_map(self):
         """
-        Tests if MPIRE correctly shuts down after one process becomes defunct using os.kill().
+        Tests if MPIRE correctly shuts down after one process becomes defunct using os.kill() in a map function.

-        We kill worker 0 and to be sure it's alive we set an event object and then go in an infinite loop. The kill
-        thread waits until the event is set and then kills the worker.
+        We kill worker 0. To make sure it's alive, worker 0 sets an event object and then goes into an infinite loop.
+        The kill thread waits until the event is set and then kills the worker. We also make sure the other workers
+        have done some work, so we can test what happens during restarts
         """
         print()
-        for n_jobs, progress_bar, worker_lifespan in [(1, False, None),
-                                                      (3, True, 1),
-                                                      (3, False, 3)]:
+        for n_jobs, progress_bar, worker_lifespan in [(1, False, None), (3, True, 1), (3, False, 3)]:
             for start_method in TEST_START_METHODS:
                 # Can't kill threads
                 if start_method == 'threading':
                     continue
@@ -1421,6 +1417,36 @@ def test_defunct_processes_kill(self):
                     pool.set_shared_objects(events)
                     pool.map(self._worker_0_sleeps_others_square, range(100), progress_bar=progress_bar,
                              worker_lifespan=worker_lifespan, chunk_size=1)
+
+    def test_defunct_processes_kill_apply(self):
+        """
+        Tests if MPIRE correctly continues after one process becomes defunct using os.kill() in an apply function.
+
+        We kill worker 0. To make sure it's alive, worker 0 sets an event object and then goes into an infinite loop.
+        The kill thread waits until the event is set and then kills the worker. We also make sure the other workers
+        have done some work, so we can test what happens during restarts
+        """
+        print()
+        for n_jobs in [1, 3]:
+            for start_method in TEST_START_METHODS:
+                # Can't kill threads
+                if start_method == 'threading':
+                    continue
+
+                print(f"========== {start_method}, {n_jobs} ==========")
+                with self.subTest(n_jobs=n_jobs, start_method=start_method), \
+                        WorkerPool(n_jobs=n_jobs, start_method=start_method, pass_worker_id=True) as pool:
+                    events = [pool.ctx.Event() for _ in range(n_jobs)]
+                    kill_thread = Thread(target=self._kill_process, args=(events[0], pool))
+                    kill_thread.start()
+                    pool.set_shared_objects(events)
+                    futures = [
+                        pool.apply_async(self._worker_0_sleeps_others_square, args=(x,))
+                        for x in range(100)
+                    ]
+                    [future.wait() for future in futures]
+                    assert [future.successful() for future in futures] == [False] + [True] * 99

     def test_dill_deadlock(self):
         """
@@ -1452,7 +1478,7 @@ def _worker_0_sleeps_others_square(worker_id, events, x):
         """
         Worker 0 waits until the other workers have at least spun up and then sets her event and sleeps
         """
-        if worker_id == 0:
+        if worker_id == 0 and not events[0].is_set():
             [event.wait() for event in events[1:]]
             events[0].set()
             while True:

From 353f3f043a133f0bdda58c8690c73a95d6bed780 Mon Sep 17 00:00:00 2001
From: Sybren Jansen
Date: Tue, 2 Jul 2024 16:18:39 +0200
Subject: [PATCH 2/2] Replace hard-coded enum values with auto()

---
 mpire/async_result.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mpire/async_result.py b/mpire/async_result.py
index b8b9932..d7dad88 100644
--- a/mpire/async_result.py
+++ b/mpire/async_result.py
@@ -2,7 +2,7 @@
 import itertools
 import queue
 import threading
-from enum import Enum
+from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional, Union

 from mpire.comms import EXIT_FUNC, MAIN_PROCESS
@@ -11,11 +11,11 @@

 class JobType(Enum):
-    MAIN = 1
-    INIT = 2
-    MAP = 3
-    EXIT = 4
-    APPLY = 5
+    MAIN = auto()
+    INIT = auto()
+    MAP = auto()
+    EXIT = auto()
+    APPLY = auto()


 class AsyncResult:
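
For illustration, a minimal sketch of how the changed behavior could be observed from user code. It only uses APIs
that the patch itself exercises (WorkerPool, apply_async, wait, successful); the allocate_forever and square tasks
are hypothetical stand-ins for real workloads:

    import time

    from mpire import WorkerPool


    def allocate_forever(x):
        # Hypothetical workload: keeps allocating until the OS OOM killer terminates
        # the worker process; no Python exception is raised, the process just dies
        data = []
        while True:
            data.append(bytearray(10 ** 8))
            time.sleep(0.1)


    def square(x):
        return x * x


    with WorkerPool(n_jobs=2) as pool:
        doomed = pool.apply_async(allocate_forever, args=(0,))
        fine = [pool.apply_async(square, args=(i,)) for i in range(10)]

        # With this patch, the killed worker is restarted and only its own task is
        # marked as failed; previously, the whole pool shut down and every
        # unfinished task was set to failed
        [result.wait() for result in fine]
        assert all(result.successful() for result in fine)

        doomed.wait()
        assert not doomed.successful()  # the task that caused the death is set to failed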