diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 00a6d87a52a..852cb991c4b 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -40,6 +40,7 @@ class Comm(ABC): def __init__(self): self._instances.add(self) + self.allow_offload = True # for deserialization in utils.from_frames self.name = None # XXX add set_close_callback()? diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 769e9132abe..c610bde31ac 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -200,7 +200,10 @@ async def read(self, deserializers=None): else: try: msg = await from_frames( - frames, deserialize=self.deserialize, deserializers=deserializers + frames, + deserialize=self.deserialize, + deserializers=deserializers, + allow_offload=self.allow_offload, ) except EOFError: # Frames possibly garbled or truncated by communication error @@ -216,6 +219,7 @@ async def write(self, msg, serializers=None, on_error="message"): frames = await to_frames( msg, + allow_offload=self.allow_offload, serializers=serializers, on_error=on_error, context={"sender": self._local_addr, "recipient": self._peer_addr}, @@ -378,12 +382,19 @@ def _get_connect_args(self, **connection_args): class BaseTCPListener(Listener, RequireEncryptionMixin): def __init__( - self, address, comm_handler, deserialize=True, default_port=0, **connection_args + self, + address, + comm_handler, + deserialize=True, + allow_offload=True, + default_port=0, + **connection_args ): self._check_encryption(address, connection_args) self.ip, self.port = parse_host_port(address, default_port) self.comm_handler = comm_handler self.deserialize = deserialize + self.allow_offload = allow_offload self.server_args = self._get_server_args(**connection_args) self.tcp_server = None self.bound_address = None @@ -432,6 +443,7 @@ async def _handle_stream(self, stream, address): logger.debug("Incoming connection from %r to %r", address, self.contact_address) local_address = self.prefix + get_stream_address(stream) comm = self.comm_class(stream, local_address, address, self.deserialize) + comm.allow_offload = self.allow_offload await self.comm_handler(comm) def get_host_port(self): diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 22bf4ad52f3..54a85279ffd 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -176,7 +176,10 @@ async def write( serializers = ("cuda", "dask", "pickle", "error") # msg can also be a list of dicts when sending batched messages frames = await to_frames( - msg, serializers=serializers, on_error=on_error + msg, + serializers=serializers, + on_error=on_error, + allow_offload=self.allow_offload, ) nframes = len(frames) cuda_frames = tuple( @@ -261,7 +264,10 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): for each_frame in recv_frames: await self.ep.recv(each_frame) msg = await from_frames( - frames, deserialize=self.deserialize, deserializers=deserializers + frames, + deserialize=self.deserialize, + deserializers=deserializers, + allow_offload=self.allow_offload, ) return msg @@ -310,13 +316,19 @@ class UCXListener(Listener): encrypted = UCXConnector.encrypted def __init__( - self, address: str, comm_handler: None, deserialize=False, **connection_args + self, + address: str, + comm_handler: None, + deserialize=False, + allow_offload=True, + **connection_args ): if not address.startswith("ucx"): address = "ucx://" + address self.ip, self._input_port = parse_host_port(address, default_port=0) self.comm_handler = comm_handler self.deserialize = deserialize + self.allow_offload = allow_offload self._ep = None # type: ucp.Endpoint self.ucp_server = None self.connection_args = connection_args @@ -337,6 +349,7 @@ async def serve_forever(client_ep): peer_addr=self.address, deserialize=self.deserialize, ) + ucx.allow_offload = self.allow_offload if self.comm_handler: await self.comm_handler(ucx) diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index b75663a14f2..d1a1a97e63c 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -21,7 +21,9 @@ FRAME_OFFLOAD_THRESHOLD = parse_bytes(FRAME_OFFLOAD_THRESHOLD) -async def to_frames(msg, serializers=None, on_error="message", context=None): +async def to_frames( + msg, serializers=None, on_error="message", context=None, allow_offload=True +): """ Serialize a message into a list of Distributed protocol frames. """ @@ -38,22 +40,25 @@ def _to_frames(): logger.exception(e) raise - try: - msg_size = sizeof(msg) - except RecursionError: - msg_size = math.inf + if FRAME_OFFLOAD_THRESHOLD and allow_offload: + try: + msg_size = sizeof(msg) + except RecursionError: + msg_size = math.inf + else: + msg_size = 0 - if FRAME_OFFLOAD_THRESHOLD and msg_size > FRAME_OFFLOAD_THRESHOLD: + if allow_offload and FRAME_OFFLOAD_THRESHOLD and msg_size > FRAME_OFFLOAD_THRESHOLD: return await offload(_to_frames) else: return _to_frames() -async def from_frames(frames, deserialize=True, deserializers=None): +async def from_frames(frames, deserialize=True, deserializers=None, allow_offload=True): """ Unserialize a list of Distributed protocol frames. """ - size = sum(map(nbytes, frames)) + size = False def _from_frames(): try: @@ -69,7 +74,14 @@ def _from_frames(): logger.error("truncated data stream (%d bytes): %s", size, datastr) raise - if deserialize and FRAME_OFFLOAD_THRESHOLD and size > FRAME_OFFLOAD_THRESHOLD: + if allow_offload and deserialize and FRAME_OFFLOAD_THRESHOLD: + size = sum(map(nbytes, frames)) + if ( + allow_offload + and deserialize + and FRAME_OFFLOAD_THRESHOLD + and size > FRAME_OFFLOAD_THRESHOLD + ): res = await offload(_from_frames) else: res = _from_frames() diff --git a/distributed/core.py b/distributed/core.py index 829d2095083..1220671d115 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -348,7 +348,7 @@ def port(self): def identity(self, comm=None): return {"type": type(self).__name__, "id": self.id} - async def listen(self, port_or_addr=None, **kwargs): + async def listen(self, port_or_addr=None, allow_offload=True, **kwargs): if port_or_addr is None: port_or_addr = self.default_port if isinstance(port_or_addr, int): @@ -359,7 +359,11 @@ async def listen(self, port_or_addr=None, **kwargs): addr = port_or_addr assert isinstance(addr, str) listener = await listen( - addr, self.handle_comm, deserialize=self.deserialize, **kwargs, + addr, + self.handle_comm, + deserialize=self.deserialize, + allow_offload=allow_offload, + **kwargs, ) self.listeners.append(listener) @@ -863,6 +867,7 @@ def __init__( limit=512, deserialize=True, serializers=None, + allow_offload=True, deserializers=None, connection_args=None, timeout=None, @@ -873,6 +878,7 @@ def __init__( self.available = defaultdict(set) # Invariant: len(occupied) == active self.occupied = defaultdict(set) + self.allow_offload = allow_offload self.deserialize = deserialize self.serializers = serializers self.deserializers = deserializers if deserializers is not None else serializers @@ -953,6 +959,7 @@ async def connect(self, addr, timeout=None): ) comm.name = "ConnectionPool" comm._pool = weakref.ref(self) + comm.allow_offload = self.allow_offload self._created.add(comm) except Exception: self.semaphore.release() diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d032086ec79..dd96550ed41 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1372,6 +1372,7 @@ def __init__( setproctitle("dask-scheduler [not started]") Scheduler._instances.add(self) + self.rpc.allow_offload = False ################## # Administration # @@ -1438,7 +1439,9 @@ async def start(self): c.cancel() for addr in self._start_address: - await self.listen(addr, **self.security.get_listen_args("scheduler")) + await self.listen( + addr, allow_offload=False, **self.security.get_listen_args("scheduler") + ) self.ip = get_address_host(self.listen_address) listen_ip = self.ip