Skip to content

Commit

Permalink
Also support arrow ExtensionTypes via to_pandas_dtype (without having pandas metadata)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Oct 25, 2019
1 parent 5df424b commit 3d669bc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
28 changes: 21 additions & 7 deletions python/pyarrow/pandas_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,13 +710,14 @@ def table_to_blockmanager(options, table, categories=None,
table, index = _reconstruct_index(table, index_descriptors,
all_columns)
ext_columns_dtypes = _get_extension_dtypes(
all_columns, extension_columns)
table, all_columns, extension_columns)
else:
index = _pandas_api.pd.RangeIndex(table.num_rows)
if extension_columns:
raise ValueError("extension_columns not supported if there is "
"no pandas_metadata")
ext_columns_dtypes = {}
ext_columns_dtypes = _get_extension_dtypes(
table, [], extension_columns)

_check_data_column_metadata_consistency(all_columns)
blocks = _table_to_blocks(options, table, categories, ext_columns_dtypes)
Expand All @@ -726,7 +727,7 @@ def table_to_blockmanager(options, table, categories=None,
return BlockManager(blocks, axes)


def _get_extension_dtypes(columns, extension_columns):
def _get_extension_dtypes(table, columns_metadata, extension_columns):
"""
Based on the stored column pandas metadata, infer which columns
should be converted to a pandas extension dtype.
Expand All @@ -743,21 +744,34 @@ def _get_extension_dtypes(columns, extension_columns):
# older pandas version that does not yet support extension dtypes
if _pandas_api.extension_dtype is None:
if extension_columns is not None:
raise ValueError("not supported")
raise ValueError(
"Converting to pandas ExtensionDtypes is not supported")
return ext_columns

if extension_columns is None:
# infer the extension columns
for col_meta in columns:
# infer the extension columns from the pandas metadata
for col_meta in columns_metadata:
name = col_meta['name']
pandas_dtype = _pandas_api.pandas_dtype(col_meta['numpy_type'])
if isinstance(pandas_dtype, _pandas_api.extension_dtype):
if hasattr(pandas_dtype, "__from_arrow__"):
ext_columns[name] = pandas_dtype
# infer from extension type in the schema
for field in table.schema:
typ = field.type
if isinstance(typ, pa.BaseExtensionType):
try:
pandas_dtype = typ.to_pandas_dtype()
except NotImplementedError:
pass
else:
ext_columns[field.name] = pandas_dtype

else:
# get the extension dtype for the specified columns
for name in extension_columns:
col_meta = [meta for meta in columns if meta['name'] == name][0]
col_meta = [
meta for meta in columns_metadata if meta['name'] == name][0]
pandas_dtype = _pandas_api.pandas_dtype(col_meta['numpy_type'])
if not isinstance(pandas_dtype, _pandas_api.extension_dtype):
raise ValueError("not an extension dtype")
Expand Down
40 changes: 39 additions & 1 deletion python/pyarrow/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3231,7 +3231,7 @@ def _Int64Dtype__from_arrow__(self, array):
# TODO: do we require handling of chunked arrays in the protocol?
arr = array.chunk(0)
buflist = arr.buffers()
data = np.frombuffer(buflist[-1], dtype=arr.type.to_pandas_dtype())[
data = np.frombuffer(buflist[-1], dtype='int64')[
arr.offset:arr.offset + len(arr)]
bitmask = buflist[0]
if bitmask is not None:
Expand Down Expand Up @@ -3295,6 +3295,44 @@ def test_convert_to_extension_array():
del pd.Int64Dtype.__from_arrow__


class MyCustomIntegerType(pa.PyExtensionType):
    """Test extension type with int64 storage that maps to pandas Int64Dtype."""

    def __init__(self):
        # The underlying storage for this extension type is plain int64.
        super().__init__(pa.int64())

    def __reduce__(self):
        # PyExtensionType requires __reduce__ so instances can be pickled
        # (the type carries no state, so reconstruct with no arguments).
        return MyCustomIntegerType, ()

    def to_pandas_dtype(self):
        # Conversion hook consulted by to_pandas(): link this arrow
        # extension type to pandas' nullable Int64 extension dtype.
        return pd.Int64Dtype()


def test_conversion_extensiontype_to_extensionarray():
    # An arrow extension type whose to_pandas_dtype() points at a pandas
    # ExtensionDtype should round-trip into a pandas ExtensionBlock.
    import pandas.core.internals as internals

    int_storage = pa.array([1, 2, 3, 4], pa.int64())
    ext_arr = pa.ExtensionArray.from_storage(MyCustomIntegerType(), int_storage)
    table = pa.table({'a': ext_arr})

    # Without the __from_arrow__ protocol method on Int64Dtype, the
    # conversion has no way to build the pandas array and must fail.
    with pytest.raises(ValueError):
        table.to_pandas()

    try:
        # patch pandas Int64Dtype to have the protocol method
        pd.Int64Dtype.__from_arrow__ = _Int64Dtype__from_arrow__

        # extension type points to Int64Dtype, which knows how to create a
        # pandas ExtensionArray
        result = table.to_pandas()
        assert isinstance(result._data.blocks[0], internals.ExtensionBlock)
        expected = pd.DataFrame({'a': pd.array([1, 2, 3, 4], dtype='Int64')})
        tm.assert_frame_equal(result, expected)

    finally:
        del pd.Int64Dtype.__from_arrow__


# ----------------------------------------------------------------------
# Legacy metadata compatibility tests

Expand Down

0 comments on commit 3d669bc

Please sign in to comment.