Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Apr 14, 2021
1 parent 6b69342 commit f8b7dc9
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 173 deletions.
61 changes: 40 additions & 21 deletions distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from .compression import decompress, maybe_compress
from .serialize import (
MsgpackList,
Serialize,
Serialized,
SerializedCallable,
merge_and_deserialize,
msgpack_decode_default,
msgpack_encode_default,
Expand Down Expand Up @@ -47,27 +49,32 @@ def _inplace_compress_frames(header, frames):

def _encode_default(obj):
typ = type(obj)
if typ is Serialize or typ is Serialized:
if typ is Serialize:
obj = obj.data
offset = len(frames)
if typ is Serialized:
sub_header, sub_frames = obj.header, obj.frames
else:
sub_header, sub_frames = serialize_and_split(
obj, serializers=serializers, on_error=on_error, context=context
)
_inplace_compress_frames(sub_header, sub_frames)
sub_header["num-sub-frames"] = len(sub_frames)
frames.append(
msgpack.dumps(
sub_header, default=msgpack_encode_default, use_bin_type=True
)
)
frames.extend(sub_frames)
return {"__Serialized__": offset}

ret = msgpack_encode_default(obj)
if ret is not obj:
return ret

print(f"msgpack - fallback: {type(obj)}")

if typ is Serialize:
obj = obj.data # TODO: remove Serialize/to_serialize completely

offset = len(frames)
if typ is Serialized:
sub_header, sub_frames = obj.header, obj.frames
else:
return msgpack_encode_default(obj)
sub_header, sub_frames = serialize_and_split(
obj, serializers=serializers, on_error=on_error, context=context
)
_inplace_compress_frames(sub_header, sub_frames)
sub_header["num-sub-frames"] = len(sub_frames)
frames.append(
msgpack.dumps(
sub_header, default=msgpack_encode_default, use_bin_type=True
)
)
frames.extend(sub_frames)
return {"__Serialized__": offset, "callable": callable(obj)}

frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True)
return frames
Expand All @@ -89,7 +96,7 @@ def _decode_default(obj):
frames[offset],
object_hook=msgpack_decode_default,
use_list=False,
**msgpack_opts
**msgpack_opts,
)
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
Expand All @@ -99,9 +106,21 @@ def _decode_default(obj):
return merge_and_deserialize(
sub_header, sub_frames, deserializers=deserializers
)
elif obj["callable"]:
return SerializedCallable(sub_header, sub_frames)
else:
return Serialized(sub_header, sub_frames)
else:
# Notice, even though `msgpack_decode_default()` supports
# `__MsgpackList__`, we decode it here explicitly. This way
# we can delay the conversion to a regular `list` until it
# gets to a worker.
if "__MsgpackList__" in obj:
if deserialize:
return list(obj["as-tuple"])
else:
return MsgpackList(obj["as-tuple"])

return msgpack_decode_default(obj)

return msgpack.loads(
Expand Down
194 changes: 83 additions & 111 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import importlib
import traceback
from array import array
Expand All @@ -7,7 +8,7 @@
import msgpack

import dask
from dask.base import normalize_token
from dask.base import normalize_token, tokenize

from ..utils import ensure_bytes, has_keyword, typename
from . import pickle
Expand All @@ -20,6 +21,8 @@
dask_deserialize = dask.utils.Dispatch("dask_deserialize")

_cached_allowed_modules = {}
non_list_collection_types = (tuple, set, frozenset)
collection_types = (list,) + non_list_collection_types


def dask_dumps(x, context=None):
Expand Down Expand Up @@ -106,6 +109,48 @@ def import_allowed_module(name):
)


class MsgpackList(collections.abc.MutableSequence):
    """Thin wrapper used to tag a sequence that should be treated as a ``list``.

    ``msgpack_persist_lists`` wraps lists in this class and
    ``msgpack_encode_default`` encodes instances under the
    ``"__MsgpackList__"`` key.

    The wrapped object is stored as-is (not copied) in ``self.data`` and every
    sequence operation is delegated to it. Indexing therefore returns whatever
    the wrapped object returns, and the mutating operations (``__setitem__``,
    ``__delitem__``, ``insert``) only work if the wrapped object supports them.

    Parameters
    ----------
    x
        The sequence to wrap.
    """

    def __init__(self, x):
        self.data = x

    def __repr__(self):
        # Show the wrapped data instead of the default opaque
        # ``<MsgpackList object at 0x...>`` to make debugging easier.
        return f"{type(self).__name__}({self.data!r})"

    def __getitem__(self, i):
        return self.data[i]

    def __delitem__(self, i):
        del self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __len__(self):
        return len(self.data)

    def insert(self, i, v):
        """Insert ``v`` before index ``i`` in the wrapped sequence."""
        return self.data.insert(i, v)


