[Python] use IPC for pickle serialisation
Existing pickling serialises the whole buffer, even if the Array is
sliced.

Pickling now reuses the buffer truncation that Arrow already implements
for IPC serialisation.

Relies on a RecordBatch wrapper, adding ~230 bytes to the
pickled payload per Array chunk.
anjakefala committed Sep 12, 2023
1 parent 47bf6e9 commit 3979363
Showing 4 changed files with 337 additions and 44 deletions.
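
For context, a rough sketch of the user-visible effect of this change (illustrative sizes only; assumes a pyarrow build that includes this commit):

import pickle
import pyarrow as pa

arr = pa.array(list(range(100_000)), type=pa.int64())
sliced = arr.slice(10, 2)

# Previously the pickled slice carried the full ~800 kB backing buffer;
# with IPC-based pickling the payload holds only the two sliced values
# plus the small RecordBatch stream framing.
payload = pickle.dumps(sliced)
restored = pickle.loads(payload)
assert restored.equals(sliced)
print(len(payload))  # on the order of a few hundred bytes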
32 changes: 26 additions & 6 deletions python/pyarrow/array.pxi
@@ -674,12 +674,18 @@ cdef shared_ptr[CArrayData] _reconstruct_array_data(data):
offset)


def _restore_array(data):
def _restore_array(buffer):
"""
Reconstruct an Array from pickled ArrayData.
Restore an IPC serialized Arrow Array.
Workaround for an issue where pickling a sliced Array
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
cdef shared_ptr[CArrayData] ad = _reconstruct_array_data(data)
return pyarrow_wrap_array(MakeArray(ad))
from pyarrow.ipc import RecordBatchStreamReader

with RecordBatchStreamReader(buffer) as reader:
return reader.read_next_batch().column(0)


cdef class _PandasConvertible(_Weakrefable):
@@ -1100,8 +1106,22 @@ cdef class Array(_PandasConvertible):
memory_pool=memory_pool)

def __reduce__(self):
return _restore_array, \
(_reduce_array_data(self.sp_array.get().data().get()),)
"""
Use Arrow IPC format for serialization.
Workaround for an issue where pickling a sliced Array
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
from pyarrow.ipc import RecordBatchStreamWriter
from pyarrow.lib import RecordBatch, BufferOutputStream

batch = RecordBatch.from_arrays([self], [''])
sink = BufferOutputStream()
with RecordBatchStreamWriter(sink, schema=batch.schema) as writer:
writer.write_batch(batch)

return _restore_array, (sink.getvalue(),)

@staticmethod
def from_buffers(DataType type, length, buffers, null_count=-1, offset=0,
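
The mechanism used by the new __reduce__/_restore_array pair can be approximated with public pyarrow APIs; a minimal sketch (the helper names are illustrative, not the internal code path):

import pyarrow as pa
from pyarrow.ipc import RecordBatchStreamReader, RecordBatchStreamWriter

def reduce_array(arr):
    # Wrap the (possibly sliced) Array in a single-column RecordBatch so the
    # IPC writer truncates its buffers while serializing.
    batch = pa.RecordBatch.from_arrays([arr], [''])
    sink = pa.BufferOutputStream()
    with RecordBatchStreamWriter(sink, schema=batch.schema) as writer:
        writer.write_batch(batch)
    return sink.getvalue()

def restore_array(buffer):
    # Read the single batch back and unwrap the lone column.
    with RecordBatchStreamReader(buffer) as reader:
        return reader.read_next_batch().column(0)

sliced = pa.array([1, 2, 3, 4, 5]).slice(1, 2)
assert restore_array(reduce_array(sliced)).equals(sliced)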
90 changes: 78 additions & 12 deletions python/pyarrow/table.pxi
@@ -16,6 +16,7 @@
# under the License.

import warnings
import functools


cdef class ChunkedArray(_PandasConvertible):
@@ -67,7 +68,21 @@ cdef class ChunkedArray(_PandasConvertible):
self.chunked_array = chunked_array.get()

def __reduce__(self):
return chunked_array, (self.chunks, self.type)
"""
Use Arrow IPC format for serialization.
Workaround for an issue where pickling a sliced Array
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
Adds ~230 extra bytes to the pickled payload per Array chunk.
"""
import pyarrow as pa

# IPC serialization requires wrapping in RecordBatch
table = pa.Table.from_arrays([self], names=[""])
reconstruct_table, serialised = table.__reduce__()
return functools.partial(_reconstruct_chunked_array, reconstruct_table), serialised

@property
def data(self):
@@ -1390,6 +1405,16 @@ def chunked_array(arrays, type=None):
c_result = GetResultValue(CChunkedArray.Make(c_arrays, c_type))
return pyarrow_wrap_chunked_array(c_result)

def _reconstruct_chunked_array(restore_table, buffer):
"""
Restore an IPC serialized ChunkedArray.
Workaround for an issue where pickling a sliced Array
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
return restore_table(buffer).column(0)


cdef _schema_from_arrays(arrays, names, metadata, shared_ptr[CSchema]* schema):
cdef:
@@ -2196,7 +2221,21 @@ cdef class RecordBatch(_Tabular):
return self.batch != NULL

def __reduce__(self):
return _reconstruct_record_batch, (self.columns, self.schema)
"""
Use Arrow IPC format for serialization.
Workaround for an issue where pickling a sliced RecordBatch
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
from pyarrow.ipc import RecordBatchStreamWriter
from pyarrow.lib import RecordBatch, BufferOutputStream

sink = BufferOutputStream()
with RecordBatchStreamWriter(sink, schema=self.schema) as writer:
writer.write_batch(self)

return _reconstruct_record_batch, (sink.getvalue(),)

def validate(self, *, full=False):
"""
@@ -2984,11 +3023,18 @@ cdef class RecordBatch(_Tabular):
return pyarrow_wrap_batch(c_batch)


def _reconstruct_record_batch(columns, schema):
def _reconstruct_record_batch(buffer):
"""
Internal: reconstruct RecordBatch from pickled components.
Restore an IPC serialized Arrow RecordBatch.
Workaround for an issue where pickling a sliced RecordBatch
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
return RecordBatch.from_arrays(columns, schema=schema)
from pyarrow.ipc import RecordBatchStreamReader

with RecordBatchStreamReader(buffer) as reader:
return reader.read_next_batch()


def table_to_blocks(options, Table table, categories, extension_columns):
@@ -3170,10 +3216,23 @@ cdef class Table(_Tabular):
check_status(self.table.Validate())

def __reduce__(self):
# Reduce the columns as ChunkedArrays to avoid serializing schema
# data twice
columns = [col for col in self.columns]
return _reconstruct_table, (columns, self.schema)
"""
Use Arrow IPC format for serialization.
Workaround for an issue where pickling a sliced Table
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
Adds ~230 extra bytes to the pickled payload per Array chunk.
"""
from pyarrow.ipc import RecordBatchStreamWriter
from pyarrow.lib import RecordBatch, BufferOutputStream

sink = BufferOutputStream()
with RecordBatchStreamWriter(sink, schema=self.schema) as writer:
writer.write_table(self)

return _reconstruct_table, (sink.getvalue(), )

def slice(self, offset=0, length=None):
"""
@@ -4754,11 +4813,18 @@ cdef class Table(_Tabular):
)


def _reconstruct_table(arrays, schema):
def _reconstruct_table(buffer):
"""
Internal: reconstruct pa.Table from pickled components.
Restore an IPC serialized Arrow Table.
Workaround for an issue where pickling a sliced Table
would serialize the whole buffer:
https://github.com/apache/arrow/issues/26685
"""
return Table.from_arrays(arrays, schema=schema)
from pyarrow.ipc import RecordBatchStreamReader

with RecordBatchStreamReader(buffer) as reader:
return reader.read_all()


def record_batch(data, names=None, schema=None, metadata=None):
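
The ChunkedArray and Table paths above share the same IPC stream machinery; a small round-trip sketch with plain stdlib pickle (example data assumed):

import pickle
import pyarrow as pa

# ChunkedArray.__reduce__ wraps the chunks in a single-column Table, reuses
# the Table's IPC-based __reduce__, and binds _reconstruct_chunked_array via
# functools.partial so that unpickling unwraps column 0 again.
chunked = pa.chunked_array([[1, 2, 3], [4, 5], [6, 7, 8]]).slice(2, 4)
assert pickle.loads(pickle.dumps(chunked)).equals(chunked)

# Tables and RecordBatches go through the same stream writer, so sliced
# tables are truncated as well; the framing adds roughly ~230 bytes per chunk.
table = pa.table({'x': list(range(1000))}).slice(5, 3)
assert pickle.loads(pickle.dumps(table)).equals(table)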
68 changes: 54 additions & 14 deletions python/pyarrow/tests/test_array.py
@@ -1985,20 +1985,35 @@ def test_cast_identities(ty, values):


pickle_test_parametrize = pytest.mark.parametrize(
('data', 'typ'),
[
([True, False, True, True], pa.bool_()),
([1, 2, 4, 6], pa.int64()),
([1.0, 2.5, None], pa.float64()),
(['a', None, 'b'], pa.string()),
([], None),
([[1, 2], [3]], pa.list_(pa.int64())),
([[4, 5], [6]], pa.large_list(pa.int16())),
([['a'], None, ['b', 'c']], pa.list_(pa.string())),
([(1, 'a'), (2, 'c'), None],
pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.string())]))
]
)
('data', 'typ'),
[
# Int array
(list(range(999)) + [None], pa.int64()),
# Float array
(list(map(float, range(999))) + [None], pa.float64()),
# Boolean array
([True, False, None, True] * 250, pa.bool_()),
# String array
(['a', 'b', 'cd', None, 'efg'] * 200, pa.string()),
# List array
([[1, 2], [3], [None, 4, 5], [6]] * 250, pa.list_(pa.int64())),
# Large list array
(
[[4, 5], [6], [None, 7], [8, 9, 10]] * 250,
pa.large_list(pa.int16())
),
# String list array
(
[['a'], None, ['b', 'cd'], ['efg']] * 250,
pa.list_(pa.string())
),
# Struct array
(
[(1, 'a'), (2, 'c'), None, (3, 'b')] * 250,
pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.string())])
),
# Empty array
])


@pickle_test_parametrize
@@ -2049,6 +2064,31 @@ def test_array_pickle_protocol5(data, typ, pickle_module):
for buf in result.buffers()]
assert result_addresses == addresses

@pickle_test_parametrize
def test_array_pickle_slice_truncation(data, typ, pickle_module):
arr = pa.array(data, type=typ)
serialized = pickle_module.dumps(arr)

slice_arr = arr.slice(10, 2)
serialized_slice = pickle_module.dumps(slice_arr)

# Check truncation upon serialization: the pickled slice must be much
# smaller than the pickled full array
assert len(serialized_slice) <= 0.2 * len(serialized)

post_pickle_slice = pickle_module.loads(serialized_slice)

# Check for post-roundtrip equality
assert post_pickle_slice.equals(slice_arr)

# Check that pickling reset the offset
assert post_pickle_slice.offset == 0

# Check that after pickling the slice buffer was trimmed
# to only contain the sliced data
buf_size = arr.get_total_buffer_size()
post_pickle_slice_buf_size = post_pickle_slice.get_total_buffer_size()
assert buf_size / post_pickle_slice_buf_size - len(arr) / len(post_pickle_slice) < 10



@pytest.mark.parametrize(
'narr',
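
The new test's checks can also be reproduced interactively; a short sketch assuming the stdlib pickle and the int64 parametrization:

import pickle
import pyarrow as pa

arr = pa.array(list(range(999)) + [None], type=pa.int64())
slice_arr = arr.slice(10, 2)
restored = pickle.loads(pickle.dumps(slice_arr))

assert restored.equals(slice_arr)   # values survive the round trip
assert restored.offset == 0         # the offset is reset on restore
# The restored buffers cover only the sliced data, far smaller than the
# ~8 kB of buffers behind the full 1000-element array.
print(arr.get_total_buffer_size(), restored.get_total_buffer_size())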
