diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 84c5882bde..c50d5487a4 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -66,7 +66,7 @@ from distributed.cluster_dump import load_cluster_dump from distributed.comm import CommClosedError from distributed.compatibility import LINUX, WINDOWS -from distributed.core import Status +from distributed.core import Server, Status from distributed.metrics import time from distributed.objects import HasWhat, WhoHas from distributed.scheduler import ( @@ -94,7 +94,6 @@ inc, map_varying, nodebug, - popen, pristine_loop, randominc, save_sys_modules, @@ -3701,60 +3700,70 @@ async def test_scatter_raises_if_no_workers(c, s): await c.scatter(1, timeout=0.5) -@pytest.mark.slow -def test_reconnect(loop): - w = Worker("127.0.0.1", 9393, loop=loop) - loop.add_callback(w.start) - - scheduler_cli = [ - "dask-scheduler", - "--host", - "127.0.0.1", - "--port", - "9393", - "--no-dashboard", - ] - with popen(scheduler_cli): - c = Client("127.0.0.1:9393", loop=loop) - c.wait_for_workers(1, timeout=10) - x = c.submit(inc, 1) - assert x.result(timeout=10) == 2 +@gen_test() +async def test_reconnect(): + async def hard_stop(s): + for pc in s.periodic_callbacks.values(): + pc.stop() + + s.stop_services() + for comm in list(s.stream_comms.values()): + comm.abort() + for comm in list(s.client_comms.values()): + comm.abort() + + await s.rpc.close() + s.stop() + await Server.close(s) + + port = 9393 + futures = [] + w = Worker(f"127.0.0.1:{port}") + futures.append(asyncio.ensure_future(w.start())) + + s = await Scheduler(port=port) + c = await Client(f"127.0.0.1:{port}", asynchronous=True) + await c.wait_for_workers(1, timeout=10) + x = c.submit(inc, 1) + assert (await x) == 2 + await hard_stop(s) start = time() while c.status != "connecting": assert time() < start + 10 - sleep(0.01) + await asyncio.sleep(0.01) assert x.status == "cancelled" with pytest.raises(CancelledError): - x.result(timeout=10) + await x - with popen(scheduler_cli): - start = time() - while c.status != "running": - sleep(0.1) - assert time() < start + 10 - start = time() - while len(c.nthreads()) != 1: - sleep(0.05) - assert time() < start + 10 + s = await Scheduler(port=port) + start = time() + while c.status != "running": + await asyncio.sleep(0.1) + assert time() < start + 10 + start = time() + while len(await c.nthreads()) != 1: + await asyncio.sleep(0.05) + assert time() < start + 10 - x = c.submit(inc, 1) - assert x.result(timeout=10) == 2 + x = c.submit(inc, 1) + assert (await x) == 2 + await hard_stop(s) start = time() while True: assert time() < start + 10 try: - x.result(timeout=10) + await x assert False except CommClosedError: continue except CancelledError: break - sync(loop, w.close, timeout=1) - c.close() + await w.close(report=False) + await c._close(fast=True) class UnhandledException(Exception): diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index e6469eae8b..e16269d92f 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import gc import itertools import logging import random @@ -945,6 +946,7 @@ class Foo: assert not s.who_has assert not any(s.has_what.values()) + gc.collect() assert not list(ws)