Skip to content

Commit

Permalink
control de/ser offload (#3793)
Browse files Browse the repository at this point in the history
* start

* Apply to to_frames too

* Apply to listeners (TCP, UCX)

* listen arg for scheduler only

* black
  • Loading branch information
martindurant authored May 25, 2020
1 parent 5f5ebaf commit 0ec78f8
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 17 deletions.
1 change: 1 addition & 0 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Comm(ABC):

def __init__(self):
self._instances.add(self)
self.allow_offload = True # for deserialization in utils.from_frames
self.name = None

# XXX add set_close_callback()?
Expand Down
16 changes: 14 additions & 2 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ async def read(self, deserializers=None):
else:
try:
msg = await from_frames(
frames, deserialize=self.deserialize, deserializers=deserializers
frames,
deserialize=self.deserialize,
deserializers=deserializers,
allow_offload=self.allow_offload,
)
except EOFError:
# Frames possibly garbled or truncated by communication error
Expand All @@ -216,6 +219,7 @@ async def write(self, msg, serializers=None, on_error="message"):

frames = await to_frames(
msg,
allow_offload=self.allow_offload,
serializers=serializers,
on_error=on_error,
context={"sender": self._local_addr, "recipient": self._peer_addr},
Expand Down Expand Up @@ -378,12 +382,19 @@ def _get_connect_args(self, **connection_args):

class BaseTCPListener(Listener, RequireEncryptionMixin):
def __init__(
self, address, comm_handler, deserialize=True, default_port=0, **connection_args
self,
address,
comm_handler,
deserialize=True,
allow_offload=True,
default_port=0,
**connection_args
):
self._check_encryption(address, connection_args)
self.ip, self.port = parse_host_port(address, default_port)
self.comm_handler = comm_handler
self.deserialize = deserialize
self.allow_offload = allow_offload
self.server_args = self._get_server_args(**connection_args)
self.tcp_server = None
self.bound_address = None
Expand Down Expand Up @@ -432,6 +443,7 @@ async def _handle_stream(self, stream, address):
logger.debug("Incoming connection from %r to %r", address, self.contact_address)
local_address = self.prefix + get_stream_address(stream)
comm = self.comm_class(stream, local_address, address, self.deserialize)
comm.allow_offload = self.allow_offload
await self.comm_handler(comm)

def get_host_port(self):
Expand Down
19 changes: 16 additions & 3 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,10 @@ async def write(
serializers = ("cuda", "dask", "pickle", "error")
# msg can also be a list of dicts when sending batched messages
frames = await to_frames(
msg, serializers=serializers, on_error=on_error
msg,
serializers=serializers,
on_error=on_error,
allow_offload=self.allow_offload,
)
nframes = len(frames)
cuda_frames = tuple(
Expand Down Expand Up @@ -261,7 +264,10 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
for each_frame in recv_frames:
await self.ep.recv(each_frame)
msg = await from_frames(
frames, deserialize=self.deserialize, deserializers=deserializers
frames,
deserialize=self.deserialize,
deserializers=deserializers,
allow_offload=self.allow_offload,
)
return msg

Expand Down Expand Up @@ -310,13 +316,19 @@ class UCXListener(Listener):
encrypted = UCXConnector.encrypted

def __init__(
self, address: str, comm_handler: None, deserialize=False, **connection_args
self,
address: str,
comm_handler: None,
deserialize=False,
allow_offload=True,
**connection_args
):
if not address.startswith("ucx"):
address = "ucx://" + address
self.ip, self._input_port = parse_host_port(address, default_port=0)
self.comm_handler = comm_handler
self.deserialize = deserialize
self.allow_offload = allow_offload
self._ep = None # type: ucp.Endpoint
self.ucp_server = None
self.connection_args = connection_args
Expand All @@ -337,6 +349,7 @@ async def serve_forever(client_ep):
peer_addr=self.address,
deserialize=self.deserialize,
)
ucx.allow_offload = self.allow_offload
if self.comm_handler:
await self.comm_handler(ucx)

Expand Down
30 changes: 21 additions & 9 deletions distributed/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
FRAME_OFFLOAD_THRESHOLD = parse_bytes(FRAME_OFFLOAD_THRESHOLD)


async def to_frames(msg, serializers=None, on_error="message", context=None):
async def to_frames(
msg, serializers=None, on_error="message", context=None, allow_offload=True
):
"""
Serialize a message into a list of Distributed protocol frames.
"""
Expand All @@ -38,22 +40,25 @@ def _to_frames():
logger.exception(e)
raise

try:
msg_size = sizeof(msg)
except RecursionError:
msg_size = math.inf
if FRAME_OFFLOAD_THRESHOLD and allow_offload:
try:
msg_size = sizeof(msg)
except RecursionError:
msg_size = math.inf
else:
msg_size = 0

if FRAME_OFFLOAD_THRESHOLD and msg_size > FRAME_OFFLOAD_THRESHOLD:
if allow_offload and FRAME_OFFLOAD_THRESHOLD and msg_size > FRAME_OFFLOAD_THRESHOLD:
return await offload(_to_frames)
else:
return _to_frames()


async def from_frames(frames, deserialize=True, deserializers=None):
async def from_frames(frames, deserialize=True, deserializers=None, allow_offload=True):
"""
Unserialize a list of Distributed protocol frames.
"""
size = sum(map(nbytes, frames))
size = False

def _from_frames():
try:
Expand All @@ -69,7 +74,14 @@ def _from_frames():
logger.error("truncated data stream (%d bytes): %s", size, datastr)
raise

if deserialize and FRAME_OFFLOAD_THRESHOLD and size > FRAME_OFFLOAD_THRESHOLD:
if allow_offload and deserialize and FRAME_OFFLOAD_THRESHOLD:
size = sum(map(nbytes, frames))
if (
allow_offload
and deserialize
and FRAME_OFFLOAD_THRESHOLD
and size > FRAME_OFFLOAD_THRESHOLD
):
res = await offload(_from_frames)
else:
res = _from_frames()
Expand Down
11 changes: 9 additions & 2 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def port(self):
def identity(self, comm=None):
return {"type": type(self).__name__, "id": self.id}

async def listen(self, port_or_addr=None, **kwargs):
async def listen(self, port_or_addr=None, allow_offload=True, **kwargs):
if port_or_addr is None:
port_or_addr = self.default_port
if isinstance(port_or_addr, int):
Expand All @@ -359,7 +359,11 @@ async def listen(self, port_or_addr=None, **kwargs):
addr = port_or_addr
assert isinstance(addr, str)
listener = await listen(
addr, self.handle_comm, deserialize=self.deserialize, **kwargs,
addr,
self.handle_comm,
deserialize=self.deserialize,
allow_offload=allow_offload,
**kwargs,
)
self.listeners.append(listener)

Expand Down Expand Up @@ -863,6 +867,7 @@ def __init__(
limit=512,
deserialize=True,
serializers=None,
allow_offload=True,
deserializers=None,
connection_args=None,
timeout=None,
Expand All @@ -873,6 +878,7 @@ def __init__(
self.available = defaultdict(set)
# Invariant: len(occupied) == active
self.occupied = defaultdict(set)
self.allow_offload = allow_offload
self.deserialize = deserialize
self.serializers = serializers
self.deserializers = deserializers if deserializers is not None else serializers
Expand Down Expand Up @@ -953,6 +959,7 @@ async def connect(self, addr, timeout=None):
)
comm.name = "ConnectionPool"
comm._pool = weakref.ref(self)
comm.allow_offload = self.allow_offload
self._created.add(comm)
except Exception:
self.semaphore.release()
Expand Down
5 changes: 4 additions & 1 deletion distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,7 @@ def __init__(

setproctitle("dask-scheduler [not started]")
Scheduler._instances.add(self)
self.rpc.allow_offload = False

##################
# Administration #
Expand Down Expand Up @@ -1438,7 +1439,9 @@ async def start(self):
c.cancel()

for addr in self._start_address:
await self.listen(addr, **self.security.get_listen_args("scheduler"))
await self.listen(
addr, allow_offload=False, **self.security.get_listen_args("scheduler")
)
self.ip = get_address_host(self.listen_address)
listen_ip = self.ip

Expand Down

0 comments on commit 0ec78f8

Please sign in to comment.