110 - Don't stop all tasks when experiencing unexpected worker deaths #134

Merged
4 changes: 4 additions & 0 deletions docs/changelog.rst
@@ -7,7 +7,11 @@ Unreleased
* Expanded error message in case of unexpected worker death (`#130`_)
* The progress bar will now show ``Keyboard interrupt`` when a keyboard interrupt is raised to distinguish it from
other exceptions
* When a worker dies unexpectedly (e.g., killed by the OS because it ran out of memory) while working on an
``apply`` task, the worker is now restarted and the other workers continue their work. Only the task that caused
the death is set to failed (`#110`_)

.. _#110: https://github.com/sybrenjansen/mpire/issues/110
.. _#130: https://github.com/sybrenjansen/mpire/issues/130

2.10.2
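To make the new changelog entry concrete: a minimal sketch of the user-visible behavior, assuming mpire's ``apply_async``/``AsyncResult`` API as exercised in the tests below (the ``work`` function and pool size are illustrative, not from the PR):

    from mpire import WorkerPool

    def work(x):
        # Illustrative task; imagine one invocation allocating enough memory
        # for the OS to kill the worker process running it
        return x * x

    with WorkerPool(n_jobs=4) as pool:
        results = [pool.apply_async(work, args=(x,)) for x in range(10)]
        [result.wait() for result in results]

        # As of this PR, an OOM-killed worker is restarted and only the task it
        # was running fails; the remaining apply tasks still complete
        for result in results:
            try:
                print(result.get())
            except RuntimeError as exc:
                print(f"Task failed: {exc}")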
52 changes: 37 additions & 15 deletions mpire/async_result.py
@@ -2,19 +2,34 @@
import itertools
import queue
import threading
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Union

from mpire.comms import EXIT_FUNC, INIT_FUNC
from mpire.comms import EXIT_FUNC, MAIN_PROCESS

job_counter = itertools.count()


class AsyncResult:
class JobType(Enum):
MAIN = auto()
INIT = auto()
MAP = auto()
EXIT = auto()
APPLY = auto()

""" Adapted from ``multiprocessing.pool.ApplyResult``. """

def __init__(self, cache: Dict, callback: Optional[Callable], error_callback: Optional[Callable],
job_id: Optional[int] = None, delete_from_cache: bool = True, timeout: Optional[float] = None) -> None:
class AsyncResult:
"""Adapted from ``multiprocessing.pool.ApplyResult``."""

def __init__(
self,
cache: Dict,
callback: Optional[Callable],
error_callback: Optional[Callable],
job_id: Optional[int] = None,
delete_from_cache: bool = True,
timeout: Optional[float] = None,
) -> None:
Comment on lines +24 to +32

Collaborator: You've started applying black?

Owner (author): Yeah, I'm applying it to some files here and there. Will do a full black conversion later.

Also planning on introducing the other pre-commit checks, toml file, etc. But that will take some time.

Collaborator: Well, I can't honestly say the changes it made in this PR improved it. I found it nicer to read before. I think it'll take some fancier automation to replace your sense of code style ;)

Owner (author): Haha, thnx :D Yeah, I'm also not in agreement with all the choices black makes, but it's so much easier to simply hit the format code shortcut in the editor and be done with it.
"""
:param cache: Cache for storing intermediate results
:param callback: Callback function to call when the task is finished. The callback function receives the output
@@ -32,6 +47,7 @@ def __init__(self, cache: Dict, callback: Optional[Callable], error_callback: Op
self._delete_from_cache = delete_from_cache
self._timeout = timeout

self.type = JobType.APPLY
self.job_id = next(job_counter) if job_id is None else job_id
self._ready_event = threading.Event()
self._success = None
@@ -103,11 +119,11 @@ def _set(self, success: bool, result: Any) -> None:


class UnorderedAsyncResultIterator:
"""Stores results of a task and provides an iterator to obtain the results in an unordered fashion"""

""" Stores results of a task and provides an iterator to obtain the results in an unordered fashion """

def __init__(self, cache: Dict, n_tasks: Optional[int], job_id: Optional[int] = None,
timeout: Optional[float] = None) -> None:
def __init__(
self, cache: Dict, n_tasks: Optional[int], job_id: Optional[int] = None, timeout: Optional[float] = None
) -> None:
"""
:param cache: Cache for storing intermediate results
:param n_tasks: Number of tasks that will be executed. If None, we don't know the length yet
@@ -119,6 +135,7 @@ def __init__(self, cache: Dict, n_tasks: Optional[int], job_id: Optional[int] =
self._n_tasks = None
self._timeout = timeout

self.type = JobType.MAP
self.job_id = next(job_counter) if job_id is None else job_id
self._items = collections.deque()
self._condition = threading.Condition(lock=threading.Lock())
@@ -202,8 +219,9 @@ def set_length(self, length: int) -> None:
"""
if self._n_tasks is not None:
if self._n_tasks != length:
raise ValueError(f"Length of iterator has already been set to {self._n_tasks}, "
f"but is now set to {length}")
raise ValueError(
f"Length of iterator has already been set to {self._n_tasks}, but is now set to {length}"
)
# Length has already been set. No need to do anything
return

@@ -228,8 +246,10 @@ def remove_from_cache(self) -> None:
class AsyncResultWithExceptionGetter(AsyncResult):

def __init__(self, cache: Dict, job_id: int) -> None:
super().__init__(cache, callback=None, error_callback=None, job_id=job_id, delete_from_cache=False,
timeout=None)
super().__init__(
cache, callback=None, error_callback=None, job_id=job_id, delete_from_cache=False, timeout=None
)
self.type = JobType.MAIN if job_id == MAIN_PROCESS else JobType.INIT

def get_exception(self) -> Exception:
"""
@@ -251,6 +271,7 @@ class UnorderedAsyncExitResultIterator(UnorderedAsyncResultIterator):

def __init__(self, cache: Dict) -> None:
super().__init__(cache, n_tasks=None, job_id=EXIT_FUNC, timeout=None)
self.type = JobType.EXIT

def get_results(self) -> List[Any]:
"""
@@ -270,5 +291,6 @@ def reset(self) -> None:
self._got_exception.clear()


AsyncResultType = Union[AsyncResult, AsyncResultWithExceptionGetter, UnorderedAsyncResultIterator,
UnorderedAsyncExitResultIterator]
AsyncResultType = Union[
AsyncResult, AsyncResultWithExceptionGetter, UnorderedAsyncResultIterator, UnorderedAsyncExitResultIterator
]
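The ``JobType`` tag added above is what the reworked death handler in ``pool.py`` (further down) branches on. A self-contained sketch of that dispatch pattern; ``FakeResult`` and ``on_worker_death`` are illustrative stand-ins, not mpire internals:

    from enum import Enum, auto

    class JobType(Enum):
        MAIN = auto()
        INIT = auto()
        MAP = auto()
        EXIT = auto()
        APPLY = auto()

    class FakeResult:
        """Stand-in for the result classes above: each instance carries a type tag."""
        def __init__(self, job_type):
            self.type = job_type

    def on_worker_death(job):
        # Apply jobs affect a single task, so the worker can simply be replaced;
        # a death during any other job type still shuts the whole pool down
        if job.type is JobType.APPLY:
            return "restart worker, fail this task only"
        return "shut down pool, fail all unfinished tasks"

    print(on_worker_death(FakeResult(JobType.APPLY)))  # restart worker, fail this task only
    print(on_worker_death(FakeResult(JobType.MAP)))    # shut down pool, fail all unfinished tasks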
11 changes: 10 additions & 1 deletion mpire/comms.py
@@ -121,7 +121,7 @@ def __init__(self, ctx: mp.context.BaseContext, n_jobs: int, order_tasks: bool)
self._worker_restart_array: Optional[mp.Array] = None
self._worker_restart_condition = self.ctx.Condition(self.ctx.Lock())

# List of Event objects to indicate whether workers are alive
# Array to indicate whether workers are alive, which is used to check whether a worker was terminated by the OS
self._workers_dead: Optional[mp.Array] = None

# Array where the child processes indicate when they started a task, worker_init, and worker_exit used for
@@ -198,6 +198,15 @@ def init_comms(self) -> None:

self.reset_progress()
self._initialized = True

def reinit_comms_for_worker(self, worker_id: int) -> None:
"""
Reinitialize the comms for a worker. This is used when a worker is restarted in case of an unexpected death.

For some reason, recreating the worker running task value makes sure the worker doesn't get stuck when it's
restarted in case of an unexpected death. For normal restarts, this is not necessary.
"""
self._worker_running_task[worker_id] = self.ctx.Value(ctypes.c_bool, False, lock=self.ctx.RLock())

def reset_progress(self) -> None:
"""
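``reinit_comms_for_worker`` recreates only the shared "worker running task" flag before the replacement worker starts. The docstring doesn't pin down the root cause; one plausible explanation is that a killed process can leave the lock guarding the old ``Value`` held. A minimal sketch of the recreate-on-restart pattern with plain ``multiprocessing`` (names and pool size are illustrative):

    import ctypes
    import multiprocessing as mp

    ctx = mp.get_context("spawn")
    n_jobs = 4

    # One shared boolean per worker, each guarded by its own RLock, mirroring
    # WorkerComms._worker_running_task
    worker_running_task = [ctx.Value(ctypes.c_bool, False, lock=ctx.RLock())
                           for _ in range(n_jobs)]

    def reinit_comms_for_worker(worker_id):
        # Replace the old Value so the restarted worker starts from a clean state
        # instead of one the dead worker may have left behind
        worker_running_task[worker_id] = ctx.Value(ctypes.c_bool, False, lock=ctx.RLock())

    reinit_comms_for_worker(0)
    print(worker_running_task[0].value)  # False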
32 changes: 21 additions & 11 deletions mpire/pool.py
@@ -13,7 +13,7 @@
np = None
NUMPY_INSTALLED = False

from mpire.async_result import (AsyncResult, AsyncResultType, AsyncResultWithExceptionGetter,
from mpire.async_result import (AsyncResult, AsyncResultType, AsyncResultWithExceptionGetter, JobType,
UnorderedAsyncExitResultIterator, UnorderedAsyncResultIterator)
from mpire.comms import EXIT_FUNC, INIT_FUNC, MAIN_PROCESS, POISON_PILL, WorkerComms
from mpire.context import DEFAULT_START_METHOD, RUNNING_WINDOWS
@@ -298,26 +298,36 @@ def _unexpected_death_handler(self) -> None:
# we just wait a bit and try again.
for worker_id in range(len(self._workers)):
try:
worker_died = (self._worker_comms.is_worker_alive(worker_id) and
not self._workers[worker_id].is_alive())
worker_died = (
self._worker_comms.is_worker_alive(worker_id) and not self._workers[worker_id].is_alive()
)
except ValueError:
worker_died = False

if worker_died:
self._worker_comms.signal_worker_dead(worker_id)

# Obtain task it was working on and set it to failed
job_id = self._worker_comms.get_worker_working_on_job(worker_id)
self._worker_comms.signal_exception_thrown(job_id)
job_type = self._cache[job_id].type
err = RuntimeError(
f"Worker-{worker_id} died unexpectedly. This usually means the OS/kernel killed the process "
"due to running out of memory"
)

# When a worker dies unexpectedly, the pool shuts down and we set all tasks that haven't completed
# yet to failed
job_ids = set(self._cache.keys()) - {MAIN_PROCESS}
for job_id in job_ids:
self._cache[job_id]._set(success=False, result=err)
return
self._cache[job_id]._set(success=False, result=err)

if job_type == JobType.APPLY:
# When a worker of an apply task dies unexpectedly we restart the worker and continue
self._worker_comms.reinit_comms_for_worker(worker_id)
self._start_worker(worker_id)
else:
# When a worker of a map task dies unexpectedly, the pool shuts down and we set all tasks that
# haven't completed yet to failed
self._worker_comms.signal_exception_thrown(job_id)
job_ids = set(self._cache.keys()) - {MAIN_PROCESS}
for job_id in job_ids:
self._cache[job_id]._set(success=False, result=err)
return

# Check this every once in a while
time.sleep(0.1)
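In summary, the handler's new control flow is: detect the death, fail the single task the worker was running, then either restart the worker (apply jobs) or tear the pool down (all other jobs). A condensed, runnable sketch of that loop; ``Worker``, ``restart_worker``, and ``fail_all_pending`` are illustrative stand-ins for mpire's internals:

    import time
    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Callable, Dict, Optional

    class JobType(Enum):
        MAP = auto()
        APPLY = auto()

    @dataclass
    class Job:
        type: JobType
        error: Optional[Exception] = None

    @dataclass
    class Worker:
        working_on: int               # job id the worker is currently running
        is_alive: Callable[[], bool]  # liveness probe for the worker process

    def death_watch(workers: Dict[int, Worker], cache: Dict[int, Job],
                    restart_worker: Callable[[int], None],
                    fail_all_pending: Callable[[Exception], None]) -> None:
        """Mirror of _unexpected_death_handler's branching, stripped of comms details."""
        while True:
            for worker_id, worker in list(workers.items()):
                if not worker.is_alive():
                    err = RuntimeError(f"Worker-{worker_id} died unexpectedly "
                                       "(e.g., killed by the OS when out of memory)")
                    job = cache[worker.working_on]
                    job.error = err  # only the task that caused the death fails
                    if job.type is JobType.APPLY:
                        # Pool keeps processing other tasks; restart_worker is
                        # assumed to replace workers[worker_id] with a live one
                        restart_worker(worker_id)
                    else:
                        fail_all_pending(err)  # map task: shut the whole pool down
                        return
            time.sleep(0.1)  # poll every once in a while, as in the real handler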
44 changes: 35 additions & 9 deletions tests/test_pool.py
@@ -1381,9 +1381,7 @@ def test_defunct_processes_exit(self):
Tests if MPIRE correctly shuts down after process becomes defunct using exit()
"""
print()
for n_jobs, progress_bar, worker_lifespan in [(1, False, None),
(3, True, 1),
(3, False, 3)]:
for n_jobs, progress_bar, worker_lifespan in [(1, False, None), (3, True, 1), (3, False, 3)]:
for start_method in TEST_START_METHODS:
# Progress bar on Windows + threading is not supported right now
if RUNNING_WINDOWS and start_method == 'threading' and progress_bar:
@@ -1394,18 +1392,16 @@ def test_defunct_processes_kill(self):
WorkerPool(n_jobs=n_jobs, start_method=start_method) as pool:
pool.map(self._exit, range(100), progress_bar=progress_bar, worker_lifespan=worker_lifespan)

def test_defunct_processes_kill(self):
def test_defunct_processes_kill_map(self):
"""
Tests if MPIRE correctly shuts down after one process becomes defunct using os.kill().
Tests if MPIRE correctly shuts down after one process becomes defunct using os.kill() in a map function.

We kill worker 0 and to be sure it's alive we set an event object and then go in an infinite loop. The kill
thread waits until the event is set and then kills the worker. The other workers are also ensured to have done
something so we can test what happens during restarts
"""
print()
for n_jobs, progress_bar, worker_lifespan in [(1, False, None),
(3, True, 1),
(3, False, 3)]:
for n_jobs, progress_bar, worker_lifespan in [(1, False, None), (3, True, 1), (3, False, 3)]:
for start_method in TEST_START_METHODS:
# Can't kill threads
if start_method == 'threading':
@@ -1421,6 +1417,36 @@ def test_defunct_processes_kill(self):
pool.set_shared_objects(events)
pool.map(self._worker_0_sleeps_others_square, range(100), progress_bar=progress_bar,
worker_lifespan=worker_lifespan, chunk_size=1)

def test_defunct_processes_kill_apply(self):
"""
Tests if MPIRE correctly continues after one process becomes defunct using os.kill() in an apply function.

We kill worker 0 and to be sure it's alive we set an event object and then go in an infinite loop. The kill
thread waits until the event is set and then kills the worker. The other workers are also ensured to have done
something so we can test what happens during restarts
"""
print()
for n_jobs in [1, 3]:
for start_method in TEST_START_METHODS:
# Can't kill threads
if start_method == 'threading':
continue

print(f"========== {start_method}, {n_jobs} ==========")
with self.subTest(n_jobs=n_jobs, start_method=start_method), \
WorkerPool(n_jobs=3, start_method=start_method, pass_worker_id=True) as pool:
events = [pool.ctx.Event() for _ in range(3)]
kill_thread = Thread(target=self._kill_process, args=(events[0], pool))
kill_thread.start()
pool.set_shared_objects(events)
futures = [
pool.apply_async(self._worker_0_sleeps_others_square, args=(x,))
for x in range(100)
]
[future.wait() for future in futures]
assert [future.successful() for future in futures] == [False] + [True] * 99
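The kill mechanics in this test: the victim worker sets an event to prove it is running, and a watcher thread then sends ``SIGKILL``, mimicking an OOM kill. A standalone, POSIX-only sketch of that pattern using plain ``multiprocessing`` rather than mpire:

    import os
    import signal
    import multiprocessing as mp
    from threading import Thread

    def victim(ready):
        ready.set()   # prove the process is up and running...
        while True:   # ...then hang until the OS kills it
            pass

    if __name__ == "__main__":
        ctx = mp.get_context("fork")
        ready = ctx.Event()
        proc = ctx.Process(target=victim, args=(ready,))
        proc.start()

        def kill_when_ready():
            ready.wait()
            os.kill(proc.pid, signal.SIGKILL)  # simulate the kernel's OOM killer

        Thread(target=kill_when_ready).start()
        proc.join()
        print(proc.exitcode)  # negative signal number: -9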


def test_dill_deadlock(self):
"""
@@ -1452,7 +1478,7 @@ def _worker_0_sleeps_others_square(worker_id, events, x):
"""
Worker 0 waits until the other workers have at least spun up and then sets her event and sleeps
"""
if worker_id == 0:
if worker_id == 0 and not events[0].is_set():
[event.wait() for event in events[1:]]
events[0].set()
while True: