Skip to content

Commit

Permalink
Backport PR #48489 on branch 1.5.x (BUG: fix test_arrow.py tests) (#48532)
Browse files Browse the repository at this point in the history

Backport PR #48489: BUG: fix test_arrow.py tests

Co-authored-by: jbrockmendel <jbrockmendel@gmail.com>
  • Loading branch information
meeseeksmachine and jbrockmendel authored Sep 13, 2022
1 parent 5817209 commit ecc8ab4
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 148 deletions.
8 changes: 4 additions & 4 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,10 +1188,10 @@ def needs_i8_conversion(arr_or_dtype) -> bool:
"""
if arr_or_dtype is None:
return False
if isinstance(arr_or_dtype, (np.dtype, ExtensionDtype)):
# fastpath
dtype = arr_or_dtype
return dtype.kind in ["m", "M"] or dtype.type is Period
if isinstance(arr_or_dtype, np.dtype):
return arr_or_dtype.kind in ["m", "M"]
elif isinstance(arr_or_dtype, ExtensionDtype):
return isinstance(arr_or_dtype, (PeriodDtype, DatetimeTZDtype))

try:
dtype = get_dtype(arr_or_dtype)
Expand Down
13 changes: 9 additions & 4 deletions pandas/core/dtypes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
is_dtype_equal,
is_sparse,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype,
ExtensionDtype,
)
from pandas.core.dtypes.generic import (
ABCCategoricalIndex,
ABCExtensionArray,
Expand Down Expand Up @@ -103,10 +106,12 @@ def is_nonempty(x) -> bool:
# ea_compat_axis see GH#39574
to_concat = non_empties

dtypes = {obj.dtype for obj in to_concat}
kinds = {obj.dtype.kind for obj in to_concat}
contains_datetime = any(kind in ["m", "M"] for kind in kinds) or any(
isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat
)
contains_datetime = any(
isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in ["m", "M"]
for dtype in dtypes
) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)

all_empty = not len(non_empties)
single_dtype = len({x.dtype for x in to_concat}) == 1
Expand Down
19 changes: 6 additions & 13 deletions pandas/core/internals/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
)
from pandas.core.dtypes.common import (
is_1d_only_ea_dtype,
is_datetime64tz_dtype,
is_dtype_equal,
is_scalar,
needs_i8_conversion,
Expand All @@ -38,7 +37,10 @@
cast_to_common_type,
concat_compat,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype,
ExtensionDtype,
)
from pandas.core.dtypes.missing import (
is_valid_na_for_dtype,
isna,
Expand Down Expand Up @@ -147,16 +149,6 @@ def concat_arrays(to_concat: list) -> ArrayLike:
else:
target_dtype = find_common_type([arr.dtype for arr in to_concat_no_proxy])

if target_dtype.kind in ["m", "M"]:
# for datetimelike use DatetimeArray/TimedeltaArray concatenation
# don't use arr.astype(target_dtype, copy=False), because that doesn't
# work for DatetimeArray/TimedeltaArray (returns ndarray)
to_concat = [
arr.to_array(target_dtype) if isinstance(arr, NullArrayProxy) else arr
for arr in to_concat
]
return type(to_concat_no_proxy[0])._concat_same_type(to_concat, axis=0)

to_concat = [
arr.to_array(target_dtype)
if isinstance(arr, NullArrayProxy)
Expand Down Expand Up @@ -471,7 +463,8 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
if len(values) and values[0] is None:
fill_value = None

if is_datetime64tz_dtype(empty_dtype):
if isinstance(empty_dtype, DatetimeTZDtype):
# NB: exclude e.g. pyarrow[dt64tz] dtypes
i8values = np.full(self.shape, fill_value.value)
return DatetimeArray(i8values, dtype=empty_dtype)

Expand Down
13 changes: 7 additions & 6 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
is_bool,
is_bool_dtype,
is_categorical_dtype,
is_datetime64tz_dtype,
is_dtype_equal,
is_extension_array_dtype,
is_float_dtype,
Expand All @@ -62,6 +61,7 @@
is_object_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
Expand Down Expand Up @@ -1349,12 +1349,12 @@ def _maybe_coerce_merge_keys(self) -> None:
raise ValueError(msg)
elif not needs_i8_conversion(lk.dtype) and needs_i8_conversion(rk.dtype):
raise ValueError(msg)
elif is_datetime64tz_dtype(lk.dtype) and not is_datetime64tz_dtype(
rk.dtype
elif isinstance(lk.dtype, DatetimeTZDtype) and not isinstance(
rk.dtype, DatetimeTZDtype
):
raise ValueError(msg)
elif not is_datetime64tz_dtype(lk.dtype) and is_datetime64tz_dtype(
rk.dtype
elif not isinstance(lk.dtype, DatetimeTZDtype) and isinstance(
rk.dtype, DatetimeTZDtype
):
raise ValueError(msg)

Expand Down Expand Up @@ -2280,9 +2280,10 @@ def _factorize_keys(
rk = extract_array(rk, extract_numpy=True, extract_range=True)
# TODO: if either is a RangeIndex, we can likely factorize more efficiently?

if is_datetime64tz_dtype(lk.dtype) and is_datetime64tz_dtype(rk.dtype):
if isinstance(lk.dtype, DatetimeTZDtype) and isinstance(rk.dtype, DatetimeTZDtype):
# Extract the ndarray (UTC-localized) values
# Note: we dont need the dtypes to match, as these can still be compared
# TODO(non-nano): need to make sure resolutions match
lk = cast("DatetimeArray", lk)._ndarray
rk = cast("DatetimeArray", rk)._ndarray

Expand Down
4 changes: 2 additions & 2 deletions pandas/io/formats/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
is_categorical_dtype,
is_complex_dtype,
is_datetime64_dtype,
is_datetime64tz_dtype,
is_extension_array_dtype,
is_float,
is_float_dtype,
Expand All @@ -79,6 +78,7 @@
is_scalar,
is_timedelta64_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import (
isna,
notna,
Expand Down Expand Up @@ -1290,7 +1290,7 @@ def format_array(
fmt_klass: type[GenericArrayFormatter]
if is_datetime64_dtype(values.dtype):
fmt_klass = Datetime64Formatter
elif is_datetime64tz_dtype(values.dtype):
elif isinstance(values.dtype, DatetimeTZDtype):
fmt_klass = Datetime64TZFormatter
elif is_timedelta64_dtype(values.dtype):
fmt_klass = Timedelta64Formatter
Expand Down
120 changes: 1 addition & 119 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,25 +539,13 @@ def test_groupby_extension_apply(
self, data_for_grouping, groupby_apply_op, request
):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
# TODO: Is there a better way to get the "object" ID for groupby_apply_op?
is_object = "object" in request.node.nodeid
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"pyarrow doesn't support factorizing {pa_dtype}",
)
)
elif pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
if is_object:
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason="GH 47514: _concat_datetime expects axis arg.",
)
)
with tm.maybe_produces_warning(
PerformanceWarning, pa_version_under7p0, check_stacklevel=False
):
Expand Down Expand Up @@ -688,70 +676,10 @@ def test_dropna_array(self, data_missing):


class TestBasePrinting(base.BasePrintingTests):
def test_series_repr(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if (
pa.types.is_date(pa_dtype)
or pa.types.is_duration(pa_dtype)
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason="GH 47514: _concat_datetime expects axis arg.",
)
)
super().test_series_repr(data)

def test_dataframe_repr(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if (
pa.types.is_date(pa_dtype)
or pa.types.is_duration(pa_dtype)
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason="GH 47514: _concat_datetime expects axis arg.",
)
)
super().test_dataframe_repr(data)
pass


class TestBaseReshaping(base.BaseReshapingTests):
@pytest.mark.parametrize("in_frame", [True, False])
def test_concat(self, data, in_frame, request):
pa_dtype = data.dtype.pyarrow_dtype
if (
pa.types.is_date(pa_dtype)
or pa.types.is_duration(pa_dtype)
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason="GH 47514: _concat_datetime expects axis arg.",
)
)
super().test_concat(data, in_frame)

@pytest.mark.parametrize("in_frame", [True, False])
def test_concat_all_na_block(self, data_missing, in_frame, request):
pa_dtype = data_missing.dtype.pyarrow_dtype
if (
pa.types.is_date(pa_dtype)
or pa.types.is_duration(pa_dtype)
or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason="GH 47514: _concat_datetime expects axis arg.",
)
)
super().test_concat_all_na_block(data_missing, in_frame)

def test_concat_columns(self, data, na_value, request):
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC"):
Expand All @@ -772,26 +700,6 @@ def test_concat_extension_arrays_copy_false(self, data, na_value, request):
)
super().test_concat_extension_arrays_copy_false(data, na_value)

def test_concat_with_reindex(self, data, request, using_array_manager):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason="GH 47514: _concat_datetime expects axis arg.",
)
)
elif pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError if not using_array_manager else TypeError,
reason="GH 34986",
)
)
super().test_concat_with_reindex(data)

def test_align(self, data, na_value, request):
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC"):
Expand Down Expand Up @@ -832,32 +740,6 @@ def test_merge(self, data, na_value, request):
)
super().test_merge(data, na_value)

def test_merge_on_extension_array(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
super().test_merge_on_extension_array(data)

def test_merge_on_extension_array_duplicates(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
super().test_merge_on_extension_array_duplicates(data)

def test_ravel(self, data, request):
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC"):
Expand Down

0 comments on commit ecc8ab4

Please sign in to comment.