Skip to content

Commit

Permalink
REF: avoid special case in DTA/TDA.median, flesh out tests (#37423)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Oct 31, 2020
1 parent dcde1f4 commit 86ee235
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 16 deletions.
21 changes: 10 additions & 11 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,21 +1356,20 @@ def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwarg
if axis is not None and abs(axis) >= self.ndim:
raise ValueError("abs(axis) must be less than ndim")

if self.size == 0:
if self.ndim == 1 or axis is None:
return NaT
shape = list(self.shape)
del shape[axis]
shape = [1 if x == 0 else x for x in shape]
result = np.empty(shape, dtype="i8")
result.fill(iNaT)
if is_period_dtype(self.dtype):
# pass datetime64 values to nanops to get correct NaT semantics
result = nanops.nanmedian(
self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna
)
result = result.view("i8")
if axis is None or self.ndim == 1:
return self._box_func(result)
return self._from_backing_data(result)

mask = self.isna()
result = nanops.nanmedian(self.asi8, axis=axis, skipna=skipna, mask=mask)
result = nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna)
if axis is None or self.ndim == 1:
return self._box_func(result)
return self._from_backing_data(result.astype("i8"))
return self._from_backing_data(result)


class DatelikeOps(DatetimeLikeArrayMixin):
Expand Down
13 changes: 11 additions & 2 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools
import operator
from typing import Any, Optional, Tuple, Union, cast
import warnings

import numpy as np

Expand Down Expand Up @@ -645,7 +646,11 @@ def get_median(x):
mask = notna(x)
if not skipna and not mask.all():
return np.nan
return np.nanmedian(x[mask])
with warnings.catch_warnings():
# Suppress RuntimeWarning about All-NaN slice
warnings.filterwarnings("ignore", "All-NaN slice encountered")
res = np.nanmedian(x[mask])
return res

values, mask, dtype, _, _ = _get_values(values, skipna, mask=mask)
if not is_float_dtype(values.dtype):
Expand Down Expand Up @@ -673,7 +678,11 @@ def get_median(x):
)

# fastpath for the skipna case
return _wrap_results(np.nanmedian(values, axis), dtype)
with warnings.catch_warnings():
# Suppress RuntimeWarning about All-NaN slice
warnings.filterwarnings("ignore", "All-NaN slice encountered")
res = np.nanmedian(values, axis)
return _wrap_results(res, dtype)

# must return the correct shape, but median is not defined for the
# empty set so return nans of shape "everything but the passed axis"
Expand Down
52 changes: 50 additions & 2 deletions pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import pytz

from pandas._libs import OutOfBoundsDatetime, Timestamp
from pandas._libs import NaT, OutOfBoundsDatetime, Timestamp
from pandas.compat.numpy import np_version_under1p18

import pandas as pd
Expand Down Expand Up @@ -456,6 +456,54 @@ def test_shift_fill_int_deprecated(self):
expected[1:] = arr[:-1]
tm.assert_equal(result, expected)

def test_median(self, arr1d):
# Exercises `median` on the 1-D `arr1d` fixture and on a 2-D reshape of it,
# covering skipna=True/False, a NaT entry, and an empty slice.
# NOTE(review): taking the middle element as the expected median presumably
# relies on `arr1d` being monotonically ordered — verify against the fixtures.
arr = arr1d
if len(arr) % 2 == 0:
# make it easier to define `expected`
arr = arr[:-1]

# With an odd length, the element at the middle position is the expected median.
expected = arr[len(arr) // 2]

result = arr.median()
assert type(result) is type(expected)
assert result == expected

# Overwrite the middle element with NaT in place; the checks below all run
# against this modified array.
arr[len(arr) // 2] = NaT
if not isinstance(expected, Period):
# For non-Period results, expected becomes the mean of the three-element
# window centred on the NaT entry.
# NOTE(review): presumably `.mean()` skips the NaT so this is the mean of
# the two neighbours — confirm. For Period the earlier `expected` is kept.
expected = arr[len(arr) // 2 - 1 : len(arr) // 2 + 2].mean()

# With skipna=False and a NaT present, the result must be the NaT singleton.
assert arr.median(skipna=False) is NaT

result = arr.median()
assert type(result) is type(expected)
assert result == expected

# The median of an empty slice is NaT for both skipna settings.
assert arr[:0].median() is NaT
assert arr[:0].median(skipna=False) is NaT

# 2d Case
# Reshape to a single column: shape (len(arr), 1).
arr2 = arr.reshape(-1, 1)

# axis=None reduces over every element, so the result is a scalar matching
# the 1-D expectation.
result = arr2.median(axis=None)
assert type(result) is type(expected)
assert result == expected

assert arr2.median(axis=None, skipna=False) is NaT

# axis=0 reduces down the single column, giving a length-1 array.
result = arr2.median(axis=0)
expected2 = type(arr)._from_sequence([expected], dtype=arr.dtype)
tm.assert_equal(result, expected2)

result = arr2.median(axis=0, skipna=False)
expected2 = type(arr)._from_sequence([NaT], dtype=arr.dtype)
tm.assert_equal(result, expected2)

# axis=1 reduces across rows that each hold one element, so the result is
# expected to equal `arr` itself (including its NaT entry) for both skipna
# settings.
result = arr2.median(axis=1)
tm.assert_equal(result, arr)

result = arr2.median(axis=1, skipna=False)
tm.assert_equal(result, arr)


class TestDatetimeArray(SharedTests):
index_cls = pd.DatetimeIndex
Expand All @@ -465,7 +513,7 @@ class TestDatetimeArray(SharedTests):
@pytest.fixture
def arr1d(self, tz_naive_fixture, freqstr):
# Fixture: builds a DatetimeIndex with `pd.date_range` using the supplied
# timezone and frequency string, and returns its `_data` attribute.
tz = tz_naive_fixture
# NOTE(review): the two `date_range` lines below look like the removed
# (periods=3) and added (periods=5) sides of this diff hunk, shown without
# +/- markers — confirm against the commit before treating both as live code.
dti = pd.date_range("2016-01-01 01:01:00", periods=3, freq=freqstr, tz=tz)
dti = pd.date_range("2016-01-01 01:01:00", periods=5, freq=freqstr, tz=tz)
dta = dti._data
return dta

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/test_datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def test_median_empty(self, skipna, tz):
tm.assert_equal(result, expected)

result = arr.median(axis=1, skipna=skipna)
expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
expected = type(arr)._from_sequence([], dtype=arr.dtype)
tm.assert_equal(result, expected)

def test_median(self, arr1d):
Expand Down

0 comments on commit 86ee235

Please sign in to comment.