diff --git a/distributed/comm/tests/test_ws.py b/distributed/comm/tests/test_ws.py index 08d1cbce0e..5bb826be13 100644 --- a/distributed/comm/tests/test_ws.py +++ b/distributed/comm/tests/test_ws.py @@ -209,3 +209,14 @@ async def test_wss_roundtrip(c, s, a, b): future = await c.scatter(x) y = await future assert (x == y).all() + + +@gen_cluster(client=True, scheduler_kwargs={"protocol": "ws://"}) +async def test_ws_roundtrip_large(c, s, a, b): + import numpy as np + + x = np.random.random(25000000) + + future = c.submit(lambda x: x, x) + y = await future + assert (x == y).all() diff --git a/distributed/comm/ws.py b/distributed/comm/ws.py index 46ce55b86e..2ab059496b 100644 --- a/distributed/comm/ws.py +++ b/distributed/comm/ws.py @@ -120,6 +120,10 @@ async def write(self, msg, serializers=None, on_error=None): }, frame_split_size=BIG_BYTES_SHARD_SIZE, ) + assert all(len(frame) <= BIG_BYTES_SHARD_SIZE for frame in frames), list( + map(len, frames) + ) + n = struct.pack("Q", len(frames)) try: await self.handler.write_message(n, binary=True) @@ -218,6 +222,10 @@ async def write(self, msg, serializers=None, on_error=None): }, frame_split_size=BIG_BYTES_SHARD_SIZE, ) + assert all(len(frame) <= BIG_BYTES_SHARD_SIZE for frame in frames), list( + map(len, frames) + ) + n = struct.pack("Q", len(frames)) try: await self.sock.write_message(n, binary=True) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 1be2d761e3..f4173ccb4c 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -74,6 +74,14 @@ def _encode_default(obj): return msgpack_encode_default(obj) frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True) + + if frame_split_size and len(frames[0]) > frame_split_size: + from distributed.protocol.utils import frame_split_size as split + + msg_frames = split(frames[0], n=frame_split_size) + header = msgpack.dumps({"large-header": True, "count": len(msg_frames)}) + frames = [header] + msg_frames + frames[1:] + return frames except Exception: @@ -108,9 +116,15 @@ def _decode_default(obj): else: return msgpack_decode_default(obj) - return msgpack.loads( + result = msgpack.loads( frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts ) + if isinstance(result, dict) and "large-header" in result: + frame = b"".join(frames[1 : result["count"] + 1]) + frames = [frame] + frames[result["count"] + 1 :] + return loads(frames, deserialize=deserialize, deserializers=deserializers) + else: + return result except Exception: logger.critical("Failed to deserialize", exc_info=True)