[REVIEW] ToPickle - Unpickle on the Scheduler #5728

Merged: 13 commits, Mar 21, 2022
96 changes: 76 additions & 20 deletions distributed/protocol/core.py
@@ -2,10 +2,15 @@

import msgpack

import dask.config

from . import pickle
from .compression import decompress, maybe_compress
from .serialize import (
Pickled,
Serialize,
Serialized,
ToPickle,
merge_and_deserialize,
msgpack_decode_default,
msgpack_encode_default,
@@ -16,6 +21,15 @@
logger = logging.getLogger(__name__)


def ensure_memoryview(obj):
"""Ensure `obj` is a memoryview of datatype bytes"""
ret = memoryview(obj)
if ret.nbytes:
return ret.cast("B")
else:
return ret
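
As an aside, a minimal sketch of what `ensure_memoryview` normalizes (a standalone illustration, not part of the diff):

import array

# 1-D views over multi-byte items are flattened to unsigned bytes ("B"):
v = ensure_memoryview(array.array("d", [1.0, 2.0]))
assert v.format == "B" and v.nbytes == 16

# Plain bytes-like objects come back as byte views:
v = ensure_memoryview(bytearray(b"abc"))
assert v.format == "B" and v.nbytes == 3

# Empty buffers are returned uncast:
v = ensure_memoryview(b"")
assert v.nbytes == 0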


def dumps(
msg, serializers=None, on_error="message", context=None, frame_split_size=None
) -> list:
@@ -45,31 +59,59 @@ def _inplace_compress_frames(header, frames):

header["compression"] = tuple(compression)

def create_sub_frames(obj) -> list:
typ = type(obj)
if typ is Serialized:
sub_header, sub_frames = obj.header, obj.frames
else:
sub_header, sub_frames = serialize_and_split(
obj,
serializers=serializers,
on_error=on_error,
context=context,
size=frame_split_size,
)
_inplace_compress_frames(sub_header, sub_frames)
sub_header["num-sub-frames"] = len(sub_frames)
sub_header = msgpack.dumps(
sub_header, default=msgpack_encode_default, use_bin_type=True
)
return [sub_header] + sub_frames

def create_pickle_sub_frames(obj) -> list:
typ = type(obj)
if typ is Pickled:
sub_header, sub_frames = obj.header, obj.frames
else:
sub_frames = []
sub_header = {
"pickled-obj": pickle.dumps(
obj.data,
                    # To support len() and slicing, we convert `PickleBuffer`
                    # objects to memoryviews of bytes.
buffer_callback=lambda x: sub_frames.append(
ensure_memoryview(x)
),
)
}
_inplace_compress_frames(sub_header, sub_frames)

sub_header["num-sub-frames"] = len(sub_frames)
sub_header = msgpack.dumps(sub_header)
return [sub_header] + sub_frames

frames = [None]

def _encode_default(obj):
typ = type(obj)
if typ is Serialize or typ is Serialized:
offset = len(frames)
frames.extend(create_sub_frames(obj))
return {"__Serialized__": offset}
elif typ is ToPickle or typ is Pickled:
offset = len(frames)
frames.extend(create_pickle_sub_frames(obj))
return {"__Pickled__": offset}
else:
return msgpack_encode_default(obj)
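
The `buffer_callback` passed to `pickle.dumps` in `create_pickle_sub_frames` is pickle protocol 5's hook for out-of-band buffers (PEP 574). A self-contained sketch of that pattern using stdlib `pickle` and a toy class (`Blob` is illustrative, not part of the PR):

import pickle

class Blob:
    """Toy object whose payload can travel out-of-band under protocol 5."""

    def __init__(self, data):
        self.data = bytearray(data)

    def __reduce_ex__(self, protocol):
        if protocol >= 5:
            return Blob, (pickle.PickleBuffer(self.data),)
        return Blob, (bytes(self.data),)

buffers = []
payload = pickle.dumps(
    Blob(b"x" * 100),
    protocol=5,
    buffer_callback=buffers.append,  # collects PickleBuffer objects
)
# `payload` holds only the object structure; the 100-byte body sits in
# `buffers` and must be handed back at load time:
restored = pickle.loads(payload, buffers=buffers)
assert bytes(restored.data) == b"x" * 100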

@@ -84,6 +126,8 @@ def _encode_default(obj):
def loads(frames, deserialize=True, deserializers=None):
"""Transform bytestream back into Python value"""

allow_pickle = dask.config.get("distributed.scheduler.pickle")

try:

def _decode_default(obj):
@@ -105,8 +149,20 @@ def _decode_default(obj):
)
else:
return Serialized(sub_header, sub_frames)

offset = obj.get("__Pickled__", 0)
if offset > 0:
sub_header = msgpack.loads(frames[offset])
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
if allow_pickle:
return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames)
else:
raise ValueError(
"Unpickle on the Scheduler isn't allowed, set `distributed.scheduler.pickle=true`"
)

return msgpack_decode_default(obj)

return msgpack.loads(
frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
    )
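
End to end, the new `__Pickled__` path round-trips like this (a sketch against this PR's `dumps`/`loads`; it assumes the default `distributed.scheduler.pickle: True` so the unpickle is allowed):

from distributed.protocol.core import dumps, loads
from distributed.protocol.serialize import ToPickle

# Wrap a payload msgpack would mangle (nested lists become tuples under
# use_list=False) so it travels as pickle instead:
msg = {"op": "update-graph", "extra": ToPickle(["kept", ["as", "lists"]])}

frames = dumps(msg)
# frames[0] holds the top-level msgpack message with a
# {"__Pickled__": offset} marker; frames[offset] is the pickle sub-header.

out = loads(frames)
assert out["extra"] == ["kept", ["as", "lists"]]  # unpickled transparently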
47 changes: 45 additions & 2 deletions distributed/protocol/serialize.py
@@ -522,8 +522,7 @@ def __hash__(self):


class Serialized:
"""
An object that is already serialized into header and frames
"""An object that is already serialized into header and frames

Normal serialization operations pass these objects through. This is
typically used within the scheduler which accepts messages that contain
@@ -545,6 +544,50 @@ def __ne__(self, other):
return not (self == other)


class ToPickle:
"""Mark an object that should be pickled

    Both the scheduler and workers will automatically unpickle this
    object on arrival.
Member: Should we add a note about the "distributed.scheduler.pickle" config here?

Contributor Author (madsbk): Added:

    """Mark an object that should be pickled

    Both the scheduler and workers with automatically unpickle this
    object on arrival.

    Notice, this requires that the scheduler is allowed to use pickle.
    If the configuration option "distributed.scheduler.pickle" is set
    to False, the scheduler will raise an exception instead.
    """

"""

def __init__(self, data):
self.data = data

def __repr__(self):
return "<ToPickle: %s>" % str(self.data)

def __eq__(self, other):
return isinstance(other, type(self)) and other.data == self.data

def __ne__(self, other):
return not (self == other)

def __hash__(self):
return hash(self.data)


class Pickled:
"""An object that is already pickled into header and frames

Normal pickled objects are unpickled by the scheduler.
"""

def __init__(self, header, frames):
self.header = header
self.frames = frames

def __eq__(self, other):
return (
isinstance(other, type(self))
and other.header == self.header
and other.frames == self.frames
)

def __ne__(self, other):
return not (self == other)


def nested_deserialize(x):
"""
Replace all Serialize and Serialized values nested in *x*
35 changes: 35 additions & 0 deletions distributed/protocol/tests/test_to_pickle.py
@@ -0,0 +1,35 @@
from typing import Dict

import dask.config
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer

from distributed.client import Client
from distributed.protocol.serialize import ToPickle
from distributed.utils_test import gen_cluster


class NonMsgPackSerializableLayer(MaterializedLayer):
"""Layer that use non-msgpack-serializable data"""

def __dask_distributed_pack__(self, *args, **kwargs):
ret = super().__dask_distributed_pack__(*args, **kwargs)
        # Some info that contains a `list`, which msgpack will convert
        # to a tuple given the chance.
ret["myinfo"] = ["myinfo"]
return ToPickle(ret)

@classmethod
def __dask_distributed_unpack__(cls, state, *args, **kwargs):
assert state["myinfo"] == ["myinfo"]
return super().__dask_distributed_unpack__(state, *args, **kwargs)


@gen_cluster(client=True)
async def test_non_msgpack_serializable_layer(c: Client, s, w1, w2):
with dask.config.set({"distributed.scheduler.allowed-imports": "test_to_pickle"}):
a = NonMsgPackSerializableLayer({"x": 42})
layers = {"a": a}
dependencies: Dict[str, set] = {"a": set()}
hg = HighLevelGraph(layers, dependencies)
res = await c.get(hg, "x", sync=False)
assert res == 42
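
The list-versus-tuple wrinkle this test guards against is easy to reproduce with msgpack alone; `use_list=False` below mirrors how distributed decodes messages (standalone illustration):

import msgpack

packed = msgpack.dumps({"myinfo": ["myinfo"]})
unpacked = msgpack.loads(packed, use_list=False)
assert unpacked["myinfo"] == ("myinfo",)  # the list came back as a tuple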