diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py
index b10bfcba18..0e0ae003b5 100644
--- a/distributed/protocol/core.py
+++ b/distributed/protocol/core.py
@@ -2,10 +2,15 @@
 
 import msgpack
 
+import dask.config
+
+from distributed.protocol import pickle
 from distributed.protocol.compression import decompress, maybe_compress
 from distributed.protocol.serialize import (
+    Pickled,
     Serialize,
     Serialized,
+    ToPickle,
     merge_and_deserialize,
     msgpack_decode_default,
     msgpack_encode_default,
@@ -16,6 +21,15 @@
 logger = logging.getLogger(__name__)
 
 
+def ensure_memoryview(obj):
+    """Ensure `obj` is a memoryview of bytes (format "B")"""
+    ret = memoryview(obj)
+    if ret.nbytes:
+        return ret.cast("B")
+    else:
+        return ret
+
+
 def dumps(
     msg, serializers=None, on_error="message", context=None, frame_split_size=None
 ) -> list:
@@ -45,31 +59,59 @@ def _inplace_compress_frames(header, frames):
             header["compression"] = tuple(compression)
 
+    def create_serialized_sub_frames(obj) -> list:
+        typ = type(obj)
+        if typ is Serialized:
+            sub_header, sub_frames = obj.header, obj.frames
+        else:
+            sub_header, sub_frames = serialize_and_split(
+                obj,
+                serializers=serializers,
+                on_error=on_error,
+                context=context,
+                size=frame_split_size,
+            )
+            _inplace_compress_frames(sub_header, sub_frames)
+        sub_header["num-sub-frames"] = len(sub_frames)
+        sub_header = msgpack.dumps(
+            sub_header, default=msgpack_encode_default, use_bin_type=True
+        )
+        return [sub_header] + sub_frames
+
+    def create_pickled_sub_frames(obj) -> list:
+        typ = type(obj)
+        if typ is Pickled:
+            sub_header, sub_frames = obj.header, obj.frames
+        else:
+            sub_frames = []
+            sub_header = {
+                "pickled-obj": pickle.dumps(
+                    obj.data,
+                    # In order to support len() and slicing, we convert
+                    # `PickleBuffer` objects to memoryviews of bytes.
+                    buffer_callback=lambda x: sub_frames.append(
+                        ensure_memoryview(x)
+                    ),
+                )
+            }
+            _inplace_compress_frames(sub_header, sub_frames)
+
+        sub_header["num-sub-frames"] = len(sub_frames)
+        sub_header = msgpack.dumps(sub_header)
+        return [sub_header] + sub_frames
+
     frames = [None]
 
     def _encode_default(obj):
         typ = type(obj)
         if typ is Serialize or typ is Serialized:
             offset = len(frames)
-            if typ is Serialized:
-                sub_header, sub_frames = obj.header, obj.frames
-            else:
-                sub_header, sub_frames = serialize_and_split(
-                    obj,
-                    serializers=serializers,
-                    on_error=on_error,
-                    context=context,
-                    size=frame_split_size,
-                )
-                _inplace_compress_frames(sub_header, sub_frames)
-            sub_header["num-sub-frames"] = len(sub_frames)
-            frames.append(
-                msgpack.dumps(
-                    sub_header, default=msgpack_encode_default, use_bin_type=True
-                )
-            )
-            frames.extend(sub_frames)
+            frames.extend(create_serialized_sub_frames(obj))
             return {"__Serialized__": offset}
+        elif typ is ToPickle or typ is Pickled:
+            offset = len(frames)
+            frames.extend(create_pickled_sub_frames(obj))
+            return {"__Pickled__": offset}
         else:
             return msgpack_encode_default(obj)
 
@@ -84,6 +126,8 @@ def _encode_default(obj):
 
 def loads(frames, deserialize=True, deserializers=None):
     """Transform bytestream back into Python value"""
+    allow_pickle = dask.config.get("distributed.scheduler.pickle")
+
     try:
 
         def _decode_default(obj):
@@ -105,8 +149,20 @@ def _decode_default(obj):
                     )
                 else:
                     return Serialized(sub_header, sub_frames)