def msgpack_persist_lists(obj):
    """Recursively wrap every ``list`` found in ``obj`` in a ``MsgpackList``.

    Only objects whose exact type is ``list``, ``dict`` or one of
    ``non_list_collection_types`` are traversed; anything else (including
    subclasses of those types) is returned unchanged. For dicts only the
    values are processed, the keys are kept as they are.
    """
    kind = type(obj)

    if kind is list:
        wrapped_items = [msgpack_persist_lists(item) for item in obj]
        return MsgpackList(wrapped_items)

    if kind in non_list_collection_types:
        # Rebuild the collection with the same concrete type.
        return kind(msgpack_persist_lists(item) for item in obj)

    if kind is dict:
        return {key: msgpack_persist_lists(val) for key, val in obj.items()}

    return obj


def msgpack_unpersist_lists(obj):
    """Recursively replace ``MsgpackList`` wrappers in ``obj`` with plain lists.

    This undoes ``msgpack_persist_lists``. Objects whose exact type is
    ``MsgpackList``, ``dict`` or one of ``collection_types`` are traversed;
    anything else is returned unchanged. For dicts only the values are
    processed, the keys are kept as they are.
    """
    typ = type(obj)
    if typ is MsgpackList:
        return [msgpack_unpersist_lists(o) for o in obj.data]
    if typ in collection_types:
        # Rebuild the collection with the same concrete type.
        return typ(msgpack_unpersist_lists(o) for o in obj)
    if typ is dict:
        # msgpack_persist_lists() wraps lists stored as dict values, so they
        # have to be unwrapped here as well to make the two functions inverse
        # of each other.
        return {k: msgpack_unpersist_lists(v) for k, v in obj.items()}
    return obj


def msgpack_decode_default(obj):
"""
Custom packer/unpacker for msgpack
Expand All @@ -116,15 +161,10 @@ def msgpack_decode_default(obj):
return getattr(typ, obj["name"])

if "__Set__" in obj:
return set(obj["as-list"])
return set(obj["as-tuple"])

if "__Serialized__" in obj:
# Notice, the data here is marked as Serialized rather than deserialized. This
# is because deserialization requires Pickle which the Scheduler cannot run
# because of security reasons.
# By marking it Serialized, the data is passed through to the workers that
# eventually will deserialize it.
return Serialized(*obj["data"])
if "__MsgpackList__" in obj:
return list(obj["as-tuple"])

return obj

Expand All @@ -134,9 +174,6 @@ def msgpack_encode_default(obj):
Custom packer/unpacker for msgpack
"""

if isinstance(obj, Serialize):
return {"__Serialized__": True, "data": serialize(obj.data)}

if isinstance(obj, Enum):
return {
"__Enum__": True,
Expand All @@ -146,22 +183,30 @@ def msgpack_encode_default(obj):
}

if isinstance(obj, set):
return {"__Set__": True, "as-list": list(obj)}
return {"__Set__": True, "as-tuple": tuple(obj)}

if isinstance(obj, MsgpackList):
return {"__MsgpackList__": True, "as-tuple": tuple(obj.data)}

return obj


def msgpack_dumps(x):
try:
frame = msgpack.dumps(x, use_bin_type=True)
frame = msgpack.dumps(x, default=msgpack_encode_default, use_bin_type=True)
except Exception:
raise NotImplementedError()
else:
return {"serializer": "msgpack"}, [frame]


def msgpack_loads(header, frames):
return msgpack.loads(b"".join(frames), use_list=False, **msgpack_opts)
return msgpack.loads(
b"".join(frames),
object_hook=msgpack_decode_default,
use_list=False,
**msgpack_opts,
)


def serialization_error_loads(header, frames):
Expand Down Expand Up @@ -238,71 +283,18 @@ def serialize(x, serializers=None, on_error="message", context=None):
if isinstance(x, Serialized):
return x.header, x.frames

