diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py
index 258f4e88a5a..b54ee59f543 100644
--- a/distributed/comm/tcp.py
+++ b/distributed/comm/tcp.py
@@ -487,7 +487,7 @@ async def _handle_stream(self, stream, address):
         try:
             await self.on_connection(comm)
         except CommClosedError:
-            logger.info("Connection closed before handshake completed")
+            logger.info("Connection from %s closed before handshake completed", address)
             return
 
         await self.comm_handler(comm)
diff --git a/distributed/core.py b/distributed/core.py
index 1c95f6bd019..843fc42dcfc 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -14,7 +14,6 @@
 
 import tblib
 from tlz import merge
-from tornado import gen
 from tornado.ioloop import IOLoop, PeriodicCallback
 
 import dask
@@ -37,6 +36,7 @@
     has_keyword,
     is_coroutine_function,
     parse_timedelta,
+    shielded,
     truncate_exception,
 )
 
@@ -160,6 +160,7 @@ def __init__(
         self.counters = None
         self.digests = None
         self._ongoing_coroutines = weakref.WeakSet()
+        self._event_finished = asyncio.Event()
 
         self.listeners = []
 
@@ -263,10 +264,17 @@ async def finished(self):
 
     def __await__(self):
         async def _():
+            if self.status == Status.running:
+                return self
+
+            if self.status in (Status.closing, Status.closed):
+                # We should never await an object which is already closing,
+                # but we should not start it up again either; otherwise we'd
+                # produce zombies.
+                await self.finished()
+
             timeout = getattr(self, "death_timeout", 0)
             async with self._startup_lock:
-                if self.status == Status.running:
-                    return self
                 if timeout:
                     try:
                         await asyncio.wait_for(self.start(), timeout=timeout)
@@ -422,7 +430,7 @@ async def handle_comm(self, comm):
 
         logger.debug("Connection from %r to %s", address, type(self).__name__)
         self._comms[comm] = op
-        await self
+
         try:
             while True:
                 try:
@@ -565,17 +573,15 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
                            break
                        handler = self.stream_handlers[op]
                        if is_coroutine_function(handler):
-                            self.loop.add_callback(handler, **merge(extra, msg))
-                            await gen.sleep(0)
+                            await handler(**merge(extra, msg))
                        else:
                            handler(**merge(extra, msg))
                    else:
                        logger.error("odd message %s", msg)
-            await asyncio.sleep(0)
 
             for func in every_cycle:
                 if is_coroutine_function(func):
-                    self.loop.add_callback(func)
+                    await func()
                 else:
                     func()
 
@@ -593,32 +599,39 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
             await comm.close()
             assert comm.closed()
 
-    @gen.coroutine
-    def close(self):
+    @shielded
+    async def close(self):
+        self.status = Status.closing
+        self.stop()
+        await self.rpc.close()
+
         for pc in self.periodic_callbacks.values():
             pc.stop()
         self.__stopped = True
         for listener in self.listeners:
             future = listener.stop()
             if inspect.isawaitable(future):
-                yield future
-        for i in range(20):
+                await future
+        for _ in range(20):
             # If there are still handlers running at this point, give them a
             # second to finish gracefully themselves, otherwise...
             if any(self._comms.values()):
-                yield asyncio.sleep(0.05)
+                await asyncio.sleep(0.05)
             else:
                 break
-        yield [comm.close() for comm in list(self._comms)]  # then forcefully close
+        await asyncio.gather(
+            *[comm.close() for comm in list(self._comms)]
+        )  # then forcefully close
         for cb in self._ongoing_coroutines:
             cb.cancel()
-        for i in range(10):
+        for _ in range(10):
             if all(c.cancelled() for c in self._ongoing_coroutines):
                 break
             else:
-                yield asyncio.sleep(0.01)
+                await asyncio.sleep(0.01)
 
         self._event_finished.set()
+        self.status = Status.closed
 
 
 def pingpong(comm):
diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py
index 070d6d0624d..bc44ffb3423 100644
--- a/distributed/deploy/spec.py
+++ b/distributed/deploy/spec.py
@@ -338,15 +338,8 @@ async def _correct_state_internal(self):
         if to_close:
             if self.scheduler.status == Status.running:
                 await self.scheduler_comm.retire_workers(workers=list(to_close))
-            tasks = [
-                asyncio.create_task(self.workers[w].close())
-                for w in to_close
-                if w in self.workers
-            ]
+            tasks = [self.workers[w].close() for w in to_close if w in self.workers]
             await asyncio.wait(tasks)
-            for task in tasks:  # for tornado gen.coroutine support
-                with suppress(RuntimeError):
-                    await task
             for name in to_close:
                 if name in self.workers:
                     del self.workers[name]
diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py
index 30496a7d233..67911f06db3 100644
--- a/distributed/deploy/tests/test_adaptive.py
+++ b/distributed/deploy/tests/test_adaptive.py
@@ -446,11 +446,8 @@ async def test_scale_needs_to_be_awaited(cleanup):
 
     class RequiresAwaitCluster(LocalCluster):
         def scale(self, n):
-            # super invocation in the nested function scope is messy
-            method = super().scale
-
             async def _():
-                return method(n)
+                return LocalCluster.scale(self, n)
 
             return self.sync(_)
 
@@ -458,7 +455,7 @@ async def _():
         async with Client(cluster, asynchronous=True) as client:
             futures = client.map(slowinc, range(5), delay=0.05)
             assert len(cluster.workers) == 0
-            cluster.adapt()
+            cluster.adapt(interval="10ms")
 
             await client.gather(futures)
 
diff --git a/distributed/nanny.py b/distributed/nanny.py
index bf51a5c9699..a1779c5284d 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -33,6 +33,7 @@
     mp_context,
     parse_ports,
     parse_timedelta,
+    shielded,
     silence_logging,
 )
 from .worker import Worker, parse_memory_limit, run
@@ -194,15 +195,17 @@ def __init__(
             "instantiate": self.instantiate,
             "kill": self.kill,
             "restart": self.restart,
-            # cannot call it 'close' on the rpc side for naming conflict
             "get_logs": self.get_logs,
+            # cannot call it 'close' on the rpc side for naming conflict
             "terminate": self.close,
             "close_gracefully": self.close_gracefully,
             "run": self.run,
         }
 
         super().__init__(
-            handlers=handlers, io_loop=self.loop, connection_args=self.connection_args
+            handlers=handlers,
+            io_loop=self.loop,
+            connection_args=self.connection_args,
         )
 
         self.scheduler = self.rpc(self.scheduler_addr)
@@ -306,6 +309,7 @@ async def start(self):
 
         return self
 
+    @shielded
     async def kill(self, comm=None, timeout=2):
         """Kill the local worker process
 
@@ -318,6 +322,7 @@ async def kill(self, comm=None, timeout=2):
 
         deadline = self.loop.time() + timeout
         await self.process.kill(timeout=0.8 * (deadline - self.loop.time()))
+        self.status = Status.stopped
 
     async def instantiate(self, comm=None) -> Status:
         """Start a local worker process
@@ -489,15 +494,13 @@ def close_gracefully(self, comm=None):
         """
         self.status = Status.closing_gracefully
 
+    @shielded
     async def close(self, comm=None, timeout=5, report=None):
         """
         Close the worker process, stop all comms.
         """
-        if self.status == Status.closing:
+        if self.status in (Status.closing, Status.closed):
             await self.finished()
-            assert self.status == Status.closed
-
-        if self.status == Status.closed:
             return "OK"
 
         self.status = Status.closing
@@ -506,18 +509,15 @@ async def close(self, comm=None, timeout=5, report=None):
         for preload in self.preloads:
             await preload.teardown()
 
-        self.stop()
         try:
             if self.process is not None:
                 await self.kill(timeout=timeout)
         except Exception:
             pass
         self.process = None
-        await self.rpc.close()
-        self.status = Status.closed
         if comm:
             await comm.write("OK")
-        await ServerNode.close(self)
+        await super().close()
 
 
 class WorkerProcess:
@@ -653,6 +653,7 @@ def mark_stopped(self):
             if self.on_exit is not None:
                 self.on_exit(r)
 
+    @shielded
     async def kill(self, timeout=2, executor_wait=True):
         """
         Ensure the worker process is stopped, waiting at most
@@ -746,6 +747,7 @@ def _run(
         loop.make_current()
         worker = Worker(**worker_kwargs)
 
+        @shielded
         async def do_stop(timeout=5, executor_wait=True):
             try:
                 await worker.close(
@@ -795,7 +797,7 @@ async def run():
                # properly handled. See also
                # WorkerProcess._wait_until_connected (the 2 is for good
                # measure)
-                sync_sleep(cls._init_msg_interval * 2)
+                await asyncio.sleep(cls._init_msg_interval * 2)
            else:
                try:
                    assert worker.address
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index a2996482ac9..3a1975e8af6 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -78,6 +78,7 @@
     no_default,
     parse_bytes,
     parse_timedelta,
+    shielded,
     tmpfile,
     validate_key,
 )
@@ -3583,7 +3584,6 @@ def __init__(
             "start_task_metadata": self.start_task_metadata,
             "stop_task_metadata": self.stop_task_metadata,
         }
-
         connection_limit = get_fileno_limit() / 2
 
         super().__init__(
@@ -3752,6 +3752,7 @@ def del_scheduler_file():
         setproctitle("dask-scheduler [%s]" % (self.address,))
 
         return self
 
+    @shielded
     async def close(self, comm=None, fast=False, close_workers=False):
         """Send cleanup signal to all coroutines then wait until finished
 
@@ -3808,10 +3809,6 @@ async def close(self, comm=None, fast=False, close_workers=False):
         for comm in self.client_comms.values():
             comm.abort()
 
-        await self.rpc.close()
-
-        self.status = Status.closed
-        self.stop()
         await super().close()
 
         setproctitle("dask-scheduler [closed]")
@@ -3821,15 +3818,19 @@ async def close_worker(self, comm=None, worker=None, safe=None):
         """Remove a worker from the cluster
 
         This both removes the worker from our local state and also sends a
-        signal to the worker to shut down. This works regardless of whether or
-        not the worker has a nanny process restarting it
+        signal to the worker to shut down.
+        If the worker is managed by a Nanny, the Nanny is instructed to close
+        instead.
""" logger.info("Closing worker %s", worker) with log_errors(): self.log_event(worker, {"action": "close-worker"}) - # FIXME: This does not handly nannys - self.worker_send(worker, {"op": "close", "report": False}) - await self.remove_worker(address=worker, safe=safe) + ws: WorkerState = self._workers_dv[worker] + if ws._nanny: + await self.rpc(ws._nanny).terminate() + else: + self.worker_send(worker, {"op": "close", "report": False}) + await self.remove_worker(address=worker, safe=safe, close=False) ########### # Stimuli # @@ -6257,7 +6258,10 @@ async def retire_workers( ) if remove: await asyncio.gather( - *[self.remove_worker(address=w, safe=True) for w in worker_keys] + *[ + self.remove_worker(address=w, safe=True, close=False) + for w in worker_keys + ] ) self.log_event( diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 4ac628a599b..df30c4e6283 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -579,3 +579,38 @@ async def test_failure_during_worker_initialization(cleanup): async with Nanny(s.address, foo="bar") as n: await n assert "Restarting worker" not in logs.getvalue() + + +@gen_cluster(nthreads=[("127.0.0.1", 0)], Worker=Nanny) +async def test_nanny_terminate_handler(s, a): + a_rpc = rpc(a.address) + await a_rpc.terminate(reply=False) + + while a.status != Status.closed: + await asyncio.sleep(0.05) + + # already closed should be noop + await a.close() + + +@gen_cluster(nthreads=[("127.0.0.1", 0)], Worker=Nanny) +async def test_nanny_kill_handler(s, a): + a_rpc = rpc(a.address) + await a_rpc.kill(reply=False) + + while a.process.status != Status.stopped: + await asyncio.sleep(0.05) + + +@gen_cluster(nthreads=[("127.0.0.1", 0)], Worker=Nanny) +async def test_nanny_close_gracefully_handler(s, a): + a_rpc = rpc(a.address) + await a_rpc.close_gracefully() + assert a.status == Status.closing_gracefully + await a_rpc.terminate() + + while a.status != Status.closed: + await asyncio.sleep(0.05) + + # already closed should be noop + await a.close() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 04f266f8e61..d761c2bdef5 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2802,3 +2802,15 @@ async def test_transition_counter(c, s, a, b): assert s.transition_counter == 0 await c.submit(inc, 1) assert s.transition_counter > 1 + + +@gen_cluster(nthreads=[]) +async def test_scheduler_terminate(s): + s_rpc = rpc(s.address) + await s_rpc.terminate(reply=False) + + while s.status != Status.closed: + await asyncio.sleep(0.05) + + # already closed should be noop + await s.close() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2f3a7f58ede..9c37651c709 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -39,6 +39,7 @@ from distributed.utils import TimeoutError, tmpfile from distributed.utils_test import ( TaskStateMetadataPlugin, + async_wait_for, captured_logger, dec, div, @@ -2334,3 +2335,76 @@ def raise_exc(*args): for server in [s, a, b]: while server.tasks: await asyncio.sleep(0.01) + + +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) +async def test_exception_in_handler(c, s, a): + """This test is supposed to ensure that regardless of whether a handler is + sync or async, the behaviour is the same.""" + + def raise_unhandled_exception(kind=None): + if kind == "runtime": + # some other exception, maybe handled + raise RuntimeError() + elif kind == "os": + # 
+            raise OSError()
+        raise Exception()
+
+    async def async_raise_unhandled_exception(kind=None):
+        return raise_unhandled_exception(kind)
+
+    a.handlers["fail"] = raise_unhandled_exception
+    a.stream_handlers["fail"] = raise_unhandled_exception
+
+    a.handlers["fail_async"] = async_raise_unhandled_exception
+    a.stream_handlers["fail_async"] = async_raise_unhandled_exception
+
+    with rpc(a.address) as rpc_a:
+        with pytest.raises(Exception):
+            await rpc_a.fail()
+
+        with pytest.raises(Exception):
+            await rpc_a.fail_async()
+
+    # Current behaviour is that any exception in the handler causes a
+    # disconnect. The worker should automatically initiate a reconnect and
+    # return to ordinary behaviour. At the moment we do not differentiate
+    # between different types of exceptions.
+    for kind in [None, "runtime", "os"]:
+        a_stream_comm = s.stream_comms.get(a.address)
+        a_stream_comm.send({"op": "fail", "kind": kind})
+        # Worker is still fine
+        assert a.status is Status.running
+
+        # Wait until the worker has been removed and wait for the stream comm to be repopulated
+        await async_wait_for(lambda: not s.stream_comms, 2)
+        await async_wait_for(lambda: s.stream_comms, 2)
+        assert a.status is Status.running
+
+    for kind in [None, "runtime", "os"]:
+        a_stream_comm = s.stream_comms.get(a.address)
+        a_stream_comm.send({"op": "fail_async", "kind": kind})
+        assert a.status is Status.running
+        await async_wait_for(lambda: not s.stream_comms, 2)
+        await async_wait_for(lambda: s.stream_comms, 2)
+        assert a.status is Status.running
+
+    await asyncio.sleep(0)
+    # Is the worker still operational?
+
+    futs = c.map(inc, range(10))
+    res = await c.gather(futs)
+    assert sum(res) == sum(range(1, 11))
+
+
+@gen_cluster(nthreads=[("127.0.0.1", 0)])
+async def test_worker_terminate(s, a):
+    w_rpc = s.rpc(a.address)
+    await w_rpc.terminate(reply=False)
+
+    while a.status != Status.closed:
+        await asyncio.sleep(0.05)
+
+    # already closed should be noop
+    await a.close()
diff --git a/distributed/utils.py b/distributed/utils.py
index 33793b2f504..13dbd200f40 100644
--- a/distributed/utils.py
+++ b/distributed/utils.py
@@ -21,6 +21,7 @@
 from collections import OrderedDict, UserDict, deque
 from concurrent.futures import CancelledError, ThreadPoolExecutor  # noqa: F401
 from contextlib import contextmanager, suppress
+from functools import wraps
 from hashlib import md5
 from importlib.util import cache_from_source
 from time import sleep
@@ -202,6 +203,22 @@ def get_ip_interface(ifname):
     raise ValueError("interface %r doesn't have an IPv4 address" % (ifname,))
 
 
+def shielded(func):
+    """
+    Shield the decorated method or function from cancellation. Note that the
+    decorated coroutine is immediately scheduled as a task when the decorated
+    function is invoked.
+
+    See also https://docs.python.org/3/library/asyncio-task.html#asyncio.shield
+    """
+
+    @wraps(func)
+    def _(*args, **kwargs):
+        return asyncio.shield(func(*args, **kwargs))
+
+    return _
+
+
 async def All(args, quiet_exceptions=()):
     """Wait on many tasks at the same time
 
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index a40cf14151d..f589b775c7f 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -1523,7 +1523,8 @@ def check_instances():
     #     raise ValueError("Unclosed Comms", L)
 
     assert all(
-        n.status == Status.closed or n.status == Status.init for n in Nanny._instances
+        n.status in [Status.closed, Status.init, Status.stopped]
+        for n in Nanny._instances
     ), {n: n.status for n in Nanny._instances}
 
     # assert not list(SpecCluster._instances)  # TODO
diff --git a/distributed/worker.py b/distributed/worker.py
index 685f6df6240..3c5ed383226 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -68,6 +68,7 @@
     parse_bytes,
     parse_ports,
     parse_timedelta,
+    shielded,
     silence_logging,
     thread_state,
     typename,
@@ -1218,14 +1219,27 @@ def _close(self, *args, **kwargs):
         warnings.warn("Worker._close has moved to Worker.close", stacklevel=2)
         return self.close(*args, **kwargs)
 
+    @shielded
     async def close(
         self, report=True, timeout=10, nanny=True, executor_wait=True, safe=False
     ):
         with log_errors():
-            if self.status in (Status.closed, Status.closing):
+            if self.status in (
+                Status.closing,
+                Status.closed,
+            ):
                 await self.finished()
                 return
 
+            if self.status not in (Status.running, Status.closing_gracefully):
+                logger.info(
+                    "Closed worker %s has not yet started: %s",
+                    self.name,
+                    self.status,
+                )
+
+            self.status = Status.closing
+
             self.reconnect = False
             disable_gc_diagnosis()
 
@@ -1233,9 +1247,6 @@ async def close(
                 logger.info("Stopping worker at %s", self.address)
             except ValueError:  # address not available if already closed
                 logger.info("Stopping worker")
-            if self.status not in (Status.running, Status.closing_gracefully):
-                logger.info("Closed worker has not yet started: %s", self.status)
-            self.status = Status.closing
 
             for preload in self.preloads:
                 await preload.teardown()
@@ -1312,11 +1323,7 @@ async def close(
                 else:
                     executor.shutdown(wait=executor_wait)
 
-            self.stop()
-            await self.rpc.close()
-
-            self.status = Status.closed
-            await ServerNode.close(self)
+            await super().close()
 
             setproctitle("dask-worker [closed]")
             return "OK"
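The standalone sketch below (not part of the patch) illustrates the semantics of the `shielded` helper added in distributed/utils.py: calling the decorated coroutine function immediately schedules its body as a task and wraps it in asyncio.shield, so cancelling the caller-side future does not cancel the body. The `close` coroutine here is a hypothetical stand-in for the Server/Worker/Nanny/Scheduler close methods that the patch decorates.

import asyncio
from functools import wraps


def shielded(func):
    # Same helper as in the patch: schedule the coroutine and shield it.
    @wraps(func)
    def _(*args, **kwargs):
        return asyncio.shield(func(*args, **kwargs))

    return _


@shielded
async def close():
    # Hypothetical stand-in for e.g. Worker.close in the patch.
    await asyncio.sleep(0.2)
    print("close ran to completion despite the caller being cancelled")


async def main():
    outer = close()  # a shielded Future wrapping the already-running close() task
    await asyncio.sleep(0.05)
    outer.cancel()  # cancels only the outer shield, not the inner task
    await asyncio.sleep(0.3)  # give the inner task time to finish


asyncio.run(main())

Note that awaiting the cancelled outer future would raise CancelledError; callers that need to observe completion instead wait on a separate signal, which matches the patch's pattern of awaiting `self.finished()` / `_event_finished` rather than the shielded call itself.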