diff --git a/LICENSE-binary b/LICENSE-binary
index 456b074842575..b6971798e5577 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -218,7 +218,6 @@ com.google.crypto.tink:tink
com.google.flatbuffers:flatbuffers-java
com.google.guava:guava
com.jamesmurty.utils:java-xmlbuilder
-com.jolbox:bonecp
com.ning:compress-lzf
com.squareup.okhttp3:logging-interceptor
com.squareup.okhttp3:okhttp
diff --git a/NOTICE-binary b/NOTICE-binary
index c82d0b52f31cc..c4cfe0e9f8b31 100644
--- a/NOTICE-binary
+++ b/NOTICE-binary
@@ -33,11 +33,12 @@ services.
// Version 2.0, in this case for
// ------------------------------------------------------------------
-Hive Beeline
-Copyright 2016 The Apache Software Foundation
+=== NOTICE FOR com.clearspring.analytics:streams ===
+stream-api
+Copyright 2016 AddThis
-This product includes software developed at
-The Apache Software Foundation (http://www.apache.org/).
+This product includes software developed by AddThis.
+=== END OF NOTICE FOR com.clearspring.analytics:streams ===
Apache Avro
Copyright 2009-2014 The Apache Software Foundation
diff --git a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala
index 08291859a32cc..ae00987cd69f6 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala
@@ -462,14 +462,13 @@ private[spark] object MavenUtils extends Logging {
val sysOut = System.out
// Default configuration name for ivy
val ivyConfName = "default"
-
- // A Module descriptor must be specified. Entries are dummy strings
- val md = getModuleDescriptor
-
- md.setDefaultConf(ivyConfName)
+ var md: DefaultModuleDescriptor = null
try {
// To prevent ivy from logging to system out
System.setOut(printStream)
+ // A Module descriptor must be specified. Entries are dummy strings
+ md = getModuleDescriptor
+ md.setDefaultConf(ivyConfName)
val artifacts = extractMavenCoordinates(coordinates)
// Directories for caching downloads through ivy and storing the jars when maven coordinates
// are supplied to spark-submit
@@ -548,7 +547,9 @@ private[spark] object MavenUtils extends Logging {
}
} finally {
System.setOut(sysOut)
- clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings.getDefaultCache, ivyConfName)
+ if (md != null) {
+ clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings.getDefaultCache, ivyConfName)
+ }
}
}
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
index 8110bd9f46a86..203b1295005af 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -108,6 +108,35 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}
+ test("interrupt all - streaming queries") {
+ val q1 = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", 1)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+
+ val q2 = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", 1)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+
+ assert(q1.isActive)
+ assert(q2.isActive)
+
+ val interrupted = spark.interruptAll()
+
+ q1.awaitTermination(timeoutMs = 20 * 1000)
+ q2.awaitTermination(timeoutMs = 20 * 1000)
+ assert(!q1.isActive)
+ assert(!q2.isActive)
+ assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
+ }
+
// TODO(SPARK-48139): Re-enable `SparkSessionE2ESuite.interrupt tag`
ignore("interrupt tag") {
val session = spark
@@ -230,6 +259,53 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}
+ test("interrupt tag - streaming query") {
+ spark.addTag("foo")
+ val q1 = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", 1)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+ assert(spark.getTags() == Set("foo"))
+
+ spark.addTag("bar")
+ val q2 = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", 1)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+ assert(spark.getTags() == Set("foo", "bar"))
+
+ spark.clearTags()
+
+ spark.addTag("zoo")
+ val q3 = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", 1)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+ assert(spark.getTags() == Set("zoo"))
+
+ assert(q1.isActive)
+ assert(q2.isActive)
+ assert(q3.isActive)
+
+ val interrupted = spark.interruptTag("foo")
+
+ q1.awaitTermination(timeoutMs = 20 * 1000)
+ q2.awaitTermination(timeoutMs = 20 * 1000)
+ assert(!q1.isActive)
+ assert(!q2.isActive)
+ assert(q3.isActive)
+ assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
+ }
+
test("progress is available for the spark result") {
val result = spark
.range(10000)
diff --git a/connector/connect/docs/client-connection-string.md b/connector/connect/docs/client-connection-string.md
index ebab7cbff4fc1..37b2956a5c44a 100644
--- a/connector/connect/docs/client-connection-string.md
+++ b/connector/connect/docs/client-connection-string.md
@@ -22,8 +22,8 @@ cannot contain arbitrary characters. Configuration parameter are passed in the
style of the HTTP URL Path Parameter Syntax. This is similar to the JDBC connection
strings. The path component must be empty. All parameters are interpreted **case sensitive**.
-```shell
-sc://hostname:port/;param1=value;param2=value
+```text
+sc://host:port/;param1=value;param2=value
```
@@ -34,7 +34,7 @@ sc://hostname:port/;param1=value;param2=value
Examples |
- hostname |
+ host |
String |
The hostname of the endpoint for Spark Connect. Since the endpoint
@@ -49,8 +49,8 @@ sc://hostname:port/;param1=value;param2=value
|
port |
-Numeric |
- The portname to be used when connecting to the GRPC endpoint. The
+ | Numeric |
+ The port to be used when connecting to the GRPC endpoint. The
default values is: 15002. Any valid port number can be used. |
15002 443 |
@@ -75,7 +75,7 @@ sc://hostname:port/;param1=value;param2=value
user_id |
String |
User ID to automatically set in the Spark Connect UserContext message.
- This is necssary for the appropriate Spark Session management. This is an
+ This is necessary for the appropriate Spark Session management. This is an
*optional* parameter and depending on the deployment this parameter might
be automatically injected using other means. |
@@ -99,9 +99,16 @@ sc://hostname:port/;param1=value;param2=value
allows to provide this session ID to allow sharing Spark Sessions for the same users
for example across multiple languages. The value must be provided in a valid UUID
string format.
- Default: A UUID generated randomly. |
+ Default: A UUID generated randomly
session_id=550e8400-e29b-41d4-a716-446655440000 |
+
+ grpc_max_message_size |
+ Numeric |
+ Maximum message size allowed for gRPC messages in bytes.
+ Default: 128 * 1024 * 1024 |
+ grpc_max_message_size=134217728 |
+
## Examples
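For illustration, a minimal Scala sketch of the new `grpc_max_message_size` parameter in a connection string (the endpoint is a placeholder and the builder calls are assumed to be the standard Spark Connect client API):

```scala
import org.apache.spark.sql.SparkSession

// Raise the gRPC message size limit to 128 MiB (the documented default) via the URI.
val spark = SparkSession.builder()
  .remote("sc://localhost:15002/;grpc_max_message_size=134217728")
  .getOrCreate()
```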
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b3ae8c71611f4..453d2b30876f4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -3163,7 +3163,11 @@ class SparkConnectPlanner(
}
// Register the new query so that its reference is cached and is stopped on session timeout.
- SparkConnectService.streamingSessionManager.registerNewStreamingQuery(sessionHolder, query)
+ SparkConnectService.streamingSessionManager.registerNewStreamingQuery(
+ sessionHolder,
+ query,
+ executeHolder.sparkSessionTags,
+ executeHolder.operationId)
// Register the runner with the query if Python foreachBatch is enabled.
foreachBatchRunnerCleaner.foreach { cleaner =>
sessionHolder.streamingForeachBatchRunnerCleanerCache.registerCleanerForQuery(
@@ -3228,7 +3232,9 @@ class SparkConnectPlanner(
// Find the query in connect service level cache, otherwise check session's active streams.
val query = SparkConnectService.streamingSessionManager
- .getCachedQuery(id, runId, session) // Common case: query is cached in the cache.
+      // Common case: the query is already in the cache.
+ .getCachedQuery(id, runId, executeHolder.sparkSessionTags, session)
+ .map(_.query)
.orElse { // Else try to find it in active streams. Mostly will not be found here either.
Option(session.streams.get(id))
} match {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 0285b48405773..681f7e29630ff 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
import scala.jdk.CollectionConverters._
import scala.util.Try
@@ -179,12 +180,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
*/
private[service] def interruptAll(): Seq[String] = {
val interruptedIds = new mutable.ArrayBuffer[String]()
+    val operationIds =
+ SparkConnectService.streamingSessionManager.cleanupRunningQueries(this, blocking = false)
executions.asScala.values.foreach { execute =>
if (execute.interrupt()) {
interruptedIds += execute.operationId
}
}
- interruptedIds.toSeq
+    interruptedIds.toSeq ++ operationIds
}
/**
@@ -194,6 +197,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
*/
private[service] def interruptTag(tag: String): Seq[String] = {
val interruptedIds = new mutable.ArrayBuffer[String]()
+ val queries = SparkConnectService.streamingSessionManager.getTaggedQuery(tag, session)
+ queries.foreach(q => Future(q.query.stop())(ExecutionContext.global))
executions.asScala.values.foreach { execute =>
if (execute.sparkSessionTags.contains(tag)) {
if (execute.interrupt()) {
@@ -201,7 +206,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
}
}
}
- interruptedIds.toSeq
+ interruptedIds.toSeq ++ queries.map(_.operationId)
}
/**
@@ -298,7 +303,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
// Clean up running streaming queries.
// Note: there can be concurrent streaming queries being started.
- SparkConnectService.streamingSessionManager.cleanupRunningQueries(this)
+ SparkConnectService.streamingSessionManager.cleanupRunningQueries(this, blocking = true)
streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers.
removeAllListeners() // removes all listener and stop python listener processes if necessary.
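A client-side sketch of the behavior these changes enable, mirroring the E2E tests added above (the tag name and sinks are placeholders); note that the query's `stop()` is issued asynchronously on `ExecutionContext.global`, so callers should still await termination:

```scala
// A streaming query started under a tag is now interruptible by that tag.
spark.addTag("etl")
val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", 1)
  .load()
  .writeStream
  .format("console")
  .start()

// The returned operation IDs now include the streaming query registered above.
val interruptedIds: Seq[String] = spark.interruptTag("etl")
query.awaitTermination(timeoutMs = 20 * 1000)
```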
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
index 91eae18f2d5da..03719ddd87419 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.{Duration, DurationInt, FiniteDuration}
import scala.util.control.NonFatal
@@ -55,16 +56,28 @@ private[connect] class SparkConnectStreamingQueryCache(
import SparkConnectStreamingQueryCache._
- def registerNewStreamingQuery(sessionHolder: SessionHolder, query: StreamingQuery): Unit = {
- queryCacheLock.synchronized {
+ def registerNewStreamingQuery(
+ sessionHolder: SessionHolder,
+ query: StreamingQuery,
+ tags: Set[String],
+ operationId: String): Unit = queryCacheLock.synchronized {
+ taggedQueriesLock.synchronized {
val value = QueryCacheValue(
userId = sessionHolder.userId,
sessionId = sessionHolder.sessionId,
session = sessionHolder.session,
query = query,
+ operationId = operationId,
expiresAtMs = None)
- queryCache.put(QueryCacheKey(query.id.toString, query.runId.toString), value) match {
+ val queryKey = QueryCacheKey(query.id.toString, query.runId.toString)
+ tags.foreach { tag =>
+ taggedQueries
+ .getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey])
+ .addOne(queryKey)
+ }
+
+ queryCache.put(queryKey, value) match {
case Some(existing) => // Query is being replace. Not really expected.
logWarning(log"Replacing existing query in the cache (unexpected). " +
log"Query Id: ${MDC(QUERY_ID, query.id)}.Existing value ${MDC(OLD_VALUE, existing)}, " +
@@ -80,7 +93,7 @@ private[connect] class SparkConnectStreamingQueryCache(
}
/**
- * Returns [[StreamingQuery]] if it is cached and session matches the cached query. It ensures
+ * Returns [[QueryCacheValue]] if it is cached and session matches the cached query. It ensures
* the session associated with it matches the session passed into the call. If the query is
* inactive (i.e. it has a cache expiry time set), this access extends its expiry time. So if a
* client keeps accessing a query, it stays in the cache.
@@ -88,8 +101,35 @@ private[connect] class SparkConnectStreamingQueryCache(
def getCachedQuery(
queryId: String,
runId: String,
- session: SparkSession): Option[StreamingQuery] = {
- val key = QueryCacheKey(queryId, runId)
+ tags: Set[String],
+ session: SparkSession): Option[QueryCacheValue] = {
+ taggedQueriesLock.synchronized {
+ val key = QueryCacheKey(queryId, runId)
+      val result = getCachedQuery(key, session)
+ tags.foreach { tag =>
+ taggedQueries.getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]).addOne(key)
+ }
+ result
+ }
+ }
+
+ /**
+   * Similar to [[getCachedQuery]], but it returns queries that were previously tagged.
+ */
+ def getTaggedQuery(tag: String, session: SparkSession): Seq[QueryCacheValue] = {
+ taggedQueriesLock.synchronized {
+ taggedQueries
+ .get(tag)
+ .map { k =>
+ k.flatMap(getCachedQuery(_, session)).toSeq
+ }
+ .getOrElse(Seq.empty[QueryCacheValue])
+ }
+ }
+
+ private def getCachedQuery(
+ key: QueryCacheKey,
+ session: SparkSession): Option[QueryCacheValue] = {
queryCacheLock.synchronized {
queryCache.get(key).flatMap { v =>
if (v.session == session) {
@@ -98,7 +138,7 @@ private[connect] class SparkConnectStreamingQueryCache(
val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis
queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs)))
}
- Some(v.query)
+ Some(v)
} else None // Should be rare, may be client is trying access from a different session.
}
}
@@ -109,7 +149,10 @@ private[connect] class SparkConnectStreamingQueryCache(
* the queryCache. This is used when session is expired and we need to cleanup resources of that
* session.
*/
- def cleanupRunningQueries(sessionHolder: SessionHolder): Unit = {
+ def cleanupRunningQueries(
+ sessionHolder: SessionHolder,
+ blocking: Boolean = true): Seq[String] = {
+ val operationIds = new mutable.ArrayBuffer[String]()
for ((k, v) <- queryCache) {
if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) {
if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) {
@@ -117,7 +160,12 @@ private[connect] class SparkConnectStreamingQueryCache(
log"Stopping the query with id ${MDC(QUERY_ID, k.queryId)} " +
log"since the session has timed out")
try {
- v.query.stop()
+ if (blocking) {
+ v.query.stop()
+ } else {
+ Future(v.query.stop())(ExecutionContext.global)
+ }
+ operationIds.addOne(v.operationId)
} catch {
case NonFatal(ex) =>
logWarning(
@@ -128,6 +176,7 @@ private[connect] class SparkConnectStreamingQueryCache(
}
}
}
+ operationIds.toSeq
}
// Visible for testing
@@ -146,6 +195,10 @@ private[connect] class SparkConnectStreamingQueryCache(
private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue]
private val queryCacheLock = new Object
+ @GuardedBy("queryCacheLock")
+ private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]]
+ private val taggedQueriesLock = new Object
+
@GuardedBy("queryCacheLock")
private var scheduledExecutor: Option[ScheduledExecutorService] = None
@@ -176,7 +229,7 @@ private[connect] class SparkConnectStreamingQueryCache(
* - Update status of query if it is inactive. Sets an expiry time for such queries
* - Drop expired queries from the cache.
*/
- private def periodicMaintenance(): Unit = {
+ private def periodicMaintenance(): Unit = taggedQueriesLock.synchronized {
queryCacheLock.synchronized {
val nowMs = clock.getTimeMillis()
@@ -212,6 +265,18 @@ private[connect] class SparkConnectStreamingQueryCache(
}
}
}
+
+ taggedQueries.toArray.foreach { case (key, value) =>
+ value.zipWithIndex.toArray.foreach { case (queryKey, i) =>
+ if (queryCache.contains(queryKey)) {
+ value.remove(i)
+ }
+ }
+
+ if (value.isEmpty) {
+ taggedQueries.remove(key)
+ }
+ }
}
}
}
@@ -225,6 +290,7 @@ private[connect] object SparkConnectStreamingQueryCache {
sessionId: String,
session: SparkSession, // Holds the reference to the session.
query: StreamingQuery, // Holds the reference to the query.
+ operationId: String,
expiresAtMs: Option[Long] = None // Expiry time for a stopped query.
) {
override def toString(): String =
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index af18fca9dd216..71ca0f44af680 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -826,7 +826,9 @@ class SparkConnectServiceSuite
when(restartedQuery.runId).thenReturn(DEFAULT_UUID)
SparkConnectService.streamingSessionManager.registerNewStreamingQuery(
SparkConnectService.getOrCreateIsolatedSession("c1", sessionId, None),
- restartedQuery)
+ restartedQuery,
+ Set.empty[String],
+ "")
f(verifyEvents)
}
}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
index ed3da2c0f7156..512a0a80c4a91 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
@@ -67,7 +67,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug
// Register the query.
- sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery)
+ sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set.empty[String], "")
sessionMgr.getCachedValue(queryId, runId) match {
case Some(v) =>
@@ -78,9 +78,14 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug
}
// Verify query is returned only with the correct session, not with a different session.
- assert(sessionMgr.getCachedQuery(queryId, runId, mock[SparkSession]).isEmpty)
+ assert(
+ sessionMgr.getCachedQuery(queryId, runId, Set.empty[String], mock[SparkSession]).isEmpty)
// Query is returned when correct session is used
- assert(sessionMgr.getCachedQuery(queryId, runId, mockSession).contains(mockQuery))
+ assert(
+ sessionMgr
+ .getCachedQuery(queryId, runId, Set.empty[String], mockSession)
+ .map(_.query)
+ .contains(mockQuery))
// Cleanup the query and verify if stop() method has been called.
when(mockQuery.isActive).thenReturn(false)
@@ -99,7 +104,11 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug
clock.advance(30.seconds.toMillis)
// Access the query. This should advance expiry time by 30 seconds.
- assert(sessionMgr.getCachedQuery(queryId, runId, mockSession).contains(mockQuery))
+ assert(
+ sessionMgr
+ .getCachedQuery(queryId, runId, Set.empty[String], mockSession)
+ .map(_.query)
+ .contains(mockQuery))
val expiresAtMs = sessionMgr.getCachedValue(queryId, runId).get.expiresAtMs.get
assert(expiresAtMs == prevExpiryTimeMs + 30.seconds.toMillis)
@@ -112,7 +121,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug
when(restartedQuery.isActive).thenReturn(true)
when(mockStreamingQueryManager.get(queryId)).thenReturn(restartedQuery)
- sessionMgr.registerNewStreamingQuery(sessionHolder, restartedQuery)
+ sessionMgr.registerNewStreamingQuery(sessionHolder, restartedQuery, Set.empty[String], "")
// Both queries should existing in the cache.
assert(sessionMgr.getCachedValue(queryId, runId).map(_.query).contains(mockQuery))
diff --git a/core/benchmarks/LZFBenchmark-jdk21-results.txt b/core/benchmarks/LZFBenchmark-jdk21-results.txt
new file mode 100644
index 0000000000000..e1566f201a1f6
--- /dev/null
+++ b/core/benchmarks/LZFBenchmark-jdk21-results.txt
@@ -0,0 +1,19 @@
+================================================================================================
+Benchmark LZFCompressionCodec
+================================================================================================
+
+OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1021-azure
+AMD EPYC 7763 64-Core Processor
+Compress small objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+--------------------------------------------------------------------------------------------------------------------------------
+Compression 256000000 int values in parallel 598 600 2 428.2 2.3 1.0X
+Compression 256000000 int values single-threaded 568 570 2 451.0 2.2 1.1X
+
+OpenJDK 64-Bit Server VM 21.0.3+9-LTS on Linux 6.5.0-1021-azure
+AMD EPYC 7763 64-Core Processor
+Compress large objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+-----------------------------------------------------------------------------------------------------------------------------
+Compression 1024 array values in 1 threads 39 45 5 0.0 38475.4 1.0X
+Compression 1024 array values single-threaded 32 33 1 0.0 31154.5 1.2X
+
+
diff --git a/core/benchmarks/LZFBenchmark-results.txt b/core/benchmarks/LZFBenchmark-results.txt
new file mode 100644
index 0000000000000..facc67f9cf4a8
--- /dev/null
+++ b/core/benchmarks/LZFBenchmark-results.txt
@@ -0,0 +1,19 @@
+================================================================================================
+Benchmark LZFCompressionCodec
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1021-azure
+AMD EPYC 7763 64-Core Processor
+Compress small objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+--------------------------------------------------------------------------------------------------------------------------------
+Compression 256000000 int values in parallel 602 612 6 425.1 2.4 1.0X
+Compression 256000000 int values single-threaded 610 617 5 419.8 2.4 1.0X
+
+OpenJDK 64-Bit Server VM 17.0.11+9-LTS on Linux 6.5.0-1021-azure
+AMD EPYC 7763 64-Core Processor
+Compress large objects: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+-----------------------------------------------------------------------------------------------------------------------------
+Compression 1024 array values in 1 threads 35 43 6 0.0 33806.8 1.0X
+Compression 1024 array values single-threaded 32 32 0 0.0 30990.4 1.1X
+
+
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 90d8cef00ef83..6eb2bea40bdb5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -421,7 +421,7 @@ class SparkContext(config: SparkConf) extends Logging {
}
// HADOOP-19097 Set fs.s3a.connection.establish.timeout to 30s
// We can remove this after Apache Hadoop 3.4.1 releases
- conf.setIfMissing("spark.hadoop.fs.s3a.connection.establish.timeout", "30s")
+ conf.setIfMissing("spark.hadoop.fs.s3a.connection.establish.timeout", "30000")
// This should be set as early as possible.
SparkContext.fillMissingMagicCommitterConfsIfNeeded(_conf)
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index dc3edfaa86133..a7268c6409913 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2031,6 +2031,13 @@ package object config {
.intConf
.createWithDefault(1)
+ private[spark] val IO_COMPRESSION_LZF_PARALLEL =
+ ConfigBuilder("spark.io.compression.lzf.parallel.enabled")
+ .doc("When true, LZF compression will use multiple threads to compress data in parallel.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
private[spark] val IO_WARNING_LARGEFILETHRESHOLD =
ConfigBuilder("spark.io.warning.largeFileThreshold")
.internal()
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 07e694b6c5b03..233228a9c6d4c 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -22,6 +22,7 @@ import java.util.Locale
import com.github.luben.zstd.{NoPool, RecyclingBufferPool, ZstdInputStreamNoFinalizer, ZstdOutputStreamNoFinalizer}
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
+import com.ning.compress.lzf.parallel.PLZFOutputStream
import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream, LZ4Factory}
import net.jpountz.xxhash.XXHashFactory
import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}
@@ -100,8 +101,9 @@ private[spark] object CompressionCodec {
* If it is already a short name, just return it.
*/
def getShortName(codecName: String): String = {
- if (shortCompressionCodecNames.contains(codecName)) {
- codecName
+ val lowercasedCodec = codecName.toLowerCase(Locale.ROOT)
+ if (shortCompressionCodecNames.contains(lowercasedCodec)) {
+ lowercasedCodec
} else {
shortCompressionCodecNames
.collectFirst { case (k, v) if v == codecName => k }
@@ -170,9 +172,14 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec {
*/
@DeveloperApi
class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
+ private val parallelCompression = conf.get(IO_COMPRESSION_LZF_PARALLEL)
override def compressedOutputStream(s: OutputStream): OutputStream = {
- new LZFOutputStream(s).setFinishBlockOnFlush(true)
+ if (parallelCompression) {
+ new PLZFOutputStream(s)
+ } else {
+ new LZFOutputStream(s).setFinishBlockOnFlush(true)
+ }
}
override def compressedInputStream(s: InputStream): InputStream = new LZFInputStream(s)
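Taken together, a hedged configuration sketch of the two codec changes above: the codec short name is now matched case-insensitively, and LZF can optionally compress with multiple threads via the new flag (default `false`):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.io.compression.codec", "LZF") // short names are now accepted regardless of case
  .set("spark.io.compression.lzf.parallel.enabled", "true") // new opt-in parallel LZF flag
```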
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 991fb074d246d..0ac1405abe6c3 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util
import java.io._
import java.lang.{Byte => JByte}
-import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
+import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, PlatformManagedObject, ThreadInfo}
import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
import java.net._
@@ -3058,8 +3058,16 @@ private[spark] object Utils
*/
lazy val isG1GC: Boolean = {
Try {
- ManagementFactory.getGarbageCollectorMXBeans.asScala
- .exists(_.getName.contains("G1"))
+ val clazz = Utils.classForName("com.sun.management.HotSpotDiagnosticMXBean")
+ .asInstanceOf[Class[_ <: PlatformManagedObject]]
+ val vmOptionClazz = Utils.classForName("com.sun.management.VMOption")
+ val hotSpotDiagnosticMXBean = ManagementFactory.getPlatformMXBean(clazz)
+ val vmOptionMethod = clazz.getMethod("getVMOption", classOf[String])
+ val valueMethod = vmOptionClazz.getMethod("getValue")
+
+ val useG1GCObject = vmOptionMethod.invoke(hotSpotDiagnosticMXBean, "UseG1GC")
+ val useG1GC = valueMethod.invoke(useG1GCObject).asInstanceOf[String]
+ "true".equals(useG1GC)
}.getOrElse(false)
}
}
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 729fcecff1207..5c09a1f965b9e 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.io
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.util.Locale
import com.google.common.io.ByteStreams
@@ -160,4 +161,18 @@ class CompressionCodecSuite extends SparkFunSuite {
ByteStreams.readFully(concatenatedBytes, decompressed)
assert(decompressed.toSeq === (0 to 127))
}
+
+ test("SPARK-48506: CompressionCodec getShortName is case insensitive for short names") {
+ CompressionCodec.shortCompressionCodecNames.foreach { case (shortName, codecClass) =>
+ assert(CompressionCodec.getShortName(shortName) === shortName)
+ assert(CompressionCodec.getShortName(shortName.toUpperCase(Locale.ROOT)) === shortName)
+ assert(CompressionCodec.getShortName(codecClass) === shortName)
+ checkError(
+ exception = intercept[SparkIllegalArgumentException] {
+ CompressionCodec.getShortName(codecClass.toUpperCase(Locale.ROOT))
+ },
+ errorClass = "CODEC_SHORT_NAME_NOT_FOUND",
+ parameters = Map("codecName" -> codecClass.toUpperCase(Locale.ROOT)))
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/io/LZFBenchmark.scala b/core/src/test/scala/org/apache/spark/io/LZFBenchmark.scala
new file mode 100644
index 0000000000000..1934bd5169703
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/io/LZFBenchmark.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.io
+
+import java.io.{ByteArrayOutputStream, ObjectOutputStream}
+import java.lang.management.ManagementFactory
+
+import org.apache.spark.SparkConf
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.internal.config.IO_COMPRESSION_LZF_PARALLEL
+
+/**
+ * Benchmark for LZF codec performance.
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt: bin/spark-submit --class <this class> <spark core test jar>
+ * 2. build/sbt "core/Test/runMain <this class>"
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/Test/runMain <this class>"
+ * Results will be written to "benchmarks/LZFBenchmark-results.txt".
+ * }}}
+ */
+object LZFBenchmark extends BenchmarkBase {
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runBenchmark("Benchmark LZFCompressionCodec") {
+ compressSmallObjects()
+ compressLargeObjects()
+ }
+ }
+
+ private def compressSmallObjects(): Unit = {
+ val N = 256_000_000
+ val benchmark = new Benchmark("Compress small objects", N, output = output)
+ Seq(true, false).foreach { parallel =>
+ val conf = new SparkConf(false).set(IO_COMPRESSION_LZF_PARALLEL, parallel)
+ val condition = if (parallel) "in parallel" else "single-threaded"
+ benchmark.addCase(s"Compression $N int values $condition") { _ =>
+ val os = new LZFCompressionCodec(conf).compressedOutputStream(new ByteArrayOutputStream())
+ for (i <- 1 until N) {
+ os.write(i)
+ }
+ os.close()
+ }
+ }
+ benchmark.run()
+ }
+
+ private def compressLargeObjects(): Unit = {
+ val N = 1024
+ val data: Array[Byte] = (1 until 128 * 1024 * 1024).map(_.toByte).toArray
+ val benchmark = new Benchmark(s"Compress large objects", N, output = output)
+
+ // com.ning.compress.lzf.parallel.PLZFOutputStream.getNThreads
+ def getNThreads: Int = {
+ var nThreads = Runtime.getRuntime.availableProcessors
+ val jmx = ManagementFactory.getOperatingSystemMXBean
+ if (jmx != null) {
+ val loadAverage = jmx.getSystemLoadAverage.toInt
+ if (nThreads > 1 && loadAverage >= 1) nThreads = Math.max(1, nThreads - loadAverage)
+ }
+ nThreads
+ }
+ Seq(true, false).foreach { parallel =>
+ val conf = new SparkConf(false).set(IO_COMPRESSION_LZF_PARALLEL, parallel)
+ val condition = if (parallel) s"in $getNThreads threads" else "single-threaded"
+ benchmark.addCase(s"Compression $N array values $condition") { _ =>
+ val baos = new ByteArrayOutputStream()
+ val zcos = new LZFCompressionCodec(conf).compressedOutputStream(baos)
+ val oos = new ObjectOutputStream(zcos)
+ 1 to N foreach { _ =>
+ oos.writeObject(data)
+ }
+ oos.close()
+ }
+ }
+ benchmark.run()
+ }
+}
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 65e627b1854f9..8ab76b5787b8c 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -29,7 +29,6 @@ azure-data-lake-store-sdk/2.3.9//azure-data-lake-store-sdk-2.3.9.jar
azure-keyvault-core/1.0.0//azure-keyvault-core-1.0.0.jar
azure-storage/7.0.1//azure-storage-7.0.1.jar
blas/3.0.3//blas-3.0.3.jar
-bonecp/0.8.0.RELEASE//bonecp-0.8.0.RELEASE.jar
breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar
breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar
bundle/2.24.6//bundle-2.24.6.jar
@@ -137,8 +136,8 @@ jersey-container-servlet/3.0.12//jersey-container-servlet-3.0.12.jar
jersey-hk2/3.0.12//jersey-hk2-3.0.12.jar
jersey-server/3.0.12//jersey-server-3.0.12.jar
jettison/1.5.4//jettison-1.5.4.jar
-jetty-util-ajax/11.0.20//jetty-util-ajax-11.0.20.jar
-jetty-util/11.0.20//jetty-util-11.0.20.jar
+jetty-util-ajax/11.0.21//jetty-util-ajax-11.0.21.jar
+jetty-util/11.0.21//jetty-util-11.0.21.jar
jline/2.14.6//jline-2.14.6.jar
jline/3.25.1//jline-3.25.1.jar
jna/5.14.0//jna-5.14.0.jar
@@ -262,7 +261,7 @@ spire-platform_2.13/0.18.0//spire-platform_2.13-0.18.0.jar
spire-util_2.13/0.18.0//spire-util_2.13-0.18.0.jar
spire_2.13/0.18.0//spire_2.13-0.18.0.jar
stax-api/1.0.1//stax-api-1.0.1.jar
-stream/2.9.6//stream-2.9.6.jar
+stream/2.9.8//stream-2.9.8.jar
super-csv/2.2.0//super-csv-2.2.0.jar
threeten-extra/1.7.1//threeten-extra-1.7.1.jar
tink/1.13.0//tink-1.13.0.jar
diff --git a/dev/pyproject.toml b/dev/pyproject.toml
index 4f462d14c7838..f19107b3782a6 100644
--- a/dev/pyproject.toml
+++ b/dev/pyproject.toml
@@ -29,6 +29,6 @@ testpaths = [
# GitHub workflow version and dev/reformat-python
required-version = "23.9.1"
line-length = 100
-target-version = ['py38']
+target-version = ['py39']
include = '\.pyi?$'
extend-exclude = 'cloudpickle|error_classes.py'
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ec4b25547dc4c..e182d0c33f16c 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1009,6 +1009,7 @@ def __hash__(self):
# sql unittests
"pyspark.sql.tests.connect.test_connect_plan",
"pyspark.sql.tests.connect.test_connect_basic",
+ "pyspark.sql.tests.connect.test_connect_dataframe_property",
"pyspark.sql.tests.connect.test_connect_error",
"pyspark.sql.tests.connect.test_connect_function",
"pyspark.sql.tests.connect.test_connect_collection",
diff --git a/docs/configuration.md b/docs/configuration.md
index 409f1f521eb52..23443cab2eacc 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1895,6 +1895,14 @@ Apart from these, the following properties are also available, and may be useful
4.0.0 |
+
+ spark.io.compression.lzf.parallel.enabled |
+ false |
+
+ When true, LZF compression will use multiple threads to compress data in parallel.
+ |
+ 4.0.0 |
+
spark.kryo.classesToRegister |
(none) |
diff --git a/docs/core-migration-guide.md b/docs/core-migration-guide.md
index 28a9dd0f43715..1c37fded53ab7 100644
--- a/docs/core-migration-guide.md
+++ b/docs/core-migration-guide.md
@@ -48,7 +48,7 @@ license: |
- Since Spark 4.0, the MDC (Mapped Diagnostic Context) key for Spark task names in Spark logs has been changed from `mdc.taskName` to `task_name`. To use the key `mdc.taskName`, you can set `spark.log.legacyTaskNameMdc.enabled` to `true`.
-- Since Spark 4.0, Spark performs speculative executions less agressively with `spark.speculation.multiplier=3` and `spark.speculation.quantile=0.9`. To restore the legacy behavior, you can set `spark.speculation.multiplier=1.5` and `spark.speculation.quantile=0.75`.
+- Since Spark 4.0, Spark performs speculative executions less aggressively with `spark.speculation.multiplier=3` and `spark.speculation.quantile=0.9`. To restore the legacy behavior, you can set `spark.speculation.multiplier=1.5` and `spark.speculation.quantile=0.75`.
## Upgrading from Core 3.4 to 3.5
diff --git a/pom.xml b/pom.xml
index ded8cc2405fde..bc81b810715b5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -140,7 +140,7 @@
1.13.1
2.0.1
shaded-protobuf
- 11.0.20
+ 11.0.21
5.0.0
4.0.1
@@ -806,7 +806,7 @@
com.clearspring.analytics
stream
- 2.9.6
+ 2.9.8
@@ -1264,7 +1264,7 @@
com.github.docker-java
docker-java
- 3.3.5
+ 3.3.6
test
@@ -1284,7 +1284,7 @@
com.github.docker-java
docker-java-transport-zerodep
- 3.3.5
+ 3.3.6
test
@@ -2332,6 +2332,10 @@
co.cask.tephra
*
+
+ com.jolbox
+ bonecp
+
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 2a5bb5939a3f8..85806b1a265b0 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -301,7 +301,7 @@ def applyInPandas(
evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
)
- return DataFrame(
+ res = DataFrame(
plan.GroupMap(
child=self._df._plan,
grouping_cols=self._grouping_cols,
@@ -310,6 +310,9 @@ def applyInPandas(
),
session=self._df._session,
)
+ if isinstance(schema, StructType):
+ res._cached_schema = schema
+ return res
applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__
@@ -370,7 +373,7 @@ def applyInArrow(
evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
)
- return DataFrame(
+ res = DataFrame(
plan.GroupMap(
child=self._df._plan,
grouping_cols=self._grouping_cols,
@@ -379,6 +382,9 @@ def applyInArrow(
),
session=self._df._session,
)
+ if isinstance(schema, StructType):
+ res._cached_schema = schema
+ return res
applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__
@@ -410,7 +416,7 @@ def applyInPandas(
evalType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
)
- return DataFrame(
+ res = DataFrame(
plan.CoGroupMap(
input=self._gd1._df._plan,
input_grouping_cols=self._gd1._grouping_cols,
@@ -420,6 +426,9 @@ def applyInPandas(
),
session=self._gd1._df._session,
)
+ if isinstance(schema, StructType):
+ res._cached_schema = schema
+ return res
applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__
@@ -436,7 +445,7 @@ def applyInArrow(
evalType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
)
- return DataFrame(
+ res = DataFrame(
plan.CoGroupMap(
input=self._gd1._df._plan,
input_grouping_cols=self._gd1._grouping_cols,
@@ -446,6 +455,9 @@ def applyInArrow(
),
session=self._gd1._df._session,
)
+ if isinstance(schema, StructType):
+ res._cached_schema = schema
+ return res
applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
new file mode 100644
index 0000000000000..f80f4509a7cec
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -0,0 +1,284 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
+from pyspark.sql.utils import is_remote
+
+from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+from pyspark.testing.sqlutils import (
+ have_pandas,
+ have_pyarrow,
+ pandas_requirement_message,
+ pyarrow_requirement_message,
+)
+
+if have_pyarrow:
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+if have_pandas:
+ import pandas as pd
+
+
+class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
+ def test_cached_schema_to(self):
+ cdf = self.connect.read.table(self.tbl_name)
+ sdf = self.spark.read.table(self.tbl_name)
+
+ schema = StructType(
+ [
+ StructField("id", IntegerType(), True),
+ StructField("name", StringType(), True),
+ ]
+ )
+
+ cdf1 = cdf.to(schema)
+ self.assertEqual(cdf1._cached_schema, schema)
+
+ sdf1 = sdf.to(schema)
+
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
+ @unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message,
+ )
+ def test_cached_schema_map_in_pandas(self):
+ data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
+ cdf = self.connect.createDataFrame(data, "a int, b string")
+ sdf = self.spark.createDataFrame(data, "a int, b string")
+
+ def func(iterator):
+ for pdf in iterator:
+ assert isinstance(pdf, pd.DataFrame)
+ assert [d.name for d in list(pdf.dtypes)] == ["int32", "object"]
+ yield pdf
+
+ schema = StructType(
+ [
+ StructField("a", IntegerType(), True),
+ StructField("b", StringType(), True),
+ ]
+ )
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+ self.assertTrue(is_remote())
+ cdf1 = cdf.mapInPandas(func, schema)
+ self.assertEqual(cdf1._cached_schema, schema)
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+ # 'mapInPandas' depends on the method 'pandas_udf', which is dispatched
+ # based on 'is_remote'. However, in SparkConnectSQLTestCase, the remote
+ # mode is always on, so 'sdf.mapInPandas' fails with incorrect dispatch.
+ # Using this temp env to properly invoke mapInPandas in PySpark Classic.
+ self.assertFalse(is_remote())
+ sdf1 = sdf.mapInPandas(func, schema)
+
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
+ @unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message,
+ )
+ def test_cached_schema_map_in_arrow(self):
+ data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
+ cdf = self.connect.createDataFrame(data, "a int, b string")
+ sdf = self.spark.createDataFrame(data, "a int, b string")
+
+ def func(iterator):
+ for batch in iterator:
+ assert isinstance(batch, pa.RecordBatch)
+ assert batch.schema.types == [pa.int32(), pa.string()]
+ yield batch
+
+ schema = StructType(
+ [
+ StructField("a", IntegerType(), True),
+ StructField("b", StringType(), True),
+ ]
+ )
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+ self.assertTrue(is_remote())
+ cdf1 = cdf.mapInArrow(func, schema)
+ self.assertEqual(cdf1._cached_schema, schema)
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+ self.assertFalse(is_remote())
+ sdf1 = sdf.mapInArrow(func, schema)
+
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
+ @unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message,
+ )
+ def test_cached_schema_group_apply_in_pandas(self):
+ data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
+ cdf = self.connect.createDataFrame(data, ("id", "v"))
+ sdf = self.spark.createDataFrame(data, ("id", "v"))
+
+ def normalize(pdf):
+ v = pdf.v
+ return pdf.assign(v=(v - v.mean()) / v.std())
+
+ schema = StructType(
+ [
+ StructField("id", LongType(), True),
+ StructField("v", DoubleType(), True),
+ ]
+ )
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+ self.assertTrue(is_remote())
+ cdf1 = cdf.groupby("id").applyInPandas(normalize, schema)
+ self.assertEqual(cdf1._cached_schema, schema)
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+ self.assertFalse(is_remote())
+ sdf1 = sdf.groupby("id").applyInPandas(normalize, schema)
+
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
+ @unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message,
+ )
+ def test_cached_schema_group_apply_in_arrow(self):
+ data = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]
+ cdf = self.connect.createDataFrame(data, ("id", "v"))
+ sdf = self.spark.createDataFrame(data, ("id", "v"))
+
+ def normalize(table):
+ v = table.column("v")
+ norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
+ return table.set_column(1, "v", norm)
+
+ schema = StructType(
+ [
+ StructField("id", LongType(), True),
+ StructField("v", DoubleType(), True),
+ ]
+ )
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+ self.assertTrue(is_remote())
+ cdf1 = cdf.groupby("id").applyInArrow(normalize, schema)
+ self.assertEqual(cdf1._cached_schema, schema)
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+ self.assertFalse(is_remote())
+ sdf1 = sdf.groupby("id").applyInArrow(normalize, schema)
+
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
+ @unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message,
+ )
+ def test_cached_schema_cogroup_apply_in_pandas(self):
+ data1 = [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)]
+ data2 = [(20000101, 1, "x"), (20000101, 2, "y")]
+
+ cdf1 = self.connect.createDataFrame(data1, ("time", "id", "v1"))
+ sdf1 = self.spark.createDataFrame(data1, ("time", "id", "v1"))
+ cdf2 = self.connect.createDataFrame(data2, ("time", "id", "v2"))
+ sdf2 = self.spark.createDataFrame(data2, ("time", "id", "v2"))
+
+ def asof_join(left, right):
+ return pd.merge_asof(left, right, on="time", by="id")
+
+ schema = StructType(
+ [
+ StructField("time", IntegerType(), True),
+ StructField("id", IntegerType(), True),
+ StructField("v1", DoubleType(), True),
+ StructField("v2", StringType(), True),
+ ]
+ )
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+ self.assertTrue(is_remote())
+ cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInPandas(asof_join, schema)
+ self.assertEqual(cdf3._cached_schema, schema)
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+ self.assertFalse(is_remote())
+ sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInPandas(asof_join, schema)
+
+ self.assertEqual(cdf3.schema, sdf3.schema)
+ self.assertEqual(cdf3.collect(), sdf3.collect())
+
+ @unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message,
+ )
+ def test_cached_schema_cogroup_apply_in_arrow(self):
+ data1 = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]
+ data2 = [(1, "x"), (2, "y")]
+
+ cdf1 = self.connect.createDataFrame(data1, ("id", "v1"))
+ sdf1 = self.spark.createDataFrame(data1, ("id", "v1"))
+ cdf2 = self.connect.createDataFrame(data2, ("id", "v2"))
+ sdf2 = self.spark.createDataFrame(data2, ("id", "v2"))
+
+ def summarize(left, right):
+ return pa.Table.from_pydict(
+ {
+ "left": [left.num_rows],
+ "right": [right.num_rows],
+ }
+ )
+
+ schema = StructType(
+ [
+ StructField("left", LongType(), True),
+ StructField("right", LongType(), True),
+ ]
+ )
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": "1"}):
+ self.assertTrue(is_remote())
+ cdf3 = cdf1.groupby("id").cogroup(cdf2.groupby("id")).applyInArrow(summarize, schema)
+ self.assertEqual(cdf3._cached_schema, schema)
+
+ with self.temp_env({"SPARK_CONNECT_MODE_ENABLED": None}):
+ self.assertFalse(is_remote())
+ sdf3 = sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInArrow(summarize, schema)
+
+ self.assertEqual(cdf3.schema, sdf3.schema)
+ self.assertEqual(cdf3.collect(), sdf3.collect())
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 74d56491a29e7..4e2e2c51b4db0 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -1204,8 +1204,13 @@ def check_toPandas_error(self, arrow_enabled):
self.spark.sql("select 1/0").toPandas()
def test_toArrow_error(self):
- with self.assertRaises(ArithmeticException):
- self.spark.sql("select 1/0").toArrow()
+ with self.sql_conf(
+ {
+ "spark.sql.ansi.enabled": True,
+ }
+ ):
+ with self.assertRaises(ArithmeticException):
+ self.spark.sql("select 1/0").toArrow()
def test_toPandas_duplicate_field_names(self):
for arrow_enabled in [True, False]:
diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py
index b381833314861..f363b8748c0b9 100644
--- a/python/pyspark/sql/tests/test_context.py
+++ b/python/pyspark/sql/tests/test_context.py
@@ -26,13 +26,13 @@
from pyspark import SparkContext, SQLContext
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import StructType, StringType, StructField
-from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.testing.sqlutils import ReusedSQLTestCase
-class HiveContextSQLTests(ReusedPySparkTestCase):
+class HiveContextSQLTests(ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- ReusedPySparkTestCase.setUpClass()
+ ReusedSQLTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
cls.hive_available = True
cls.spark = None
@@ -58,7 +58,7 @@ def setUp(self):
@classmethod
def tearDownClass(cls):
- ReusedPySparkTestCase.tearDownClass()
+ ReusedSQLTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
if cls.spark is not None:
cls.spark.stop()
@@ -100,23 +100,20 @@ def test_save_and_load_table(self):
self.spark.sql("DROP TABLE savedJsonTable")
self.spark.sql("DROP TABLE externalJsonTable")
- defaultDataSourceName = self.spark.conf.get(
- "spark.sql.sources.default", "org.apache.spark.sql.parquet"
- )
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
- df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
- actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath)
- self.assertEqual(
- sorted(df.collect()), sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())
- )
- self.assertEqual(
- sorted(df.collect()),
- sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()),
- )
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.spark.sql("DROP TABLE savedJsonTable")
- self.spark.sql("DROP TABLE externalJsonTable")
- self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+ with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
+ df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+ actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath)
+ self.assertEqual(
+ sorted(df.collect()),
+ sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()),
+ )
+ self.assertEqual(
+ sorted(df.collect()),
+ sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()),
+ )
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ self.spark.sql("DROP TABLE savedJsonTable")
+ self.spark.sql("DROP TABLE externalJsonTable")
shutil.rmtree(tmpPath)
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index e752856d03164..8060a9ae8bc76 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -55,12 +55,9 @@ def test_save_and_load(self):
)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- try:
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
actual = self.spark.read.load(path=tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- finally:
- self.spark.sql("RESET spark.sql.sources.default")
csvpath = os.path.join(tempfile.mkdtemp(), "data")
df.write.option("quote", None).format("csv").save(csvpath)
@@ -94,12 +91,9 @@ def test_save_and_load_builder(self):
)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- try:
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
actual = self.spark.read.load(path=tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- finally:
- self.spark.sql("RESET spark.sql.sources.default")
finally:
shutil.rmtree(tmpPath)
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 80f2c0fcbc033..1882c1fd1f6ad 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -491,14 +491,11 @@ class User:
self.assertEqual(asdict(user), r.asDict())
def test_negative_decimal(self):
- try:
- self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
+ with self.sql_conf({"spark.sql.legacy.allowNegativeScaleOfDecimal": True}):
df = self.spark.createDataFrame([(1,), (11,)], ["value"])
ret = df.select(F.col("value").cast(DecimalType(1, -1))).collect()
actual = list(map(lambda r: int(r.value), ret))
self.assertEqual(actual, [0, 10])
- finally:
- self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")
def test_create_dataframe_from_objects(self):
data = [MyObject(1, "1"), MyObject(2, "2")]
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index a0fdada72972b..9f07c44c084cf 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -247,6 +247,29 @@ def function(self, *functions):
for f in functions:
self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
+ @contextmanager
+ def temp_env(self, pairs):
+ assert isinstance(pairs, dict), "pairs should be a dictionary."
+
+ keys = pairs.keys()
+ new_values = pairs.values()
+ old_values = [os.environ.get(key, None) for key in keys]
+ for key, new_value in zip(keys, new_values):
+ if new_value is None:
+ if key in os.environ:
+ del os.environ[key]
+ else:
+ os.environ[key] = new_value
+ try:
+ yield
+ finally:
+ for key, old_value in zip(keys, old_values):
+ if old_value is None:
+ if key in os.environ:
+ del os.environ[key]
+ else:
+ os.environ[key] = old_value
+
@staticmethod
def assert_close(a, b):
c = [j[0] for j in b]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1c2baa78be1b9..f4408220ac939 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -143,50 +143,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
errorClass, missingCol, orderedCandidates, a.origin)
}
- private def checkUnreferencedCTERelations(
- cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
- visited: mutable.Map[Long, Boolean],
- danglingCTERelations: mutable.ArrayBuffer[CTERelationDef],
- cteId: Long): Unit = {
- if (visited(cteId)) {
- return
- }
- val (cteDef, _, refMap) = cteMap(cteId)
- refMap.foreach { case (id, _) =>
- checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, id)
- }
- danglingCTERelations.append(cteDef)
- visited(cteId) = true
- }
-
def checkAnalysis(plan: LogicalPlan): Unit = {
- val inlineCTE = InlineCTE(alwaysInline = true)
- val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
- inlineCTE.buildCTEMap(plan, cteMap)
- val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef]
- val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false)
- // If a CTE relation is never used, it will disappear after inline. Here we explicitly collect
- // these dangling CTE relations, and put them back in the main query, to make sure the entire
- // query plan is valid.
- cteMap.foreach { case (cteId, (_, refCount, _)) =>
- // If a CTE relation ref count is 0, the other CTE relations that reference it should also be
- // collected. This code will also guarantee the leaf relations that do not reference
- // any others are collected first.
- if (refCount == 0) {
- checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, cteId)
- }
- }
- // Inline all CTEs in the plan to help check query plan structures in subqueries.
- var inlinedPlan: LogicalPlan = plan
- try {
- inlinedPlan = inlineCTE(plan)
+ // We should inline all CTE relations to restore the original plan shape, as the analysis check
+    // may need to match certain plan shapes. Dangling CTE relations will still be kept in the
+    // original `WithCTE` node, as we need to perform the analysis check for them as well.
+ val inlineCTE = InlineCTE(alwaysInline = true, keepDanglingRelations = true)
+ val inlinedPlan: LogicalPlan = try {
+ inlineCTE(plan)
} catch {
case e: AnalysisException =>
throw new ExtendedAnalysisException(e, plan)
}
- if (danglingCTERelations.nonEmpty) {
- inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq)
- }
try {
checkAnalysis0(inlinedPlan)
} catch {
@@ -1404,6 +1371,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
aggregated,
canContainOuter && SQLConf.get.getConf(SQLConf.DECORRELATE_OFFSET_ENABLED))
+      // We always inline CTE relations before the analysis check, and only un-referenced CTE
+      // relations will be kept in the plan. Here we should simply skip them and check the
+      // children, as un-referenced CTE relations won't be executed anyway and don't need to
+      // be restricted by the current subquery correlation limitations.
+ case _: WithCTE | _: CTERelationDef =>
+ plan.children.foreach(p => checkPlan(p, aggregated, canContainOuter))
+
// Category 4: Any other operators not in the above 3 categories
// cannot be on a correlation path, that is they are allowed only
// under a correlation point but they and their descendant operators
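
The new case above is what lets the un-referenced-CTE queries added to cte-nested.sql further down pass the correlation check; a rough PySpark sketch of the first of them, assuming an active SparkSession named spark:

# The CTE body carries an outer reference to `id` and is never referenced itself,
# so after inlining it survives only as a dangling relation under WithCTE.
spark.sql("""
    SELECT (
      WITH unreferenced AS (SELECT id)
      SELECT 1
    ) FROM range(1)
""").show()
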
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a52feaa41acf9..588752f3fc176 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -222,7 +222,7 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging
builder: FunctionBuilder): Unit = {
val newFunction = (info, builder)
functionBuilders.put(name, newFunction) match {
- case previousFunction if previousFunction != newFunction =>
+ case previousFunction if previousFunction != null =>
logWarning(log"The function ${MDC(FUNCTION_NAME, name)} replaced a " +
log"previously registered function.")
case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
index 8d7ff4cbf163d..50828b945bb40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
@@ -37,23 +37,19 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
* query level.
*
* @param alwaysInline if true, inline all CTEs in the query plan.
+ * @param keepDanglingRelations if true, dangling CTE relations will be kept in the original
+ * `WithCTE` node.
*/
-case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
+case class InlineCTE(
+ alwaysInline: Boolean = false,
+ keepDanglingRelations: Boolean = false) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
- val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
+ val cteMap = mutable.SortedMap.empty[Long, CTEReferenceInfo]
buildCTEMap(plan, cteMap)
cleanCTEMap(cteMap)
- val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
- val inlined = inlineCTE(plan, cteMap, notInlined)
- // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
- // WithCTE as top node here.
- if (notInlined.isEmpty) {
- inlined
- } else {
- WithCTE(inlined, notInlined.toSeq)
- }
+ inlineCTE(plan, cteMap)
} else {
plan
}
@@ -74,22 +70,23 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
*
* @param plan The plan to collect the CTEs from
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
- * ids. The value of the map is tuple whose elements are:
- * - The CTE definition
- * - The number of incoming references to the CTE. This includes references from
- * other CTEs and regular places.
- * - A mutable inner map that tracks outgoing references (counts) to other CTEs.
+ * ids.
* @param outerCTEId While collecting the map we use this optional CTE id to identify the
* current outer CTE.
*/
- def buildCTEMap(
+ private def buildCTEMap(
plan: LogicalPlan,
- cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
+ cteMap: mutable.Map[Long, CTEReferenceInfo],
outerCTEId: Option[Long] = None): Unit = {
plan match {
case WithCTE(child, cteDefs) =>
cteDefs.foreach { cteDef =>
- cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0))
+ cteMap(cteDef.id) = CTEReferenceInfo(
+ cteDef = cteDef,
+ refCount = 0,
+ outgoingRefs = mutable.Map.empty.withDefaultValue(0),
+ shouldInline = true
+ )
}
cteDefs.foreach { cteDef =>
buildCTEMap(cteDef, cteMap, Some(cteDef.id))
@@ -97,11 +94,9 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
buildCTEMap(child, cteMap, outerCTEId)
case ref: CTERelationRef =>
- val (cteDef, refCount, refMap) = cteMap(ref.cteId)
- cteMap(ref.cteId) = (cteDef, refCount + 1, refMap)
+ cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1)
outerCTEId.foreach { cteId =>
- val (_, _, outerRefMap) = cteMap(cteId)
- outerRefMap(ref.cteId) += 1
+ cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1)
}
case _ =>
@@ -129,15 +124,12 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
* @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE
* ids. Needs to be sorted to speed up cleaning.
*/
- private def cleanCTEMap(
- cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])]
- ) = {
+ private def cleanCTEMap(cteMap: mutable.SortedMap[Long, CTEReferenceInfo]): Unit = {
cteMap.keys.toSeq.reverse.foreach { currentCTEId =>
- val (_, currentRefCount, refMap) = cteMap(currentCTEId)
- if (currentRefCount == 0) {
- refMap.foreach { case (referencedCTEId, uselessRefCount) =>
- val (cteDef, refCount, refMap) = cteMap(referencedCTEId)
- cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap)
+ val refInfo = cteMap(currentCTEId)
+ if (refInfo.refCount == 0) {
+ refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) =>
+ cteMap(referencedCTEId) = cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount)
}
}
}
@@ -145,30 +137,45 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
private def inlineCTE(
plan: LogicalPlan,
- cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
- notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
+ cteMap: mutable.Map[Long, CTEReferenceInfo]): LogicalPlan = {
plan match {
case WithCTE(child, cteDefs) =>
- cteDefs.foreach { cteDef =>
- val (cte, refCount, refMap) = cteMap(cteDef.id)
- if (refCount > 0) {
- val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined))
- cteMap(cteDef.id) = (inlined, refCount, refMap)
- if (!shouldInline(inlined, refCount)) {
- notInlined.append(inlined)
- }
+ val remainingDefs = cteDefs.filter { cteDef =>
+ val refInfo = cteMap(cteDef.id)
+ if (refInfo.refCount > 0) {
+ val newDef = refInfo.cteDef.copy(child = inlineCTE(refInfo.cteDef.child, cteMap))
+ val inlineDecision = shouldInline(newDef, refInfo.refCount)
+ cteMap(cteDef.id) = cteMap(cteDef.id).copy(
+ cteDef = newDef, shouldInline = inlineDecision
+ )
+ // Retain the not-inlined CTE relations in place.
+ !inlineDecision
+ } else {
+ keepDanglingRelations
}
}
- inlineCTE(child, cteMap, notInlined)
+ val inlined = inlineCTE(child, cteMap)
+ if (remainingDefs.isEmpty) {
+ inlined
+ } else {
+ WithCTE(inlined, remainingDefs)
+ }
case ref: CTERelationRef =>
- val (cteDef, refCount, _) = cteMap(ref.cteId)
- if (shouldInline(cteDef, refCount)) {
- if (ref.outputSet == cteDef.outputSet) {
- cteDef.child
+ val refInfo = cteMap(ref.cteId)
+ if (refInfo.shouldInline) {
+ if (ref.outputSet == refInfo.cteDef.outputSet) {
+ refInfo.cteDef.child
} else {
val ctePlan = DeduplicateRelations(
- Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, None))).children(1)
+ Join(
+ refInfo.cteDef.child,
+ refInfo.cteDef.child,
+ Inner,
+ None,
+ JoinHint(None, None)
+ )
+ ).children(1)
val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) =>
if (srcAttr.semanticEquals(tgtAttr)) {
tgtAttr
@@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
case _ if plan.containsPattern(CTE) =>
plan
- .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined)))
+ .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap)))
.transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
case e: SubqueryExpression =>
- e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined))
+ e.withNewPlan(inlineCTE(e.plan, cteMap))
}
case _ => plan
}
}
}
+
+/**
+ * The bookkeeping information for tracking CTE relation references.
+ *
+ * @param cteDef The CTE relation definition
+ * @param refCount The number of incoming references to this CTE relation. This includes references
+ * from other CTE relations and regular places.
+ * @param outgoingRefs A mutable map that tracks outgoing reference counts to other CTE relations.
+ * @param shouldInline If true, this CTE relation should be inlined in the places that reference it.
+ */
+case class CTEReferenceInfo(
+ cteDef: CTERelationDef,
+ refCount: Int,
+ outgoingRefs: mutable.Map[Long, Int],
+ shouldInline: Boolean) {
+
+ def withRefCountIncreased(count: Int): CTEReferenceInfo = {
+ copy(refCount = refCount + count)
+ }
+
+ def withRefCountDecreased(count: Int): CTEReferenceInfo = {
+ copy(refCount = refCount - count)
+ }
+
+ def increaseOutgoingRefCount(cteDefId: Long, count: Int): Unit = {
+ outgoingRefs(cteDefId) += count
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 86490a2eea970..0e0946668197a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.parser
import java.util.Locale
import java.util.concurrent.TimeUnit
-import scala.collection.immutable.Seq
import scala.collection.mutable.{ArrayBuffer, Set}
import scala.jdk.CollectionConverters._
import scala.util.{Left, Right}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 9242a06cf1d6e..0135fcfb3cc8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -911,6 +911,10 @@ case class WithCTE(plan: LogicalPlan, cteDefs: Seq[CTERelationDef]) extends Logi
def withNewPlan(newPlan: LogicalPlan): WithCTE = {
withNewChildren(children.init :+ newPlan).asInstanceOf[WithCTE]
}
+
+ override def maxRows: Option[Long] = plan.maxRows
+
+ override def maxRowsPerPartition: Option[Long] = plan.maxRowsPerPartition
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 88c2228e640c4..f4751f2027894 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2301,7 +2301,9 @@ object SQLConf {
buildConf("spark.sql.streaming.stateStore.skipNullsForStreamStreamJoins.enabled")
.internal()
.doc("When true, this config will skip null values in hash based stream-stream joins. " +
- "The number of skipped null values will be shown as custom metric of stream join operator.")
+        "The number of skipped null values will be shown as a custom metric of the stream join " +
+        "operator. If the streaming query was started with Spark 3.5 or above, please exercise " +
+        "caution before enabling this config since it may hide potential data loss/corruption " +
+        "issues.")
.version("3.3.0")
.booleanConf
.createWithDefault(false)
@@ -4614,6 +4616,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val LEGACY_NO_CHAR_PADDING_IN_PREDICATE = buildConf("spark.sql.legacy.noCharPaddingInPredicate")
+ .internal()
+ .doc("When true, Spark will not apply char type padding for CHAR type columns in string " +
+ s"comparison predicates, when '${READ_SIDE_CHAR_PADDING.key}' is false.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
val CLI_PRINT_HEADER =
buildConf("spark.sql.cli.print.header")
.doc("When set to true, spark-sql CLI prints the names of the columns in query output.")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala
new file mode 100644
index 0000000000000..9d775a5335c67
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.TestRelation
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CTERelationDef, CTERelationRef, LogicalPlan, OneRowRelation, WithCTE}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class InlineCTESuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("inline CTE", FixedPoint(100), InlineCTE()) :: Nil
+ }
+
+ test("SPARK-48307: not-inlined CTE relation in command") {
+ val cteDef = CTERelationDef(OneRowRelation().select(rand(0).as("a")))
+ val cteRef = CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming)
+ val plan = AppendData.byName(
+ TestRelation(Seq($"a".double)),
+ WithCTE(cteRef.except(cteRef, isAll = true), Seq(cteDef))
+ ).analyze
+ comparePlans(Optimize.execute(plan), plan)
+ }
+}
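
The suite constructs the plan directly; for orientation only, here is a rough SQL analogue of the scenario it covers (a non-deterministic CTE referenced twice inside a write command), assuming a pre-existing target table named tgt. CTE substitution in SQL commands may take a different path, so treat this purely as a sketch:

# A non-deterministic CTE referenced twice inside a write command, which InlineCTE
# should leave un-inlined rather than duplicate.
spark.sql("""
    INSERT INTO tgt
    WITH v AS (SELECT rand(0) AS a)
    SELECT * FROM v EXCEPT ALL SELECT * FROM v
""")
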
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
index b5bf337a5a2e6..1b7b0d702ab98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{CharType, Metadata, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -66,9 +67,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols)
})
}
- paddingForStringComparison(newPlan)
+ paddingForStringComparison(newPlan, padCharCol = false)
} else {
- paddingForStringComparison(plan)
+ paddingForStringComparison(
+ plan, padCharCol = !conf.getConf(SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE))
}
}
@@ -90,7 +92,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
}
}
- private def paddingForStringComparison(plan: LogicalPlan): LogicalPlan = {
+ private def paddingForStringComparison(plan: LogicalPlan, padCharCol: Boolean): LogicalPlan = {
plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
case operator => operator.transformExpressionsUpWithPruning(
_.containsAnyPattern(BINARY_COMPARISON, IN)) {
@@ -99,12 +101,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
// String literal is treated as char type when it's compared to a char type column.
// We should pad the shorter one to the longer length.
case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable =>
- padAttrLitCmp(e, attr.metadata, lit).map { newChildren =>
+ padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren =>
b.withNewChildren(newChildren)
}.getOrElse(b)
case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable =>
- padAttrLitCmp(e, attr.metadata, lit).map { newChildren =>
+ padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren =>
b.withNewChildren(newChildren.reverse)
}.getOrElse(b)
@@ -117,9 +119,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
val literalCharLengths = literalChars.map(_.numChars())
val targetLen = (length +: literalCharLengths).max
Some(i.copy(
- value = addPadding(e, length, targetLen),
+ value = addPadding(e, length, targetLen, alwaysPad = padCharCol),
list = list.zip(literalCharLengths).map {
- case (lit, charLength) => addPadding(lit, charLength, targetLen)
+ case (lit, charLength) =>
+ addPadding(lit, charLength, targetLen, alwaysPad = false)
} ++ nulls.map(Literal.create(_, StringType))))
case _ => None
}.getOrElse(i)
@@ -162,6 +165,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
private def padAttrLitCmp(
expr: Expression,
metadata: Metadata,
+ padCharCol: Boolean,
lit: Expression): Option[Seq[Expression]] = {
if (expr.dataType == StringType) {
CharVarcharUtils.getRawType(metadata).flatMap {
@@ -174,7 +178,14 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
if (length < stringLitLen) {
Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit))
} else if (length > stringLitLen) {
- Some(Seq(expr, StringRPad(lit, Literal(length))))
+ val paddedExpr = if (padCharCol) {
+ StringRPad(expr, Literal(length))
+ } else {
+ expr
+ }
+ Some(Seq(paddedExpr, StringRPad(lit, Literal(length))))
+ } else if (padCharCol) {
+ Some(Seq(StringRPad(expr, Literal(length)), lit))
} else {
None
}
@@ -186,7 +197,15 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
}
}
- private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
- if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
+ private def addPadding(
+ expr: Expression,
+ charLength: Int,
+ targetLength: Int,
+ alwaysPad: Boolean): Expression = {
+ if (targetLength > charLength) {
+ StringRPad(expr, Literal(targetLength))
+ } else if (alwaysPad) {
+ StringRPad(expr, Literal(charLength))
+ } else expr
}
}
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-legacy.sql.out
index 594a30b054edd..f9b78e94236fb 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-legacy.sql.out
@@ -43,6 +43,30 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x]
+- OneRowRelation
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT id)
+ SELECT 1
+) FROM range(1)
+-- !query analysis
+Project [scalar-subquery#x [] AS scalarsubquery()#x]
+: +- Project [1 AS 1#x]
+: +- OneRowRelation
++- Range (0, 1, step=1)
+
+
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT 1)
+ SELECT id
+) FROM range(1)
+-- !query analysis
+Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#xL]
+: +- Project [outer(id#xL)]
+: +- OneRowRelation
++- Range (0, 1, step=1)
+
+
-- !query
SELECT * FROM
(
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out
index f1a302b06f2a8..3a9fc5ea1297f 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nested.sql.out
@@ -58,6 +58,40 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x]
+- OneRowRelation
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT id)
+ SELECT 1
+) FROM range(1)
+-- !query analysis
+Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#x]
+: +- WithCTE
+: :- CTERelationDef xxxx, false
+: : +- SubqueryAlias unreferenced
+: : +- Project [outer(id#xL)]
+: : +- OneRowRelation
+: +- Project [1 AS 1#x]
+: +- OneRowRelation
++- Range (0, 1, step=1)
+
+
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT 1)
+ SELECT id
+) FROM range(1)
+-- !query analysis
+Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#xL]
+: +- WithCTE
+: :- CTERelationDef xxxx, false
+: : +- SubqueryAlias unreferenced
+: : +- Project [1 AS 1#x]
+: : +- OneRowRelation
+: +- Project [outer(id#xL)]
+: +- OneRowRelation
++- Range (0, 1, step=1)
+
+
-- !query
SELECT * FROM
(
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out
index 6e55c6fa83cd9..e8640c3cbb6bd 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-nonlegacy.sql.out
@@ -58,6 +58,40 @@ Project [scalar-subquery#x [] AS scalarsubquery()#x]
+- OneRowRelation
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT id)
+ SELECT 1
+) FROM range(1)
+-- !query analysis
+Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#x]
+: +- WithCTE
+: :- CTERelationDef xxxx, false
+: : +- SubqueryAlias unreferenced
+: : +- Project [outer(id#xL)]
+: : +- OneRowRelation
+: +- Project [1 AS 1#x]
+: +- OneRowRelation
++- Range (0, 1, step=1)
+
+
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT 1)
+ SELECT id
+) FROM range(1)
+-- !query analysis
+Project [scalar-subquery#x [id#xL] AS scalarsubquery(id)#xL]
+: +- WithCTE
+: :- CTERelationDef xxxx, false
+: : +- SubqueryAlias unreferenced
+: : +- Project [1 AS 1#x]
+: : +- OneRowRelation
+: +- Project [outer(id#xL)]
+: +- OneRowRelation
++- Range (0, 1, step=1)
+
+
-- !query
SELECT * FROM
(
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql
index e5ef244341751..3b2ba1fcdd66e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte-nested.sql
@@ -17,6 +17,18 @@ SELECT (
SELECT * FROM t
);
+-- un-referenced CTE in subquery expression: outer reference in CTE relation
+SELECT (
+ WITH unreferenced AS (SELECT id)
+ SELECT 1
+) FROM range(1);
+
+-- un-referenced CTE in subquery expression: outer reference in CTE main query
+SELECT (
+ WITH unreferenced AS (SELECT 1)
+ SELECT id
+) FROM range(1);
+
-- Make sure CTE in subquery is scoped to that subquery rather than global
-- the 2nd half of the union should fail because the cte is scoped to the first half
SELECT * FROM
diff --git a/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out
index b79d8b1afb0d4..1255e8b51f301 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-legacy.sql.out
@@ -33,6 +33,28 @@ struct
1
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT id)
+ SELECT 1
+) FROM range(1)
+-- !query schema
+struct
+-- !query output
+1
+
+
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT 1)
+ SELECT id
+) FROM range(1)
+-- !query schema
+struct
+-- !query output
+0
+
+
-- !query
SELECT * FROM
(
diff --git a/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out
index a93bcb7593768..7cf488ce8cad4 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-nested.sql.out
@@ -33,6 +33,28 @@ struct
1
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT id)
+ SELECT 1
+) FROM range(1)
+-- !query schema
+struct
+-- !query output
+1
+
+
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT 1)
+ SELECT id
+) FROM range(1)
+-- !query schema
+struct
+-- !query output
+0
+
+
-- !query
SELECT * FROM
(
diff --git a/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out
index ba311c0253ab1..94ef47397eff1 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-nonlegacy.sql.out
@@ -33,6 +33,28 @@ struct
1
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT id)
+ SELECT 1
+) FROM range(1)
+-- !query schema
+struct
+-- !query output
+1
+
+
+-- !query
+SELECT (
+ WITH unreferenced AS (SELECT 1)
+ SELECT id
+) FROM range(1)
+-- !query schema
+struct
+-- !query output
+0
+
+
-- !query
SELECT * FROM
(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index 013177425da78..a93dee3bf2a61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -942,6 +942,34 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
}
}
}
+
+ test("SPARK-48498: always do char padding in predicates") {
+ import testImplicits._
+ withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false") {
+ withTempPath { dir =>
+ withTable("t") {
+ Seq(
+ "12" -> "12",
+ "12" -> "12 ",
+ "12 " -> "12",
+ "12 " -> "12 "
+ ).toDF("c1", "c2").write.format(format).save(dir.toString)
+ sql(s"CREATE TABLE t (c1 CHAR(3), c2 STRING) USING $format LOCATION '$dir'")
+          // Comparing a CHAR column with a STRING column directly compares the stored values.
+ checkAnswer(
+ sql("SELECT c1 = c2 FROM t"),
+ Seq(Row(true), Row(false), Row(false), Row(true))
+ )
+          // No matter whether the CHAR type value is padded in storage or not, we should always
+          // pad it before comparing it with STRING literals.
+ checkAnswer(
+ sql("SELECT c1 = '12', c1 = '12 ', c1 = '12 ' FROM t WHERE c2 = '12'"),
+ Seq(Row(true, true, true), Row(true, true, true))
+ )
+ }
+ }
+ }
+ }
}
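
A PySpark rendering of the scenario the test above exercises, in case the behavior is easier to see outside the Scala test harness (a sketch; assumes an active non-Hive local session named spark, uses parquet, and creates a throwaway external table t):

import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), "data")
spark.conf.set("spark.sql.readSideCharPadding", "false")
spark.createDataFrame(
    [("12", "12"), ("12", "12 "), ("12 ", "12"), ("12 ", "12 ")], ["c1", "c2"]
).write.parquet(path)
spark.sql(f"CREATE TABLE t (c1 CHAR(3), c2 STRING) USING parquet LOCATION '{path}'")
# Column-to-column comparison sees the stored values as-is ...
spark.sql("SELECT c1 = c2 FROM t").show()
# ... while comparison against string literals pads the CHAR(3) column first.
spark.sql("SELECT c1 = '12', c1 = '12 ', c1 = '12  ' FROM t WHERE c2 = '12'").show()
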
class DSV2CharVarcharTestSuite extends CharVarcharTestSuite
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
index 34c6c49bc4981..ad424b3a7cc76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
@@ -256,9 +256,11 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite {
protected def testQuery(tpcdsGroup: String, query: String, suffix: String = ""): Unit = {
val queryString = resourceToString(s"$tpcdsGroup/$query.sql",
classLoader = Thread.currentThread().getContextClassLoader)
- // Disable char/varchar read-side handling for better performance.
- withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
+ withSQLConf(
+ // Disable char/varchar read-side handling for better performance.
+ SQLConf.READ_SIDE_CHAR_PADDING.key -> "false",
+ SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
val qe = sql(queryString).queryExecution
val plan = qe.executedPlan
val explain = normalizeLocation(normalizeIds(qe.explainString(FormattedMode)))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 2bb2fe970a118..11e077e891bd7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -1340,6 +1340,15 @@ private[hive] object HiveClientImpl extends Logging {
log"will be reset to 'mr' to disable useless hive logic")
hiveConf.set("hive.execution.engine", "mr", SOURCE_SPARK)
}
+ val cpType = hiveConf.get("datanucleus.connectionPoolingType")
+    // BoneCP might cause a memory leak, which could affect some Hive client versions we support.
+    // See HIVE-15551 for more details.
+    // Also, BoneCP was removed in Hive 4.0.0, see HIVE-23258.
+    // Here we replace BoneCP with DBCP rather than HikariCP, as HikariCP was introduced in
+    // Hive 2.2.0 (see HIVE-13931) while the minimum Hive version we support is 2.0.0.
+ if ("bonecp".equalsIgnoreCase(cpType)) {
+ hiveConf.set("datanucleus.connectionPoolingType", "DBCP", SOURCE_SPARK)
+ }
hiveConf
}
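
For deployments that want to choose the pool explicitly rather than rely on this fallback, a sketch along these lines should do it (assuming the DataNucleus property reaches the metastore client via the spark.hadoop.* prefix or an equivalent hive-site.xml entry):

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("metastore-pool-example")
    # Hypothetical explicit choice; DBCP is also what Spark now falls back to when BoneCP is set.
    .config("spark.hadoop.datanucleus.connectionPoolingType", "DBCP")
    .enableHiveSupport()
    .getOrCreate()
)
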