diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 0d87595bf75..2fc08313388 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -4,8 +4,10 @@ from .compression import decompress, maybe_compress from .serialize import ( + MsgpackList, Serialize, Serialized, + SerializedCallable, merge_and_deserialize, msgpack_decode_default, msgpack_encode_default, @@ -47,27 +49,32 @@ def _inplace_compress_frames(header, frames): def _encode_default(obj): typ = type(obj) - if typ is Serialize or typ is Serialized: - if typ is Serialize: - obj = obj.data - offset = len(frames) - if typ is Serialized: - sub_header, sub_frames = obj.header, obj.frames - else: - sub_header, sub_frames = serialize_and_split( - obj, serializers=serializers, on_error=on_error, context=context - ) - _inplace_compress_frames(sub_header, sub_frames) - sub_header["num-sub-frames"] = len(sub_frames) - frames.append( - msgpack.dumps( - sub_header, default=msgpack_encode_default, use_bin_type=True - ) - ) - frames.extend(sub_frames) - return {"__Serialized__": offset} + + ret = msgpack_encode_default(obj) + if ret is not obj: + return ret + + print(f"msgpack - fallback: {type(obj)}") + + if typ is Serialize: + obj = obj.data # TODO: remove Serialize/to_serialize completely + + offset = len(frames) + if typ is Serialized: + sub_header, sub_frames = obj.header, obj.frames else: - return msgpack_encode_default(obj) + sub_header, sub_frames = serialize_and_split( + obj, serializers=serializers, on_error=on_error, context=context + ) + _inplace_compress_frames(sub_header, sub_frames) + sub_header["num-sub-frames"] = len(sub_frames) + frames.append( + msgpack.dumps( + sub_header, default=msgpack_encode_default, use_bin_type=True + ) + ) + frames.extend(sub_frames) + return {"__Serialized__": offset, "callable": callable(obj)} frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True) return frames @@ -89,7 +96,7 @@ def _decode_default(obj): frames[offset], object_hook=msgpack_decode_default, use_list=False, - **msgpack_opts + **msgpack_opts, ) offset += 1 sub_frames = frames[offset : offset + sub_header["num-sub-frames"]] @@ -99,9 +106,21 @@ def _decode_default(obj): return merge_and_deserialize( sub_header, sub_frames, deserializers=deserializers ) + elif obj["callable"]: + return SerializedCallable(sub_header, sub_frames) else: return Serialized(sub_header, sub_frames) else: + # Notice, even though `msgpack_decode_default()` supports + # `__MsgpackList__`, we decode it here explicitly. This way + # we can delay the convertion to a regular `list` until it + # gets to a worker. + if "__MsgpackList__" in obj: + if deserialize: + return list(obj["as-tuple"]) + else: + return MsgpackList(obj["as-tuple"]) + return msgpack_decode_default(obj) return msgpack.loads( diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 50d87953111..c5d04706944 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -1,3 +1,4 @@ +import collections.abc import importlib import traceback from array import array @@ -7,7 +8,7 @@ import msgpack import dask -from dask.base import normalize_token +from dask.base import normalize_token, tokenize from ..utils import ensure_bytes, has_keyword, typename from . import pickle @@ -20,6 +21,8 @@ dask_deserialize = dask.utils.Dispatch("dask_deserialize") _cached_allowed_modules = {} +non_list_collection_types = (tuple, set, frozenset) +collection_types = (list,) + non_list_collection_types def dask_dumps(x, context=None): @@ -106,6 +109,48 @@ def import_allowed_module(name): ) +class MsgpackList(collections.abc.MutableSequence): + def __init__(self, x): + self.data = x + + def __getitem__(self, i): + return self.data[i] + + def __delitem__(self, i): + del self.data[i] + + def __setitem__(self, i, v): + self.data[i] = v + + def __len__(self): + return len(self.data) + + def insert(self, i, v): + return self.data.insert(i, v) + + +def msgpack_persist_lists(obj): + typ = type(obj) + if typ is list: + return MsgpackList([msgpack_persist_lists(o) for o in obj]) + elif typ in non_list_collection_types: + return typ(msgpack_persist_lists(o) for o in obj) + elif typ is dict: + return {k: msgpack_persist_lists(v) for k, v in obj.items()} + else: + return obj + + +def msgpack_unpersist_lists(obj): + typ = type(obj) + if typ is MsgpackList: + return list(msgpack_unpersist_lists(o) for o in obj.data) + elif typ in collection_types: + return typ(msgpack_unpersist_lists(o) for o in obj) + else: + return obj + + def msgpack_decode_default(obj): """ Custom packer/unpacker for msgpack @@ -116,15 +161,10 @@ def msgpack_decode_default(obj): return getattr(typ, obj["name"]) if "__Set__" in obj: - return set(obj["as-list"]) + return set(obj["as-tuple"]) - if "__Serialized__" in obj: - # Notice, the data here is marked a Serialized rather than deserialized. This - # is because deserialization requires Pickle which the Scheduler cannot run - # because of security reasons. - # By marking it Serialized, the data is passed through to the workers that - # eventually will deserialize it. - return Serialized(*obj["data"]) + if "__MsgpackList__" in obj: + return list(obj["as-tuple"]) return obj @@ -134,9 +174,6 @@ def msgpack_encode_default(obj): Custom packer/unpacker for msgpack """ - if isinstance(obj, Serialize): - return {"__Serialized__": True, "data": serialize(obj.data)} - if isinstance(obj, Enum): return { "__Enum__": True, @@ -146,14 +183,17 @@ def msgpack_encode_default(obj): } if isinstance(obj, set): - return {"__Set__": True, "as-list": list(obj)} + return {"__Set__": True, "as-tuple": tuple(obj)} + + if isinstance(obj, MsgpackList): + return {"__MsgpackList__": True, "as-tuple": tuple(obj.data)} return obj def msgpack_dumps(x): try: - frame = msgpack.dumps(x, use_bin_type=True) + frame = msgpack.dumps(x, default=msgpack_encode_default, use_bin_type=True) except Exception: raise NotImplementedError() else: @@ -161,7 +201,12 @@ def msgpack_dumps(x): def msgpack_loads(header, frames): - return msgpack.loads(b"".join(frames), use_list=False, **msgpack_opts) + return msgpack.loads( + b"".join(frames), + object_hook=msgpack_decode_default, + use_list=False, + **msgpack_opts, + ) def serialization_error_loads(header, frames): @@ -238,71 +283,18 @@ def serialize(x, serializers=None, on_error="message", context=None): if isinstance(x, Serialized): return x.header, x.frames - if type(x) in (list, set, tuple, dict): - iterate_collection = False - if type(x) is list and "msgpack" in serializers: - # Note: "msgpack" will always convert lists to tuples - # (see GitHub #3716), so we should iterate - # through the list if "msgpack" comes before "pickle" - # in the list of serializers. - iterate_collection = ("pickle" not in serializers) or ( - serializers.index("pickle") > serializers.index("msgpack") - ) - if not iterate_collection: - # Check for "dask"-serializable data in dict/list/set - iterate_collection = check_dask_serializable(x) - - # Determine whether keys are safe to be serialized with msgpack - if type(x) is dict and iterate_collection: - try: - msgpack.dumps(list(x.keys())) - except Exception: - dict_safe = False - else: - dict_safe = True - + # Note: "msgpack" will always convert lists to tuple (see GitHub #3716), + # so we should persist lists if "msgpack" comes before "pickle" + # in the list of serializers. if ( - type(x) in (list, set, tuple) - and iterate_collection - or type(x) is dict - and iterate_collection - and dict_safe + type(x) is list + and "msgpack" in serializers + and ( + "pickle" not in serializers + or serializers.index("pickle") > serializers.index("msgpack") + ) ): - if isinstance(x, dict): - headers_frames = [] - for k, v in x.items(): - _header, _frames = serialize( - v, serializers=serializers, on_error=on_error, context=context - ) - _header["key"] = k - headers_frames.append((_header, _frames)) - else: - headers_frames = [ - serialize( - obj, serializers=serializers, on_error=on_error, context=context - ) - for obj in x - ] - - frames = [] - lengths = [] - compressions = [] - for _header, _frames in headers_frames: - frames.extend(_frames) - length = len(_frames) - lengths.append(length) - compressions.extend(_header.get("compression") or [None] * len(_frames)) - - headers = [obj[0] for obj in headers_frames] - headers = { - "sub-headers": headers, - "is-collection": True, - "frame-lengths": lengths, - "type-serialized": type(x).__name__, - } - if any(compression is not None for compression in compressions): - headers["compression"] = compressions - return headers, frames + x = msgpack_persist_lists(x) tb = "" @@ -347,37 +339,6 @@ def deserialize(header, frames, deserializers=None): -------- serialize """ - if "is-collection" in header: - headers = header["sub-headers"] - lengths = header["frame-lengths"] - cls = {"tuple": tuple, "list": list, "set": set, "dict": dict}[ - header["type-serialized"] - ] - - start = 0 - if cls is dict: - d = {} - for _header, _length in zip(headers, lengths): - k = _header.pop("key") - d[k] = deserialize( - _header, - frames[start : start + _length], - deserializers=deserializers, - ) - start += _length - return d - else: - lst = [] - for _header, _length in zip(headers, lengths): - lst.append( - deserialize( - _header, - frames[start : start + _length], - deserializers=deserializers, - ) - ) - start += _length - return cls(lst) name = header.get("serializer") if deserializers is not None and name not in deserializers: @@ -511,6 +472,14 @@ def __eq__(self, other): def __ne__(self, other): return not (self == other) + def __hash__(self): + return hash(tokenize((self.header, self.frames))) + + +class SerializedCallable(Serialized): + def __call__(self) -> None: + raise NotImplementedError + def nested_deserialize(x): """ @@ -529,7 +498,7 @@ def replace_inner(x): typ = type(v) if typ is dict or typ is list: x[k] = replace_inner(v) - elif typ is Serialize: + elif typ is Serialize or typ is MsgpackList: x[k] = v.data elif typ is Serialized: x[k] = deserialize(v.header, v.frames) @@ -540,11 +509,14 @@ def replace_inner(x): typ = type(v) if typ is dict or typ is list: x[k] = replace_inner(v) - elif typ is Serialize: + elif typ is Serialize or typ is MsgpackList: x[k] = v.data elif typ is Serialized: x[k] = deserialize(v.header, v.frames) + elif type(x) is MsgpackList: + x = replace_inner(x.data) + return x return replace_inner(x) diff --git a/distributed/protocol/tests/test_collection.py b/distributed/protocol/tests/test_collection.py index 32d11a74755..f7db7c61216 100644 --- a/distributed/protocol/tests/test_collection.py +++ b/distributed/protocol/tests/test_collection.py @@ -25,13 +25,6 @@ def test_serialize_collection(collection, y, y_serializer): t = deserialize(header, frames, deserializers=("dask", "pickle", "error")) assert isinstance(t, collection) - assert header["is-collection"] is True - sub_headers = header["sub-headers"] - - if collection is not dict: - assert sub_headers[0]["serializer"] == "dask" - assert sub_headers[1]["serializer"] == y_serializer - if collection is dict: assert (t["x"] == x).all() assert str(t["y"]) == str(y) @@ -43,11 +36,3 @@ def test_serialize_collection(collection, y, y_serializer): def test_large_collections_serialize_simply(): header, frames = serialize(tuple(range(1000))) assert len(frames) == 1 - - -def test_nested_types(): - x = np.ones(5) - header, frames = serialize([[[x]]]) - assert "dask" in str(header) - assert len(frames) == 1 - assert x.data == np.frombuffer(frames[0]).data diff --git a/distributed/protocol/tests/test_collection_cuda.py b/distributed/protocol/tests/test_collection_cuda.py index a50fb7e2bb8..93431d1df5c 100644 --- a/distributed/protocol/tests/test_collection_cuda.py +++ b/distributed/protocol/tests/test_collection_cuda.py @@ -2,7 +2,7 @@ from dask.dataframe.utils import assert_eq -from distributed.protocol import deserialize, serialize +from distributed.protocol import dumps, loads @pytest.mark.parametrize("collection", [tuple, dict]) @@ -14,19 +14,13 @@ def test_serialize_cupy(collection, y, y_serializer): if y is not None: y = cupy.arange(y) if issubclass(collection, dict): - header, frames = serialize( - {"x": x, "y": y}, serializers=("cuda", "dask", "pickle") - ) + frames = dumps({"x": x, "y": y}, serializers=("cuda", "dask", "pickle")) else: - header, frames = serialize((x, y), serializers=("cuda", "dask", "pickle")) - t = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + frames = dumps((x, y), serializers=("cuda", "dask", "pickle")) - assert header["is-collection"] is True - sub_headers = header["sub-headers"] - assert sub_headers[0]["serializer"] == "cuda" - assert sub_headers[1]["serializer"] == y_serializer - assert isinstance(t, collection) + assert any(isinstance(f, cupy.ndarray) for f in frames) + t = loads(frames, deserializers=("cuda", "dask", "pickle", "error")) assert ((t["x"] if isinstance(t, dict) else t[0]) == x).all() if y is None: assert (t["y"] if isinstance(t, dict) else t[1]) is None @@ -46,19 +40,12 @@ def test_serialize_pandas_pandas(collection, df2, df2_serializer): if df2 is not None: df2 = cudf.from_pandas(pd.DataFrame(df2)) if issubclass(collection, dict): - header, frames = serialize( - {"df1": df1, "df2": df2}, serializers=("cuda", "dask", "pickle") - ) + frames = dumps({"df1": df1, "df2": df2}, serializers=("cuda", "dask", "pickle")) else: - header, frames = serialize((df1, df2), serializers=("cuda", "dask", "pickle")) - t = deserialize(header, frames, deserializers=("cuda", "dask", "pickle")) - - assert header["is-collection"] is True - sub_headers = header["sub-headers"] - assert sub_headers[0]["serializer"] == "cuda" - assert sub_headers[1]["serializer"] == df2_serializer - assert isinstance(t, collection) + frames = dumps((df1, df2), serializers=("cuda", "dask", "pickle")) + assert any(isinstance(f, cudf.core.buffer.Buffer) for f in frames) + t = loads(frames, deserializers=("cuda", "dask", "pickle")) assert_eq(t["df1"] if isinstance(t, dict) else t[0], df1) if df2 is None: assert (t["df2"] if isinstance(t, dict) else t[1]) is None diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index cfb767bc0cb..9c84e02cc46 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -19,6 +19,7 @@ Serialize, Serialized, dask_serialize, + dask_deserialize, deserialize, deserialize_bytes, dumps, @@ -419,12 +420,16 @@ class MyObj: @dask_serialize.register(MyObj) def _(x): - header = {"compression": [False]} + header = {"compression": (False,)} frames = [b""] return header, frames - header, frames = serialize([MyObj(), MyObj()]) - assert header["compression"] == [False, False] + @dask_deserialize.register(MyObj) + def _(header, frames): + assert header["compression"] == (False,) + + frames = dumps([MyObj(), MyObj()]) + loads(frames) @pytest.mark.parametrize( @@ -468,7 +473,7 @@ def test_check_dask_serializable(data, is_serializable): ) def test_serialize_lists(serializers): data_in = ["a", 2, "c", None, "e", 6] - header, frames = serialize(data_in, serializers=serializers) + header, frames = serialize(data_in, serializers=serializers, on_error="error") data_out = deserialize(header, frames) assert data_in == data_out diff --git a/distributed/worker.py b/distributed/worker.py index 9e79ef5ab6e..cfac3df42da 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -47,6 +47,7 @@ from .node import ServerNode from .proctitle import setproctitle from .protocol import deserialize_bytes, pickle, serialize_bytelist, to_serialize +from .protocol.serialize import msgpack_persist_lists from .pubsub import PubSubWorkerExtension from .security import Security from .sizeof import safe_sizeof as sizeof @@ -3556,6 +3557,7 @@ def dumps_task(task): >>> dumps_task(1) # doctest: +SKIP {'task': b'\x80\x04\x95\x03\x00\x00\x00\x00\x00\x00\x00K\x01.'} """ + return msgpack_persist_lists(task) if istask(task): if task[0] is apply and not any(map(_maybe_complex, task[2:])): d = {"function": dumps_function(task[1]), "args": warn_dumps(task[2])}