From 176ed1546f941bb658f6e039833c2e3349161030 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Thu, 15 Apr 2021 20:03:26 +0200 Subject: [PATCH 01/21] Implementing single pass serialization --- distributed/protocol/core.py | 81 +++++-- distributed/protocol/serialize.py | 228 +++++++----------- distributed/protocol/tests/test_collection.py | 15 -- .../protocol/tests/test_collection_cuda.py | 31 +-- distributed/protocol/tests/test_serialize.py | 13 +- distributed/tests/test_client.py | 23 +- distributed/tests/test_utils.py | 10 - distributed/utils.py | 12 - distributed/worker.py | 61 ++++- 9 files changed, 226 insertions(+), 248 deletions(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 0d87595bf7..8b8835afa8 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -4,8 +4,11 @@ from .compression import decompress, maybe_compress from .serialize import ( + MsgpackList, Serialize, Serialized, + SerializedCallable, + TaskGraphValue, merge_and_deserialize, msgpack_decode_default, msgpack_encode_default, @@ -47,27 +50,30 @@ def _inplace_compress_frames(header, frames): def _encode_default(obj): typ = type(obj) - if typ is Serialize or typ is Serialized: - if typ is Serialize: - obj = obj.data - offset = len(frames) - if typ is Serialized: - sub_header, sub_frames = obj.header, obj.frames - else: - sub_header, sub_frames = serialize_and_split( - obj, serializers=serializers, on_error=on_error, context=context - ) - _inplace_compress_frames(sub_header, sub_frames) - sub_header["num-sub-frames"] = len(sub_frames) - frames.append( - msgpack.dumps( - sub_header, default=msgpack_encode_default, use_bin_type=True - ) - ) - frames.extend(sub_frames) - return {"__Serialized__": offset} + + ret = msgpack_encode_default(obj) + if ret is not obj: + return ret + + if typ is Serialize: + obj = obj.data # TODO: remove Serialize/to_serialize completely + + offset = len(frames) + if typ is Serialized: + sub_header, sub_frames = obj.header, obj.frames else: - return msgpack_encode_default(obj) + sub_header, sub_frames = serialize_and_split( + obj, serializers=serializers, on_error=on_error, context=context + ) + _inplace_compress_frames(sub_header, sub_frames) + sub_header["num-sub-frames"] = len(sub_frames) + frames.append( + msgpack.dumps( + sub_header, default=msgpack_encode_default, use_bin_type=True + ) + ) + frames.extend(sub_frames) + return {"__Serialized__": offset, "callable": callable(obj)} frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True) return frames @@ -77,6 +83,11 @@ def _encode_default(obj): raise +class DelayedExceptionRaise: + def __init__(self, err): + self.err = err + + def loads(frames, deserialize=True, deserializers=None): """ Transform bytestream back into Python value """ @@ -89,19 +100,41 @@ def _decode_default(obj): frames[offset], object_hook=msgpack_decode_default, use_list=False, - **msgpack_opts + **msgpack_opts, ) offset += 1 sub_frames = frames[offset : offset + sub_header["num-sub-frames"]] if deserialize: if "compression" in sub_header: sub_frames = decompress(sub_header, sub_frames) - return merge_and_deserialize( - sub_header, sub_frames, deserializers=deserializers - ) + try: + return merge_and_deserialize( + sub_header, sub_frames, deserializers=deserializers + ) + except Exception as e: + if deserialize == "delay-exception": + return DelayedExceptionRaise(e) + else: + raise + elif obj["callable"]: + return SerializedCallable(sub_header, sub_frames) else: return 
Serialized(sub_header, sub_frames) else: + # Notice, even though `msgpack_decode_default()` supports + # `__MsgpackList__`, we decode it here explicitly. This way + # we can delay the conversion to a regular `list` until it + # gets to a worker. + if "__MsgpackList__" in obj: + if deserialize: + return list(obj["as-tuple"]) + else: + return MsgpackList(obj["as-tuple"]) + if "__TaskGraphValue__" in obj: + if deserialize: + return obj["data"] + else: + return TaskGraphValue(obj["data"]) return msgpack_decode_default(obj) return msgpack.loads( diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 50d8795311..8b52b68f99 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -1,3 +1,4 @@ +import collections.abc import importlib import traceback from array import array @@ -7,7 +8,7 @@ import msgpack import dask -from dask.base import normalize_token +from dask.base import normalize_token, tokenize from ..utils import ensure_bytes, has_keyword, typename from . import pickle @@ -20,6 +21,8 @@ dask_deserialize = dask.utils.Dispatch("dask_deserialize") _cached_allowed_modules = {} +non_list_collection_types = (tuple, set, frozenset) +collection_types = (list,) + non_list_collection_types def dask_dumps(x, context=None): @@ -106,6 +109,42 @@ def import_allowed_module(name): ) +class MsgpackList(collections.abc.MutableSequence): + def __init__(self, x): + self.data = x + + def __getitem__(self, i): + return self.data[i] + + def __delitem__(self, i): + del self.data[i] + + def __setitem__(self, i, v): + self.data[i] = v + + def __len__(self): + return len(self.data) + + def insert(self, i, v): + return self.data.insert(i, v) + + +class TaskGraphValue: + def __init__(self, x): + self.data = x + + +def msgpack_persist_lists(obj): + typ = type(obj) + if typ is list: + return MsgpackList([msgpack_persist_lists(o) for o in obj]) + if typ in non_list_collection_types: + return typ(msgpack_persist_lists(o) for o in obj) + if typ is dict: + return {k: msgpack_persist_lists(v) for k, v in obj.items()} + return obj + + def msgpack_decode_default(obj): """ Custom packer/unpacker for msgpack @@ -116,15 +155,13 @@ def msgpack_decode_default(obj): return getattr(typ, obj["name"]) if "__Set__" in obj: - return set(obj["as-list"]) + return set(obj["as-tuple"]) - if "__Serialized__" in obj: - # Notice, the data here is marked a Serialized rather than deserialized. This - # is because deserialization requires Pickle which the Scheduler cannot run - # because of security reasons. - # By marking it Serialized, the data is passed through to the workers that - # eventually will deserialize it. 
- return Serialized(*obj["data"]) + if "__MsgpackList__" in obj: + return list(obj["as-tuple"]) + + if "__TaskGraphValue__" in obj: + return obj["data"] return obj @@ -133,9 +170,7 @@ def msgpack_encode_default(obj): """ Custom packer/unpacker for msgpack """ - - if isinstance(obj, Serialize): - return {"__Serialized__": True, "data": serialize(obj.data)} + typ = type(obj) if isinstance(obj, Enum): return { @@ -146,14 +181,20 @@ def msgpack_encode_default(obj): } if isinstance(obj, set): - return {"__Set__": True, "as-list": list(obj)} + return {"__Set__": True, "as-tuple": tuple(obj)} + + if typ is MsgpackList: + return {"__MsgpackList__": True, "as-tuple": tuple(obj.data)} + + if typ is TaskGraphValue: + return {"__TaskGraphValue__": True, "data": obj.data} return obj def msgpack_dumps(x): try: - frame = msgpack.dumps(x, use_bin_type=True) + frame = msgpack.dumps(x, default=msgpack_encode_default, use_bin_type=True) except Exception: raise NotImplementedError() else: @@ -161,7 +202,12 @@ def msgpack_dumps(x): def msgpack_loads(header, frames): - return msgpack.loads(b"".join(frames), use_list=False, **msgpack_opts) + return msgpack.loads( + b"".join(frames), + object_hook=msgpack_decode_default, + use_list=False, + **msgpack_opts, + ) def serialization_error_loads(header, frames): @@ -238,71 +284,18 @@ def serialize(x, serializers=None, on_error="message", context=None): if isinstance(x, Serialized): return x.header, x.frames - if type(x) in (list, set, tuple, dict): - iterate_collection = False - if type(x) is list and "msgpack" in serializers: - # Note: "msgpack" will always convert lists to tuples - # (see GitHub #3716), so we should iterate - # through the list if "msgpack" comes before "pickle" - # in the list of serializers. - iterate_collection = ("pickle" not in serializers) or ( - serializers.index("pickle") > serializers.index("msgpack") - ) - if not iterate_collection: - # Check for "dask"-serializable data in dict/list/set - iterate_collection = check_dask_serializable(x) - - # Determine whether keys are safe to be serialized with msgpack - if type(x) is dict and iterate_collection: - try: - msgpack.dumps(list(x.keys())) - except Exception: - dict_safe = False - else: - dict_safe = True - + # Note: "msgpack" will always convert lists to tuple (see GitHub #3716), + # so we should persist lists if "msgpack" comes before "pickle" + # in the list of serializers. 
if ( - type(x) in (list, set, tuple) - and iterate_collection - or type(x) is dict - and iterate_collection - and dict_safe + type(x) is list + and "msgpack" in serializers + and ( + "pickle" not in serializers + or serializers.index("pickle") > serializers.index("msgpack") + ) ): - if isinstance(x, dict): - headers_frames = [] - for k, v in x.items(): - _header, _frames = serialize( - v, serializers=serializers, on_error=on_error, context=context - ) - _header["key"] = k - headers_frames.append((_header, _frames)) - else: - headers_frames = [ - serialize( - obj, serializers=serializers, on_error=on_error, context=context - ) - for obj in x - ] - - frames = [] - lengths = [] - compressions = [] - for _header, _frames in headers_frames: - frames.extend(_frames) - length = len(_frames) - lengths.append(length) - compressions.extend(_header.get("compression") or [None] * len(_frames)) - - headers = [obj[0] for obj in headers_frames] - headers = { - "sub-headers": headers, - "is-collection": True, - "frame-lengths": lengths, - "type-serialized": type(x).__name__, - } - if any(compression is not None for compression in compressions): - headers["compression"] = compressions - return headers, frames + x = msgpack_persist_lists(x) tb = "" @@ -347,37 +340,6 @@ def deserialize(header, frames, deserializers=None): -------- serialize """ - if "is-collection" in header: - headers = header["sub-headers"] - lengths = header["frame-lengths"] - cls = {"tuple": tuple, "list": list, "set": set, "dict": dict}[ - header["type-serialized"] - ] - - start = 0 - if cls is dict: - d = {} - for _header, _length in zip(headers, lengths): - k = _header.pop("key") - d[k] = deserialize( - _header, - frames[start : start + _length], - deserializers=deserializers, - ) - start += _length - return d - else: - lst = [] - for _header, _length in zip(headers, lengths): - lst.append( - deserialize( - _header, - frames[start : start + _length], - deserializers=deserializers, - ) - ) - start += _length - return cls(lst) name = header.get("serializer") if deserializers is not None and name not in deserializers: @@ -511,6 +473,14 @@ def __eq__(self, other): def __ne__(self, other): return not (self == other) + def __hash__(self): + return hash(tokenize((self.header, self.frames))) + + +class SerializedCallable(Serialized): + def __call__(self) -> None: + raise NotImplementedError + def nested_deserialize(x): """ @@ -522,32 +492,20 @@ def nested_deserialize(x): {'op': 'update', 'data': 123} """ - def replace_inner(x): - if type(x) is dict: - x = x.copy() - for k, v in x.items(): - typ = type(v) - if typ is dict or typ is list: - x[k] = replace_inner(v) - elif typ is Serialize: - x[k] = v.data - elif typ is Serialized: - x[k] = deserialize(v.header, v.frames) - - elif type(x) is list: - x = list(x) - for k, v in enumerate(x): - typ = type(v) - if typ is dict or typ is list: - x[k] = replace_inner(v) - elif typ is Serialize: - x[k] = v.data - elif typ is Serialized: - x[k] = deserialize(v.header, v.frames) - - return x - - return replace_inner(x) + typ = type(x) + if typ is dict: + return {k: nested_deserialize(v) for k, v in x.items()} + if typ is MsgpackList: + return list(nested_deserialize(x.data)) + if typ is TaskGraphValue: + return x.data + if typ is Serialize: + return x.data + if isinstance(x, Serialized): + return deserialize(x.header, x.frames) + if typ in collection_types: + return typ(nested_deserialize(o) for o in x) + return x def serialize_bytelist(x, **kwargs): diff --git 
a/distributed/protocol/tests/test_collection.py b/distributed/protocol/tests/test_collection.py index 32d11a7475..f7db7c6121 100644 --- a/distributed/protocol/tests/test_collection.py +++ b/distributed/protocol/tests/test_collection.py @@ -25,13 +25,6 @@ def test_serialize_collection(collection, y, y_serializer): t = deserialize(header, frames, deserializers=("dask", "pickle", "error")) assert isinstance(t, collection) - assert header["is-collection"] is True - sub_headers = header["sub-headers"] - - if collection is not dict: - assert sub_headers[0]["serializer"] == "dask" - assert sub_headers[1]["serializer"] == y_serializer - if collection is dict: assert (t["x"] == x).all() assert str(t["y"]) == str(y) @@ -43,11 +36,3 @@ def test_serialize_collection(collection, y, y_serializer): def test_large_collections_serialize_simply(): header, frames = serialize(tuple(range(1000))) assert len(frames) == 1 - - -def test_nested_types(): - x = np.ones(5) - header, frames = serialize([[[x]]]) - assert "dask" in str(header) - assert len(frames) == 1 - assert x.data == np.frombuffer(frames[0]).data diff --git a/distributed/protocol/tests/test_collection_cuda.py b/distributed/protocol/tests/test_collection_cuda.py index a50fb7e2bb..93431d1df5 100644 --- a/distributed/protocol/tests/test_collection_cuda.py +++ b/distributed/protocol/tests/test_collection_cuda.py @@ -2,7 +2,7 @@ from dask.dataframe.utils import assert_eq -from distributed.protocol import deserialize, serialize +from distributed.protocol import dumps, loads @pytest.mark.parametrize("collection", [tuple, dict]) @@ -14,19 +14,13 @@ def test_serialize_cupy(collection, y, y_serializer): if y is not None: y = cupy.arange(y) if issubclass(collection, dict): - header, frames = serialize( - {"x": x, "y": y}, serializers=("cuda", "dask", "pickle") - ) + frames = dumps({"x": x, "y": y}, serializers=("cuda", "dask", "pickle")) else: - header, frames = serialize((x, y), serializers=("cuda", "dask", "pickle")) - t = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error")) + frames = dumps((x, y), serializers=("cuda", "dask", "pickle")) - assert header["is-collection"] is True - sub_headers = header["sub-headers"] - assert sub_headers[0]["serializer"] == "cuda" - assert sub_headers[1]["serializer"] == y_serializer - assert isinstance(t, collection) + assert any(isinstance(f, cupy.ndarray) for f in frames) + t = loads(frames, deserializers=("cuda", "dask", "pickle", "error")) assert ((t["x"] if isinstance(t, dict) else t[0]) == x).all() if y is None: assert (t["y"] if isinstance(t, dict) else t[1]) is None @@ -46,19 +40,12 @@ def test_serialize_pandas_pandas(collection, df2, df2_serializer): if df2 is not None: df2 = cudf.from_pandas(pd.DataFrame(df2)) if issubclass(collection, dict): - header, frames = serialize( - {"df1": df1, "df2": df2}, serializers=("cuda", "dask", "pickle") - ) + frames = dumps({"df1": df1, "df2": df2}, serializers=("cuda", "dask", "pickle")) else: - header, frames = serialize((df1, df2), serializers=("cuda", "dask", "pickle")) - t = deserialize(header, frames, deserializers=("cuda", "dask", "pickle")) - - assert header["is-collection"] is True - sub_headers = header["sub-headers"] - assert sub_headers[0]["serializer"] == "cuda" - assert sub_headers[1]["serializer"] == df2_serializer - assert isinstance(t, collection) + frames = dumps((df1, df2), serializers=("cuda", "dask", "pickle")) + assert any(isinstance(f, cudf.core.buffer.Buffer) for f in frames) + t = loads(frames, deserializers=("cuda", "dask", 
"pickle")) assert_eq(t["df1"] if isinstance(t, dict) else t[0], df1) if df2 is None: assert (t["df2"] if isinstance(t, dict) else t[1]) is None diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index cfb767bc0c..73f20373aa 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -18,6 +18,7 @@ from distributed.protocol import ( Serialize, Serialized, + dask_deserialize, dask_serialize, deserialize, deserialize_bytes, @@ -419,12 +420,16 @@ class MyObj: @dask_serialize.register(MyObj) def _(x): - header = {"compression": [False]} + header = {"compression": (False,)} frames = [b""] return header, frames - header, frames = serialize([MyObj(), MyObj()]) - assert header["compression"] == [False, False] + @dask_deserialize.register(MyObj) + def _(header, frames): + assert header["compression"] == (False,) + + frames = dumps([MyObj(), MyObj()]) + loads(frames) @pytest.mark.parametrize( @@ -468,7 +473,7 @@ def test_check_dask_serializable(data, is_serializable): ) def test_serialize_lists(serializers): data_in = ["a", 2, "c", None, "e", 6] - header, frames = serialize(data_in, serializers=serializers) + header, frames = serialize(data_in, serializers=serializers, on_error="error") data_out = deserialize(header, frames) assert data_in == data_out diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fb70994b56..1e59f78599 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4572,8 +4572,6 @@ async def test_get_future_error_simple(c, s, a, b): assert f.status == "error" function, args, kwargs, deps = await c._get_futures_error(f) - # args contains only solid values, not keys - assert function.__name__ == "div" with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4591,8 +4589,7 @@ async def test_get_futures_error(c, s, a, b): assert f.status == "error" function, args, kwargs, deps = await c._get_futures_error(f) - assert function.__name__ == "div" - assert args == (1, y0.key) + assert args == ((div, 1, y0.key),) @gen_cluster(client=True) @@ -4609,8 +4606,6 @@ async def test_recreate_error_delayed(c, s, a, b): function, args, kwargs = await c._recreate_error_locally(f) assert f.status == "error" - assert function.__name__ == "div" - assert args == (1, 0) with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4628,8 +4623,6 @@ async def test_recreate_error_futures(c, s, a, b): function, args, kwargs = await c._recreate_error_locally(f) assert f.status == "error" - assert function.__name__ == "div" - assert args == (1, 0) with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4720,14 +4713,18 @@ class Foo: def __getstate__(self): raise MyException() - with pytest.raises(MyException): + # Notice, because serialization is delayed until `distributed.batched` + # we get a `CancelledError` exception. + # Before the serialization + # happed immediately in `submit()`, which would raise the `MyException`. 
+ with pytest.raises(CancelledError): future = c.submit(identity, Foo()) - futures = c.map(inc, range(10)) - results = await c.gather(futures) + futures = c.map(inc, range(10)) + results = await c.gather(futures) - assert results == list(map(inc, range(10))) - assert a.data and b.data + assert results == list(map(inc, range(10))) + assert a.data and b.data @gen_cluster(client=True) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 669368c70f..dae6313d74 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -23,7 +23,6 @@ Logs, LoopRunner, TimeoutError, - _maybe_complex, ensure_bytes, ensure_ip, format_dashboard_link, @@ -202,15 +201,6 @@ def c(x): assert type(tb).__name__ == "traceback" -def test_maybe_complex(): - assert not _maybe_complex(1) - assert not _maybe_complex("x") - assert _maybe_complex((inc, 1)) - assert _maybe_complex([(inc, 1)]) - assert _maybe_complex([(inc, 1)]) - assert _maybe_complex({"x": (inc, 1)}) - - def test_read_block(): delimiter = b"\n" data = delimiter.join([b"123", b"456", b"789"]) diff --git a/distributed/utils.py b/distributed/utils.py index 00c4dab50d..948759999d 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -38,7 +38,6 @@ from tornado.ioloop import IOLoop import dask -from dask import istask # Import config serialization functions here for backward compatibility from dask.config import deserialize as deserialize_for_cli # noqa @@ -759,17 +758,6 @@ def validate_key(k): raise TypeError("Unexpected key type %s (value: %r)" % (typ, k)) -def _maybe_complex(task): - """ Possibly contains a nested task """ - return ( - istask(task) - or type(task) is list - and any(map(_maybe_complex, task)) - or type(task) is dict - and any(map(_maybe_complex, task.values())) - ) - - def seek_delimiter(file, delimiter, blocksize): """Seek current file to next byte after a delimiter bytestring diff --git a/distributed/worker.py b/distributed/worker.py index 9e79ef5ab6..fa96e11c03 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -23,7 +23,6 @@ from tornado.ioloop import IOLoop, PeriodicCallback import dask -from dask.compatibility import apply from dask.core import istask from dask.system import CPU_COUNT from dask.utils import format_bytes, funcname @@ -47,6 +46,8 @@ from .node import ServerNode from .proctitle import setproctitle from .protocol import deserialize_bytes, pickle, serialize_bytelist, to_serialize +from .protocol.core import DelayedExceptionRaise +from .protocol.serialize import TaskGraphValue, collection_types, msgpack_persist_lists from .pubsub import PubSubWorkerExtension from .security import Security from .sizeof import safe_sizeof as sizeof @@ -55,7 +56,6 @@ from .utils import ( LRU, TimeoutError, - _maybe_complex, get_ip, has_arg, import_file, @@ -702,6 +702,7 @@ def __init__( stream_handlers=stream_handlers, io_loop=self.loop, connection_args=self.connection_args, + deserialize="delay-exception", **kwargs, ) @@ -878,7 +879,11 @@ async def _register_with_scheduler(self): while True: try: _start = time() - comm = await connect(self.scheduler.address, **self.connection_args) + comm = await connect( + self.scheduler.address, + deserialize="delay-exception", + **self.connection_args, + ) comm.name = "Worker->Scheduler" comm._server = weakref.ref(self) await comm.write( @@ -3481,6 +3486,18 @@ def loads_function(bytes_object): return pickle.loads(bytes_object) +def raise_delayed_exceptions(x): + typ = type(x) + if typ is DelayedExceptionRaise: + raise x.err + elif 
typ is dict: + for y in x.values(): + raise_delayed_exceptions(y) + elif typ in collection_types: + for y in x: + raise_delayed_exceptions(y) + + def _deserialize(function=None, args=None, kwargs=None, task=no_value): """ Deserialize task inputs and regularize to func, args, kwargs """ if function is not None: @@ -3494,6 +3511,7 @@ def _deserialize(function=None, args=None, kwargs=None, task=no_value): assert not function and not args and not kwargs function = execute_task args = (task,) + raise_delayed_exceptions(task) return function, args or (), kwargs or {} @@ -3536,6 +3554,31 @@ def dumps_function(func): return result +_warn_dumps_warned = [False] + + +def warn_large_args(obj, limit=1e6): + if _warn_dumps_warned[0]: + return + size = sizeof(obj) + if size > limit: + _warn_dumps_warned[0] = True + s = str(obj) + if len(s) > 70: + s = s[:50] + " ... " + s[-15:] + warnings.warn( + "Large object of size %s detected in task graph: \n" + " %s\n" + "Consider scattering large objects ahead of time\n" + "with client.scatter to reduce scheduler burden and \n" + "keep data on workers\n\n" + " future = client.submit(func, big_data) # bad\n\n" + " big_future = client.scatter(big_data) # good\n" + " future = client.submit(func, big_future) # good" + % (format_bytes(size), s) + ) + + def dumps_task(task): """Serialize a dask task @@ -3557,17 +3600,9 @@ def dumps_task(task): {'task': b'\x80\x04\x95\x03\x00\x00\x00\x00\x00\x00\x00K\x01.'} """ if istask(task): - if task[0] is apply and not any(map(_maybe_complex, task[2:])): - d = {"function": dumps_function(task[1]), "args": warn_dumps(task[2])} - if len(task) == 4: - d["kwargs"] = warn_dumps(task[3]) - return d - elif not any(map(_maybe_complex, task[1:])): - return {"function": dumps_function(task[0]), "args": warn_dumps(task[1:])} - return to_serialize(task) - + warn_large_args(task[1:]) -_warn_dumps_warned = [False] + return TaskGraphValue(msgpack_persist_lists(task)) def warn_dumps(obj, dumps=pickle.dumps, limit=1e6): From 0a4433a06f30e82f1a359b818298a2c10457c20c Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Thu, 15 Apr 2021 21:44:08 +0200 Subject: [PATCH 02/21] Added dumps/loads cache --- distributed/protocol/core.py | 51 ++++++++++++++++++++++++++--- distributed/tests/test_scheduler.py | 17 ---------- distributed/worker.py | 47 +++++--------------------- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 8b8835afa8..9856458605 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -1,7 +1,10 @@ import logging +import threading import msgpack +from ..utils import LRU +from . 
import pickle from .compression import decompress, maybe_compress from .serialize import ( MsgpackList, @@ -18,6 +21,39 @@ logger = logging.getLogger(__name__) +cache_dumps = LRU(maxsize=100) +cache_loads = LRU(maxsize=100) +cache_dumps_lock = threading.Lock() +cache_loads_lock = threading.Lock() + + +def dumps_function(func): + """ Dump a function to bytes, cache functions """ + try: + with cache_dumps_lock: + result = cache_dumps[func] + except KeyError: + result = pickle.dumps(func, protocol=4) + if len(result) < 100000: + with cache_dumps_lock: + cache_dumps[func] = result + except TypeError: # Unhashable function + result = pickle.dumps(func, protocol=4) + return result + + +def loads_function(bytes_object): + """ Load a function from bytes, cache bytes """ + if len(bytes_object) < 100000: + with cache_loads_lock: + try: + result = cache_loads[bytes_object] + except KeyError: + result = pickle.loads(bytes_object) + cache_loads[bytes_object] = result + return result + return pickle.loads(bytes_object) + def dumps(msg, serializers=None, on_error="message", context=None) -> list: """Transform Python message to bytestream suitable for communication @@ -59,8 +95,10 @@ def _encode_default(obj): obj = obj.data # TODO: remove Serialize/to_serialize completely offset = len(frames) - if typ is Serialized: + if typ in (Serialized, SerializedCallable): sub_header, sub_frames = obj.header, obj.frames + elif callable(obj): + sub_header, sub_frames = {"callable": dumps_function(obj)}, [] else: sub_header, sub_frames = serialize_and_split( obj, serializers=serializers, on_error=on_error, context=context @@ -73,7 +111,7 @@ def _encode_default(obj): ) ) frames.extend(sub_frames) - return {"__Serialized__": offset, "callable": callable(obj)} + return {"__Serialized__": offset} frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True) return frames @@ -104,6 +142,11 @@ def _decode_default(obj): ) offset += 1 sub_frames = frames[offset : offset + sub_header["num-sub-frames"]] + if "callable" in sub_header: + if deserialize: + return loads_function(sub_header["callable"]) + else: + return SerializedCallable(sub_header, sub_frames) if deserialize: if "compression" in sub_header: sub_frames = decompress(sub_header, sub_frames) @@ -116,8 +159,8 @@ def _decode_default(obj): return DelayedExceptionRaise(e) else: raise - elif obj["callable"]: - return SerializedCallable(sub_header, sub_frames) + # elif obj["callable"]: + # return SerializedCallable(sub_header, sub_frames) else: return Serialized(sub_header, sub_frames) else: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index c82a4a0483..e08d695e4e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -13,7 +13,6 @@ import dask from dask import delayed -from dask.compatibility import apply from distributed import Client, Nanny, Worker, fire_and_forget, wait from distributed.client import wait @@ -471,22 +470,6 @@ def test_dumps_function(): assert a != c -def test_dumps_task(): - d = dumps_task((inc, 1)) - assert set(d) == {"function", "args"} - - f = lambda x, y=2: x + y - d = dumps_task((apply, f, (1,), {"y": 10})) - assert cloudpickle.loads(d["function"])(1, 2) == 3 - assert cloudpickle.loads(d["args"]) == (1,) - assert cloudpickle.loads(d["kwargs"]) == {"y": 10} - - d = dumps_task((apply, f, (1,))) - assert cloudpickle.loads(d["function"])(1, 2) == 3 - assert cloudpickle.loads(d["args"]) == (1,) - assert set(d) == {"function", "args"} - - @gen_cluster() async def 
test_ready_remove_worker(s, a, b): s.update_graph( diff --git a/distributed/worker.py b/distributed/worker.py index fa96e11c03..880d739f48 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -46,7 +46,13 @@ from .node import ServerNode from .proctitle import setproctitle from .protocol import deserialize_bytes, pickle, serialize_bytelist, to_serialize -from .protocol.core import DelayedExceptionRaise +from .protocol.core import ( + DelayedExceptionRaise, + cache_loads, + cache_loads_lock, + dumps_function, + loads_function, +) from .protocol.serialize import TaskGraphValue, collection_types, msgpack_persist_lists from .pubsub import PubSubWorkerExtension from .security import Security @@ -54,7 +60,6 @@ from .threadpoolexecutor import ThreadPoolExecutor from .threadpoolexecutor import secede as tpe_secede from .utils import ( - LRU, TimeoutError, get_ip, has_arg, @@ -1049,7 +1054,8 @@ def func(data): if load: try: import_file(out_filename) - cache_loads.data.clear() + with cache_loads_lock: + cache_loads.data.clear() except Exception as e: logger.exception(e) raise e @@ -3471,21 +3477,6 @@ async def _get_data(): job_counter = [0] -cache_loads = LRU(maxsize=100) - - -def loads_function(bytes_object): - """ Load a function from bytes, cache bytes """ - if len(bytes_object) < 100000: - try: - result = cache_loads[bytes_object] - except KeyError: - result = pickle.loads(bytes_object) - cache_loads[bytes_object] = result - return result - return pickle.loads(bytes_object) - - def raise_delayed_exceptions(x): typ = type(x) if typ is DelayedExceptionRaise: @@ -3534,26 +3525,6 @@ def execute_task(task): return task -cache_dumps = LRU(maxsize=100) - -_cache_lock = threading.Lock() - - -def dumps_function(func): - """ Dump a function to bytes, cache functions """ - try: - with _cache_lock: - result = cache_dumps[func] - except KeyError: - result = pickle.dumps(func, protocol=4) - if len(result) < 100000: - with _cache_lock: - cache_dumps[func] = result - except TypeError: # Unhashable function - result = pickle.dumps(func, protocol=4) - return result - - _warn_dumps_warned = [False] From ff0a542941a9da4f945fe03696bc0b0666286ade Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 16 Apr 2021 15:40:03 +0200 Subject: [PATCH 03/21] Fixed test_batched test --- distributed/tests/test_batched.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index ee84ec3224..a582df6ebb 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -1,5 +1,6 @@ import asyncio import random +import types import pytest from tlz import assoc @@ -7,7 +8,6 @@ from distributed.batched import BatchedSend from distributed.core import CommClosedError, connect, listen from distributed.metrics import time -from distributed.protocol import to_serialize from distributed.utils import All from distributed.utils_test import captured_logger @@ -234,18 +234,16 @@ async def test_serializers(): b = BatchedSend(interval="10ms", serializers=["msgpack"]) b.start(comm) - b.send({"x": to_serialize(123)}) - b.send({"x": to_serialize("hello")}) + b.send({"x": 123}) + b.send({"x": "hello"}) await asyncio.sleep(0.100) - b.send({"x": to_serialize(lambda x: x + 1)}) + b.send({"x": types.SimpleNamespace()}) # Object not msgpack serializable with captured_logger("distributed.protocol") as sio: await asyncio.sleep(0.100) value = sio.getvalue() - assert "serialize" in value - assert "type" in value - assert "function" in value - + assert "Failed to Serialize" in value + assert "SimpleNamespace" in value assert comm.closed() From 084680c22c79e29bf42fa9273ca8bed85ff0f24a Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 16 Apr 2021 15:52:24 +0200 Subject: [PATCH 04/21] Fixed test_turn_off_pickle --- distributed/tests/test_client.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 1e59f78599..fcfafa6520 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -17,6 +17,7 @@ from operator import add from threading import Semaphore from time import sleep +from types import SimpleNamespace import psutil import pytest @@ -5551,27 +5552,20 @@ async def test(s, a, b): # Can't send complex data with pytest.raises(TypeError): - future = await c.scatter(inc) + future = await c.scatter(SimpleNamespace()) - # can send complex tasks (this uses pickle regardless) + # Can send and receive functions tasks (this uses pickle regardless) future = c.submit(lambda x: x, inc) await wait(future) - - # but can't receive complex results - with pytest.raises(TypeError): - await c.gather(future, direct=direct) + await c.gather(future, direct=direct) # Run works result = await c.run(lambda: 1) assert list(result.values()) == [1, 1] result = await c.run_on_scheduler(lambda: 1) assert result == 1 - - # But not with complex return values - with pytest.raises(TypeError): - await c.run(lambda: inc) - with pytest.raises(TypeError): - await c.run_on_scheduler(lambda: inc) + await c.run(lambda: inc) + await c.run_on_scheduler(lambda: inc) test() From 6548968973e3e0ac1b82001330112d92be4a7a3b Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 16 Apr 2021 16:17:04 +0200 Subject: [PATCH 05/21] fixed test_pickle_safe --- distributed/tests/test_publish.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index 7f4c03fe71..e7fb181b2c 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -1,4 +1,5 @@ import asyncio +from types import SimpleNamespace import pytest @@ -271,12 +272,11 @@ async def test_pickle_safe(c, s, a, b): assert result == [1, 2, 3] with pytest.raises(TypeError): - await c2.publish_dataset(y=lambda x: x) + # SimpleNamespace() is not serializable + await c2.publish_dataset(y=SimpleNamespace()) await c.publish_dataset(z=lambda x: x) # this can use pickle - - with pytest.raises(TypeError): - await c2.get_dataset("z") + await c2.get_dataset("z") @gen_cluster(client=True) From ca6d0f6c704576c0aaeaa91ac10a7718056f3544 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 16 Apr 2021 16:23:49 +0200 Subject: [PATCH 06/21] fixed test_executor_offload --- distributed/tests/test_worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 775703ef1d..b6e334e455 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1781,7 +1781,9 @@ def __setstate__(self, state): x = SameThreadClass() def f(x): - return threading.get_ident() == x._thread_ident + # With , + # deserialization is done as part of communication. + return threading.get_ident() != x._thread_ident assert await c.submit(f, x) From 2da24765c7e5b901a1f675c349f08fc6801bcfb3 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Thu, 22 Apr 2021 11:14:24 +0200 Subject: [PATCH 07/21] delay exceptions when loads_function --- distributed/protocol/core.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 9856458605..63d6d154c0 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -46,12 +46,12 @@ def loads_function(bytes_object): """ Load a function from bytes, cache bytes """ if len(bytes_object) < 100000: with cache_dumps_lock: - try: - result = cache_loads[bytes_object] - except KeyError: + if bytes_object in cache_loads: + return cache_loads[bytes_object] + else: result = pickle.loads(bytes_object) cache_loads[bytes_object] = result - return result + return result return pickle.loads(bytes_object) @@ -144,7 +144,13 @@ def _decode_default(obj): sub_frames = frames[offset : offset + sub_header["num-sub-frames"]] if "callable" in sub_header: if deserialize: - return loads_function(sub_header["callable"]) + try: + return loads_function(sub_header["callable"]) + except Exception as e: + if deserialize == "delay-exception": + return DelayedExceptionRaise(e) + else: + raise else: return SerializedCallable(sub_header, sub_frames) if deserialize: @@ -159,8 +165,6 @@ def _decode_default(obj): return DelayedExceptionRaise(e) else: raise - # elif obj["callable"]: - # return SerializedCallable(sub_header, sub_frames) else: return Serialized(sub_header, sub_frames) else: From 9717ae8edc9e43399aeb950f7b40b6969ad30f01 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Thu, 22 Apr 2021 11:18:20 +0200 Subject: [PATCH 08/21] fixed test_rpc_serialization --- distributed/tests/test_core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 398b933c02..95787a1f55 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -771,8 +771,7 @@ async def f(): await server.listen("tcp://") async with rpc(server.address, serializers=["msgpack"]) as r: - with pytest.raises(TypeError): - await r.echo(x=to_serialize(inc)) + await r.echo(x=to_serialize(inc)) async with rpc(server.address, serializers=["msgpack", "pickle"]) as r: result = await r.echo(x=to_serialize(inc)) From ec3d4ada940afad6cd18610a0c37c3e328e0ce1a Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Thu, 22 Apr 2021 11:31:49 +0200 Subject: [PATCH 09/21] Making sure to also unpack remotes in MsgpackList --- distributed/utils_comm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 80e6e0b8ae..076165cd20 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -12,6 +12,7 @@ from dask.utils import parse_timedelta, stringify from .core import rpc +from .protocol.serialize import MsgpackList from .utils import All logger = logging.getLogger(__name__) @@ -161,7 +162,7 @@ async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=N return (names, who_has, nbytes) -collection_types = (tuple, list, set, frozenset) +collection_types = (tuple, list, set, frozenset, MsgpackList) def unpack_remotedata(o, byte_keys=False, myset=None): From 845b2902056f62cb96a5c21636a7a05ce85f5ebb Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 27 Apr 2021 16:58:51 +0200 Subject: [PATCH 10/21] Revert "fixed test_rpc_serialization" This reverts commit 9717ae8edc9e43399aeb950f7b40b6969ad30f01. --- distributed/tests/test_core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 95787a1f55..398b933c02 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -771,7 +771,8 @@ async def f(): await server.listen("tcp://") async with rpc(server.address, serializers=["msgpack"]) as r: - await r.echo(x=to_serialize(inc)) + with pytest.raises(TypeError): + await r.echo(x=to_serialize(inc)) async with rpc(server.address, serializers=["msgpack", "pickle"]) as r: result = await r.echo(x=to_serialize(inc)) From cd937b52a65e29d4feaadfbd58df12192b2de498 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Tue, 27 Apr 2021 19:19:08 +0200 Subject: [PATCH 11/21] Checking the deserialises before unpickle --- distributed/protocol/core.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 63d6d154c0..2960e81083 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -129,6 +129,10 @@ def __init__(self, err): def loads(frames, deserialize=True, deserializers=None): """ Transform bytestream back into Python value """ + if deserializers is None: + # TODO: get from configuration both here and in protocol.serialize() + deserializers = ("dask", "pickle") + try: def _decode_default(obj): @@ -145,6 +149,11 @@ def _decode_default(obj): if "callable" in sub_header: if deserialize: try: + if "pickle" not in deserializers: + raise TypeError( + f"Cannot deserialize {sub_header['callable']}, " + "pickle isn't in deserializers" + ) return loads_function(sub_header["callable"]) except Exception as e: if deserialize == "delay-exception": From 7fb8a269038c70250a6496a70a0179ba7763bd11 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 27 Apr 2021 19:20:40 +0200 Subject: [PATCH 12/21] Rewritten test_robust_unserializable() --- distributed/tests/test_client.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fcfafa6520..1411e803a3 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4715,17 +4715,22 @@ def __getstate__(self): raise MyException() # Notice, because serialization is delayed until `distributed.batched` - # we get a `CancelledError` exception. - # Before the serialization - # happed immediately in `submit()`, which would raise the `MyException`. - with pytest.raises(CancelledError): + # we don't get an exception immediately. The exception is raised and logged + # when the ongoing communication between the client the scheduler encounters + # the `Foo` class. Before + # the serialization happed immediately in `submit()`, which would raise the + # `MyException`. + with captured_logger("distributed") as caplog: future = c.submit(identity, Foo()) + await asyncio.sleep(c.scheduler_comm.interval) + # Check that the serialization error was logged + assert "Failed to serialize" in caplog.getvalue() - futures = c.map(inc, range(10)) - results = await c.gather(futures) + futures = c.map(inc, range(10)) + results = await c.gather(futures) - assert results == list(map(inc, range(10))) - assert a.data and b.data + assert results == list(map(inc, range(10))) + assert a.data and b.data @gen_cluster(client=True) From 79f8d243ab315ee565638e8bcf42757c451708f0 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 27 Apr 2021 19:39:40 +0200 Subject: [PATCH 13/21] typo --- distributed/tests/test_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 1411e803a3..10c656ce17 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4718,10 +4718,11 @@ def __getstate__(self): # we don't get an exception immediately. The exception is raised and logged # when the ongoing communication between the client the scheduler encounters # the `Foo` class. 
Before - # the serialization happed immediately in `submit()`, which would raise the + # the serialization happened immediately in `submit()`, which would raise the # `MyException`. with captured_logger("distributed") as caplog: future = c.submit(identity, Foo()) + # We sleep to make sure that a `BatchedSend.interval` has passed. await asyncio.sleep(c.scheduler_comm.interval) # Check that the serialization error was logged assert "Failed to serialize" in caplog.getvalue() From 9aa4bbe09780fe76dedd97792eb948e6932bbc7c Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 28 Apr 2021 16:24:45 +0200 Subject: [PATCH 14/21] Revert "Fixed test_turn_off_pickle" This reverts commit 084680c22c79e29bf42fa9273ca8bed85ff0f24a. --- distributed/tests/test_client.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 10c656ce17..7ed8f4858f 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -17,7 +17,6 @@ from operator import add from threading import Semaphore from time import sleep -from types import SimpleNamespace import psutil import pytest @@ -5558,20 +5557,27 @@ async def test(s, a, b): # Can't send complex data with pytest.raises(TypeError): - future = await c.scatter(SimpleNamespace()) + future = await c.scatter(inc) - # Can send and receive functions tasks (this uses pickle regardless) + # can send complex tasks (this uses pickle regardless) future = c.submit(lambda x: x, inc) await wait(future) - await c.gather(future, direct=direct) + + # but can't receive complex results + with pytest.raises(TypeError): + await c.gather(future, direct=direct) # Run works result = await c.run(lambda: 1) assert list(result.values()) == [1, 1] result = await c.run_on_scheduler(lambda: 1) assert result == 1 - await c.run(lambda: inc) - await c.run_on_scheduler(lambda: inc) + + # But not with complex return values + with pytest.raises(TypeError): + await c.run(lambda: inc) + with pytest.raises(TypeError): + await c.run_on_scheduler(lambda: inc) test() From f700725f9465246e4afd2882d33cf45c071ae91b Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 28 Apr 2021 17:08:03 +0200 Subject: [PATCH 15/21] Checking the serializers before pickling --- distributed/protocol/core.py | 13 +++++++++++-- distributed/tests/test_client.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 2960e81083..c7258f873b 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -65,6 +65,11 @@ def dumps(msg, serializers=None, on_error="message", context=None) -> list: encounters an object it cannot serialize like a NumPy array, it is handled out-of-band by `_encode_default()` and appended to the output frame list. 
""" + + if serializers is None: + # TODO: get from configuration both here and in protocol.serialize() + serializers = ("dask", "pickle") + try: if context and "compression" in context: compress_opts = {"compression": context["compression"]} @@ -98,6 +103,10 @@ def _encode_default(obj): if typ in (Serialized, SerializedCallable): sub_header, sub_frames = obj.header, obj.frames elif callable(obj): + if "pickle" not in serializers: + raise TypeError( + f"Cannot serialize {repr(obj)} since pickle isn't in serializers" + ) sub_header, sub_frames = {"callable": dumps_function(obj)}, [] else: sub_header, sub_frames = serialize_and_split( @@ -151,8 +160,8 @@ def _decode_default(obj): try: if "pickle" not in deserializers: raise TypeError( - f"Cannot deserialize {sub_header['callable']}, " - "pickle isn't in deserializers" + f"Cannot deserialize {sub_header['callable']}, since " + f"pickle isn't in deserializers" ) return loads_function(sub_header["callable"]) except Exception as e: diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 7ed8f4858f..6818e041eb 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5574,9 +5574,9 @@ async def test(s, a, b): assert result == 1 # But not with complex return values - with pytest.raises(TypeError): + with pytest.raises(CommClosedError): await c.run(lambda: inc) - with pytest.raises(TypeError): + with pytest.raises(CommClosedError): await c.run_on_scheduler(lambda: inc) test() From 13ddcec4bb0b702b8fac045e9be69c06403ee49b Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 28 Apr 2021 17:10:42 +0200 Subject: [PATCH 16/21] Revert "fixed test_pickle_safe" This reverts commit 6548968973e3e0ac1b82001330112d92be4a7a3b. --- distributed/tests/test_publish.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index e7fb181b2c..7f4c03fe71 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -1,5 +1,4 @@ import asyncio -from types import SimpleNamespace import pytest @@ -272,11 +271,12 @@ async def test_pickle_safe(c, s, a, b): assert result == [1, 2, 3] with pytest.raises(TypeError): - # SimpleNamespace() is not serializable - await c2.publish_dataset(y=SimpleNamespace()) + await c2.publish_dataset(y=lambda x: x) await c.publish_dataset(z=lambda x: x) # this can use pickle - await c2.get_dataset("z") + + with pytest.raises(TypeError): + await c2.get_dataset("z") @gen_cluster(client=True) From 2fa9f12b5f7a41f72700a2d70842159482d65450 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 28 Apr 2021 21:05:27 +0200 Subject: [PATCH 17/21] test_robust_unserializable(): catching CancelledError --- distributed/tests/test_client.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 6818e041eb..0fded31b94 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4714,17 +4714,12 @@ def __getstate__(self): raise MyException() # Notice, because serialization is delayed until `distributed.batched` - # we don't get an exception immediately. The exception is raised and logged - # when the ongoing communication between the client the scheduler encounters - # the `Foo` class. Before - # the serialization happened immediately in `submit()`, which would raise the - # `MyException`. 
- with captured_logger("distributed") as caplog: - future = c.submit(identity, Foo()) - # We sleep to make sure that a `BatchedSend.interval` has passed. - await asyncio.sleep(c.scheduler_comm.interval) - # Check that the serialization error was logged - assert "Failed to serialize" in caplog.getvalue() + # we get a `CancelledError`. The exception is raised when the ongoing + # communication between the client the scheduler encounters the `Foo` class. + # Before the serialization + # happened immediately in `submit()`, which would raise the `MyException`. + with pytest.raises(CancelledError): + await c.submit(identity, Foo()) futures = c.map(inc, range(10)) results = await c.gather(futures) From 6a2a915ae4e983500b2afad7bc0b7d9c3dd36784 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 5 May 2021 14:58:43 +0200 Subject: [PATCH 18/21] MsgpackList: support isinstance(x, list) --- distributed/protocol/serialize.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 390c924b21..db5f5511f5 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -110,6 +110,8 @@ def import_allowed_module(name): class MsgpackList(collections.abc.MutableSequence): + __class__ = list # Make `isinstance(x, list)` true + def __init__(self, x): self.data = x @@ -128,6 +130,9 @@ def __len__(self): def insert(self, i, v): return self.data.insert(i, v) + def __repr__(self): + return f"MsgpackList({repr(self.data)})" + class TaskGraphValue: def __init__(self, x): From be143bae69bcf337d1d3f98d18f9c7f328b5697f Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 5 May 2021 14:59:39 +0200 Subject: [PATCH 19/21] SerializedCallable: accept args --- distributed/protocol/serialize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index db5f5511f5..01020b29cb 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -493,8 +493,8 @@ def __hash__(self): class SerializedCallable(Serialized): - def __call__(self) -> None: - raise NotImplementedError + def __call__(self, *args, **kwargs) -> None: + raise RuntimeError("SerializedCallable should never be called!") def nested_deserialize(x): From 29baab1cbc7db7ef48490a70fa2aa675a69b119e Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 5 May 2021 15:05:37 +0200 Subject: [PATCH 20/21] to_serialize = msgpack_persist_lists --- distributed/protocol/serialize.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 01020b29cb..f4c0fa797b 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -462,7 +462,9 @@ def __hash__(self): return hash(self.data) -to_serialize = Serialize +# Historically, `to_serialize()` has been used to mark data for serialization. +# Now, we use it to preserve lists through msgpack serialization. +to_serialize = msgpack_persist_lists class Serialized: From e7d51051d31040d266d219d985261c5972a83ee8 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Wed, 5 May 2021 16:08:57 +0200 Subject: [PATCH 21/21] replaced Serialize and to_serialize with TaskGraphValue --- distributed/comm/tests/test_comms.py | 3 +- distributed/protocol/serialize.py | 60 +++++++++++++--------------- 2 files changed, 30 insertions(+), 33 deletions(-) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 8b78f7f37d..eef7c2136b 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -32,6 +32,7 @@ from distributed.compatibility import WINDOWS from distributed.metrics import time from distributed.protocol import Serialized, deserialize, serialize, to_serialize +from distributed.protocol.serialize import TaskGraphValue from distributed.utils import get_ip, get_ipv6 from distributed.utils_test import loop # noqa: F401 from distributed.utils_test import ( @@ -1148,7 +1149,7 @@ async def check_deserialize_roundtrip(addr): assert isinstance(got["to_ser"][0], (bytes, bytearray)) assert isinstance(got["ser"], (bytes, bytearray)) else: - assert isinstance(got["to_ser"][0], (to_serialize, Serialized)) + assert isinstance(got["to_ser"][0], TaskGraphValue) assert isinstance(got["ser"], Serialized) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index f4c0fa797b..0ab3b9ae3c 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -135,8 +135,27 @@ def __repr__(self): class TaskGraphValue: - def __init__(self, x): - self.data = x + """Mark object as a graph value not concerning the scheduler + + This is usefull for objects we don't want the scheduler to interpret + literally such as `False` and `None`, which makes code like + `dsk.pop("key", None)` work as expected. + """ + + def __init__(self, data): + self.data = data + + def __repr__(self): + return f"" + + def __eq__(self, other): + return isinstance(other, TaskGraphValue) and other.data == self.data + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.data) def msgpack_persist_lists(obj): @@ -432,39 +451,16 @@ def merge_and_deserialize(header, frames, deserializers=None): return deserialize(header, merged_frames, deserializers=deserializers) -class Serialize: - """Mark an object that should be serialized - - Examples - -------- - >>> msg = {'op': 'update', 'data': to_serialize(123)} - >>> msg # doctest: +SKIP - {'op': 'update', 'data': } - - See also - -------- - distributed.protocol.dumps +# TODO: remove Serialize and to_serialize and use TaskGraphValue instead +def to_serialize(x): """ - - def __init__(self, data): - self.data = data - - def __repr__(self): - return "" % str(self.data) - - def __eq__(self, other): - return isinstance(other, Serialize) and other.data == self.data - - def __ne__(self, other): - return not (self == other) - - def __hash__(self): - return hash(self.data) + Traditionally, `to_serialize()` has been used to mark data for serialization. + Now, we use it to preserve lists through msgpack serialization. + """ + return TaskGraphValue(msgpack_persist_lists(x)) -# Historically, `to_serialize()` has been used to mark data for serialization. -# Now, we use it to preserve lists through msgpack serialization. -to_serialize = msgpack_persist_lists +Serialize = TaskGraphValue class Serialized: