diff --git a/LICENSE-binary b/LICENSE-binary index 456b074842575..b6971798e5577 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -218,7 +218,6 @@ com.google.crypto.tink:tink com.google.flatbuffers:flatbuffers-java com.google.guava:guava com.jamesmurty.utils:java-xmlbuilder -com.jolbox:bonecp com.ning:compress-lzf com.squareup.okhttp3:logging-interceptor com.squareup.okhttp3:okhttp diff --git a/NOTICE-binary b/NOTICE-binary index c82d0b52f31cc..c4cfe0e9f8b31 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -33,11 +33,12 @@ services. // Version 2.0, in this case for // ------------------------------------------------------------------ -Hive Beeline -Copyright 2016 The Apache Software Foundation +=== NOTICE FOR com.clearspring.analytics:streams === +stream-api +Copyright 2016 AddThis -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). +This product includes software developed by AddThis. +=== END OF NOTICE FOR com.clearspring.analytics:streams === Apache Avro Copyright 2009-2014 The Apache Software Foundation diff --git a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala index 08291859a32cc..ae00987cd69f6 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala @@ -462,14 +462,13 @@ private[spark] object MavenUtils extends Logging { val sysOut = System.out // Default configuration name for ivy val ivyConfName = "default" - - // A Module descriptor must be specified. Entries are dummy strings - val md = getModuleDescriptor - - md.setDefaultConf(ivyConfName) + var md: DefaultModuleDescriptor = null try { // To prevent ivy from logging to system out System.setOut(printStream) + // A Module descriptor must be specified. 
Entries are dummy strings + md = getModuleDescriptor + md.setDefaultConf(ivyConfName) val artifacts = extractMavenCoordinates(coordinates) // Directories for caching downloads through ivy and storing the jars when maven coordinates // are supplied to spark-submit @@ -548,7 +547,9 @@ private[spark] object MavenUtils extends Logging { } } finally { System.setOut(sysOut) - clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings.getDefaultCache, ivyConfName) + if (md != null) { + clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings.getDefaultCache, ivyConfName) + } } } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index 8110bd9f46a86..203b1295005af 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -108,6 +108,35 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession { assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } + test("interrupt all - streaming queries") { + val q1 = spark.readStream + .format("rate") + .option("rowsPerSecond", 1) + .load() + .writeStream + .format("console") + .start() + + val q2 = spark.readStream + .format("rate") + .option("rowsPerSecond", 1) + .load() + .writeStream + .format("console") + .start() + + assert(q1.isActive) + assert(q2.isActive) + + val interrupted = spark.interruptAll() + + q1.awaitTermination(timeoutMs = 20 * 1000) + q2.awaitTermination(timeoutMs = 20 * 1000) + assert(!q1.isActive) + assert(!q2.isActive) + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + // TODO(SPARK-48139): Re-enable `SparkSessionE2ESuite.interrupt tag` ignore("interrupt tag") { val session = spark @@ -230,6 +259,53 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession { assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } + test("interrupt tag - streaming query") { + spark.addTag("foo") + val q1 = spark.readStream + .format("rate") + .option("rowsPerSecond", 1) + .load() + .writeStream + .format("console") + .start() + assert(spark.getTags() == Set("foo")) + + spark.addTag("bar") + val q2 = spark.readStream + .format("rate") + .option("rowsPerSecond", 1) + .load() + .writeStream + .format("console") + .start() + assert(spark.getTags() == Set("foo", "bar")) + + spark.clearTags() + + spark.addTag("zoo") + val q3 = spark.readStream + .format("rate") + .option("rowsPerSecond", 1) + .load() + .writeStream + .format("console") + .start() + assert(spark.getTags() == Set("zoo")) + + assert(q1.isActive) + assert(q2.isActive) + assert(q3.isActive) + + val interrupted = spark.interruptTag("foo") + + q1.awaitTermination(timeoutMs = 20 * 1000) + q2.awaitTermination(timeoutMs = 20 * 1000) + assert(!q1.isActive) + assert(!q2.isActive) + assert(q3.isActive) + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + test("progress is available for the spark result") { val result = spark .range(10000) diff --git a/connector/connect/docs/client-connection-string.md b/connector/connect/docs/client-connection-string.md index ebab7cbff4fc1..37b2956a5c44a 100644 --- a/connector/connect/docs/client-connection-string.md +++ b/connector/connect/docs/client-connection-string.md @@ -22,8 +22,8 @@ cannot contain arbitrary characters. 
Configuration parameters are passed in the style of the HTTP URL Path Parameter Syntax. This is similar to the JDBC connection strings. The path component must be empty. All parameters are interpreted **case sensitive**.
-```shell
-sc://hostname:port/;param1=value;param2=value
+```text
+sc://host:port/;param1=value;param2=value
 ```
[The remaining hunks of this file edit rows of an HTML parameter table whose markup was lost in extraction; the recoverable changes per hunk are:]
@@ -34,7 +34,7 @@ sc://hostname:port/;param1=value;param2=value
  [table row: the parameter name `hostname` (String, "The hostname of the endpoint for Spark Connect. Since the endpoint ...") is changed to `host`]
@@ -49,8 +49,8 @@ sc://hostname:port/;param1=value;param2=value
  [table row `port` (Numeric): "The portname to be used when connecting to the GRPC endpoint" becomes "The port to be used when connecting to the GRPC endpoint"; the default value is 15002 and any valid port number can be used (examples: 15002, 443)]
@@ -75,7 +75,7 @@ sc://hostname:port/;param1=value;param2=value
  [table row `user_id` (String, "User ID to automatically set in the Spark Connect UserContext message"): the typo "necssary" is corrected to "necessary" in "This is necessary for the appropriate Spark Session management. This is an *optional* parameter and depending on the deployment this parameter might be automatically injected using other means."]
@@ -99,9 +99,16 @@ sc://hostname:port/;param1=value;param2=value
  [table row `session_id` ("... allows to provide this session ID to allow sharing Spark Sessions for the same users for example across multiple languages. The value must be provided in a valid UUID string format."; Default: a UUID generated randomly; example: session_id=550e8400-e29b-41d4-a716-446655440000) is followed by a new row: `grpc_max_message_size` (Numeric), "Maximum message size allowed for gRPC messages in bytes.", Default: 128 * 1024 * 1024, example: grpc_max_message_size=134217728]
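For reference, a minimal sketch of how a connection string with the new `grpc_max_message_size` parameter might be used from the Spark Connect Scala client; the endpoint, user id, and workload below are illustrative assumptions, not part of this change.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: connect to a local Spark Connect server, identify as
// "example_user", and raise the gRPC message limit to 128 MiB (134217728 bytes).
val spark = SparkSession
  .builder()
  .remote("sc://localhost:15002/;user_id=example_user;grpc_max_message_size=134217728")
  .getOrCreate()

spark.range(5).show()
spark.close()
```

The same string can also be supplied wherever the client accepts a remote address, for example via the `SPARK_REMOTE` environment variable.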
## Examples diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index b3ae8c71611f4..453d2b30876f4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3163,7 +3163,11 @@ class SparkConnectPlanner( } // Register the new query so that its reference is cached and is stopped on session timeout. - SparkConnectService.streamingSessionManager.registerNewStreamingQuery(sessionHolder, query) + SparkConnectService.streamingSessionManager.registerNewStreamingQuery( + sessionHolder, + query, + executeHolder.sparkSessionTags, + executeHolder.operationId) // Register the runner with the query if Python foreachBatch is enabled. foreachBatchRunnerCleaner.foreach { cleaner => sessionHolder.streamingForeachBatchRunnerCleanerCache.registerCleanerForQuery( @@ -3228,7 +3232,9 @@ class SparkConnectPlanner( // Find the query in connect service level cache, otherwise check session's active streams. val query = SparkConnectService.streamingSessionManager - .getCachedQuery(id, runId, session) // Common case: query is cached in the cache. + // Common case: query is cached in the cache. + .getCachedQuery(id, runId, executeHolder.sparkSessionTags, session) + .map(_.query) .orElse { // Else try to find it in active streams. Mostly will not be found here either. Option(session.streams.get(id)) } match { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 0285b48405773..681f7e29630ff 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} import scala.jdk.CollectionConverters._ import scala.util.Try @@ -179,12 +180,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ private[service] def interruptAll(): Seq[String] = { val interruptedIds = new mutable.ArrayBuffer[String]() + val operationsIds = + SparkConnectService.streamingSessionManager.cleanupRunningQueries(this, blocking = false) executions.asScala.values.foreach { execute => if (execute.interrupt()) { interruptedIds += execute.operationId } } - interruptedIds.toSeq + interruptedIds.toSeq ++ operationsIds } /** @@ -194,6 +197,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ private[service] def interruptTag(tag: String): Seq[String] = { val interruptedIds = new mutable.ArrayBuffer[String]() + val queries = SparkConnectService.streamingSessionManager.getTaggedQuery(tag, session) + queries.foreach(q => Future(q.query.stop())(ExecutionContext.global)) executions.asScala.values.foreach { execute => if (execute.sparkSessionTags.contains(tag)) { if (execute.interrupt()) { @@ -201,7 +206,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } } } - interruptedIds.toSeq + interruptedIds.toSeq ++ 
queries.map(_.operationId) } /** @@ -298,7 +303,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio // Clean up running streaming queries. // Note: there can be concurrent streaming queries being started. - SparkConnectService.streamingSessionManager.cleanupRunningQueries(this) + SparkConnectService.streamingSessionManager.cleanupRunningQueries(this, blocking = true) streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers. removeAllListeners() // removes all listener and stop python listener processes if necessary. diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala index 91eae18f2d5da..03719ddd87419 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import javax.annotation.concurrent.GuardedBy import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.{Duration, DurationInt, FiniteDuration} import scala.util.control.NonFatal @@ -55,16 +56,28 @@ private[connect] class SparkConnectStreamingQueryCache( import SparkConnectStreamingQueryCache._ - def registerNewStreamingQuery(sessionHolder: SessionHolder, query: StreamingQuery): Unit = { - queryCacheLock.synchronized { + def registerNewStreamingQuery( + sessionHolder: SessionHolder, + query: StreamingQuery, + tags: Set[String], + operationId: String): Unit = queryCacheLock.synchronized { + taggedQueriesLock.synchronized { val value = QueryCacheValue( userId = sessionHolder.userId, sessionId = sessionHolder.sessionId, session = sessionHolder.session, query = query, + operationId = operationId, expiresAtMs = None) - queryCache.put(QueryCacheKey(query.id.toString, query.runId.toString), value) match { + val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) + tags.foreach { tag => + taggedQueries + .getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]) + .addOne(queryKey) + } + + queryCache.put(queryKey, value) match { case Some(existing) => // Query is being replace. Not really expected. logWarning(log"Replacing existing query in the cache (unexpected). " + log"Query Id: ${MDC(QUERY_ID, query.id)}.Existing value ${MDC(OLD_VALUE, existing)}, " + @@ -80,7 +93,7 @@ private[connect] class SparkConnectStreamingQueryCache( } /** - * Returns [[StreamingQuery]] if it is cached and session matches the cached query. It ensures + * Returns [[QueryCacheValue]] if it is cached and session matches the cached query. It ensures * the session associated with it matches the session passed into the call. If the query is * inactive (i.e. it has a cache expiry time set), this access extends its expiry time. So if a * client keeps accessing a query, it stays in the cache. 
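The `interruptTag` and `cleanupRunningQueries` changes above stop streaming queries without blocking the calling RPC. A small self-contained sketch of that fire-and-forget pattern follows; `CachedQuery` is a hypothetical stand-in for the cache value, not a Spark class.

```scala
import scala.concurrent.{ExecutionContext, Future}

// StreamingQuery.stop() may block until the query terminates, so the stop call is
// dispatched to the global execution context and only the operation ids are
// returned immediately to the interrupt caller.
final case class CachedQuery(operationId: String, stop: () => Unit)

def interruptQueries(queries: Seq[CachedQuery]): Seq[String] = {
  queries.foreach(q => Future(q.stop())(ExecutionContext.global))
  queries.map(_.operationId)
}

// Usage with dummy queries whose stop just sleeps briefly.
val interrupted = interruptQueries(Seq(
  CachedQuery("op-1", () => Thread.sleep(100)),
  CachedQuery("op-2", () => Thread.sleep(100))))
```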
@@ -88,8 +101,35 @@ private[connect] class SparkConnectStreamingQueryCache( def getCachedQuery( queryId: String, runId: String, - session: SparkSession): Option[StreamingQuery] = { - val key = QueryCacheKey(queryId, runId) + tags: Set[String], + session: SparkSession): Option[QueryCacheValue] = { + taggedQueriesLock.synchronized { + val key = QueryCacheKey(queryId, runId) + val result = getCachedQuery(QueryCacheKey(queryId, runId), session) + tags.foreach { tag => + taggedQueries.getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]).addOne(key) + } + result + } + } + + /** + * Similar with [[getCachedQuery]] but it gets queries tagged previously. + */ + def getTaggedQuery(tag: String, session: SparkSession): Seq[QueryCacheValue] = { + taggedQueriesLock.synchronized { + taggedQueries + .get(tag) + .map { k => + k.flatMap(getCachedQuery(_, session)).toSeq + } + .getOrElse(Seq.empty[QueryCacheValue]) + } + } + + private def getCachedQuery( + key: QueryCacheKey, + session: SparkSession): Option[QueryCacheValue] = { queryCacheLock.synchronized { queryCache.get(key).flatMap { v => if (v.session == session) { @@ -98,7 +138,7 @@ private[connect] class SparkConnectStreamingQueryCache( val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) } - Some(v.query) + Some(v) } else None // Should be rare, may be client is trying access from a different session. } } @@ -109,7 +149,10 @@ private[connect] class SparkConnectStreamingQueryCache( * the queryCache. This is used when session is expired and we need to cleanup resources of that * session. */ - def cleanupRunningQueries(sessionHolder: SessionHolder): Unit = { + def cleanupRunningQueries( + sessionHolder: SessionHolder, + blocking: Boolean = true): Seq[String] = { + val operationIds = new mutable.ArrayBuffer[String]() for ((k, v) <- queryCache) { if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) { if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) { @@ -117,7 +160,12 @@ private[connect] class SparkConnectStreamingQueryCache( log"Stopping the query with id ${MDC(QUERY_ID, k.queryId)} " + log"since the session has timed out") try { - v.query.stop() + if (blocking) { + v.query.stop() + } else { + Future(v.query.stop())(ExecutionContext.global) + } + operationIds.addOne(v.operationId) } catch { case NonFatal(ex) => logWarning( @@ -128,6 +176,7 @@ private[connect] class SparkConnectStreamingQueryCache( } } } + operationIds.toSeq } // Visible for testing @@ -146,6 +195,10 @@ private[connect] class SparkConnectStreamingQueryCache( private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue] private val queryCacheLock = new Object + @GuardedBy("queryCacheLock") + private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]] + private val taggedQueriesLock = new Object + @GuardedBy("queryCacheLock") private var scheduledExecutor: Option[ScheduledExecutorService] = None @@ -176,7 +229,7 @@ private[connect] class SparkConnectStreamingQueryCache( * - Update status of query if it is inactive. Sets an expiry time for such queries * - Drop expired queries from the cache. 
*/ - private def periodicMaintenance(): Unit = { + private def periodicMaintenance(): Unit = taggedQueriesLock.synchronized { queryCacheLock.synchronized { val nowMs = clock.getTimeMillis() @@ -212,6 +265,18 @@ private[connect] class SparkConnectStreamingQueryCache( } } } + + taggedQueries.toArray.foreach { case (key, value) => + value.zipWithIndex.toArray.foreach { case (queryKey, i) => + if (queryCache.contains(queryKey)) { + value.remove(i) + } + } + + if (value.isEmpty) { + taggedQueries.remove(key) + } + } } } } @@ -225,6 +290,7 @@ private[connect] object SparkConnectStreamingQueryCache { sessionId: String, session: SparkSession, // Holds the reference to the session. query: StreamingQuery, // Holds the reference to the query. + operationId: String, expiresAtMs: Option[Long] = None // Expiry time for a stopped query. ) { override def toString(): String = diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index af18fca9dd216..71ca0f44af680 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -826,7 +826,9 @@ class SparkConnectServiceSuite when(restartedQuery.runId).thenReturn(DEFAULT_UUID) SparkConnectService.streamingSessionManager.registerNewStreamingQuery( SparkConnectService.getOrCreateIsolatedSession("c1", sessionId, None), - restartedQuery) + restartedQuery, + Set.empty[String], + "") f(verifyEvents) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala index ed3da2c0f7156..512a0a80c4a91 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala @@ -67,7 +67,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug // Register the query. - sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery) + sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set.empty[String], "") sessionMgr.getCachedValue(queryId, runId) match { case Some(v) => @@ -78,9 +78,14 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug } // Verify query is returned only with the correct session, not with a different session. - assert(sessionMgr.getCachedQuery(queryId, runId, mock[SparkSession]).isEmpty) + assert( + sessionMgr.getCachedQuery(queryId, runId, Set.empty[String], mock[SparkSession]).isEmpty) // Query is returned when correct session is used - assert(sessionMgr.getCachedQuery(queryId, runId, mockSession).contains(mockQuery)) + assert( + sessionMgr + .getCachedQuery(queryId, runId, Set.empty[String], mockSession) + .map(_.query) + .contains(mockQuery)) // Cleanup the query and verify if stop() method has been called. when(mockQuery.isActive).thenReturn(false) @@ -99,7 +104,11 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug clock.advance(30.seconds.toMillis) // Access the query. This should advance expiry time by 30 seconds. 
- assert(sessionMgr.getCachedQuery(queryId, runId, mockSession).contains(mockQuery)) + assert( + sessionMgr + .getCachedQuery(queryId, runId, Set.empty[String], mockSession) + .map(_.query) + .contains(mockQuery)) val expiresAtMs = sessionMgr.getCachedValue(queryId, runId).get.expiresAtMs.get assert(expiresAtMs == prevExpiryTimeMs + 30.seconds.toMillis) @@ -112,7 +121,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug when(restartedQuery.isActive).thenReturn(true) when(mockStreamingQueryManager.get(queryId)).thenReturn(restartedQuery) - sessionMgr.registerNewStreamingQuery(sessionHolder, restartedQuery) + sessionMgr.registerNewStreamingQuery(sessionHolder, restartedQuery, Set.empty[String], "") // Both queries should existing in the cache. assert(sessionMgr.getCachedValue(queryId, runId).map(_.query).contains(mockQuery)) diff --git a/core/benchmarks/LZFBenchmark-jdk21-results.txt b/core/benchmarks/LZFBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..e1566f201a1f6 --- /dev/null +++ b/core/benchmarks/LZFBenchmark-jdk21-results.txt @@ -0,0 +1,19 @@ +================================================================================================ +Benchmark LZFCompressionCodec +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1021-azure +AMD EPYC 7763 64-Core Processor +Compress small objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Compression 256000000 int values in parallel 598 600 2 428.2 2.3 1.0X +Compression 256000000 int values single-threaded 568 570 2 451.0 2.2 1.1X + +OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1021-azure +AMD EPYC 7763 64-Core Processor +Compress large objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Compression 1024 array values in 1 threads 39 45 5 0.0 38475.4 1.0X +Compression 1024 array values single-threaded 32 33 1 0.0 31154.5 1.2X + + diff --git a/core/benchmarks/LZFBenchmark-results.txt b/core/benchmarks/LZFBenchmark-results.txt new file mode 100644 index 0000000000000..facc67f9cf4a8 --- /dev/null +++ b/core/benchmarks/LZFBenchmark-results.txt @@ -0,0 +1,19 @@ +================================================================================================ +Benchmark LZFCompressionCodec +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1021-azure +AMD EPYC 7763 64-Core Processor +Compress small objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------- +Compression 256000000 int values in parallel 602 612 6 425.1 2.4 1.0X +Compression 256000000 int values single-threaded 610 617 5 419.8 2.4 1.0X + +OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1021-azure +AMD EPYC 7763 64-Core Processor +Compress large objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------- +Compression 1024 array 
values in 1 threads 35 43 6 0.0 33806.8 1.0X +Compression 1024 array values single-threaded 32 32 0 0.0 30990.4 1.1X + + diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 90d8cef00ef83..6eb2bea40bdb5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -421,7 +421,7 @@ class SparkContext(config: SparkConf) extends Logging { } // HADOOP-19097 Set fs.s3a.connection.establish.timeout to 30s // We can remove this after Apache Hadoop 3.4.1 releases - conf.setIfMissing("spark.hadoop.fs.s3a.connection.establish.timeout", "30s") + conf.setIfMissing("spark.hadoop.fs.s3a.connection.establish.timeout", "30000") // This should be set as early as possible. SparkContext.fillMissingMagicCommitterConfsIfNeeded(_conf) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index dc3edfaa86133..a7268c6409913 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2031,6 +2031,13 @@ package object config { .intConf .createWithDefault(1) + private[spark] val IO_COMPRESSION_LZF_PARALLEL = + ConfigBuilder("spark.io.compression.lzf.parallel.enabled") + .doc("When true, LZF compression will use multiple threads to compress data in parallel.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + private[spark] val IO_WARNING_LARGEFILETHRESHOLD = ConfigBuilder("spark.io.warning.largeFileThreshold") .internal() diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 07e694b6c5b03..233228a9c6d4c 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -22,6 +22,7 @@ import java.util.Locale import com.github.luben.zstd.{NoPool, RecyclingBufferPool, ZstdInputStreamNoFinalizer, ZstdOutputStreamNoFinalizer} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} +import com.ning.compress.lzf.parallel.PLZFOutputStream import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream, LZ4Factory} import net.jpountz.xxhash.XXHashFactory import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} @@ -100,8 +101,9 @@ private[spark] object CompressionCodec { * If it is already a short name, just return it. 
*/ def getShortName(codecName: String): String = { - if (shortCompressionCodecNames.contains(codecName)) { - codecName + val lowercasedCodec = codecName.toLowerCase(Locale.ROOT) + if (shortCompressionCodecNames.contains(lowercasedCodec)) { + lowercasedCodec } else { shortCompressionCodecNames .collectFirst { case (k, v) if v == codecName => k } @@ -170,9 +172,14 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { */ @DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { + private val parallelCompression = conf.get(IO_COMPRESSION_LZF_PARALLEL) override def compressedOutputStream(s: OutputStream): OutputStream = { - new LZFOutputStream(s).setFinishBlockOnFlush(true) + if (parallelCompression) { + new PLZFOutputStream(s) + } else { + new LZFOutputStream(s).setFinishBlockOnFlush(true) + } } override def compressedInputStream(s: InputStream): InputStream = new LZFInputStream(s) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 991fb074d246d..0ac1405abe6c3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.{Byte => JByte} -import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, PlatformManagedObject, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ @@ -3058,8 +3058,16 @@ private[spark] object Utils */ lazy val isG1GC: Boolean = { Try { - ManagementFactory.getGarbageCollectorMXBeans.asScala - .exists(_.getName.contains("G1")) + val clazz = Utils.classForName("com.sun.management.HotSpotDiagnosticMXBean") + .asInstanceOf[Class[_ <: PlatformManagedObject]] + val vmOptionClazz = Utils.classForName("com.sun.management.VMOption") + val hotSpotDiagnosticMXBean = ManagementFactory.getPlatformMXBean(clazz) + val vmOptionMethod = clazz.getMethod("getVMOption", classOf[String]) + val valueMethod = vmOptionClazz.getMethod("getValue") + + val useG1GCObject = vmOptionMethod.invoke(hotSpotDiagnosticMXBean, "UseG1GC") + val useG1GC = valueMethod.invoke(useG1GCObject).asInstanceOf[String] + "true".equals(useG1GC) }.getOrElse(false) } } diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 729fcecff1207..5c09a1f965b9e 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.util.Locale import com.google.common.io.ByteStreams @@ -160,4 +161,18 @@ class CompressionCodecSuite extends SparkFunSuite { ByteStreams.readFully(concatenatedBytes, decompressed) assert(decompressed.toSeq === (0 to 127)) } + + test("SPARK-48506: CompressionCodec getShortName is case insensitive for short names") { + CompressionCodec.shortCompressionCodecNames.foreach { case (shortName, codecClass) => + assert(CompressionCodec.getShortName(shortName) === shortName) + assert(CompressionCodec.getShortName(shortName.toUpperCase(Locale.ROOT)) === shortName) + assert(CompressionCodec.getShortName(codecClass) === shortName) + checkError( + exception = 
intercept[SparkIllegalArgumentException] { + CompressionCodec.getShortName(codecClass.toUpperCase(Locale.ROOT)) + }, + errorClass = "CODEC_SHORT_NAME_NOT_FOUND", + parameters = Map("codecName" -> codecClass.toUpperCase(Locale.ROOT))) + } + } } diff --git a/core/src/test/scala/org/apache/spark/io/LZFBenchmark.scala b/core/src/test/scala/org/apache/spark/io/LZFBenchmark.scala new file mode 100644 index 0000000000000..1934bd5169703 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/io/LZFBenchmark.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.io + +import java.io.{ByteArrayOutputStream, ObjectOutputStream} +import java.lang.management.ManagementFactory + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.internal.config.IO_COMPRESSION_LZF_PARALLEL + +/** + * Benchmark for LZF codec performance. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class <this class> <spark core test jar> + * 2. build/sbt "core/Test/runMain <this class>" + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/Test/runMain <this class>" + * Results will be written to "benchmarks/LZFBenchmark-results.txt".
+ * }}} + */ +object LZFBenchmark extends BenchmarkBase { + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Benchmark LZFCompressionCodec") { + compressSmallObjects() + compressLargeObjects() + } + } + + private def compressSmallObjects(): Unit = { + val N = 256_000_000 + val benchmark = new Benchmark("Compress small objects", N, output = output) + Seq(true, false).foreach { parallel => + val conf = new SparkConf(false).set(IO_COMPRESSION_LZF_PARALLEL, parallel) + val condition = if (parallel) "in parallel" else "single-threaded" + benchmark.addCase(s"Compression $N int values $condition") { _ => + val os = new LZFCompressionCodec(conf).compressedOutputStream(new ByteArrayOutputStream()) + for (i <- 1 until N) { + os.write(i) + } + os.close() + } + } + benchmark.run() + } + + private def compressLargeObjects(): Unit = { + val N = 1024 + val data: Array[Byte] = (1 until 128 * 1024 * 1024).map(_.toByte).toArray + val benchmark = new Benchmark(s"Compress large objects", N, output = output) + + // com.ning.compress.lzf.parallel.PLZFOutputStream.getNThreads + def getNThreads: Int = { + var nThreads = Runtime.getRuntime.availableProcessors + val jmx = ManagementFactory.getOperatingSystemMXBean + if (jmx != null) { + val loadAverage = jmx.getSystemLoadAverage.toInt + if (nThreads > 1 && loadAverage >= 1) nThreads = Math.max(1, nThreads - loadAverage) + } + nThreads + } + Seq(true, false).foreach { parallel => + val conf = new SparkConf(false).set(IO_COMPRESSION_LZF_PARALLEL, parallel) + val condition = if (parallel) s"in $getNThreads threads" else "single-threaded" + benchmark.addCase(s"Compression $N array values $condition") { _ => + val baos = new ByteArrayOutputStream() + val zcos = new LZFCompressionCodec(conf).compressedOutputStream(baos) + val oos = new ObjectOutputStream(zcos) + 1 to N foreach { _ => + oos.writeObject(data) + } + oos.close() + } + } + benchmark.run() + } +} diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 65e627b1854f9..8ab76b5787b8c 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -29,7 +29,6 @@ azure-data-lake-store-sdk/2.3.9//azure-data-lake-store-sdk-2.3.9.jar azure-keyvault-core/1.0.0//azure-keyvault-core-1.0.0.jar azure-storage/7.0.1//azure-storage-7.0.1.jar blas/3.0.3//blas-3.0.3.jar -bonecp/0.8.0.RELEASE//bonecp-0.8.0.RELEASE.jar breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar bundle/2.24.6//bundle-2.24.6.jar @@ -137,8 +136,8 @@ jersey-container-servlet/3.0.12//jersey-container-servlet-3.0.12.jar jersey-hk2/3.0.12//jersey-hk2-3.0.12.jar jersey-server/3.0.12//jersey-server-3.0.12.jar jettison/1.5.4//jettison-1.5.4.jar -jetty-util-ajax/11.0.20//jetty-util-ajax-11.0.20.jar -jetty-util/11.0.20//jetty-util-11.0.20.jar +jetty-util-ajax/11.0.21//jetty-util-ajax-11.0.21.jar +jetty-util/11.0.21//jetty-util-11.0.21.jar jline/2.14.6//jline-2.14.6.jar jline/3.25.1//jline-3.25.1.jar jna/5.14.0//jna-5.14.0.jar @@ -262,7 +261,7 @@ spire-platform_2.13/0.18.0//spire-platform_2.13-0.18.0.jar spire-util_2.13/0.18.0//spire-util_2.13-0.18.0.jar spire_2.13/0.18.0//spire_2.13-0.18.0.jar stax-api/1.0.1//stax-api-1.0.1.jar -stream/2.9.6//stream-2.9.6.jar +stream/2.9.8//stream-2.9.8.jar super-csv/2.2.0//super-csv-2.2.0.jar threeten-extra/1.7.1//threeten-extra-1.7.1.jar tink/1.13.0//tink-1.13.0.jar diff --git a/dev/pyproject.toml b/dev/pyproject.toml index 4f462d14c7838..f19107b3782a6 100644 --- 
a/dev/pyproject.toml
+++ b/dev/pyproject.toml
@@ -29,6 +29,6 @@ testpaths = [
 # GitHub workflow version and dev/reformat-python
 required-version = "23.9.1"
 line-length = 100
-target-version = ['py38']
+target-version = ['py39']
 include = '\.pyi?$'
 extend-exclude = 'cloudpickle|error_classes.py'
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ec4b25547dc4c..e182d0c33f16c 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1009,6 +1009,7 @@ def __hash__(self):
         # sql unittests
         "pyspark.sql.tests.connect.test_connect_plan",
         "pyspark.sql.tests.connect.test_connect_basic",
+        "pyspark.sql.tests.connect.test_connect_dataframe_property",
         "pyspark.sql.tests.connect.test_connect_error",
         "pyspark.sql.tests.connect.test_connect_function",
         "pyspark.sql.tests.connect.test_connect_collection",
diff --git a/docs/configuration.md b/docs/configuration.md
index 409f1f521eb52..23443cab2eacc 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1895,6 +1895,14 @@ Apart from these, the following properties are also available, and may be useful
[HTML table markup lost in extraction; the hunk adds a new row before the `spark.kryo.classesToRegister` (none) row:]
+  spark.io.compression.lzf.parallel.enabled
+  Default: false
+  When true, LZF compression will use multiple threads to compress data in parallel.
+  Since: 4.0.0
diff --git a/docs/core-migration-guide.md b/docs/core-migration-guide.md
index 28a9dd0f43715..1c37fded53ab7 100644
--- a/docs/core-migration-guide.md
+++ b/docs/core-migration-guide.md
@@ -48,7 +48,7 @@ license: |
 - Since Spark 4.0, the MDC (Mapped Diagnostic Context) key for Spark task names in Spark logs has been changed from `mdc.taskName` to `task_name`. To use the key `mdc.taskName`, you can set `spark.log.legacyTaskNameMdc.enabled` to `true`.
-- Since Spark 4.0, Spark performs speculative executions less agressively with `spark.speculation.multiplier=3` and `spark.speculation.quantile=0.9`. To restore the legacy behavior, you can set `spark.speculation.multiplier=1.5` and `spark.speculation.quantile=0.75`.
+- Since Spark 4.0, Spark performs speculative executions less aggressively with `spark.speculation.multiplier=3` and `spark.speculation.quantile=0.9`. To restore the legacy behavior, you can set `spark.speculation.multiplier=1.5` and `spark.speculation.quantile=0.75`.
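Tying the new `spark.io.compression.lzf.parallel.enabled` entry above back to code, a small sketch that exercises the codec directly, mirroring what the new LZFBenchmark does. Only the property, class, and method names come from this change; the buffer and loop are illustrative, and typical applications would simply set `spark.io.compression.codec=lzf` together with the new flag.

```scala
import java.io.ByteArrayOutputStream

import org.apache.spark.SparkConf
import org.apache.spark.io.LZFCompressionCodec

// Enable multi-threaded LZF block compression and compress a small buffer.
// Whether the parallel path wins depends on object size and available cores,
// as the LZFBenchmark results in this change suggest.
val conf = new SparkConf(loadDefaults = false)
  .set("spark.io.compression.lzf.parallel.enabled", "true")

val sink = new ByteArrayOutputStream()
val out = new LZFCompressionCodec(conf).compressedOutputStream(sink)
(1 to 10000).foreach(out.write)
out.close()
```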
## Upgrading from Core 3.4 to 3.5 diff --git a/pom.xml b/pom.xml index ded8cc2405fde..bc81b810715b5 100644 --- a/pom.xml +++ b/pom.xml @@ -140,7 +140,7 @@ 1.13.1 2.0.1 shaded-protobuf - 11.0.20 + 11.0.21 5.0.0 4.0.1 @@ -806,7 +806,7 @@ com.clearspring.analytics stream - 2.9.6 + 2.9.8 @@ -1264,7 +1264,7 @@ com.github.docker-java docker-java - 3.3.5 + 3.3.6 test @@ -1284,7 +1284,7 @@ com.github.docker-java docker-java-transport-zerodep - 3.3.5 + 3.3.6 test @@ -2332,6 +2332,10 @@ co.cask.tephra * + + com.jolbox + bonecp + diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index 2a5bb5939a3f8..85806b1a265b0 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -301,7 +301,7 @@ def applyInPandas( evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, ) - return DataFrame( + res = DataFrame( plan.GroupMap( child=self._df._plan, grouping_cols=self._grouping_cols, @@ -310,6 +310,9 @@ def applyInPandas( ), session=self._df._session, ) + if isinstance(schema, StructType): + res._cached_schema = schema + return res applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__ @@ -370,7 +373,7 @@ def applyInArrow( evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, ) - return DataFrame( + res = DataFrame( plan.GroupMap( child=self._df._plan, grouping_cols=self._grouping_cols, @@ -379,6 +382,9 @@ def applyInArrow( ), session=self._df._session, ) + if isinstance(schema, StructType): + res._cached_schema = schema + return res applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__ @@ -410,7 +416,7 @@ def applyInPandas( evalType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, ) - return DataFrame( + res = DataFrame( plan.CoGroupMap( input=self._gd1._df._plan, input_grouping_cols=self._gd1._grouping_cols, @@ -420,6 +426,9 @@ def applyInPandas( ), session=self._gd1._df._session, ) + if isinstance(schema, StructType): + res._cached_schema = schema + return res applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__ @@ -436,7 +445,7 @@ def applyInArrow( evalType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, ) - return DataFrame( + res = DataFrame( plan.CoGroupMap( input=self._gd1._df._plan, input_grouping_cols=self._gd1._grouping_cols, @@ -446,6 +455,9 @@ def applyInArrow( ), session=self._gd1._df._session, ) + if isinstance(schema, StructType): + res._cached_schema = schema + return res applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__ diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py new file mode 100644 index 0000000000000..f80f4509a7cec --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py @@ -0,0 +1,284 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType +from pyspark.sql.utils import is_remote + +from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase +from pyspark.testing.sqlutils import ( + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + +if have_pyarrow: + import pyarrow as pa + import pyarrow.compute as pc + +if have_pandas: + import pandas as pd + + +class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase): + def test_cached_schema_to(self): + cdf = self.connect.read.table(self.tbl_name) + sdf = self.spark.read.table(self.tbl_name) + + schema = StructType( + [ + StructField("id", IntegerType(), True), + StructField("name", StringType(), True), + ] + ) + + cdf1 = cdf.to(schema) + self.assertEqual(cdf1._cached_schema, schema) + + sdf1 = sdf.to(schema) + + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertEqual(cdf1.collect(), sdf1.collect()) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message, + ) + def test_cached_schema_map_in_pandas(self): + data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")] + cdf = self.connect.createDataFrame(data, "a int, b string") + sdf = self.spark.createDataFrame(data, "a int, b string") + + def func(iterator): + for pdf in iterator: + assert isinstance(pdf, pd.DataFrame) + assert [d.name for d in list(pdf.dtypes)] == ["int32", "object"] + yield pdf + + schema = StructType( + [ + StructField("a", IntegerType(), True), + StructField("b", StringType(), True), + ] + ) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}): + self.assertTrue(is_remote()) + cdf1 = cdf.mapInPandas(func, schema) + self.assertEqual(cdf1._cached_schema, schema) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}): + # 'mapInPandas' depends on the method 'pandas_udf', which is dispatched + # based on 'is_remote'. However, in SparkConnectSQLTestCase, the remote + # mode is always on, so 'sdf.mapInPandas' fails with incorrect dispatch. + # Using this temp env to properly invoke mapInPandas in PySpark Classic. 
+ self.assertFalse(is_remote()) + sdf1 = sdf.mapInPandas(func, schema) + + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertEqual(cdf1.collect(), sdf1.collect()) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message, + ) + def test_cached_schema_map_in_arrow(self): + data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")] + cdf = self.connect.createDataFrame(data, "a int, b string") + sdf = self.spark.createDataFrame(data, "a int, b string") + + def func(iterator): + for batch in iterator: + assert isinstance(batch, pa.RecordBatch) + assert batch.schema.types == [pa.int32(), pa.string()] + yield batch + + schema = StructType( + [ + StructField("a", IntegerType(), True), + StructField("b", StringType(), True), + ] + ) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}): + self.assertTrue(is_remote()) + cdf1 = cdf.mapInArrow(func, schema) + self.assertEqual(cdf1._cached_schema, schema) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}): + self.assertFalse(is_remote()) + sdf1 = sdf.mapInArrow(func, schema) + + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertEqual(cdf1.collect(), sdf1.collect()) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message, + ) + def test_cached_schema_group_apply_in_pandas(self): + data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)] + cdf = self.connect.createDataFrame(data, ("id", "v")) + sdf = self.spark.createDataFrame(data, ("id", "v")) + + def normalize(pdf): + v = pdf.v + return pdf.assign(v=(v - v.mean()) / v.std()) + + schema = StructType( + [ + StructField("id", LongType(), True), + StructField("v", DoubleType(), True), + ] + ) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}): + self.assertTrue(is_remote()) + cdf1 = cdf.groupby("id").applyInPandas(normalize, schema) + self.assertEqual(cdf1._cached_schema, schema) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}): + self.assertFalse(is_remote()) + sdf1 = sdf.groupby("id").applyInPandas(normalize, schema) + + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertEqual(cdf1.collect(), sdf1.collect()) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message, + ) + def test_cached_schema_group_apply_in_arrow(self): + data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)] + cdf = self.connect.createDataFrame(data, ("id", "v")) + sdf = self.spark.createDataFrame(data, ("id", "v")) + + def normalize(table): + v = table.column("v") + norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1)) + return table.set_column(1, "v", norm) + + schema = StructType( + [ + StructField("id", LongType(), True), + StructField("v", DoubleType(), True), + ] + ) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}): + self.assertTrue(is_remote()) + cdf1 = cdf.groupby("id").applyInArrow(normalize, schema) + self.assertEqual(cdf1._cached_schema, schema) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}): + self.assertFalse(is_remote()) + sdf1 = sdf.groupby("id").applyInArrow(normalize, schema) + + self.assertEqual(cdf1.schema, sdf1.schema) + self.assertEqual(cdf1.collect(), sdf1.collect()) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message, + ) + def test_cached_schema_cogroup_apply_in_pandas(self): + data1 = [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 
4.0)] + data2 = [(20000101, 1, "x"), (20000101, 2, "y")] + + cdf1 = self.connect.createDataFrame(data1, ("time", "id", "v1")) + sdf1 = self.spark.createDataFrame(data1, ("time", "id", "v1")) + cdf2 = self.connect.createDataFrame(data2, ("time", "id", "v2")) + sdf2 = self.spark.createDataFrame(data2, ("time", "id", "v2")) + + def asof_join(left, right): + return pd.merge_asof(left, right, on="time", by="id") + + schema = StructType( + [ + StructField("time", IntegerType(), True), + StructField("id", IntegerType(), True), + StructField("v1", DoubleType(), True), + StructField("v2", StringType(), True), + ] + ) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}): + self.assertTrue(is_remote()) + cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInPandas(asof_join, schema) + self.assertEqual(cdf3._cached_schema, schema) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}): + self.assertFalse(is_remote()) + sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInPandas(asof_join, schema) + + self.assertEqual(cdf3.schema, sdf3.schema) + self.assertEqual(cdf3.collect(), sdf3.collect()) + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message, + ) + def test_cached_schema_cogroup_apply_in_arrow(self): + data1 = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)] + data2 = [(1, "x"), (2, "y")] + + cdf1 = self.connect.createDataFrame(data1, ("id", "v1")) + sdf1 = self.spark.createDataFrame(data1, ("id", "v1")) + cdf2 = self.connect.createDataFrame(data2, ("id", "v2")) + sdf2 = self.spark.createDataFrame(data2, ("id", "v2")) + + def summarize(left, right): + return pa.Table.from_pydict( + { + "left": [left.num_rows], + "right": [right.num_rows], + } + ) + + schema = StructType( + [ + StructField("left", LongType(), True), + StructField("right", LongType(), True), + ] + ) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}): + self.assertTrue(is_remote()) + cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInArrow(summarize, schema) + self.assertEqual(cdf3._cached_schema, schema) + + with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}): + self.assertFalse(is_remote()) + sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInArrow(summarize, schema) + + self.assertEqual(cdf3.schema, sdf3.schema) + self.assertEqual(cdf3.collect(), sdf3.collect()) + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 74d56491a29e7..4e2e2c51b4db0 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -1204,8 +1204,13 @@ def check_toPandas_error(self, arrow_enabled): self.spark.sql("select 1/0").toPandas() def test_toArrow_error(self): - with self.assertRaises(ArithmeticException): - self.spark.sql("select 1/0").toArrow() + with self.sql_conf( + { + "spark.sql.ansi.enabled": True, + } + ): + with self.assertRaises(ArithmeticException): + self.spark.sql("select 1/0").toArrow() def test_toPandas_duplicate_field_names(self): for arrow_enabled in [True, False]: diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index b381833314861..f363b8748c0b9 100644 --- 
a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -26,13 +26,13 @@ from pyspark import SparkContext, SQLContext from pyspark.sql import Row, SparkSession from pyspark.sql.types import StructType, StringType, StructField -from pyspark.testing.utils import ReusedPySparkTestCase +from pyspark.testing.sqlutils import ReusedSQLTestCase -class HiveContextSQLTests(ReusedPySparkTestCase): +class HiveContextSQLTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() + ReusedSQLTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) cls.hive_available = True cls.spark = None @@ -58,7 +58,7 @@ def setUp(self): @classmethod def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() + ReusedSQLTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) if cls.spark is not None: cls.spark.stop() @@ -100,23 +100,20 @@ def test_save_and_load_table(self): self.spark.sql("DROP TABLE savedJsonTable") self.spark.sql("DROP TABLE externalJsonTable") - defaultDataSourceName = self.spark.conf.get( - "spark.sql.sources.default", "org.apache.spark.sql.parquet" - ) - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") - actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath) - self.assertEqual( - sorted(df.collect()), sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()) - ) - self.assertEqual( - sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()), - ) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("DROP TABLE savedJsonTable") - self.spark.sql("DROP TABLE externalJsonTable") - self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}): + df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath) + self.assertEqual( + sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()), + ) + self.assertEqual( + sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()), + ) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.spark.sql("DROP TABLE savedJsonTable") + self.spark.sql("DROP TABLE externalJsonTable") shutil.rmtree(tmpPath) diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index e752856d03164..8060a9ae8bc76 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -55,12 +55,9 @@ def test_save_and_load(self): ) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - try: - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}): actual = self.spark.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - finally: - self.spark.sql("RESET spark.sql.sources.default") csvpath = os.path.join(tempfile.mkdtemp(), "data") df.write.option("quote", None).format("csv").save(csvpath) @@ -94,12 +91,9 @@ def test_save_and_load_builder(self): ) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - try: - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + with 
self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}): actual = self.spark.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - finally: - self.spark.sql("RESET spark.sql.sources.default") finally: shutil.rmtree(tmpPath) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 80f2c0fcbc033..1882c1fd1f6ad 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -491,14 +491,11 @@ class User: self.assertEqual(asdict(user), r.asDict()) def test_negative_decimal(self): - try: - self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true") + with self.sql_conf({"spark.sql.legacy.allowNegativeScaleOfDecimal": True}): df = self.spark.createDataFrame([(1,), (11,)], ["value"]) ret = df.select(F.col("value").cast(DecimalType(1, -1))).collect() actual = list(map(lambda r: int(r.value), ret)) self.assertEqual(actual, [0, 10]) - finally: - self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false") def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index a0fdada72972b..9f07c44c084cf 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -247,6 +247,29 @@ def function(self, *functions): for f in functions: self.spark.sql("DROP FUNCTION IF EXISTS %s" % f) + @contextmanager + def temp_env(self, pairs): + assert isinstance(pairs, dict), "pairs should be a dictionary." + + keys = pairs.keys() + new_values = pairs.values() + old_values = [os.environ.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + if new_value is None: + if key in os.environ: + del os.environ[key] + else: + os.environ[key] = new_value + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + if key in os.environ: + del os.environ[key] + else: + os.environ[key] = old_value + @staticmethod def assert_close(a, b): c = [j[0] for j in b] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1c2baa78be1b9..f4408220ac939 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -143,50 +143,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass, missingCol, orderedCandidates, a.origin) } - private def checkUnreferencedCTERelations( - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], - visited: mutable.Map[Long, Boolean], - danglingCTERelations: mutable.ArrayBuffer[CTERelationDef], - cteId: Long): Unit = { - if (visited(cteId)) { - return - } - val (cteDef, _, refMap) = cteMap(cteId) - refMap.foreach { case (id, _) => - checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, id) - } - danglingCTERelations.append(cteDef) - visited(cteId) = true - } - def checkAnalysis(plan: LogicalPlan): Unit = { - val inlineCTE = InlineCTE(alwaysInline = true) - val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] - inlineCTE.buildCTEMap(plan, cteMap) - val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef] - val visited: mutable.Map[Long, Boolean] = 
mutable.Map.empty.withDefaultValue(false) - // If a CTE relation is never used, it will disappear after inline. Here we explicitly collect - // these dangling CTE relations, and put them back in the main query, to make sure the entire - // query plan is valid. - cteMap.foreach { case (cteId, (_, refCount, _)) => - // If a CTE relation ref count is 0, the other CTE relations that reference it should also be - // collected. This code will also guarantee the leaf relations that do not reference - // any others are collected first. - if (refCount == 0) { - checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, cteId) - } - } - // Inline all CTEs in the plan to help check query plan structures in subqueries. - var inlinedPlan: LogicalPlan = plan - try { - inlinedPlan = inlineCTE(plan) + // We should inline all CTE relations to restore the original plan shape, as the analysis check + // may need to match certain plan shapes. For dangling CTE relations, they will still be kept + // in the original `WithCTE` node, as we need to perform analysis check for them as well. + val inlineCTE = InlineCTE(alwaysInline = true, keepDanglingRelations = true) + val inlinedPlan: LogicalPlan = try { + inlineCTE(plan) } catch { case e: AnalysisException => throw new ExtendedAnalysisException(e, plan) } - if (danglingCTERelations.nonEmpty) { - inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq) - } try { checkAnalysis0(inlinedPlan) } catch { @@ -1404,6 +1371,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB aggregated, canContainOuter && SQLConf.get.getConf(SQLConf.DECORRELATE_OFFSET_ENABLED)) + // We always inline CTE relations before analysis check, and only un-referenced CTE + // relations will be kept in the plan. Here we should simply skip them and check the + // children, as un-referenced CTE relations won't be executed anyway and doesn't need to + // be restricted by the current subquery correlation limitations. 
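An aside before the new WithCTE / CTERelationDef branch that follows: the reason un-referenced CTE relations are kept in the plan at all is that they still need an analysis check even though they are never executed. The sketch below is a hedged, standalone illustration, not part of this patch; the object name is made up, it assumes a local SparkSession, and the exact error class and message wording may differ between Spark versions.

import org.apache.spark.sql.{AnalysisException, SparkSession}

object DanglingCteCheckSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("dangling-cte-sketch").getOrCreate()
    try {
      // The CTE below is never referenced by the main query, so inlining would drop it.
      // Keeping it as a dangling relation means the analysis check still visits its body
      // and reports the unresolved column instead of silently ignoring it.
      spark.sql("WITH unreferenced AS (SELECT non_existing_col) SELECT 1").collect()
      println("unexpected: analysis passed")
    } catch {
      case e: AnalysisException =>
        println(s"analysis failed as expected: ${e.getMessage}")
    } finally {
      spark.stop()
    }
  }
}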
+ case _: WithCTE | _: CTERelationDef => + plan.children.foreach(p => checkPlan(p, aggregated, canContainOuter)) + // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only // under a correlation point but they and their descendant operators diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a52feaa41acf9..588752f3fc176 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -222,7 +222,7 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging builder: FunctionBuilder): Unit = { val newFunction = (info, builder) functionBuilders.put(name, newFunction) match { - case previousFunction if previousFunction != newFunction => + case previousFunction if previousFunction != null => logWarning(log"The function ${MDC(FUNCTION_NAME, name)} replaced a " + log"previously registered function.") case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 8d7ff4cbf163d..50828b945bb40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -37,23 +37,19 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION} * query level. * * @param alwaysInline if true, inline all CTEs in the query plan. + * @param keepDanglingRelations if true, dangling CTE relations will be kept in the original + * `WithCTE` node. */ -case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { +case class InlineCTE( + alwaysInline: Boolean = false, + keepDanglingRelations: Boolean = false) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) { - val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] + val cteMap = mutable.SortedMap.empty[Long, CTEReferenceInfo] buildCTEMap(plan, cteMap) cleanCTEMap(cteMap) - val notInlined = mutable.ArrayBuffer.empty[CTERelationDef] - val inlined = inlineCTE(plan, cteMap, notInlined) - // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add - // WithCTE as top node here. - if (notInlined.isEmpty) { - inlined - } else { - WithCTE(inlined, notInlined.toSeq) - } + inlineCTE(plan, cteMap) } else { plan } @@ -74,22 +70,23 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { * * @param plan The plan to collect the CTEs from * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE - * ids. The value of the map is tuple whose elements are: - * - The CTE definition - * - The number of incoming references to the CTE. This includes references from - * other CTEs and regular places. - * - A mutable inner map that tracks outgoing references (counts) to other CTEs. + * ids. * @param outerCTEId While collecting the map we use this optional CTE id to identify the * current outer CTE. 
*/ - def buildCTEMap( + private def buildCTEMap( plan: LogicalPlan, - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], + cteMap: mutable.Map[Long, CTEReferenceInfo], outerCTEId: Option[Long] = None): Unit = { plan match { case WithCTE(child, cteDefs) => cteDefs.foreach { cteDef => - cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0)) + cteMap(cteDef.id) = CTEReferenceInfo( + cteDef = cteDef, + refCount = 0, + outgoingRefs = mutable.Map.empty.withDefaultValue(0), + shouldInline = true + ) } cteDefs.foreach { cteDef => buildCTEMap(cteDef, cteMap, Some(cteDef.id)) @@ -97,11 +94,9 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { buildCTEMap(child, cteMap, outerCTEId) case ref: CTERelationRef => - val (cteDef, refCount, refMap) = cteMap(ref.cteId) - cteMap(ref.cteId) = (cteDef, refCount + 1, refMap) + cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1) outerCTEId.foreach { cteId => - val (_, _, outerRefMap) = cteMap(cteId) - outerRefMap(ref.cteId) += 1 + cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1) } case _ => @@ -129,15 +124,12 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE * ids. Needs to be sorted to speed up cleaning. */ - private def cleanCTEMap( - cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] - ) = { + private def cleanCTEMap(cteMap: mutable.SortedMap[Long, CTEReferenceInfo]): Unit = { cteMap.keys.toSeq.reverse.foreach { currentCTEId => - val (_, currentRefCount, refMap) = cteMap(currentCTEId) - if (currentRefCount == 0) { - refMap.foreach { case (referencedCTEId, uselessRefCount) => - val (cteDef, refCount, refMap) = cteMap(referencedCTEId) - cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap) + val refInfo = cteMap(currentCTEId) + if (refInfo.refCount == 0) { + refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) => + cteMap(referencedCTEId) = cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount) } } } @@ -145,30 +137,45 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { private def inlineCTE( plan: LogicalPlan, - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], - notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = { + cteMap: mutable.Map[Long, CTEReferenceInfo]): LogicalPlan = { plan match { case WithCTE(child, cteDefs) => - cteDefs.foreach { cteDef => - val (cte, refCount, refMap) = cteMap(cteDef.id) - if (refCount > 0) { - val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined)) - cteMap(cteDef.id) = (inlined, refCount, refMap) - if (!shouldInline(inlined, refCount)) { - notInlined.append(inlined) - } + val remainingDefs = cteDefs.filter { cteDef => + val refInfo = cteMap(cteDef.id) + if (refInfo.refCount > 0) { + val newDef = refInfo.cteDef.copy(child = inlineCTE(refInfo.cteDef.child, cteMap)) + val inlineDecision = shouldInline(newDef, refInfo.refCount) + cteMap(cteDef.id) = cteMap(cteDef.id).copy( + cteDef = newDef, shouldInline = inlineDecision + ) + // Retain the not-inlined CTE relations in place. 
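An aside on the bookkeeping being introduced here: CTEReferenceInfo replaces the old (CTERelationDef, Int, mutable.Map[Long, Int]) tuples, and the cleaning step is easiest to see in isolation. The sketch below is hedged and self-contained; RefInfo and cleanRefCounts are simplified stand-ins invented for illustration, not the real Catalyst types, and it only shows how references coming from unreferenced CTE relations are discounted before inlining decisions are made.

import scala.collection.mutable

object CteRefCountSketch {
  // Simplified stand-in for the reference info tracked per CTE id.
  final case class RefInfo(refCount: Int, outgoingRefs: mutable.Map[Long, Int])

  // Walk ids from highest to lowest, mirroring the rule above; with ids growing in
  // definition order (an assumption of this sketch) a single pass handles chains of
  // unreferenced CTE relations.
  def cleanRefCounts(cteMap: mutable.SortedMap[Long, RefInfo]): Unit = {
    cteMap.keys.toSeq.reverse.foreach { id =>
      val info = cteMap(id)
      if (info.refCount == 0) {
        info.outgoingRefs.foreach { case (refId, n) =>
          val target = cteMap(refId)
          cteMap(refId) = target.copy(refCount = target.refCount - n)
        }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // cte2 references cte1 once, but nothing references cte2: after cleaning, cte1's
    // effective ref count drops to 0 as well, so neither relation needs to be inlined.
    val cteMap = mutable.SortedMap(
      1L -> RefInfo(refCount = 1, outgoingRefs = mutable.Map.empty[Long, Int]),
      2L -> RefInfo(refCount = 0, outgoingRefs = mutable.Map(1L -> 1))
    )
    cleanRefCounts(cteMap)
    assert(cteMap(1L).refCount == 0)
    assert(cteMap(2L).refCount == 0)
    println(cteMap)
  }
}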
+ !inlineDecision + } else { + keepDanglingRelations } } - inlineCTE(child, cteMap, notInlined) + val inlined = inlineCTE(child, cteMap) + if (remainingDefs.isEmpty) { + inlined + } else { + WithCTE(inlined, remainingDefs) + } case ref: CTERelationRef => - val (cteDef, refCount, _) = cteMap(ref.cteId) - if (shouldInline(cteDef, refCount)) { - if (ref.outputSet == cteDef.outputSet) { - cteDef.child + val refInfo = cteMap(ref.cteId) + if (refInfo.shouldInline) { + if (ref.outputSet == refInfo.cteDef.outputSet) { + refInfo.cteDef.child } else { val ctePlan = DeduplicateRelations( - Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, None))).children(1) + Join( + refInfo.cteDef.child, + refInfo.cteDef.child, + Inner, + None, + JoinHint(None, None) + ) + ).children(1) val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) => if (srcAttr.semanticEquals(tgtAttr)) { tgtAttr @@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { case _ if plan.containsPattern(CTE) => plan - .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined))) + .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap))) .transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) { case e: SubqueryExpression => - e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined)) + e.withNewPlan(inlineCTE(e.plan, cteMap)) } case _ => plan } } } + +/** + * The bookkeeping information for tracking CTE relation references. + * + * @param cteDef The CTE relation definition + * @param refCount The number of incoming references to this CTE relation. This includes references + * from other CTE relations and regular places. + * @param outgoingRefs A mutable map that tracks outgoing reference counts to other CTE relations. + * @param shouldInline If true, this CTE relation should be inlined in the places that reference it. 
+ */ +case class CTEReferenceInfo( + cteDef: CTERelationDef, + refCount: Int, + outgoingRefs: mutable.Map[Long, Int], + shouldInline: Boolean) { + + def withRefCountIncreased(count: Int): CTEReferenceInfo = { + copy(refCount = refCount + count) + } + + def withRefCountDecreased(count: Int): CTEReferenceInfo = { + copy(refCount = refCount - count) + } + + def increaseOutgoingRefCount(cteDefId: Long, count: Int): Unit = { + outgoingRefs(cteDefId) += count + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 86490a2eea970..0e0946668197a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale import java.util.concurrent.TimeUnit -import scala.collection.immutable.Seq import scala.collection.mutable.{ArrayBuffer, Set} import scala.jdk.CollectionConverters._ import scala.util.{Left, Right} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 9242a06cf1d6e..0135fcfb3cc8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -911,6 +911,10 @@ case class WithCTE(plan: LogicalPlan, cteDefs: Seq[CTERelationDef]) extends Logi def withNewPlan(newPlan: LogicalPlan): WithCTE = { withNewChildren(children.init :+ newPlan).asInstanceOf[WithCTE] } + + override def maxRows: Option[Long] = plan.maxRows + + override def maxRowsPerPartition: Option[Long] = plan.maxRowsPerPartition } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 88c2228e640c4..f4751f2027894 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2301,7 +2301,9 @@ object SQLConf { buildConf("spark.sql.streaming.stateStore.skipNullsForStreamStreamJoins.enabled") .internal() .doc("When true, this config will skip null values in hash based stream-stream joins. " + - "The number of skipped null values will be shown as custom metric of stream join operator.") + "The number of skipped null values will be shown as custom metric of stream join operator. 
" + + "If the streaming query was started with Spark 3.5 or above, please exercise caution " + + "before enabling this config since it may hide potential data loss/corruption issues.") .version("3.3.0") .booleanConf .createWithDefault(false) @@ -4614,6 +4616,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val LEGACY_NO_CHAR_PADDING_IN_PREDICATE = buildConf("spark.sql.legacy.noCharPaddingInPredicate") + .internal() + .doc("When true, Spark will not apply char type padding for CHAR type columns in string " + + s"comparison predicates, when '${READ_SIDE_CHAR_PADDING.key}' is false.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CLI_PRINT_HEADER = buildConf("spark.sql.cli.print.header") .doc("When set to true, spark-sql CLI prints the names of the columns in query output.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala new file mode 100644 index 0000000000000..9d775a5335c67 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.TestRelation +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CTERelationDef, CTERelationRef, LogicalPlan, OneRowRelation, WithCTE} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class InlineCTESuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("inline CTE", FixedPoint(100), InlineCTE()) :: Nil + } + + test("SPARK-48307: not-inlined CTE relation in command") { + val cteDef = CTERelationDef(OneRowRelation().select(rand(0).as("a"))) + val cteRef = CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming) + val plan = AppendData.byName( + TestRelation(Seq($"a".double)), + WithCTE(cteRef.except(cteRef, isAll = true), Seq(cteDef)) + ).analyze + comparePlans(Optimize.execute(plan), plan) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala index b5bf337a5a2e6..1b7b0d702ab98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{CharType, Metadata, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -66,9 +67,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols) }) } - paddingForStringComparison(newPlan) + paddingForStringComparison(newPlan, padCharCol = false) } else { - paddingForStringComparison(plan) + paddingForStringComparison( + plan, padCharCol = !conf.getConf(SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE)) } } @@ -90,7 +92,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def paddingForStringComparison(plan: LogicalPlan): LogicalPlan = { + private def paddingForStringComparison(plan: LogicalPlan, padCharCol: Boolean): LogicalPlan = { plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) { case operator => operator.transformExpressionsUpWithPruning( _.containsAnyPattern(BINARY_COMPARISON, IN)) { @@ -99,12 +101,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { // String literal is treated as char type when it's compared to a char type column. // We should pad the shorter one to the longer length. 
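An aside on the padding rule mentioned in the comment above, which the padCharCol flag below extends: once both sides of the comparison are padded to the longer of the two lengths, trailing spaces stop mattering. The following is a hedged, standalone model with assumed example values, using plain strings rather than Catalyst expressions; the helper names are invented for illustration.

object CharPaddingRuleSketch {
  // Right-pad with spaces up to `len`, the role StringRPad plays in these comparisons.
  private def rpad(s: String, len: Int): String =
    if (s.length >= len) s else s + " " * (len - s.length)

  // Model of comparing a CHAR(charLength) column value against a string literal once both
  // sides are padded to the longer of the two lengths (the padCharCol = true behaviour).
  def charEqualsLiteral(charLength: Int, storedValue: String, literal: String): Boolean = {
    val target = math.max(charLength, literal.length)
    rpad(storedValue, target) == rpad(literal, target)
  }

  def main(args: Array[String]): Unit = {
    // With read-side padding disabled, the stored value of a CHAR(3) column may be "12",
    // but after predicate-side padding it still matches both "12" and "12 ".
    assert(charEqualsLiteral(3, "12", "12"))
    assert(charEqualsLiteral(3, "12", "12 "))
    assert(!charEqualsLiteral(3, "12", "123"))
    println("padding rule sketch ok")
  }
}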
case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable => - padAttrLitCmp(e, attr.metadata, lit).map { newChildren => + padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren => b.withNewChildren(newChildren) }.getOrElse(b) case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable => - padAttrLitCmp(e, attr.metadata, lit).map { newChildren => + padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren => b.withNewChildren(newChildren.reverse) }.getOrElse(b) @@ -117,9 +119,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { val literalCharLengths = literalChars.map(_.numChars()) val targetLen = (length +: literalCharLengths).max Some(i.copy( - value = addPadding(e, length, targetLen), + value = addPadding(e, length, targetLen, alwaysPad = padCharCol), list = list.zip(literalCharLengths).map { - case (lit, charLength) => addPadding(lit, charLength, targetLen) + case (lit, charLength) => + addPadding(lit, charLength, targetLen, alwaysPad = false) } ++ nulls.map(Literal.create(_, StringType)))) case _ => None }.getOrElse(i) @@ -162,6 +165,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { private def padAttrLitCmp( expr: Expression, metadata: Metadata, + padCharCol: Boolean, lit: Expression): Option[Seq[Expression]] = { if (expr.dataType == StringType) { CharVarcharUtils.getRawType(metadata).flatMap { @@ -174,7 +178,14 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { if (length < stringLitLen) { Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit)) } else if (length > stringLitLen) { - Some(Seq(expr, StringRPad(lit, Literal(length)))) + val paddedExpr = if (padCharCol) { + StringRPad(expr, Literal(length)) + } else { + expr + } + Some(Seq(paddedExpr, StringRPad(lit, Literal(length)))) + } else if (padCharCol) { + Some(Seq(StringRPad(expr, Literal(length)), lit)) } else { None } @@ -186,7 +197,15 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { - if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr + private def addPadding( + expr: Expression, + charLength: Int, + targetLength: Int, + alwaysPad: Boolean): Expression = { + if (targetLength > charLength) { + StringRPad(expr, Literal(targetLength)) + } else if (alwaysPad) { + StringRPad(expr, Literal(charLength)) + } else expr } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-legacy.sql.out index 594a30b054edd..f9b78e94236fb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-legacy.sql.out @@ -43,6 +43,30 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x] +- OneRowRelation +-- !query +SELECT ( + WITH unreferenced AS (SELECT id) + SELECT 1 +) FROM range(1) +-- !query analysis +Project [scalar-subquery#x [] AS scalarsubquery()#x] +: +- Project [1 AS 1#x] +: +- OneRowRelation ++- Range (0, 1, step=1) + + +-- !query +SELECT ( + WITH unreferenced AS (SELECT 1) + SELECT id +) FROM range(1) +-- !query analysis +Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#xL] +: +- Project [outer(id#xL)] +: +- OneRowRelation ++- Range (0, 1, step=1) + + -- !query SELECT * FROM ( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out index f1a302b06f2a8..3a9fc5ea1297f 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out @@ -58,6 +58,40 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x] +- OneRowRelation +-- !query +SELECT ( + WITH unreferenced AS (SELECT id) + SELECT 1 +) FROM range(1) +-- !query analysis +Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#x] +: +- WithCTE +: :- CTERelationDef xxxx, false +: : +- SubqueryAlias unreferenced +: : +- Project [outer(id#xL)] +: : +- OneRowRelation +: +- Project [1 AS 1#x] +: +- OneRowRelation ++- Range (0, 1, step=1) + + +-- !query +SELECT ( + WITH unreferenced AS (SELECT 1) + SELECT id +) FROM range(1) +-- !query analysis +Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#xL] +: +- WithCTE +: :- CTERelationDef xxxx, false +: : +- SubqueryAlias unreferenced +: : +- Project [1 AS 1#x] +: : +- OneRowRelation +: +- Project [outer(id#xL)] +: +- OneRowRelation ++- Range (0, 1, step=1) + + -- !query SELECT * FROM ( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out index 6e55c6fa83cd9..e8640c3cbb6bd 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out @@ -58,6 +58,40 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x] +- OneRowRelation +-- !query +SELECT ( + WITH unreferenced AS (SELECT id) + SELECT 1 +) FROM range(1) +-- !query analysis +Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#x] +: +- WithCTE +: :- CTERelationDef xxxx, false +: : +- SubqueryAlias unreferenced +: : +- Project [outer(id#xL)] +: : +- OneRowRelation +: +- Project [1 AS 1#x] +: +- OneRowRelation ++- Range (0, 1, step=1) + + +-- !query +SELECT ( + WITH unreferenced AS (SELECT 1) + SELECT id +) FROM range(1) +-- !query analysis +Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#xL] +: +- WithCTE +: :- CTERelationDef xxxx, false +: : +- SubqueryAlias unreferenced +: : +- Project [1 AS 1#x] +: : +- OneRowRelation +: +- Project [outer(id#xL)] +: +- OneRowRelation ++- Range (0, 1, step=1) + + -- !query SELECT * FROM ( diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql index e5ef244341751..3b2ba1fcdd66e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql @@ -17,6 +17,18 @@ SELECT ( SELECT * FROM t ); +-- un-referenced CTE in subquery expression: outer reference in CTE relation +SELECT ( + WITH unreferenced AS (SELECT id) + SELECT 1 +) FROM range(1); + +-- un-referenced CTE in subquery expression: outer reference in CTE main query +SELECT ( + WITH unreferenced AS (SELECT 1) + SELECT id +) FROM range(1); + -- Make sure CTE in subquery is scoped to that subquery rather than global -- the 2nd half of the union should fail because the cte is scoped to the first half SELECT * FROM diff --git a/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out index b79d8b1afb0d4..1255e8b51f301 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out @@ -33,6 +33,28 @@ struct 1 +-- !query +SELECT ( + WITH unreferenced AS (SELECT id) + SELECT 1 +) FROM range(1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT ( + WITH unreferenced AS (SELECT 1) + SELECT id +) FROM range(1) +-- !query schema +struct +-- !query output +0 + + -- !query SELECT * FROM ( diff --git a/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out index a93bcb7593768..7cf488ce8cad4 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out @@ -33,6 +33,28 @@ struct 1 +-- !query +SELECT ( + WITH unreferenced AS (SELECT id) + SELECT 1 +) FROM range(1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT ( + WITH unreferenced AS (SELECT 1) + SELECT id +) FROM range(1) +-- !query schema +struct +-- !query output +0 + + -- !query SELECT * FROM ( diff --git a/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out index ba311c0253ab1..94ef47397eff1 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out @@ -33,6 +33,28 @@ struct 1 +-- !query +SELECT ( + WITH unreferenced AS (SELECT id) + SELECT 1 +) FROM range(1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT ( + WITH unreferenced AS (SELECT 1) + SELECT id +) FROM range(1) +-- !query schema +struct +-- !query output +0 + + -- !query SELECT * FROM ( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 013177425da78..a93dee3bf2a61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -942,6 +942,34 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa } } } + + test("SPARK-48498: always do char padding in predicates") { + import testImplicits._ + withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false") { + withTempPath { dir => + withTable("t") { + Seq( + "12" -> "12", + "12" -> "12 ", + "12 " -> "12", + "12 " -> "12 " + ).toDF("c1", "c2").write.format(format).save(dir.toString) + sql(s"CREATE TABLE t (c1 CHAR(3), c2 STRING) USING $format LOCATION '$dir'") + // Comparing CHAR column with STRING column directly compares the stored value. + checkAnswer( + sql("SELECT c1 = c2 FROM t"), + Seq(Row(true), Row(false), Row(false), Row(true)) + ) + // No matter the CHAR type value is padded or not in the storage, we should always pad it + // before comparison with STRING literals. 
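An aside showing the same behaviour, plus the spark.sql.legacy.noCharPaddingInPredicate escape hatch, outside the test harness. This is a hedged sketch: the object name, table name and temporary Parquet location are assumptions, it expects a local SparkSession, and it reuses this test's trick of writing the unpadded value as a plain string before overlaying a CHAR(3) schema.

import org.apache.spark.sql.SparkSession

object NoCharPaddingFlagDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("char-padding-demo").getOrCreate()
    import spark.implicits._

    // Mirror the test setup: store an unpadded value by writing plain strings, then put a
    // CHAR(3) schema on top of the files with read-side padding disabled.
    spark.conf.set("spark.sql.readSideCharPadding", "false")
    val dir = java.nio.file.Files.createTempDirectory("char_demo").toFile.getAbsolutePath
    Seq("12").toDF("c1").write.mode("overwrite").parquet(dir)
    spark.sql(s"CREATE TABLE char_demo (c1 CHAR(3)) USING parquet LOCATION '$dir'")

    // Default behaviour after this change: the CHAR(3) column is padded in the predicate,
    // so the unpadded stored value still matches a padded literal.
    spark.sql("SELECT c1 = '12 ' FROM char_demo").show() // expected: true

    // Legacy flag introduced above: skip predicate-side padding and compare raw values.
    spark.conf.set("spark.sql.legacy.noCharPaddingInPredicate", "true")
    spark.sql("SELECT c1 = '12 ' FROM char_demo").show() // expected: false

    spark.stop()
  }
}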
+ checkAnswer( + sql("SELECT c1 = '12', c1 = '12 ', c1 = '12 ' FROM t WHERE c2 = '12'"), + Seq(Row(true, true, true), Row(true, true, true)) + ) + } + } + } + } } class DSV2CharVarcharTestSuite extends CharVarcharTestSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index 34c6c49bc4981..ad424b3a7cc76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -256,9 +256,11 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { protected def testQuery(tpcdsGroup: String, query: String, suffix: String = ""): Unit = { val queryString = resourceToString(s"$tpcdsGroup/$query.sql", classLoader = Thread.currentThread().getContextClassLoader) - // Disable char/varchar read-side handling for better performance. - withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { + withSQLConf( + // Disable char/varchar read-side handling for better performance. + SQLConf.READ_SIDE_CHAR_PADDING.key -> "false", + SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { val qe = sql(queryString).queryExecution val plan = qe.executedPlan val explain = normalizeLocation(normalizeIds(qe.explainString(FormattedMode))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 2bb2fe970a118..11e077e891bd7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -1340,6 +1340,15 @@ private[hive] object HiveClientImpl extends Logging { log"will be reset to 'mr' to disable useless hive logic") hiveConf.set("hive.execution.engine", "mr", SOURCE_SPARK) } + val cpType = hiveConf.get("datanucleus.connectionPoolingType") + // Bonecp might cause memory leak, it could affect some hive client versions we support + // See more details in HIVE-15551 + // Also, Bonecp is removed in Hive 4.0.0, see HIVE-23258 + // Here we use DBCP to replace bonecp instead of HikariCP as HikariCP was introduced in + // Hive 2.2.0 (see HIVE-13931) while the minium Hive we support is 2.0.0. + if ("bonecp".equalsIgnoreCase(cpType)) { + hiveConf.set("datanucleus.connectionPoolingType", "DBCP", SOURCE_SPARK) + } hiveConf }
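An aside on the connection-pool override above: the guard is a plain case-insensitive string substitution applied while building the HiveConf. Below is a hedged, standalone model that uses java.util.Properties as a stand-in for HiveConf (the real code calls hiveConf.set with a Spark source tag); the object and method names are invented for illustration.

import java.util.Properties

object ConnectionPoolOverrideSketch {
  // Replace BoneCP with DBCP, mirroring the guard above: BoneCP can leak memory
  // (HIVE-15551) and was removed in Hive 4.0.0 (HIVE-23258), while DBCP is available in
  // every Hive version Spark still supports, unlike HikariCP (added in HIVE-13931).
  def replaceBoneCP(conf: Properties): Unit = {
    val cpType = conf.getProperty("datanucleus.connectionPoolingType")
    if ("bonecp".equalsIgnoreCase(cpType)) {
      conf.setProperty("datanucleus.connectionPoolingType", "DBCP")
    }
  }

  def main(args: Array[String]): Unit = {
    val bonecp = new Properties()
    bonecp.setProperty("datanucleus.connectionPoolingType", "BoneCP")
    replaceBoneCP(bonecp)
    assert(bonecp.getProperty("datanucleus.connectionPoolingType") == "DBCP")

    // Anything other than BoneCP (including unset) is left untouched.
    val hikari = new Properties()
    hikari.setProperty("datanucleus.connectionPoolingType", "HikariCP")
    replaceBoneCP(hikari)
    assert(hikari.getProperty("datanucleus.connectionPoolingType") == "HikariCP")
    println("connection pool override sketch ok")
  }
}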