-            else:
-                return msgpack_decode_default(obj)
+
+            offset = obj.get("__Pickled__", 0)
+            if offset > 0:
+                sub_header = msgpack.loads(frames[offset])
+                offset += 1
+                sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
+                if allow_pickle:
+                    return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames)
+                else:
+                    raise ValueError(
+                        "Unpickling on the scheduler isn't allowed; set `distributed.scheduler.pickle=true`"
+                    )
+
+            return msgpack_decode_default(obj)
 
         return msgpack.loads(
             frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py
index c8bfc9dce7..b4daf5bd65 100644
--- a/distributed/protocol/serialize.py
+++ b/distributed/protocol/serialize.py
@@ -522,8 +522,7 @@ def __hash__(self):
 
 class Serialized:
-    """
-    An object that is already serialized into header and frames
+    """An object that is already serialized into header and frames
 
     Normal serialization operations pass these objects through. This is
     typically used within the scheduler which accepts messages that contain
@@ -545,6 +544,54 @@ def __ne__(self, other):
         return not (self == other)
 
 
+class ToPickle:
+    """Mark an object that should be pickled
+
+    Both the scheduler and workers will automatically unpickle this
+    object on arrival.
+
+    Note that this requires the scheduler to be allowed to use pickle.
+    If the configuration option "distributed.scheduler.pickle" is set
+    to False, the scheduler will raise an exception instead.
+    """
+
+    def __init__(self, data):
+        self.data = data
+
+    def __repr__(self):
+        return "<ToPickle: %s>" % str(self.data)
+
+    def __eq__(self, other):
+        return isinstance(other, type(self)) and other.data == self.data
+
+    def __ne__(self, other):
+        return not (self == other)
+
+    def __hash__(self):
+        return hash(self.data)
+
+
+class Pickled:
+    """An object that is already pickled into header and frames
+
+    Normal pickled objects are unpickled by the scheduler.
+ """ + + def __init__(self, header, frames): + self.header = header + self.frames = frames + + def __eq__(self, other): + return ( + isinstance(other, type(self)) + and other.header == self.header + and other.frames == self.frames + ) + + def __ne__(self, other): + return not (self == other) + + def nested_deserialize(x): """ Replace all Serialize and Serialized values nested in *x* diff --git a/distributed/protocol/tests/test_to_pickle.py b/distributed/protocol/tests/test_to_pickle.py new file mode 100644 index 0000000000..7db7a5d973 --- /dev/null +++ b/distributed/protocol/tests/test_to_pickle.py @@ -0,0 +1,35 @@ +from typing import Dict + +import dask.config +from dask.highlevelgraph import HighLevelGraph, MaterializedLayer + +from distributed.client import Client +from distributed.protocol.serialize import ToPickle +from distributed.utils_test import gen_cluster + + +class NonMsgPackSerializableLayer(MaterializedLayer): + """Layer that uses non-msgpack-serializable data""" + + def __dask_distributed_pack__(self, *args, **kwargs): + ret = super().__dask_distributed_pack__(*args, **kwargs) + # Some info that contains a `list`, which msgpack will convert to + # a tuple if getting the chance. + ret["myinfo"] = ["myinfo"] + return ToPickle(ret) + + @classmethod + def __dask_distributed_unpack__(cls, state, *args, **kwargs): + assert state["myinfo"] == ["myinfo"] + return super().__dask_distributed_unpack__(state, *args, **kwargs) + + +@gen_cluster(client=True) +async def test_non_msgpack_serializable_layer(c: Client, s, w1, w2): + with dask.config.set({"distributed.scheduler.allowed-imports": "test_to_pickle"}): + a = NonMsgPackSerializableLayer({"x": 42}) + layers = {"a": a} + dependencies: Dict[str, set] = {"a": set()} + hg = HighLevelGraph(layers, dependencies) + res = await c.get(hg, "x", sync=False) + assert res == 42