Skip to content

Commit

Permalink
Fix the fix for the failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
AlenkaF committed Jan 12, 2023
1 parent 1b5f248 commit 9139444
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions python/pyarrow/tests/interchange/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,6 @@ def test_pandas_roundtrip(uint, int, float, np_float):
"c": pa.array(np.array(arr, dtype=np_float), type=float),
}
)
if Version(pd.__version__) >= Version("2.0"):
# See https://github.com/pandas-dev/pandas/issues/50554
table["d"] = ["a", "", "c"]
# large string is not supported by pandas implementation

from pandas.api.interchange import (
from_dataframe as pandas_from_dataframe
)
Expand All @@ -192,6 +187,34 @@ def test_pandas_roundtrip(uint, int, float, np_float):
assert table_protocol.column_names() == result_protocol.column_names()


@pytest.mark.pandas
def test_roundtrip_pandas_string():
# See https://github.com/pandas-dev/pandas/issues/50554
if Version(pd.__version__) < Version("1.6"):
pytest.skip(" Column.size() called as a method in pandas 2.0.0")

# large string is not supported by pandas implementation
table = pa.table({"a": pa.array(["a", "", "c"])})

from pandas.api.interchange import (
from_dataframe as pandas_from_dataframe
)
pandas_df = pandas_from_dataframe(table)
result = pi.from_dataframe(pandas_df)

assert result[0].to_pylist() == table[0].to_pylist()
assert pa.types.is_string(table[0].type)
assert pa.types.is_large_string(result[0].type)

table_protocol = table.__dataframe__()
result_protocol = result.__dataframe__()

assert table_protocol.num_columns() == result_protocol.num_columns()
assert table_protocol.num_rows() == result_protocol.num_rows()
assert table_protocol.num_chunks() == result_protocol.num_chunks()
assert table_protocol.column_names() == result_protocol.column_names()


@pytest.mark.pandas
def test_roundtrip_pandas_boolean():
if Version(pd.__version__) < Version("1.5.0"):
Expand Down Expand Up @@ -228,12 +251,12 @@ def test_roundtrip_pandas_datetime(unit):
dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
table = pa.table({"a": pa.array(dt_arr, type=pa.timestamp(unit))})

if Version(pd.__version__) >= Version("2.0"):
expected = table
else:
if Version(pd.__version__) < Version("1.6"):
# pandas < 2.0 always creates datetime64 in "ns"
# resolution
expected = pa.table({"a": pa.array(dt_arr, type=pa.timestamp('ns'))})
else:
expected = table

from pandas.api.interchange import (
from_dataframe as pandas_from_dataframe
Expand Down

0 comments on commit 9139444

Please sign in to comment.