Skip to content

Commit

Permalink
Add the missing buffer_callback argument (#308)
Browse files Browse the repository at this point in the history
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
  • Loading branch information
suquark and ogrisel authored Feb 7, 2020
1 parent f4ce61f commit e0ad635
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
over the module list
([PR #322](https://github.com/cloudpipe/cloudpickle/pull/322)).

- Add support for out-of-band pickling (Python 3.8 and later).
https://docs.python.org/3/library/pickle.html#example
([issue #308](https://github.com/cloudpipe/cloudpickle/pull/308))

1.2.2
=====

Expand Down
12 changes: 6 additions & 6 deletions cloudpickle/cloudpickle_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


# Shorthands similar to pickle.dump/pickle.dumps
def dump(obj, file, protocol=None):
def dump(obj, file, protocol=None, buffer_callback=None):
"""Serialize obj as bytes streamed into file
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
Expand All @@ -44,10 +44,10 @@ def dump(obj, file, protocol=None):
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
CloudPickler(file, protocol=protocol).dump(obj)
CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback).dump(obj)


def dumps(obj, protocol=None):
def dumps(obj, protocol=None, buffer_callback=None):
"""Serialize obj as a string of bytes allocated in memory
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
Expand All @@ -58,7 +58,7 @@ def dumps(obj, protocol=None):
compatibility with older versions of Python.
"""
with io.BytesIO() as file:
cp = CloudPickler(file, protocol=protocol)
cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
cp.dump(obj)
return file.getvalue()

Expand Down Expand Up @@ -421,10 +421,10 @@ class CloudPickler(Pickler):
dispatch[types.MappingProxyType] = _mappingproxy_reduce
dispatch[weakref.WeakSet] = _weakset_reduce

def __init__(self, file, protocol=None):
def __init__(self, file, protocol=None, buffer_callback=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(self, file, protocol=protocol)
Pickler.__init__(self, file, protocol=protocol, buffer_callback=buffer_callback)
# map functions __globals__ attribute ids, to ensure that functions
# sharing the same global namespace at pickling time also share their
# global namespace at unpickling time.
Expand Down
16 changes: 16 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,22 @@ def __getattr__(self, name):
with pytest.raises(pickle.PicklingError, match='recursion'):
cloudpickle.dumps(a)

def test_out_of_band_buffers(self):
if self.protocol < 5:
pytest.skip("Need Pickle Protocol 5 or later")
np = pytest.importorskip("numpy")

class LocallyDefinedClass:
data = np.zeros(10)

data_instance = LocallyDefinedClass()
buffers = []
pickle_bytes = cloudpickle.dumps(data_instance, protocol=self.protocol,
buffer_callback=buffers.append)
assert len(buffers) == 1
reconstructed = pickle.loads(pickle_bytes, buffers=buffers)
np.testing.assert_allclose(reconstructed.data, data_instance.data)


class Protocol2CloudPickleTest(CloudPickleTest):

Expand Down

0 comments on commit e0ad635

Please sign in to comment.