diff --git a/python/cudf/cudf/core/_internals/timezones.py b/python/cudf/cudf/core/_internals/timezones.py index 269fcf3e37f..29cb9d7bd12 100644 --- a/python/cudf/cudf/core/_internals/timezones.py +++ b/python/cudf/cudf/core/_internals/timezones.py @@ -1,21 +1,50 @@ # Copyright (c) 2023-2024, NVIDIA CORPORATION. from __future__ import annotations +import datetime import os import zoneinfo from functools import lru_cache from typing import TYPE_CHECKING, Literal import numpy as np +import pandas as pd +import cudf from cudf._lib.timezone import make_timezone_transition_table -from cudf.core.column.column import as_column if TYPE_CHECKING: from cudf.core.column.datetime import DatetimeColumn from cudf.core.column.timedelta import TimeDeltaColumn +def get_compatible_timezone(dtype: pd.DatetimeTZDtype) -> pd.DatetimeTZDtype: + """Convert dtype.tz object to zoneinfo object if possible.""" + tz = dtype.tz + if isinstance(tz, zoneinfo.ZoneInfo): + return dtype + if cudf.get_option("mode.pandas_compatible"): + raise NotImplementedError( + f"{tz} must be a zoneinfo.ZoneInfo object in pandas_compatible mode." + ) + elif (tzname := getattr(tz, "zone", None)) is not None: + # pytz-like + key = tzname + elif (tz_file := getattr(tz, "_filename", None)) is not None: + # dateutil-like + key = tz_file.split("zoneinfo/")[-1] + elif isinstance(tz, datetime.tzinfo): + # Try to get UTC-like tzinfos + reference = datetime.datetime.now() + key = tz.tzname(reference) + if not (isinstance(key, str) and key.lower() == "utc"): + raise NotImplementedError(f"cudf does not support {tz}") + else: + raise NotImplementedError(f"cudf does not support {tz}") + new_tz = zoneinfo.ZoneInfo(key) + return pd.DatetimeTZDtype(dtype.unit, new_tz) + + @lru_cache(maxsize=20) def get_tz_data(zone_name: str) -> tuple[DatetimeColumn, TimeDeltaColumn]: """ @@ -87,6 +116,8 @@ def _read_tzfile_as_columns( ) if not transition_times_and_offsets: + from cudf.core.column.column import as_column + # this happens for UTC-like zones min_date = np.int64(np.iinfo("int64").min + 1).astype("M8[s]") return (as_column([min_date]), as_column([np.timedelta64(0, "s")])) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index c4e715aeb45..586689e2ee3 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -47,6 +47,7 @@ is_string_dtype, ) from cudf.core._compat import PANDAS_GE_210 +from cudf.core._internals.timezones import get_compatible_timezone from cudf.core.abc import Serializable from cudf.core.buffer import ( Buffer, @@ -1854,6 +1855,21 @@ def as_column( arbitrary.dtype, (pd.CategoricalDtype, pd.IntervalDtype, pd.DatetimeTZDtype), ): + if isinstance(arbitrary.dtype, pd.DatetimeTZDtype): + new_tz = get_compatible_timezone(arbitrary.dtype) + arbitrary = arbitrary.astype(new_tz) + if isinstance(arbitrary.dtype, pd.CategoricalDtype) and isinstance( + arbitrary.dtype.categories.dtype, pd.DatetimeTZDtype + ): + new_tz = get_compatible_timezone( + arbitrary.dtype.categories.dtype + ) + new_cats = arbitrary.dtype.categories.astype(new_tz) + new_dtype = pd.CategoricalDtype( + categories=new_cats, ordered=arbitrary.dtype.ordered + ) + arbitrary = arbitrary.astype(new_dtype) + return as_column( pa.array(arbitrary, from_pandas=True), nan_as_null=nan_as_null, diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index 9ac761b6be1..d88553361dd 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -21,6 +21,11 @@ from cudf._lib.search import search_sorted from cudf.api.types import is_datetime64_dtype, is_scalar, is_timedelta64_dtype from cudf.core._compat import PANDAS_GE_220 +from cudf.core._internals.timezones import ( + check_ambiguous_and_nonexistent, + get_compatible_timezone, + get_tz_data, +) from cudf.core.column import ColumnBase, as_column, column, string from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion from cudf.utils.dtypes import _get_base_dtype @@ -282,8 +287,6 @@ def __contains__(self, item: ScalarLike) -> bool: @functools.cached_property def time_unit(self) -> str: - if isinstance(self.dtype, pd.DatetimeTZDtype): - return self.dtype.unit return np.datetime_data(self.dtype)[0] @property @@ -725,8 +728,6 @@ def _find_ambiguous_and_nonexistent( transitions occur in the time zone database for the given timezone. If no transitions occur, the tuple `(False, False)` is returned. """ - from cudf.core._internals.timezones import get_tz_data - transition_times, offsets = get_tz_data(zone_name) offsets = offsets.astype(f"timedelta64[{self.time_unit}]") # type: ignore[assignment] @@ -785,26 +786,22 @@ def tz_localize( ambiguous: Literal["NaT"] = "NaT", nonexistent: Literal["NaT"] = "NaT", ): - from cudf.core._internals.timezones import ( - check_ambiguous_and_nonexistent, - get_tz_data, - ) - if tz is None: return self.copy() ambiguous, nonexistent = check_ambiguous_and_nonexistent( ambiguous, nonexistent ) - dtype = pd.DatetimeTZDtype(self.time_unit, tz) + dtype = get_compatible_timezone(pd.DatetimeTZDtype(self.time_unit, tz)) + tzname = dtype.tz.key ambiguous_col, nonexistent_col = self._find_ambiguous_and_nonexistent( - tz + tzname ) localized = self._scatter_by_column( self.isnull() | (ambiguous_col | nonexistent_col), cudf.Scalar(cudf.NaT, dtype=self.dtype), ) - transition_times, offsets = get_tz_data(tz) + transition_times, offsets = get_tz_data(tzname) transition_times_local = (transition_times + offsets).astype( localized.dtype ) @@ -845,7 +842,7 @@ def __init__( offset=offset, null_count=null_count, ) - self._dtype = dtype + self._dtype = get_compatible_timezone(dtype) def to_pandas( self, @@ -865,6 +862,10 @@ def to_arrow(self): self._local_time.to_arrow(), str(self.dtype.tz) ) + @functools.cached_property + def time_unit(self) -> str: + return self.dtype.unit + @property def _utc_time(self): """Return UTC time as naive timestamps.""" @@ -880,8 +881,6 @@ def _utc_time(self): @property def _local_time(self): """Return the local time as naive timestamps.""" - from cudf.core._internals.timezones import get_tz_data - transition_times, offsets = get_tz_data(str(self.dtype.tz)) transition_times = transition_times.astype(_get_base_dtype(self.dtype)) indices = search_sorted([transition_times], [self], "right") - 1 @@ -911,10 +910,6 @@ def __repr__(self): ) def tz_localize(self, tz: str | None, ambiguous="NaT", nonexistent="NaT"): - from cudf.core._internals.timezones import ( - check_ambiguous_and_nonexistent, - ) - if tz is None: return self._local_time ambiguous, nonexistent = check_ambiguous_and_nonexistent( diff --git a/python/cudf/cudf/tests/indexes/datetime/test_indexing.py b/python/cudf/cudf/tests/indexes/datetime/test_indexing.py index f2c2d9a263b..ee4d0f7e816 100644 --- a/python/cudf/cudf/tests/indexes/datetime/test_indexing.py +++ b/python/cudf/cudf/tests/indexes/datetime/test_indexing.py @@ -1,4 +1,5 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +import zoneinfo import pandas as pd @@ -7,13 +8,10 @@ def test_slice_datetimetz_index(): + tz = zoneinfo.ZoneInfo("US/Eastern") data = ["2001-01-01", "2001-01-02", None, None, "2001-01-03"] - pidx = pd.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize( - "US/Eastern" - ) - idx = cudf.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize( - "US/Eastern" - ) + pidx = pd.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(tz) + idx = cudf.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(tz) expected = pidx[1:4] got = idx[1:4] assert_eq(expected, got) diff --git a/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py b/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py index b28ef131025..77b32b8ce89 100644 --- a/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py +++ b/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py @@ -1,4 +1,6 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION. +import zoneinfo + import pandas as pd import cudf @@ -6,24 +8,21 @@ def test_tz_localize(): + tz = zoneinfo.ZoneInfo("America/New_York") pidx = pd.date_range("2001-01-01", "2001-01-02", freq="1s") pidx = pidx.astype("