Skip to content

Commit

Permalink
Ensure connectionpool does not leave comms if closed mid connect (#4951)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Jun 24, 2021
1 parent 1bcd8a9 commit f1b0172
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 15 deletions.
62 changes: 47 additions & 15 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,12 @@ def __init__(
self.server = weakref.ref(server) if server else None
self._created = weakref.WeakSet()
self._instances.add(self)
# _n_connecting and _connecting have subtle different semantics. The set
# _connecting contains futures actively trying to establish a connection
# while the _n_connecting also accounts for connection attempts which
# are waiting due to the connection limit
self._connecting = set()
self.status = Status.init

def _validate(self):
"""
Expand Down Expand Up @@ -987,6 +993,7 @@ async def _():
async def start(self):
# Invariant: semaphore._value == limit - open - _n_connecting
self.semaphore = asyncio.Semaphore(self.limit)
self.status = Status.running

async def connect(self, addr, timeout=None):
"""
Expand All @@ -1007,28 +1014,43 @@ async def connect(self, addr, timeout=None):

self._n_connecting += 1
await self.semaphore.acquire()

fut = None
try:
comm = await connect(
addr,
timeout=timeout or self.timeout,
deserialize=self.deserialize,
**self.connection_args,
if self.status != Status.running:
raise CommClosedError(
f"ConnectionPool not running. Status: {self.status}"
)

fut = asyncio.ensure_future(
connect(
addr,
timeout=timeout or self.timeout,
deserialize=self.deserialize,
**self.connection_args,
)
)
self._connecting.add(fut)
comm = await fut
comm.name = "ConnectionPool"
comm._pool = weakref.ref(self)
comm.allow_offload = self.allow_offload
self._created.add(comm)
except Exception:

occupied.add(comm)

return comm
except asyncio.CancelledError as exc:
self.semaphore.release()
raise
raise CommClosedError(
f"ConnectionPool not running. Status: {self.status}"
) from exc
except Exception as exc:
self.semaphore.release()
raise exc
finally:
self._connecting.discard(fut)
self._n_connecting -= 1

occupied.add(comm)

return comm

def reuse(self, addr, comm):
"""
Reuse an open communication to the given address. For internal use.
Expand Down Expand Up @@ -1082,16 +1104,26 @@ async def close(self):
"""
Close all communications
"""
self.status = Status.closed
for d in [self.available, self.occupied]:
comms = [comm for comms in d.values() for comm in comms]
comms = set()
while d:
comms.update(d.popitem()[1])

await asyncio.gather(
*[comm.close() for comm in comms], return_exceptions=True
)

for _ in comms:
self.semaphore.release()

for comm in self._created:
IOLoop.current().add_callback(comm.abort)
for conn_fut in self._connecting:
conn_fut.cancel()

# We might still have tasks haning in the semaphore. This will let them
# run into an exception and raise a commclosed
while self._n_connecting:
await asyncio.sleep(0.005)


def coerce_to_address(o):
Expand Down
50 changes: 50 additions & 0 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import dask

from distributed.comm.core import CommClosedError
from distributed.core import (
ConnectionPool,
Server,
Expand Down Expand Up @@ -614,6 +615,55 @@ async def ping(comm, delay=0.1):
await rpc.close()


@pytest.mark.asyncio
async def test_connection_pool_close_while_connecting(monkeypatch):
"""
Ensure a closed connection pool guarantees to have no connections left open
even if it is closed mid-connecting
"""
from distributed.comm.registry import backends
from distributed.comm.tcp import TCPBackend, TCPConnector

class SlowConnector(TCPConnector):
async def connect(self, address, deserialize, **connection_args):
await asyncio.sleep(0.1)
return await super().connect(
address, deserialize=deserialize, **connection_args
)

class SlowBackend(TCPBackend):
_connector_class = SlowConnector

monkeypatch.setitem(backends, "tcp", SlowBackend())

server = Server({})
await server.listen("tcp://")

pool = await ConnectionPool(limit=2)

async def connect_to_server():
comm = await pool.connect(server.address)
pool.reuse(server.address, comm)

tasks = [asyncio.create_task(connect_to_server()) for _ in range(30)]

await asyncio.sleep(0)
assert pool._connecting
close_fut = asyncio.create_task(pool.close())

with pytest.raises(
CommClosedError, match="ConnectionPool not running. Status: Status.closed"
):
await asyncio.gather(*tasks)

await close_fut
assert not pool.open
assert not pool._n_connecting

for t in tasks:
t.cancel()


@pytest.mark.asyncio
async def test_connection_pool_respects_limit():

Expand Down

0 comments on commit f1b0172

Please sign in to comment.