[SPARK-25003][PYSPARK] Use SessionExtensions in Pyspark

## What changes were proposed in this pull request?

Previously PySpark used the private constructor for SparkSession when
building that object. This resulted in a SparkSession that never checked
the spark.sql.extensions parameter for additional session extensions. To fix
this, we instead use the SparkSession.builder() path, as SparkR does; this
loads the extensions and allows their use in PySpark.
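
With this change, an extensions class set via `spark.sql.extensions` is applied when PySpark builds its SparkSession, matching the Scala and SparkR behaviour. Below is a minimal sketch of how this could be exercised from the Python side; `org.example.MyExtensions` is a hypothetical configurator class (a `SparkSessionExtensions => Unit` function) assumed to be on the driver classpath:

```python
from pyspark.sql import SparkSession

# Sketch only: "org.example.MyExtensions" is a hypothetical extensions class
# (implementing SparkSessionExtensions => Unit) assumed to be on the classpath.
spark = (SparkSession.builder
         .master("local[2]")
         .appName("extensions-demo")
         .config("spark.sql.extensions", "org.example.MyExtensions")
         .getOrCreate())

# Any rules or strategies injected by the extensions are now part of the
# session state; they are visible through the underlying JVM session.
print(spark._jsparkSession.sessionState().planner().strategies())

spark.stop()
```

If the configured class cannot be loaded or has the wrong type, the session still starts and the extensions are skipped with a warning (see `applyExtensions` in the diff below).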

## How was this patch tested?

An integration test was added which mimics the Scala test for the same feature.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Closes #21990 from RussellSpitzer/SPARK-25003-master.

Authored-by: Russell Spitzer <Russell.Spitzer@gmail.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
RussellSpitzer authored and HyukjinKwon committed Oct 18, 2018
1 parent 7d425b1 commit c3eaee7
Showing 2 changed files with 80 additions and 18 deletions.
42 changes: 42 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3837,6 +3837,48 @@ def test_query_execution_listener_on_collect_with_arrow(self):
             "The callback from the query execution listener should be called after 'toPandas'")


+class SparkExtensionsTest(unittest.TestCase):
+    # These tests are separate because it uses 'spark.sql.extensions' which is
+    # static and immutable. This can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "SparkSessionExtensionSuite.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.extensions' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.extensions",
+                "org.apache.spark.sql.MyExtensions") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def test_use_custom_class_for_extensions(self):
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().planner().strategies().contains(
+                self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
+            "MySparkStrategy not found in active planner strategies")
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
+                self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
+            "MyRule not found in extended resolution rules")
+
+
 class SparkSessionTests(PySparkTestCase):

     # This test is separate because it's closely related with session's start and stop.
56 changes: 38 additions & 18 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -84,8 +84,17 @@ class SparkSession private(
   // The call site where this SparkSession was constructed.
   private val creationSite: CallSite = Utils.getCallSite()

+  /**
+   * Constructor used in Pyspark. Contains explicit application of Spark Session Extensions
+   * which otherwise only occurs during getOrCreate. We cannot add this to the default constructor
+   * since that would cause every new session to reinvoke Spark Session Extensions on the currently
+   * running extensions.
+   */
   private[sql] def this(sc: SparkContext) {
-    this(sc, None, None, new SparkSessionExtensions)
+    this(sc, None, None,
+      SparkSession.applyExtensions(
+        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+        new SparkSessionExtensions))
   }

   sparkContext.assertNotStopped()
@@ -936,23 +945,9 @@ object SparkSession extends Logging {
           // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
         }

-        // Initialize extensions if the user has defined a configurator class.
-        val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-        if (extensionConfOption.isDefined) {
-          val extensionConfClassName = extensionConfOption.get
-          try {
-            val extensionConfClass = Utils.classForName(extensionConfClassName)
-            val extensionConf = extensionConfClass.newInstance()
-              .asInstanceOf[SparkSessionExtensions => Unit]
-            extensionConf(extensions)
-          } catch {
-            // Ignore the error if we cannot find the class or when the class has the wrong type.
-            case e @ (_: ClassCastException |
-              _: ClassNotFoundException |
-              _: NoClassDefFoundError) =>
-              logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
-          }
-        }
+        applyExtensions(
+          sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+          extensions)

         session = new SparkSession(sparkContext, None, None, extensions)
         options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1137,4 +1132,29 @@
       SparkSession.clearDefaultSession()
     }
   }
+
+  /**
+   * Initialize extensions for given extension classname. This class will be applied to the
+   * extensions passed into this function.
+   */
+  private def applyExtensions(
+      extensionOption: Option[String],
+      extensions: SparkSessionExtensions): SparkSessionExtensions = {
+    if (extensionOption.isDefined) {
+      val extensionConfClassName = extensionOption.get
+      try {
+        val extensionConfClass = Utils.classForName(extensionConfClassName)
+        val extensionConf = extensionConfClass.newInstance()
+          .asInstanceOf[SparkSessionExtensions => Unit]
+        extensionConf(extensions)
+      } catch {
+        // Ignore the error if we cannot find the class or when the class has the wrong type.
+        case e@(_: ClassCastException |
+          _: ClassNotFoundException |
+          _: NoClassDefFoundError) =>
+          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+      }
+    }
+    extensions
+  }
 }
