Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use asyncio for TCP/TLS comms #5450

Merged
merged 16 commits into from
Dec 10, 2021
112 changes: 93 additions & 19 deletions distributed/comm/asyncio_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,60 @@
_COMM_CLOSED = object()


def coalesce_buffers(buffers, target_buffer_size=64 * 1024, small_buffer_size=2048):
"""Given a list of buffers, coalesce them into a new list of buffers that
minimizes both copying and tiny writes.

Parameters
----------
buffers : list of bytes_like
target_buffer_size : int, optional
The target intermediate buffer size from concatenating small buffers
together. Coalesced buffers will be no larger than approximately this size.
small_buffer_size : int, optional
Buffers <= this size are considered "small" and may be copied. Buffers
larger than this may also be copied if the total message length is less
than ``target_buffer_size``.
"""
# Nothing to do
if len(buffers) == 1:
return buffers

# If the whole message can be sent in <= target_buffer_size, always concatenate
if sum(map(len, buffers)) <= target_buffer_size:
jcrist marked this conversation as resolved.
Show resolved Hide resolved
return [b"".join(buffers)]

out_buffers = []
concat = [] # A list of buffers to concatenate
csize = 0 # The total size of the concatenated buffers

def flush():
nonlocal csize
if concat:
if len(concat) == 1:
out_buffers.append(concat[0])
else:
out_buffers.append(b"".join(concat))
concat.clear()
csize = 0

for b in buffers:
if isinstance(b, memoryview):
b = b.cast("B")
size = len(b)
jcrist marked this conversation as resolved.
Show resolved Hide resolved
if size <= small_buffer_size:
concat.append(b)
csize += size
if csize >= target_buffer_size:
flush()
else:
flush()
out_buffers.append(b)
flush()

return out_buffers


class DaskCommProtocol(asyncio.BufferedProtocol):
def __init__(self, on_connection=None, min_read_size=128 * 1024):
jcrist marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
Expand Down Expand Up @@ -272,16 +326,13 @@ async def write(self, frames):
# change to the comms.
msg_nbytes = sum(frames_nbytes) + (nframes + 1) * 8
header = struct.pack(f"{nframes + 2}Q", msg_nbytes, nframes, *frames_nbytes)
frames = [header, *frames]

if msg_nbytes < 2 ** 17: # 128kiB
# small enough, send in one go
frames = [b"".join(frames)]
buffers = coalesce_buffers([header, *frames])

if len(frames) > 1:
self._transport.writelines(frames)
if len(buffers) > 1:
self._transport.writelines(buffers)
else:
self._transport.write(frames[0])
self._transport.write(buffers[0])
if self._transport.is_closing():
await asyncio.sleep(0)
jcrist marked this conversation as resolved.
Show resolved Hide resolved
elif self._paused:
jcrist marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -656,6 +707,9 @@ class _ZeroCopyWriter:
Note that this workaround isn't used with the windows ProactorEventLoop or
uvloop."""

SENDMSG_MAX_SIZE = 1024 * 1024 # 1 MiB
SENDMSG_MAX_COUNT = 16
jcrist marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, transport):
self.transport = transport
self._buffers = collections.deque()
Expand All @@ -666,17 +720,34 @@ def _buffer_append(self, data):

def _buffer_peek(self):
offset = self._offset
buf = self._buffers[0]
return buf[offset:] if offset else buf
buffers = []
size = 0
count = 0
for b in self._buffers:
if offset:
b = b[offset:]
offset = 0
buffers.append(b)
size += len(b)
count += 1
if size > self.SENDMSG_MAX_SIZE or count == self.SENDMSG_MAX_COUNT:
jcrist marked this conversation as resolved.
Show resolved Hide resolved
break
return buffers

def _buffer_advance(self, size):
b = self._buffers[0]
b_len = len(b) - self._offset
if b_len == size:
self._buffers.popleft()
self._offset = 0
else:
self._offset += size
offset = self._offset
buffers = self._buffers
while size:
b = buffers[0]
b_len = len(b) - offset
if b_len <= size:
buffers.popleft()
size -= b_len
offset = 0
else:
offset += size
break
self._offset = offset

def write(self, data):
transport = self.transport
Expand Down Expand Up @@ -749,9 +820,12 @@ def get_extra_info(self, key):
return self.transport.get_extra_info(key)

def _do_bulk_write(self):
buf = self._buffer_peek()
n = self.transport._sock.send(buf)
if n:
buffers = self._buffer_peek()
if len(buffers) == 1:
n = self.transport._sock.send(buffers[0])
self._buffer_advance(n)
else:
n = self.transport._sock.sendmsg(buffers)
self._buffer_advance(n)

def _on_write_ready(self):
Expand Down