Skip to content

Commit

Permalink
BUG: Series.where casting dt64 to int64 (pandas-dev#38073)
Browse files Browse the repository at this point in the history
* ENH: support 2D in DatetimeArray._from_sequence

* BUG: Series.where casting dt64 to int64

* whatsnew

* move whatsnew

* use fixture, remove unnecessary check
  • Loading branch information
jbrockmendel authored and luckyvs1 committed Jan 20, 2021
1 parent 9b8641c commit 4b89980
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 14 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ Datetimelike
- Bug in :meth:`DataFrame.first` and :meth:`Series.first` returning two months for offset one month when first day is last calendar day (:issue:`29623`)
- Bug in constructing a :class:`DataFrame` or :class:`Series` with mismatched ``datetime64`` data and ``timedelta64`` dtype, or vice-versa, failing to raise ``TypeError`` (:issue:`38575`)
- Bug in :meth:`DatetimeIndex.intersection`, :meth:`DatetimeIndex.symmetric_difference`, :meth:`PeriodIndex.intersection`, :meth:`PeriodIndex.symmetric_difference` always returning object-dtype when operating with :class:`CategoricalIndex` (:issue:`38741`)
- Bug in :meth:`Series.where` incorrectly casting ``datetime64`` values to ``int64`` (:issue:`37682`)
-

Timedelta
^^^^^^^^^
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def __init__(self, values: Union[np.ndarray, "PandasArray"], copy: bool = False)
f"'values' must be a NumPy array, not {type(values).__name__}"
)

if values.ndim != 1:
if values.ndim == 0:
# Technically we support 2, but do not advertise that fact.
raise ValueError("PandasArray must be 1-dimensional.")

if copy:
Expand Down
50 changes: 38 additions & 12 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,22 @@ def shift(self, periods: int, axis: int = 0, fill_value=None):

return [self.make_block(new_values)]

def _maybe_reshape_where_args(self, values, other, cond, axis):
transpose = self.ndim == 2

cond = _extract_bool_array(cond)

# If the default broadcasting would go in the wrong direction, then
# explicitly reshape other instead
if getattr(other, "ndim", 0) >= 1:
if values.ndim - 1 == other.ndim and axis == 1:
other = other.reshape(tuple(other.shape + (1,)))
elif transpose and values.ndim == self.ndim - 1:
# TODO(EA2D): not neceesssary with 2D EAs
cond = cond.T

return other, cond

def where(
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
) -> List["Block"]:
Expand All @@ -1354,7 +1370,6 @@ def where(
"""
import pandas.core.computation.expressions as expressions

cond = _extract_bool_array(cond)
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))

assert errors in ["raise", "ignore"]
Expand All @@ -1365,17 +1380,7 @@ def where(
if transpose:
values = values.T

# If the default broadcasting would go in the wrong direction, then
# explicitly reshape other instead
if getattr(other, "ndim", 0) >= 1:
if values.ndim - 1 == other.ndim and axis == 1:
other = other.reshape(tuple(other.shape + (1,)))
elif transpose and values.ndim == self.ndim - 1:
# TODO(EA2D): not neceesssary with 2D EAs
cond = cond.T

if not hasattr(cond, "shape"):
raise ValueError("where must have a condition that is ndarray like")
other, cond = self._maybe_reshape_where_args(values, other, cond, axis)

if cond.ravel("K").all():
result = values
Expand Down Expand Up @@ -2128,6 +2133,26 @@ def to_native_types(self, na_rep="NaT", **kwargs):
result = arr._format_native_types(na_rep=na_rep, **kwargs)
return self.make_block(result)

def where(
self, other, cond, errors="raise", try_cast: bool = False, axis: int = 0
) -> List["Block"]:
# TODO(EA2D): reshape unnecessary with 2D EAs
arr = self.array_values().reshape(self.shape)

other, cond = self._maybe_reshape_where_args(arr, other, cond, axis)

try:
res_values = arr.T.where(cond, other).T
except (ValueError, TypeError):
return super().where(
other, cond, errors=errors, try_cast=try_cast, axis=axis
)

# TODO(EA2D): reshape not needed with 2D EAs
res_values = res_values.reshape(self.values.shape)
nb = self.make_block_same_class(res_values)
return [nb]

def _can_hold_element(self, element: Any) -> bool:
arr = self.array_values()

Expand Down Expand Up @@ -2196,6 +2221,7 @@ class DatetimeTZBlock(ExtensionBlock, DatetimeBlock):
fillna = DatetimeBlock.fillna # i.e. Block.fillna
fill_value = DatetimeBlock.fill_value
_can_hold_na = DatetimeBlock._can_hold_na
where = DatetimeBlock.where

array_values = ExtensionBlock.array_values

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def test_array_inference_fails(data):
tm.assert_extension_array_equal(result, expected)


@pytest.mark.parametrize("data", [np.array([[1, 2], [3, 4]]), [[1, 2], [3, 4]]])
@pytest.mark.parametrize("data", [np.array(0)])
def test_nd_raises(data):
with pytest.raises(ValueError, match="PandasArray must be 1-dimensional"):
pd.array(data, dtype="int64")
Expand Down
32 changes: 32 additions & 0 deletions pandas/tests/series/indexing/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,3 +464,35 @@ def test_where_categorical(klass):
df = klass(["A", "A", "B", "B", "C"], dtype="category")
res = df.where(df != "C")
tm.assert_equal(exp, res)


def test_where_datetimelike_categorical(tz_naive_fixture):
# GH#37682
tz = tz_naive_fixture

dr = pd.date_range("2001-01-01", periods=3, tz=tz)._with_freq(None)
lvals = pd.DatetimeIndex([dr[0], dr[1], pd.NaT])
rvals = pd.Categorical([dr[0], pd.NaT, dr[2]])

mask = np.array([True, True, False])

# DatetimeIndex.where
res = lvals.where(mask, rvals)
tm.assert_index_equal(res, dr)

# DatetimeArray.where
res = lvals._data.where(mask, rvals)
tm.assert_datetime_array_equal(res, dr._data)

# Series.where
res = Series(lvals).where(mask, rvals)
tm.assert_series_equal(res, Series(dr))

# DataFrame.where
if tz is None:
res = pd.DataFrame(lvals).where(mask[:, None], pd.DataFrame(rvals))
else:
with pytest.xfail(reason="frame._values loses tz"):
res = pd.DataFrame(lvals).where(mask[:, None], pd.DataFrame(rvals))

tm.assert_frame_equal(res, pd.DataFrame(dr))

0 comments on commit 4b89980

Please sign in to comment.