Skip to content

Commit

Permalink
[SPARK-44717][PYTHON][PS] Respect TimestampNTZ in resampling
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR proposes to respect `TimestampNTZ` type in resampling at pandas API on Spark.

### Why are the changes needed?

It still operates as if the timestamps are `TIMESTAMP_LTZ` even when `spark.sql.timestampType` is set to `TIMESTAMP_NTZ`, which is unexpected.

### Does this PR introduce _any_ user-facing change?

This fixes a bug so end users can use exactly same behaviour with pandas with `TimestampNTZType` - pandas does not respect the local timezone with DST. While we might need to follow this even for `TimestampType`, this PR does not address the case as it might be controversial.

### How was this patch tested?

Unittest was added.

Closes apache#42392 from HyukjinKwon/SPARK-44717.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
HyukjinKwon authored and vpolet committed Aug 24, 2023
1 parent c53d4e0 commit 25a9e3d
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 18 deletions.
4 changes: 3 additions & 1 deletion python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -13155,7 +13155,9 @@ def resample(

if on is None and not isinstance(self.index, DatetimeIndex):
raise NotImplementedError("resample currently works only for DatetimeIndex")
if on is not None and not isinstance(as_spark_type(on.dtype), TimestampType):
if on is not None and not isinstance(
as_spark_type(on.dtype), (TimestampType, TimestampNTZType)
):
raise NotImplementedError("`on` currently works only for TimestampType")

agg_columns: List[ps.Series] = []
Expand Down
43 changes: 28 additions & 15 deletions python/pyspark/pandas/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
from pyspark.sql.types import (
NumericType,
StructField,
TimestampType,
TimestampNTZType,
DataType,
)

from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
Expand Down Expand Up @@ -130,6 +131,13 @@ def _resamplekey_scol(self) -> Column:
else:
return self._resamplekey.spark.column

@property
def _resamplekey_type(self) -> DataType:
if self._resamplekey is None:
return self._psdf.index.spark.data_type
else:
return self._resamplekey.spark.data_type

@property
def _agg_columns_scols(self) -> List[Column]:
return [s.spark.column for s in self._agg_columns]
Expand All @@ -154,7 +162,8 @@ def get_make_interval( # type: ignore[return]
col = col._jc if isinstance(col, Column) else F.lit(col)._jc
return sql_utils.makeInterval(unit, col)

def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
key_type = self._resamplekey_type
origin_scol = F.lit(origin)
(rule_code, n) = (self._offset.rule_code, getattr(self._offset, "n"))
left_closed, right_closed = (self._closed == "left", self._closed == "right")
Expand Down Expand Up @@ -188,7 +197,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
F.year(ts_scol) - (mod - n)
)

return F.to_timestamp(
ret = F.to_timestamp(
F.make_date(
F.when(edge_cond, edge_label).otherwise(non_edge_label), F.lit(12), F.lit(31)
)
Expand Down Expand Up @@ -227,7 +236,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
truncated_ts_scol - self.get_make_interval("MONTH", mod - n)
)

return F.to_timestamp(
ret = F.to_timestamp(
F.last_day(F.when(edge_cond, edge_label).otherwise(non_edge_label))
)

Expand All @@ -242,15 +251,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
)

if left_closed and left_labeled:
return F.date_trunc("DAY", ts_scol)
ret = F.date_trunc("DAY", ts_scol)
elif left_closed and right_labeled:
return F.date_trunc("DAY", F.date_add(ts_scol, 1))
ret = F.date_trunc("DAY", F.date_add(ts_scol, 1))
elif right_closed and left_labeled:
return F.when(edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))).otherwise(
ret = F.when(edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))).otherwise(
F.date_trunc("DAY", ts_scol)
)
else:
return F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise(
ret = F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise(
F.date_trunc("DAY", F.date_add(ts_scol, 1))
)

Expand All @@ -272,13 +281,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
else:
non_edge_label = F.date_sub(truncated_ts_scol, mod - n)

return F.when(edge_cond, edge_label).otherwise(non_edge_label)
ret = F.when(edge_cond, edge_label).otherwise(non_edge_label)

elif rule_code in ["H", "T", "S"]:
unit_mapping = {"H": "HOUR", "T": "MINUTE", "S": "SECOND"}
unit_str = unit_mapping[rule_code]

truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
if isinstance(key_type, TimestampNTZType):
truncated_ts_scol = F.to_timestamp_ntz(truncated_ts_scol)
diff = timestampdiff(unit_str, origin_scol, truncated_ts_scol)
mod = F.lit(0) if n == 1 else (diff % F.lit(n))

Expand Down Expand Up @@ -307,11 +318,16 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
truncated_ts_scol + self.get_make_interval(unit_str, n),
).otherwise(truncated_ts_scol - self.get_make_interval(unit_str, mod - n))

