diff --git a/distributed/core.py b/distributed/core.py
index 1c95f6bd01..227c42610f 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -946,6 +946,12 @@ def __init__(
         self.server = weakref.ref(server) if server else None
         self._created = weakref.WeakSet()
         self._instances.add(self)
+        # _n_connecting and _connecting have subtly different semantics. The set
+        # _connecting contains futures actively trying to establish a connection,
+        # while _n_connecting also accounts for connection attempts which are
+        # waiting due to the connection limit.
+        self._connecting = set()
+        self.status = Status.init
 
     def _validate(self):
         """
@@ -987,6 +993,7 @@ async def _():
     async def start(self):
         # Invariant: semaphore._value == limit - open - _n_connecting
         self.semaphore = asyncio.Semaphore(self.limit)
+        self.status = Status.running
 
     async def connect(self, addr, timeout=None):
         """
@@ -1007,28 +1014,43 @@ async def connect(self, addr, timeout=None):
 
         self._n_connecting += 1
         await self.semaphore.acquire()
-
+        fut = None
         try:
-            comm = await connect(
-                addr,
-                timeout=timeout or self.timeout,
-                deserialize=self.deserialize,
-                **self.connection_args,
+            if self.status != Status.running:
+                raise CommClosedError(
+                    f"ConnectionPool not running. Status: {self.status}"
+                )
+
+            fut = asyncio.ensure_future(
+                connect(
+                    addr,
+                    timeout=timeout or self.timeout,
+                    deserialize=self.deserialize,
+                    **self.connection_args,
+                )
             )
+            self._connecting.add(fut)
+            comm = await fut
             comm.name = "ConnectionPool"
             comm._pool = weakref.ref(self)
             comm.allow_offload = self.allow_offload
             self._created.add(comm)
-        except Exception:
+
+            occupied.add(comm)
+
+            return comm
+        except asyncio.CancelledError as exc:
             self.semaphore.release()
-            raise
+            raise CommClosedError(
+                f"ConnectionPool not running. Status: {self.status}"
+            ) from exc
+        except Exception as exc:
+            self.semaphore.release()
+            raise exc
         finally:
+            self._connecting.discard(fut)
             self._n_connecting -= 1
 
-        occupied.add(comm)
-
-        return comm
-
     def reuse(self, addr, comm):
         """
         Reuse an open communication to the given address. For internal use.
@@ -1082,16 +1104,26 @@ async def close(self):
         """
         Close all communications
         """
+        self.status = Status.closed
         for d in [self.available, self.occupied]:
-            comms = [comm for comms in d.values() for comm in comms]
+            comms = set()
+            while d:
+                comms.update(d.popitem()[1])
+
             await asyncio.gather(
                 *[comm.close() for comm in comms], return_exceptions=True
             )
+
             for _ in comms:
                 self.semaphore.release()
 
-        for comm in self._created:
-            IOLoop.current().add_callback(comm.abort)
+        for conn_fut in self._connecting:
+            conn_fut.cancel()
+
+        # We might still have tasks hanging in the semaphore. This will let
+        # them run into an exception and raise a CommClosedError.
+        while self._n_connecting:
+            await asyncio.sleep(0.005)
 
 
 def coerce_to_address(o):
diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py
index 9fc98f276f..d3df2e6186 100644
--- a/distributed/tests/test_core.py
+++ b/distributed/tests/test_core.py
@@ -9,6 +9,7 @@
 
 import dask
 
+from distributed.comm.core import CommClosedError
 from distributed.core import (
     ConnectionPool,
     Server,
@@ -614,6 +615,55 @@ async def ping(comm, delay=0.1):
     await rpc.close()
 
 
+@pytest.mark.asyncio
+async def test_connection_pool_close_while_connecting(monkeypatch):
+    """
+    Ensure a closed connection pool leaves no connections open,
+    even if it is closed mid-connect.
+    """
+    from distributed.comm.registry import backends
+    from distributed.comm.tcp import TCPBackend, TCPConnector
+
+    class SlowConnector(TCPConnector):
+        async def connect(self, address, deserialize, **connection_args):
+            await asyncio.sleep(0.1)
+            return await super().connect(
+                address, deserialize=deserialize, **connection_args
+            )
+
+    class SlowBackend(TCPBackend):
+        _connector_class = SlowConnector
+
+    monkeypatch.setitem(backends, "tcp", SlowBackend())
+
+    server = Server({})
+    await server.listen("tcp://")
+
+    pool = await ConnectionPool(limit=2)
+
+    async def connect_to_server():
+        comm = await pool.connect(server.address)
+        pool.reuse(server.address, comm)
+
+    tasks = [asyncio.create_task(connect_to_server()) for _ in range(30)]
+
+    await asyncio.sleep(0)
+    assert pool._connecting
+    close_fut = asyncio.create_task(pool.close())
+
+    with pytest.raises(
+        CommClosedError, match="ConnectionPool not running. Status: Status.closed"
+    ):
+        await asyncio.gather(*tasks)
+
+    await close_fut
+    assert not pool.open
+    assert not pool._n_connecting
+
+    for t in tasks:
+        t.cancel()
+
+
 @pytest.mark.asyncio
 async def test_connection_pool_respects_limit():
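
Reviewer note: below is a minimal, self-contained sketch of the close-while-connecting pattern this patch introduces. All names in it (MiniPool, PoolClosedError, slow_connect) are illustrative stand-ins, not distributed's API; it assumes only stock asyncio on Python 3.8+, where CancelledError is a BaseException and therefore escapes a bare `except Exception`.

import asyncio


class PoolClosedError(Exception):
    """Illustrative stand-in for distributed's CommClosedError."""


async def slow_connect(addr):
    # Stand-in for a comm handshake that takes a while.
    await asyncio.sleep(0.1)
    return f"comm->{addr}"


class MiniPool:
    def __init__(self, limit=2):
        self.semaphore = asyncio.Semaphore(limit)
        self._connecting = set()  # futures actively establishing a connection
        self._n_connecting = 0  # also counts attempts queued on the semaphore
        self.running = True

    async def connect(self, addr):
        self._n_connecting += 1
        await self.semaphore.acquire()
        fut = None
        try:
            if not self.running:
                raise PoolClosedError("pool is closed")
            # Wrapping the coroutine in a future gives close() a handle
            # it can cancel.
            fut = asyncio.ensure_future(slow_connect(addr))
            self._connecting.add(fut)
            # On success the semaphore slot stays held until the comm is
            # handed back to the pool (elided in this sketch).
            return await fut
        except asyncio.CancelledError as exc:
            # close() cancelled the in-flight attempt; surface a pool-level
            # error to the caller instead of a bare cancellation.
            self.semaphore.release()
            raise PoolClosedError("pool closed mid-connect") from exc
        except Exception:
            self.semaphore.release()
            raise
        finally:
            self._connecting.discard(fut)
            self._n_connecting -= 1

    async def close(self):
        self.running = False
        for fut in self._connecting:
            fut.cancel()
        # Attempts still queued on the semaphore wake up, fail the
        # self.running check, and raise PoolClosedError themselves.
        while self._n_connecting:
            await asyncio.sleep(0.005)


async def main():
    pool = MiniPool(limit=2)
    tasks = [asyncio.create_task(pool.connect(f"addr-{i}")) for i in range(5)]
    await asyncio.sleep(0)  # let the first attempts start connecting
    await pool.close()
    results = await asyncio.gather(*tasks, return_exceptions=True)
    # Both the cancelled in-flight attempts and the ones that were still
    # waiting on the semaphore end with a pool-level error.
    assert all(isinstance(r, PoolClosedError) for r in results)


asyncio.run(main())

The design point mirrored here is the asyncio.ensure_future wrapper: awaiting connect(...) directly would leave close() no handle to cancel an in-flight handshake, whereas a wrapped future can be cancelled and the resulting CancelledError translated into a pool-level error for the caller.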