if type(x) in (list, set, tuple, dict):
iterate_collection = False
if type(x) is list and "msgpack" in serializers:
# Note: "msgpack" will always convert lists to tuples
# (see GitHub #3716), so we should iterate
# through the list if "msgpack" comes before "pickle"
# in the list of serializers.
iterate_collection = ("pickle" not in serializers) or (
serializers.index("pickle") > serializers.index("msgpack")
)
if not iterate_collection:
# Check for "dask"-serializable data in dict/list/set
iterate_collection = check_dask_serializable(x)

# Determine whether keys are safe to be serialized with msgpack
if type(x) is dict and iterate_collection:
try:
msgpack.dumps(list(x.keys()))
except Exception:
dict_safe = False
else:
dict_safe = True

# Note: "msgpack" will always convert lists to tuple (see GitHub #3716),
# so we should persist lists if "msgpack" comes before "pickle"
# in the list of serializers.
if (
type(x) in (list, set, tuple)
and iterate_collection
or type(x) is dict
and iterate_collection
and dict_safe
type(x) is list
and "msgpack" in serializers
and (
"pickle" not in serializers
or serializers.index("pickle") > serializers.index("msgpack")
)
):
if isinstance(x, dict):
headers_frames = []
for k, v in x.items():
_header, _frames = serialize(
v, serializers=serializers, on_error=on_error, context=context
)
_header["key"] = k
headers_frames.append((_header, _frames))
else:
headers_frames = [
serialize(
obj, serializers=serializers, on_error=on_error, context=context
)
for obj in x
]

frames = []
lengths = []
compressions = []
for _header, _frames in headers_frames:
frames.extend(_frames)
length = len(_frames)
lengths.append(length)
compressions.extend(_header.get("compression") or [None] * len(_frames))

headers = [obj[0] for obj in headers_frames]
headers = {
"sub-headers": headers,
"is-collection": True,
"frame-lengths": lengths,
"type-serialized": type(x).__name__,
}
if any(compression is not None for compression in compressions):
headers["compression"] = compressions
return headers, frames
x = msgpack_persist_lists(x)

tb = ""

Expand Down Expand Up @@ -347,37 +339,6 @@ def deserialize(header, frames, deserializers=None):
--------
serialize
"""
if "is-collection" in header:
headers = header["sub-headers"]
lengths = header["frame-lengths"]
cls = {"tuple": tuple, "list": list, "set": set, "dict": dict}[
header["type-serialized"]
]

start = 0
if cls is dict:
d = {}
for _header, _length in zip(headers, lengths):
k = _header.pop("key")
d[k] = deserialize(
_header,
frames[start : start + _length],
deserializers=deserializers,
)
start += _length
return d
else:
lst = []
for _header, _length in zip(headers, lengths):
lst.append(
deserialize(
_header,
frames[start : start + _length],
deserializers=deserializers,
)
)
start += _length
return cls(lst)

name = header.get("serializer")
if deserializers is not None and name not in deserializers:
Expand Down Expand Up @@ -511,6 +472,14 @@ def __eq__(self, other):
def __ne__(self, other):
    """Return the logical negation of ``self == other``."""
    is_equal = self == other
    return not is_equal

def __hash__(self):
    """Hash the instance through the dask token of ``(header, frames)``."""
    token = tokenize((self.header, self.frames))
    return hash(token)


class SerializedCallable(Serialized):
    """``Serialized`` subclass that defines ``__call__``.

    Because ``__call__`` is defined, ``callable()`` returns ``True`` for
    instances. Actually calling an instance is not supported.
    """

    def __call__(self) -> None:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError()


def nested_deserialize(x):
"""
Expand All @@ -529,7 +498,7 @@ def replace_inner(x):
typ = type(v)
if typ is dict or typ is list:
x[k] = replace_inner(v)
elif typ is Serialize:
elif typ is Serialize or typ is MsgpackList:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)
Expand All @@ -540,11 +509,14 @@ def replace_inner(x):
typ = type(v)
if typ is dict or typ is list:
x[k] = replace_inner(v)
elif typ is Serialize:
elif typ is Serialize or typ is MsgpackList:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)

elif type(x) is MsgpackList:
x = replace_inner(x.data)

return x

return replace_inner(x)
Expand Down
Loading

0 comments on commit f8b7dc9

Please sign in to comment.