Ensure exceptions in handlers are handled equally for sync and async
fjetter committed Jun 30, 2021
1 parent 84641e7 commit 1cee222
Showing 14 changed files with 233 additions and 70 deletions.
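The theme of the commit: coroutine handlers used to be scheduled fire-and-forget on the event loop, so their exceptions never reached the calling code, while synchronous handlers raised in place. A minimal standalone sketch of the difference (illustrative names, not distributed's actual API):

import asyncio

async def handler(msg):
    raise ValueError(f"bad message: {msg}")

async def dispatch_fire_and_forget(msg):
    # Old pattern: schedule and move on. The error only surfaces as a
    # "Task exception was never retrieved" warning, never in the caller.
    asyncio.create_task(handler(msg))
    await asyncio.sleep(0.1)

async def dispatch_awaited(msg):
    # New pattern: the caller sees the exception, same as a sync handler.
    await handler(msg)

async def main():
    await dispatch_fire_and_forget("a")  # returns without raising
    try:
        await dispatch_awaited("b")
    except ValueError as exc:
        print("caught:", exc)

asyncio.run(main())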
2 changes: 1 addition & 1 deletion distributed/comm/tcp.py
@@ -484,7 +484,7 @@ async def _handle_stream(self, stream, address):
         try:
             await self.on_connection(comm)
         except CommClosedError:
-            logger.info("Connection closed before handshake completed")
+            logger.info("Connection from %s closed before handshake completed", address)
             return
 
         await self.comm_handler(comm)
46 changes: 30 additions & 16 deletions distributed/core.py
@@ -14,7 +14,6 @@
 
 import tblib
 from tlz import merge
-from tornado import gen
 from tornado.ioloop import IOLoop, PeriodicCallback
 
 import dask
@@ -37,6 +36,7 @@
     get_traceback,
     has_keyword,
     is_coroutine_function,
+    shielded,
     truncate_exception,
 )
 
@@ -160,6 +160,7 @@ def __init__(
         self.counters = None
         self.digests = None
         self._ongoing_coroutines = weakref.WeakSet()
+
         self._event_finished = asyncio.Event()
 
         self.listeners = []
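The `shielded` decorator imported above is added to distributed's utils module by this commit, but its definition is outside the visible diff. A plausible reconstruction, assuming it simply wraps the coroutine in asyncio.shield so that cleanup paths like close() cannot be interrupted by a caller's cancellation:

import asyncio
from functools import wraps

def shielded(func):
    # Hypothetical reconstruction; the real decorator in distributed/utils.py
    # is not shown in this excerpt.
    @wraps(func)
    def wrapper(*args, **kwargs):
        # The caller may still observe CancelledError, but the wrapped
        # coroutine itself runs to completion.
        return asyncio.shield(func(*args, **kwargs))

    return wrapper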
@@ -263,10 +264,18 @@ async def finished(self):
 
     def __await__(self):
         async def _():
+            if self.status == Status.running:
+                return self
+
+            if self.status in (Status.closing, Status.closed):
+                # We should never await an object which is already closing but
+                # we should also not start it up again otherwise we'd produce
+                # zombies
+                await self.finished()
+                return
+
             timeout = getattr(self, "death_timeout", 0)
             async with self._startup_lock:
-                if self.status == Status.running:
-                    return self
                 if timeout:
                     try:
                         await asyncio.wait_for(self.start(), timeout=timeout)
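With this change, `await server` is a no-op for a running server and a plain wait for a closing one, instead of re-entering start() and producing the "zombies" the comment mentions. A compact sketch of the awaitable-server pattern (toy Status and start, not the real class):

import asyncio
from enum import Enum

class Status(Enum):
    undefined = "undefined"
    running = "running"

class ToyServer:
    def __init__(self):
        self.status = Status.undefined
        self._startup_lock = asyncio.Lock()

    async def start(self):
        self.status = Status.running

    def __await__(self):
        async def _():
            if self.status == Status.running:
                return self  # already started: awaiting again is a no-op
            async with self._startup_lock:
                await self.start()
            return self

        return _().__await__()

async def main():
    server = ToyServer()
    await server  # first await starts the server
    await server  # second await returns immediately
    print(server.status)

asyncio.run(main())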
@@ -422,7 +431,7 @@ async def handle_comm(self, comm):
 
         logger.debug("Connection from %r to %s", address, type(self).__name__)
         self._comms[comm] = op
-        await self
+
         try:
             while True:
                 try:
@@ -565,17 +574,15 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
                         break
                     handler = self.stream_handlers[op]
                     if is_coroutine_function(handler):
-                        self.loop.add_callback(handler, **merge(extra, msg))
-                        await gen.sleep(0)
+                        await handler(**merge(extra, msg))
                     else:
                         handler(**merge(extra, msg))
                 else:
                     logger.error("odd message %s", msg)
-                await asyncio.sleep(0)
 
             for func in every_cycle:
                 if is_coroutine_function(func):
-                    self.loop.add_callback(func)
+                    await func()
                 else:
                     func()
 
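Awaiting stream handlers inline, instead of scheduling them with loop.add_callback, means messages are processed in order and a handler's exception propagates into handle_stream's own error handling, just as a synchronous handler's would. A standalone sketch of the dispatch pattern (made-up handlers; distributed uses its own is_coroutine_function helper):

import asyncio
import inspect

def sync_handler(x):
    print("sync handled", x)

async def async_handler(x):
    await asyncio.sleep(0)
    print("async handled", x)

async def dispatch(messages, handlers):
    for op, msg in messages:
        handler = handlers[op]
        if inspect.iscoroutinefunction(handler):
            # Awaiting inline: errors raise here, and messages stay ordered.
            await handler(**msg)
        else:
            handler(**msg)

asyncio.run(
    dispatch(
        [("sync", {"x": 1}), ("async", {"x": 2})],
        {"sync": sync_handler, "async": async_handler},
    )
)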
@@ -593,32 +600,39 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
             await comm.close()
             assert comm.closed()
 
-    @gen.coroutine
-    def close(self):
+    @shielded
+    async def close(self):
+        self.status = Status.closing
+        self.stop()
+        await self.rpc.close()
+
         for pc in self.periodic_callbacks.values():
             pc.stop()
         self.__stopped = True
         for listener in self.listeners:
             future = listener.stop()
             if inspect.isawaitable(future):
-                yield future
-        for i in range(20):
+                await future
+        for _ in range(20):
             # If there are still handlers running at this point, give them a
             # second to finish gracefully themselves, otherwise...
             if any(self._comms.values()):
-                yield asyncio.sleep(0.05)
+                await asyncio.sleep(0.05)
             else:
                 break
-        yield [comm.close() for comm in list(self._comms)]  # then forcefully close
+        await asyncio.gather(
+            *[comm.close() for comm in list(self._comms)]
+        )  # then forcefully close
         for cb in self._ongoing_coroutines:
             cb.cancel()
-        for i in range(10):
+        for _ in range(10):
             if all(c.cancelled() for c in self._ongoing_coroutines):
                 break
             else:
-                yield asyncio.sleep(0.01)
+                await asyncio.sleep(0.01)
 
         self._event_finished.set()
+        self.status = Status.closed
 
 
 def pingpong(comm):
22 changes: 11 additions & 11 deletions distributed/deploy/spec.py
@@ -332,15 +332,12 @@ async def _correct_state_internal(self):
         if to_close:
             if self.scheduler.status == Status.running:
                 await self.scheduler_comm.retire_workers(workers=list(to_close))
-            tasks = [
-                asyncio.create_task(self.workers[w].close())
-                for w in to_close
-                if w in self.workers
-            ]
-            await asyncio.wait(tasks)
-            for task in tasks:  # for tornado gen.coroutine support
-                with suppress(RuntimeError):
-                    await task
+            finished, _ = await asyncio.wait(
+                [self.workers[w].close() for w in to_close if w in self.workers]
+            )
+            for task in finished:
+                if task.exception():
+                    raise task.exception()
             for name in to_close:
                 if name in self.workers:
                     del self.workers[name]
@@ -359,10 +356,13 @@ async def _correct_state_internal(self):
                 self._created.add(worker)
                 workers.append(worker)
         if workers:
-            await asyncio.wait(workers)
+            finished, _ = await asyncio.wait(workers)
+            for task in finished:
+                if task.exception():
+                    raise task.exception()
+
             for w in workers:
                 w._cluster = weakref.ref(self)
-                await w  # for tornado gen.coroutine support
             self.workers.update(dict(zip(to_open, workers)))
 
     def _update_worker_status(self, op, msg):
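Unlike asyncio.gather, asyncio.wait never raises the tasks' exceptions itself; they stay stored on the task objects, which is why the new code inspects each finished task. Note the parentheses: Task.exception is a method, and it returns None for tasks that succeeded. A standalone sketch of the pattern (hypothetical close_worker stands in for Worker.close):

import asyncio

async def close_worker(i):
    # Hypothetical stand-in for self.workers[w].close(); one of them fails.
    if i == 2:
        raise RuntimeError(f"worker {i} failed to close")

async def close_all():
    tasks = [asyncio.ensure_future(close_worker(i)) for i in range(3)]
    # asyncio.wait stores exceptions on the tasks instead of raising them.
    finished, _ = await asyncio.wait(tasks)
    for task in finished:
        if task.exception():  # None for tasks that succeeded
            raise task.exception()

try:
    asyncio.run(close_all())
except RuntimeError as exc:
    print("propagated:", exc)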
7 changes: 2 additions & 5 deletions distributed/deploy/tests/test_adaptive.py
@@ -446,19 +446,16 @@ async def test_scale_needs_to_be_awaited(cleanup):
 
     class RequiresAwaitCluster(LocalCluster):
         def scale(self, n):
-            # super invocation in the nested function scope is messy
-            method = super().scale
-
             async def _():
-                return method(n)
+                return LocalCluster.scale(self, n)
 
             return self.sync(_)
 
     async with RequiresAwaitCluster(n_workers=0, asynchronous=True) as cluster:
         async with Client(cluster, asynchronous=True) as client:
             futures = client.map(slowinc, range(5), delay=0.05)
             assert len(cluster.workers) == 0
-            cluster.adapt()
+            cluster.adapt(interval="10ms")
 
             await client.gather(futures)
 
6 changes: 3 additions & 3 deletions distributed/deploy/tests/test_local.py
@@ -1070,12 +1070,12 @@ async def test_local_cluster_redundant_kwarg(nanny):
     # Extra arguments are forwarded to the worker class. Depending on
     # whether we use the nanny or not, the error treatment is quite
     # different and we should assert that an exception is raised
-    async with await LocalCluster(
-        typo_kwarg="foo", processes=nanny, n_workers=1
+    async with LocalCluster(
+        typo_kwarg="foo", processes=nanny, n_workers=1, asynchronous=True
     ) as cluster:
 
         # This will never work but is a reliable way to block without hard
         # coding any sleep values
-        async with Client(cluster) as c:
+        async with Client(cluster, asynchronous=True) as c:
             f = c.submit(sleep, 0)
             await f
3 changes: 1 addition & 2 deletions distributed/deploy/tests/test_spec_cluster.py
@@ -137,7 +137,6 @@ async def test_scale(cleanup):
     assert len(cluster.workers) == 2
 
 
-@pytest.mark.slow
 @pytest.mark.asyncio
 async def test_adaptive_killed_worker(cleanup):
     with dask.config.set({"distributed.deploy.lost-worker-timeout": 0.1}):
@@ -150,7 +149,7 @@ async def test_adaptive_killed_worker(cleanup):
 
         async with Client(cluster, asynchronous=True) as client:
 
-            cluster.adapt(minimum=1, maximum=1)
+            cluster.adapt(minimum=1, maximum=1, interval="10ms")
 
             # Scale up a cluster with 1 worker.
             while len(cluster.workers) != 1:
24 changes: 13 additions & 11 deletions distributed/nanny.py
@@ -33,6 +33,7 @@
     json_load_robust,
     mp_context,
     parse_ports,
+    shielded,
     silence_logging,
 )
 from .worker import Worker, parse_memory_limit, run
@@ -194,15 +195,17 @@ def __init__(
             "instantiate": self.instantiate,
             "kill": self.kill,
             "restart": self.restart,
-            # cannot call it 'close' on the rpc side for naming conflict
             "get_logs": self.get_logs,
+            # cannot call it 'close' on the rpc side for naming conflict
             "terminate": self.close,
             "close_gracefully": self.close_gracefully,
             "run": self.run,
         }
 
         super().__init__(
-            handlers=handlers, io_loop=self.loop, connection_args=self.connection_args
+            handlers=handlers,
+            io_loop=self.loop,
+            connection_args=self.connection_args,
         )
 
         self.scheduler = self.rpc(self.scheduler_addr)
@@ -306,6 +309,7 @@ async def start(self):
 
         return self
 
+    @shielded
     async def kill(self, comm=None, timeout=2):
         """Kill the local worker process
@@ -317,6 +321,7 @@ async def kill(self, comm=None, timeout=2):
             return "OK"
 
         deadline = self.loop.time() + timeout
+        self.status = Status.stopped
         await self.process.kill(timeout=0.8 * (deadline - self.loop.time()))
 
     async def instantiate(self, comm=None) -> Status:
@@ -489,15 +494,13 @@ def close_gracefully(self, comm=None):
         """
         self.status = Status.closing_gracefully
 
+    @shielded
     async def close(self, comm=None, timeout=5, report=None):
         """
         Close the worker process, stop all comms.
         """
-        if self.status == Status.closing:
+        if self.status in (Status.closing, Status.closed):
             await self.finished()
-            assert self.status == Status.closed
-
-        if self.status == Status.closed:
             return "OK"
 
         self.status = Status.closing
@@ -506,18 +509,15 @@
         for preload in self.preloads:
             await preload.teardown()
-
-        self.stop()
         try:
             if self.process is not None:
                 await self.kill(timeout=timeout)
         except Exception:
             pass
         self.process = None
         await self.rpc.close()
-        self.status = Status.closed
         if comm:
             await comm.write("OK")
-        await ServerNode.close(self)
+        await super().close()
 
 
 class WorkerProcess:
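Server.close and Nanny.close now share the same idempotency pattern: a status flag plus a finished event, so a concurrent second close() waits for the first instead of re-running teardown. A simplified standalone sketch (toy Server, not the real class):

import asyncio
from enum import Enum

class Status(Enum):
    running = "running"
    closing = "closing"
    closed = "closed"

class ToyServer:
    def __init__(self):
        self.status = Status.running
        self._event_finished = asyncio.Event()

    async def finished(self):
        await self._event_finished.wait()

    async def close(self):
        if self.status in (Status.closing, Status.closed):
            # A close is already in flight: wait for it rather than
            # tearing down twice.
            await self.finished()
            return "OK"
        self.status = Status.closing
        await asyncio.sleep(0.01)  # stand-in for the real teardown work
        self.status = Status.closed
        self._event_finished.set()
        return "OK"

async def main():
    server = ToyServer()
    print(await asyncio.gather(server.close(), server.close()))

asyncio.run(main())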
@@ -653,6 +653,7 @@ def mark_stopped(self):
         if self.on_exit is not None:
             self.on_exit(r)
 
+    @shielded
     async def kill(self, timeout=2, executor_wait=True):
         """
         Ensure the worker process is stopped, waiting at most
@@ -746,6 +747,7 @@ def _run(
             loop.make_current()
             worker = Worker(**worker_kwargs)
 
+            @shielded
             async def do_stop(timeout=5, executor_wait=True):
                 try:
                     await worker.close(
@@ -795,7 +797,7 @@ async def run():
                     # properly handled. See also
                     # WorkerProcess._wait_until_connected (the 2 is for good
                     # measure)
-                    sync_sleep(cls._init_msg_interval * 2)
+                    await asyncio.sleep(cls._init_msg_interval * 2)
                 else:
                     try:
                         assert worker.address
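The last visible hunk replaces a blocking sync_sleep with await asyncio.sleep: inside a coroutine, time.sleep stalls the whole event loop, while asyncio.sleep suspends only the current task. A toy demonstration (illustrative timings):

import asyncio
import time

async def ticker(label):
    for _ in range(3):
        print(label, "tick")
        await asyncio.sleep(0.05)

async def blocking_sleep():
    time.sleep(0.2)  # freezes the event loop; ticks stall until it returns

async def async_sleep():
    await asyncio.sleep(0.2)  # suspends only this task; ticks keep running

async def main():
    await asyncio.gather(ticker("blocked"), blocking_sleep())
    await asyncio.gather(ticker("free"), async_sleep())

asyncio.run(main())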