diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index f99af9853e518..e29136cc1aef0 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -491,7 +491,8 @@ cdef class _PandasConvertible:
             bint deduplicate_objects=True,
             bint ignore_metadata=False,
             bint split_blocks=False,
-            bint self_destruct=False
+            bint self_destruct=False,
+            types_mapping=None
     ):
         """
         Convert to a pandas-compatible NumPy array or DataFrame, as appropriate
@@ -531,6 +532,11 @@ cdef class _PandasConvertible:
            memory while converting the Arrow object to pandas. If you use the
            object after calling to_pandas with this option it will crash your
            program
+        types_mapping : dict, default None
+            A mapping from pyarrow DataType to pandas ExtensionDtype. This can
+            be used to override the default pandas type for conversion of
+            built-in pyarrow types or in the absence of pandas_metadata in the
+            Table schema.
 
         Returns
         -------
@@ -548,7 +554,8 @@ cdef class _PandasConvertible:
             self_destruct=self_destruct
         )
         return self._to_pandas(options, categories=categories,
-                               ignore_metadata=ignore_metadata)
+                               ignore_metadata=ignore_metadata,
+                               types_mapping=types_mapping)
 
 
 cdef PandasOptions _convert_pandas_options(dict options):
diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py
index bab12bf1de280..b986e84d0ca30 100644
--- a/python/pyarrow/pandas_compat.py
+++ b/python/pyarrow/pandas_compat.py
@@ -741,7 +741,8 @@ def make_datetimetz(tz):
 
 
 def table_to_blockmanager(options, table, categories=None,
-                          extension_columns=None, ignore_metadata=False):
+                          extension_columns=None, ignore_metadata=False,
+                          types_mapping=None):
     from pandas.core.internals import BlockManager
 
     all_columns = []
@@ -756,14 +757,14 @@ def table_to_blockmanager(options, table, categories=None,
         table, index = _reconstruct_index(table, index_descriptors,
                                           all_columns)
         ext_columns_dtypes = _get_extension_dtypes(
-            table, all_columns, extension_columns)
+            table, all_columns, extension_columns, types_mapping)
     else:
         index = _pandas_api.pd.RangeIndex(table.num_rows)
         if extension_columns:
             raise ValueError("extension_columns not supported if there is "
                              "no pandas_metadata")
         ext_columns_dtypes = _get_extension_dtypes(
-            table, [], extension_columns)
+            table, [], extension_columns, types_mapping)
 
     _check_data_column_metadata_consistency(all_columns)
     columns = _deserialize_column_index(table, all_columns, column_indexes)
@@ -782,7 +783,8 @@ def table_to_blockmanager(options, table, categories=None,
     ])
 
 
-def _get_extension_dtypes(table, columns_metadata, extension_columns):
+def _get_extension_dtypes(table, columns_metadata, extension_columns,
+                          types_mapping=None):
     """
     Based on the stored column pandas metadata and the extension types in
     the arrow schema, infer which columns should be converted to a
@@ -840,6 +842,12 @@ def _get_extension_dtypes(table, columns_metadata, extension_columns):
                     "converted to extension dtype")
             ext_columns[name] = pandas_dtype
 
+    if types_mapping:
+        for field in table.schema:
+            typ = field.type
+            if typ in types_mapping:
+                ext_columns[field.name] = types_mapping[typ]
+
     return ext_columns
 
 
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 1cd575e0c843d..4d7176d0c6a3b 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -1361,11 +1361,13 @@ cdef class Table(_PandasConvertible):
 
         return result
 
-    def _to_pandas(self, options, categories=None, ignore_metadata=False):
+    def _to_pandas(self, options, categories=None, ignore_metadata=False,
+                   types_mapping=None):
         from pyarrow.pandas_compat import table_to_blockmanager
         mgr = table_to_blockmanager(
             options, self, categories,
-            ignore_metadata=ignore_metadata)
+            ignore_metadata=ignore_metadata,
+            types_mapping=types_mapping)
         return pandas_api.data_frame(mgr)
 
     def to_pydict(self):
diff --git a/python/pyarrow/tests/test_pandas.py b/python/pyarrow/tests/test_pandas.py
index 61d09e2b20baf..af12c6bd2b4a9 100644
--- a/python/pyarrow/tests/test_pandas.py
+++ b/python/pyarrow/tests/test_pandas.py
@@ -3544,6 +3544,21 @@ def test_conversion_extensiontype_to_extensionarray(monkeypatch):
         table.to_pandas()
 
 
+def test_to_pandas_extension_dtypes_mapping():
+    if LooseVersion(pd.__version__) < "0.26.0.dev":
+        pytest.skip("Conversion to pandas IntegerArray not yet supported")
+
+    table = pa.table({'a': pa.array([1, 2, 3], pa.int64())})
+
+    # by default, the numpy dtype is used
+    result = table.to_pandas()
+    assert result['a'].dtype == np.dtype('int64')
+
+    # specify a mapping to override the default
+    result = table.to_pandas(types_mapping={pa.int64(): pd.Int64Dtype()})
+    assert isinstance(result['a'].dtype, pd.Int64Dtype)
+
+
 # ----------------------------------------------------------------------
 # Legacy metadata compatibility tests
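
Usage sketch (not part of the patch): assuming this diff is applied and pandas >= 0.26.0.dev is installed (the same version gate as the test above, needed for the nullable Int64 extension dtype), the new types_mapping argument would be used as follows. The table and dataframe names here are illustrative only.

    import pandas as pd
    import pyarrow as pa

    # One int64 column; the None becomes a null in Arrow.
    table = pa.table({'a': pa.array([1, 2, None], pa.int64())})

    # Default conversion: the null forces a float64 NumPy column with NaN.
    print(table.to_pandas()['a'].dtype)  # float64

    # With types_mapping, int64 columns are converted to pandas' nullable
    # Int64 extension dtype, so the null survives as a missing value.
    df = table.to_pandas(types_mapping={pa.int64(): pd.Int64Dtype()})
    print(df['a'].dtype)  # Int64

This works because _get_extension_dtypes records the mapped ExtensionDtype for every schema field whose Arrow type appears in the dict, and the block conversion then defers to that dtype instead of the default NumPy path; explicit entries in the mapping take precedence over anything inferred from pandas_metadata.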