From 7cd5b13340315803e0738844ba212502d161315f Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Fri, 11 Aug 2023 17:35:00 +0800
Subject: [PATCH] [SPARK-44731][PYTHON][CONNECT] Make TimestampNTZ works with
 literals in Python Spark Connect

### What changes were proposed in this pull request?

This PR proposes to:
- Share the namespaces for `to_timestamp_ntz`, `to_timestamp_ltz` and `to_unix_timestamp` in Spark Connect; they were previously missed.
- Add support for `TimestampNTZ` in literal handling in Python Spark Connect (by respecting `spark.sql.timestampType`).

### Why are the changes needed?

For feature parity, and to respect TimestampNTZ in resampling in the pandas API on Spark.

### Does this PR introduce _any_ user-facing change?

Yes, this virtually fixes the same bug as https://github.com/apache/spark/pull/42392, but in Spark Connect with Python.

### How was this patch tested?

Unit tests were re-enabled.

Closes #42445 from HyukjinKwon/SPARK-44731.

Authored-by: Hyukjin Kwon
Signed-off-by: Ruifeng Zheng
---
 .../pandas/tests/connect/test_parity_resample.py |  4 +---
 python/pyspark/sql/connect/expressions.py        |  3 +++
 python/pyspark/sql/functions.py                  | 12 ++++++++++++
 python/pyspark/sql/utils.py                      | 13 +++++++++++--
 4 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/pandas/tests/connect/test_parity_resample.py b/python/pyspark/pandas/tests/connect/test_parity_resample.py
index d5c901f113a05..caca2f957b507 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_resample.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_resample.py
@@ -30,9 +30,7 @@ class ResampleParityTests(
 class ResampleWithTimezoneTests(
     ResampleWithTimezoneMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
 ):
-    @unittest.skip("SPARK-44731: Support 'spark.sql.timestampType' in Python Spark Connect client")
-    def test_series_resample_with_timezone(self):
-        super().test_series_resample_with_timezone()
+    pass


 if __name__ == "__main__":
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 44e6e174f70c5..d0a9b1d69aee3 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
+from pyspark.sql.utils import is_timestamp_ntz_preferred

 check_dependencies(__name__)

@@ -295,6 +296,8 @@ def _infer_type(cls, value: Any) -> DataType:
             return StringType()
         elif isinstance(value, decimal.Decimal):
             return DecimalType()
+        elif isinstance(value, datetime.datetime) and is_timestamp_ntz_preferred():
+            return TimestampNTZType()
         elif isinstance(value, datetime.datetime):
             return TimestampType()
         elif isinstance(value, datetime.date):
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1b8a946e02e48..fdb4ec8111ed4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7758,6 +7758,7 @@ def check_field(field: Union[Column, str], fieldName: str) -> None:
     return _invoke_function("session_window", time_col, gap_duration)


+@try_remote_functions
 def to_unix_timestamp(
     timestamp: "ColumnOrName",
     format: Optional["ColumnOrName"] = None,
@@ -7767,6 +7768,9 @@ def to_unix_timestamp(

     .. versionadded:: 3.5.0

+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     timestamp : :class:`~pyspark.sql.Column` or str
@@ -7794,6 +7798,7 @@ def to_unix_timestamp(
         return _invoke_function_over_columns("to_unix_timestamp", timestamp)


+@try_remote_functions
 def to_timestamp_ltz(
     timestamp: "ColumnOrName",
     format: Optional["ColumnOrName"] = None,
@@ -7804,6 +7809,9 @@ def to_timestamp_ltz(

     .. versionadded:: 3.5.0

+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     timestamp : :class:`~pyspark.sql.Column` or str
@@ -7831,6 +7839,7 @@ def to_timestamp_ltz(
         return _invoke_function_over_columns("to_timestamp_ltz", timestamp)


+@try_remote_functions
 def to_timestamp_ntz(
     timestamp: "ColumnOrName",
     format: Optional["ColumnOrName"] = None,
@@ -7841,6 +7850,9 @@ def to_timestamp_ntz(

     .. versionadded:: 3.5.0

+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     timestamp : :class:`~pyspark.sql.Column` or str
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index d4f56fe822f3e..cb262a14cbe2c 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -140,8 +140,17 @@ def is_timestamp_ntz_preferred() -> bool:
     """
     Return a bool if TimestampNTZType is preferred according to the SQL configuration set.
     """
-    jvm = SparkContext._jvm
-    return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred()
+    if is_remote():
+        from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
+
+        session = ConnectSparkSession.getActiveSession()
+        if session is None:
+            return False
+        else:
+            return session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ"
+    else:
+        jvm = SparkContext._jvm
+        return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred()


 def is_remote() -> bool:
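
For reference, here is a minimal usage sketch (not part of the patch) of the client-side behavior this change enables. It assumes a Spark Connect server reachable at `sc://localhost` running PySpark 3.5+; the connection URL and the sample timestamp value are illustrative only.

```python
import datetime

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Assumed Spark Connect endpoint; replace with your own server URL.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# With TIMESTAMP_NTZ preferred, the Connect client now infers a datetime
# literal as TimestampNTZType instead of TimestampType.
spark.conf.set("spark.sql.timestampType", "TIMESTAMP_NTZ")
ntz_df = spark.range(1).select(F.lit(datetime.datetime(2023, 8, 11, 17, 35)).alias("ts"))
print(ntz_df.schema["ts"].dataType)  # expected: TimestampNTZType()

# With the default TIMESTAMP_LTZ, the literal is still inferred as TimestampType.
spark.conf.set("spark.sql.timestampType", "TIMESTAMP_LTZ")
ltz_df = spark.range(1).select(F.lit(datetime.datetime(2023, 8, 11, 17, 35)).alias("ts"))
print(ltz_df.schema["ts"].dataType)  # expected: TimestampType()
```

Note that the new `elif` branch in `_infer_type` is placed before the existing `datetime.datetime` case, so `TimestampNTZType` is chosen only when `spark.sql.timestampType` is set to `TIMESTAMP_NTZ`; otherwise inference falls through to `TimestampType` as before.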