Skip to content

Commit

Permalink
Allow pickle to fall back to dask_serialize (#7567)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Feb 23, 2023
1 parent 8f77b44 commit 41fdb91
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
31 changes: 27 additions & 4 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
from __future__ import annotations

import inspect
import io
import logging
import pickle

import cloudpickle
from packaging.version import parse as parse_version

from distributed.protocol.serialize import dask_deserialize, dask_serialize

CLOUDPICKLE_GTE_20 = parse_version(cloudpickle.__version__) >= parse_version("2.0.0")

HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL

logger = logging.getLogger(__name__)


class _DaskPickler(pickle.Pickler):
    """Pickler that consults the ``dask_serialize`` dispatch while reducing.

    For every object visited during pickling, ``reducer_override`` looks up
    handlers on ``dask_serialize`` / ``dask_deserialize`` for the object's
    type and, when the lookup and the serialization succeed, pickles the
    object through them instead of through the regular pickle machinery.
    """

    def reducer_override(self, obj):
        """Return a reduce tuple built from dask handlers, or ``NotImplemented``.

        ``NotImplemented`` tells pickle to carry on with its normal reduction
        of ``obj``.
        """
        # For some objects this causes segfaults otherwise, see
        # https://github.com/dask/distributed/pull/7564#issuecomment-1438727339
        if _always_use_pickle_for(obj):
            return NotImplemented

        obj_type = type(obj)
        try:
            encode = dask_serialize.dispatch(obj_type)
            decode = dask_deserialize.dispatch(obj_type)
            # Reduce-tuple form: on load, pickle calls ``decode(*encode(obj))``.
            reduced = (decode, encode(obj))
        except TypeError:
            # Either dispatch lookup or the serializer itself raised TypeError
            # (presumably "no handler registered" -- verify against
            # the dispatch implementation); defer to regular pickling.
            reduced = NotImplemented
        return reduced


def _always_use_pickle_for(x):
mod, _, _ = x.__class__.__module__.partition(".")
if mod == "numpy":
Expand Down Expand Up @@ -42,8 +59,14 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
if dump_kwargs["protocol"] >= 5 and buffer_callback is not None:
dump_kwargs["buffer_callback"] = buffers.append
try:
buffers.clear()
result = pickle.dumps(x, **dump_kwargs)
try:
result = pickle.dumps(x, **dump_kwargs)
except Exception:
f = io.BytesIO()
pickler = _DaskPickler(f, **dump_kwargs)
buffers.clear()
pickler.dump(x)
result = f.getvalue()
if b"__main__" in result or (
CLOUDPICKLE_GTE_20
and getattr(inspect.getmodule(x), "__name__", None)
Expand All @@ -56,8 +79,8 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
try:
buffers.clear()
result = cloudpickle.dumps(x, **dump_kwargs)
except Exception as e:
logger.info("Failed to serialize %s. Exception: %s", x, e)
except Exception:
logger.exception("Failed to serialize %s.", x)
raise
if buffer_callback is not None:
for b in buffers:
Expand Down
58 changes: 58 additions & 0 deletions distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
dumps,
loads,
)
from distributed.protocol.serialize import dask_deserialize, dask_serialize
from distributed.utils_test import save_sys_modules


Expand Down Expand Up @@ -220,3 +221,60 @@ def test_pickle_by_value_when_registered():

finally:
sys.path.pop(0)


class NoPickle:
    """Test helper whose ``__getstate__`` always raises ``TypeError("nope")``."""

    def __getstate__(self):
        # Fail unconditionally so that any code path which asks this
        # instance for its state gets a recognisable error.
        refusal = TypeError("nope")
        raise refusal


def _serialize_nopickle(x):
return {}, ["hooray"]


def _deserialize_nopickle(header, frames):
assert header == {}
assert frames == ["hooray"]
return NoPickle()


def test_allow_pickle_if_registered_in_dask_serialize():
    """``dumps`` handles ``NoPickle`` only once dask handlers are registered."""
    # Without registered handlers the TypeError from __getstate__ propagates.
    with pytest.raises(TypeError, match="nope"):
        dumps(NoPickle())

    register_serializer = dask_serialize.register(NoPickle)
    register_serializer(_serialize_nopickle)
    register_deserializer = dask_deserialize.register(NoPickle)
    register_deserializer(_deserialize_nopickle)

    try:
        payload = dumps(NoPickle())
        restored = loads(payload)
        assert isinstance(restored, NoPickle)
    finally:
        # Undo the registrations so they do not persist past this test.
        for registry in (dask_serialize, dask_deserialize):
            del registry._lookup[NoPickle]


class NestedNoPickle:
    """Test helper that holds a ``NoPickle`` inside a dict attribute."""

    def __init__(self) -> None:
        # The NoPickle instance sits one level down, under the "foo" key.
        inner = NoPickle()
        self.stuff = {"foo": inner}


def test_nopickle_nested():
    """Registered dask handlers are used for ``NoPickle`` nested in containers."""
    nested_obj = [NoPickle()]
    # Before registration both nesting shapes fail with NoPickle's error.
    for unpicklable in (nested_obj, NestedNoPickle()):
        with pytest.raises(TypeError, match="nope"):
            dumps(unpicklable)

    dask_serialize.register(NoPickle)(_serialize_nopickle)
    dask_deserialize.register(NoPickle)(_deserialize_nopickle)

    try:
        # NoPickle stored in a dict held on an object attribute.
        holder = NestedNoPickle()
        restored_holder = loads(dumps(holder))
        assert restored_holder is not holder
        assert isinstance(restored_holder.stuff["foo"], NoPickle)

        # NoPickle stored directly inside a list.
        restored_list = loads(dumps(nested_obj))
        assert restored_list is not nested_obj
        assert isinstance(restored_list[0], NoPickle)
    finally:
        # Undo the registrations so they do not persist past this test.
        for registry in (dask_serialize, dask_deserialize):
            del registry._lookup[NoPickle]

0 comments on commit 41fdb91

Please sign in to comment.