diff --git a/python/pyarrow/tests/interchange/test_conversion.py b/python/pyarrow/tests/interchange/test_conversion.py index 6565312118e2a..0680d9c4ec11d 100644 --- a/python/pyarrow/tests/interchange/test_conversion.py +++ b/python/pyarrow/tests/interchange/test_conversion.py @@ -171,11 +171,6 @@ def test_pandas_roundtrip(uint, int, float, np_float): "c": pa.array(np.array(arr, dtype=np_float), type=float), } ) - if Version(pd.__version__) >= Version("2.0"): - # See https://github.com/pandas-dev/pandas/issues/50554 - table["d"] = ["a", "", "c"] - # large string is not supported by pandas implementation - from pandas.api.interchange import ( from_dataframe as pandas_from_dataframe ) @@ -192,6 +187,34 @@ def test_pandas_roundtrip(uint, int, float, np_float): assert table_protocol.column_names() == result_protocol.column_names() +@pytest.mark.pandas +def test_roundtrip_pandas_string(): + # See https://github.com/pandas-dev/pandas/issues/50554 + if Version(pd.__version__) < Version("1.6"): + pytest.skip(" Column.size() called as a method in pandas 2.0.0") + + # large string is not supported by pandas implementation + table = pa.table({"a": pa.array(["a", "", "c"])}) + + from pandas.api.interchange import ( + from_dataframe as pandas_from_dataframe + ) + pandas_df = pandas_from_dataframe(table) + result = pi.from_dataframe(pandas_df) + + assert result[0].to_pylist() == table[0].to_pylist() + assert pa.types.is_string(table[0].type) + assert pa.types.is_large_string(result[0].type) + + table_protocol = table.__dataframe__() + result_protocol = result.__dataframe__() + + assert table_protocol.num_columns() == result_protocol.num_columns() + assert table_protocol.num_rows() == result_protocol.num_rows() + assert table_protocol.num_chunks() == result_protocol.num_chunks() + assert table_protocol.column_names() == result_protocol.column_names() + + @pytest.mark.pandas def test_roundtrip_pandas_boolean(): if Version(pd.__version__) < Version("1.5.0"): @@ -228,12 +251,12 @@ def test_roundtrip_pandas_datetime(unit): dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)] table = pa.table({"a": pa.array(dt_arr, type=pa.timestamp(unit))}) - if Version(pd.__version__) >= Version("2.0"): - expected = table - else: + if Version(pd.__version__) < Version("1.6"): # pandas < 2.0 always creates datetime64 in "ns" # resolution expected = pa.table({"a": pa.array(dt_arr, type=pa.timestamp('ns'))}) + else: + expected = table from pandas.api.interchange import ( from_dataframe as pandas_from_dataframe