From bcb5024a0b3efce57b94b9a42537e35c48198e7a Mon Sep 17 00:00:00 2001
From: Ralf Gommers
Date: Tue, 24 Aug 2021 13:38:48 +0200
Subject: [PATCH] Add `allow_copy` flag to interchange protocol (#51)

Add a flag to throw an exception if the export cannot be zero-copy
(e.g. for pandas this is possible due to the block manager, where rows
are contiguous and columns are not).

- Add `allow_copy` flag to the DataFrame class.
- Propagate the flag to the buffer and raise a `RuntimeError` when needed
- Fix `test_noncontiguous_columns`
- Update the requirements doc

Co-authored-by: Stephannie Jimenez
---
 protocol/dataframe_protocol.py    |  9 +++-
 protocol/pandas_implementation.py | 88 ++++++++++++++++++++-----------
 2 files changed, 66 insertions(+), 31 deletions(-)

diff --git a/protocol/dataframe_protocol.py b/protocol/dataframe_protocol.py
index 3964ec6b..6b0c0f3f 100644
--- a/protocol/dataframe_protocol.py
+++ b/protocol/dataframe_protocol.py
@@ -354,7 +354,8 @@ class DataFrame:
     ``__dataframe__`` method of a public data frame class in a library adhering
     to the dataframe interchange protocol specification.
     """
-    def __dataframe__(self, nan_as_null : bool = False) -> dict:
+    def __dataframe__(self, nan_as_null : bool = False,
+                      allow_copy : bool = True) -> dict:
         """
         Produces a dictionary object following the dataframe protocol
         specification.
@@ -362,8 +363,14 @@ def __dataframe__(self, nan_as_null : bool = False) -> dict:
         producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
         It is intended for cases where the consumer does not support the bit
         mask or byte mask that is the producer's native representation.
+
+        ``allow_copy`` is a keyword that defines whether or not the library is
+        allowed to make a copy of the data. For example, copying data would be
+        necessary if a library supports strided buffers, given that this protocol
+        specifies contiguous buffers.
         """
         self._nan_as_null = nan_as_null
+        self._allow_copy = allow_copy
         return {
             "dataframe": self,  # DataFrame object adhering to the protocol
             "version": 0        # Version number of the protocol
diff --git a/protocol/pandas_implementation.py b/protocol/pandas_implementation.py
index bbc79ef8..786bfbd2 100644
--- a/protocol/pandas_implementation.py
+++ b/protocol/pandas_implementation.py
@@ -35,7 +35,8 @@
 ColumnObject = Any


-def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
+def from_dataframe(df : DataFrameObject,
+                   allow_copy : bool = True) -> pd.DataFrame:
     """
     Construct a pandas DataFrame from ``df`` if it supports ``__dataframe__``
     """
@@ -46,7 +47,7 @@ def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
     if not hasattr(df, '__dataframe__'):
         raise ValueError("`df` does not support __dataframe__")

-    return _from_dataframe(df.__dataframe__())
+    return _from_dataframe(df.__dataframe__(allow_copy=allow_copy))


 def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
@@ -63,19 +64,24 @@ def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
     # least for now, deal with non-numpy dtypes later).
     columns = dict()
     _k = _DtypeKind
+    _buffers = []  # hold on to buffers, keeps memory alive
     for name in df.column_names():
         col = df.get_column_by_name(name)
         if col.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
             # Simple numerical or bool dtype, turn into numpy array
-            columns[name] = convert_column_to_ndarray(col)
+            columns[name], _buf = convert_column_to_ndarray(col)
         elif col.dtype[0] == _k.CATEGORICAL:
-            columns[name] = convert_categorical_column(col)
+            columns[name], _buf = convert_categorical_column(col)
         elif col.dtype[0] == _k.STRING:
-            columns[name] = convert_string_column(col)
+            columns[name], _buf = convert_string_column(col)
         else:
             raise NotImplementedError(f"Data type {col.dtype[0]} not handled yet")
-    return pd.DataFrame(columns)
+        _buffers.append(_buf)
+
+    df_new = pd.DataFrame(columns)
+    df_new._buffers = _buffers
+    return df_new


 class _DtypeKind(enum.IntEnum):
@@ -100,7 +106,7 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
                                   "sentinel values not handled yet")

     _buffer, _dtype = col.get_buffers()["data"]
-    return buffer_to_ndarray(_buffer, _dtype)
+    return buffer_to_ndarray(_buffer, _dtype), _buffer


 def buffer_to_ndarray(_buffer, _dtype) -> np.ndarray:
@@ -159,7 +165,7 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series:
         raise NotImplementedError("Only categorical columns with sentinel "
                                   "value supported at the moment")

-    return series
+    return series, codes_buffer


 def convert_string_column(col : ColumnObject) -> np.ndarray:
@@ -218,10 +224,11 @@ def convert_string_column(col : ColumnObject) -> np.ndarray:
         str_list.append(s)

     # Convert the string list to a NumPy array
-    return np.asarray(str_list, dtype="object")
+    return np.asarray(str_list, dtype="object"), buffers


-def __dataframe__(cls, nan_as_null : bool = False) -> dict:
+def __dataframe__(cls, nan_as_null : bool = False,
+                  allow_copy : bool = True) -> dict:
     """
     The public method to attach to pd.DataFrame.

@@ -232,12 +239,21 @@ def __dataframe__(cls, nan_as_null : bool = False) -> dict:
     producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
     This currently has no effect; once support for nullable extension
     dtypes is added, this value should be propagated to columns.
+
+    ``allow_copy`` is a keyword that defines whether or not the library is
+    allowed to make a copy of the data. For example, copying data would be
+    necessary if a library supports strided buffers, given that this protocol
+    specifies contiguous buffers.
+
+    Currently, if the flag is set to ``False`` and a copy is needed, a
+    ``RuntimeError`` will be raised.
     """
-    return _PandasDataFrame(cls, nan_as_null=nan_as_null)
+    return _PandasDataFrame(
+        cls, nan_as_null=nan_as_null, allow_copy=allow_copy)


 # Monkeypatch the Pandas DataFrame class to support the interchange protocol
 pd.DataFrame.__dataframe__ = __dataframe__
+pd.DataFrame._buffers = []


 # Implementation of interchange protocol
@@ -248,16 +264,18 @@ class _PandasBuffer:
     Data in the buffer is guaranteed to be contiguous in memory.
     """

-    def __init__(self, x : np.ndarray) -> None:
+    def __init__(self, x : np.ndarray, allow_copy : bool = True) -> None:
         """
         Handle only regular columns (= numpy arrays) for now.
         """
         if not x.strides == (x.dtype.itemsize,):
-            # Array is not contiguous - this is possible to get in Pandas,
-            # there was some discussion on whether to support it. Som extra
-            # complexity for libraries that don't support it (e.g. Arrow),
-            # but would help with numpy-based libraries like Pandas.
-            raise RuntimeError("Design needs fixing - non-contiguous buffer")
+            # The protocol does not support strided buffers, so a copy is
+            # necessary. If that's not allowed, we need to raise an exception.
+            if allow_copy:
+                x = x.copy()
+            else:
+                raise RuntimeError("Exports cannot be zero-copy in the case "
+                                   "of a non-contiguous buffer")

         # Store the numpy array in which the data resides as a private
         # attribute, so we can use it to retrieve the public attributes
@@ -313,7 +331,8 @@ class _PandasColumn:
     """

-    def __init__(self, column : pd.Series) -> None:
+    def __init__(self, column : pd.Series,
+                 allow_copy : bool = True) -> None:
         """
         Note: doesn't deal with extension arrays yet, just assume a regular
         Series/ndarray for now.
@@ -324,6 +343,7 @@ def __init__(self, column : pd.Series) -> None:

         # Store the column as a private attribute
         self._col = column
+        self._allow_copy = allow_copy

     @property
     def size(self) -> int:
@@ -560,11 +580,13 @@ def _get_data_buffer(self) -> Tuple[_PandasBuffer, Any]:  # Any is for self.dtyp
         """
         _k = _DtypeKind
         if self.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
-            buffer = _PandasBuffer(self._col.to_numpy())
+            buffer = _PandasBuffer(
+                self._col.to_numpy(), allow_copy=self._allow_copy)
             dtype = self.dtype
         elif self.dtype[0] == _k.CATEGORICAL:
             codes = self._col.values.codes
-            buffer = _PandasBuffer(codes)
+            buffer = _PandasBuffer(
+                codes, allow_copy=self._allow_copy)
             dtype = self._dtype_from_pandasdtype(codes.dtype)
         elif self.dtype[0] == _k.STRING:
             # Marshal the strings from a NumPy object array into a byte array
@@ -677,7 +699,8 @@ class _PandasDataFrame:
     ``pd.DataFrame.__dataframe__`` as objects with the methods and attributes
     defined on this class.
     """
-    def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
+    def __init__(self, df : pd.DataFrame, nan_as_null : bool = False,
+                 allow_copy : bool = True) -> None:
         """
         Constructor - an instance of this (private) class is returned from
         `pd.DataFrame.__dataframe__`.
@@ -688,6 +711,7 @@ def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
         # This currently has no effect; once support for nullable extension
         # dtypes is added, this value should be propagated to columns.
         self._nan_as_null = nan_as_null
+        self._allow_copy = allow_copy

     @property
     def metadata(self):
@@ -708,13 +732,16 @@ def column_names(self) -> Iterable[str]:
         return self._df.columns.tolist()

     def get_column(self, i: int) -> _PandasColumn:
-        return _PandasColumn(self._df.iloc[:, i])
+        return _PandasColumn(
+            self._df.iloc[:, i], allow_copy=self._allow_copy)

     def get_column_by_name(self, name: str) -> _PandasColumn:
-        return _PandasColumn(self._df[name])
+        return _PandasColumn(
+            self._df[name], allow_copy=self._allow_copy)

     def get_columns(self) -> Iterable[_PandasColumn]:
-        return [_PandasColumn(self._df[name]) for name in self._df.columns]
+        return [_PandasColumn(self._df[name], allow_copy=self._allow_copy)
+                for name in self._df.columns]

     def select_columns(self, indices: Sequence[int]) -> '_PandasDataFrame':
         if not isinstance(indices, collections.Sequence):
@@ -752,13 +779,14 @@ def test_mixed_intfloat():


 def test_noncontiguous_columns():
-    # Currently raises: TBD whether it should work or not, see code comment
-    # where the RuntimeError is raised.
     arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
-    df = pd.DataFrame(arr)
-    assert df[0].to_numpy().strides == (24,)
-    pytest.raises(RuntimeError, from_dataframe, df)
-    #df2 = from_dataframe(df)
+    df = pd.DataFrame(arr, columns=['a', 'b', 'c'])
+    assert df['a'].to_numpy().strides == (24,)
+    df2 = from_dataframe(df)  # uses default of allow_copy=True
+    tm.assert_frame_equal(df, df2)
+
+    with pytest.raises(RuntimeError):
+        from_dataframe(df, allow_copy=False)


 def test_categorical_dtype():
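
For reference, below is a minimal usage sketch of the new flag, mirroring
test_noncontiguous_columns above. It is not part of the patch and assumes
protocol/pandas_implementation.py is importable as `pandas_implementation`
(a hypothetical import path); importing that module runs the monkeypatch
shown in the diff, so `pd.DataFrame.__dataframe__` and `from_dataframe`
behave as in the tests.

import numpy as np
import pandas as pd

# Hypothetical import path for protocol/pandas_implementation.py; importing it
# attaches __dataframe__ to pd.DataFrame (see the monkeypatch in the diff).
from pandas_implementation import from_dataframe

# A DataFrame built from a single 2-D array: the block manager stores one
# 3x3 block with contiguous rows, so each column is a strided (non-contiguous)
# view into that block.
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
df = pd.DataFrame(arr, columns=['a', 'b', 'c'])

# Default allow_copy=True: the producer copies each strided column into a
# contiguous buffer, so the export and round trip succeed.
df2 = from_dataframe(df)

# allow_copy=False: zero-copy export is impossible here, so the producer
# raises RuntimeError instead of copying silently.
try:
    from_dataframe(df, allow_copy=False)
except RuntimeError as exc:
    print(f"zero-copy export refused: {exc}")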