diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 6ccffc718d064..64f095a5b018b 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -26,6 +26,7 @@
 from collections.abc import Callable, Sized
 import functools
 from threading import RLock
+from types import TracebackType
 from typing import (
     Optional,
     Any,
@@ -40,6 +41,7 @@
     Mapping,
     TYPE_CHECKING,
     ClassVar,
+    Type,
 )
 
 import numpy as np
@@ -947,6 +949,57 @@ def stop(self) -> None:
         if "SPARK_REMOTE" in os.environ:
             del os.environ["SPARK_REMOTE"]
 
+    def __enter__(self) -> "SparkSession":
+        """
+        Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax.
+
+        .. versionadded:: 4.1.0
+
+        Examples
+        --------
+        >>> with SparkSession.builder.master("local").getOrCreate() as session:
+        ...     session.range(5).show()  # doctest: +SKIP
+        +---+
+        | id|
+        +---+
+        |  0|
+        |  1|
+        |  2|
+        |  3|
+        |  4|
+        +---+
+        """
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        """
+        Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax.
+
+        Specifically stop the SparkSession on exit of the with block.
+
+        .. versionadded:: 4.1.0
+
+        Examples
+        --------
+        >>> with SparkSession.builder.master("local").getOrCreate() as session:
+        ...     session.range(5).show()  # doctest: +SKIP
+        +---+
+        | id|
+        +---+
+        |  0|
+        |  1|
+        |  2|
+        |  3|
+        |  4|
+        +---+
+        """
+        self.stop()
+
     @property
     def is_stopped(self) -> bool:
         """
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py
index 1857796ac9aa0..98b1dfc43dd01 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -324,6 +324,32 @@ def test_config(self):
         self.assertEqual(self.spark.conf.get("boolean"), "false")
         self.assertEqual(self.spark.conf.get("integer"), "1")
 
+    def test_context_manager_enter_exit(self):
+        """Test that SparkSession works as a context manager."""
+        # Create a new session for testing
+        with PySparkSession.builder.remote("local[2]").getOrCreate() as session:
+            self.assertIsInstance(session, RemoteSparkSession)
+            self.assertFalse(session.is_stopped)
+
+            df = session.range(3)
+            result = df.collect()
+            self.assertEqual(len(result), 3)
+
+        self.assertTrue(session.is_stopped)
+
+    def test_context_manager_with_exception(self):
+        """Test that SparkSession is properly stopped even when exception occurs."""
+        session = None
+        try:
+            with PySparkSession.builder.remote("local[2]").getOrCreate() as session:
+                self.assertIsInstance(session, RemoteSparkSession)
+                self.assertFalse(session.is_stopped)
+                raise ValueError("Test exception")
+        except ValueError:
+            pass  # Expected exception
+
+        self.assertTrue(session.is_stopped)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_session import *  # noqa: F401
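
For illustration, a minimal sketch of the usage this patch enables (not part of the diff; assumes a PySpark build containing this change, where SparkSession.builder.remote("local[2]") starts an in-process Spark Connect server for local testing):

    from pyspark.sql import SparkSession

    # __enter__ simply returns the session, so it can be bound with "as".
    with SparkSession.builder.remote("local[2]").getOrCreate() as session:
        assert not session.is_stopped
        session.range(5).show()

    # __exit__ called session.stop(), and would have done so even if the
    # block had raised an exception.
    assert session.is_stopped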