Skip to content

Commit

Permalink
Restrict the allowed pandas timezone objects in cudf (rapidsai#16013)
Browse files Browse the repository at this point in the history
Since cudf's timezone support is based on the OS's tz data and hence `zoneinfo`, cudf cannot naturally support the variety of timezone objects supported by pandas (`pytz`, `dateutil`, etc). Therefore:

* In pandas compatible mode, only accept pandas objects with zoneinfo timezones.
* Otherwise, try to convert the pandas timezone to an equivalent zoneinfo object e.g. `pytz.timezone("US/Pacific")`-> `zoneinfo.ZoneInfo("US/Pacific")`

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: rapidsai#16013
  • Loading branch information
mroeschke authored Jun 24, 2024
1 parent f3183c1 commit 0c6b828
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 39 deletions.
33 changes: 32 additions & 1 deletion python/cudf/cudf/core/_internals/timezones.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,50 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
from __future__ import annotations

import datetime
import os
import zoneinfo
from functools import lru_cache
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd

import cudf
from cudf._lib.timezone import make_timezone_transition_table
from cudf.core.column.column import as_column

if TYPE_CHECKING:
from cudf.core.column.datetime import DatetimeColumn
from cudf.core.column.timedelta import TimeDeltaColumn


def get_compatible_timezone(dtype: pd.DatetimeTZDtype) -> pd.DatetimeTZDtype:
"""Convert dtype.tz object to zoneinfo object if possible."""
tz = dtype.tz
if isinstance(tz, zoneinfo.ZoneInfo):
return dtype
if cudf.get_option("mode.pandas_compatible"):
raise NotImplementedError(
f"{tz} must be a zoneinfo.ZoneInfo object in pandas_compatible mode."
)
elif (tzname := getattr(tz, "zone", None)) is not None:
# pytz-like
key = tzname
elif (tz_file := getattr(tz, "_filename", None)) is not None:
# dateutil-like
key = tz_file.split("zoneinfo/")[-1]
elif isinstance(tz, datetime.tzinfo):
# Try to get UTC-like tzinfos
reference = datetime.datetime.now()
key = tz.tzname(reference)
if not (isinstance(key, str) and key.lower() == "utc"):
raise NotImplementedError(f"cudf does not support {tz}")
else:
raise NotImplementedError(f"cudf does not support {tz}")
new_tz = zoneinfo.ZoneInfo(key)
return pd.DatetimeTZDtype(dtype.unit, new_tz)


@lru_cache(maxsize=20)
def get_tz_data(zone_name: str) -> tuple[DatetimeColumn, TimeDeltaColumn]:
"""
Expand Down Expand Up @@ -87,6 +116,8 @@ def _read_tzfile_as_columns(
)

if not transition_times_and_offsets:
from cudf.core.column.column import as_column

# this happens for UTC-like zones
min_date = np.int64(np.iinfo("int64").min + 1).astype("M8[s]")
return (as_column([min_date]), as_column([np.timedelta64(0, "s")]))
Expand Down
16 changes: 16 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
is_string_dtype,
)
from cudf.core._compat import PANDAS_GE_210
from cudf.core._internals.timezones import get_compatible_timezone
from cudf.core.abc import Serializable
from cudf.core.buffer import (
Buffer,
Expand Down Expand Up @@ -1854,6 +1855,21 @@ def as_column(
arbitrary.dtype,
(pd.CategoricalDtype, pd.IntervalDtype, pd.DatetimeTZDtype),
):
if isinstance(arbitrary.dtype, pd.DatetimeTZDtype):
new_tz = get_compatible_timezone(arbitrary.dtype)
arbitrary = arbitrary.astype(new_tz)
if isinstance(arbitrary.dtype, pd.CategoricalDtype) and isinstance(
arbitrary.dtype.categories.dtype, pd.DatetimeTZDtype
):
new_tz = get_compatible_timezone(
arbitrary.dtype.categories.dtype
)
new_cats = arbitrary.dtype.categories.astype(new_tz)
new_dtype = pd.CategoricalDtype(
categories=new_cats, ordered=arbitrary.dtype.ordered
)
arbitrary = arbitrary.astype(new_dtype)

return as_column(
pa.array(arbitrary, from_pandas=True),
nan_as_null=nan_as_null,
Expand Down
33 changes: 14 additions & 19 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from cudf._lib.search import search_sorted
from cudf.api.types import is_datetime64_dtype, is_scalar, is_timedelta64_dtype
from cudf.core._compat import PANDAS_GE_220
from cudf.core._internals.timezones import (
check_ambiguous_and_nonexistent,
get_compatible_timezone,
get_tz_data,
)
from cudf.core.column import ColumnBase, as_column, column, string
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.utils.dtypes import _get_base_dtype
Expand Down Expand Up @@ -282,8 +287,6 @@ def __contains__(self, item: ScalarLike) -> bool:

@functools.cached_property
def time_unit(self) -> str:
if isinstance(self.dtype, pd.DatetimeTZDtype):
return self.dtype.unit
return np.datetime_data(self.dtype)[0]

@property
Expand Down Expand Up @@ -725,8 +728,6 @@ def _find_ambiguous_and_nonexistent(
transitions occur in the time zone database for the given timezone.
If no transitions occur, the tuple `(False, False)` is returned.
"""
from cudf.core._internals.timezones import get_tz_data

transition_times, offsets = get_tz_data(zone_name)
offsets = offsets.astype(f"timedelta64[{self.time_unit}]") # type: ignore[assignment]

Expand Down Expand Up @@ -785,26 +786,22 @@ def tz_localize(
ambiguous: Literal["NaT"] = "NaT",
nonexistent: Literal["NaT"] = "NaT",
):
from cudf.core._internals.timezones import (
check_ambiguous_and_nonexistent,
get_tz_data,
)

if tz is None:
return self.copy()
ambiguous, nonexistent = check_ambiguous_and_nonexistent(
ambiguous, nonexistent
)
dtype = pd.DatetimeTZDtype(self.time_unit, tz)
dtype = get_compatible_timezone(pd.DatetimeTZDtype(self.time_unit, tz))
tzname = dtype.tz.key
ambiguous_col, nonexistent_col = self._find_ambiguous_and_nonexistent(
tz
tzname
)
localized = self._scatter_by_column(
self.isnull() | (ambiguous_col | nonexistent_col),
cudf.Scalar(cudf.NaT, dtype=self.dtype),
)

transition_times, offsets = get_tz_data(tz)
transition_times, offsets = get_tz_data(tzname)
transition_times_local = (transition_times + offsets).astype(
localized.dtype
)
Expand Down Expand Up @@ -845,7 +842,7 @@ def __init__(
offset=offset,
null_count=null_count,
)
self._dtype = dtype
self._dtype = get_compatible_timezone(dtype)

def to_pandas(
self,
Expand All @@ -865,6 +862,10 @@ def to_arrow(self):
self._local_time.to_arrow(), str(self.dtype.tz)
)

@functools.cached_property
def time_unit(self) -> str:
return self.dtype.unit

@property
def _utc_time(self):
"""Return UTC time as naive timestamps."""
Expand All @@ -880,8 +881,6 @@ def _utc_time(self):
@property
def _local_time(self):
"""Return the local time as naive timestamps."""
from cudf.core._internals.timezones import get_tz_data

transition_times, offsets = get_tz_data(str(self.dtype.tz))
transition_times = transition_times.astype(_get_base_dtype(self.dtype))
indices = search_sorted([transition_times], [self], "right") - 1
Expand Down Expand Up @@ -911,10 +910,6 @@ def __repr__(self):
)

