Skip to content

Commit

Permalink
API (string): return str dtype for .dt methods, DatetimeIndex methods (
Browse files Browse the repository at this point in the history
…#59526)

* API (string): return str dtype for .dt methods, DatetimeIndex methods

* mypy fixup
  • Loading branch information
jbrockmendel authored Aug 16, 2024
1 parent 96a7462 commit ff28a3e
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 17 deletions.
5 changes: 5 additions & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np

from pandas._config import using_string_dtype
from pandas._config.config import get_option

from pandas._libs import (
Expand Down Expand Up @@ -1759,6 +1760,10 @@ def strftime(self, date_format: str) -> npt.NDArray[np.object_]:
dtype='object')
"""
result = self._format_native_types(date_format=date_format, na_rep=np.nan)
if using_string_dtype():
from pandas import StringDtype

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result.astype(object, copy=False)


Expand Down
16 changes: 16 additions & 0 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np

from pandas._config import using_string_dtype
from pandas._config.config import get_option

from pandas._libs import (
Expand Down Expand Up @@ -1332,6 +1333,13 @@ def month_name(self, locale=None) -> npt.NDArray[np.object_]:
values, "month_name", locale=locale, reso=self._creso
)
result = self._maybe_mask_results(result, fill_value=None)
if using_string_dtype():
from pandas import (
StringDtype,
array as pd_array,
)

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result

def day_name(self, locale=None) -> npt.NDArray[np.object_]:
Expand Down Expand Up @@ -1393,6 +1401,14 @@ def day_name(self, locale=None) -> npt.NDArray[np.object_]:
values, "day_name", locale=locale, reso=self._creso
)
result = self._maybe_mask_results(result, fill_value=None)
if using_string_dtype():
# TODO: no tests that check for dtype of result as of 2024-08-15
from pandas import (
StringDtype,
array as pd_array,
)

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result

@property
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def _engine_type(self) -> type[libindex.DatetimeEngine]:
@doc(DatetimeArray.strftime)
def strftime(self, date_format) -> Index:
arr = self._data.strftime(date_format)
return Index(arr, name=self.name, dtype=object)
return Index(arr, name=self.name, dtype=arr.dtype)

@doc(DatetimeArray.tz_convert)
def tz_convert(self, tz) -> Self:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def fget(self):
return type(self)._simple_new(result, name=self.name)
elif isinstance(result, ABCDataFrame):
return result.set_index(self)
return Index(result, name=self.name)
return Index(result, name=self.name, dtype=result.dtype)
return result

def fset(self, value) -> None:
Expand All @@ -101,7 +101,7 @@ def method(self, *args, **kwargs): # type: ignore[misc]
return type(self)._simple_new(result, name=self.name)
elif isinstance(result, ABCDataFrame):
return result.set_index(self)
return Index(result, name=self.name)
return Index(result, name=self.name, dtype=result.dtype)
return result

# error: "property" has no attribute "__name__"
Expand Down
24 changes: 16 additions & 8 deletions pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,20 +891,24 @@ def test_concat_same_type_different_freq(self, unit):

tm.assert_datetime_array_equal(result, expected)

def test_strftime(self, arr1d):
def test_strftime(self, arr1d, using_infer_string):
arr = arr1d

result = arr.strftime("%Y %b")
expected = np.array([ts.strftime("%Y %b") for ts in arr], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)

def test_strftime_nat(self):
def test_strftime_nat(self, using_infer_string):
# GH 29578
arr = DatetimeIndex(["2019-01-01", NaT])._data

result = arr.strftime("%Y-%m-%d")
expected = np.array(["2019-01-01", np.nan], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)


class TestTimedeltaArray(SharedTests):
Expand Down Expand Up @@ -1161,20 +1165,24 @@ def test_array_interface(self, arr1d):
expected = np.asarray(arr).astype("S20")
tm.assert_numpy_array_equal(result, expected)

def test_strftime(self, arr1d):
def test_strftime(self, arr1d, using_infer_string):
arr = arr1d

result = arr.strftime("%Y")
expected = np.array([per.strftime("%Y") for per in arr], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)

def test_strftime_nat(self):
def test_strftime_nat(self, using_infer_string):
# GH 29578
arr = PeriodArray(PeriodIndex(["2019-01-01", NaT], dtype="period[D]"))

result = arr.strftime("%Y-%m-%d")
expected = np.array(["2019-01-01", np.nan], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/io/excel/test_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ def test_excel_multindex_roundtrip(
)
tm.assert_frame_equal(df, act, check_names=check_names)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_read_excel_parse_dates(self, tmp_excel):
# see gh-11544, gh-12051
df = DataFrame(
Expand Down
8 changes: 3 additions & 5 deletions pandas/tests/series/accessors/test_dt_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Period,
PeriodIndex,
Series,
StringDtype,
TimedeltaIndex,
date_range,
period_range,
Expand Down Expand Up @@ -513,7 +514,6 @@ def test_dt_accessor_datetime_name_accessors(self, time_locale):
ser = pd.concat([ser, Series([pd.NaT])])
assert np.isnan(ser.dt.month_name(locale=time_locale).iloc[-1])

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_strftime(self):
# GH 10086
ser = Series(date_range("20130101", periods=5))
Expand Down Expand Up @@ -584,10 +584,9 @@ def test_strftime_period_days(self, using_infer_string):
dtype="=U10",
)
if using_infer_string:
expected = expected.astype("str")
expected = expected.astype(StringDtype(na_value=np.nan))
tm.assert_index_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_strftime_dt64_microsecond_resolution(self):
ser = Series([datetime(2013, 1, 1, 2, 32, 59), datetime(2013, 1, 2, 14, 32, 1)])
result = ser.dt.strftime("%Y-%m-%d %H:%M:%S")
Expand Down Expand Up @@ -620,7 +619,6 @@ def test_strftime_period_minutes(self):
)
tm.assert_series_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize(
"data",
[
Expand All @@ -643,7 +641,7 @@ def test_strftime_all_nat(self, data):
ser = Series(data)
with tm.assert_produces_warning(None):
result = ser.dt.strftime("%Y-%m-%d")
expected = Series([np.nan], dtype=object)
expected = Series([np.nan], dtype="str")
tm.assert_series_equal(result, expected)

def test_valid_dt_with_missing_values(self):
Expand Down

0 comments on commit ff28a3e

Please sign in to comment.