[SPARK-44731][PYTHON][CONNECT] Make TimestampNTZ works with literals in Python Spark Connect

### What changes were proposed in this pull request?

This PR proposes:
- Shares the namespaces for `to_timestamp_ntz`, `to_timestamp_ltz` and `to_unix_timestamp` in Spark Connect; these functions were missed earlier.
- Adds support for `TimestampNTZ` in literal handling in Python Spark Connect (by respecting `spark.sql.timestampType`).
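
The second item can be pictured with a small, self-contained sketch. It does not use PySpark: `infer_literal_type` and its `ntz_preferred` flag are hypothetical stand-ins for the `_infer_type` classmethod and `is_timestamp_ntz_preferred()` touched by this patch, and the returned strings stand in for Spark's `DataType` classes. It assumes the flag reflects `spark.sql.timestampType` being set to `TIMESTAMP_NTZ`.

```python
import datetime
import decimal


def infer_literal_type(value, ntz_preferred=False):
    """Hypothetical stand-in for literal type inference; returns type names as strings."""
    if isinstance(value, decimal.Decimal):
        return "DecimalType"
    # datetime.datetime must be tested before datetime.date because
    # datetime.datetime is a subclass of datetime.date.
    if isinstance(value, datetime.datetime):
        return "TimestampNTZType" if ntz_preferred else "TimestampType"
    if isinstance(value, datetime.date):
        return "DateType"
    raise TypeError(f"unsupported literal: {type(value).__name__}")


ts = datetime.datetime(2023, 8, 24, 12, 0)
default_type = infer_literal_type(ts)                  # flag off
ntz_type = infer_literal_type(ts, ntz_preferred=True)  # flag on
date_type = infer_literal_type(datetime.date(2023, 8, 24))
```

Only the `datetime.datetime` branch depends on the flag; every other literal type is inferred as before.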

### Why are the changes needed?

For feature parity, and to respect `TimestampNTZ` when resampling in pandas API on Spark.

### Does this PR introduce _any_ user-facing change?

Yes, this fixes virtually the same bug as apache#42392, in Spark Connect with Python.

### How was this patch tested?

Re-enabled the unit tests.

Closes apache#42445 from HyukjinKwon/SPARK-44731.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
HyukjinKwon authored and vpolet committed Aug 24, 2023
1 parent 3f047e6 commit f2d136b
Showing 4 changed files with 27 additions and 5 deletions.
4 changes: 1 addition & 3 deletions python/pyspark/pandas/tests/connect/test_parity_resample.py
@@ -30,9 +30,7 @@ class ResampleParityTests(
 class ResampleWithTimezoneTests(
     ResampleWithTimezoneMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
 ):
-    @unittest.skip("SPARK-44731: Support 'spark.sql.timestampType' in Python Spark Connect client")
-    def test_series_resample_with_timezone(self):
-        super().test_series_resample_with_timezone()
+    pass


 if __name__ == "__main__":

3 changes: 3 additions & 0 deletions python/pyspark/sql/connect/expressions.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
+from pyspark.sql.utils import is_timestamp_ntz_preferred

 check_dependencies(__name__)

@@ -295,6 +296,8 @@ def _infer_type(cls, value: Any) -> DataType:
             return StringType()
         elif isinstance(value, decimal.Decimal):
             return DecimalType()
+        elif isinstance(value, datetime.datetime) and is_timestamp_ntz_preferred():
+            return TimestampNTZType()
         elif isinstance(value, datetime.datetime):
             return TimestampType()
         elif isinstance(value, datetime.date):

12 changes: 12 additions & 0 deletions python/pyspark/sql/functions.py
@@ -7758,6 +7758,7 @@ def check_field(field: Union[Column, str], fieldName: str) -> None:
     return _invoke_function("session_window", time_col, gap_duration)


+@try_remote_functions
 def to_unix_timestamp(
     timestamp: "ColumnOrName",
     format: Optional["ColumnOrName"] = None,
@@ -7767,6 +7768,9 @@ def to_unix_timestamp(

     .. versionadded:: 3.5.0

+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     timestamp : :class:`~pyspark.sql.Column` or str
@@ -7794,6 +7798,7 @@ def to_unix_timestamp(
         return _invoke_function_over_columns("to_unix_timestamp", timestamp)


+@try_remote_functions
 def to_timestamp_ltz(
     timestamp: "ColumnOrName",
     format: Optional["ColumnOrName"] = None,
@@ -7804,6 +7809,9 @@ def to_timestamp_ltz(

     .. versionadded:: 3.5.0

+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     timestamp : :class:`~pyspark.sql.Column` or str
@@ -7831,6 +7839,7 @@ def to_timestamp_ltz(
         return _invoke_function_over_columns("to_timestamp_ltz", timestamp)


+@try_remote_functions
 def to_timestamp_ntz(
     timestamp: "ColumnOrName",
     format: Optional["ColumnOrName"] = None,
@@ -7841,6 +7850,9 @@ def to_timestamp_ntz(

     .. versionadded:: 3.5.0

+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     timestamp : :class:`~pyspark.sql.Column` or str

13 changes: 11 additions & 2 deletions python/pyspark/sql/utils.py
@@ -140,8 +140,17 @@ def is_timestamp_ntz_preferred() -> bool:
     """
     Return a bool if TimestampNTZType is preferred according to the SQL configuration set.
     """
-    jvm = SparkContext._jvm
-    return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred()
+    if is_remote():
+        from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
+
+        session = ConnectSparkSession.getActiveSession()
+        if session is None:
+            return False
+        else:
+            return session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ"
+    else:
+        jvm = SparkContext._jvm
+        return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred()


 def is_remote() -> bool:
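
The remote branch added to `is_timestamp_ntz_preferred()` can be exercised without a Spark cluster. The following is only a sketch under stated assumptions: `StubConf`, `StubSession` and `ntz_preferred` are hypothetical names that imitate the `session.conf.get(...)` call shown in the diff, not PySpark classes.

```python
from typing import Optional


class StubConf:
    """Hypothetical stand-in for a session's runtime configuration."""

    def __init__(self, values):
        self._values = dict(values)

    def get(self, key, default=None):
        return self._values.get(key, default)


class StubSession:
    """Hypothetical stand-in for an active Spark Connect session."""

    def __init__(self, values):
        self.conf = StubConf(values)


def ntz_preferred(session: Optional[StubSession]) -> bool:
    # Mirrors the remote branch in the diff: with no active session,
    # TimestampNTZ is treated as not preferred.
    if session is None:
        return False
    return session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ"


no_session = ntz_preferred(None)
unset = ntz_preferred(StubSession({}))
ltz = ntz_preferred(StubSession({"spark.sql.timestampType": "TIMESTAMP_LTZ"}))
ntz = ntz_preferred(StubSession({"spark.sql.timestampType": "TIMESTAMP_NTZ"}))
```

Only an explicit `TIMESTAMP_NTZ` setting yields `True`; a missing session or an unset key falls back to `False`.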