def tz_localize(self, tz: str | None, ambiguous="NaT", nonexistent="NaT"):
from cudf.core._internals.timezones import (
check_ambiguous_and_nonexistent,
)

if tz is None:
return self._local_time
ambiguous, nonexistent = check_ambiguous_and_nonexistent(
Expand Down
12 changes: 5 additions & 7 deletions python/cudf/cudf/tests/indexes/datetime/test_indexing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
import zoneinfo

import pandas as pd

Expand All @@ -7,13 +8,10 @@


def test_slice_datetimetz_index():
tz = zoneinfo.ZoneInfo("US/Eastern")
data = ["2001-01-01", "2001-01-02", None, None, "2001-01-03"]
pidx = pd.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(
"US/Eastern"
)
idx = cudf.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(
"US/Eastern"
)
pidx = pd.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(tz)
idx = cudf.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(tz)
expected = pidx[1:4]
got = idx[1:4]
assert_eq(expected, got)
13 changes: 6 additions & 7 deletions python/cudf/cudf/tests/indexes/datetime/test_time_specific.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
import zoneinfo

import pandas as pd

import cudf
from cudf.testing._utils import assert_eq


def test_tz_localize():
tz = zoneinfo.ZoneInfo("America/New_York")
pidx = pd.date_range("2001-01-01", "2001-01-02", freq="1s")
pidx = pidx.astype("<M8[ns]")
idx = cudf.from_pandas(pidx)
assert pidx.dtype == idx.dtype
assert_eq(
pidx.tz_localize("America/New_York"),
idx.tz_localize("America/New_York"),
)
assert_eq(pidx.tz_localize(tz), idx.tz_localize(tz))


def test_tz_convert():
tz = zoneinfo.ZoneInfo("America/New_York")
pidx = pd.date_range("2023-01-01", periods=3, freq="h")
idx = cudf.from_pandas(pidx)
pidx = pidx.tz_localize("UTC")
idx = idx.tz_localize("UTC")
assert_eq(
pidx.tz_convert("America/New_York"), idx.tz_convert("America/New_York")
)
assert_eq(pidx.tz_convert(tz), idx.tz_convert(tz))


def test_delocalize_naive():
Expand Down
40 changes: 35 additions & 5 deletions python/cudf/cudf/tests/series/test_datetimelike.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.

import datetime
import os
import zoneinfo

import pandas as pd
import pytest
Expand Down Expand Up @@ -70,7 +72,7 @@ def test_localize_ambiguous(request, unit, zone_name):
dtype=f"datetime64[{unit}]",
)
expect = s.to_pandas().dt.tz_localize(
zone_name, ambiguous="NaT", nonexistent="NaT"
zoneinfo.ZoneInfo(zone_name), ambiguous="NaT", nonexistent="NaT"
)
got = s.dt.tz_localize(zone_name)
assert_eq(expect, got)
Expand All @@ -96,7 +98,7 @@ def test_localize_nonexistent(request, unit, zone_name):
dtype=f"datetime64[{unit}]",
)
expect = s.to_pandas().dt.tz_localize(
zone_name, ambiguous="NaT", nonexistent="NaT"
zoneinfo.ZoneInfo(zone_name), ambiguous="NaT", nonexistent="NaT"
)
got = s.dt.tz_localize(zone_name)
assert_eq(expect, got)
Expand Down Expand Up @@ -130,6 +132,9 @@ def test_delocalize_naive():
"to_tz", ["Europe/London", "America/Chicago", "UTC", None]
)
def test_convert(from_tz, to_tz):
from_tz = zoneinfo.ZoneInfo(from_tz)
if to_tz is not None:
to_tz = zoneinfo.ZoneInfo(to_tz)
ps = pd.Series(pd.date_range("2023-01-01", periods=3, freq="h"))
gs = cudf.from_pandas(ps)
ps = ps.dt.tz_localize(from_tz)
Expand Down Expand Up @@ -169,6 +174,8 @@ def test_convert_from_naive():
],
)
def test_convert_edge_cases(data, original_timezone, target_timezone):
original_timezone = zoneinfo.ZoneInfo(original_timezone)
target_timezone = zoneinfo.ZoneInfo(target_timezone)
ps = pd.Series(data, dtype="datetime64[s]").dt.tz_localize(
original_timezone
)
Expand Down Expand Up @@ -229,10 +236,33 @@ def test_tz_convert_naive_typeerror():
"klass", ["Series", "DatetimeIndex", "Index", "CategoricalIndex"]
)
def test_from_pandas_obj_tz_aware(klass):
tz_aware_data = [
pd.Timestamp("2020-01-01", tz="UTC").tz_convert("US/Pacific")
]
tz = zoneinfo.ZoneInfo("US/Pacific")
tz_aware_data = [pd.Timestamp("2020-01-01", tz="UTC").tz_convert(tz)]
pandas_obj = getattr(pd, klass)(tz_aware_data)
result = cudf.from_pandas(pandas_obj)
expected = getattr(cudf, klass)(tz_aware_data)
assert_eq(result, expected)


@pytest.mark.parametrize(
"klass", ["Series", "DatetimeIndex", "Index", "CategoricalIndex"]
)
def test_from_pandas_obj_tz_aware_unsupported(klass):
tz = datetime.timezone(datetime.timedelta(hours=1))
tz_aware_data = [pd.Timestamp("2020-01-01", tz="UTC").tz_convert(tz)]
pandas_obj = getattr(pd, klass)(tz_aware_data)
with pytest.raises(NotImplementedError):
cudf.from_pandas(pandas_obj)


@pytest.mark.parametrize(
"klass", ["Series", "DatetimeIndex", "Index", "CategoricalIndex"]
)
def test_pandas_compatible_non_zoneinfo_raises(klass):
pytz = pytest.importorskip("pytz")
tz = pytz.timezone("US/Pacific")
tz_aware_data = [pd.Timestamp("2020-01-01", tz="UTC").tz_convert(tz)]
pandas_obj = getattr(pd, klass)(tz_aware_data)
with cudf.option_context("mode.pandas_compatible", True):
with pytest.raises(NotImplementedError):
cudf.from_pandas(pandas_obj)

0 comments on commit 0c6b828

Please sign in to comment.