Skip to content

Commit

Permalink
Add allow_copy flag to interchange protocol (data-apis#51)
Browse files Browse the repository at this point in the history
Add a flag to throw an exception if the export cannot be
zero-copy. (e.g. for pandas, possible due to block manager where rows
are contiguous and columns are not) .

- Add `allow_zero_copy` flag to the DataFrame class.
- Propagate the flag to the buffer and raise a `RuntimeError` when needed
- Fix `test_noncontiguous_columns`
- Make update in the requirements doc

Co-authored-by: Stephannie Jimenez <steff456@hotmail.com>
  • Loading branch information
rgommers and steff456 authored Aug 24, 2021
1 parent d9419b1 commit bcb5024
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 31 deletions.
9 changes: 8 additions & 1 deletion protocol/dataframe_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,23 @@ class DataFrame:
``__dataframe__`` method of a public data frame class in a library adhering
to the dataframe interchange protocol specification.
"""
def __dataframe__(self, nan_as_null : bool = False) -> dict:
def __dataframe__(self, nan_as_null : bool = False,
allow_copy : bool = True) -> dict:
"""
Produces a dictionary object following the dataframe protocol specification.
``nan_as_null`` is a keyword intended for the consumer to tell the
producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
It is intended for cases where the consumer does not support the bit
mask or byte mask that is the producer's native representation.
``allow_copy`` is a keyword that defines whether or not the library is
allowed to make a copy of the data. For example, copying data would be
necessary if a library supports strided buffers, given that this protocol
specifies contiguous buffers.
"""
self._nan_as_null = nan_as_null
self._allow_zero_zopy = allow_copy
return {
"dataframe": self, # DataFrame object adhering to the protocol
"version": 0 # Version number of the protocol
Expand Down
88 changes: 58 additions & 30 deletions protocol/pandas_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
ColumnObject = Any


def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
def from_dataframe(df : DataFrameObject,
allow_copy : bool = True) -> pd.DataFrame:
"""
Construct a pandas DataFrame from ``df`` if it supports ``__dataframe__``
"""
Expand All @@ -46,7 +47,7 @@ def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
if not hasattr(df, '__dataframe__'):
raise ValueError("`df` does not support __dataframe__")

return _from_dataframe(df.__dataframe__())
return _from_dataframe(df.__dataframe__(allow_copy=allow_copy))


def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
Expand All @@ -63,19 +64,24 @@ def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
# least for now, deal with non-numpy dtypes later).
columns = dict()
_k = _DtypeKind
_buffers = [] # hold on to buffers, keeps memory alive
for name in df.column_names():
col = df.get_column_by_name(name)
if col.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
# Simple numerical or bool dtype, turn into numpy array
columns[name] = convert_column_to_ndarray(col)
columns[name], _buf = convert_column_to_ndarray(col)
elif col.dtype[0] == _k.CATEGORICAL:
columns[name] = convert_categorical_column(col)
columns[name], _buf = convert_categorical_column(col)
elif col.dtype[0] == _k.STRING:
columns[name] = convert_string_column(col)
columns[name], _buf = convert_string_column(col)
else:
raise NotImplementedError(f"Data type {col.dtype[0]} not handled yet")

return pd.DataFrame(columns)
_buffers.append(_buf)

df_new = pd.DataFrame(columns)
df_new._buffers = _buffers
return df_new


class _DtypeKind(enum.IntEnum):
Expand All @@ -100,7 +106,7 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
"sentinel values not handled yet")

_buffer, _dtype = col.get_buffers()["data"]
return buffer_to_ndarray(_buffer, _dtype)
return buffer_to_ndarray(_buffer, _dtype), _buffer


def buffer_to_ndarray(_buffer, _dtype) -> np.ndarray:
Expand Down Expand Up @@ -159,7 +165,7 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series:
raise NotImplementedError("Only categorical columns with sentinel "
"value supported at the moment")

return series
return series, codes_buffer


def convert_string_column(col : ColumnObject) -> np.ndarray:
Expand Down Expand Up @@ -218,10 +224,11 @@ def convert_string_column(col : ColumnObject) -> np.ndarray:
str_list.append(s)

# Convert the string list to a NumPy array
return np.asarray(str_list, dtype="object")
return np.asarray(str_list, dtype="object"), buffers


def __dataframe__(cls, nan_as_null : bool = False) -> dict:
def __dataframe__(cls, nan_as_null : bool = False,
allow_copy : bool = True) -> dict:
"""
The public method to attach to pd.DataFrame.
Expand All @@ -232,12 +239,21 @@ def __dataframe__(cls, nan_as_null : bool = False) -> dict:
producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
This currently has no effect; once support for nullable extension
dtypes is added, this value should be propagated to columns.
``allow_copy`` is a keyword that defines whether or not the library is
allowed to make a copy of the data. For example, copying data would be
necessary if a library supports strided buffers, given that this protocol
specifies contiguous buffers.
Currently, if the flag is set to ``False`` and a copy is needed, a
``RuntimeError`` will be raised.
"""
return _PandasDataFrame(cls, nan_as_null=nan_as_null)
return _PandasDataFrame(
cls, nan_as_null=nan_as_null, allow_copy=allow_copy)


# Monkeypatch the Pandas DataFrame class to support the interchange protocol
pd.DataFrame.__dataframe__ = __dataframe__
pd.DataFrame._buffers = []


# Implementation of interchange protocol
Expand All @@ -248,16 +264,18 @@ class _PandasBuffer:
Data in the buffer is guaranteed to be contiguous in memory.
"""

def __init__(self, x : np.ndarray) -> None:
def __init__(self, x : np.ndarray, allow_copy : bool = True) -> None:
"""
Handle only regular columns (= numpy arrays) for now.
"""
if not x.strides == (x.dtype.itemsize,):
# Array is not contiguous - this is possible to get in Pandas,
# there was some discussion on whether to support it. Som extra
# complexity for libraries that don't support it (e.g. Arrow),
# but would help with numpy-based libraries like Pandas.
raise RuntimeError("Design needs fixing - non-contiguous buffer")
# The protocol does not support strided buffers, so a copy is
# necessary. If that's not allowed, we need to raise an exception.
if allow_copy:
x = x.copy()
else:
raise RuntimeError("Exports cannot be zero-copy in the case "
"of a non-contiguous buffer")

# Store the numpy array in which the data resides as a private
# attribute, so we can use it to retrieve the public attributes
Expand Down Expand Up @@ -313,7 +331,8 @@ class _PandasColumn:
"""

def __init__(self, column : pd.Series) -> None:
def __init__(self, column : pd.Series,
allow_copy : bool = True) -> None:
"""
Note: doesn't deal with extension arrays yet, just assume a regular
Series/ndarray for now.
Expand All @@ -324,6 +343,7 @@ def __init__(self, column : pd.Series) -> None:

# Store the column as a private attribute
self._col = column
self._allow_copy = allow_copy

@property
def size(self) -> int:
Expand Down Expand Up @@ -560,11 +580,13 @@ def _get_data_buffer(self) -> Tuple[_PandasBuffer, Any]: # Any is for self.dtyp
"""
_k = _DtypeKind
if self.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
buffer = _PandasBuffer(self._col.to_numpy())
buffer = _PandasBuffer(
self._col.to_numpy(), allow_copy=self._allow_copy)
dtype = self.dtype
elif self.dtype[0] == _k.CATEGORICAL:
codes = self._col.values.codes
buffer = _PandasBuffer(codes)
buffer = _PandasBuffer(
codes, allow_copy=self._allow_copy)
dtype = self._dtype_from_pandasdtype(codes.dtype)
elif self.dtype[0] == _k.STRING:
# Marshal the strings from a NumPy object array into a byte array
Expand Down Expand Up @@ -677,7 +699,8 @@ class _PandasDataFrame:
``pd.DataFrame.__dataframe__`` as objects with the methods and
attributes defined on this class.
"""
def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
def __init__(self, df : pd.DataFrame, nan_as_null : bool = False,
allow_copy : bool = True) -> None:
"""
Constructor - an instance of this (private) class is returned from
`pd.DataFrame.__dataframe__`.
Expand All @@ -688,6 +711,7 @@ def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
# This currently has no effect; once support for nullable extension
# dtypes is added, this value should be propagated to columns.
self._nan_as_null = nan_as_null
self._allow_copy = allow_copy

@property
def metadata(self):
Expand All @@ -708,13 +732,16 @@ def column_names(self) -> Iterable[str]:
return self._df.columns.tolist()

def get_column(self, i: int) -> _PandasColumn:
return _PandasColumn(self._df.iloc[:, i])
return _PandasColumn(
self._df.iloc[:, i], allow_copy=self._allow_copy)

def get_column_by_name(self, name: str) -> _PandasColumn:
return _PandasColumn(self._df[name])
return _PandasColumn(
self._df[name], allow_copy=self._allow_copy)

def get_columns(self) -> Iterable[_PandasColumn]:
return [_PandasColumn(self._df[name]) for name in self._df.columns]
return [_PandasColumn(self._df[name], allow_copy=self._allow_copy)
for name in self._df.columns]

def select_columns(self, indices: Sequence[int]) -> '_PandasDataFrame':
if not isinstance(indices, collections.Sequence):
Expand Down Expand Up @@ -752,13 +779,14 @@ def test_mixed_intfloat():


def test_noncontiguous_columns():
# Currently raises: TBD whether it should work or not, see code comment
# where the RuntimeError is raised.
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
df = pd.DataFrame(arr)
assert df[0].to_numpy().strides == (24,)
pytest.raises(RuntimeError, from_dataframe, df)
#df2 = from_dataframe(df)
df = pd.DataFrame(arr, columns=['a', 'b', 'c'])
assert df['a'].to_numpy().strides == (24,)
df2 = from_dataframe(df) # uses default of allow_copy=True
tm.assert_frame_equal(df, df2)

with pytest.raises(RuntimeError):
from_dataframe(df, allow_copy=False)


def test_categorical_dtype():
Expand Down

0 comments on commit bcb5024

Please sign in to comment.