Ensure exceptions in handlers are handled equally for sync and async
fjetter committed Jun 24, 2021
1 parent f1b0172 commit 6ddf5b1
Showing 13 changed files with 219 additions and 64 deletions.
2 changes: 1 addition & 1 deletion distributed/comm/tcp.py
@@ -487,7 +487,7 @@ async def _handle_stream(self, stream, address):
try:
await self.on_connection(comm)
except CommClosedError:
logger.info("Connection closed before handshake completed")
logger.info("Connection from %s closed before handshake completed", address)
return

await self.comm_handler(comm)
45 changes: 29 additions & 16 deletions distributed/core.py
@@ -14,7 +14,6 @@

import tblib
from tlz import merge
from tornado import gen
from tornado.ioloop import IOLoop, PeriodicCallback

import dask
@@ -37,6 +36,7 @@
has_keyword,
is_coroutine_function,
parse_timedelta,
shielded,
truncate_exception,
)

@@ -160,6 +160,7 @@ def __init__(
self.counters = None
self.digests = None
self._ongoing_coroutines = weakref.WeakSet()

self._event_finished = asyncio.Event()

self.listeners = []
@@ -263,10 +264,17 @@ async def finished(self):

def __await__(self):
async def _():
if self.status == Status.running:
return self

if self.status in (Status.closing, Status.closed):
# We should never await an object which is already closing but
# we should also not start it up again otherwise we'd produce
# zombies
await self.finished()

timeout = getattr(self, "death_timeout", 0)
async with self._startup_lock:
if self.status == Status.running:
return self
if timeout:
try:
await asyncio.wait_for(self.start(), timeout=timeout)
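
The reworked __await__ makes awaiting a server idempotent with respect to its lifecycle: a running server is returned as-is, a closing or closed one first waits for the in-flight shutdown instead of being restarted (the "zombies" the comment warns about), and anything else is started under the startup lock. A hedged usage sketch, illustrative only and not part of the diff:

server = await Server(handlers={})  # not yet running: start() is awaited under the startup lock
server = await server               # already running: returned immediately
await server.close()
await server                        # closing/closed: the in-flight close is awaited first
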
@@ -422,7 +430,7 @@ async def handle_comm(self, comm):

logger.debug("Connection from %r to %s", address, type(self).__name__)
self._comms[comm] = op
await self

try:
while True:
try:
@@ -565,17 +573,15 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
break
handler = self.stream_handlers[op]
if is_coroutine_function(handler):
self.loop.add_callback(handler, **merge(extra, msg))
await gen.sleep(0)
await handler(**merge(extra, msg))
else:
handler(**merge(extra, msg))
else:
logger.error("odd message %s", msg)
await asyncio.sleep(0)

for func in every_cycle:
if is_coroutine_function(func):
self.loop.add_callback(func)
await func()
else:
func()

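This hunk is the heart of the commit title: coroutine stream handlers used to be scheduled with loop.add_callback, so an exception they raised surfaced later on the IOLoop instead of inside handle_stream's own error handling, whereas exceptions from plain synchronous handlers propagated immediately. Awaiting the handler makes both paths behave the same. A minimal, self-contained illustration of the difference (plain asyncio, not distributed code):

import asyncio

async def handler(msg):
    raise ValueError(f"bad message: {msg}")

async def dispatch_old_style(msg):
    # Fire-and-forget: the exception is raised inside the scheduled task, so
    # the caller's try/except never sees it (asyncio later logs
    # "Task exception was never retrieved").
    asyncio.ensure_future(handler(msg))
    await asyncio.sleep(0)

async def dispatch_new_style(msg):
    # Await the coroutine handler directly: its exception reaches the caller
    # exactly like an exception from a synchronous handler would.
    await handler(msg)

async def main():
    try:
        await dispatch_old_style("x")
    except ValueError:
        print("not reached: the error escaped to the event loop")
    try:
        await dispatch_new_style("x")
    except ValueError as exc:
        print("caught in the caller:", exc)

asyncio.run(main())
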
@@ -593,32 +599,39 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
await comm.close()
assert comm.closed()

@gen.coroutine
def close(self):
@shielded
async def close(self):
self.status = Status.closing
self.stop()
await self.rpc.close()

for pc in self.periodic_callbacks.values():
pc.stop()
self.__stopped = True
for listener in self.listeners:
future = listener.stop()
if inspect.isawaitable(future):
yield future
for i in range(20):
await future
for _ in range(20):
# If there are still handlers running at this point, give them a
# second to finish gracefully themselves, otherwise...
if any(self._comms.values()):
yield asyncio.sleep(0.05)
await asyncio.sleep(0.05)
else:
break
yield [comm.close() for comm in list(self._comms)] # then forcefully close
await asyncio.gather(
*[comm.close() for comm in list(self._comms)]
) # then forcefully close
for cb in self._ongoing_coroutines:
cb.cancel()
for i in range(10):
for _ in range(10):
if all(c.cancelled() for c in self._ongoing_coroutines):
break
else:
yield asyncio.sleep(0.01)
await asyncio.sleep(0.01)

self._event_finished.set()
self.status = Status.closed


def pingpong(comm):
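
Several close/kill coroutines in this commit gain a shielded decorator imported from the project's utils module. Its implementation is not part of the hunks shown here; a rough sketch, assuming it simply wraps the call in asyncio.shield so a cancelled caller cannot abort an in-flight shutdown (the real utility may differ):

import asyncio
from functools import wraps

def shielded(func):
    # Assumed behaviour only: protect the wrapped coroutine from the caller's
    # cancellation; the actual decorator lives in distributed/utils.py.
    @wraps(func)
    async def wrapper(*args, **kwargs):
        return await asyncio.shield(func(*args, **kwargs))
    return wrapper
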
9 changes: 1 addition & 8 deletions distributed/deploy/spec.py
@@ -338,15 +338,8 @@ async def _correct_state_internal(self):
if to_close:
if self.scheduler.status == Status.running:
await self.scheduler_comm.retire_workers(workers=list(to_close))
tasks = [
asyncio.create_task(self.workers[w].close())
for w in to_close
if w in self.workers
]
tasks = [self.workers[w].close() for w in to_close if w in self.workers]
await asyncio.wait(tasks)
for task in tasks: # for tornado gen.coroutine support
with suppress(RuntimeError):
await task
for name in to_close:
if name in self.workers:
del self.workers[name]
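
With Worker.close now a native coroutine rather than a tornado-style one, the asyncio.create_task wrapping and the RuntimeError suppression are no longer needed; the coroutines can be awaited like any others. A minimal sketch of the idea, using asyncio.gather for brevity where the diff itself uses asyncio.wait, and assuming each worker object exposes an async close():

import asyncio

async def close_workers(workers):
    # Close all workers concurrently; exceptions propagate like those of any
    # awaited coroutine, so no special-casing is required.
    if workers:
        await asyncio.gather(*[w.close() for w in workers])
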
7 changes: 2 additions & 5 deletions distributed/deploy/tests/test_adaptive.py
@@ -446,19 +446,16 @@ async def test_scale_needs_to_be_awaited(cleanup):

class RequiresAwaitCluster(LocalCluster):
def scale(self, n):
# super invocation in the nested function scope is messy
method = super().scale

async def _():
return method(n)
return LocalCluster.scale(self, n)

return self.sync(_)

async with RequiresAwaitCluster(n_workers=0, asynchronous=True) as cluster:
async with Client(cluster, asynchronous=True) as client:
futures = client.map(slowinc, range(5), delay=0.05)
assert len(cluster.workers) == 0
cluster.adapt()
cluster.adapt(interval="10ms")

await client.gather(futures)

3 changes: 1 addition & 2 deletions distributed/deploy/tests/test_spec_cluster.py
@@ -137,7 +137,6 @@ async def test_scale(cleanup):
assert len(cluster.workers) == 2


@pytest.mark.slow
@pytest.mark.asyncio
async def test_adaptive_killed_worker(cleanup):
with dask.config.set({"distributed.deploy.lost-worker-timeout": 0.1}):
@@ -150,7 +149,7 @@ async def test_adaptive_killed_worker(cleanup):

async with Client(cluster, asynchronous=True) as client:

cluster.adapt(minimum=1, maximum=1)
cluster.adapt(minimum=1, maximum=1, interval="10ms")

# Scale up a cluster with 1 worker.
while len(cluster.workers) != 1:
24 changes: 13 additions & 11 deletions distributed/nanny.py
@@ -33,6 +33,7 @@
mp_context,
parse_ports,
parse_timedelta,
shielded,
silence_logging,
)
from .worker import Worker, parse_memory_limit, run
@@ -194,15 +195,17 @@ def __init__(
"instantiate": self.instantiate,
"kill": self.kill,
"restart": self.restart,
# cannot call it 'close' on the rpc side for naming conflict
"get_logs": self.get_logs,
# cannot call it 'close' on the rpc side for naming conflict
"terminate": self.close,
"close_gracefully": self.close_gracefully,
"run": self.run,
}

super().__init__(
handlers=handlers, io_loop=self.loop, connection_args=self.connection_args
handlers=handlers,
io_loop=self.loop,
connection_args=self.connection_args,
)

self.scheduler = self.rpc(self.scheduler_addr)
@@ -306,6 +309,7 @@ async def start(self):

return self

@shielded
async def kill(self, comm=None, timeout=2):
"""Kill the local worker process
@@ -317,6 +321,7 @@ async def kill(self, comm=None, timeout=2):
return "OK"

deadline = self.loop.time() + timeout
self.status = Status.stopped
await self.process.kill(timeout=0.8 * (deadline - self.loop.time()))

async def instantiate(self, comm=None) -> Status:
Expand Down Expand Up @@ -489,15 +494,13 @@ def close_gracefully(self, comm=None):
"""
self.status = Status.closing_gracefully

@shielded
async def close(self, comm=None, timeout=5, report=None):
"""
Close the worker process, stop all comms.
"""
if self.status == Status.closing:
if self.status in (Status.closing, Status.closed):
await self.finished()
assert self.status == Status.closed

if self.status == Status.closed:
return "OK"

self.status = Status.closing
@@ -506,18 +509,15 @@ async def close(self, comm=None, timeout=5, report=None):
for preload in self.preloads:
await preload.teardown()

self.stop()
try:
if self.process is not None:
await self.kill(timeout=timeout)
except Exception:
pass
self.process = None
await self.rpc.close()
self.status = Status.closed
if comm:
await comm.write("OK")
await ServerNode.close(self)
await super().close()


class WorkerProcess:
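
Like Server.close in core.py above, Nanny.close now treats a repeated close as "wait for the close that is already running" rather than tearing everything down twice, keyed off the closing/closed status and the finished event. A distilled, self-contained sketch of that pattern (hypothetical class, not distributed code):

import asyncio
from enum import Enum

class Status(Enum):
    running = "running"
    closing = "closing"
    closed = "closed"

class Closable:
    def __init__(self):
        self.status = Status.running
        self._event_finished = asyncio.Event()

    async def finished(self):
        await self._event_finished.wait()

    async def close(self):
        if self.status in (Status.closing, Status.closed):
            # A close is already in progress (or done): wait for it instead
            # of running the teardown a second time.
            await self.finished()
            return "OK"
        self.status = Status.closing
        await asyncio.sleep(0.1)  # stand-in for the real teardown work
        self._event_finished.set()
        self.status = Status.closed
        return "OK"

async def main():
    c = Closable()
    # Both callers get "OK", but the teardown body runs only once.
    print(await asyncio.gather(c.close(), c.close()))

asyncio.run(main())
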
@@ -653,6 +653,7 @@ def mark_stopped(self):
if self.on_exit is not None:
self.on_exit(r)

@shielded
async def kill(self, timeout=2, executor_wait=True):
"""
Ensure the worker process is stopped, waiting at most
@@ -746,6 +747,7 @@ def _run(
loop.make_current()
worker = Worker(**worker_kwargs)

@shielded
async def do_stop(timeout=5, executor_wait=True):
try:
await worker.close(
@@ -795,7 +797,7 @@ async def run():
# properly handled. See also
# WorkerProcess._wait_until_connected (the 2 is for good
# measure)
sync_sleep(cls._init_msg_interval * 2)
await asyncio.sleep(cls._init_msg_interval * 2)
else:
try:
assert worker.address
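
The switch from sync_sleep to await asyncio.sleep in the child process's run() matters because a blocking sleep stalls the entire event loop the worker runs on, including the startup handshake the nearby comment refers to, while awaiting yields control back to the loop. A small generic illustration (not distributed code):

import asyncio
import time

async def heartbeat():
    # Stands in for any concurrent work the loop must keep servicing.
    for _ in range(3):
        print("tick")
        await asyncio.sleep(0.02)

async def blocking_wait():
    time.sleep(0.1)           # blocks the loop: no ticks run meanwhile

async def cooperative_wait():
    await asyncio.sleep(0.1)  # yields to the loop: ticks keep flowing

async def main():
    await asyncio.gather(heartbeat(), cooperative_wait())
    # Swapping cooperative_wait() for blocking_wait() would delay every tick
    # until the blocking sleep finished.

asyncio.run(main())
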
27 changes: 16 additions & 11 deletions distributed/scheduler.py
@@ -79,6 +79,7 @@
no_default,
parse_bytes,
parse_timedelta,
shielded,
tmpfile,
validate_key,
)
@@ -3585,7 +3586,6 @@ def __init__(
"start_task_metadata": self.start_task_metadata,
"stop_task_metadata": self.stop_task_metadata,
}

connection_limit = get_fileno_limit() / 2

super().__init__(
@@ -3754,6 +3754,7 @@ def del_scheduler_file():
setproctitle("dask-scheduler [%s]" % (self.address,))
return self

@shielded
async def close(self, comm=None, fast=False, close_workers=False):
"""Send cleanup signal to all coroutines then wait until finished
@@ -3810,10 +3811,6 @@ async def close(self, comm=None, fast=False, close_workers=False):
for comm in self.client_comms.values():
comm.abort()

await self.rpc.close()

self.status = Status.closed
self.stop()
await super().close()

setproctitle("dask-scheduler [closed]")
@@ -3823,15 +3820,20 @@ async def close_worker(self, comm=None, worker=None, safe=None):
"""Remove a worker from the cluster
This both removes the worker from our local state and also sends a
signal to the worker to shut down. This works regardless of whether or
not the worker has a nanny process restarting it
signal to the worker to shut down.
If a Nanny is in front of the worker, the Nanny is instead instructed to
close.
"""
logger.info("Closing worker %s", worker)
parent: SchedulerState = cast(SchedulerState, self)
with log_errors():
self.log_event(worker, {"action": "close-worker"})
# FIXME: This does not handly nannys
self.worker_send(worker, {"op": "close", "report": False})
await self.remove_worker(address=worker, safe=safe)
ws: WorkerState = parent._workers_dv[worker]
if ws._nanny:
await self.rpc(ws._nanny).terminate()
else:
self.worker_send(worker, {"op": "close", "report": False})
await self.remove_worker(address=worker, safe=safe, close=False)

###########
# Stimuli #
@@ -6277,7 +6279,10 @@ async def retire_workers(
)
if remove:
await asyncio.gather(
*[self.remove_worker(address=w, safe=True) for w in worker_keys]
*[
self.remove_worker(address=w, safe=True, close=False)
for w in worker_keys
]
)

self.log_event(