return F.when(edge_cond, edge_label).otherwise(non_edge_label)
ret = F.when(edge_cond, edge_label).otherwise(non_edge_label)

else:
raise ValueError("Got the unexpected unit {}".format(rule_code))

if isinstance(key_type, TimestampNTZType):
return F.to_timestamp_ntz(ret)
else:
return ret

def _downsample(self, f: str) -> DataFrame:
"""
Downsample the defined function.
Expand Down Expand Up @@ -374,12 +390,9 @@ def _downsample(self, f: str) -> DataFrame:
bin_col_label = verify_temp_column_name(self._psdf, bin_col_name)
bin_col_field = InternalField(
dtype=np.dtype("datetime64[ns]"),
struct_field=StructField(bin_col_name, TimestampType(), True),
)
bin_scol = self._bin_time_stamp(
ts_origin,
self._resamplekey_scol,
struct_field=StructField(bin_col_name, self._resamplekey_type, True),
)
bin_scol = self._bin_timestamp(ts_origin, self._resamplekey_scol)

agg_columns = [
psser for psser in self._agg_columns if (isinstance(psser.spark.data_type, NumericType))
Expand Down
12 changes: 10 additions & 2 deletions python/pyspark/pandas/tests/connect/test_parity_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,25 @@
#
import unittest

from pyspark.pandas.tests.test_resample import ResampleTestsMixin
from pyspark.pandas.tests.test_resample import ResampleTestsMixin, ResampleWithTimezoneMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils


class ResampleTestsParityMixin(
class ResampleParityTests(
ResampleTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
):
pass


class ResampleWithTimezoneTests(
ResampleWithTimezoneMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
):
@unittest.skip("SPARK-44731: Support 'spark.sql.timestampType' in Python Spark Connect client")
def test_series_resample_with_timezone(self):
super().test_series_resample_with_timezone()


if __name__ == "__main__":
from pyspark.pandas.tests.connect.test_parity_resample import * # noqa: F401

Expand Down
47 changes: 47 additions & 0 deletions python/pyspark/pandas/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import unittest
import inspect
import datetime
import os

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -283,10 +285,55 @@ def test_resample_on(self):
)


class ResampleWithTimezoneMixin:
timezone = None

@classmethod
def setUpClass(cls):
cls.timezone = os.environ.get("TZ", None)
os.environ["TZ"] = "America/New_York"
super(ResampleWithTimezoneMixin, cls).setUpClass()

@classmethod
def tearDownClass(cls):
super(ResampleWithTimezoneMixin, cls).tearDownClass()
if cls.timezone is not None:
os.environ["TZ"] = cls.timezone

@property
def pdf(self):
np.random.seed(22)
index = pd.date_range(start="2011-01-02", end="2022-05-01", freq="1D")
return pd.DataFrame(np.random.rand(len(index), 2), index=index, columns=list("AB"))

@property
def psdf(self):
return ps.from_pandas(self.pdf)

def test_series_resample_with_timezone(self):
with self.sql_conf(
{
"spark.sql.session.timeZone": "Asia/Seoul",
"spark.sql.timestampType": "TIMESTAMP_NTZ",
}
):
p_resample = self.pdf.resample(rule="1001H", closed="right", label="right")
ps_resample = self.psdf.resample(rule="1001H", closed="right", label="right")
self.assert_eq(
p_resample.sum().sort_index(),
ps_resample.sum().sort_index(),
almost=True,
)


class ResampleTests(ResampleTestsMixin, PandasOnSparkTestCase, TestUtils):
pass


class ResampleWithTimezoneTests(ResampleWithTimezoneMixin, PandasOnSparkTestCase, TestUtils):
pass


if __name__ == "__main__":
from pyspark.pandas.tests.test_resample import * # noqa: F401

Expand Down

0 comments on commit 25a9e3d

Please sign in to comment.