Ensure exceptions in handlers are handled equally for sync and async #4734

Closed
wants to merge 8 commits
11 changes: 5 additions & 6 deletions distributed/cli/tests/test_dask_scheduler.py
@@ -21,6 +21,7 @@
from distributed.utils_test import (
assert_can_connect_from_everywhere_4_6,
assert_can_connect_locally_4,
get_unused_port,
popen,
)

@@ -296,20 +297,18 @@ def check_scheduler():
def test_preload_remote_module(loop, tmp_path):
with open(tmp_path / "scheduler_info.py", "w") as f:
f.write(PRELOAD_TEXT)

with popen([sys.executable, "-m", "http.server", "9382"], cwd=tmp_path):
port = get_unused_port()
with popen([sys.executable, "-m", "http.server", str(port)], cwd=tmp_path):
with popen(
[
"dask-scheduler",
"--scheduler-file",
str(tmp_path / "scheduler-file.json"),
"--preload",
"http://localhost:9382/scheduler_info.py",
f"http://localhost:{port}/scheduler_info.py",
]
) as proc:
with Client(
scheduler_file=tmp_path / "scheduler-file.json", loop=loop
) as c:
with Client(scheduler_file=tmp_path / "scheduler-file.json") as c:
assert (
c.run_on_scheduler(
lambda dask_scheduler: getattr(dask_scheduler, "foo", None)
2 changes: 1 addition & 1 deletion distributed/comm/tcp.py
@@ -489,7 +489,7 @@ async def _handle_stream(self, stream, address):
try:
await self.on_connection(comm)
except CommClosedError:
logger.info("Connection closed before handshake completed")
logger.info("Connection from %s closed before handshake completed", address)
return

await self.comm_handler(comm)
92 changes: 60 additions & 32 deletions distributed/core.py
@@ -14,7 +14,6 @@

import tblib
from tlz import merge
from tornado import gen
from tornado.ioloop import IOLoop, PeriodicCallback

import dask
@@ -37,6 +36,7 @@
get_traceback,
has_keyword,
is_coroutine_function,
shielded,
truncate_exception,
)

@@ -161,6 +161,7 @@ def __init__(
self.counters = None
self.digests = None
self._ongoing_coroutines = weakref.WeakSet()

self._event_finished = asyncio.Event()

self.listeners = []
@@ -223,7 +224,7 @@ def set_thread_ident():
self.thread_id = threading.get_ident()

self.io_loop.add_callback(set_thread_ident)
self._startup_lock = asyncio.Lock()
self._startup_fut = None
self.status = Status.undefined

self.rpc = ConnectionPool(
@@ -264,30 +265,46 @@ async def finished(self):

def __await__(self):
async def _():
timeout = getattr(self, "death_timeout", 0)
async with self._startup_lock:
if self.status == Status.running:
return self
if self.status in (Status.closing, Status.closed):
# We should never await an object which is already closing but
# we should also not start it up again otherwise we'd produce
# zombies
await self.finished()
return

timeout = getattr(self, "death_timeout", None)
if self._startup_fut is None:
self._startup_fut = asyncio.ensure_future(
asyncio.wait_for(self._start(), timeout=timeout)
)
await self.rpc.start()

try:
await self._startup_fut
except Exception:
await self.rpc.close()
# Suppress all exceptions since the objects might not have been
# properly initialized for close to be successful.
with suppress(Exception):
await self.close()
if timeout:
try:
await asyncio.wait_for(self.start(), timeout=timeout)
self.status = Status.running
except Exception:
await self.close(timeout=1)
raise TimeoutError(
"{} failed to start in {} seconds".format(
type(self).__name__, timeout
)
)
raise TimeoutError(
f"{type(self).__name__} failed to start in {timeout} seconds"
)
else:
await self.start()
self.status = Status.running
raise
return self

return _().__await__()

async def _start(self):
"""Child specific logic to define startup behaviour."""
self.status = Status.running

async def start(self):
await self.rpc.start()
"""Start the server. Child classes are supposed to overwrite
Server._start for custom startup logic."""
await self

async def __aenter__(self):
await self
@@ -426,7 +443,11 @@ async def handle_comm(self, comm):

logger.debug("Connection from %r to %s", address, type(self).__name__)
self._comms[comm] = op

# The server might already be listening even though it is not properly
# started yet (e.g. preload modules are not done).
await self
Review comment (author): Without the guard introduced in __await__, this await self would allow an incoming comm to restart a scheduler that is currently shutting down. There is a time window, while the scheduler is winding down its affairs but before the actual server closes, in which it still accepts incoming comms. If a new comm is opened during that window, the scheduler is restarted even though it is trying to shut down.
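A minimal sketch of that guard, using a simplified Status enum and a hypothetical MiniServer class (illustrative only; the real Server in distributed.core has far more state to manage):

```python
import asyncio
from enum import Enum


class Status(Enum):
    undefined = "undefined"
    running = "running"
    closing = "closing"
    closed = "closed"


class MiniServer:
    """Hypothetical stand-in for a Server, showing only the await guard."""

    def __init__(self):
        self.status = Status.undefined
        self._event_finished = asyncio.Event()

    async def _start(self):
        self.status = Status.running

    async def close(self):
        self.status = Status.closing
        # ... tear down comms, listeners, periodic callbacks ...
        self.status = Status.closed
        self._event_finished.set()

    def __await__(self):
        async def _():
            if self.status == Status.running:
                return self
            if self.status in (Status.closing, Status.closed):
                # Never (re)start a server that is shutting down; wait for
                # the shutdown to finish instead of producing a zombie.
                await self._event_finished.wait()
                return self
            await self._start()
            return self

        return _().__await__()


async def main():
    server = MiniServer()
    await server          # starts the server
    await server.close()
    await server          # does not restart it; returns once shutdown is done


asyncio.run(main())
```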

self.status = Status.closing
logger.info("Scheduler closing...")
setproctitle("dask-scheduler [closing]")
for preload in self.preloads:
await preload.teardown()
if close_workers:
await self.broadcast(msg={"op": "close_gracefully"}, nanny=True)
for worker in parent._workers_dv:
self.worker_send(worker, {"op": "close"})
for i in range(20): # wait a second for send signals to clear
if parent._workers_dv:
await asyncio.sleep(0.05)
else:
break
await asyncio.gather(*[plugin.close() for plugin in self.plugins])
for pc in self.periodic_callbacks.values():
pc.stop()
self.periodic_callbacks.clear()
self.stop_services()
for ext in parent._extensions.values():
with suppress(AttributeError):
ext.teardown()
logger.info("Scheduler closing all comms")
futures = []
for w, comm in list(self.stream_comms.items()):
if not comm.closed():
comm.send({"op": "close", "report": False})
comm.send({"op": "close-stream"})
with suppress(AttributeError):
futures.append(comm.close())
for future in futures: # TODO: do all at once
await future
for comm in self.client_comms.values():
comm.abort()
await self.rpc.close()
self.status = Status.closed
self.stop()
await super().close()

Review comment (author): I didn't add a dedicated test for this, since there are already many tests that fail once the close coroutines are shielded.
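The shielded decorator is imported from distributed.utils and its implementation is not part of this diff; the following is only a plausible sketch of what such a decorator does, wrapping the coroutine in asyncio.shield so a close body keeps running even when the awaiting task is cancelled:

```python
import asyncio
import functools


def shielded(func):
    # Sketch only: the real helper lives in distributed.utils and may differ.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        # asyncio.shield keeps the inner task running even when the outer,
        # awaiting task is cancelled; the caller still sees CancelledError,
        # but the wrapped body is not interrupted halfway through.
        return await asyncio.shield(func(*args, **kwargs))

    return wrapper


@shielded
async def close():
    await asyncio.sleep(0.1)  # stand-in for comm/listener teardown
    print("teardown ran to completion despite the caller being cancelled")


async def main():
    task = asyncio.create_task(close())
    await asyncio.sleep(0)    # let close() reach the shielded await
    task.cancel()             # cancels the caller, not the shielded body
    try:
        await task
    except asyncio.CancelledError:
        pass                  # the caller observes the cancellation ...
    await asyncio.sleep(0.2)  # ... while the teardown still finishes


asyncio.run(main())
```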


try:
while True:
try:
@@ -569,17 +590,15 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
break
handler = self.stream_handlers[op]
if is_coroutine_function(handler):
self.loop.add_callback(handler, **merge(extra, msg))
await gen.sleep(0)
await handler(**merge(extra, msg))
else:
handler(**merge(extra, msg))
else:
logger.error("odd message %s", msg)
await asyncio.sleep(0)

for func in every_cycle:
if is_coroutine_function(func):
self.loop.add_callback(func)
Review comment (author): I tracked down the dangling task: it was the ensure_computing call scheduled here. In particular, it concerned the ensure_computing that uses the async interface to offload deserialisation, introduced in #4307 just before the infamous 2020.12.0 release. (A sketch of the difference follows this hunk.)

await func()
else:
func()
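A minimal, self-contained sketch of the behavioural difference this change is about, using plain asyncio rather than distributed's stream machinery (handler and the message list are hypothetical):

```python
import asyncio


async def handler(msg):
    await asyncio.sleep(0.01)  # stand-in for offloaded deserialisation
    print("handled", msg)


async def stream_loop_fire_and_forget(messages):
    # Old behaviour (simplified): schedule the coroutine and move on.
    # The task may still be running ("dangling") after this loop returns,
    # and any exception it raises never surfaces here.
    for msg in messages:
        asyncio.ensure_future(handler(msg))
        await asyncio.sleep(0)


async def stream_loop_awaited(messages):
    # New behaviour (simplified): await the handler, so exceptions surface
    # in the stream loop and no task outlives it.
    for msg in messages:
        await handler(msg)


# Running the fire-and-forget variant instead would typically end with
# "Task was destroyed but it is pending!" warnings at shutdown.
asyncio.run(stream_loop_awaited(["x", "y"]))
```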

@@ -597,32 +616,39 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
await comm.close()
assert comm.closed()

@gen.coroutine
def close(self):
@shielded
async def close(self):
self.status = Status.closing
self.stop()
await self.rpc.close()

for pc in self.periodic_callbacks.values():
pc.stop()
self.__stopped = True
for listener in self.listeners:
future = listener.stop()
if inspect.isawaitable(future):
yield future
for i in range(20):
await future
for _ in range(20):
# If there are still handlers running at this point, give them a
# second to finish gracefully themselves, otherwise...
if any(self._comms.values()):
yield asyncio.sleep(0.05)
await asyncio.sleep(0.05)
else:
break
yield [comm.close() for comm in list(self._comms)] # then forcefully close
await asyncio.gather(
*[comm.close() for comm in list(self._comms)]
) # then forcefully close
for cb in self._ongoing_coroutines:
cb.cancel()
for i in range(10):
for _ in range(10):
if all(c.cancelled() for c in self._ongoing_coroutines):
break
else:
yield asyncio.sleep(0.01)
await asyncio.sleep(0.01)

self._event_finished.set()
self.status = Status.closed


def pingpong(comm):
Expand Down Expand Up @@ -993,6 +1019,8 @@ async def _():
return _().__await__()

async def start(self):
if self.status is not Status.init:
return
# Invariant: semaphore._value == limit - open - _n_connecting
self.semaphore = asyncio.Semaphore(self.limit)
self.status = Status.running
@@ -1022,7 +1050,6 @@ async def connect(self, addr, timeout=None):
raise CommClosedError(
f"ConnectionPool not running. Status: {self.status}"
)

fut = asyncio.ensure_future(
connect(
addr,
@@ -1126,6 +1153,7 @@ async def close(self):
# run into an exception and raise a commclosed
while self._n_connecting:
await asyncio.sleep(0.005)
await asyncio.sleep(0)


def coerce_to_address(o):
24 changes: 13 additions & 11 deletions distributed/deploy/spec.py
@@ -332,15 +332,13 @@ async def _correct_state_internal(self):
if to_close:
if self.scheduler.status == Status.running:
await self.scheduler_comm.retire_workers(workers=list(to_close))
tasks = [
asyncio.create_task(self.workers[w].close())
for w in to_close
if w in self.workers
]
await asyncio.wait(tasks)
for task in tasks: # for tornado gen.coroutine support
with suppress(RuntimeError):
await task
finished, _ = await asyncio.wait(
[self.workers[w].close() for w in to_close if w in self.workers]
)
for task in finished:
exc = task.exception()
if exc:
raise exc
for name in to_close:
if name in self.workers:
del self.workers[name]
@@ -359,10 +357,14 @@ async def _correct_state_internal(self):
self._created.add(worker)
workers.append(worker)
if workers:
await asyncio.wait(workers)
finished, _ = await asyncio.wait(workers)
for task in finished:
exc = task.exception()
if exc:
raise exc

for w in workers:
w._cluster = weakref.ref(self)
await w # for tornado gen.coroutine support
self.workers.update(dict(zip(to_open, workers)))

def _update_worker_status(self, op, msg):
7 changes: 2 additions & 5 deletions distributed/deploy/tests/test_adaptive.py
@@ -439,19 +439,16 @@ async def test_scale_needs_to_be_awaited(cleanup):

class RequiresAwaitCluster(LocalCluster):
def scale(self, n):
# super invocation in the nested function scope is messy
method = super().scale

async def _():
return method(n)
return LocalCluster.scale(self, n)

return self.sync(_)

async with RequiresAwaitCluster(n_workers=0, asynchronous=True) as cluster:
async with Client(cluster, asynchronous=True) as client:
futures = client.map(slowinc, range(5), delay=0.05)
assert len(cluster.workers) == 0
cluster.adapt()
cluster.adapt(interval="10ms")

await client.gather(futures)

6 changes: 3 additions & 3 deletions distributed/deploy/tests/test_local.py
@@ -1068,12 +1068,12 @@ async def test_local_cluster_redundant_kwarg(nanny):
# Extra arguments are forwarded to the worker class. Depending on
# whether we use the nanny or not, the error treatment is quite
# different and we should assert that an exception is raised
async with await LocalCluster(
typo_kwarg="foo", processes=nanny, n_workers=1
async with LocalCluster(
typo_kwarg="foo", processes=nanny, n_workers=1, asynchronous=True
) as cluster:

# This will never work but is a reliable way to block without hard
# coding any sleep values
async with Client(cluster) as c:
async with Client(cluster, asynchronous=True) as c:
f = c.submit(sleep, 0)
await f
3 changes: 3 additions & 0 deletions distributed/deploy/tests/test_spec_cluster.py
@@ -409,6 +409,9 @@ def __str__(self):

__repr__ = __str__

def __await__(self):
return self.start().__await__()

Comment on lines +412 to +414
Review comment (author): This MultiWorker is a bit strange, since it isn't a Server even though Workers are always Servers. (A sketch of the pattern follows below.)

async def start(self):
await asyncio.gather(*self.workers)
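A sketch of the pattern referred to above: an object that is not a Server can still be awaited by SpecCluster-style code if it defines __await__ delegating to an async start(). The class name below is hypothetical:

```python
import asyncio


class AwaitableWorkerGroup:
    # Hypothetical stand-in for MultiWorker: not a Server subclass, but
    # awaitable, so cluster code that does `await worker` keeps working.
    def __init__(self, n):
        self.n = n
        self.started = 0

    async def start(self):
        self.started = self.n  # stand-in for awaiting the real sub-workers

    def __await__(self):
        return self.start().__await__()


async def main():
    group = AwaitableWorkerGroup(3)
    await group  # delegates to start()
    assert group.started == 3


asyncio.run(main())
```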
