Commit

tmp
madsbk committed Apr 13, 2021
1 parent 040292d commit 72a7d1b
Showing 4 changed files with 22 additions and 129 deletions.
94 changes: 5 additions & 89 deletions distributed/protocol/serialize.py
@@ -146,11 +146,14 @@ def msgpack_dumps(x):
     except Exception:
         raise NotImplementedError()
     else:
-        return {"serializer": "msgpack"}, [frame]
+        return {"serializer": "msgpack", "list-as-tuple": type(x) is list}, [frame]


 def msgpack_loads(header, frames):
-    return msgpack.loads(b"".join(frames), use_list=False, **msgpack_opts)
+    ret = msgpack.loads(b"".join(frames), use_list=False, **msgpack_opts)
+    if header["list-as-tuple"]:
+        ret = list(ret)
+    return ret


 def serialization_error_loads(header, frames):
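The new "list-as-tuple" flag exists because the receiving side decodes with use_list=False, so every msgpack sequence comes back as a tuple and the original list type is lost. A standalone sketch of the round-trip, outside the diff, with msgpack_opts omitted and illustrative names sketch_dumps/sketch_loads; as in the change above, only the top-level type is tracked:

    import msgpack

    def sketch_dumps(x):
        # Remember whether the caller passed a list, since msgpack alone
        # cannot distinguish lists from tuples.
        frame = msgpack.dumps(x, use_bin_type=True)
        return {"serializer": "msgpack", "list-as-tuple": type(x) is list}, [frame]

    def sketch_loads(header, frames):
        # use_list=False decodes every sequence as a tuple ...
        ret = msgpack.loads(b"".join(frames), use_list=False)
        if header["list-as-tuple"]:
            ret = list(ret)  # ... so convert back when the input was a list
        return ret

    header, frames = sketch_dumps([1, 2, 3])
    assert sketch_loads(header, frames) == [1, 2, 3]   # list round-trips as list
    header, frames = sketch_dumps((1, 2, 3))
    assert sketch_loads(header, frames) == (1, 2, 3)   # tuple stays a tuple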
@@ -228,7 +231,6 @@ def serialize(x, serializers=None, on_error="message", context=None):
         return x.header, x.frames

     if type(x) in (list, set, tuple, dict):
-        iterate_collection = False
         if type(x) is list and "msgpack" in serializers:
             # Note: "msgpack" will always convert lists to tuples
             # (see GitHub #3716), so we should iterate
@@ -237,61 +239,6 @@ def serialize(x, serializers=None, on_error="message", context=None):
             iterate_collection = ("pickle" not in serializers) or (
                 serializers.index("pickle") > serializers.index("msgpack")
             )
-        if not iterate_collection:
-            # Check for "dask"-serializable data in dict/list/set
-            iterate_collection = check_dask_serializable(x)
-
-        # Determine whether keys are safe to be serialized with msgpack
-        if type(x) is dict and iterate_collection:
-            try:
-                msgpack.dumps(list(x.keys()))
-            except Exception:
-                dict_safe = False
-            else:
-                dict_safe = True
-
-        if (
-            type(x) in (list, set, tuple)
-            and iterate_collection
-            or type(x) is dict
-            and iterate_collection
-            and dict_safe
-        ):
-            if isinstance(x, dict):
-                headers_frames = []
-                for k, v in x.items():
-                    _header, _frames = serialize(
-                        v, serializers=serializers, on_error=on_error, context=context
-                    )
-                    _header["key"] = k
-                    headers_frames.append((_header, _frames))
-            else:
-                headers_frames = [
-                    serialize(
-                        obj, serializers=serializers, on_error=on_error, context=context
-                    )
-                    for obj in x
-                ]
-
-            frames = []
-            lengths = []
-            compressions = []
-            for _header, _frames in headers_frames:
-                frames.extend(_frames)
-                length = len(_frames)
-                lengths.append(length)
-                compressions.extend(_header.get("compression") or [None] * len(_frames))
-
-            headers = [obj[0] for obj in headers_frames]
-            headers = {
-                "sub-headers": headers,
-                "is-collection": True,
-                "frame-lengths": lengths,
-                "type-serialized": type(x).__name__,
-            }
-            if any(compression is not None for compression in compressions):
-                headers["compression"] = compressions
-            return headers, frames

     tb = ""

@@ -336,37 +283,6 @@ def deserialize(header, frames, deserializers=None):
     --------
     serialize
     """
-    if "is-collection" in header:
-        headers = header["sub-headers"]
-        lengths = header["frame-lengths"]
-        cls = {"tuple": tuple, "list": list, "set": set, "dict": dict}[
-            header["type-serialized"]
-        ]
-
-        start = 0
-        if cls is dict:
-            d = {}
-            for _header, _length in zip(headers, lengths):
-                k = _header.pop("key")
-                d[k] = deserialize(
-                    _header,
-                    frames[start : start + _length],
-                    deserializers=deserializers,
-                )
-                start += _length
-            return d
-        else:
-            lst = []
-            for _header, _length in zip(headers, lengths):
-                lst.append(
-                    deserialize(
-                        _header,
-                        frames[start : start + _length],
-                        deserializers=deserializers,
-                    )
-                )
-                start += _length
-            return cls(lst)

     name = header.get("serializer")
     if deserializers is not None and name not in deserializers:
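The mirror image removed from deserialize(): "frame-lengths" is what let the receiving side slice the flat frame list back into one chunk per element before recursing. A minimal sketch of just that slicing, outside the diff (unpack is an illustrative name):

    def unpack(header, frames):
        # Walk the flat frame list, taking as many frames per element
        # as "frame-lengths" recorded on the sending side.
        start = 0
        chunks = []
        for length in header["frame-lengths"]:
            chunks.append(frames[start : start + length])
            start += length
        return chunks

    assert unpack({"frame-lengths": [1, 2]}, [b"a", b"b", b"c"]) == [[b"a"], [b"b", b"c"]]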
15 changes: 0 additions & 15 deletions distributed/protocol/tests/test_collection.py
@@ -25,13 +25,6 @@ def test_serialize_collection(collection, y, y_serializer):
     t = deserialize(header, frames, deserializers=("dask", "pickle", "error"))
     assert isinstance(t, collection)

-    assert header["is-collection"] is True
-    sub_headers = header["sub-headers"]
-
-    if collection is not dict:
-        assert sub_headers[0]["serializer"] == "dask"
-        assert sub_headers[1]["serializer"] == y_serializer
-
     if collection is dict:
         assert (t["x"] == x).all()
         assert str(t["y"]) == str(y)
@@ -43,11 +36,3 @@
 def test_large_collections_serialize_simply():
     header, frames = serialize(tuple(range(1000)))
     assert len(frames) == 1
-
-
-def test_nested_types():
-    x = np.ones(5)
-    header, frames = serialize([[[x]]])
-    assert "dask" in str(header)
-    assert len(frames) == 1
-    assert x.data == np.frombuffer(frames[0]).data
31 changes: 9 additions & 22 deletions distributed/protocol/tests/test_collection_cuda.py
@@ -2,7 +2,7 @@

 from dask.dataframe.utils import assert_eq

-from distributed.protocol import deserialize, serialize
+from distributed.protocol import dumps, loads


 @pytest.mark.parametrize("collection", [tuple, dict])
@@ -14,19 +14,13 @@ def test_serialize_cupy(collection, y, y_serializer):
     if y is not None:
         y = cupy.arange(y)
     if issubclass(collection, dict):
-        header, frames = serialize(
-            {"x": x, "y": y}, serializers=("cuda", "dask", "pickle")
-        )
+        frames = dumps({"x": x, "y": y}, serializers=("cuda", "dask", "pickle"))
     else:
-        header, frames = serialize((x, y), serializers=("cuda", "dask", "pickle"))
-    t = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error"))
+        frames = dumps((x, y), serializers=("cuda", "dask", "pickle"))

-    assert header["is-collection"] is True
-    sub_headers = header["sub-headers"]
-    assert sub_headers[0]["serializer"] == "cuda"
-    assert sub_headers[1]["serializer"] == y_serializer
-    assert isinstance(t, collection)
+    assert any(isinstance(f, cupy.ndarray) for f in frames)

+    t = loads(frames, deserializers=("cuda", "dask", "pickle", "error"))
     assert ((t["x"] if isinstance(t, dict) else t[0]) == x).all()
     if y is None:
         assert (t["y"] if isinstance(t, dict) else t[1]) is None
@@ -46,19 +40,12 @@ def test_serialize_pandas_pandas(collection, df2, df2_serializer):
     if df2 is not None:
         df2 = cudf.from_pandas(pd.DataFrame(df2))
     if issubclass(collection, dict):
-        header, frames = serialize(
-            {"df1": df1, "df2": df2}, serializers=("cuda", "dask", "pickle")
-        )
+        frames = dumps({"df1": df1, "df2": df2}, serializers=("cuda", "dask", "pickle"))
     else:
-        header, frames = serialize((df1, df2), serializers=("cuda", "dask", "pickle"))
-    t = deserialize(header, frames, deserializers=("cuda", "dask", "pickle"))
-
-    assert header["is-collection"] is True
-    sub_headers = header["sub-headers"]
-    assert sub_headers[0]["serializer"] == "cuda"
-    assert sub_headers[1]["serializer"] == df2_serializer
-    assert isinstance(t, collection)
+        frames = dumps((df1, df2), serializers=("cuda", "dask", "pickle"))
+    assert any(isinstance(f, cudf.core.buffer.Buffer) for f in frames)

+    t = loads(frames, deserializers=("cuda", "dask", "pickle"))
     assert_eq(t["df1"] if isinstance(t, dict) else t[0], df1)
     if df2 is None:
         assert (t["df2"] if isinstance(t, dict) else t[1]) is None
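These CUDA tests now exercise the full message path through distributed.protocol.dumps and loads instead of calling serialize/deserialize directly, so there is no separate header object left to inspect, only the list of frames. A minimal CPU-only usage sketch of that API, outside the diff; the message contents are made up, and to_serialize is the standard wrapper that marks a value for dask serialization:

    import numpy as np
    from distributed.protocol import dumps, loads, to_serialize

    msg = {"op": "update", "data": to_serialize(np.arange(5))}

    # dumps() returns a flat list of bytes-like frames; the header is
    # encoded into one of those frames rather than returned separately.
    frames = dumps(msg, serializers=("dask", "pickle"))
    assert isinstance(frames, list)

    out = loads(frames, deserializers=("dask", "pickle", "error"))
    assert (out["data"] == np.arange(5)).all()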
11 changes: 8 additions & 3 deletions distributed/protocol/tests/test_serialize.py
@@ -19,6 +19,7 @@
     Serialize,
     Serialized,
     dask_serialize,
+    dask_deserialize,
     deserialize,
     deserialize_bytes,
     dumps,
@@ -419,12 +420,16 @@ class MyObj:

     @dask_serialize.register(MyObj)
     def _(x):
-        header = {"compression": [False]}
+        header = {"compression": (False,)}
         frames = [b""]
         return header, frames

-    header, frames = serialize([MyObj(), MyObj()])
-    assert header["compression"] == [False, False]
+    @dask_deserialize.register(MyObj)
+    def _(header, frames):
+        assert header["compression"] == (False,)
+
+    frames = dumps([MyObj(), MyObj()])
+    loads(frames)


 @pytest.mark.parametrize(
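The switch from [False] to (False,) in the compression header is consistent with headers now traveling through msgpack on the dumps/loads path: as noted in serialize.py above (GitHub #3716), msgpack decoded with use_list=False turns lists into tuples, so the registered deserializer sees a tuple. A standalone msgpack-only illustration, outside the diff:

    import msgpack

    header = {"compression": [False]}
    roundtripped = msgpack.loads(
        msgpack.dumps(header, use_bin_type=True), use_list=False, raw=False
    )
    # The list-valued field comes back as a tuple, matching the
    # (False,) the test now expects.
    assert roundtripped["compression"] == (False,)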
