From 758016f9cd9719a4e531d93f02a55a83548e15ba Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 8 Aug 2023 20:00:31 +0900 Subject: [PATCH 1/2] Respect TimestampNTZ in resampling --- python/pyspark/pandas/frame.py | 4 +- python/pyspark/pandas/resample.py | 43 +++++++++++++------- python/pyspark/pandas/tests/test_resample.py | 22 +++++++++- 3 files changed, 52 insertions(+), 17 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 72d4a88b69203..65c43eb7cf42c 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -13155,7 +13155,9 @@ def resample( if on is None and not isinstance(self.index, DatetimeIndex): raise NotImplementedError("resample currently works only for DatetimeIndex") - if on is not None and not isinstance(as_spark_type(on.dtype), TimestampType): + if on is not None and not isinstance( + as_spark_type(on.dtype), (TimestampType, TimestampNTZType) + ): raise NotImplementedError("`on` currently works only for TimestampType") agg_columns: List[ps.Series] = [] diff --git a/python/pyspark/pandas/resample.py b/python/pyspark/pandas/resample.py index c6c6019c07e6a..30f8c9d31695e 100644 --- a/python/pyspark/pandas/resample.py +++ b/python/pyspark/pandas/resample.py @@ -46,7 +46,8 @@ from pyspark.sql.types import ( NumericType, StructField, - TimestampType, + TimestampNTZType, + DataType, ) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. @@ -130,6 +131,13 @@ def _resamplekey_scol(self) -> Column: else: return self._resamplekey.spark.column + @property + def _resamplekey_type(self) -> DataType: + if self._resamplekey is None: + return self._psdf.index.spark.data_type + else: + return self._resamplekey.spark.data_type + @property def _agg_columns_scols(self) -> List[Column]: return [s.spark.column for s in self._agg_columns] @@ -154,7 +162,8 @@ def get_make_interval( # type: ignore[return] col = col._jc if isinstance(col, Column) else F.lit(col)._jc return sql_utils.makeInterval(unit, col) - def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: + def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: + key_type = self._resamplekey_type origin_scol = F.lit(origin) (rule_code, n) = (self._offset.rule_code, getattr(self._offset, "n")) left_closed, right_closed = (self._closed == "left", self._closed == "right") @@ -188,7 +197,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: F.year(ts_scol) - (mod - n) ) - return F.to_timestamp( + ret = F.to_timestamp( F.make_date( F.when(edge_cond, edge_label).otherwise(non_edge_label), F.lit(12), F.lit(31) ) @@ -227,7 +236,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: truncated_ts_scol - self.get_make_interval("MONTH", mod - n) ) - return F.to_timestamp( + ret = F.to_timestamp( F.last_day(F.when(edge_cond, edge_label).otherwise(non_edge_label)) ) @@ -242,15 +251,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: ) if left_closed and left_labeled: - return F.date_trunc("DAY", ts_scol) + ret = F.date_trunc("DAY", ts_scol) elif left_closed and right_labeled: - return F.date_trunc("DAY", F.date_add(ts_scol, 1)) + ret = F.date_trunc("DAY", F.date_add(ts_scol, 1)) elif right_closed and left_labeled: - return F.when(edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))).otherwise( + ret = F.when(edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))).otherwise( F.date_trunc("DAY", ts_scol) ) else: - return 
F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise( + ret = F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise( F.date_trunc("DAY", F.date_add(ts_scol, 1)) ) @@ -272,13 +281,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: else: non_edge_label = F.date_sub(truncated_ts_scol, mod - n) - return F.when(edge_cond, edge_label).otherwise(non_edge_label) + ret = F.when(edge_cond, edge_label).otherwise(non_edge_label) elif rule_code in ["H", "T", "S"]: unit_mapping = {"H": "HOUR", "T": "MINUTE", "S": "SECOND"} unit_str = unit_mapping[rule_code] truncated_ts_scol = F.date_trunc(unit_str, ts_scol) + if isinstance(key_type, TimestampNTZType): + truncated_ts_scol = F.to_timestamp_ntz(truncated_ts_scol) diff = timestampdiff(unit_str, origin_scol, truncated_ts_scol) mod = F.lit(0) if n == 1 else (diff % F.lit(n)) @@ -307,11 +318,16 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: truncated_ts_scol + self.get_make_interval(unit_str, n), ).otherwise(truncated_ts_scol - self.get_make_interval(unit_str, mod - n)) - return F.when(edge_cond, edge_label).otherwise(non_edge_label) + ret = F.when(edge_cond, edge_label).otherwise(non_edge_label) else: raise ValueError("Got the unexpected unit {}".format(rule_code)) + if isinstance(key_type, TimestampNTZType): + return F.to_timestamp_ntz(ret) + else: + return ret + def _downsample(self, f: str) -> DataFrame: """ Downsample the defined function. @@ -374,12 +390,9 @@ def _downsample(self, f: str) -> DataFrame: bin_col_label = verify_temp_column_name(self._psdf, bin_col_name) bin_col_field = InternalField( dtype=np.dtype("datetime64[ns]"), - struct_field=StructField(bin_col_name, TimestampType(), True), - ) - bin_scol = self._bin_time_stamp( - ts_origin, - self._resamplekey_scol, + struct_field=StructField(bin_col_name, self._resamplekey_type, True), ) + bin_scol = self._bin_timestamp(ts_origin, self._resamplekey_scol) agg_columns = [ psser for psser in self._agg_columns if (isinstance(psser.spark.data_type, NumericType)) diff --git a/python/pyspark/pandas/tests/test_resample.py b/python/pyspark/pandas/tests/test_resample.py index 0650fc40448e3..641d5572b928d 100644 --- a/python/pyspark/pandas/tests/test_resample.py +++ b/python/pyspark/pandas/tests/test_resample.py @@ -19,6 +19,8 @@ import unittest import inspect import datetime +import os + import numpy as np import pandas as pd @@ -252,7 +254,7 @@ def test_dataframe_resample(self): self._test_resample(self.pdf5, self.psdf5, ["55MIN", "2H", "D"], "left", "left", "std") self._test_resample(self.pdf6, self.psdf6, ["29S", "10MIN", "3H"], "left", "right", "var") - def test_series_resample(self): + def check_series_resample(self): self._test_resample(self.pdf1.A, self.psdf1.A, ["4Y"], "right", None, "min") self._test_resample(self.pdf2.A, self.psdf2.A, ["13M"], "right", "left", "max") self._test_resample(self.pdf3.A, self.psdf3.A, ["1001H"], "right", "right", "sum") @@ -260,6 +262,24 @@ def test_series_resample(self): self._test_resample(self.pdf5.A, self.psdf5.A, ["47T"], "left", "left", "var") self._test_resample(self.pdf6.A, self.psdf6.A, ["111S"], "right", "right", "std") + def test_series_resample(self): + self.check_series_resample() + + def test_series_resample_with_timezone(self): + timezone = os.environ.get("TZ", None) + try: + os.environ["TZ"] = "America/New_York" + with self.sql_conf( + { + "spark.sql.session.timeZone": "Asia/Seoul", + "spark.sql.timestampType": "TIMESTAMP_NTZ", + } + ): + self.check_series_resample() + finally: + if 
timezone is not None: + os.environ["TZ"] = timezone + def test_resample_on(self): np.random.seed(77) dates = [ From 1d3df69ea8c2b255909772aafa52e8fae821f3e9 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 9 Aug 2023 08:54:00 +0900 Subject: [PATCH 2/2] Address comments --- .../tests/connect/test_parity_resample.py | 12 +++- python/pyspark/pandas/tests/test_resample.py | 65 +++++++++++++------ 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/python/pyspark/pandas/tests/connect/test_parity_resample.py b/python/pyspark/pandas/tests/connect/test_parity_resample.py index e5957cc9b4a29..d5c901f113a05 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_resample.py +++ b/python/pyspark/pandas/tests/connect/test_parity_resample.py @@ -16,17 +16,25 @@ # import unittest -from pyspark.pandas.tests.test_resample import ResampleTestsMixin +from pyspark.pandas.tests.test_resample import ResampleTestsMixin, ResampleWithTimezoneMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils -class ResampleTestsParityMixin( +class ResampleParityTests( ResampleTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase ): pass +class ResampleWithTimezoneTests( + ResampleWithTimezoneMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +): + @unittest.skip("SPARK-44731: Support 'spark.sql.timestampType' in Python Spark Connect client") + def test_series_resample_with_timezone(self): + super().test_series_resample_with_timezone() + + if __name__ == "__main__": from pyspark.pandas.tests.connect.test_parity_resample import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/test_resample.py b/python/pyspark/pandas/tests/test_resample.py index 641d5572b928d..4061402590767 100644 --- a/python/pyspark/pandas/tests/test_resample.py +++ b/python/pyspark/pandas/tests/test_resample.py @@ -254,7 +254,7 @@ def test_dataframe_resample(self): self._test_resample(self.pdf5, self.psdf5, ["55MIN", "2H", "D"], "left", "left", "std") self._test_resample(self.pdf6, self.psdf6, ["29S", "10MIN", "3H"], "left", "right", "var") - def check_series_resample(self): + def test_series_resample(self): self._test_resample(self.pdf1.A, self.psdf1.A, ["4Y"], "right", None, "min") self._test_resample(self.pdf2.A, self.psdf2.A, ["13M"], "right", "left", "max") self._test_resample(self.pdf3.A, self.psdf3.A, ["1001H"], "right", "right", "sum") @@ -262,24 +262,6 @@ def check_series_resample(self): self._test_resample(self.pdf5.A, self.psdf5.A, ["47T"], "left", "left", "var") self._test_resample(self.pdf6.A, self.psdf6.A, ["111S"], "right", "right", "std") - def test_series_resample(self): - self.check_series_resample() - - def test_series_resample_with_timezone(self): - timezone = os.environ.get("TZ", None) - try: - os.environ["TZ"] = "America/New_York" - with self.sql_conf( - { - "spark.sql.session.timeZone": "Asia/Seoul", - "spark.sql.timestampType": "TIMESTAMP_NTZ", - } - ): - self.check_series_resample() - finally: - if timezone is not None: - os.environ["TZ"] = timezone - def test_resample_on(self): np.random.seed(77) dates = [ @@ -303,10 +285,55 @@ def test_resample_on(self): ) +class ResampleWithTimezoneMixin: + timezone = None + + @classmethod + def setUpClass(cls): + cls.timezone = os.environ.get("TZ", None) + os.environ["TZ"] = "America/New_York" + super(ResampleWithTimezoneMixin, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + super(ResampleWithTimezoneMixin, cls).tearDownClass() + if 
cls.timezone is not None: + os.environ["TZ"] = cls.timezone + + @property + def pdf(self): + np.random.seed(22) + index = pd.date_range(start="2011-01-02", end="2022-05-01", freq="1D") + return pd.DataFrame(np.random.rand(len(index), 2), index=index, columns=list("AB")) + + @property + def psdf(self): + return ps.from_pandas(self.pdf) + + def test_series_resample_with_timezone(self): + with self.sql_conf( + { + "spark.sql.session.timeZone": "Asia/Seoul", + "spark.sql.timestampType": "TIMESTAMP_NTZ", + } + ): + p_resample = self.pdf.resample(rule="1001H", closed="right", label="right") + ps_resample = self.psdf.resample(rule="1001H", closed="right", label="right") + self.assert_eq( + p_resample.sum().sort_index(), + ps_resample.sum().sort_index(), + almost=True, + ) + + class ResampleTests(ResampleTestsMixin, PandasOnSparkTestCase, TestUtils): pass +class ResampleWithTimezoneTests(ResampleWithTimezoneMixin, PandasOnSparkTestCase, TestUtils): + pass + + if __name__ == "__main__": from pyspark.pandas.tests.test_resample import * # noqa: F401
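

Example (illustrative; not part of the patch series above). With
spark.sql.timestampType set to TIMESTAMP_NTZ, ps.from_pandas infers a
TimestampNTZType index, and with these patches applied the resample bins stay
timezone-naive, so pandas-on-Spark agrees with pandas regardless of
spark.sql.session.timeZone. The session settings below mirror the new
test_series_resample_with_timezone test; the data and the "3H" rule are
arbitrary, and the sketch assumes a local PySpark session recent enough to
support spark.sql.timestampType (3.4+).

    import numpy as np
    import pandas as pd

    import pyspark.pandas as ps
    from pyspark.sql import SparkSession

    # Configure the session before creating any DataFrame: the timestamp type
    # decides whether a datetime index is stored as TIMESTAMP or TIMESTAMP_NTZ.
    spark = (
        SparkSession.builder
        .config("spark.sql.session.timeZone", "Asia/Seoul")
        .config("spark.sql.timestampType", "TIMESTAMP_NTZ")
        .getOrCreate()
    )

    index = pd.date_range(start="2022-01-01", periods=240, freq="H")
    pdf = pd.DataFrame(np.random.rand(len(index), 2), index=index, columns=list("AB"))
    psdf = ps.from_pandas(pdf)  # index is inferred as TimestampNTZType here

    # Before this series, the binning expression was always built with
    # F.to_timestamp and declared as TimestampType, i.e. session-timezone-aware;
    # now an NTZ resampling key is binned via to_timestamp_ntz and stays NTZ,
    # so the two results below agree (sort_index: Spark output order varies).
    print(pdf.resample(rule="3H", closed="right", label="right").sum().sort_index())
    print(psdf.resample(rule="3H", closed="right", label="right").sum().sort_index())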