From 5a75023405461524c325cb35717cbf6454374457 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Fri, 12 Nov 2021 16:17:23 -0700 Subject: [PATCH] Deserialization: zero-copy merge subframes when possible (#5208) --- distributed/protocol/serialize.py | 21 +++- distributed/protocol/tests/test_numpy.py | 4 +- .../protocol/tests/test_protocol_utils.py | 105 +++++++++++++++- distributed/protocol/utils.py | 117 ++++++++++++++++++ 4 files changed, 239 insertions(+), 8 deletions(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 1a957f903e..6736cd90d5 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -16,7 +16,13 @@ from ..utils import ensure_bytes, has_keyword from . import pickle from .compression import decompress, maybe_compress -from .utils import frame_split_size, msgpack_opts, pack_frames_prelude, unpack_frames +from .utils import ( + frame_split_size, + merge_memoryviews, + msgpack_opts, + pack_frames_prelude, + unpack_frames, +) dask_serialize = dask.utils.Dispatch("dask_serialize") dask_deserialize = dask.utils.Dispatch("dask_deserialize") @@ -466,15 +472,18 @@ def merge_and_deserialize(header, frames, deserializers=None): deserialize serialize_and_split """ - merged_frames = [] if "split-num-sub-frames" not in header: merged_frames = frames else: + merged_frames = [] for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]): - if n == 1: - merged_frames.append(frames[offset]) - else: - merged_frames.append(bytearray().join(frames[offset : offset + n])) + subframes = frames[offset : offset + n] + try: + merged = merge_memoryviews(subframes) + except (ValueError, TypeError): + merged = bytearray().join(subframes) + + merged_frames.append(merged) return deserialize(header, merged_frames, deserializers=deserializers) diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 3fb7dc182e..5d5acaee85 100644 --- 
# distributed/protocol/tests/test_protocol_utils.py
class TestMergeMemroyviews:
    """Behavioural tests for ``merge_memoryviews``.

    The happy-path tests check that the merged view aliases the original
    buffer (``.obj`` identity) and equals the equivalent single slice.
    The ``test_catch_*`` tests pin the exception type and message fragment
    raised for each kind of input that cannot be merged.
    """

    def test_empty(self):
        # No inputs: an empty memoryview comes back.
        merged = merge_memoryviews([])
        assert isinstance(merged, memoryview)
        assert len(merged) == 0

    def test_one(self):
        # A single view is handed back as the very same object.
        view = memoryview(bytearray(range(10)))
        assert merge_memoryviews([view]) is view

    @pytest.mark.parametrize(
        "slices",
        [
            [slice(None, 3), slice(3, None)],
            [slice(1, 3), slice(3, None)],
            [slice(1, 3), slice(3, -1)],
            [slice(0, 0), slice(None)],
            [slice(None), slice(-1, -1)],
            [slice(0, 0), slice(0, 0)],
            [slice(None, 3), slice(3, 7), slice(7, None)],
            [slice(2, 3), slice(3, 7), slice(7, 9)],
            [slice(2, 3), slice(3, 7), slice(7, 9), slice(9, 9)],
            [slice(1, 2), slice(2, 5), slice(5, 8), slice(8, None)],
        ],
    )
    def test_parts(self, slices: list[slice]):
        # Adjacent (possibly empty) slices of one buffer merge into the
        # single slice spanning the smallest start to the largest stop.
        buf = bytearray(range(10))
        view = memoryview(buf)

        bounds = [s.indices(10) for s in slices]
        lowest = min(start for start, _, _ in bounds)
        highest = max(stop for _, stop, _ in bounds)
        expected = view[lowest:highest]

        merged = merge_memoryviews([view[s] for s in slices])
        assert merged.obj is buf
        assert len(merged) == len(expected)
        assert merged == expected

    def test_readonly_buffer(self):
        pytest.importorskip(
            "numpy", reason="Read-only buffer zero-copy merging requires NumPy"
        )
        buf = bytes(range(10))
        view = memoryview(buf)

        merged = merge_memoryviews([view[:4], view[4:]])
        assert merged.obj is buf
        assert len(merged) == len(buf)
        assert merged == buf

    def test_catch_non_memoryview(self):
        # Non-memoryview in first position ...
        with pytest.raises(TypeError, match="Expected memoryview"):
            merge_memoryviews([b"1234", memoryview(b"4567")])

        # ... and in a later position.
        with pytest.raises(TypeError, match="expected memoryview"):
            merge_memoryviews([memoryview(b"123"), b"1234"])

    @pytest.mark.parametrize(
        "slices",
        [
            [slice(None, 3), slice(4, None)],
            [slice(None, 3), slice(2, None)],
            [slice(1, 3), slice(3, 6), slice(9, None)],
        ],
    )
    def test_catch_gaps(self, slices: list[slice]):
        # Gaps and overlaps between consecutive views are both rejected.
        view = memoryview(bytearray(range(10)))
        pieces = [view[s] for s in slices]

        with pytest.raises(ValueError, match="does not start where the previous ends"):
            merge_memoryviews(pieces)

    def test_catch_different_buffer(self):
        buf = bytearray(range(8))
        view = memoryview(buf)
        with pytest.raises(ValueError, match="different buffer"):
            merge_memoryviews([view, memoryview(buf.copy())])

    def test_catch_different_non_contiguous(self):
        # A reversed view has a negative stride, hence is not contiguous.
        reversed_view = memoryview(bytearray(range(8)))[::-1]
        with pytest.raises(ValueError, match="non-contiguous"):
            merge_memoryviews([reversed_view[:3], reversed_view[3:]])

    def test_catch_multidimensional(self):
        grid = memoryview(bytearray(range(6))).cast("B", [3, 2])
        with pytest.raises(ValueError, match="has 2 dimensions, not 1"):
            merge_memoryviews([grid[:1], grid[1:]])

    def test_catch_different_formats(self):
        view = memoryview(bytearray(range(8)))
        with pytest.raises(ValueError, match="inconsistent format: I vs B"):
            merge_memoryviews([view[:4], view[4:].cast("I")])
# distributed/protocol/utils.py
def merge_memoryviews(mvs: Sequence[memoryview]) -> memoryview:
    """
    Zero-copy "concatenate" a sequence of contiguous memoryviews.

    Returns a new memoryview which slices into the underlying buffer
    to extract out the portion equivalent to all of ``mvs`` being concatenated.

    All the (non-empty) memoryviews must:
    * Share the same underlying buffer (``.obj``)
    * When merged, cover a continuous portion of that buffer with no gaps
      and no overlaps
    * Be 1-dimensional
    * Have the same format
    * Be contiguous

    Strides are not compared directly; contiguous 1-dimensional views of the
    same format necessarily share them.

    Special cases:
    * An empty sequence returns a new, empty memoryview.
    * A sequence of length one returns its only element unchanged (the same
      object), without validating it in any way.
    * Zero-length memoryviews are skipped and exempt from the buffer, shape
      and format checks. If *every* element is zero-length, the first element
      is returned.

    Raises
    ------
    TypeError
        If ``mvs`` has two or more elements and any of them is not a
        ``memoryview``.
    ValueError
        If the conditions listed above are not met, or if the address of a
        read-only memoryview is needed and NumPy is not installed.
    """
    if not mvs:
        return memoryview(bytearray())
    if len(mvs) == 1:
        return mvs[0]

    first = mvs[0]
    if not isinstance(first, memoryview):
        raise TypeError(f"Expected memoryview; got {type(first)}")
    obj = first.obj
    fmt = first.format  # local name chosen so the ``format`` builtin stays usable

    # Address of the first non-empty view; ``None`` until one has been seen.
    first_start_addr: int | None = None
    nbytes = 0
    for i, mv in enumerate(mvs):
        if not isinstance(mv, memoryview):
            raise TypeError(f"{i}: expected memoryview; got {type(mv)}")

        if mv.nbytes == 0:
            # Contributes nothing to the result, so it cannot break contiguity.
            continue

        if mv.obj is not obj:
            raise ValueError(
                f"{i}: memoryview has different buffer: {mv.obj!r} vs {obj!r}"
            )
        if not mv.contiguous:
            raise ValueError(f"{i}: memoryview non-contiguous")
        if mv.ndim != 1:
            raise ValueError(f"{i}: memoryview has {mv.ndim} dimensions, not 1")
        if mv.format != fmt:
            raise ValueError(f"{i}: inconsistent format: {mv.format} vs {fmt}")

        start_addr = address_of_memoryview(mv)
        if first_start_addr is None:
            first_start_addr = start_addr
        else:
            # Each view must begin exactly where the bytes merged so far end.
            expected_addr = first_start_addr + nbytes
            if start_addr != expected_addr:
                raise ValueError(
                    f"memoryview {i} does not start where the previous ends. "
                    f"Expected {expected_addr:x}, starts {start_addr - expected_addr} byte(s) away."
                )
        nbytes += mv.nbytes

    if nbytes == 0:
        # all memoryviews were zero-length
        assert len(first) == 0
        return first

    assert first_start_addr, "Underlying buffer is null pointer?!"

    # Re-slice the whole underlying buffer as raw bytes, then restore the
    # element format of the inputs.
    base_mv = memoryview(obj).cast("B")
    base_start_addr = address_of_memoryview(base_mv)
    start_index = first_start_addr - base_start_addr

    return base_mv[start_index : start_index + nbytes].cast(fmt)


one_byte_carr = ctypes.c_byte * 1
# ^ length and type don't matter, just use it to get the address of the first byte


def address_of_memoryview(mv: memoryview) -> int:
    """
    Get the pointer to the first byte of a memoryview's data.

    Any object supporting the buffer protocol is accepted and is viewed
    through ``memoryview(mv)`` first.

    If the ``from_buffer`` route is unavailable (e.g. the memoryview is
    read-only), NumPy must be installed.

    Raises
    ------
    TypeError
        If ``mv`` does not support the buffer protocol.
    ValueError
        If the NumPy fallback is needed but NumPy is not installed.
        Errors other than ``TypeError`` raised by ``from_buffer`` itself are
        propagated unchanged.
    """
    # NOTE: this method relies on pointer arithmetic to figure out
    # where each memoryview starts within the underlying buffer.
    # There's no direct API to get the address of a memoryview,
    # so we use a trick through ctypes and the buffer protocol:
    # https://mattgwwalker.wordpress.com/2020/10/15/address-of-a-buffer-in-python/

    # Normalise up front. Without this, a non-buffer argument (or a raw
    # ``bytes`` object) would make ``from_buffer`` raise TypeError, fall
    # through to ``np.asarray`` below, be *copied* by NumPy, and yield the
    # address of a temporary instead of an error or the real address.
    view = memoryview(mv)

    try:
        carr = one_byte_carr.from_buffer(view)
    except TypeError:
        # `from_buffer` rejected the buffer; typically it is read-only, and
        # `from_buffer` requires the buffer to be writeable.
        # See https://bugs.python.org/issue11427 for discussion.
        # This typically comes from `deserialize_bytes`, where `mv.obj` is an
        # immutable bytestring.
        pass
    else:
        return ctypes.addressof(carr)

    try:
        import numpy as np
    except ImportError as err:
        raise ValueError(
            f"Cannot get address of read-only memoryview {mv} since NumPy is not installed."
        ) from err

    # NumPy doesn't mind read-only buffers. We could just use this method
    # for all cases, but it's nice to use the pure-Python method for the common
    # case of writeable buffers (created by TCP comms, for example).
    return np.asarray(view).__array_interface__["data"][0]