From d6b53c34c5c586fe04e000929412d54383202c0f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 1 Nov 2023 11:53:32 -0700 Subject: [PATCH 001/121] [SPARK-45754][CORE] Support `spark.deploy.appIdPattern` ### What changes were proposed in this pull request? This PR aims to support `spark.deploy.appIdPattern` for Apache Spark 4.0.0. ### Why are the changes needed? This allows users to control the app ID pattern. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with the newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43616 from dongjoon-hyun/SPARK-45754. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/master/Master.scala | 3 ++- .../org/apache/spark/internal/config/Deploy.scala | 9 +++++++++ .../apache/spark/deploy/master/MasterSuite.scala | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 0a66cc974da7c..058b944c591ad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -54,6 +54,7 @@ private[deploy] class Master( ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") private val driverIdPattern = conf.get(DRIVER_ID_PATTERN) + private val appIdPattern = conf.get(APP_ID_PATTERN) // For application IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) @@ -1152,7 +1153,7 @@ private[deploy] class Master( /** Generate a new app ID given an app's submission date */ private def newApplicationId(submitDate: Date): String = { - val appId = "app-%s-%04d".format(createDateFormat.format(submitDate), nextAppNumber) + val appId = appIdPattern.format(createDateFormat.format(submitDate), nextAppNumber) nextAppNumber += 1 appId } diff --git a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala index bffdc79175bd9..c6ccf9550bc91 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala @@ -90,4 +90,13 @@ private[spark] object Deploy { .stringConf .checkValue(!_.format("20231101000000", 0).exists(_.isWhitespace), "Whitespace is not allowed.") .createWithDefault("driver-%s-%04d") + + val APP_ID_PATTERN = ConfigBuilder("spark.deploy.appIdPattern") + .doc("The pattern for app ID generation based on Java `String.format` method. " + "The default value is `app-%s-%04d` which represents the existing app id string, " + "e.g., `app-20231031224509-0008`.
Please be careful to generate unique IDs.") + .version("4.0.0") + .stringConf + .checkValue(!_.format("20231101000000", 0).exists(_.isWhitespace), "Whitespace is not allowed.") + .createWithDefault("app-%s-%04d") } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index cef0e84f20f7a..e8615cdbdd559 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -803,6 +803,7 @@ class MasterSuite extends SparkFunSuite PrivateMethod[mutable.ArrayBuffer[DriverInfo]](Symbol("waitingDrivers")) private val _state = PrivateMethod[RecoveryState.Value](Symbol("state")) private val _newDriverId = PrivateMethod[String](Symbol("newDriverId")) + private val _newApplicationId = PrivateMethod[String](Symbol("newApplicationId")) private val workerInfo = makeWorkerInfo(4096, 10) private val workerInfos = Array(workerInfo, workerInfo, workerInfo) @@ -1251,6 +1252,20 @@ class MasterSuite extends SparkFunSuite }.getMessage assert(m.contains("Whitespace is not allowed")) } + + test("SPARK-45754: Support app id pattern") { + val master = makeMaster(new SparkConf().set(APP_ID_PATTERN, "my-app-%2$05d")) + val submitDate = new Date() + assert(master.invokePrivate(_newApplicationId(submitDate)) === "my-app-00000") + assert(master.invokePrivate(_newApplicationId(submitDate)) === "my-app-00001") + } + + test("SPARK-45754: Prevent invalid app id patterns") { + val m = intercept[IllegalArgumentException] { + makeMaster(new SparkConf().set(APP_ID_PATTERN, "my app")) + }.getMessage + assert(m.contains("Whitespace is not allowed")) + } } private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) From b14c1f036f8f394ad1903998128c05d04dd584a9 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 1 Nov 2023 13:31:12 -0700 Subject: [PATCH 002/121] [SPARK-45763][CORE][UI] Improve `MasterPage` to show `Resource` column only when it exists MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to improve `MasterPage` to show `Resource` column only when it exists. ### Why are the changes needed? For non-GPU clusters, the `Resource` column is always empty. ### Does this PR introduce _any_ user-facing change? After this PR, `MasterPage` still shows `Resource` column if the resource exists like the following. ![Screenshot 2023-11-01 at 11 02 43 AM](https://github.com/apache/spark/assets/9700541/104dd4e7-938b-4269-8952-512e8fb5fa39) If there is no resource on all workers, the `Resource` column is omitted. ![Screenshot 2023-11-01 at 11 03 20 AM](https://github.com/apache/spark/assets/9700541/12c9d4b2-330a-4e36-a6eb-ac2813e0649a) ### How was this patch tested? Manual test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43628 from dongjoon-hyun/SPARK-45763.
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/deploy/master/ui/MasterPage.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 48c0c9601c14b..cb325b37958ec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -98,10 +98,15 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { val state = getMasterState - val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory", "Resources") + val showResourceColumn = state.workers.filter(_.resourcesInfoUsed.nonEmpty).nonEmpty + val workerHeaders = if (showResourceColumn) { + Seq("Worker Id", "Address", "State", "Cores", "Memory", "Resources") + } else { + Seq("Worker Id", "Address", "State", "Cores", "Memory") + } val workers = state.workers.sortBy(_.id) val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) - val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) + val workerTable = UIUtils.listingTable(workerHeaders, workerRow(showResourceColumn), workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Executor", "Resources Per Executor", "Submitted Time", "User", "State", "Duration") @@ -256,7 +261,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } - private def workerRow(worker: WorkerInfo): Seq[Node] = { + private def workerRow(showResourceColumn: Boolean): WorkerInfo => Seq[Node] = worker => { { @@ -276,7 +281,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {Utils.megabytesToString(worker.memory)} ({Utils.megabytesToString(worker.memoryUsed)} Used) - {formatWorkerResourcesDetails(worker)} + {if (showResourceColumn) { + {formatWorkerResourcesDetails(worker)} + }} } From 59e291d36c4a9d956b993968a324359b3d75fe5f Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 2 Nov 2023 09:11:48 +0900 Subject: [PATCH 003/121] [SPARK-45680][CONNECT] Release session ### What changes were proposed in this pull request? Introduce a new `ReleaseSession` Spark Connect RPC, which cancels everything running in the session and removes the session server side. Refactor code around managing the cache of sessions into `SparkConnectSessionManager`. ### Why are the changes needed? Better session management. ### Does this PR introduce _any_ user-facing change? Not really. `SparkSession.stop()` API already existed on the client side. It was closing the client's network connection, but the Session was still there cached for 1 hour on the server side. Caveats, which were not really supported user behaviour: * After `session.stop()`, user could have created a new session with the same session_id in Configuration. That session would be a new session on the client side, but connect to the old cached session in the server. It could therefore e.g. access that old session's state like views or artifacts. * If a session timed out and was removed in the server, it used to be that a new request would re-create the session. The client would then see this as the old session, but the server would see a new one, and e.g. not have access to old session state that was removed. 
* User is no longer allowed to create a new session with the same session_id as before. ### How was this patch tested? Tests added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43546 from juliuszsompolski/release-session. Lead-authored-by: Juliusz Sompolski Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../main/resources/error/error-classes.json | 5 + .../org/apache/spark/sql/SparkSession.scala | 8 + .../spark/sql/PlanGenerationTestSuite.scala | 4 +- .../apache/spark/sql/SparkSessionSuite.scala | 38 ++-- .../main/protobuf/spark/connect/base.proto | 30 +++ .../CustomSparkConnectBlockingStub.scala | 11 ++ .../connect/client/SparkConnectClient.scala | 10 + .../spark/sql/connect/config/Connect.scala | 18 ++ .../sql/connect/service/SessionHolder.scala | 79 +++++++- .../SparkConnectExecutionManager.scala | 23 ++- .../SparkConnectReleaseExecuteHandler.scala | 4 +- .../SparkConnectReleaseSessionHandler.scala | 40 ++++ .../connect/service/SparkConnectService.scala | 117 +++--------- .../service/SparkConnectSessionManager.scala | 177 ++++++++++++++++++ .../spark/sql/connect/utils/ErrorUtils.scala | 27 +-- .../sql/connect/SparkConnectServerTest.scala | 21 ++- .../execution/ReattachableExecuteSuite.scala | 4 + .../planner/SparkConnectServiceSuite.scala | 4 +- .../service/SparkConnectServiceE2ESuite.scala | 158 ++++++++++++++++ ...r-conditions-invalid-handle-error-class.md | 4 + python/pyspark/sql/connect/client/core.py | 23 ++- python/pyspark/sql/connect/proto/base_pb2.py | 42 +++-- python/pyspark/sql/connect/proto/base_pb2.pyi | 78 ++++++++ .../sql/connect/proto/base_pb2_grpc.py | 49 +++++ python/pyspark/sql/connect/session.py | 12 +- .../sql/tests/connect/test_connect_basic.py | 1 + 26 files changed, 819 insertions(+), 168 deletions(-) create mode 100644 connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala create mode 100644 connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 278011b8cc8f4..af32bcf129c08 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -1737,6 +1737,11 @@ "Session already exists." ] }, + "SESSION_CLOSED" : { + "message" : [ + "Session was closed." + ] + }, "SESSION_NOT_FOUND" : { "message" : [ "Session not found." diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 969ac017ecb1d..1cc1c8400fa89 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -665,6 +665,9 @@ class SparkSession private[sql] ( * @since 3.4.0 */ override def close(): Unit = { + if (releaseSessionOnClose) { + client.releaseSession() + } client.shutdown() allocator.close() SparkSession.onSessionClose(this) @@ -735,6 +738,11 @@ class SparkSession private[sql] ( * We null out the instance for now. */ private def writeReplace(): Any = null + + /** + * Set to false to prevent client.releaseSession on close() (testing only) + */ + private[sql] var releaseSessionOnClose = true } // The minimal builder needed to create a spark session. 
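To make the release-on-close behaviour from the `SparkSession.scala` change above concrete, here is a minimal client-side sketch. It is not part of the patch; it only illustrates the user-visible effect, and it assumes a Spark Connect server reachable at the default `sc://localhost:15002` endpoint (the object name `ReleaseOnCloseExample` is illustrative).

```scala
// Minimal sketch, assuming a Spark Connect server at the default local endpoint.
// Not part of the patch; it only illustrates the effect of close() now issuing
// a ReleaseSession RPC before shutting down the client.
import org.apache.spark.sql.SparkSession

object ReleaseOnCloseExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
    spark.range(10).count()

    // close() sends ReleaseSession before tearing down the gRPC channel, so the
    // server drops the session state (running executions, views, artifacts)
    // immediately instead of keeping it cached until the idle timeout.
    spark.close()

    // A later getOrCreate() builds a session with a fresh session_id; reusing the
    // released session_id would be rejected with INVALID_HANDLE.SESSION_CLOSED.
    val spark2 = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
    spark2.range(5).count()
    spark2.close()
  }
}
```

For test environments where the connection details are dummy, the diff above also adds the internal `releaseSessionOnClose` flag so that `close()` can skip the RPC, as used in `PlanGenerationTestSuite` and `SparkSessionSuite` below.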
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index cf287088b59fb..5cc63bc45a04a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -120,7 +120,9 @@ class PlanGenerationTestSuite } override protected def afterAll(): Unit = { - session.close() + // Don't call client.releaseSession on close(), because the connection details are dummy. + session.releaseSessionOnClose = false + session.stop() if (cleanOrphanedGoldenFiles) { cleanOrphanedGoldenFile() } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 4c858262c6ef5..8abc41639fdd2 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -33,18 +33,24 @@ class SparkSessionSuite extends ConnectFunSuite { private val connectionString2: String = "sc://test.me:14099" private val connectionString3: String = "sc://doit:16845" + private def closeSession(session: SparkSession): Unit = { + // Don't call client.releaseSession on close(), because the connection details are dummy. + session.releaseSessionOnClose = false + session.close() + } + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") assert(session.client.configuration.port == 15002) - session.close() + closeSession(session) } test("remote") { val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) - session.close() + closeSession(session) } test("getOrCreate") { @@ -53,8 +59,8 @@ class SparkSessionSuite extends ConnectFunSuite { try { assert(session1 eq session2) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -65,8 +71,8 @@ class SparkSessionSuite extends ConnectFunSuite { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -77,8 +83,8 @@ class SparkSessionSuite extends ConnectFunSuite { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -98,7 +104,7 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } - session.close() + closeSession(session) } test("Default/Active session") { @@ -136,12 +142,12 @@ class SparkSessionSuite extends ConnectFunSuite { assert(SparkSession.getActiveSession.contains(session1)) // Close session1 - session1.close() + closeSession(session1) assert(SparkSession.getDefaultSession.contains(session2)) assert(SparkSession.getActiveSession.isEmpty) // Close session2 - session2.close() + closeSession(session2) assert(SparkSession.getDefaultSession.isEmpty) assert(SparkSession.getActiveSession.isEmpty) } @@ -187,7 +193,7 @@ class SparkSessionSuite extends 
ConnectFunSuite { // Step 3 - close session 1, no more default session in both scripts phaser.arriveAndAwaitAdvance() - session1.close() + closeSession(session1) // Step 4 - no default session, same active session. phaser.arriveAndAwaitAdvance() @@ -240,13 +246,13 @@ class SparkSessionSuite extends ConnectFunSuite { // Step 7 - close active session in script2 phaser.arriveAndAwaitAdvance() - internalSession.close() + closeSession(internalSession) assert(SparkSession.getActiveSession.isEmpty) } assert(script1.get()) assert(script2.get()) assert(SparkSession.getActiveSession.contains(session2)) - session2.close() + closeSession(session2) assert(SparkSession.getActiveSession.isEmpty) } finally { executor.shutdown() @@ -254,13 +260,13 @@ class SparkSessionSuite extends ConnectFunSuite { } test("deprecated methods") { - SparkSession + val session = SparkSession .builder() .master("yayay") .appName("bob") .enableHiveSupport() .create() - .close() + closeSession(session) } test("serialize as null") { diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 27f51551ba921..19a94a5a429f0 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -784,6 +784,30 @@ message ReleaseExecuteResponse { optional string operation_id = 2; } +message ReleaseSessionRequest { + // (Required) + // + // The session_id of the request to reattach to. + // This must be an id of existing session. + string session_id = 1; + + // (Required) User context + // + // user_context.user_id and session+id both identify a unique remote spark session on the + // server side. + UserContext user_context = 2; + + // Provides optional information about the client sending the request. This field + // can be used for language or version specific information and is only intended for + // logging purposes and will not be interpreted by the server. + optional string client_type = 3; +} + +message ReleaseSessionResponse { + // Session id of the session on which the release executed. + string session_id = 1; +} + message FetchErrorDetailsRequest { // (Required) @@ -934,6 +958,12 @@ service SparkConnectService { // RPC and ReleaseExecute may not be used. rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {} + // Release a session. + // All the executions in the session will be released. Any further requests for the session with + // that session_id for the given user_id will fail. If the session didn't exist or was already + // released, this is a noop. + rpc ReleaseSession(ReleaseSessionRequest) returns (ReleaseSessionResponse) {} + // FetchErrorDetails retrieves the matched exception with details based on a provided error id. 
rpc FetchErrorDetails(FetchErrorDetailsRequest) returns (FetchErrorDetailsResponse) {} } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index f2efa26f6b609..e963b4136160f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -96,6 +96,17 @@ private[connect] class CustomSparkConnectBlockingStub( } } + def releaseSession(request: ReleaseSessionRequest): ReleaseSessionResponse = { + grpcExceptionConverter.convert( + request.getSessionId, + request.getUserContext, + request.getClientType) { + retryHandler.retry { + stub.releaseSession(request) + } + } + } + def artifactStatus(request: ArtifactStatusesRequest): ArtifactStatusesResponse = { grpcExceptionConverter.convert( request.getSessionId, diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 42ace003da89f..6d3d9420e2263 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -243,6 +243,16 @@ private[sql] class SparkConnectClient( bstub.interrupt(request) } + private[sql] def releaseSession(): proto.ReleaseSessionResponse = { + val builder = proto.ReleaseSessionRequest.newBuilder() + val request = builder + .setUserContext(userContext) + .setSessionId(sessionId) + .setClientType(userAgent) + .build() + bstub.releaseSession(request) + } + private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] { override def childValue(parent: mutable.Set[String]): mutable.Set[String] = { // Note: make a clone such that changes in the parent tags aren't reflected in diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 2b3f218362cd3..1a5944676f5fb 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -74,6 +74,24 @@ object Connect { .intConf .createWithDefault(1024) + val CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT = + buildStaticConf("spark.connect.session.manager.defaultSessionTimeout") + .internal() + .doc("Timeout after which sessions without any new incoming RPC will be removed.") + .version("4.0.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("60m") + + val CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE = + buildStaticConf("spark.connect.session.manager.closedSessionsTombstonesSize") + .internal() + .doc( + "Maximum size of the cache of sessions after which sessions that did not receive any " + + "requests will be removed.") + .version("4.0.0") + .intConf + .createWithDefaultString("1000") + val CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT = buildStaticConf("spark.connect.execute.manager.detachedTimeout") .internal() diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index dcced21f37148..792012a682b28 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder -import org.apache.spark.{JobArtifactSet, SparkException} +import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException} import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession @@ -40,12 +40,19 @@ import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils +// Unique key identifying session by combination of user, and session id +case class SessionKey(userId: String, sessionId: String) + /** * Object used to hold the Spark Connect session state. */ case class SessionHolder(userId: String, sessionId: String, session: SparkSession) extends Logging { + @volatile private var lastRpcAccessTime: Option[Long] = None + + @volatile private var isClosing: Boolean = false + private val executions: ConcurrentMap[String, ExecuteHolder] = new ConcurrentHashMap[String, ExecuteHolder]() @@ -73,8 +80,21 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private[connect] lazy val streamingForeachBatchRunnerCleanerCache = new StreamingForeachBatchHelper.CleanerCache(this) - /** Add ExecuteHolder to this session. Called only by SparkConnectExecutionManager. */ + def key: SessionKey = SessionKey(userId, sessionId) + + /** + * Add ExecuteHolder to this session. + * + * Called only by SparkConnectExecutionManager under executionsLock. + */ private[service] def addExecuteHolder(executeHolder: ExecuteHolder): Unit = { + if (isClosing) { + // Do not accept new executions if the session is closing. + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> sessionId)) + } + val oldExecute = executions.putIfAbsent(executeHolder.operationId, executeHolder) if (oldExecute != null) { // the existence of this should alrady be checked by SparkConnectExecutionManager @@ -160,21 +180,55 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ def classloader: ClassLoader = artifactManager.classloader + private[connect] def updateAccessTime(): Unit = { + lastRpcAccessTime = Some(System.currentTimeMillis()) + } + + /** + * Initialize the session. + * + * Called only by SparkConnectSessionManager. + */ private[connect] def initializeSession(): Unit = { + updateAccessTime() eventManager.postStarted() } /** * Expire this session and trigger state cleanup mechanisms. + * + * Called only by SparkConnectSessionManager. */ - private[connect] def expireSession(): Unit = { - logDebug(s"Expiring session with userId: $userId and sessionId: $sessionId") + private[connect] def close(): Unit = { + logInfo(s"Closing session with userId: $userId and sessionId: $sessionId") + + // After isClosing=true, SessionHolder.addExecuteHolder() will not allow new executions for + // this session. 
Because both SessionHolder.addExecuteHolder() and + // SparkConnectExecutionManager.removeAllExecutionsForSession() are executed under + // executionsLock, this guarantees that removeAllExecutionsForSession triggered below will + // remove all executions and no new executions will be added in the meanwhile. + isClosing = true + + // Note on the below notes about concurrency: + // While closing the session can potentially race with operations started on the session, the + // intended use is that the client session will get closed when it's really not used anymore, + // or that it expires due to inactivity, in which case there should be no races. + + // Clean up all artifacts. + // Note: there can be concurrent AddArtifact calls still adding something. artifactManager.cleanUpResources() - eventManager.postClosed() - // Clean up running queries + + // Clean up running streaming queries. + // Note: there can be concurrent streaming queries being started. SparkConnectService.streamingSessionManager.cleanupRunningQueries(this) streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers. removeAllListeners() // removes all listener and stop python listener processes if necessary. + + // Clean up all executions + // It is guaranteed at this point that no new addExecuteHolder are getting started. + SparkConnectService.executionManager.removeAllExecutionsForSession(this.key) + + eventManager.postClosed() } /** @@ -204,6 +258,10 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } } + /** Get SessionInfo with information about this SessionHolder. */ + def getSessionHolderInfo: SessionHolderInfo = + SessionHolderInfo(userId, sessionId, eventManager.status, lastRpcAccessTime) + /** * Caches given DataFrame with the ID. The cache does not expire. The entry needs to be * explicitly removed by the owners of the DataFrame once it is not needed. @@ -291,7 +349,14 @@ object SessionHolder { userId = "testUser", sessionId = UUID.randomUUID().toString, session = session) - SparkConnectService.putSessionForTesting(ret) + SparkConnectService.sessionManager.putSessionForTesting(ret) ret } } + +/** Basic information about SessionHolder. */ +case class SessionHolderInfo( + userId: String, + sessionId: String, + status: SessionStatus, + lastRpcAccesTime: Option[Long]) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index 3c72548978222..c004358e1cf18 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -95,11 +95,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * Remove an ExecuteHolder from this global manager and from its session. Interrupt the * execution if still running, free all resources. 
*/ - private[connect] def removeExecuteHolder(key: ExecuteKey): Unit = { + private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean = false): Unit = { var executeHolder: Option[ExecuteHolder] = None executionsLock.synchronized { executeHolder = executions.remove(key) - executeHolder.foreach(e => e.sessionHolder.removeExecuteHolder(e.operationId)) + executeHolder.foreach { e => + if (abandoned) { + abandonedTombstones.put(key, e.getExecuteInfo) + } + e.sessionHolder.removeExecuteHolder(e.operationId) + } if (executions.isEmpty) { lastExecutionTime = Some(System.currentTimeMillis()) } @@ -115,6 +120,17 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } } + private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = { + val sessionExecutionHolders = executionsLock.synchronized { + executions.filter(_._2.sessionHolder.key == key) + } + sessionExecutionHolders.foreach { case (_, executeHolder) => + val info = executeHolder.getExecuteInfo + logInfo(s"Execution $info removed in removeSessionExecutions.") + removeExecuteHolder(executeHolder.key, abandoned = true) + } + } + /** Get info about abandoned execution, if there is one. */ private[connect] def getAbandonedTombstone(key: ExecuteKey): Option[ExecuteInfo] = { Option(abandonedTombstones.getIfPresent(key)) @@ -204,8 +220,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { toRemove.foreach { executeHolder => val info = executeHolder.getExecuteInfo logInfo(s"Found execution $info that was abandoned and expired and will be removed.") - removeExecuteHolder(executeHolder.key) - abandonedTombstones.put(executeHolder.key, info) + removeExecuteHolder(executeHolder.key, abandoned = true) } } logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.") diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala index a3a7815609e40..1ca886960d536 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala @@ -28,8 +28,8 @@ class SparkConnectReleaseExecuteHandler( extends Logging { def handle(v: proto.ReleaseExecuteRequest): Unit = { - val sessionHolder = SparkConnectService - .getIsolatedSession(v.getUserContext.getUserId, v.getSessionId) + val sessionHolder = SparkConnectService.sessionManager + .getIsolatedSession(SessionKey(v.getUserContext.getUserId, v.getSessionId)) val responseBuilder = proto.ReleaseExecuteResponse .newBuilder() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala new file mode 100644 index 0000000000000..a32852bac45ea --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging + +class SparkConnectReleaseSessionHandler( + responseObserver: StreamObserver[proto.ReleaseSessionResponse]) + extends Logging { + + def handle(v: proto.ReleaseSessionRequest): Unit = { + val responseBuilder = proto.ReleaseSessionResponse.newBuilder() + responseBuilder.setSessionId(v.getSessionId) + + // If the session doesn't exist, this will just be a noop. + val key = SessionKey(v.getUserContext.getUserId, v.getSessionId) + SparkConnectService.sessionManager.closeSession(key) + + responseObserver.onNext(responseBuilder.build()) + responseObserver.onCompleted() + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index e82c9cba56264..e4b60eeeff0d6 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -18,13 +18,10 @@ package org.apache.spark.sql.connect.service import java.net.InetSocketAddress -import java.util.UUID -import java.util.concurrent.{Callable, TimeUnit} +import java.util.concurrent.TimeUnit import scala.jdk.CollectionConverters._ -import com.google.common.base.Ticker -import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification} import com.google.protobuf.MessageLite import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition} import io.grpc.MethodDescriptor.PrototypeMarshaller @@ -34,13 +31,12 @@ import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.commons.lang3.StringUtils -import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException} +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc} import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService import org.apache.spark.internal.Logging import org.apache.spark.internal.config.UI.UI_ENABLED -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE} import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab} import org.apache.spark.sql.connect.utils.ErrorUtils @@ -201,6 +197,22 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ sessionId = request.getSessionId) } + /** + * Release session. 
+ */ + override def releaseSession( + request: proto.ReleaseSessionRequest, + responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = { + try { + new SparkConnectReleaseSessionHandler(responseObserver).handle(request) + } catch + ErrorUtils.handleError( + "releaseSession", + observer = responseObserver, + userId = request.getUserContext.getUserId, + sessionId = request.getSessionId) + } + override def fetchErrorDetails( request: proto.FetchErrorDetailsRequest, responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit = { @@ -268,14 +280,6 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ */ object SparkConnectService extends Logging { - private val CACHE_SIZE = 100 - - private val CACHE_TIMEOUT_SECONDS = 3600 - - // Type alias for the SessionCacheKey. Right now this is a String but allows us to switch to a - // different or complex type easily. - private type SessionCacheKey = (String, String) - private[connect] var server: Server = _ private[connect] var uiTab: Option[SparkConnectServerTab] = None @@ -289,77 +293,18 @@ object SparkConnectService extends Logging { server.getPort } - private val userSessionMapping = - cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, SessionHolder]() - private[connect] lazy val executionManager = new SparkConnectExecutionManager() + private[connect] lazy val sessionManager = new SparkConnectSessionManager() + private[connect] val streamingSessionManager = new SparkConnectStreamingQueryCache() - private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] { - override def onRemoval( - notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = { - notification.getValue.expireSession() - } - } - - // Simple builder for creating the cache of Sessions. - private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int): CacheBuilder[Object, Object] = { - var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker()) - if (cacheSize >= 0) { - cacheBuilder = cacheBuilder.maximumSize(cacheSize) - } - if (timeoutSeconds >= 0) { - cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS) - } - cacheBuilder.removalListener(new RemoveSessionListener) - cacheBuilder - } - /** * Based on the userId and sessionId, find or create a new SparkSession. */ def getOrCreateIsolatedSession(userId: String, sessionId: String): SessionHolder = { - getSessionOrDefault( - userId, - sessionId, - () => { - val holder = SessionHolder(userId, sessionId, newIsolatedSession()) - holder.initializeSession() - holder - }) - } - - /** - * Based on the userId and sessionId, find an existing SparkSession or throw error. - */ - def getIsolatedSession(userId: String, sessionId: String): SessionHolder = { - getSessionOrDefault( - userId, - sessionId, - () => { - logDebug(s"Session not found: ($userId, $sessionId)") - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND", - messageParameters = Map("handle" -> sessionId)) - }) - } - - private def getSessionOrDefault( - userId: String, - sessionId: String, - default: Callable[SessionHolder]): SessionHolder = { - // Validate that sessionId is formatted like UUID before creating session. 
- try { - UUID.fromString(sessionId).toString - } catch { - case _: IllegalArgumentException => - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.FORMAT", - messageParameters = Map("handle" -> sessionId)) - } - userSessionMapping.get((userId, sessionId), default) + sessionManager.getOrCreateIsolatedSession(SessionKey(userId, sessionId)) } /** @@ -368,24 +313,6 @@ object SparkConnectService extends Logging { */ def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = executionManager.listActiveExecutions - /** - * Used for testing - */ - private[connect] def invalidateAllSessions(): Unit = { - userSessionMapping.invalidateAll() - } - - /** - * Used for testing. - */ - private[connect] def putSessionForTesting(sessionHolder: SessionHolder): Unit = { - userSessionMapping.put((sessionHolder.userId, sessionHolder.sessionId), sessionHolder) - } - - private def newIsolatedSession(): SparkSession = { - SparkSession.active.newSession() - } - private def createListenerAndUI(sc: SparkContext): Unit = { val kvStore = sc.statusStore.store.asInstanceOf[ElementTrackingStore] listener = new SparkConnectServerListener(kvStore, sc.conf) @@ -445,7 +372,7 @@ object SparkConnectService extends Logging { } streamingSessionManager.shutdown() executionManager.shutdown() - userSessionMapping.invalidateAll() + sessionManager.shutdown() uiTab.foreach(_.detach()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala new file mode 100644 index 0000000000000..5c8e3c611586c --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import java.util.UUID +import java.util.concurrent.{Callable, TimeUnit} + +import com.google.common.base.Ticker +import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification} + +import org.apache.spark.{SparkEnv, SparkSQLException} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.config.Connect.{CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE, CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT} + +/** + * Global tracker of all SessionHolders holding Spark Connect sessions. 
+ */ +class SparkConnectSessionManager extends Logging { + + private val sessionsLock = new Object + + private val sessionStore = + CacheBuilder + .newBuilder() + .ticker(Ticker.systemTicker()) + .expireAfterAccess( + SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT), + TimeUnit.MILLISECONDS) + .removalListener(new RemoveSessionListener) + .build[SessionKey, SessionHolder]() + + private val closedSessionsCache = + CacheBuilder + .newBuilder() + .maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE)) + .build[SessionKey, SessionHolderInfo]() + + /** + * Based on the userId and sessionId, find or create a new SparkSession. + */ + private[connect] def getOrCreateIsolatedSession(key: SessionKey): SessionHolder = { + // Lock to guard against concurrent removal and insertion into closedSessionsCache. + sessionsLock.synchronized { + getSession( + key, + Some(() => { + validateSessionCreate(key) + val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession()) + holder.initializeSession() + holder + })) + } + } + + /** + * Based on the userId and sessionId, find an existing SparkSession or throw error. + */ + private[connect] def getIsolatedSession(key: SessionKey): SessionHolder = { + getSession( + key, + Some(() => { + logDebug(s"Session not found: $key") + if (closedSessionsCache.getIfPresent(key) != null) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> key.sessionId)) + } else { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND", + messageParameters = Map("handle" -> key.sessionId)) + } + })) + } + + /** + * Based on the userId and sessionId, get an existing SparkSession if present. + */ + private[connect] def getIsolatedSessionIfPresent(key: SessionKey): Option[SessionHolder] = { + Option(getSession(key, None)) + } + + private def getSession( + key: SessionKey, + default: Option[Callable[SessionHolder]]): SessionHolder = { + val session = default match { + case Some(callable) => sessionStore.get(key, callable) + case None => sessionStore.getIfPresent(key) + } + // record access time before returning + session match { + case null => + null + case s: SessionHolder => + s.updateAccessTime() + s + } + } + + def closeSession(key: SessionKey): Unit = { + // Invalidate will trigger RemoveSessionListener + sessionStore.invalidate(key) + } + + private class RemoveSessionListener extends RemovalListener[SessionKey, SessionHolder] { + override def onRemoval(notification: RemovalNotification[SessionKey, SessionHolder]): Unit = { + val sessionHolder = notification.getValue + sessionsLock.synchronized { + // First put into closedSessionsCache, so that it cannot get accidentally recreated by + // getOrCreateIsolatedSession. + closedSessionsCache.put(sessionHolder.key, sessionHolder.getSessionHolderInfo) + } + // Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by + // getOrCreateIsolatedSession. + sessionHolder.close() + } + } + + def shutdown(): Unit = { + sessionsLock.synchronized { + sessionStore.invalidateAll() + closedSessionsCache.invalidateAll() + } + } + + private def newIsolatedSession(): SparkSession = { + SparkSession.active.newSession() + } + + private def validateSessionCreate(key: SessionKey): Unit = { + // Validate that sessionId is formatted like UUID before creating session. 
+ try { + UUID.fromString(key.sessionId).toString + } catch { + case _: IllegalArgumentException => + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.FORMAT", + messageParameters = Map("handle" -> key.sessionId)) + } + // Validate that session with that key has not been already closed. + if (closedSessionsCache.getIfPresent(key) != null) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> key.sessionId)) + } + } + + /** + * Used for testing + */ + private[connect] def invalidateAllSessions(): Unit = { + sessionStore.invalidateAll() + closedSessionsCache.invalidateAll() + } + + /** + * Used for testing. + */ + private[connect] def putSessionForTesting(sessionHolder: SessionHolder): Unit = { + sessionStore.put(sessionHolder.key, sessionHolder) + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 741fa97f17878..837ee5a00227c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -41,7 +41,7 @@ import org.apache.spark.api.python.PythonException import org.apache.spark.connect.proto.FetchErrorDetailsResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect -import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SparkConnectService} +import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService} import org.apache.spark.sql.internal.SQLConf private[connect] object ErrorUtils extends Logging { @@ -153,7 +153,9 @@ private[connect] object ErrorUtils extends Logging { .build() } - private def buildStatusFromThrowable(st: Throwable, sessionHolder: SessionHolder): RPCStatus = { + private def buildStatusFromThrowable( + st: Throwable, + sessionHolderOpt: Option[SessionHolder]): RPCStatus = { val errorInfo = ErrorInfo .newBuilder() .setReason(st.getClass.getName) @@ -162,20 +164,20 @@ private[connect] object ErrorUtils extends Logging { "classes", JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) - if (sessionHolder.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)) { + if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) { // Generate a new unique key for this exception. 
val errorId = UUID.randomUUID().toString errorInfo.putMetadata("errorId", errorId) - sessionHolder.errorIdToError + sessionHolderOpt.get.errorIdToError .put(errorId, st) } lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st)) val withStackTrace = - if (sessionHolder.session.conf.get( - SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty) { + if (sessionHolderOpt.exists( + _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty)) { val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE) errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) } else { @@ -215,19 +217,22 @@ private[connect] object ErrorUtils extends Logging { sessionId: String, events: Option[ExecuteEventsManager] = None, isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = { - val sessionHolder = - SparkConnectService - .getOrCreateIsolatedSession(userId, sessionId) + + // SessionHolder may not be present, e.g. if the session was already closed. + // When SessionHolder is not present error details will not be available for FetchErrorDetails. + val sessionHolderOpt = + SparkConnectService.sessionManager.getIsolatedSessionIfPresent( + SessionKey(userId, sessionId)) val partial: PartialFunction[Throwable, (Throwable, Throwable)] = { case se: SparkException if isPythonExecutionException(se) => ( se, StatusProto.toStatusRuntimeException( - buildStatusFromThrowable(se.getCause, sessionHolder))) + buildStatusFromThrowable(se.getCause, sessionHolderOpt))) case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => - (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolder))) + (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolderOpt))) case e: Throwable => ( diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index 7b02377f4847c..120126f20ec24 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -59,10 +59,6 @@ trait SparkConnectServerTest extends SharedSparkSession { withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, serverPort.toString)) { SparkConnectService.start(spark.sparkContext) } - // register udf directly on the server, we're not testing client UDFs here... 
- val serverSession = - SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session - serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms })) } override def afterAll(): Unit = { @@ -84,6 +80,7 @@ trait SparkConnectServerTest extends SharedSparkSession { protected def clearAllExecutions(): Unit = { SparkConnectService.executionManager.listExecuteHolders.foreach(_.close()) SparkConnectService.executionManager.periodicMaintenance(0) + SparkConnectService.sessionManager.invalidateAllSessions() assertNoActiveExecutions() } @@ -215,12 +212,24 @@ trait SparkConnectServerTest extends SharedSparkSession { } } + protected def withClient(sessionId: String = defaultSessionId, userId: String = defaultUserId)( + f: SparkConnectClient => Unit): Unit = { + withClient(f, sessionId, userId) + } + protected def withClient(f: SparkConnectClient => Unit): Unit = { + withClient(f, defaultSessionId, defaultUserId) + } + + protected def withClient( + f: SparkConnectClient => Unit, + sessionId: String, + userId: String): Unit = { val client = SparkConnectClient .builder() .port(serverPort) - .sessionId(defaultSessionId) - .userId(defaultUserId) + .sessionId(sessionId) + .userId(userId) .enableReattachableExecute() .build() try f(client) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala index 0e29a07b719af..784b978f447df 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala @@ -347,6 +347,10 @@ class ReattachableExecuteSuite extends SparkConnectServerTest { } test("long sleeping query") { + // register udf directly on the server, we're not testing client UDFs here... 
+ val serverSession = + SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session + serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms })) // query will be sleeping and not returning results, while having multiple reattach withSparkEnvConfs( (Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key, "1s")) { diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index ce452623e6b84..b314e7d8d4834 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -841,12 +841,12 @@ class SparkConnectServiceSuite spark.sparkContext.addSparkListener(verifyEvents.listener) Utils.tryWithSafeFinally({ f(verifyEvents) - SparkConnectService.invalidateAllSessions() + SparkConnectService.sessionManager.invalidateAllSessions() verifyEvents.onSessionClosed() }) { verifyEvents.waitUntilEmpty() spark.sparkContext.removeSparkListener(verifyEvents.listener) - SparkConnectService.invalidateAllSessions() + SparkConnectService.sessionManager.invalidateAllSessions() SparkConnectPluginRegistry.reset() } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala index 14ecc9a2e95e4..cc0481dab0f4f 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala @@ -16,13 +16,171 @@ */ package org.apache.spark.sql.connect.service +import java.util.UUID + import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkException import org.apache.spark.sql.connect.SparkConnectServerTest class SparkConnectServiceE2ESuite extends SparkConnectServerTest { + // Making results of these queries large enough, so that all the results do not fit in the + // buffers and are not pushed out immediately even when the client doesn't consume them, so that + // even if the connection got closed, the client would see it as succeeded because the results + // were all already in the buffer. + val BIG_ENOUGH_QUERY = "select * from range(1000000)" + + test("ReleaseSession releases all queries and does not allow more requests in the session") { + withClient { client => + val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY)) + val query2 = client.execute(buildPlan(BIG_ENOUGH_QUERY)) + val query3 = client.execute(buildPlan("select 1")) + // just creating the iterator is lazy, trigger query1 and query2 to be sent. + query1.hasNext + query2.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 2 + } + + // Close session + client.releaseSession() + + // Check that queries get cancelled + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 0 + // SparkConnectService.sessionManager. 
+ } + + // query1 and query2 could get either an: + // OPERATION_CANCELED if it happens fast - when closing the session interrupted the queries, + // and that error got pushed to the client buffers before the client got disconnected. + // OPERATION_ABANDONED if it happens slow - when closing the session interrupted the client + // RPCs before it pushed out the error above. The client would then get an + // INVALID_CURSOR.DISCONNECTED, which it will retry with a ReattachExecute, and then get an + // INVALID_HANDLE.OPERATION_ABANDONED. + val query1Error = intercept[SparkException] { + while (query1.hasNext) query1.next() + } + assert( + query1Error.getMessage.contains("OPERATION_CANCELED") || + query1Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + val query2Error = intercept[SparkException] { + while (query2.hasNext) query2.next() + } + assert( + query2Error.getMessage.contains("OPERATION_CANCELED") || + query2Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + + // query3 has not been submitted before, so it should now fail with SESSION_CLOSED + val query3Error = intercept[SparkException] { + query3.hasNext + } + assert(query3Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + + // No other requests should be allowed in the session, failing with SESSION_CLOSED + val requestError = intercept[SparkException] { + client.interruptAll() + } + assert(requestError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + + private def testReleaseSessionTwoSessions( + sessionIdA: String, + userIdA: String, + sessionIdB: String, + userIdB: String): Unit = { + withClient(sessionId = sessionIdA, userId = userIdA) { clientA => + withClient(sessionId = sessionIdB, userId = userIdB) { clientB => + val queryA = clientA.execute(buildPlan(BIG_ENOUGH_QUERY)) + val queryB = clientB.execute(buildPlan(BIG_ENOUGH_QUERY)) + // just creating the iterator is lazy, trigger query1 and query2 to be sent. + queryA.hasNext + queryB.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 2 + } + // Close session A + clientA.releaseSession() + + // A's query gets kicked out. + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 1 + } + val queryAError = intercept[SparkException] { + while (queryA.hasNext) queryA.next() + } + assert( + queryAError.getMessage.contains("OPERATION_CANCELED") || + queryAError.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + + // B's query can run. + while (queryB.hasNext) queryB.next() + + // B can submit more queries. + val queryB2 = clientB.execute(buildPlan("SELECT 1")) + while (queryB2.hasNext) queryB2.next() + // A can't submit more queries. 
+ val queryA2 = clientA.execute(buildPlan("SELECT 1")) + val queryA2Error = intercept[SparkException] { + clientA.interruptAll() + } + assert(queryA2Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + } + + test("ReleaseSession for different user_id with same session_id do not affect each other") { + testReleaseSessionTwoSessions(defaultSessionId, "A", defaultSessionId, "B") + } + + test("ReleaseSession for different session_id with same user_id do not affect each other") { + val sessionIdA = UUID.randomUUID.toString() + val sessionIdB = UUID.randomUUID.toString() + testReleaseSessionTwoSessions(sessionIdA, "X", sessionIdB, "X") + } + + test("ReleaseSession: can't create a new session with the same id and user after release") { + val sessionId = UUID.randomUUID.toString() + val userId = "Y" + withClient(sessionId = sessionId, userId = userId) { client => + // this will create the session, and then ReleaseSession at the end of withClient. + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = sessionId, userId = userId) { client => + // shall not be able to create a new session with the same id and user. + val query = client.execute(buildPlan("SELECT 1")) + val queryError = intercept[SparkException] { + while (query.hasNext) query.next() + } + assert(queryError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + + test("ReleaseSession: session with different session_id or user_id allowed after release") { + val sessionId = UUID.randomUUID.toString() + val userId = "Y" + withClient(sessionId = sessionId, userId = userId) { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = UUID.randomUUID.toString, userId = userId) { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = sessionId, userId = "YY") { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + } + test("SPARK-45133 query should reach FINISHED state when results are not consumed") { withRawBlockingStub { stub => val iter = diff --git a/docs/sql-error-conditions-invalid-handle-error-class.md b/docs/sql-error-conditions-invalid-handle-error-class.md index c4cbb48035ff5..14526cd53724f 100644 --- a/docs/sql-error-conditions-invalid-handle-error-class.md +++ b/docs/sql-error-conditions-invalid-handle-error-class.md @@ -45,6 +45,10 @@ Operation not found. Session already exists. +## SESSION_CLOSED + +Session was closed. + ## SESSION_NOT_FOUND Session not found. 
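To make the client-facing effect of ReleaseSession concrete, here is a minimal PySpark sketch. It is not part of this patch: the connection string, the toy query, and the variable names are assumptions for illustration, and it relies only on behavior introduced by this change (the release_session() client call added below and the updated SparkSession.stop()). Once a session is released, any request that reuses the same session_id and user_id fails with INVALID_HANDLE.SESSION_CLOSED, while a different session_id or user_id can still start a fresh session, as the tests above exercise.

    # Minimal sketch, assuming a reachable Spark Connect server at sc://localhost
    # and a pyspark installation with the Spark Connect client dependencies.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    spark.range(10).count()

    # stop() now also sends ReleaseSession to the server (unless the testing-only
    # release_session_on_close flag is set to False), so further requests reusing
    # the same session_id and user_id are rejected with INVALID_HANDLE.SESSION_CLOSED.
    spark.stop()

    # Creating another session afterwards is still allowed because it gets a new
    # session_id; only the released (session_id, user_id) pair stays blocked.
    spark2 = SparkSession.builder.remote("sc://localhost").getOrCreate()
    spark2.range(10).count()
    spark2.stop()
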
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 318f7d7ade4a2..11a1112ad1fe7 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -19,7 +19,7 @@ "SparkConnectClient", ] -from pyspark.loose_version import LooseVersion + from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) @@ -61,6 +61,7 @@ from google.protobuf import text_format from google.rpc import error_details_pb2 +from pyspark.loose_version import LooseVersion from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation from pyspark.sql.connect.client.artifact import ArtifactManager @@ -1471,6 +1472,26 @@ def interrupt_operation(self, op_id: str) -> Optional[List[str]]: except Exception as error: self._handle_error(error) + def release_session(self) -> None: + req = pb2.ReleaseSessionRequest() + req.session_id = self._session_id + req.client_type = self._builder.userAgent + if self._user_id: + req.user_context.user_id = self._user_id + try: + for attempt in self._retrying(): + with attempt: + resp = self._stub.ReleaseSession(req, metadata=self._builder.metadata()) + if resp.session_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request:" + f"{resp.session_id} != {self._session_id}" + ) + return + raise SparkConnectException("Invalid state during retry exception handling.") + except Exception as error: + self._handle_error(error) + def add_tag(self, tag: str) -> None: self._throw_if_invalid_tag(tag) if not hasattr(self.thread_local, "tags"): diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 0ea02525f78ff..0e374e7aa2ccb 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t 
\x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 
\x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n 
\x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 
\x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xbe\x0b\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 
\x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xd1\x06\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 
\x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 
\x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t 
\x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 
\x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 
\x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xab\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"7\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xbe\x0b\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 
\x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xb2\x07\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -199,22 +199,26 @@ _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11234 _RELEASEEXECUTERESPONSE._serialized_start = 11263 _RELEASEEXECUTERESPONSE._serialized_end = 11375 - _FETCHERRORDETAILSREQUEST._serialized_start = 11378 - _FETCHERRORDETAILSREQUEST._serialized_end = 11579 - _FETCHERRORDETAILSRESPONSE._serialized_start = 11582 - _FETCHERRORDETAILSRESPONSE._serialized_end = 13052 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12271 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12234 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12271 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 
12274 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12683 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12585 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12653 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12686 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13033 - _SPARKCONNECTSERVICE._serialized_start = 13055 - _SPARKCONNECTSERVICE._serialized_end = 13904 + _RELEASESESSIONREQUEST._serialized_start = 11378 + _RELEASESESSIONREQUEST._serialized_end = 11549 + _RELEASESESSIONRESPONSE._serialized_start = 11551 + _RELEASESESSIONRESPONSE._serialized_end = 11606 + _FETCHERRORDETAILSREQUEST._serialized_start = 11609 + _FETCHERRORDETAILSREQUEST._serialized_end = 11810 + _FETCHERRORDETAILSRESPONSE._serialized_start = 11813 + _FETCHERRORDETAILSRESPONSE._serialized_end = 13283 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11958 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12132 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12135 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12502 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12465 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12502 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12505 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12914 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12816 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12884 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12917 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13264 + _SPARKCONNECTSERVICE._serialized_start = 13286 + _SPARKCONNECTSERVICE._serialized_end = 14232 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index c29feb4164cf1..20abbcb348bdd 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -2763,6 +2763,84 @@ class ReleaseExecuteResponse(google.protobuf.message.Message): global___ReleaseExecuteResponse = ReleaseExecuteResponse +class ReleaseSessionRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SESSION_ID_FIELD_NUMBER: builtins.int + USER_CONTEXT_FIELD_NUMBER: builtins.int + CLIENT_TYPE_FIELD_NUMBER: builtins.int + session_id: builtins.str + """(Required) + + The session_id of the request to reattach to. + This must be an id of existing session. + """ + @property + def user_context(self) -> global___UserContext: + """(Required) User context + + user_context.user_id and session+id both identify a unique remote spark session on the + server side. + """ + client_type: builtins.str + """Provides optional information about the client sending the request. This field + can be used for language or version specific information and is only intended for + logging purposes and will not be interpreted by the server. + """ + def __init__( + self, + *, + session_id: builtins.str = ..., + user_context: global___UserContext | None = ..., + client_type: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_type", + b"client_type", + "user_context", + b"user_context", + ], + ) -> builtins.bool: ... 
+ def ClearField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_type", + b"client_type", + "session_id", + b"session_id", + "user_context", + b"user_context", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"] + ) -> typing_extensions.Literal["client_type"] | None: ... + +global___ReleaseSessionRequest = ReleaseSessionRequest + +class ReleaseSessionResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SESSION_ID_FIELD_NUMBER: builtins.int + session_id: builtins.str + """Session id of the session on which the release executed.""" + def __init__( + self, + *, + session_id: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["session_id", b"session_id"] + ) -> None: ... + +global___ReleaseSessionResponse = ReleaseSessionResponse + class FetchErrorDetailsRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py index f6c5573ded6b5..12675747e0f92 100644 --- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py +++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py @@ -70,6 +70,11 @@ def __init__(self, channel): request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString, ) + self.ReleaseSession = channel.unary_unary( + "/spark.connect.SparkConnectService/ReleaseSession", + request_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString, + ) self.FetchErrorDetails = channel.unary_unary( "/spark.connect.SparkConnectService/FetchErrorDetails", request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString, @@ -141,6 +146,16 @@ def ReleaseExecute(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def ReleaseSession(self, request, context): + """Release a session. + All the executions in the session will be released. Any further requests for the session with + that session_id for the given user_id will fail. If the session didn't exist or was already + released, this is a noop. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def FetchErrorDetails(self, request, context): """FetchErrorDetails retrieves the matched exception with details based on a provided error id.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -190,6 +205,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server): request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.FromString, response_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.SerializeToString, ), + "ReleaseSession": grpc.unary_unary_rpc_method_handler( + servicer.ReleaseSession, + request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.FromString, + response_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.SerializeToString, + ), "FetchErrorDetails": grpc.unary_unary_rpc_method_handler( servicer.FetchErrorDetails, request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString, @@ -438,6 +458,35 @@ def ReleaseExecute( metadata, ) + @staticmethod + def ReleaseSession( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/spark.connect.SparkConnectService/ReleaseSession", + spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString, + spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def FetchErrorDetails( request, diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 09bd60606c769..1aa857b4f6175 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -254,6 +254,9 @@ def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] self._client = SparkConnectClient(connection=connection, user_id=userId) self._session_id = self._client._session_id + # Set to false to prevent client.release_session on close() (testing only) + self.release_session_on_close = True + @classmethod def _set_default_and_active_session(cls, session: "SparkSession") -> None: """ @@ -645,15 +648,16 @@ def clearTags(self) -> None: clearTags.__doc__ = PySparkSession.clearTags.__doc__ def stop(self) -> None: - # Stopping the session will only close the connection to the current session (and - # the life cycle of the session is maintained by the server), - # whereas the regular PySpark session immediately terminates the Spark Context - # itself, meaning that stopping all Spark sessions. + # Whereas the regular PySpark session immediately terminates the Spark Context + # itself, meaning that stopping all Spark sessions, this will only stop this one session + # on the server. # It is controversial to follow the existing the regular Spark session's behavior # specifically in Spark Connect the Spark Connect server is designed for # multi-tenancy - the remote client side cannot just stop the server and stop # other remote clients being used from other users. 
with SparkSession._lock: + if not self.is_stopped and self.release_session_on_close: + self.client.release_session() self.client.close() if self is SparkSession._default_session: SparkSession._default_session = None diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 34bd314c76f7c..f024a03c2686c 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3437,6 +3437,7 @@ def test_can_create_multiple_sessions_to_different_remotes(self): # Gets currently active session. same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate() self.assertEquals(other, same) + same.release_session_on_close = False # avoid sending release to dummy connection same.stop() # Make sure the environment is clean. From a04d4e2233c0d20c6a86c64391b1e1a6071b4550 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 2 Nov 2023 09:18:42 +0900 Subject: [PATCH 004/121] [SPARK-45761][K8S][INFRA][DOCS] Upgrade `Volcano` to 1.8.1 ### What changes were proposed in this pull request? This PR aims to upgrade `Volcano` to 1.8.1 in K8s integration test document and GitHub Action job. ### Why are the changes needed? To bring the latest feature and bug fixes in addition to the test coverage for Volcano scheduler 1.8.1. - https://github.com/volcano-sh/volcano/releases/tag/v1.8.1 - https://github.com/volcano-sh/volcano/pull/3101 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43624 from dongjoon-hyun/SPARK-45761. Authored-by: Dongjoon Hyun Signed-off-by: Hyukjin Kwon --- .github/workflows/build_and_test.yml | 2 +- resource-managers/kubernetes/integration-tests/README.md | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 5825185f34450..eded5da5c1ddd 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -1063,7 +1063,7 @@ jobs: export PVC_TESTS_VM_PATH=$PVC_TMP_DIR minikube mount ${PVC_TESTS_HOST_PATH}:${PVC_TESTS_VM_PATH} --gid=0 --uid=185 & kubectl create clusterrolebinding serviceaccounts-cluster-admin --clusterrole=cluster-admin --group=system:serviceaccounts || true - kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml || true + kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml || true eval $(minikube docker-env) build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" - name: Upload Spark on K8S integration tests log files diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index d39fdfbfd3966..d5ccd3fe756b7 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -329,11 +329,11 @@ You can also specify your specific dockerfile to build JVM/Python/R based image ## Requirements - A minimum of 6 CPUs and 9G of memory is required to complete all Volcano test cases. -- Volcano v1.8.0. 
+- Volcano v1.8.1. ## Installation - kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml + kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml ## Run tests @@ -354,5 +354,5 @@ You can also specify `volcano` tag to only run Volcano test: ## Cleanup Volcano - kubectl delete -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml + kubectl delete -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml From 30ec6e358536dfb695fcc1b8c3f084acb576d871 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Wed, 1 Nov 2023 21:08:04 -0700 Subject: [PATCH 005/121] [SPARK-45742][CORE][CONNECT][MLLIB][PYTHON] Introduce an implicit function for Scala Array to wrap into `immutable.ArraySeq` ### What changes were proposed in this pull request? Currently, we need to use `immutable.ArraySeq.unsafeWrapArray(array)` to wrap an Array into an `immutable.ArraySeq`, which makes the code look bloated. So this PR introduces an implicit function `toImmutableArraySeq` to make it easier for Scala Array to be wrapped into `immutable.ArraySeq`. After this pr, we can use the following way to wrap an array into an `immutable.ArraySeq`: ```scala import org.apache.spark.util.ArrayImplicits._ val dataArray = ... val immutableArraySeq = dataArray.toImmutableArraySeq ``` At the same time, this pr replaces the existing use of `immutable.ArraySeq.unsafeWrapArray(array)` with the new method. On the other hand, this implicit function will be conducive to the progress of work SPARK-45686 and SPARK-45687. ### Why are the changes needed? Makes the code for wrapping a Scala Array into an `immutable.ArraySeq` look less bloated. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43607 from LuciferYang/SPARK-45742. Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../apache/spark/util/ArrayImplicits.scala | 36 +++++++++++++ .../org/apache/spark/sql/SparkSession.scala | 4 +- .../client/GrpcExceptionConverter.scala | 4 +- .../connect/planner/SparkConnectPlanner.scala | 27 +++++----- .../spark/sql/connect/utils/ErrorUtils.scala | 32 ++++++------ .../spark/util/ArrayImplicitsSuite.scala | 50 +++++++++++++++++++ .../python/GaussianMixtureModelWrapper.scala | 4 +- .../mllib/api/python/LDAModelWrapper.scala | 8 +-- 8 files changed, 126 insertions(+), 39 deletions(-) create mode 100644 common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala create mode 100644 core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala diff --git a/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala new file mode 100644 index 0000000000000..08997a800c957 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.immutable + +/** + * Implicit methods related to Scala Array. + */ +private[spark] object ArrayImplicits { + + implicit class SparkArrayOps[T](xs: Array[T]) { + + /** + * Wraps an Array[T] as an immutable.ArraySeq[T] without copying. + */ + def toImmutableArraySeq: immutable.ArraySeq[T] = + if (xs eq null) null + else immutable.ArraySeq.unsafeWrapArray(xs) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1cc1c8400fa89..34756f9a440bb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,6 @@ import java.net.URI import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicLong, AtomicReference} -import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag @@ -45,6 +44,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * The entry point to programming Spark with the Dataset and DataFrame API. @@ -248,7 +248,7 @@ class SparkSession private[sql] ( proto.SqlCommand .newBuilder() .setSql(sqlText) - .addAllPosArguments(immutable.ArraySeq.unsafeWrapArray(args.map(lit(_).expr)).asJava))) + .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) // .toBuffer forces that the iterator is consumed and closed val responseSeq = client.execute(plan.build()).toBuffer.toSeq diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 3e53722caeb07..652797bc2e40f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.client import java.time.DateTimeException -import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -37,6 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.streaming.StreamingQueryException +import org.apache.spark.util.ArrayImplicits._ /** * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions. 
@@ -375,7 +375,7 @@ private[client] object GrpcExceptionConverter { FetchErrorDetailsResponse.Error .newBuilder() .setMessage(message) - .addAllErrorTypeHierarchy(immutable.ArraySeq.unsafeWrapArray(classes).asJava) + .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava) .build())) } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index ec57909ad144e..018e293795e9d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connect.planner -import scala.collection.immutable import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.util.Try @@ -80,6 +79,7 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.CacheId +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils final case class InvalidCommandInput( @@ -3184,9 +3184,9 @@ class SparkConnectPlanner( case StreamingQueryManagerCommand.CommandCase.ACTIVE => val active_queries = session.streams.active respBuilder.getActiveBuilder.addAllActiveQueries( - immutable.ArraySeq - .unsafeWrapArray(active_queries - .map(query => buildStreamingQueryInstance(query))) + active_queries + .map(query => buildStreamingQueryInstance(query)) + .toImmutableArraySeq .asJava) case StreamingQueryManagerCommand.CommandCase.GET_QUERY => @@ -3265,15 +3265,16 @@ class SparkConnectPlanner( .setGetResourcesCommandResult( proto.GetResourcesCommandResult .newBuilder() - .putAllResources(session.sparkContext.resources.view - .mapValues(resource => - proto.ResourceInformation - .newBuilder() - .setName(resource.name) - .addAllAddresses(immutable.ArraySeq.unsafeWrapArray(resource.addresses).asJava) - .build()) - .toMap - .asJava) + .putAllResources( + session.sparkContext.resources.view + .mapValues(resource => + proto.ResourceInformation + .newBuilder() + .setName(resource.name) + .addAllAddresses(resource.addresses.toImmutableArraySeq.asJava) + .build()) + .toMap + .asJava) .build()) .build()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 837ee5a00227c..744fa3c8aa1a4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.connect.utils import java.util.UUID import scala.annotation.tailrec -import scala.collection.immutable import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ @@ -43,6 +42,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ private[connect] object ErrorUtils extends Logging { @@ -91,21 +91,21 @@ 
private[connect] object ErrorUtils extends Logging { if (serverStackTraceEnabled) { builder.addAllStackTrace( - immutable.ArraySeq - .unsafeWrapArray(currentError.getStackTrace - .map { stackTraceElement => - val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement - .newBuilder() - .setDeclaringClass(stackTraceElement.getClassName) - .setMethodName(stackTraceElement.getMethodName) - .setLineNumber(stackTraceElement.getLineNumber) - - if (stackTraceElement.getFileName != null) { - stackTraceBuilder.setFileName(stackTraceElement.getFileName) - } - - stackTraceBuilder.build() - }) + currentError.getStackTrace + .map { stackTraceElement => + val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement + .newBuilder() + .setDeclaringClass(stackTraceElement.getClassName) + .setMethodName(stackTraceElement.getMethodName) + .setLineNumber(stackTraceElement.getLineNumber) + + if (stackTraceElement.getFileName != null) { + stackTraceBuilder.setFileName(stackTraceElement.getFileName) + } + + stackTraceBuilder.build() + } + .toImmutableArraySeq .asJava) } diff --git a/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala new file mode 100644 index 0000000000000..135af550c4b39 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import scala.collection.immutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ArrayImplicits._ + +class ArrayImplicitsSuite extends SparkFunSuite { + + test("Int Array") { + val data = Array(1, 2, 3) + val arraySeq = data.toImmutableArraySeq + assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofInt]) + assert(arraySeq.length === 3) + assert(arraySeq.unsafeArray.sameElements(data)) + } + + test("TestClass Array") { + val data = Array(TestClass(1), TestClass(2), TestClass(3)) + val arraySeq = data.toImmutableArraySeq + assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofRef[TestClass]]) + assert(arraySeq.length === 3) + assert(arraySeq.unsafeArray.sameElements(data)) + } + + test("Null Array") { + val data: Array[Int] = null + val arraySeq = data.toImmutableArraySeq + assert(arraySeq == null) + } + + case class TestClass(i: Int) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 1eed97a8d4f65..2f3f396730be2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.api.python -import scala.collection.immutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkContext import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.util.ArrayImplicits._ /** * Wrapper around GaussianMixtureModel to provide helper methods in Python @@ -38,7 +38,7 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { val modelGaussians = model.gaussians.map { gaussian => Array[Any](gaussian.mu, gaussian.sigma) } - SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(modelGaussians).asJava) + SerDe.dumps(modelGaussians.toImmutableArraySeq.asJava) } def predictSoft(point: Vector): Vector = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala index b919b0a8c3f2e..6a6c6cf6bcfb3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -16,12 +16,12 @@ */ package org.apache.spark.mllib.api.python -import scala.collection.immutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkContext import org.apache.spark.mllib.clustering.LDAModel import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.util.ArrayImplicits._ /** * Wrapper around LDAModel to provide helper methods in Python @@ -36,11 +36,11 @@ private[python] class LDAModelWrapper(model: LDAModel) { def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => - val jTerms = immutable.ArraySeq.unsafeWrapArray(terms).asJava - val jTermWeights = immutable.ArraySeq.unsafeWrapArray(termWeights).asJava + val jTerms = terms.toImmutableArraySeq.asJava + val jTermWeights = termWeights.toImmutableArraySeq.asJava Array[Any](jTerms, jTermWeights) } - SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(topics).asJava) + SerDe.dumps(topics.toImmutableArraySeq.asJava) } def save(sc: SparkContext, path: 
String): Unit = model.save(sc, path) From 5970d353360d4fb6647c8fbc10f733abf009eca1 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 2 Nov 2023 23:04:07 +0800 Subject: [PATCH 006/121] [SPARK-45767][CORE] Delete `TimeStampedHashMap` and its UT ### What changes were proposed in this pull request? The pr aims to delete `TimeStampedHashMap` and its UT. ### Why are the changes needed? During Pr https://github.com/apache/spark/pull/43578, we found that the class `TimeStampedHashMap` is no longer in use. Based on the suggestion, we have removed it. https://github.com/apache/spark/pull/43578#discussion_r1378687555 - First time this class `TimeStampedHashMap` be introduced: https://github.com/apache/spark/commit/b18d70870a33a4783c6b3b787bef9b0eec30bce0#diff-77b12178a7036c71135074c6ddf7d659e5a69906264d5e3061087e4352e304ed introduced this data structure - After https://github.com/apache/spark/pull/22339, this class `TimeStampedHashMap` is only being used in UT `TimeStampedHashMapSuite`. So, after Spark 3.0, this data structure has not been used by any production code of Spark. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43633 from panbingkun/remove_TimeStampedHashMap. Authored-by: panbingkun Signed-off-by: yangjie01 --- .../spark/util/TimeStampedHashMap.scala | 143 -------------- .../spark/util/TimeStampedHashMapSuite.scala | 179 ------------------ 2 files changed, 322 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala deleted file mode 100644 index b0fb339465205..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.Map.Entry -import java.util.Set -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.mutable -import scala.jdk.CollectionConverters._ - -import org.apache.spark.internal.Logging - -private[spark] case class TimeStampedValue[V](value: V, timestamp: Long) - -/** - * This is a custom implementation of scala.collection.mutable.Map which stores the insertion - * timestamp along with each key-value pair. If specified, the timestamp of each pair can be - * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular - * threshold time can then be removed using the clearOldValues method. 
This is intended to - * be a drop-in replacement of scala.collection.mutable.HashMap. - * - * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed - */ -private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends mutable.Map[A, B]() with Logging { - - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. - - private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() - - def get(key: A): Option[B] = { - val value = internalMap.get(key) - if (value != null && updateTimeStampOnGet) { - internalMap.replace(key, value, TimeStampedValue(value.value, currentTime)) - } - Option(value).map(_.value) - } - - def iterator: Iterator[(A, B)] = { - getEntrySet.iterator.asScala.map(kv => (kv.getKey, kv.getValue.value)) - } - - def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet - - override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { - val newMap = new TimeStampedHashMap[A, B1] - val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]] - newMap.internalMap.putAll(oldInternalMap) - kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) } - newMap - } - - override def addOne(kv: (A, B)): this.type = { - kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) } - this - } - - override def subtractOne(key: A): this.type = { - internalMap.remove(key) - this - } - - override def update(key: A, value: B): Unit = { - this += ((key, value)) - } - - override def apply(key: A): B = { - get(key).getOrElse { throw new NoSuchElementException() } - } - - override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { - internalMap.asScala.map { case (k, TimeStampedValue(v, t)) => (k, v) }.filter(p) - } - - override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() - - override def size: Int = internalMap.size - - override def foreach[U](f: ((A, B)) => U): Unit = { - val it = getEntrySet.iterator - while(it.hasNext) { - val entry = it.next() - val kv = (entry.getKey, entry.getValue.value) - f(kv) - } - } - - def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime)) - Option(prev).map(_.value) - } - - def putAll(map: Map[A, B]): Unit = { - map.foreach { case (k, v) => update(k, v) } - } - - def toMap: Map[A, B] = iterator.toMap - - def clearOldValues(threshTime: Long, f: (A, B) => Unit): Unit = { - val it = getEntrySet.iterator - while (it.hasNext) { - val entry = it.next() - if (entry.getValue.timestamp < threshTime) { - f(entry.getKey, entry.getValue.value) - logDebug("Removing key " + entry.getKey) - it.remove() - } - } - } - - /** Removes old key-value pairs that have timestamp earlier than `threshTime`. 
*/ - def clearOldValues(threshTime: Long): Unit = { - clearOldValues(threshTime, (_, _) => ()) - } - - private def currentTime: Long = System.currentTimeMillis - - // For testing - - def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = { - Option(internalMap.get(key)) - } - - def getTimestamp(key: A): Option[Long] = { - getTimeStampedValue(key).map(_.timestamp) - } -} diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala deleted file mode 100644 index 1644540946839..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.util.Random - -import org.apache.spark.SparkFunSuite - -class TimeStampedHashMapSuite extends SparkFunSuite { - - // Test the testMap function - a Scala HashMap should obviously pass - testMap(new mutable.HashMap[String, String]()) - - // Test TimeStampedHashMap basic functionality - testMap(new TimeStampedHashMap[String, String]()) - testMapThreadSafety(new TimeStampedHashMap[String, String]()) - - test("TimeStampedHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis - assert(map.getTimestamp("k1").isDefined) - assert(map.getTimestamp("k1").get < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - - // clearing by modification time - val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true) - map1("k1") = "v1" - map1("k2") = "v2" - assert(map1("k1") === "v1") - Thread.sleep(10) - val threshTime1 = System.currentTimeMillis - Thread.sleep(10) - assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime - assert(map1.getTimestamp("k1").isDefined) - assert(map1.getTimestamp("k1").get < threshTime1) - assert(map1.getTimestamp("k2").isDefined) - assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) // should only clear k1 - assert(map1.get("k1") === None) - assert(map1.get("k2").isDefined) - } - - /** Test basic operations of a Scala mutable Map. 
*/ - def testMap(hashMapConstructor: => mutable.Map[String, String]): Unit = { - def newMap() = hashMapConstructor - val testMap1 = newMap() - val testMap2 = newMap() - val name = testMap1.getClass.getSimpleName - - test(name + " - basic test") { - // put, get, and apply - testMap1 += (("k1", "v1")) - assert(testMap1.get("k1").isDefined) - assert(testMap1("k1") === "v1") - testMap1("k2") = "v2" - assert(testMap1.get("k2").isDefined) - assert(testMap1("k2") === "v2") - assert(testMap1("k2") === "v2") - testMap1.update("k3", "v3") - assert(testMap1.get("k3").isDefined) - assert(testMap1("k3") === "v3") - - // remove - testMap1.remove("k1") - assert(testMap1.get("k1").isEmpty) - testMap1.remove("k2") - intercept[NoSuchElementException] { - testMap1("k2") // Map.apply() causes exception - } - testMap1 -= "k3" - assert(testMap1.get("k3").isEmpty) - - // multi put - val keys = (1 to 100).map(_.toString) - val pairs = keys.map(x => (x, x * 2)) - assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) - testMap2 ++= pairs - - // iterator - assert(testMap2.iterator.toSet === pairs.toSet) - - // filter - val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 } - val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 } - assert(filtered.iterator.toSet === evenPairs.toSet) - - // foreach - val buffer = new ArrayBuffer[(String, String)] - testMap2.foreach(x => buffer += x) - assert(testMap2.toSet === buffer.toSet) - - // multi remove - testMap2("k1") = "v1" - testMap2 --= keys - assert(testMap2.size === 1) - assert(testMap2.iterator.toSeq.head === (("k1", "v1"))) - - // + - val testMap3 = testMap2 + (("k0", "v0")) - assert(testMap3.size === 2) - assert(testMap3.get("k1").isDefined) - assert(testMap3("k1") === "v1") - assert(testMap3.get("k0").isDefined) - assert(testMap3("k0") === "v0") - - // - - val testMap4 = testMap3 - "k0" - assert(testMap4.size === 1) - assert(testMap4.get("k1").isDefined) - assert(testMap4("k1") === "v1") - } - } - - /** Test thread safety of a Scala mutable map. */ - def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]): Unit = { - def newMap() = hashMapConstructor - val name = newMap().getClass.getSimpleName - val testMap = newMap() - @volatile var error = false - - def getRandomKey(m: mutable.Map[String, String]): Option[String] = { - val keys = testMap.keysIterator.toSeq - if (keys.nonEmpty) { - Some(keys(Random.nextInt(keys.size))) - } else { - None - } - } - - val threads = (1 to 25).map(i => new Thread() { - override def run(): Unit = { - try { - for (j <- 1 to 1000) { - Random.nextInt(3) match { - case 0 => - testMap(Random.nextString(10)) = Random.nextDouble().toString // put - case 1 => - getRandomKey(testMap).map(testMap.get) // get - case 2 => - getRandomKey(testMap).map(testMap.remove) // remove - } - } - } catch { - case t: Throwable => - error = true - throw t - } - } - }) - - test(name + " - threading safety test") { - threads.foreach(_.start()) - threads.foreach(_.join()) - assert(!error) - } - } -} From 653b31e18b3fc2546bd6b13b384459f8afddabdc Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 2 Nov 2023 13:14:04 -0700 Subject: [PATCH 007/121] [SPARK-45771][CORE] Enable `spark.eventLog.rolling.enabled` by default ### What changes were proposed in this pull request? This PR aims to enable `spark.eventLog.rolling.enabled` by default for Apache Spark 4.0.0. ### Why are the changes needed? 
Since Apache Spark 3.0.0, we have been using event log rolling not only for **long-running jobs**, but also for **some failed jobs** to archive the partial event logs incrementally. - https://github.com/apache/spark/pull/25670 ### Does this PR introduce _any_ user-facing change? - No because `spark.eventLog.enabled` is disabled by default. - For the users with `spark.eventLog.enabled=true`, yes, `spark-events` directory will have different layouts. However, all 3.3+ `Spark History Server` can read both old and new event logs. I believe that the event log users are already using this configuration to avoid the loss of event logs for long-running jobs and some failed jobs. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43638 from dongjoon-hyun/SPARK-45771. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/internal/config/package.scala | 2 +- .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 1 + .../apache/spark/deploy/history/EventLogFileWritersSuite.scala | 2 +- .../org/apache/spark/deploy/history/EventLogTestHelper.scala | 1 + .../org/apache/spark/scheduler/EventLoggingListenerSuite.scala | 3 ++- docs/core-migration-guide.md | 2 ++ 6 files changed, 8 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7b0fcf3433cf6..143dd0c44ce84 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -238,7 +238,7 @@ package object config { "each event log file to the configured size.") .version("3.0.0") .booleanConf - .createWithDefault(false) + .createWithDefault(true) private[spark] val EVENT_LOG_ROLLING_MAX_FILE_SIZE = ConfigBuilder("spark.eventLog.rolling.maxFileSize") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 88f015f864def..7ebb0165e620a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -731,6 +731,7 @@ class SparkSubmitSuite "--conf", "spark.master.rest.enabled=false", "--conf", "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password", "--conf", "spark.eventLog.enabled=true", + "--conf", "spark.eventLog.rolling.enabled=false", "--conf", "spark.eventLog.testing=true", "--conf", s"spark.eventLog.dir=${testDirPath.toUri.toString}", "--conf", "spark.hadoop.fs.defaultFS=unsupported://example.com", diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala index 455e2e18b11e1..b575cbc080c07 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala @@ -66,7 +66,7 @@ abstract class EventLogFileWritersSuite extends SparkFunSuite with LocalSparkCon conf.set(EVENT_LOG_DIR, testDir.toString) // default config - buildWriterAndVerify(conf, classOf[SingleEventLogFileWriter]) + buildWriterAndVerify(conf, classOf[RollingEventLogFilesWriter]) conf.set(EVENT_LOG_ENABLE_ROLLING, true) buildWriterAndVerify(conf, classOf[RollingEventLogFilesWriter]) diff --git 
a/core/src/test/scala/org/apache/spark/deploy/history/EventLogTestHelper.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogTestHelper.scala index ea8da01085920..ac89f60955eed 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogTestHelper.scala @@ -38,6 +38,7 @@ object EventLogTestHelper { def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set(EVENT_LOG_ENABLED, true) + conf.set(EVENT_LOG_ENABLE_ROLLING, false) conf.set(EVENT_LOG_BLOCK_UPDATES, true) conf.set(EVENT_LOG_TESTING, true) conf.set(EVENT_LOG_DIR, logDir.toString) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index cd8fac2c65701..939923e12b8e7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.{EventLogFileReader, SingleEventLogFileWriter} import org.apache.spark.deploy.history.EventLogTestHelper._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} -import org.apache.spark.internal.config.{EVENT_LOG_COMPRESS, EVENT_LOG_DIR, EVENT_LOG_ENABLED} +import org.apache.spark.internal.config.{EVENT_LOG_COMPRESS, EVENT_LOG_DIR, EVENT_LOG_ENABLE_ROLLING, EVENT_LOG_ENABLED} import org.apache.spark.io._ import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} import org.apache.spark.resource.ResourceProfile @@ -163,6 +163,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit test("SPARK-31764: isBarrier should be logged in event log") { val conf = new SparkConf() conf.set(EVENT_LOG_ENABLED, true) + conf.set(EVENT_LOG_ENABLE_ROLLING, false) conf.set(EVENT_LOG_COMPRESS, false) conf.set(EVENT_LOG_DIR, testDirPath.toString) val sc = new SparkContext("local", "test-SPARK-31764", conf) diff --git a/docs/core-migration-guide.md b/docs/core-migration-guide.md index fb9471d0c1ae6..09ba4b474e975 100644 --- a/docs/core-migration-guide.md +++ b/docs/core-migration-guide.md @@ -24,6 +24,8 @@ license: | ## Upgrading from Core 3.5 to 4.0 +- Since Spark 4.0, Spark will roll event logs to archive them incrementally. To restore the behavior before Spark 4.0, you can set `spark.eventLog.rolling.enabled` to `false`. + - Since Spark 4.0, Spark will compress event logs. To restore the behavior before Spark 4.0, you can set `spark.eventLog.compress` to `false`. - Since Spark 4.0, `spark.shuffle.service.db.backend` is set to `ROCKSDB` by default which means Spark will use RocksDB store for shuffle service. To restore the behavior before Spark 4.0, you can set `spark.shuffle.service.db.backend` to `LEVELDB`. From eba37d2d408ba21e849c1a945a6620b66cc299a9 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 2 Nov 2023 14:22:37 -0700 Subject: [PATCH 008/121] [SPARK-45718][PS] Remove remaining deprecated Pandas features from Spark 3.4.0 ### What changes were proposed in this pull request? This PR proposes to remove remaining deprecated Pandas features from Spark 3.4.0 ### Why are the changes needed? To match the behavior of Pandas. We cleaned up most of APIs, but there are still some features that deprecated from Spark 3.4.0 need to be removed. 
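For example, callers of the removed groupby aliases can switch to the canonical methods. A minimal sketch (the tiny DataFrame below is only illustrative and is not taken from this patch):

```python
import pyspark.pandas as ps

psdf = ps.DataFrame({"group": ["a", "a", "b"], "value": [1.0, None, 3.0]})

# Removed in Spark 4.0: psdf.groupby("group").pad() and psdf.groupby("group").backfill()
# Use the canonical names instead:
forward_filled = psdf.groupby("group").ffill()
backward_filled = psdf.groupby("group").bfill()
```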
### Does this PR introduce _any_ user-facing change? Yes, some parameters and APIs are removed from Spark 4.0.0. ### How was this patch tested? The existing CI should pass. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43581 from itholic/SPARK-45718. Authored-by: Haejoon Lee Signed-off-by: Dongjoon Hyun --- .../migration_guide/pyspark_upgrade.rst | 9 ++++ .../reference/pyspark.pandas/groupby.rst | 1 - .../reference/pyspark.pandas/indexing.rst | 2 - python/pyspark/pandas/generic.py | 13 ------ python/pyspark/pandas/groupby.py | 28 ------------- python/pyspark/pandas/indexes/base.py | 41 ------------------- python/pyspark/pandas/indexes/multi.py | 22 ---------- python/pyspark/pandas/namespace.py | 29 ------------- python/pyspark/pandas/tests/test_csv.py | 5 --- 9 files changed, 9 insertions(+), 141 deletions(-) diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index 20fab57850463..06991281bf07a 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -54,6 +54,15 @@ Upgrading from PySpark 3.5 to 4.0 * In Spark 4.0, ``DataFrame.to_spark_io`` has been removed from pandas API on Spark, use ``DataFrame.spark.to_spark_io`` instead. * In Spark 4.0, ``Series.is_monotonic`` and ``Index.is_monotonic`` have been removed from pandas API on Spark, use ``Series.is_monotonic_increasing`` or ``Index.is_monotonic_increasing`` instead respectively. * In Spark 4.0, ``DataFrame.get_dtype_counts`` has been removed from pandas API on Spark, use ``DataFrame.dtypes.value_counts()`` instead. +* In Spark 4.0, ``encoding`` parameter from ``DataFrame.to_excel`` and ``Series.to_excel`` have been removed from pandas API on Spark. +* In Spark 4.0, ``verbose`` parameter from ``DataFrame.to_excel`` and ``Series.to_excel`` have been removed from pandas API on Spark. +* In Spark 4.0, ``mangle_dupe_cols`` parameter from ``read_csv`` has been removed from pandas API on Spark. +* In Spark 4.0, ``DataFrameGroupBy.backfill`` has been removed from pandas API on Spark, use ``DataFrameGroupBy.bfill`` instead. +* In Spark 4.0, ``DataFrameGroupBy.pad`` has been removed from pandas API on Spark, use ``DataFrameGroupBy.ffill`` instead. +* In Spark 4.0, ``Index.is_all_dates`` has been removed from pandas API on Spark. +* In Spark 4.0, ``convert_float`` parameter from ``read_excel`` has been removed from pandas API on Spark. +* In Spark 4.0, ``mangle_dupe_cols`` parameter from ``read_excel`` has been removed from pandas API on Spark. 
+ Upgrading from PySpark 3.3 to 3.4 diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst index e71e81c56dd3e..7a0c771e8caac 100644 --- a/python/docs/source/reference/pyspark.pandas/groupby.rst +++ b/python/docs/source/reference/pyspark.pandas/groupby.rst @@ -89,7 +89,6 @@ Computations / Descriptive Stats GroupBy.bfill GroupBy.ffill GroupBy.head - GroupBy.backfill GroupBy.shift GroupBy.tail diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst index 08f5e224e06eb..71584892ca38d 100644 --- a/python/docs/source/reference/pyspark.pandas/indexing.rst +++ b/python/docs/source/reference/pyspark.pandas/indexing.rst @@ -43,7 +43,6 @@ Properties Index.hasnans Index.dtype Index.inferred_type - Index.is_all_dates Index.shape Index.name Index.names @@ -219,7 +218,6 @@ MultiIndex Properties MultiIndex.has_duplicates MultiIndex.hasnans MultiIndex.inferred_type - MultiIndex.is_all_dates MultiIndex.shape MultiIndex.names MultiIndex.ndim diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py index 16eaeb6142e55..231397628822f 100644 --- a/python/pyspark/pandas/generic.py +++ b/python/pyspark/pandas/generic.py @@ -990,9 +990,7 @@ def to_excel( startcol: int = 0, engine: Optional[str] = None, merge_cells: bool = True, - encoding: Optional[str] = None, inf_rep: str = "inf", - verbose: bool = True, freeze_panes: Optional[Tuple[int, int]] = None, ) -> None: """ @@ -1043,20 +1041,9 @@ def to_excel( ``io.excel.xlsm.writer``. merge_cells: bool, default True Write MultiIndex and Hierarchical Rows as merged cells. - encoding: str, optional - Encoding of the resulting excel file. Only necessary for xlwt, - other writers support unicode natively. - - .. deprecated:: 3.4.0 - inf_rep: str, default 'inf' Representation for infinity (there is no native representation for infinity in Excel). - verbose: bool, default True - Display more information in the error logs. - - .. deprecated:: 3.4.0 - freeze_panes: tuple of int (length 2), optional Specifies the one-based bottommost row and rightmost column that is to be frozen. diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index b19a40b837a0a..4cce147b2606e 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -2614,20 +2614,6 @@ def bfill(self, limit: Optional[int] = None) -> FrameLike: """ return self.fillna(method="bfill", limit=limit) - def backfill(self, limit: Optional[int] = None) -> FrameLike: - """ - Alias for bfill. - - .. deprecated:: 3.4.0 - """ - warnings.warn( - "The GroupBy.backfill method is deprecated " - "and will be removed in a future version. " - "Use GroupBy.bfill instead.", - FutureWarning, - ) - return self.bfill(limit=limit) - def ffill(self, limit: Optional[int] = None) -> FrameLike: """ Synonym for `DataFrame.fillna()` with ``method=`ffill```. @@ -2677,20 +2663,6 @@ def ffill(self, limit: Optional[int] = None) -> FrameLike: """ return self.fillna(method="ffill", limit=limit) - def pad(self, limit: Optional[int] = None) -> FrameLike: - """ - Alias for ffill. - - .. deprecated:: 3.4.0 - """ - warnings.warn( - "The GroupBy.pad method is deprecated " - "and will be removed in a future version. " - "Use GroupBy.ffill instead.", - FutureWarning, - ) - return self.ffill(limit=limit) - def _limit(self, n: int, asc: bool) -> FrameLike: """ Private function for tail and head. 
diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index 2ec0a39dc7135..6c6ee9ae0d7dc 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -2118,47 +2118,6 @@ def difference(self, other: "Index", sort: Optional[bool] = None) -> "Index": result.name = self.name return result if sort is None else cast(Index, result.sort_values()) - @property - def is_all_dates(self) -> bool: - """ - Return if all data types of the index are datetime. - remember that since pandas-on-Spark does not support multiple data types in an index, - so it returns True if any type of data is datetime. - - .. deprecated:: 3.4.0 - - Examples - -------- - >>> from datetime import datetime - - >>> idx = ps.Index([datetime(2019, 1, 1, 0, 0, 0), datetime(2019, 2, 3, 0, 0, 0)]) - >>> idx - DatetimeIndex(['2019-01-01', '2019-02-03'], dtype='datetime64[ns]', freq=None) - - >>> idx.is_all_dates - True - - >>> idx = ps.Index([datetime(2019, 1, 1, 0, 0, 0), None]) - >>> idx - DatetimeIndex(['2019-01-01', 'NaT'], dtype='datetime64[ns]', freq=None) - - >>> idx.is_all_dates - True - - >>> idx = ps.Index([0, 1, 2]) - >>> idx - Index([0, 1, 2], dtype='int64') - - >>> idx.is_all_dates - False - """ - warnings.warn( - "Index.is_all_dates is deprecated, will be removed in a future version. " - "check index.inferred_type instead", - FutureWarning, - ) - return isinstance(self.spark.data_type, (TimestampType, TimestampNTZType)) - def repeat(self, repeats: int) -> "Index": """ Repeat elements of a Index/MultiIndex. diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index 9917a42fb3857..41c3b93ed51b6 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -975,28 +975,6 @@ def asof(self, label: Any) -> None: "only the default get_loc method is currently supported for MultiIndex" ) - @property - def is_all_dates(self) -> bool: - """ - is_all_dates always returns False for MultiIndex - - Examples - -------- - >>> from datetime import datetime - - >>> idx = ps.MultiIndex.from_tuples( - ... [(datetime(2019, 1, 1, 0, 0, 0), datetime(2019, 1, 1, 0, 0, 0)), - ... (datetime(2019, 1, 1, 0, 0, 0), datetime(2019, 1, 1, 0, 0, 0))]) - >>> idx # doctest: +SKIP - MultiIndex([('2019-01-01', '2019-01-01'), - ('2019-01-01', '2019-01-01')], - ) - - >>> idx.is_all_dates - False - """ - return False - def __getattr__(self, item: str) -> Any: if hasattr(MissingPandasLikeMultiIndex, item): property_or_func = getattr(MissingPandasLikeMultiIndex, item) diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index 9b64300e948fd..aa9374b6dceb2 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -222,7 +222,6 @@ def read_csv( names: Optional[Union[str, List[str]]] = None, index_col: Optional[Union[str, List[str]]] = None, usecols: Optional[Union[List[int], List[str], Callable[[str], bool]]] = None, - mangle_dupe_cols: bool = True, dtype: Optional[Union[str, Dtype, Dict[str, Union[str, Dtype]]]] = None, nrows: Optional[int] = None, parse_dates: bool = False, @@ -261,14 +260,6 @@ def read_csv( from the document header row(s). If callable, the callable function will be evaluated against the column names, returning names where the callable function evaluates to `True`. - mangle_dupe_cols : bool, default True - Duplicate columns will be specified as 'X0', 'X1', ... 'XN', rather - than 'X' ... 'X'. 
Passing in False will cause data to be overwritten if - there are duplicate names in the columns. - Currently only `True` is allowed. - - .. deprecated:: 3.4.0 - dtype : Type name or dict of column -> type, default None Data type for data or columns. E.g. {‘a’: np.float64, ‘b’: np.int32} Use str or object together with suitable na_values settings to preserve and not interpret dtype. @@ -310,8 +301,6 @@ def read_csv( if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: options = options.get("options") - if mangle_dupe_cols is not True: - raise ValueError("mangle_dupe_cols can only be `True`: %s" % mangle_dupe_cols) if parse_dates is not False: raise ValueError("parse_dates can only be `False`: %s" % parse_dates) @@ -917,8 +906,6 @@ def read_excel( thousands: Optional[str] = None, comment: Optional[str] = None, skipfooter: int = 0, - convert_float: bool = True, - mangle_dupe_cols: bool = True, **kwds: Any, ) -> Union[DataFrame, Series, Dict[str, Union[DataFrame, Series]]]: """ @@ -1041,20 +1028,6 @@ def read_excel( comment string and the end of the current line is ignored. skipfooter : int, default 0 Rows at the end to skip (0-indexed). - convert_float : bool, default True - Convert integral floats to int (i.e., 1.0 --> 1). If False, all numeric - data will be read in as floats: Excel stores all numbers as floats - internally. - - .. deprecated:: 3.4.0 - - mangle_dupe_cols : bool, default True - Duplicate columns will be specified as 'X', 'X.1', ...'X.N', rather than - 'X'...'X'. Passing in False will cause data to be overwritten if there - are duplicate names in the columns. - - .. deprecated:: 3.4.0 - **kwds : optional Optional keyword arguments can be passed to ``TextFileReader``. @@ -1150,8 +1123,6 @@ def pd_read_excel( thousands=thousands, comment=comment, skipfooter=skipfooter, - convert_float=convert_float, - mangle_dupe_cols=mangle_dupe_cols, **kwds, ) diff --git a/python/pyspark/pandas/tests/test_csv.py b/python/pyspark/pandas/tests/test_csv.py index a62388050472c..e35b49315712e 100644 --- a/python/pyspark/pandas/tests/test_csv.py +++ b/python/pyspark/pandas/tests/test_csv.py @@ -254,11 +254,6 @@ def test_read_csv_with_sep(self): actual = ps.read_csv(fn, sep="\t") self.assert_eq(expected, actual, almost=True) - def test_read_csv_with_mangle_dupe_cols(self): - self.assertRaisesRegex( - ValueError, "mangle_dupe_cols", lambda: ps.read_csv("path", mangle_dupe_cols=False) - ) - def test_read_csv_with_parse_dates(self): self.assertRaisesRegex( ValueError, "parse_dates", lambda: ps.read_csv("path", parse_dates=True) From cdc66a71ea3d5da0f3f3ef7eabb215336ae472a5 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 2 Nov 2023 14:27:14 -0700 Subject: [PATCH 009/121] [SPARK-45757][ML] Avoid re-computation of NNZ in Binarizer ### What changes were proposed in this pull request? 1, compress vectors with given nnz in Binarizer; 2, rename internal function `def compressed(nnz: Int): Vector` to avoid ambiguous reference issue (`vec.compressed.apply(nnz)`) when there is no type hint ``` [error] /Users/ruifeng.zheng/Dev/spark/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala:132:61: ambiguous reference to overloaded definition, [error] both method compressed in trait Vector of type (nnz: Int): org.apache.spark.ml.linalg.Vector [error] and method compressed in trait Vector of type org.apache.spark.ml.linalg.Vector ``` ### Why are the changes needed? `nnz` is known before compression ### Does this PR introduce _any_ user-facing change? 
no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #43619 from zhengruifeng/ml_binarizer_nnz. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/ml/linalg/Vectors.scala | 4 +-- .../apache/spark/ml/feature/Binarizer.scala | 26 ++++++++++++------- .../spark/ml/feature/VectorAssembler.scala | 6 +++-- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 016a8366ab868..d8e17ddd24db7 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -184,9 +184,9 @@ sealed trait Vector extends Serializable { * Returns a vector in either dense or sparse format, whichever uses less storage. */ @Since("2.0.0") - def compressed: Vector = compressed(numNonzeros) + def compressed: Vector = compressedWithNNZ(numNonzeros) - private[ml] def compressed(nnz: Int): Vector = { + private[ml] def compressedWithNNZ(nnz: Int): Vector = { // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. if (1.5 * (nnz + 1.0) < size) { toSparseWithSize(nnz) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 2ec7a8632e39d..2e09e74449572 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -112,10 +112,11 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) (Seq($(inputCol)), Seq($(outputCol)), Seq($(threshold))) } - val mappedOutputCols = inputColNames.zip(tds).map { case (inputColName, td) => - val binarizerUDF = dataset.schema(inputColName).dataType match { + val mappedOutputCols = inputColNames.zip(tds).map { case (colName, td) => + dataset.schema(colName).dataType match { case DoubleType => - udf { in: Double => if (in > td) 1.0 else 0.0 } + when(!col(colName).isNaN && col(colName) > td, lit(1.0)) + .otherwise(lit(0.0)) case _: VectorUDT if td >= 0 => udf { vector: Vector => @@ -124,27 +125,32 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) vector.foreachNonZero { (index, value) => if (value > td) { indices += index - values += 1.0 + values += 1.0 } } - Vectors.sparse(vector.size, indices.result(), values.result()).compressed - } + + val idxArray = indices.result() + val valArray = values.result() + Vectors.sparse(vector.size, idxArray, valArray) + .compressedWithNNZ(idxArray.length) + }.apply(col(colName)) case _: VectorUDT if td < 0 => this.logWarning(s"Binarization operations on sparse dataset with negative threshold " + s"$td will build a dense output, so take care when applying to sparse input.") udf { vector: Vector => val values = Array.fill(vector.size)(1.0) + var nnz = vector.size vector.foreachNonZero { (index, value) => if (value <= td) { values(index) = 0.0 + nnz -= 1 } } - Vectors.dense(values).compressed - } - } - binarizerUDF(col(inputColName)) + Vectors.dense(values).compressedWithNNZ(nnz) + }.apply(col(colName)) + } } val outputMetadata = outputColNames.map(outputSchema(_).metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 
761352e34a3eb..cf5b5ecb20148 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -296,7 +296,9 @@ object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - val (idxArray, valArray) = (indices.result(), values.result()) - Vectors.sparse(featureIndex, idxArray, valArray).compressed(idxArray.length) + val idxArray = indices.result() + val valArray = values.result() + Vectors.sparse(featureIndex, idxArray, valArray) + .compressedWithNNZ(idxArray.length) } } From eda9911057b893e42f49dbd7448f20f91f2798c4 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 3 Nov 2023 09:33:02 +0900 Subject: [PATCH 010/121] [SPARK-45768][SQL][PYTHON] Make faulthandler a runtime configuration for Python execution in SQL ### What changes were proposed in this pull request? This PR proposes to make `faulthandler` as a runtime configuration so we can turn on and off during runtime. ### Why are the changes needed? `faulthandler` feature within PySpark is really useful especially to debug an errors that regular Python interpreter cannot catch out of the box such as a segmentation fault errors, see also https://github.com/apache/spark/pull/43600. It would be very useful to convert this as a runtime configuration without restarting the shell. ### Does this PR introduce _any_ user-facing change? Yes, users can now set `spark.sql.execution.pyspark.udf.faulthandler.enabled` during runtime to enable `faulthandler` feature. ### How was this patch tested? Unittest added ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43635 from HyukjinKwon/runtime-conf. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/api/python/PythonRunner.scala | 2 +- python/pyspark/sql/tests/test_udf.py | 7 +++++++ .../org/apache/spark/sql/internal/SQLConf.scala | 14 ++++++++++++-- .../ApplyInPandasWithStatePythonRunner.scala | 2 ++ .../sql/execution/python/ArrowPythonRunner.scala | 2 ++ .../execution/python/ArrowPythonUDTFRunner.scala | 2 ++ .../python/CoGroupedArrowPythonRunner.scala | 2 ++ .../sql/execution/python/PythonForeachWriter.scala | 4 ++++ .../sql/execution/python/PythonUDFRunner.scala | 2 ++ 9 files changed, 34 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index d265bb2fd8b8e..1a01ad1bc219a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -109,7 +109,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( protected val timelyFlushTimeoutNanos: Long = 0 protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) - private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) + protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) protected val simplifiedTraceback: Boolean = false // All the Python functions should have the same exec, version and envvars. 
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 1f895b1780bb6..8b82f591ffc35 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -1020,6 +1020,13 @@ def test_udf(a): with self.assertRaisesRegex(PythonException, "StopIteration"): self.spark.range(10).select(test_udf(col("id"))).show() + def test_python_udf_segfault(self): + with self.sql_conf({"spark.sql.execution.pyspark.udf.faulthandler.enabled": True}): + with self.assertRaisesRegex(Exception, "Segmentation fault"): + import ctypes + + self.spark.range(1).select(udf(lambda x: ctypes.string_at(0))("id")).collect() + class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1f37363648aee..1af0b41d0faac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1504,7 +1504,7 @@ object SQLConf { .booleanConf .createWithDefault(false) - val V2_BUCKETING_SHUFFLE_ENABLED = + val V2_BUCKETING_SHUFFLE_ENABLED = buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled") .doc("During a storage-partitioned join, whether to allow to shuffle only one side." + "When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " + @@ -1514,7 +1514,7 @@ object SQLConf { .booleanConf .createWithDefault(false) - val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS = + val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS = buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled") .doc("Whether to allow storage-partition join in the case where join keys are" + "a subset of the partition keys of the source tables. At planning time, " + @@ -2899,6 +2899,14 @@ object SQLConf { // show full stacktrace in tests but hide in production by default. .createWithDefault(Utils.isTesting) + val PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED = + buildConf("spark.sql.execution.pyspark.udf.faulthandler.enabled") + .doc( + s"Same as ${Python.PYTHON_WORKER_FAULTHANLDER_ENABLED.key} for Python execution with " + + "DataFrame and SQL. It can change during runtime.") + .version("4.0.0") + .fallbackConf(Python.PYTHON_WORKER_FAULTHANLDER_ENABLED) + val ARROW_SPARKR_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.sparkr.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. 
" + @@ -5168,6 +5176,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED) + def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED) + def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED) def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index e5763b9f230b1..86ebc2e7ef148 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -71,6 +71,8 @@ class ApplyInPandasWithStatePythonRunner( SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( funcs.head.funcs.head.pythonExec) + override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled + private val sqlConf = SQLConf.get // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 251e682c9e1be..a9eaf79c9db03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -46,6 +46,8 @@ abstract class BaseArrowPythonRunner( SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( funcs.head.funcs.head.pythonExec) + override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled + override val errorOnDuplicatedFieldNames: Boolean = true override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala index 690947b41293a..87d1ccb257769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala @@ -55,6 +55,8 @@ class ArrowPythonUDTFRunner( SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( funcs.head.funcs.head.pythonExec) + override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled + override val errorOnDuplicatedFieldNames: Boolean = true override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index bd901545bb03c..eb56298bfbee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -57,6 +57,8 @@ class CoGroupedArrowPythonRunner( SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( funcs.head.funcs.head.pythonExec) + override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled 
+ override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback protected def newWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index f5c2d9acbf298..67b264436fea9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -92,6 +92,10 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( funcs.head.funcs.head.pythonExec) + + override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled + + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index dfcfcecaeb010..167a96ed41c72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -44,6 +44,8 @@ abstract class BasePythonUDFRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled + protected def writeUDF(dataOut: DataOutputStream): Unit protected override def newWriter( From bc9c255e28c323035036a6f75a7bd984e0306b4c Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 3 Nov 2023 11:14:24 +0800 Subject: [PATCH 011/121] [SPARK-45694][SPARK-45695][SQL] Clean up deprecated API usage `View.force` and `ScalaNumberProxy.signum` ### What changes were proposed in this pull request? Clean up deprecated API usage: 1. ScalaNumberProxy.signum -> use `sign` instead; 2. Map.view.mapValues.view.force.toMap -> replaced with Map.view.mapValues.toMap; ### Why are the changes needed? Eliminate compile warnings and no longer use deprecated Scala APIs. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43637 from ivoson/SPARK-45694. 
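For illustration, the two replacements in isolation as a standalone Scala 2.13 sketch (the values here are made up; the real changes are in the diff below):

```scala
// Standalone sketch of the non-deprecated forms used by this patch.
object DeprecatedApiCleanupSketch {
  def main(args: Array[String]): Unit = {
    val useCount = -3
    // `sign` replaces the deprecated `signum` on the rich number wrappers.
    val uc = useCount.sign                       // -1
    val m = Map("a" -> 1, "b" -> 2)
    // `mapValues` on a view is lazy; `toMap` materializes it directly,
    // so the deprecated `.view.force` step is no longer needed.
    val doubled = m.view.mapValues(_ * 2).toMap  // Map(a -> 2, b -> 4)
    println(s"sign=$uc, doubled=$doubled")
  }
}
```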
Authored-by: Tengfei Huang Signed-off-by: yangjie01 --- .../catalyst/expressions/EquivalentExpressions.scala | 2 +- .../apache/spark/sql/catalyst/trees/TreeNode.scala | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 8738015ce9107..7f43b2b784785 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -191,7 +191,7 @@ class EquivalentExpressions( val skip = useCount == 0 || expr.isInstanceOf[LeafExpression] if (!skip && !updateExprInMap(expr, map, useCount)) { - val uc = useCount.signum + val uc = useCount.sign childrenToRecurse(expr).foreach(updateExprTree(_, map, uc)) commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index c9425b24764cc..017b20077cf66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -362,10 +362,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] case s: Seq[_] => s.map(mapChild) case m: Map[_, _] => - // `map.mapValues().view.force` return `Map` in Scala 2.12 but return `IndexedSeq` in Scala - // 2.13, call `toMap` method manually to compatible with Scala 2.12 and Scala 2.13 - // `mapValues` is lazy and we need to force it to materialize - m.view.mapValues(mapChild).view.force.toMap + // `mapValues` is lazy and we need to force it to materialize by converting to Map + m.view.mapValues(mapChild).toMap case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) case Some(child) => Some(mapChild(child)) case nonChild: AnyRef => nonChild @@ -784,13 +782,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] arg.asInstanceOf[BaseType].clone() case Some(arg: TreeNode[_]) if containsChild(arg) => Some(arg.asInstanceOf[BaseType].clone()) - // `map.mapValues().view.force` return `Map` in Scala 2.12 but return `IndexedSeq` in Scala - // 2.13, call `toMap` method manually to compatible with Scala 2.12 and Scala 2.13 + // `mapValues` is lazy and we need to force it to materialize by converting to Map case m: Map[_, _] => m.view.mapValues { case arg: TreeNode[_] if containsChild(arg) => arg.asInstanceOf[BaseType].clone() case other => other - }.view.force.toMap // `mapValues` is lazy and we need to force it to materialize + }.toMap case d: DataType => d // Avoid unpacking Structs case args: LazyList[_] => args.map(mapChild).force // Force materialization on stream case args: Iterable[_] => args.map(mapChild) From 0ebfa691fc8e88d2910bf41900718252edea1ee3 Mon Sep 17 00:00:00 2001 From: Hasnain Lakhani Date: Fri, 3 Nov 2023 00:09:21 -0500 Subject: [PATCH 012/121] [SPARK-45730][CORE] Make ReloadingX509TrustManagerSuite less flaky ### What changes were proposed in this pull request? Improve a few timing related constraints: * Wait 10s instead of 5 for a reload to happen when under high load. This should not delay the test in the average case as it checks every 100ms for an event to happen. 
* In certain cases we run *too fast* so the new file we create has the same timestamp as the old file, and thus we never reload. Add a sleep there so the modification times are different. This was accidentally reverted in https://github.com/apache/spark/pull/43249/commits/b7dac1f96ec45a300963e4da1dc4fc1173470da7#diff-de96db6b61e9f48fb9bd8b781f4367e60a48b3886dfe03d4cf16b47ef6c26d0a ### Why are the changes needed? These changes are needed to make the test more reliable and less flaky under load. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Ran this test in parallel on a machine under high load. Previously under those conditions I would repeatedly get high rates of failure (80%+) and now it does not fail. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43596 from hasnain-db/fix-flaky-test. Authored-by: Hasnain Lakhani Signed-off-by: Mridul Muralidharan gmail.com> --- .../ssl/ReloadingX509TrustManagerSuite.java | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java index 0526fcb11bea3..5bb47ff388671 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java @@ -161,14 +161,17 @@ public void testReload() throws Exception { // At this point we haven't reloaded, just the initial load assertEquals(0, tm.reloadCount); + // Wait so that the file modification time is different + Thread.sleep((tm.getReloadInterval() + 1000)); + // Add another cert Map certs = new HashMap(); certs.put("cert1", cert1); certs.put("cert2", cert2); createTrustStore(trustStore, "password", certs); - // Wait up to 5s until we reload - waitForReloadCount(tm, 1, 50); + // Wait up to 10s until we reload + waitForReloadCount(tm, 1, 100); assertEquals(2, tm.getAcceptedIssuers().length); } finally { @@ -286,8 +289,8 @@ public void testReloadSymlink() throws Exception { trustStoreSymlink.delete(); Files.createSymbolicLink(trustStoreSymlink.toPath(), trustStore2.toPath()); - // Wait up to 5s until we reload - waitForReloadCount(tm, 1, 50); + // Wait up to 10s until we reload + waitForReloadCount(tm, 1, 100); assertEquals(2, tm.getAcceptedIssuers().length); @@ -295,8 +298,8 @@ public void testReloadSymlink() throws Exception { certs.put("cert3", cert3); createTrustStore(trustStore2, "password", certs); - // Wait up to 5s until we reload - waitForReloadCount(tm, 2, 50); + // Wait up to 10s until we reload + waitForReloadCount(tm, 2, 100); assertEquals(3, tm.getAcceptedIssuers().length); } finally { From b9d379a6b84b67b29ccb578938764b888d64f293 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 2 Nov 2023 23:13:27 -0700 Subject: [PATCH 013/121] [SPARK-45777][CORE] Support `spark.test.appId` in `LocalSchedulerBackend` ### What changes were proposed in this pull request? This PR aims to support `spark.test.appId` in `LocalSchedulerBackend` like the following. ``` $ bin/spark-shell --driver-java-options="-Dspark.test.appId=test-app-2023" ... Spark context available as 'sc' (master = local[*], app id = test-app-2023). ``` ``` $ bin/spark-shell -c spark.test.appId=test-app-2026 -c spark.eventLog.enabled=true -c spark.eventLog.dir=/Users/dongjoon/data/history ... 
Spark context available as 'sc' (master = local[*], app id = test-app-2026). ``` ### Why are the changes needed? Like the other `spark.test.*` configurations, this enables the developers control the appId in `LocalSchedulerBackend`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43645 from dongjoon-hyun/SPARK-45777. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../apache/spark/scheduler/local/LocalSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 79084b75f6c39..a00fe2a06899f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -110,7 +110,7 @@ private[spark] class LocalSchedulerBackend( val totalCores: Int) extends SchedulerBackend with ExecutorBackend with Logging { - private val appId = "local-" + System.currentTimeMillis + private val appId = conf.get("spark.test.appId", "local-" + System.currentTimeMillis) private var localEndpoint: RpcEndpointRef = null private val userClassPath = getUserClasspath(conf) private val listenerBus = scheduler.sc.listenerBus From 1649b256133f352be8120c84513f67c8a873f0a9 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 3 Nov 2023 16:06:03 +0800 Subject: [PATCH 014/121] [SPARK-45742][CORE][FOLLOWUP] Remove unnecessary null check from `ArrayImplicits.SparkArrayOps#toImmutableArraySeq` ### What changes were proposed in this pull request? The implementation of the `mmutable.ArraySeq.unsafeWrapArray` function is as follows: ```scala def unsafeWrapArray[T](x: Array[T]): ArraySeq[T] = ((x: unchecked) match { case null => null case x: Array[AnyRef] => new ofRef[AnyRef](x) case x: Array[Int] => new ofInt(x) case x: Array[Double] => new ofDouble(x) case x: Array[Long] => new ofLong(x) case x: Array[Float] => new ofFloat(x) case x: Array[Char] => new ofChar(x) case x: Array[Byte] => new ofByte(x) case x: Array[Short] => new ofShort(x) case x: Array[Boolean] => new ofBoolean(x) case x: Array[Unit] => new ofUnit(x) }).asInstanceOf[ArraySeq[T]] ``` The first case of match is null, there is no need to do another manual null check, so this PR removes it. ### Why are the changes needed? Remove unnecessary null check from `ArrayImplicits.SparkArrayOps#toImmutableArraySeq` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test cases, such as `ArrayImplicitsSuite`. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43641 from LuciferYang/SPARK-45742-FOLLOWUP. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: yangjie01 --- .../src/main/scala/org/apache/spark/util/ArrayImplicits.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala index 08997a800c957..38c2a415af3db 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala @@ -30,7 +30,6 @@ private[spark] object ArrayImplicits { * Wraps an Array[T] as an immutable.ArraySeq[T] without copying. 
*/ def toImmutableArraySeq: immutable.ArraySeq[T] = - if (xs eq null) null - else immutable.ArraySeq.unsafeWrapArray(xs) + immutable.ArraySeq.unsafeWrapArray(xs) } } From 1359c1327345efdf9a35c46a355b5f928ac33e6d Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 3 Nov 2023 17:56:29 +0800 Subject: [PATCH 015/121] [SPARK-45776][CORE] Remove the defensive null check for `MapOutputTrackerMaster#unregisterShuffle` added in SPARK-39553 ### What changes were proposed in this pull request? This pr Remove the defensive null check for `MapOutputTrackerMaster#unregisterShuffle` added in SPARK-39553. ### Why are the changes needed? https://github.com/scala/bug/issues/12613 has been fixed in Scala 2.13.9. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test like `SPARK-39553: Multi-thread unregister shuffle shouldn't throw NPE` in `MapOutputTrackerSuite` ### Was this patch authored or co-authored using generative AI tooling? No Closes #43644 from LuciferYang/remove-39553-null-check. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Kent Yao --- .../main/scala/org/apache/spark/MapOutputTracker.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6d5c04635ad7a..a787cdefe8085 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -906,12 +906,8 @@ private[spark] class MapOutputTrackerMaster( /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int): Unit = { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => - // SPARK-39553: Add protection for Scala 2.13 due to https://github.com/scala/bug/issues/12613 - // We should revert this if Scala 2.13 solves this issue. - if (shuffleStatus != null) { - shuffleStatus.invalidateSerializedMapOutputStatusCache() - shuffleStatus.invalidateSerializedMergeOutputStatusCache() - } + shuffleStatus.invalidateSerializedMapOutputStatusCache() + shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } From e9339e60dd418559619e4beae3de0a58ca3b2a31 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 3 Nov 2023 23:17:06 +0800 Subject: [PATCH 016/121] [SPARK-45774][CORE][UI] Support `spark.master.ui.historyServerUrl` in `ApplicationPage` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to support a new configuration `spark.master.ui.historyServerUrl` to show a new link, **Application History UI**, from `Application Page` to `Spark History Server` in `Spark Standalone Cluster`s. This is useful when `spark.eventLog.enabled=true` and `Spark History Server` exists. A user can launch `Spark Master` with `spark.master.ui.historyServerUrl` configuration to link to it. Please note that `Spark History Server` is an orthogonal service from not only `Spark Master` but also `Spark Applications`. For example, `Spark Application`s know only `spark.eventLog.dir`, not SHS info. Moreover, you can launch multiple SHS based on the same event location with load balancer. Lastly, SHS can be behind proxy layer too. So, only the users know the exposed SHS URL. ### Why are the changes needed? **Application Detail UI** Apache Spark `Master` currently shows `Application Detail UI` link for only live Spark jobs. 
Screenshot 2023-11-02 at 8 30 20 PM **Application History UI** This PR adds `Application History UI` link for completed jobs as the best effort when `spark.ui.historyServerUrl` is given. Screenshot 2023-11-02 at 8 38 37 PM ### Does this PR introduce _any_ user-facing change? No. This is disabled by default. ### How was this patch tested? Pass the CIs with the newly added test suite. Also, manual procedure ``` $ SPARK_HISTORY_OPTS="-Dspark.history.fs.logDirectory=$HOME/data/history" sbin/start-history-server.sh $ SPARK_MASTER_OPTS=-Dspark.master.ui.historyServerUrl=http://localhost:18080 sbin/start-master.sh $ sbin/start-worker.sh spark://localhost:7077 $ bin/spark-shell --master spark://127.0.0.1:7077 -c spark.eventLog.enabled=true -c spark.eventLog.dir=/Users/dongjoon/data/history ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43643 from dongjoon-hyun/SPARK-45774. Authored-by: Dongjoon Hyun Signed-off-by: Kent Yao --- .../apache/spark/deploy/master/Master.scala | 1 + .../deploy/master/ui/ApplicationPage.scala | 5 ++ .../spark/internal/config/package.scala | 8 ++ .../master/ui/ApplicationPageSuite.scala | 73 +++++++++++++++++++ 4 files changed, 87 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 058b944c591ad..d5de1366ac053 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -121,6 +121,7 @@ private[deploy] class Master( if (defaultCores < 1) { throw new SparkException(s"${DEFAULT_CORES.key} must be positive") } + val historyServerUrl = conf.get(MASTER_UI_HISTORY_SERVER_URL) // Alternative application submission gateway that is stable across Spark versions private val restServerEnabled = conf.get(MASTER_REST_SERVER_ENABLED) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 202926233d978..3087d5e8c966f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -98,6 +98,11 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") Application Detail UI + } else if (parent.master.historyServerUrl.nonEmpty) { +
  • + + Application History UI +
  • } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 143dd0c44ce84..93a42eec8326b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1837,6 +1837,14 @@ package object config { .intConf .createWithDefault(8080) + private[spark] val MASTER_UI_HISTORY_SERVER_URL = + ConfigBuilder("spark.master.ui.historyServerUrl") + .doc("The URL where Spark history server is running. Please note that this assumes " + + "that all Spark jobs share the same event log location where the history server accesses.") + .version("4.0.0") + .stringConf + .createOptional + private[spark] val IO_COMPRESSION_SNAPPY_BLOCKSIZE = ConfigBuilder("spark.io.compression.snappy.blockSize") .doc("Block size in bytes used in Snappy compression, in the case when " + diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala new file mode 100644 index 0000000000000..9890ac24e168e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/ApplicationPageSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master.ui + +import java.util.Date +import javax.servlet.http.HttpServletRequest + +import org.mockito.Mockito.{mock, when} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master.{ApplicationInfo, ApplicationState, Master} +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.rpc.RpcEndpointRef + +class ApplicationPageSuite extends SparkFunSuite { + + private val master = mock(classOf[Master]) + when(master.historyServerUrl).thenReturn(Some("http://my-history-server:18080")) + + private val rp = new ResourceProfile(Map.empty, Map.empty) + private val desc = ApplicationDescription("name", Some(4), null, "appUiUrl", rp) + private val appFinished = new ApplicationInfo(0, "app-finished", desc, new Date, null, 1) + appFinished.markFinished(ApplicationState.FINISHED) + private val appLive = new ApplicationInfo(0, "app-live", desc, new Date, null, 1) + + private val state = mock(classOf[MasterStateResponse]) + when(state.completedApps).thenReturn(Array(appFinished)) + when(state.activeApps).thenReturn(Array(appLive)) + + private val rpc = mock(classOf[RpcEndpointRef]) + when(rpc.askSync[MasterStateResponse](RequestMasterState)).thenReturn(state) + + private val masterWebUI = mock(classOf[MasterWebUI]) + when(masterWebUI.master).thenReturn(master) + when(masterWebUI.masterEndpointRef).thenReturn(rpc) + + test("SPARK-45774: Application Detail UI") { + val request = mock(classOf[HttpServletRequest]) + when(request.getParameter("appId")).thenReturn("app-live") + + val result = new ApplicationPage(masterWebUI).render(request).toString() + assert(result.contains("Application Detail UI")) + assert(!result.contains("Application History UI")) + assert(!result.contains(master.historyServerUrl.get)) + } + + test("SPARK-45774: Application History UI") { + val request = mock(classOf[HttpServletRequest]) + when(request.getParameter("appId")).thenReturn("app-finished") + + val result = new ApplicationPage(masterWebUI).render(request).toString() + assert(!result.contains("Application Detail UI")) + assert(result.contains("Application History UI")) + assert(result.contains(master.historyServerUrl.get)) + } +} From ec875c521feed18f72200a8f87a2be5d9e3ccf96 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 3 Nov 2023 23:27:38 +0800 Subject: [PATCH 017/121] [SPARK-45688][SPARK-45693][CORE] Clean up the deprecated API usage related to `MapOps` & Fix `method += in trait Growable is deprecated` ### What changes were proposed in this pull request? The pr aims to: - clean up the deprecated API usage related to MapOps. - fix method += in trait Growable is deprecated. ### Why are the changes needed? Eliminate warnings and no longer use `deprecated scala APIs`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Pass GA. - Manually test: ``` build/sbt -Phadoop-3 -Pdocker-integration-tests -Pspark-ganglia-lgpl -Pkinesis-asl -Pkubernetes -Phive-thriftserver -Pconnect -Pyarn -Phive -Phadoop-cloud -Pvolcano -Pkubernetes-integration-tests Test/package streaming-kinesis-asl-assembly/assembly connect/assembly ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43578 from panbingkun/SPARK-45688. 
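As a standalone illustration of the `Growable` change (the argument strings below are made up, not the real SparkSubmit options):

```scala
import scala.collection.mutable.ArrayBuffer

// Sketch of the append-style change: in Scala 2.13 the varargs form of `+=`
// is deprecated, but `+=` returns the buffer, so appends can be chained.
object GrowableAppendSketch {
  def main(args: Array[String]): Unit = {
    val childArgs = new ArrayBuffer[String]()
    // Deprecated: childArgs += ("--class", "org.example.Main")
    childArgs += "--class" += "org.example.Main"
    println(childArgs.mkString(" "))  // --class org.example.Main
  }
}
```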
Authored-by: panbingkun Signed-off-by: yangjie01 --- .../org/apache/spark/deploy/SparkSubmit.scala | 28 +++++++++---------- .../spark/deploy/worker/CommandUtils.scala | 5 ++-- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- .../cluster/YarnClientSchedulerBackend.scala | 2 +- .../statsEstimation/JoinEstimation.scala | 7 ++--- .../v2/DescribeNamespaceExec.scala | 4 +-- .../v2/V2SessionCatalogSuite.scala | 2 +- 7 files changed, 25 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e60be5d5a651f..30b542eefb60b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -715,7 +715,7 @@ private[spark] class SparkSubmit extends Logging { if (opt.value != null && (deployMode & opt.deployMode) != 0 && (clusterManager & opt.clusterManager) != 0) { - if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } + if (opt.clOption != null) { childArgs += opt.clOption += opt.value } if (opt.confKey != null) { if (opt.mergeFn.isDefined && sparkConf.contains(opt.confKey)) { sparkConf.set(opt.confKey, opt.mergeFn.get.apply(sparkConf.get(opt.confKey), opt.value)) @@ -747,15 +747,15 @@ private[spark] class SparkSubmit extends Logging { if (args.isStandaloneCluster) { if (args.useRest) { childMainClass = REST_CLUSTER_SUBMIT_CLASS - childArgs += (args.primaryResource, args.mainClass) + childArgs += args.primaryResource += args.mainClass } else { // In legacy standalone cluster mode, use Client as a wrapper around the user class childMainClass = STANDALONE_CLUSTER_SUBMIT_CLASS if (args.supervise) { childArgs += "--supervise" } - Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) } - Option(args.driverCores).foreach { c => childArgs += ("--cores", c) } + Option(args.driverMemory).foreach { m => childArgs += "--memory" += m } + Option(args.driverCores).foreach { c => childArgs += "--cores" += c } childArgs += "launch" - childArgs += (args.master, args.primaryResource, args.mainClass) + childArgs += args.master += args.primaryResource += args.mainClass } if (args.childArgs != null) { childArgs ++= args.childArgs @@ -777,20 +777,20 @@ private[spark] class SparkSubmit extends Logging { if (isYarnCluster) { childMainClass = YARN_CLUSTER_SUBMIT_CLASS if (args.isPython) { - childArgs += ("--primary-py-file", args.primaryResource) - childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + childArgs += "--primary-py-file" += args.primaryResource + childArgs += "--class" += "org.apache.spark.deploy.PythonRunner" } else if (args.isR) { val mainFile = new Path(args.primaryResource).getName - childArgs += ("--primary-r-file", mainFile) - childArgs += ("--class", "org.apache.spark.deploy.RRunner") + childArgs += "--primary-r-file" += mainFile + childArgs += "--class" += "org.apache.spark.deploy.RRunner" } else { if (args.primaryResource != SparkLauncher.NO_RESOURCE) { - childArgs += ("--jar", args.primaryResource) + childArgs += "--jar" += args.primaryResource } - childArgs += ("--class", args.mainClass) + childArgs += "--class" += args.mainClass } if (args.childArgs != null) { - args.childArgs.foreach { arg => childArgs += ("--arg", arg) } + args.childArgs.foreach { arg => childArgs += "--arg" += arg } } } @@ -813,12 +813,12 @@ private[spark] class SparkSubmit extends Logging { } if (args.childArgs != null) { args.childArgs.foreach { arg => - childArgs += ("--arg", arg) + 
childArgs += "--arg" += arg } } // Pass the proxyUser to the k8s app so it is possible to add it to the driver args if (args.proxyUser != null) { - childArgs += ("--proxy-user", args.proxyUser) + childArgs += "--proxy-user" += args.proxyUser } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index c04214de4ddc6..d1190ca46c2a8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -81,14 +81,15 @@ object CommandUtils extends Logging { var newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName) - command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator))) + command.environment ++ Map(libraryPathName -> libraryPaths.mkString(File.pathSeparator)) } else { command.environment } // set auth secret to env variable if needed if (securityMgr.isAuthenticationEnabled()) { - newEnvironment += (SecurityManager.ENV_AUTH_SECRET -> securityMgr.getSecretKey()) + newEnvironment = newEnvironment ++ + Map(SecurityManager.ENV_AUTH_SECRET -> securityMgr.getSecretKey()) } // set SSL env variables if needed newEnvironment ++= securityMgr.getEnvironmentForSslRpcPasswords diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 3defd4b1a7d90..948bc8889bcd1 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -626,7 +626,7 @@ class JsonProtocolSuite extends SparkFunSuite { val expectedEvent: SparkListenerEnvironmentUpdate = { val e = JsonProtocol.environmentUpdateFromJson(environmentUpdateJsonString) e.copy(environmentDetails = - e.environmentDetails + ("Metrics Properties" -> Seq.empty[(String, String)])) + e.environmentDetails ++ Map("Metrics Properties" -> Seq.empty[(String, String)])) } val oldEnvironmentUpdateJson = environmentUpdateJsonString .removeField("Metrics Properties") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 717c620f5c341..af41d30c2cdb8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -53,7 +53,7 @@ private[spark] class YarnClientSchedulerBackend( sc.ui.foreach { ui => conf.set(DRIVER_APP_UI_ADDRESS, ui.webUrl) } val argsArrayBuf = new ArrayBuffer[String]() - argsArrayBuf += ("--arg", hostport) + argsArrayBuf += "--arg" += hostport logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" ")) val args = new ClientArguments(argsArrayBuf.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index c6e76df1b31ad..10646130a9106 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -206,13 +206,12 @@ case class JoinEstimation(join: Join) extends Logging { case _ => computeByNdv(leftKey, rightKey, newMin, newMax) } - keyStatsAfterJoin += ( + keyStatsAfterJoin += // Histograms are propagated as unchanged. During future estimation, they should be // truncated by the updated max/min. In this way, only pointers of the histograms are // propagated and thus reduce memory consumption. - leftKey -> joinStat.copy(histogram = leftKeyStat.histogram), - rightKey -> joinStat.copy(histogram = rightKeyStat.histogram) - ) + (leftKey -> joinStat.copy(histogram = leftKeyStat.histogram)) += + (rightKey -> joinStat.copy(histogram = rightKeyStat.histogram)) // Return cardinality estimated from the most selective join keys. if (card < joinCard) joinCard = card } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala index 125952566d7e8..d97ffb6940600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala @@ -46,12 +46,12 @@ case class DescribeNamespaceExec( } if (isExtended) { - val properties = metadata.asScala -- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES + val properties = metadata.asScala.toMap -- CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES val propertiesStr = if (properties.isEmpty) { "" } else { - conf.redactOptions(properties.toMap).toSeq.sortBy(_._1).mkString("(", ", ", ")") + conf.redactOptions(properties).toSeq.sortBy(_._1).mkString("(", ", ", ")") } rows += toCatalystRow("Properties", propertiesStr) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index c43658eacabc2..f9da55ed6ba31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -827,7 +827,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { // remove location and comment that are automatically added by HMS unless they are expected val toRemove = CatalogV2Util.NAMESPACE_RESERVED_PROPERTIES.filter(expected.contains) - assert(expected -- toRemove === actual) + assert(expected.toMap -- toRemove === actual) } test("listNamespaces: basic behavior") { From da5aa706147c6a462a8470f8fdf2c3cd17fede01 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Sat, 4 Nov 2023 01:10:30 +0800 Subject: [PATCH 018/121] [SPARK-45779][INFRA] Add a check step for jira issue ticket to be resolved ### What changes were proposed in this pull request? This PR moves the step for printing the summary of the Jira issue and adds a confirmation step for committers before assigning it to a user. Also, we show the final status of the Jira issue after resolution. ### Why are the changes needed? The JIRA ID mentioned in the PR title can sometimes be incorrect due to carelessness. Double-checking the ID is necessary to ensure that it is accurate. It is now too late for the committers to modify the ID once they have seen the ticket summary. ### Does this PR introduce _any_ user-facing change? 
no ### How was this patch tested? Tested with SPARK-45776. Note that the message below may seem noisy because I declined the merge script at first. ``` Would you like to update an associated JIRA? (y/n): y Enter a JIRA id[SPARK-45776]: === JIRA SPARK-45776 === summary Remove the defensive null check added in SPARK-39553. assignee None status Open url https://issues.apache.org/jira/browse/SPARK-45776 Check if the JIRA information is as expected(y/n): n Enter the revised JIRA ID again or leave blank to skip[]: SPARK-45776 === JIRA SPARK-45776 === summary Remove the defensive null check added in SPARK-39553. assignee None status Open url https://issues.apache.org/jira/browse/SPARK-45776 Check if the JIRA information is as expected(y/n): y JIRA is unassigned, choose assignee [0] Yang Jie (Reporter) Enter number of user, or userid, to assign to (blank to leave unassigned):0 Enter comma-separated fix version(s) [4.0.0]: === JIRA SPARK-45776 === summary Remove the defensive null check added in SPARK-39553. assignee Yang Jie status RESOLVED url https://issues.apache.org/jira/browse/SPARK-45776 Successfully resolved SPARK-45776 with fixVersions=['4.0.0']! Enter a JIRA id[SPARK-39553]: ``` ### Was this patch authored or co-authored using generative AI tooling? no Closes #43648 from yaooqinn/SPARK-45779. Authored-by: Kent Yao Signed-off-by: Kent Yao --- dev/merge_spark_pr.py | 70 +++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 643bc37ced193..59cb027389798 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -248,37 +248,49 @@ def cherry_pick(pr_num, merge_hash, default_branch): return pick_ref -def resolve_jira_issue(merge_branches, comment, default_jira_id=""): - jira_id = input("Enter a JIRA id [%s]: " % default_jira_id) +def print_jira_issue_summary(issue): + summary = issue.fields.summary + assignee = issue.fields.assignee + if assignee is not None: + assignee = assignee.displayName + status = issue.fields.status.name + print("=== JIRA %s ===" % issue.key) + print( + "summary\t\t%s\nassignee\t%s\nstatus\t\t%s\nurl\t\t%s/%s\n" + % (summary, assignee, status, JIRA_BASE, issue.key) + ) + + +def get_jira_issue(prompt, default_jira_id=""): + jira_id = input("%s[%s]: " % (prompt, default_jira_id)) if jira_id == "": jira_id = default_jira_id if jira_id == "": print("JIRA ID not found, skipping.") - return - + return None try: issue = asf_jira.issue(jira_id) + print_jira_issue_summary(issue) + status = issue.fields.status.name + if status == "Resolved" or status == "Closed": + print("JIRA issue %s already has status '%s'" % (jira_id, status)) + return None + if input("Check if the JIRA information is as expected(y/n): ").lower() != "n": + return issue + else: + return get_jira_issue("Enter the revised JIRA ID again or leave blank to skip") except Exception as e: - fail("ASF JIRA could not find %s\n%s" % (jira_id, e)) - - cur_status = issue.fields.status.name - cur_summary = issue.fields.summary - cur_assignee = issue.fields.assignee - if cur_assignee is None: - cur_assignee = choose_jira_assignee(issue) - # Check again, we might not have chosen an assignee - if cur_assignee is None: - cur_assignee = "NOT ASSIGNED!!!" 
- else: - cur_assignee = cur_assignee.displayName + print("ASF JIRA could not find %s: %s" % (jira_id, e)) + return get_jira_issue("Enter the revised JIRA ID again or leave blank to skip") - if cur_status == "Resolved" or cur_status == "Closed": - fail("JIRA issue %s already has status '%s'" % (jira_id, cur_status)) - print("=== JIRA %s ===" % jira_id) - print( - "summary\t\t%s\nassignee\t%s\nstatus\t\t%s\nurl\t\t%s/%s\n" - % (cur_summary, cur_assignee, cur_status, JIRA_BASE, jira_id) - ) + +def resolve_jira_issue(merge_branches, comment, default_jira_id=""): + issue = get_jira_issue("Enter a JIRA id", default_jira_id) + if issue is None: + return + + if issue.fields.assignee is None: + choose_jira_assignee(issue) versions = asf_jira.project_versions("SPARK") # Consider only x.y.z, unreleased, unarchived versions @@ -351,17 +363,23 @@ def get_version_json(version_str): jira_fix_versions = list(map(lambda v: get_version_json(v), fix_versions)) - resolve = list(filter(lambda a: a["name"] == "Resolve Issue", asf_jira.transitions(jira_id)))[0] + resolve = list(filter(lambda a: a["name"] == "Resolve Issue", asf_jira.transitions(issue.key)))[ + 0 + ] resolution = list(filter(lambda r: r.raw["name"] == "Fixed", asf_jira.resolutions()))[0] asf_jira.transition_issue( - jira_id, + issue.key, resolve["id"], fixVersions=jira_fix_versions, comment=comment, resolution={"id": resolution.raw["id"]}, ) - print("Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions)) + try: + print_jira_issue_summary(asf_jira.issue(issue.key)) + except Exception: + print("Unable to fetch JIRA issue %s after resolving" % issue.key) + print("Successfully resolved %s with fixVersions=%s!" % (issue.key, fix_versions)) def choose_jira_assignee(issue): From afdce266f0ffeb068d47eca2f2af1bcba66b0e95 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Sat, 4 Nov 2023 01:16:32 +0800 Subject: [PATCH 019/121] [SPARK-44843][TESTS] Double streamingTimeout for StateStoreMetricsTest to make RocksDBStateStore related streaming tests reliable ### What changes were proposed in this pull request? This PR increases streamingTimeout and the check interval for StateStoreMetricsTest to make RocksDBStateStore-related streaming tests reliable, hopefully. ### Why are the changes needed? ``` SPARK-35896: metrics in StateOperatorProgress are output correctly (RocksDBStateStore with changelog checkpointing) *** FAILED *** (1 minute) [info] Timed out waiting for stream: The code passed to failAfter did not complete within 60 seconds. [info] java.base/java.lang.Thread.getStackTrace(Thread.java:1619) ``` The probability of these tests failing is close to 100%, which seriously affects the UX of making PRs for the contributors. https://github.com/yaooqinn/spark/actions/runs/6744173341/job/18333952141 ### Does this PR introduce _any_ user-facing change? no, test only ### How was this patch tested? this can be verified by `sql - slow test` job in CI ### Was this patch authored or co-authored using generative AI tooling? no Closes #43647 from yaooqinn/SPARK-44843. 
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../apache/spark/sql/streaming/StateStoreMetricsTest.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 57ced748cd9f0..07837f5c06473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.execution.streaming.StreamExecution trait StateStoreMetricsTest extends StreamTest { @@ -24,6 +26,8 @@ trait StateStoreMetricsTest extends StreamTest { private var lastCheckedRecentProgressIndex = -1 private var lastQuery: StreamExecution = null + override val streamingTimeout = 120.seconds + override def beforeEach(): Unit = { super.beforeEach() lastCheckedRecentProgressIndex = -1 @@ -106,7 +110,7 @@ trait StateStoreMetricsTest extends StreamTest { AssertOnQuery(s"Check operator progress metrics: operatorName = $operatorName, " + s"numShufflePartitions = $numShufflePartitions, " + s"numStateStoreInstances = $numStateStoreInstances") { q => - eventually(timeout(streamingTimeout)) { + eventually(timeout(streamingTimeout), interval(200.milliseconds)) { val (progressesSinceLastCheck, lastCheckedProgressIndex, numStateOperators) = retrieveProgressesSinceLastCheck(q) assert(operatorIndex < numStateOperators, s"Invalid operator Index: $operatorIndex") From 31a81983fb2d19e05fadccdf49c37dd4f5c50465 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 3 Nov 2023 17:53:55 -0700 Subject: [PATCH 020/121] [SPARK-45780][CONNECT] Propagate all Spark Connect client threadlocals in InheritableThread ### What changes were proposed in this pull request? Currently pyspark InheritableThread propagates Spark Connect session.client.thread_local.tags to child threads. Generalize this to propagate all thread locals, and also make a deep copy, just like the scala equivalent does a clone. ### Why are the changes needed? Generalize the mechanism of SparkConnectClient.thread_local ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing test for propagating SparkSession tags should cover this. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43649 from juliuszsompolski/SPARK-45780. Authored-by: Juliusz Sompolski Signed-off-by: Hyukjin Kwon --- python/pyspark/util.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 9c70bac2a3d95..4a828d6bfc947 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +import copy import functools import itertools import os @@ -343,14 +344,19 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non assert session is not None, "Spark Connect session must be provided." 
def outer(ff: Callable) -> Callable: - if not hasattr(session.client.thread_local, "tags"): # type: ignore[union-attr] - session.client.thread_local.tags = set() # type: ignore[union-attr] - tags = set(session.client.thread_local.tags) # type: ignore[union-attr] + session_client_thread_local_attrs = [ + (attr, copy.deepcopy(value)) + for ( + attr, + value, + ) in session.client.thread_local.__dict__.items() # type: ignore[union-attr] + ] @functools.wraps(ff) def inner(*args: Any, **kwargs: Any) -> Any: - # Set tags in child thread. - session.client.thread_local.tags = tags # type: ignore[union-attr] + # Set thread locals in child thread. + for attr, value in session_client_thread_local_attrs: + setattr(session.client.thread_local, attr, value) # type: ignore[union-attr] return ff(*args, **kwargs) return inner From 4f56e3852b9275a0097384305e3966eda49c045d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 3 Nov 2023 22:21:36 -0700 Subject: [PATCH 021/121] [SPARK-45785][CORE] Support `spark.deploy.appNumberModulo` to rotate app number MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to support to rotate app number by introducing a new configuration, `spark.deploy.appNumberModulo`. ### Why are the changes needed? Historically, Apache Spark's App ID has a style, `app-yyyyMMddHHmmss-1234`. Since the 3rd part, `1234`, is a simple sequentially incremented number without any rotation, the generated IDs are like the following. ``` app-yyyyMMddHHmmss-0000 app-yyyyMMddHHmmss-0001 ... app-yyyyMMddHHmmss-9999 app-yyyyMMddHHmmss-10000 ``` If we support rotation by modulo 10000, it will keep 4 digits. ``` app-yyyyMMddHHmmss-0000 app-yyyyMMddHHmmss-0001 ... app-yyyyMMddHHmmss-9999 app-yyyyMMddHHmmss-0000 ``` Please note that the second part changes every seconds. In general, modulo by 10000 is enough to generate unique AppIDs. The following is an example to use modulo 1000. You can tune further by using `spark.deploy.appIdPattern` configuration. ``` $ SPARK_MASTER_OPTS="-Dspark.deploy.appNumberModulo=1000 -Dspark.master.rest.enabled=true" sbin/start-master.sh ``` Screenshot 2023-11-03 at 5 56 17 PM ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43654 from dongjoon-hyun/SPARK-45785. 
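A self-contained sketch of the rotation logic described above; it mirrors the `Master` change in the diff that follows, but is not the actual Master code:

```scala
// Sketch: app numbers wrap around once a modulo is configured.
object AppIdRotationSketch {
  def main(args: Array[String]): Unit = {
    var nextAppNumber = 0
    val moduloAppNumber = 1000  // e.g. spark.deploy.appNumberModulo=1000
    def newApplicationId(dateStr: String): String = {
      val appId = "app-%s-%04d".format(dateStr, nextAppNumber)
      nextAppNumber += 1
      if (moduloAppNumber > 0) nextAppNumber %= moduloAppNumber
      appId
    }
    val ids = (0 until 1001).map(_ => newApplicationId("20231104000000"))
    println(ids.head)  // app-20231104000000-0000
    println(ids.last)  // app-20231104000000-0000 again: the counter wrapped at 1000
  }
}
```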
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/deploy/master/Master.scala | 4 ++++ .../org/apache/spark/internal/config/Deploy.scala | 10 ++++++++++ .../org/apache/spark/deploy/master/MasterSuite.scala | 9 +++++++++ 3 files changed, 23 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index d5de1366ac053..63d981c5fde82 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -79,6 +79,7 @@ private[deploy] class Master( private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 + private val moduloAppNumber = conf.get(APP_NUMBER_MODULO).getOrElse(0) private val drivers = new HashSet[DriverInfo] private val completedDrivers = new ArrayBuffer[DriverInfo] @@ -1156,6 +1157,9 @@ private[deploy] class Master( private def newApplicationId(submitDate: Date): String = { val appId = appIdPattern.format(createDateFormat.format(submitDate), nextAppNumber) nextAppNumber += 1 + if (moduloAppNumber > 0) { + nextAppNumber %= moduloAppNumber + } appId } diff --git a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala index c6ccf9550bc91..906ec0fc99737 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala @@ -82,6 +82,16 @@ private[spark] object Deploy { .checkValue(_ > 0, "The maximum number of running drivers should be positive.") .createWithDefault(Int.MaxValue) + val APP_NUMBER_MODULO = ConfigBuilder("spark.deploy.appNumberModulo") + .doc("The modulo for app number. By default, the next of `app-yyyyMMddHHmmss-9999` is " + + "`app-yyyyMMddHHmmss-10000`. If we have 10000 as modulo, it will be " + + "`app-yyyyMMddHHmmss-0000`. In most cases, the prefix `app-yyyyMMddHHmmss` is increased " + + "already during creating 10000 applications.") + .version("4.0.0") + .intConf + .checkValue(_ >= 1000, "The modulo for app number should be greater than or equal to 1000.") + .createOptional + val DRIVER_ID_PATTERN = ConfigBuilder("spark.deploy.driverIdPattern") .doc("The pattern for driver ID generation based on Java `String.format` method. 
" + "The default value is `driver-%s-%04d` which represents the existing driver id string " + diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index e8615cdbdd559..4f8457f930e4a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -1266,6 +1266,15 @@ class MasterSuite extends SparkFunSuite }.getMessage assert(m.contains("Whitespace is not allowed")) } + + test("SPARK-45785: Rotate app num with modulo operation") { + val conf = new SparkConf().set(APP_ID_PATTERN, "%2$d").set(APP_NUMBER_MODULO, 1000) + val master = makeMaster(conf) + val submitDate = new Date() + (0 to 2000).foreach { i => + assert(master.invokePrivate(_newApplicationId(submitDate)) === s"${i % 1000}") + } + } } private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) From 3363c2af3f6a59363135451d251f25e328a4fddf Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 4 Nov 2023 00:23:33 -0700 Subject: [PATCH 022/121] [MINOR][CORE] Validate spark.deploy.defaultCores value during config setting ### What changes were proposed in this pull request? This aims to move `spark.deploy.defaultCores` validation logic to config setting. ### Why are the changes needed? In order to ensure the value range by early checking. ### Does this PR introduce _any_ user-facing change? No. `Spark Master` will fail to start in both cases, *before* and *after*. ### How was this patch tested? Manual review. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43655 from dongjoon-hyun/spark.deploy.defaultCores. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/deploy/master/Master.scala | 5 +---- .../main/scala/org/apache/spark/internal/config/Deploy.scala | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 63d981c5fde82..e63d72ebb40d2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -24,7 +24,7 @@ import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random -import org.apache.spark.{SecurityManager, SparkConf, SparkException} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState.DriverState @@ -119,9 +119,6 @@ private[deploy] class Master( // Default maxCores for applications that don't specify it (i.e. 
pass Int.MaxValue) private val defaultCores = conf.get(DEFAULT_CORES) val reverseProxy = conf.get(UI_REVERSE_PROXY) - if (defaultCores < 1) { - throw new SparkException(s"${DEFAULT_CORES.key} must be positive") - } val historyServerUrl = conf.get(MASTER_UI_HISTORY_SERVER_URL) // Alternative application submission gateway that is stable across Spark versions diff --git a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala index 906ec0fc99737..7b35e92022ae0 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala @@ -73,6 +73,7 @@ private[spark] object Deploy { val DEFAULT_CORES = ConfigBuilder("spark.deploy.defaultCores") .version("0.9.0") .intConf + .checkValue(_ > 0, "spark.deploy.defaultCores must be positive.") .createWithDefault(Int.MaxValue) val MAX_DRIVERS = ConfigBuilder("spark.deploy.maxDrivers") From 749c79e6db1ca35eb47ba66aa4bc31c285260eae Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sat, 4 Nov 2023 01:01:21 -0700 Subject: [PATCH 023/121] [SPARK-45781][BUILD] Upgrade Arrow to 14.0.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This pr upgrade Apache Arrow from 13.0.0 to 14.0.0. ### Why are the changes needed? The Apache Arrow 14.0.0 release brings a number of enhancements and bug fixes. ‎ In terms of bug fixes, the release addresses several critical issues that were causing failures in integration jobs with Spark([GH-36332](https://github.com/apache/arrow/issues/36332)) and problems with importing empty data arrays([GH-37056](https://github.com/apache/arrow/issues/37056)). It also optimizes the process of appending variable length vectors([GH-37829](https://github.com/apache/arrow/issues/37829)) and includes C++ libraries for MacOS AARCH 64 in Java-Jars([GH-38076](https://github.com/apache/arrow/issues/38076)). ‎ The new features and improvements focus on enhancing the handling and manipulation of data. This includes the introduction of DefaultVectorComparators for large types([GH-25659](https://github.com/apache/arrow/issues/25659)), support for extended expressions in ScannerBuilder([GH-34252](https://github.com/apache/arrow/issues/34252)), and the exposure of the VectorAppender class([GH-37246](https://github.com/apache/arrow/issues/37246)). ‎ The release also brings enhancements to the development and testing process, with the CI environment now using JDK 21([GH-36994](https://github.com/apache/arrow/issues/36994)). In addition, the release introduces vector validation consistent with C++, ensuring consistency across different languages([GH-37702](https://github.com/apache/arrow/issues/37702)). ‎ Furthermore, the usability of VarChar writers and binary writers has been improved with the addition of extra input methods([GH-37705](https://github.com/apache/arrow/issues/37705)), and VarCharWriter now supports writing from `Text` and `String`([GH-37706](https://github.com/apache/arrow/issues/37706)). The release also adds typed getters for StructVector, improving the ease of accessing data([GH-37863](https://github.com/apache/arrow/issues/37863)). The full release notes as follows: - https://arrow.apache.org/release/14.0.0.html ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? 
No Closes #43650 from LuciferYang/arrow-14. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 8 ++++---- pom.xml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 6364ec48fb664..b7d6bdbfd1299 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -16,10 +16,10 @@ antlr4-runtime/4.13.1//antlr4-runtime-4.13.1.jar aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar arpack/3.0.3//arpack-3.0.3.jar arpack_combined_all/0.1//arpack_combined_all-0.1.jar -arrow-format/13.0.0//arrow-format-13.0.0.jar -arrow-memory-core/13.0.0//arrow-memory-core-13.0.0.jar -arrow-memory-netty/13.0.0//arrow-memory-netty-13.0.0.jar -arrow-vector/13.0.0//arrow-vector-13.0.0.jar +arrow-format/14.0.0//arrow-format-14.0.0.jar +arrow-memory-core/14.0.0//arrow-memory-core-14.0.0.jar +arrow-memory-netty/14.0.0//arrow-memory-netty-14.0.0.jar +arrow-vector/14.0.0//arrow-vector-14.0.0.jar audience-annotations/0.5.0//audience-annotations-0.5.0.jar avro-ipc/1.11.3//avro-ipc-1.11.3.jar avro-mapred/1.11.3//avro-mapred-1.11.3.jar diff --git a/pom.xml b/pom.xml index 2e0c95516c177..cae315f4d7182 100644 --- a/pom.xml +++ b/pom.xml @@ -228,7 +228,7 @@ If you are changing Arrow version specification, please check ./python/pyspark/sql/pandas/utils.py, and ./python/setup.py too. --> - 13.0.0 + 14.0.0 2.5.11 From 6d669fa957463851af463d0ba03d6e6ee76e2cda Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 4 Nov 2023 09:23:58 -0700 Subject: [PATCH 024/121] [SPARK-45791][CONNECT][TESTS] Rename `SparkConnectSessionHodlerSuite.scala` to `SparkConnectSessionHolderSuite.scala` ### What changes were proposed in this pull request? This PR aims to fix a typo `Hodler` in file name. - `SparkConnectSessionHodlerSuite.scala` (from) - `SparkConnectSessionHolderSuite.scala` (to) It's also unmatched with the class name in the file because class name itself is correct. https://github.com/apache/spark/blob/3363c2af3f6a59363135451d251f25e328a4fddf/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala#L37 ### Why are the changes needed? This is a typo from the original PR. - https://github.com/apache/spark/pull/41580 Since the original PR is shipped as Apache Spark 3.5.0, I created a JIRA instead of a follow-up. We need to backport this patch to `branch-3.5`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43657 from dongjoon-hyun/SPARK-45791. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- ...sionHodlerSuite.scala => SparkConnectSessionHolderSuite.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/{SparkConnectSessionHodlerSuite.scala => SparkConnectSessionHolderSuite.scala} (100%) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala similarity index 100% rename from connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala rename to connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala From 89cbe85b0b26c7994d2fb78733769d93cd1ec992 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 4 Nov 2023 09:25:49 -0700 Subject: [PATCH 025/121] [SPARK-45779][INFRA][FOLLOWUP] Add more spaces to the prompts ### What changes were proposed in this pull request? This PR aims to add more spaces to the prompts to be consistent with the existing behaviors. ### Why are the changes needed? Previously, we have these spaces. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual review. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43658 from dongjoon-hyun/SPARK-45779-2. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/merge_spark_pr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 59cb027389798..d837e5c298252 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -262,7 +262,7 @@ def print_jira_issue_summary(issue): def get_jira_issue(prompt, default_jira_id=""): - jira_id = input("%s[%s]: " % (prompt, default_jira_id)) + jira_id = input("%s [%s]: " % (prompt, default_jira_id)) if jira_id == "": jira_id = default_jira_id if jira_id == "": @@ -275,7 +275,7 @@ def get_jira_issue(prompt, default_jira_id=""): if status == "Resolved" or status == "Closed": print("JIRA issue %s already has status '%s'" % (jira_id, status)) return None - if input("Check if the JIRA information is as expected(y/n): ").lower() != "n": + if input("Check if the JIRA information is as expected (y/n): ").lower() != "n": return issue else: return get_jira_issue("Enter the revised JIRA ID again or leave blank to skip") From 8e1cf59ab68995863945aa69f53a04647b328952 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 4 Nov 2023 09:28:17 -0700 Subject: [PATCH 026/121] [SPARK-45790][INFRA] Move `graphx` to `mllib*` test pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to move `graphx` module test from `core` pipeline to `mllib` pipeline. ### Why are the changes needed? To off-load `core` test pipeline and avoid the situation where `graphx` patch triggers `core` pipeline. - `core` test pipeline takes 1 and half hour in general. - `mllib*` test pipeline seems to be stable in these day and finished in a hour. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs and manually check. 
- https://github.com/dongjoon-hyun/spark/actions/runs/6753488512/job/18360020367 Screenshot 2023-11-04 at 12 25 00 AM ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43656 from dongjoon-hyun/SPARK-45790. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .github/workflows/build_and_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index eded5da5c1ddd..1e28f5530513a 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -143,11 +143,11 @@ jobs: - >- core, unsafe, kvstore, avro, utils, network-common, network-shuffle, repl, launcher, - examples, sketch, graphx + examples, sketch - >- api, catalyst, hive-thriftserver - >- - mllib-local,mllib + mllib-local, mllib, graphx - >- streaming, sql-kafka-0-10, streaming-kafka-0-10, yarn, kubernetes, hadoop-cloud, spark-ganglia-lgpl, From 8553de36d8d22b3819b1f7da45cfc992d6216fc9 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sun, 5 Nov 2023 08:47:20 -0800 Subject: [PATCH 027/121] [SPARK-45799][SQL] Remove some unused method in QueryExecutionErrors & QueryCompilationErrors ### What changes were proposed in this pull request? The pr aims to clear some unused method in `QueryExecutionErrors` &`QueryCompilationErrors` and related error classes in `error-classes.json`. ### Why are the changes needed? Make code clear. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43665 from panbingkun/SPARK-45799. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .../main/resources/error/error-classes.json | 20 ------------------- .../sql/errors/QueryCompilationErrors.scala | 6 ------ .../sql/errors/QueryExecutionErrors.scala | 19 ------------------ 3 files changed, 45 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index af32bcf129c08..8b0951a7b00b4 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -4777,11 +4777,6 @@ "Hive metastore does not support altering database location." ] }, - "_LEGACY_ERROR_TEMP_1221" : { - "message" : [ - "Hive 0.12 doesn't support creating permanent functions. Please use Hive 0.13 or higher." - ] - }, "_LEGACY_ERROR_TEMP_1222" : { "message" : [ "Unknown resource type: ." @@ -5901,16 +5896,6 @@ ", db: , table: ." ] }, - "_LEGACY_ERROR_TEMP_2190" : { - "message" : [ - "DROP TABLE ... PURGE." - ] - }, - "_LEGACY_ERROR_TEMP_2191" : { - "message" : [ - "ALTER TABLE ... DROP PARTITION ... PURGE." - ] - }, "_LEGACY_ERROR_TEMP_2192" : { "message" : [ "Partition filter cannot have both `\"` and `'` characters." @@ -5931,11 +5916,6 @@ " when creating Hive client using classpath: Please make sure that jars for your version of hive and hadoop are included in the paths passed to ." ] }, - "_LEGACY_ERROR_TEMP_2197" : { - "message" : [ - "LOCATION clause illegal for view partition." - ] - }, "_LEGACY_ERROR_TEMP_2198" : { "message" : [ "Failed to rename as already exists." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 4b28eadfec6fb..1925eddd2ce23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2447,12 +2447,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "tableType" -> tableType)) } - def hiveCreatePermanentFunctionsUnsupportedError(): Throwable = { - new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1221", - messageParameters = Map.empty) - } - def unknownHiveResourceTypeError(resourceType: String): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1222", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 30dfe8eebe6cf..3a60480cfc52a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1643,18 +1643,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map.empty) } - def dropTableWithPurgeUnsupportedError(): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2190", - messageParameters = Map.empty) - } - - def alterTableWithDropPartitionAndPurgeUnsupportedError(): SparkUnsupportedOperationException = { - new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2191", - messageParameters = Map.empty) - } - def invalidPartitionFilterError(): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( errorClass = "_LEGACY_ERROR_TEMP_2192", @@ -1701,13 +1689,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = e) } - def illegalLocationClauseForViewPartitionError(): Throwable = { - new SparkException( - errorClass = "_LEGACY_ERROR_TEMP_2197", - messageParameters = Map.empty, - cause = null) - } - def renamePathAsExistsPathError(srcPath: Path, dstPath: Path): Throwable = { new SparkFileAlreadyExistsException( errorClass = "FAILED_RENAME_PATH", From 9cbc2d138a1de2d3ff399d224d817554ba0b9e18 Mon Sep 17 00:00:00 2001 From: dengziming Date: Sun, 5 Nov 2023 20:07:07 +0300 Subject: [PATCH 028/121] [SPARK-45710][SQL] Assign names to error _LEGACY_ERROR_TEMP_21[59,60,61,62] ### What changes were proposed in this pull request? This PR are removing `_LEGACY_ERROR_TEMP_21[59,60,61,62]` and `TOO_MANY_ARRAY_ELEMENTS`: 1. `_LEGACY_ERROR_TEMP_2159` is used in concat/array_insert; 2. `_LEGACY_ERROR_TEMP_2160` is only used in flatten; 3. `_LEGACY_ERROR_TEMP_2161` is used in array_repeat/array_insert/array_distinct/array_union/array_intersect/array_remove; 4. `_LEGACY_ERROR_TEMP_2162` is used in array_union/array_distinct; 5. There is another similar error class `TOO_MANY_ARRAY_ELEMENTS` which are used in `UnsafeArrayWriter.java`. I removed these 5 similar error classes and create a new error class `COLLECTION_SIZE_LIMIT_EXCEEDED` with 3 sub-classes: 1. `PARAMETER` is used when the parameter exceed size limit, such as `array_repeat` with count too large; 6. 
`FUNCTION` is used when trying to create an array exceeding size limit in a function, for example, flatten 2 arrays to a larger array; 7. `INITIALIZE` is used in `UnsafeArrayWriter.java` when trying to initialize an array exceeding size limit. ### Why are the changes needed? To assign proper name as a part of activity in SPARK-37935. ### Does this PR introduce _any_ user-facing change? Yes, the error message will include the error class name. ### How was this patch tested? 1. `COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER` can be tested from use code; 2. `COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION` is tested using a `ColumnarArray` in `concat/flatten`, but can't be tested in `array_insert/array_distinct/array_union/array_intersect/array_remove` since we need to deduplicate the data and create an array which will cause OOM. 3. `INITIALIZE` is already tested in a existing case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43567 from dengziming/SPARK-45710. Authored-by: dengziming Signed-off-by: Max Gekk --- .../main/resources/error/error-classes.json | 49 +++++----- docs/_data/menu-sql.yaml | 2 + ...lection-size-limit-exceeded-error-class.md | 38 ++++++++ docs/sql-error-conditions.md | 14 +-- .../codegen/UnsafeArrayWriter.java | 3 +- .../expressions/collectionOperations.scala | 95 +++++++------------ .../sql/errors/QueryExecutionErrors.scala | 48 ++++------ .../codegen/UnsafeArrayWriterSuite.scala | 6 +- .../errors/QueryExecutionErrorsSuite.scala | 59 +++++++++++- 9 files changed, 184 insertions(+), 130 deletions(-) create mode 100644 docs/sql-error-conditions-collection-size-limit-exceeded-error-class.md diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 8b0951a7b00b4..3e0743d366ae3 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -405,6 +405,29 @@ ], "sqlState" : "42704" }, + "COLLECTION_SIZE_LIMIT_EXCEEDED" : { + "message" : [ + "Can't create array with elements which exceeding the array size limit ," + ], + "subClass" : { + "FUNCTION" : { + "message" : [ + "unsuccessful try to create arrays in the function ." + ] + }, + "INITIALIZE" : { + "message" : [ + "cannot initialize an array with specified parameters." + ] + }, + "PARAMETER" : { + "message" : [ + "the value of parameter(s) in the function is invalid." + ] + } + }, + "sqlState" : "54000" + }, "COLUMN_ALIASES_IS_NOT_ALLOWED" : { "message" : [ "Columns aliases are not allowed in ." @@ -3017,12 +3040,6 @@ ], "sqlState" : "428EK" }, - "TOO_MANY_ARRAY_ELEMENTS" : { - "message" : [ - "Cannot initialize array with elements of size ." - ], - "sqlState" : "54000" - }, "UDTF_ALIAS_NUMBER_MISMATCH" : { "message" : [ "The number of aliases supplied in the AS clause does not match the number of columns output by the UDTF.", @@ -5765,26 +5782,6 @@ " is not annotated with SQLUserDefinedType nor registered with UDTRegistration.}" ] }, - "_LEGACY_ERROR_TEMP_2159" : { - "message" : [ - "Unsuccessful try to concat arrays with elements due to exceeding the array size limit ." - ] - }, - "_LEGACY_ERROR_TEMP_2160" : { - "message" : [ - "Unsuccessful try to flatten an array of arrays with elements due to exceeding the array size limit ." - ] - }, - "_LEGACY_ERROR_TEMP_2161" : { - "message" : [ - "Unsuccessful try to create array with elements due to exceeding the array size limit ." 
- ] - }, - "_LEGACY_ERROR_TEMP_2162" : { - "message" : [ - "Unsuccessful try to union arrays with elements due to exceeding the array size limit ." - ] - }, "_LEGACY_ERROR_TEMP_2163" : { "message" : [ "Initial type must be a ." diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index 4125860642294..95833849cc595 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -109,6 +109,8 @@ subitems: - text: SQLSTATE Codes url: sql-error-conditions-sqlstates.html + - text: COLLECTION_SIZE_LIMIT_EXCEEDED error class + url: sql-error-conditions-collection-size-limit-exceeded-error-class.html - text: CONNECT error class url: sql-error-conditions-connect-error-class.html - text: DATATYPE_MISMATCH error class diff --git a/docs/sql-error-conditions-collection-size-limit-exceeded-error-class.md b/docs/sql-error-conditions-collection-size-limit-exceeded-error-class.md new file mode 100644 index 0000000000000..78b9e43826157 --- /dev/null +++ b/docs/sql-error-conditions-collection-size-limit-exceeded-error-class.md @@ -0,0 +1,38 @@ +--- +layout: global +title: COLLECTION_SIZE_LIMIT_EXCEEDED error class +displayTitle: COLLECTION_SIZE_LIMIT_EXCEEDED error class +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +[SQLSTATE: 54000](sql-error-conditions-sqlstates.html#class-54-program-limit-exceeded) + +Can't create array with `` elements which exceeding the array size limit ``, + +This error class has the following derived error classes: + +## FUNCTION + +unsuccessful try to create arrays in the function ``. + +## INITIALIZE + +cannot initialize an array with specified parameters. + +## PARAMETER + +the value of parameter(s) `` in the function `` is invalid. diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 1741e49f56152..a6f003647ddc8 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -338,6 +338,14 @@ The codec `` is not available. Consider to set the config ``. +### [COLLECTION_SIZE_LIMIT_EXCEEDED](sql-error-conditions-collection-size-limit-exceeded-error-class.html) + +[SQLSTATE: 54000](sql-error-conditions-sqlstates.html#class-54-program-limit-exceeded) + +Can't create array with `` elements which exceeding the array size limit ``, + +For more details see [COLLECTION_SIZE_LIMIT_EXCEEDED](sql-error-conditions-collection-size-limit-exceeded-error-class.html) + ### COLUMN_ALIASES_IS_NOT_ALLOWED [SQLSTATE: 42601](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -1921,12 +1929,6 @@ Choose a different name, drop or replace the existing view, or add the IF NOT E CREATE TEMPORARY VIEW or the corresponding Dataset APIs only accept single-part view names, but got: ``. 
-### TOO_MANY_ARRAY_ELEMENTS - -[SQLSTATE: 54000](sql-error-conditions-sqlstates.html#class-54-program-limit-exceeded) - -Cannot initialize array with `` elements of size ``. - ### UDTF_ALIAS_NUMBER_MISMATCH [SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 65d984bcd19fe..3070fa3e74b1f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -62,7 +62,8 @@ public void initialize(int numElements) { long totalInitialSize = headerInBytes + fixedPartInBytesLong; if (totalInitialSize > Integer.MAX_VALUE) { - throw QueryExecutionErrors.tooManyArrayElementsError(numElements, elementSize); + throw QueryExecutionErrors.tooManyArrayElementsError( + fixedPartInBytesLong, Integer.MAX_VALUE); } // it's now safe to cast fixedPartInBytesLong and totalInitialSize to int diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 25da787b8874f..196c4a6cdd69f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2658,7 +2658,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val arrayData = inputs.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(numberOfElements) + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, numberOfElements) } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 @@ -2839,7 +2840,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran val arrayData = elements.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.flattenArraysWithElementsExceedLimitError(numberOfElements) + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, numberOfElements) } val flattenedData = new Array(numberOfElements.toInt) var position = 0 @@ -3552,7 +3554,8 @@ case class ArrayRepeat(left: Expression, right: Expression) null } else { if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(count) + throw QueryExecutionErrors.createArrayWithElementsExceedLimitError( + prettyName, count) } val element = left.eval(input) new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) @@ -3842,10 +3845,12 @@ trait ArraySetLike { builder: String, value : String, size : String, - nullElementIndex : String): String = withResultArrayNullCheck( + nullElementIndex : String, + functionName: String): String = withResultArrayNullCheck( s""" |if ($size > 
${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($size); + | throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + | "$functionName", $size); |} | |if (!UnsafeArrayData.shouldUseGenericArrayData(${et.defaultSize}, $size)) { @@ -3903,7 +3908,8 @@ case class ArrayDistinct(child: Expression) (value: Any) => if (!hs.contains(value)) { if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, arrayBuffer.size) } arrayBuffer += value hs.add(value) @@ -4013,7 +4019,7 @@ case class ArrayDistinct(child: Expression) |for (int $i = 0; $i < $array.numElements(); $i++) { | $processArray |} - |${buildResultArray(builder, ev.value, size, nullElementIndex)} + |${buildResultArray(builder, ev.value, size, nullElementIndex, prettyName)} """.stripMargin }) } else { @@ -4048,13 +4054,6 @@ trait ArrayBinaryLike } } -object ArrayBinaryLike { - def throwUnionLengthOverflowException(length: Int): Unit = { - throw QueryExecutionErrors.unionArrayWithElementsExceedLimitError(length) - } -} - - /** * Returns an array of the elements in the union of x and y, without duplicates */ @@ -4082,7 +4081,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi (value: Any) => if (!hs.contains(value)) { if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size) + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, arrayBuffer.size) } arrayBuffer += value hs.add(value) @@ -4125,7 +4125,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi } if (!found) { if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length) + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, arrayBuffer.length) } arrayBuffer += elem } @@ -4213,7 +4214,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi | $processArray | } |} - |${buildResultArray(builder, ev.value, size, nullElementIndex)} + |${buildResultArray(builder, ev.value, size, nullElementIndex, prettyName)} """.stripMargin }) } else { @@ -4230,44 +4231,6 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi newLeft: Expression, newRight: Expression): ArrayUnion = copy(left = newLeft, right = newRight) } -object ArrayUnion { - def unionOrdering( - array1: ArrayData, - array2: ArrayData, - elementType: DataType, - ordering: Ordering[Any]): ArrayData = { - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var alreadyIncludeNull = false - Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { - var found = false - if (elem == null) { - if (alreadyIncludeNull) { - found = true - } else { - alreadyIncludeNull = true - } - } else { - // check elem is already stored in arrayBuffer or not? 
- var j = 0 - while (!found && j < arrayBuffer.size) { - val va = arrayBuffer(j) - if (va != null && ordering.equiv(va, elem)) { - found = true - } - j = j + 1 - } - } - if (!found) { - if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.length) - } - arrayBuffer += elem - } - })) - new GenericArrayData(arrayBuffer) - } -} - /** * Returns an array of the elements in the intersect of x and y, without duplicates */ @@ -4482,7 +4445,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina |for (int $i = 0; $i < $array1.numElements(); $i++) { | $processArray1 |} - |${buildResultArray(builder, ev.value, size, nullElementIndex)} + |${buildResultArray(builder, ev.value, size, nullElementIndex, prettyName)} """.stripMargin }) } else { @@ -4693,7 +4656,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL |for (int $i = 0; $i < $array1.numElements(); $i++) { | $processArray1 |} - |${buildResultArray(builder, ev.value, size, nullElementIndex)} + |${buildResultArray(builder, ev.value, size, nullElementIndex, prettyName)} """.stripMargin }) } else { @@ -4808,7 +4771,8 @@ case class ArrayInsert( val newArrayLength = math.max(baseArr.numElements() + 1, positivePos.get) if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength) + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, newArrayLength) } val newArray = new Array[Any](newArrayLength) @@ -4842,7 +4806,8 @@ case class ArrayInsert( val newArrayLength = -posInt + baseOffset if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength) + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, newArrayLength) } val newArray = new Array[Any](newArrayLength) @@ -4866,7 +4831,8 @@ case class ArrayInsert( val newArrayLength = math.max(baseArr.numElements() + 1, posInt + 1) if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength) + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, newArrayLength) } val newArray = new Array[Any](newArrayLength) @@ -4912,7 +4878,8 @@ case class ArrayInsert( | |final int $resLength = java.lang.Math.max($arr.numElements() + 1, ${positivePos.get}); |if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength); + | throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + | "$prettyName", $resLength); |} | |$allocation @@ -4949,7 +4916,8 @@ case class ArrayInsert( | | $resLength = java.lang.Math.abs($pos) + $baseOffset; | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength); + | throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + | "$prettyName", $resLength); | } | | $allocation @@ -4976,7 +4944,8 @@ case class ArrayInsert( | | $resLength = java.lang.Math.max($arr.numElements() + 1, $itemInsertionIndex + 1); | if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength); + | throw 
QueryExecutionErrors.createArrayWithElementsExceedLimitError( + | "$prettyName", $resLength); | } | | $allocation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 3a60480cfc52a..0d315f526d9e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1400,36 +1400,27 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE summary = getSummary(context)) } - def concatArraysWithElementsExceedLimitError(numberOfElements: Long): SparkRuntimeException = { + def arrayFunctionWithElementsExceedLimitError( + prettyName: String, numberOfElements: Long): SparkRuntimeException = { new SparkRuntimeException( - errorClass = "_LEGACY_ERROR_TEMP_2159", + errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION", messageParameters = Map( "numberOfElements" -> numberOfElements.toString(), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) - } - - def flattenArraysWithElementsExceedLimitError(numberOfElements: Long): SparkRuntimeException = { - new SparkRuntimeException( - errorClass = "_LEGACY_ERROR_TEMP_2160", - messageParameters = Map( - "numberOfElements" -> numberOfElements.toString(), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) - } - - def createArrayWithElementsExceedLimitError(count: Any): SparkRuntimeException = { - new SparkRuntimeException( - errorClass = "_LEGACY_ERROR_TEMP_2161", - messageParameters = Map( - "count" -> count.toString(), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), + "functionName" -> toSQLId(prettyName) + )) } - def unionArrayWithElementsExceedLimitError(length: Int): SparkRuntimeException = { + def createArrayWithElementsExceedLimitError( + prettyName: String, count: Any): SparkRuntimeException = { new SparkRuntimeException( - errorClass = "_LEGACY_ERROR_TEMP_2162", + errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", messageParameters = Map( - "length" -> length.toString(), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) + "numberOfElements" -> count.toString, + "functionName" -> toSQLId(prettyName), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), + "parameter" -> toSQLId("count") + )) } def initialTypeNotTargetDataTypeError( @@ -2567,13 +2558,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def tooManyArrayElementsError( - numElements: Int, - elementSize: Int): SparkIllegalArgumentException = { + numElements: Long, + maxRoundedArrayLength: Int): SparkIllegalArgumentException = { new SparkIllegalArgumentException( - errorClass = "TOO_MANY_ARRAY_ELEMENTS", + errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", messageParameters = Map( - "numElements" -> numElements.toString, - "size" -> elementSize.toString)) + "numberOfElements" -> numElements.toString, + "maxRoundedArrayLength" -> maxRoundedArrayLength.toString) + ) } def invalidEmptyLocationError(location: String): SparkIllegalArgumentException = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriterSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriterSuite.scala index f10fb0754f574..a968b1fe53506 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriterSuite.scala @@ -30,10 +30,10 @@ class UnsafeArrayWriterSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { arrayWriter.initialize(numElements) }, - errorClass = "TOO_MANY_ARRAY_ELEMENTS", + errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", parameters = Map( - "numElements" -> numElements.toString, - "size" -> elementSize.toString + "numberOfElements" -> (numElements * elementSize).toString, + "maxRoundedArrayLength" -> Int.MaxValue.toString ) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index dd3f3dc60048a..a49352cbe5080 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -33,29 +33,32 @@ import org.apache.spark._ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator} -import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber} +import org.apache.spark.sql.catalyst.expressions.{Concat, CreateArray, EmptyRow, Flatten, Grouping, Literal, RowNumber} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean import org.apache.spark.sql.catalyst.rules.RuleIdCollection import org.apache.spark.sql.catalyst.util.BadRecordException -import org.apache.spark.sql.errors.DataTypeErrorsBase import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions} import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.execution.datasources.orc.OrcTest import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.execution.streaming.FileSystemBasedCheckpointFileManager +import org.apache.spark.sql.execution.vectorized.ConstantColumnVector import org.apache.spark.sql.functions.{lit, lower, struct, sum, udf} import org.apache.spark.sql.internal.LegacyBehaviorPolicy.EXCEPTION import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{DataType, DecimalType, LongType, MetadataBuilder, StructType} +import org.apache.spark.sql.types.{ArrayType, BooleanType, DataType, DecimalType, LongType, MetadataBuilder, StructType} +import org.apache.spark.sql.vectorized.ColumnarArray +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.util.ThreadUtils import org.apache.spark.util.Utils + class QueryExecutionErrorsSuite extends QueryTest with ParquetTest @@ -1095,6 +1098,56 
@@ class QueryExecutionErrorsSuite ) ) } + + test("Elements exceed limit for concat()") { + val array = new ColumnarArray( + new ConstantColumnVector(Int.MaxValue, BooleanType), 0, Int.MaxValue) + + checkError( + exception = intercept[SparkRuntimeException] { + Concat(Seq(Literal.create(array, ArrayType(BooleanType)))).eval(EmptyRow) + }, + errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION", + parameters = Map( + "numberOfElements" -> Int.MaxValue.toString, + "maxRoundedArrayLength" -> MAX_ROUNDED_ARRAY_LENGTH.toString, + "functionName" -> toSQLId("concat") + ) + ) + } + + test("Elements exceed limit for flatten()") { + val array = new ColumnarArray( + new ConstantColumnVector(Int.MaxValue, BooleanType), 0, Int.MaxValue) + + checkError( + exception = intercept[SparkRuntimeException] { + Flatten(CreateArray(Seq(Literal.create(array, ArrayType(BooleanType))))).eval(EmptyRow) + }, + errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION", + parameters = Map( + "numberOfElements" -> Int.MaxValue.toString, + "maxRoundedArrayLength" -> MAX_ROUNDED_ARRAY_LENGTH.toString, + "functionName" -> toSQLId("flatten") + ) + ) + } + + test("Elements exceed limit for array_repeat()") { + val count = 2147483647 + checkError( + exception = intercept[SparkRuntimeException] { + sql(s"select array_repeat(1, $count)").collect() + }, + errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", + parameters = Map( + "parameter" -> toSQLId("count"), + "numberOfElements" -> count.toString, + "functionName" -> toSQLId("array_repeat"), + "maxRoundedArrayLength" -> MAX_ROUNDED_ARRAY_LENGTH.toString + ) + ) + } } class FakeFileSystemSetPermission extends LocalFileSystem { From 7c513807352b90464275667f8182dffd0019da77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A2=81=E7=84=8A=E5=BF=A0?= Date: Mon, 6 Nov 2023 16:46:31 +0800 Subject: [PATCH 029/121] [MINOR][DOCS] Fix some spelling typos MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Fixed typo. ### Why are the changes needed? To help make spark perfect. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Just comment typo, no need to test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43634 from YuanHanzhong/patch-3. Lead-authored-by: 袁焊忠 Co-authored-by: YuanHanzhong Signed-off-by: Kent Yao --- hadoop-cloud/README.md | 2 +- .../execution/command/AlterTableDropPartitionSuiteBase.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hadoop-cloud/README.md b/hadoop-cloud/README.md index 840ff1576f527..dc7647b75b1b1 100644 --- a/hadoop-cloud/README.md +++ b/hadoop-cloud/README.md @@ -16,5 +16,5 @@ Integration tests will have some extra configurations for example selecting the run the test against. Those configs are passed as environment variables and the existence of these variables must be checked by the test. Like for `AwsS3AbortableStreamBasedCheckpointFileManagerSuite` the S3 bucket used for testing -is passed in the `S3A_PATH` and the credetinals to access AWS S3 are AWS_ACCESS_KEY_ID and +is passed in the `S3A_PATH` and the credentials to access AWS S3 are AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY (in addition you can define optional AWS_SESSION_TOKEN and AWS_ENDPOINT_URL too). 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionSuiteBase.scala index 1e786c8e57811..199d1b8b4b674 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableDropPartitionSuiteBase.scala @@ -155,7 +155,7 @@ trait AlterTableDropPartitionSuiteBase extends QueryTest with DDLCommandTestUtil } } - test("SPARK-33990: don not return data from dropped partition") { + test("SPARK-33990: do not return data from dropped partition") { withNamespaceAndTable("ns", "tbl") { t => sql(s"CREATE TABLE $t (id int, part int) $defaultUsing PARTITIONED BY (part)") sql(s"INSERT INTO $t PARTITION (part=0) SELECT 0") From fc867266f0898866ab5ff7ed82b0c7c5fbaccefc Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 6 Nov 2023 18:01:11 +0800 Subject: [PATCH 030/121] [SPARK-45758][SQL] Introduce a mapper for hadoop compression codecs ### What changes were proposed in this pull request? Currently, Spark supported partial Hadoop compression codecs, but the Hadoop supported compression codecs and spark supported are not completely one-on-one due to Spark introduce two fake compression codecs none and uncompress. There are a lot of magic strings copy from Hadoop compression codecs. This issue lead to developers need to manually maintain its consistency. It is easy to make mistakes and reduce development efficiency. ### Why are the changes needed? Let developers easy to use Hadoop compression codecs. ### Does this PR introduce _any_ user-facing change? 'No'. Introduce a new class. ### How was this patch tested? Exists test cases. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #43620 from beliefer/SPARK-45758. Authored-by: Jiaan Geng Signed-off-by: Jiaan Geng --- .../catalyst/util/HadoopCompressionCodec.java | 63 +++++++++++++++++++ .../sql/catalyst/util/CompressionCodecs.scala | 12 +--- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../benchmark/DataSourceReadBenchmark.scala | 8 ++- .../execution/datasources/csv/CSVSuite.scala | 4 +- .../datasources/json/JsonSuite.scala | 4 +- .../datasources/text/TextSuite.scala | 10 +-- .../datasources/text/WholeTextFileSuite.scala | 3 +- 8 files changed, 87 insertions(+), 21 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/HadoopCompressionCodec.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/HadoopCompressionCodec.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/HadoopCompressionCodec.java new file mode 100644 index 0000000000000..ee4cb4da322b8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/HadoopCompressionCodec.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +import java.util.Arrays; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.hadoop.io.compress.BZip2Codec; +import org.apache.hadoop.io.compress.CompressionCodec; +import org.apache.hadoop.io.compress.DeflateCodec; +import org.apache.hadoop.io.compress.GzipCodec; +import org.apache.hadoop.io.compress.Lz4Codec; +import org.apache.hadoop.io.compress.SnappyCodec; + +/** + * A mapper class from Spark supported hadoop compression codecs to hadoop compression codecs. + */ +public enum HadoopCompressionCodec { + NONE(null), + UNCOMPRESSED(null), + BZIP2(new BZip2Codec()), + DEFLATE(new DeflateCodec()), + GZIP(new GzipCodec()), + LZ4(new Lz4Codec()), + SNAPPY(new SnappyCodec()); + + // TODO supports ZStandardCodec + + private final CompressionCodec compressionCodec; + + HadoopCompressionCodec(CompressionCodec compressionCodec) { + this.compressionCodec = compressionCodec; + } + + public CompressionCodec getCompressionCodec() { + return this.compressionCodec; + } + + private static final Map codecNameMap = + Arrays.stream(HadoopCompressionCodec.values()).collect( + Collectors.toMap(Enum::name, codec -> codec.name().toLowerCase(Locale.ROOT))); + + public String lowerCaseName() { + return codecNameMap.get(this.name()); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala index 1377a03d93b7e..a1d6446cc1053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala @@ -21,19 +21,13 @@ import java.util.Locale import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType -import org.apache.hadoop.io.compress._ import org.apache.spark.util.Utils object CompressionCodecs { - private val shortCompressionCodecNames = Map( - "none" -> null, - "uncompressed" -> null, - "bzip2" -> classOf[BZip2Codec].getName, - "deflate" -> classOf[DeflateCodec].getName, - "gzip" -> classOf[GzipCodec].getName, - "lz4" -> classOf[Lz4Codec].getName, - "snappy" -> classOf[SnappyCodec].getName) + private val shortCompressionCodecNames = HadoopCompressionCodec.values().map { codec => + codec.lowerCaseName() -> Option(codec.getCompressionCodec).map(_.getClass.getName).orNull + }.toMap /** * Return the full version of the given codec class. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b0a0b189cb7f1..d3271283baa33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.HadoopCompressionCodec.GZIP import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -2766,7 +2767,8 @@ class DataFrameSuite extends QueryTest // The data set has 2 partitions, so Spark will write at least 2 json files. // Use a non-splittable compression (gzip), to make sure the json scan RDD has at least 2 // partitions. - .write.partitionBy("p").option("compression", "gzip").json(path.getCanonicalPath) + .write.partitionBy("p") + .option("compression", GZIP.lowerCaseName()).json(path.getCanonicalPath) val numJobs = new AtomicLong(0) sparkContext.addSparkListener(new SparkListener { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index 74043bac49a3e..ea90cd9cd09b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, TestUtils} import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.HadoopCompressionCodec.GZIP import org.apache.spark.sql.execution.datasources.orc.OrcCompressionCodec import org.apache.spark.sql.execution.datasources.parquet.{ParquetCompressionCodec, VectorizedParquetRecordReader} import org.apache.spark.sql.internal.SQLConf @@ -91,12 +92,15 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { } private def saveAsCsvTable(df: DataFrameWriter[Row], dir: String): Unit = { - df.mode("overwrite").option("compression", "gzip").option("header", true).csv(dir) + df.mode("overwrite") + .option("compression", GZIP.lowerCaseName()) + .option("header", true) + .csv(dir) spark.read.option("header", true).csv(dir).createOrReplaceTempView("csvTable") } private def saveAsJsonTable(df: DataFrameWriter[Row], dir: String): Unit = { - df.mode("overwrite").option("compression", "gzip").json(dir) + df.mode("overwrite").option("compression", GZIP.lowerCaseName()).json(dir) spark.read.json(dir).createOrReplaceTempView("jsonTable") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index a84aea2786823..a2ce9b5db2a0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -39,7 +39,7 @@ import org.apache.logging.log4j.Level import org.apache.spark.{SparkConf, SparkException, SparkFileNotFoundException, SparkRuntimeException, SparkUpgradeException, TestUtils} import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Encoders, QueryTest, Row} import org.apache.spark.sql.catalyst.csv.CSVOptions -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, HadoopCompressionCodec} import org.apache.spark.sql.execution.datasources.CommonFileDataSourceSuite import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.test.SharedSparkSession @@ -874,7 +874,7 @@ abstract class CSVSuite cars.coalesce(1).write .format("csv") .option("header", "true") - .option("compression", "none") + .option("compression", HadoopCompressionCodec.NONE.lowerCaseName()) .options(extraOptions) .save(csvDir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 2f8b0a323dc8c..d906ae80a80ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFileNotFoundException, import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, HadoopCompressionCodec} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLType import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, InMemoryFileIndex, NoopCache} @@ -1689,7 +1689,7 @@ abstract class JsonSuite val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .format("json") - .option("compression", "none") + .option("compression", HadoopCompressionCodec.NONE.lowerCaseName()) .options(extraOptions) .save(jsonDir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index ff6b9aadf7cfb..6e3210f8c1778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkConf, TestUtils} import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.util.HadoopCompressionCodec.{BZIP2, DEFLATE, GZIP, NONE} import org.apache.spark.sql.execution.datasources.CommonFileDataSourceSuite import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -92,7 +93,8 @@ abstract class TextSuite extends QueryTest with SharedSparkSession with CommonFi test("SPARK-13503 Support to specify the option for compression codec for TEXT") { val testDf = spark.read.text(testFile) - val extensionNameMap = Map("bzip2" -> ".bz2", "deflate" 
-> ".deflate", "gzip" -> ".gz") + val extensionNameMap = Seq(BZIP2, DEFLATE, GZIP) + .map(codec => codec.lowerCaseName() -> codec.getCompressionCodec.getDefaultExtension) extensionNameMap.foreach { case (codecName, extension) => val tempDir = Utils.createTempDir() @@ -122,7 +124,7 @@ abstract class TextSuite extends QueryTest with SharedSparkSession with CommonFi withTempDir { dir => val testDf = spark.read.text(testFile) val tempDirPath = dir.getAbsolutePath - testDf.write.option("compression", "none") + testDf.write.option("compression", NONE.lowerCaseName()) .options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz"))) @@ -141,7 +143,7 @@ abstract class TextSuite extends QueryTest with SharedSparkSession with CommonFi withTempDir { dir => val testDf = spark.read.text(testFile) val tempDirPath = dir.getAbsolutePath - testDf.write.option("CoMpReSsIoN", "none") + testDf.write.option("CoMpReSsIoN", NONE.lowerCaseName()) .options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz"))) @@ -166,7 +168,7 @@ abstract class TextSuite extends QueryTest with SharedSparkSession with CommonFi withTempDir { dir => val path = dir.getCanonicalPath val df1 = spark.range(0, 1000).selectExpr("CAST(id AS STRING) AS s") - df1.write.option("compression", "gzip").mode("overwrite").text(path) + df1.write.option("compression", GZIP.lowerCaseName()).mode("overwrite").text(path) val expected = df1.collect() Seq(10, 100, 1000).foreach { bytes => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index f4812844cbae3..57e08c5587479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -21,6 +21,7 @@ import java.io.File import org.apache.spark.SparkConf import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.util.HadoopCompressionCodec.GZIP import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{StringType, StructType} @@ -90,7 +91,7 @@ abstract class WholeTextFileSuite extends QueryTest with SharedSparkSession { withTempDir { dir => val path = dir.getCanonicalPath val df1 = spark.range(0, 1000).selectExpr("CAST(id AS STRING) AS s").repartition(1) - df1.write.option("compression", "gzip").mode("overwrite").text(path) + df1.write.option("compression", GZIP.lowerCaseName()).mode("overwrite").text(path) // On reading through wholetext mode, one file will be read as a single row, i.e. not // delimited by "next line" character. val expected = Row(df1.collect().map(_.getString(0)).mkString("", "\n", "\n")) From f6038302dd615f4bf9bed9c4af3d04426f7e5c5e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 6 Nov 2023 20:06:39 +0800 Subject: [PATCH 031/121] [SPARK-45793][CORE] Improve the built-in compression codecs ### What changes were proposed in this pull request? Currently, Spark supported many built-in compression codecs used for I/O and storage. There are a lot of magic strings copy from built-in compression codecs. 
This issue lead to developers need to manually maintain its consistency. It is easy to make mistakes and reduce development efficiency. ### Why are the changes needed? Improve some code for storage compression codecs ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #43659 from beliefer/improve_storage_code. Authored-by: Jiaan Geng Signed-off-by: Jiaan Geng --- .../history/HistoryServerMemoryManager.scala | 3 ++- .../spark/internal/config/package.scala | 7 ++++--- .../apache/spark/io/CompressionCodec.scala | 21 +++++++++++-------- .../history/EventLogFileWritersSuite.scala | 6 +++--- .../history/FsHistoryProviderSuite.scala | 5 +++-- .../spark/io/CompressionCodecSuite.scala | 8 +++---- .../spark/storage/FallbackStorageSuite.scala | 3 ++- .../ExternalAppendOnlyMapSuite.scala | 3 +-- .../k8s/integrationtest/BasicTestsSuite.scala | 3 ++- .../apache/spark/sql/internal/SQLConf.scala | 3 ++- .../sql/execution/streaming/OffsetSeq.scala | 3 ++- .../streaming/state/RocksDBFileManager.scala | 2 +- 12 files changed, 38 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerMemoryManager.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerMemoryManager.scala index 00e58cbdc57b9..b95f1ed24f376 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerMemoryManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config.History._ +import org.apache.spark.io.CompressionCodec import org.apache.spark.util.Utils /** @@ -75,7 +76,7 @@ private class HistoryServerMemoryManager( private def approximateMemoryUsage(eventLogSize: Long, codec: Option[String]): Long = { codec match { - case Some("zstd") => + case Some(CompressionCodec.ZSTD) => eventLogSize * 10 case Some(_) => eventLogSize * 4 diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 93a42eec8326b..bbadf91fc41cb 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -21,6 +21,7 @@ import java.util.Locale import java.util.concurrent.TimeUnit import org.apache.spark.SparkContext +import org.apache.spark.io.CompressionCodec import org.apache.spark.launcher.SparkLauncher import org.apache.spark.metrics.GarbageCollectionMetrics import org.apache.spark.network.shuffle.Constants @@ -1530,7 +1531,7 @@ package object config { "use fully qualified class names to specify the codec.") .version("3.0.0") .stringConf - .createWithDefault("zstd") + .createWithDefault(CompressionCodec.ZSTD) private[spark] val SHUFFLE_SPILL_INITIAL_MEM_THRESHOLD = ConfigBuilder("spark.shuffle.spill.initialMemoryThreshold") @@ -1871,7 +1872,7 @@ package object config { "the codec") .version("0.8.0") .stringConf - .createWithDefaultString("lz4") + .createWithDefaultString(CompressionCodec.LZ4) private[spark] val IO_COMPRESSION_ZSTD_BUFFERSIZE = ConfigBuilder("spark.io.compression.zstd.bufferSize") @@ -1914,7 +1915,7 @@ package object config { "the codec.") .version("3.0.0") .stringConf - .createWithDefault("zstd") + .createWithDefault(CompressionCodec.ZSTD) 
private[spark] val BUFFER_SIZE = ConfigBuilder("spark.buffer.size") diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0bb392deb3923..a6a5b1f67c6f9 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -58,18 +58,21 @@ trait CompressionCodec { private[spark] object CompressionCodec { - private val configKey = IO_COMPRESSION_CODEC.key - private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = { (codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] || codec.isInstanceOf[LZ4CompressionCodec] || codec.isInstanceOf[ZStdCompressionCodec]) } - private val shortCompressionCodecNames = Map( - "lz4" -> classOf[LZ4CompressionCodec].getName, - "lzf" -> classOf[LZFCompressionCodec].getName, - "snappy" -> classOf[SnappyCompressionCodec].getName, - "zstd" -> classOf[ZStdCompressionCodec].getName) + val LZ4 = "lz4" + val LZF = "lzf" + val SNAPPY = "snappy" + val ZSTD = "zstd" + + private[spark] val shortCompressionCodecNames = Map( + LZ4 -> classOf[LZ4CompressionCodec].getName, + LZF -> classOf[LZFCompressionCodec].getName, + SNAPPY -> classOf[SnappyCompressionCodec].getName, + ZSTD -> classOf[ZStdCompressionCodec].getName) def getCodecName(conf: SparkConf): String = { conf.get(IO_COMPRESSION_CODEC) @@ -93,7 +96,7 @@ private[spark] object CompressionCodec { errorClass = "CODEC_NOT_AVAILABLE", messageParameters = Map( "codecName" -> codecName, - "configKey" -> toConf(configKey), + "configKey" -> toConf(IO_COMPRESSION_CODEC.key), "configVal" -> toConfVal(FALLBACK_COMPRESSION_CODEC)))) } @@ -113,7 +116,7 @@ private[spark] object CompressionCodec { } } - val FALLBACK_COMPRESSION_CODEC = "snappy" + val FALLBACK_COMPRESSION_CODEC = SNAPPY val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala index b575cbc080c07..349985207e48c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala @@ -176,7 +176,7 @@ class SingleEventLogFileWriterSuite extends EventLogFileWritersSuite { baseDirUri, "app1", None, None)) // with compression assert(s"${baseDirUri.toString}/app1.lzf" === - SingleEventLogFileWriter.getLogPath(baseDirUri, "app1", None, Some("lzf"))) + SingleEventLogFileWriter.getLogPath(baseDirUri, "app1", None, Some(CompressionCodec.LZF))) // illegal characters in app ID assert(s"${baseDirUri.toString}/a-fine-mind_dollar_bills__1" === SingleEventLogFileWriter.getLogPath(baseDirUri, @@ -184,7 +184,7 @@ class SingleEventLogFileWriterSuite extends EventLogFileWritersSuite { // illegal characters in app ID with compression assert(s"${baseDirUri.toString}/a-fine-mind_dollar_bills__1.lz4" === SingleEventLogFileWriter.getLogPath(baseDirUri, - "a fine:mind$dollar{bills}.1", None, Some("lz4"))) + "a fine:mind$dollar{bills}.1", None, Some(CompressionCodec.LZ4))) } override protected def createWriter( @@ -239,7 +239,7 @@ class RollingEventLogFilesWriterSuite extends EventLogFileWritersSuite { // with compression assert(s"$logDir/${EVENT_LOG_FILE_NAME_PREFIX}1_${appId}.lzf" === RollingEventLogFilesWriter.getEventLogFilePath(logDir, appId, 
appAttemptId, - 1, Some("lzf")).toString) + 1, Some(CompressionCodec.LZF)).toString) // illegal characters in app ID assert(s"${baseDirUri.toString}/${EVENT_LOG_DIR_NAME_PREFIX}a-fine-mind_dollar_bills__1" === diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index ae8481a852bf2..d16e904bdcf13 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -126,8 +126,9 @@ abstract class FsHistoryProviderSuite extends SparkFunSuite with Matchers with P // Write a new-style application log. val newAppCompressedComplete = newLogFile("new1compressed", None, inProgress = false, - Some("lzf")) - writeFile(newAppCompressedComplete, Some(CompressionCodec.createCodec(conf, "lzf")), + Some(CompressionCodec.LZF)) + writeFile( + newAppCompressedComplete, Some(CompressionCodec.createCodec(conf, CompressionCodec.LZF)), SparkListenerApplicationStart(newAppCompressedComplete.getName(), Some("new-complete-lzf"), 1L, "test", None), SparkListenerApplicationEnd(4L)) diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 244c007f53925..9c9fac0d4833d 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -58,7 +58,7 @@ class CompressionCodecSuite extends SparkFunSuite { } test("lz4 compression codec short form") { - val codec = CompressionCodec.createCodec(conf, "lz4") + val codec = CompressionCodec.createCodec(conf, CompressionCodec.LZ4) assert(codec.getClass === classOf[LZ4CompressionCodec]) testCodec(codec) } @@ -76,7 +76,7 @@ class CompressionCodecSuite extends SparkFunSuite { } test("lzf compression codec short form") { - val codec = CompressionCodec.createCodec(conf, "lzf") + val codec = CompressionCodec.createCodec(conf, CompressionCodec.LZF) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } @@ -94,7 +94,7 @@ class CompressionCodecSuite extends SparkFunSuite { } test("snappy compression codec short form") { - val codec = CompressionCodec.createCodec(conf, "snappy") + val codec = CompressionCodec.createCodec(conf, CompressionCodec.SNAPPY) assert(codec.getClass === classOf[SnappyCompressionCodec]) testCodec(codec) } @@ -115,7 +115,7 @@ class CompressionCodecSuite extends SparkFunSuite { } test("zstd compression codec short form") { - val codec = CompressionCodec.createCodec(conf, "zstd") + val codec = CompressionCodec.createCodec(conf, CompressionCodec.ZSTD) assert(codec.getClass === classOf[ZStdCompressionCodec]) testCodec(codec) } diff --git a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala index 83c9707bfc273..6c51bd4ff2e2f 100644 --- a/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala @@ -31,6 +31,7 @@ import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.LocalSparkContext.withSpark import org.apache.spark.internal.config._ +import org.apache.spark.io.CompressionCodec import 
org.apache.spark.launcher.SparkLauncher.{EXECUTOR_MEMORY, SPARK_MASTER} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.buffer.ManagedBuffer @@ -292,7 +293,7 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext { } } - Seq("lz4", "lzf", "snappy", "zstd").foreach { codec => + CompressionCodec.shortCompressionCodecNames.keys.foreach { codec => test(s"$codec - Newly added executors should access old data from remote storage") { sc = new SparkContext(getSparkConf(2, 0).set(IO_COMPRESSION_CODEC, codec)) withSpark(sc) { sc => diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 59f6e3f2d359a..2a760c39b46be 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -37,7 +37,6 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with Matchers { import TestUtils.{assertNotSpilled, assertSpilled} - private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS private def createCombiner[T](i: T) = ArrayBuffer[T](i) private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] = @@ -224,7 +223,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite // Keep track of which compression codec we're using to report in test failure messages var lastCompressionCodec: Option[String] = None try { - allCompressionCodecs.foreach { c => + CompressionCodec.ALL_COMPRESSION_CODECS.foreach { c => lastCompressionCodec = Some(c) testSimpleSpilling(Some(c), encrypt) } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala index 992fe7c97ff1a..d6911aadfa237 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.time.{Seconds, Span} import org.apache.spark.{SparkFunSuite, TestUtils} import org.apache.spark.deploy.k8s.integrationtest.KubernetesSuite.SPARK_PI_MAIN_CLASS +import org.apache.spark.io.CompressionCodec import org.apache.spark.launcher.SparkLauncher private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite => @@ -93,7 +94,7 @@ private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite => test("Run SparkPi with an argument.", k8sTestTag) { // This additional configuration with snappy is for SPARK-26995 sparkAppConf - .set("spark.io.compression.codec", "snappy") + .set("spark.io.compression.codec", CompressionCodec.SNAPPY) runSparkPiAndVerifyCompletion(appArgs = Array("5")) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1af0b41d0faac..ecc3e6e101fcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -34,6 +34,7 @@ import 
org.apache.hadoop.fs.Path import org.apache.spark.{ErrorMessageFormat, SparkConf, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.io.CompressionCodec import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{HintErrorLogger, Resolver} @@ -1997,7 +1998,7 @@ object SQLConf { "use fully qualified class names to specify the codec. Default codec is lz4.") .version("3.1.0") .stringConf - .createWithDefault("lz4") + .createWithDefault(CompressionCodec.LZ4) val CHECKPOINT_RENAMEDFILE_CHECK_ENABLED = buildConf("spark.sql.streaming.checkpoint.renamedFileCheck.enabled") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 913805d1a074d..dea75e3ec4783 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -21,6 +21,7 @@ import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging +import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager, SymmetricHashJoinStateManager} @@ -118,7 +119,7 @@ object OffsetSeqMetadata extends Logging { StreamingAggregationStateManager.legacyVersion.toString, STREAMING_JOIN_STATE_FORMAT_VERSION.key -> SymmetricHashJoinStateManager.legacyVersion.toString, - STATE_STORE_COMPRESSION_CODEC.key -> "lz4", + STATE_STORE_COMPRESSION_CODEC.key -> CompressionCodec.LZ4, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false" ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala index 6a62a6c52f519..046cf69f1fcaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala @@ -126,7 +126,7 @@ class RocksDBFileManager( dfsRootDir: String, localTempDir: File, hadoopConf: Configuration, - codecName: String = "zstd", + codecName: String = CompressionCodec.ZSTD, loggingId: String = "") extends Logging { From 82ffd253520297197fd7c27fe3c907d941124dce Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Mon, 6 Nov 2023 11:10:15 -0800 Subject: [PATCH 032/121] [SPARK-45773][PYTHON][DOCS] Refine docstring of SparkSession.builder.config ### What changes were proposed in this pull request? This PR refines the docstring of the method `SparkSession.builder.config`. ### Why are the changes needed? To improve PySpark documentation. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? doc test ### Was this patch authored or co-authored using generative AI tooling? No Closes #43639 from allisonwang-db/spark-45773-config. 
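For readers following along outside Python, the builder patterns that the refined docstring illustrates map to the following minimal Scala sketch (the `spark.some.config.*` keys and their values are illustrative placeholders carried over from the docstring, not real Spark settings):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// Seed the builder from an existing SparkConf.
val conf = new SparkConf().setAppName("example").setMaster("local")
val builder = SparkSession.builder.config(conf)

// Set options one at a time; calls can be chained.
builder
  .config("spark.some.config.option", "some-value")
  .config("spark.some.config.number", 123L)

// Or pass several options at once as a map (available since Spark 3.4).
builder.config(Map(
  "spark.some.config.number" -> 123,
  "spark.some.config.float" -> 0.123))

val spark = builder.getOrCreate()
```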
Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/session.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 1ffb602ba86bc..4ab7281d7ac87 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -253,12 +253,17 @@ def config( ------- :class:`SparkSession.Builder` + See Also + -------- + :class:`SparkConf` + Examples -------- - For an existing class:`SparkConf`, use `conf` parameter. + For an existing :class:`SparkConf`, use `conf` parameter. >>> from pyspark.conf import SparkConf - >>> SparkSession.builder.config(conf=SparkConf()) + >>> conf = SparkConf().setAppName("example").setMaster("local") + >>> SparkSession.builder.config(conf=conf) >> SparkSession.builder.config("spark.some.config.option", "some-value") >> SparkSession.builder.config( + ... "spark.some.config.number", 123).config("spark.some.config.float", 0.123) + >> SparkSession.builder.config( ... map={"spark.some.config.number": 123, "spark.some.config.float": 0.123}) From 74f9ebeac8a38a57d683b1636857045329ec29f6 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Mon, 6 Nov 2023 11:21:28 -0800 Subject: [PATCH 033/121] [MINOR][INFRA] Correct Java version in RM Dockerfile description ### What changes were proposed in this pull request? As title, it's a super nit fix. ### Why are the changes needed? Correct the description to match the install JDK version. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Review ### Was this patch authored or co-authored using generative AI tooling? No Closes #43669 from pan3793/minor-rm. Authored-by: Cheng Pan Signed-off-by: Hyukjin Kwon --- dev/create-release/spark-rm/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index 50562e38fb562..b628798bbf593 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -18,7 +18,7 @@ # Image for building Spark releases. Based on Ubuntu 20.04. # # Includes: -# * Java 8 +# * Java 17 # * Ivy # * Python (3.8.5) # * R-base/R-base-dev (4.0.3) From 58c30ad4f6c5ee88bb31f148f106db613f7ff962 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 6 Nov 2023 11:26:36 -0800 Subject: [PATCH 034/121] [SPARK-44751][SQL] Move `XSDToSchema` from `catalyst` to `core` package ### What changes were proposed in this pull request? Move XSDToSchema from catalyst to core package ### Why are the changes needed? It facilitates the a follow-up refactoring PR. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43652 from shujingyang-db/mv-xsdToSchema. 
Authored-by: Shujing Yang Signed-off-by: Hyukjin Kwon --- .../spark/sql/execution/datasources}/xml/XSDToSchema.scala | 3 ++- .../sql/execution/datasources/xml/util/XSDToSchemaSuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) rename sql/{catalyst/src/main/scala/org/apache/spark/sql/catalyst => core/src/main/scala/org/apache/spark/sql/execution/datasources}/xml/XSDToSchema.scala (99%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XSDToSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XSDToSchema.scala similarity index 99% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XSDToSchema.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XSDToSchema.scala index 6c3958f5dd1a6..b0894ed34844e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XSDToSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XSDToSchema.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.catalyst.xml +package org.apache.spark.sql.execution.datasources.xml import java.io.{File, FileInputStream, InputStreamReader, StringReader} import java.nio.charset.StandardCharsets @@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._ import org.apache.ws.commons.schema._ import org.apache.ws.commons.schema.constants.Constants +import org.apache.spark.sql.catalyst.xml.XmlOptions import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/XSDToSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/XSDToSchemaSuite.scala index e163c05846d02..9d8b1eec8f731 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/XSDToSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/XSDToSchemaSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.datasources.xml.util import java.nio.file.Paths -import org.apache.spark.sql.catalyst.xml.XSDToSchema import org.apache.spark.sql.execution.datasources.xml.TestUtils._ +import org.apache.spark.sql.execution.datasources.xml.XSDToSchema import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, LongType, StringType} From 2764ee0e9329a68ef10b6dc79fec20c722aaaf96 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Mon, 6 Nov 2023 21:18:43 +0100 Subject: [PATCH 035/121] [SPARK-45805][SQL] Make `withOrigin` more generic ### What changes were proposed in this pull request? In the PR, I propose to change the implementation of `sql.withOrigin`, and eliminate the magic number 3 from which the algorithm starts iterations. New implementation starts from the index 0, and finds the first block of Spark traces. It stops immediately after the block at the first non-Spark trace. For example: Screenshot 2023-11-01 at 21 29 18 new implementation finds the block [2, 4], and stops at the index 5 by catching and returning the block of traces [4, 6]. ### Why are the changes needed? The PR makes `withOrigin` more generic and improves code maintenance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
By existing test suites, for instance: ``` $ build/sbt "test:testOnly *DatasetSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43671 from MaxGekk/refactor-withOrigin. Authored-by: Max Gekk Signed-off-by: Peter Toth --- .../sql/catalyst/trees/QueryContexts.scala | 18 ++++++++++-------- .../scala/org/apache/spark/sql/package.scala | 8 +++++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index b8288b24535e8..8d885d07ca8b0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -160,14 +160,16 @@ case class DataFrameQueryContext( object DataFrameQueryContext { def apply(elements: Array[StackTraceElement]): DataFrameQueryContext = { - val methodName = elements(0).getMethodName - val code = if (methodName.length > 1 && methodName(0) == '$') { - methodName.substring(1) - } else { - methodName - } - val callSite = elements(1).toString + val fragment = elements.headOption.map { firstElem => + val methodName = firstElem.getMethodName + if (methodName.length > 1 && methodName(0) == '$') { + methodName.substring(1) + } else { + methodName + } + }.getOrElse("") + val callSite = elements.tail.headOption.map(_.toString).getOrElse("") - DataFrameQueryContext(code, callSite) + DataFrameQueryContext(fragment, callSite) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 7f00f6d6317c8..96bef83af0a86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -98,10 +98,12 @@ package object sql { f } else { val st = Thread.currentThread().getStackTrace - var i = 3 + var i = 0 + // Find the beginning of Spark code traces + while (i < st.length && !sparkCode(st(i))) i += 1 + // Stop at the end of the first Spark code traces while (i < st.length && sparkCode(st(i))) i += 1 - val origin = - Origin(stackTrace = Some(Thread.currentThread().getStackTrace.slice(i - 1, i + 1))) + val origin = Origin(stackTrace = Some(st.slice(i - 1, i + 1))) CurrentOrigin.withOrigin(origin)(f) } } From 370870b7a0303e4a2c4b3dea1b479b4fcbc93f8d Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 6 Nov 2023 12:48:52 -0800 Subject: [PATCH 036/121] [SPARK-45803][CORE] Remove the no longer used `RpcAbortException` ### What changes were proposed in this pull request? `RpcAbortException` introduced in SPARK-28483 | https://github.com/apache/spark/pull/25235 and become unused after SPARK-31472 | https://github.com/apache/spark/pull/28245. At the same time, `RpcAbortException` is a `private[spark]` definition, so this pr removes it. ### Why are the changes needed? Clean up useless code. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43673 from LuciferYang/SPARK-45803. 
Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/util/SparkThreadUtils.scala | 2 +- .../src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala index ec14688a00625..a5e4cef1ec1a9 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala @@ -49,7 +49,7 @@ private[spark] object SparkThreadUtils { } catch { case e: SparkFatalException => throw e.throwable - // TimeoutException and RpcAbortException is thrown in the current thread, so not need to warp + // TimeoutException is thrown in the current thread, so not need to warp // the exception. case NonFatal(t) if !t.isInstanceOf[TimeoutException] => diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 925dcdba07328..d907763639feb 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -103,11 +103,6 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf) } -/** - * An exception thrown if the RPC is aborted. - */ -private[spark] class RpcAbortException(message: String) extends Exception(message) - /** * A wrapper for [[Future]] but add abort method. * This is used in long run RPC and provide an approach to abort the RPC. From 296d3c50d2b6212040836b5fe0f86d9890564dc6 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 6 Nov 2023 20:06:01 -0800 Subject: [PATCH 037/121] [SPARK-45804][UI] Add spark.ui.threadDump.flamegraphEnabled config to switch flame graph on/off ### What changes were proposed in this pull request? Add spark.ui.threadDump.flamegraphEnabled config to switch flame graph on/off ### Why are the changes needed? UI stability ### Does this PR introduce _any_ user-facing change? yes, new spark.ui.threadDump.flamegraphEnabled config ### How was this patch tested? locally tested bin/spark-shell -c spark.ui.threadDump.flamegraphEnabled=true ![image](https://github.com/apache/spark/assets/8326978/538e1991-da15-4323-a28e-63565e042804) bin/spark-shell -c spark.ui.threadDump.flamegraphEnabled=false ![image](https://github.com/apache/spark/assets/8326978/4562e359-cec3-41e5-a47e-636d5edf9ec4) ### Was this patch authored or co-authored using generative AI tooling? no Closes #43674 from yaooqinn/SPARK-45804. 
Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/internal/config/UI.scala | 6 ++++++ .../apache/spark/ui/exec/ExecutorThreadDumpPage.scala | 10 +++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/UI.scala b/core/src/main/scala/org/apache/spark/internal/config/UI.scala index 841d2b494c05e..f983308667e3f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/UI.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/UI.scala @@ -97,6 +97,12 @@ private[spark] object UI { .booleanConf .createWithDefault(true) + val UI_FLAMEGRAPH_ENABLED = ConfigBuilder("spark.ui.threadDump.flamegraphEnabled") + .doc("Whether to render the Flamegraph for executor thread dumps") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val UI_HEAP_HISTOGRAM_ENABLED = ConfigBuilder("spark.ui.heapHistogramEnabled") .version("3.5.0") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 4a00777c509fb..328abdb5c5f9b 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} import org.apache.spark.SparkContext +import org.apache.spark.internal.config.UI.UI_FLAMEGRAPH_ENABLED import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.ui.{SparkUITab, UIUtils, WebUIPage} import org.apache.spark.ui.UIUtils.prependBaseUri @@ -31,6 +32,8 @@ private[ui] class ExecutorThreadDumpPage( parent: SparkUITab, sc: Option[SparkContext]) extends WebUIPage("threadDump") { + private val flamegraphEnabled = sc.isDefined && sc.get.conf.get(UI_FLAMEGRAPH_ENABLED) + def render(request: HttpServletRequest): Seq[Node] = { val executorId = Option(request.getParameter("executorId")).map { executorId => UIUtils.decodeURLParameter(executorId) @@ -70,7 +73,12 @@ private[ui] class ExecutorThreadDumpPage(

          <p>Updated at {UIUtils.formatDate(time)}</p>
-          {drawExecutorFlamegraph(request, threadDump)}
+          { if (flamegraphEnabled) {
+              drawExecutorFlamegraph(request, threadDump) }
+            else {
+              Seq.empty
+            }
+          }
          {
            // scalastyle:off
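For reference, the new toggle can also be set programmatically before the UI starts; a minimal sketch, assuming a local master and an arbitrary app name (both placeholders):

```scala
import org.apache.spark.sql.SparkSession

// Turn flame-graph rendering off on the executor thread-dump page.
// The config key comes from the UI.scala change above; it defaults to true.
val spark = SparkSession.builder
  .master("local[2]")
  .appName("flamegraph-toggle") // placeholder app name
  .config("spark.ui.threadDump.flamegraphEnabled", "false")
  .getOrCreate()
```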

    From ff4ca0b1bcab7a5f8f14f845a168b12664e16e51 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 6 Nov 2023 20:10:10 -0800 Subject: [PATCH 038/121] [SPARK-45812][BUILD][PYTHON][PS] Upgrade Pandas to 2.1.2 ### What changes were proposed in this pull request? This PR proposes to upgrade Pandas to 2.1.2. See https://pandas.pydata.org/docs/dev/whatsnew/v2.1.2.html for detail ### Why are the changes needed? Pandas 2.1.2 is released, and we should support the latest Pandas. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? The existing CI should pass ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43689 from itholic/SPARK-45812. Authored-by: Haejoon Lee Signed-off-by: Dongjoon Hyun --- dev/infra/Dockerfile | 4 ++-- python/pyspark/pandas/supported_api_gen.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 001de613a9200..e6a58cc3fc753 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -84,8 +84,8 @@ RUN Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='ht # See more in SPARK-39735 ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library" -RUN pypy3 -m pip install numpy 'pandas<=2.1.1' scipy coverage matplotlib -RUN python3.9 -m pip install numpy pyarrow 'pandas<=2.1.1' scipy unittest-xml-reporting plotly>=4.8 'mlflow>=2.3.1' coverage matplotlib openpyxl 'memory-profiler==0.60.0' 'scikit-learn==1.1.*' +RUN pypy3 -m pip install numpy 'pandas<=2.1.2' scipy coverage matplotlib +RUN python3.9 -m pip install numpy pyarrow 'pandas<=2.1.2' scipy unittest-xml-reporting plotly>=4.8 'mlflow>=2.3.1' coverage matplotlib openpyxl 'memory-profiler==0.60.0' 'scikit-learn==1.1.*' # Add Python deps for Spark Connect. RUN python3.9 -m pip install 'grpcio>=1.48,<1.57' 'grpcio-status>=1.48,<1.57' 'protobuf==3.20.3' 'googleapis-common-protos==1.56.4' diff --git a/python/pyspark/pandas/supported_api_gen.py b/python/pyspark/pandas/supported_api_gen.py index c4471a0af36d4..8d49fef279946 100644 --- a/python/pyspark/pandas/supported_api_gen.py +++ b/python/pyspark/pandas/supported_api_gen.py @@ -98,7 +98,7 @@ def generate_supported_api(output_rst_file_path: str) -> None: Write supported APIs documentation. """ - pandas_latest_version = "2.1.1" + pandas_latest_version = "2.1.2" if LooseVersion(pd.__version__) != LooseVersion(pandas_latest_version): msg = ( "Warning: Latest version of pandas (%s) is required to generate the documentation; " From 1370e9de0e49012f2886aedb30da5117e37f2ef1 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 6 Nov 2023 22:53:12 -0800 Subject: [PATCH 039/121] [SPARK-45013][TEST] Flaky Test with NPE: track allocated resources by taskId ### What changes were proposed in this pull request? This PR ensures the runningTasks to be updated before subsequent tasks causing NPE ### Why are the changes needed? fix flakey tests ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? 
It shall fix ``` - track allocated resources by taskId *** FAILED *** (36 milliseconds) [info] java.lang.NullPointerException: Cannot invoke "org.apache.spark.executor.Executor$TaskRunner.taskDescription()" because the return value of "java.util.concurrent.ConcurrentHashMap.get(Object)" is null [info] at org.apache.spark.executor.CoarseGrainedExecutorBackend.statusUpdate(CoarseGrainedExecutorBackend.scala:275) [info] at org.apache.spark.executor.CoarseGrainedExecutorBackendSuite.$anonfun$new$22(CoarseGrainedExecutorBackendSuite.scala:351) [info] ``` ### Was this patch authored or co-authored using generative AI tooling? no Closes #43693 from yaooqinn/SPARK-45013. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .../spark/executor/CoarseGrainedExecutorBackendSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 28af0656869b3..35fe0b0d1c908 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -343,6 +343,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite backend.self.send(LaunchTask(new SerializableBuffer(serializedTaskDescription))) eventually(timeout(10.seconds)) { assert(backend.taskResources.size == 1) + assert(runningTasks.size == 1) val resources = backend.taskResources.get(taskId) assert(resources(GPU).addresses sameElements Array("0", "1")) } From 8c41629033508197b80f6a50c78fcf4d5f9752ff Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 7 Nov 2023 16:35:48 +0800 Subject: [PATCH 040/121] [SPARK-45509][SQL] Fix df column reference behavior for Spark Connect ### What changes were proposed in this pull request? This PR fixes a few problems of column resolution for Spark Connect, to make the behavior closer to classic Spark SQL (unfortunately we still have some behavior differences in corner cases). 1. resolve df column references in both `resolveExpressionByPlanChildren` and `resolveExpressionByPlanOutput`. Previously it's only in `resolveExpressionByPlanChildren`. 2. when the plan id has multiple matches, fail with `AMBIGUOUS_COLUMN_REFERENCE` ### Why are the changes needed? fix behavior differences between spark connect and classic spark sql ### Does this PR introduce _any_ user-facing change? Yes, for spark connect scala client ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #43465 from cloud-fan/column. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../main/resources/error/error-classes.json | 9 ++ .../apache/spark/sql/ClientE2ETestSuite.scala | 62 +++++++++++- docs/sql-error-conditions.md | 9 ++ python/pyspark/pandas/indexes/multi.py | 2 +- python/pyspark/sql/connect/plan.py | 4 +- .../analysis/ColumnResolutionHelper.scala | 98 +++++++++++-------- .../analysis/ResolveSetVariable.scala | 1 - 7 files changed, 140 insertions(+), 45 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 3e0743d366ae3..db46ee8ca208c 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -31,6 +31,15 @@ ], "sqlState" : "42702" }, + "AMBIGUOUS_COLUMN_REFERENCE" : { + "message" : [ + "Column is ambiguous. 
It's because you joined several DataFrame together, and some of these DataFrames are the same.", + "This column points to one of the DataFrame but Spark is unable to figure out which one.", + "Please alias the DataFrames with different names via `DataFrame.alias` before joining them,", + "and specify the column using qualified name, e.g. `df.alias(\"a\").join(df.alias(\"b\"), col(\"a.id\") > col(\"b.id\"))`." + ], + "sqlState" : "42702" + }, "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : { "message" : [ "Lateral column alias is ambiguous and has matches." diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index d9a77f2830b94..b9fa415034c3e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types._ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { - test(s"throw SparkException with null filename in stack trace elements") { + test("throw SparkException with null filename in stack trace elements") { withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") { val session = spark import session.implicits._ @@ -94,7 +94,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM } } - test(s"throw SparkException with large cause exception") { + test("throw SparkException with large cause exception") { withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") { val session = spark import session.implicits._ @@ -872,6 +872,64 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assert(joined2.schema.catalogString === "struct") } + test("SPARK-45509: ambiguous column reference") { + val session = spark + import session.implicits._ + val df1 = Seq(1 -> "a").toDF("i", "j") + val df1_filter = df1.filter(df1("i") > 0) + val df2 = Seq(2 -> "b").toDF("i", "y") + + checkSameResult( + Seq(Row(1)), + // df1("i") is not ambiguous, and it's still valid in the filtered df. + df1_filter.select(df1("i"))) + + val e1 = intercept[AnalysisException] { + // df1("i") is not ambiguous, but it's not valid in the projected df. + df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect() + } + assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT")) + + checkSameResult( + Seq(Row(1, "a")), + // All these column references are not ambiguous and are still valid after join. + df1.join(df2, df1("i") + 1 === df2("i")).sort(df1("i").desc).select(df1("i"), df1("j"))) + + val e2 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides. + df1.join(df1, df1("i") === 1).collect() + } + assert(e2.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + val e3 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides. + df1.join(df1).select(df1("i")).collect() + } + assert(e3.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + val e4 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides (df1_filter contains df1). + df1.join(df1_filter, df1("i") === 1).collect() + } + assert(e4.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + checkSameResult( + Seq(Row("a")), + // df1_filter("i") is not ambiguous as df1_filter does not exist in the join left side. 
+ df1.join(df1_filter, df1_filter("i") === 1).select(df1_filter("j"))) + + val e5 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both sides of the first join. + df1.join(df1_filter, df1_filter("i") === 1).join(df2, df1("i") === 1).collect() + } + assert(e5.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + checkSameResult( + Seq(Row("a")), + // df1_filter("i") is not ambiguous as df1_filter only appears once. + df1.join(df1_filter).join(df2, df1_filter("i") === 1).select(df1_filter("j"))) + } + test("broadcast join") { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { val left = spark.range(100).select(col("id"), rand(10).as("a")) diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index a6f003647ddc8..7b0bc8ceb2b5a 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -55,6 +55,15 @@ See '``/sql-migration-guide.html#query-engine'. Column or field `` is ambiguous and has `` matches. +### AMBIGUOUS_COLUMN_REFERENCE + +[SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Column `` is ambiguous. It's because you joined several DataFrame together, and some of these DataFrames are the same. +This column points to one of the DataFrame but Spark is unable to figure out which one. +Please alias the DataFrames with different names via `DataFrame.alias` before joining them, +and specify the column using qualified name, e.g. `df.alias("a").join(df.alias("b"), col("a.id") > col("b.id"))`. + ### AMBIGUOUS_LATERAL_COLUMN_ALIAS [SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index 41c3b93ed51b6..9fbc608c12a4b 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -813,7 +813,7 @@ def symmetric_difference( # type: ignore[override] sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other)) if sort: - sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_columns) + sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names) internal = InternalFrame( spark_frame=sdf_symdiff, diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index d888422d29f71..607d1429a9efd 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -2251,7 +2251,9 @@ def __init__( self._input_grouping_cols = input_grouping_cols self._other_grouping_cols = other_grouping_cols self._other = cast(LogicalPlan, other) - self._func = function._build_common_inline_user_defined_function(*cols) + # The function takes entire DataFrame as inputs, no need to do + # column binding (no input columns). 
+ self._func = function._build_common_inline_user_defined_function() def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 54a9c6ca01813..edfc60fc6eaa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -30,10 +30,10 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier} -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors} import org.apache.spark.sql.internal.SQLConf -trait ColumnResolutionHelper extends Logging { +trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { def conf: SQLConf @@ -426,7 +426,7 @@ trait ColumnResolutionHelper extends Logging { throws: Boolean = false, includeLastResort: Boolean = false): Expression = { resolveExpression( - expr, + tryResolveColumnByPlanId(expr, plan), resolveColumnByName = nameParts => { plan.resolve(nameParts, conf.resolver) }, @@ -447,21 +447,8 @@ trait ColumnResolutionHelper extends Logging { e: Expression, q: LogicalPlan, includeLastResort: Boolean = false): Expression = { - val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) { - // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and - // expression are from Spark Connect, and need to be resolved in this way: - // 1, extract the attached plan id from the expression (UnresolvedAttribute only for now); - // 2, top-down traverse the query plan to find the plan node that matches the plan id; - // 3, if can not find the matching node, fail the analysis due to illegal references; - // 4, resolve the expression with the matching node, if any error occurs here, apply the - // old code path; - resolveExpressionByPlanId(e, q) - } else { - e - } - resolveExpression( - newE, + tryResolveColumnByPlanId(e, q), resolveColumnByName = nameParts => { q.resolveChildren(nameParts, conf.resolver) }, @@ -490,39 +477,46 @@ trait ColumnResolutionHelper extends Logging { } } - private def resolveExpressionByPlanId( + // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and + // expression are from Spark Connect, and need to be resolved in this way: + // 1. extract the attached plan id from UnresolvedAttribute; + // 2. top-down traverse the query plan to find the plan node that matches the plan id; + // 3. if can not find the matching node, fail the analysis due to illegal references; + // 4. if more than one matching nodes are found, fail due to ambiguous column reference; + // 5. resolve the expression with the matching node, if any error occurs here, return the + // original expression as it is. 
+ private def tryResolveColumnByPlanId( e: Expression, - q: LogicalPlan): Expression = { - if (!e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) { - return e - } - - e match { - case u: UnresolvedAttribute => - resolveUnresolvedAttributeByPlanId(u, q).getOrElse(u) - case _ => - e.mapChildren(c => resolveExpressionByPlanId(c, q)) - } + q: LogicalPlan, + idToPlan: mutable.HashMap[Long, LogicalPlan] = mutable.HashMap.empty): Expression = e match { + case u: UnresolvedAttribute => + resolveUnresolvedAttributeByPlanId( + u, q, idToPlan: mutable.HashMap[Long, LogicalPlan] + ).getOrElse(u) + case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) => + e.mapChildren(c => tryResolveColumnByPlanId(c, q, idToPlan)) + case _ => e } private def resolveUnresolvedAttributeByPlanId( u: UnresolvedAttribute, - q: LogicalPlan): Option[NamedExpression] = { + q: LogicalPlan, + idToPlan: mutable.HashMap[Long, LogicalPlan]): Option[NamedExpression] = { val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG) if (planIdOpt.isEmpty) return None val planId = planIdOpt.get logDebug(s"Extract plan_id $planId from $u") - val planOpt = q.find(_.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(planId)) - if (planOpt.isEmpty) { - // For example: - // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) - // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) - // df1.select(df2.a) <- illegal reference df2.a - throw new AnalysisException(s"When resolving $u, " + - s"fail to find subplan with plan_id=$planId in $q") - } - val plan = planOpt.get + val plan = idToPlan.getOrElseUpdate(planId, { + findPlanById(u, planId, q).getOrElse { + // For example: + // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) + // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) + // df1.select(df2.a) <- illegal reference df2.a + throw new AnalysisException(s"When resolving $u, " + + s"fail to find subplan with plan_id=$planId in $q") + } + }) val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined try { @@ -539,4 +533,28 @@ trait ColumnResolutionHelper extends Logging { None } } + + private def findPlanById( + u: UnresolvedAttribute, + id: Long, + plan: LogicalPlan): Option[LogicalPlan] = { + if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) { + Some(plan) + } else if (plan.children.length == 1) { + findPlanById(u, id, plan.children.head) + } else if (plan.children.length > 1) { + val matched = plan.children.flatMap(findPlanById(u, id, _)) + if (matched.length > 1) { + throw new AnalysisException( + errorClass = "AMBIGUOUS_COLUMN_REFERENCE", + messageParameters = Map("name" -> toSQLId(u.nameParts)), + origin = u.origin + ) + } else { + matched.headOption + } + } else { + None + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala index ebf56ef1cc4f3..bd0204ba06fd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan, SetVaria import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import 
org.apache.spark.sql.errors.QueryCompilationErrors.unresolvedVariableError import org.apache.spark.sql.types.IntegerType From 2aa172dba1176de76719021a45a017759379abe5 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 7 Nov 2023 23:32:47 +0800 Subject: [PATCH 041/121] [SPARK-45822][CONNECT] SparkConnectSessionManager may look up a stopped sparkcontext ### What changes were proposed in this pull request? This PR checks whether the sc is still functional before cloning a new isolated session from it. ### Why are the changes needed? SparkSession.active is a thread-local value and not be updated by other thread. This causes https://github.com/LuciferYang/spark/actions/runs/6767960232/job/18426049162 ```java - ReleaseSession: session with different session_id or user_id allowed after release *** FAILED *** (9 milliseconds) [info] org.apache.spark.SparkException: com.google.common.util.concurrent.UncheckedExecutionException: java.lang.IllegalStateException: Cannot call methods on a stopped SparkContext. [info] This stopped SparkContext was created at: [info] [info] org.apache.spark.sql.connect.service.SparkConnectSessionHolderSuite.beforeAll(SparkConnectSessionHolderSuite.scala:37) [info] org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:212) ``` For shared spark sessions in tests, these sessions are created, stopped, and retrieved in different threads. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? I ran `build/sbt "connect/testOnly *SparkConnect*"` locally and the test consistently failed w/o this patch. Otherwise, it passed. ### Was this patch authored or co-authored using generative AI tooling? no Closes #43701 from yaooqinn/SPARK-45822. Authored-by: Kent Yao Signed-off-by: yangjie01 --- .../sql/connect/service/SparkConnectSessionManager.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala index 5c8e3c611586c..ba402a90a71e5 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -139,7 +139,13 @@ class SparkConnectSessionManager extends Logging { } private def newIsolatedSession(): SparkSession = { - SparkSession.active.newSession() + val active = SparkSession.active + if (active.sparkContext.isStopped) { + assert(SparkSession.getDefaultSession.nonEmpty) + SparkSession.getDefaultSession.get.newSession() + } else { + active.newSession() + } } private def validateSessionCreate(key: SessionKey): Unit = { From 01c294b05f3a9b7bd87cda0ee8b0160f5f58bb24 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 8 Nov 2023 00:57:31 +0800 Subject: [PATCH 042/121] [SPARK-45760][SQL] Add With expression to avoid duplicating expressions ### What changes were proposed in this pull request? Sometimes we need to duplicate expressions when rewriting the plan. It's OK for small query, as codegen has common-subexpression-elimination (CSE) to avoid evaluating the same expression. However, when the query is big, duplicating expressions can lead to a very big expression tree and make catalyst rules very slow, or even OOM when updating a leaf node (need to copy all tree nodes). 
This PR introduces a new expression to do expression-level CTE: it adds a Project to pre-evaluate the common expressions, so that they appear only once on the query plan tree, and are evaluated only once. `NullIf` now uses this new expression to avoid duplicating the `left` child expression. ### Why are the changes needed? make catalyst more efficient. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? new test suite ### Was this patch authored or co-authored using generative AI tooling? No Closes #43623 from cloud-fan/with. Lead-authored-by: Wenchen Fan Co-authored-by: Peter Toth Signed-off-by: Wenchen Fan --- .../explain-results/function_count_if.explain | 5 +- .../function_regexp_substr.explain | 5 +- .../connect/ProtoToParsedPlanTestSuite.scala | 15 +- .../spark/sql/catalyst/expressions/With.scala | 63 +++++++ .../expressions/nullExpressions.scala | 6 +- .../sql/catalyst/optimizer/Optimizer.scala | 3 + .../optimizer/RewriteWithExpression.scala | 90 ++++++++++ .../sql/catalyst/trees/TreePatterns.scala | 2 + .../RewriteWithExpressionSuite.scala | 157 ++++++++++++++++++ 9 files changed, 338 insertions(+), 8 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain index 1c23bbf6bce55..f2ada15eccb7d 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain @@ -1,2 +1,3 @@ -Aggregate [count(if (((a#0 > 0) = false)) null else (a#0 > 0)) AS count_if((a > 0))#0L] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] +Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L] ++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0] + +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain index 69fc760c82910..1811f770f8297 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain @@ -1,2 +1,3 @@ -Project [if ((regexp_extract(g#0, \d{2}(a|b|m), 0) = )) null else regexp_extract(g#0, \d{2}(a|b|m), 0) AS regexp_substr(g, \d{2}(a|b|m))#0] -+- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] +Project [if ((_common_expr_0#0 = )) null else _common_expr_0#0 AS regexp_substr(g, \d{2}(a|b|m))#0] ++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, \d{2}(a|b|m), 0) AS _common_expr_0#0] + +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala index 9fdaffcba670c..e0c4e21503e91 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala @@ -29,7 +29,9 @@ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog -import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions +import org.apache.spark.sql.catalyst.optimizer.{ReplaceExpressions, RewriteWithExpression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.SessionHolder @@ -181,8 +183,15 @@ class ProtoToParsedPlanTestSuite val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark)) val catalystPlan = analyzer.executeAndCheck(planner.transformRelation(relation), new QueryPlanningTracker) - val actual = - removeMemoryAddress(normalizeExprIds(ReplaceExpressions(catalystPlan)).treeString) + val finalAnalyzedPlan = { + object Helper extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Finish Analysis", Once, ReplaceExpressions) :: + Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil + } + Helper.execute(catalystPlan) + } + val actual = removeMemoryAddress(normalizeExprIds(finalAnalyzedPlan).treeString) val goldenFile = goldenFilePath.resolve(relativePath).getParent.resolve(name + ".explain") Try(readGoldenFile(goldenFile)) match { case Success(expected) if expected == actual => // Test passes. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala new file mode 100644 index 0000000000000..bfed63af17409 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} +import org.apache.spark.sql.types.DataType + +/** + * An expression holder that keeps a list of common expressions and allow the actual expression to + * reference these common expressions. 
The common expressions are guaranteed to be evaluated only + * once even if it's referenced more than once. This is similar to CTE but is expression-level. + */ +case class With(child: Expression, defs: Seq[CommonExpressionDef]) + extends Expression with Unevaluable { + override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION) + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable + override def children: Seq[Expression] = child +: defs + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + copy(child = newChildren.head, defs = newChildren.tail.map(_.asInstanceOf[CommonExpressionDef])) + } +} + +/** + * A wrapper of common expression to carry the id. + */ +case class CommonExpressionDef(child: Expression, id: Long = CommonExpressionDef.newId) + extends UnaryExpression with Unevaluable { + override def dataType: DataType = child.dataType + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +/** + * A reference to the common expression by its id. Only resolved common expressions can be + * referenced, so that we can determine the data type and nullable of the reference node. + */ +case class CommonExpressionRef(id: Long, dataType: DataType, nullable: Boolean) + extends LeafExpression with Unevaluable { + def this(exprDef: CommonExpressionDef) = this(exprDef.id, exprDef.dataType, exprDef.nullable) + override val nodePatterns: Seq[TreePattern] = Seq(COMMON_EXPR_REF) +} + +object CommonExpressionDef { + private[sql] val curId = new java.util.concurrent.atomic.AtomicLong() + def newId: Long = curId.getAndIncrement() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 948cb6fbedd32..0e9e375b8acf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -154,7 +154,11 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = { - this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left)) + this(left, right, { + val commonExpr = CommonExpressionDef(left) + val ref = new CommonExpressionRef(commonExpr) + With(If(EqualTo(ref, right), Literal.create(null, left.dataType), ref), Seq(commonExpr)) + }) } override def parameters: Seq[Expression] = Seq(left, right) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 48ecb9aee2118..decef766ae97d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -147,6 +147,9 @@ abstract class Optimizer(catalogManager: CatalogManager) val batches = ( Batch("Finish Analysis", Once, FinishAnalysis) :: + // We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression + // may produce `With` expressions that need to be rewritten. 
+ Batch("Rewrite With expression", Once, RewriteWithExpression) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala new file mode 100644 index 0000000000000..c5bd71b4a7d1f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION} + +/** + * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or + * just inline them if they are cheap. + * + * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its + * usage, we should support aggregate/window functions as well. + */ +object RewriteWithExpression extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) { + case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => + var newChildren = p.children + var newPlan: LogicalPlan = p.transformExpressionsUp { + case With(child, defs) => + val refToExpr = mutable.HashMap.empty[Long, Expression] + val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias]) + + defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) => + if (CollapseProject.isCheap(child)) { + refToExpr(id) = child + } else { + val childProjectionIndex = newChildren.indexWhere( + c => child.references.subsetOf(c.outputSet) + ) + if (childProjectionIndex == -1) { + // When we cannot rewrite the common expressions, force to inline them so that the + // query can still run. This can happen if the join condition contains `With` and + // the common expression references columns from both join sides. + // TODO: things can go wrong if the common expression is nondeterministic. We + // don't fix it for now to match the old buggy behavior when certain + // `RuntimeReplaceable` did not use the `With` expression. 
+ // TODO: we should calculate the ref count and also inline the common expression + // if it's ref count is 1. + refToExpr(id) = child + } else { + val alias = Alias(child, s"_common_expr_$index")() + childProjections(childProjectionIndex) += alias + refToExpr(id) = alias.toAttribute + } + } + } + + newChildren = newChildren.zip(childProjections).map { case (child, projections) => + if (projections.nonEmpty) { + Project(child.output ++ projections, child) + } else { + child + } + } + + child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) { + case ref: CommonExpressionRef => refToExpr(ref.id) + } + } + + newPlan = newPlan.withNewChildren(newChildren) + if (p.output == newPlan.output) { + newPlan + } else { + Project(p.output, newPlan) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 8b714d5a5d280..9b3337d1a9406 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -36,6 +36,7 @@ object TreePattern extends Enumeration { val CASE_WHEN: Value = Value val CAST: Value = Value val COALESCE: Value = Value + val COMMON_EXPR_REF: Value = Value val CONCAT: Value = Value val COUNT: Value = Value val CREATE_NAMED_STRUCT: Value = Value @@ -132,6 +133,7 @@ object TreePattern extends Enumeration { val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WINDOW_GROUP_LIMIT: Value = Value + val WITH_EXPRESSION: Value = Value val WITH_WINDOW_DEFINITION: Value = Value // Unresolved expression patterns (Alphabetically ordered) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala new file mode 100644 index 0000000000000..c625379eb5ffd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.IntegerType + +class RewriteWithExpressionSuite extends PlanTest { + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil + } + + private val testRelation = LocalRelation($"a".int, $"b".int) + private val testRelation2 = LocalRelation($"x".int, $"y".int) + + test("simple common expression") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a) + val ref = new CommonExpressionRef(commonExprDef) + val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col")) + comparePlans(Optimizer.execute(plan), testRelation.select((a + a).as("col"))) + } + + test("non-cheap common expression") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val plan = testRelation.select(With(ref * ref, Seq(commonExprDef)).as("col")) + val commonExprName = "_common_expr_0" + comparePlans( + Optimizer.execute(plan), + testRelation + .select((testRelation.output :+ (a + a).as(commonExprName)): _*) + .select(($"$commonExprName" * $"$commonExprName").as("col")) + .analyze + ) + } + + test("nested WITH expression") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val innerExpr = With(ref + ref, Seq(commonExprDef)) + val innerCommonExprName = "_common_expr_0" + + val b = testRelation.output.last + val outerCommonExprDef = CommonExpressionDef(innerExpr + b) + val outerRef = new CommonExpressionRef(outerCommonExprDef) + val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef)) + val outerCommonExprName = "_common_expr_0" + + val plan = testRelation.select(outerExpr.as("col")) + val rewrittenOuterExpr = ($"$innerCommonExprName" + $"$innerCommonExprName" + b) + .as(outerCommonExprName) + val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)( + exprId = rewrittenOuterExpr.exprId) + comparePlans( + Optimizer.execute(plan), + testRelation + .select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*) + .select((testRelation.output :+ $"$innerCommonExprName" :+ rewrittenOuterExpr): _*) + .select((outerExprAttr * outerExprAttr).as("col")) + .analyze + ) + } + + test("WITH expression in filter") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val plan = testRelation.where(With(ref < 10 && ref > 0, Seq(commonExprDef))) + val commonExprName = "_common_expr_0" + comparePlans( + Optimizer.execute(plan), + testRelation + .select((testRelation.output :+ (a + a).as(commonExprName)): _*) + .where($"$commonExprName" < 10 && $"$commonExprName" > 0) + .select(testRelation.output: _*) + .analyze + ) + } + + test("WITH expression in join condition: only reference left child") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val 
condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) +    val plan = testRelation.join(testRelation2, condition = Some(condition)) +    val commonExprName = "_common_expr_0" +    comparePlans( +      Optimizer.execute(plan), +      testRelation +        .select((testRelation.output :+ (a + a).as(commonExprName)): _*) +        .join(testRelation2, condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0)) +        .select((testRelation.output ++ testRelation2.output): _*) +        .analyze +    ) +  } + +  test("WITH expression in join condition: only reference right child") { +    val x = testRelation2.output.head +    val commonExprDef = CommonExpressionDef(x + x) +    val ref = new CommonExpressionRef(commonExprDef) +    val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) +    val plan = testRelation.join(testRelation2, condition = Some(condition)) +    val commonExprName = "_common_expr_0" +    comparePlans( +      Optimizer.execute(plan), +      testRelation +        .join( +          testRelation2.select((testRelation2.output :+ (x + x).as(commonExprName)): _*), +          condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0) +        ) +        .select((testRelation.output ++ testRelation2.output): _*) +        .analyze +    ) +  } + +  test("WITH expression in join condition: reference both children") { +    val a = testRelation.output.head +    val x = testRelation2.output.head +    val commonExprDef = CommonExpressionDef(a + x) +    val ref = new CommonExpressionRef(commonExprDef) +    val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) +    val plan = testRelation.join(testRelation2, condition = Some(condition)) +    comparePlans( +      Optimizer.execute(plan), +      testRelation +        .join( +          testRelation2, +          // Can't pre-evaluate, have to inline +          condition = Some((a + x) < 10 && (a + x) > 0) +        ) +    ) +  } +} From 5ef3a846f52ab90cb7183953cff3080449d0b57b Mon Sep 17 00:00:00 2001 From: Kazuyuki Tanimura Date: Tue, 7 Nov 2023 09:06:00 -0800 Subject: [PATCH 043/121] [SPARK-45786][SQL] Fix inaccurate Decimal multiplication and division results ### What changes were proposed in this pull request? This PR fixes inaccurate Decimal multiplication and division results. ### Why are the changes needed? Decimal multiplication and division results may be inaccurate due to rounding issues. #### Multiplication: ``` scala> sql("select -14120025096157587712113961295153.858047 * -0.4652").show(truncate=false) +----------------------------------------------------+ |(-14120025096157587712113961295153.858047 * -0.4652)| +----------------------------------------------------+ |6568635674732509803675414794505.574764 | +----------------------------------------------------+ ``` The correct answer is `6568635674732509803675414794505.574763`. Please note that the last digit is `3` instead of `4` as ``` scala> java.math.BigDecimal("-14120025096157587712113961295153.858047").multiply(java.math.BigDecimal("-0.4652")) val res21: java.math.BigDecimal = 6568635674732509803675414794505.5747634644 ``` Since the fractional part `.574763` is followed by `4644`, it should not be rounded up.
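The interplay of the two rounding steps can be sketched with plain `java.math.BigDecimal` (an illustrative sketch rather than code from this patch; it assumes the result scale of 6 shown above):

```scala
import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

// Exact product: 6568635674732509803675414794505.5747634644
val exact = new JBigDecimal("-14120025096157587712113961295153.858047")
  .multiply(new JBigDecimal("-0.4652"))

// Old behavior: round to 38 significant digits with HALF_UP, then round again to the
// result scale. The intermediate ...505.5747635 rounds up to the wrong ...505.574764.
val oldStyle = exact
  .round(new MathContext(38, RoundingMode.HALF_UP))
  .setScale(6, RoundingMode.HALF_UP)

// New behavior: keep one extra digit (39) and round DOWN, so the final rounding still
// sees the true next digit and yields the correct ...505.574763.
val newStyle = exact
  .round(new MathContext(39, RoundingMode.DOWN))
  .setScale(6, RoundingMode.HALF_UP)
```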
#### Division: ``` scala> sql("select -0.172787979 / 533704665545018957788294905796.5").show(truncate=false) +-------------------------------------------------+ |(-0.172787979 / 533704665545018957788294905796.5)| +-------------------------------------------------+ |-3.237521E-31 | +-------------------------------------------------+ ``` The correct answer is `-3.237520E-31`. Please note that the last digit is `0` instead of `1` as ``` scala> java.math.BigDecimal("-0.172787979").divide(java.math.BigDecimal("533704665545018957788294905796.5"), 100, java.math.RoundingMode.DOWN) val res22: java.math.BigDecimal = -3.237520489418037889998826491401059986665344697406144511563561222578738E-31 ``` Since the fractional part `.237520` is followed by `4894...`, it should not be rounded up. ### Does this PR introduce _any_ user-facing change? Yes, users will see correct Decimal multiplication and division results. Directly multiplying and dividing with `org.apache.spark.sql.types.Decimal()` (not via SQL) will return at most 39 digits instead of 38, and will round down instead of rounding half-up. ### How was this patch tested? Test added ### Was this patch authored or co-authored using generative AI tooling? No Closes #43678 from kazuyukitanimura/SPARK-45786. Authored-by: Kazuyuki Tanimura Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/types/Decimal.scala | 8 +- .../ArithmeticExpressionSuite.scala | 107 ++++++++++++++++++ .../ansi/decimalArithmeticOperations.sql.out | 14 +-- 3 files changed, 120 insertions(+), 9 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 5652e5adda9d4..0bcbefaa54828 100--- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -499,7 +499,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def / (that: Decimal): Decimal = if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, - DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode)) + DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode)) def % (that: Decimal): Decimal = if (that.isZero) null @@ -547,7 +547,11 @@ object Decimal { val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP) + // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results + // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer + // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down + // the last extra digit.
+ private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN) private[sql] val ZERO = Decimal(0) private[sql] val ONE = Decimal(1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e21793ab506c4..568dcd10d1166 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.RoundingMode import java.sql.{Date, Timestamp} import java.time.{Duration, Period} import java.time.temporal.ChronoUnit @@ -225,6 +226,112 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } + test("SPARK-45786: Decimal multiply, divide, remainder, quot") { + // Some known cases + checkEvaluation( + Multiply( + Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)), + Literal(Decimal(BigDecimal("-0.4652"), 4, 4)) + ), + Decimal(BigDecimal("6568635674732509803675414794505.574763")) + ) + checkEvaluation( + Multiply( + Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)), + Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25)) + ), + Decimal(BigDecimal("1367249507675382200.164877854336665327")) + ) + checkEvaluation( + Divide( + Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)), + Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1)) + ), + Decimal(BigDecimal("-3.237520E-31")) + ) + checkEvaluation( + Divide( + Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)), + Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3)) + ), + Decimal(BigDecimal("7.21642358550E-25")) + ) + + // Random tests + val rand = scala.util.Random + def makeNum(p: Int, s: Int): String = { + val int1 = rand.nextLong() + val int2 = rand.nextLong().abs + val frac1 = rand.nextLong().abs + val frac2 = rand.nextLong().abs + s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." 
+ s"$frac1$frac2".take(s) + } + + (0 until 100).foreach { _ => + val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38 + val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1 + val p2 = rand.nextInt(38) + 1 + val s2 = rand.nextInt(p2 + 1) + + val n1 = makeNum(p1, s1) + val n2 = makeNum(p2, s2) + + val mulActual = Multiply( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2)) + + val divActual = Divide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val divExact = new java.math.BigDecimal(n1) + .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN) + + val remActual = Remainder( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2)) + + val quotActual = IntegralDivide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val quotExact = + new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2)) + + Seq(true, false).foreach { allowPrecLoss => + withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) { + val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2) + val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP)) + val mulExpected = + if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult + checkEvaluation(mulActual, mulExpected) + + val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2) + val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP)) + val divExpected = + if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult + checkEvaluation(divActual, divExpected) + + val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2) + val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP)) + val remExpected = + if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult + checkEvaluation(remActual, remExpected) + + val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2) + val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP)) + val quotExpected = + if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult + checkEvaluation(quotActual, quotExpected.toLong) + } + } + } + } + private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = { testFunc(_.toDouble) testFunc(Decimal(_)) diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out index 699c916fd8fdb..9593291fae21d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out @@ -155,7 +155,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "1000000000000000000000000000000000000.00000000000000000000000000000000000000" + "value" : "1000000000000000000000000000000000000.000000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -204,7 +204,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : 
"10123456789012345678901234567890123456.00000000000000000000000000000000000000" + "value" : "10123456789012345678901234567890123456.000000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -229,7 +229,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901234.56000000000000000000000000000000000000" + "value" : "101234567890123456789012345678901234.560000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -254,7 +254,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "10123456789012345678901234567890123.45600000000000000000000000000000000000" + "value" : "10123456789012345678901234567890123.456000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -279,7 +279,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "1012345678901234567890123456789012.34560000000000000000000000000000000000" + "value" : "1012345678901234567890123456789012.345600000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -304,7 +304,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000" + "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -337,7 +337,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000" + "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", From 50d137d752e5be61bcd7c28754eda7b300eb6e27 Mon Sep 17 00:00:00 2001 From: Sandip Agarwala <131817656+sandip-db@users.noreply.github.com> Date: Tue, 7 Nov 2023 09:33:02 -0800 Subject: [PATCH 044/121] [SPARK-44751][SQL] XML: Refactor TypeCast and timestampFormatter ### What changes were proposed in this pull request? - Move initialization of TimestampFormatter from XmlOptions to StaxXmlParser, StaxXmlGenerator and XmlInferSchema. - Move functions from typecast.scala to StaxXmlParser or XmlInferSchema - Convert XmlInferSchema to a class ### Why are the changes needed? Some of the timestampformatter fields were not correctly initialized when accessed in StaxXmlParser in the executor. This was resulting in some timestamp parsing failures. Moving the initialization of timestampformatter to StaxXmlParser fixed the issue. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added new unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #43697 from sandip-db/xml-typecast. 
Authored-by: Sandip Agarwala <131817656+sandip-db@users.noreply.github.com> Signed-off-by: Hyukjin Kwon --- .../catalyst/expressions/xmlExpressions.scala | 11 +- .../sql/catalyst/xml/StaxXmlGenerator.scala | 29 ++- .../sql/catalyst/xml/StaxXmlParser.scala | 224 +++++++++++++--- .../spark/sql/catalyst/xml/TypeCast.scala | 244 ------------------ .../sql/catalyst/xml/XmlInferSchema.scala | 158 +++++++++--- .../spark/sql/catalyst/xml/XmlOptions.scala | 36 +-- .../apache/spark/sql/DataFrameReader.scala | 5 +- .../datasources/xml/XmlDataSource.scala | 6 +- .../sql/streaming/DataStreamReader.scala | 6 +- .../execution/datasources/xml/XmlSuite.scala | 182 ++++++++++++- .../datasources/xml/util/TypeCastSuite.scala | 234 ----------------- 11 files changed, 546 insertions(+), 589 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/TypeCast.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/TypeCastSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 047b669fc8960..c581643460f65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -188,6 +188,9 @@ case class SchemaOfXml( @transient private lazy val xmlFactory = xmlOptions.buildXmlFactory() + @transient + private lazy val xmlInferSchema = new XmlInferSchema(xmlOptions) + @transient private lazy val xml = child.eval().asInstanceOf[UTF8String] @@ -209,16 +212,16 @@ case class SchemaOfXml( } override def eval(v: InternalRow): Any = { - val dataType = XmlInferSchema.infer(xml.toString, xmlOptions).get match { + val dataType = xmlInferSchema.infer(xml.toString).get match { case st: StructType => - XmlInferSchema.canonicalizeType(st).getOrElse(StructType(Nil)) + xmlInferSchema.canonicalizeType(st).getOrElse(StructType(Nil)) case at: ArrayType if at.elementType.isInstanceOf[StructType] => - XmlInferSchema + xmlInferSchema .canonicalizeType(at.elementType) .map(ArrayType(_, containsNull = at.containsNull)) .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) case other: DataType => - XmlInferSchema.canonicalizeType(other).getOrElse(StringType) + xmlInferSchema.canonicalizeType(other).getOrElse(StringType) } UTF8String.fromString(dataType.sql) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala index 4477cf50823cb..ae3a64d865cf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala @@ -26,7 +26,8 @@ import com.sun.xml.txw2.output.IndentingXMLStreamWriter import org.apache.hadoop.shaded.com.ctc.wstx.api.WstxOutputProperties import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, MapData, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -40,6 +41,26 @@ class StaxXmlGenerator( "'attributePrefix' option should not be empty 
string.") private val indentDisabled = options.indent == "" + private val timestampFormatter = TimestampFormatter( + options.timestampFormatInWrite, + options.zoneId, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false) + + private val timestampNTZFormatter = TimestampFormatter( + options.timestampNTZFormatInWrite, + options.zoneId, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false, + forTimestampNTZ = true) + + private val dateFormatter = DateFormatter( + options.dateFormatInWrite, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false) + private val gen = { val factory = XMLOutputFactory.newInstance() // to_xml disables structure validation to allow multiple root tags @@ -149,11 +170,11 @@ class StaxXmlGenerator( case (StringType, v: UTF8String) => gen.writeCharacters(v.toString) case (StringType, v: String) => gen.writeCharacters(v) case (TimestampType, v: Timestamp) => - gen.writeCharacters(options.timestampFormatterInWrite.format(v.toInstant())) + gen.writeCharacters(timestampFormatter.format(v.toInstant())) case (TimestampType, v: Long) => - gen.writeCharacters(options.timestampFormatterInWrite.format(v)) + gen.writeCharacters(timestampFormatter.format(v)) case (DateType, v: Int) => - gen.writeCharacters(options.dateFormatterInWrite.format(v)) + gen.writeCharacters(dateFormatter.format(v)) case (IntegerType, v: Int) => gen.writeCharacters(v.toString) case (ShortType, v: Short) => gen.writeCharacters(v.toString) case (FloatType, v: Float) => gen.writeCharacters(v.toString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index dcb760aca9d2a..77a0bd1dff179 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.xml import java.io.{CharConversionException, InputStream, InputStreamReader, StringReader} import java.nio.charset.{Charset, MalformedInputException} +import java.text.NumberFormat +import java.util.Locale import javax.xml.stream.{XMLEventReader, XMLStreamException} import javax.xml.stream.events._ import javax.xml.transform.stream.StreamSource @@ -25,15 +27,17 @@ import javax.xml.validation.Schema import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import scala.util.Try import scala.util.control.NonFatal import scala.xml.SAXException import org.apache.spark.SparkUpgradeException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode} +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream -import org.apache.spark.sql.catalyst.xml.TypeCast._ import org.apache.spark.sql.errors.QueryExecutionErrors import 
org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -46,6 +50,29 @@ class StaxXmlParser( private val factory = options.buildXmlFactory() + private lazy val timestampFormatter = TimestampFormatter( + options.timestampFormatInRead, + options.zoneId, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) + + private lazy val timestampNTZFormatter = TimestampFormatter( + options.timestampNTZFormatInRead, + options.zoneId, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true, + forTimestampNTZ = true) + + private lazy val dateFormatter = DateFormatter( + options.dateFormatInRead, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) + + private val decimalParser = ExprUtils.getDecimalParser(options.locale) + + /** * Parses a single XML string and turns it into either one resulting row or no row (if the * the record is malformed). @@ -108,7 +135,7 @@ class StaxXmlParser( val isRootAttributesOnly = schema.fields.forall { f => f.name == options.valueTag || f.name.startsWith(options.attributePrefix) } - Some(convertObject(parser, schema, options, rootAttributes, isRootAttributesOnly)) + Some(convertObject(parser, schema, rootAttributes, isRootAttributesOnly)) } catch { case e: SparkUpgradeException => throw e case e@(_: RuntimeException | _: XMLStreamException | _: MalformedInputException @@ -145,15 +172,14 @@ class StaxXmlParser( private[xml] def convertField( parser: XMLEventReader, dataType: DataType, - options: XmlOptions, attributes: Array[Attribute] = Array.empty): Any = { def convertComplicatedType(dt: DataType, attributes: Array[Attribute]): Any = dt match { - case st: StructType => convertObject(parser, st, options) - case MapType(StringType, vt, _) => convertMap(parser, vt, options, attributes) - case ArrayType(st, _) => convertField(parser, st, options) + case st: StructType => convertObject(parser, st) + case MapType(StringType, vt, _) => convertMap(parser, vt, attributes) + case ArrayType(st, _) => convertField(parser, st) case _: StringType => - convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType, options) + convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType) } (parser.peek, dataType) match { @@ -168,7 +194,7 @@ class StaxXmlParser( case (_: EndElement, _: DataType) => null case (c: Characters, ArrayType(st, _)) => // For `ArrayType`, it needs to return the type of element. The values are merged later. - convertTo(c.getData, st, options) + convertTo(c.getData, st) case (c: Characters, st: StructType) => // If a value tag is present, this can be an attribute-only element whose values is in that // value tag field. Or, it can be a mixed-type element with both some character elements @@ -180,18 +206,18 @@ class StaxXmlParser( // If everything else is an attribute column, there's no complex structure. 
// Just return the value of the character element, or null if we don't have a value tag st.find(_.name == options.valueTag).map( - valueTag => convertTo(c.getData, valueTag.dataType, options)).orNull + valueTag => convertTo(c.getData, valueTag.dataType)).orNull } else { // Otherwise, ignore this character element, and continue parsing the following complex // structure parser.next parser.peek match { case _: EndElement => null // no struct here at all; done - case _ => convertObject(parser, st, options) + case _ => convertObject(parser, st) } } case (_: Characters, _: StringType) => - convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType, options) + convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType) case (c: Characters, _: DataType) if c.isWhiteSpace => // When `Characters` is found, we need to look further to decide // if this is really data or space between other elements. @@ -201,11 +227,11 @@ class StaxXmlParser( case _: StartElement => convertComplicatedType(dataType, attributes) case _: EndElement if data.isEmpty => null case _: EndElement if options.treatEmptyValuesAsNulls => null - case _: EndElement => convertTo(data, dataType, options) - case _ => convertField(parser, dataType, options, attributes) + case _: EndElement => convertTo(data, dataType) + case _ => convertField(parser, dataType, attributes) } case (c: Characters, dt: DataType) => - convertTo(c.getData, dt, options) + convertTo(c.getData, dt) case (e: XMLEvent, dt: DataType) => throw new IllegalArgumentException( s"Failed to parse a value for data type $dt with event ${e.toString}") @@ -218,12 +244,11 @@ class StaxXmlParser( private def convertMap( parser: XMLEventReader, valueType: DataType, - options: XmlOptions, attributes: Array[Attribute]): MapData = { val kvPairs = ArrayBuffer.empty[(UTF8String, Any)] attributes.foreach { attr => kvPairs += (UTF8String.fromString(options.attributePrefix + attr.getName.getLocalPart) - -> convertTo(attr.getValue, valueType, options)) + -> convertTo(attr.getValue, valueType)) } var shouldStop = false while (!shouldStop) { @@ -231,7 +256,7 @@ class StaxXmlParser( case e: StartElement => kvPairs += (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) -> - convertField(parser, valueType, options)) + convertField(parser, valueType)) case _: EndElement => shouldStop = StaxXmlParserUtils.checkEndElement(parser) case _ => // do nothing @@ -245,14 +270,13 @@ class StaxXmlParser( */ private def convertAttributes( attributes: Array[Attribute], - schema: StructType, - options: XmlOptions): Map[String, Any] = { + schema: StructType): Map[String, Any] = { val convertedValuesMap = collection.mutable.Map.empty[String, Any] val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options) valuesMap.foreach { case (f, v) => val nameToIndex = schema.map(_.name).zipWithIndex.toMap nameToIndex.get(f).foreach { i => - convertedValuesMap(f) = convertTo(v, schema(i).dataType, options) + convertedValuesMap(f) = convertTo(v, schema(i).dataType) } } convertedValuesMap.toMap @@ -266,16 +290,15 @@ class StaxXmlParser( private def convertObjectWithAttributes( parser: XMLEventReader, schema: StructType, - options: XmlOptions, attributes: Array[Attribute] = Array.empty): InternalRow = { // TODO: This method might have to be removed. Some logics duplicate `convertObject()` val row = new Array[Any](schema.length) // Read attributes first. 
- val attributesMap = convertAttributes(attributes, schema, options) + val attributesMap = convertAttributes(attributes, schema) // Then, we read elements here. - val fieldsMap = convertField(parser, schema, options) match { + val fieldsMap = convertField(parser, schema) match { case internalRow: InternalRow => Map(schema.map(_.name).zip(internalRow.toSeq(schema)): _*) case v if schema.fieldNames.contains(options.valueTag) => @@ -309,13 +332,12 @@ class StaxXmlParser( private def convertObject( parser: XMLEventReader, schema: StructType, - options: XmlOptions, rootAttributes: Array[Attribute] = Array.empty, isRootAttributesOnly: Boolean = false): InternalRow = { val row = new Array[Any](schema.length) val nameToIndex = schema.map(_.name).zipWithIndex.toMap // If there are attributes, then we process them first. - convertAttributes(rootAttributes, schema, options).toSeq.foreach { case (f, v) => + convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) => nameToIndex.get(f).foreach { row(_) = v } } @@ -334,7 +356,7 @@ class StaxXmlParser( nameToIndex.get(field) match { case Some(index) => schema(index).dataType match { case st: StructType => - row(index) = convertObjectWithAttributes(parser, st, options, attributes) + row(index) = convertObjectWithAttributes(parser, st, attributes) case ArrayType(dt: DataType, _) => val values = Option(row(index)) @@ -342,21 +364,21 @@ class StaxXmlParser( .getOrElse(ArrayBuffer.empty[Any]) val newValue = dt match { case st: StructType => - convertObjectWithAttributes(parser, st, options, attributes) + convertObjectWithAttributes(parser, st, attributes) case dt: DataType => - convertField(parser, dt, options) + convertField(parser, dt) } row(index) = values :+ newValue case dt: DataType => - row(index) = convertField(parser, dt, options, attributes) + row(index) = convertField(parser, dt, attributes) } case None => if (hasWildcard) { // Special case: there's an 'any' wildcard element that matches anything else // as a string (or array of strings, to parse multiple ones) - val newValue = convertField(parser, StringType, options) + val newValue = convertField(parser, StringType) val anyIndex = schema.fieldIndex(wildcardColName) schema(wildcardColName).dataType match { case StringType => @@ -380,7 +402,7 @@ class StaxXmlParser( case c: Characters if !c.isWhiteSpace && isRootAttributesOnly => nameToIndex.get(options.valueTag) match { case Some(index) => - row(index) = convertTo(c.getData, schema(index).dataType, options) + row(index) = convertTo(c.getData, schema(index).dataType) case None => // do nothing } @@ -410,6 +432,144 @@ class StaxXmlParser( badRecordException.get) } } + + /** + * Casts given string datum to specified type. + * + * For string types, this is simply the datum. + * For other nullable types, returns null if it is null or equals to the value specified + * in `nullValue` option. 
+ * + * @param datum string value + * @param castType SparkSQL type + */ + private def castTo( + datum: String, + castType: DataType): Any = { + if ((datum == options.nullValue) || + (options.treatEmptyValuesAsNulls && datum == "")) { + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + case _: DoubleType => Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + case _: BooleanType => parseXmlBoolean(datum) + case dt: DecimalType => + Decimal(decimalParser(datum), dt.precision, dt.scale) + case _: TimestampType => parseXmlTimestamp(datum, options) + case _: DateType => parseXmlDate(datum, options) + case _: StringType => UTF8String.fromString(datum) + case _ => throw new IllegalArgumentException(s"Unsupported type: ${castType.typeName}") + } + } + } + + private def parseXmlBoolean(s: String): Boolean = { + s.toLowerCase(Locale.ROOT) match { + case "true" | "1" => true + case "false" | "0" => false + case _ => throw new IllegalArgumentException(s"For input string: $s") + } + } + + private def parseXmlDate(value: String, options: XmlOptions): Int = { + dateFormatter.parse(value) + } + + private def parseXmlTimestamp(value: String, options: XmlOptions): Long = { + timestampFormatter.parse(value) + } + + // TODO: This function unnecessarily does type dispatch. Should merge it with `castTo`. + private def convertTo( + datum: String, + dataType: DataType): Any = { + val value = if (datum != null && options.ignoreSurroundingSpaces) { + datum.trim() + } else { + datum + } + if ((value == options.nullValue) || + (options.treatEmptyValuesAsNulls && value == "")) { + null + } else { + dataType match { + case NullType => castTo(value, StringType) + case LongType => signSafeToLong(value) + case DoubleType => signSafeToDouble(value) + case BooleanType => castTo(value, BooleanType) + case StringType => castTo(value, StringType) + case DateType => castTo(value, DateType) + case TimestampType => castTo(value, TimestampType) + case FloatType => signSafeToFloat(value) + case ByteType => castTo(value, ByteType) + case ShortType => castTo(value, ShortType) + case IntegerType => signSafeToInt(value) + case dt: DecimalType => castTo(value, dt) + case _ => throw new IllegalArgumentException( + s"Failed to parse a value for data type $dataType.") + } + } + } + + + private def signSafeToLong(value: String): Long = { + if (value.startsWith("+")) { + val data = value.substring(1) + castTo(data, LongType).asInstanceOf[Long] + } else if (value.startsWith("-")) { + val data = value.substring(1) + -castTo(data, LongType).asInstanceOf[Long] + } else { + val data = value + castTo(data, LongType).asInstanceOf[Long] + } + } + + private def signSafeToDouble(value: String): Double = { + if (value.startsWith("+")) { + val data = value.substring(1) + castTo(data, DoubleType).asInstanceOf[Double] + } else if (value.startsWith("-")) { + val data = value.substring(1) + -castTo(data, DoubleType).asInstanceOf[Double] + } else { + val data = value + castTo(data, DoubleType).asInstanceOf[Double] + } + } + + private def signSafeToInt(value: String): Int = { + if (value.startsWith("+")) { + val data = value.substring(1) + castTo(data, IntegerType).asInstanceOf[Int] + } else if (value.startsWith("-")) { + val data = value.substring(1) + 
-castTo(data, IntegerType).asInstanceOf[Int] + } else { + val data = value + castTo(data, IntegerType).asInstanceOf[Int] + } + } + + private def signSafeToFloat(value: String): Float = { + if (value.startsWith("+")) { + val data = value.substring(1) + castTo(data, FloatType).asInstanceOf[Float] + } else if (value.startsWith("-")) { + val data = value.substring(1) + -castTo(data, FloatType).asInstanceOf[Float] + } else { + val data = value + castTo(data, FloatType).asInstanceOf[Float] + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/TypeCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/TypeCast.scala deleted file mode 100644 index 3315196ffc765..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/TypeCast.scala +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.xml - -import java.math.BigDecimal -import java.text.NumberFormat -import java.util.Locale - -import scala.util.Try -import scala.util.control.Exception._ - -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * Utility functions for type casting - */ -private[sql] object TypeCast { - - /** - * Casts given string datum to specified type. - * Currently we do not support complex types (ArrayType, MapType, StructType). - * - * For string types, this is simply the datum. For other types. - * For other nullable types, this is null if the string datum is empty. 
- * - * @param datum string value - * @param castType SparkSQL type - */ - private[sql] def castTo( - datum: String, - castType: DataType, - options: XmlOptions): Any = { - if ((datum == options.nullValue) || - (options.treatEmptyValuesAsNulls && datum == "")) { - null - } else { - castType match { - case _: ByteType => datum.toByte - case _: ShortType => datum.toShort - case _: IntegerType => datum.toInt - case _: LongType => datum.toLong - case _: FloatType => Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) - case _: DoubleType => Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) - case _: BooleanType => parseXmlBoolean(datum) - case dt: DecimalType => - Decimal(new BigDecimal(datum.replaceAll(",", "")), dt.precision, dt.scale) - case _: TimestampType => parseXmlTimestamp(datum, options) - case _: DateType => parseXmlDate(datum, options) - case _: StringType => UTF8String.fromString(datum) - case _ => throw new IllegalArgumentException(s"Unsupported type: ${castType.typeName}") - } - } - } - - private def parseXmlBoolean(s: String): Boolean = { - s.toLowerCase(Locale.ROOT) match { - case "true" | "1" => true - case "false" | "0" => false - case _ => throw new IllegalArgumentException(s"For input string: $s") - } - } - - private def parseXmlDate(value: String, options: XmlOptions): Int = { - options.dateFormatter.parse(value) - } - - private def parseXmlTimestamp(value: String, options: XmlOptions): Long = { - options.timestampFormatter.parse(value) - } - - // TODO: This function unnecessarily does type dispatch. Should merge it with `castTo`. - private[sql] def convertTo( - datum: String, - dataType: DataType, - options: XmlOptions): Any = { - val value = if (datum != null && options.ignoreSurroundingSpaces) { - datum.trim() - } else { - datum - } - if ((value == options.nullValue) || - (options.treatEmptyValuesAsNulls && value == "")) { - null - } else { - dataType match { - case NullType => castTo(value, StringType, options) - case LongType => signSafeToLong(value, options) - case DoubleType => signSafeToDouble(value, options) - case BooleanType => castTo(value, BooleanType, options) - case StringType => castTo(value, StringType, options) - case DateType => castTo(value, DateType, options) - case TimestampType => castTo(value, TimestampType, options) - case FloatType => signSafeToFloat(value, options) - case ByteType => castTo(value, ByteType, options) - case ShortType => castTo(value, ShortType, options) - case IntegerType => signSafeToInt(value, options) - case dt: DecimalType => castTo(value, dt, options) - case _ => throw new IllegalArgumentException( - s"Failed to parse a value for data type $dataType.") - } - } - } - - /** - * Helper method that checks and cast string representation of a numeric types. - */ - private[sql] def isBoolean(value: String): Boolean = { - value.toLowerCase(Locale.ROOT) match { - case "true" | "false" => true - case _ => false - } - } - - private[sql] def isDouble(value: String): Boolean = { - val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) { - value.substring(1) - } else { - value - } - // A little shortcut to avoid trying many formatters in the common case that - // the input isn't a double. All built-in formats will start with a digit or period. 
- if (signSafeValue.isEmpty || - !(Character.isDigit(signSafeValue.head) || signSafeValue.head == '.')) { - return false - } - // Rule out strings ending in D or F, as they will parse as double but should be disallowed - if (value.nonEmpty && (value.last match { - case 'd' | 'D' | 'f' | 'F' => true - case _ => false - })) { - return false - } - (allCatch opt signSafeValue.toDouble).isDefined - } - - private[sql] def isInteger(value: String): Boolean = { - val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) { - value.substring(1) - } else { - value - } - // A little shortcut to avoid trying many formatters in the common case that - // the input isn't a number. All built-in formats will start with a digit. - if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) { - return false - } - (allCatch opt signSafeValue.toInt).isDefined - } - - private[sql] def isLong(value: String): Boolean = { - val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) { - value.substring(1) - } else { - value - } - // A little shortcut to avoid trying many formatters in the common case that - // the input isn't a number. All built-in formats will start with a digit. - if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) { - return false - } - (allCatch opt signSafeValue.toLong).isDefined - } - - private[sql] def isTimestamp(value: String, options: XmlOptions): Boolean = { - try { - options.timestampFormatter.parseOptional(value).isDefined - } catch { - case _: IllegalArgumentException => false - } - } - - private[sql] def isDate(value: String, options: XmlOptions): Boolean = { - (allCatch opt options.dateFormatter.parse(value)).isDefined - } - - private[sql] def signSafeToLong(value: String, options: XmlOptions): Long = { - if (value.startsWith("+")) { - val data = value.substring(1) - TypeCast.castTo(data, LongType, options).asInstanceOf[Long] - } else if (value.startsWith("-")) { - val data = value.substring(1) - -TypeCast.castTo(data, LongType, options).asInstanceOf[Long] - } else { - val data = value - TypeCast.castTo(data, LongType, options).asInstanceOf[Long] - } - } - - private[sql] def signSafeToDouble(value: String, options: XmlOptions): Double = { - if (value.startsWith("+")) { - val data = value.substring(1) - TypeCast.castTo(data, DoubleType, options).asInstanceOf[Double] - } else if (value.startsWith("-")) { - val data = value.substring(1) - -TypeCast.castTo(data, DoubleType, options).asInstanceOf[Double] - } else { - val data = value - TypeCast.castTo(data, DoubleType, options).asInstanceOf[Double] - } - } - - private[sql] def signSafeToInt(value: String, options: XmlOptions): Int = { - if (value.startsWith("+")) { - val data = value.substring(1) - TypeCast.castTo(data, IntegerType, options).asInstanceOf[Int] - } else if (value.startsWith("-")) { - val data = value.substring(1) - -TypeCast.castTo(data, IntegerType, options).asInstanceOf[Int] - } else { - val data = value - TypeCast.castTo(data, IntegerType, options).asInstanceOf[Int] - } - } - - private[sql] def signSafeToFloat(value: String, options: XmlOptions): Float = { - if (value.startsWith("+")) { - val data = value.substring(1) - TypeCast.castTo(data, FloatType, options).asInstanceOf[Float] - } else if (value.startsWith("-")) { - val data = value.substring(1) - -TypeCast.castTo(data, FloatType, options).asInstanceOf[Float] - } else { - val data = value - TypeCast.castTo(data, FloatType, options).asInstanceOf[Float] - } - } -} diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index 8bddb8f5bd99b..777dd69fd7fa0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.xml import java.io.StringReader +import java.util.Locale import javax.xml.stream.XMLEventReader import javax.xml.stream.events._ import javax.xml.transform.stream.StreamSource @@ -25,14 +26,39 @@ import javax.xml.validation.Schema import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import scala.util.control.Exception._ import scala.util.control.NonFatal +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.util.PermissiveMode -import org.apache.spark.sql.catalyst.xml.TypeCast._ +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.types._ -private[sql] object XmlInferSchema { +private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with Logging { + + private val decimalParser = ExprUtils.getDecimalParser(options.locale) + + private val timestampFormatter = TimestampFormatter( + options.timestampFormatInRead, + options.zoneId, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) + + private val timestampNTZFormatter = TimestampFormatter( + options.timestampNTZFormatInRead, + options.zoneId, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true, + forTimestampNTZ = true) + + private lazy val dateFormatter = DateFormatter( + options.dateFormatInRead, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) /** * Copied from internal Spark api @@ -66,7 +92,7 @@ private[sql] object XmlInferSchema { * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. 
Replace any remaining null fields with string, the top type */ - def infer(xml: RDD[String], options: XmlOptions): StructType = { + def infer(xml: RDD[String]): StructType = { val schemaData = if (options.samplingRatio < 1.0) { xml.sample(withReplacement = false, options.samplingRatio, 1) } else { @@ -77,9 +103,9 @@ private[sql] object XmlInferSchema { val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema) iter.flatMap { xml => - infer(xml, options, xsdSchema) + infer(xml, xsdSchema) } - }.fold(StructType(Seq()))(compatibleType(options)) + }.fold(StructType(Seq()))(compatibleType) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -90,7 +116,6 @@ private[sql] object XmlInferSchema { } def infer(xml: String, - options: XmlOptions, xsdSchema: Option[Schema] = None): Option[DataType] = { try { val xsd = xsdSchema.orElse(Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)) @@ -99,7 +124,7 @@ private[sql] object XmlInferSchema { } val parser = StaxXmlParserUtils.filteredReader(xml) val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser) - Some(inferObject(parser, options, rootAttributes)) + Some(inferObject(parser, rootAttributes)) } catch { case NonFatal(_) if options.parseMode == PermissiveMode => Some(StructType(Seq(StructField(options.columnNameOfCorruptRecord, StringType)))) @@ -108,7 +133,7 @@ private[sql] object XmlInferSchema { } } - private def inferFrom(datum: String, options: XmlOptions): DataType = { + private def inferFrom(datum: String): DataType = { val value = if (datum != null && options.ignoreSurroundingSpaces) { datum.trim() } else { @@ -123,8 +148,8 @@ private[sql] object XmlInferSchema { case v if isInteger(v) => IntegerType case v if isDouble(v) => DoubleType case v if isBoolean(v) => BooleanType - case v if isDate(v, options) => DateType - case v if isTimestamp(v, options) => TimestampType + case v if isDate(v) => DateType + case v if isTimestamp(v) => TimestampType case _ => StringType } } else { @@ -133,32 +158,32 @@ private[sql] object XmlInferSchema { } @tailrec - private def inferField(parser: XMLEventReader, options: XmlOptions): DataType = { + private def inferField(parser: XMLEventReader): DataType = { parser.peek match { case _: EndElement => NullType - case _: StartElement => inferObject(parser, options) + case _: StartElement => inferObject(parser) case c: Characters if c.isWhiteSpace => // When `Characters` is found, we need to look further to decide // if this is really data or space between other elements. val data = c.getData parser.nextEvent() parser.peek match { - case _: StartElement => inferObject(parser, options) + case _: StartElement => inferObject(parser) case _: EndElement if data.isEmpty => NullType case _: EndElement if options.treatEmptyValuesAsNulls => NullType case _: EndElement => StringType - case _ => inferField(parser, options) + case _ => inferField(parser) } case c: Characters if !c.isWhiteSpace => // This could be the characters of a character-only element, or could have mixed // characters and other complex structure - val characterType = inferFrom(c.getData, options) + val characterType = inferFrom(c.getData) parser.nextEvent() parser.peek match { case _: StartElement => // Some more elements follow; so ignore the characters. 
// Use the schema of the rest - inferObject(parser, options).asInstanceOf[StructType] + inferObject(parser).asInstanceOf[StructType] case _ => // That's all, just the character-only body; use that as the type characterType @@ -173,7 +198,6 @@ private[sql] object XmlInferSchema { */ private def inferObject( parser: XMLEventReader, - options: XmlOptions, rootAttributes: Array[Attribute] = Array.empty): DataType = { val builder = ArrayBuffer[StructField]() val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]] @@ -182,7 +206,7 @@ private[sql] object XmlInferSchema { StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options) rootValuesMap.foreach { case (f, v) => - nameToDataType += (f -> ArrayBuffer(inferFrom(v, options))) + nameToDataType += (f -> ArrayBuffer(inferFrom(v))) } var shouldStop = false while (!shouldStop) { @@ -190,14 +214,14 @@ private[sql] object XmlInferSchema { case e: StartElement => val attributes = e.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options) - val inferredType = inferField(parser, options) match { + val inferredType = inferField(parser) match { case st: StructType if valuesMap.nonEmpty => // Merge attributes to the field val nestedBuilder = ArrayBuffer[StructField]() nestedBuilder ++= st.fields valuesMap.foreach { case (f, v) => - nestedBuilder += StructField(f, inferFrom(v, options), nullable = true) + nestedBuilder += StructField(f, inferFrom(v), nullable = true) } StructType(nestedBuilder.sortBy(_.name).toArray) @@ -207,7 +231,7 @@ private[sql] object XmlInferSchema { nestedBuilder += StructField(options.valueTag, dt, nullable = true) valuesMap.foreach { case (f, v) => - nestedBuilder += StructField(f, inferFrom(v, options), nullable = true) + nestedBuilder += StructField(f, inferFrom(v), nullable = true) } StructType(nestedBuilder.sortBy(_.name).toArray) @@ -221,7 +245,7 @@ private[sql] object XmlInferSchema { case c: Characters if !c.isWhiteSpace => // This can be an attribute-only object - val valueTagType = inferFrom(c.getData, options) + val valueTagType = inferFrom(c.getData) nameToDataType += options.valueTag -> ArrayBuffer(valueTagType) case _: EndElement => @@ -245,7 +269,7 @@ private[sql] object XmlInferSchema { // This can be inferred as ArrayType. nameToDataType.foreach { case (field, dataTypes) if dataTypes.length > 1 => - val elementType = dataTypes.reduceLeft(XmlInferSchema.compatibleType(options)) + val elementType = dataTypes.reduceLeft(compatibleType) builder += StructField(field, ArrayType(elementType), nullable = true) case (field, dataTypes) => builder += StructField(field, dataTypes.head, nullable = true) @@ -255,6 +279,78 @@ private[sql] object XmlInferSchema { StructType(builder.sortBy(_.name).toArray) } + /** + * Helper method that checks and cast string representation of a numeric types. + */ + private def isBoolean(value: String): Boolean = { + value.toLowerCase(Locale.ROOT) match { + case "true" | "false" => true + case _ => false + } + } + + private def isDouble(value: String): Boolean = { + val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) { + value.substring(1) + } else { + value + } + // A little shortcut to avoid trying many formatters in the common case that + // the input isn't a double. All built-in formats will start with a digit or period. 
+ if (signSafeValue.isEmpty || + !(Character.isDigit(signSafeValue.head) || signSafeValue.head == '.')) { + return false + } + // Rule out strings ending in D or F, as they will parse as double but should be disallowed + if (value.nonEmpty && (value.last match { + case 'd' | 'D' | 'f' | 'F' => true + case _ => false + })) { + return false + } + (allCatch opt signSafeValue.toDouble).isDefined + } + + private def isInteger(value: String): Boolean = { + val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) { + value.substring(1) + } else { + value + } + // A little shortcut to avoid trying many formatters in the common case that + // the input isn't a number. All built-in formats will start with a digit. + if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) { + return false + } + (allCatch opt signSafeValue.toInt).isDefined + } + + private def isLong(value: String): Boolean = { + val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) { + value.substring(1) + } else { + value + } + // A little shortcut to avoid trying many formatters in the common case that + // the input isn't a number. All built-in formats will start with a digit. + if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) { + return false + } + (allCatch opt signSafeValue.toLong).isDefined + } + + private def isTimestamp(value: String): Boolean = { + try { + timestampFormatter.parseOptional(value).isDefined + } catch { + case _: IllegalArgumentException => false + } + } + + private def isDate(value: String): Boolean = { + (allCatch opt dateFormatter.parse(value)).isDefined + } + /** * Convert NullType to StringType and remove StructTypes with no fields */ @@ -288,7 +384,7 @@ private[sql] object XmlInferSchema { /** * Returns the most general data type for two given data types. */ - private[xml] def compatibleType(options: XmlOptions)(t1: DataType, t2: DataType): DataType = { + def compatibleType(t1: DataType, t2: DataType): DataType = { // TODO: Optimise this logic. findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. @@ -312,22 +408,22 @@ private[sql] object XmlInferSchema { case (StructType(fields1), StructType(fields2)) => val newFields = (fields1 ++ fields2).groupBy(_.name).map { case (name, fieldTypes) => - val dataType = fieldTypes.map(_.dataType).reduce(compatibleType(options)) + val dataType = fieldTypes.map(_.dataType).reduce(compatibleType) StructField(name, dataType, nullable = true) } StructType(newFields.toArray.sortBy(_.name)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType( - compatibleType(options)(elementType1, elementType2), containsNull1 || containsNull2) + compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // In XML datasource, since StructType can be compared with ArrayType. // In this case, ArrayType wraps the StructType. 
case (ArrayType(ty1, _), ty2) => - ArrayType(compatibleType(options)(ty1, ty2)) + ArrayType(compatibleType(ty1, ty2)) case (ty1, ArrayType(ty2, _)) => - ArrayType(compatibleType(options)(ty1, ty2)) + ArrayType(compatibleType(ty1, ty2)) // As this library can infer an element with attributes as StructType whereas // some can be inferred as other non-structural data types, this case should be @@ -335,14 +431,14 @@ private[sql] object XmlInferSchema { case (st: StructType, dt: DataType) if st.fieldNames.contains(options.valueTag) => val valueIndex = st.fieldNames.indexOf(options.valueTag) val valueField = st.fields(valueIndex) - val valueDataType = compatibleType(options)(valueField.dataType, dt) + val valueDataType = compatibleType(valueField.dataType, dt) st.fields(valueIndex) = StructField(options.valueTag, valueDataType, nullable = true) st case (dt: DataType, st: StructType) if st.fieldNames.contains(options.valueTag) => val valueIndex = st.fieldNames.indexOf(options.valueTag) val valueField = st.fields(valueIndex) - val valueDataType = compatibleType(options)(dt, valueField.dataType) + val valueDataType = compatibleType(dt, valueField.dataType) st.fields(valueIndex) = StructField(options.valueTag, valueDataType, nullable = true) st diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala index 7d049fdd82b8a..aac6eec21c60a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala @@ -23,8 +23,7 @@ import javax.xml.stream.XMLInputFactory import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, DateFormatter, DateTimeUtils, ParseMode, PermissiveMode, TimestampFormatter} -import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, DateFormatter, DateTimeUtils, ParseMode, PermissiveMode} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} @@ -32,7 +31,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} * Options for the XML data source. 
*/ private[sql] class XmlOptions( - @transient val parameters: CaseInsensitiveMap[String], + val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String, rowTagRequired: Boolean) @@ -147,6 +146,10 @@ private[sql] class XmlOptions( s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]" }) + val timestampNTZFormatInRead: Option[String] = parameters.get(TIMESTAMP_NTZ_FORMAT) + val timestampNTZFormatInWrite: String = + parameters.getOrElse(TIMESTAMP_NTZ_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") + val timezone = parameters.get("timezone") val zoneId: ZoneId = DateTimeUtils.getZoneId( @@ -163,32 +166,6 @@ private[sql] class XmlOptions( def buildXmlFactory(): XMLInputFactory = { XMLInputFactory.newInstance() } - - val timestampFormatter = TimestampFormatter( - timestampFormatInRead, - zoneId, - locale, - legacyFormat = FAST_DATE_FORMAT, - isParsing = true) - - val timestampFormatterInWrite = TimestampFormatter( - timestampFormatInWrite, - zoneId, - locale, - legacyFormat = FAST_DATE_FORMAT, - isParsing = false) - - val dateFormatter = DateFormatter( - dateFormatInRead, - locale, - legacyFormat = FAST_DATE_FORMAT, - isParsing = true) - - val dateFormatterInWrite = DateFormatter( - dateFormatInWrite, - locale, - legacyFormat = FAST_DATE_FORMAT, - isParsing = false) } private[sql] object XmlOptions extends DataSourceOptions { @@ -225,6 +202,7 @@ private[sql] object XmlOptions extends DataSourceOptions { val COLUMN_NAME_OF_CORRUPT_RECORD = newOption("columnNameOfCorruptRecord") val DATE_FORMAT = newOption("dateFormat") val TIMESTAMP_FORMAT = newOption("timestampFormat") + val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat") val TIME_ZONE = newOption("timeZone") val INDENT = newOption("indent") // Options with alternative diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index bc62003b251e1..9992d8cbba076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -565,7 +565,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 4.0.0 */ @scala.annotation.varargs - def xml(paths: String*): DataFrame = format("xml").load(paths: _*) + def xml(paths: String*): DataFrame = { + userSpecifiedSchema.foreach(checkXmlSchema) + format("xml").load(paths: _*) + } /** * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala index d96d80d6ce51d..b09be84130abb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala @@ -121,7 +121,9 @@ object TextInputXmlDataSource extends XmlDataSource { def inferFromDataset( xml: Dataset[String], parsedOptions: XmlOptions): StructType = { - XmlInferSchema.infer(xml.rdd, parsedOptions) + SQLExecution.withSQLConfPropagated(xml.sparkSession) { + new XmlInferSchema(parsedOptions).infer(xml.rdd) + } } private def createBaseDataset( @@ -177,7 +179,7 @@ object MultiLineXmlDataSource extends XmlDataSource { parsedOptions) } SQLExecution.withSQLConfPropagated(sparkSession) { - val schema = XmlInferSchema.infer(tokenRDD, parsedOptions) + val schema = new XmlInferSchema(parsedOptions).infer(tokenRDD) schema } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index fc8f5a416ab14..36dd168992a14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.JsonUtils.checkJsonSchema import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2} +import org.apache.spark.sql.execution.datasources.xml.XmlUtils.checkXmlSchema import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.types.StructType @@ -278,7 +279,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * * @since 4.0.0 */ - def xml(path: String): DataFrame = format("xml").load(path) + def xml(path: String): DataFrame = { + userSpecifiedSchema.foreach(checkXmlSchema) + format("xml").load(path) + } /** * Loads a ORC file stream, returning the result as a `DataFrame`. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 20600848019d6..2d4cd2f403c56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.xml import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} import java.nio.file.{Files, Path, Paths} import java.sql.{Date, Timestamp} +import java.time.Instant import java.util.TimeZone import scala.collection.mutable @@ -31,15 +32,15 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Encoders, QueryTest, Row, SaveMode} +import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.xml.XmlOptions import org.apache.spark.sql.catalyst.xml.XmlOptions._ import org.apache.spark.sql.execution.datasources.xml.TestUtils._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class XmlSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -48,9 +49,6 @@ class XmlSuite extends QueryTest with SharedSparkSession { private var tempDir: Path = _ - protected override def sparkConf = super.sparkConf - .set(SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC") - override protected def beforeAll(): Unit = { super.beforeAll() tempDir = Files.createTempDirectory("XmlSuite") @@ -1511,7 +1509,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { val expectedSchema = buildSchema(field("author"), field("date", TimestampType), field("date2", DateType)) assert(df.schema === expectedSchema) - assert(df.collect().head.getAs[Timestamp](1).toString === "2021-01-31 16:00:00.0") + assert(df.collect().head.getAs[Timestamp](1) === Timestamp.valueOf("2021-02-01 00:00:00")) assert(df.collect().head.getAs[Date](2).toString === "2021-02-01") } @@ -1556,7 +1554,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { val res = df.collect() assert(res.head.get(1) === "2011-12-03T10:15:30Z") assert(res.head.get(2) === "12-03-2011 10:15:30 PST") - assert(res.head.getAs[Timestamp](3).getTime === 1322892930000L) + assert(res.head.getAs[Timestamp](3) === Timestamp.valueOf("2011-12-03 06:15:30")) } test("Test custom timestampFormat with offset") { @@ -1803,4 +1801,174 @@ class XmlSuite extends QueryTest with SharedSparkSession { "declaration" -> s"<${XmlOptions.DEFAULT_DECLARATION}>"), "'declaration' should not include angle brackets") } + + def dataTypeTest(data: String, + dt: DataType): Unit = { + val xmlString = s"""$data""" + val schema = new StructType().add(XmlOptions.VALUE_TAG, dt) + val df = spark.read + .option("rowTag", "ROW") + .schema(schema) + .xml(spark.createDataset(Seq(xmlString))) + } + + test("Primitive field casting") { + val ts = Seq("2002-05-30 21:46:54", "2002-05-30T21:46:54", "2002-05-30T21:46:54.1234", + "2002-05-30T21:46:54Z", "2002-05-30T21:46:54.1234Z", "2002-05-30T21:46:54-06:00", + "2002-05-30T21:46:54+06:00", "2002-05-30T21:46:54.1234-06:00", + "2002-05-30T21:46:54.1234+06:00", 
"2002-05-30T21:46:54+00:00", "2002-05-30T21:46:54.0000Z") + + val tsXMLStr = ts.map(t => s"$t").mkString("\n") + val tsResult = ts.map(t => + Timestamp.from(Instant.ofEpochSecond(0, DateTimeUtils.stringToTimestamp( + UTF8String.fromString(t), DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)).get * 1000)) + ) + + val primitiveFieldAndType: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + s""" + 10.05 + 1,000.01 + 158,058,049.001 + + 10 + 10 + +10 + -10 + 10 + +10 + -10 + 10 + +10 + -10 + 1.00 + +1.00 + -1.00 + 1.00 + +1.00 + -1.00 + true + 1 + false + 0 + $tsXMLStr + 2002-09-24 + """.stripMargin :: Nil))(Encoders.STRING) + + val decimalType = DecimalType(20, 3) + + val schema = StructType( + StructField("decimal", ArrayType(decimalType), true) :: + StructField("emptyString", StringType, true) :: + StructField("ByteType", ByteType, true) :: + StructField("ShortType", ArrayType(ShortType), true) :: + StructField("IntegerType", ArrayType(IntegerType), true) :: + StructField("LongType", ArrayType(LongType), true) :: + StructField("FloatType", ArrayType(FloatType), true) :: + StructField("DoubleType", ArrayType(DoubleType), true) :: + StructField("BooleanType", ArrayType(BooleanType), true) :: + StructField("TimestampType", ArrayType(TimestampType), true) :: + StructField("DateType", DateType, true) :: Nil) + + val df = spark.read.schema(schema).xml(primitiveFieldAndType) + + checkAnswer( + df, + Seq(Row(Array( + Decimal(BigDecimal("10.05"), decimalType.precision, decimalType.scale).toJavaBigDecimal, + Decimal(BigDecimal("1000.01"), decimalType.precision, decimalType.scale).toJavaBigDecimal, + Decimal(BigDecimal("158058049.001"), decimalType.precision, decimalType.scale) + .toJavaBigDecimal), + "", + 10.toByte, + Array(10.toShort, 10.toShort, -10.toShort), + Array(10, 10, -10), + Array(10L, 10L, -10L), + Array(1.0.toFloat, 1.0.toFloat, -1.0.toFloat), + Array(1.0, 1.0, -1.0), + Array(true, true, false, false), + tsResult, + Date.valueOf("2002-09-24") + )) + ) + } + + test("Nullable types are handled") { + val dataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + BooleanType, TimestampType, DateType, StringType) + + val dataXMLString = dataTypes.map { dt => + s"""<${dt.toString}>-""" + }.mkString("\n") + + val fields = dataTypes.map { dt => + StructField(dt.toString, dt, true) + } + val schema = StructType(fields) + + val res = dataTypes.map { dt => null } + + val nullDataset: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + s""" + $dataXMLString + """.stripMargin :: Nil))(Encoders.STRING) + + val df = spark.read.option("nullValue", "-").schema(schema).xml(nullDataset) + checkAnswer(df, Row.fromSeq(res)) + + val df2 = spark.read.xml(nullDataset) + checkAnswer(df2, Row.fromSeq(dataTypes.map { dt => "-" })) + } + + test("Custom timestamp format is used to parse correctly") { + val schema = StructType( + StructField("ts", TimestampType, true) :: Nil) + + Seq( + ("12-03-2011 10:15:30", "2011-12-03 10:15:30", "MM-dd-yyyy HH:mm:ss", "UTC"), + ("2011/12/03 10:15:30", "2011-12-03 10:15:30", "yyyy/MM/dd HH:mm:ss", "UTC"), + ("2011/12/03 10:15:30", "2011-12-03 10:15:30", "yyyy/MM/dd HH:mm:ss", "Asia/Shanghai") + ).foreach { case (ts, resTS, fmt, zone) => + val tsDataset: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + s""" + $ts + """.stripMargin :: Nil))(Encoders.STRING) + val timestampResult = Timestamp.from(Instant.ofEpochSecond(0, + DateTimeUtils.stringToTimestamp(UTF8String.fromString(resTS), + 
DateTimeUtils.getZoneId(zone)).get * 1000)) + + val df = spark.read.option("timestampFormat", fmt).option("timezone", zone) + .schema(schema).xml(tsDataset) + checkAnswer(df, Row(timestampResult)) + } + } + + test("Schema Inference for primitive types") { + val dataset: Dataset[String] = + spark.createDataset(spark.sparkContext.parallelize( + s""" + true + +10.1 + -10 + 10 + 8E9D + 8E9F + 2015-01-01 00:00:00 + """.stripMargin :: Nil))(Encoders.STRING) + + val expectedSchema = StructType(StructField("bool1", BooleanType, true) :: + StructField("double1", DoubleType, true) :: + StructField("long1", LongType, true) :: + StructField("long2", LongType, true) :: + StructField("string1", StringType, true) :: + StructField("string2", StringType, true) :: + StructField("ts1", TimestampType, true) :: Nil) + + val df = spark.read.xml(dataset) + assert(df.schema.toSet === expectedSchema.toSet) + checkAnswer(df, Row(true, 10.1, -10, 10, "8E9D", "8E9F", + Timestamp.valueOf("2015-01-01 00:00:00"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/TypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/TypeCastSuite.scala deleted file mode 100644 index 096fb3d83a54c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/TypeCastSuite.scala +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.execution.datasources.xml.util - -import java.math.BigDecimal -import java.util.Locale - -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} -import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT -import org.apache.spark.sql.catalyst.xml.{TypeCast, XmlOptions} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -final class TypeCastSuite extends SharedSparkSession { - - test("Can parse decimal type values") { - val options = new XmlOptions() - val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") - val decimalValues = Seq(10.05, 1000.01, 158058049.001) - val decimalType = DecimalType.SYSTEM_DEFAULT - - stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => - val dt = new BigDecimal(decimalVal.toString) - assert(TypeCast.castTo(strVal, decimalType, options) === - Decimal(dt, dt.precision(), dt.scale())) - } - } - - test("Nullable types are handled") { - val options = new XmlOptions(Map("nullValue" -> "-")) - for (t <- Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - BooleanType, TimestampType, DateType, StringType)) { - assert(TypeCast.castTo("-", t, options) === null) - } - } - - test("String type should always return the same as the input") { - val options = new XmlOptions() - assert(TypeCast.castTo("", StringType, options) === UTF8String.fromString("")) - } - - test("Types are cast correctly") { - val options = new XmlOptions() - assert(TypeCast.castTo("10", ByteType, options) === 10) - assert(TypeCast.castTo("10", ShortType, options) === 10) - assert(TypeCast.castTo("10", IntegerType, options) === 10) - assert(TypeCast.castTo("10", LongType, options) === 10) - assert(TypeCast.castTo("1.00", FloatType, options) === 1.0) - assert(TypeCast.castTo("1.00", DoubleType, options) === 1.0) - assert(TypeCast.castTo("true", BooleanType, options) === true) - assert(TypeCast.castTo("1", BooleanType, options) === true) - assert(TypeCast.castTo("false", BooleanType, options) === false) - assert(TypeCast.castTo("0", BooleanType, options) === false) - - { - val ts = TypeCast.castTo("2002-05-30 21:46:54", TimestampType, options) - assert(ts === 1022820414000000L) - assert(ts === - TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30 21:46:54")) - } - { - val ts = TypeCast.castTo("2002-05-30T21:46:54", TimestampType, options) - assert(ts === 1022820414000000L) - assert(ts === - TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54")) - } - { - val ts = TypeCast.castTo("2002-05-30T21:46:54.1234", TimestampType, options) - assert(ts === 1022820414123400L) - assert(ts === - TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.1234")) - } - { - val ts = TypeCast.castTo("2002-05-30T21:46:54Z", TimestampType, options) - assert(ts === 1022795214000000L) - assert(ts === - TimestampFormatter(None, DateTimeUtils.getZoneId("UTC"), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54Z")) - } - { - val ts = TypeCast.castTo("2002-05-30T21:46:54-06:00", TimestampType, options) - assert(ts === 1022816814000000L) - assert(ts === - TimestampFormatter(None, 
DateTimeUtils.getZoneId("-06:00"), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54-06:00")) - } - { - val ts = TypeCast.castTo("2002-05-30T21:46:54+06:00", TimestampType, options) - assert(ts === 1022773614000000L) - assert(ts === - TimestampFormatter(None, DateTimeUtils.getZoneId("+06:00"), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54+06:00")) - } - { - val ts = TypeCast.castTo("2002-05-30T21:46:54.1234Z", TimestampType, options) - assert(ts === 1022795214123400L) - assert(ts === - TimestampFormatter(None, DateTimeUtils.getZoneId("UTC"), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.1234Z")) - } - { - val ts = TypeCast.castTo("2002-05-30T21:46:54.1234-06:00", TimestampType, options) - assert(ts === 1022816814123400L) - assert(ts === - TimestampFormatter(None, DateTimeUtils.getZoneId("-06:00"), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.1234-06:00")) - } - { - val ts = TypeCast.castTo("2002-05-30T21:46:54.1234+06:00", TimestampType, options) - assert(ts === 1022773614123400L) - assert(ts === - TimestampFormatter(None, DateTimeUtils.getZoneId("+06:00"), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.1234+06:00")) - } - { - val date = TypeCast.castTo("2002-09-24", DateType, options) - assert(date === 11954) - assert(date === DateFormatter(DateFormatter.defaultPattern, - Locale.US, FAST_DATE_FORMAT, true).parse("2002-09-24")) - } - } - - test("Types with sign are cast correctly") { - val options = new XmlOptions() - assert(TypeCast.signSafeToInt("+10", options) === 10) - assert(TypeCast.signSafeToLong("-10", options) === -10) - assert(TypeCast.signSafeToFloat("1.00", options) === 1.0) - assert(TypeCast.signSafeToDouble("-1.00", options) === -1.0) - } - - test("Types with sign are checked correctly") { - assert(TypeCast.isBoolean("true")) - assert(TypeCast.isInteger("10")) - assert(TypeCast.isLong("10")) - assert(TypeCast.isDouble("+10.1")) - assert(!TypeCast.isDouble("8E9D")) - assert(!TypeCast.isDouble("8E9F")) - val timestamp = "2015-01-01 00:00:00" - assert(TypeCast.isTimestamp(timestamp, new XmlOptions())) - } - - test("Float and Double Types are cast correctly with Locale") { - val options = new XmlOptions() - val defaultLocale = Locale.getDefault - try { - Locale.setDefault(Locale.FRANCE) - assert(TypeCast.castTo("1,00", FloatType, options) === 1.0) - assert(TypeCast.castTo("1,00", DoubleType, options) === 1.0) - } finally { - Locale.setDefault(defaultLocale) - } - } - - test("Parsing built-in timestamp formatters") { - val options = XmlOptions(Map()) - val expectedResult = - TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30 21:46:54") - assert( - TypeCast.castTo("2002-05-30 21:46:54", TimestampType, options) === expectedResult - ) - assert( - TypeCast.castTo("2002-05-30T21:46:54", TimestampType, options) === expectedResult - ) - assert( - TypeCast.castTo("2002-05-30T21:46:54+00:00", TimestampType, options) === - TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54+00:00") - ) - assert( - TypeCast.castTo("2002-05-30T21:46:54.0000Z", TimestampType, options) === - TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone), - Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.0000Z") - ) - } - - test("Custom timestamp format is used to parse correctly") { - var options = 
XmlOptions(Map("timestampFormat" -> "MM-dd-yyyy HH:mm:ss", "timezone" -> "UTC")) - assert( - TypeCast.castTo("12-03-2011 10:15:30", TimestampType, options) === - TimestampFormatter("MM-dd-yyyy HH:mm:ss", DateTimeUtils.getZoneId("UTC"), - Locale.US, FAST_DATE_FORMAT, true).parse("12-03-2011 10:15:30") - ) - - options = XmlOptions(Map("timestampFormat" -> "yyyy/MM/dd HH:mm:ss", "timezone" -> "UTC")) - assert( - TypeCast.castTo("2011/12/03 10:15:30", TimestampType, options) === - TimestampFormatter("yyyy/MM/dd HH:mm:ss", DateTimeUtils.getZoneId("UTC"), - Locale.US, FAST_DATE_FORMAT, true).parse("2011/12/03 10:15:30") - ) - - options = XmlOptions(Map("timestampFormat" -> "yyyy/MM/dd HH:mm:ss", - "timezone" -> "Asia/Shanghai")) - assert( - TypeCast.castTo("2011/12/03 10:15:30", TimestampType, options) === - TimestampFormatter("yyyy/MM/dd HH:mm:ss", DateTimeUtils.getZoneId("Asia/Shanghai"), - Locale.US, FAST_DATE_FORMAT, true).parse("2011/12/03 10:15:30") - ) - - options = XmlOptions(Map("timestampFormat" -> "yyyy/MM/dd HH:mm:ss", - "timezone" -> "Asia/Shanghai")) - assert( - TypeCast.castTo("2011/12/03 10:15:30", TimestampType, options) === - TimestampFormatter("yyyy/MM/dd HH:mm:ss", DateTimeUtils.getZoneId("Asia/Shanghai"), - Locale.US, FAST_DATE_FORMAT, true).parse("2011/12/03 10:15:30") - ) - - options = XmlOptions(Map("timestampFormat" -> "yyyy/MM/dd HH:mm:ss")) - assert(TypeCast.castTo("2011/12/03 10:15:30", TimestampType, options) === - TimestampFormatter("yyyy/MM/dd HH:mm:ss", - DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone), - Locale.US, FAST_DATE_FORMAT, true).parse("2011/12/03 10:15:30") - ) - } -} From a0dea7ca40bcd0340b72dff356b3e27ecefc45b7 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 7 Nov 2023 09:44:12 -0800 Subject: [PATCH 045/121] [SPARK-45802][CORE] Remove no longer needed Java `majorVersion` checks in `Platform` ### What changes were proposed in this pull request? This PR removes the version checks for Java 9 and 11 from `Platform` because Spark 4.0 minimum supports Java 17. ### Why are the changes needed? Remove no longer needed Java `majorVersion` checks. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43672 from LuciferYang/SPARK-45802. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- .../java/org/apache/spark/unsafe/Platform.java | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index dfa5734ccbce9..9d97d04a58137 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -59,13 +59,6 @@ public final class Platform { // reflection to invoke it, which is not necessarily possible by default in Java 9+. // Code below can test for null to see whether to use it. - // The implementation of Cleaner changed from JDK 8 to 9 - String cleanerClassName; - if (majorVersion < 9) { - cleanerClassName = "sun.misc.Cleaner"; - } else { - cleanerClassName = "jdk.internal.ref.Cleaner"; - } try { Class cls = Class.forName("java.nio.DirectByteBuffer"); Constructor constructor = (majorVersion < 21) ? 
@@ -84,7 +77,7 @@ public final class Platform { // no point continuing if the above failed: if (DBB_CONSTRUCTOR != null && DBB_CLEANER_FIELD != null) { - Class cleanerClass = Class.forName(cleanerClassName); + Class cleanerClass = Class.forName("jdk.internal.ref.Cleaner"); Method createMethod = cleanerClass.getMethod("create", Object.class, Runnable.class); // Accessing jdk.internal.ref.Cleaner should actually fail by default in JDK 9+, // unfortunately, unless the user has allowed access with something like @@ -314,7 +307,7 @@ public static void throwException(Throwable t) { } } - // This requires `majorVersion` and `_UNSAFE`. + // This requires `_UNSAFE`. static { boolean _unaligned; String arch = System.getProperty("os.arch", ""); @@ -326,10 +319,8 @@ public static void throwException(Throwable t) { try { Class bitsClass = Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); - if (_UNSAFE != null && majorVersion >= 9) { - // Java 9/10 and 11/12 have different field names. - Field unalignedField = - bitsClass.getDeclaredField(majorVersion >= 11 ? "UNALIGNED" : "unaligned"); + if (_UNSAFE != null) { + Field unalignedField = bitsClass.getDeclaredField("UNALIGNED"); _unaligned = _UNSAFE.getBoolean( _UNSAFE.staticFieldBase(unalignedField), _UNSAFE.staticFieldOffset(unalignedField)); } else { From b701d6e8951dd1a506e6a6bd0a5c3c7c23b8ddf0 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 09:51:51 -0800 Subject: [PATCH 046/121] [SPARK-45258][PYTHON][DOCS] Refine docstring of `sum` ### What changes were proposed in this pull request? This PR proposes to improve the docstring of `sum`. ### Why are the changes needed? For end users, and better usability of PySpark. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user facing documentation. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43684 from HyukjinKwon/SPARK-45258. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 869506a35586e..a32f04164f314 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1197,13 +1197,27 @@ def sum(col: "ColumnOrName") -> Column: Examples -------- + Example 1: Calculating the sum of values in a column + + >>> from pyspark.sql import functions as sf >>> df = spark.range(10) - >>> df.select(sum(df["id"])).show() + >>> df.select(sf.sum(df["id"])).show() +-------+ |sum(id)| +-------+ | 45| +-------+ + + Example 2: Using a plus expression together to calculate the sum + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(1, 2), (3, 4)], ["A", "B"]) + >>> df.select(sf.sum(sf.col("A") + sf.col("B"))).show() + +------------+ + |sum((A + B))| + +------------+ + | 10| + +------------+ """ return _invoke_function_over_columns("sum", col) From 563b3cab749f0104ef399730fe69fa4efd14be84 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 09:52:51 -0800 Subject: [PATCH 047/121] [SPARK-45259][PYTHON][DOCS] Refine docstring of `count` ### What changes were proposed in this pull request? This PR proposes to improve the docstring of `count`. ### Why are the changes needed? For end users, and better usability of PySpark. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user facing documentation. 
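(Not part of the patch — a minimal, hypothetical PySpark sketch of the `sum` aggregate whose docstring is refined here; the session, app name, and column names are illustrative assumptions only.)

    from pyspark.sql import SparkSession, functions as sf

    spark = SparkSession.builder.master("local[1]").appName("sum_docstring_demo").getOrCreate()
    df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["k", "v"])
    df.select(sf.sum("v")).show()                            # sum over all rows -> 6
    df.groupBy("k").agg(sf.sum("v").alias("total")).show()   # per-key totals: a -> 3, b -> 3
    spark.stop()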
### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43685 from HyukjinKwon/SPARK-45259. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 47 ++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a32f04164f314..81d120e2ff49d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1162,15 +1162,48 @@ def count(col: "ColumnOrName") -> Column: Examples -------- - Count by all columns (start), and by a column that does not count ``None``. + Example 1: Count all rows in a DataFrame + >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([(None,), ("a",), ("b",), ("c",)], schema=["alphabets"]) - >>> df.select(count(expr("*")), count(df.alphabets)).show() - +--------+----------------+ - |count(1)|count(alphabets)| - +--------+----------------+ - | 4| 3| - +--------+----------------+ + >>> df.select(sf.count(sf.expr("*"))).show() + +--------+ + |count(1)| + +--------+ + | 4| + +--------+ + + Example 2: Count non-null values in a specific column + + >>> from pyspark.sql import functions as sf + >>> df.select(sf.count(df.alphabets)).show() + +----------------+ + |count(alphabets)| + +----------------+ + | 3| + +----------------+ + + Example 3: Count all rows in a DataFrame with multiple columns + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame( + ... [(1, "apple"), (2, "banana"), (3, None)], schema=["id", "fruit"]) + >>> df.select(sf.count(sf.expr("*"))).show() + +--------+ + |count(1)| + +--------+ + | 3| + +--------+ + + Example 4: Count non-null values in multiple columns + + >>> from pyspark.sql import functions as sf + >>> df.select(sf.count(df.id), sf.count(df.fruit)).show() + +---------+------------+ + |count(id)|count(fruit)| + +---------+------------+ + | 3| 2| + +---------+------------+ """ return _invoke_function_over_columns("count", col) From 5d851d9989056e7d3c4f7ddc27c3fd5043ab19d5 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 09:53:58 -0800 Subject: [PATCH 048/121] [SPARK-45260][PYTHON][DOCS] Refine docstring of `count_distinct` ### What changes were proposed in this pull request? This PR proposes to improve the docstring of `count_distinct`. ### Why are the changes needed? For end users, and better usability of PySpark. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user facing documentation. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43686 from HyukjinKwon/SPARK-45260. 
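(Not part of the patch — a minimal, hypothetical PySpark sketch of `count_distinct`, the function whose docstring is refined here; the session, app name, and column names are illustrative assumptions only.)

    from pyspark.sql import SparkSession, functions as sf

    spark = SparkSession.builder.master("local[1]").appName("count_distinct_demo").getOrCreate()
    df = spark.createDataFrame([(1, 1), (1, 2), (3, 2)], ["a", "b"])
    df.select(sf.count_distinct("a"), sf.count_distinct("a", "b")).show()
    # count(DISTINCT a) -> 2 distinct values; count(DISTINCT a, b) -> 3 distinct pairs
    spark.stop()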
Lead-authored-by: Hyukjin Kwon Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 52 ++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 81d120e2ff49d..bf67f7ff51fda 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -4673,26 +4673,38 @@ def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: Examples -------- - >>> from pyspark.sql import types - >>> df1 = spark.createDataFrame([1, 1, 3], types.IntegerType()) - >>> df2 = spark.createDataFrame([1, 2], types.IntegerType()) - >>> df1.join(df2).show() - +-----+-----+ - |value|value| - +-----+-----+ - | 1| 1| - | 1| 2| - | 1| 1| - | 1| 2| - | 3| 1| - | 3| 2| - +-----+-----+ - >>> df1.join(df2).select(count_distinct(df1.value, df2.value)).show() - +----------------------------+ - |count(DISTINCT value, value)| - +----------------------------+ - | 4| - +----------------------------+ + Example 1: Counting distinct values of a single column + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(1,), (1,), (3,)], ["value"]) + >>> df.select(sf.count_distinct(df.value)).show() + +---------------------+ + |count(DISTINCT value)| + +---------------------+ + | 2| + +---------------------+ + + Example 2: Counting distinct values of multiple columns + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(1, 1), (1, 2)], ["value1", "value2"]) + >>> df.select(sf.count_distinct(df.value1, df.value2)).show() + +------------------------------+ + |count(DISTINCT value1, value2)| + +------------------------------+ + | 2| + +------------------------------+ + + Example 3: Counting distinct values with column names as strings + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(1, 1), (1, 2)], ["value1", "value2"]) + >>> df.select(sf.count_distinct("value1", "value2")).show() + +------------------------------+ + |count(DISTINCT value1, value2)| + +------------------------------+ + | 2| + +------------------------------+ """ sc = _get_active_spark_context() return _invoke_function( From 90c6c2b36743e64ecdeaebb34fe37aa348701370 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 09:58:50 -0800 Subject: [PATCH 049/121] [SPARK-45222][PYTHON][DOCS] Refine docstring of `DataFrameReader.json` ### What changes were proposed in this pull request? This PR proposes to improve the docstring of `DataFrameReader.json`. ### Why are the changes needed? For end users, and better usability of PySpark. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user facing documentation. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43687 from HyukjinKwon/SPARK-45222. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/readwriter.py | 51 +++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 75faa13f02b38..b7e2c145f443e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -380,22 +380,59 @@ def json( Examples -------- - Write a DataFrame into a JSON file and read it back. + Example 1: Write a DataFrame into a JSON file and read it back. >>> import tempfile >>> with tempfile.TemporaryDirectory() as d: ... 
# Write a DataFrame into a JSON file ... spark.createDataFrame( - ... [{"age": 100, "name": "Hyukjin Kwon"}] + ... [{"age": 100, "name": "Hyukjin"}] ... ).write.mode("overwrite").format("json").save(d) ... ... # Read the JSON file as a DataFrame. ... spark.read.json(d).show() - +---+------------+ - |age| name| - +---+------------+ - |100|Hyukjin Kwon| - +---+------------+ + +---+-------+ + |age| name| + +---+-------+ + |100|Hyukjin| + +---+-------+ + + Example 2: Read JSON from multiple files in a directory + + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as d1, tempfile.TemporaryDirectory() as d2: + ... # Write a DataFrame into a JSON file + ... spark.createDataFrame( + ... [{"age": 30, "name": "Bob"}] + ... ).write.mode("overwrite").format("json").save(d1) + ... + ... # Read the JSON files as a DataFrame. + ... spark.createDataFrame( + ... [{"age": 25, "name": "Alice"}] + ... ).write.mode("overwrite").format("json").save(d2) + ... spark.read.json([d1, d2]).show() + +---+-----+ + |age| name| + +---+-----+ + | 25|Alice| + | 30| Bob| + +---+-----+ + + Example 3: Read JSON with a custom schema + + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as d: + ... # Write a DataFrame into a JSON file + ... spark.createDataFrame( + ... [{"age": 30, "name": "Bob"}] + ... ).write.mode("overwrite").format("json").save(d) + ... custom_schema = "name STRING, age INT" + ... spark.read.json(d, schema=custom_schema).show() + +----+---+ + |name|age| + +----+---+ + | Bob| 30| + +----+---+ """ self._set_opts( schema=schema, From 8a9b44ea211e32328b680b21d50beebd5d8b83d8 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 09:59:58 -0800 Subject: [PATCH 050/121] [SPARK-45809][PYTHON][DOCS] Refine docstring of `lit` ### What changes were proposed in this pull request? This PR proposes to improve the docstring of `lit`. ### Why are the changes needed? For end users, and better usability of PySpark. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user facing documentation. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43679 from HyukjinKwon/SPARK-45809. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bf67f7ff51fda..dd6be89ab853d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -163,22 +163,49 @@ def lit(col: Any) -> Column: Examples -------- + Example 1: Creating a literal column with an integer value. + + >>> import pyspark.sql.functions as sf >>> df = spark.range(1) - >>> df.select(lit(5).alias('height'), df.id).show() + >>> df.select(sf.lit(5).alias('height'), df.id).show() +------+---+ |height| id| +------+---+ | 5| 0| +------+---+ - Create a literal from a list. + Example 2: Creating a literal column from a list. - >>> spark.range(1).select(lit([1, 2, 3])).show() + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.lit([1, 2, 3])).show() +--------------+ |array(1, 2, 3)| +--------------+ | [1, 2, 3]| +--------------+ + + Example 3: Creating a literal column from a string. 
+ + >>> import pyspark.sql.functions as sf + >>> df = spark.range(1) + >>> df.select(sf.lit("PySpark").alias('framework'), df.id).show() + +---------+---+ + |framework| id| + +---------+---+ + | PySpark| 0| + +---------+---+ + + Example 4: Creating a literal column from a boolean value. + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([(True, "Yes"), (False, "No")], ["flag", "response"]) + >>> df.select(sf.lit(False).alias('is_approved'), df.response).show() + +-----------+--------+ + |is_approved|response| + +-----------+--------+ + | false| Yes| + | false| No| + +-----------+--------+ """ if isinstance(col, Column): return col From 6a44b627f40f501a171794416b6a6a9cae8893b5 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 10:01:50 -0800 Subject: [PATCH 051/121] [SPARK-45811][PYTHON][DOCS] Refine docstring of `from_xml` ### What changes were proposed in this pull request? This PR proposes to improve the docstring of `from_xml`. ### Why are the changes needed? For end users, and better usability of PySpark. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user facing documentation. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43680 from HyukjinKwon/SPARK-45186. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 51 +++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dd6be89ab853d..ef5c0ea073ab7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -13635,6 +13635,8 @@ def json_object_keys(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("json_object_keys", col) +# TODO: Fix and add an example for StructType with Spark Connect +# e.g., StructType([StructField("a", IntegerType())]) @_try_remote_functions def from_xml( col: "ColumnOrName", @@ -13668,40 +13670,51 @@ def from_xml( Examples -------- - >>> from pyspark.sql.types import * - >>> from pyspark.sql.functions import from_xml, schema_of_xml, lit - - StructType input with simple IntegerType. + Example 1: Parsing XML with a :class:`StructType` schema + >>> import pyspark.sql.functions as sf + >>> from pyspark.sql.types import StructType, StructField, LongType + ... # Sample data with an XML column >>> data = [(1, '''

    <p><a>1</a></p>

    ''')] >>> df = spark.createDataFrame(data, ("key", "value")) + ... # Define the schema using a StructType + >>> schema = StructType([StructField("a", LongType())]) + ... # Parse the XML column using the specified schema + >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect() + [Row(xml=Row(a=1))] - TODO: Fix StructType for spark connect - schema = StructType([StructField("a", IntegerType())]) + Example 2: Parsing XML with a DDL-formatted string schema + >>> import pyspark.sql.functions as sf + >>> data = [(1, '''

    <p><a>1</a></p>

    ''')] + >>> df = spark.createDataFrame(data, ("key", "value")) + ... # Define the schema using a DDL-formatted string >>> schema = "STRUCT" - >>> df.select(from_xml(df.value, schema).alias("xml")).collect() + ... # Parse the XML column using the DDL-formatted schema + >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect() [Row(xml=Row(a=1))] - String input. - - >>> df.select(from_xml(df.value, "a INT").alias("xml")).collect() - [Row(xml=Row(a=1))] + Example 3: Parsing XML with :class:`ArrayType` in schema + >>> import pyspark.sql.functions as sf >>> data = [(1, '

    <p><a>1</a><a>2</a></p>

    ')] >>> df = spark.createDataFrame(data, ("key", "value")) - - TODO: Fix StructType for spark connect - schema = StructType([StructField("a", ArrayType(IntegerType()))]) - + ... # Define the schema with an Array type >>> schema = "STRUCT>" - >>> df.select(from_xml(df.value, schema).alias("xml")).collect() + ... # Parse the XML column using the schema with an Array + >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect() [Row(xml=Row(a=[1, 2]))] - Column input generated by schema_of_xml. + Example 4: Parsing XML using :meth:`pyspark.sql.functions.schema_of_xml` - >>> schema = schema_of_xml(lit(data[0][1])) - >>> df.select(from_xml(df.value, schema).alias("xml")).collect() + >>> import pyspark.sql.functions as sf + >>> # Sample data with an XML column + ... data = [(1, '

    <p><a>1</a><a>2</a></p>

    ')] + >>> df = spark.createDataFrame(data, ("key", "value")) + ... # Generate the schema from an example XML value + >>> schema = sf.schema_of_xml(sf.lit(data[0][1])) + ... # Parse the XML column using the generated schema + >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect() [Row(xml=Row(a=[1, 2]))] """ From 1157ffddde9713cccfd6f8572171fff36e5aa3be Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 10:03:05 -0800 Subject: [PATCH 052/121] [SPARK-45186][PYTHON][DOCS] Refine docstring of `schema_of_xml` ### What changes were proposed in this pull request? This PR proposes to improve the docstring of `schema_of_xml`. ### Why are the changes needed? For end users, and better usability of PySpark. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user facing documentation. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43681 from HyukjinKwon/SPARK-45186-1. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 37 ++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ef5c0ea073ab7..95821deeec173 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -13755,14 +13755,45 @@ def schema_of_xml(xml: "ColumnOrName", options: Optional[Dict[str, str]] = None) Examples -------- + Example 1: Parsing a simple XML with a single element + + >>> from pyspark.sql import functions as sf >>> df = spark.range(1) - >>> df.select(schema_of_xml(lit('

    <p><a>1</a></p>

    ')).alias("xml")).collect() + >>> df.select(sf.schema_of_xml(sf.lit('

    <p><a>1</a></p>

    ')).alias("xml")).collect() [Row(xml='STRUCT')] - >>> df.select(schema_of_xml(lit('

    <p><a>1</a><a>2</a></p>

    ')).alias("xml")).collect() + + Example 2: Parsing an XML with multiple elements in an array + + >>> from pyspark.sql import functions as sf + >>> df.select(sf.schema_of_xml(sf.lit('

    <p><a>1</a><a>2</a></p>

    ')).alias("xml")).collect() [Row(xml='STRUCT>')] - >>> schema = schema_of_xml('

    <p><a attr="2">1</a></p>

    ', {'excludeAttribute':'true'}) + + Example 3: Parsing XML with options to exclude attributes + + >>> from pyspark.sql import functions as sf + >>> schema = sf.schema_of_xml('

    <p><a attr="2">1</a></p>

    ', {'excludeAttribute':'true'}) >>> df.select(schema.alias("xml")).collect() [Row(xml='STRUCT')] + + Example 4: Parsing XML with complex structure + + >>> from pyspark.sql import functions as sf + >>> df.select( + ... sf.schema_of_xml( + ... sf.lit('Alice30') + ... ).alias("xml") + ... ).collect() + [Row(xml='STRUCT>')] + + Example 5: Parsing XML with nested arrays + + >>> from pyspark.sql import functions as sf + >>> df.select( + ... sf.schema_of_xml( + ... sf.lit('12') + ... ).alias("xml") + ... ).collect() + [Row(xml='STRUCT>>')] """ if isinstance(xml, str): col = _create_column_from_literal(xml) From beb0238273b937bb42e746f7b240dd63e48f0667 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 7 Nov 2023 10:11:19 -0800 Subject: [PATCH 053/121] [SPARK-45819][CORE] Support `clear` in REST Submission API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to support `clear` in REST Submission API to clear `completed` drivers and apps. ### Why are the changes needed? This new feature is helpful for users to reset the completed drivers and apps in Spark Master. **"1 Completed"** Screenshot 2023-11-07 at 12 56 02 AM **After invoking `clear` API, "0 Completed"** ``` $ curl -X POST http://max.local:6066/v1/submissions/clear { "action" : "ClearResponse", "message" : "", "serverSparkVersion" : "4.0.0-SNAPSHOT", "success" : true } ``` Screenshot 2023-11-07 at 12 56 24 AM ### Does this PR introduce _any_ user-facing change? No, this is a new API. ### How was this patch tested? Pass the CIs with the newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43698 from dongjoon-hyun/SPARK-45819. Lead-authored-by: Dongjoon Hyun Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../apache/spark/deploy/DeployMessage.scala | 2 ++ .../apache/spark/deploy/master/Master.scala | 8 +++++ .../deploy/rest/RestSubmissionClient.scala | 35 +++++++++++++++++++ .../deploy/rest/RestSubmissionServer.scala | 22 +++++++++++- .../deploy/rest/StandaloneRestServer.scala | 19 ++++++++++ .../rest/SubmitRestProtocolResponse.scala | 10 ++++++ .../rest/StandaloneRestSubmitSuite.scala | 33 +++++++++++++++++ 7 files changed, 128 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 4ec0edd5909ec..f49530461b4d0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -238,6 +238,8 @@ private[deploy] object DeployMessages { case class DriverStatusResponse(found: Boolean, state: Option[DriverState], workerId: Option[String], workerHostPort: Option[String], exception: Option[Exception]) + case object RequestClearCompletedDriversAndApps extends DeployMessage + // Internal message in AppClient case object StopAppClient diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e63d72ebb40d2..3ba50318610ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -460,6 +460,14 @@ private[deploy] class Master( } } + case RequestClearCompletedDriversAndApps => + val numDrivers = completedDrivers.length + val numApps = completedApps.length + logInfo(s"Asked to clear $numDrivers 
completed drivers and $numApps completed apps.") + completedDrivers.clear() + completedApps.clear() + context.reply(true) + case RequestDriverStatus(driverId) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 68f08dd951ef7..3010efc936f97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -135,6 +135,35 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { response } + /** Request that the server clears all submissions and applications. */ + def clear(): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to clear $master.") + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getClearUrl(m) + try { + response = post(url) + response match { + case k: ClearResponse => + if (!Utils.responseFromBackup(k.message)) { + handleRestResponse(k) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } + } + response + } + /** Request the status of a submission from the server. */ def requestSubmissionStatus( submissionId: String, @@ -300,6 +329,12 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { new URL(s"$baseUrl/kill/$submissionId") } + /** Return the REST URL for clear all existing submissions and applications. */ + private def getClearUrl(master: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/clear") + } + /** Return the REST URL for requesting the status of an existing submission. */ private def getStatusUrl(master: String, submissionId: String): URL = { val baseUrl = getBaseUrl(master) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index 41845dc31a988..3323d0f529ebf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -55,6 +55,7 @@ private[spark] abstract class RestSubmissionServer( protected val submitRequestServlet: SubmitRequestServlet protected val killRequestServlet: KillRequestServlet protected val statusRequestServlet: StatusRequestServlet + protected val clearRequestServlet: ClearRequestServlet private var _server: Option[Server] = None @@ -64,6 +65,7 @@ private[spark] abstract class RestSubmissionServer( s"$baseContext/create/*" -> submitRequestServlet, s"$baseContext/kill/*" -> killRequestServlet, s"$baseContext/status/*" -> statusRequestServlet, + s"$baseContext/clear/*" -> clearRequestServlet, "/*" -> new ErrorServlet // default handler ) @@ -227,6 +229,24 @@ private[rest] abstract class KillRequestServlet extends RestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse } +/** + * A servlet for handling clear requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class ClearRequestServlet extends RestServlet { + + /** + * Clear the completed drivers and apps. 
+ */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val responseMessage = handleClear() + sendResponse(responseMessage, response) + } + + protected def handleClear(): ClearResponse +} + /** * A servlet for handling status requests passed to the [[RestSubmissionServer]]. */ @@ -311,7 +331,7 @@ private class ErrorServlet extends RestServlet { "Missing the /submissions prefix." case `serverVersion` :: "submissions" :: tail => // http://host:port/correct-version/submissions/* - "Missing an action: please specify one of /create, /kill, or /status." + "Missing an action: please specify one of /create, /kill, /clear or /status." case unknownVersion :: tail => // http://host:port/unknown-version/* versionMismatch = true diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index a298e4f6dbf03..8ed716428dc28 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -65,6 +65,8 @@ private[deploy] class StandaloneRestServer( new StandaloneKillRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = new StandaloneStatusRequestServlet(masterEndpoint, masterConf) + protected override val clearRequestServlet = + new StandaloneClearRequestServlet(masterEndpoint, masterConf) } /** @@ -107,6 +109,23 @@ private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRe } } +/** + * A servlet for handling clear requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StandaloneClearRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) + extends ClearRequestServlet { + + protected def handleClear(): ClearResponse = { + val response = masterEndpoint.askSync[Boolean]( + DeployMessages.RequestClearCompletedDriversAndApps) + val c = new ClearResponse + c.serverSparkVersion = sparkVersion + c.message = "" + c.success = response + c + } +} + /** * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index 0e226ee294cab..21614c22285f8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -55,6 +55,16 @@ private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { } } +/** + * A response to a clear request in the REST application submission protocol. + */ +private[spark] class ClearResponse extends SubmitRestProtocolResponse { + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(success, "success") + } +} + /** * A response to a status request in the REST application submission protocol. 
*/ diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 3cd96670c8b5f..d775aa6542dcd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -227,6 +227,15 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { assert(statusResponse.submissionId === doesNotExist) } + test("SPARK-45819: clear") { + val masterUrl = startDummyServer() + val response = new RestSubmissionClient(masterUrl).clear() + val clearResponse = getClearResponse(response) + assert(clearResponse.action === Utils.getFormattedClassName(clearResponse)) + assert(clearResponse.serverSparkVersion === SPARK_VERSION) + assert(clearResponse.success) + } + /* ---------------------------------------- * | Aberrant client / server behavior | * ---------------------------------------- */ @@ -505,6 +514,15 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { } } + /** Return the response as a clear response, or fail with error otherwise. */ + private def getClearResponse(response: SubmitRestProtocolResponse): ClearResponse = { + response match { + case k: ClearResponse => k + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected clear response. Actual: ${r.toJson}") + } + } + /** Return the response as a status response, or fail with error otherwise. */ private def getStatusResponse(response: SubmitRestProtocolResponse): SubmissionStatusResponse = { response match { @@ -574,6 +592,8 @@ private class DummyMaster( context.reply(KillDriverResponse(self, driverId, success = true, killMessage)) case RequestDriverStatus(driverId) => context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception)) + case RequestClearCompletedDriversAndApps => + context.reply(true) } } @@ -617,6 +637,7 @@ private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEn * When handling a submit request, the server returns a malformed JSON. * When handling a kill request, the server returns an invalid JSON. * When handling a status request, the server throws an internal exception. + * When handling a clear request, the server throws an internal exception. * The purpose of this class is to test that client handles these cases gracefully. */ private class FaultyStandaloneRestServer( @@ -630,6 +651,7 @@ private class FaultyStandaloneRestServer( protected override val submitRequestServlet = new MalformedSubmitServlet protected override val killRequestServlet = new InvalidKillServlet protected override val statusRequestServlet = new ExplodingStatusServlet + protected override val clearRequestServlet = new ExplodingClearServlet /** A faulty servlet that produces malformed responses. */ class MalformedSubmitServlet @@ -660,4 +682,15 @@ private class FaultyStandaloneRestServer( s } } + + /** A faulty clear servlet that explodes. 
*/ + class ExplodingClearServlet extends StandaloneClearRequestServlet(masterEndpoint, masterConf) { + private def explode: Int = 1 / 0 + + protected override def handleClear(): ClearResponse = { + val s = super.handleClear() + s.message = explode.toString + s + } + } } From 1ca543b5595ebfff4c46500df0ef7715c440c050 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 7 Nov 2023 10:12:16 -0800 Subject: [PATCH 054/121] [SPARK-45808][CONNECT][PYTHON] Better error handling for SQL Exceptions ### What changes were proposed in this pull request? This patch optimizes the handling of errors reported back to Python. First, it properly allows the extraction of the `ERROR_CLASS` and the `SQL_STATE` and gives simpler accces to the stack trace. It therefore makes sure that the display of the stack trace is no longer only server-side decided but becomes a local usability property. In addition the following methods on the `SparkConnectGrpcException` become actually useful: * `getSqlState()` * `getErrorClass()` * `getStackTrace()` ### Why are the changes needed? Compatibility ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Updated the existing tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43667 from grundprinzip/SPARK-XXXX-ex. Authored-by: Martin Grund Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/ClientE2ETestSuite.scala | 3 +- ...SparkConnectFetchErrorDetailsHandler.scala | 6 +- .../spark/sql/connect/utils/ErrorUtils.scala | 14 ++ .../FetchErrorDetailsHandlerSuite.scala | 14 +- .../SparkConnectSessionHolderSuite.scala | 102 +++++----- python/pyspark/errors/exceptions/base.py | 2 +- python/pyspark/errors/exceptions/captured.py | 2 +- python/pyspark/errors/exceptions/connect.py | 178 +++++++++++++++--- python/pyspark/sql/connect/client/core.py | 13 +- .../sql/tests/connect/test_connect_basic.py | 25 +-- 10 files changed, 258 insertions(+), 101 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index b9fa415034c3e..10c928f130416 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -136,8 +136,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assert( ex.getStackTrace .find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis")) - .isDefined - == isServerStackTraceEnabled) + .isDefined) } } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala index 17a6e9e434f37..b5a3c986d169b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala @@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto import org.apache.spark.connect.proto.FetchErrorDetailsResponse -import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.utils.ErrorUtils -import org.apache.spark.sql.internal.SQLConf /** * Handles 
[[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler @@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler( ErrorUtils.throwableToFetchErrorDetailsResponse( st = error, - serverStackTraceEnabled = sessionHolder.session.conf.get( - Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get( - SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED)) + serverStackTraceEnabled = true) } .getOrElse(FetchErrorDetailsResponse.newBuilder().build()) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 744fa3c8aa1a4..7cb555ca47ec9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging { "classes", JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) + // Add the SQL State and Error Class to the response metadata of the ErrorInfoObject. + st match { + case e: SparkThrowable => + val state = e.getSqlState + if (state != null && state.nonEmpty) { + errorInfo.putMetadata("sqlState", state) + } + val errorClass = e.getErrorClass + if (errorClass != null && errorClass.nonEmpty) { + errorInfo.putMetadata("errorClass", errorClass) + } + case _ => + } + if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) { // Generate a new unique key for this exception. val errorId = UUID.randomUUID().toString diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala index 40439a2172303..ebcd1de600573 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala @@ -103,15 +103,11 @@ class FetchErrorDetailsHandlerSuite extends SharedSparkSession with ResourceHelp assert(response.getErrors(1).getErrorTypeHierarchy(1) == classOf[Throwable].getName) assert(response.getErrors(1).getErrorTypeHierarchy(2) == classOf[Object].getName) assert(!response.getErrors(1).hasCauseIdx) - if (serverStacktraceEnabled) { - assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length) - assert( - response.getErrors(1).getStackTraceCount == - testError.getCause.getStackTrace.length) - } else { - assert(response.getErrors(0).getStackTraceCount == 0) - assert(response.getErrors(1).getStackTraceCount == 0) - } + assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length) + assert( + response.getErrors(1).getStackTraceCount == + testError.getCause.getStackTrace.length) + } finally { sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key) sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala index 910c2a2650c6b..9845cee31037c 100644 --- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala @@ -169,6 +169,56 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { accumulator = null) } + test("python listener process: process terminates after listener is removed") { + // scalastyle:off assume + assume(IntegratedUDFTestUtils.shouldTestPandasUDFs) + // scalastyle:on assume + + val sessionHolder = SessionHolder.forTesting(spark) + try { + SparkConnectService.start(spark.sparkContext) + + val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction) + + val id1 = "listener_removeListener_test_1" + val id2 = "listener_removeListener_test_2" + val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder) + val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder) + + sessionHolder.cacheListenerById(id1, listener1) + spark.streams.addListener(listener1) + sessionHolder.cacheListenerById(id2, listener2) + spark.streams.addListener(listener2) + + val (runner1, runner2) = (listener1.runner, listener2.runner) + + // assert both python processes are running + assert(!runner1.isWorkerStopped().get) + assert(!runner2.isWorkerStopped().get) + + // remove listener1 + spark.streams.removeListener(listener1) + sessionHolder.removeCachedListener(id1) + // assert listener1's python process is not running + eventually(timeout(30.seconds)) { + assert(runner1.isWorkerStopped().get) + assert(!runner2.isWorkerStopped().get) + } + + // remove listener2 + spark.streams.removeListener(listener2) + sessionHolder.removeCachedListener(id2) + eventually(timeout(30.seconds)) { + // assert listener2's python process is not running + assert(runner2.isWorkerStopped().get) + // all listeners are removed + assert(spark.streams.listListeners().isEmpty) + } + } finally { + SparkConnectService.stop() + } + } + test("python foreachBatch process: process terminates after query is stopped") { // scalastyle:off assume assume(IntegratedUDFTestUtils.shouldTestPandasUDFs) @@ -232,58 +282,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { assert(spark.streams.listListeners().length == 1) // only process termination listener } finally { SparkConnectService.stop() + // Wait for things to calm down. 
+ Thread.sleep(4.seconds.toMillis) // remove process termination listener spark.streams.listListeners().foreach(spark.streams.removeListener) } } - - test("python listener process: process terminates after listener is removed") { - // scalastyle:off assume - assume(IntegratedUDFTestUtils.shouldTestPandasUDFs) - // scalastyle:on assume - - val sessionHolder = SessionHolder.forTesting(spark) - try { - SparkConnectService.start(spark.sparkContext) - - val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction) - - val id1 = "listener_removeListener_test_1" - val id2 = "listener_removeListener_test_2" - val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder) - val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder) - - sessionHolder.cacheListenerById(id1, listener1) - spark.streams.addListener(listener1) - sessionHolder.cacheListenerById(id2, listener2) - spark.streams.addListener(listener2) - - val (runner1, runner2) = (listener1.runner, listener2.runner) - - // assert both python processes are running - assert(!runner1.isWorkerStopped().get) - assert(!runner2.isWorkerStopped().get) - - // remove listener1 - spark.streams.removeListener(listener1) - sessionHolder.removeCachedListener(id1) - // assert listener1's python process is not running - eventually(timeout(30.seconds)) { - assert(runner1.isWorkerStopped().get) - assert(!runner2.isWorkerStopped().get) - } - - // remove listener2 - spark.streams.removeListener(listener2) - sessionHolder.removeCachedListener(id2) - eventually(timeout(30.seconds)) { - // assert listener2's python process is not running - assert(runner2.isWorkerStopped().get) - // all listeners are removed - assert(spark.streams.listListeners().isEmpty) - } - } finally { - SparkConnectService.stop() - } - } } diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py index 1d09a68dffbfe..518a2d99ce889 100644 --- a/python/pyspark/errors/exceptions/base.py +++ b/python/pyspark/errors/exceptions/base.py @@ -75,7 +75,7 @@ def getMessageParameters(self) -> Optional[Dict[str, str]]: """ return self.message_parameters - def getSqlState(self) -> None: + def getSqlState(self) -> Optional[str]: """ Returns an SQLSTATE as a string. 
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py index d62b7d24347e7..55ed7ab3a6d52 100644 --- a/python/pyspark/errors/exceptions/captured.py +++ b/python/pyspark/errors/exceptions/captured.py @@ -107,7 +107,7 @@ def getMessageParameters(self) -> Optional[Dict[str, str]]: else: return None - def getSqlState(self) -> Optional[str]: # type: ignore[override] + def getSqlState(self) -> Optional[str]: assert SparkContext._gateway is not None gw = SparkContext._gateway if self._origin is not None and is_instance_of( diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index 423fb2c6f0acc..2558c425469a1 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -46,55 +46,155 @@ class SparkConnectException(PySparkException): def convert_exception( - info: "ErrorInfo", truncated_message: str, resp: Optional[pb2.FetchErrorDetailsResponse] + info: "ErrorInfo", + truncated_message: str, + resp: Optional[pb2.FetchErrorDetailsResponse], + display_server_stacktrace: bool = False, ) -> SparkConnectException: classes = [] + sql_state = None + error_class = None + + stacktrace: Optional[str] = None + if "classes" in info.metadata: classes = json.loads(info.metadata["classes"]) + if "sqlState" in info.metadata: + sql_state = info.metadata["sqlState"] + + if "errorClass" in info.metadata: + error_class = info.metadata["errorClass"] + if resp is not None and resp.HasField("root_error_idx"): message = resp.errors[resp.root_error_idx].message stacktrace = _extract_jvm_stacktrace(resp) else: message = truncated_message - stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else "" - - if len(stacktrace) > 0: - message += f"\n\nJVM stacktrace:\n{stacktrace}" + stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None + display_server_stacktrace = display_server_stacktrace if stacktrace is not None else False if "org.apache.spark.sql.catalyst.parser.ParseException" in classes: - return ParseException(message) + return ParseException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) # Order matters. ParseException inherits AnalysisException. elif "org.apache.spark.sql.AnalysisException" in classes: - return AnalysisException(message) + return AnalysisException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes: - return StreamingQueryException(message) + return StreamingQueryException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "org.apache.spark.sql.execution.QueryExecutionException" in classes: - return QueryExecutionException(message) + return QueryExecutionException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) # Order matters. NumberFormatException inherits IllegalArgumentException. 
elif "java.lang.NumberFormatException" in classes: - return NumberFormatException(message) + return NumberFormatException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "java.lang.IllegalArgumentException" in classes: - return IllegalArgumentException(message) + return IllegalArgumentException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "java.lang.ArithmeticException" in classes: - return ArithmeticException(message) + return ArithmeticException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "java.lang.UnsupportedOperationException" in classes: - return UnsupportedOperationException(message) + return UnsupportedOperationException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "java.lang.ArrayIndexOutOfBoundsException" in classes: - return ArrayIndexOutOfBoundsException(message) + return ArrayIndexOutOfBoundsException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "java.time.DateTimeException" in classes: - return DateTimeException(message) + return DateTimeException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "org.apache.spark.SparkRuntimeException" in classes: - return SparkRuntimeException(message) + return SparkRuntimeException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "org.apache.spark.SparkUpgradeException" in classes: - return SparkUpgradeException(message) + return SparkUpgradeException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) elif "org.apache.spark.api.python.PythonException" in classes: return PythonException( "\n An exception was thrown from the Python worker. " "Please see the stack trace below.\n%s" % message ) + # Make sure that the generic SparkException is handled last. 
+ elif "org.apache.spark.SparkException" in classes: + return SparkException( + message, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) else: - return SparkConnectGrpcException(message, reason=info.reason) + return SparkConnectGrpcException( + message, + reason=info.reason, + error_class=error_class, + sql_state=sql_state, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, + ) def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str: @@ -106,7 +206,7 @@ def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str: def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None: message = f"{error.error_type_hierarchy[0]}: {error.message}" if len(lines) == 0: - lines.append(message) + lines.append(error.error_type_hierarchy[0]) else: lines.append(f"Caused by: {message}") for elem in error.stack_trace: @@ -135,16 +235,48 @@ def __init__( error_class: Optional[str] = None, message_parameters: Optional[Dict[str, str]] = None, reason: Optional[str] = None, + sql_state: Optional[str] = None, + server_stacktrace: Optional[str] = None, + display_server_stacktrace: bool = False, ) -> None: self.message = message # type: ignore[assignment] if reason is not None: self.message = f"({reason}) {self.message}" + # PySparkException has the assumption that error_class and message_parameters are + # only occurring together. If only one is set, we assume the message to be fully + # parsed. + tmp_error_class = error_class + tmp_message_parameters = message_parameters + if error_class is not None and message_parameters is None: + tmp_error_class = None + elif error_class is None and message_parameters is not None: + tmp_message_parameters = None + super().__init__( message=self.message, - error_class=error_class, - message_parameters=message_parameters, + error_class=tmp_error_class, + message_parameters=tmp_message_parameters, ) + self.error_class = error_class + self._sql_state: Optional[str] = sql_state + self._stacktrace: Optional[str] = server_stacktrace + self._display_stacktrace: bool = display_server_stacktrace + + def getSqlState(self) -> Optional[str]: + if self._sql_state is not None: + return self._sql_state + else: + return super().getSqlState() + + def getStackTrace(self) -> Optional[str]: + return self._stacktrace + + def __str__(self) -> str: + desc = self.message + if self._display_stacktrace: + desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace + return desc class AnalysisException(SparkConnectGrpcException, BaseAnalysisException): @@ -223,3 +355,7 @@ class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException """ Exception thrown because of Spark upgrade from Spark Connect. 
""" + + +class SparkException(SparkConnectGrpcException): + """ """ diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 11a1112ad1fe7..cef0ea4f305df 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1564,6 +1564,14 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet except grpc.RpcError: return None + def _display_server_stack_trace(self) -> bool: + from pyspark.sql.connect.conf import RuntimeConf + + conf = RuntimeConf(self) + if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true": + return True + return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true" + def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: """ Error handling helper for dealing with GRPC Errors. On the server side, certain @@ -1594,7 +1602,10 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: d.Unpack(info) raise convert_exception( - info, status.message, self._fetch_enriched_error(info) + info, + status.message, + self._fetch_enriched_error(info), + self._display_server_stack_trace(), ) from None raise SparkConnectGrpcException(status.message) from None diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f024a03c2686c..daf6772e52bf5 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3378,35 +3378,37 @@ def test_error_enrichment_jvm_stacktrace(self): """select from_json( '{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""" ).collect() - self.assertTrue("JVM stacktrace" in e.exception.message) - self.assertTrue("org.apache.spark.SparkUpgradeException:" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertTrue("org.apache.spark.SparkUpgradeException" in str(e.exception)) self.assertTrue( "at org.apache.spark.sql.errors.ExecutionErrors" - ".failToParseDateTimeInNewParserError" in e.exception.message + ".failToParseDateTimeInNewParserError" in str(e.exception) ) - self.assertTrue("Caused by: java.time.DateTimeException:" in e.exception.message) + self.assertTrue("Caused by: java.time.DateTimeException:" in str(e.exception)) def test_not_hitting_netty_header_limit(self): with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): with self.assertRaises(AnalysisException): - self.spark.sql("select " + "test" * 10000).collect() + self.spark.sql("select " + "test" * 1).collect() def test_error_stack_trace(self): with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}): with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): with self.assertRaises(AnalysisException) as e: self.spark.sql("select x").collect() - self.assertTrue("JVM stacktrace" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertIsNotNone(e.exception.getStackTrace()) self.assertTrue( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}): with self.assertRaises(AnalysisException) as e: self.spark.sql("select x").collect() - self.assertFalse("JVM stacktrace" in e.exception.message) + self.assertFalse("JVM stacktrace" in str(e.exception)) + self.assertIsNone(e.exception.getStackTrace()) self.assertFalse( - "at 
org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) # Create a new session with a different stack trace size. @@ -3421,9 +3423,10 @@ def test_error_stack_trace(self): spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True) with self.assertRaises(AnalysisException) as e: spark.sql("select x").collect() - self.assertTrue("JVM stacktrace" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertIsNotNone(e.exception.getStackTrace()) self.assertFalse( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) spark.stop() From 57fc1abc1ac85e49a3df71a7327085a6aa39ecb0 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 11:25:58 -0800 Subject: [PATCH 055/121] [SPARK-45811][PYTHON][DOCS][FOLLOW-UP] Remove an example Spark Connect does not support ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/43680 that fixes the failing test with Spark Connect. ### Why are the changes needed? To fix the CI. ### Does this PR introduce _any_ user-facing change? No, the documentation change has not been released out. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43704 from HyukjinKwon/SPARK-45811. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 95821deeec173..ae0f1e70be675 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -13670,20 +13670,7 @@ def from_xml( Examples -------- - Example 1: Parsing XML with a :class:`StructType` schema - - >>> import pyspark.sql.functions as sf - >>> from pyspark.sql.types import StructType, StructField, LongType - ... # Sample data with an XML column - >>> data = [(1, '''

    <p><a>1</a></p>

    ''')] - >>> df = spark.createDataFrame(data, ("key", "value")) - ... # Define the schema using a StructType - >>> schema = StructType([StructField("a", LongType())]) - ... # Parse the XML column using the specified schema - >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect() - [Row(xml=Row(a=1))] - - Example 2: Parsing XML with a DDL-formatted string schema + Example 1: Parsing XML with a DDL-formatted string schema >>> import pyspark.sql.functions as sf >>> data = [(1, '''

    <p><a>1</a></p>

    ''')] @@ -13694,7 +13681,7 @@ def from_xml( >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect() [Row(xml=Row(a=1))] - Example 3: Parsing XML with :class:`ArrayType` in schema + Example 2: Parsing XML with :class:`ArrayType` in schema >>> import pyspark.sql.functions as sf >>> data = [(1, '

    <p><a>1</a><a>2</a></p>

    ')] @@ -13705,7 +13692,7 @@ def from_xml( >>> df.select(sf.from_xml(df.value, schema).alias("xml")).collect() [Row(xml=Row(a=[1, 2]))] - Example 4: Parsing XML using :meth:`pyspark.sql.functions.schema_of_xml` + Example 3: Parsing XML using :meth:`pyspark.sql.functions.schema_of_xml` >>> import pyspark.sql.functions as sf >>> # Sample data with an XML column From b0791b513da3f0671417b9fbcd3a0caddbb45318 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 7 Nov 2023 13:51:37 -0800 Subject: [PATCH 056/121] [SPARK-45223][PYTHON][DOCS] Refine docstring of `Column.when` ### What changes were proposed in this pull request? This PR proposes to improve the docstring of `Column.when`. ### Why are the changes needed? For end users, and better usability of PySpark. ### Does this PR introduce _any_ user-facing change? Yes, it fixes the user facing documentation. ### How was this patch tested? Manually tested. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43688 from HyukjinKwon/SPARK-45223. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/column.py | 40 +++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 203e53474f74a..9357b4842bbdb 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -1388,10 +1388,12 @@ def when(self, condition: "Column", value: Any) -> "Column": Examples -------- + Example 1: Using :func:`when` with conditions and values to create a new Column + >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame( - ... [(2, "Alice"), (5, "Bob")], ["age", "name"]) - >>> df.select(df.name, sf.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"]) + >>> result = df.select(df.name, sf.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)) + >>> result.show() +-----+------------------------------------------------------------+ | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END| +-----+------------------------------------------------------------+ @@ -1399,6 +1401,38 @@ def when(self, condition: "Column", value: Any) -> "Column": | Bob| 1| +-----+------------------------------------------------------------+ + Example 2: Chaining multiple :func:`when` conditions + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(1, "Alice"), (4, "Bob"), (6, "Charlie")], ["age", "name"]) + >>> result = df.select( + ... df.name, + ... sf.when(df.age < 3, "Young").when(df.age < 5, "Middle-aged").otherwise("Old") + ... ) + >>> result.show() + +-------+---------------------------------------------------------------------------+ + | name|CASE WHEN (age < 3) THEN Young WHEN (age < 5) THEN Middle-aged ELSE Old END| + +-------+---------------------------------------------------------------------------+ + | Alice| Young| + | Bob| Middle-aged| + |Charlie| Old| + +-------+---------------------------------------------------------------------------+ + + Example 3: Using literal values as conditions + + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"]) + >>> result = df.select( + ... df.name, sf.when(sf.lit(True), 1).otherwise( + ... 
sf.raise_error("unreachable")).alias("when")) + >>> result.show() + +-----+----+ + | name|when| + +-----+----+ + |Alice| 1| + | Bob| 1| + +-----+----+ + See Also -------- pyspark.sql.functions.when From 5a49af205411feaf0f5aee07f5d6d122e10bfe1f Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Tue, 7 Nov 2023 16:30:49 -0800 Subject: [PATCH 057/121] [SPARK-45555][PYTHON] Includes a debuggable object for failed assertion ### What changes were proposed in this pull request? This PR proposes to enhanced the `assertDataFrameEqual` function to support an optional `includeDiffRows` parameter. This parameter, will return the rows from both DataFrames that are not equal when set to `True`. ### Why are the changes needed? This enhancement provides users with an easier debugging experience by directly pointing out the rows that do not match, eliminating the need for manual comparison in case of large DataFrames. ### Does this PR introduce _any_ user-facing change? Yes. An optional parameter `includeDiffRows` has been introduced in the `assertDataFrameEqual` function. When set to `True`, it will return unequal rows for further analysis. For example: ```python df1 = spark.createDataFrame( data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"]) df2 = spark.createDataFrame( data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"]) try: assertDataFrameEqual(df1, df2, includeDiffRows=True) except PySparkAssertionError as e: spark.createDataFrame(e.data).show() ``` The above code will produce the following DataFrame: ``` +-----------+-----------+ | _1| _2| +-----------+-----------+ |{1, 1000.0}|{1, 1001.0}| |{3, 2000.0}|{3, 2003.0}| +-----------+-----------+ ``` ### How was this patch tested? Added usage example into doctest. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43444 from itholic/SPARK-45555. Authored-by: Haejoon Lee Signed-off-by: Hyukjin Kwon --- python/pyspark/errors/exceptions/base.py | 15 +++++++++- python/pyspark/sql/tests/test_utils.py | 22 ++++++++++++++ python/pyspark/testing/utils.py | 37 ++++++++++++++++++++---- 3 files changed, 68 insertions(+), 6 deletions(-) diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py index 518a2d99ce889..5ab73b63d362b 100644 --- a/python/pyspark/errors/exceptions/base.py +++ b/python/pyspark/errors/exceptions/base.py @@ -15,11 +15,14 @@ # limitations under the License. # -from typing import Dict, Optional, cast +from typing import Dict, Optional, cast, Iterable, TYPE_CHECKING from pyspark.errors.utils import ErrorClassesReader from pickle import PicklingError +if TYPE_CHECKING: + from pyspark.sql.types import Row + class PySparkException(Exception): """ @@ -222,6 +225,16 @@ class PySparkAssertionError(PySparkException, AssertionError): Wrapper class for AssertionError to support error classes. 
""" + def __init__( + self, + message: Optional[str] = None, + error_class: Optional[str] = None, + message_parameters: Optional[Dict[str, str]] = None, + data: Optional[Iterable["Row"]] = None, + ): + super().__init__(message, error_class, message_parameters) + self.data = data + class PySparkNotImplementedError(PySparkException, NotImplementedError): """ diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index 421043a41bb47..ebdab31ec2075 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -1614,6 +1614,28 @@ def test_list_row_unequal_schema(self): message_parameters={"error_msg": error_msg}, ) + def test_dataframe_include_diff_rows(self): + df1 = self.spark.createDataFrame( + [("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], ["id", "amount"] + ) + df2 = self.spark.createDataFrame( + [("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], ["id", "amount"] + ) + + with self.assertRaises(PySparkAssertionError) as context: + assertDataFrameEqual(df1, df2, includeDiffRows=True) + + # Extracting the differing rows data from the exception + error_data = context.exception.data + + # Expected differences + expected_diff = [ + (Row(id="1", amount=1000.0), Row(id="1", amount=1001.0)), + (Row(id="3", amount=2000.0), Row(id="3", amount=2003.0)), + ] + + self.assertEqual(error_data, expected_diff) + def test_dataframe_ignore_column_order(self): df1 = self.spark.createDataFrame([Row(A=1, B=2), Row(A=3, B=4)]) df2 = self.spark.createDataFrame([Row(B=2, A=1), Row(B=4, A=3)]) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 8b2332208c190..5d284ffc7956b 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -484,6 +484,7 @@ def assertDataFrameEqual( ignoreColumnType: bool = False, maxErrors: Optional[int] = None, showOnlyDiff: bool = False, + includeDiffRows=False, ): r""" A util function to assert equality between `actual` and `expected` @@ -560,6 +561,11 @@ def assertDataFrameEqual( If set to `False` (default), the error message will include all rows (when there is at least one row that is different). + .. versionadded:: 4.0.0 + includeDiffRows: bool, False + If set to `True`, the unequal rows are included in PySparkAssertionError for further + debugging. If set to `False` (default), the unequal rows are not returned as a data set. + .. versionadded:: 4.0.0 Notes @@ -704,6 +710,24 @@ def assertDataFrameEqual( *** expected *** ! Row(_1=2, _2='X') ! Row(_1=3, _2='Y') + + The `includeDiffRows` parameter can be used to include the rows that did not match + in the PySparkAssertionError. This can be useful for debugging or further analysis. + + >>> df1 = spark.createDataFrame( + ... data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"]) + >>> df2 = spark.createDataFrame( + ... data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"]) + >>> try: + ... assertDataFrameEqual(df1, df2, includeDiffRows=True) + ... except PySparkAssertionError as e: + ... 
spark.createDataFrame(e.data).show() # doctest: +NORMALIZE_WHITESPACE + +-----------+-----------+ + | _1| _2| + +-----------+-----------+ + |{1, 1000.0}|{1, 1001.0}| + |{3, 2000.0}|{3, 2003.0}| + +-----------+-----------+ """ if actual is None and expected is None: return True @@ -843,7 +867,8 @@ def assert_rows_equal( ): zipped = list(zip_longest(rows1, rows2)) diff_rows_cnt = 0 - diff_rows = False + diff_rows = [] + has_diff_rows = False rows_str1 = "" rows_str2 = "" @@ -852,7 +877,9 @@ def assert_rows_equal( for r1, r2 in zipped: if not compare_rows(r1, r2): diff_rows_cnt += 1 - diff_rows = True + has_diff_rows = True + if includeDiffRows: + diff_rows.append((r1, r2)) rows_str1 += str(r1) + "\n" rows_str2 += str(r2) + "\n" if maxErrors is not None and diff_rows_cnt >= maxErrors: @@ -865,14 +892,14 @@ def assert_rows_equal( actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped) ) - if diff_rows: + if has_diff_rows: error_msg = "Results do not match: " percent_diff = (diff_rows_cnt / len(zipped)) * 100 error_msg += "( %.5f %% )" % percent_diff error_msg += "\n" + "\n".join(generated_diff) + data = diff_rows if includeDiffRows else None raise PySparkAssertionError( - error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": error_msg}, + error_class="DIFFERENT_ROWS", message_parameters={"error_msg": error_msg}, data=data ) # only compare schema if expected is not a List From bfcd2c47742d69632584f6bb3cf08a8bc3ca8c8c Mon Sep 17 00:00:00 2001 From: Hasnain Lakhani Date: Tue, 7 Nov 2023 21:16:33 -0600 Subject: [PATCH 058/121] [SPARK-45431][DOCS] Document new SSL RPC feature ### What changes were proposed in this pull request? Add some documentation to clarify the new flags being added and how this feature interacts with existing SSL and encryption support. It's a little confusing so feedback is welcome. Note I am not sure if it's best practice to merge this after the feature is fully merged - thought I would put this up now since it's unblocked. ### Why are the changes needed? New features require documentation so users can understand how to use them. ### Does this PR introduce _any_ user-facing change? Yes, this is adding documentation for a new feature ### How was this patch tested? Not applicable, this is documentation ### Was this patch authored or co-authored using generative AI tooling? No Closes #43240 from hasnain-db/spark-tls-docs. Authored-by: Hasnain Lakhani Signed-off-by: Mridul Muralidharan gmail.com> --- docs/security.md | 109 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 104 insertions(+), 5 deletions(-) diff --git a/docs/security.md b/docs/security.md index 0a6a3cf089185..2a1105fea33fe 100644 --- a/docs/security.md +++ b/docs/security.md @@ -147,7 +147,26 @@ Note that when using files, Spark will not mount these files into the containers you to ensure that the secret files are deployed securely into your containers and that the driver's secret file agrees with the executors' secret file. -## Encryption +# Network Encryption + +Spark supports two mutually exclusive forms of encryption for RPC connections. + +The first is an AES-based encryption which relies on a shared secret, and thus requires +RPC authentication to also be enabled. + +The second is an SSL based encryption mechanism utilizing Netty's support for SSL. This requires +keys and certificates to be properly configured. It can be used with or without the authentication +mechanism discussed earlier. 
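As a rough illustrative sketch only (not part of the patch itself; the paths and passwords below are placeholders), the SSL based RPC encryption described above would typically be wired up through the namespaced `spark.ssl.rpc.*` options covered later in this section, for example in `spark-defaults.conf`:

```
# SSL for RPC must be enabled explicitly; it is not inherited from spark.ssl.enabled
spark.ssl.rpc.enabled              true
# Standard ${ns} key store / trust store options, applied to the spark.ssl.rpc namespace
spark.ssl.rpc.keyStore             /path/to/keystore.jks
spark.ssl.rpc.keyStorePassword     keystore-password
spark.ssl.rpc.trustStore           /path/to/truststore.jks
spark.ssl.rpc.trustStorePassword   truststore-password
```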
+ +One may prefer to use the SSL based encryption in scenarios where compliance mandates the usage +of specific protocols; or to leverage the security of a more standard encryption library. However, +the AES based encryption is simpler to configure and may be preferred if the only requirement +is that data be encrypted in transit. + +If both options are enabled in the configuration, the SSL based RPC encryption takes precedence +and the AES based encryption will not be used (and a warning message will be emitted). + +## AES based Encryption Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC authentication must also be enabled and properly configured. AES encryption uses the @@ -209,6 +228,17 @@ The following table describes the different options available for configuring th +## SSL Encryption + +Spark supports SSL based encryption for RPC connections. Please refer to the SSL Configuration +section below to understand how to configure it. The SSL settings are mostly similar across the UI +and RPC, however there are a few additional settings which are specific to the RPC implementation. +The RPC implementation uses Netty under the hood (while the UI uses Jetty), which supports a +different set of options. + +Unlike the other SSL settings for the UI, the RPC SSL is *not* automatically enabled if +`spark.ssl.enabled` is set. It must be explicitly enabled, to ensure a safe migration path for users +upgrading Spark versions. # Local Storage Encryption @@ -437,8 +467,10 @@ application configurations will be ignored. Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the -protocols without disabling the ability to configure each one individually. The following table -describes the SSL configuration namespaces: +protocols without disabling the ability to configure each one individually. Note that all settings +are inherited this way, *except* for `spark.ssl.rpc.enabled` which must be explicitly set. + +The following table describes the SSL configuration namespaces: @@ -466,17 +498,22 @@ describes the SSL configuration namespaces: + + + +
    spark.ssl.historyServer History Server Web UI
    spark.ssl.rpc Spark RPC communication
    The full breakdown of available SSL options can be found below. The `${ns}` placeholder should be replaced with one of the above namespaces. - + + @@ -490,6 +527,7 @@ replaced with one of the above namespaces.
    When not set, the SSL port will be derived from the non-SSL port for the same service. A value of "0" will make the service bind to an ephemeral port. + @@ -504,6 +542,7 @@ replaced with one of the above namespaces.
    Note: If not set, the default cipher suite for the JRE will be used. + @@ -511,6 +550,7 @@ replaced with one of the above namespaces. + @@ -519,16 +559,19 @@ replaced with one of the above namespaces. Path to the key store file. The path can be absolute or relative to the directory in which the process is started. + + + @@ -541,11 +584,15 @@ replaced with one of the above namespaces. this page. + - + + @@ -554,16 +601,68 @@ replaced with one of the above namespaces. Path to the trust store file. The path can be absolute or relative to the directory in which the process is started. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    Property NameDefaultMeaningSupported Namespaces
    ${ns}.enabled false Enables SSL. When enabled, ${ns}.ssl.protocol is required.ui,standalone,historyServer,rpc
    ${ns}.portui,standalone,historyServer
    ${ns}.enabledAlgorithmsui,standalone,historyServer,rpc
    ${ns}.keyPassword The password to the private key in the key store. ui,standalone,historyServer,rpc
    ${ns}.keyStoreui,standalone,historyServer,rpc
    ${ns}.keyStorePassword None Password to the key store.ui,standalone,historyServer,rpc
    ${ns}.keyStoreType JKS The type of the key store.ui,standalone,historyServer
    ${ns}.protocolui,standalone,historyServer,rpc
    ${ns}.needClientAuth falseWhether to require client authentication. + Whether to require client authentication. + ui,standalone,historyServer
    ${ns}.trustStoreui,standalone,historyServer,rpc
    ${ns}.trustStorePassword None Password for the trust store.ui,standalone,historyServer,rpc
    ${ns}.trustStoreType JKS The type of the trust store.ui,standalone,historyServer
    ${ns}.openSSLEnabledfalse + Whether to use OpenSSL for cryptographic operations instead of the JDK SSL provider. + This setting requires the `certChain` and `privateKey` settings to be set. + This takes precedence over the `keyStore` and `trustStore` settings if both are specified. + If the OpenSSL library is not available at runtime, we will fall back to the JDK provider. + rpc
    ${ns}.privateKeyNone + Path to the private key file in PEM format. The path can be absolute or relative to the + directory in which the process is started. + This setting is required when using the OpenSSL implementation. + rpc
    ${ns}.certChainNone + Path to the certificate chain file in PEM format. The path can be absolute or relative to the + directory in which the process is started. + This setting is required when using the OpenSSL implementation. + rpc
    ${ns}.trustStoreReloadingEnabledfalse + Whether the trust store should be reloaded periodically. + This setting is mostly only useful in standalone deployments, not k8s or yarn deployments. + rpc
    ${ns}.trustStoreReloadIntervalMs10000 + The interval at which the trust store should be reloaded (in milliseconds). + This setting is mostly only useful in standalone deployments, not k8s or yarn deployments. + rpc
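
To illustrate the RPC-only OpenSSL options above, a PEM-based variant might look roughly like this (a sketch; the file paths are placeholders, and as noted above Spark falls back to the JDK provider if the OpenSSL library is not available at runtime):

```python
from pyspark import SparkConf

# Sketch: PEM/OpenSSL flavour of the RPC SSL settings documented in the table above.
conf = (
    SparkConf()
    .set("spark.ssl.rpc.enabled", "true")
    .set("spark.ssl.rpc.openSSLEnabled", "true")
    .set("spark.ssl.rpc.privateKey", "/etc/spark/tls/key.pem")        # placeholder path
    .set("spark.ssl.rpc.certChain", "/etc/spark/tls/cert-chain.pem")  # placeholder path
)
```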
    From 44522756382dd628c2278d22df4cf7db4461079d Mon Sep 17 00:00:00 2001 From: Chaoqin Li Date: Wed, 8 Nov 2023 14:04:39 +0900 Subject: [PATCH 059/121] [SPARK-45794][SS] Introduce state metadata source to query the streaming state metadata information MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Introduce a new data source so that user can query the metadata of each state store of a streaming query, the schema of the result will be following: column | type --------- | ------- | operatorId | INT | | operatorName | STRING | | stateStoreName | STRING | | numPartitions | INT | | minBatchId | LONG | | minBatchId | LONG | | _numColsPrefixKey (metadata column) | INT | To use this source, specify the source format and checkpoint path and load the dataframe `df = spark.read.format(“state-metadata”).load(“/checkpointPath”)` ### Why are the changes needed? To improve debugability. Also facilitate the query of state store data source introduced in SPARK-45511 by displaying the operator id, batch id and state store name. ### Does this PR introduce _any_ user-facing change? Yes, this is a new source exposed to user. ### How was this patch tested? Add test to verify the output of state metadata ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43660 from chaoqin-li1123/state_metadata_source. Authored-by: Chaoqin Li Signed-off-by: Jungtaek Lim --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../state/metadata/StateMetadataSource.scala | 214 ++++++++++++++++++ .../state/OperatorStateMetadataSuite.scala | 37 ++- 3 files changed, 252 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 3169e75031fca..b4c18c38f04aa 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -27,4 +27,5 @@ org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider org.apache.spark.sql.execution.datasources.binaryfile.BinaryFileFormat -org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider \ No newline at end of file +org.apache.spark.sql.execution.streaming.sources.RatePerMicroBatchProvider +org.apache.spark.sql.execution.datasources.v2.state.StateMetadataSource \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala new file mode 100644 index 0000000000000..8a74db8d19639 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.util + +import scala.jdk.CollectionConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration + +case class StateMetadataTableEntry( + operatorId: Long, + operatorName: String, + stateStoreName: String, + numPartitions: Int, + minBatchId: Long, + maxBatchId: Long, + numColsPrefixKey: Int) { + def toRow(): InternalRow = { + InternalRow.fromSeq( + Seq(operatorId, + UTF8String.fromString(operatorName), + UTF8String.fromString(stateStoreName), + numPartitions, + minBatchId, + maxBatchId, + numColsPrefixKey)) + } +} + +object StateMetadataTableEntry { + private[sql] val schema = { + new StructType() + .add("operatorId", LongType) + .add("operatorName", StringType) + .add("stateStoreName", StringType) + .add("numPartitions", IntegerType) + .add("minBatchId", LongType) + .add("maxBatchId", LongType) + } +} + +class StateMetadataSource extends TableProvider with DataSourceRegister { + override def shortName(): String = "state-metadata" + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + new StateMetadataTable + } + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + // The schema of state metadata table is static. 
+ StateMetadataTableEntry.schema + } +} + + +class StateMetadataTable extends Table with SupportsRead with SupportsMetadataColumns { + override def name(): String = "state-metadata-table" + + override def schema(): StructType = StateMetadataTableEntry.schema + + override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + () => { + if (!options.containsKey("path")) { + throw new IllegalArgumentException("Checkpoint path is not specified for" + + " state metadata data source.") + } + new StateMetadataScan(options.get("path")) + } + } + + private object NumColsPrefixKeyColumn extends MetadataColumn { + override def name: String = "_numColsPrefixKey" + override def dataType: DataType = IntegerType + override def comment: String = "Number of columns in prefix key of the state store instance" + } + + override val metadataColumns: Array[MetadataColumn] = Array(NumColsPrefixKeyColumn) +} + +case class StateMetadataInputPartition(checkpointLocation: String) extends InputPartition + +class StateMetadataScan(checkpointLocation: String) extends Scan { + override def readSchema: StructType = StateMetadataTableEntry.schema + + override def toBatch: Batch = { + new Batch { + override def planInputPartitions(): Array[InputPartition] = { + Array(StateMetadataInputPartition(checkpointLocation)) + } + + override def createReaderFactory(): PartitionReaderFactory = { + // Don't need to broadcast the hadoop conf because this source only has one partition. + val conf = new SerializableConfiguration(SparkSession.active.sessionState.newHadoopConf()) + StateMetadataPartitionReaderFactory(conf) + } + } + } +} + +case class StateMetadataPartitionReaderFactory(hadoopConf: SerializableConfiguration) + extends PartitionReaderFactory { + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + new StateMetadataPartitionReader( + partition.asInstanceOf[StateMetadataInputPartition].checkpointLocation, hadoopConf) + } +} + +class StateMetadataPartitionReader( + checkpointLocation: String, + serializedHadoopConf: SerializableConfiguration) extends PartitionReader[InternalRow] { + + override def next(): Boolean = { + stateMetadata.hasNext + } + + override def get(): InternalRow = { + stateMetadata.next().toRow() + } + + override def close(): Unit = {} + + private def pathToLong(path: Path) = { + path.getName.toLong + } + + private def pathNameCanBeParsedAsLong(path: Path) = { + try { + pathToLong(path) + true + } catch { + case _: NumberFormatException => false + } + } + + // Return true when the filename can be parsed as long integer. + private val pathNameCanBeParsedAsLongFilter = new PathFilter { + override def accept(path: Path): Boolean = pathNameCanBeParsedAsLong(path) + } + + private lazy val hadoopConf: Configuration = serializedHadoopConf.value + + private lazy val fileManager = + CheckpointFileManager.create(new Path(checkpointLocation), hadoopConf) + + // List the commit log entries to find all the available batch ids. 
+ private def batchIds: Array[Long] = { + val commitLog = new Path(checkpointLocation, "commits") + if (fileManager.exists(commitLog)) { + fileManager + .list(commitLog, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted + } else Array.empty + } + + private def allOperatorStateMetadata: Array[OperatorStateMetadata] = { + val stateDir = new Path(checkpointLocation, "state") + val opIds = fileManager + .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted + opIds.map { opId => + new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read() + } + } + + private lazy val stateMetadata: Iterator[StateMetadataTableEntry] = { + allOperatorStateMetadata.flatMap { operatorStateMetadata => + require(operatorStateMetadata.version == 1) + val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1] + operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata => + StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId, + operatorStateMetadataV1.operatorInfo.operatorName, + stateStoreMetadata.storeName, + stateStoreMetadata.numPartitions, + if (batchIds.nonEmpty) batchIds.head else -1, + if (batchIds.nonEmpty) batchIds.last else -1, + stateStoreMetadata.numColsPrefixKey + ) + } + } + }.iterator +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala index 48cc17bbbabf2..340187fa49514 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.hadoop.fs.Path -import org.apache.spark.sql.Column +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} @@ -53,6 +53,15 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { val operatorMetadata = OperatorStateMetadataV1(operatorInfo, stateStoreInfo.toArray) new OperatorStateMetadataWriter(statePath, hadoopConf).write(operatorMetadata) checkOperatorStateMetadata(checkpointDir.toString, 0, operatorMetadata) + val df = spark.read.format("state-metadata").load(checkpointDir.toString) + // Commit log is empty, there is no available batch id. 
+ checkAnswer(df, Seq(Row(1, "Join", "store1", 200, -1L, -1L), + Row(1, "Join", "store2", 200, -1L, -1L), + Row(1, "Join", "store3", 200, -1L, -1L), + Row(1, "Join", "store4", 200, -1L, -1L) + )) + checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), + Seq(Row(1), Row(1), Row(1), Row(1))) } } @@ -105,6 +114,16 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { val expectedMetadata = OperatorStateMetadataV1( OperatorInfoV1(0, "symmetricHashJoin"), expectedStateStoreInfo) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) + + val df = spark.read.format("state-metadata") + .load(checkpointDir.toString) + checkAnswer(df, Seq(Row(0, "symmetricHashJoin", "left-keyToNumValues", 5, 0L, 1L), + Row(0, "symmetricHashJoin", "left-keyWithIndexToValue", 5, 0L, 1L), + Row(0, "symmetricHashJoin", "right-keyToNumValues", 5, 0L, 1L), + Row(0, "symmetricHashJoin", "right-keyWithIndexToValue", 5, 0L, 1L) + )) + checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), + Seq(Row(0), Row(0), Row(0), Row(0))) } } @@ -147,6 +166,10 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { Array(StateStoreMetadataV1("default", 1, spark.sessionState.conf.numShufflePartitions)) ) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata) + + val df = spark.read.format("state-metadata").load(checkpointDir.toString) + checkAnswer(df, Seq(Row(0, "sessionWindowStateStoreSaveExec", "default", 5, 0L, 0L))) + checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(1))) } } @@ -176,6 +199,18 @@ class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession { Array(StateStoreMetadataV1("default", 0, numShufflePartitions))) checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata0) checkOperatorStateMetadata(checkpointDir.toString, 1, expectedMetadata1) + + val df = spark.read.format("state-metadata").load(checkpointDir.toString) + checkAnswer(df, Seq(Row(0, "stateStoreSave", "default", 5, 0L, 1L), + Row(1, "stateStoreSave", "default", 5, 0L, 1L))) + checkAnswer(df.select(df.metadataColumn("_numColsPrefixKey")), Seq(Row(0), Row(0))) + } + } + + test("State metadata data source handle missing argument") { + val e = intercept[IllegalArgumentException] { + spark.read.format("state-metadata").load().collect() } + assert(e.getMessage == "Checkpoint path is not specified for state metadata data source.") } } From d582b74d97a7e44bdd2f0ae6c63121fc5e5466b7 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 8 Nov 2023 11:38:55 +0300 Subject: [PATCH 060/121] [SPARK-45824][SQL] Enforce the error class in `ParseException` ### What changes were proposed in this pull request? In the PR, I propose to enforce creation of `ParseException` with an error class always. In particular, it converts the constructor with a message to private one, so, callers have to create `ParseException` with an error class. ### Why are the changes needed? This simplifies migration on error classes. ### Does this PR introduce _any_ user-facing change? No since user code doesn't throw `ParseException` in regular cases. ### How was this patch tested? By existing test suites, for instance: ``` $ build/sbt "sql/testOnly *QueryParsingErrorsSuite" $ build/sbt "test:testOnly *SparkConnectClientSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43702 from MaxGekk/ban-message-ParseException. 
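
As a quick illustration of the user-visible contract (a sketch assuming an active `spark` session), a parse error surfaced to PySpark now always carries an error class:

```python
from pyspark.errors import ParseException

try:
    spark.sql("SELEC 1")  # malformed SQL goes through the parser error path
except ParseException as e:
    # Every ParseException is expected to expose an error class, e.g. PARSE_SYNTAX_ERROR.
    print(e.getErrorClass())
```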
Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../client/SparkConnectClientSuite.scala | 15 ++---- .../client/GrpcExceptionConverter.scala | 3 +- .../spark/sql/catalyst/parser/parsers.scala | 53 ++++++++++++++----- .../spark/sql/errors/QueryParsingErrors.scala | 8 ++- 4 files changed, 51 insertions(+), 28 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index b3ff4eb0bb296..d0c85da5f212e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -31,7 +31,6 @@ import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.connect.common.config.ConnectCommon import org.apache.spark.sql.test.ConnectFunSuite @@ -210,19 +209,15 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } for ((name, constructor) <- GrpcExceptionConverter.errorFactory) { - test(s"error framework parameters - ${name}") { + test(s"error framework parameters - $name") { val testParams = GrpcExceptionConverter.ErrorParams( - message = "test message", + message = "Found duplicate keys `abc`", cause = None, - errorClass = Some("test error class"), - messageParameters = Map("key" -> "value"), + errorClass = Some("DUPLICATE_KEY"), + messageParameters = Map("keyColumn" -> "`abc`"), queryContext = Array.empty) val error = constructor(testParams) - if (!error.isInstanceOf[ParseException]) { - assert(error.getMessage == testParams.message) - } else { - assert(error.getMessage == s"\n${testParams.message}") - } + assert(error.getMessage.contains(testParams.message)) assert(error.getCause == null) error match { case sparkThrowable: SparkThrowable => diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 652797bc2e40f..88cd2118ba755 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -191,10 +191,9 @@ private[client] object GrpcExceptionConverter { errorConstructor(params => new ParseException( None, - params.message, Origin(), Origin(), - errorClass = params.errorClass, + errorClass = params.errorClass.orNull, messageParameters = params.messageParameters, queryContext = params.queryContext)), errorConstructor(params => diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index 22e6c67090b4d..2689e317128a8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ 
-23,7 +23,7 @@ import org.antlr.v4.runtime.atn.PredictionMode import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} import org.antlr.v4.runtime.tree.TerminalNodeImpl -import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper} +import org.apache.spark.{QueryContext, SparkException, SparkThrowable, SparkThrowableHelper} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin} @@ -99,10 +99,9 @@ abstract class AbstractParser extends DataTypeParserInterface with Logging { case e: SparkThrowable with WithOrigin => throw new ParseException( command = Option(command), - message = e.getMessage, start = e.origin, stop = e.origin, - errorClass = Option(e.getErrorClass), + errorClass = e.getErrorClass, messageParameters = e.getMessageParameters.asScala.toMap, queryContext = e.getQueryContext) } @@ -174,7 +173,12 @@ case object ParseErrorListener extends BaseErrorListener { case sre: SparkRecognitionException if sre.errorClass.isDefined => throw new ParseException(None, start, stop, sre.errorClass.get, sre.messageParameters) case _ => - throw new ParseException(None, msg, start, stop) + throw new ParseException( + command = None, + start = start, + stop = stop, + errorClass = "PARSE_SYNTAX_ERROR", + messageParameters = Map("error" -> msg, "hint" -> "")) } } } @@ -183,7 +187,7 @@ case object ParseErrorListener extends BaseErrorListener { * A [[ParseException]] is an [[SparkException]] that is thrown during the parse process. It * contains fields and an extended error message that make reporting and diagnosing errors easier. */ -class ParseException( +class ParseException private( val command: Option[String], message: String, val start: Origin, @@ -223,7 +227,24 @@ class ParseException( start, stop, Some(errorClass), - messageParameters) + messageParameters, + queryContext = ParseException.getQueryContext()) + + def this( + command: Option[String], + start: Origin, + stop: Origin, + errorClass: String, + messageParameters: Map[String, String], + queryContext: Array[QueryContext]) = + this( + command, + SparkThrowableHelper.getMessage(errorClass, messageParameters), + start, + stop, + Some(errorClass), + messageParameters, + queryContext) override def getMessage: String = { val builder = new StringBuilder @@ -247,17 +268,21 @@ class ParseException( } def withCommand(cmd: String): ParseException = { - val (cls, params) = - if (errorClass == Some("PARSE_SYNTAX_ERROR") && cmd.trim().isEmpty) { - // PARSE_EMPTY_STATEMENT error class overrides the PARSE_SYNTAX_ERROR when cmd is empty - (Some("PARSE_EMPTY_STATEMENT"), Map.empty[String, String]) - } else { - (errorClass, messageParameters) - } - new ParseException(Option(cmd), message, start, stop, cls, params, queryContext) + val cl = getErrorClass + val (newCl, params) = if (cl == "PARSE_SYNTAX_ERROR" && cmd.trim().isEmpty) { + // PARSE_EMPTY_STATEMENT error class overrides the PARSE_SYNTAX_ERROR when cmd is empty + ("PARSE_EMPTY_STATEMENT", Map.empty[String, String]) + } else { + (cl, messageParameters) + } + new ParseException(Option(cmd), start, stop, newCl, params, queryContext) } override def getQueryContext: Array[QueryContext] = queryContext + + override def getErrorClass: String = errorClass.getOrElse { + throw SparkException.internalError("ParseException shall have an error class.") + } } object ParseException { diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index f63fc8c4785bc..2067bf7d0955d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -431,8 +431,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { } def sqlStatementUnsupportedError(sqlText: String, position: Origin): Throwable = { - new ParseException(Option(sqlText), "Unsupported SQL statement", position, position, - Some("_LEGACY_ERROR_TEMP_0039")) + new ParseException( + command = Option(sqlText), + start = position, + stop = position, + errorClass = "_LEGACY_ERROR_TEMP_0039", + messageParameters = Map.empty) } def invalidIdentifierError(ident: String, ctx: ErrorIdentContext): Throwable = { From a9127068194a48786df4f429ceb4f908c71f7138 Mon Sep 17 00:00:00 2001 From: chenyu <119398199+chenyu-opensource@users.noreply.github.com> Date: Wed, 8 Nov 2023 19:16:48 +0800 Subject: [PATCH 061/121] [SPARK-45829][DOCS] Update the default value for spark.executor.logs.rolling.maxSize **What changes were proposed in this pull request?** The PR updates the default value of 'spark.executor.logs.rolling.maxSize' in configuration.html on the website **Why are the changes needed?** The default value of 'spark.executor.logs.rolling.maxSize' is 1024 * 1024, but the website is wrong. **Does this PR introduce any user-facing change?** No **How was this patch tested?** It doesn't need to. **Was this patch authored or co-authored using generative AI tooling?** No Closes #43712 from chenyu-opensource/branch-SPARK-45829. Authored-by: chenyu <119398199+chenyu-opensource@users.noreply.github.com> Signed-off-by: Kent Yao --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 60cad24e71c44..3d54aaf6518be 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -684,7 +684,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.logs.rolling.maxSize - (none) + 1024 * 1024 Set the max size of the file in bytes by which the executor logs will be rolled over. Rolling is disabled by default. See spark.executor.logs.rolling.maxRetainedFiles From 1d8df4f6b99b836f4267b888e81d67c75b4dfdcd Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 8 Nov 2023 19:43:33 +0800 Subject: [PATCH 062/121] [SPARK-45606][SQL] Release restrictions on multi-layer runtime filter ### What changes were proposed in this pull request? Before https://github.com/apache/spark/pull/39170, Spark only supports insert runtime filter for application side of shuffle join on single-layer. Considered it's not worth to insert more runtime filter if the column already exists runtime filter, Spark restricts it at https://github.com/apache/spark/blob/7057952f6bc2c5cf97dd408effd1b18bee1cb8f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala#L346 For example `select * from bf1 join bf2 on bf1.c1 = bf2.c2 and bf1.c1 = bf2.b2 where bf2.a2 = 62` This SQL have two join conditions. There will insert two runtime filter on `bf1.c1` if haven't the restriction mentioned above. At that time, it was reasonable. After https://github.com/apache/spark/pull/39170, Spark supports insert runtime filter for one side of any shuffle join on multi-layer. 
But the restrictions on multi-layer runtime filter mentioned above looks outdated. For example `select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5` Assume bf2 as the build side and insert a runtime filter for bf1. We can't insert the same runtime filter for bf3 due to there are already a runtime filter on `bf1.c1`. The behavior is different from the origin and is unexpected. The change of the PR doesn't affect the restriction mentioned above. ### Why are the changes needed? Release restrictions on multi-layer runtime filter. Expand optimization surface. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? Test cases updated. Micro benchmark for q9 in TPC-H. **TPC-H 100** Query | Master(ms) | PR(ms) | Difference(ms) | Percent -- | -- | -- | -- | -- q9 | 26491 | 20725 | 5766| 27.82% ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #43449 from beliefer/SPARK-45606. Authored-by: Jiaan Geng Signed-off-by: Jiaan Geng --- .../optimizer/InjectRuntimeFilter.scala | 33 +++++++++---------- .../spark/sql/InjectRuntimeFilterSuite.scala | 8 ++--- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 5f5508d6b22c2..9c150f1f3308f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -247,15 +247,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J } } - private def hasBloomFilter( - left: LogicalPlan, - right: LogicalPlan, - leftKey: Expression, - rightKey: Expression): Boolean = { - findBloomFilterWithKey(left, leftKey) || findBloomFilterWithKey(right, rightKey) - } - - private def findBloomFilterWithKey(plan: LogicalPlan, key: Expression): Boolean = { + private def hasBloomFilter(plan: LogicalPlan, key: Expression): Boolean = { plan.exists { case Filter(condition, _) => splitConjunctivePredicates(condition).exists { @@ -277,28 +269,33 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J leftKeys.lazyZip(rightKeys).foreach((l, r) => { // Check if: // 1. There is already a DPP filter on the key - // 2. There is already a bloom filter on the key - // 3. The keys are simple cheap expressions + // 2. The keys are simple cheap expressions if (filterCounter < numFilterThreshold && !hasDynamicPruningSubquery(left, right, l, r) && - !hasBloomFilter(newLeft, newRight, l, r) && isSimpleExpression(l) && isSimpleExpression(r)) { val oldLeft = newLeft val oldRight = newRight - // Check if the current join is a shuffle join or a broadcast join that - // has a shuffle below it + // Check if: + // 1. The current join type supports prune the left side with runtime filter + // 2. The current join is a shuffle join or a broadcast join that + // has a shuffle below it + // 3. 
There is no bloom filter on the left key yet val hasShuffle = isProbablyShuffleJoin(left, right, hint) - if (canPruneLeft(joinType) && (hasShuffle || probablyHasShuffle(left))) { + if (canPruneLeft(joinType) && (hasShuffle || probablyHasShuffle(left)) && + !hasBloomFilter(newLeft, l)) { extractBeneficialFilterCreatePlan(left, right, l, r).foreach { case (filterCreationSideKey, filterCreationSidePlan) => newLeft = injectFilter(l, newLeft, filterCreationSideKey, filterCreationSidePlan) } } // Did we actually inject on the left? If not, try on the right - // Check if the current join is a shuffle join or a broadcast join that - // has a shuffle below it + // Check if: + // 1. The current join type supports prune the right side with runtime filter + // 2. The current join is a shuffle join or a broadcast join that + // has a shuffle below it + // 3. There is no bloom filter on the right key yet if (newLeft.fastEquals(oldLeft) && canPruneRight(joinType) && - (hasShuffle || probablyHasShuffle(right))) { + (hasShuffle || probablyHasShuffle(right)) && !hasBloomFilter(newRight, r)) { extractBeneficialFilterCreatePlan(right, left, r, l).foreach { case (filterCreationSideKey, filterCreationSidePlan) => newRight = injectFilter( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala index 2e57975ee6d1d..fc1524be13179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -335,14 +335,12 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2) assertRewroteWithBloomFilter("select * from (select * from bf1 right join bf2 on " + "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2) - // Can't leverage the transitivity of join keys due to runtime filters already exists. - // bf2 as creation side and inject runtime filter for bf1. assertRewroteWithBloomFilter("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " + - "and bf3.c3 = bf1.c1 where bf2.a2 = 5") + "and bf3.c3 = bf1.c1 where bf2.a2 = 5", 2) assertRewroteWithBloomFilter("select * from bf1 left outer join bf2 join bf3 on " + - "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5") + "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5", 2) assertRewroteWithBloomFilter("select * from bf1 right outer join bf2 join bf3 on " + - "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5") + "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5", 2) } withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", From f866549a5aa86f379cb71732b97fa547f2c4eb0a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 8 Nov 2023 20:21:55 +0800 Subject: [PATCH 063/121] [SPARK-45816][SQL] Return `NULL` when overflowing during casting from timestamp to integers ### What changes were proposed in this pull request? Spark cast works in two modes: ansi and non-ansi. When overflowing during casting, the common behavior under non-ansi mode is to return null. However, casting from Timestamp to Int/Short/Byte returns a wrapping value now. The behavior to silently overflow doesn't make sense. This patch changes it to the common behavior, i.e., returning null. ### Why are the changes needed? 
Returning a wrapping value, e.g., negative one, during casting Timestamp to Int/Short/Byte could implicitly cause misinterpret casted result without caution. We also should follow the common behavior of overflowing handling. ### Does this PR introduce _any_ user-facing change? Yes. Overflowing during casting from Timestamp to Int/Short/Byte under non-ansi mode, returns null instead of wrapping value. ### How was this patch tested? Will add test or update test if any existing ones fail ### Was this patch authored or co-authored using generative AI tooling? No Closes #43694 from viirya/fix_cast_integers. Authored-by: Liang-Chi Hsieh Signed-off-by: Jiaan Geng --- docs/sql-migration-guide.md | 1 + .../spark/sql/catalyst/expressions/Cast.scala | 51 +++++++++++-------- .../expressions/CastWithAnsiOffSuite.scala | 6 +-- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index b0dc49ed47683..5c00ce6558513 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -28,6 +28,7 @@ license: | - Since Spark 4.0, any read of SQL tables takes into consideration the SQL configs `spark.sql.files.ignoreCorruptFiles`/`spark.sql.files.ignoreMissingFiles` instead of the core config `spark.files.ignoreCorruptFiles`/`spark.files.ignoreMissingFiles`. - Since Spark 4.0, `spark.sql.hive.metastore` drops the support of Hive prior to 2.0.0 as they require JDK 8 that Spark does not support anymore. Users should migrate to higher versions. - Since Spark 4.0, `spark.sql.parquet.compression.codec` drops the support of codec name `lz4raw`, please use `lz4_raw` instead. +- Since Spark 4.0, when overflowing during casting timestamp to byte/short/int under non-ansi mode, Spark will return null instead a wrapping value. 
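
For instance, with ANSI mode disabled the following cast now yields `NULL` rather than a wrapped byte value (a sketch assuming an active `spark` session):

```python
# 1900-05-05 is far enough in the past that its epoch seconds overflow a byte/short/int.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT CAST(TIMESTAMP '1900-05-05 18:34:56' AS TINYINT) AS b").show()
# Previously this produced a wrapped, misleading small integer; it now shows NULL.
```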
## Upgrading from Spark SQL 3.4 to 3.5 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 62295fe260535..ee022c068b987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -344,6 +344,7 @@ object Cast extends QueryErrorsBase { case (StringType, _) => true case (_, StringType) => false + case (TimestampType, ByteType | ShortType | IntegerType) => true case (FloatType | DoubleType, TimestampType) => true case (TimestampType, DateType) => false case (_, DateType) => true @@ -777,6 +778,14 @@ case class Cast( buildCast[Int](_, i => yearMonthIntervalToInt(i, x.startField, x.endField).toLong) } + private def errorOrNull(t: Any, from: DataType, to: DataType) = { + if (ansiEnabled) { + throw QueryExecutionErrors.castingCauseOverflowError(t, from, to) + } else { + null + } + } + // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType if ansiEnabled => @@ -788,17 +797,15 @@ case class Cast( buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType => buildCast[Long](_, t => { val longValue = timestampToLong(t) if (longValue == longValue.toInt) { longValue.toInt } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, from, IntegerType) + errorOrNull(t, from, IntegerType) } }) - case TimestampType => - buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => exactNumeric.toInt(b) @@ -826,17 +833,15 @@ case class Cast( buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType => buildCast[Long](_, t => { val longValue = timestampToLong(t) if (longValue == longValue.toShort) { longValue.toShort } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, from, ShortType) + errorOrNull(t, from, ShortType) } }) - case TimestampType => - buildCast[Long](_, t => timestampToLong(t).toShort) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => @@ -875,17 +880,15 @@ case class Cast( buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType => buildCast[Long](_, t => { val longValue = timestampToLong(t) if (longValue == longValue.toByte) { longValue.toByte } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, from, ByteType) + errorOrNull(t, from, ByteType) } }) - case TimestampType => - buildCast[Long](_, t => timestampToLong(t).toByte) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => @@ -1661,22 +1664,26 @@ case class Cast( integralType: String, from: DataType, to: DataType): CastFunction = { - if (ansiEnabled) { - val longValue = ctx.freshName("longValue") - val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) - val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) - (c, evPrim, _) => - code""" + + val longValue = ctx.freshName("longValue") + val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) + val toDt = ctx.addReferenceObj("to", to, 
to.getClass.getName) + + (c, evPrim, evNull) => + val overflow = if (ansiEnabled) { + code"""throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);""" + } else { + code"$evNull = true;" + } + + code""" long $longValue = ${timestampToLongCode(c)}; if ($longValue == ($integralType) $longValue) { $evPrim = ($integralType) $longValue; } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); + $overflow } """ - } else { - (c, evPrim, _) => code"$evPrim = ($integralType) ${timestampToLongCode(c)};" - } } private[this] def castDayTimeIntervalToIntegralTypeCode( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala index 1dbf03b1538a6..e260b6fdbdb52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala @@ -514,9 +514,9 @@ class CastWithAnsiOffSuite extends CastSuiteBase { val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1") assert(negativeTs.getTime < 0) val expectedSecs = Math.floorDiv(negativeTs.getTime, MILLIS_PER_SECOND) - checkEvaluation(cast(negativeTs, ByteType), expectedSecs.toByte) - checkEvaluation(cast(negativeTs, ShortType), expectedSecs.toShort) - checkEvaluation(cast(negativeTs, IntegerType), expectedSecs.toInt) + checkEvaluation(cast(negativeTs, ByteType), null) + checkEvaluation(cast(negativeTs, ShortType), null) + checkEvaluation(cast(negativeTs, IntegerType), null) checkEvaluation(cast(negativeTs, LongType), expectedSecs) } } From 6abc4a1a58ef4e5d896717b10b2314dae2af78af Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 8 Nov 2023 15:51:50 +0300 Subject: [PATCH 064/121] [SPARK-45841][SQL] Expose stack trace by `DataFrameQueryContext` ### What changes were proposed in this pull request? In the PR, I propose to change the case class `DataFrameQueryContext`, and add stack traces as a field and override `callSite`, `fragment` using the new field `stackTrace`. ### Why are the changes needed? By exposing the stack trace, we give users opportunity to see all stack traces needed for debugging. ### Does this PR introduce _any_ user-facing change? No, `DataFrameQueryContext` hasn't been released yet. ### How was this patch tested? By running the modified test suite: ``` $ build/sbt "test:testOnly *DatasetSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43703 from MaxGekk/stack-traces-in-DataFrameQueryContext. 
Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../sql/catalyst/trees/QueryContexts.scala | 33 ++++++++----------- .../org/apache/spark/sql/DatasetSuite.scala | 13 +++++--- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 8d885d07ca8b0..874c834b75585 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -134,9 +134,7 @@ case class SQLQueryContext( override def callSite: String = throw new UnsupportedOperationException } -case class DataFrameQueryContext( - override val fragment: String, - override val callSite: String) extends QueryContext { +case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends QueryContext { override val contextType = QueryContextType.DataFrame override def objectType: String = throw new UnsupportedOperationException @@ -144,6 +142,19 @@ case class DataFrameQueryContext( override def startIndex: Int = throw new UnsupportedOperationException override def stopIndex: Int = throw new UnsupportedOperationException + override val fragment: String = { + stackTrace.headOption.map { firstElem => + val methodName = firstElem.getMethodName + if (methodName.length > 1 && methodName(0) == '$') { + methodName.substring(1) + } else { + methodName + } + }.getOrElse("") + } + + override val callSite: String = stackTrace.tail.headOption.map(_.toString).getOrElse("") + override lazy val summary: String = { val builder = new StringBuilder builder ++= "== DataFrame ==\n" @@ -157,19 +168,3 @@ case class DataFrameQueryContext( builder.result() } } - -object DataFrameQueryContext { - def apply(elements: Array[StackTraceElement]): DataFrameQueryContext = { - val fragment = elements.headOption.map { firstElem => - val methodName = firstElem.getMethodName - if (methodName.length > 1 && methodName(0) == '$') { - methodName.substring(1) - } else { - methodName - } - }.getOrElse("") - val callSite = elements.tail.headOption.map(_.toString).getOrElse("") - - DataFrameQueryContext(fragment, callSite) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 66105d2ac429f..dcbd8948120ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoders, ExpressionEncod import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BoxedIntEncoder import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRowWithSchema} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -2668,16 +2669,18 @@ class DatasetSuite extends QueryTest withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { val df = Seq(1).toDS() var callSitePattern: String = null + val exception = intercept[AnalysisException] { + callSitePattern = getNextLineCallSitePattern() + val c = col("a") + df.select(c) + } checkError( - exception = intercept[AnalysisException] 
{ - callSitePattern = getNextLineCallSitePattern() - val c = col("a") - df.select(c) - }, + exception, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map("objectName" -> "`a`", "proposal" -> "`value`"), context = ExpectedContext(fragment = "col", callSitePattern = callSitePattern)) + assert(exception.context.head.asInstanceOf[DataFrameQueryContext].stackTrace.length == 2) } } } From b5408e1ce61ce2195de72dcf79d8355c16b4b92a Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 Nov 2023 08:24:52 -0800 Subject: [PATCH 065/121] [SPARK-45828][SQL] Remove deprecated method in dsl ### What changes were proposed in this pull request? The pr aims to remove `some deprecated method` in dsl. ### Why are the changes needed? After https://github.com/apache/spark/pull/36646 (Apache Spark 3.4.0), the method `def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()` and `def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan)` has been marked as `deprecated` and we need to remove it in `Spark 4.0`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43708 from panbingkun/SPARK-45828. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/sql/catalyst/dsl/package.scala | 6 ------ .../sql/catalyst/optimizer/TransposeWindowSuite.scala | 8 ++++---- .../spark/sql/catalyst/plans/LogicalPlanSuite.scala | 6 +++--- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5f85716fa2833..30d4c2dbb409f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -152,9 +152,6 @@ package object dsl { def desc: SortOrder = SortOrder(expr, Descending) def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() - // TODO: Remove at Spark 4.0.0 - @deprecated("Use as(alias: String)", "3.4.0") - def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } trait ExpressionConversions { @@ -468,9 +465,6 @@ package object dsl { limit: Int): LogicalPlan = WindowGroupLimit(partitionSpec, orderSpec, rankLikeFunction, limit, logicalPlan) - // TODO: Remove at Spark 4.0.0 - @deprecated("Use subquery(alias: String)", "3.4.0") - def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def subquery(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala index 8d4c2de10e34f..f4d520bbb4439 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala @@ -146,15 +146,15 @@ class TransposeWindowSuite extends PlanTest { test("SPARK-38034: transpose two adjacent windows with compatible partitions " + "which is not a prefix") { val query = testRelation - .window(Seq(sum(c).as(Symbol("sum_a_2"))), 
partitionSpec4, orderSpec2) - .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec3, orderSpec1) + .window(Seq(sum(c).as("sum_a_2")), partitionSpec4, orderSpec2) + .window(Seq(sum(c).as("sum_a_1")), partitionSpec3, orderSpec1) val analyzed = query.analyze val optimized = Optimize.execute(analyzed) val correctAnswer = testRelation - .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec3, orderSpec1) - .window(Seq(sum(c).as(Symbol("sum_a_2"))), partitionSpec4, orderSpec2) + .window(Seq(sum(c).as("sum_a_1")), partitionSpec3, orderSpec1) + .window(Seq(sum(c).as("sum_a_2")), partitionSpec4, orderSpec2) .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("sum_a_2"), Symbol("sum_a_1")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index ea0fcac881c7a..3eba9eebc3d5f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -126,12 +126,12 @@ class LogicalPlanSuite extends SparkFunSuite { assert(sort2.maxRows === Some(100)) assert(sort2.maxRowsPerPartition === Some(100)) - val c1 = Literal(1).as(Symbol("a")).toAttribute.newInstance().withNullability(true) - val c2 = Literal(2).as(Symbol("b")).toAttribute.newInstance().withNullability(true) + val c1 = Literal(1).as("a").toAttribute.newInstance().withNullability(true) + val c2 = Literal(2).as("b").toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), Symbol("b")), Seq(Symbol("a"), Literal(null))), Seq(c1, c2), - sort.select(Symbol("id") as Symbol("a"), Symbol("id") + 1 as Symbol("b"))) + sort.select(Symbol("id") as "a", Symbol("id") + 1 as "b")) assert(expand.maxRows === Some(200)) assert(expand.maxRowsPerPartition === Some(68)) From e331de06dd0526761c804b32640e3471ce772d38 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Wed, 8 Nov 2023 08:26:23 -0800 Subject: [PATCH 066/121] [MINOR][CORE][SQL] Clean up expired comments: `Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation.` ### What changes were proposed in this pull request? This pr just clean up expired comments: `Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation.` ### Why are the changes needed? Apache Spark 4.0 only support Scala 2.13, so these comments are no longer needed ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? No testing required ### Was this patch authored or co-authored using generative AI tooling? No Closes #43718 from LuciferYang/minor-comments. 
Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/util/BoundedPriorityQueue.scala | 2 -- .../org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala | 2 -- .../apache/spark/sql/catalyst/expressions/AttributeMap.scala | 2 -- .../apache/spark/sql/execution/streaming/StreamProgress.scala | 2 -- 4 files changed, 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala index ccb4d2063ff3b..9fed2373ea552 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -31,8 +31,6 @@ import scala.jdk.CollectionConverters._ private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) extends Iterable[A] with Growable[A] with Serializable { - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. - private val underlying = new JPriorityQueue[A](maxSize, ord) override def iterator: Iterator[A] = underlying.iterator.asScala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index e18a01810d2eb..640304efce4b4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -30,8 +30,6 @@ import java.util.Locale class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T] with Serializable { - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. - val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase(Locale.ROOT))) override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase(Locale.ROOT)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index ac6149f3acc4d..b317cacc061b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -41,8 +41,6 @@ object AttributeMap { class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) extends Map[Attribute, A] with Serializable { - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. - override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 = get(k).getOrElse(default) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 6aa1b46cbb94a..02f52bb30e1f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -29,8 +29,6 @@ class StreamProgress( new immutable.HashMap[SparkDataStream, OffsetV2]) extends scala.collection.immutable.Map[SparkDataStream, OffsetV2] { - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. 
- def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) } From 9d93b7112a31965447a34301889f90d14578e628 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 8 Nov 2023 09:23:12 -0800 Subject: [PATCH 067/121] [SPARK-45639][SQL][PYTHON] Support loading Python data sources in DataFrameReader ### What changes were proposed in this pull request? This PR supports `spark.read.format(...).load()` for Python data sources. After this PR, users can use a Python data source directly like this: ```python from pyspark.sql.datasource import DataSource, DataSourceReader class MyReader(DataSourceReader): def read(self, partition): yield (0, 1) class MyDataSource(DataSource): classmethod def name(cls): return "my-source" def schema(self): return "id INT, value INT" def reader(self, schema): return MyReader() spark.dataSource.register(MyDataSource) df = spark.read.format("my-source").load() df.show() +---+-----+ | id|value| +---+-----+ | 0| 1| +---+-----+ ``` ### Why are the changes needed? To support Python data sources. ### Does this PR introduce _any_ user-facing change? Yes. After this PR, users can load a custom Python data source using `spark.read.format(...).load()`. ### How was this patch tested? New unit tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43630 from allisonwang-db/spark-45639-ds-lookup. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- .../main/resources/error/error-classes.json | 12 +++ dev/sparktestsupport/modules.py | 1 + docs/sql-error-conditions.md | 12 +++ python/pyspark/sql/session.py | 4 + .../sql/tests/test_python_datasource.py | 97 +++++++++++++++++-- .../pyspark/sql/worker/create_data_source.py | 16 ++- .../sql/errors/QueryCompilationErrors.scala | 12 +++ .../apache/spark/sql/DataFrameReader.scala | 48 +++++++-- .../datasources/DataSourceManager.scala | 31 +++++- .../python/UserDefinedPythonDataSource.scala | 15 ++- .../python/PythonDataSourceSuite.scala | 35 +++++++ 11 files changed, 255 insertions(+), 28 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index db46ee8ca208c..c38171c3d9e63 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -850,6 +850,12 @@ ], "sqlState" : "42710" }, + "DATA_SOURCE_NOT_EXIST" : { + "message" : [ + "Data source '' not found. Please make sure the data source is registered." + ], + "sqlState" : "42704" + }, "DATA_SOURCE_NOT_FOUND" : { "message" : [ "Failed to find the data source: . Please find packages at `https://spark.apache.org/third-party-projects.html`." @@ -1095,6 +1101,12 @@ ], "sqlState" : "42809" }, + "FOUND_MULTIPLE_DATA_SOURCES" : { + "message" : [ + "Detected multiple data sources with the name ''. Please check the data source isn't simultaneously registered and located in the classpath." + ], + "sqlState" : "42710" + }, "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : { "message" : [ "A column cannot have both a default value and a generation expression but column has default value: () and generation expression: ()." 
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 95c9069a83131..01757ba28dd23 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -511,6 +511,7 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_udf_window", "pyspark.sql.tests.pandas.test_converter", "pyspark.sql.tests.test_pandas_sqlmetrics", + "pyspark.sql.tests.test_python_datasource", "pyspark.sql.tests.test_readwriter", "pyspark.sql.tests.test_serde", "pyspark.sql.tests.test_session", diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 7b0bc8ceb2b5a..8a5faa15dc9cd 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -454,6 +454,12 @@ DataType `` requires a length parameter, for example ``(10). Please Data source '``' already exists in the registry. Please use a different name for the new data source. +### DATA_SOURCE_NOT_EXIST + +[SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Data source '``' not found. Please make sure the data source is registered. + ### DATA_SOURCE_NOT_FOUND [SQLSTATE: 42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -669,6 +675,12 @@ No such struct field `` in ``. The operation `` is not allowed on the ``: ``. +### FOUND_MULTIPLE_DATA_SOURCES + +[SQLSTATE: 42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Detected multiple data sources with the name '``'. Please check the data source isn't simultaneously registered and located in the classpath. + ### GENERATED_COLUMN_WITH_DEFAULT_VALUE [SQLSTATE: 42623](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 4ab7281d7ac87..85aff09aa3df1 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -884,6 +884,10 @@ def dataSource(self) -> "DataSourceRegistration": Returns ------- :class:`DataSourceRegistration` + + Notes + ----- + This feature is experimental and unstable. """ from pyspark.sql.datasource import DataSourceRegistration diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index b429d73fb7d77..fe6a841752746 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -14,10 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os import unittest from pyspark.sql.datasource import DataSource, DataSourceReader +from pyspark.sql.types import Row +from pyspark.testing import assertDataFrameEqual from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.testing.utils import SPARK_HOME class BasePythonDataSourceTestsMixin: @@ -45,16 +49,93 @@ def read(self, partition): self.assertEqual(list(reader.partitions()), [None]) self.assertEqual(list(reader.read(None)), [(None,)]) - def test_register_data_source(self): - class MyDataSource(DataSource): - ... 
+ def test_in_memory_data_source(self): + class InMemDataSourceReader(DataSourceReader): + DEFAULT_NUM_PARTITIONS: int = 3 + + def __init__(self, paths, options): + self.paths = paths + self.options = options + + def partitions(self): + if "num_partitions" in self.options: + num_partitions = int(self.options["num_partitions"]) + else: + num_partitions = self.DEFAULT_NUM_PARTITIONS + return range(num_partitions) + + def read(self, partition): + yield partition, str(partition) + + class InMemoryDataSource(DataSource): + @classmethod + def name(cls): + return "memory" + + def schema(self): + return "x INT, y STRING" + + def reader(self, schema) -> "DataSourceReader": + return InMemDataSourceReader(self.paths, self.options) + + self.spark.dataSource.register(InMemoryDataSource) + df = self.spark.read.format("memory").load() + self.assertEqual(df.rdd.getNumPartitions(), 3) + assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, y="2")]) - self.spark.dataSource.register(MyDataSource) + df = self.spark.read.format("memory").option("num_partitions", 2).load() + assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")]) + self.assertEqual(df.rdd.getNumPartitions(), 2) + + def test_custom_json_data_source(self): + import json + + class JsonDataSourceReader(DataSourceReader): + def __init__(self, paths, options): + self.paths = paths + self.options = options + + def partitions(self): + return iter(self.paths) + + def read(self, path): + with open(path, "r") as file: + for line in file.readlines(): + if line.strip(): + data = json.loads(line) + yield data.get("name"), data.get("age") + + class JsonDataSource(DataSource): + @classmethod + def name(cls): + return "my-json" + + def schema(self): + return "name STRING, age INT" + + def reader(self, schema) -> "DataSourceReader": + return JsonDataSourceReader(self.paths, self.options) + + self.spark.dataSource.register(JsonDataSource) + path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json") + path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json") + df1 = self.spark.read.format("my-json").load(path1) + self.assertEqual(df1.rdd.getNumPartitions(), 1) + assertDataFrameEqual( + df1, + [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)], + ) - self.assertTrue( - self.spark._jsparkSession.sharedState() - .dataSourceRegistry() - .dataSourceExists("MyDataSource") + df2 = self.spark.read.format("my-json").load([path1, path2]) + self.assertEqual(df2.rdd.getNumPartitions(), 2) + assertDataFrameEqual( + df2, + [ + Row(name="Michael", age=None), + Row(name="Andy", age=30), + Row(name="Justin", age=19), + Row(name="Jonathan", age=None), + ], ) diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py index ea56d2cc75221..6a9ef79b7c18d 100644 --- a/python/pyspark/sql/worker/create_data_source.py +++ b/python/pyspark/sql/worker/create_data_source.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import inspect import os import sys from typing import IO, List from pyspark.accumulators import _accumulatorRegistry -from pyspark.errors import PySparkAssertionError, PySparkRuntimeError +from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_bool, @@ -84,8 +84,20 @@ def main(infile: IO, outfile: IO) -> None: }, ) + # Check the name method is a class method. + if not inspect.ismethod(data_source_cls.name): + raise PySparkTypeError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "'name()' method to be a classmethod", + "actual": f"'{type(data_source_cls.name).__name__}'", + }, + ) + # Receive the provider name. provider = utf8_deserializer.loads(infile) + + # Check if the provider name matches the data source's name. if provider.lower() != data_source_cls.name().lower(): raise PySparkAssertionError( error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 1925eddd2ce23..0c5dcb1ead01e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3805,4 +3805,16 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat errorClass = "DATA_SOURCE_ALREADY_EXISTS", messageParameters = Map("provider" -> name)) } + + def dataSourceDoesNotExist(name: String): Throwable = { + new AnalysisException( + errorClass = "DATA_SOURCE_NOT_EXIST", + messageParameters = Map("provider" -> name)) + } + + def foundMultipleDataSources(provider: String): Throwable = { + new AnalysisException( + errorClass = "FOUND_MULTIPLE_DATA_SOURCES", + messageParameters = Map("provider" -> provider)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9992d8cbba076..ef447e8a80102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql -import java.util.{Locale, Properties} +import java.util.{Locale, Properties, ServiceConfigurationError} import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Success, Try} -import org.apache.spark.Partition +import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable} import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -208,10 +209,45 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError() } - DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider => - DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions, - source, paths: _*) - }.getOrElse(loadV1Source(paths: _*)) + val isUserDefinedDataSource = + sparkSession.sharedState.dataSourceManager.dataSourceExists(source) + + Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match { + case Success(providerOpt) => + // The source can be successfully loaded as either a V1 or a V2 data source. 
+ // Check if it is also a user-defined data source. + if (isUserDefinedDataSource) { + throw QueryCompilationErrors.foundMultipleDataSources(source) + } + providerOpt.flatMap { provider => + DataSourceV2Utils.loadV2Source( + sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*) + }.getOrElse(loadV1Source(paths: _*)) + case Failure(exception) => + // Exceptions are thrown while trying to load the data source as a V1 or V2 data source. + // For the following not found exceptions, if the user-defined data source is defined, + // we can instead return the user-defined data source. + val isNotFoundError = exception match { + case _: NoClassDefFoundError | _: SparkClassNotFoundException => true + case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND" + case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError] + case _ => false + } + if (isNotFoundError && isUserDefinedDataSource) { + loadUserDefinedDataSource(paths) + } else { + // Throw the original exception. + throw exception + } + } + } + + private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = { + val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source) + // Unless the legacy path option behavior is enabled, the extraOptions here + // should not include "path" or "paths" as keys. + val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions) + Dataset.ofRows(sparkSession, plan) } private def loadV1Source(paths: String*) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala index 283ca2ac62edc..72a9e6497aca5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala @@ -22,10 +22,14 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap +/** + * A manager for user-defined data sources. It is used to register and lookup data sources by + * their short names or fully qualified names. + */ class DataSourceManager { private type DataSourceBuilder = ( @@ -33,22 +37,41 @@ class DataSourceManager { String, // provider name Seq[String], // paths Option[StructType], // user specified schema - CaseInsensitiveStringMap // options + CaseInsensitiveMap[String] // options ) => LogicalPlan private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]() private def normalize(name: String): String = name.toLowerCase(Locale.ROOT) + /** + * Register a data source builder for the given provider. + * Note that the provider name is case-insensitive. + */ def registerDataSource(name: String, builder: DataSourceBuilder): Unit = { val normalizedName = normalize(name) if (dataSourceBuilders.containsKey(normalizedName)) { throw QueryCompilationErrors.dataSourceAlreadyExists(name) } - // TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using loadDataSource. 
dataSourceBuilders.put(normalizedName, builder) } - def dataSourceExists(name: String): Boolean = + /** + * Returns a data source builder for the given provider and throw an exception if + * it does not exist. + */ + def lookupDataSource(name: String): DataSourceBuilder = { + if (dataSourceExists(name)) { + dataSourceBuilders.get(normalize(name)) + } else { + throw QueryCompilationErrors.dataSourceDoesNotExist(name) + } + } + + /** + * Checks if a data source with the specified name exists (case-insensitive). + */ + def dataSourceExists(name: String): Boolean = { dataSourceBuilders.containsKey(normalize(name)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index dbff8eefcd5fb..703c1e10ce265 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.python import java.io.{DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.jdk.CollectionConverters._ import net.razorvine.pickle.Pickler @@ -28,9 +27,9 @@ import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SimplePyt import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSource} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A user-defined Python data source. This is used by the Python API. 
@@ -44,7 +43,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { provider: String, paths: Seq[String], userSpecifiedSchema: Option[StructType], - options: CaseInsensitiveStringMap): LogicalPlan = { + options: CaseInsensitiveMap[String]): LogicalPlan = { val runner = new UserDefinedPythonDataSourceRunner( dataSourceCls, provider, paths, userSpecifiedSchema, options) @@ -70,7 +69,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { provider: String, paths: Seq[String] = Seq.empty, userSpecifiedSchema: Option[StructType] = None, - options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): DataFrame = { + options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = { val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options) Dataset.ofRows(sparkSession, plan) } @@ -91,7 +90,7 @@ class UserDefinedPythonDataSourceRunner( provider: String, paths: Seq[String], userSpecifiedSchema: Option[StructType], - options: CaseInsensitiveStringMap) + options: CaseInsensitiveMap[String]) extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) { override val workerModule = "pyspark.sql.worker.create_data_source" @@ -113,9 +112,9 @@ class UserDefinedPythonDataSourceRunner( // Send the options dataOut.writeInt(options.size) - options.entrySet.asScala.foreach { e => - PythonWorkerUtils.writeUTF(e.getKey, dataOut) - PythonWorkerUtils.writeUTF(e.getValue, dataOut) + options.iterator.foreach { case (key, value) => + PythonWorkerUtils.writeUTF(key, dataOut) + PythonWorkerUtils.writeUTF(value, dataOut) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index 6c749c2c9b67a..22a1e5250cd95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -155,6 +155,41 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { parameters = Map("provider" -> dataSourceName)) } + test("load data source") { + assume(shouldTestPythonUDFs) + val dataSourceScript = + s""" + |from pyspark.sql.datasource import DataSource, DataSourceReader + |class SimpleDataSourceReader(DataSourceReader): + | def __init__(self, paths, options): + | self.paths = paths + | self.options = options + | + | def partitions(self): + | return iter(self.paths) + | + | def read(self, path): + | yield (path, 1) + | + |class $dataSourceName(DataSource): + | @classmethod + | def name(cls) -> str: + | return "test" + | + | def schema(self) -> str: + | return "id STRING, value INT" + | + | def reader(self, schema): + | return SimpleDataSourceReader(self.paths, self.options) + |""".stripMargin + val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) + spark.dataSource.registerPython("test", dataSource) + + checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1))) + checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1))) + checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1))) + } + test("reader not implemented") { assume(shouldTestPythonUDFs) val dataSourceScript = From eabea643c7424340397fc91dd89329baf31b48dd Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 Nov 2023 14:58:36 -0600 Subject: [PATCH 068/121] [SPARK-42821][SQL] Remove unused parameters in splitFiles 
methods ### What changes were proposed in this pull request? The pr aims to remove unused parameters in PartitionedFileUtil.splitFiles methods ### Why are the changes needed? Make the code more concise. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. Closes #40454 from panbingkun/minor_PartitionedFileUtil. Authored-by: panbingkun Signed-off-by: Sean Owen --- .../org/apache/spark/sql/execution/DataSourceScanExec.scala | 3 +-- .../org/apache/spark/sql/execution/PartitionedFileUtil.scala | 2 -- .../apache/spark/sql/execution/datasources/v2/FileScan.scala | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index e5a38967dc3e1..c7bb3b6719157 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -687,7 +687,7 @@ case class FileSourceScanExec( * @param selectedPartitions Hive-style partition that are part of the read. */ private def createReadRDD( - readFile: (PartitionedFile) => Iterator[InternalRow], + readFile: PartitionedFile => Iterator[InternalRow], selectedPartitions: Array[PartitionDirectory]): RDD[InternalRow] = { val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes val maxSplitBytes = @@ -711,7 +711,6 @@ case class FileSourceScanExec( val isSplitable = relation.fileFormat.isSplitable( relation.sparkSession, relation.options, file.getPath) PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, file = file, isSplitable = isSplitable, maxSplitBytes = maxSplitBytes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala index cc234565d1112..b31369b6768e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.execution import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus} import org.apache.spark.paths.SparkPath -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources._ object PartitionedFileUtil { def splitFiles( - sparkSession: SparkSession, file: FileStatusWithMetadata, isSplitable: Boolean, maxSplitBytes: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 71e86beefdaff..61d61ee7af250 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -152,7 +152,6 @@ trait FileScan extends Scan } partition.files.flatMap { file => PartitionedFileUtil.splitFiles( - sparkSession = sparkSession, file = file, isSplitable = isSplitable(file.getPath), maxSplitBytes = maxSplitBytes, From 4df4fec622f3f6926b979f89daa177ec5e53d4ad Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 8 Nov 2023 13:07:04 -0800 Subject: [PATCH 069/121] [SPARK-45843][CORE] Support `killall` in REST Submission API ### What changes were proposed in this pull request? 
This PR aims to add `killall` action in REST Submission API. ### Why are the changes needed? To help users to kill all submissions easily. **BEFORE: Script** ```bash for id in $(curl http://master:8080/json/activedrivers | grep id | sed 's/"/ /g' | awk '{print $3}') do curl -XPOST http://master:6066/v1/submissions/kill/$id done ``` **AFTER** ```bash $ curl -XPOST http://master:6066/v1/submissions/killall { "action" : "KillAllSubmissionResponse", "message" : "Kill request for all drivers submitted", "serverSparkVersion" : "4.0.0-SNAPSHOT", "success" : true } ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with the newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43721 from dongjoon-hyun/SPARK-45843. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../apache/spark/deploy/DeployMessage.scala | 6 ++++ .../apache/spark/deploy/master/Master.scala | 25 +++++++++++++ .../deploy/rest/RestSubmissionClient.scala | 35 +++++++++++++++++++ .../deploy/rest/RestSubmissionServer.scala | 23 +++++++++++- .../deploy/rest/StandaloneRestServer.scala | 19 ++++++++++ .../rest/SubmitRestProtocolResponse.scala | 10 ++++++ .../rest/StandaloneRestSubmitSuite.scala | 31 ++++++++++++++++ 7 files changed, 148 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index f49530461b4d0..4ccc0bd7cdc26 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -233,6 +233,12 @@ private[deploy] object DeployMessages { master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage + case object RequestKillAllDrivers extends DeployMessage + + case class KillAllDriversResponse( + master: RpcEndpointRef, success: Boolean, message: String) + extends DeployMessage + case class RequestDriverStatus(driverId: String) extends DeployMessage case class DriverStatusResponse(found: Boolean, state: Option[DriverState], diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 3ba50318610ba..dbb647252c5f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -460,6 +460,31 @@ private[deploy] class Master( } } + case RequestKillAllDrivers => + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillAllDriversResponse(self, success = false, msg)) + } else { + logInfo("Asked to kill all drivers") + drivers.foreach { d => + val driverId = d.id + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. 
+ d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + logInfo(s"Kill request for $driverId submitted") + } + context.reply(KillAllDriversResponse(self, true, "Kill request for all drivers submitted")) + } + case RequestClearCompletedDriversAndApps => val numDrivers = completedDrivers.length val numApps = completedApps.length diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 3010efc936f97..286305bb76b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -135,6 +135,35 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { response } + /** Request that the server kill all submissions. */ + def killAllSubmissions(): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to kill all submissions in $master.") + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getKillAllUrl(m) + try { + response = post(url) + response match { + case k: KillAllSubmissionResponse => + if (!Utils.responseFromBackup(k.message)) { + handleRestResponse(k) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } + } + response + } + /** Request that the server clears all submissions and applications. */ def clear(): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to clear $master.") @@ -329,6 +358,12 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { new URL(s"$baseUrl/kill/$submissionId") } + /** Return the REST URL for killing all submissions. */ + private def getKillAllUrl(master: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/killall") + } + /** Return the REST URL for clear all existing submissions and applications. 
*/ private def getClearUrl(master: String): URL = { val baseUrl = getBaseUrl(master) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index 3323d0f529ebf..28197fd0a556d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -54,6 +54,7 @@ private[spark] abstract class RestSubmissionServer( protected val submitRequestServlet: SubmitRequestServlet protected val killRequestServlet: KillRequestServlet + protected val killAllRequestServlet: KillAllRequestServlet protected val statusRequestServlet: StatusRequestServlet protected val clearRequestServlet: ClearRequestServlet @@ -64,6 +65,7 @@ private[spark] abstract class RestSubmissionServer( protected lazy val contextToServlet = Map[String, RestServlet]( s"$baseContext/create/*" -> submitRequestServlet, s"$baseContext/kill/*" -> killRequestServlet, + s"$baseContext/killall/*" -> killAllRequestServlet, s"$baseContext/status/*" -> statusRequestServlet, s"$baseContext/clear/*" -> clearRequestServlet, "/*" -> new ErrorServlet // default handler @@ -229,6 +231,25 @@ private[rest] abstract class KillRequestServlet extends RestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse } +/** + * A servlet for handling killAll requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class KillAllRequestServlet extends RestServlet { + + /** + * Have the Master kill all drivers and return an appropriate response to the client. + * Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val responseMessage = handleKillAll() + sendResponse(responseMessage, response) + } + + protected def handleKillAll(): KillAllSubmissionResponse +} + /** * A servlet for handling clear requests passed to the [[RestSubmissionServer]]. */ @@ -331,7 +352,7 @@ private class ErrorServlet extends RestServlet { "Missing the /submissions prefix." case `serverVersion` :: "submissions" :: tail => // http://host:port/correct-version/submissions/* - "Missing an action: please specify one of /create, /kill, /clear or /status." + "Missing an action: please specify one of /create, /kill, /killall, /clear or /status." 
case unknownVersion :: tail => // http://host:port/unknown-version/* versionMismatch = true diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 8ed716428dc28..d382ec12847dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -63,6 +63,8 @@ private[deploy] class StandaloneRestServer( new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) protected override val killRequestServlet = new StandaloneKillRequestServlet(masterEndpoint, masterConf) + protected override val killAllRequestServlet = + new StandaloneKillAllRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = new StandaloneStatusRequestServlet(masterEndpoint, masterConf) protected override val clearRequestServlet = @@ -87,6 +89,23 @@ private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, } } +/** + * A servlet for handling killAll requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StandaloneKillAllRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) + extends KillAllRequestServlet { + + protected def handleKillAll() : KillAllSubmissionResponse = { + val response = masterEndpoint.askSync[DeployMessages.KillAllDriversResponse]( + DeployMessages.RequestKillAllDrivers) + val k = new KillAllSubmissionResponse + k.serverSparkVersion = sparkVersion + k.message = response.message + k.success = response.success + k + } +} + /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index 21614c22285f8..b9e3b3028ac79 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -55,6 +55,16 @@ private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { } } +/** + * A response to a killAll request in the REST application submission protocol. + */ +private[spark] class KillAllSubmissionResponse extends SubmitRestProtocolResponse { + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(success, "success") + } +} + /** * A response to a clear request in the REST application submission protocol. 
*/ diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index d775aa6542dcd..1cc2c873760df 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -236,6 +236,15 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { assert(clearResponse.success) } + test("SPARK-45843: killAll") { + val masterUrl = startDummyServer() + val response = new RestSubmissionClient(masterUrl).killAllSubmissions() + val killAllResponse = getKillAllResponse(response) + assert(killAllResponse.action === Utils.getFormattedClassName(killAllResponse)) + assert(killAllResponse.serverSparkVersion === SPARK_VERSION) + assert(killAllResponse.success) + } + /* ---------------------------------------- * | Aberrant client / server behavior | * ---------------------------------------- */ @@ -514,6 +523,16 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { } } + /** Return the response as a killAll response, or fail with error otherwise. */ + private def getKillAllResponse(response: SubmitRestProtocolResponse) + : KillAllSubmissionResponse = { + response match { + case k: KillAllSubmissionResponse => k + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected killAll response. Actual: ${r.toJson}") + } + } + /** Return the response as a clear response, or fail with error otherwise. */ private def getClearResponse(response: SubmitRestProtocolResponse): ClearResponse = { response match { @@ -590,6 +609,8 @@ private class DummyMaster( context.reply(SubmitDriverResponse(self, success = true, Some(submitId), submitMessage)) case RequestKillDriver(driverId) => context.reply(KillDriverResponse(self, driverId, success = true, killMessage)) + case RequestKillAllDrivers => + context.reply(KillAllDriversResponse(self, success = true, killMessage)) case RequestDriverStatus(driverId) => context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception)) case RequestClearCompletedDriversAndApps => @@ -636,6 +657,7 @@ private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEn * * When handling a submit request, the server returns a malformed JSON. * When handling a kill request, the server returns an invalid JSON. + * When handling a killAll request, the server returns an invalid JSON. * When handling a status request, the server throws an internal exception. * When handling a clear request, the server throws an internal exception. * The purpose of this class is to test that client handles these cases gracefully. @@ -650,6 +672,7 @@ private class FaultyStandaloneRestServer( protected override val submitRequestServlet = new MalformedSubmitServlet protected override val killRequestServlet = new InvalidKillServlet + protected override val killAllRequestServlet = new InvalidKillAllServlet protected override val statusRequestServlet = new ExplodingStatusServlet protected override val clearRequestServlet = new ExplodingClearServlet @@ -673,6 +696,14 @@ private class FaultyStandaloneRestServer( } } + /** A faulty servlet that produces invalid responses. 
*/ + class InvalidKillAllServlet extends StandaloneKillAllRequestServlet(masterEndpoint, masterConf) { + protected override def handleKillAll(): KillAllSubmissionResponse = { + val k = super.handleKillAll() + k + } + } + /** A faulty status servlet that explodes. */ class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterEndpoint, masterConf) { private def explode: Int = 1 / 0 From dfd7cde91c5d6f034a11ea492be83afaf771ceb6 Mon Sep 17 00:00:00 2001 From: Yihong He Date: Wed, 8 Nov 2023 18:22:40 -0800 Subject: [PATCH 070/121] [SPARK-45842][SQL] Refactor Catalog Function APIs to use analyzer ### What changes were proposed in this pull request? - Refactor Catalog Function APIs to use analyzer ### Why are the changes needed? - Less duplicate logics. We should not directly invoke catalog APIs, but go through analyzer. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #43720 from heyihong/SPARK-45842. Authored-by: Yihong He Signed-off-by: Dongjoon Hyun --- .../spark/sql/internal/CatalogImpl.scala | 59 +++++++++++-------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 5650e9d2399cc..b1ad454fc041f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -22,14 +22,14 @@ import scala.util.control.NonFatal import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, CatalogMetadata, Column, Database, Function, Table} -import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View} import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, FunctionCatalog, Identifier, SupportsNamespaces, Table => V2Table, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, SupportsNamespaces, Table => V2Table, TableCatalog, V1Table} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, NamespaceHelper, TransformHelper} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.ShowTablesCommand @@ -284,6 +284,33 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(functions.result(), sparkSession) } + private def toFunctionIdent(functionName: String): Seq[String] = { + val parsed = parseIdent(functionName) + // For backward compatibility (Spark 3.3 and prior), we should check if the function exists in + // the Hive Metastore first. 
+ if (parsed.length <= 2 && + !sessionCatalog.isTemporaryFunction(parsed.asFunctionIdentifier) && + sessionCatalog.isPersistentFunction(parsed.asFunctionIdentifier)) { + qualifyV1Ident(parsed) + } else { + parsed + } + } + + private def functionExists(ident: Seq[String]): Boolean = { + val plan = + UnresolvedFunctionName(ident, CatalogImpl.FUNCTION_EXISTS_COMMAND_NAME, false, None) + try { + sparkSession.sessionState.executePlan(plan).analyzed match { + case _: ResolvedPersistentFunc => true + case _: ResolvedNonPersistentFunc => true + case _ => false + } + } catch { + case e: AnalysisException if e.getErrorClass == "UNRESOLVED_ROUTINE" => false + } + } + private def makeFunction(ident: Seq[String]): Function = { val plan = UnresolvedFunctionName(ident, "Catalog.makeFunction", false, None) sparkSession.sessionState.executePlan(plan).analyzed match { @@ -465,17 +492,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * function. This throws an `AnalysisException` when no `Function` can be found. */ override def getFunction(functionName: String): Function = { - val parsed = parseIdent(functionName) - // For backward compatibility (Spark 3.3 and prior), we should check if the function exists in - // the Hive Metastore first. - val nameParts = if (parsed.length <= 2 && - !sessionCatalog.isTemporaryFunction(parsed.asFunctionIdentifier) && - sessionCatalog.isPersistentFunction(parsed.asFunctionIdentifier)) { - qualifyV1Ident(parsed) - } else { - parsed - } - makeFunction(nameParts) + makeFunction(toFunctionIdent(functionName)) } /** @@ -540,23 +557,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * or a function. */ override def functionExists(functionName: String): Boolean = { - val parsed = parseIdent(functionName) - // For backward compatibility (Spark 3.3 and prior), we should check if the function exists in - // the Hive Metastore first. This also checks if it's a built-in/temp function. - (parsed.length <= 2 && sessionCatalog.functionExists(parsed.asFunctionIdentifier)) || { - val plan = UnresolvedIdentifier(parsed) - sparkSession.sessionState.executePlan(plan).analyzed match { - case ResolvedIdentifier(catalog: FunctionCatalog, ident) => catalog.functionExists(ident) - case _ => false - } - } + functionExists(toFunctionIdent(functionName)) } /** * Checks if the function with the specified name exists in the specified database. */ override def functionExists(dbName: String, functionName: String): Boolean = { - sessionCatalog.functionExists(FunctionIdentifier(functionName, Option(dbName))) + // For backward compatibility (Spark 3.3 and prior), here we always look up the function from + // the Hive Metastore. + functionExists(Seq(CatalogManager.SESSION_CATALOG_NAME, dbName, functionName)) } /** @@ -942,4 +952,5 @@ private[sql] object CatalogImpl { new Dataset[T](queryExecution, enc) } + private val FUNCTION_EXISTS_COMMAND_NAME = "Catalog.functionExists" } From 974313994a0594fde7b424e569febed89cafd9ca Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 Nov 2023 19:22:13 -0800 Subject: [PATCH 071/121] [SPARK-45835][INFRA] Make gitHub labeler more accurate and remove outdated comments ### What changes were proposed in this pull request? The pr aims to make gitHub labeler more accurate and remove outdated comments. ### Why are the changes needed? The functions mentioned in the comments have been released in the latest version of Github Action labeler. 
https://github.com/actions/labeler/issues/111 https://github.com/actions/labeler/issues/111#issuecomment-1345989028 image According to the description of the original PR (https://github.com/apache/spark/pull/30244/files), after 'any/all' is released in the official version of `Github Action labeler`, we need to make subsequent updates to better identify the code `label`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Continuous manual observation is required. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43716 from panbingkun/SPARK-45835. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .github/labeler.yml | 40 +++++------------------------------ .github/workflows/labeler.yml | 13 ------------ 2 files changed, 5 insertions(+), 48 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index f21b90d460fb6..fc69733f4b66a 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -17,23 +17,6 @@ # under the License. # -# -# Pull Request Labeler Github Action Configuration: https://github.com/marketplace/actions/labeler -# -# Note that we currently cannot use the negatioon operator (i.e. `!`) for miniglob matches as they -# would match any file that doesn't touch them. What's needed is the concept of `any `, which takes a -# list of constraints / globs and then matches all of the constraints for either `any` of the files or -# `all` of the files in the change set. -# -# However, `any`/`all` are not supported in a released version and testing off of the `main` branch -# resulted in some other errors when testing. -# -# An issue has been opened upstream requesting that a release be cut that has support for all/any: -# - https://github.com/actions/labeler/issues/111 -# -# While we wait for this issue to be handled upstream, we can remove -# the negated / `!` matches for now and at least have labels again. -# INFRA: - ".github/**/*" - "appveyor.yml" @@ -45,9 +28,7 @@ INFRA: - "dev/merge_spark_pr.py" - "dev/run-tests-jenkins*" BUILD: - # Can be supported when a stable release with correct all/any is released - #- any: ['dev/**/*', '!dev/merge_spark_pr.py', '!dev/.rat-excludes'] - - "dev/**/*" + - any: ['dev/**/*', '!dev/merge_spark_pr.py', '!dev/run-tests-jenkins*'] - "build/**/*" - "project/**/*" - "assembly/**/*" @@ -55,22 +36,16 @@ BUILD: - "bin/docker-image-tool.sh" - "bin/find-spark-home*" - "scalastyle-config.xml" - # These can be added in the above `any` clause (and the /dev/**/* glob removed) when - # `any`/`all` support is released - # - "!dev/merge_spark_pr.py" - # - "!dev/run-tests-jenkins*" - # - "!dev/.rat-excludes" DOCS: - "docs/**/*" - "**/README.md" - "**/CONTRIBUTING.md" + - "python/docs/**/*" EXAMPLES: - "examples/**/*" - "bin/run-example*" -# CORE needs to be updated when all/any are released upstream. CORE: - # - any: ["core/**/*", "!**/*UI.scala", "!**/ui/**/*"] # If any file matches all of the globs defined in the list started by `any`, label is applied. 
- - "core/**/*" + - any: ["core/**/*", "!**/*UI.scala", "!**/ui/**/*"] - "common/kvstore/**/*" - "common/network-common/**/*" - "common/network-shuffle/**/*" @@ -82,12 +57,8 @@ SPARK SHELL: - "repl/**/*" - "bin/spark-shell*" SQL: -#- any: ["**/sql/**/*", "!python/pyspark/sql/avro/**/*", "!python/pyspark/sql/streaming/**/*", "!python/pyspark/sql/tests/streaming/test_streaming.py"] - - "**/sql/**/*" + - any: ["**/sql/**/*", "!python/pyspark/sql/avro/**/*", "!python/pyspark/sql/streaming/**/*", "!python/pyspark/sql/tests/streaming/test_streaming*.py"] - "common/unsafe/**/*" - #- "!python/pyspark/sql/avro/**/*" - #- "!python/pyspark/sql/streaming/**/*" - #- "!python/pyspark/sql/tests/streaming/test_streaming.py" - "bin/spark-sql*" - "bin/beeline*" - "sbin/*thriftserver*.sh" @@ -123,7 +94,7 @@ STRUCTURED STREAMING: - "**/sql/**/streaming/**/*" - "connector/kafka-0-10-sql/**/*" - "python/pyspark/sql/streaming/**/*" - - "python/pyspark/sql/tests/streaming/test_streaming.py" + - "python/pyspark/sql/tests/streaming/test_streaming*.py" - "**/*streaming.R" PYTHON: - "bin/pyspark*" @@ -148,7 +119,6 @@ DEPLOY: - "sbin/**/*" CONNECT: - "connector/connect/**/*" - - "**/sql/sparkconnect/**/*" - "python/pyspark/sql/**/connect/**/*" - "python/pyspark/ml/**/connect/**/*" PROTOBUF: diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index c6b6e65bc9fec..b55d28e5a6406 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -34,19 +34,6 @@ jobs: contents: read pull-requests: write steps: - # In order to get back the negated matches like in the old config, - # we need the actinons/labeler concept of `all` and `any` which matches - # all of the given constraints / glob patterns for either `all` - # files or `any` file in the change set. - # - # Github issue which requests a timeline for a release with any/all support: - # - https://github.com/actions/labeler/issues/111 - # This issue also references the issue that mentioned that any/all are only - # supported on main branch (previously called master): - # - https://github.com/actions/labeler/issues/73#issuecomment-639034278 - # - # However, these are not in a published release and the current `main` branch - # has some issues upon testing. - uses: actions/labeler@v4 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" From 093fbf1aa8520193b8d929f9f855afe0aded20a1 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Wed, 8 Nov 2023 19:23:29 -0800 Subject: [PATCH 072/121] [SPARK-45831][CORE][SQL][DSTREAM] Use collection factory instead to create immutable Java collections ### What changes were proposed in this pull request? This pr change to use collection factory instread of `Collections.unmodifiable` to create an immutable Java collection(new collection API introduced after [JEP 269](https://openjdk.org/jeps/269)) ### Why are the changes needed? Make the relevant code look simple and clear. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43709 from LuciferYang/collection-factory. 
Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../apache/spark/network/util/JavaUtils.java | 44 +++++++++---------- .../scala/org/apache/spark/FutureAction.scala | 5 +-- .../org/apache/spark/util/AccumulatorV2.scala | 3 +- .../SpecificParquetRecordReaderBase.java | 5 +-- .../apache/spark/streaming/JavaAPISuite.java | 4 +- 5 files changed, 26 insertions(+), 35 deletions(-) diff --git a/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java index bbe764b8366c8..fa0a2629f3502 100644 --- a/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -202,29 +202,27 @@ private static boolean isSymlink(File file) throws IOException { private static final Map byteSuffixes; static { - final Map timeSuffixesBuilder = new HashMap<>(); - timeSuffixesBuilder.put("us", TimeUnit.MICROSECONDS); - timeSuffixesBuilder.put("ms", TimeUnit.MILLISECONDS); - timeSuffixesBuilder.put("s", TimeUnit.SECONDS); - timeSuffixesBuilder.put("m", TimeUnit.MINUTES); - timeSuffixesBuilder.put("min", TimeUnit.MINUTES); - timeSuffixesBuilder.put("h", TimeUnit.HOURS); - timeSuffixesBuilder.put("d", TimeUnit.DAYS); - timeSuffixes = Collections.unmodifiableMap(timeSuffixesBuilder); - - final Map byteSuffixesBuilder = new HashMap<>(); - byteSuffixesBuilder.put("b", ByteUnit.BYTE); - byteSuffixesBuilder.put("k", ByteUnit.KiB); - byteSuffixesBuilder.put("kb", ByteUnit.KiB); - byteSuffixesBuilder.put("m", ByteUnit.MiB); - byteSuffixesBuilder.put("mb", ByteUnit.MiB); - byteSuffixesBuilder.put("g", ByteUnit.GiB); - byteSuffixesBuilder.put("gb", ByteUnit.GiB); - byteSuffixesBuilder.put("t", ByteUnit.TiB); - byteSuffixesBuilder.put("tb", ByteUnit.TiB); - byteSuffixesBuilder.put("p", ByteUnit.PiB); - byteSuffixesBuilder.put("pb", ByteUnit.PiB); - byteSuffixes = Collections.unmodifiableMap(byteSuffixesBuilder); + timeSuffixes = Map.of( + "us", TimeUnit.MICROSECONDS, + "ms", TimeUnit.MILLISECONDS, + "s", TimeUnit.SECONDS, + "m", TimeUnit.MINUTES, + "min", TimeUnit.MINUTES, + "h", TimeUnit.HOURS, + "d", TimeUnit.DAYS); + + byteSuffixes = Map.ofEntries( + Map.entry("b", ByteUnit.BYTE), + Map.entry("k", ByteUnit.KiB), + Map.entry("kb", ByteUnit.KiB), + Map.entry("m", ByteUnit.MiB), + Map.entry("mb", ByteUnit.MiB), + Map.entry("g", ByteUnit.GiB), + Map.entry("gb", ByteUnit.GiB), + Map.entry("t", ByteUnit.TiB), + Map.entry("tb", ByteUnit.TiB), + Map.entry("p", ByteUnit.PiB), + Map.entry("pb", ByteUnit.PiB)); } /** diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 9100d4ce041bf..a68700421b8df 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -17,7 +17,6 @@ package org.apache.spark -import java.util.Collections import java.util.concurrent.TimeUnit import scala.concurrent._ @@ -255,8 +254,6 @@ private[spark] class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) extends JavaFutureAction[T] { - import scala.jdk.CollectionConverters._ - override def isCancelled: Boolean = futureAction.isCancelled override def isDone: Boolean = { @@ -266,7 +263,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S } override def jobIds(): java.util.List[java.lang.Integer] = { - 
Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava) + java.util.List.of(futureAction.jobIds.map(Integer.valueOf): _*) } private def getImpl(timeout: Duration): T = { diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 181033c9d20c8..c6d8073a0c2fa 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -19,7 +19,6 @@ package org.apache.spark.util import java.{lang => jl} import java.io.ObjectInputStream -import java.util.ArrayList import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong @@ -505,7 +504,7 @@ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { } override def value: java.util.List[T] = this.synchronized { - java.util.Collections.unmodifiableList(new ArrayList[T](getOrCreate)) + java.util.List.copyOf(getOrCreate) } private[spark] def setValue(newValue: java.util.List[T]): Unit = this.synchronized { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 4f2b65f36120a..6d00048154a56 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -23,7 +23,6 @@ import java.lang.reflect.InvocationTargetException; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -238,9 +237,7 @@ public void close() throws IOException { private static Map> toSetMultiMap(Map map) { Map> setMultiMap = new HashMap<>(); for (Map.Entry entry : map.entrySet()) { - Set set = new HashSet<>(); - set.add(entry.getValue()); - setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + setMultiMap.put(entry.getKey(), Set.of(entry.getValue())); } return Collections.unmodifiableMap(setMultiMap); } diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index b1f743b921969..f8d961fa8dd8e 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -704,11 +704,11 @@ public static void assertOrderInvariantEquals( List> expected, List> actual) { List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); + expectedSets.add(Set.copyOf(list)); } List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); + actualSets.add(Set.copyOf(list)); } Assertions.assertEquals(expectedSets, actualSets); } From 24edc0ef5bee578de8eec3b032f993812e4303ea Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Thu, 9 Nov 2023 15:25:52 +0800 Subject: [PATCH 073/121] [SPARK-45752][SQL] Unreferenced CTE should all be checked by CheckAnalysis0 ### What changes were proposed in this pull request? This PR fixes an issue that if a CTE is referenced by a non-referenced CTE, then this CTE should also have ref count as 0 and goes through CheckAnalysis0. 
This will guarantee analyzer throw proper error message for problematic CTE which is not referenced. ### Why are the changes needed? To improve error message for non-referenced CTE case. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? UT ### Was this patch authored or co-authored using generative AI tooling? NO Closes #43614 from amaliujia/cte_ref. Lead-authored-by: Rui Wang Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 28 +++++++++++++++++-- .../org/apache/spark/sql/CTEInlineSuite.scala | 11 ++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index cebaee2cdec9c..29d60ae0f41e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -148,15 +148,39 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass, missingCol, orderedCandidates, a.origin) } + private def checkUnreferencedCTERelations( + cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], + visited: mutable.Map[Long, Boolean], + cteId: Long): Unit = { + if (visited(cteId)) { + return + } + val (cteDef, _, refMap) = cteMap(cteId) + refMap.foreach { case (id, _) => + checkUnreferencedCTERelations(cteMap, visited, id) + } + checkAnalysis0(cteDef.child) + visited(cteId) = true + } + def checkAnalysis(plan: LogicalPlan): Unit = { val inlineCTE = InlineCTE(alwaysInline = true) val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] inlineCTE.buildCTEMap(plan, cteMap) - cteMap.values.foreach { case (relation, refCount, _) => + cteMap.values.foreach { case (relation, _, _) => // If a CTE relation is never used, it will disappear after inline. Here we explicitly check // analysis for it, to make sure the entire query plan is valid. try { - if (refCount == 0) checkAnalysis0(relation.child) + // If a CTE relation ref count is 0, the other CTE relations that reference it + // should also be checked by checkAnalysis0. This code will also guarantee the leaf + // relations that do not reference any others are checked first. 
+ val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false) + cteMap.foreach { case (cteId, _) => + val (_, refCount, _) = cteMap(cteId) + if (refCount == 0) { + checkUnreferencedCTERelations(cteMap, visited, cteId) + } + } } catch { case e: AnalysisException => throw new ExtendedAnalysisException(e, relation.child) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 5f6c44792658a..055c04992c009 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -678,6 +678,17 @@ abstract class CTEInlineSuiteBase }.isDefined, "CTE columns should not be pruned.") } } + + test("SPARK-45752: Unreferenced CTE should all be checked by CheckAnalysis0") { + val e = intercept[AnalysisException](sql( + s""" + |with + |a as (select * from non_exist), + |b as (select * from a) + |select 2 + |""".stripMargin)) + checkErrorTableNotFound(e, "`non_exist`", ExpectedContext("non_exist", 26, 34)) + } } class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with DisableAdaptiveExecutionSuite From c128f811820e5a31ddd5bd1c95ed8dd49017eaea Mon Sep 17 00:00:00 2001 From: xieshuaihu Date: Thu, 9 Nov 2023 15:56:40 +0800 Subject: [PATCH 074/121] [SPARK-45814][CONNECT][SQL] Make ArrowConverters.createEmptyArrowBatch call close() to avoid memory leak ### What changes were proposed in this pull request? Make `ArrowBatchIterator` implement `AutoCloseable` and make `ArrowConverters.createEmptyArrowBatch()` call close() to avoid a memory leak. ### Why are the changes needed? `ArrowConverters.createEmptyArrowBatch` does not call `super.hasNext`, so if `TaskContext.get` returns `None`, the memory allocated in `ArrowBatchIterator` is leaked. In Spark Connect, `createEmptyArrowBatch` is called in [SparkConnectPlanner](https://github.com/apache/spark/blob/master/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L2558) and [SparkConnectPlanExecution](https://github.com/apache/spark/blob/master/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala#L224), which causes a long-running driver to consume all of the off-heap memory specified by `-XX:MaxDirectMemorySize`. This is the exception stack: ``` org.apache.arrow.memory.OutOfMemoryException: Failure allocating buffer.
at io.netty.buffer.PooledByteBufAllocatorL.allocate(PooledByteBufAllocatorL.java:67) at org.apache.arrow.memory.NettyAllocationManager.(NettyAllocationManager.java:77) at org.apache.arrow.memory.NettyAllocationManager.(NettyAllocationManager.java:84) at org.apache.arrow.memory.NettyAllocationManager$1.create(NettyAllocationManager.java:34) at org.apache.arrow.memory.BaseAllocator.newAllocationManager(BaseAllocator.java:354) at org.apache.arrow.memory.BaseAllocator.newAllocationManager(BaseAllocator.java:349) at org.apache.arrow.memory.BaseAllocator.bufferWithoutReservation(BaseAllocator.java:337) at org.apache.arrow.memory.BaseAllocator.buffer(BaseAllocator.java:315) at org.apache.arrow.memory.BaseAllocator.buffer(BaseAllocator.java:279) at org.apache.arrow.vector.BaseValueVector.allocFixedDataAndValidityBufs(BaseValueVector.java:192) at org.apache.arrow.vector.BaseFixedWidthVector.allocateBytes(BaseFixedWidthVector.java:338) at org.apache.arrow.vector.BaseFixedWidthVector.allocateNew(BaseFixedWidthVector.java:308) at org.apache.arrow.vector.BaseFixedWidthVector.allocateNew(BaseFixedWidthVector.java:273) at org.apache.spark.sql.execution.arrow.ArrowWriter$.$anonfun$create$1(ArrowWriter.scala:44) at scala.collection.StrictOptimizedIterableOps.map(StrictOptimizedIterableOps.scala:100) at scala.collection.StrictOptimizedIterableOps.map$(StrictOptimizedIterableOps.scala:87) at scala.collection.convert.JavaCollectionWrappers$JListWrapper.map(JavaCollectionWrappers.scala:103) at org.apache.spark.sql.execution.arrow.ArrowWriter$.create(ArrowWriter.scala:43) at org.apache.spark.sql.execution.arrow.ArrowConverters$ArrowBatchIterator.(ArrowConverters.scala:93) at org.apache.spark.sql.execution.arrow.ArrowConverters$ArrowBatchWithSchemaIterator.(ArrowConverters.scala:138) at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.(ArrowConverters.scala:231) at org.apache.spark.sql.execution.arrow.ArrowConverters$.createEmptyArrowBatch(ArrowConverters.scala:229) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleSqlCommand(SparkConnectPlanner.scala:2481) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.process(SparkConnectPlanner.scala:2426) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.handleCommand(ExecuteThreadRunner.scala:202) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1(ExecuteThreadRunner.scala:158) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1$adapted(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:189) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:189) at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withContextClassLoader$1(SessionHolder.scala:176) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:178) at org.apache.spark.sql.connect.service.SessionHolder.withContextClassLoader(SessionHolder.scala:175) at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:188) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.executeInternal(ExecuteThreadRunner.scala:132) at 
org.apache.spark.sql.connect.execution.ExecuteThreadRunner.org$apache$spark$sql$connect$execution$ExecuteThreadRunner$$execute(ExecuteThreadRunner.scala:84) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner$ExecutionThread.run(ExecuteThreadRunner.scala:228) Caused by: io.netty.util.internal.OutOfDirectMemoryError: failed to allocate 4194304 byte(s) of direct memory (used: 1069547799, max: 1073741824) at io.netty.util.internal.PlatformDependent.incrementMemoryCounter(PlatformDependent.java:845) at io.netty.util.internal.PlatformDependent.allocateDirectNoCleaner(PlatformDependent.java:774) at io.netty.buffer.PoolArena$DirectArena.allocateDirect(PoolArena.java:721) at io.netty.buffer.PoolArena$DirectArena.newChunk(PoolArena.java:696) at io.netty.buffer.PoolArena.allocateNormal(PoolArena.java:215) at io.netty.buffer.PoolArena.tcacheAllocateSmall(PoolArena.java:180) at io.netty.buffer.PoolArena.allocate(PoolArena.java:137) at io.netty.buffer.PoolArena.allocate(PoolArena.java:129) at io.netty.buffer.PooledByteBufAllocatorL$InnerAllocator.newDirectBufferL(PooledByteBufAllocatorL.java:181) at io.netty.buffer.PooledByteBufAllocatorL$InnerAllocator.directBuffer(PooledByteBufAllocatorL.java:214) at io.netty.buffer.PooledByteBufAllocatorL.allocate(PooledByteBufAllocatorL.java:58) ... 37 more ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manually test ### Was this patch authored or co-authored using generative AI tooling? No Closes #43691 from xieshuaihu/spark-45814. Authored-by: xieshuaihu Signed-off-by: yangjie01 --- .../sql/execution/arrow/ArrowConverters.scala | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index d6bf1e29edddd..9ddec74374abd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -80,7 +80,7 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean, - context: TaskContext) extends Iterator[Array[Byte]] { + context: TaskContext) extends Iterator[Array[Byte]] with AutoCloseable { protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) @@ -93,13 +93,11 @@ private[sql] object ArrowConverters extends Logging { protected val arrowWriter = ArrowWriter.create(root) Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() + close() }} override def hasNext: Boolean = rowIter.hasNext || { - root.close() - allocator.close() + close() false } @@ -124,6 +122,11 @@ private[sql] object ArrowConverters extends Logging { out.toByteArray } + + override def close(): Unit = { + root.close() + allocator.close() + } } private[sql] class ArrowBatchWithSchemaIterator( @@ -226,11 +229,19 @@ private[sql] object ArrowConverters extends Logging { schema: StructType, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean): Array[Byte] = { - new ArrowBatchWithSchemaIterator( + val batches = new ArrowBatchWithSchemaIterator( Iterator.empty, schema, 0L, 0L, timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get()) { override def hasNext: Boolean = true - }.next() + } + Utils.tryWithSafeFinally { + batches.next() + } { + // If taskContext is null, 
`batches.close()` should be called to avoid memory leak. + if (TaskContext.get() == null) { + batches.close() + } + } } /** From 06d8cbe073499ff16bca3165e2de1192daad3984 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 9 Nov 2023 16:23:38 +0800 Subject: [PATCH 075/121] [SPARK-45847][SQL][TESTS] CliSuite flakiness due to non-sequential guarantee for stdout&stderr ### What changes were proposed in this pull request? In CliSuite, this PR adds a retry for tests that write errors to STDERR. ### Why are the changes needed? To fix flaky tests such as the following: https://github.com/chenhao-db/apache-spark/actions/runs/6791437199/job/18463313766 https://github.com/dongjoon-hyun/spark/actions/runs/6753670527/job/18361206900 ```sql [info] Spark master: local, Application Id: local-1699402393189 [info] spark-sql> /* SELECT /*+ HINT() 4; */; [info] [info] [PARSE_SYNTAX_ERROR] Syntax error at or near ';'. SQLSTATE: 42601 (line 1, pos 26) [info] [info] == SQL == [info] /* SELECT /*+ HINT() 4; */; [info] --------------------------^^^ [info] [info] spark-sql> /* SELECT /*+ HINT() 4; */ SELECT 1; [info] 1 [info] Time taken: 1.499 seconds, Fetched 1 row(s) [info] [info] [UNCLOSED_BRACKETED_COMMENT] Found an unclosed bracketed comment. Please, append */ at the end of the comment. SQLSTATE: 42601 [info] == SQL == [info] /* Here is a unclosed bracketed comment SELECT 1; [info] spark-sql> /* Here is a unclosed bracketed comment SELECT 1; [info] spark-sql> /* SELECT /*+ HINT() */ 4; */; [info] spark-sql> ``` As you can see in the fragment above, the query on the 3rd line from the bottom, which came from STDOUT, was printed later than its error output, which came from STDERR. In this scenario, the error output would not match anything and would simply go unnoticed. Finally, the test timed out and failed. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests and CI ### Was this patch authored or co-authored using generative AI tooling? no Closes #43725 from yaooqinn/SPARK-45847.
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../sql/hive/thriftserver/CliSuite.scala | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 5391965ded2e9..4f0d4dff566c4 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -383,7 +383,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-11188 Analysis error reporting") { + testRetry("SPARK-11188 Analysis error reporting") { runCliWithin(timeout = 2.minute, errorResponses = Seq("AnalysisException"))( "select * from nonexistent_table;" -> "nonexistent_table" @@ -551,7 +551,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SparkException with root cause will be printStacktrace") { + testRetry("SparkException with root cause will be printStacktrace") { // If it is not in silent mode, will print the stacktrace runCliWithin( 1.minute, @@ -575,8 +575,8 @@ class CliSuite extends SparkFunSuite { runCliWithin(1.minute)("SELECT MAKE_DATE(-44, 3, 15);" -> "-0044-03-15") } - test("SPARK-33100: Ignore a semicolon inside a bracketed comment in spark-sql") { - runCliWithin(4.minute)( + testRetry("SPARK-33100: Ignore a semicolon inside a bracketed comment in spark-sql") { + runCliWithin(1.minute)( "/* SELECT 'test';*/ SELECT 'test';" -> "test", ";;/* SELECT 'test';*/ SELECT 'test';" -> "test", "/* SELECT 'test';*/;; SELECT 'test';" -> "test", @@ -623,8 +623,8 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-37555: spark-sql should pass last unclosed comment to backend") { - runCliWithin(5.minute)( + testRetry("SPARK-37555: spark-sql should pass last unclosed comment to backend") { + runCliWithin(1.minute)( // Only unclosed comment. "/* SELECT /*+ HINT() 4; */;".stripMargin -> "Syntax error at or near ';'", // Unclosed nested bracketed comment. @@ -637,7 +637,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-37694: delete [jar|file|archive] shall use spark sql processor") { + testRetry("SPARK-37694: delete [jar|file|archive] shall use spark sql processor") { runCliWithin(2.minute, errorResponses = Seq("ParseException"))( "delete jar dummy.jar;" -> "Syntax error at or near 'jar': missing 'FROM'. 
SQLSTATE: 42601 (line 1, pos 7)") @@ -679,7 +679,7 @@ class CliSuite extends SparkFunSuite { SparkSQLEnv.stop() } - test("SPARK-39068: support in-memory catalog and running concurrently") { + testRetry("SPARK-39068: support in-memory catalog and running concurrently") { val extraConf = Seq("-c", s"${StaticSQLConf.CATALOG_IMPLEMENTATION.key}=in-memory") val cd = new CountDownLatch(2) def t: Thread = new Thread { @@ -699,7 +699,7 @@ class CliSuite extends SparkFunSuite { } // scalastyle:off line.size.limit - test("formats of error messages") { + testRetry("formats of error messages") { def check(format: ErrorMessageFormat.Value, errorMessage: String, silent: Boolean): Unit = { val expected = errorMessage.split(System.lineSeparator()).map("" -> _) runCliWithin( @@ -811,7 +811,6 @@ class CliSuite extends SparkFunSuite { s"spark.sql.catalog.$catalogName.url=jdbc:derby:memory:$catalogName;create=true" val catalogDriver = s"spark.sql.catalog.$catalogName.driver=org.apache.derby.jdbc.AutoloadedDriver" - val database = s"-database $catalogName.SYS" val catalogConfigs = Seq(catalogImpl, catalogDriver, catalogUrl, "spark.sql.catalogImplementation=in-memory") .flatMap(Seq("--conf", _)) From 5ac88b12f86b306e7612591154c26aebabb957a8 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 9 Nov 2023 19:30:12 +0800 Subject: [PATCH 076/121] [SPARK-44886][SQL] Introduce CLUSTER BY clause for CREATE/REPLACE TABLE ### What changes were proposed in this pull request? This proposes to introduce `CLUSTER BY` SQL clause to CREATE/REPLACE SQL syntax: ``` CREATE TABLE tbl(a int, b string) CLUSTER BY (a, b) ``` This doesn't introduce a default implementation for clustering, but it's up to the catalog/datasource implementation to utilize the clustering information (e.g., Delta, Iceberg, etc.). ### Why are the changes needed? To introduce the concept of clustering to datasources. ### Does this PR introduce _any_ user-facing change? Yes, this introduces a new SQL keyword. ### How was this patch tested? Added extensive unit tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42577 from imback82/cluster_by. 
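For reference, a couple of additional forms exercised by the new tests in this patch, shown as an illustrative sketch (the table and column names, and `USING parquet` in the second statement, are arbitrary):

```sql
CREATE TABLE my_tab (a STRUCT<b INT, c STRING>, ts TIMESTAMP) USING parquet
CLUSTER BY (a.b, ts);

REPLACE TABLE my_tab (id2 INT) USING parquet CLUSTER BY (id2);
```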
Lead-authored-by: Terry Kim Co-authored-by: Terry Kim Signed-off-by: Wenchen Fan --- .../main/resources/error/error-classes.json | 12 ++ docs/sql-error-conditions.md | 12 ++ .../sql/catalyst/parser/SqlBaseParser.g4 | 5 + .../spark/sql/errors/QueryParsingErrors.scala | 8 ++ .../sql/catalyst/catalog/interface.scala | 63 +++++++++- .../sql/catalyst/parser/AstBuilder.scala | 46 ++++++-- .../catalog/CatalogV2Implicits.scala | 26 ++++- .../connector/expressions/expressions.scala | 36 ++++++ .../sql/catalyst/parser/DDLParserSuite.scala | 110 +++++++++++++++++- .../connector/catalog/InMemoryBaseTable.scala | 1 + .../expressions/TransformExtractorSuite.scala | 43 ++++++- .../analysis/ResolveSessionCatalog.scala | 8 +- .../spark/sql/execution/SparkSqlParser.scala | 3 +- .../datasources/v2/V2SessionCatalog.scala | 8 +- .../spark/sql/internal/CatalogImpl.scala | 3 +- .../CreateTableClusterBySuiteBase.scala | 83 +++++++++++++ .../v1/CreateTableClusterBySuite.scala | 51 ++++++++ .../v2/CreateTableClusterBySuite.scala | 50 ++++++++ .../command/CreateTableClusterBySuite.scala | 39 +++++++ 19 files changed, 583 insertions(+), 24 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableClusterBySuiteBase.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableClusterBySuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableClusterBySuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CreateTableClusterBySuite.scala diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index c38171c3d9e63..26f6c0240afb3 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -2963,6 +2963,18 @@ ], "sqlState" : "42601" }, + "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED" : { + "message" : [ + "Cannot specify both CLUSTER BY and CLUSTERED BY INTO BUCKETS." + ], + "sqlState" : "42908" + }, + "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED" : { + "message" : [ + "Cannot specify both CLUSTER BY and PARTITIONED BY." + ], + "sqlState" : "42908" + }, "SPECIFY_PARTITION_IS_NOT_ALLOWED" : { "message" : [ "A CREATE TABLE without explicit column list cannot specify PARTITIONED BY.", diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 8a5faa15dc9cd..2cb433b19fa56 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -1852,6 +1852,18 @@ A CREATE TABLE without explicit column list cannot specify bucketing information Please use the form with explicit column list and specify bucketing information. Alternatively, allow bucketing information to be inferred by omitting the clause. +### SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED + +[SQLSTATE: 42908](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot specify both CLUSTER BY and CLUSTERED BY INTO BUCKETS. + +### SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED + +[SQLSTATE: 42908](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot specify both CLUSTER BY and PARTITIONED BY. 
+ ### SPECIFY_PARTITION_IS_NOT_ALLOWED [SQLSTATE: 42601](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 84a31dafed985..bd449a4e194e8 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -298,6 +298,10 @@ replaceTableHeader : (CREATE OR)? REPLACE TABLE identifierReference ; +clusterBySpec + : CLUSTER BY LEFT_PAREN multipartIdentifierList RIGHT_PAREN + ; + bucketSpec : CLUSTERED BY identifierList (SORTED BY orderedIdentifierList)? @@ -383,6 +387,7 @@ createTableClauses :((OPTIONS options=expressionPropertyList) | (PARTITIONED BY partitioning=partitionFieldList) | skewSpec | + clusterBySpec | bucketSpec | rowFormat | createFileFormat | diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 2067bf7d0955d..841a678144f5f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -683,4 +683,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { ctx ) } + + def clusterByWithPartitionedBy(ctx: ParserRuleContext): Throwable = { + new ParseException(errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED", ctx) + } + + def clusterByWithBucketing(ctx: ParserRuleContext): Throwable = { + new ParseException(errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", ctx) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 634afb47ea5e2..066dbd9fad15b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -24,6 +24,9 @@ import java.util.Date import scala.collection.mutable import scala.util.control.NonFatal +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} +import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule} import org.apache.commons.lang3.StringUtils import org.json4s.JsonAST.{JArray, JString} import org.json4s.jackson.JsonMethods._ @@ -31,7 +34,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CurrentUserContext, FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedLeafNode} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, Resolver, UnresolvedLeafNode} import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} import org.apache.spark.sql.catalyst.plans.logical._ @@ -39,10 +42,11 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUti import org.apache.spark.sql.catalyst.types.DataTypeUtils import 
org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} /** @@ -170,6 +174,55 @@ case class CatalogTablePartition( } } +/** + * A container for clustering information. + * + * @param columnNames the names of the columns used for clustering. + */ +case class ClusterBySpec(columnNames: Seq[NamedReference]) { + override def toString: String = toJson + + def toJson: String = ClusterBySpec.mapper.writeValueAsString(columnNames.map(_.fieldNames)) +} + +object ClusterBySpec { + private val mapper = { + val ret = new ObjectMapper() with ClassTagExtensions + ret.setSerializationInclusion(Include.NON_ABSENT) + ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + ret.registerModule(DefaultScalaModule) + ret + } + + def fromProperty(columns: String): ClusterBySpec = { + ClusterBySpec(mapper.readValue[Seq[Seq[String]]](columns).map(FieldReference(_))) + } + + def toProperty( + schema: StructType, + clusterBySpec: ClusterBySpec, + resolver: Resolver): (String, String) = { + CatalogTable.PROP_CLUSTERING_COLUMNS -> + normalizeClusterBySpec(schema, clusterBySpec, resolver).toJson + } + + private def normalizeClusterBySpec( + schema: StructType, + clusterBySpec: ClusterBySpec, + resolver: Resolver): ClusterBySpec = { + val normalizedColumns = clusterBySpec.columnNames.map { columnName => + val position = SchemaUtils.findColumnPosition( + columnName.fieldNames(), schema, resolver) + FieldReference(SchemaUtils.getColumnName(position, schema)) + } + + SchemaUtils.checkColumnNameDuplication( + normalizedColumns.map(_.toString), + resolver) + + ClusterBySpec(normalizedColumns) + } +} /** * A container for bucketing information. 
@@ -462,6 +515,10 @@ case class CatalogTable( if (value.isEmpty) key else s"$key: $value" }.mkString("", "\n", "") } + + lazy val clusterBySpec: Option[ClusterBySpec] = { + properties.get(PROP_CLUSTERING_COLUMNS).map(ClusterBySpec.fromProperty) + } } object CatalogTable { @@ -499,6 +556,8 @@ object CatalogTable { val VIEW_STORING_ANALYZED_PLAN = VIEW_PREFIX + "storingAnalyzedPlan" + val PROP_CLUSTERING_COLUMNS: String = "clusteringColumns" + def splitLargeTableProp( key: String, value: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6d70ad29f8761..eb501f56d81ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -33,7 +33,7 @@ import org.apache.spark.{SparkArithmeticException, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, ClusterBySpec} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last, PercentileCont, PercentileDisc} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -3241,6 +3241,15 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { }) } + /** + * Create a [[ClusterBySpec]]. + */ + override def visitClusterBySpec(ctx: ClusterBySpecContext): ClusterBySpec = withOrigin(ctx) { + val columnNames = ctx.multipartIdentifierList.multipartIdentifier.asScala + .map(typedVisit[Seq[String]]).map(FieldReference(_)).toSeq + ClusterBySpec(columnNames) + } + /** * Convert a property list into a key-value map. * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. @@ -3341,6 +3350,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { * - location * - comment * - serde + * - clusterBySpec * * Note: Partition transforms are based on existing table schema definition. It can be simple * column names, or functions like `year(date_col)`. Partition columns are column names with data @@ -3348,7 +3358,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { */ type TableClauses = ( Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String], - OptionList, Option[String], Option[String], Option[SerdeInfo]) + OptionList, Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec]) /** * Validate a create table statement and return the [[TableIdentifier]]. 
@@ -3809,6 +3819,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.clusterBySpec(), "CLUSTER BY", ctx) checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) if (ctx.skewSpec.size > 0) { @@ -3827,8 +3838,19 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { val comment = visitCommentSpecList(ctx.commentSpec()) val serdeInfo = getSerdeInfo(ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx) + val clusterBySpec = ctx.clusterBySpec().asScala.headOption.map(visitClusterBySpec) + + if (clusterBySpec.isDefined) { + if (partCols.nonEmpty || partTransforms.nonEmpty) { + throw QueryParsingErrors.clusterByWithPartitionedBy(ctx) + } + if (bucketSpec.isDefined) { + throw QueryParsingErrors.clusterByWithBucketing(ctx) + } + } + (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, - serdeInfo) + serdeInfo, clusterBySpec) } protected def getSerdeInfo( @@ -3881,6 +3903,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { * [OPTIONS table_property_list] * [ROW FORMAT row_format] * [STORED AS file_format] + * [CLUSTER BY (col_name, col_name, ...)] * [CLUSTERED BY (col_name, col_name, ...) * [SORTED BY (col_name [ASC|DESC], ...)] * INTO num_buckets BUCKETS @@ -3902,7 +3925,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { .map(visitCreateOrReplaceTableColTypeList).getOrElse(Nil) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) val (partTransforms, partCols, bucketSpec, properties, options, location, - comment, serdeInfo) = visitCreateTableClauses(ctx.createTableClauses()) + comment, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) if (provider.isDefined && serdeInfo.isDefined) { operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx) @@ -3915,7 +3938,10 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } val partitioning = - partitionExpressions(partTransforms, partCols, ctx) ++ bucketSpec.map(_.asTransform) + partitionExpressions(partTransforms, partCols, ctx) ++ + bucketSpec.map(_.asTransform) ++ + clusterBySpec.map(_.asTransform) + val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, serdeInfo, external) @@ -3958,6 +3984,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { * replace_table_clauses (order insensitive): * [OPTIONS table_property_list] * [PARTITIONED BY (partition_fields)] + * [CLUSTER BY (col_name, col_name, ...)] * [CLUSTERED BY (col_name, col_name, ...) 
* [SORTED BY (col_name [ASC|DESC], ...)] * INTO num_buckets BUCKETS @@ -3973,8 +4000,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { */ override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) { val orCreate = ctx.replaceTableHeader().CREATE() != null - val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = - visitCreateTableClauses(ctx.createTableClauses()) + val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo, + clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses()) val columns = Option(ctx.createOrReplaceTableColTypeList()) .map(visitCreateOrReplaceTableColTypeList).getOrElse(Nil) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) @@ -3984,7 +4011,10 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } val partitioning = - partitionExpressions(partTransforms, partCols, ctx) ++ bucketSpec.map(_.asTransform) + partitionExpressions(partTransforms, partCols, ctx) ++ + bucketSpec.map(_.asTransform) ++ + clusterBySpec.map(_.asTransform) + val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment, serdeInfo, external = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 8843a9fa237f0..0c49f9e46730c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.connector.catalog import scala.collection.mutable +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ClusterBySpec} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, QuotingUtils} -import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.StructType @@ -53,10 +54,15 @@ private[sql] object CatalogV2Implicits { } } + implicit class ClusterByHelper(spec: ClusterBySpec) { + def asTransform: Transform = clusterBy(spec.columnNames.toArray) + } + implicit class TransformHelper(transforms: Seq[Transform]) { - def convertTransforms: (Seq[String], Option[BucketSpec]) = { + def convertTransforms: (Seq[String], Option[BucketSpec], Option[ClusterBySpec]) = { val identityCols = new mutable.ArrayBuffer[String] var bucketSpec = Option.empty[BucketSpec] + var clusterBySpec = Option.empty[ClusterBySpec] transforms.map { case IdentityTransform(FieldReference(Seq(col))) => @@ -73,11 +79,23 @@ private[sql] object CatalogV2Implicits { sortCol.map(_.fieldNames.mkString(".")))) } + case ClusterByTransform(columnNames) => + if (clusterBySpec.nonEmpty) { + // AstBuilder 
guarantees that it only passes down one ClusterByTransform. + throw SparkException.internalError("Cannot have multiple cluster by transforms.") + } + clusterBySpec = Some(ClusterBySpec(columnNames)) + case transform => throw QueryExecutionErrors.unsupportedPartitionTransformError(transform) } - (identityCols.toSeq, bucketSpec) + // Parser guarantees that partition and clustering cannot co-exist. + assert(!(identityCols.toSeq.nonEmpty && clusterBySpec.nonEmpty)) + // Parser guarantees that bucketing and clustering cannot co-exist. + assert(!(bucketSpec.nonEmpty && clusterBySpec.nonEmpty)) + + (identityCols.toSeq, bucketSpec, clusterBySpec) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index fbd2520e2a774..0037f52a21b73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -52,6 +52,9 @@ private[sql] object LogicalExpressions { sortedCols: Array[NamedReference]): SortedBucketTransform = SortedBucketTransform(literal(numBuckets, IntegerType), references, sortedCols) + def clusterBy(references: Array[NamedReference]): ClusterByTransform = + ClusterByTransform(references) + def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference) def years(reference: NamedReference): YearsTransform = YearsTransform(reference) @@ -150,6 +153,39 @@ private[sql] object BucketTransform { } } +/** + * This class represents a transform for [[ClusterBySpec]]. This is used to bundle + * ClusterBySpec in CreateTable's partitioning transforms to pass it down to analyzer. + */ +final case class ClusterByTransform( + columnNames: Seq[NamedReference]) extends RewritableTransform { + + override val name: String = "cluster_by" + + override def references: Array[NamedReference] = columnNames.toArray + + override def arguments: Array[Expression] = columnNames.toArray + + override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})" + + override def withReferences(newReferences: Seq[NamedReference]): Transform = { + this.copy(columnNames = newReferences) + } +} + +/** + * Convenience extractor for ClusterByTransform. 
+ */ +object ClusterByTransform { + def unapply(transform: Transform): Option[Seq[NamedReference]] = + transform match { + case NamedTransform("cluster_by", arguments) => + Some(arguments.map(_.asInstanceOf[NamedReference])) + case _ => + None + } +} + private[sql] final case class SortedBucketTransform( numBuckets: Literal[Int], columns: Seq[NamedReference], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index b1f7340504490..2e896d563a73a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Hex, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{GeneratedColumn, ResolveDefaultColumns} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first} -import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} +import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, ClusterByTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{Decimal, IntegerType, LongType, MetadataBuilder, StringType, StructType, TimestampType} @@ -190,6 +190,50 @@ class DDLParserSuite extends AnalysisTest { } } + test("create/replace table - with cluster by") { + // Testing cluster by single part and multipart name. 
+ Seq( + ("a INT, b STRING, ts TIMESTAMP", + "a, b", + new StructType() + .add("a", IntegerType) + .add("b", StringType) + .add("ts", TimestampType), + ClusterByTransform(Seq(FieldReference("a"), FieldReference("b")))), + ("a STRUCT, ts TIMESTAMP", + "a.b, ts", + new StructType() + .add("a", + new StructType() + .add("b", IntegerType) + .add("c", StringType)) + .add("ts", TimestampType), + ClusterByTransform(Seq(FieldReference(Seq("a", "b")), FieldReference("ts")))) + ).foreach { case (columns, clusteringColumns, schema, clusterByTransform) => + val createSql = + s"""CREATE TABLE my_tab ($columns) USING parquet + |CLUSTER BY ($clusteringColumns) + |""".stripMargin + val replaceSql = + s"""REPLACE TABLE my_tab ($columns) USING parquet + |CLUSTER BY ($clusteringColumns) + |""".stripMargin + val expectedTableSpec = TableSpec( + Seq("my_tab"), + Some(schema), + Seq(clusterByTransform), + Map.empty[String, String], + Some("parquet"), + OptionList(Seq.empty), + None, + None, + None) + Seq(createSql, replaceSql).foreach { sql => + testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false) + } + } + } + test("create/replace table - with comment") { val createSql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" val replaceSql = "REPLACE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" @@ -859,6 +903,26 @@ class DDLParserSuite extends AnalysisTest { fragment = sql18, start = 0, stop = 86)) + + val sql19 = createTableHeader("CLUSTER BY (a)") + checkError( + exception = parseException(sql19), + errorClass = "DUPLICATE_CLAUSES", + parameters = Map("clauseName" -> "CLUSTER BY"), + context = ExpectedContext( + fragment = sql19, + start = 0, + stop = 65)) + + val sql20 = replaceTableHeader("CLUSTER BY (a)") + checkError( + exception = parseException(sql20), + errorClass = "DUPLICATE_CLAUSES", + parameters = Map("clauseName" -> "CLUSTER BY"), + context = ExpectedContext( + fragment = sql20, + start = 0, + stop = 66)) } test("support for other types in OPTIONS") { @@ -2896,6 +2960,50 @@ class DDLParserSuite extends AnalysisTest { ) } + test("create table cluster by with bucket") { + val sql1 = "CREATE TABLE my_tab(a INT, b STRING) " + + "USING parquet CLUSTERED BY (a) INTO 2 BUCKETS CLUSTER BY (a)" + checkError( + exception = parseException(sql1), + errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", + parameters = Map.empty, + context = ExpectedContext(fragment = sql1, start = 0, stop = 96) + ) + } + + test("replace table cluster by with bucket") { + val sql1 = "REPLACE TABLE my_tab(a INT, b STRING) " + + "USING parquet CLUSTERED BY (a) INTO 2 BUCKETS CLUSTER BY (a)" + checkError( + exception = parseException(sql1), + errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED", + parameters = Map.empty, + context = ExpectedContext(fragment = sql1, start = 0, stop = 97) + ) + } + + test("create table cluster by with partitioned by") { + val sql1 = "CREATE TABLE my_tab(a INT, b STRING) " + + "USING parquet CLUSTER BY (a) PARTITIONED BY (a)" + checkError( + exception = parseException(sql1), + errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED", + parameters = Map.empty, + context = ExpectedContext(fragment = sql1, start = 0, stop = 83) + ) + } + + test("replace table cluster by with partitioned by") { + val sql1 = "REPLACE TABLE my_tab(a INT, b STRING) " + + "USING parquet CLUSTER BY (a) PARTITIONED BY (a)" + checkError( + exception = parseException(sql1), + errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED", + 
parameters = Map.empty, + context = ExpectedContext(fragment = sql1, start = 0, stop = 84) + ) + } + test("AstBuilder don't support `INSERT OVERWRITE DIRECTORY`") { val insertDirSql = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index cd7f7295d5cb9..318cbf6962c19 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -90,6 +90,7 @@ abstract class InMemoryBaseTable( case _: HoursTransform => case _: BucketTransform => case _: SortedBucketTransform => + case _: ClusterByTransform => case NamedTransform("truncate", Seq(_: NamedReference, _: Literal[_])) => case t if !allowUnsupportedTransforms => throw new IllegalArgumentException(s"Transform $t is not a supported transform") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index 62cae3c861071..8ac268df80bcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst -import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket +import org.apache.spark.sql.connector.expressions.LogicalExpressions.{bucket, clusterBy} import org.apache.spark.sql.types.DataType class TransformExtractorSuite extends SparkFunSuite { @@ -210,4 +210,45 @@ class TransformExtractorSuite extends SparkFunSuite { val copied2 = sortedBucketTransform.withReferences(reference2) assert(copied2.equals(sortedBucketTransform)) } + + test("ClusterBySpec extractor") { + val col = ref("a", "b") + val clusterByTransform = new Transform { + override def name: String = "cluster_by" + override def references: Array[NamedReference] = Array(col) + override def arguments: Array[Expression] = Array(col) + override def toString: String = s"$name(${col.describe})" + } + + clusterByTransform match { + case ClusterByTransform(columnNames) => + assert(columnNames.size === 1) + assert(columnNames(0).fieldNames === Seq("a", "b")) + case _ => + fail("Did not match ClusterByTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case ClusterByTransform(_) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("test cluster by") { + val col = Array(ref("a a", "b"), ref("ts")) + + val clusterByTransform = clusterBy(col) + val reference = clusterByTransform.references + assert(reference.length == 2) + assert(reference(0).fieldNames() === Seq("a a", "b")) + assert(reference(1).fieldNames() === Seq("ts")) + val arguments = clusterByTransform.arguments + assert(arguments.length == 2) + assert(arguments(0).asInstanceOf[NamedReference].fieldNames() === Seq("a a", "b")) + assert(arguments(1).asInstanceOf[NamedReference].fieldNames() === Seq("ts")) + val copied = clusterByTransform.withReferences(reference) + assert(copied.equals(clusterByTransform)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 174881a465929..c557ec4a486db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -21,7 +21,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils, ClusterBySpec} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -532,7 +532,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } else { CatalogTableType.MANAGED } - val (partitionColumns, maybeBucketSpec) = partitioning.convertTransforms + val (partitionColumns, maybeBucketSpec, maybeClusterBySpec) = partitioning.convertTransforms CatalogTable( identifier = table, @@ -542,7 +542,9 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) provider = Some(provider), partitionColumnNames = partitionColumns, bucketSpec = maybeBucketSpec, - properties = properties, + properties = properties ++ + maybeClusterBySpec.map( + clusterBySpec => ClusterBySpec.toProperty(schema, clusterBySpec, conf.resolver)), comment = comment) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 1cc2147725355..d8e5d4f227015 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -323,7 +323,8 @@ class SparkSqlAstBuilder extends AstBuilder { operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) } - val (_, _, _, _, options, location, _, _) = visitCreateTableClauses(ctx.createTableClauses()) + val (_, _, _, _, options, location, _, _, _) = + visitCreateTableClauses(ctx.createTableClauses()) val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText).getOrElse( throw QueryParsingErrors.createTempTableNotSpecifyProviderError(ctx)) val schema = Option(ctx.createOrReplaceTableColTypeList()).map(createSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index bd2e659749311..6dd76973baa5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, ClusterBySpec, SessionCatalog} import 
org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Column, FunctionCatalog, Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableCatalogCapability, TableChange, V1Table} import org.apache.spark.sql.connector.catalog.NamespaceChange.RemoveProperty @@ -114,7 +114,7 @@ class V2SessionCatalog(catalog: SessionCatalog) partitions: Array[Transform], properties: util.Map[String, String]): Table = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TransformHelper - val (partitionColumns, maybeBucketSpec) = partitions.toSeq.convertTransforms + val (partitionColumns, maybeBucketSpec, maybeClusterBySpec) = partitions.toSeq.convertTransforms val provider = properties.getOrDefault(TableCatalog.PROP_PROVIDER, conf.defaultDataSourceName) val tableProperties = properties.asScala val location = Option(properties.get(TableCatalog.PROP_LOCATION)) @@ -135,7 +135,9 @@ class V2SessionCatalog(catalog: SessionCatalog) provider = Some(provider), partitionColumnNames = partitionColumns, bucketSpec = maybeBucketSpec, - properties = tableProperties.toMap, + properties = tableProperties.toMap ++ + maybeClusterBySpec.map( + clusterBySpec => ClusterBySpec.toProperty(schema, clusterBySpec, conf.resolver)), tracksPartitionsInCatalog = conf.manageFilesourcePartitions, comment = Option(properties.get(TableCatalog.PROP_COMMENT))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index b1ad454fc041f..d58cd001e9416 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -379,7 +379,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { val columns = sparkSession.sessionState.executePlan(plan).analyzed match { case ResolvedTable(_, _, table, _) => - val (partitionColumnNames, bucketSpecOpt) = table.partitioning.toSeq.convertTransforms + // TODO (SPARK-45787): Support clusterBySpec for listColumns(). + val (partitionColumnNames, bucketSpecOpt, _) = table.partitioning.toSeq.convertTransforms val bucketColumnNames = bucketSpecOpt.map(_.bucketColumnNames).getOrElse(Nil) schemaToColumns(table.schema(), partitionColumnNames.contains, bucketColumnNames.contains) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableClusterBySuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableClusterBySuiteBase.scala new file mode 100644 index 0000000000000..cb56d11b665db --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableClusterBySuiteBase.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{AnalysisException, QueryTest} + +/** + * This base suite contains unified tests for the `CREATE/REPLACE TABLE ... CLUSTER BY` command + * that check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are + * located in more specific test suites: + * + * - V2 table catalog tests: `org.apache.spark.sql.execution.command.v2.CreateTableClusterBySuite` + * - V1 table catalog tests: + * `org.apache.spark.sql.execution.command.v1.CreateTableClusterBySuiteBase` + * - V1 In-Memory catalog: `org.apache.spark.sql.execution.command.v1.CreateTableClusterBySuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.CreateTableClusterBySuite` + */ +trait CreateTableClusterBySuiteBase extends QueryTest with DDLCommandTestUtils { + override val command = "CREATE/REPLACE TABLE CLUSTER BY" + + protected val nestedColumnSchema: String = + "col1 INT, col2 STRUCT, col3 STRUCT<`col4.1` INT>" + protected val nestedClusteringColumns: Seq[String] = + Seq("col2.col3", "col2.`col4 1`", "col3.`col4.1`") + + def validateClusterBy(tableName: String, clusteringColumns: Seq[String]): Unit + + test("test basic CREATE TABLE with clustering columns") { + withNamespaceAndTable("ns", "table") { tbl => + spark.sql(s"CREATE TABLE $tbl (id INT, data STRING) $defaultUsing CLUSTER BY (id, data)") + validateClusterBy(tbl, Seq("id", "data")) + } + } + + test("test clustering columns with comma") { + withNamespaceAndTable("ns", "table") { tbl => + spark.sql(s"CREATE TABLE $tbl (`i,d` INT, data STRING) $defaultUsing " + + "CLUSTER BY (`i,d`, data)") + validateClusterBy(tbl, Seq("`i,d`", "data")) + } + } + + test("test nested clustering columns") { + withNamespaceAndTable("ns", "table") { tbl => + spark.sql(s"CREATE TABLE $tbl " + + s"($nestedColumnSchema) " + + s"$defaultUsing CLUSTER BY (${nestedClusteringColumns.mkString(",")})") + validateClusterBy(tbl, nestedClusteringColumns) + } + } + + test("clustering columns not defined in schema") { + withNamespaceAndTable("ns", "table") { tbl => + val err = intercept[AnalysisException] { + sql(s"CREATE TABLE $tbl (id bigint, data string) $defaultUsing CLUSTER BY (unknown)") + } + assert(err.message.contains("Couldn't find column unknown in:")) + } + } + + // Converts three-part table name (catalog.namespace.table) to TableIdentifier. + protected def parseTableName(threePartTableName: String): (String, String, String) = { + val tablePath = threePartTableName.split('.') + assert(tablePath.length === 3) + (tablePath(0), tablePath(1), tablePath(2)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableClusterBySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableClusterBySuite.scala new file mode 100644 index 0000000000000..2444fe062e283 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableClusterBySuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v1 + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.ClusterBySpec +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.execution.command + +/** + * This base suite contains unified tests for the `CREATE TABLE ... CLUSTER BY` command that + * checks V1 table catalogs. The tests that cannot run for all V1 catalogs are located in more + * specific test suites: + * + * - V1 In-Memory catalog: `org.apache.spark.sql.execution.command.v1.CreateTableClusterBySuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.CreateTableClusterBySuite` + */ +trait CreateTableClusterBySuiteBase extends command.CreateTableClusterBySuiteBase + with command.TestsV1AndV2Commands { + override def validateClusterBy(tableName: String, clusteringColumns: Seq[String]): Unit = { + val catalog = spark.sessionState.catalog + val (_, db, t) = parseTableName(tableName) + val table = catalog.getTableMetadata(TableIdentifier.apply(t, Some(db))) + assert(table.clusterBySpec === Some(ClusterBySpec(clusteringColumns.map(FieldReference(_))))) + } +} + +/** + * The class contains tests for the `CREATE TABLE ... CLUSTER BY` command to check V1 In-Memory + * table catalog. + */ +class CreateTableClusterBySuite extends CreateTableClusterBySuiteBase + with CommandSuiteBase { + override def commandVersion: String = super[CreateTableClusterBySuiteBase].commandVersion +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableClusterBySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableClusterBySuite.scala new file mode 100644 index 0000000000000..86b14d6680388 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableClusterBySuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryPartitionTable} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper +import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference} +import org.apache.spark.sql.execution.command + +/** + * The class contains tests for the `CREATE TABLE ... CLUSTER BY` command to check V2 table + * catalogs. + */ +class CreateTableClusterBySuite extends command.CreateTableClusterBySuiteBase + with CommandSuiteBase { + override def validateClusterBy(tableName: String, clusteringColumns: Seq[String]): Unit = { + val (catalog, namespace, table) = parseTableName(tableName) + val catalogPlugin = spark.sessionState.catalogManager.catalog(catalog) + val partTable = catalogPlugin.asTableCatalog + .loadTable(Identifier.of(Array(namespace), table)) + .asInstanceOf[InMemoryPartitionTable] + assert(partTable.partitioning === + Array(ClusterByTransform(clusteringColumns.map(FieldReference(_))))) + } + + test("test REPLACE TABLE with clustering columns") { + withNamespaceAndTable("ns", "table") { tbl => + spark.sql(s"CREATE TABLE $tbl (id INT) $defaultUsing CLUSTER BY (id)") + validateClusterBy(tbl, Seq("id")) + + spark.sql(s"REPLACE TABLE $tbl (id2 INT) $defaultUsing CLUSTER BY (id2)") + validateClusterBy(tbl, Seq("id2")) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CreateTableClusterBySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CreateTableClusterBySuite.scala new file mode 100644 index 0000000000000..496cc13c49715 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CreateTableClusterBySuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution.command + +import org.apache.spark.sql.execution.command.v1 + +/** + * The class contains tests for the `CREATE TABLE ... CLUSTER BY` command to check V1 Hive external + * table catalog. + */ +class CreateTableClusterBySuite extends v1.CreateTableClusterBySuiteBase with CommandSuiteBase { + // Hive doesn't support nested column names with space and dot. + override protected val nestedColumnSchema: String = + "col1 INT, col2 STRUCT" + override protected val nestedClusteringColumns: Seq[String] = + Seq("col2.col3") + + // Hive catalog doesn't support column names with commas. 
+ override def excluded: Seq[String] = Seq( + s"$command using Hive V1 catalog V1 command: test clustering columns with comma", + s"$command using Hive V1 catalog V2 command: test clustering columns with comma") + + override def commandVersion: String = super[CreateTableClusterBySuiteBase].commandVersion +} From 1e93c408e19f4ce8cec8220fd5eb6c1cb76468ff Mon Sep 17 00:00:00 2001 From: Yaohua Zhao Date: Thu, 9 Nov 2023 19:35:51 +0800 Subject: [PATCH 077/121] [SPARK-45815][SQL][STREAMING] Provide an interface for other Streaming sources to add `_metadata` columns ### What changes were proposed in this pull request? Currently, only the native V1 file-based streaming source can read the `_metadata` column: https://github.com/apache/spark/blob/370870b7a0303e4a2c4b3dea1b479b4fcbc93f8d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala#L63 Our goal is to create an interface that allows other streaming sources to add `_metadata` columns. For instance, we would like the Delta Streaming source, which you can find here: https://github.com/delta-io/delta/blob/master/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaDataSource.scala#L49, to extend this interface and provide the `_metadata` column for its underlying storage format, such as Parquet. ### Why are the changes needed? A generic interface to enable other streaming sources to expose and add `_metadata` columns. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? No Closes #43692 from Yaohua628/spark-45815. Authored-by: Yaohua Zhao Signed-off-by: Wenchen Fan --- .../streaming/StreamingRelation.scala | 11 ++++--- .../apache/spark/sql/sources/interfaces.scala | 31 +++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 135d46c5291e7..c5d5a79d34545 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat} +import org.apache.spark.sql.sources.SupportsStreamSourceMetadataColumns object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -60,11 +61,11 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance())) override lazy val metadataOutput: Seq[AttributeReference] = { - dataSource.providingClass match { - // If the dataSource provided class is a same or subclass of FileFormat class - case f if classOf[FileFormat].isAssignableFrom(f) => - metadataOutputWithOutConflicts( - Seq(dataSource.providingInstance().asInstanceOf[FileFormat].createFileMetadataCol())) + dataSource.providingInstance() match { + case f: FileFormat => metadataOutputWithOutConflicts(Seq(f.createFileMetadataCol())) + case s: SupportsStreamSourceMetadataColumns => + metadataOutputWithOutConflicts(s.getMetadataOutput( + dataSource.sparkSession, dataSource.options, 
dataSource.userSpecifiedSchema)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 63e57c6804e16..d194ae77e968f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -309,3 +309,34 @@ trait InsertableRelation { trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } + +/** + * Implemented by StreamSourceProvider objects that can generate file metadata columns. + * This trait extends the basic StreamSourceProvider by allowing the addition of metadata + * columns to the schema of the Stream Data Source. + */ +trait SupportsStreamSourceMetadataColumns extends StreamSourceProvider { + + /** + * Returns the metadata columns that should be added to the schema of the Stream Source. + * These metadata columns supplement the columns + * defined in the sourceSchema() of the StreamSourceProvider. + * + * The final schema for the Stream Source, therefore, consists of the source schema as + * defined by StreamSourceProvider.sourceSchema(), with the metadata columns added at the end. + * The caller is responsible for resolving any naming conflicts with the source schema. + * + * An example of using this streaming source metadata output interface is + * when a customized file-based streaming source needs to expose file metadata columns, + * leveraging the hidden file metadata columns from its underlying storage format. + * + * @param spark The SparkSession used for the operation. + * @param options A map of options of the Stream Data Source. + * @param userSpecifiedSchema An optional user-provided schema of the Stream Data Source. + * @return A Seq of AttributeReference representing the metadata output attributes. + */ + def getMetadataOutput( + spark: SparkSession, + options: Map[String, String], + userSpecifiedSchema: Option[StructType]): Seq[AttributeReference] +} From 7bc96e8e37672483a07088dbbdcf3610a497af1d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 9 Nov 2023 13:59:44 -0800 Subject: [PATCH 078/121] [SPARK-45867][CORE] Support `spark.worker.idPattern` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to support `spark.worker.idPattern`. ### Why are the changes needed? To allow users to customize the worker IDs if they want. - From: `worker-20231109183042-[fe80::1%lo0]-39729` - To: `my-worker-20231109183042-[fe80::1%lo0]` For example, ``` $ cat conf/spark-defaults.conf spark.worker.idPattern worker-%2$s ``` Screenshot 2023-11-09 at 1 25 19 PM ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43740 from dongjoon-hyun/SPARK-45867. 
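As an aside, a minimal sketch (not part of the patch) of how such a pattern is expanded: `generateWorkerId` passes the creation timestamp, host, and port to Java `String.format` in that order, and `%2$s` is a positional specifier that selects the second argument (the host), dropping the rest. The values below are illustrative.

```scala
// Illustration only: expanding a custom worker ID pattern with positional format arguments.
// Arguments are (timestamp, host, port); "%2$s" picks the host and ignores the others.
val pattern = "my-worker-%2$s"
val workerId = pattern.format("20231109183042", "localhost", 39729)
assert(workerId == "my-worker-localhost")
```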
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../apache/spark/deploy/worker/Worker.scala | 3 ++- .../apache/spark/internal/config/Worker.scala | 11 +++++++++++ .../spark/deploy/worker/WorkerSuite.scala | 18 ++++++++++++++++-- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 44082ae78794b..ddbba55e00b44 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -62,6 +62,7 @@ private[deploy] class Worker( private val host = rpcEnv.address.host private val port = rpcEnv.address.port + private val workerIdPattern = conf.get(config.Worker.WORKER_ID_PATTERN) Utils.checkHost(host) assert (port > 0) @@ -813,7 +814,7 @@ private[deploy] class Worker( } private def generateWorkerId(): String = { - "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) + workerIdPattern.format(createDateFormat.format(new Date), host, port) } override def onStop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/Worker.scala b/core/src/main/scala/org/apache/spark/internal/config/Worker.scala index fda3a57546b67..f160470edd8f0 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Worker.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Worker.scala @@ -89,4 +89,15 @@ private[spark] object Worker { .version("3.2.0") .stringConf .createWithDefaultString("PWR") + + val WORKER_ID_PATTERN = ConfigBuilder("spark.worker.idPattern") + .internal() + .doc("The pattern for worker ID generation based on Java `String.format` method. The " + + "default value is `worker-%s-%s-%d` which represents the existing worker id string, e.g.," + + " `worker-20231109183042-[fe80::1%lo0]-39729`. 
Please be careful to generate unique IDs") + .version("4.0.0") + .stringConf + .checkValue(!_.format("20231109000000", "host", 0).exists(_.isWhitespace), + "Whitespace is not allowed.") + .createWithDefaultString("worker-%s-%s-%d") } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index a07d4f76905a7..1b2d92af4b026 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -29,7 +29,7 @@ import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ @@ -49,7 +49,7 @@ import org.apache.spark.resource.TestResourceIDs.{WORKER_FPGA_ID, WORKER_GPU_ID} import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.apache.spark.util.Utils -class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { +class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter with PrivateMethodTester { import org.apache.spark.deploy.DeployTestUtils._ @@ -62,6 +62,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { implicit val formats = DefaultFormats + private val _generateWorkerId = PrivateMethod[String](Symbol("generateWorkerId")) + private var _worker: Worker = _ private def makeWorker( @@ -391,4 +393,16 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { assert(cleanupCalled.get() == dbCleanupEnabled) } } + + test("SPARK-45867: Support worker id pattern") { + val worker = makeWorker(new SparkConf().set(WORKER_ID_PATTERN, "my-worker-%2$s")) + assert(worker.invokePrivate(_generateWorkerId()) === "my-worker-localhost") + } + + test("SPARK-45867: Prevent invalid worker id patterns") { + val m = intercept[IllegalArgumentException] { + makeWorker(new SparkConf().set(WORKER_ID_PATTERN, "my worker")) + }.getMessage + assert(m.contains("Whitespace is not allowed")) + } } From ce818ba969537cf9eb16865a88148407a5992e98 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 9 Nov 2023 15:56:47 -0800 Subject: [PATCH 079/121] [SPARK-45731][SQL] Also update partition statistics with `ANALYZE TABLE` command ### What changes were proposed in this pull request? Also update partition statistics (e.g., total size in bytes, row count) with `ANALYZE TABLE` command. ### Why are the changes needed? Currently when a `ANALYZE TABLE ` command is triggered against a partition table, only table stats are updated, but not partition stats. For Spark users who want to update the latter, they have to use a different syntax: `ANALYZE TABLE PARTITION()` which is more verbose. Given `ANALYZE TABLE` internally already calculates total size for all the partitions, it makes sense to also update partition stats using the result. In this way, Spark users do not need to remember two different syntaxes. In addition, when using `ANALYZE TABLE` with the "scan node", i.e., `NOSCAN` is NOT specified, we can also calculate row count for all the partitions and update the stats accordingly. 
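As a rough sketch of the intended workflow (illustrative only; the table name `t` and partition column `ds` below are placeholders, and the flag shown is the new one described next):

```scala
// Hypothetical usage sketch, assuming a partitioned table `t` with partition column `ds`.
// The new flag (off by default) opts in to partition-level stats updates.
spark.conf.set("spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled", "true")

// One command now refreshes both table-level and partition-level statistics.
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS")

// Partition statistics can be read back through the session catalog.
import org.apache.spark.sql.catalyst.TableIdentifier
val partStats = spark.sessionState.catalog
  .getPartition(TableIdentifier("t"), Map("ds" -> "2023-11-09"))
  .stats
```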
The above behavior is controlled via a new flag `spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled`, which by default is turned off. ### Does this PR introduce _any_ user-facing change? Not by default. When `spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled`, Spark will now update partition stats as well with `ANALYZE TABLE` command, on a partitioned table. ### How was this patch tested? Added a unit test for this feature. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43629 from sunchao/SPARK-45731. Authored-by: Chao Sun Signed-off-by: Chao Sun --- .../apache/spark/sql/internal/SQLConf.scala | 13 +++ .../command/AnalyzePartitionCommand.scala | 50 ++--------- .../sql/execution/command/CommandUtils.scala | 87 ++++++++++++++++--- .../spark/sql/hive/StatisticsSuite.scala | 78 +++++++++++++++++ 4 files changed, 170 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ecc3e6e101fcb..ff6ab7b541a35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2664,6 +2664,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED = + buildConf("spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled") + .doc("When this config is enabled, Spark will also update partition statistics in analyze " + + "table command (i.e., ANALYZE TABLE .. COMPUTE STATISTICS [NOSCAN]). Note the command " + + "will also become more expensive. When this config is disabled, Spark will only " + + "update table level statistics.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CBO_ENABLED = buildConf("spark.sql.cbo.enabled") .doc("Enables CBO for estimation of plan statistics when set true.") @@ -5113,6 +5123,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def autoSizeUpdateEnabled: Boolean = getConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED) + def updatePartStatsInAnalyzeTableEnabled: Boolean = + getConf(SQLConf.UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED) + def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index c2b227d6cad77..7fe4c73abf903 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} import org.apache.spark.sql.errors.QueryCompilationErrors import 
org.apache.spark.sql.util.PartitioningUtils -import org.apache.spark.util.collection.Utils /** * Analyzes a given set of partitions to generate per-partition statistics, which will be used in @@ -101,20 +98,13 @@ case class AnalyzePartitionCommand( if (noscan) { Map.empty } else { - calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec) + CommandUtils.calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec) } // Update the metastore if newly computed statistics are different from those // recorded in the metastore. - - val sizes = CommandUtils.calculateMultipleLocationSizes(sparkSession, tableMeta.identifier, - partitions.map(_.storage.locationUri)) - val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) => - val newRowCount = rowCounts.get(p.spec) - val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount) - newStats.map(_ => p.copy(stats = newStats)) - } - + val (_, newPartitions) = CommandUtils.calculatePartitionStats( + sparkSession, tableMeta, partitions, Some(rowCounts)) if (newPartitions.nonEmpty) { sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions) } @@ -122,35 +112,5 @@ case class AnalyzePartitionCommand( Seq.empty[Row] } - private def calculateRowCountsPerPartition( - sparkSession: SparkSession, - tableMeta: CatalogTable, - partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = { - val filter = if (partitionValueSpec.isDefined) { - val filters = partitionValueSpec.get.map { - case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value)) - } - filters.reduce(And) - } else { - Literal.TrueLiteral - } - - val tableDf = sparkSession.table(tableMeta.identifier) - val partitionColumns = tableMeta.partitionColumnNames.map(Column(_)) - - val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() - df.collect().map { r => - val partitionColumnValues = partitionColumns.indices.map { i => - if (r.isNullAt(i)) { - ExternalCatalogUtils.DEFAULT_PARTITION_NAME - } else { - r.get(i).toString - } - } - val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues) - val count = BigInt(r.getLong(partitionColumns.size)) - (spec, count) - }.toMap - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index c656bdbafa0c7..73478272a6841 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -25,9 +25,11 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Column, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import 
org.apache.spark.sql.catalyst.plans.logical._ @@ -37,6 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex} import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.Utils /** * For the purpose of calculating total directory sizes, use this filter to @@ -76,29 +79,42 @@ object CommandUtils extends Logging { def calculateTotalSize( spark: SparkSession, - catalogTable: CatalogTable): (BigInt, Seq[CatalogTablePartition]) = { + catalogTable: CatalogTable, + partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None): + (BigInt, Seq[CatalogTablePartition]) = { val sessionState = spark.sessionState val startTime = System.nanoTime() val (totalSize, newPartitions) = if (catalogTable.partitionColumnNames.isEmpty) { - (calculateSingleLocationSize(sessionState, catalogTable.identifier, - catalogTable.storage.locationUri), Seq()) + val size = calculateSingleLocationSize(sessionState, catalogTable.identifier, + catalogTable.storage.locationUri) + (BigInt(size), Seq()) } else { // Calculate table size as a sum of the visible partitions. See SPARK-21079 val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) logInfo(s"Starting to calculate sizes for ${partitions.length} partitions.") - val paths = partitions.map(_.storage.locationUri) - val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths) - val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) => - val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), None) - newStats.map(_ => p.copy(stats = newStats)) - } - (sizes.sum, newPartitions) + calculatePartitionStats(spark, catalogTable, partitions, partitionRowCount) } logInfo(s"It took ${(System.nanoTime() - startTime) / (1000 * 1000)} ms to calculate" + s" the total size for table ${catalogTable.identifier}.") (totalSize, newPartitions) } + def calculatePartitionStats( + spark: SparkSession, + catalogTable: CatalogTable, + partitions: Seq[CatalogTablePartition], + partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None): + (BigInt, Seq[CatalogTablePartition]) = { + val paths = partitions.map(_.storage.locationUri) + val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths) + val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) => + val newRowCount = partitionRowCount.flatMap(_.get(p.spec)) + val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount) + newStats.map(_ => p.copy(stats = newStats)) + } + (sizes.sum, newPartitions) + } + def calculateSingleLocationSize( sessionState: SessionState, identifier: TableIdentifier, @@ -214,6 +230,7 @@ object CommandUtils extends Logging { tableIdent: TableIdentifier, noScan: Boolean): Unit = { val sessionState = sparkSession.sessionState + val partitionStatsEnabled = sessionState.conf.updatePartStatsInAnalyzeTableEnabled val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) @@ -231,7 +248,15 @@ object CommandUtils extends Logging { } } else { // Compute stats for the whole table - val (newTotalSize, _) = CommandUtils.calculateTotalSize(sparkSession, tableMeta) + val rowCounts: Map[TablePartitionSpec, BigInt] = + if (noScan || !partitionStatsEnabled) { + Map.empty + } else { 
+ calculateRowCountsPerPartition(sparkSession, tableMeta, None) + } + val (newTotalSize, newPartitions) = CommandUtils.calculateTotalSize( + sparkSession, tableMeta, Some(rowCounts)) + val newRowCount = if (noScan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count())) @@ -241,6 +266,10 @@ object CommandUtils extends Logging { if (newStats.isDefined) { sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) } + // Also update partition stats when the config is enabled + if (newPartitions.nonEmpty && partitionStatsEnabled) { + sessionState.catalog.alterPartitions(tableIdentWithDB, newPartitions) + } } } @@ -440,4 +469,36 @@ object CommandUtils extends Logging { case NonFatal(e) => logWarning(s"Exception when attempting to uncache $name", e) } } + + def calculateRowCountsPerPartition( + sparkSession: SparkSession, + tableMeta: CatalogTable, + partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = { + val filter = if (partitionValueSpec.isDefined) { + val filters = partitionValueSpec.get.map { + case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value)) + } + filters.reduce(And) + } else { + Literal.TrueLiteral + } + + val tableDf = sparkSession.table(tableMeta.identifier) + val partitionColumns = tableMeta.partitionColumnNames.map(Column(_)) + + val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() + + df.collect().map { r => + val partitionColumnValues = partitionColumns.indices.map { i => + if (r.isNullAt(i)) { + ExternalCatalogUtils.DEFAULT_PARTITION_NAME + } else { + r.get(i).toString + } + } + val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues) + val count = BigInt(r.getLong(partitionColumns.size)) + (spec, count) + }.toMap + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 11134a891960c..21a115486298d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -363,6 +363,84 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("SPARK-45731: update partition stats with ANALYZE TABLE") { + val tableName = "analyzeTable_part" + + def queryStats(ds: String): Option[CatalogStatistics] = { + val partition = + spark.sessionState.catalog.getPartition(TableIdentifier(tableName), Map("ds" -> ds)) + partition.stats + } + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + val expectedRowCount = 500 + + Seq(true, false).foreach { partitionStatsEnabled => + withSQLConf(SQLConf.UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED.key -> + partitionStatsEnabled.toString) { + withTable(tableName) { + withTempPath { path => + // Create a table with 3 partitions all located under a directory 'path' + sql( + s""" + |CREATE TABLE $tableName (key INT, value STRING) + |USING parquet + |PARTITIONED BY (ds STRING) + |LOCATION '${path.toURI}' + """.stripMargin) + + partitionDates.foreach { ds => + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds') LOCATION '$path/ds=$ds'") + sql("SELECT * FROM src").write.mode(SaveMode.Overwrite) + .format("parquet").save(s"$path/ds=$ds") + } + + assert(getCatalogTable(tableName).stats.isEmpty) + partitionDates.foreach { ds => + assert(queryStats(ds).isEmpty) + } + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS NOSCAN") + + // Table size should also have been updated + 
assert(getTableStats(tableName).sizeInBytes > 0) + // Row count should NOT be updated with the `NOSCAN` option + assert(getTableStats(tableName).rowCount.isEmpty) + + partitionDates.foreach { ds => + val partStats = queryStats(ds) + if (partitionStatsEnabled) { + assert(partStats.nonEmpty) + assert(partStats.get.sizeInBytes > 0) + assert(partStats.get.rowCount.isEmpty) + } else { + assert(partStats.isEmpty) + } + } + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + + assert(getTableStats(tableName).sizeInBytes > 0) + // Table row count should be updated + assert(getTableStats(tableName).rowCount.get == 3 * expectedRowCount) + + partitionDates.foreach { ds => + val partStats = queryStats(ds) + if (partitionStatsEnabled) { + assert(partStats.nonEmpty) + // The scan option should update partition row count + assert(partStats.get.sizeInBytes > 0) + assert(partStats.get.rowCount.get == expectedRowCount) + } else { + assert(partStats.isEmpty) + } + } + } + } + } + } + } + test("analyze single partition") { val tableName = "analyzeTable_part" From d9c5f9d6d42a51156c0aeb2aae2764a9c6f691e4 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 10 Nov 2023 10:32:58 +0900 Subject: [PATCH 080/121] [SPARK-45798][CONNECT] Assert server-side session ID ### What changes were proposed in this pull request? Without this patch, when the server would restart because of an abnormal condition, the client would not realize that this be the case. For example, when a driver OOM occurs and the driver is restarted, the client would not realize that the server is restarted and a new session is assigned. This patch fixes this behavior and asserts that the server side session ID does not change during the connection. If it does change it throws an exception like this: ``` >>> spark.range(10).collect() Traceback (most recent call last): File "", line 1, in File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/dataframe.py", line 1710, in collect table, schema = self._to_table() File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/dataframe.py", line 1722, in _to_table table, schema = self._session.client.to_table(query, self._plan.observations) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 839, in to_table table, schema, _, _, _ = self._execute_and_fetch(req, observations) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1295, in _execute_and_fetch for response in self._execute_and_fetch_as_iterator(req, observations): File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1273, in _execute_and_fetch_as_iterator self._handle_error(error) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1521, in _handle_error raise error File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1266, in _execute_and_fetch_as_iterator yield from handle_response(b) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1193, in handle_response self._verify_response_integrity(b) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1622, in _verify_response_integrity raise SparkConnectException( pyspark.errors.exceptions.connect.SparkConnectException: Received incorrect server side session identifier for request. Please restart Spark Session. 
(9493c83d-cfa4-488f-9522-838ef3df90bf != c5302e8f-170d-477e-908d-299927b68fd8) ``` ### Why are the changes needed? Stability ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Existing tests cover the basic invariant. - Added new tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43664 from grundprinzip/SPARK-45798. Authored-by: Martin Grund Signed-off-by: Hyukjin Kwon --- .../sql/connect/client/ArtifactSuite.scala | 6 +- .../main/protobuf/spark/connect/base.proto | 69 ++++- .../sql/connect/client/ArtifactManager.scala | 10 +- .../CustomSparkConnectBlockingStub.scala | 34 ++- .../client/CustomSparkConnectStub.scala | 6 +- ...cutePlanResponseReattachableIterator.scala | 19 +- .../client/GrpcExceptionConverter.scala | 10 +- .../connect/client/ResponseValidator.scala | 111 ++++++++ .../connect/client/SparkConnectClient.scala | 7 +- .../client/SparkConnectStubState.scala | 43 +++ .../SparkConnectArtifactManager.scala | 2 +- .../execution/ExecuteResponseObserver.scala | 1 + .../execution/SparkConnectPlanExecution.scala | 11 +- .../connect/planner/SparkConnectPlanner.scala | 7 +- .../sql/connect/service/SessionHolder.scala | 7 + .../SparkConnectAddArtifactsHandler.scala | 2 + .../service/SparkConnectAnalyzeHandler.scala | 4 +- .../SparkConnectArtifactStatusesHandler.scala | 5 + .../service/SparkConnectConfigHandler.scala | 5 +- .../SparkConnectInterruptHandler.scala | 1 + .../SparkConnectReleaseExecuteHandler.scala | 1 + .../SparkConnectReleaseSessionHandler.scala | 3 + .../sql/connect/utils/MetricGenerator.scala | 8 +- .../sql/connect/SparkConnectServerTest.scala | 7 +- .../artifact/ArtifactManagerSuite.scala | 18 +- .../ArtifactStatusesHandlerSuite.scala | 7 +- .../service/SparkConnectServiceE2ESuite.scala | 2 + python/pyspark/sql/connect/client/core.py | 92 ++++--- python/pyspark/sql/connect/proto/base_pb2.py | 260 +++++++++--------- python/pyspark/sql/connect/proto/base_pb2.pyi | 165 +++++++++-- 30 files changed, 671 insertions(+), 252 deletions(-) create mode 100644 connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala create mode 100644 connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index e11d2bf2e3ab7..79aba053ea04e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -45,6 +45,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach { private var retryPolicy: GrpcRetryHandler.RetryPolicy = _ private var bstub: CustomSparkConnectBlockingStub = _ private var stub: CustomSparkConnectStub = _ + private var state: SparkConnectStubState = _ private def startDummyServer(): Unit = { service = new DummySparkConnectService() @@ -58,8 +59,9 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach { private def createArtifactManager(): Unit = { channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build() retryPolicy = GrpcRetryHandler.RetryPolicy() - bstub = new CustomSparkConnectBlockingStub(channel, retryPolicy) - stub = new CustomSparkConnectStub(channel, retryPolicy) + state = new 
SparkConnectStubState(channel, retryPolicy) + bstub = new CustomSparkConnectBlockingStub(channel, state) + stub = new CustomSparkConnectStub(channel, state) artifactManager = new ArtifactManager(Configuration(), "", bstub, stub) } diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 19a94a5a429f0..da089dcd75640 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -197,8 +197,12 @@ message AnalyzePlanRequest { // Response to performing analysis of the query. Contains relevant metadata to be able to // reason about the performance. +// Next ID: 16 message AnalyzePlanResponse { string session_id = 1; + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 15; oneof result { Schema schema = 2; @@ -317,8 +321,12 @@ message ExecutePlanRequest { // The response of a query, can be one or more for each request. Responses belonging to the // same input query, carry the same `session_id`. +// Next ID: 16 message ExecutePlanResponse { string session_id = 1; + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 15; // Identifies the ExecutePlan execution. // If set by the client in ExecutePlanRequest.operationId, that value is returned. @@ -492,8 +500,12 @@ message ConfigRequest { } // Response to the config request. +// Next ID: 5 message ConfigResponse { string session_id = 1; + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 4; // (Optional) The result key-value pairs. // @@ -584,7 +596,17 @@ message AddArtifactsRequest { // Response to adding an artifact. Contains relevant metadata to verify successful transfer of // artifact(s). +// Next ID: 4 message AddArtifactsResponse { + // Session id in which the AddArtifact was running. + string session_id = 2; + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 3; + + // The list of artifact(s) seen by the server. + repeated ArtifactSummary artifacts = 1; + // Metadata of an artifact. message ArtifactSummary { string name = 1; @@ -593,9 +615,6 @@ message AddArtifactsResponse { // If false, the client may choose to resend the artifact specified by `name`. bool is_crc_successful = 2; } - - // The list of artifact(s) seen by the server. - repeated ArtifactSummary artifacts = 1; } // Request to get current statuses of artifacts at the server side. @@ -626,14 +645,20 @@ message ArtifactStatusesRequest { } // Response to checking artifact statuses. +// Next ID: 4 message ArtifactStatusesResponse { + // Session id in which the ArtifactStatus was running. + string session_id = 2; + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 3; + // A map of artifact names to their statuses. + map statuses = 1; + message ArtifactStatus { // Exists or not particular artifact at the server. bool exists = 1; } - - // A map of artifact names to their statuses. 
- map statuses = 1; } message InterruptRequest { @@ -678,12 +703,17 @@ message InterruptRequest { } } +// Next ID: 4 message InterruptResponse { // Session id in which the interrupt was running. string session_id = 1; + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 3; // Operation ids of the executions which were interrupted. repeated string interrupted_ids = 2; + } message ReattachOptions { @@ -774,9 +804,13 @@ message ReleaseExecuteRequest { } } +// Next ID: 4 message ReleaseExecuteResponse { // Session id in which the release was running. string session_id = 1; + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 3; // Operation id of the operation on which the release executed. // If the operation couldn't be found (because e.g. it was concurrently released), will be unset. @@ -803,9 +837,13 @@ message ReleaseSessionRequest { optional string client_type = 3; } +// Next ID: 3 message ReleaseSessionResponse { // Session id of the session on which the release executed. string session_id = 1; + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 2; } message FetchErrorDetailsRequest { @@ -828,8 +866,21 @@ message FetchErrorDetailsRequest { optional string client_type = 4; } +// Next ID: 5 message FetchErrorDetailsResponse { + // Server-side generated idempotency key that the client can use to assert that the server side + // session has not changed. + string server_side_session_id = 3; + + string session_id = 4; + + // The index of the root error in errors. The field will not be set if the error is not found. + optional int32 root_error_idx = 1; + + // A list of errors. + repeated Error errors = 2; + message StackTraceElement { // The fully qualified name of the class containing the execution point. string declaring_class = 1; @@ -914,12 +965,6 @@ message FetchErrorDetailsResponse { // The structured data of a SparkThrowable exception. optional SparkThrowable spark_throwable = 5; } - - // The index of the root error in errors. The field will not be set if the error is not found. - optional int32 root_error_idx = 1; - - // A list of errors. - repeated Error errors = 2; } // Main interface for the SparkConnect service. 
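The new `server_side_session_id` fields reduce to a simple client-side invariant; a simplified sketch of the check is shown below (names are illustrative; the reusable implementation is the `ResponseValidator` introduced further down in this patch):

```scala
// Simplified sketch of the invariant enabled by server_side_session_id: once a value has
// been observed, every later response handled by the same client stub must carry it again.
object ServerSessionCheck {
  private var observed: Option[String] = None

  def verify(serverSideSessionId: String): Unit = synchronized {
    observed match {
      case Some(id) if id != serverSideSessionId =>
        throw new IllegalStateException(
          s"Server side session ID changed from $id to $serverSideSessionId")
      case None => observed = Some(serverSideSessionId)
      case _ => // unchanged, nothing to do
    }
  }
}
```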
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala index 10c1414ba5b0e..00fba781813e9 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -123,7 +123,12 @@ class ArtifactManager( .setSessionId(sessionId) .addAllNames(Arrays.asList(artifactName)) .build() - val statuses = bstub.artifactStatus(request).getStatusesMap + val response = bstub.artifactStatus(request) + if (response.getSessionId != sessionId) { + throw new IllegalStateException( + s"Session ID mismatch: $sessionId != ${response.getSessionId}") + } + val statuses = response.getStatusesMap if (statuses.containsKey(artifactName)) { statuses.get(artifactName).getExists } else false @@ -179,6 +184,9 @@ class ArtifactManager( val responseHandler = new StreamObserver[proto.AddArtifactsResponse] { private val summaries = mutable.Buffer.empty[ArtifactSummary] override def onNext(v: AddArtifactsResponse): Unit = { + if (v.getSessionId != sessionId) { + throw new IllegalStateException(s"Session ID mismatch: $sessionId != ${v.getSessionId}") + } v.getArtifactsList.forEach { summary => summaries += summary } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index e963b4136160f..f8df2fa3f650c 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -24,14 +24,14 @@ import org.apache.spark.connect.proto._ private[connect] class CustomSparkConnectBlockingStub( channel: ManagedChannel, - retryPolicy: GrpcRetryHandler.RetryPolicy) { + stubState: SparkConnectStubState) { private val stub = SparkConnectServiceGrpc.newBlockingStub(channel) - private val retryHandler = new GrpcRetryHandler(retryPolicy) + private val retryHandler = stubState.retryHandler // GrpcExceptionConverter with a GRPC stub for fetching error details from server. - private val grpcExceptionConverter = new GrpcExceptionConverter(stub) + private val grpcExceptionConverter = stubState.exceptionConverter def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = { grpcExceptionConverter.convert( @@ -44,7 +44,10 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType, retryHandler.RetryIterator[ExecutePlanRequest, ExecutePlanResponse]( request, - r => CloseableIterator(stub.executePlan(r).asScala))) + r => { + stubState.responseValidator.wrapIterator( + CloseableIterator(stub.executePlan(r).asScala)) + })) } } @@ -59,7 +62,8 @@ private[connect] class CustomSparkConnectBlockingStub( request.getUserContext, request.getClientType, // Don't use retryHandler - own retry handling is inside. 
- new ExecutePlanResponseReattachableIterator(request, channel, retryPolicy)) + stubState.responseValidator.wrapIterator( + new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryPolicy))) } } @@ -69,7 +73,9 @@ private[connect] class CustomSparkConnectBlockingStub( request.getUserContext, request.getClientType) { retryHandler.retry { - stub.analyzePlan(request) + stubState.responseValidator.verifyResponse { + stub.analyzePlan(request) + } } } } @@ -80,7 +86,9 @@ private[connect] class CustomSparkConnectBlockingStub( request.getUserContext, request.getClientType) { retryHandler.retry { - stub.config(request) + stubState.responseValidator.verifyResponse { + stub.config(request) + } } } } @@ -91,7 +99,9 @@ private[connect] class CustomSparkConnectBlockingStub( request.getUserContext, request.getClientType) { retryHandler.retry { - stub.interrupt(request) + stubState.responseValidator.verifyResponse { + stub.interrupt(request) + } } } } @@ -102,7 +112,9 @@ private[connect] class CustomSparkConnectBlockingStub( request.getUserContext, request.getClientType) { retryHandler.retry { - stub.releaseSession(request) + stubState.responseValidator.verifyResponse { + stub.releaseSession(request) + } } } } @@ -113,7 +125,9 @@ private[connect] class CustomSparkConnectBlockingStub( request.getUserContext, request.getClientType) { retryHandler.retry { - stub.artifactStatus(request) + stubState.responseValidator.verifyResponse { + stub.artifactStatus(request) + } } } } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala index a388260f1cfc6..382bc87069558 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala @@ -23,13 +23,13 @@ import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse private[client] class CustomSparkConnectStub( channel: ManagedChannel, - retryPolicy: GrpcRetryHandler.RetryPolicy) { + stubState: SparkConnectStubState) { private val stub = SparkConnectServiceGrpc.newStub(channel) - private val retryHandler = new GrpcRetryHandler(retryPolicy) def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse]) : StreamObserver[AddArtifactsRequest] = { - retryHandler.RetryStreamObserver(responseObserver, stub.addArtifacts) + stubState.responseValidator.wrapStreamObserver( + stubState.retryHandler.RetryStreamObserver(responseObserver, stub.addArtifacts)) } } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 9d134f5935446..9fd8b12fef96f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -97,6 +97,10 @@ class ExecutePlanResponseReattachableIterator( private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] = Some(rawBlockingStub.executePlan(initialRequest)) + // Server side session ID, used to detect if the server side session changed. 
This is set upon + // receiving the first response from the server. + private var serverSideSessionId: Option[String] = None + override def innerIterator: Iterator[proto.ExecutePlanResponse] = iter match { case Some(it) => it.asScala case None => @@ -114,10 +118,23 @@ class ExecutePlanResponseReattachableIterator( try { // Get next response, possibly triggering reattach in case of stream error. - val ret = retry { + val ret: proto.ExecutePlanResponse = retry { callIter(_.next()) } + // Check if the server-side session state has changed. If this is the case, immediately + // abandon execution. + serverSideSessionId match { + case Some(id) => + if (id != ret.getServerSideSessionId) { + throw new IllegalStateException( + s"Server side session ID changed. Create a new SparkSession to continue. " + + s"(Old: $id, New: ${ret.getServerSideSessionId})") + } + case None => + serverSideSessionId = Some(ret.getServerSideSessionId) + } + // Record last returned response, to know where to restart in case of reattach. lastReturnedResponseId = Some(ret.getResponseId) if (ret.hasResultComplete) { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 88cd2118ba755..52bd276b0c4b5 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -22,14 +22,13 @@ import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import com.google.rpc.ErrorInfo -import io.grpc.StatusRuntimeException +import io.grpc.{ManagedChannel, StatusRuntimeException} import io.grpc.protobuf.StatusProto import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods import org.apache.spark.{QueryContext, QueryContextType, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} -import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext} -import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub +import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, SparkConnectServiceGrpc, UserContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException} @@ -49,10 +48,11 @@ import org.apache.spark.util.ArrayImplicits._ * the ErrorInfo is missing, the exception will be constructed based on the StatusRuntimeException * itself. 
*/ -private[client] class GrpcExceptionConverter(grpcStub: SparkConnectServiceBlockingStub) - extends Logging { +private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Logging { import GrpcExceptionConverter._ + val grpcStub = SparkConnectServiceGrpc.newBlockingStub(channel) + def convert[T](sessionId: String, userContext: UserContext, clientType: String)(f: => T): T = { try { f diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala new file mode 100644 index 0000000000000..2081196d46711 --- /dev/null +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client + +import com.google.protobuf.GeneratedMessageV3 +import io.grpc.stub.StreamObserver + +import org.apache.spark.internal.Logging + +// This is common logic to be shared between different stub instances to validate responses as +// seen by the client. +class ResponseValidator extends Logging { + + // Server side session ID, used to detect if the server side session changed. This is set upon + // receiving the first response from the server. This value is used only for executions that + // do not use server-side streaming. + private var serverSideSessionId: Option[String] = None + + def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = { + val response = fn + val field = response.getDescriptorForType.findFieldByName("server_side_session_id") + // If the field does not exist, we ignore it. New / Old message might not contain it and this + // behavior allows us to be compatible. + if (field != null) { + val value = response.getField(field).asInstanceOf[String] + // Ignore, if the value is unset. + if (response.hasField(field) && value != null && value.nonEmpty) { + serverSideSessionId match { + case Some(id) if value != id && value != "" => + throw new IllegalStateException(s"Server side session ID changed from $id to $value") + case _ if value != "" => + synchronized { + serverSideSessionId = Some(value.toString) + } + case _ => // No-op + } + } + } else { + logDebug("Server side session ID field not found in response - Ignoring.") + } + response + } + + /** + * Wraps an existing iterator with another closeable iterator that verifies the response. This + * is needed for server-side streaming calls that are converted to iterators. 
+ */ + def wrapIterator[T <: GeneratedMessageV3, V <: CloseableIterator[T]]( + inner: V): WrappedCloseableIterator[T] = { + new WrappedCloseableIterator[T] { + + override def innerIterator: Iterator[T] = inner + + override def hasNext: Boolean = { + innerIterator.hasNext + } + + override def next(): T = { + verifyResponse { + innerIterator.next() + } + } + + override def close(): Unit = { + innerIterator match { + case it: CloseableIterator[T] => it.close() + case _ => // nothing + } + } + } + } + + /** + * Wraps an existing stream observer with another stream observer that verifies the response. + * This is necessary for client-side streaming calls. + */ + def wrapStreamObserver[T <: GeneratedMessageV3](inner: StreamObserver[T]): StreamObserver[T] = { + new StreamObserver[T] { + private val innerObserver = inner + override def onNext(value: T): Unit = { + try { + innerObserver.onNext(verifyResponse(value)) + } catch { + case e: Exception => + onError(e) + } + } + override def onError(t: Throwable): Unit = { + innerObserver.onError(t) + } + override def onCompleted(): Unit = { + innerObserver.onCompleted() + } + } + } + +} diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 6d3d9420e2263..9fc74f1af2c2f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -43,8 +43,11 @@ private[sql] class SparkConnectClient( private val userContext: UserContext = configuration.userContext - private[this] val bstub = new CustomSparkConnectBlockingStub(channel, configuration.retryPolicy) - private[this] val stub = new CustomSparkConnectStub(channel, configuration.retryPolicy) + private[this] val stubState = new SparkConnectStubState(channel, configuration.retryPolicy) + private[this] val bstub = + new CustomSparkConnectBlockingStub(channel, stubState) + private[this] val stub = + new CustomSparkConnectStub(channel, stubState) private[client] def userAgent: String = configuration.userAgent diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala new file mode 100644 index 0000000000000..e6c7ebf9211ed --- /dev/null +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.client + +import io.grpc.ManagedChannel + +import org.apache.spark.internal.Logging + +// This is common state shared between the blocking and non-blocking stubs. +// +// The common logic is responsible to verify the integrity of the response. The invariant is +// that the same stub instance is used for all requests from the same client. In addition, +// this class provides access to the commonly configured retry policy and exception conversion +// logic. +class SparkConnectStubState( + channel: ManagedChannel, + val retryPolicy: GrpcRetryHandler.RetryPolicy) + extends Logging { + + // Responsible to convert the GRPC Status exceptions into Spark exceptions. + lazy val exceptionConverter: GrpcExceptionConverter = new GrpcExceptionConverter(channel) + + // Manages the retry handler logic used by the stubs. + lazy val retryHandler = new GrpcRetryHandler(retryPolicy) + + // Provides a helper for validating the responses processed by the stub. + lazy val responseValidator = new ResponseValidator() + +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index ad551c4b0f540..ba36b708e83a0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -64,7 +64,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging // The base directory/URI where all class file artifacts are stored for this `sessionUUID`. val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder) val state: JobArtifactState = - JobArtifactState(sessionHolder.session.sessionUUID, Option(classURI)) + JobArtifactState(sessionHolder.serverSessionId, Option(classURI)) private val jarsList = new CopyOnWriteArrayList[Path] private val pythonIncludeList = new CopyOnWriteArrayList[String] diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala index 3c416bb3ab64b..b5844486b73aa 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala @@ -295,6 +295,7 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder: executePlanResponse .toBuilder() .setSessionId(executeHolder.sessionHolder.sessionId) + .setServerSideSessionId(executeHolder.sessionHolder.serverSessionId) .setOperationId(executeHolder.operationId) .setResponseId(UUID.randomUUID.toString) .build() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 9dd54e5b2b5d3..002239aba96ee 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -65,8 +65,7 @@ private[execution] 
class SparkConnectPlanExecution(executeHolder: ExecuteHolder) tracker) responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema)) processAsArrowBatches(dataframe, responseObserver, executeHolder) - responseObserver.onNext( - MetricGenerator.createMetricsResponse(request.getSessionId, dataframe)) + responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe)) if (dataframe.queryExecution.observedMetrics.nonEmpty) { responseObserver.onNext(createObservedMetricsResponse(request.getSessionId, dataframe)) } @@ -111,7 +110,11 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) var numSent = 0 def sendBatch(bytes: Array[Byte], count: Long, startOffset: Long): Unit = { - val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId) + val response = proto.ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) + val batch = proto.ExecutePlanResponse.ArrowBatch .newBuilder() .setRowCount(count) @@ -235,6 +238,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) ExecutePlanResponse .newBuilder() .setSessionId(sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .setSchema(DataTypeProtoConverter.toConnectProtoType(schema)) .build() } @@ -257,6 +261,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) ExecutePlanResponse .newBuilder() .setSessionId(sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .addAllObservedMetrics(observedMetrics.asJava) .build() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 018e293795e9d..8b852babb544a 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2602,11 +2602,12 @@ class SparkConnectPlanner( ExecutePlanResponse .newBuilder() .setSessionId(sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .setSqlCommandResult(result) .build()) // Send Metrics - responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionId, df)) + responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, df)) } private def handleRegisterUserDefinedFunction( @@ -2994,6 +2995,7 @@ class SparkConnectPlanner( ExecutePlanResponse .newBuilder() .setSessionId(sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .setWriteStreamOperationStartResult(result) .build()) } @@ -3114,6 +3116,7 @@ class SparkConnectPlanner( ExecutePlanResponse .newBuilder() .setSessionId(sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .setStreamingQueryCommandResult(respBuilder.build()) .build()) } @@ -3251,6 +3254,7 @@ class SparkConnectPlanner( ExecutePlanResponse .newBuilder() .setSessionId(sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .setStreamingQueryManagerCommandResult(respBuilder.build()) .build()) } @@ -3262,6 +3266,7 @@ class SparkConnectPlanner( proto.ExecutePlanResponse .newBuilder() .setSessionId(sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .setGetResourcesCommandResult( proto.GetResourcesCommandResult .newBuilder() diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 792012a682b28..0c55e30ba5010 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -82,6 +82,13 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio def key: SessionKey = SessionKey(userId, sessionId) + // Returns the server side session ID and asserts that it must be different from the client-side + // session ID. + def serverSessionId: String = { + assert(session.sessionUUID != sessionId) + session.sessionUUID + } + /** * Add ExecuteHolder to this session. * diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala index d9de2a8094d5e..636054198fbf0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala @@ -119,6 +119,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr val artifactSummaries = flushStagedArtifacts() // Add the artifacts to the session and return the summaries to the client. val builder = proto.AddArtifactsResponse.newBuilder() + builder.setSessionId(holder.sessionId) + builder.setServerSideSessionId(holder.serverSessionId) artifactSummaries.foreach(summary => builder.addArtifacts(summary)) // Delete temp dir cleanUpStagedArtifacts() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 9d1cf9e36d094..f6fb42d9fcaa0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -201,7 +201,9 @@ private[connect] class SparkConnectAnalyzeHandler( case other => throw InvalidPlanInput(s"Unknown Analyze Method $other!") } - builder.setSessionId(request.getSessionId) + builder + .setSessionId(request.getSessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) builder.build() } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala index 5699dd11bde3f..325832ac07e67 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala @@ -37,7 +37,12 @@ class SparkConnectArtifactStatusesHandler( } def handle(request: proto.ArtifactStatusesRequest): Unit = { + val holder = SparkConnectService + .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId) + val builder = proto.ArtifactStatusesResponse.newBuilder() 
+ builder.setSessionId(holder.sessionId) + builder.setServerSideSessionId(holder.serverSessionId) request.getNamesList().iterator().asScala.foreach { name => val status = proto.ArtifactStatusesResponse.ArtifactStatus.newBuilder() val exists = if (name.startsWith("cache/")) { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala index 9e514f4f65d8c..8a7ce6d7b48f8 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala @@ -30,10 +30,10 @@ class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigRes extends Logging { def handle(request: proto.ConfigRequest): Unit = { - val session = + val holder = SparkConnectService .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId) - .session + val session = holder.session val builder = request.getOperation.getOpTypeCase match { case proto.ConfigRequest.Operation.OpTypeCase.SET => @@ -54,6 +54,7 @@ class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigRes } builder.setSessionId(request.getSessionId) + builder.setServerSideSessionId(holder.serverSessionId) responseObserver.onNext(builder.build()) responseObserver.onCompleted() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala index 97b57c4940b62..9e1ab16208f2e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala @@ -54,6 +54,7 @@ class SparkConnectInterruptHandler(responseObserver: StreamObserver[proto.Interr val response = proto.InterruptResponse .newBuilder() .setSessionId(v.getSessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .addAllInterruptedIds(interruptedIds.asJava) .build() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala index 1ca886960d536..88c1456602d3b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala @@ -34,6 +34,7 @@ class SparkConnectReleaseExecuteHandler( val responseBuilder = proto.ReleaseExecuteResponse .newBuilder() .setSessionId(v.getSessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) // ExecuteHolder may be concurrently released by SparkConnectExecutionManager if the // ReleaseExecute arrived after it was abandoned and timed out. 
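The server-side changes above all serve one invariant: every response message is now stamped with the server's SparkSession UUID (`SessionHolder.serverSessionId`), while the client (`ResponseValidator` on the Scala side, `_verify_response_integrity` on the Python side) caches the first value it sees and aborts if a later response reports a different one, for example after a server restart silently recreated the session. A minimal, self-contained sketch of that client-side check follows; it is illustrative only, the names `ServerSessionIdTracker` and `check` are hypothetical and not part of the Spark Connect API, and the real logic lives in the `ResponseValidator` added in this patch.

    // Illustrative sketch (not part of the patch): the invariant enforced by the
    // client-side response validation. The first non-empty server-side session ID
    // is cached; any later response carrying a different ID aborts the client.
    class ServerSessionIdTracker {
      private var seen: Option[String] = None

      // Call on every response, e.g. tracker.check(response.getServerSideSessionId)
      def check(serverSideSessionId: String): Unit = synchronized {
        // Empty values are ignored so that servers that do not set the field yet
        // remain compatible with newer clients.
        if (serverSideSessionId.nonEmpty) {
          seen match {
            case Some(id) if id != serverSideSessionId =>
              throw new IllegalStateException(
                s"Server side session ID changed from $id to $serverSideSessionId. " +
                  "Create a new SparkSession to continue.")
            case None => seen = Some(serverSideSessionId)
            case _ => // Same ID as before, nothing to do.
          }
        }
      }
    }

Keeping this state in the shared stub state (`SparkConnectStubState`) rather than per request means every RPC issued through one client instance is validated against the same cached value.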
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala index a32852bac45ea..c8a3ceab674ff 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala @@ -32,6 +32,9 @@ class SparkConnectReleaseSessionHandler( // If the session doesn't exist, this will just be a noop. val key = SessionKey(v.getUserContext.getUserId, v.getSessionId) + // if the session is present, update the server-side session ID. + val maybeSession = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key) + maybeSession.foreach(f => responseBuilder.setServerSideSessionId(f.serverSessionId)) SparkConnectService.sessionManager.closeSession(key) responseObserver.onNext(responseBuilder.build()) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala index 0ddaf4d0c1312..c9bba653e8a8f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala @@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -29,10 +30,13 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper */ private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper { - def createMetricsResponse(sessionId: String, rows: DataFrame): ExecutePlanResponse = { + def createMetricsResponse( + sessionHolder: SessionHolder, + rows: DataFrame): ExecutePlanResponse = { ExecutePlanResponse .newBuilder() - .setSessionId(sessionId) + .setSessionId(sessionHolder.sessionId) + .setServerSideSessionId(sessionHolder.serverSessionId) .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan)) .build() } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index 120126f20ec24..c4a5539ce0b7b 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -27,7 +27,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, WrappedCloseableIterator} +import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, SparkConnectStubState, WrappedCloseableIterator} import 
org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.connect.common.config.ConnectCommon import org.apache.spark.sql.connect.config.Connect @@ -150,6 +150,8 @@ trait SparkConnectServerTest extends SharedSparkSession { // This depends on the wrapping in CustomSparkConnectBlockingStub.executePlanReattachable: // GrpcExceptionConverter.convertIterator stubIterator + .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]] + .innerIterator .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]] // ExecutePlanResponseReattachableIterator .innerIterator @@ -254,7 +256,8 @@ trait SparkConnectServerTest extends SharedSparkSession { f: CustomSparkConnectBlockingStub => Unit): Unit = { val conf = SparkConnectClient.Configuration(port = serverPort) val channel = conf.createChannel() - val bstub = new CustomSparkConnectBlockingStub(channel, retryPolicy) + val stubState = new SparkConnectStubState(channel, retryPolicy) + val bstub = new CustomSparkConnectBlockingStub(channel, stubState) try f(bstub) finally { channel.shutdownNow() diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala index f34c4f770885e..0c095384de86b 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala @@ -41,10 +41,10 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { } private val artifactPath = commonResourcePath.resolve("artifact-tests") - private def sessionHolder(): SessionHolder = { - SessionHolder("test", spark.sessionUUID, spark) + private lazy val sessionHolder: SessionHolder = { + SessionHolder("test", UUID.randomUUID().toString, spark) } - private lazy val artifactManager = new SparkConnectArtifactManager(sessionHolder()) + private lazy val artifactManager = new SparkConnectArtifactManager(sessionHolder) private def sessionUUID: String = spark.sessionUUID @@ -125,7 +125,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = path.toPath Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8)) val remotePath = Paths.get("cache/abc") - val session = sessionHolder() + val session = sessionHolder val blockManager = spark.sparkContext.env.blockManager val blockId = CacheId(session.userId, session.sessionId, "abc") try { @@ -193,7 +193,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = path.toPath Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8)) val remotePath = Paths.get("cache/abc") - val session = sessionHolder() + val session = sessionHolder val blockManager = spark.sparkContext.env.blockManager val blockId = CacheId(session.userId, session.sessionId, "abc") // Setup artifact dir @@ -294,15 +294,15 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("Hello.class") val remotePath = Paths.get("classes/Hello.class") - val sessionHolder = + val holder = SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString) - sessionHolder.addArtifact(remotePath, stagingPath, None) + holder.addArtifact(remotePath, stagingPath, None) val sessionDirectory = - 
SparkConnectArtifactManager.getArtifactDirectoryAndUriForSession(sessionHolder)._1.toFile + SparkConnectArtifactManager.getArtifactDirectoryAndUriForSession(holder)._1.toFile assert(sessionDirectory.exists()) - sessionHolder.artifactManager.cleanUpResources() + holder.artifactManager.cleanUpResources() assert(!sessionDirectory.exists()) assert(SparkConnectArtifactManager.artifactRootPath.toFile.exists()) } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala index 54f7396ec097b..8fabcf61cb6f3 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.connect.service +import java.util.UUID + import scala.concurrent.Promise import scala.concurrent.duration._ import scala.jdk.CollectionConverters._ @@ -37,6 +39,9 @@ private class DummyStreamObserver(p: Promise[ArtifactStatusesResponse]) } class ArtifactStatusesHandlerSuite extends SharedSparkSession with ResourceHelper { + + val sessionId = UUID.randomUUID().toString + def getStatuses(names: Seq[String], exist: Set[String]): ArtifactStatusesResponse = { val promise = Promise[ArtifactStatusesResponse]() val handler = new SparkConnectArtifactStatusesHandler(new DummyStreamObserver(promise)) { @@ -54,7 +59,7 @@ class ArtifactStatusesHandlerSuite extends SharedSparkSession with ResourceHelpe val request = proto.ArtifactStatusesRequest .newBuilder() .setUserContext(context) - .setSessionId("abc") + .setSessionId(sessionId) .addAllNames(names.asJava) .build() handler.handle(request) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala index cc0481dab0f4f..454dd0c74b3d5 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala @@ -46,6 +46,8 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest { // Close session client.releaseSession() + // Calling release session again should be a no-op. + client.releaseSession() // Check that queries get cancelled Eventually.eventually(timeout(eventuallyTimeout)) { diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index cef0ea4f305df..965c4107cacee 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -95,8 +95,7 @@ from pyspark.sql.types import DataType, StructType, TimestampType, _has_type from pyspark.rdd import PythonEvalType from pyspark.storagelevel import StorageLevel -from pyspark.errors import PySparkValueError - +from pyspark.errors import PySparkValueError, PySparkAssertionError if TYPE_CHECKING: from google.rpc.error_details_pb2 import ErrorInfo @@ -676,6 +675,10 @@ def __init__( self._use_reattachable_execute = use_reattachable_execute # Configure logging for the SparkConnect client. + # Capture the server-side session ID and set it to None initially. It will + # be updated on the first response received. 
+ self._server_session_id: Optional[str] = None + def _retrying(self) -> "Retrying": return Retrying( can_retry=SparkConnectClient.retry_exception, **self._retry_policy # type: ignore @@ -1121,11 +1124,7 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: for attempt in self._retrying(): with attempt: resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) - if resp.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request:" - f"{resp.session_id} != {self._session_id}" - ) + self._verify_response_integrity(resp) return AnalyzeResult.fromProto(resp) raise SparkConnectException("Invalid state during retry exception handling.") except Exception as error: @@ -1144,11 +1143,7 @@ def _execute(self, req: pb2.ExecutePlanRequest) -> None: logger.info("Execute") def handle_response(b: pb2.ExecutePlanResponse) -> None: - if b.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request: " - f"{b.session_id} != {self._session_id}" - ) + self._verify_response_integrity(b) try: if self._use_reattachable_execute: @@ -1193,11 +1188,8 @@ def handle_response( ] ]: nonlocal num_records - if b.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request: " - f"{b.session_id} != {self._session_id}" - ) + # The session ID is the local session ID and should match what we expect. + self._verify_response_integrity(b) if b.HasField("metrics"): logger.debug("Received metric batch.") yield from self._build_metrics(b.metrics) @@ -1387,11 +1379,7 @@ def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult: for attempt in self._retrying(): with attempt: resp = self._stub.Config(req, metadata=self._builder.metadata()) - if resp.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request:" - f"{resp.session_id} != {self._session_id}" - ) + self._verify_response_integrity(resp) return ConfigResult.fromProto(resp) raise SparkConnectException("Invalid state during retry exception handling.") except Exception as error: @@ -1430,11 +1418,7 @@ def interrupt_all(self) -> Optional[List[str]]: for attempt in self._retrying(): with attempt: resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) - if resp.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request:" - f"{resp.session_id} != {self._session_id}" - ) + self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") except Exception as error: @@ -1446,11 +1430,7 @@ def interrupt_tag(self, tag: str) -> Optional[List[str]]: for attempt in self._retrying(): with attempt: resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) - if resp.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request:" - f"{resp.session_id} != {self._session_id}" - ) + self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") except Exception as error: @@ -1462,11 +1442,7 @@ def interrupt_operation(self, op_id: str) -> Optional[List[str]]: for attempt in self._retrying(): with attempt: resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) - if resp.session_id != self._session_id: - raise 
SparkConnectException( - "Received incorrect session identifier for request:" - f"{resp.session_id} != {self._session_id}" - ) + self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") except Exception as error: @@ -1482,11 +1458,7 @@ def release_session(self) -> None: for attempt in self._retrying(): with attempt: resp = self._stub.ReleaseSession(req, metadata=self._builder.metadata()) - if resp.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request:" - f"{resp.session_id} != {self._session_id}" - ) + self._verify_response_integrity(resp) return raise SparkConnectException("Invalid state during retry exception handling.") except Exception as error: @@ -1631,6 +1603,42 @@ def cache_artifact(self, blob: bytes) -> str: return self._artifact_manager.cache_artifact(blob) raise SparkConnectException("Invalid state during retry exception handling.") + def _verify_response_integrity( + self, + response: Union[ + pb2.ConfigResponse, + pb2.ExecutePlanResponse, + pb2.InterruptResponse, + pb2.ReleaseExecuteResponse, + pb2.AddArtifactsResponse, + pb2.AnalyzePlanResponse, + pb2.FetchErrorDetailsResponse, + pb2.ReleaseSessionResponse, + ], + ) -> None: + """ + Verifies the integrity of the response. This method checks if the session ID and the + server-side session ID match. If not, it throws an exception. + Parameters + ---------- + response - One of the different response types handled by the Spark Connect service + """ + if self._session_id != response.session_id: + raise PySparkAssertionError( + "Received incorrect session identifier for request:" + f"{response.session_id} != {self._session_id}" + ) + if self._server_session_id is not None: + if response.server_side_session_id != self._server_session_id: + raise PySparkAssertionError( + "Received incorrect server side session identifier for request. " + "Please create a new Spark Session to reconnect. (" + f"{response.server_side_session_id} != {self._server_session_id})" + ) + else: + # Update the server side session ID. 
+ self._server_session_id = response.server_side_session_id + class RetryState: """ diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 0e374e7aa2ccb..e23f3bdaaaa4f 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 
\x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 
\x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 
\x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 
\x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xab\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"7\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xbe\x0b\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 
\x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xb2\x07\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e 
\x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\xce\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b 
\x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\x9b\x10\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 
\x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keysB\x0e\n\x0c_client_type"\xaf\x01\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x04 \x01(\tR\x13serverSideSessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\x90\x02\n\x14\x41\x64\x64\x41rtifactsResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\xe0\x02\n\x18\x41rtifactStatusesResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 
\x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"\x90\x01\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"\xa5\x01\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xab\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"l\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x02 \x01(\tR\x13serverSideSessionId"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\x92\x0c\n\x19\x46\x65tchErrorDetailsResponse\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\x1d\n\nsession_id\x18\x04 \x01(\tR\tsessionId\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n 
\x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xb2\x07\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -88,137 +88,137 @@ _ANALYZEPLANREQUEST_GETSTORAGELEVEL._serialized_start = 2786 _ANALYZEPLANREQUEST_GETSTORAGELEVEL._serialized_end = 2856 _ANALYZEPLANRESPONSE._serialized_start = 2886 - _ANALYZEPLANRESPONSE._serialized_end = 4575 - _ANALYZEPLANRESPONSE_SCHEMA._serialized_start = 3994 - _ANALYZEPLANRESPONSE_SCHEMA._serialized_end = 4051 - _ANALYZEPLANRESPONSE_EXPLAIN._serialized_start = 4053 - _ANALYZEPLANRESPONSE_EXPLAIN._serialized_end = 4101 - 
_ANALYZEPLANRESPONSE_TREESTRING._serialized_start = 4103 - _ANALYZEPLANRESPONSE_TREESTRING._serialized_end = 4148 - _ANALYZEPLANRESPONSE_ISLOCAL._serialized_start = 4150 - _ANALYZEPLANRESPONSE_ISLOCAL._serialized_end = 4186 - _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_start = 4188 - _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_end = 4236 - _ANALYZEPLANRESPONSE_INPUTFILES._serialized_start = 4238 - _ANALYZEPLANRESPONSE_INPUTFILES._serialized_end = 4272 - _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_start = 4274 - _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_end = 4314 - _ANALYZEPLANRESPONSE_DDLPARSE._serialized_start = 4316 - _ANALYZEPLANRESPONSE_DDLPARSE._serialized_end = 4375 - _ANALYZEPLANRESPONSE_SAMESEMANTICS._serialized_start = 4377 - _ANALYZEPLANRESPONSE_SAMESEMANTICS._serialized_end = 4416 - _ANALYZEPLANRESPONSE_SEMANTICHASH._serialized_start = 4418 - _ANALYZEPLANRESPONSE_SEMANTICHASH._serialized_end = 4456 + _ANALYZEPLANRESPONSE._serialized_end = 4628 + _ANALYZEPLANRESPONSE_SCHEMA._serialized_start = 4047 + _ANALYZEPLANRESPONSE_SCHEMA._serialized_end = 4104 + _ANALYZEPLANRESPONSE_EXPLAIN._serialized_start = 4106 + _ANALYZEPLANRESPONSE_EXPLAIN._serialized_end = 4154 + _ANALYZEPLANRESPONSE_TREESTRING._serialized_start = 4156 + _ANALYZEPLANRESPONSE_TREESTRING._serialized_end = 4201 + _ANALYZEPLANRESPONSE_ISLOCAL._serialized_start = 4203 + _ANALYZEPLANRESPONSE_ISLOCAL._serialized_end = 4239 + _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_start = 4241 + _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_end = 4289 + _ANALYZEPLANRESPONSE_INPUTFILES._serialized_start = 4291 + _ANALYZEPLANRESPONSE_INPUTFILES._serialized_end = 4325 + _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_start = 4327 + _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_end = 4367 + _ANALYZEPLANRESPONSE_DDLPARSE._serialized_start = 4369 + _ANALYZEPLANRESPONSE_DDLPARSE._serialized_end = 4428 + _ANALYZEPLANRESPONSE_SAMESEMANTICS._serialized_start = 4430 + _ANALYZEPLANRESPONSE_SAMESEMANTICS._serialized_end = 4469 + _ANALYZEPLANRESPONSE_SEMANTICHASH._serialized_start = 4471 + _ANALYZEPLANRESPONSE_SEMANTICHASH._serialized_end = 4509 _ANALYZEPLANRESPONSE_PERSIST._serialized_start = 2521 _ANALYZEPLANRESPONSE_PERSIST._serialized_end = 2530 _ANALYZEPLANRESPONSE_UNPERSIST._serialized_start = 2674 _ANALYZEPLANRESPONSE_UNPERSIST._serialized_end = 2685 - _ANALYZEPLANRESPONSE_GETSTORAGELEVEL._serialized_start = 4482 - _ANALYZEPLANRESPONSE_GETSTORAGELEVEL._serialized_end = 4565 - _EXECUTEPLANREQUEST._serialized_start = 4578 - _EXECUTEPLANREQUEST._serialized_end = 5122 - _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 4924 - _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5089 - _EXECUTEPLANRESPONSE._serialized_start = 5125 - _EXECUTEPLANRESPONSE._serialized_end = 7147 - _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 6283 - _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6354 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 6356 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 6474 - _EXECUTEPLANRESPONSE_METRICS._serialized_start = 6477 - _EXECUTEPLANRESPONSE_METRICS._serialized_end = 6994 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 6572 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 6904 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 6781 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 6904 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 6906 - 
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 6994 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 6996 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7112 - _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7114 - _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7130 - _KEYVALUE._serialized_start = 7149 - _KEYVALUE._serialized_end = 7214 - _CONFIGREQUEST._serialized_start = 7217 - _CONFIGREQUEST._serialized_end = 8245 - _CONFIGREQUEST_OPERATION._serialized_start = 7437 - _CONFIGREQUEST_OPERATION._serialized_end = 7935 - _CONFIGREQUEST_SET._serialized_start = 7937 - _CONFIGREQUEST_SET._serialized_end = 7989 - _CONFIGREQUEST_GET._serialized_start = 7991 - _CONFIGREQUEST_GET._serialized_end = 8016 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 8018 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 8081 - _CONFIGREQUEST_GETOPTION._serialized_start = 8083 - _CONFIGREQUEST_GETOPTION._serialized_end = 8114 - _CONFIGREQUEST_GETALL._serialized_start = 8116 - _CONFIGREQUEST_GETALL._serialized_end = 8164 - _CONFIGREQUEST_UNSET._serialized_start = 8166 - _CONFIGREQUEST_UNSET._serialized_end = 8193 - _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 8195 - _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 8229 - _CONFIGRESPONSE._serialized_start = 8247 - _CONFIGRESPONSE._serialized_end = 8369 - _ADDARTIFACTSREQUEST._serialized_start = 8372 - _ADDARTIFACTSREQUEST._serialized_end = 9243 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8759 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8812 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8814 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8925 - _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8927 - _ADDARTIFACTSREQUEST_BATCH._serialized_end = 9020 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 9023 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 9216 - _ADDARTIFACTSRESPONSE._serialized_start = 9246 - _ADDARTIFACTSRESPONSE._serialized_end = 9434 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 9353 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9434 - _ARTIFACTSTATUSESREQUEST._serialized_start = 9437 - _ARTIFACTSTATUSESREQUEST._serialized_end = 9632 - _ARTIFACTSTATUSESRESPONSE._serialized_start = 9635 - _ARTIFACTSTATUSESRESPONSE._serialized_end = 9903 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9746 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9786 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9788 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9903 - _INTERRUPTREQUEST._serialized_start = 9906 - _INTERRUPTREQUEST._serialized_end = 10378 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 10221 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 10349 - _INTERRUPTRESPONSE._serialized_start = 10380 - _INTERRUPTRESPONSE._serialized_end = 10471 - _REATTACHOPTIONS._serialized_start = 10473 - _REATTACHOPTIONS._serialized_end = 10526 - _REATTACHEXECUTEREQUEST._serialized_start = 10529 - _REATTACHEXECUTEREQUEST._serialized_end = 10804 - _RELEASEEXECUTEREQUEST._serialized_start = 10807 - _RELEASEEXECUTEREQUEST._serialized_end = 11261 - _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 11173 - _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 11185 - _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 11187 - _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11234 - _RELEASEEXECUTERESPONSE._serialized_start = 11263 - 
_RELEASEEXECUTERESPONSE._serialized_end = 11375 - _RELEASESESSIONREQUEST._serialized_start = 11378 - _RELEASESESSIONREQUEST._serialized_end = 11549 - _RELEASESESSIONRESPONSE._serialized_start = 11551 - _RELEASESESSIONRESPONSE._serialized_end = 11606 - _FETCHERRORDETAILSREQUEST._serialized_start = 11609 - _FETCHERRORDETAILSREQUEST._serialized_end = 11810 - _FETCHERRORDETAILSRESPONSE._serialized_start = 11813 - _FETCHERRORDETAILSRESPONSE._serialized_end = 13283 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11958 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12132 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12135 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12502 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12465 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12502 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12505 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12914 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12816 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12884 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12917 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13264 - _SPARKCONNECTSERVICE._serialized_start = 13286 - _SPARKCONNECTSERVICE._serialized_end = 14232 + _ANALYZEPLANRESPONSE_GETSTORAGELEVEL._serialized_start = 4535 + _ANALYZEPLANRESPONSE_GETSTORAGELEVEL._serialized_end = 4618 + _EXECUTEPLANREQUEST._serialized_start = 4631 + _EXECUTEPLANREQUEST._serialized_end = 5175 + _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 4977 + _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5142 + _EXECUTEPLANRESPONSE._serialized_start = 5178 + _EXECUTEPLANRESPONSE._serialized_end = 7253 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 6389 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6460 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 6462 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 6580 + _EXECUTEPLANRESPONSE_METRICS._serialized_start = 6583 + _EXECUTEPLANRESPONSE_METRICS._serialized_end = 7100 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 6678 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 7010 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 6887 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 7010 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 7012 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 7100 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7102 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7218 + _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7220 + _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7236 + _KEYVALUE._serialized_start = 7255 + _KEYVALUE._serialized_end = 7320 + _CONFIGREQUEST._serialized_start = 7323 + _CONFIGREQUEST._serialized_end = 8351 + _CONFIGREQUEST_OPERATION._serialized_start = 7543 + _CONFIGREQUEST_OPERATION._serialized_end = 8041 + _CONFIGREQUEST_SET._serialized_start = 8043 + _CONFIGREQUEST_SET._serialized_end = 8095 + _CONFIGREQUEST_GET._serialized_start = 8097 + _CONFIGREQUEST_GET._serialized_end = 8122 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 8124 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 8187 + _CONFIGREQUEST_GETOPTION._serialized_start = 8189 + 
_CONFIGREQUEST_GETOPTION._serialized_end = 8220 + _CONFIGREQUEST_GETALL._serialized_start = 8222 + _CONFIGREQUEST_GETALL._serialized_end = 8270 + _CONFIGREQUEST_UNSET._serialized_start = 8272 + _CONFIGREQUEST_UNSET._serialized_end = 8299 + _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 8301 + _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 8335 + _CONFIGRESPONSE._serialized_start = 8354 + _CONFIGRESPONSE._serialized_end = 8529 + _ADDARTIFACTSREQUEST._serialized_start = 8532 + _ADDARTIFACTSREQUEST._serialized_end = 9403 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8919 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8972 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8974 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 9085 + _ADDARTIFACTSREQUEST_BATCH._serialized_start = 9087 + _ADDARTIFACTSREQUEST_BATCH._serialized_end = 9180 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 9183 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 9376 + _ADDARTIFACTSRESPONSE._serialized_start = 9406 + _ADDARTIFACTSRESPONSE._serialized_end = 9678 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 9597 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9678 + _ARTIFACTSTATUSESREQUEST._serialized_start = 9681 + _ARTIFACTSTATUSESREQUEST._serialized_end = 9876 + _ARTIFACTSTATUSESRESPONSE._serialized_start = 9879 + _ARTIFACTSTATUSESRESPONSE._serialized_end = 10231 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 10074 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 10189 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 10191 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 10231 + _INTERRUPTREQUEST._serialized_start = 10234 + _INTERRUPTREQUEST._serialized_end = 10706 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 10549 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 10677 + _INTERRUPTRESPONSE._serialized_start = 10709 + _INTERRUPTRESPONSE._serialized_end = 10853 + _REATTACHOPTIONS._serialized_start = 10855 + _REATTACHOPTIONS._serialized_end = 10908 + _REATTACHEXECUTEREQUEST._serialized_start = 10911 + _REATTACHEXECUTEREQUEST._serialized_end = 11186 + _RELEASEEXECUTEREQUEST._serialized_start = 11189 + _RELEASEEXECUTEREQUEST._serialized_end = 11643 + _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 11555 + _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 11567 + _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 11569 + _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11616 + _RELEASEEXECUTERESPONSE._serialized_start = 11646 + _RELEASEEXECUTERESPONSE._serialized_end = 11811 + _RELEASESESSIONREQUEST._serialized_start = 11814 + _RELEASESESSIONREQUEST._serialized_end = 11985 + _RELEASESESSIONRESPONSE._serialized_start = 11987 + _RELEASESESSIONRESPONSE._serialized_end = 12095 + _FETCHERRORDETAILSREQUEST._serialized_start = 12098 + _FETCHERRORDETAILSREQUEST._serialized_end = 12299 + _FETCHERRORDETAILSRESPONSE._serialized_start = 12302 + _FETCHERRORDETAILSRESPONSE._serialized_end = 13856 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 12531 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12705 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12708 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 13075 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 13038 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 13075 + 
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 13078 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 13487 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 13389 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 13457 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 13490 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13837 + _SPARKCONNECTSERVICE._serialized_start = 13859 + _SPARKCONNECTSERVICE._serialized_end = 14805 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 20abbcb348bdd..cdf7e2b0bce0b 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -666,6 +666,7 @@ global___AnalyzePlanRequest = AnalyzePlanRequest class AnalyzePlanResponse(google.protobuf.message.Message): """Response to performing analysis of the query. Contains relevant metadata to be able to reason about the performance. + Next ID: 16 """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -854,6 +855,7 @@ class AnalyzePlanResponse(google.protobuf.message.Message): ) -> None: ... SESSION_ID_FIELD_NUMBER: builtins.int + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int SCHEMA_FIELD_NUMBER: builtins.int EXPLAIN_FIELD_NUMBER: builtins.int TREE_STRING_FIELD_NUMBER: builtins.int @@ -868,6 +870,10 @@ class AnalyzePlanResponse(google.protobuf.message.Message): UNPERSIST_FIELD_NUMBER: builtins.int GET_STORAGE_LEVEL_FIELD_NUMBER: builtins.int session_id: builtins.str + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. + """ @property def schema(self) -> global___AnalyzePlanResponse.Schema: ... @property @@ -898,6 +904,7 @@ class AnalyzePlanResponse(google.protobuf.message.Message): self, *, session_id: builtins.str = ..., + server_side_session_id: builtins.str = ..., schema: global___AnalyzePlanResponse.Schema | None = ..., explain: global___AnalyzePlanResponse.Explain | None = ..., tree_string: global___AnalyzePlanResponse.TreeString | None = ..., @@ -970,6 +977,8 @@ class AnalyzePlanResponse(google.protobuf.message.Message): b"schema", "semantic_hash", b"semantic_hash", + "server_side_session_id", + b"server_side_session_id", "session_id", b"session_id", "spark_version", @@ -1169,6 +1178,7 @@ global___ExecutePlanRequest = ExecutePlanRequest class ExecutePlanResponse(google.protobuf.message.Message): """The response of a query, can be one or more for each request. Responses belonging to the same input query, carry the same `session_id`. + Next ID: 16 """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -1391,6 +1401,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): ) -> None: ... SESSION_ID_FIELD_NUMBER: builtins.int + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int OPERATION_ID_FIELD_NUMBER: builtins.int RESPONSE_ID_FIELD_NUMBER: builtins.int ARROW_BATCH_FIELD_NUMBER: builtins.int @@ -1405,6 +1416,10 @@ class ExecutePlanResponse(google.protobuf.message.Message): OBSERVED_METRICS_FIELD_NUMBER: builtins.int SCHEMA_FIELD_NUMBER: builtins.int session_id: builtins.str + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. + """ operation_id: builtins.str """Identifies the ExecutePlan execution. 
If set by the client in ExecutePlanRequest.operationId, that value is returned. @@ -1465,6 +1480,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): self, *, session_id: builtins.str = ..., + server_side_session_id: builtins.str = ..., operation_id: builtins.str = ..., response_id: builtins.str = ..., arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ..., @@ -1534,6 +1550,8 @@ class ExecutePlanResponse(google.protobuf.message.Message): b"result_complete", "schema", b"schema", + "server_side_session_id", + b"server_side_session_id", "session_id", b"session_id", "sql_command_result", @@ -1870,14 +1888,21 @@ class ConfigRequest(google.protobuf.message.Message): global___ConfigRequest = ConfigRequest class ConfigResponse(google.protobuf.message.Message): - """Response to the config request.""" + """Response to the config request. + Next ID: 5 + """ DESCRIPTOR: google.protobuf.descriptor.Descriptor SESSION_ID_FIELD_NUMBER: builtins.int + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int PAIRS_FIELD_NUMBER: builtins.int WARNINGS_FIELD_NUMBER: builtins.int session_id: builtins.str + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. + """ @property def pairs( self, @@ -1899,13 +1924,21 @@ class ConfigResponse(google.protobuf.message.Message): self, *, session_id: builtins.str = ..., + server_side_session_id: builtins.str = ..., pairs: collections.abc.Iterable[global___KeyValue] | None = ..., warnings: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ - "pairs", b"pairs", "session_id", b"session_id", "warnings", b"warnings" + "pairs", + b"pairs", + "server_side_session_id", + b"server_side_session_id", + "session_id", + b"session_id", + "warnings", + b"warnings", ], ) -> None: ... @@ -2141,6 +2174,7 @@ global___AddArtifactsRequest = AddArtifactsRequest class AddArtifactsResponse(google.protobuf.message.Message): """Response to adding an artifact. Contains relevant metadata to verify successful transfer of artifact(s). + Next ID: 4 """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -2171,7 +2205,15 @@ class AddArtifactsResponse(google.protobuf.message.Message): ], ) -> None: ... + SESSION_ID_FIELD_NUMBER: builtins.int + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int ARTIFACTS_FIELD_NUMBER: builtins.int + session_id: builtins.str + """Session id in which the AddArtifact was running.""" + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. + """ @property def artifacts( self, @@ -2182,11 +2224,21 @@ class AddArtifactsResponse(google.protobuf.message.Message): def __init__( self, *, + session_id: builtins.str = ..., + server_side_session_id: builtins.str = ..., artifacts: collections.abc.Iterable[global___AddArtifactsResponse.ArtifactSummary] | None = ..., ) -> None: ... def ClearField( - self, field_name: typing_extensions.Literal["artifacts", b"artifacts"] + self, + field_name: typing_extensions.Literal[ + "artifacts", + b"artifacts", + "server_side_session_id", + b"server_side_session_id", + "session_id", + b"session_id", + ], ) -> None: ... 
global___AddArtifactsResponse = AddArtifactsResponse @@ -2268,25 +2320,12 @@ class ArtifactStatusesRequest(google.protobuf.message.Message): global___ArtifactStatusesRequest = ArtifactStatusesRequest class ArtifactStatusesResponse(google.protobuf.message.Message): - """Response to checking artifact statuses.""" + """Response to checking artifact statuses. + Next ID: 4 + """ DESCRIPTOR: google.protobuf.descriptor.Descriptor - class ArtifactStatus(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - EXISTS_FIELD_NUMBER: builtins.int - exists: builtins.bool - """Exists or not particular artifact at the server.""" - def __init__( - self, - *, - exists: builtins.bool = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["exists", b"exists"] - ) -> None: ... - class StatusesEntry(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -2308,7 +2347,30 @@ class ArtifactStatusesResponse(google.protobuf.message.Message): self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] ) -> None: ... + class ArtifactStatus(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + EXISTS_FIELD_NUMBER: builtins.int + exists: builtins.bool + """Exists or not particular artifact at the server.""" + def __init__( + self, + *, + exists: builtins.bool = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["exists", b"exists"] + ) -> None: ... + + SESSION_ID_FIELD_NUMBER: builtins.int + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int STATUSES_FIELD_NUMBER: builtins.int + session_id: builtins.str + """Session id in which the ArtifactStatus was running.""" + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. + """ @property def statuses( self, @@ -2319,13 +2381,23 @@ class ArtifactStatusesResponse(google.protobuf.message.Message): def __init__( self, *, + session_id: builtins.str = ..., + server_side_session_id: builtins.str = ..., statuses: collections.abc.Mapping[ builtins.str, global___ArtifactStatusesResponse.ArtifactStatus ] | None = ..., ) -> None: ... def ClearField( - self, field_name: typing_extensions.Literal["statuses", b"statuses"] + self, + field_name: typing_extensions.Literal[ + "server_side_session_id", + b"server_side_session_id", + "session_id", + b"session_id", + "statuses", + b"statuses", + ], ) -> None: ... global___ArtifactStatusesResponse = ArtifactStatusesResponse @@ -2449,12 +2521,19 @@ class InterruptRequest(google.protobuf.message.Message): global___InterruptRequest = InterruptRequest class InterruptResponse(google.protobuf.message.Message): + """Next ID: 4""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor SESSION_ID_FIELD_NUMBER: builtins.int + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int INTERRUPTED_IDS_FIELD_NUMBER: builtins.int session_id: builtins.str """Session id in which the interrupt was running.""" + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. + """ @property def interrupted_ids( self, @@ -2464,12 +2543,18 @@ class InterruptResponse(google.protobuf.message.Message): self, *, session_id: builtins.str = ..., + server_side_session_id: builtins.str = ..., interrupted_ids: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... 
def ClearField( self, field_name: typing_extensions.Literal[ - "interrupted_ids", b"interrupted_ids", "session_id", b"session_id" + "interrupted_ids", + b"interrupted_ids", + "server_side_session_id", + b"server_side_session_id", + "session_id", + b"session_id", ], ) -> None: ... @@ -2723,12 +2808,19 @@ class ReleaseExecuteRequest(google.protobuf.message.Message): global___ReleaseExecuteRequest = ReleaseExecuteRequest class ReleaseExecuteResponse(google.protobuf.message.Message): + """Next ID: 4""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor SESSION_ID_FIELD_NUMBER: builtins.int + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int OPERATION_ID_FIELD_NUMBER: builtins.int session_id: builtins.str """Session id in which the release was running.""" + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. + """ operation_id: builtins.str """Operation id of the operation on which the release executed. If the operation couldn't be found (because e.g. it was concurrently released), will be unset. @@ -2738,6 +2830,7 @@ class ReleaseExecuteResponse(google.protobuf.message.Message): self, *, session_id: builtins.str = ..., + server_side_session_id: builtins.str = ..., operation_id: builtins.str | None = ..., ) -> None: ... def HasField( @@ -2753,6 +2846,8 @@ class ReleaseExecuteResponse(google.protobuf.message.Message): b"_operation_id", "operation_id", b"operation_id", + "server_side_session_id", + b"server_side_session_id", "session_id", b"session_id", ], @@ -2825,18 +2920,29 @@ class ReleaseSessionRequest(google.protobuf.message.Message): global___ReleaseSessionRequest = ReleaseSessionRequest class ReleaseSessionResponse(google.protobuf.message.Message): + """Next ID: 3""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor SESSION_ID_FIELD_NUMBER: builtins.int + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int session_id: builtins.str """Session id of the session on which the release executed.""" + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. + """ def __init__( self, *, session_id: builtins.str = ..., + server_side_session_id: builtins.str = ..., ) -> None: ... def ClearField( - self, field_name: typing_extensions.Literal["session_id", b"session_id"] + self, + field_name: typing_extensions.Literal[ + "server_side_session_id", b"server_side_session_id", "session_id", b"session_id" + ], ) -> None: ... global___ReleaseSessionResponse = ReleaseSessionResponse @@ -2906,6 +3012,8 @@ class FetchErrorDetailsRequest(google.protobuf.message.Message): global___FetchErrorDetailsRequest = FetchErrorDetailsRequest class FetchErrorDetailsResponse(google.protobuf.message.Message): + """Next ID: 5""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor class StackTraceElement(google.protobuf.message.Message): @@ -3224,8 +3332,15 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["_spark_throwable", b"_spark_throwable"] ) -> typing_extensions.Literal["spark_throwable"] | None: ... + SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int + SESSION_ID_FIELD_NUMBER: builtins.int ROOT_ERROR_IDX_FIELD_NUMBER: builtins.int ERRORS_FIELD_NUMBER: builtins.int + server_side_session_id: builtins.str + """Server-side generated idempotency key that the client can use to assert that the server side + session has not changed. 
+ """ + session_id: builtins.str root_error_idx: builtins.int """The index of the root error in errors. The field will not be set if the error is not found.""" @property @@ -3238,6 +3353,8 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): def __init__( self, *, + server_side_session_id: builtins.str = ..., + session_id: builtins.str = ..., root_error_idx: builtins.int | None = ..., errors: collections.abc.Iterable[global___FetchErrorDetailsResponse.Error] | None = ..., ) -> None: ... @@ -3256,6 +3373,10 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): b"errors", "root_error_idx", b"root_error_idx", + "server_side_session_id", + b"server_side_session_id", + "session_id", + b"session_id", ], ) -> None: ... def WhichOneof( From a566099133ff38cd1b2cd2fe64879bf0ba75fa9b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 9 Nov 2023 18:34:43 -0800 Subject: [PATCH 081/121] [SPARK-45756][CORE] Support `spark.master.useAppNameAsAppId.enabled` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to support `spark.master.useAppNameAsAppId.enabled` as an experimental feature in Spark Standalone cluster. ### Why are the changes needed? This allows the users to control the appID completely. Screenshot 2023-11-09 at 5 33 45 PM ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual tests with the following procedure. ``` $ SPARK_MASTER_OPTS="-Dspark.master.useAppNameAsAppId.enabled=true" sbin/start-master.sh $ bin/spark-shell --master spark://max.local:7077 ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43743 from dongjoon-hyun/SPARK-45756. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/master/Master.scala | 7 ++++++- .../apache/spark/internal/config/package.scala | 8 ++++++++ .../apache/spark/deploy/master/MasterSuite.scala | 16 ++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index dbb647252c5f7..b3fbec1830e45 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -120,6 +120,7 @@ private[deploy] class Master( private val defaultCores = conf.get(DEFAULT_CORES) val reverseProxy = conf.get(UI_REVERSE_PROXY) val historyServerUrl = conf.get(MASTER_UI_HISTORY_SERVER_URL) + val useAppNameAsAppId = conf.get(MASTER_USE_APP_NAME_AS_APP_ID) // Alternative application submission gateway that is stable across Spark versions private val restServerEnabled = conf.get(MASTER_REST_SERVER_ENABLED) @@ -1041,7 +1042,11 @@ private[deploy] class Master( ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val appId = newApplicationId(date) + val appId = if (useAppNameAsAppId) { + desc.name.toLowerCase().replaceAll("\\s+", "") + } else { + newApplicationId(date) + } new ApplicationInfo(now, appId, desc, date, driver, defaultCores) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bbadf91fc41cb..b2bf30863a91e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1846,6 +1846,14 
@@ package object config { .stringConf .createOptional + private[spark] val MASTER_USE_APP_NAME_AS_APP_ID = + ConfigBuilder("spark.master.useAppNameAsAppId.enabled") + .internal() + .doc("(Experimental) If true, Spark master uses the user-provided appName for appId.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + private[spark] val IO_COMPRESSION_SNAPPY_BLOCKSIZE = ConfigBuilder("spark.io.compression.snappy.blockSize") .doc("Block size in bytes used in Snappy compression, in the case when " + diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 4f8457f930e4a..2e54673649c74 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -804,6 +804,7 @@ class MasterSuite extends SparkFunSuite private val _state = PrivateMethod[RecoveryState.Value](Symbol("state")) private val _newDriverId = PrivateMethod[String](Symbol("newDriverId")) private val _newApplicationId = PrivateMethod[String](Symbol("newApplicationId")) + private val _createApplication = PrivateMethod[ApplicationInfo](Symbol("createApplication")) private val workerInfo = makeWorkerInfo(4096, 10) private val workerInfos = Array(workerInfo, workerInfo, workerInfo) @@ -1275,6 +1276,21 @@ class MasterSuite extends SparkFunSuite assert(master.invokePrivate(_newApplicationId(submitDate)) === s"${i % 1000}") } } + + test("SPARK-45756: Use appName for appId") { + val conf = new SparkConf() + .set(MASTER_USE_APP_NAME_AS_APP_ID, true) + val master = makeMaster(conf) + val desc = new ApplicationDescription( + name = " spark - 45756 ", + maxCores = None, + command = null, + appUiUrl = "", + defaultProfile = DeployTestUtils.defaultResourceProfile, + eventLogDir = None, + eventLogCodec = None) + assert(master.invokePrivate(_createApplication(desc, null)).id === "spark-45756") + } } private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) From 2085ac5217577d20ab095c7f4eb37e72eb8e926b Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 9 Nov 2023 18:48:14 -0800 Subject: [PATCH 082/121] [SPARK-45850][BUILD] Upgrade oracle jdbc driver to 23.3.0.23.09 ### What changes were proposed in this pull request? This PR aims to upgrade the Oracle JDBC driver from `com.oracle.database.jdbc:ojdbc8:23.2.0.0` to `com.oracle.database.jdbc:ojdbc11:23.3.0.23.09`. ### Why are the changes needed? The `Run Docker integration tests` job is often interrupted by `OracleIntegrationSuite` failures. Several possible issues have been identified: - The Oracle JDBC driver currently used in Spark, `com.oracle.database.jdbc:ojdbc8`, is not intended for use with `JDK17`. https://www.oracle.com/database/technologies/maven-central-guide.html - The Oracle server version used in our GA environment is also `23.3`. After upgrading to `com.oracle.database.jdbc:ojdbc11`, I triggered `Run Docker integration tests` 6 times (spread across days, nights, and weekends), and all of them succeeded on the first attempt.
https://github.com/panbingkun/spark/actions/runs/6760572538/job/18383757692 https://github.com/panbingkun/spark/actions/runs/6765163954/job/18399468066 https://github.com/panbingkun/spark/actions/runs/6778851101/job/18465823046 https://github.com/panbingkun/spark/actions/runs/6797719942/job/18480445904 https://github.com/panbingkun/spark/actions/runs/6804543746/job/18502353511 https://github.com/panbingkun/spark/actions/runs/6809925232/job/18521970051 So, I proposed upgrading this driver first and then continuing to observe the effect. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43662 from panbingkun/oracle_integration_suite_fix. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- connector/docker-integration-tests/pom.xml | 2 +- pom.xml | 4 ++-- sql/core/pom.xml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/connector/docker-integration-tests/pom.xml b/connector/docker-integration-tests/pom.xml index 4308abbbe2016..ac8d2990c0e64 100644 --- a/connector/docker-integration-tests/pom.xml +++ b/connector/docker-integration-tests/pom.xml @@ -121,7 +121,7 @@ com.oracle.database.jdbc - ojdbc8 + ojdbc11 test diff --git a/pom.xml b/pom.xml index cae315f4d7182..71d8ffcc1c909 100644 --- a/pom.xml +++ b/pom.xml @@ -1279,8 +1279,8 @@ com.oracle.database.jdbc - ojdbc8 - 23.2.0.0 + ojdbc11 + 23.3.0.23.09 test diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 8fabfd4699db6..185d50c018c6a 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -186,7 +186,7 @@ com.oracle.database.jdbc - ojdbc8 + ojdbc11 test From 9bac48d4bd68d4f0d54c53c29a27b1f6e02c5f61 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 10 Nov 2023 17:12:25 +0900 Subject: [PATCH 083/121] [SPARK-45852][CONNECT][PYTHON] Gracefully deal with recursion error during logging ### What changes were proposed in this pull request? The Python client for Spark connect logs the text representation of the proto message. However, for deeply nested objects this can lead to a Python recursion error even before the maximum nested recursion limit of the GRPC message is reached. This patch fixes this issue by explicitly catching the recursion error during text conversion. ### Why are the changes needed? Stability ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT ### Was this patch authored or co-authored using generative AI tooling? No Closes #43732 from grundprinzip/SPARK-45852. Authored-by: Martin Grund Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/core.py | 5 ++++- .../pyspark/sql/tests/connect/test_connect_basic.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 965c4107cacee..7eafcc501f5f9 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -935,7 +935,10 @@ def _proto_to_string(self, p: google.protobuf.message.Message) -> str: ------- Single line string of the serialized proto message. 
""" - return text_format.MessageToString(p, as_one_line=True) + try: + return text_format.MessageToString(p, as_one_line=True) + except RecursionError: + return "" def schema(self, plan: pb2.Plan) -> StructType: """ diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index daf6772e52bf5..7a224d68219b0 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -159,6 +159,19 @@ def spark_connect_clean_up_test_data(cls): class SparkConnectBasicTests(SparkConnectSQLTestCase): + def test_recursion_handling_for_plan_logging(self): + """SPARK-45852 - Test that we can handle recursion in plan logging.""" + cdf = self.connect.range(1) + for x in range(400): + cdf = cdf.withColumn(f"col_{x}", CF.lit(x)) + + # Calling schema will trigger logging the message that will in turn trigger the message + # conversion into protobuf that will then trigger the recursion error. + self.assertIsNotNone(cdf.schema) + + result = self.connect._client._proto_to_string(cdf._plan.to_proto(self.connect._client)) + self.assertIn("recursion", result) + def test_df_getattr_behavior(self): cdf = self.connect.range(10) sdf = self.spark.range(10) From f5b1b8306cf13218f5ff79944aaa9c0b4e74fda4 Mon Sep 17 00:00:00 2001 From: Sandip Agarwala <131817656+sandip-db@users.noreply.github.com> Date: Fri, 10 Nov 2023 17:44:39 +0900 Subject: [PATCH 084/121] [SPARK-45562][SQL] XML: Add SQL error class for missing rowTag option ### What changes were proposed in this pull request? rowTag option is required for reading XML files. This PR adds a SQL error class for missing rowTag option. ### Why are the changes needed? rowTag option is required for reading XML files. This PR adds a SQL error class for missing rowTag option. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Updated the unit test to check for error message. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43710 from sandip-db/xml-rowTagRequiredError. Authored-by: Sandip Agarwala <131817656+sandip-db@users.noreply.github.com> Signed-off-by: Hyukjin Kwon --- .../utils/src/main/resources/error/error-classes.json | 6 ++++++ docs/sql-error-conditions.md | 6 ++++++ .../apache/spark/sql/catalyst/xml/XmlOptions.scala | 8 ++++++-- .../spark/sql/errors/QueryCompilationErrors.scala | 7 +++++++ .../sql/execution/datasources/xml/XmlSuite.scala | 11 ++++++++--- 5 files changed, 33 insertions(+), 5 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 26f6c0240afb3..3b7a3a6006ef3 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3911,6 +3911,12 @@ }, "sqlState" : "42605" }, + "XML_ROW_TAG_MISSING" : { + "message" : [ + " option is required for reading files in XML format." + ], + "sqlState" : "42000" + }, "_LEGACY_ERROR_TEMP_0001" : { "message" : [ "Invalid InsertIntoContext." diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 2cb433b19fa56..a811019e0a57b 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2369,3 +2369,9 @@ The operation `` requires a ``. But `` is a The `` requires `` parameters but the actual number is ``. 
For more details see [WRONG_NUM_ARGS](sql-error-conditions-wrong-num-args-error-class.html) + +### XML_ROW_TAG_MISSING + +[SQLSTATE: 42000](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +`` option is required for reading files in XML format. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala index aac6eec21c60a..8f6cdbf360ef0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala @@ -24,7 +24,7 @@ import javax.xml.stream.XMLInputFactory import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, DateFormatter, DateTimeUtils, ParseMode, PermissiveMode} -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} /** @@ -66,7 +66,11 @@ private[sql] class XmlOptions( val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName) val rowTagOpt = parameters.get(XmlOptions.ROW_TAG).map(_.trim) - require(!rowTagRequired || rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is required.") + + if (rowTagRequired && rowTagOpt.isEmpty) { + throw QueryCompilationErrors.xmlRowTagRequiredError(XmlOptions.ROW_TAG) + } + val rowTag = rowTagOpt.getOrElse(XmlOptions.DEFAULT_ROW_TAG) require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be an empty string.") require(!rowTag.startsWith("<") && !rowTag.endsWith(">"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 0c5dcb1ead01e..e772b3497ac34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3817,4 +3817,11 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat errorClass = "FOUND_MULTIPLE_DATA_SOURCES", messageParameters = Map("provider" -> provider)) } + + def xmlRowTagRequiredError(optionName: String): Throwable = { + new AnalysisException( + errorClass = "XML_ROW_TAG_MISSING", + messageParameters = Map("rowTag" -> toSQLId(optionName)) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 2d4cd2f403c56..21122676c46be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, QueryTest, Ro import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.xml.XmlOptions import org.apache.spark.sql.catalyst.xml.XmlOptions._ +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.xml.TestUtils._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession @@ -1782,17 +1783,21 @@ class XmlSuite extends 
QueryTest with SharedSparkSession { test("Test XML Options Error Messages") { def checkXmlOptionErrorMessage( parameters: Map[String, String] = Map.empty, - msg: String): Unit = { - val e = intercept[IllegalArgumentException] { + msg: String, + exception: Throwable = new IllegalArgumentException().getCause): Unit = { + val e = intercept[Exception] { spark.read .options(parameters) .xml(getTestResourcePath(resDir + "ages.xml")) .collect() } + assert(e.getCause === exception) assert(e.getMessage.contains(msg)) } - checkXmlOptionErrorMessage(Map.empty, "'rowTag' option is required.") + checkXmlOptionErrorMessage(Map.empty, + "[XML_ROW_TAG_MISSING] `rowTag` option is required for reading files in XML format.", + QueryCompilationErrors.xmlRowTagRequiredError(XmlOptions.ROW_TAG).getCause) checkXmlOptionErrorMessage(Map("rowTag" -> ""), "'rowTag' option should not be an empty string.") checkXmlOptionErrorMessage(Map("rowTag" -> " "), From bd526986a5f0087ebdae0fc0ac0f07c273b8fc38 Mon Sep 17 00:00:00 2001 From: Alice Sayutina Date: Fri, 10 Nov 2023 17:47:27 +0900 Subject: [PATCH 085/121] [SPARK-45837][CONNECT] Improve logging information in handling retries ### What changes were proposed in this pull request? Add the suppressed exception when handling retries. ### Why are the changes needed? Improves the user and debugging experience by showing the underlying error. ### Does this PR introduce _any_ user-facing change? Better exceptions. ### How was this patch tested? Hand testing ### Was this patch authored or co-authored using generative AI tooling? NA Closes #43719 from cdkrot/SPARK-45837. Authored-by: Alice Sayutina Signed-off-by: Hyukjin Kwon --- .../client/ExecutePlanResponseReattachableIterator.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 9fd8b12fef96f..2b61463c343fb 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -236,7 +236,9 @@ class ExecutePlanResponseReattachableIterator( } // Try a new ExecutePlan, and throw upstream for retry. iter = Some(rawBlockingStub.executePlan(initialRequest)) - throw new GrpcRetryHandler.RetryException + val error = new GrpcRetryHandler.RetryException() + error.addSuppressed(ex) + throw error case NonFatal(e) => // Remove the iterator, so that a new one will be created after retry. iter = None From 49ca6aa6cb75b931d1c38dcffb4cd3dd63b0a2f3 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 10 Nov 2023 12:17:09 +0300 Subject: [PATCH 086/121] [MINOR][SQL] Pass `cause` in `CannotReplaceMissingTableException` constructor ### What changes were proposed in this pull request? In the PR, I propose to use the `cause` argument in the `CannotReplaceMissingTableException` constructor. ### Why are the changes needed? To improve user experience with Spark SQL while troubleshooting issues. Currently, users don't see where the exception comes from. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43738 from MaxGekk/fix-missed-cause.
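As a standalone illustration of the pattern this change relies on (all names below are made up, not Spark's): constructing the wrapping exception with the underlying error as its `cause` keeps the original stack trace reachable through `getCause`, instead of silently dropping it.

```scala
// Illustrative sketch only: wrapping an error with vs. without a cause.
class TableNotFoundError(message: String, cause: Throwable = null)
  extends RuntimeException(message, cause)

object CauseChainingSketch {
  private def resolveTable(name: String): Nothing =
    throw new IllegalStateException(s"catalog lookup failed for '$name'")

  def main(args: Array[String]): Unit = {
    try {
      try resolveTable("sales") catch {
        // Forward the original failure as the cause instead of dropping it.
        case e: IllegalStateException =>
          throw new TableNotFoundError("TABLE_OR_VIEW_NOT_FOUND: `sales`", cause = e)
      }
    } catch {
      // With the cause attached, the report shows where the failure started.
      case e: TableNotFoundError => println(s"caused by: ${e.getCause}")
    }
  }
}
```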
Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../catalyst/analysis/CannotReplaceMissingTableException.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala index 910bb9d374971..032cdca12c050 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala @@ -28,4 +28,5 @@ class CannotReplaceMissingTableException( extends AnalysisException( errorClass = "TABLE_OR_VIEW_NOT_FOUND", messageParameters = Map("relationName" - -> quoteNameParts(tableIdentifier.namespace :+ tableIdentifier.name))) + -> quoteNameParts(tableIdentifier.namespace :+ tableIdentifier.name)), + cause = cause) From b347237735094e9092f4100583ed1d6f3eacf1f6 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 10 Nov 2023 21:09:43 +0800 Subject: [PATCH 087/121] [SPARK-45878][SQL][TESTS] Fix ConcurrentModificationException in CliSuite ### What changes were proposed in this pull request? This PR changes the ArrayBuffer for logs to immutable for reading to prevent ConcurrentModificationException which hides the actual cause of failure ### Why are the changes needed? ```scala [info] - SPARK-29022 Commands using SerDe provided in ADD JAR sql *** FAILED *** (11 seconds, 105 milliseconds) [info] java.util.ConcurrentModificationException: mutation occurred during iteration [info] at scala.collection.mutable.MutationTracker$.checkMutations(MutationTracker.scala:43) [info] at scala.collection.mutable.CheckedIndexedSeqView$CheckedIterator.hasNext(CheckedIndexedSeqView.scala:47) [info] at scala.collection.IterableOnceOps.addString(IterableOnce.scala:1247) [info] at scala.collection.IterableOnceOps.addString$(IterableOnce.scala:1241) [info] at scala.collection.AbstractIterable.addString(Iterable.scala:933) [info] at org.apache.spark.sql.hive.thriftserver.CliSuite.runCliWithin(CliSuite.scala:205) [info] at org.apache.spark.sql.hive.thriftserver.CliSuite.$anonfun$new$20(CliSuite.scala:501) ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #43749 from yaooqinn/SPARK-45878. 
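A minimal sketch of the race and the fix, using made-up names rather than the suite's actual fields: one thread appends log lines to a shared `ArrayBuffer` while another iterates it to build a failure message; taking the same lock for both the append and the read avoids the "mutation occurred during iteration" error shown above.

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative sketch only; names do not match the actual suite.
object SynchronizedLogBuffer {
  private val lock = new Object
  private val lines = new ArrayBuffer[String]()

  // Writer side: append under the lock.
  def append(line: String): Unit = lock.synchronized { lines += line }

  // Reader side: iterate under the same lock, so the buffer cannot be
  // mutated while mkString walks it.
  def report(): String = lock.synchronized { lines.mkString("\n") }

  def main(args: Array[String]): Unit = {
    val writer = new Thread(() => (1 to 10000).foreach(i => append(s"line $i")))
    writer.start()
    println(report().linesIterator.length)
    writer.join()
  }
}
```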
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 4f0d4dff566c4..110ef7b0affab 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -193,7 +193,7 @@ class CliSuite extends SparkFunSuite { ThreadUtils.awaitResult(foundAllExpectedAnswers.future, timeoutForQuery) log.info("Found all expected output.") } catch { case cause: Throwable => - val message = + val message = lock.synchronized { s""" |======================= |CliSuite failure output @@ -207,6 +207,7 @@ class CliSuite extends SparkFunSuite { |End CliSuite failure output |=========================== """.stripMargin + } logError(message, cause) fail(message, cause) } finally { From 6851cb96ec651b25a8103f7681e8528ff7d625ff Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 10 Nov 2023 22:00:51 +0800 Subject: [PATCH 088/121] [SPARK-45752][SQL] Simplify the code for check unreferenced CTE relations ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/43614 let unreferenced `CTE` checked by `CheckAnalysis0`. This PR follows up https://github.com/apache/spark/pull/43614 to simplify the code for check unreferenced CTE relations. ### Why are the changes needed? Simplify the code for check unreferenced CTE relations ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? Exists test cases. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #43727 from beliefer/SPARK-45752_followup. Authored-by: Jiaan Geng Signed-off-by: Jiaan Geng --- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++++-------- .../org/apache/spark/sql/CTEInlineSuite.scala | 18 ++++++++++++++++-- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 29d60ae0f41e1..f9010d47508c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -167,25 +167,21 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB val inlineCTE = InlineCTE(alwaysInline = true) val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] inlineCTE.buildCTEMap(plan, cteMap) - cteMap.values.foreach { case (relation, _, _) => + val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false) + cteMap.foreach { case (cteId, (relation, refCount, _)) => // If a CTE relation is never used, it will disappear after inline. Here we explicitly check // analysis for it, to make sure the entire query plan is valid. try { // If a CTE relation ref count is 0, the other CTE relations that reference it // should also be checked by checkAnalysis0. This code will also guarantee the leaf // relations that do not reference any others are checked first. 
- val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false) - cteMap.foreach { case (cteId, _) => - val (_, refCount, _) = cteMap(cteId) - if (refCount == 0) { - checkUnreferencedCTERelations(cteMap, visited, cteId) - } + if (refCount == 0) { + checkUnreferencedCTERelations(cteMap, visited, cteId) } } catch { case e: AnalysisException => throw new ExtendedAnalysisException(e, relation.child) } - } // Inline all CTEs in the plan to help check query plan structures in subqueries. var inlinedPlan: Option[LogicalPlan] = None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 055c04992c009..a06b50d175f90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -683,11 +683,25 @@ abstract class CTEInlineSuiteBase val e = intercept[AnalysisException](sql( s""" |with - |a as (select * from non_exist), + |a as (select * from tab_non_exists), |b as (select * from a) |select 2 |""".stripMargin)) - checkErrorTableNotFound(e, "`non_exist`", ExpectedContext("non_exist", 26, 34)) + checkErrorTableNotFound(e, "`tab_non_exists`", ExpectedContext("tab_non_exists", 26, 39)) + + withTable("tab_exists") { + spark.sql("CREATE TABLE tab_exists(id INT) using parquet") + val e = intercept[AnalysisException](sql( + s""" + |with + |a as (select * from tab_exists), + |b as (select * from a), + |c as (select * from tab_non_exists), + |d as (select * from c) + |select 2 + |""".stripMargin)) + checkErrorTableNotFound(e, "`tab_non_exists`", ExpectedContext("tab_non_exists", 83, 96)) + } } } From 605aa0c299c1d88f8a31ba888ac8e6b6203be6c5 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 10 Nov 2023 08:10:20 -0600 Subject: [PATCH 089/121] [SPARK-45687][CORE][SQL][ML][MLLIB][KUBERNETES][EXAMPLES][CONNECT][STRUCTURED STREAMING] Fix `Passing an explicit array value to a Scala varargs method is deprecated` ### What changes were proposed in this pull request? Fix the deprecated behavior below: `Passing an explicit array value to a Scala varargs method is deprecated (since 2.13.0) and will result in a defensive copy; Use the more efficient non-copying ArraySeq.unsafeWrapArray or an explicit toIndexedSeq call` For all the use cases, we don't need to make a copy of the array. Explicitly use `ArraySeq.unsafeWrapArray` to do the conversion. ### Why are the changes needed? Eliminate compile warnings and no longer use deprecated scala APIs. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GA. Fixed all the warning with build: `mvn clean package -DskipTests -Pspark-ganglia-lgpl -Pkinesis-asl -Pdocker-integration-tests -Pyarn -Pkubernetes -Pkubernetes-integration-tests -Phive-thriftserver -Phadoop-cloud` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43642 from ivoson/SPARK-45687. 
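A self-contained sketch of the deprecated pattern and the non-copying alternative; `concat` below is a made-up varargs method standing in for calls like `Dataset.select(cols: _*)`, and the `toImmutableArraySeq` helper used throughout the diff serves the same purpose as the explicit `ArraySeq.unsafeWrapArray` wrapping shown here.

```scala
import scala.collection.immutable.ArraySeq

object VarargsSketch {
  // A made-up varargs method, standing in for APIs like select(cols: Column*).
  def concat(parts: String*): String = parts.mkString("-")

  def main(args: Array[String]): Unit = {
    val arr = Array("a", "b", "c")

    // Deprecated since Scala 2.13: passing an Array directly to a varargs
    // parameter forces a defensive copy and emits the warning quoted above.
    // concat(arr: _*)

    // Non-copying alternative: wrap the array without copying its elements.
    // This is safe as long as the array is not mutated afterwards.
    println(concat(ArraySeq.unsafeWrapArray(arr): _*)) // prints a-b-c
  }
}
```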
Authored-by: Tengfei Huang Signed-off-by: Sean Owen --- .../apache/spark/sql/KeyValueGroupedDataset.scala | 9 ++++++--- .../scala/org/apache/spark/sql/ColumnTestSuite.scala | 3 ++- .../spark/sql/UserDefinedFunctionE2ETestSuite.scala | 5 ++++- .../sql/connect/planner/SparkConnectPlanner.scala | 3 ++- .../org/apache/spark/api/python/PythonRDD.scala | 3 ++- .../scala/org/apache/spark/executor/Executor.scala | 3 ++- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 ++- .../org/apache/spark/examples/graphx/Analytics.scala | 4 ++-- .../apache/spark/ml/classification/OneVsRest.scala | 3 ++- .../org/apache/spark/ml/feature/FeatureHasher.scala | 4 +++- .../scala/org/apache/spark/ml/feature/Imputer.scala | 8 +++++--- .../org/apache/spark/ml/feature/Interaction.scala | 4 +++- .../scala/org/apache/spark/ml/feature/RFormula.scala | 6 ++++-- .../apache/spark/ml/feature/VectorAssembler.scala | 5 +++-- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 3 ++- .../scala/org/apache/spark/ml/fpm/PrefixSpan.scala | 3 ++- .../scala/org/apache/spark/ml/r/KSTestWrapper.scala | 3 ++- .../spark/ml/regression/DecisionTreeRegressor.scala | 3 ++- .../scala/org/apache/spark/ml/tree/treeModels.scala | 3 ++- .../scala/org/apache/spark/mllib/util/MLUtils.scala | 12 ++++++++---- .../org/apache/spark/ml/feature/ImputerSuite.scala | 12 ++++++++---- .../spark/ml/source/image/ImageFileFormatSuite.scala | 3 ++- .../spark/ml/stat/KolmogorovSmirnovTestSuite.scala | 3 ++- .../test/scala/org/apache/spark/ml/util/MLTest.scala | 6 ++++-- .../k8s/features/DriverCommandFeatureStepSuite.scala | 2 +- .../spark/sql/catalyst/expressions/generators.scala | 8 ++++++-- .../expressions/UnsafeRowConverterSuite.scala | 4 +++- .../apache/spark/sql/DataFrameStatFunctions.scala | 3 ++- .../apache/spark/sql/KeyValueGroupedDataset.scala | 8 ++++++-- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../spark/sql/execution/stat/StatFunctions.scala | 3 ++- .../spark/sql/execution/streaming/OffsetSeqLog.scala | 3 ++- .../continuous/ContinuousRateStreamSource.scala | 3 ++- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 3 ++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 6 ++++-- .../scala/org/apache/spark/sql/GenTPCDSData.scala | 3 ++- .../scala/org/apache/spark/sql/ParametersSuite.scala | 9 +++++---- .../sql/connector/SimpleWritableDataSource.scala | 4 +++- .../datasources/FileMetadataStructSuite.scala | 3 ++- .../sql/execution/datasources/csv/CSVBenchmark.scala | 7 ++++--- .../org/apache/spark/sql/streaming/StreamSuite.scala | 2 +- .../spark/sql/streaming/StreamingQuerySuite.scala | 3 ++- .../spark/sql/hive/thriftserver/CliSuite.scala | 3 ++- .../sql/hive/execution/AggregationQuerySuite.scala | 4 +++- .../hive/execution/ObjectHashAggregateSuite.scala | 8 +++++--- 45 files changed, 136 insertions(+), 69 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index dac89bf3eb5ac..2e6117abbf32c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -194,7 +194,9 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { SortExprs: Array[Column], f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapSortedGroups(SortExprs: _*)(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder) + 
import org.apache.spark.util.ArrayImplicits._ + flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)( + UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder) } /** @@ -458,8 +460,9 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { otherSortExprs: Array[Column], f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { - cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)( - UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) + import org.apache.spark.util.ArrayImplicits._ + cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)( + otherSortExprs.toImmutableArraySeq: _*)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) } protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder]( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala index c1e4399ccb054..0fb6894e457ae 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala @@ -207,5 +207,6 @@ class ColumnTestSuite extends ConnectFunSuite { private val structType1 = new StructType().add("a", "int").add("b", "string") private val structType2 = structType1.add("c", "binary") testColName(structType1, _.struct(structType1)) - testColName(structType2, _.struct(structType2.fields: _*)) + import org.apache.spark.util.ArrayImplicits._ + testColName(structType2, _.struct(structType2.fields.toImmutableArraySeq: _*)) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index baf65e7bb330a..f7ffe7aa12719 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -286,7 +286,10 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { import session.implicits._ val df = Seq((1, 2, 3)).toDF("a", "b", "c") val f = udf((row: Row) => row.schema.fieldNames) - checkDataset(df.select(f(struct(df.columns map col: _*))), Row(Seq("a", "b", "c"))) + import org.apache.spark.util.ArrayImplicits._ + checkDataset( + df.select(f(struct((df.columns map col).toImmutableArraySeq: _*))), + Row(Seq("a", "b", "c"))) } test("Filter with row input encoder") { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 8b852babb544a..4925bc0a5dc16 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1178,11 +1178,12 @@ class SparkConnectPlanner( val normalized = normalize(schema).asInstanceOf[StructType] + import org.apache.spark.util.ArrayImplicits._ val project = Dataset .ofRows( session, logicalPlan = logical.LocalRelation(normalize(structType).asInstanceOf[StructType])) - .toDF(normalized.names: _*) + .toDF(normalized.names.toImmutableArraySeq: _*) .to(normalized) .logicalPlan .asInstanceOf[Project] diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 3eea0ebcdb2a6..e98259562c92f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -43,6 +43,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer} import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util._ +import org.apache.spark.util.ArrayImplicits._ private[spark] class PythonRDD( @@ -179,7 +180,7 @@ private[spark] object PythonRDD extends Logging { type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala.toSeq) - val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) + val flattenedPartition: UnrolledPartition = Array.concat(allPartitions.toImmutableArraySeq: _*) serveIterator(flattenedPartition.iterator, s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}") } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index b12b5e2131214..e340667173b0b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -324,7 +324,8 @@ private[spark] class Executor( private val Seq(initialUserJars, initialUserFiles, initialUserArchives) = Seq("jar", "file", "archive").map { key => conf.getOption(s"spark.app.initial.$key.urls").map { urls => - immutable.Map(urls.split(",").map(url => (url, appStartTime)): _*) + import org.apache.spark.util.ArrayImplicits._ + immutable.Map(urls.split(",").map(url => (url, appStartTime)).toImmutableArraySeq: _*) }.getOrElse(immutable.Map.empty) } updateDependencies(initialUserFiles, initialUserJars, initialUserArchives, defaultSessionState) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 5dc666c62d1ad..610b48ea2ba50 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1044,7 +1044,8 @@ abstract class RDD[T: ClassTag]( */ def collect(): Array[T] = withScope { val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) - Array.concat(results: _*) + import org.apache.spark.util.ArrayImplicits._ + Array.concat(results.toImmutableArraySeq: _*) } /** diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index a8f9b32b0f3bf..5529da74970d6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.graphx -import scala.collection.mutable +import scala.collection.{immutable, mutable} import org.apache.spark._ import org.apache.spark.graphx._ @@ -51,7 +51,7 @@ object Analytics { case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } - val options = mutable.Map(optionsList: _*) + val options = mutable.Map(immutable.ArraySeq.unsafeWrapArray(optionsList): _*) val conf = new SparkConf() GraphXUtils.registerKryoClasses(conf) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 52106f4010f93..b70f3ddd4c14d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -213,9 +213,10 @@ final class OneVsRestModel private[ml] ( tmpModel.asInstanceOf[ProbabilisticClassificationModel[_, _]].setProbabilityCol("") } + import org.apache.spark.util.ArrayImplicits._ tmpModel.transform(df) .withColumn(accColName, updateUDF(col(accColName), col(tmpRawPredName))) - .select(columns: _*) + .select(columns.toImmutableArraySeq: _*) } if (handlePersistence) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index f1268bdf6bd89..866bf9e5bf3fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -187,7 +187,9 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme } val metadata = outputSchema($(outputCol)).metadata - dataset.withColumn($(outputCol), hashFeatures(struct($(inputCols).map(col): _*)), metadata) + import org.apache.spark.util.ArrayImplicits._ + dataset.withColumn($(outputCol), + hashFeatures(struct($(inputCols).map(col).toImmutableArraySeq: _*)), metadata) } @Since("2.3.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 5998887923f8b..4d38c127d412d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -163,24 +163,26 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) } val numCols = cols.length + import org.apache.spark.util.ArrayImplicits._ val results = $(strategy) match { case Imputer.mean => // Function avg will ignore null automatically. // For a column only containing null, avg will return null. - val row = dataset.select(cols.map(avg): _*).head() + val row = dataset.select(cols.map(avg).toImmutableArraySeq: _*).head() Array.tabulate(numCols)(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i)) case Imputer.median => // Function approxQuantile will ignore null automatically. // For a column only containing null, approxQuantile will return an empty array. - dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError)) + dataset.select(cols.toImmutableArraySeq: _*) + .stat.approxQuantile(inputColumns, Array(0.5), $(relativeError)) .map(_.headOption.getOrElse(Double.NaN)) case Imputer.mode => import spark.implicits._ // If there is more than one mode, choose the smallest one to keep in line // with sklearn.impute.SimpleImputer (using scipy.stats.mode). - val modes = dataset.select(cols: _*).flatMap { row => + val modes = dataset.select(cols.toImmutableArraySeq: _*).flatMap { row => // Ignore null. 
Iterator.range(0, numCols) .flatMap(i => if (row.isNullAt(i)) None else Some((i, row.getDouble(i)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9a4f1d97c907a..a81c55a171571 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -108,9 +108,11 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) } } + import org.apache.spark.util.ArrayImplicits._ dataset.select( col("*"), - interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata())) + interactFunc(struct(featureCols.toImmutableArraySeq: _*)) + .as($(outputCol), featureAttrs.toMetadata())) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 9387ab3daeb8a..f3f85b4098672 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -476,7 +476,8 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str override def transform(dataset: Dataset[_]): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) - dataset.select(columnsToKeep.map(dataset.col): _*) + import org.apache.spark.util.ArrayImplicits._ + dataset.select(columnsToKeep.map(dataset.col).toImmutableArraySeq: _*) } override def transformSchema(schema: StructType): StructType = { @@ -564,7 +565,8 @@ private class VectorAttributeRewriter( } val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col) val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata) - dataset.select(otherCols :+ rewrittenCol : _*) + import org.apache.spark.util.ArrayImplicits._ + dataset.select((otherCols :+ rewrittenCol).toImmutableArraySeq : _*) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index cf5b5ecb20148..47c0ca22f9672 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -149,8 +149,9 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - - filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) + import org.apache.spark.util.ArrayImplicits._ + filteredDataset.select(col("*"), + assembleFunc(struct(args.toImmutableArraySeq: _*)).as($(outputCol), metadata)) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 7fe9aa414f2d5..081a40bfbe801 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -164,7 +164,8 @@ class FPGrowth @Since("2.2.0") ( instr.logPipelineStage(this) instr.logDataset(dataset) - instr.logParams(this, params: _*) + import org.apache.spark.util.ArrayImplicits._ + instr.logParams(this, params.toImmutableArraySeq: _*) val data = dataset.select($(itemsCol)) val items = 
data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 5c98ffa394fe5..3ea76658d1a92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -135,7 +135,8 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params @Since("2.4.0") def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = instrumented { instr => instr.logDataset(dataset) - instr.logParams(this, params: _*) + import org.apache.spark.util.ArrayImplicits._ + instr.logParams(this, params.toImmutableArraySeq: _*) val sequenceColParam = $(sequenceCol) val inputType = dataset.schema(sequenceColParam).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KSTestWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KSTestWrapper.scala index 21531eb057ad3..234b8bbf6f0ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KSTestWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KSTestWrapper.scala @@ -49,7 +49,8 @@ private[r] object KSTestWrapper { case Row(feature: Double) => feature } - val ksTestResult = kolmogorovSmirnovTest(rddData, distName, distParams : _*) + import org.apache.spark.util.ArrayImplicits._ + val ksTestResult = kolmogorovSmirnovTest(rddData, distName, distParams.toImmutableArraySeq : _*) new KSTestWrapper(ksTestResult, distName, distParams) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index d9942f1c4f350..6c0089b689499 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -129,7 +129,8 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S instr.logPipelineStage(this) instr.logDataset(instances) - instr.logParams(this, params: _*) + import org.apache.spark.util.ArrayImplicits._ + instr.logParams(this, params.toImmutableArraySeq: _*) val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index cc917db98b328..47fb8bc92298a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -535,7 +535,8 @@ private[ml] object EnsembleModelReadWrite { val newNodeDataCol = df.schema("nodeData").dataType match { case StructType(fields) => val cols = fields.map(f => col(s"nodeData.${f.name}")) :+ lit(-1L).as("rawCount") - struct(cols: _*) + import org.apache.spark.util.ArrayImplicits._ + struct(cols.toImmutableArraySeq: _*) } df = df.withColumn("nodeData", newNodeDataCol) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 7bce38cc38a08..378f1381e4cf8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -359,7 +359,8 @@ object MLUtils extends Logging { 
col(c) } } - dataset.select(exprs: _*) + import org.apache.spark.util.ArrayImplicits._ + dataset.select(exprs.toImmutableArraySeq: _*) } /** @@ -411,7 +412,8 @@ object MLUtils extends Logging { col(c) } } - dataset.select(exprs: _*) + import org.apache.spark.util.ArrayImplicits._ + dataset.select(exprs.toImmutableArraySeq: _*) } /** @@ -461,7 +463,8 @@ object MLUtils extends Logging { col(c) } } - dataset.select(exprs: _*) + import org.apache.spark.util.ArrayImplicits._ + dataset.select(exprs.toImmutableArraySeq: _*) } /** @@ -511,7 +514,8 @@ object MLUtils extends Logging { col(c) } } - dataset.select(exprs: _*) + import org.apache.spark.util.ArrayImplicits._ + dataset.select(exprs.toImmutableArraySeq: _*) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index 5ef22a282c3a5..4873dacfc0f1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -339,9 +339,10 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { .setOutputCols(Array("out1")) val types = Seq(IntegerType, LongType) + import org.apache.spark.util.ArrayImplicits._ for (mType <- types) { // cast all columns to desired data type for testing - val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + val df2 = df.select(df.columns.map(c => col(c).cast(mType)).toImmutableArraySeq: _*) ImputerSuite.iterateStrategyTest(true, imputer, df2) } } @@ -360,9 +361,10 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { .setOutputCol("out") val types = Seq(IntegerType, LongType) + import org.apache.spark.util.ArrayImplicits._ for (mType <- types) { // cast all columns to desired data type for testing - val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + val df2 = df.select(df.columns.map(c => col(c).cast(mType)).toImmutableArraySeq: _*) ImputerSuite.iterateStrategyTest(false, imputer, df2) } } @@ -382,9 +384,10 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { .setMissingValue(-1.0) val types = Seq(IntegerType, LongType) + import org.apache.spark.util.ArrayImplicits._ for (mType <- types) { // cast all columns to desired data type for testing - val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + val df2 = df.select(df.columns.map(c => col(c).cast(mType)).toImmutableArraySeq: _*) ImputerSuite.iterateStrategyTest(true, imputer, df2) } } @@ -404,9 +407,10 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { .setMissingValue(-1.0) val types = Seq(IntegerType, LongType) + import org.apache.spark.util.ArrayImplicits._ for (mType <- types) { // cast all columns to desired data type for testing - val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + val df2 = df.select(df.columns.map(c => col(c).cast(mType)).toImmutableArraySeq: _*) ImputerSuite.iterateStrategyTest(false, imputer, df2) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala index 411e056bffb4c..32c5062544728 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala @@ -95,7 +95,8 @@ class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext { .select(substring_index(col("image.origin"), "/", -1).as("origin"), 
col("cls"), col("date")) .collect() - assert(Set(result: _*) === Set( + import org.apache.spark.util.ArrayImplicits._ + assert(Set(result.toImmutableArraySeq: _*) === Set( Row("29.5.a_b_EGDP022204.jpg", "kittens", "2018-01"), Row("54893.jpg", "kittens", "2018-02"), Row("DP153539.jpg", "kittens", "2018-02"), diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala index 1312de3a1b522..2ae21401538e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala @@ -60,9 +60,10 @@ class KolmogorovSmirnovTestSuite val cdf = (x: Double) => theoreticalDist.cumulativeProbability(x) KolmogorovSmirnovTest.test(sampledDF, "sample", cdf).head() } else { + import org.apache.spark.util.ArrayImplicits._ KolmogorovSmirnovTest.test(sampledDF, "sample", theoreticalDistByName._1, - theoreticalDistByName._2: _* + theoreticalDistByName._2.toImmutableArraySeq: _* ).head() } val theoreticalDistMath3 = if (theoreticalDist == null) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index b847c905e5f00..def04b5011873 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -112,13 +112,15 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => val columnsWithMetadata = dataframe.schema.map { structField => col(structField.name).as(structField.name, structField.metadata) } - val streamDF = stream.toDS().toDF(columnNames: _*).select(columnsWithMetadata: _*) + import org.apache.spark.util.ArrayImplicits._ + val streamDF = stream.toDS() + .toDF(columnNames.toImmutableArraySeq: _*).select(columnsWithMetadata: _*) val data = dataframe.as[A].collect() val streamOutput = transformer.transform(streamDF) .select(firstResultCol, otherResultCols: _*) testStream(streamOutput) ( - AddData(stream, data: _*), + AddData(stream, data.toImmutableArraySeq: _*), CheckAnswer(globalCheckFunction) ) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala index 4c38989955b80..b12508573b7cc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStepSuite.scala @@ -79,7 +79,7 @@ class DriverCommandFeatureStepSuite extends SparkFunSuite { ( envPy.map(v => ENV_PYSPARK_PYTHON -> v :: Nil) ++ envDriverPy.map(v => ENV_PYSPARK_DRIVER_PYTHON -> v :: Nil) - ).flatten.toArray: _*) + ).flatten.toSeq: _*) val spec = applyFeatureStep( PythonMainAppResource(mainResource), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 49cf01d472ebb..b4be09f333d5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -234,13 +234,15 @@ case class Stack(children: Seq[Expression]) extends 
Generator { override def eval(input: InternalRow): IterableOnce[InternalRow] = { val values = children.tail.map(_.eval(input)).toArray + + import org.apache.spark.util.ArrayImplicits._ for (row <- 0 until numRows) yield { val fields = new Array[Any](numFields) for (col <- 0 until numFields) { val index = row * numFields + col fields.update(col, if (index < values.length) values(index) else null) } - InternalRow(fields: _*) + InternalRow(fields.toImmutableArraySeq: _*) } } @@ -293,12 +295,14 @@ case class ReplicateRows(children: Seq[Expression]) extends Generator with Codeg override def eval(input: InternalRow): IterableOnce[InternalRow] = { val numRows = children.head.eval(input).asInstanceOf[Long] val values = children.tail.map(_.eval(input)).toArray + + import org.apache.spark.util.ArrayImplicits._ Range.Long(0, numRows, 1).map { _ => val fields = new Array[Any](numColumns) for (col <- 0 until numColumns) { fields.update(col, values(col)) } - InternalRow(fields: _*) + InternalRow(fields.toImmutableArraySeq: _*) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index cbab8894cb553..44264a846630e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -314,7 +314,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val row = new SpecificInternalRow(fieldTypes) val values = Array(new CalendarInterval(0, 7, 0L), null) - row.update(0, createArray(values: _*)) + + import org.apache.spark.util.ArrayImplicits._ + row.update(0, createArray(values.toImmutableArraySeq: _*)) val unsafeRow: UnsafeRow = converter.apply(row) testArrayInterval(unsafeRow.getArray(0), values) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index a8c4d4f8d2ba7..2f2857760526d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -98,8 +98,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { cols: Array[String], probabilities: Array[Double], relativeError: Double): Array[Array[Double]] = withOrigin { + import org.apache.spark.util.ArrayImplicits._ StatFunctions.multipleApproxQuantiles( - df.select(cols.map(col): _*), + df.select(cols.map(col).toImmutableArraySeq: _*), cols, probabilities, relativeError).map(_.toArray).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index ef0a3e0266c4a..22dfed3ea4c5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -241,7 +241,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( SortExprs: Array[Column], f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapSortedGroups(SortExprs: _*)((key, data) => f.call(key, data.asJava).asScala)(encoder) + import org.apache.spark.util.ArrayImplicits._ + flatMapSortedGroups( + SortExprs.toImmutableArraySeq: _*)((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -901,7 
+903,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( otherSortExprs: Array[Column], f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { - cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)( + import org.apache.spark.util.ArrayImplicits._ + cogroupSorted(other)( + thisSortExprs.toImmutableArraySeq: _*)(otherSortExprs.toImmutableArraySeq: _*)( (key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 53b09179cc3a1..934ed9ac2a1bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -89,7 +89,7 @@ object JDBCRDD extends Logging { * @return A Catalyst schema corresponding to columns in the given order. */ private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { - val fieldMap = Map(schema.fields.map(x => x.name -> x): _*) + val fieldMap = schema.fields.map(x => x.name -> x).toMap new StructType(columns.map(name => fieldMap(name))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index e7f1affbde44d..db26f8c7758e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -253,9 +253,10 @@ object StatFunctions extends Logging { val valueColumns = columnNames.map { columnName => new Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName) } + import org.apache.spark.util.ArrayImplicits._ ds.select(mapColumns: _*) .withColumn("summary", explode(lit(selectedStatistics))) - .select(Array(col("summary")) ++ valueColumns: _*) + .select((Array(col("summary")) ++ valueColumns).toImmutableArraySeq: _*) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index 5646f61440e77..7e490ef4cd53d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -64,7 +64,8 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) case "" => None case md => Some(md) } - OffsetSeq.fill(metadata, lines.map(parseOffset).toArray: _*) + import org.apache.spark.util.ArrayImplicits._ + OffsetSeq.fill(metadata, lines.map(parseOffset).toArray.toImmutableArraySeq: _*) } override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 08840496b052b..132d9a9d61e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -42,7 +42,8 @@ class RateStreamContinuousStream(rowsPerSecond: Long, numPartitions: Int) extend case RateStreamPartitionOffset(i, currVal, 
nextRead) => (i, ValueRunTimeMsPair(currVal, nextRead)) } - RateStreamOffset(Map(tuples: _*)) + import org.apache.spark.util.ArrayImplicits._ + RateStreamOffset(Map(tuples.toImmutableArraySeq: _*)) } override def deserializeOffset(json: String): Offset = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d3271283baa33..b28a23f13f86d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1182,8 +1182,9 @@ class DataFrameSuite extends QueryTest } test("summary advanced") { + import org.apache.spark.util.ArrayImplicits._ val stats = Array("count", "50.01%", "max", "mean", "min", "25%") - val orderMatters = person2.summary(stats: _*) + val orderMatters = person2.summary(stats.toImmutableArraySeq: _*) assert(orderMatters.collect().map(_.getString(0)) === stats) val onlyPercentiles = person2.summary("0.1%", "99.9%") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dcbd8948120ce..9285c31d70253 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -2058,13 +2058,14 @@ class DatasetSuite extends QueryTest } test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { + import org.apache.spark.util.ArrayImplicits._ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { val a = Seq(Some(1)).toDS() val b = Seq(Some(1.2)).toDS() val expected = Seq((Some(1), Some(1.2))).toDS() val joined = a.joinWith(b, lit(true)) assert(joined.schema == expected.schema) - checkDataset(joined, expected.collect(): _*) + checkDataset(joined, expected.collect().toImmutableArraySeq: _*) } } @@ -2078,7 +2079,8 @@ class DatasetSuite extends QueryTest val ds1 = spark.createDataset(rdd) val ds2 = spark.createDataset(rdd)(encoder) assert(ds1.schema == ds2.schema) - checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) + import org.apache.spark.util.ArrayImplicits._ + checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect().toImmutableArraySeq: _*) } test("SPARK-23862: Spark ExpressionEncoder should support Java Enum type from Scala") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala index 6768c5fd07b3e..5f95ba4f38547 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala @@ -169,7 +169,8 @@ class TPCDSTables(spark: SparkSession, dsdgenDir: String, scaleFactor: Int) } c.as(f.name) } - stringData.select(columns: _*) + import org.apache.spark.util.ArrayImplicits._ + stringData.select(columns.toImmutableArraySeq: _*) } convertedData diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index afbe9cdac6366..974def7f3b85e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -532,6 +532,7 @@ class ParametersSuite extends QueryTest with SharedSparkSession { } test("SPARK-45033: maps as parameters") { + import org.apache.spark.util.ArrayImplicits._ def fromArr(keys: Array[_], values: Array[_]): Column = { 
map_from_arrays(Column(Literal(keys)), Column(Literal(values))) } @@ -540,21 +541,21 @@ class ParametersSuite extends QueryTest with SharedSparkSession { } def createMap(keys: Array[_], values: Array[_]): Column = { val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v)))) - map(zipped.map { case (k, v) => Seq(k, v) }.flatten: _*) + map(zipped.flatMap { case (k, v) => Seq(k, v) }.toImmutableArraySeq: _*) } def callMap(keys: Array[_], values: Array[_]): Column = { val zipped = keys.map(k => Column(Literal(k))).zip(values.map(v => Column(Literal(v)))) - call_function("map", zipped.map { case (k, v) => Seq(k, v) }.flatten: _*) + call_function("map", zipped.flatMap { case (k, v) => Seq(k, v) }.toImmutableArraySeq: _*) } def fromEntries(keys: Array[_], values: Array[_]): Column = { val structures = keys.zip(values) .map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))} - map_from_entries(array(structures: _*)) + map_from_entries(array(structures.toImmutableArraySeq: _*)) } def callFromEntries(keys: Array[_], values: Array[_]): Column = { val structures = keys.zip(values) .map { case (k, v) => struct(Column(Literal(k)), Column(Literal(v)))} - call_function("map_from_entries", call_function("array", structures: _*)) + call_function("map_from_entries", call_function("array", structures.toImmutableArraySeq: _*)) } Seq(fromArr(_, _), createMap(_, _), callFromArr(_, _), callMap(_, _)).foreach { f => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 235a8ff3869bd..de8cf7a7b2d7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -157,6 +157,7 @@ class CSVReaderFactory(conf: SerializableConfiguration) val fs = filePath.getFileSystem(conf.value) new PartitionReader[InternalRow] { + import org.apache.spark.util.ArrayImplicits._ private val inputStream = fs.open(filePath) private val lines = new BufferedReader(new InputStreamReader(inputStream)) .lines().iterator().asScala @@ -172,7 +173,8 @@ class CSVReaderFactory(conf: SerializableConfiguration) } } - override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toInt): _*) + override def get(): InternalRow = + InternalRow(currentLine.split(",").map(_.trim.toInt).toImmutableArraySeq: _*) override def close(): Unit = { inputStream.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index 0e4985bac9941..6bf72b82564ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -1083,7 +1083,8 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { // Transform the result into a literal that can be used in an expression. 
val metadataColumnFields = metadataColumnRow.schema.fields .map(field => lit(metadataColumnRow.getAs[Any](field.name)).as(field.name)) - val metadataColumnStruct = struct(metadataColumnFields: _*) + import org.apache.spark.util.ArrayImplicits._ + val metadataColumnStruct = struct(metadataColumnFields.toImmutableArraySeq: _*) val selectSingleRowDf = spark.read.load(dir.getAbsolutePath) .where(col("_metadata").equalTo(lit(metadataColumnStruct))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala index eb561e13fc6de..e9cf35d9fab9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala @@ -70,6 +70,7 @@ object CSVBenchmark extends SqlBasedBenchmark { val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum, output = output) withTempPath { path => + import org.apache.spark.util.ArrayImplicits._ val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) val schema = StructType(fields) val values = (0 until colsNum).map(i => i.toString).mkString(",") @@ -87,7 +88,7 @@ object CSVBenchmark extends SqlBasedBenchmark { } val cols100 = columnNames.take(100).map(Column(_)) benchmark.addCase(s"Select 100 columns", numIters) { _ => - ds.select(cols100: _*).noop() + ds.select(cols100.toImmutableArraySeq: _*).noop() } benchmark.addCase(s"Select one column", numIters) { _ => ds.select($"col1").noop() @@ -100,7 +101,7 @@ object CSVBenchmark extends SqlBasedBenchmark { (1 until colsNum).map(i => StructField(s"col$i", IntegerType))) val dsErr1 = spark.read.schema(schemaErr1).csv(path.getAbsolutePath) benchmark.addCase(s"Select 100 columns, one bad input field", numIters) { _ => - dsErr1.select(cols100: _*).noop() + dsErr1.select(cols100.toImmutableArraySeq: _*).noop() } val badRecColName = "badRecord" @@ -109,7 +110,7 @@ object CSVBenchmark extends SqlBasedBenchmark { .option("columnNameOfCorruptRecord", badRecColName) .csv(path.getAbsolutePath) benchmark.addCase(s"Select 100 columns, corrupt record field", numIters) { _ => - dsErr2.select((Column(badRecColName) +: cols100): _*).noop() + dsErr2.select((Column(badRecColName) +: cols100).toImmutableArraySeq: _*).noop() } benchmark.run() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 66d37e996a6cf..953bbddf6abbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -224,7 +224,7 @@ class StreamSuite extends StreamTest { // Parquet write page-level CRC checksums will change the file size and // affect the data order when reading these files. Please see PARQUET-1746 for details. 
val outputDf = spark.read.parquet(outputDir.getAbsolutePath).sort($"a").as[Long] - checkDataset[Long](outputDf, (0L to 10L).toArray: _*) + checkDataset[Long](outputDf, 0L to 10L: _*) } finally { query.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1e0fa5b6bc9de..e388de214056a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -891,7 +891,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } override def getOffset: Option[Offset] = Some(LongOffset(1)) override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*) + import org.apache.spark.util.ArrayImplicits._ + spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames.toImmutableArraySeq: _*) } override def schema: StructType = MockSourceProvider.fakeSchema } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 110ef7b0affab..649c985cade98 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -703,6 +703,7 @@ class CliSuite extends SparkFunSuite { testRetry("formats of error messages") { def check(format: ErrorMessageFormat.Value, errorMessage: String, silent: Boolean): Unit = { val expected = errorMessage.split(System.lineSeparator()).map("" -> _) + import org.apache.spark.util.ArrayImplicits._ runCliWithin( 1.minute, extraArgs = Seq( @@ -710,7 +711,7 @@ class CliSuite extends SparkFunSuite { "--conf", s"${SQLConf.ERROR_MESSAGE_FORMAT.key}=$format", "--conf", s"${SQLConf.ANSI_ENABLED.key}=true", "-e", "select 1 / 0"), - errorResponses = Seq("DIVIDE_BY_ZERO"))(expected: _*) + errorResponses = Seq("DIVIDE_BY_ZERO"))(expected.toImmutableArraySeq: _*) } check( format = ErrorMessageFormat.PRETTY, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index dc8b184fcee32..4b000fff0eb92 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -941,8 +941,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te data .find(r => r.getInt(0) == 50) .getOrElse(fail("A row with id 50 should be the expected answer.")) + + import org.apache.spark.util.ArrayImplicits._ checkAnswer( - df.agg(udaf(allColumns: _*)), + df.agg(udaf(allColumns.toImmutableArraySeq: _*)), // udaf returns a Row as the output value. 
Row(expectedAnswer) ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index cc95de793ee4d..4e2db21403599 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -152,13 +152,15 @@ class ObjectHashAggregateSuite val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) val aggFunctions = schema.fieldNames.map(f => typed_count(col(f))) + import org.apache.spark.util.ArrayImplicits._ checkAnswer( - df.agg(aggFunctions.head, aggFunctions.tail: _*), + df.agg(aggFunctions.head, aggFunctions.tail.toImmutableArraySeq: _*), Row.fromSeq(data.map(_.toSeq).transpose.map(_.count(_ != null): Long)) ) checkAnswer( - df.groupBy($"id" % 4 as "mod").agg(aggFunctions.head, aggFunctions.tail: _*), + df.groupBy($"id" % 4 as "mod") + .agg(aggFunctions.head, aggFunctions.tail.toImmutableArraySeq: _*), data.groupBy(_.getInt(0) % 4).map { case (key, value) => key -> Row.fromSeq(value.map(_.toSeq).transpose.map(_.count(_ != null): Long)) }.toSeq.map { @@ -168,7 +170,7 @@ class ObjectHashAggregateSuite withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "5") { checkAnswer( - df.agg(aggFunctions.head, aggFunctions.tail: _*), + df.agg(aggFunctions.head, aggFunctions.tail.toImmutableArraySeq: _*), Row.fromSeq(data.map(_.toSeq).transpose.map(_.count(_ != null): Long)) ) } From 917947e62e1e67f49a83c1ffb0833b61f0c48eb6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 10 Nov 2023 07:50:17 -0800 Subject: [PATCH 090/121] [SPARK-45883][BUILD] Upgrade ORC to 1.9.2 ### What changes were proposed in this pull request? This PR aims to upgrade ORC to 1.9.2 for Apache Spark 4.0.0 and 3.5.1. ### Why are the changes needed? To bring the latest bug fixes. - https://github.com/apache/orc/releases/tag/v1.9.2 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43754 from dongjoon-hyun/SPARK-45883. 
Authored-by: Dongjoon Hyun
Signed-off-by: Dongjoon Hyun
---
 dev/deps/spark-deps-hadoop-3-hive-2.3 | 6 +++---
 pom.xml | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index b7d6bdbfd1299..0a952aa6ee881 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -218,9 +218,9 @@ opencsv/2.3//opencsv-2.3.jar
 opentracing-api/0.33.0//opentracing-api-0.33.0.jar
 opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar
 opentracing-util/0.33.0//opentracing-util-0.33.0.jar
-orc-core/1.9.1/shaded-protobuf/orc-core-1.9.1-shaded-protobuf.jar
-orc-mapreduce/1.9.1/shaded-protobuf/orc-mapreduce-1.9.1-shaded-protobuf.jar
-orc-shims/1.9.1//orc-shims-1.9.1.jar
+orc-core/1.9.2/shaded-protobuf/orc-core-1.9.2-shaded-protobuf.jar
+orc-mapreduce/1.9.2/shaded-protobuf/orc-mapreduce-1.9.2-shaded-protobuf.jar
+orc-shims/1.9.2//orc-shims-1.9.2.jar
 oro/2.0.8//oro-2.0.8.jar
 osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar
 paranamer/2.8//paranamer-2.8.jar
diff --git a/pom.xml b/pom.xml
index 71d8ffcc1c909..14754c0bcaa4f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -141,7 +141,7 @@
     10.14.2.0
     1.13.1
-    1.9.1
+    1.9.2
     shaded-protobuf
     9.4.53.v20231009
     4.0.3

From 0a791993be7b6f4b843887403460ef9aebe3daf9 Mon Sep 17 00:00:00 2001
From: yangjie01
Date: Fri, 10 Nov 2023 18:46:46 -0600
Subject: [PATCH 091/121] [SPARK-45686][INFRA][CORE][SQL][SS][CONNECT][MLLIB][DSTREAM][AVRO][ML][K8S][YARN][PYTHON][R][UI][GRAPHX][PROTOBUF][TESTS][EXAMPLES] Explicitly convert `Array` to `Seq` when function input is defined as `Seq` to avoid compilation warnings related to `class LowPriorityImplicits2 is deprecated`

### What changes were proposed in this pull request?
This PR explicitly converts `Array` to `Seq` when a function input is defined as `Seq`, to avoid compilation warnings like the following:

```
[error] /Users/yangjie01/SourceCode/git/spark-mine-sbt/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala:57:31: method copyArrayToImmutableIndexedSeq in class LowPriorityImplicits2 is deprecated (since 2.13.0): implicit conversions from Array to immutable.IndexedSeq are implemented by copying; use `toIndexedSeq` explicitly if you want to copy, or use the more efficient non-copying ArraySeq.unsafeWrapArray
[error] Applicable -Wconf / nowarn filters for this fatal warning: msg=, cat=deprecation, site=org.apache.spark.ml.linalg.Vector.equals, origin=scala.LowPriorityImplicits2.copyArrayToImmutableIndexedSeq, version=2.13.0
[error] Vectors.equals(s1.indices, s1.values, s2.indices, s2.values)
[error] ^
```

There are mainly three ways to fix it:

- `tools` and `mllib-local` modules: since these modules do not depend on the `common-utils` module, `scala.collection.immutable.ArraySeq.unsafeWrapArray` is used directly.
- `examples` module: since `ArrayImplicits` is an internal tool class in Spark, `scala.collection.immutable.ArraySeq.unsafeWrapArray` is used directly.
- Other modules: by importing `ArrayImplicits` and calling `toImmutableArraySeq`, the `Array` is wrapped into an `immutable.ArraySeq`.

### Why are the changes needed?
Clean up deprecated Scala API usage and use `Array.toImmutableArraySeq` (equivalent to `immutable.ArraySeq.unsafeWrapArray`) to avoid collection copies.

Why use `ArraySeq.unsafeWrapArray` instead of `toIndexedSeq`:
1. `ArraySeq.unsafeWrapArray` saves the overhead of collection copying compared to `toIndexedSeq`; it has less memory overhead and certain performance advantages. Moreover, `ArraySeq.unsafeWrapArray` is faster in scenarios such as:
   - `Array.fill.toImmutableArraySeq` versus `IndexedSeq.fill`
   - `Array.apply(data).toImmutableArraySeq` versus `IndexedSeq.apply(data)`
   - `Array.emptyXXArray.toImmutableArraySeq` versus `IndexedSeq.empty`

2. In Scala 2.12, when a function is defined as

```
def func(input: Seq[T]): R = {
  ...
}
```

and an `Array` is used as the function input, it is implicitly converted by default through the `scala.Predef#genericArrayOps` function; the specific implementation is as follows:

```scala
implicit def genericArrayOps[T](xs: Array[T]): ArrayOps[T] = (xs match {
  case x: Array[AnyRef]  => refArrayOps[AnyRef](x)
  case x: Array[Boolean] => booleanArrayOps(x)
  case x: Array[Byte]    => byteArrayOps(x)
  case x: Array[Char]    => charArrayOps(x)
  case x: Array[Double]  => doubleArrayOps(x)
  case x: Array[Float]   => floatArrayOps(x)
  case x: Array[Int]     => intArrayOps(x)
  case x: Array[Long]    => longArrayOps(x)
  case x: Array[Short]   => shortArrayOps(x)
  case x: Array[Unit]    => unitArrayOps(x)
  case null => null
}).asInstanceOf[ArrayOps[T]]

implicit def booleanArrayOps(xs: Array[Boolean]): ArrayOps.ofBoolean = new ArrayOps.ofBoolean(xs)
implicit def byteArrayOps(xs: Array[Byte]): ArrayOps.ofByte = new ArrayOps.ofByte(xs)
implicit def charArrayOps(xs: Array[Char]): ArrayOps.ofChar = new ArrayOps.ofChar(xs)
implicit def doubleArrayOps(xs: Array[Double]): ArrayOps.ofDouble = new ArrayOps.ofDouble(xs)
implicit def floatArrayOps(xs: Array[Float]): ArrayOps.ofFloat = new ArrayOps.ofFloat(xs)
implicit def intArrayOps(xs: Array[Int]): ArrayOps.ofInt = new ArrayOps.ofInt(xs)
implicit def longArrayOps(xs: Array[Long]): ArrayOps.ofLong = new ArrayOps.ofLong(xs)
implicit def refArrayOps[T <: AnyRef](xs: Array[T]): ArrayOps.ofRef[T] = new ArrayOps.ofRef[T](xs)
implicit def shortArrayOps(xs: Array[Short]): ArrayOps.ofShort = new ArrayOps.ofShort(xs)
implicit def unitArrayOps(xs: Array[Unit]): ArrayOps.ofUnit = new ArrayOps.ofUnit(xs)
```

This implicit conversion wraps the input data in a `mutable.WrappedArray`; for example, `Array[Int]` data is wrapped into a `mutable.WrappedArray.ofInt`:

```scala
final class ofInt(override val repr: Array[Int]) extends AnyVal with ArrayOps[Int] with ArrayLike[Int, Array[Int]] {

  override protected[this] def thisCollection: WrappedArray[Int] = new WrappedArray.ofInt(repr)
  override protected[this] def toCollection(repr: Array[Int]): WrappedArray[Int] = new WrappedArray.ofInt(repr)
  override protected[this] def newBuilder = new ArrayBuilder.ofInt

  def length: Int = repr.length
  def apply(index: Int): Int = repr(index)
  def update(index: Int, elem: Int) { repr(index) = elem }
}

final class ofInt(val array: Array[Int]) extends WrappedArray[Int] with Serializable {
  def elemTag = ClassTag.Int
  def length: Int = array.length
  def apply(index: Int): Int = array(index)
  def update(index: Int, elem: Int) { array(index) = elem }
  override def hashCode = MurmurHash3.wrappedArrayHash(array)
  override def equals(that: Any) = that match {
    case that: ofInt => Arrays.equals(array, that.array)
    case _ => super.equals(that)
  }
}
```

As we can see, in Scala 2.12 an `Array` input is implicitly converted into a `mutable.WrappedArray`, and no collection copying is performed.
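Concretely, the `toImmutableArraySeq` extension that most modules adopt in this patch is a thin wrapper over this same no-copy idea, delegating to `immutable.ArraySeq.unsafeWrapArray`. Below is a minimal, self-contained sketch of the pattern; the names used in the demo (`ArrayImplicitsSketch`, `SparkArrayOps`, `ToImmutableArraySeqDemo`) are illustrative only, and the real `org.apache.spark.util.ArrayImplicits` may differ in details such as visibility and null handling:

```scala
import scala.collection.immutable

// Illustrative stand-in for org.apache.spark.util.ArrayImplicits (assumption:
// the real helper may differ in visibility and null handling).
object ArrayImplicitsSketch {
  implicit class SparkArrayOps[T](xs: Array[T]) {
    // Wraps the array in an immutable.ArraySeq without copying its elements.
    def toImmutableArraySeq: immutable.ArraySeq[T] =
      immutable.ArraySeq.unsafeWrapArray(xs)
  }
}

object ToImmutableArraySeqDemo {
  import ArrayImplicitsSketch._

  // A function whose parameter is declared as Seq, like the call sites touched in this patch.
  def sum(values: Seq[Int]): Int = values.sum

  def main(args: Array[String]): Unit = {
    val arr = Array(1, 2, 3)
    // Explicit zero-copy wrapping: no deprecation warning, no defensive copy.
    println(sum(arr.toImmutableArraySeq)) // 6
    // `toIndexedSeq` also silences the warning, but copies the elements first;
    // passing `arr` directly still compiles via the deprecated implicit and copies as well.
    println(sum(arr.toIndexedSeq)) // 6
  }
}
```

Either explicit form avoids the deprecation warning; the wrapped variant additionally avoids the element copy that `toIndexedSeq` performs.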
In Scala 2.13, although the default implicit type conversion will perform a defensive collection copy, but based on the facts that existed when Spark using Scala 2.12, we can assume that it is still safe to explicitly wrap Array type input into an `immutable.ArraySeq` without collection copying in Scala 2.13. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43670 from LuciferYang/SPARK-45686-2. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Sean Owen --- .../org/apache/spark/util/MavenUtils.scala | 5 +- .../apache/spark/sql/v2/avro/AvroScan.scala | 5 +- .../AvroCatalystDataConversionSuite.scala | 4 +- .../org/apache/spark/sql/test/QueryTest.scala | 7 ++- .../spark/sql/test/RemoteSparkSession.scala | 3 +- .../sql/connect/client/ArtifactManager.scala | 3 +- .../connect/planner/SparkConnectPlanner.scala | 10 ++-- .../SparkConnectInterceptorRegistry.scala | 4 +- .../SparkConnectSessionHolderSuite.scala | 3 +- .../sql/kafka010/KafkaMicroBatchStream.scala | 3 +- .../sql/kafka010/KafkaOffsetReaderAdmin.scala | 3 +- .../kafka010/KafkaOffsetReaderConsumer.scala | 3 +- .../spark/sql/kafka010/KafkaSource.scala | 3 +- .../sql/kafka010/KafkaSourceProvider.scala | 3 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../spark/sql/kafka010/KafkaSinkSuite.scala | 5 +- .../spark/sql/kafka010/KafkaTestUtils.scala | 5 ++ .../spark/streaming/kafka010/KafkaRDD.scala | 3 +- .../streaming/kinesis/KinesisReceiver.scala | 4 +- .../kinesis/KinesisBackedBlockRDDSuite.scala | 9 ++-- .../ProtobufCatalystDataConversionSuite.scala | 4 +- .../spark/ExecutorAllocationClient.scala | 3 +- .../scala/org/apache/spark/SparkConf.scala | 3 +- .../scala/org/apache/spark/SparkContext.scala | 3 +- .../scala/org/apache/spark/SparkEnv.scala | 3 +- .../apache/spark/api/java/JavaRDDLike.scala | 3 +- .../scala/org/apache/spark/api/r/RRDD.scala | 3 +- .../org/apache/spark/deploy/Client.scala | 4 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 3 +- .../org/apache/spark/deploy/SparkSubmit.scala | 5 +- .../deploy/StandaloneResourceUtils.scala | 3 +- .../deploy/history/FsHistoryProvider.scala | 3 +- .../master/FileSystemPersistenceEngine.scala | 3 +- .../apache/spark/deploy/master/Master.scala | 4 +- .../spark/deploy/master/WorkerInfo.scala | 5 +- .../spark/executor/ProcfsMetricsGetter.scala | 5 +- .../spark/internal/io/SparkHadoopWriter.scala | 4 +- .../org/apache/spark/rdd/CoalescedRDD.scala | 7 +-- .../org/apache/spark/rdd/HadoopRDD.scala | 5 +- .../spark/rdd/LocalRDDCheckpointData.scala | 3 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 3 +- .../apache/spark/rdd/PairRDDFunctions.scala | 5 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 4 +- .../spark/resource/ResourceInformation.scala | 3 +- .../apache/spark/resource/ResourceUtils.scala | 6 ++- .../apache/spark/scheduler/DAGScheduler.scala | 9 ++-- .../apache/spark/scheduler/MergeStatus.scala | 3 +- .../CoarseGrainedSchedulerBackend.scala | 19 ++++--- .../apache/spark/status/AppStatusStore.scala | 7 +-- .../apache/spark/storage/BlockManager.scala | 3 +- .../storage/BlockManagerMasterEndpoint.scala | 3 +- .../spark/storage/DiskBlockManager.scala | 3 +- .../apache/spark/util/DependencyUtils.scala | 3 +- .../org/apache/spark/util/HadoopFSUtils.scala | 13 ++--- .../org/apache/spark/util/JsonProtocol.scala | 5 +- .../scala/org/apache/spark/util/Utils.scala | 12 +++-- .../util/logging/RollingFileAppender.scala | 4 +- 
.../org/apache/spark/CheckpointSuite.scala | 7 +-- .../scala/org/apache/spark/FileSuite.scala | 4 +- .../org/apache/spark/PartitioningSuite.scala | 8 +-- .../scala/org/apache/spark/ShuffleSuite.scala | 5 +- .../apache/spark/SparkContextInfoSuite.scala | 11 ++-- .../org/apache/spark/SparkContextSuite.scala | 9 ++-- .../org/apache/spark/UnpersistSuite.scala | 4 +- .../deploy/DecommissionWorkerSuite.scala | 2 +- .../StandaloneDynamicAllocationSuite.scala | 2 +- .../spark/deploy/client/AppClientSuite.scala | 2 +- .../history/EventLogFileWritersSuite.scala | 3 +- .../deploy/history/HistoryServerSuite.scala | 3 +- .../rest/StandaloneRestSubmitSuite.scala | 3 +- .../org/apache/spark/rdd/DoubleRDDSuite.scala | 3 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 7 +-- .../rdd/ParallelCollectionSplitSuite.scala | 15 +++--- .../rdd/PartitionwiseSampledRDDSuite.scala | 3 +- .../org/apache/spark/rdd/PipedRDDSuite.scala | 19 +++---- .../scala/org/apache/spark/rdd/RDDSuite.scala | 31 +++++------ .../org/apache/spark/rdd/SortingSuite.scala | 27 +++++----- .../spark/scheduler/DAGSchedulerSuite.scala | 5 +- .../scheduler/EventLoggingListenerSuite.scala | 3 +- .../scheduler/ExecutorResourceInfoSuite.scala | 15 +++--- .../spark/scheduler/TaskContextSuite.scala | 3 +- .../spark/scheduler/TaskSetManagerSuite.scala | 5 +- .../serializer/KryoSerializerSuite.scala | 6 ++- .../status/ListenerEventsTestHelper.scala | 4 +- ...kManagerDecommissionIntegrationSuite.scala | 3 +- .../spark/storage/BlockManagerSuite.scala | 3 +- .../spark/util/ClosureCleanerSuite.scala | 17 ++++--- .../util/MutableURLClassLoaderSuite.scala | 3 +- .../org/apache/spark/examples/SparkLR.scala | 5 +- .../spark/examples/ml/BinarizerExample.scala | 4 +- .../spark/examples/ml/BucketizerExample.scala | 8 ++- .../apache/spark/examples/ml/PCAExample.scala | 5 +- .../ml/PolynomialExpansionExample.scala | 5 +- .../ml/QuantileDiscretizerExample.scala | 4 +- .../examples/mllib/CorrelationsExample.scala | 8 ++- .../mllib/PCAOnRowMatrixExample.scala | 4 +- .../spark/examples/mllib/SVDExample.scala | 4 +- .../examples/streaming/RawNetworkGrep.scala | 4 +- .../apache/spark/graphx/GraphOpsSuite.scala | 3 +- .../StronglyConnectedComponentsSuite.scala | 3 +- .../org/apache/spark/ml/linalg/Vectors.scala | 12 +++-- .../apache/spark/ml/linalg/VectorsSuite.scala | 8 +-- .../spark/ml/clustering/GaussianMixture.scala | 4 +- .../apache/spark/ml/clustering/KMeans.scala | 5 +- .../org/apache/spark/ml/clustering/LDA.scala | 4 +- .../feature/BucketedRandomProjectionLSH.scala | 3 +- .../spark/ml/feature/CountVectorizer.scala | 3 +- .../org/apache/spark/ml/feature/Imputer.scala | 5 +- .../apache/spark/ml/feature/Instance.scala | 3 +- .../apache/spark/ml/feature/Interaction.scala | 7 +-- .../spark/ml/feature/OneHotEncoder.scala | 6 ++- .../spark/ml/feature/RFormulaParser.scala | 3 +- .../spark/ml/feature/StopWordsRemover.scala | 6 ++- .../spark/ml/feature/StringIndexer.scala | 12 +++-- .../apache/spark/ml/feature/Tokenizer.scala | 3 +- .../spark/ml/feature/VectorAssembler.scala | 4 +- .../apache/spark/ml/feature/Word2Vec.scala | 7 ++- .../org/apache/spark/ml/tree/treeModels.scala | 6 ++- .../clustering/BisectingKMeansModel.scala | 13 ++--- .../clustering/GaussianMixtureModel.scala | 4 +- .../spark/mllib/clustering/KMeansModel.scala | 15 +++--- .../spark/mllib/feature/ChiSqSelector.scala | 4 +- .../apache/spark/mllib/linalg/Vectors.scala | 8 +-- .../mllib/tree/model/treeEnsembleModels.scala | 8 +-- .../spark/mllib/util/MFDataGenerator.scala | 5 +- 
.../org/apache/spark/ml/ann/ANNSuite.scala | 5 +- .../DecisionTreeClassifierSuite.scala | 35 ++++++++----- .../ml/classification/FMClassifierSuite.scala | 3 +- .../classification/GBTClassifierSuite.scala | 13 +++-- .../ml/classification/LinearSVCSuite.scala | 3 +- .../ml/classification/NaiveBayesSuite.scala | 5 +- .../RandomForestClassifierSuite.scala | 13 +++-- .../spark/ml/feature/BinarizerSuite.scala | 3 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 10 ++-- .../ml/feature/CountVectorizerSuite.scala | 3 +- .../apache/spark/ml/feature/IDFSuite.scala | 3 +- .../apache/spark/ml/feature/PCASuite.scala | 11 ++-- .../ml/feature/QuantileDiscretizerSuite.scala | 12 +++-- .../UnivariateFeatureSelectorSuite.scala | 10 ++-- .../VarianceThresholdSelectorSuite.scala | 7 +-- .../spark/ml/feature/VectorSlicerSuite.scala | 4 +- .../BinaryLogisticBlockAggregatorSuite.scala | 22 ++++---- .../HingeBlockAggregatorSuite.scala | 18 ++++--- .../HuberBlockAggregatorSuite.scala | 20 +++++--- .../LeastSquaresBlockAggregatorSuite.scala | 9 ++-- ...tinomialLogisticBlockAggregatorSuite.scala | 22 ++++---- .../spark/ml/recommendation/ALSSuite.scala | 3 +- .../DecisionTreeRegressorSuite.scala | 6 ++- .../ml/regression/GBTRegressorSuite.scala | 12 +++-- .../RandomForestRegressorSuite.scala | 5 +- .../spark/ml/stat/ChiSquareTestSuite.scala | 3 +- .../spark/ml/stat/FValueTestSuite.scala | 7 +-- .../ml/stat/KolmogorovSmirnovTestSuite.scala | 5 +- .../spark/ml/tree/impl/BaggedPointSuite.scala | 11 ++-- .../tree/impl/GradientBoostedTreesSuite.scala | 7 ++- .../ml/tree/impl/RandomForestSuite.scala | 33 ++++++------ .../apache/spark/ml/tree/impl/TreeTests.scala | 3 +- .../org/apache/spark/ml/util/MLTest.scala | 3 +- .../LogisticRegressionSuite.scala | 19 ++++--- .../classification/NaiveBayesSuite.scala | 7 ++- .../spark/mllib/classification/SVMSuite.scala | 9 ++-- .../clustering/GaussianMixtureSuite.scala | 11 ++-- .../spark/mllib/clustering/LDASuite.scala | 27 +++++----- .../BinaryClassificationMetricsSuite.scala | 15 +++--- .../mllib/feature/ChiSqSelectorSuite.scala | 10 ++-- .../feature/ElementwiseProductSuite.scala | 5 +- .../spark/mllib/feature/HashingTFSuite.scala | 7 +-- .../spark/mllib/feature/NormalizerSuite.scala | 3 +- .../apache/spark/mllib/feature/PCASuite.scala | 3 +- .../mllib/feature/StandardScalerSuite.scala | 13 ++--- .../spark/mllib/fpm/PrefixSpanSuite.scala | 5 +- .../spark/mllib/linalg/VectorsSuite.scala | 8 +-- .../linalg/distributed/RowMatrixSuite.scala | 3 +- .../spark/mllib/recommendation/ALSSuite.scala | 5 +- .../regression/IsotonicRegressionSuite.scala | 3 +- .../spark/mllib/regression/LassoSuite.scala | 7 ++- .../regression/LinearRegressionSuite.scala | 10 ++-- .../regression/RidgeRegressionSuite.scala | 7 ++- .../spark/mllib/stat/CorrelationSuite.scala | 11 ++-- .../mllib/stat/HypothesisTestSuite.scala | 11 ++-- .../spark/mllib/stat/KernelDensitySuite.scala | 5 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 51 ++++++++++--------- .../tree/GradientBoostedTreesSuite.scala | 18 ++++--- .../spark/mllib/tree/ImpuritySuite.scala | 5 +- .../spark/mllib/tree/RandomForestSuite.scala | 17 ++++--- .../mllib/util/MLlibTestSparkContext.scala | 3 +- .../k8s/features/LocalDirsFeatureStep.scala | 3 +- .../KubernetesClusterSchedulerBackend.scala | 3 +- ...yPreferredContainerPlacementStrategy.scala | 3 +- .../YarnCoarseGrainedExecutorBackend.scala | 3 +- .../ContainerPlacementStrategySuite.scala | 15 +++--- .../deploy/yarn/YarnAllocatorSuite.scala | 32 ++++++------ .../sql/catalyst/JavaTypeInference.scala | 
3 +- .../analysis/alreadyExistException.scala | 3 +- .../analysis/noSuchItemsExceptions.scala | 3 +- .../sql/catalyst/encoders/RowEncoder.scala | 3 +- .../spark/sql/catalyst/expressions/rows.scala | 3 +- .../spark/sql/catalyst/trees/origin.scala | 3 +- .../spark/sql/catalyst/util/StringUtils.scala | 3 +- .../apache/spark/sql/util/ArrowUtils.scala | 3 +- .../sql/catalyst/CatalystTypeConverters.scala | 5 +- .../spark/sql/catalyst/InternalRow.scala | 3 +- .../sql/catalyst/analysis/Analyzer.scala | 8 +-- .../catalyst/analysis/AssignmentUtils.scala | 3 +- .../CannotReplaceMissingTableException.scala | 3 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 ++- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../catalyst/analysis/ResolveCatalogs.scala | 16 ++++-- .../analysis/ResolvePartitionSpec.scala | 4 +- .../analysis/RewriteRowLevelCommand.scala | 5 +- .../catalyst/analysis/v2ResolutionPlans.scala | 3 +- .../sql/catalyst/catalog/interface.scala | 3 +- .../sql/catalyst/csv/UnivocityGenerator.scala | 3 +- .../expressions/ApplyFunctionExpression.scala | 3 +- .../expressions/CallMethodViaReflection.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 3 +- .../InterpretedUnsafeProjection.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 7 +-- .../catalyst/expressions/ToStringBase.scala | 4 +- .../expressions/V2ExpressionUtils.scala | 5 +- .../aggregate/ApproximatePercentile.scala | 5 +- .../aggregate/CentralMomentAgg.scala | 4 +- .../catalyst/expressions/aggregate/Corr.scala | 3 +- .../expressions/aggregate/Covariance.scala | 3 +- .../expressions/aggregate/percentiles.scala | 3 +- .../codegen/GenerateOrdering.scala | 6 ++- .../expressions/complexTypeCreator.scala | 3 +- .../spark/sql/catalyst/expressions/hash.scala | 5 +- .../sql/catalyst/expressions/literals.scala | 4 +- .../expressions/objects/objects.scala | 5 +- .../sql/catalyst/expressions/predicates.scala | 3 +- .../spark/sql/catalyst/expressions/rows.scala | 3 +- .../expressions/stringExpressions.scala | 5 +- .../sql/catalyst/json/JacksonGenerator.scala | 3 +- .../spark/sql/catalyst/json/JsonFilters.scala | 3 +- .../sql/catalyst/optimizer/expressions.scala | 28 +++++----- .../sql/catalyst/parser/AstBuilder.scala | 3 +- .../plans/logical/basicLogicalOperators.scala | 8 +-- .../plans/logical/v2AlterTableCommands.scala | 3 +- .../catalyst/plans/logical/v2Commands.scala | 8 +-- .../spark/sql/catalyst/trees/TreeNode.scala | 4 +- .../sql/catalyst/types/PhysicalDataType.scala | 3 +- .../sql/catalyst/util/CharVarcharUtils.scala | 3 +- .../sql/catalyst/util/QuantileSummaries.scala | 7 +-- .../util/ResolveDefaultColumnsUtil.scala | 3 +- .../catalog/CatalogV2Implicits.scala | 3 +- .../sql/connector/catalog/CatalogV2Util.scala | 17 ++++--- .../distributions/distributions.scala | 5 +- .../connector/expressions/expressions.scala | 12 +++-- .../sql/errors/QueryCompilationErrors.scala | 3 +- .../apache/spark/sql/util/SchemaUtils.scala | 6 ++- .../spark/sql/RandomDataGenerator.scala | 3 +- .../scala/org/apache/spark/sql/RowTest.scala | 3 +- ...eateTablePartitioningValidationSuite.scala | 15 +++--- .../expressions/HashExpressionsSuite.scala | 5 +- .../expressions/MutableProjectionSuite.scala | 4 +- .../catalyst/expressions/OrderingSuite.scala | 5 +- .../expressions/UnsafeRowConverterSuite.scala | 20 ++++---- ...ApproxCountDistinctForIntervalsSuite.scala | 11 ++-- .../codegen/GeneratedProjectionSuite.scala | 7 +-- .../BasicStatsEstimationSuite.scala | 4 +- .../util/GenericArrayDataBenchmark.scala | 3 +- 
.../InMemoryAtomicPartitionTable.scala | 3 +- .../connector/catalog/InMemoryBaseTable.scala | 8 +-- .../sql/connector/catalog/InMemoryTable.scala | 6 ++- .../catalog/InMemoryTableWithV2Filter.scala | 9 ++-- .../expressions/TransformExtractorSuite.scala | 7 +-- .../scala/org/apache/spark/sql/Column.scala | 5 +- .../spark/sql/DataFrameStatFunctions.scala | 10 ++-- .../scala/org/apache/spark/sql/Dataset.scala | 13 ++--- .../org/apache/spark/sql/Observation.scala | 3 +- .../org/apache/spark/sql/SparkSession.scala | 5 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 5 +- .../analysis/ResolveSessionCatalog.scala | 4 +- .../spark/sql/execution/CacheManager.scala | 3 +- .../execution/ColumnarEvaluatorFactory.scala | 5 +- .../sql/execution/CommandResultExec.scala | 5 +- .../sql/execution/DataSourceScanExec.scala | 15 +++--- .../spark/sql/execution/HiveResult.scala | 9 ++-- .../sql/execution/LocalTableScanExec.scala | 5 +- .../spark/sql/execution/QueryExecution.scala | 3 +- .../aggregate/AggregationIterator.scala | 9 ++-- .../aggregate/HashAggregateExec.scala | 4 +- .../aggregate/ObjectAggregationIterator.scala | 8 +-- .../TungstenAggregationIterator.scala | 9 ++-- .../sql/execution/arrow/ArrowConverters.scala | 7 ++- .../execution/columnar/InMemoryRelation.scala | 3 +- .../sql/execution/command/commands.scala | 9 ++-- .../spark/sql/execution/command/ddl.scala | 5 +- .../spark/sql/execution/command/tables.scala | 3 +- .../spark/sql/execution/command/views.scala | 8 +-- .../execution/datasources/DataSource.scala | 3 +- .../datasources/FileFormatDataWriter.scala | 9 ++-- .../datasources/FileFormatWriter.scala | 9 ++-- .../sql/execution/datasources/FileIndex.scala | 3 +- .../execution/datasources/FileScanRDD.scala | 3 +- .../datasources/FileSourceStrategy.scala | 3 +- .../PartitioningAwareFileIndex.scala | 3 +- .../datasources/PartitioningUtils.scala | 3 +- .../execution/datasources/WriteFiles.scala | 3 +- .../datasources/jdbc/JdbcUtils.scala | 4 +- .../datasources/orc/OrcFiltersBase.scala | 5 +- .../execution/datasources/orc/OrcUtils.scala | 3 +- .../datasources/parquet/ParquetColumn.scala | 5 +- .../datasources/parquet/ParquetUtils.scala | 9 ++-- .../sql/execution/datasources/rules.scala | 6 ++- .../datasources/v2/BatchScanExec.scala | 11 ++-- .../datasources/v2/ContinuousScanExec.scala | 4 +- .../datasources/v2/DataSourceRDD.scala | 6 ++- .../v2/DataSourceV2ScanExecBase.scala | 3 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../datasources/v2/DescribeTableExec.scala | 3 +- .../datasources/v2/DropTableExec.scala | 3 +- .../datasources/v2/FileBatchWrite.scala | 7 ++- .../execution/datasources/v2/FileWrite.scala | 4 +- .../datasources/v2/MicroBatchScanExec.scala | 4 +- .../OptimizeMetadataOnlyDeleteFromTable.scala | 3 +- .../datasources/v2/PushDownUtils.scala | 9 ++-- .../datasources/v2/ShowFunctionsExec.scala | 8 +-- .../datasources/v2/ShowPartitionsExec.scala | 3 +- .../v2/V2ScanPartitioningAndOrdering.scala | 4 +- .../v2/V2ScanRelationPushDown.scala | 3 +- .../datasources/v2/V2SessionCatalog.scala | 7 +-- .../v2/WriteToDataSourceV2Exec.scala | 9 ++-- .../datasources/v2/csv/CSVScan.scala | 6 ++- .../datasources/v2/jdbc/JDBCScan.scala | 7 +-- .../datasources/v2/json/JsonScan.scala | 4 +- .../v2/orc/OrcPartitionReaderFactory.scala | 3 +- .../datasources/v2/orc/OrcScan.scala | 7 +-- .../datasources/v2/orc/OrcScanBuilder.scala | 3 +- .../datasources/v2/parquet/ParquetScan.scala | 7 +-- .../v2/parquet/ParquetScanBuilder.scala | 3 +- .../dynamicpruning/PartitionPruning.scala | 4 +- 
...wLevelOperationRuntimeGroupFiltering.scala | 3 +- .../joins/BroadcastNestedLoopJoinExec.scala | 3 +- .../apache/spark/sql/execution/limit.scala | 3 +- .../PythonDataSourcePartitionsExec.scala | 3 +- .../python/UserDefinedPythonDataSource.scala | 3 +- .../WindowInPandasEvaluatorFactory.scala | 6 ++- .../sql/execution/stat/StatFunctions.scala | 5 +- .../streaming/FileStreamSource.scala | 3 +- .../streaming/FileStreamSourceLog.scala | 3 +- .../execution/streaming/HDFSMetadataLog.scala | 3 +- .../continuous/ContinuousDataSourceRDD.scala | 6 ++- .../continuous/ContinuousWriteRDD.scala | 7 ++- .../sources/ConsoleStreamingWrite.scala | 3 +- .../state/HDFSBackedStateStoreProvider.scala | 3 +- .../streaming/state/RocksDBFileManager.scala | 4 +- .../spark/sql/internal/CatalogImpl.scala | 9 ++-- .../ui/StreamingQueryStatisticsPage.scala | 24 +++++---- .../spark/sql/ColumnExpressionSuite.scala | 5 +- .../sql/DataFrameSetOperationsSuite.scala | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 5 +- .../spark/sql/DataFrameWriterV2Suite.scala | 20 ++++---- .../org/apache/spark/sql/DatasetSuite.scala | 3 +- .../spark/sql/DatasetUnpivotSuite.scala | 8 +-- .../org/apache/spark/sql/GenTPCDSData.scala | 3 +- .../spark/sql/IntegratedUDFTestUtils.scala | 13 ++--- .../org/apache/spark/sql/QueryTest.scala | 13 ++++- .../apache/spark/sql/SQLQueryTestSuite.scala | 9 ++-- .../sql/connector/DataSourceV2Suite.scala | 3 +- .../RowLevelOperationSuiteBase.scala | 3 +- .../sql/connector/V1WriteFallbackSuite.scala | 5 +- .../V2CommandsCaseSensitivitySuite.scala | 27 ++++++---- .../BaseScriptTransformationSuite.scala | 41 +++++++-------- .../CoalesceShufflePartitionsSuite.scala | 11 ++-- .../execution/RowToColumnConverterSuite.scala | 3 +- .../sql/execution/SQLViewTestSuite.scala | 3 +- .../ShufflePartitionsUtilSuite.scala | 37 +++++++------- .../sql/execution/SparkSqlParserSuite.scala | 15 +++--- .../adaptive/AdaptiveQueryExecSuite.scala | 3 +- .../BuiltInDataSourceWriteBenchmark.scala | 3 +- .../command/ShowFunctionsSuiteBase.scala | 3 +- .../command/v2/CommandSuiteBase.scala | 4 +- .../command/v2/ShowFunctionsSuite.scala | 3 +- .../datasources/json/JsonSuite.scala | 5 +- .../datasources/orc/OrcFilterSuite.scala | 19 +++---- .../execution/datasources/orc/OrcTest.scala | 4 +- ...rquetFileMetadataStructRowIndexSuite.scala | 3 +- .../parquet/ParquetFilterSuite.scala | 34 ++++++++----- .../parquet/ParquetRowIndexSuite.scala | 3 +- .../parquet/ParquetVectorizedSuite.scala | 3 +- .../execution/joins/HashedRelationSuite.scala | 3 +- .../streaming/FileStreamSinkLogSuite.scala | 3 +- .../sources/ForeachBatchSinkSuite.scala | 4 +- .../sources/ForeachWriterSuite.scala | 3 +- .../streaming/state/RocksDBSuite.scala | 4 ++ .../vectorized/ColumnVectorSuite.scala | 15 +++--- .../vectorized/ColumnarBatchSuite.scala | 7 +-- .../spark/sql/jdbc/JDBCWriteSuite.scala | 9 ++-- .../spark/sql/sources/FilteredScanSuite.scala | 3 +- .../sql/streaming/FileStreamSinkSuite.scala | 3 +- .../sql/streaming/FileStreamSourceSuite.scala | 3 +- .../sql/streaming/StateStoreMetricsTest.scala | 11 ++-- .../streaming/StreamingAggregationSuite.scala | 3 +- ...StreamingQueryStatusAndProgressSuite.scala | 3 +- .../apache/spark/sql/test/SQLTestUtils.scala | 3 +- .../execution/HiveCompatibilitySuite.scala | 3 +- .../apache/spark/sql/hive/TableReader.scala | 3 +- .../hive/client/IsolatedClientLoader.scala | 3 +- .../hive/execution/V1WritesHiveUtils.scala | 3 +- .../spark/sql/hive/CachedTableSuite.scala | 11 ++-- .../HiveExternalCatalogVersionsSuite.scala | 3 +- 
.../sql/hive/HiveUDFDynamicLoadSuite.scala | 13 ++--- .../HiveScriptTransformationSuite.scala | 30 ++++++----- .../sql/sources/HadoopFsRelationTest.scala | 3 +- .../sql/sources/SimpleTextRelation.scala | 3 +- .../apache/spark/streaming/Checkpoint.scala | 3 +- .../streaming/dstream/FileInputDStream.scala | 5 +- .../rdd/WriteAheadLogBackedBlockRDD.scala | 2 + .../streaming/scheduler/ReceiverTracker.scala | 5 +- .../util/FileBasedWriteAheadLog.scala | 3 +- .../spark/streaming/ReceiverSuite.scala | 2 + .../spark/streaming/TestSuiteBase.scala | 5 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 5 +- .../streaming/util/WriteAheadLogSuite.scala | 5 +- .../spark/tools/GenerateMIMAIgnore.scala | 6 ++- 417 files changed, 1709 insertions(+), 1069 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala index f71ea873ab2c9..854c5581ee48f 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala @@ -37,6 +37,7 @@ import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBibl import org.apache.spark.SparkException import org.apache.spark.internal.Logging +import org.apache.spark.util.ArrayImplicits._ /** Provides utility functions to be used inside SparkSubmit. */ private[spark] object MavenUtils extends Logging { @@ -113,7 +114,7 @@ private[spark] object MavenUtils extends Logging { s"The version cannot be null or " + s"be whitespace. The version provided is: ${splits(2)}") new MavenCoordinate(splits(0), splits(1), splits(2)) - } + }.toImmutableArraySeq } /** Path of the local Maven cache. */ @@ -222,7 +223,7 @@ private[spark] object MavenUtils extends Logging { } cacheDirectory.getAbsolutePath + File.separator + s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}$classifier.jar" - } + }.toImmutableArraySeq } /** Adds the given maven coordinates to Ivy's module descriptor. 
*/ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index a9f3bb78658ca..cf2cc1f0aefcd 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SerializableConfiguration case class AvroScan( @@ -59,7 +60,7 @@ case class AvroScan( readDataSchema, readPartitionSchema, parsedOptions, - pushedFilters) + pushedFilters.toImmutableArraySeq) } override def equals(obj: Any): Boolean = obj match { @@ -71,6 +72,6 @@ case class AvroScan( override def hashCode(): Int = super.hashCode() override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters.toImmutableArraySeq)) } } diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 250b5e0615ad8..633bbce8df801 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.sources.{EqualTo, Not} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ class AvroCatalystDataConversionSuite extends SparkFunSuite with SharedSparkSession @@ -90,7 +91,8 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite // Spark byte and short both map to avro int case b: Byte => b.toInt case s: Short => s.toInt - case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case row: GenericInternalRow => + InternalRow.fromSeq(row.values.map(prepareExpectedResult).toImmutableArraySeq) case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) case map: MapData => val keys = new GenericArrayData( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala index adbd8286090d9..54fc97c50b3ec 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala @@ -23,6 +23,7 @@ import org.scalatest.Assertions import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.util.SparkStringUtils.sideBySide +import org.apache.spark.util.ArrayImplicits._ abstract class QueryTest extends RemoteSparkSession { @@ -43,7 +44,11 @@ abstract class QueryTest extends RemoteSparkSession { } protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { - checkAnswer(df, expectedAnswer.collect()) + checkAnswer(df, expectedAnswer.collect().toImmutableArraySeq) + } + + protected def checkAnswer(df: => DataFrame, 
expectedAnswer: Array[Row]): Unit = { + checkAnswer(df, expectedAnswer.toImmutableArraySeq) } /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala index 8a8f739a7c502..172efb7db7c50 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryPolicy import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.common.config.ConnectCommon import org.apache.spark.sql.test.IntegrationTestUtils._ +import org.apache.spark.util.ArrayImplicits._ /** * An util class to start a local spark connect server in a different process for local E2E tests. @@ -175,7 +176,7 @@ object SparkConnectServerUtils { (fileName.startsWith("scalatest") || fileName.startsWith("scalactic")) } .map(e => Paths.get(e).toUri) - spark.client.artifactManager.addArtifacts(jars) + spark.client.artifactManager.addArtifacts(jars.toImmutableArraySeq) } def createSparkSession(): SparkSession = { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala index 00fba781813e9..7a6eb963cb33b 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -39,6 +39,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.AddArtifactsResponse import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary import org.apache.spark.util.{MavenUtils, SparkFileUtils, SparkThreadUtils} +import org.apache.spark.util.ArrayImplicits._ /** * The Artifact Manager is responsible for handling and transferring artifacts from the local @@ -392,7 +393,7 @@ object Artifact { val exclusionsList: Seq[String] = if (!StringUtils.isBlank(exclusions)) { - exclusions.split(",") + exclusions.split(",").toImmutableArraySeq } else { Nil } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 4925bc0a5dc16..6545138578249 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -938,7 +938,7 @@ class SparkConnectPlanner( private def transformPythonTableFunction(fun: proto.PythonUDTF): SimplePythonFunction = { SimplePythonFunction( - command = fun.getCommand.toByteArray, + command = fun.getCommand.toByteArray.toImmutableArraySeq, // Empty environment variables envVars = Maps.newHashMap(), pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, @@ -1030,7 +1030,7 @@ class SparkConnectPlanner( if (!rel.hasValues) { Unpivot( - Some(ids.map(_.named)), + Some(ids.map(_.named).toImmutableArraySeq), None, None, rel.getVariableColumnName, @@ -1042,8 +1042,8 @@ class SparkConnectPlanner( } Unpivot( - Some(ids.map(_.named)), - Some(values.map(v => 
Seq(v.named))), + Some(ids.map(_.named).toImmutableArraySeq), + Some(values.map(v => Seq(v.named)).toImmutableArraySeq), None, rel.getVariableColumnName, Seq(rel.getValueColumnName), @@ -1624,7 +1624,7 @@ class SparkConnectPlanner( private def transformPythonFunction(fun: proto.PythonUDF): SimplePythonFunction = { SimplePythonFunction( - command = fun.getCommand.toByteArray, + command = fun.getCommand.toByteArray.toImmutableArraySeq, // Empty environment variables envVars = Maps.newHashMap(), pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala index b8df0ed21d107..c1b1bacba3b6d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala @@ -24,6 +24,7 @@ import io.grpc.netty.NettyServerBuilder import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -64,7 +65,8 @@ object SparkConnectInterceptorRegistry { .map(_.trim) .filter(_.nonEmpty) .map(Utils.classForName[ServerInterceptor](_)) - .map(createInstance(_)) + .map(createInstance) + .toImmutableArraySeq } else { Seq.empty } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala index 9845cee31037c..2eaa8c8383e3d 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, StreamingForeachBatchHelper} import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.ArrayImplicits._ class SparkConnectSessionHolderSuite extends SharedSparkSession { @@ -160,7 +161,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { s"${IntegratedUDFTestUtils.pysparkPythonPath}:${IntegratedUDFTestUtils.pythonPath}" SimplePythonFunction( - command = fcn(sparkPythonPath), + command = fcn(sparkPythonPath).toImmutableArraySeq, envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava, pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, pythonExec = IntegratedUDFTestUtils.pythonExec, diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 3287761b1f5d4..e92ebecfce08c 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -34,6 +34,7 @@ import 
org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.kafka010.MockedSystemClock.currentMockSystemTime import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A [[MicroBatchStream]] that reads data from Kafka. @@ -131,7 +132,7 @@ private[kafka010] class KafkaMicroBatchStream( } val limits: Seq[ReadLimit] = readLimit match { - case rows: CompositeReadLimit => rows.getReadLimits + case rows: CompositeReadLimit => rows.getReadLimits.toImmutableArraySeq case rows => Seq(rows) } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala index 7c4c35998e4f2..e13c79625cc55 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala @@ -34,6 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.kafka010.KafkaSourceProvider.StrategyOnNoMatchStartingOffset +import org.apache.spark.util.ArrayImplicits._ /** * This class uses Kafka's own [[Admin]] API to read data offsets from Kafka. @@ -488,7 +489,7 @@ private[kafka010] class KafkaOffsetReaderAdmin( } KafkaOffsetRange(tp, fromOffset, untilOffset, preferredLoc = None) } - rangeCalculator.getRanges(ranges, getSortedExecutorList) + rangeCalculator.getRanges(ranges, getSortedExecutorList.toImmutableArraySeq) } private def partitionsAssignedToAdmin( diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala index e83a2c9e8f5d7..c4ac4c7d57db2 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala @@ -32,6 +32,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.kafka010.KafkaSourceProvider.StrategyOnNoMatchStartingOffset import org.apache.spark.util.{UninterruptibleThread, UninterruptibleThreadRunner} +import org.apache.spark.util.ArrayImplicits._ /** * This class uses Kafka's own [[org.apache.kafka.clients.consumer.KafkaConsumer]] API to @@ -535,7 +536,7 @@ private[kafka010] class KafkaOffsetReaderConsumer( } KafkaOffsetRange(tp, fromOffset, untilOffset, preferredLoc = None) } - rangeCalculator.getRanges(ranges, getSortedExecutorList()) + rangeCalculator.getRanges(ranges, getSortedExecutorList().toImmutableArraySeq) } private def partitionsAssignedToConsumer( diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index f5d4abb569a31..6f6d6319cd6f9 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.streaming._ 
import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.types._ import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A [[Source]] that reads data from Kafka using the following design. @@ -180,7 +181,7 @@ private[kafka010] class KafkaSource( latestPartitionOffsets = if (latest.isEmpty) None else Some(latest) val limits: Seq[ReadLimit] = limit match { - case rows: CompositeReadLimit => rows.getReadLimits + case rows: CompositeReadLimit => rows.getReadLimits.toImmutableArraySeq case rows => Seq(rows) } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 9da1fe0280e57..f63f5e541e078 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ /** * The provider class for all Kafka readers and writers. It is designed such that it throws @@ -204,7 +205,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister case (ASSIGN, value) => AssignStrategy(JsonUtils.partitions(value)) case (SUBSCRIBE, value) => - SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty)) + SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty).toImmutableArraySeq) case (SUBSCRIBE_PATTERN, value) => SubscribePatternStrategy(value.trim()) case _ => diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 54e05ff95c938..921fcde9ebaaf 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -2410,7 +2410,7 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest { private def sendMessagesWithTimestamp( topic: String, - msgs: Seq[String], + msgs: Array[String], part: Int, ts: Long): Unit = { val records = msgs.map { msg => diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 3a400c657bab3..6753f8be54bf2 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession} import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest { protected var testUtils: KafkaTestUtils = _ @@ -364,7 +365,7 @@ class KafkaContinuousSinkSuite extends KafkaSinkStreamingSuiteBase { try { val 
fieldTypes: Array[DataType] = Array(BinaryType) val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) row.update(0, data) val iter = Seq.fill(1000)(converter.apply(row)).iterator iter.foreach(writeTask.write(_)) @@ -580,7 +581,7 @@ class KafkaSinkBatchSuiteV2 extends KafkaSinkBatchSuiteBase { try { val fieldTypes: Array[DataType] = Array(BinaryType) val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) row.update(0, data) val iter = Seq.fill(1000)(converter.apply(row)).iterator writeTask.execute(iter) diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 1624f7320bb9b..d7e049629254b 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -54,6 +54,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.kafka010.KafkaTokenUtil import org.apache.spark.util.{SecurityUtils, ShutdownHookManager, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -440,6 +441,10 @@ class KafkaTestUtils( offsets } + def sendMessages(msgs: Array[ProducerRecord[String, String]]): Seq[(String, RecordMetadata)] = { + sendMessages(msgs.toImmutableArraySeq) + } + def cleanupLogs(): Unit = { server.logManager.cleanupLogs() } diff --git a/connector/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/connector/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index b8dbfe2fcf048..286b073125ff0 100644 --- a/connector/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/connector/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -29,6 +29,7 @@ import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ /** * A batch-oriented interface for consuming from Kafka. 
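
Note on the conversion repeated throughout these hunks: an Array coming from a Java- or Hadoop-facing API is wrapped as an immutable Seq before it reaches Scala code that expects one. As a point of reference only, here is a standalone Scala 2.13 sketch of the two ways such a wrap can be done; it is not Spark's private ArrayImplicits helper, whose exact behaviour is assumed rather than shown by this patch.

import scala.collection.immutable.ArraySeq

object WrapVsCopy {
  def main(args: Array[String]): Unit = {
    val raw: Array[String] = Array("gpu0", "gpu1", "gpu2")

    // O(1) wrap: no copy, the Seq is backed by the original array.
    val wrapped: Seq[String] = ArraySeq.unsafeWrapArray(raw)

    // O(n) copy: isolated from later mutation of the array.
    val copied: Seq[String] = raw.toIndexedSeq

    raw(0) = "fpga0"
    println(wrapped.head) // "fpga0" -- the wrapper sees the mutation
    println(copied.head)  // "gpu0"  -- the copy does not
  }
}

Whether the helper used in this patch wraps or copies is not visible here; the sketch only illustrates the trade-off between the two options.
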
@@ -135,7 +136,7 @@ private[spark] class KafkaRDD[K, V]( context.runJob( this, (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => - it.take(parts(tc.partitionId())).toArray, parts.keys.toArray + it.take(parts(tc.partitionId())).toArray, parts.keys.toArray.toImmutableArraySeq ).flatten } } diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 7824daea8319a..4ddf1d9993ed6 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -33,6 +33,7 @@ import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration import org.apache.spark.streaming.kinesis.KinesisInitialPositions.AtTimestamp import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -287,7 +288,8 @@ private[kinesis] class KinesisReceiver[T]( * for next block. Internally, this is synchronized with `rememberAddedRange()`. */ private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = { - blockIdToSeqNumRanges.put(blockId, SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray)) + blockIdToSeqNumRanges.put(blockId, + SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray.toImmutableArraySeq)) seqNumRangesInCurrentBlock.clear() logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges") } diff --git a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index c9d1498a5a46c..7c6cf7c3a227b 100644 --- a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.util.ArrayImplicits._ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) extends KinesisFunSuite with BeforeAndAfterEach with LocalSparkContext { @@ -79,21 +80,21 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) // Verify all data using multiple ranges in a single RDD partition val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, testUtils.endpointUrl, fakeBlockIds(1), - Array(SequenceNumberRanges(allRanges.toArray)) + Array(SequenceNumberRanges(allRanges.toArray.toImmutableArraySeq)) ).map { bytes => new String(bytes).toInt }.collect() assert(receivedData1.toSet === testData.toSet) // Verify all data using one range in each of the multiple RDD partitions val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, testUtils.endpointUrl, fakeBlockIds(allRanges.size), - allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + allRanges.map { range => SequenceNumberRanges(Array(range).toImmutableArraySeq) }.toArray ).map { bytes => new String(bytes).toInt }.collect() assert(receivedData2.toSet === testData.toSet) // Verify 
ordering within each partition val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, testUtils.endpointUrl, fakeBlockIds(allRanges.size), - allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + allRanges.map { range => SequenceNumberRanges(Array(range).toImmutableArraySeq) }.toArray ).map { bytes => new String(bytes).toInt }.collectPartitions() assert(receivedData3.length === allRanges.size) for (i <- allRanges.indices) { @@ -184,7 +185,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 1))) val realRanges = Array.tabulate(numPartitionsInKinesis) { i => val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) - SequenceNumberRanges(Array(range)) + SequenceNumberRanges(Array(range).toImmutableArraySeq) } val ranges = (fakeRanges ++ realRanges) diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala index b7f17fece5fa6..02526adb33707 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.sources.{EqualTo, Not} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ class ProtobufCatalystDataConversionSuite extends SparkFunSuite @@ -99,7 +100,8 @@ class ProtobufCatalystDataConversionSuite // Spark byte and short both map to Protobuf int case b: Byte => b.toInt case s: Short => s.toInt - case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case row: GenericInternalRow => + InternalRow.fromSeq(row.values.map(prepareExpectedResult).toImmutableArraySeq) case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) case map: MapData => val keys = new GenericArrayData( diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 5b587d7fbadbb..957408ac24f75 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.scheduler.ExecutorDecommissionInfo +import org.apache.spark.util.ArrayImplicits._ /** * A client that communicates with the cluster manager to request or kill executors. 
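
Several call sites in this patch share one shape: the primary method signature takes a Seq, and an Array-accepting entry point merely wraps its argument and delegates. A hypothetical, self-contained illustration of that shape follows; the trait and method names are invented for the example and are not Spark APIs.

import scala.collection.immutable.ArraySeq

// Invented client: the Seq overload carries the real logic,
// the Array overload only adapts and forwards.
trait ShutdownClient {
  // Primary entry point: works on an immutable Seq of executor IDs.
  def killExecutors(ids: Seq[String]): Seq[String]

  // Convenience overload for callers that naturally hold an Array.
  def killExecutors(ids: Array[String]): Seq[String] =
    killExecutors(ArraySeq.unsafeWrapArray(ids))
}

object ShutdownClientDemo {
  def main(args: Array[String]): Unit = {
    val client = new ShutdownClient {
      def killExecutors(ids: Seq[String]): Seq[String] =
        ids.filter(_.startsWith("exec-"))
    }
    // Prints the surviving executor IDs from the Array-based call.
    println(client.killExecutors(Array("exec-1", "driver", "exec-2")))
  }
}

Keeping the conversion at the overload boundary means the body of the Seq method never has to care which collection the caller started from.
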
@@ -98,7 +99,7 @@ private[spark] trait ExecutorAllocationClient { executorsAndDecomInfo: Array[(String, ExecutorDecommissionInfo)], adjustTargetNumExecutors: Boolean, triggeredByExecutor: Boolean): Seq[String] = { - killExecutors(executorsAndDecomInfo.map(_._1), + killExecutors(executorsAndDecomInfo.map(_._1).toImmutableArraySeq, adjustTargetNumExecutors, countFailures = false) } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index b8fd2700771c3..af33ef2415783 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -31,6 +31,7 @@ import org.apache.spark.internal.config.History._ import org.apache.spark.internal.config.Kryo._ import org.apache.spark.internal.config.Network._ import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -440,7 +441,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** Get all executor environment variables set on this SparkConf */ def getExecutorEnv: Seq[(String, String)] = { - getAllWithPrefix("spark.executorEnv.") + getAllWithPrefix("spark.executorEnv.").toImmutableArraySeq } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c86f755bbd16a..73dcaffa6ceff 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -71,6 +71,7 @@ import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.{TriggerHeapHistogram, TriggerThreadDump} import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} import org.apache.spark.util._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.logging.DriverLogger /** @@ -2817,7 +2818,7 @@ class SparkContext(config: SparkConf) extends Logging { val driverUpdates = new HashMap[(Int, Int), ExecutorMetrics] // In the driver, we do not track per-stage metrics, so use a dummy stage for the key driverUpdates.put(EventLoggingListener.DRIVER_STAGE_KEY, new ExecutorMetrics(currentMetrics)) - val accumUpdates = new Array[(Long, Int, Int, Seq[AccumulableInfo])](0) + val accumUpdates = new Array[(Long, Int, Int, Seq[AccumulableInfo])](0).toImmutableArraySeq listenerBus.post(SparkListenerExecutorMetricsUpdate("driver", accumUpdates, driverUpdates)) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index c2bae41d34eec..3277f86e36710 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -46,6 +46,7 @@ import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManage import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * :: DeveloperApi :: @@ -533,7 +534,7 @@ object SparkEnv extends Logging { .map(entry => (entry.getKey, entry.getValue)).toSeq.sorted Map[String, Seq[(String, String)]]( "JVM Information" -> jvmInformation, - "Spark Properties" -> sparkProperties, + "Spark Properties" -> sparkProperties.toImmutableArraySeq, "Hadoop Properties" -> hadoopProperties, "System Properties" -> otherProperties, "Classpath Entries" -> classPaths, diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala 
b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index af1cc127bcad2..9bad4d9e163df 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -35,6 +35,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -375,7 +376,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. - val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) + val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds.toImmutableArraySeq) res.map(_.toSeq.asJava) } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index b60d90275cb70..ff6ed9f86b554 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -30,6 +30,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.security.SocketAuthServer +import org.apache.spark.util.ArrayImplicits._ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( parent: RDD[T], @@ -149,7 +150,7 @@ private[spark] object RRDD { * called from R. */ def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { - JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) + JavaRDD.fromRDD(jsc.sc.parallelize(arr.toImmutableArraySeq, arr.length)) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index e9482d1581400..13a67a794c83f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -36,6 +36,7 @@ import org.apache.spark.internal.config.Network.RPC_ASK_TIMEOUT import org.apache.spark.resource.ResourceUtils import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.util.{SparkExitCode, ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * Proxy that relays messages to the driver. @@ -290,7 +291,8 @@ private[spark] class ClientApp extends SparkApplication { val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). 
map(rpcEnv.setupEndpointRef(_, Master.ENDPOINT_NAME)) - rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) + rpcEnv.setupEndpoint("client", + new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints.toImmutableArraySeq, conf)) rpcEnv.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 0bb3dab62ed45..dab880b37dc22 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -39,6 +39,7 @@ import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdenti import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.BUFFER_SIZE +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -219,7 +220,7 @@ private[spark] class SparkHadoopUtil extends Logging { def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { def recurse(status: FileStatus): Seq[FileStatus] = { val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDirectory) - leaves ++ directories.flatMap(f => listLeafStatuses(fs, f)) + (leaves ++ directories.flatMap(f => listLeafStatuses(fs, f))).toImmutableArraySeq } if (baseStatus.isDirectory) recurse(baseStatus) else Seq(baseStatus) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 30b542eefb60b..29bb8f84ee6cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -45,6 +45,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.UI._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util._ +import org.apache.spark.util.ArrayImplicits._ /** * Whether to submit, kill, or request the status of an application. 
@@ -84,7 +85,7 @@ private[spark] class SparkSubmit extends Logging { } protected def parseArguments(args: Array[String]): SparkSubmitArguments = { - new SparkSubmitArguments(args) + new SparkSubmitArguments(args.toImmutableArraySeq) } /** @@ -1057,7 +1058,7 @@ object SparkSubmit extends CommandLineUtils with Logging { self => override protected def parseArguments(args: Array[String]): SparkSubmitArguments = { - new SparkSubmitArguments(args) { + new SparkSubmitArguments(args.toImmutableArraySeq) { override protected def logInfo(msg: => String): Unit = self.logInfo(msg) override protected def logWarning(msg: => String): Unit = self.logWarning(msg) diff --git a/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala b/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala index 641c5416cbb33..db9f5c516cf65 100644 --- a/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/StandaloneResourceUtils.scala @@ -29,6 +29,7 @@ import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.resource.{ResourceAllocation, ResourceID, ResourceInformation, ResourceRequirement} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils private[spark] object StandaloneResourceUtils extends Logging { @@ -96,7 +97,7 @@ private[spark] object StandaloneResourceUtils extends Logging { val compShortName = componentName.substring(componentName.lastIndexOf(".") + 1) val tmpFile = Utils.tempFileWith(dir) val allocations = resources.map { case (rName, rInfo) => - ResourceAllocation(new ResourceID(componentName, rName), rInfo.addresses) + ResourceAllocation(new ResourceID(componentName, rName), rInfo.addresses.toImmutableArraySeq) }.toSeq try { writeResourceAllocationJson(allocations, tmpFile) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index d8bd16e970f24..e76d72194c39f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -50,6 +50,7 @@ import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.kvstore._ /** @@ -316,7 +317,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Split a comma separated String, filter out any empty items, and return a Sequence of strings */ private def stringToSeq(list: String): Seq[String] = { - list.split(',').map(_.trim).filter(!_.isEmpty) + list.split(',').map(_.trim).filter(_.nonEmpty).toImmutableArraySeq } override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index ba949e2630e43..ea776c6dd2ae2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import 
org.apache.spark.internal.Logging import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -53,7 +54,7 @@ private[master] class FileSystemPersistenceEngine( override def read[T: ClassTag](prefix: String): Seq[T] = { val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) - files.map(deserializeFromFile[T]) + files.map(deserializeFromFile[T]).toImmutableArraySeq } private def serializeIntoFile(file: File, value: AnyRef): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b3fbec1830e45..976655f029a9d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -41,6 +41,7 @@ import org.apache.spark.resource.{ResourceProfile, ResourceRequirement, Resource import org.apache.spark.rpc._ import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ private[deploy] class Master( override val rpcEnv: RpcEnv, @@ -277,7 +278,8 @@ private[deploy] class Master( } else if (idToWorker.contains(id)) { workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, true)) } else { - val workerResources = resources.map(r => r._1 -> WorkerResourceInfo(r._1, r._2.addresses)) + val workerResources = + resources.map(r => r._1 -> WorkerResourceInfo(r._1, r._2.addresses.toImmutableArraySeq)) val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, workerRef, workerWebUiUrl, workerResources) if (registerWorker(worker)) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 0137e2be74720..4d475fa8a6e8f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.resource.{ResourceAllocator, ResourceInformation, ResourceRequirement} import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils private[spark] case class WorkerResourceInfo(name: String, addresses: Seq[String]) @@ -162,7 +163,7 @@ private[spark] class WorkerInfo( */ def recoverResources(expected: Map[String, ResourceInformation]): Unit = { expected.foreach { case (rName, rInfo) => - resources(rName).acquire(rInfo.addresses) + resources(rName).acquire(rInfo.addresses.toImmutableArraySeq) } } @@ -172,7 +173,7 @@ private[spark] class WorkerInfo( */ private def releaseResources(allocated: Map[String, ResourceInformation]): Unit = { allocated.foreach { case (rName, rInfo) => - resources(rName).release(rInfo.addresses) + resources(rName).release(rInfo.addresses.toImmutableArraySeq) } } } diff --git a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala index 463552fadb4b7..3120f69822830 100644 --- a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala +++ b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala @@ -28,6 +28,7 @@ import scala.util.Try import org.apache.spark.{SparkEnv, SparkException} import 
org.apache.spark.internal.{config, Logging} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -72,7 +73,7 @@ private[spark] class ProcfsMetricsGetter(procfsDir: String = "/proc/") extends L // This can be simplified in java9: // https://docs.oracle.com/javase/9/docs/api/java/lang/ProcessHandle.html val cmd = Array("bash", "-c", "echo $PPID") - val out = Utils.executeAndGetOutput(cmd) + val out = Utils.executeAndGetOutput(cmd.toImmutableArraySeq) Integer.parseInt(out.split("\n")(0)) } catch { @@ -90,7 +91,7 @@ private[spark] class ProcfsMetricsGetter(procfsDir: String = "/proc/") extends L } try { val cmd = Array("getconf", "PAGESIZE") - val out = Utils.executeAndGetOutput(cmd) + val out = Utils.executeAndGetOutput(cmd.toImmutableArraySeq) Integer.parseInt(out.split("\n")(0)) } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 0643ab3ea2e1a..a84fadcf965b2 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -36,6 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A helper object that saves an RDD using a Hadoop OutputFormat. @@ -97,7 +98,8 @@ object SparkHadoopWriter extends Logging { }) logInfo(s"Start to commit write Job ${jobContext.getJobID}.") - val (_, duration) = Utils.timeTakenMs { committer.commitJob(jobContext, ret) } + val (_, duration) = Utils + .timeTakenMs { committer.commitJob(jobContext, ret.toImmutableArraySeq) } logInfo(s"Write Job ${jobContext.getJobID} committed. 
Elapsed time: $duration ms.") } catch { case cause: Throwable => diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 4dc01674473e1..895e314b7c2f3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -38,12 +39,12 @@ private[spark] case class CoalescedRDDPartition( @transient rdd: RDD[_], parentsIndices: Array[Int], @transient preferredLocation: Option[String] = None) extends Partition { - var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)) + var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)).toImmutableArraySeq @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { // Update the reference to parent partition at the time of task serialization - parents = parentsIndices.map(rdd.partitions(_)) + parents = parentsIndices.map(rdd.partitions(_)).toImmutableArraySeq oos.defaultWriteObject() } @@ -103,7 +104,7 @@ private[spark] class CoalescedRDD[T: ClassTag]( override def getDependencies: Seq[Dependency[_]] = { Seq(new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = - partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices + partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices.toImmutableArraySeq }) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 9430a6973f13e..bea1f3093fd7b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -43,6 +43,7 @@ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownHookManager, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A Spark split class that wraps around a Hadoop InputSplit. 
@@ -399,7 +400,7 @@ class HadoopRDD[K, V]( HadoopRDD.convertSplitLocationInfo(lsplit.getLocationInfo) case _ => None } - locs.getOrElse(hsplit.getLocations.filter(_ != "localhost")) + locs.getOrElse(hsplit.getLocations.filter(_ != "localhost").toImmutableArraySeq) } override def checkpoint(): Unit = { @@ -482,6 +483,6 @@ private[spark] object HadoopRDD extends Logging { } else { None } - }) + }.toImmutableArraySeq) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala index 56f53714cbe3a..f9b2ffc068b06 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -51,7 +52,7 @@ private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient private val !SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i)) } if (missingPartitionIndices.nonEmpty) { - rdd.sparkContext.runJob(rdd, action, missingPartitionIndices) + rdd.sparkContext.runJob(rdd, action, missingPartitionIndices.toImmutableArraySeq) } new LocalCheckpointRDD[T](rdd) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f03bcaad80c1e..a12d4a40b1464 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -41,6 +41,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} +import org.apache.spark.util.ArrayImplicits._ private[spark] class NewHadoopPartition( rddId: Int, @@ -346,7 +347,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.convertSplitLocationInfo(split.getLocationInfo) - locs.getOrElse(split.getLocations.filter(_ != "localhost")) + locs.getOrElse(split.getLocations.filter(_ != "localhost").toImmutableArraySeq) } override def persist(storageLevel: StorageLevel): this.type = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index a40d6636ff061..f5a731d134eaf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -41,6 +41,7 @@ import org.apache.spark.internal.io._ import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -937,10 +938,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } buf.toSeq } : Seq[V] - val res = self.context.runJob(self, process, Array(index)) + val res = self.context.runJob(self, process, 
Array(index).toImmutableArraySeq) res(0) case None => - self.filter(_._1 == key).map(_._2).collect() + self.filter(_._1 == key).map(_._2).collect().toImmutableArraySeq } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 610b48ea2ba50..fe10e140f82de 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -45,6 +45,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.resource.ResourceProfile import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.{ExternalAppendOnlyMap, OpenHashMap, Utils => collectionUtils} @@ -1963,8 +1964,7 @@ abstract class RDD[T: ClassTag]( val storageInfo = rdd.context.getRDDStorageInfo(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), bytesToString(info.diskSize))) - - s"$rdd [$persistence]" +: storageInfo + (s"$rdd [$persistence]" +: storageInfo).toImmutableArraySeq } // Apply a different rule to the last child diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala b/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala index 7f7bb36512d14..c9e5ba1ad8e04 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceInformation.scala @@ -24,6 +24,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkException import org.apache.spark.annotation.Evolving +import org.apache.spark.util.ArrayImplicits._ /** * Class to hold information about a type of Resource. A resource could be a GPU, FPGA, etc. 
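
The resource-related hunks also reflect a defensive-exposure idea: a class may keep its addresses in an Array internally while only ever publishing an immutable Seq, so callers cannot mutate the backing state through the returned collection. A minimal sketch under that assumption; ResourceView and its fields are invented names, not the Spark class.

import scala.collection.immutable.ArraySeq

// Invented example class: addresses are stored as an Array but
// published to callers as an immutable Seq.
final class ResourceView(val name: String, private val addrs: Array[String]) {
  // Callers cannot modify the backing array through this Seq. The view
  // still reflects internal mutation, since it wraps rather than copies;
  // use addrs.toIndexedSeq instead if full isolation is required.
  def addresses: Seq[String] = ArraySeq.unsafeWrapArray(addrs)

  override def toString: String = s"$name: ${addresses.mkString(",")}"
}

object ResourceViewDemo {
  def main(args: Array[String]): Unit = {
    val gpu = new ResourceView("gpu", Array("0", "1", "2"))
    println(gpu)                   // gpu: 0,1,2
    println(gpu.addresses.drop(1)) // ArraySeq(1, 2)
  }
}
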
@@ -57,7 +58,7 @@ class ResourceInformation( // TODO(SPARK-39658): reconsider whether we want to expose a third-party library's // symbols as part of a public API: - final def toJson(): JValue = ResourceInformationJson(name, addresses).toJValue + final def toJson(): JValue = ResourceInformationJson(name, addresses.toImmutableArraySeq).toJValue } private[spark] object ResourceInformation { diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala index d19f413598b58..9080be01a9e66 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.api.resource.ResourceDiscoveryPlugin import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_CORES, RESOURCES_DISCOVERY_PLUGIN, SPARK_TASK_PREFIX} import org.apache.spark.internal.config.Tests.RESOURCES_WARNING_TESTING +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -155,7 +156,7 @@ private[spark] object ResourceUtils extends Logging { s"config: $componentName.$RESOURCE_PREFIX.$key") } key.substring(0, index) - }.distinct.map(name => new ResourceID(componentName, name)) + }.distinct.map(name => new ResourceID(componentName, name)).toImmutableArraySeq } def parseAllResourceRequests( @@ -273,7 +274,8 @@ private[spark] object ResourceUtils extends Logging { val otherResources = otherResourceIds.flatMap { id => val request = parseResourceRequest(sparkConf, id) if (request.amount > 0) { - Some(ResourceAllocation(id, discoverResource(sparkConf, request).addresses)) + Some(ResourceAllocation(id, + discoverResource(sparkConf, request).addresses.toImmutableArraySeq)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7241376dcdc12..045f67ffd80f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -51,6 +51,7 @@ import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ +import org.apache.spark.util.ArrayImplicits._ /** * The high-level scheduling layer that implements stage-oriented scheduling. 
It computes a DAG of @@ -355,7 +356,7 @@ private[spark] class DAGScheduler( blockManagerId: BlockManagerId, // (stageId, stageAttemptId) -> metrics executorUpdates: mutable.Map[(Int, Int), ExecutorMetrics]): Boolean = { - listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates, + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates.toImmutableArraySeq, executorUpdates)) blockManagerMaster.driverHeartbeatEndPoint.askSync[Boolean]( BlockManagerHeartbeat(blockManagerId), @@ -1324,7 +1325,8 @@ private[spark] class DAGScheduler( activeJobs += job finalStage.setActiveJob(job) val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + val stageInfos = + stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)).toImmutableArraySeq listenerBus.post( SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, Utils.cloneProperties(properties))) @@ -1364,7 +1366,8 @@ private[spark] class DAGScheduler( activeJobs += job finalStage.addActiveJob(job) val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + val stageInfos = + stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)).toImmutableArraySeq listenerBus.post( SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, Utils.cloneProperties(properties))) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala index 850756b50a3cb..44655f3d91ed7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -23,6 +23,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.network.shuffle.protocol.MergeStatuses import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -102,7 +103,7 @@ private[spark] object MergeStatus { val mergeStatus = new MergeStatus(mergerLoc, shuffleMergeId, bitmap, mergeStatuses.sizes(index)) (mergeStatuses.reduceIds(index), mergeStatus) - } + }.toImmutableArraySeq } def apply( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index e02dd27937062..f23902eb68a91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -42,6 +42,7 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A scheduler backend that waits for coarse-grained executors to connect. 
@@ -173,7 +174,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorInfo.freeCores += taskCpus resources.foreach { case (k, v) => executorInfo.resourcesInfo.get(k).foreach { r => - r.release(v.addresses) + r.release(v.addresses.toImmutableArraySeq) } } makeOffers(executorId) @@ -274,7 +275,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // as configured by the user, or set to 1 as that is the default (1 task/resource) val numParts = scheduler.sc.resourceProfileManager .resourceProfileFromId(resourceProfileId).getNumSlotsPerAddress(rName, conf) - (info.name, new ExecutorResourceInfo(info.name, info.addresses, numParts)) + (info.name, + new ExecutorResourceInfo(info.name, info.addresses.toImmutableArraySeq, numParts)) } // If we've requested the executor figure out when we did. val reqTs: Option[Long] = CoarseGrainedSchedulerBackend.this.synchronized { @@ -446,7 +448,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorData.freeCores -= task.cpus task.resources.foreach { case (rName, rInfo) => assert(executorData.resourcesInfo.contains(rName)) - executorData.resourcesInfo(rName).acquire(rInfo.addresses) + executorData.resourcesInfo(rName).acquire(rInfo.addresses.toImmutableArraySeq) } logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + @@ -569,14 +571,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } if (executorsToDecommission.isEmpty) { - return executorsToDecommission + return executorsToDecommission.toImmutableArraySeq } logInfo(s"Decommission executors: ${executorsToDecommission.mkString(", ")}") // If we don't want to replace the executors we are decommissioning if (adjustTargetNumExecutors) { - adjustExecutors(executorsToDecommission) + adjustExecutors(executorsToDecommission.toImmutableArraySeq) } // Mark those corresponding BlockManagers as decommissioned first before we sending @@ -586,7 +588,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Note that marking BlockManager as decommissioned doesn't need depend on // `spark.storage.decommission.enabled`. Because it's meaningless to save more blocks // for the BlockManager since the executor will be shutdown soon. 
- scheduler.sc.env.blockManager.master.decommissionBlockManagers(executorsToDecommission) + scheduler.sc.env.blockManager.master + .decommissionBlockManagers(executorsToDecommission.toImmutableArraySeq) if (!triggeredByExecutor) { executorsToDecommission.foreach { executorId => @@ -603,14 +606,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } if (stragglers.nonEmpty) { logInfo(s"${stragglers.toList} failed to decommission in ${cleanupInterval}, killing.") - killExecutors(stragglers, false, false, true) + killExecutors(stragglers.toImmutableArraySeq, false, false, true) } } } cleanupService.map(_.schedule(cleanupTask, cleanupInterval, TimeUnit.SECONDS)) } - executorsToDecommission + executorsToDecommission.toImmutableArraySeq } override def start(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 03501d9d7407c..c1c36d7a9f046 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -31,6 +31,7 @@ import org.apache.spark.status.AppStatusUtils.getQuantilesValue import org.apache.spark.status.api.v1 import org.apache.spark.storage.FallbackStorage.FALLBACK_BLOCK_MANAGER_ID import org.apache.spark.ui.scope._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils import org.apache.spark.util.kvstore.KVStore @@ -255,7 +256,7 @@ private[spark] class AppStatusStore( stageAttemptId: Int, unsortedQuantiles: Array[Double]): Option[v1.TaskMetricDistributions] = { val stageKey = Array(stageId, stageAttemptId) - val quantiles = unsortedQuantiles.sorted + val quantiles = unsortedQuantiles.sorted.toImmutableArraySeq // We don't know how many tasks remain in the store that actually have metrics. So scan one // metric and count how many valid tasks there are. Use skip() instead of next() since it's @@ -745,7 +746,7 @@ private[spark] class AppStatusStore( } else { val values = summary.values.toIndexedSeq Some(new v1.ExecutorMetricsDistributions( - quantiles = quantiles, + quantiles = quantiles.toImmutableArraySeq, taskTime = getQuantilesValue(values.map(_.taskTime.toDouble).sorted, quantiles), failedTasks = getQuantilesValue(values.map(_.failedTasks.toDouble).sorted, quantiles), succeededTasks = getQuantilesValue(values.map(_.succeededTasks.toDouble).sorted, quantiles), @@ -765,7 +766,7 @@ private[spark] class AppStatusStore( diskBytesSpilled = getQuantilesValue(values.map(_.diskBytesSpilled.toDouble).sorted, quantiles), peakMemoryMetrics = - new v1.ExecutorPeakMetricsDistributions(quantiles, + new v1.ExecutorPeakMetricsDistributions(quantiles.toImmutableArraySeq, values.flatMap(_.peakMemoryMetrics)) )) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index f77fda4614936..da48af90a9c2d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -62,6 +62,7 @@ import org.apache.spark.storage.BlockManagerMessages.{DecommissionBlockManager, import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.io.ChunkedByteBuffer /* Class for returning a fetched block and associated metrics. 
*/ @@ -2151,7 +2152,7 @@ private[spark] object BlockManager { // blockManagerMaster != null is used in tests assert(env != null || blockManagerMaster != null) val blockLocations: Seq[Seq[BlockManagerId]] = if (blockManagerMaster == null) { - env.blockManager.getLocationBlockIds(blockIds) + env.blockManager.getLocationBlockIds(blockIds).toImmutableArraySeq } else { blockManagerMaster.getLocations(blockIds) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 81a6bb5d45c3e..0d52b66a400f6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -40,6 +40,7 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseG import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * BlockManagerMasterEndpoint is an [[IsolatedThreadSafeRpcEndpoint]] on the master node to @@ -877,7 +878,7 @@ class BlockManagerMasterEndpoint( private def getLocationsMultipleBlockIds( blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { - blockIds.map(blockId => getLocations(blockId)) + blockIds.map(blockId => getLocations(blockId)).toImmutableArraySeq } /** Get the list of the peers of the given block manager */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 6118631e549b7..512ee3cc806fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -36,6 +36,7 @@ import org.apache.spark.storage.DiskBlockManager.ATTEMPT_ID_KEY import org.apache.spark.storage.DiskBlockManager.MERGE_DIR_KEY import org.apache.spark.storage.DiskBlockManager.MERGE_DIRECTORY import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -168,7 +169,7 @@ private[spark] class DiskBlockManager( }.filter(_ != null).flatMap { dir => val files = dir.listFiles() if (files != null) files.toSeq else Seq.empty - } + }.toImmutableArraySeq } /** List all the blocks currently stored on disk by the disk manager. 
*/ diff --git a/core/src/main/scala/org/apache/spark/util/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/util/DependencyUtils.scala index 1d158ad50dc58..14851d8772895 100644 --- a/core/src/main/scala/org/apache/spark/util/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/DependencyUtils.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.SparkSubmit import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.util.ArrayImplicits._ private[spark] case class IvyProperties( packagesExclusions: String, @@ -101,7 +102,7 @@ private[spark] object DependencyUtils extends Logging { ivySettingsPath: Option[String]): Seq[String] = { val exclusions: Seq[String] = if (!StringUtils.isBlank(packagesExclusions)) { - packagesExclusions.split(",") + packagesExclusions.split(",").toImmutableArraySeq } else { Nil } diff --git a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala index 5cd93dfae3580..3245a528b74cf 100644 --- a/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.util.ArrayImplicits._ /** * Utility functions to simplify and speed-up file listing. @@ -95,11 +96,11 @@ private[spark] object HadoopFSUtils extends Logging { }.filterNot(status => shouldFilterOutPath(status.getPath.toString.substring(prefixLength))) .filter(f => filter.accept(f.getPath)) .toArray - Seq((path, statues)) + Seq((path, statues.toImmutableArraySeq)) } catch { case _: FileNotFoundException => logWarning(s"The root directory $path was not found. 
Was it deleted very recently?") - Seq((path, Array.empty[FileStatus])) + Seq((path, Seq.empty[FileStatus])) } } @@ -168,7 +169,7 @@ private[spark] object HadoopFSUtils extends Logging { parallelismMax = 0) (path, leafFiles) } - }.collect() + }.collect().toImmutableArraySeq } finally { sc.setJobDescription(previousJobDescription) } @@ -247,7 +248,7 @@ private[spark] object HadoopFSUtils extends Logging { case Some(context) if dirs.size > parallelismThreshold => parallelListLeafFilesInternal( context, - dirs.map(_.getPath), + dirs.map(_.getPath).toImmutableArraySeq, hadoopConf = hadoopConf, filter = filter, isRootLevel = false, @@ -268,7 +269,7 @@ private[spark] object HadoopFSUtils extends Logging { isRootPath = false, parallelismThreshold = parallelismThreshold, parallelismMax = parallelismMax) - } + }.toImmutableArraySeq } val filteredTopLevelFiles = if (filter != null) { topLevelFiles.filter(f => filter.accept(f.getPath)) @@ -326,7 +327,7 @@ private[spark] object HadoopFSUtils extends Logging { s"the following files were missing during file scan:\n ${missingFiles.mkString("\n ")}") } - resolvedLeafStatuses + resolvedLeafStatuses.toImmutableArraySeq } // scalastyle:on argcount diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b3c208bb518e4..02f4912326365 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -34,6 +34,7 @@ import org.apache.spark.resource.{ExecutorResourceRequest, ResourceInformation, import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils.weakIntern /** @@ -1149,8 +1150,8 @@ private[spark] object JsonProtocol extends JsonUtils { val rpId = jsonOption(json.get("Resource Profile Id")).map(_.extractInt) val stageProf = rpId.getOrElse(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) - val stageInfo = new StageInfo(stageId, attemptId, stageName, numTasks, rddInfos, - parentIds, details, resourceProfileId = stageProf, + val stageInfo = new StageInfo(stageId, attemptId, stageName, numTasks, + rddInfos.toImmutableArraySeq, parentIds, details, resourceProfileId = stageProf, isShufflePushEnabled = isShufflePushEnabled, shuffleMergerCount = shufflePushMergersCount) stageInfo.submissionTime = submissionTime diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 99ba13479898e..049999281f5bb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -77,6 +77,7 @@ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.{Utils => CUtils} import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -839,7 +840,7 @@ private[spark] object Utils * uses a local random number generator, avoiding inter-thread contention. 
*/ def randomize[T: ClassTag](seq: IterableOnce[T]): Seq[T] = { - randomizeInPlace(seq.iterator.toArray) + randomizeInPlace(seq.iterator.toArray).toImmutableArraySeq } /** @@ -2092,7 +2093,7 @@ private[spark] object Utils } else "" val locking = monitors.get(idx).map(mi => s"\t- locked $mi\n").getOrElse("") s"${frame.toString}\n$locked$locking" - }) + }.toImmutableArraySeq) val synchronizers = threadInfo.getLockedSynchronizers.map(_.toString) val monitorStrs = monitors.values.toSeq @@ -2103,8 +2104,8 @@ private[spark] object Utils stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), - synchronizers ++ monitorStrs, - synchronizers, + (synchronizers ++ monitorStrs).toImmutableArraySeq, + synchronizers.toImmutableArraySeq, monitorStrs, Option(threadInfo.getLockName), Option(threadInfo.getLockOwnerName), @@ -2121,6 +2122,7 @@ private[spark] object Utils conf.getAll .filter { case (k, _) => filterKey(k) } .map { case (k, v) => s"-D$k=$v" } + .toImmutableArraySeq } /** @@ -2691,7 +2693,7 @@ private[spark] object Utils } def stringToSeq(str: String): Seq[String] = { - str.split(",").map(_.trim()).filter(_.nonEmpty) + str.split(",").map(_.trim()).filter(_.nonEmpty).toImmutableArraySeq } /** diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 68a59232c7a96..e374c41b91405 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -25,6 +25,7 @@ import org.apache.commons.io.IOUtils import org.apache.spark.SparkConf import org.apache.spark.internal.config +import org.apache.spark.util.ArrayImplicits._ /** * Continuously appends data from input stream into the given file, and rolls @@ -184,6 +185,7 @@ private[spark] object RollingFileAppender { val file = new File(directory, activeFileName).getAbsoluteFile if (file.exists) Some(file) else None } - rolledOverFiles.sortBy(_.getName.stripSuffix(GZIP_LOG_SUFFIX)) ++ activeFile + (rolledOverFiles.sortBy(_.getName.stripSuffix(GZIP_LOG_SUFFIX)) ++ activeFile) + .toImmutableArraySeq } } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index e42df0821589b..c425596eb0433 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils trait RDDCheckpointTester { self: SparkFunSuite => @@ -475,14 +476,14 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( generateFatPairRDD(), rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) - )) + ).toImmutableArraySeq) }, reliableCheckpoint) testRDDPartitions(rdd => { new PartitionerAwareUnionRDD[(Int, Int)](sc, Array( generateFatPairRDD(), rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) - )) + ).toImmutableArraySeq) }, reliableCheckpoint) // Test that the PartitionerAwareUnionRDD updates parent partitions @@ -491,7 +492,7 @@ class CheckpointSuite extends 
SparkFunSuite with RDDCheckpointTester with LocalS // implementation of PartitionerAwareUnionRDD. val pairRDD = generateFatPairRDD() checkpoint(pairRDD, reliableCheckpoint) - val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD)) + val unionRDD = new PartitionerAwareUnionRDD(sc, Seq(pairRDD)) val partitionBeforeCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) pairRDD.count() diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index f8c0e823cbe4e..4a2b2339159cb 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -596,7 +596,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { // Ensure that if all of the splits are empty, we remove the splits correctly testIgnoreEmptySplits( - data = Array.empty[Tuple2[String, String]], + data = Seq.empty[Tuple2[String, String]], actualPartitionNum = 1, expectedPartitionNum = 0) @@ -639,7 +639,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { // Ensure that if all of the splits are empty, we remove the splits correctly testIgnoreEmptySplits( - data = Array.empty[Tuple2[String, String]], + data = Seq.empty[Tuple2[String, String]], actualPartitionNum = 1, expectedPartitionNum = 0) diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index eea7753ee2555..28fa9f5e23e79 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -23,6 +23,7 @@ import scala.math.abs import org.scalatest.PrivateMethodTester import org.apache.spark.rdd.RDD +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.StatCounter class PartitioningSuite extends SparkFunSuite with SharedSparkContext with PrivateMethodTester { @@ -208,9 +209,10 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva } test("partitioning Java arrays should fail") { - val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) + val arrs: RDD[Array[Int]] = + sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq, 2).map(x => Array(x)) val arrPairs: RDD[(Array[Int], Int)] = - sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) + sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq, 2).map(x => (Array(x), x)) def verify(testFun: => Unit): Unit = { intercept[SparkException](testFun).getMessage.contains("array") @@ -235,7 +237,7 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("zero-length partitions should be correctly handled") { // Create RDD with some consecutive empty partitions (including the "first" one) val rdd: RDD[Double] = sc - .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0).toImmutableArraySeq, 8) .filter(_ >= 0.0) // Run the partitions, including the consecutive empty ones, through StatCounter diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d403f940cf4da..a92d532907adf 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.scheduler.{MapStatus, MergeStatus, MyRDD, SparkListener, import 
org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDirsTest { @@ -154,7 +155,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) - val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) + val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)).toImmutableArraySeq val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new ShuffledRDD[Int, Int, Int](pairs, new HashPartitioner(2)).collect() @@ -167,7 +168,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) - val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) + val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)).toImmutableArraySeq val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) .sortByKey().collect() diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 6271ce507fddb..4871e14cc80c8 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -23,13 +23,14 @@ import org.scalatest.Assertions import org.scalatest.concurrent.Eventually._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty) - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) + val rdd = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) assert(sc.getPersistentRDDs.isEmpty) rdd.cache() @@ -39,14 +40,14 @@ class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs returns an immutable map") { sc = new SparkContext("local", "test") - val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2).cache() val myRdds = sc.getPersistentRDDs assert(myRdds.size === 1) assert(myRdds(0) === rdd1) assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY) // myRdds2 should have 2 RDDs, but myRdds should not change - val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() + val rdd2 = sc.makeRDD(Array(5, 6, 7, 8).toImmutableArraySeq, 1).cache() val myRdds2 = sc.getPersistentRDDs assert(myRdds2.size === 2) assert(myRdds2(0) === rdd1) @@ -60,7 +61,7 @@ class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getRDDStorageInfo only reports on RDDs that actually persist data") { sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + val rdd = 
sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2).cache() assert(sc.getRDDStorageInfo.length === 0) rdd.collect() sc.listenerBus.waitUntilEmpty() @@ -83,7 +84,7 @@ package object testPackage extends Assertions { private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r def runCallSiteTest(sc: SparkContext): Unit = { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) + val rdd = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val rddCreationSite = rdd.getCreationSite val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index e1401805c17e7..12f9d2f83c777 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -47,6 +47,7 @@ import org.apache.spark.resource.TestResourceIDs._ import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -126,7 +127,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) sc.addFile(file1.getAbsolutePath) sc.addFile(relativePath) - sc.parallelize(Array(1), 1).map(x => { + sc.parallelize(Array(1).toImmutableArraySeq, 1).map(x => { val gotten1 = new File(SparkFiles.get(file1.getName)) val gotten2 = new File(SparkFiles.get(file2.getName)) if (!gotten1.exists()) { @@ -196,7 +197,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc.addArchive(s"${zipFile.getAbsolutePath}#bar") sc.addArchive(relativePath2) - sc.parallelize(Array(1), 1).map { x => + sc.parallelize(Array(1).toImmutableArraySeq, 1).map { x => val gotten1 = new File(SparkFiles.get(jarFile.getName)) val gotten2 = new File(SparkFiles.get(zipFile.getName)) val gotten3 = new File(SparkFiles.get("foo")) @@ -294,7 +295,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) sc.addFile(neptune.getAbsolutePath, true) - sc.parallelize(Array(1), 1).map(x => { + sc.parallelize(Array(1).toImmutableArraySeq, 1).map(x => { val sep = File.separator if (!new File(SparkFiles.get(neptune.getName + sep + alien1.getName)).exists()) { throw new SparkException("can't access file under root added directory") @@ -1241,7 +1242,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc.addFile(fileUrl1) sc.addJar(jar2.toString) sc.addFile(file2.toString) - sc.parallelize(Array(1), 1).map { x => + sc.parallelize(Array(1).toImmutableArraySeq, 1).map { x => val gottenJar1 = new File(SparkFiles.get(jar1.getName)) if (!gottenJar1.exists()) { throw new SparkException("file doesn't exist : " + jar1) diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index fc70e13246638..f4e1a6457a154 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark import 
org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.{Millis, Span} +import org.apache.spark.util.ArrayImplicits._ + class UnpersistSuite extends SparkFunSuite with LocalSparkContext with TimeLimits { // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x @@ -27,7 +29,7 @@ class UnpersistSuite extends SparkFunSuite with LocalSparkContext with TimeLimit test("unpersist RDD") { sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + val rdd = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2).cache() rdd.count() assert(sc.persistentRdds.nonEmpty) rdd.unpersist(blocking = true) diff --git a/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala index b22c07dce46ba..3b3bcff0c5a3f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala @@ -414,7 +414,7 @@ class DecommissionWorkerSuite master.self.askSync[MasterStateResponse](RequestMasterState) } - private def getApplications(): Seq[ApplicationInfo] = { + private def getApplications(): Array[ApplicationInfo] = { getMasterState.activeApps } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 1a21fc38095a7..01995ca3632d2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -571,7 +571,7 @@ class StandaloneDynamicAllocationSuite } /** Get the applications that are active from Master */ - private def getApplications(): Seq[ApplicationInfo] = { + private def getApplications(): Array[ApplicationInfo] = { getMasterState.activeApps } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 7a67c4c8f75cc..d109ed8442d44 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -265,7 +265,7 @@ class AppClientSuite } /** Get the applications that are active from Master */ - private def getApplications(): Seq[ApplicationInfo] = { + private def getApplications(): Array[ApplicationInfo] = { getMasterState.activeApps } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala index 349985207e48c..3d35a612b5bbb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.EventLogTestHelper._ import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -367,6 +368,6 @@ class RollingEventLogFilesWriterSuite extends EventLogFileWritersSuite { private def listEventLogFiles(logDirPath: Path): Seq[FileStatus] = { fileSystem.listStatus(logDirPath).filter(isEventLogFile) - .sortBy { fs => getEventLogFileIndex(fs.getPath.getName) } 
+ .sortBy { fs => getEventLogFileIndex(fs.getPath.getName) }.toImmutableArraySeq } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 9aa366de2ea51..1ca1e8fefd06f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -52,6 +52,7 @@ import org.apache.spark.status.api.v1.JobData import org.apache.spark.tags.ExtendedLevelDBTest import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ResetSystemProperties, ShutdownHookManager, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A collection of tests against the historyserver, including comparing responses from the json @@ -408,7 +409,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with def listDir(dir: Path): Seq[FileStatus] = { val statuses = fs.listStatus(dir) statuses.flatMap( - stat => if (stat.isDirectory) listDir(stat.getPath) else Seq(stat)) + stat => if (stat.isDirectory) listDir(stat.getPath) else Seq(stat)).toImmutableArraySeq } def dumpLogDir(msg: String = ""): Unit = { diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 1cc2c873760df..2f645e69079a2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState._ import org.apache.spark.rpc._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -499,7 +500,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { "--name", mainClass, "--class", mainClass, mainJar) ++ appArgs - val args = new SparkSubmitArguments(commandLineArgs) + val args = new SparkSubmitArguments(commandLineArgs.toImmutableArraySeq) val (_, _, sparkConf, _) = new SparkSubmit().prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 864adddad3426..38f65f8104554 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import org.apache.spark._ +import org.apache.spark.util.ArrayImplicits._ class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { test("sum") { @@ -282,7 +283,7 @@ class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { } test("WorksWithHugeRange") { - val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30)) + val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30).toImmutableArraySeq) val histogramResults = rdd.histogram(1000000)._2 assert(histogramResults(0) === 1) assert(histogramResults(1) === 1) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index c30b4ca4dae11..9b60d2eeeed1b 100644 --- 
a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -34,6 +34,7 @@ import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.Partitioner +import org.apache.spark.util.ArrayImplicits._ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("aggregateByKey") { @@ -486,7 +487,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("default partitioner uses partition size") { // specify 2000 partitions - val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) + val a = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2000) // do a map, which loses the partitioner val b = a.map(a => (a, (a * 2).toString)) // then a group by, and see we didn't revert to 2 partitions @@ -502,8 +503,8 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("subtract") { - val a = sc.parallelize(Array(1, 2, 3), 2) - val b = sc.parallelize(Array(2, 3, 4), 4) + val a = sc.parallelize(Array(1, 2, 3).toImmutableArraySeq, 2) + val b = sc.parallelize(Array(2, 3, 4).toImmutableArraySeq, 4) val c = a.subtract(b) assert(c.collect().toSet === Set(1)) assert(c.partitions.size === a.partitions.size) diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 2b57f8c8f6f23..d51e87c979ce4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -25,11 +25,12 @@ import org.scalacheck.Prop._ import org.scalatestplus.scalacheck.Checkers import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ArrayImplicits._ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { test("one element per slice") { val data = Array(1, 2, 3) - val slices = ParallelCollectionRDD.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data.toImmutableArraySeq, 3) assert(slices.size === 3) assert(slices(0).mkString(",") === "1") assert(slices(1).mkString(",") === "2") @@ -37,7 +38,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { } test("one slice") { - val data = Array(1, 2, 3) + val data = Seq(1, 2, 3) val slices = ParallelCollectionRDD.slice(data, 1) assert(slices.size === 1) assert(slices(0).mkString(",") === "1,2,3") @@ -45,7 +46,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { test("equal slices") { val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) - val slices = ParallelCollectionRDD.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data.toImmutableArraySeq, 3) assert(slices.size === 3) assert(slices(0).mkString(",") === "1,2,3") assert(slices(1).mkString(",") === "4,5,6") @@ -54,7 +55,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { test("non-equal slices") { val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val slices = ParallelCollectionRDD.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data.toImmutableArraySeq, 3) assert(slices.size === 3) assert(slices(0).mkString(",") === "1,2,3") assert(slices(1).mkString(",") === "4,5,6") @@ -81,19 +82,19 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { } test("empty data") { - val data = new Array[Int](0) + val data = Seq.empty[Int] val slices = ParallelCollectionRDD.slice(data, 5) assert(slices.size === 5) for (slice <- slices) 
assert(slice.size === 0) } test("zero slices") { - val data = Array(1, 2, 3) + val data = Array(1, 2, 3).toImmutableArraySeq intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) } } test("negative number of slices") { - val data = Array(1, 2, 3) + val data = Array(1, 2, 3).toImmutableArraySeq intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index d55acec1e4eec..111f782809551 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler} /** a sampler that outputs its seed */ @@ -41,7 +42,7 @@ class MockSampler extends RandomSampler[Long, Long] { class PartitionwiseSampledRDDSuite extends SparkFunSuite with SharedSparkContext { test("seed distribution") { - val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) + val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L).toImmutableArraySeq, 2) val sampler = new MockSampler val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, false, 0L) assert(sample.distinct().count() == 2, "Seeds must be different.") diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 2acd648b98d9e..3a097e5335a2a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} import org.scalatest.concurrent.Eventually import org.apache.spark._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { @@ -41,7 +42,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("basic pipe") { assume(TestUtils.testCommandAvailable("cat")) - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val piped = nums.pipe(Seq("cat")) @@ -55,7 +56,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("basic pipe with tokenization") { assume(TestUtils.testCommandAvailable("wc")) - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) // verify that both RDD.pipe(command: String) and RDD.pipe(command: String, env) work good for (piped <- Seq(nums.pipe("wc -l"), nums.pipe("wc -l", Map[String, String]()))) { @@ -69,7 +70,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("failure in iterating over pipe input") { assume(TestUtils.testCommandAvailable("cat")) val nums = - sc.makeRDD(Array(1, 2, 3, 4), 2) + sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) .mapPartitionsWithIndex((index, iterator) => { new Iterator[Int] { def hasNext = true @@ -88,7 +89,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("stdin writer thread should be exited when task is finished") { assume(TestUtils.testCommandAvailable("cat")) - val nums = 
sc.makeRDD(Array(1, 2, 3, 4), 1).map { x => + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 1).map { x => val obj = new Object() obj.synchronized { obj.wait() // make the thread waits here. @@ -116,7 +117,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("advanced pipe") { assume(TestUtils.testCommandAvailable("cat")) - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val bl = sc.broadcast(List("0")) val piped = nums.pipe(Seq("cat"), @@ -178,7 +179,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("pipe with env variable") { val executable = envCommand.split("\\s+", 2)(0) assume(TestUtils.testCommandAvailable(executable)) - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val piped = nums.pipe(s"$envCommand MY_TEST_ENV", Map("MY_TEST_ENV" -> "LALALA")) val c = piped.collect() assert(c.length === 2) @@ -190,7 +191,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("pipe with process which cannot be launched due to bad command") { assume(!TestUtils.testCommandAvailable("some_nonexistent_command")) - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val command = Seq("some_nonexistent_command") val piped = nums.pipe(command) val exception = intercept[SparkException] { @@ -201,7 +202,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("pipe with process which is launched but fails with non-zero exit status") { assume(TestUtils.testCommandAvailable("cat")) - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val command = Seq("cat", "nonexistent_file") val piped = nums.pipe(command) val exception = intercept[SparkException] { @@ -212,7 +213,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext with Eventuall test("basic pipe with separate working directory") { assume(TestUtils.testCommandAvailable("cat")) - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val piped = nums.pipe(Seq("cat"), separateWorkingDir = true) val c = piped.collect() assert(c.size === 4) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index f925a8b8b71d4..32ba2053258eb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_P import org.apache.spark.rdd.RDDSuiteUtils._ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { var tempDir: File = _ @@ -55,11 +56,11 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } test("basic operations") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) assert(nums.getNumPartitions === 2) assert(nums.collect().toList === List(1, 2, 3, 4)) assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) - val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) + val dups = 
sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4).toImmutableArraySeq, 2) assert(dups.distinct().count() === 4) assert(dups.distinct().count() === 4) // Can distinct and count be called without parentheses? assert(dups.distinct().collect() === dups.distinct().collect()) @@ -127,7 +128,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } test("SparkContext.union") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(sc.union(Seq(nums)).collect().toList === List(1, 2, 3, 4)) @@ -135,8 +136,8 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } test("SparkContext.union parallel partition listing") { - val nums1 = sc.makeRDD(Array(1, 2, 3, 4), 2) - val nums2 = sc.makeRDD(Array(5, 6, 7, 8), 2) + val nums1 = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) + val nums2 = sc.makeRDD(Array(5, 6, 7, 8).toImmutableArraySeq, 2) val serialUnion = sc.union(nums1, nums2) val expected = serialUnion.collect().toList @@ -296,7 +297,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } test("basic caching") { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + val rdd = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2).cache() assert(rdd.collect().toList === List(1, 2, 3, 4)) assert(rdd.collect().toList === List(1, 2, 3, 4)) assert(rdd.collect().toList === List(1, 2, 3, 4)) @@ -366,7 +367,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { // Coalesce partitions val input = Array.fill(1000)(1) val initialPartitions = 10 - val data = sc.parallelize(input, initialPartitions) + val data = sc.parallelize(input.toImmutableArraySeq, initialPartitions) val repartitioned1 = data.repartition(2) assert(repartitioned1.partitions.size == 2) @@ -392,9 +393,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } } - testSplitPartitions(Array.fill(100)(1), 10, 20) - testSplitPartitions(Array.fill(10000)(1) ++ Array.fill(10000)(2), 20, 100) - testSplitPartitions(Array.fill(1000)(1), 250, 128) + testSplitPartitions(Array.fill(100)(1).toImmutableArraySeq, 10, 20) + testSplitPartitions((Array.fill(10000)(1) ++ Array.fill(10000)(2)).toImmutableArraySeq, 20, 100) + testSplitPartitions(Array.fill(1000)(1).toImmutableArraySeq, 250, 128) } test("coalesced RDDs") { @@ -594,7 +595,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { } test("zipped RDDs") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2) val zipped = nums.zip(nums.map(_ + 1.0)) assert(zipped.glom().map(_.toList).collect().toList === List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0)))) @@ -684,7 +685,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { test("takeOrdered with predefined ordering") { val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val rdd = sc.makeRDD(nums, 2) + val rdd = sc.makeRDD(nums.toImmutableArraySeq, 2) val sortedLowerK = rdd.takeOrdered(5) assert(sortedLowerK.size === 5) assert(sortedLowerK === Array(1, 2, 3, 4, 5)) @@ -692,7 +693,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { test("takeOrdered with limit 0") { val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val rdd = sc.makeRDD(nums, 2) + val rdd = 
sc.makeRDD(nums.toImmutableArraySeq, 2) val sortedLowerK = rdd.takeOrdered(0) assert(sortedLowerK.size === 0) } @@ -705,7 +706,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { test("takeOrdered with custom ordering") { val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) implicit val ord = implicitly[Ordering[Int]].reverse - val rdd = sc.makeRDD(nums, 2) + val rdd = sc.makeRDD(nums.toImmutableArraySeq, 2) val sortedTopK = rdd.takeOrdered(5) assert(sortedTopK.size === 5) assert(sortedTopK === Array(10, 9, 8, 7, 6)) @@ -1238,7 +1239,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { val numCoalescedPartitions = 50 val locations = Array("locA", "locB") - val inputRDD = sc.makeRDD(Range(0, numInputPartitions).toArray[Int], numInputPartitions) + val inputRDD = sc.makeRDD(Range(0, numInputPartitions), numInputPartitions) assert(inputRDD.getNumPartitions == numInputPartitions) val locationPrefRDD = new LocationPrefRDD(inputRDD, { (p: Partition) => diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index 1ae2d51b125db..802889b047796 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.util.ArrayImplicits._ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers { @@ -32,7 +33,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers { test("large array") { val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) + val pairs = sc.parallelize(pairArr.toImmutableArraySeq, 2) val sorted = pairs.sortByKey() assert(sorted.partitions.size === 2) assert(sorted.collect() === pairArr.sortBy(_._1)) @@ -41,7 +42,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers { test("large array with one split") { val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) + val pairs = sc.parallelize(pairArr.toImmutableArraySeq, 2) val sorted = pairs.sortByKey(true, 1) assert(sorted.partitions.size === 1) assert(sorted.collect() === pairArr.sortBy(_._1)) @@ -50,7 +51,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers { test("large array with many partitions") { val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) + val pairs = sc.parallelize(pairArr.toImmutableArraySeq, 2) val sorted = pairs.sortByKey(true, 20) assert(sorted.partitions.size === 20) assert(sorted.collect() === pairArr.sortBy(_._1)) @@ -59,40 +60,40 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers { test("sort descending") { val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) + val pairs = sc.parallelize(pairArr.toImmutableArraySeq, 2) assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) } test("sort descending with one split") { val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), 
rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 1) + val pairs = sc.parallelize(pairArr.toImmutableArraySeq, 1) assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) } test("sort descending with many partitions") { val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) + val pairs = sc.parallelize(pairArr.toImmutableArraySeq, 2) assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) } test("more partitions than elements") { val rand = new scala.util.Random() val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 30) + val pairs = sc.parallelize(pairArr.toImmutableArraySeq, 30) assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) } test("empty RDD") { val pairArr = new Array[(Int, Int)](0) - val pairs = sc.parallelize(pairArr, 2) + val pairs = sc.parallelize(pairArr.toImmutableArraySeq, 2) assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) } test("partition balancing") { val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 4).sortByKey() + val sorted = sc.parallelize(pairArr.toImmutableArraySeq, 4).sortByKey() assert(sorted.collect() === pairArr.sortBy(_._1)) val partitions = sorted.collectPartitions() logInfo("Partition lengths: " + partitions.map(_.length).mkString(", ")) @@ -107,7 +108,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers { test("partition balancing for descending sort") { val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 4).sortByKey(false) + val sorted = sc.parallelize(pairArr.toImmutableArraySeq, 4).sortByKey(false) assert(sorted.collect() === pairArr.sortBy(_._1).reverse) val partitions = sorted.collectPartitions() logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) @@ -122,14 +123,14 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers { test("get a range of elements in a sorted RDD that is on one partition") { val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 10).sortByKey() + val sorted = sc.parallelize(pairArr.toImmutableArraySeq, 10).sortByKey() val range = sorted.filterByRange(20, 40).collect() assert((20 to 40).toArray === range.map(_._1)) } test("get a range of elements over multiple partitions in a descendingly sorted RDD") { val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 10).sortByKey(false) + val sorted = sc.parallelize(pairArr.toImmutableArraySeq, 10).sortByKey(false) val range = sorted.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray === range.map(_._1)) } @@ -143,7 +144,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers { test("get a range of elements over multiple partitions but not taking up full partitions") { val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 10).sortByKey(false) + val sorted = sc.parallelize(pairArr.toImmutableArraySeq, 10).sortByKey(false) val range = sorted.filterByRange(250, 850).collect() assert((850 to 250 by -1).toArray === range.map(_._1)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5040532df4dd6..0f596d7d5b7bc 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -49,6 +49,7 @@ import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, Clock, LongAccumulator, SystemClock, Utils} +import org.apache.spark.util.ArrayImplicits._ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -2788,7 +2789,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti "still behave correctly on fetch failures") { // Runs a job that always encounters a fetch failure, so should eventually be aborted def runJobWithPersistentFetchFailure: Unit = { - val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2).map(x => (x, 1)).groupByKey() val shuffleHandle = rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle rdd1.map { @@ -2801,7 +2802,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti // Runs a job that encounters a single fetch failure but succeeds on the second attempt def runJobWithTemporaryFetchFailure: Unit = { - val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4).toImmutableArraySeq, 2).map(x => (x, 1)).groupByKey() val shuffleHandle = rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle rdd1.map { diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 939923e12b8e7..cbb91ff9dca98 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.{JsonProtocol, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * Test whether EventLoggingListener logs events properly. 
@@ -587,7 +588,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit } else { stageIds.map(id => (id, 0) -> executorMetrics).toMap } - SparkListenerExecutorMetricsUpdate(executorId, accum, executorUpdates) + SparkListenerExecutorMetricsUpdate(executorId, accum.toImmutableArraySeq, executorUpdates) } private def createTaskEndEvent( diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala index 3f99e2b4598f0..8676efe3140b3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExecutorResourceInfoSuite.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.resource.ResourceUtils.GPU +import org.apache.spark.util.ArrayImplicits._ class ExecutorResourceInfoSuite extends SparkFunSuite { @@ -31,12 +32,12 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { assert(info.assignedAddrs.isEmpty) // Acquire addresses - info.acquire(Seq("0", "1")) + info.acquire(Array("0", "1").toImmutableArraySeq) assert(info.availableAddrs.sorted sameElements Seq("2", "3")) assert(info.assignedAddrs.sorted sameElements Seq("0", "1")) // release addresses - info.release(Array("0", "1")) + info.release(Array("0", "1").toImmutableArraySeq) assert(info.availableAddrs.sorted sameElements Seq("0", "1", "2", "3")) assert(info.assignedAddrs.isEmpty) } @@ -49,7 +50,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { assert(!info.availableAddrs.contains("1")) // Acquire an address that is not available val e = intercept[SparkException] { - info.acquire(Array("1")) + info.acquire(Array("1").toImmutableArraySeq) } assert(e.getMessage.contains("Try to acquire an address that is not available.")) } @@ -60,7 +61,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { assert(!info.availableAddrs.contains("4")) // Acquire an address that doesn't exist val e = intercept[SparkException] { - info.acquire(Array("4")) + info.acquire(Array("4").toImmutableArraySeq) } assert(e.getMessage.contains("Try to acquire an address that doesn't exist.")) } @@ -69,11 +70,11 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { // Init Executor Resource. 
val info = new ExecutorResourceInfo(GPU, Seq("0", "1", "2", "3"), 1) // Acquire addresses - info.acquire(Array("0", "1")) + info.acquire(Array("0", "1").toImmutableArraySeq) assert(!info.assignedAddrs.contains("2")) // Release an address that is not assigned val e = intercept[SparkException] { - info.release(Array("2")) + info.release(Array("2").toImmutableArraySeq) } assert(e.getMessage.contains("Try to release an address that is not assigned.")) } @@ -84,7 +85,7 @@ class ExecutorResourceInfoSuite extends SparkFunSuite { assert(!info.assignedAddrs.contains("4")) // Release an address that doesn't exist val e = intercept[SparkException] { - info.release(Array("4")) + info.release(Array("4").toImmutableArraySeq) } assert(e.getMessage.contains("Try to release an address that doesn't exist.")) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index ff5d8213ced98..c56fd3fd1f570 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ +import org.apache.spark.util.ArrayImplicits._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -476,7 +477,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("localProperties are propagated to executors correctly") { sc = new SparkContext("local", "test") sc.setLocalProperty("testPropKey", "testPropValue") - val res = sc.parallelize(Array(1), 1).map(i => i).map(i => { + val res = sc.parallelize(Array(1).toImmutableArraySeq, 1).map(i => i).map(i => { val inTask = TaskContext.get().getLocalProperty("testPropKey") val inDeser = Executor.taskDeserializationProps.get().getProperty("testPropKey") s"$inTask,$inDeser" diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 299ef51605991..2fe50a486dbd6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -45,6 +45,7 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.serializer.SerializerInstance import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AccumulatorV2, Clock, ManualClock, SystemClock} +import org.apache.spark.util.ArrayImplicits._ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -836,10 +837,10 @@ class TaskSetManagerSuite val conf = new SparkConf().set(config.MAX_RESULT_SIZE.key, "2m") sc = new SparkContext("local", "test", conf) - def genBytes(size: Int): (Int) => Array[Byte] = { (x: Int) => + def genBytes(size: Int): (Int) => Seq[Byte] = { (x: Int) => val bytes = Array.ofDim[Byte](size) scala.util.Random.nextBytes(bytes) - bytes + bytes.toImmutableArraySeq } // multiple 1k result diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 86dc6c89883bf..2368900233648 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.OpenHashMap class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { @@ -294,7 +295,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("kryo with parallelize for primitive arrays") { - assert(sc.parallelize(Array(1, 2, 3)).count() === 3) + assert(sc.parallelize(Array(1, 2, 3).toImmutableArraySeq).count() === 3) } test("kryo with collect for specialized tuples") { @@ -302,7 +303,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("kryo with SerializableHyperLogLog") { - assert(sc.parallelize(Array(1, 2, 3, 2, 3, 3, 2, 3, 1)).countApproxDistinct(0.01) === 3) + assert(sc.parallelize(Array(1, 2, 3, 2, 3, 3, 2, 3, 1).toImmutableArraySeq) + .countApproxDistinct(0.01) === 3) } test("kryo with reduce") { diff --git a/core/src/test/scala/org/apache/spark/status/ListenerEventsTestHelper.scala b/core/src/test/scala/org/apache/spark/status/ListenerEventsTestHelper.scala index e7d78cbe3047f..8468624511ec5 100644 --- a/core/src/test/scala/org/apache/spark/status/ListenerEventsTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/status/ListenerEventsTestHelper.scala @@ -27,6 +27,7 @@ import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded, SparkListenerExecutorMetricsUpdate, SparkListenerExecutorRemoved, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerStageSubmitted, SparkListenerTaskEnd, SparkListenerTaskStart, StageInfo, TaskInfo, TaskLocality} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{RDDInfo, StorageLevel} +import org.apache.spark.util.ArrayImplicits._ object ListenerEventsTestHelper { @@ -139,7 +140,8 @@ object ListenerEventsTestHelper { taskMetrics.incMemoryBytesSpilled(222) val accum = Array((333L, 1, 1, taskMetrics.accumulators().map(AccumulatorSuite.makeInfo))) val executorUpdates = Map((stageId, 0) -> new ExecutorMetrics(executorMetrics)) - SparkListenerExecutorMetricsUpdate(executorId.toString, accum, executorUpdates) + SparkListenerExecutorMetricsUpdate( + executorId.toString, accum.toImmutableArraySeq, executorUpdates) } case class JobInfo( diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala index 59ebc750af97e..4609925454526 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.config import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend import org.apache.spark.util.{ResetSystemProperties, SystemClock, ThreadUtils} +import org.apache.spark.util.ArrayImplicits._ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties with Eventually { @@ -75,7 +76,7 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS val 
blockManagerDecommissionStatus = if (SparkEnv.get.blockManager.decommissioner.isEmpty) false else true Iterator.single(blockManagerDecommissionStatus) - }.collect() + }.collect().toImmutableArraySeq assert(decommissionStatus.forall(_ == isEnabled)) sc.removeSparkListener(decommissionListener) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index a5e9a58d6456e..17dff20dd993b 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -65,6 +65,7 @@ import org.apache.spark.shuffle.{MigratableResolver, ShuffleBlockInfo, ShuffleBl import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTester @@ -1972,7 +1973,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) val blockLocations = Seq(BlockManagerId("1", "host1", 100), BlockManagerId("2", "host2", 200)) when(mockBlockManagerMaster.getLocations(mc.any[Array[BlockId]])) - .thenReturn(Array(blockLocations)) + .thenReturn(Array(blockLocations).toImmutableArraySeq) val env = mock(classOf[SparkEnv]) val blockIds: Array[BlockId] = Array(StreamBlockId(1, 2)) diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 2f084b2037efa..938465ab53265 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContex import org.apache.spark.LocalSparkContext._ import org.apache.spark.partial.CountEvaluator import org.apache.spark.rdd.RDD +import org.apache.spark.util.ArrayImplicits._ class ClosureCleanerSuite extends SparkFunSuite { test("closures inside an object") { @@ -140,7 +141,7 @@ object TestObject { var nonSer = new NonSerializable val x = 5 withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) + val nums = sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq) nums.map(_ + x).reduce(_ + _) } } @@ -154,7 +155,7 @@ class TestClass extends Serializable { def run(): Int = { var nonSer = new NonSerializable withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) + val nums = sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq) nums.map(_ + getX).reduce(_ + _) } } @@ -166,7 +167,7 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) + val nums = sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq) nums.map(_ + getX).reduce(_ + _) } } @@ -181,7 +182,7 @@ class TestClassWithoutFieldAccess { var nonSer2 = new NonSerializable val x = 5 withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) + val nums = sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq) nums.map(_ + x).reduce(_ + _) } } @@ -190,7 +191,7 @@ class 
TestClassWithoutFieldAccess { object TestObjectWithBogusReturns { def run(): Int = { withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) + val nums = sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq) // this return is invalid since it will transfer control outside the closure nums.map {x => return 1 ; x * 2} 1 @@ -201,7 +202,7 @@ object TestObjectWithBogusReturns { object TestObjectWithNestedReturns { def run(): Int = { withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) + val nums = sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq) nums.map {x => // this return is fine since it will not transfer control outside the closure def foo(): Int = { return 5; 1 } @@ -217,7 +218,7 @@ object TestObjectWithNesting { var nonSer = new NonSerializable var answer = 0 withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) + val nums = sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq) val y = 1 for (i <- 1 to 4) { var nonSer2 = new NonSerializable @@ -236,7 +237,7 @@ class TestClassWithNesting(val y: Int) extends Serializable { var nonSer = new NonSerializable var answer = 0 withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) + val nums = sc.parallelize(Array(1, 2, 3, 4).toImmutableArraySeq) for (i <- 1 to 4) { var nonSer2 = new NonSerializable val x = i diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 82c11a3c7330a..b8b7756f36ee3 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} +import org.apache.spark.util.ArrayImplicits._ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { @@ -35,7 +36,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { classNames = Seq("FakeClass1"), classNamesWithBase = Seq(("FakeClass2", "FakeClass3")), // FakeClass3 is in parent toStringValue = "1", - classpathUrls = urls2)).toArray + classpathUrls = urls2.toImmutableArraySeq)).toArray val fileUrlsChild = List(TestUtils.createJarWithFiles(Map( "resource1" -> "resource1Contents-child", diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 469e429867fc3..7016ecd4f2ec6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -20,6 +20,7 @@ package org.apache.spark.examples import java.util.Random +import scala.collection.immutable import scala.math.exp import breeze.linalg.{DenseVector, Vector} @@ -42,13 +43,13 @@ object SparkLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData: Array[DataPoint] = { + def generateData: Seq[DataPoint] = { def generatePoint(i: Int): DataPoint = { val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D) {rand.nextGaussian + y * R} DataPoint(x, y) } - Array.tabulate(N)(generatePoint) + immutable.ArraySeq.unsafeWrapArray(Array.tabulate(N)(generatePoint)) } def showWarning(): Unit = { diff --git 
a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala index 38e7eca6e04a2..87e2f8082d8a3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import scala.collection.immutable + // $example on$ import org.apache.spark.ml.feature.Binarizer // $example off$ @@ -31,7 +33,7 @@ object BinarizerExample { .getOrCreate() // $example on$ - val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val data = immutable.ArraySeq.unsafeWrapArray(Array((0, 0.1), (1, 0.8), (2, 0.2))) val dataFrame = spark.createDataFrame(data).toDF("id", "feature") val binarizer: Binarizer = new Binarizer() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala index 185a9f3ed9cf0..560ac03c96866 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import scala.collection.immutable + // $example on$ import org.apache.spark.ml.feature.Bucketizer // $example off$ @@ -40,7 +42,8 @@ object BucketizerExample { val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) val data = Array(-999.9, -0.5, -0.3, 0.0, 0.2, 999.9) - val dataFrame = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val dataFrame = spark.createDataFrame( + immutable.ArraySeq.unsafeWrapArray(data.map(Tuple1.apply))).toDF("features") val bucketizer = new Bucketizer() .setInputCol("features") @@ -66,7 +69,8 @@ object BucketizerExample { (0.0, 0.0), (0.2, 0.4), (999.9, 999.9)) - val dataFrame2 = spark.createDataFrame(data2).toDF("features1", "features2") + val dataFrame2 = spark.createDataFrame(immutable.ArraySeq.unsafeWrapArray(data2)) + .toDF("features1", "features2") val bucketizer2 = new Bucketizer() .setInputCols(Array("features1", "features2")) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala index 4a0fd2b22faad..74ed2fd23cfc7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import scala.collection.immutable + // $example on$ import org.apache.spark.ml.feature.PCA import org.apache.spark.ml.linalg.Vectors @@ -37,7 +39,8 @@ object PCAExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) ) - val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val df = spark.createDataFrame(immutable.ArraySeq.unsafeWrapArray(data.map(Tuple1.apply))) + .toDF("features") val pca = new PCA() .setInputCol("features") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala index 2dbfe6cbef431..f8913bd8aa698 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import scala.collection.immutable + // $example on$ import org.apache.spark.ml.feature.PolynomialExpansion import org.apache.spark.ml.linalg.Vectors @@ -37,7 +39,8 @@ object PolynomialExpansionExample { Vectors.dense(0.0, 0.0), Vectors.dense(3.0, -1.0) ) - val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val df = spark.createDataFrame(immutable.ArraySeq.unsafeWrapArray(data.map(Tuple1.apply))) + .toDF("features") val polyExpansion = new PolynomialExpansion() .setInputCol("features") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index 3e022d23e6440..e08ceee927d6a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -17,6 +17,8 @@ package org.apache.spark.examples.ml +import scala.collection.immutable + // $example on$ import org.apache.spark.ml.feature.QuantileDiscretizer // $example off$ @@ -31,7 +33,7 @@ object QuantileDiscretizerExample { // $example on$ val data = Array((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)) - val df = spark.createDataFrame(data).toDF("id", "hour") + val df = spark.createDataFrame(immutable.ArraySeq.unsafeWrapArray(data)).toDF("id", "hour") // $example off$ // Output of QuantileDiscretizer for such small datasets can depend on the number of // partitions. Here we force a single partition to ensure consistent results. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CorrelationsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CorrelationsExample.scala index 1202caf534e95..cc54596ed9029 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CorrelationsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CorrelationsExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import scala.collection.immutable + import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.linalg._ @@ -33,9 +35,11 @@ object CorrelationsExample { val sc = new SparkContext(conf) // $example on$ - val seriesX: RDD[Double] = sc.parallelize(Array(1, 2, 3, 3, 5)) // a series + val seriesX: RDD[Double] = sc.parallelize( + immutable.ArraySeq.unsafeWrapArray(Array(1.0, 2.0, 3.0, 3.0, 5.0))) // a series // must have the same number of partitions and cardinality as seriesX - val seriesY: RDD[Double] = sc.parallelize(Array(11, 22, 33, 33, 555)) + val seriesY: RDD[Double] = sc.parallelize( + immutable.ArraySeq.unsafeWrapArray(Array(11.0, 22.0, 33.0, 33.0, 555.0))) // compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a // method is not specified, Pearson's method will be used by default. 
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala index da43a8d9c7e80..cb7c00682239d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import scala.collection.immutable + import org.apache.spark.SparkConf import org.apache.spark.SparkContext // $example on$ @@ -39,7 +41,7 @@ object PCAOnRowMatrixExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val rows = sc.parallelize(data) + val rows = sc.parallelize(immutable.ArraySeq.unsafeWrapArray(data)) val mat: RowMatrix = new RowMatrix(rows) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala index 769ae2a3a88b1..66d6a5c4ee32f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import scala.collection.immutable + import org.apache.spark.SparkConf import org.apache.spark.SparkContext // $example on$ @@ -44,7 +46,7 @@ object SVDExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val rows = sc.parallelize(data) + val rows = sc.parallelize(immutable.ArraySeq.unsafeWrapArray(data)) val mat: RowMatrix = new RowMatrix(rows) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index a20abd6e9d12e..cc699fcda3ed4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import scala.collection.immutable + import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ @@ -52,7 +54,7 @@ object RawNetworkGrep { val rawStreams = (1 to numStreams).map(_ => ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = ssc.union(rawStreams) + val union = ssc.union(immutable.ArraySeq.unsafeWrapArray(rawStreams)) union.filter(_.contains("the")).count().foreachRDD(r => println(s"Grep count: ${r.collect().mkString}")) ssc.start() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 3d4dde98318ee..e973fd4e38d81 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.graphx import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ +import org.apache.spark.util.ArrayImplicits._ class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { @@ -60,7 +61,7 @@ class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { case (a, b) => (a.toLong, b.toLong) } val correctEdges = edgeArray.filter { case (a, b) => a != b }.toSet - val graph = 
Graph.fromEdgeTuples(sc.parallelize(edgeArray), 1) + val graph = Graph.fromEdgeTuples(sc.parallelize(edgeArray.toImmutableArraySeq), 1) val canonicalizedEdges = graph.removeSelfEdges().edges.map(e => (e.srcId, e.dstId)) .collect() assert(canonicalizedEdges.toSet.size === canonicalizedEdges.size) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index 2c57e8927e4d6..a0a10c440efe9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.graphx.lib import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ +import org.apache.spark.util.ArrayImplicits._ class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { @@ -52,7 +53,7 @@ class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkCont Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Array(3L -> 4L, 4L -> 5L, 5L -> 3L) ++ Array(6L -> 0L, 5L -> 7L) - val rawEdges = sc.parallelize(edges) + val rawEdges = sc.parallelize(edges.toImmutableArraySeq) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) for ((id, scc) <- sccGraph.vertices.collect()) { diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index d8e17ddd24db7..5a0ee9307ab8a 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -21,7 +21,7 @@ import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => Java import java.util import scala.annotation.varargs -import scala.collection.mutable +import scala.collection.{immutable, mutable} import scala.jdk.CollectionConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} @@ -54,11 +54,15 @@ sealed trait Vector extends Serializable { if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => - Vectors.equals(s1.indices, s1.values, s2.indices, s2.values) + Vectors.equals( + immutable.ArraySeq.unsafeWrapArray(s1.indices), s1.values, + immutable.ArraySeq.unsafeWrapArray(s2.indices), s2.values) case (s1: SparseVector, d1: DenseVector) => - Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values) + Vectors.equals( + immutable.ArraySeq.unsafeWrapArray(s1.indices), s1.values, 0 until d1.size, d1.values) case (d1: DenseVector, s1: SparseVector) => - Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) + Vectors.equals( + 0 until d1.size, d1.values, immutable.ArraySeq.unsafeWrapArray(s1.indices), s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } case _ => false diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index f991f35b9d828..a84c5b62c430a 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.linalg +import scala.collection.immutable import scala.collection.mutable.ArrayBuilder import scala.util.Random @@ -52,7 +53,8 @@ class 
VectorsSuite extends SparkMLFunSuite { } test("sparse vector construction with unordered elements") { - val vec = Vectors.sparse(n, indices.zip(values).reverse).asInstanceOf[SparseVector] + val vec = Vectors.sparse(n, immutable.ArraySeq.unsafeWrapArray(indices.zip(values).reverse)) + .asInstanceOf[SparseVector] assert(vec.size === n) assert(vec.indices === indices) assert(vec.values === values) @@ -392,7 +394,7 @@ class VectorsSuite extends SparkMLFunSuite { test("sparse vector only support non-negative length") { val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) - val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + val v2 = Vectors.sparse(0, immutable.ArraySeq.unsafeWrapArray(Array.empty[(Int, Double)])) assert(v1.size === 0) assert(v2.size === 0) @@ -400,7 +402,7 @@ class VectorsSuite extends SparkMLFunSuite { Vectors.sparse(-1, Array(1), Array(2.0)) } intercept[IllegalArgumentException] { - Vectors.sparse(-1, Array((1, 2.0))) + Vectors.sparse(-1, immutable.ArraySeq.unsafeWrapArray(Array((1, 2.0)))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 98ab5faeb5e9d..cfb76c780cf14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.ArrayImplicits._ /** * Common params for GaussianMixture and GaussianMixtureModel @@ -189,7 +189,7 @@ class GaussianMixtureModel private[ml] ( def gaussiansDF: DataFrame = { val modelGaussians = gaussians.map { gaussian => (OldVectors.fromML(gaussian.mean), OldMatrices.fromML(gaussian.cov)) - } + }.toImmutableArraySeq SparkSession.builder().getOrCreate().createDataFrame(modelGaussians).toDF("mean", "cov") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index e73d236d7d0e3..519978a0733b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.VersionUtils.majorVersion /** @@ -225,7 +226,7 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegi ClusterData(idx, center) } val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(data.toImmutableArraySeq).repartition(1).write.parquet(dataPath) } } @@ -499,7 +500,7 @@ class KMeans @Since("1.5.0") ( // Execute iterations of Lloyd's algorithm until converged while (iteration < $(maxIter) && !converged) { // Find the new centers - val bcCenters = sc.broadcast(DenseMatrix.fromVectors(centers)) + val bcCenters = sc.broadcast(DenseMatrix.fromVectors(centers.toImmutableArraySeq)) val countSumAccum = if (iteration == 0) sc.longAccumulator else null val weightSumAccum = if (iteration == 0) sc.doubleAccumulator else null val costSumAccum = sc.doubleAccumulator diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index c160dec13ff18..23402599e543d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils @@ -593,7 +594,8 @@ abstract class LDAModel private[ml] ( case ((termIndices, termWeights), topic) => (topic, termIndices.toSeq, termWeights.toSeq) } - sparkSession.createDataFrame(topics).toDF("topic", "termIndices", "termWeights") + sparkSession.createDataFrame(topics.toImmutableArraySeq) + .toDF("topic", "termIndices", "termWeights") } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index e073e41af4a78..16f72e18b9776 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -29,6 +29,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.Row import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * Params for [[BucketedRandomProjectionLSH]]. @@ -68,7 +69,7 @@ class BucketedRandomProjectionLSHModel private[ml]( extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams { private[ml] def this(uid: String, randUnitVectors: Array[Vector]) = { - this(uid, Matrices.fromVectors(randUnitVectors)) + this(uid, Matrices.fromVectors(randUnitVectors.toImmutableArraySeq)) } private[ml] def randUnitVectors: Array[Vector] = randMatrix.rowIter.toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 74be13a074103..c02b43ceed5e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.{OpenHashMap, Utils} /** @@ -369,7 +370,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.vocabulary) + val data = Data(instance.vocabulary.toImmutableArraySeq) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 4d38c127d412d..ae65b17d7a810 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala 
@@ -28,6 +28,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Params for [[Imputer]] and [[ImputerModel]]. @@ -201,7 +202,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) s"missingValue(${$(missingValue)})") } - val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results))) + val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results.toImmutableArraySeq))) val schema = StructType(inputColumns.map(col => StructField(col, DoubleType, nullable = false))) val surrogateDF = spark.createDataFrame(rows, schema) copyValues(new ImputerModel(uid, surrogateDF).setParent(this)) @@ -277,7 +278,7 @@ class ImputerModel private[ml] ( .otherwise(ic) .cast(inputType) } - dataset.withColumns(outputColumns, newCols).toDF() + dataset.withColumns(outputColumns.toImmutableArraySeq, newCols.toImmutableArraySeq).toDF() } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala index c237366ec5c3d..2c7a9d2a91e6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.ml.linalg._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.ArrayImplicits._ /** * Class that represents an instance of weighted data point with label and features. @@ -174,7 +175,7 @@ private[spark] object InstanceBlock { // the block memory usage may slightly exceed threshold, not a big issue. // and this ensure even if one row exceed block limit, each block has one row. - InstanceBlock.fromInstances(buff.result()) + InstanceBlock.fromInstances(buff.result().toImmutableArraySeq) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index a81c55a171571..c791edb9de153 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Implements the feature interaction transform. 
This transformer takes in Double and Vector type @@ -70,8 +71,8 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val inputFeatures = $(inputCols).map(c => dataset.schema(c)) - val featureEncoders = getFeatureEncoders(inputFeatures) - val featureAttrs = getFeatureAttrs(inputFeatures) + val featureEncoders = getFeatureEncoders(inputFeatures.toImmutableArraySeq) + val featureAttrs = getFeatureAttrs(inputFeatures.toImmutableArraySeq) def interactFunc = udf { row: Row => var indices = ArrayBuilder.make[Int] @@ -168,7 +169,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext } case _: VectorUDT => val group = AttributeGroup.fromStructField(f) - encodedFeatureAttrs(group.attributes.get, Some(group.name)) + encodedFeatureAttrs(group.attributes.get.toImmutableArraySeq, Some(group.name)) } if (featureAttrs.isEmpty) { featureAttrs = encodedAttrs diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 5a500fefb57ec..e32addc7ee195 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ /** Private trait for params and common methods for OneHotEncoder and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid @@ -193,7 +194,8 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String) // When fitting data, we want the plain number of categories without `handleInvalid` and // `dropLast` taken into account. val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( - dataset, inputColNames, outputColNames, dropLast = false) + dataset, inputColNames.toImmutableArraySeq, outputColNames.toImmutableArraySeq, + dropLast = false) attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => categorySizes(idx) = attrGroup.size } @@ -372,7 +374,7 @@ class OneHotEncoderModel private[ml] ( encoder(col(inputColName).cast(DoubleType), lit(idx)) .as(outputColName, metadata) } - dataset.withColumns(outputColNames, encodedColumns) + dataset.withColumns(outputColNames.toImmutableArraySeq, encodedColumns) } @Since("3.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index c5b28c95eb7c9..88e63b766ca6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -22,6 +22,7 @@ import scala.util.parsing.combinator.RegexParsers import org.apache.spark.ml.linalg.VectorUDT import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Represents a parsed R formula. 
@@ -125,7 +126,7 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { schema.fields.filter(_.dataType match { case _: NumericType | StringType | BooleanType | _: VectorUDT => true case _ => false - }).map(_.name).filter(_ != label.value) + }).map(_.name).filter(_ != label.value).toImmutableArraySeq } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 056baa1b6bf55..5862a60a407d4 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ /** * A feature transformer that filters out stop words from input. @@ -171,7 +172,10 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String t(col(inputColName)) } val outputMetadata = outputColNames.map(outputSchema(_).metadata) - dataset.withColumns(outputColNames, outputCols, outputMetadata) + dataset.withColumns( + outputColNames.toImmutableArraySeq, + outputCols.toImmutableArraySeq, + outputMetadata.toImmutableArraySeq) } @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 7f8850bebae35..4250b50673e81 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.ThreadUtils import org.apache.spark.util.VersionUtils.majorMinorVersion import org.apache.spark.util.collection.OpenHashMap @@ -197,7 +198,7 @@ class StringIndexer @Since("1.4.0") ( val aggregator = new StringIndexerAggregator(inputCols.length) implicit val encoder = Encoders.kryo[Array[OpenHashMap[String, Long]]] - val selectedCols = getSelectedCols(dataset, inputCols) + val selectedCols = getSelectedCols(dataset, inputCols.toImmutableArraySeq) dataset.select(selectedCols: _*) .toDF() .agg(aggregator.toColumn) @@ -218,7 +219,7 @@ class StringIndexer @Since("1.4.0") ( private def sortByAlphabet(dataset: Dataset[_], ascending: Boolean): Array[Array[String]] = { val (inputCols, _) = getInOutCols() - val selectedCols = getSelectedCols(dataset, inputCols).map(collect_set(_)) + val selectedCols = getSelectedCols(dataset, inputCols.toImmutableArraySeq).map(collect_set) val allLabels = dataset.select(selectedCols: _*) .collect().toSeq.flatMap(_.toSeq) .asInstanceOf[scala.collection.Seq[scala.collection.Seq[String]]].toSeq @@ -418,7 +419,7 @@ class StringIndexerModel ( // Skips invalid rows if `handleInvalid` is set to `StringIndexer.SKIP_INVALID`. 
val filteredDataset = if (getHandleInvalid == StringIndexer.SKIP_INVALID) { - filterInvalidData(dataset, inputColNames) + filterInvalidData(dataset, inputColNames.toImmutableArraySeq) } else { dataset } @@ -443,7 +444,7 @@ class StringIndexerModel ( .withValues(filteredLabels) .toMetadata() - val indexer = getIndexer(labels, labelToIndex) + val indexer = getIndexer(labels.toImmutableArraySeq, labelToIndex) outputColumns(i) = indexer(dataset(inputColName).cast(StringType)) .as(outputColName, metadata) @@ -455,7 +456,8 @@ class StringIndexerModel ( require(filteredOutputColNames.length == filteredOutputColumns.length) if (filteredOutputColNames.length > 0) { - filteredDataset.withColumns(filteredOutputColNames, filteredOutputColumns) + filteredDataset.withColumns( + filteredOutputColNames.toImmutableArraySeq, filteredOutputColumns.toImmutableArraySeq) } else { filteredDataset.toDF() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 04cd02fa0e277..8337e305e2b31 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} +import org.apache.spark.util.ArrayImplicits._ /** * A tokenizer that converts the input string to lowercase and then splits it by white spaces. @@ -37,7 +38,7 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) override protected def createTransformFunc: String => Seq[String] = { // scalastyle:off caselocale - _.toLowerCase.split("\\s") + _.toLowerCase.split("\\s").toImmutableArraySeq // scalastyle:on caselocale } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 47c0ca22f9672..fe54347b818a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -32,6 +32,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * A feature transformer that merges multiple columns into a vector column. 
@@ -92,7 +93,8 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) case _ => false } } - val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid)) + val vectorColsLengths = VectorAssembler.getLengths( + dataset, vectorCols.toImmutableArraySeq, $(handleInvalid)) val featureAttributesMap = $(inputCols).map { c => val field = schema(c) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 053aaac742b0b..638d8463b9d27 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.{Utils, VersionUtils} +import org.apache.spark.util.ArrayImplicits._ /** * Params for [[Word2Vec]] and [[Word2VecModel]]. @@ -230,7 +231,8 @@ class Word2VecModel private[ml] ( @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(word, num).toImmutableArraySeq) + .toDF("word", "similarity") } /** @@ -243,7 +245,8 @@ class Word2VecModel private[ml] ( @Since("2.0.0") def findSynonyms(vec: Vector, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(vec, num).toImmutableArraySeq) + .toDF("word", "similarity") } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 47fb8bc92298a..6a7615fb149b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.{col, lit, struct} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.VersionUtils import org.apache.spark.util.collection.OpenHashMap @@ -477,14 +478,15 @@ private[ml] object EnsembleModelReadWrite { instance.treeWeights(treeID)) } val treesMetadataPath = new Path(path, "treesMetadata").toString - sparkSession.createDataFrame(treesMetadataWeights) + sparkSession.createDataFrame(treesMetadataWeights.toImmutableArraySeq) .toDF("treeID", "metadata", "weights") .repartition(1) .write.parquet(treesMetadataPath) val dataPath = new Path(path, "data").toString val numDataParts = NodeData.inferNumPartitions(instance.trees.map(_.numNodes.toLong).sum) - val nodeDataRDD = sparkSession.sparkContext.parallelize(instance.trees.zipWithIndex) + val nodeDataRDD = sparkSession.sparkContext + .parallelize(instance.trees.zipWithIndex.toImmutableArraySeq) .flatMap { case (tree, treeID) => EnsembleNodeData.build(tree, treeID) } sparkSession.createDataFrame(nodeDataRDD) .repartition(numDataParts) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index c3979118de403..ceaae27d83429 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.util.ArrayImplicits._ /** * Clustering model produced by [[BisectingKMeans]]. @@ -182,8 +183,8 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val data = getNodes(model.root).map(node => Data(node.index, node.size, node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, - node.children.map(_.index))) - spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + node.children.map(_.index).toImmutableArraySeq)) + spark.createDataFrame(data.toImmutableArraySeq).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): BisectingKMeansModel = { @@ -218,8 +219,8 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val data = getNodes(model.root).map(node => Data(node.index, node.size, node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, - node.children.map(_.index))) - spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + node.children.map(_.index).toImmutableArraySeq)) + spark.createDataFrame(data.toImmutableArraySeq).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): BisectingKMeansModel = { @@ -256,8 +257,8 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val data = getNodes(model.root).map(node => Data(node.index, node.size, node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, - node.children.map(_.index))) - spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + node.children.map(_.index).toImmutableArraySeq)) + spark.createDataFrame(data.toImmutableArraySeq).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): BisectingKMeansModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 0c9c6ab826e62..42ec37f438f4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.util.ArrayImplicits._ /** * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points @@ -152,7 +153,8 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val dataArray = Array.tabulate(weights.length) { i => Data(weights(i), gaussians(i).mu, gaussians(i).sigma) } - spark.createDataFrame(sc.makeRDD(dataArray, 1)).write.parquet(Loader.dataPath(path)) + spark.createDataFrame(sc.makeRDD(dataArray.toImmutableArraySeq, 1)) + .write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): GaussianMixtureModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 1e8a271592bb9..c4ea263ee116a 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.util.ArrayImplicits._ /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. @@ -172,9 +173,10 @@ object KMeansModel extends Loader[KMeansModel] { val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) - val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => - Cluster(id, p.vector) - } + val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex.toImmutableArraySeq) + .map { case (p, id) => + Cluster(id, p.vector) + } spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) } @@ -206,9 +208,10 @@ object KMeansModel extends Loader[KMeansModel] { ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure) ~ ("trainingCost" -> model.trainingCost))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) - val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => - Cluster(id, p.vector) - } + val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex.toImmutableArraySeq) + .map { case (p, id) => + Cluster(id, p.vector) + } spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 3202f08e220b0..a1b00bccbc34e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.util.ArrayImplicits._ /** * Chi Squared selector model. @@ -141,7 +142,8 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val dataArray = Array.tabulate(model.selectedFeatures.length) { i => Data(model.selectedFeatures(i)) } - spark.createDataFrame(sc.makeRDD(dataArray, 1)).write.parquet(Loader.dataPath(path)) + spark.createDataFrame(sc.makeRDD(dataArray.toImmutableArraySeq, 1)) + .write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): ChiSqSelectorModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 25899fd3ebbc8..fbdb5843eb99d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Represents a numeric vector, whose index type is Int and value type is Double. 
@@ -64,11 +65,12 @@ sealed trait Vector extends Serializable { if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => - Vectors.equals(s1.indices, s1.values, s2.indices, s2.values) + Vectors.equals( + s1.indices.toImmutableArraySeq, s1.values, s2.indices.toImmutableArraySeq, s2.values) case (s1: SparseVector, d1: DenseVector) => - Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values) + Vectors.equals(s1.indices.toImmutableArraySeq, s1.values, 0 until d1.size, d1.values) case (d1: DenseVector, s1: SparseVector) => - Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) + Vectors.equals(0 until d1.size, d1.values, s1.indices.toImmutableArraySeq, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } case _ => false diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 1f879a4d9dfbb..b403770f5616a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -37,6 +37,7 @@ import org.apache.spark.mllib.tree.loss.Loss import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -428,9 +429,10 @@ private[tree] object TreeEnsembleModel extends Logging { sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) // Create Parquet data. - val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => - tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) - } + val dataRDD = sc.parallelize(model.trees.zipWithIndex.toImmutableArraySeq) + .flatMap { case (tree, treeId) => + tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) + } spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 31417ddeb2d44..0146dc67386f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD +import org.apache.spark.util.ArrayImplicits._ /** * Generate RDD(s) containing data for Matrix Factorization. 
@@ -90,7 +91,7 @@ object MFDataGenerator { val omega = shuffled.slice(0, sampSize) val ordered = omega.sortWith(_ < _).toArray - val trainData: RDD[(Int, Int, Double)] = sc.parallelize(ordered) + val trainData: RDD[(Int, Int, Double)] = sc.parallelize(ordered.toImmutableArraySeq) .map(x => (x % m, x / m, fullData.values(x))) // optionally add gaussian noise @@ -105,7 +106,7 @@ object MFDataGenerator { val testSampSize = math.min(math.round(sampSize * testSampFact).toInt, mn - sampSize) val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray - val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) + val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered.toImmutableArraySeq) .map(x => (x % m, x / m, fullData.values(x))) testData.map(x => s"${x._1},${x._2},${x._3}").saveAsTextFile(outputPath) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala index fdd6e352fa639..e2d00c98f1ca8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -36,7 +37,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { val data = inputs.zip(outputs).map { case (features, label) => (Vectors.dense(features), Vectors.dense(label)) } - val rddData = sc.parallelize(data, 1) + val rddData = sc.parallelize(data.toImmutableArraySeq, 1) val hiddenLayersTopology = Array(5) val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size @@ -70,7 +71,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { val data = inputs.zip(outputs).map { case (features, label) => (Vectors.dense(features), Vectors.dense(label)) } - val rddData = sc.parallelize(data, 1) + val rddData = sc.parallelize(data.toImmutableArraySeq, 1) val hiddenLayersTopology = Array(5) val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index b01ee79af151b..765fccf6c6207 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.ArrayImplicits._ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { @@ -47,17 +48,27 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { override def beforeAll(): Unit = { super.beforeAll() categoricalDataPointsRDD = - sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()).map(_.asML) + sc.parallelize( + OldDecisionTreeSuite.generateCategoricalDataPoints().toImmutableArraySeq).map(_.asML) 
orderedLabeledPointsWithLabel0RDD = - sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()).map(_.asML) + sc.parallelize( + OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0().toImmutableArraySeq) + .map(_.asML) orderedLabeledPointsWithLabel1RDD = - sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()).map(_.asML) + sc.parallelize( + OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1().toImmutableArraySeq) + .map(_.asML) categoricalDataPointsForMulticlassRDD = - sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()).map(_.asML) + sc.parallelize( + OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass().toImmutableArraySeq) + .map(_.asML) continuousDataPointsForMulticlassRDD = - sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()).map(_.asML) + sc.parallelize( + OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass().toImmutableArraySeq) + .map(_.asML) categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize( - OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) + OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + .toImmutableArraySeq) .map(_.asML) } @@ -117,7 +128,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { LabeledPoint(1.0, Vectors.dense(1.0)), LabeledPoint(1.0, Vectors.dense(2.0)), LabeledPoint(1.0, Vectors.dense(3.0))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val dt = new DecisionTreeClassifier() .setImpurity("Gini") .setMaxDepth(4) @@ -131,7 +142,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val dt = new DecisionTreeClassifier() .setImpurity("Gini") .setMaxDepth(4) @@ -201,7 +212,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val dt = new DecisionTreeClassifier() .setImpurity("Gini") .setMaxDepth(2) @@ -218,7 +229,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val dt = new DecisionTreeClassifier() .setImpurity("Gini") .setMaxBins(2) @@ -234,7 +245,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val dt = new DecisionTreeClassifier() .setImpurity("Gini") @@ -340,7 +351,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { val data = TreeTests.getSingleTreeLeafData data.foreach { case (leafId, vec) => assert(leafId === model.predictLeaf(vec)) } - val df = sc.parallelize(data, 1).toDF("leafId", 
"features") + val df = sc.parallelize(data.toImmutableArraySeq, 1).toDF("leafId", "features") model.transform(df).select("leafId", "predictedLeafId") .collect() .foreach { case Row(leafId: Double, predictedLeafId: Double) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala index 3efb5eb196897..68e83fccf3d11 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/FMClassifierSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.ml.regression.FMRegressorSuite._ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.ArrayImplicits._ class FMClassifierSuite extends MLTest with DefaultReadWriteTest { @@ -122,7 +123,7 @@ class FMClassifierSuite extends MLTest with DefaultReadWriteTest { (0.0, Vectors.sparse(3, Array(0, 2), Array(-1.0, 2.0))), (0.0, Vectors.sparse(3, Array.emptyIntArray, Array.emptyDoubleArray)), (1.0, Vectors.sparse(3, Array(0, 1), Array(2.0, 3.0))) - )).toDF("label", "features") + ).toImmutableArraySeq).toDF("label", "features") val fm = new FMClassifier().setMaxIter(10) fm.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 0994465c92afb..6ce2108b1f7c8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.lit +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -58,13 +59,17 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { override def beforeAll(): Unit = { super.beforeAll() - data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + data = sc + .parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + .toImmutableArraySeq, 2) .map(_.asML) trainData = - sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) + .toImmutableArraySeq, 2) .map(_.asML) validationData = - sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) + .toImmutableArraySeq, 2) .map(_.asML) binaryDataset = generateSVMInput(0.01, Array[Double](-1.5, 1.0), 1000, seed).toDF() } @@ -270,7 +275,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { val data = TreeTests.getTwoTreesLeafData data.foreach { case (leafId, vec) => assert(leafId === model.predictLeaf(vec)) } - val df = sc.parallelize(data, 1).toDF("leafId", "features") + val df = sc.parallelize(data.toImmutableArraySeq, 1).toDF("leafId", "features") model.transform(df).select("leafId", "predictedLeafId") .collect() .foreach { case Row(leafId: Vector, predictedLeafId: Vector) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 7376dd686b263..6fe482801c470 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.util.ArrayImplicits._ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { @@ -365,7 +366,7 @@ object LinearSVCSuite { val yD = new BDV(xi).dot(weightsMat) + intercept + 0.01 * rnd.nextGaussian() if (yD > 0) 1.0 else 0.0 } - y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) + y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))).toImmutableArraySeq } def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index bed45fc68f478..ec8ffd1073047 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.util.ArrayImplicits._ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { @@ -418,11 +419,11 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { generateGaussianNaiveBayesInput(piArray, thetaArray, sigmaArray, nPoints, 17).toDF() val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels.collect()) + validatePrediction(predictionAndLabels.collect().toImmutableArraySeq) val featureAndProbabilities = model.transform(validationDataset) .select("features", "probability") - validateProbabilities(featureAndProbabilities.collect(), model, "gaussian") + validateProbabilities(featureAndProbabilities.collect().toImmutableArraySeq, model, "gaussian") } test("Naive Bayes Gaussian - Model Coefficients") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index d0e55fa84241a..562cccedeef4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.util.ArrayImplicits._ /** * Test suite for [[RandomForestClassifier]]. 
@@ -49,10 +50,12 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { override def beforeAll(): Unit = { super.beforeAll() orderedLabeledPoints50_1000 = - sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + .toImmutableArraySeq) .map(_.asML) orderedLabeledPoints5_20 = - sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20)) + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20) + .toImmutableArraySeq) .map(_.asML) binaryDataset = generateSVMInput(0.01, Array[Double](-1.5, 1.0), 1000, seed).toDF() } @@ -108,7 +111,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) ) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4) val numClasses = 3 @@ -229,7 +232,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { val data = TreeTests.getTwoTreesLeafData data.foreach { case (leafId, vec) => assert(leafId === model.predictLeaf(vec)) } - val df = sc.parallelize(data, 1).toDF("leafId", "features") + val df = sc.parallelize(data.toImmutableArraySeq, 1).toDF("leafId", "features") model.transform(df).select("leafId", "predictedLeafId") .collect() .foreach { case Row(leafId: Vector, predictedLeafId: Vector) => @@ -314,7 +317,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { arr(i) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0)) } } - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val multinomialDataset = spark.createDataFrame(rdd) val rf = new RandomForestClassifier() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 9baad52db00b3..66b9b8f2ab31d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.ArrayImplicits._ class BinarizerSuite extends MLTest with DefaultReadWriteTest { @@ -203,7 +204,7 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { .setInputCols(Array("input")) .setOutputCols(Array("result1", "result2")) .setThreshold(1.0) - val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).toImmutableArraySeq) .map(Tuple1.apply).toDF("input") intercept[IllegalArgumentException] { binarizer.transform(df).count() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 89a8e81a2a031..b48ef0fad1587 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import 
org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.util.ArrayImplicits._ class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { @@ -99,9 +100,12 @@ class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { */ dataset = spark.createDataFrame(Seq( - (0.0, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0))), Vectors.dense(6.0)), - (1.0, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0))), Vectors.dense(0.0)), - (1.0, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0))), Vectors.dense(0.0)), + (0.0, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0)).toImmutableArraySeq), + Vectors.dense(6.0)), + (1.0, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0)).toImmutableArraySeq), + Vectors.dense(0.0)), + (1.0, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0)).toImmutableArraySeq), + Vectors.dense(0.0)), (1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)), Vectors.dense(0.0)), (2.0, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)), Vectors.dense(8.0)), (2.0, Vectors.dense(Array(8.0, 9.0, 6.0, 4.0, 0.0, 0.0)), Vectors.dense(8.0)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 5e32a654c130b..431772006c820 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.Row +import org.apache.spark.util.ArrayImplicits._ class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { @@ -31,7 +32,7 @@ class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } - private def split(s: String): Seq[String] = s.split("\\s+") + private def split(s: String): Seq[String] = s.split("\\s+").toImmutableArraySeq test("CountVectorizerModel common cases") { val df = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index b4e144ea5ba5e..582a11df793ed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.sql.Row +import org.apache.spark.util.ArrayImplicits._ class IDFSuite extends MLTest with DefaultReadWriteTest { @@ -38,7 +39,7 @@ class IDFSuite extends MLTest with DefaultReadWriteTest { val res = data.indices.zip(data.values).map { case (id, value) => (id, value * model(id)) } - Vectors.sparse(data.size, res) + Vectors.sparse(data.size, res.toImmutableArraySeq) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 73ef647f0e7a1..576065ce3be84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -24,6 +24,7 @@ import 
org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.sql.Row +import org.apache.spark.util.ArrayImplicits._ class PCASuite extends MLTest with DefaultReadWriteTest { @@ -44,7 +45,7 @@ class PCASuite extends MLTest with DefaultReadWriteTest { Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) ) - val dataRDD = sc.parallelize(data, 2) + val dataRDD = sc.parallelize(data.toImmutableArraySeq, 2) val mat = new RowMatrix(dataRDD.map(OldVectors.fromML)) val pc = mat.computePrincipalComponents(3) @@ -83,10 +84,10 @@ class PCASuite extends MLTest with DefaultReadWriteTest { val data3 = data1.map(_.toSparse) val data4 = data1.map(_.toDense) - val df1 = spark.createDataFrame(data1.map(Tuple1.apply)).toDF("features") - val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("features") - val df3 = spark.createDataFrame(data3.map(Tuple1.apply)).toDF("features") - val df4 = spark.createDataFrame(data4.map(Tuple1.apply)).toDF("features") + val df1 = spark.createDataFrame(data1.map(Tuple1.apply).toImmutableArraySeq).toDF("features") + val df2 = spark.createDataFrame(data2.map(Tuple1.apply).toImmutableArraySeq).toDF("features") + val df3 = spark.createDataFrame(data3.map(Tuple1.apply).toImmutableArraySeq).toDF("features") + val df4 = spark.createDataFrame(data4.map(Tuple1.apply).toImmutableArraySeq).toDF("features") val pca = new PCA() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 542eb1759c442..5c654764a68f4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ +import org.apache.spark.util.ArrayImplicits._ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { @@ -63,7 +64,8 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { val numBuckets = 5 val expectedNumBuckets = 3 - val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)) + val df = sc.parallelize( + Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0).toImmutableArraySeq) .map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") @@ -431,7 +433,7 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { .setInputCols(Array("input")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(3) - val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).toImmutableArraySeq) .map(Tuple1.apply).toDF("input") intercept[IllegalArgumentException] { discretizer.fit(df) @@ -476,7 +478,7 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { .setInputCol("input") .setOutputCol("result") .setNumBucketsArray(Array(2, 5)) - val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).toImmutableArraySeq) .map(Tuple1.apply).toDF("input") intercept[IllegalArgumentException] { discretizer.fit(df) @@ -499,7 +501,7 @@ class QuantileDiscretizerSuite extends MLTest with 
DefaultReadWriteTest { val spark = this.spark import spark.implicits._ - val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val df = sc.parallelize(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).toImmutableArraySeq) .map(Tuple1.apply).toDF("input") val numBuckets = 2 val discretizer = new QuantileDiscretizer() @@ -520,7 +522,7 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { val a1 = Array.tabulate(200)(_ => rng.nextDouble() * 2.0 - 1.0) ++ Array.fill(20)(0.0) ++ Array.fill(20)(-0.0) - val df1 = sc.parallelize(a1, 2).toDF("id") + val df1 = sc.parallelize(a1.toImmutableArraySeq, 2).toDF("id") val qd = new QuantileDiscretizer() .setInputCol("id") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/UnivariateFeatureSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/UnivariateFeatureSelectorSuite.scala index e83d0f6b72ad1..2f9aa69e413d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/UnivariateFeatureSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/UnivariateFeatureSelectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.ml.stat.{ANOVATest, FValueTest} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.util.ArrayImplicits._ class UnivariateFeatureSelectorSuite extends MLTest with DefaultReadWriteTest { @@ -105,9 +106,12 @@ class UnivariateFeatureSelectorSuite extends MLTest with DefaultReadWriteTest { */ datasetChi2 = spark.createDataFrame(Seq( - (0.0, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0))), Vectors.dense(6.0)), - (1.0, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0))), Vectors.dense(0.0)), - (1.0, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0))), Vectors.dense(0.0)), + (0.0, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0)).toImmutableArraySeq), + Vectors.dense(6.0)), + (1.0, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0)).toImmutableArraySeq), + Vectors.dense(0.0)), + (1.0, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0)).toImmutableArraySeq), + Vectors.dense(0.0)), (1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)), Vectors.dense(0.0)), (2.0, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)), Vectors.dense(8.0)), (2.0, Vectors.dense(Array(8.0, 9.0, 6.0, 4.0, 0.0, 0.0)), Vectors.dense(8.0)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala index 142abf2ccdfb9..77b819d1c9991 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VarianceThresholdSelectorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.util.ArrayImplicits._ class VarianceThresholdSelectorSuite extends MLTest with DefaultReadWriteTest { @@ -83,11 +84,11 @@ class VarianceThresholdSelectorSuite extends MLTest with DefaultReadWriteTest { test("Test VarianceThresholdSelector: sparse vector") { val df = spark.createDataFrame(Seq( - (1, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0))), + (1, Vectors.sparse(6, Array((0, 
6.0), (1, 7.0), (3, 7.0), (4, 6.0)).toImmutableArraySeq), Vectors.dense(Array(6.0, 0.0, 7.0, 0.0))), - (2, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0))), + (2, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0)).toImmutableArraySeq), Vectors.dense(Array(0.0, 6.0, 0.0, 9.0))), - (3, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0))), + (3, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0)).toImmutableArraySeq), Vectors.dense(Array(0.0, 3.0, 0.0, 5.0))), (4, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)), Vectors.dense(Array(0.0, 8.0, 5.0, 4.0))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 3d90f9d9ac764..0991c0a9e1fdf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ class VectorSlicerSuite extends MLTest with DefaultReadWriteTest { @@ -78,7 +79,8 @@ class VectorSlicerSuite extends MLTest with DefaultReadWriteTest { val resultAttrs = Array("f1", "f4").map(defaultAttr.withName) val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]]) - val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) } + val rdd = sc.parallelize(data.zip(expected).toImmutableArraySeq) + .map { case (a, b) => Row(a, b) } val df = spark.createDataFrame(rdd, StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField()))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregatorSuite.scala index 9ece1da7be067..20221e8c5ddbe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregatorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -57,7 +58,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa fitIntercept: Boolean, fitWithMean: Boolean): BinaryLogisticBlockAggregator = { val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std.toArray val featuresMean = featuresSummarizer.mean.toArray val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0) @@ -115,7 +116,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa val numFeatures = instances.head.features.size val coefWithIntercept = Vectors.dense(Array.fill(numFeatures + 1)(rng.nextDouble())) val coefWithoutIntercept = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble())) - val block = 
InstanceBlock.fromInstances(instances) + val block = InstanceBlock.fromInstances(instances.toImmutableArraySeq) val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true, fitWithMean = false) @@ -132,7 +133,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa val coefVec = Vectors.dense(1.0, 2.0) val numFeatures = instances.head.features.size val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i))) val weightSum = instances.map(_.weight).sum @@ -154,7 +155,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -175,7 +176,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa val interceptValue = 1.0 val numFeatures = instances.head.features.size val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i))) val weightSum = instances.map(_.weight).sum @@ -199,7 +200,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -220,7 +221,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa val interceptValue = 1.0 val numFeatures = instances.head.features.size val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val featuresMean = featuresSummarizer.mean val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i))) @@ -247,7 +248,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -277,7 +278,7 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa val aggConstantFeature = getNewAggregator(instancesConstantFeature, coefVec, fitIntercept = fitIntercept, fitWithMean = fitWithMean) aggConstantFeature - .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature))) + 
.add(InstanceBlock.fromInstances(standardize(instancesConstantFeature).toImmutableArraySeq)) val grad = aggConstantFeature.gradient val coefVecFiltered = if (fitIntercept) { @@ -288,7 +289,8 @@ class BinaryLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpa val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered, coefVecFiltered, fitIntercept = fitIntercept, fitWithMean = fitWithMean) aggConstantFeatureFiltered - .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered))) + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered) + .toImmutableArraySeq)) val gradFiltered = aggConstantFeatureFiltered.gradient // constant features should not affect gradient diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala index f0167a698578e..e729a5cb3b672 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -57,7 +58,7 @@ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext coefficients: Vector, fitIntercept: Boolean): HingeBlockAggregator = { val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std.toArray val featuresMean = featuresSummarizer.mean.toArray val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0) @@ -114,7 +115,7 @@ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val numFeatures = instances.head.features.size val coefWithIntercept = Vectors.dense(Array.fill(numFeatures + 1)(rng.nextDouble())) val coefWithoutIntercept = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble())) - val block = InstanceBlock.fromInstances(instances) + val block = InstanceBlock.fromInstances(instances.toImmutableArraySeq) val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true) aggIntercept.add(block) @@ -129,7 +130,7 @@ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val coefVec = Vectors.dense(1.0, 2.0) val numFeatures = instances.head.features.size val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i))) val weightSum = instances.map(_.weight).sum @@ -153,7 +154,7 @@ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, 
block.matrix.toSparseRowMajor) @@ -173,7 +174,7 @@ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val interceptValue = 1.0 val numFeatures = instances.head.features.size val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val featuresMean = featuresSummarizer.mean val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i))) @@ -202,7 +203,7 @@ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -232,7 +233,7 @@ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val aggConstantFeature = getNewAggregator(instancesConstantFeature, coefVec, fitIntercept = fitIntercept) aggConstantFeature - .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature))) + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature).toImmutableArraySeq)) val grad = aggConstantFeature.gradient val coefVecFiltered = if (fitIntercept) { @@ -243,7 +244,8 @@ class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered, coefVecFiltered, fitIntercept = fitIntercept) aggConstantFeatureFiltered - .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered))) + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered) + .toImmutableArraySeq)) val gradFiltered = aggConstantFeatureFiltered.gradient // constant features should not affect gradient diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberBlockAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberBlockAggregatorSuite.scala index 4bad50f2d4113..5c214d17fee8a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberBlockAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberBlockAggregatorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -59,7 +60,7 @@ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext fitIntercept: Boolean, epsilon: Double): HuberBlockAggregator = { val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std.toArray val featuresMean = featuresSummarizer.mean.toArray val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0) @@ -116,7 +117,7 @@ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val numFeatures = instances.head.features.size val coefWithIntercept = Vectors.dense(Array.fill(numFeatures + 1)(rng.nextDouble())) val 
coefWithoutIntercept = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble())) - val block = InstanceBlock.fromInstances(instances) + val block = InstanceBlock.fromInstances(instances.toImmutableArraySeq) val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true, epsilon = epsilon) @@ -133,7 +134,8 @@ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val coefVec = Vectors.dense(1.0, 2.0) val sigmaValue = 4.0 val numFeatures = instances.head.features.size - val (featuresSummarizer, _) = Summarizer.getRegressionSummarizers(sc.parallelize(instances)) + val (featuresSummarizer, _) = + Summarizer.getRegressionSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val stdCoefVec = Vectors.dense(Array.tabulate(numFeatures)(i => coefVec(i) / featuresStd(i))) val weightSum = instances.map(_.weight).sum @@ -169,7 +171,7 @@ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -190,7 +192,8 @@ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val interceptValue = 3.0 val sigmaValue = 4.0 val numFeatures = instances.head.features.size - val (featuresSummarizer, _) = Summarizer.getRegressionSummarizers(sc.parallelize(instances)) + val (featuresSummarizer, _) = + Summarizer.getRegressionSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val featuresMean = featuresSummarizer.mean val stdCoefVec = Vectors.dense(Array.tabulate(numFeatures)(i => coefVec(i) / featuresStd(i))) @@ -232,7 +235,7 @@ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -264,7 +267,7 @@ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val aggConstantFeature = getNewAggregator(instancesConstantFeature, coefVec, fitIntercept = fitIntercept, epsilon = epsilon) aggConstantFeature - .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature))) + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature).toImmutableArraySeq)) val grad = aggConstantFeature.gradient val coefVecFiltered = if (fitIntercept) { @@ -275,7 +278,8 @@ class HuberBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered, coefVecFiltered, fitIntercept = fitIntercept, epsilon = epsilon) aggConstantFeatureFiltered - .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered))) + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered) + .toImmutableArraySeq)) val gradFiltered = aggConstantFeatureFiltered.gradient // constant features should not affect gradient diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresBlockAggregatorSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresBlockAggregatorSuite.scala index 11020edabdd08..e8ce703747634 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresBlockAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresBlockAggregatorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ class LeastSquaresBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -56,7 +57,7 @@ class LeastSquaresBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpark coefficients: Vector, fitIntercept: Boolean): LeastSquaresBlockAggregator = { val (featuresSummarizer, ySummarizer) = - Summarizer.getRegressionSummarizers(sc.parallelize(instances)) + Summarizer.getRegressionSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val yStd = ySummarizer.std(0) val yMean = ySummarizer.mean(0) val featuresStd = featuresSummarizer.std.toArray @@ -111,7 +112,7 @@ class LeastSquaresBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpark val rng = new scala.util.Random val numFeatures = instances.head.features.size val coefVec = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble())) - val block = InstanceBlock.fromInstances(instances) + val block = InstanceBlock.fromInstances(instances.toImmutableArraySeq) val agg = getNewAggregator(instances, coefVec, fitIntercept = true) @@ -127,7 +128,7 @@ class LeastSquaresBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpark val coefficients = Vectors.dense(1.0, 2.0) val numFeatures = coefficients.size val (featuresSummarizer, ySummarizer) = - Summarizer.getRegressionSummarizers(sc.parallelize(instances)) + Summarizer.getRegressionSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std.toArray val featuresMean = featuresSummarizer.mean.toArray val yStd = ySummarizer.std(0) @@ -162,7 +163,7 @@ class LeastSquaresBlockAggregatorSuite extends SparkFunSuite with MLlibTestSpark Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregatorSuite.scala index 8200085f100ac..52848fa6a3c73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregatorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.linalg._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -57,7 +58,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe fitIntercept: Boolean, fitWithMean: Boolean): MultinomialLogisticBlockAggregator = { val 
(featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std.toArray val featuresMean = featuresSummarizer.mean.toArray val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0) @@ -118,7 +119,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe Array.fill(numClasses * (numFeatures + 1))(rng.nextDouble())) val coefWithoutIntercept = Vectors.dense( Array.fill(numClasses * numFeatures)(rng.nextDouble())) - val block = InstanceBlock.fromInstances(instances) + val block = InstanceBlock.fromInstances(instances.toImmutableArraySeq) val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true, fitWithMean = false) @@ -136,7 +137,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe val numFeatures = instances.head.features.size val numClasses = instances.map(_.label).toSet.size val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val stdCoefMat = Matrices.dense(numClasses, numFeatures, Array.tabulate(coefArray.size)(i => coefArray(i) / featuresStd(i / numClasses))) @@ -179,7 +180,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -202,7 +203,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe val numClasses = instances.map(_.label).toSet.size val intercepts = Vectors.dense(interceptArray) val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val stdCoefMat = Matrices.dense(numClasses, numFeatures, Array.tabulate(coefArray.size)(i => coefArray(i) / featuresStd(i / numClasses))) @@ -252,7 +253,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) - .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -275,7 +276,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe val numClasses = instances.map(_.label).toSet.size val intercepts = Vectors.dense(interceptArray) val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val featuresStd = featuresSummarizer.std val featuresMean = featuresSummarizer.mean val stdCoefMat = Matrices.dense(numClasses, numFeatures, @@ -330,7 +331,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe Seq(1, 2, 4).foreach { blockSize => val blocks1 = scaledInstances .grouped(blockSize) 
- .map(seq => InstanceBlock.fromInstances(seq)) + .map(seq => InstanceBlock.fromInstances(seq.toImmutableArraySeq)) .toArray val blocks2 = blocks1.map { block => new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) @@ -360,7 +361,7 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe val aggConstantFeature = getNewAggregator(instancesConstantFeature, coefVec, fitIntercept = fitIntercept, fitWithMean = fitWithMean) aggConstantFeature - .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature))) + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature).toImmutableArraySeq)) val grad = aggConstantFeature.gradient val coefVecFiltered = if (fitIntercept) { @@ -371,7 +372,8 @@ class MultinomialLogisticBlockAggregatorSuite extends SparkFunSuite with MLlibTe val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered, coefVecFiltered, fitIntercept = fitIntercept, fitWithMean = fitWithMean) aggConstantFeatureFiltered - .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered))) + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered) + .toImmutableArraySeq)) val gradFiltered = aggConstantFeatureFiltered.gradient // constant features should not affect gradient diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 08202c7f1f3c4..dafdd06a3a233 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.types._ import org.apache.spark.storage.{StorageLevel, StorageLevelMapper} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { @@ -766,7 +767,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { } test("SPARK-18268: ALS with empty RDD should fail with better message") { - val ratings = sc.parallelize(Array.empty[Rating[Int]]) + val ratings = sc.parallelize(Array.empty[Rating[Int]].toImmutableArraySeq) intercept[IllegalArgumentException] { ALS.train(ratings) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index a7b696e248c12..679ef83ec7624 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.ArrayImplicits._ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { @@ -43,7 +44,8 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { override def beforeAll(): Unit = { super.beforeAll() categoricalDataPointsRDD = - sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)) + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML) + .toImmutableArraySeq) linearRegressionData = 
sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF() @@ -172,7 +174,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { val data = TreeTests.getSingleTreeLeafData data.foreach { case (leafId, vec) => assert(leafId === model.predictLeaf(vec)) } - val df = sc.parallelize(data, 1).toDF("leafId", "features") + val df = sc.parallelize(data.toImmutableArraySeq, 1).toDF("leafId", "features") model.transform(df).select("leafId", "predictedLeafId") .collect() .foreach { case Row(leafId: Double, predictedLeafId: Double) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 6a745b6252ccc..d7f15dc2cfe9e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.lit +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -52,13 +53,16 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { override def beforeAll(): Unit = { super.beforeAll() - data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + .toImmutableArraySeq, 2) .map(_.asML) trainData = - sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) + .toImmutableArraySeq, 2) .map(_.asML) validationData = - sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) + .toImmutableArraySeq, 2) .map(_.asML) linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), @@ -149,7 +153,7 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { val data = TreeTests.getTwoTreesLeafData data.foreach { case (leafId, vec) => assert(leafId === model.predictLeaf(vec)) } - val df = sc.parallelize(data, 1).toDF("leafId", "features") + val df = sc.parallelize(data.toImmutableArraySeq, 1).toDF("leafId", "features") model.transform(df).select("leafId", "predictedLeafId") .collect() .foreach { case Row(leafId: Vector, predictedLeafId: Vector) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index ff17be1fc5364..15db8c5c22531 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.ArrayImplicits._ /** * Test suite for [[RandomForestRegressor]]. 
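The call sites being touched, such as sc.parallelize, accept a Seq[T] rather than an Array. The hypothetical standalone snippet below is not part of the patch; it assumes spark-core 4.x on the classpath and uses ArraySeq.unsafeWrapArray directly (ArrayImplicits is internal to Spark), and the object name ParallelizeWrappedArray is invented. It shows the same wrap-then-parallelize pattern outside a test suite.

    import scala.collection.immutable.ArraySeq
    import org.apache.spark.{SparkConf, SparkContext}

    object ParallelizeWrappedArray {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("wrap-sketch").setMaster("local[1]"))
        try {
          val arr = Array(1.0, 2.0, 3.0, 4.0)
          // Wrap the array as an immutable Seq before handing it to parallelize,
          // mirroring the conversions applied in the hunks above.
          val rdd = sc.parallelize(ArraySeq.unsafeWrapArray(arr), numSlices = 2)
          println(rdd.sum()) // 10.0
        } finally {
          sc.stop()
        }
      }
    }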
@@ -46,7 +47,7 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest { super.beforeAll() orderedLabeledPoints50_1000 = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) - .map(_.asML)) + .map(_.asML).toImmutableArraySeq) linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), @@ -139,7 +140,7 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest { val data = TreeTests.getTwoTreesLeafData data.foreach { case (leafId, vec) => assert(leafId === model.predictLeaf(vec)) } - val df = sc.parallelize(data, 1).toDF("leafId", "features") + val df = sc.parallelize(data.toImmutableArraySeq, 1).toDF("leafId", "features") model.transform(df).select("leafId", "predictedLeafId") .collect() .foreach { case Row(leafId: Vector, predictedLeafId: Vector) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala index bb34f3da23296..24c75803b6900 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ class ChiSquareTestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -95,7 +96,7 @@ class ChiSquareTestSuite val sparseData = Array( LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) - val df = spark.createDataFrame(sparseData) + val df = spark.createDataFrame(sparseData.toImmutableArraySeq) val chi = ChiSquareTest.test(df, "features", "label") val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = chi.select("pValues", "degreesOfFreedom", "statistics") diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/FValueTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/FValueTestSuite.scala index 37195d2b503bf..7505a862988eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/FValueTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/FValueTestSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ class FValueTestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -150,9 +151,9 @@ class FValueTestSuite test("test DataFrame with sparse vector") { val df = spark.createDataFrame(Seq( - (4.6, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0)))), - (6.6, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0)))), - (5.1, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0)))), + (4.6, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0)).toImmutableArraySeq)), + (6.6, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0)).toImmutableArraySeq)), + (5.1, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0)).toImmutableArraySeq )), (7.6, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0))), (9.0, 
Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0))), (9.0, Vectors.dense(Array(8.0, 9.0, 6.0, 4.0, 0.0, 0.0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala index 2ae21401538e7..27e90f1a44d06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row +import org.apache.spark.util.ArrayImplicits._ class KolmogorovSmirnovTestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -48,7 +49,7 @@ class KolmogorovSmirnovTestSuite // Sample data from the distributions and parallelize it val n = 100000 val sampledArray = sampleDist.sample(n) - val sampledDF = sc.parallelize(sampledArray, 10).toDF("sample") + val sampledDF = sc.parallelize(sampledArray.toImmutableArraySeq, 10).toDF("sample") // Use a apache math commons local KS test to verify calculations val ksTest = new Math3KSTest() @@ -131,7 +132,7 @@ class KolmogorovSmirnovTestSuite -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 - ) + ).toImmutableArraySeq ).toDF("sample") val Row(pValue: Double, statistic: Double) = KolmogorovSmirnovTest .test(rData, "sample", "norm", 0, 1).head() diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala index 2a95faef98b63..0c66f35b8b3d2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ /** * Test suite for [[BaggedPoint]]. 
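The reason an explicit conversion is preferred in these tests: with Scala 2.13, scala.Seq means immutable.Seq, and turning an Array into one by copying (for example via toIndexedSeq, or the compiler's fallback implicit conversion) duplicates every element, while ArraySeq.unsafeWrapArray, which toImmutableArraySeq appears to build on, only wraps the existing array. The small Spark-free example below, with the invented name WrapVsCopy, makes that difference observable.

    import scala.collection.immutable.ArraySeq

    object WrapVsCopy {
      def main(args: Array[String]): Unit = {
        val arr = Array(1, 2, 3)

        // toIndexedSeq builds an independent immutable copy of the data.
        val copied: IndexedSeq[Int] = arr.toIndexedSeq

        // unsafeWrapArray wraps the same underlying array with no element copy,
        // which is the cheap path the explicit conversions in these tests take.
        val wrapped: Seq[Int] = ArraySeq.unsafeWrapArray(arr)

        arr(0) = 42
        println(copied.head)  // 1  -> the copy is unaffected by later mutation
        println(wrapped.head) // 42 -> the wrapper sees the mutated array
      }
    }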
@@ -31,7 +32,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map { lp => Instance(lp.label, 0.5, lp.features.asML) } - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, (instance: Instance) => instance.weight * 4.0, seed = 42) baggedRDD.collect().foreach { baggedPoint => @@ -46,7 +47,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val seeds = Array(123, 5354, 230, 349867, 23987) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) seeds.foreach { seed => val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, (_: LabeledPoint) => 2.0, seed) @@ -65,7 +66,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val seeds = Array(123, 5354, 230, 349867, 23987) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) seeds.foreach { seed => val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed = seed) @@ -82,7 +83,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val seeds = Array(123, 5354, 230, 349867, 23987) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) seeds.foreach { seed => val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, (_: LabeledPoint) => 2.0, seed) @@ -101,7 +102,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val seeds = Array(123, 5354, 230, 349867, 23987) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) seeds.foreach { seed => val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed = seed) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala index 3a8fcd9cc4eff..6609e2c3c10e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ /** * Test suite for [[GradientBoostedTrees]]. @@ -34,8 +35,10 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext test("runWithValidation stops early and performs better on a validation dataset") { // Set numIterations large enough so that it stops early. 
val numIterations = 20 - val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2).map(_.asML.toInstance) - val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2).map(_.asML.toInstance) + val trainRdd = sc.parallelize(OldGBTSuite.trainData.toImmutableArraySeq, 2) + .map(_.asML.toInstance) + val validateRdd = sc.parallelize(OldGBTSuite.validateData.toImmutableArraySeq, 2) + .map(_.asML.toInstance) val seed = 42 val algos = Array(Regression, Regression, Classification) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 274e5b0cb556f..c765960d5d1ee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTes import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator, Variance} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.OpenHashMap /** @@ -46,7 +47,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification with continuous features: split calculation") { val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance) assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -58,7 +59,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification with binary (ordered) categorical features: split calculation") { val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) @@ -75,7 +76,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { " with no samples for one category: split calculation") { val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) @@ -214,7 +215,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("train with empty arrays") { val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])).toInstance val data = Array.fill(5)(lp) - val rdd = sc.parallelize(data) + val rdd = sc.parallelize(data.toImmutableArraySeq) val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2, maxBins = 5) @@ -229,7 +230,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("train with constant features") { val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)).toInstance val data = Array.fill(5)(instance) - val rdd = sc.parallelize(data) + 
val rdd = sc.parallelize(data.toImmutableArraySeq) val strategy = new OldStrategy( OldAlgo.Classification, Gini, @@ -257,7 +258,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass classification with unordered categorical features: split calculations") { val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new OldStrategy( OldAlgo.Classification, Gini, @@ -299,7 +300,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() .map(_.asML.toInstance) assert(arr.length === 3000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) // 2^(10-1) - 1 > 100, so categorical features will be ordered @@ -329,7 +330,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) - val input = sc.parallelize(arr.map(_.toInstance)) + val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) @@ -373,7 +374,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) - val input = sc.parallelize(arr.map(_.toInstance)) + val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) @@ -427,7 +428,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(0.0, Vectors.dense(2.0)), LabeledPoint(0.0, Vectors.dense(2.0)), LabeledPoint(1.0, Vectors.dense(2.0))) - val input = sc.parallelize(arr.map(_.toInstance)) + val input = sc.parallelize(arr.map(_.toInstance).toImmutableArraySeq) // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, @@ -449,7 +450,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Second level node building with vs. without groups") { val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance) assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) // For tree with 1 group val strategy1 = new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 1000) @@ -492,7 +493,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { strategy: OldStrategy): Unit = { val numFeatures = 50 val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) - val rdd = sc.parallelize(arr).map(_.asML.toInstance) + val rdd = sc.parallelize(arr.toImmutableArraySeq).map(_.asML.toInstance) // Select feature subset for top nodes. Return true if OK. 
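Nearly every hunk above applies the same one-line change: an `Array` that previously reached `SparkContext.parallelize` through the implicit Array-to-Seq conversion is now converted explicitly with the `toImmutableArraySeq` extension supplied by the added `org.apache.spark.util.ArrayImplicits._` import. A minimal sketch of that pattern, assuming a live `SparkContext` named `sc` and made-up sample data:

  import org.apache.spark.util.ArrayImplicits._

  // Hypothetical data; any Array[T] is handled the same way.
  val labels: Array[Double] = Array(0.0, 1.0, 1.0, 0.0)
  // parallelize takes a Seq; toImmutableArraySeq makes the Array-to-Seq
  // conversion explicit instead of leaning on the implicit conversion.
  val rdd = sc.parallelize(labels.toImmutableArraySeq, 2)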
def checkFeatureSubsetStrategy( @@ -678,7 +679,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)), Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)) ) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val numClasses = 2 val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, @@ -709,7 +710,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(0.0, 1.0, Vectors.dense(1.0, 1.0)), Instance(0.5, 1.0, Vectors.dense(1.0, 1.0)) ) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, numClasses = 0, maxBins = 32) @@ -727,7 +728,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("weights at arbitrary scale") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10) - val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance)) + val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance).toImmutableArraySeq) val rddWithSmallWeights = rddWithUnitWeights.map { inst => Instance(inst.label, 0.001, inst.features) } @@ -756,7 +757,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(0.0, 1.0, Vectors.dense(0.0)), Instance(1.0, 0.1, Vectors.dense(1.0)) ) - val rdd = sc.parallelize(data) + val rdd = sc.parallelize(data.toImmutableArraySeq) val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, minWeightFractionPerNode = 0.5) val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index a8a85391b1cb2..afca7e993b462 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.tree._ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.util.ArrayImplicits._ private[ml] object TreeTests extends SparkFunSuite { @@ -233,7 +234,7 @@ private[ml] object TreeTests extends SparkFunSuite { LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), LabeledPoint(1.0, Vectors.dense(1.0, 2.0))) - sc.parallelize(arr) + sc.parallelize(arr.toImmutableArraySeq) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index def04b5011873..18cce169b4ce9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils trait MLTest extends StreamTest with TempDirectory { self: Suite => @@ -133,7 +134,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => (globalCheckFunction: Seq[Row] => Unit): Unit = { val dfOutput = transformer.transform(dataframe) val outputs = dfOutput.select(firstResultCol, otherResultCols: _*).collect() - globalCheckFunction(outputs) + 
globalCheckFunction(outputs.toImmutableArraySeq) } def testTransformer[A : Encoder]( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index ee90c8253c47a..4b3519879d4ec 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -237,7 +238,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -277,7 +279,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -308,7 +311,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -340,7 +344,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.8) + validatePrediction(model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, + validationData, 0.8) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData, 0.8) @@ -371,7 +376,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. 
- validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -525,7 +531,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w // The validation accuracy is not good since this model (even the original weights) doesn't have // very steep curve in logistic function so that when we draw samples from distribution, it's // very easy to assign to another labels. However, this prediction result is consistent to R. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.47) + validatePrediction(model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, + validationData, 0.47) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 287ef127e64f0..a617b2ed94c8b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils object NaiveBayesSuite { @@ -160,7 +161,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -211,7 +213,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. 
validatePrediction(validationData.map(row => model.predict(row.features)), validationData) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index eacfeb7621d51..2d5292090eb36 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils object SVMSuite { @@ -52,7 +53,7 @@ object SVMSuite { val yD = new BDV(xi).dot(weightsMat) + intercept + 0.01 * rnd.nextGaussian() if (yD < 0) 0.0 else 1.0 } - y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) + y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))).toImmutableArraySeq } /** Binary labels, 3 features */ @@ -129,7 +130,8 @@ class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -161,7 +163,8 @@ class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. 
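The conversion also shows up on the result side: `RDD.collect()` returns an `Array`, which these hunks wrap before handing it to `Seq`-typed test helpers such as `validatePrediction`. A small sketch, with a hypothetical helper standing in for the suite's own:

  import org.apache.spark.rdd.RDD
  import org.apache.spark.util.ArrayImplicits._

  // Hypothetical Seq-typed check, in place of helpers like validatePrediction.
  def check(predictions: Seq[Double]): Unit = assert(predictions.nonEmpty)

  def run(rdd: RDD[Double]): Unit = {
    // collect() yields Array[Double]; convert it explicitly before the call.
    check(rdd.collect().toImmutableArraySeq)
  }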
validatePrediction(validationData.map(row => model.predict(row.features)), validationData) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index 54ed30799e7b9..2ba987b96ef79 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -62,7 +63,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } test("two clusters") { - val data = sc.parallelize(GaussianTestData.data) + val data = sc.parallelize(GaussianTestData.data.toImmutableArraySeq) // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( @@ -91,7 +92,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } test("two clusters with distributed decompositions") { - val data = sc.parallelize(GaussianTestData.data2, 2) + val data = sc.parallelize(GaussianTestData.data2.toImmutableArraySeq, 2) val k = 5 val d = data.first().size @@ -127,7 +128,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } test("two clusters with sparse data") { - val data = sc.parallelize(GaussianTestData.data) + val data = sc.parallelize(GaussianTestData.data.toImmutableArraySeq) val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray)) // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( @@ -155,7 +156,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } test("model save / load") { - val data = sc.parallelize(GaussianTestData.data) + val data = sc.parallelize(GaussianTestData.data.toImmutableArraySeq) val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) val tempDir = Utils.createTempDir() @@ -177,7 +178,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } test("model prediction, parallel and local") { - val data = sc.parallelize(GaussianTestData.data) + val data = sc.parallelize(GaussianTestData.data.toImmutableArraySeq) val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) val batchPredictions = gmm.predict(data) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 8f311bbf9f840..6c0f096dc14a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.graphx.Edge import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class LDASuite extends SparkFunSuite with MLlibTestSparkContext { @@ -73,7 +74,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { 
.setTopicConcentration(termSmoothing) .setMaxIterations(5) .setSeed(12345) - val corpus = sc.parallelize(tinyCorpus, 2) + val corpus = sc.parallelize(tinyCorpus.toImmutableArraySeq, 2) val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] @@ -189,7 +190,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("initializing with alpha length != k or 1 fails") { intercept[IllegalArgumentException] { val lda = new LDA().setK(2).setAlpha(Vectors.dense(1, 2, 3, 4)) - val corpus = sc.parallelize(tinyCorpus, 2) + val corpus = sc.parallelize(tinyCorpus.toImmutableArraySeq, 2) lda.run(corpus) } } @@ -197,14 +198,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("initializing with elements in alpha < 0 fails") { intercept[IllegalArgumentException] { val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4)) - val corpus = sc.parallelize(tinyCorpus, 2) + val corpus = sc.parallelize(tinyCorpus.toImmutableArraySeq, 2) lda.run(corpus) } } test("OnlineLDAOptimizer initialization") { val lda = new LDA().setK(2) - val corpus = sc.parallelize(tinyCorpus, 2) + val corpus = sc.parallelize(tinyCorpus.toImmutableArraySeq, 2) val op = new OnlineLDAOptimizer().initialize(corpus, lda) op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau0(567) assert(op.getAlpha.toArray.forall(_ === 0.5)) // default 1.0 / k @@ -224,7 +225,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.sparse(vocabSize, Array(0, 1, 2), Array(1, 1, 1)), // apple, orange, banana Vectors.sparse(vocabSize, Array(3, 4, 5), Array(1, 1, 1)) // tiger, cat, dog ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val corpus = sc.parallelize(docs, 2) + val corpus = sc.parallelize(docs.toImmutableArraySeq, 2) // Set GammaShape large to avoid the stochastic impact. 
val op = new OnlineLDAOptimizer().setTau0(1024).setKappa(0.51).setGammaShape(1e40) @@ -252,7 +253,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with toy data") { - val docs = sc.parallelize(toyData) + val docs = sc.parallelize(toyData.toImmutableArraySeq) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -312,7 +313,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("LocalLDAModel logPerplexity") { - val docs = sc.parallelize(toyData) + val docs = sc.parallelize(toyData.toImmutableArraySeq) val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: @@ -338,7 +339,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("LocalLDAModel predict") { - val docs = sc.parallelize(toyData) + val docs = sc.parallelize(toyData.toImmutableArraySeq) val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: @@ -389,7 +390,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with asymmetric prior") { - val docs = sc.parallelize(toyData) + val docs = sc.parallelize(toyData.toImmutableArraySeq) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -431,7 +432,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("OnlineLDAOptimizer alpha hyperparameter optimization") { val k = 2 - val docs = sc.parallelize(toyData) + val docs = sc.parallelize(toyData.toImmutableArraySeq) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(100).setOptimizeDocConcentration(true).setSampleWithReplacement(false) val lda = new LDA().setK(k) @@ -480,7 +481,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { .setTopicConcentration(topicConcentration) .setMaxIterations(5) .setSeed(12345) - val corpus = sc.parallelize(tinyCorpus, 2) + val corpus = sc.parallelize(tinyCorpus.toImmutableArraySeq, 2) val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] val tempDir2 = Utils.createTempDir() val path2 = tempDir2.toURI.toString @@ -531,7 +532,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { .zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + val distributedEmptyDocs = sc.parallelize(emptyDocs.toImmutableArraySeq, 2) val op = new EMLDAOptimizer() val lda = new LDA() @@ -551,7 +552,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { .zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + val distributedEmptyDocs = sc.parallelize(emptyDocs.toImmutableArraySeq, 2) val op = new OnlineLDAOptimizer() val lda = new LDA() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 96c1f220f6731..091220eff5452 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite import 
org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -44,18 +45,18 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark expectedPrecisions: Seq[Double], expectedRecalls: Seq[Double]): Unit = { - assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds) - assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve) + assertSequencesMatch(metrics.thresholds().collect().toImmutableArraySeq, expectedThresholds) + assertTupleSequencesMatch(metrics.roc().collect().toImmutableArraySeq, expectedROCCurve) assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5) - assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve) + assertTupleSequencesMatch(metrics.pr().collect().toImmutableArraySeq, expectedPRCurve) assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5) - assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), + assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect().toImmutableArraySeq, expectedThresholds.zip(expectedFMeasures1)) - assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), + assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect().toImmutableArraySeq, expectedThresholds.zip(expectedFmeasures2)) - assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), + assertTupleSequencesMatch(metrics.precisionByThreshold().collect().toImmutableArraySeq, expectedThresholds.zip(expectedPrecisions)) - assertTupleSequencesMatch(metrics.recallByThreshold().collect(), + assertTupleSequencesMatch(metrics.recallByThreshold().collect().toImmutableArraySeq, expectedThresholds.zip(expectedRecalls)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 4bc84b83ab1db..2ff1c9194ab6c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -76,9 +77,12 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { */ lazy val labeledDiscreteData = sc.parallelize( - Seq(LabeledPoint(0.0, Vectors.sparse(6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0)))), - LabeledPoint(1.0, Vectors.sparse(6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0)))), - LabeledPoint(1.0, Vectors.sparse(6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0)))), + Seq(LabeledPoint(0.0, Vectors.sparse( + 6, Array((0, 6.0), (1, 7.0), (3, 7.0), (4, 6.0)).toImmutableArraySeq)), + LabeledPoint(1.0, Vectors.sparse( + 6, Array((1, 9.0), (2, 6.0), (4, 5.0), (5, 9.0)).toImmutableArraySeq)), + LabeledPoint(1.0, Vectors.sparse( + 6, Array((1, 9.0), (2, 3.0), (4, 5.0), (5, 5.0)).toImmutableArraySeq)), LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 5.0, 6.0, 4.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 6.0, 5.0, 4.0, 4.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 
9.0, 6.0, 4.0, 0.0, 0.0)))), 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala index a764ac8551f2f..dfa44a6099559 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -30,7 +31,7 @@ class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { ) val scalingVec = Vectors.dense(2.0, 0.5, 0.0, 0.25) val transformer = new ElementwiseProduct(scalingVec) - val transformedData = transformer.transform(sc.makeRDD(denseData)) + val transformedData = transformer.transform(sc.makeRDD(denseData.toImmutableArraySeq)) val transformedVecs = transformedData.collect() val transformedVec = transformedVecs(0) val expectedVec = Vectors.dense(2.0, 2.0, 0.0, -2.25) @@ -42,7 +43,7 @@ class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { val sparseData = Array( Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))) ) - val dataRDD = sc.parallelize(sparseData, 3) + val dataRDD = sc.parallelize(sparseData.toImmutableArraySeq, 3) val scalingVec = Vectors.dense(1.0, 0.0, 0.5) val transformer = new ElementwiseProduct(scalingVec) val data2 = sparseData.map(transformer.transform) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index 6c07e3a5cef2e..d3c6bbb57fa06 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -43,9 +44,9 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("hashing tf on an RDD") { val hashingTF = new HashingTF val localDocs: Seq[Seq[String]] = Seq( - "a a b b b c d".split(" "), - "a b c d a b c".split(" "), - "c b a c b a a".split(" ")) + "a a b b b c d".split(" ").toImmutableArraySeq, + "a b c d a b c".split(" ").toImmutableArraySeq, + "c b a c b a a".split(" ").toImmutableArraySeq) val docs = sc.parallelize(localDocs, 2) assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 77eade64db904..81b42da8036e4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import 
org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -35,7 +36,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.sparse(3, Seq()) ) - lazy val dataRDD = sc.parallelize(data, 3) + lazy val dataRDD = sc.parallelize(data.toImmutableArraySeq, 3) test("Normalization using L1 distance") { val l1Normalizer = new Normalizer(1) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index d0c8de0e75d53..ea02f58ca906c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { @@ -31,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) ) - private lazy val dataRDD = sc.parallelize(data, 2) + private lazy val dataRDD = sc.parallelize(data.toImmutableArraySeq, 2) test("Correct computing use a PCA wrapper") { val k = dataRDD.count().toInt diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 2fad944344b31..18741f11fc07d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateSt import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.ArrayImplicits._ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -60,7 +61,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Standardization with dense input when means and stds are provided") { - val dataRDD = sc.parallelize(denseData, 3) + val dataRDD = sc.parallelize(denseData.toImmutableArraySeq, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -128,7 +129,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Standardization with dense input") { - val dataRDD = sc.parallelize(denseData, 3) + val dataRDD = sc.parallelize(denseData.toImmutableArraySeq, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -193,7 +194,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Standardization with sparse input when means and stds are provided") { - val dataRDD = sc.parallelize(sparseData, 3) + val dataRDD = sc.parallelize(sparseData.toImmutableArraySeq, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -246,7 +247,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Standardization with sparse input") { - val dataRDD = 
sc.parallelize(sparseData, 3) + val dataRDD = sc.parallelize(sparseData.toImmutableArraySeq, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -295,7 +296,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Standardization with constant input when means and stds are provided") { - val dataRDD = sc.parallelize(constantData, 2) + val dataRDD = sc.parallelize(constantData.toImmutableArraySeq, 2) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler(withMean = true, withStd = false) @@ -323,7 +324,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Standardization with constant input") { - val dataRDD = sc.parallelize(constantData, 2) + val dataRDD = sc.parallelize(constantData.toImmutableArraySeq, 2) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler(withMean = true, withStd = false) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index a700706cdc87d..9bcaa9e515c22 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -43,7 +44,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { Array(0, 3, 0, 4, 0, 4, 0, 3, 0), Array(0, 6, 0, 5, 0, 3, 0)) - val rdd = sc.parallelize(sequences, 2).cache() + val rdd = sc.parallelize(sequences.toImmutableArraySeq, 2).cache() val result1 = PrefixSpan.genFreqPatterns( rdd, minCount = 2L, maxPatternLength = 50, maxLocalProjDBSize = 16L) @@ -108,7 +109,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { Array(0, 1, 4, 0, 3, 0, 2, 3, 0, 1, 5, 0), Array(0, 5, 6, 0, 1, 2, 0, 4, 6, 0, 3, 0, 2, 0), Array(0, 5, 0, 7, 0, 1, 6, 0, 3, 0, 2, 0, 3, 0)) - val rdd = sc.parallelize(sequences, 2).cache() + val rdd = sc.parallelize(sequences.toImmutableArraySeq, 2).cache() val result = PrefixSpan.genFreqPatterns( rdd, minCount = 2, maxPatternLength = 5, maxLocalProjDBSize = 128L) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 19d424cb7ae30..a25a19e2d354b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.config.Kryo._ import org.apache.spark.ml.{linalg => newlinalg} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.util.ArrayImplicits._ class VectorsSuite extends SparkFunSuite { @@ -72,7 +73,8 @@ class VectorsSuite extends SparkFunSuite { } test("sparse vector construction with unordered elements") { - val vec = Vectors.sparse(n, indices.zip(values).reverse).asInstanceOf[SparseVector] + val vec = Vectors.sparse(n, indices.zip(values).reverse.toImmutableArraySeq) + .asInstanceOf[SparseVector] assert(vec.size === n) 
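The `Vectors.sparse` hunks follow the same rule for the `Seq[(Int, Double)]` overload: an array of (index, value) pairs is converted explicitly before the call. A brief sketch, assuming only the public mllib `Vectors` factory:

  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.util.ArrayImplicits._

  // (index, value) pairs built as an Array, passed to the Seq-based overload.
  val elements = Array((0, 1.5), (3, -2.0))
  val vec = Vectors.sparse(5, elements.toImmutableArraySeq)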
assert(vec.indices === indices) assert(vec.values === values) @@ -516,7 +518,7 @@ class VectorsSuite extends SparkFunSuite { test("sparse vector only support non-negative length") { val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) - val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)].toImmutableArraySeq) assert(v1.size === 0) assert(v2.size === 0) @@ -524,7 +526,7 @@ class VectorsSuite extends SparkFunSuite { Vectors.sparse(-1, Array(1), Array(2.0)) } intercept[IllegalArgumentException] { - Vectors.sparse(-1, Array((1, 2.0))) + Vectors.sparse(-1, Array((1, 2.0)).toImmutableArraySeq) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 9bd65296cfe5b..2fde176a82c3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -200,7 +201,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } test("svd of a low-rank matrix") { - val rows = sc.parallelize(Array.fill(4)(Vectors.dense(1.0, 1.0, 1.0)), 2) + val rows = sc.parallelize(Array.fill(4)(Vectors.dense(1.0, 1.0, 1.0)).toImmutableArraySeq, 2) val mat = new RowMatrix(rows, 4, 3) for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { val svd = mat.computeSVD(2, computeU = true, 1e-6, 300, 1e-10, mode) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index e5d5fbd33c159..cf11f848f9715 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -26,6 +26,7 @@ import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ object ALSSuite { @@ -174,7 +175,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { val model = ALS.train(ratings, 5, 15) val pairs = Array.tabulate(50, 50)((u, p) => (u - 25, p - 25)).flatten - val ans = model.predict(sc.parallelize(pairs)).collect() + val ans = model.predict(sc.parallelize(pairs.toImmutableArraySeq)).collect() ans.foreach { r => val u = r.user + 25 val p = r.product + 25 @@ -189,7 +190,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { } test("SPARK-18268: ALS with empty RDD should fail with better message") { - val ratings = sc.parallelize(Array.empty[Rating]) + val ratings = sc.parallelize(Array.empty[Rating].toImmutableArraySeq) intercept[IllegalArgumentException] { new ALS().run(ratings) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index a206e922e5fc4..aa06b70307f53 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -68,7 +69,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w private def runIsotonicRegression( labels: Seq[Double], isotonic: Boolean): IsotonicRegressionModel = { - runIsotonicRegression(labels, Array.fill(labels.size)(1d), isotonic) + runIsotonicRegression(labels, Array.fill(labels.size)(1d).toImmutableArraySeq, isotonic) } private def runIsotonicRegression( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 8eb142541c9c5..8934eb85685db 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils private object LassoSuite { @@ -73,7 +74,8 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -117,7 +119,8 @@ class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index be0834d0fd7df..d38d399f828cd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils private object LinearRegressionSuite { @@ -62,7 +63,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2).cache() // Test prediction on RDD. 
- validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -89,7 +91,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val validationRDD = sc.parallelize(validationData, 2).cache() // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + validatePrediction( + model.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, validationData) // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) @@ -124,7 +127,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { // Test prediction on RDD. validatePrediction( - model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData) + model.predict( + sparseValidationRDD.map(_.features)).collect().toImmutableArraySeq, sparseValidationData) // Test prediction on Array. validatePrediction( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 2d6aec184ad9d..74fe57a64e5b6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils private object RidgeRegressionSuite { @@ -64,12 +65,14 @@ class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val linearModel = linearReg.run(testRDD) val linearErr = predictionError( - linearModel.predict(validationRDD.map(_.features)).collect(), validationData) + linearModel.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, + validationData) val ridgeReg = new RidgeRegressionWithSGD(1.0, 200, 0.1, 1.0) val ridgeModel = ridgeReg.run(testRDD) val ridgeErr = predictionError( - ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) + ridgeModel.predict(validationRDD.map(_.features)).collect().toImmutableArraySeq, + validationData) // Ridge validation error should be lower than linear regression. 
assert(ridgeErr < linearErr, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index eec54bec3277e..9a7f6d5fcad74 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -26,13 +26,14 @@ import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { // test input data - val xData = Array(1.0, 0.0, -2.0) - val yData = Array(4.0, 5.0, 3.0) - val zeros = new Array[Double](3) + val xData = Array(1.0, 0.0, -2.0).toImmutableArraySeq + val yData = Array(4.0, 5.0, 3.0).toImmutableArraySeq + val zeros = new Array[Double](3).toImmutableArraySeq val data = Seq( Vectors.dense(1.0, 0.0, 0.0, -2.0), Vectors.dense(4.0, 5.0, 0.0, 3.0), @@ -41,8 +42,8 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { ) test("corr(x, y) pearson, 1 value in data") { - val x = sc.parallelize(Array(1.0)) - val y = sc.parallelize(Array(4.0)) + val x = sc.parallelize(Array(1.0).toImmutableArraySeq) + val y = sc.parallelize(Array(4.0).toImmutableArraySeq) intercept[IllegalArgumentException] { Statistics.corr(x, y, "pearson") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 992b876561896..e24c7a746c997 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -140,7 +141,7 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { val sparseData = Array( new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) - val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) + val chi = Statistics.chiSqTest(sc.parallelize(sparseData.toImmutableArraySeq)) assert(chi.size === numCols) assert(chi(1000) != null) // SPARK-3087 @@ -175,9 +176,9 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { // Sample data from the distributions and parallelize it val n = 100000 - val sampledNorm = sc.parallelize(stdNormalDist.sample(n), 10) - val sampledExp = sc.parallelize(expDist.sample(n), 10) - val sampledUnif = sc.parallelize(unifDist.sample(n), 10) + val sampledNorm = sc.parallelize(stdNormalDist.sample(n).toImmutableArraySeq, 10) + val sampledExp = sc.parallelize(expDist.sample(n).toImmutableArraySeq, 10) + val sampledUnif = sc.parallelize(unifDist.sample(n).toImmutableArraySeq, 10) // Use a apache math commons local KS test to verify calculations val ksTest = new KolmogorovSmirnovTest() @@ -251,7 +252,7 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { -0.461702683149641, -0.555540910137444, -0.0201353678515895, 
-0.150382224136063, -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 - ) + ).toImmutableArraySeq ) val rCompResult = Statistics.kolmogorovSmirnovTest(rData, "norm", 0, 1) assert(rCompResult.statistic ~== rKSStat relTol 1e-4) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala index 9cbb3d0024daa..d60d2350319ee 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -22,10 +22,11 @@ import org.apache.commons.math3.distribution.NormalDistribution import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.ArrayImplicits._ class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { test("kernel density single sample") { - val rdd = sc.parallelize(Array(5.0)) + val rdd = sc.parallelize(Array(5.0).toImmutableArraySeq) val evaluationPoints = Array(5.0, 6.0) val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal = new NormalDistribution(5.0, 3.0) @@ -35,7 +36,7 @@ class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { } test("kernel density multiple samples") { - val rdd = sc.parallelize(Array(5.0, 10.0)) + val rdd = sc.parallelize(Array(5.0, 10.0).toImmutableArraySeq) val evaluationPoints = Array(5.0, 6.0) val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal1 = new NormalDistribution(5.0, 3.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 9a2356f980008..2f5579acea46c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -41,7 +42,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy( Classification, Gini, @@ -65,7 +66,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression stump with 3-ary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy( Regression, Variance, @@ -93,7 +94,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression stump with binary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = 
sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy( Regression, Variance, @@ -105,7 +106,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - DecisionTreeSuite.validateRegressor(model, arr, 0.0) + DecisionTreeSuite.validateRegressor(model, arr.toImmutableArraySeq, 0.0) assert(model.numNodes === 3) assert(model.depth === 1) } @@ -113,7 +114,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) @@ -130,7 +131,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification stump with fixed label 1 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) @@ -147,7 +148,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification stump with fixed label 0 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) @@ -164,7 +165,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification stump with fixed label 1 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) @@ -180,7 +181,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) @@ -203,12 +204,12 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(1.0)), LabeledPoint(1.0, Vectors.dense(2.0)), LabeledPoint(1.0, Vectors.dense(3.0))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 2) val 
model = DecisionTree.train(rdd, strategy) - DecisionTreeSuite.validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr.toImmutableArraySeq, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) } @@ -220,12 +221,12 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 2) val model = DecisionTree.train(rdd, strategy) - DecisionTreeSuite.validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr.toImmutableArraySeq, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) assert(model.topNode.split.get.feature === 1) @@ -235,7 +236,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { " with just enough bins") { val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = maxBins, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) @@ -245,7 +246,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - DecisionTreeSuite.validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr.toImmutableArraySeq, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -264,13 +265,13 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass classification stump with continuous features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100) assert(strategy.isMulticlassClassification) val model = DecisionTree.train(rdd, strategy) - DecisionTreeSuite.validateClassifier(model, arr, 0.9) + DecisionTreeSuite.validateClassifier(model, arr.toImmutableArraySeq, 0.9) val rootNode = model.topNode @@ -284,7 +285,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass classification stump with continuous + unordered categorical features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) @@ -292,7 +293,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) - DecisionTreeSuite.validateClassifier(model, arr, 0.9) + DecisionTreeSuite.validateClassifier(model, arr.toImmutableArraySeq, 0.9) val rootNode = model.topNode @@ -305,7 +306,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass classification stump with 10-ary (ordered) categorical 
features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) @@ -326,14 +327,14 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass classification tree with 10-ary (ordered) categorical features," + " with just enough bins") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 10, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val model = DecisionTree.train(rdd, strategy) - DecisionTreeSuite.validateClassifier(model, arr, 0.6) + DecisionTreeSuite.validateClassifier(model, arr.toImmutableArraySeq, 0.6) } test("split must satisfy min instances per node requirements") { @@ -341,7 +342,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClasses = 2, minInstancesPerNode = 2) @@ -368,7 +369,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 0.0)), LabeledPoint(1.0, Vectors.dense(0.0, 0.0))) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2), numClasses = 2, minInstancesPerNode = 2) @@ -388,7 +389,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) - val input = sc.parallelize(arr) + val input = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClasses = 2, minInfoGain = 1.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 206aef78884d0..de8e8c9969d44 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError} import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -35,7 +36,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { case (numIterations, learningRate, subsamplingRate) => - val rdd = 
sc.parallelize(GradientBoostedTreesSuite.data, 2) + val rdd = sc.parallelize(GradientBoostedTreesSuite.data.toImmutableArraySeq, 2) val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) @@ -46,7 +47,8 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext assert(gbt.trees.size === numIterations) try { - EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) + EnsembleTestHelper.validateRegressor( + gbt, GradientBoostedTreesSuite.data.toImmutableArraySeq, 0.06) } catch { case e: java.lang.AssertionError => logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + @@ -65,7 +67,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext test("Regression with continuous features: Absolute Error") { GradientBoostedTreesSuite.testCombinations.foreach { case (numIterations, learningRate, subsamplingRate) => - val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + val rdd = sc.parallelize(GradientBoostedTreesSuite.data.toImmutableArraySeq, 2) val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) @@ -76,7 +78,8 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext assert(gbt.trees.size === numIterations) try { - EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae") + EnsembleTestHelper.validateRegressor( + gbt, GradientBoostedTreesSuite.data.toImmutableArraySeq, 0.85, "mae") } catch { case e: java.lang.AssertionError => logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + @@ -95,7 +98,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext test("Binary classification with continuous features: Log Loss") { GradientBoostedTreesSuite.testCombinations.foreach { case (numIterations, learningRate, subsamplingRate) => - val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + val rdd = sc.parallelize(GradientBoostedTreesSuite.data.toImmutableArraySeq, 2) val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, numClasses = 2, categoricalFeaturesInfo = Map.empty, @@ -107,7 +110,8 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext assert(gbt.trees.size === numIterations) try { - EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9) + EnsembleTestHelper.validateClassifier( + gbt, GradientBoostedTreesSuite.data.toImmutableArraySeq, 0.9) } catch { case e: java.lang.AssertionError => logError(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + @@ -162,7 +166,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val path = tempDir.toURI.toString sc.setCheckpointDir(path) - val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + val rdd = sc.parallelize(GradientBoostedTreesSuite.data.toImmutableArraySeq, 2) val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, categoricalFeaturesInfo = Map.empty, checkpointInterval = 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 078c6e6fff9fc..9b613795e7a89 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.util.ArrayImplicits._ /** * Test suites for `GiniAggregator` and `EntropyAggregator`. @@ -63,8 +64,8 @@ class ImpuritySuite extends SparkFunSuite { } } val rng = new scala.util.Random(seed) - val samples = Array.fill(10)(rng.nextDouble()) - val _weights = Array.fill(10)(rng.nextDouble()) + val samples = Array.fill(10)(rng.nextDouble()).toImmutableArraySeq + val _weights = Array.fill(10)(rng.nextDouble()).toImmutableArraySeq val smallWeights = _weights.map(_ * 0.0001) val largeWeights = _weights.map(_ * 10000) val (count, sum, sumSquared) = computeStats(samples, _weights) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index b1a385a576cea..24470a5baad09 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impurity.{Gini, Variance} import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -34,7 +35,7 @@ import org.apache.spark.util.Utils class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeatures(strategy: Strategy): Unit = { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val numTrees = 1 val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, @@ -44,8 +45,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val dt = DecisionTree.train(rdd, strategy) - EnsembleTestHelper.validateClassifier(rf, arr, 0.9) - DecisionTreeSuite.validateClassifier(dt, arr, 0.9) + EnsembleTestHelper.validateClassifier(rf, arr.toImmutableArraySeq, 0.9) + DecisionTreeSuite.validateClassifier(dt, arr.toImmutableArraySeq, 0.9) // Make sure trees are the same. assert(rfTree.toString == dt.toString) @@ -70,7 +71,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def regressionTestWithContinuousFeatures(strategy: Strategy): Unit = { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val numTrees = 1 val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, @@ -80,8 +81,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val dt = DecisionTree.train(rdd, strategy) - EnsembleTestHelper.validateRegressor(rf, arr, 0.01) - DecisionTreeSuite.validateRegressor(dt, arr, 0.01) + EnsembleTestHelper.validateRegressor(rf, arr.toImmutableArraySeq, 0.01) + DecisionTreeSuite.validateRegressor(dt, arr.toImmutableArraySeq, 0.01) // Make sure trees are the same. 
assert(rfTree.toString == dt.toString) @@ -112,7 +113,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)) arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4) - val input = sc.parallelize(arr) + val input = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) @@ -122,7 +123,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("subsampling rate in RandomForest") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr.toImmutableArraySeq) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int], useNodeIdCache = true) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index d61514bf5fd9b..82361f684ec61 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -26,6 +26,7 @@ import org.apache.spark.ml.feature._ import org.apache.spark.ml.stat.Summarizer import org.apache.spark.ml.util.TempDirectory import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils trait MLlibTestSparkContext extends TempDirectory { self: Suite => @@ -73,7 +74,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => private[spark] def standardize(instances: Array[Instance]): Array[Instance] = { val (featuresSummarizer, _) = - Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + Summarizer.getClassificationSummarizers(sc.parallelize(instances.toImmutableArraySeq)) val inverseStd = featuresSummarizer.std.toArray .map { std => if (std != 0) 1.0 / std else 0.0 } val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala index 648d2fa42eb33..679ee2991eb9e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -24,6 +24,7 @@ import io.fabric8.kubernetes.api.model._ import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils.randomize private[spark] class LocalDirsFeatureStep( @@ -61,7 +62,7 @@ private[spark] class LocalDirsFeatureStep( .withMedium(if (useLocalDirTmpFs) "Memory" else null) .endEmptyDir() .build() - } + }.toImmutableArraySeq localDirVolumeMounts = localDirVolumes .zip(resolvedLocalDirs) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 93b6ca8adc369..daf8d5e3f58a2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -40,6 +40,7 @@ import org.apache.spark.scheduler.{ExecutorDecommission, ExecutorDecommissionInf import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -225,7 +226,7 @@ private[spark] class KubernetesClusterSchedulerBackend( // If decommissioning is triggered by the executor the K8s cluster manager has already // picked the pod to evict so we don't need to update the labels. if (!triggeredByExecutor) { - labelDecommissioningExecs(executorsAndDecomInfo.map(_._1)) + labelDecommissioningExecs(executorsAndDecomInfo.map(_._1).toImmutableArraySeq) } super.decommissionExecutors(executorsAndDecomInfo, adjustTargetNumExecutors, triggeredByExecutor) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala index d200a378a38e5..f6dc7af64510b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.spark.SparkConf import org.apache.spark.resource.ResourceProfile +import org.apache.spark.util.ArrayImplicits._ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], racks: Array[String]) @@ -139,7 +140,7 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( // Only filter out the ratio which is larger than 0, which means the current host can // still be allocated with new container request. 
val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray - val racks = resolver.resolve(hosts).map(_.getNetworkLocation) + val racks = resolver.resolve(hosts.toImmutableArraySeq).map(_.getNetworkLocation) .filter(_ != null).toSet containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala index 3dd51f174b01f..9a339544d5d9c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala @@ -25,6 +25,7 @@ import org.apache.spark.deploy.yarn.Client import org.apache.spark.internal.Logging import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.YarnContainerInfoHelper /** @@ -56,7 +57,7 @@ private[spark] class YarnCoarseGrainedExecutorBackend( private lazy val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(env.conf) override def getUserClassPath: Seq[URL] = - Client.getUserClasspathUrls(env.conf, useClusterPath = true) + Client.getUserClasspathUrls(env.conf, useClusterPath = true).toImmutableArraySeq override def extractLogUrls: Map[String, String] = { YarnContainerInfoHelper.getLogUrls(hadoopConfiguration, container = None) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala index b9f60d6c24a85..bc1e7269aa0f5 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala @@ -23,6 +23,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.resource.ResourceProfile +import org.apache.spark.util.ArrayImplicits._ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers { @@ -49,7 +50,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers { val (handler, allocatorConf) = createAllocator(2) handler.updateResourceRequests() - handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + handler.handleAllocatedContainers( + Array(createContainer("host1"), createContainer("host2")).toImmutableArraySeq) ResourceProfile.clearDefaultProfile() val rp = ResourceProfile.getOrCreateDefaultProfile(allocatorConf) @@ -74,7 +76,7 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers { createContainer("host1"), createContainer("host1"), createContainer("host2") - )) + ).toImmutableArraySeq) ResourceProfile.clearDefaultProfile() val rp = ResourceProfile.getOrCreateDefaultProfile(allocatorConf) @@ -98,7 +100,7 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers { createContainer("host1"), createContainer("host1"), createContainer("host2") - )) + ).toImmutableArraySeq) ResourceProfile.clearDefaultProfile() val rp = ResourceProfile.getOrCreateDefaultProfile(allocatorConf) @@ -120,7 +122,7 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers { createContainer("host2"), 
createContainer("host2"), createContainer("host3") - )) + ).toImmutableArraySeq) ResourceProfile.clearDefaultProfile() val rp = ResourceProfile.getOrCreateDefaultProfile(allocatorConf) @@ -136,7 +138,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers { val (handler, allocatorConf) = createAllocator(2) handler.updateResourceRequests() - handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + handler.handleAllocatedContainers( + Array(createContainer("host1"), createContainer("host2")).toImmutableArraySeq) ResourceProfile.clearDefaultProfile() val rp = ResourceProfile.getOrCreateDefaultProfile(allocatorConf) @@ -154,7 +157,7 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers { createContainer("host1"), createContainer("host1"), createContainer("host2") - )) + ).toImmutableArraySeq) val pendingAllocationRequests = Seq( createContainerRequest(Array("host2", "host3")), diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 4aada0f3a8dac..efd66a912174b 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -49,6 +49,7 @@ import org.apache.spark.resource.TestResourceIDs._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.DecommissionExecutorsOnHost +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.ManualClock class MockResolver extends SparkRackResolver(SparkHadoopUtil.get.conf) { @@ -181,7 +182,7 @@ class YarnAllocatorSuite extends SparkFunSuite handler.getNumContainersPendingAllocate should be (1) val container = createContainer("host1") - handler.handleAllocatedContainers(Array(container)) + handler.handleAllocatedContainers(Array(container).toImmutableArraySeq) handler.getNumExecutorsRunning should be (1) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") @@ -213,7 +214,7 @@ class YarnAllocatorSuite extends SparkFunSuite handler.getNumContainersPendingAllocate should be (1) val container = createContainer("host1", priority = Priority.newInstance(rprof.id)) - handler.handleAllocatedContainers(Array(container)) + handler.handleAllocatedContainers(Array(container).toImmutableArraySeq) handler.getNumExecutorsRunning should be (1) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") @@ -260,7 +261,8 @@ class YarnAllocatorSuite extends SparkFunSuite priority = Priority.newInstance(rprof2.id)) val container3 = createContainer("host3", resource = containerResourcerp2, priority = Priority.newInstance(rprof2.id)) - handler.handleAllocatedContainers(Array(container, container2, container3)) + handler.handleAllocatedContainers( + Array(container, container2, container3).toImmutableArraySeq) handler.getNumExecutorsRunning should be (3) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") @@ -290,7 +292,7 @@ class YarnAllocatorSuite extends SparkFunSuite handler.updateResourceRequests() val defaultResource = handler.rpIdToYarnResource.get(defaultRPId) val container = createContainer("host1", resource = defaultResource) - handler.handleAllocatedContainers(Array(container)) + 
handler.handleAllocatedContainers(Array(container).toImmutableArraySeq) // get amount of memory and vcores from resource, so effectively skipping their validation val expectedResources = Resource.newInstance(defaultResource.getMemorySize(), @@ -370,7 +372,7 @@ class YarnAllocatorSuite extends SparkFunSuite handler.getNumContainersPendingAllocate should be (1) val container = createContainer("host1") - handler.handleAllocatedContainers(Array(container)) + handler.handleAllocatedContainers(Array(container).toImmutableArraySeq) handler.getNumExecutorsRunning should be (1) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") @@ -378,7 +380,7 @@ class YarnAllocatorSuite extends SparkFunSuite hostTocontainer.get("host1").get should contain(container.getId) val container2 = createContainer("host2") - handler.handleAllocatedContainers(Array(container2)) + handler.handleAllocatedContainers(Array(container2).toImmutableArraySeq) handler.getNumExecutorsRunning should be (1) } @@ -392,7 +394,7 @@ class YarnAllocatorSuite extends SparkFunSuite val container1 = createContainer("host1") val container2 = createContainer("host1") val container3 = createContainer("host2") - handler.handleAllocatedContainers(Array(container1, container2, container3)) + handler.handleAllocatedContainers(Array(container1, container2, container3).toImmutableArraySeq) handler.getNumExecutorsRunning should be (3) handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1") @@ -413,7 +415,7 @@ class YarnAllocatorSuite extends SparkFunSuite val container1 = createContainer("host1") val container2 = createContainer("host2") val container3 = createContainer("host4") - handler.handleAllocatedContainers(Array(container1, container2, container3)) + handler.handleAllocatedContainers(Array(container1, container2, container3).toImmutableArraySeq) handler.getNumExecutorsRunning should be (2) handler.allocatedContainerToHostMap.get(container1.getId).get should be ("host1") @@ -439,7 +441,7 @@ class YarnAllocatorSuite extends SparkFunSuite handler.getNumContainersPendingAllocate should be (3) val container = createContainer("host1") - handler.handleAllocatedContainers(Array(container)) + handler.handleAllocatedContainers(Array(container).toImmutableArraySeq) handler.getNumExecutorsRunning should be (1) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") @@ -468,7 +470,7 @@ class YarnAllocatorSuite extends SparkFunSuite val container1 = createContainer("host1") val container2 = createContainer("host2") - handler.handleAllocatedContainers(Array(container1, container2)) + handler.handleAllocatedContainers(Array(container1, container2).toImmutableArraySeq) handler.getNumExecutorsRunning should be (2) @@ -488,7 +490,7 @@ class YarnAllocatorSuite extends SparkFunSuite val container1 = createContainer("host1") val container2 = createContainer("host2") - handler.handleAllocatedContainers(Array(container1, container2)) + handler.handleAllocatedContainers(Array(container1, container2).toImmutableArraySeq) val resourceProfileToTotalExecs = mutable.HashMap(defaultRP -> 1) val numLocalityAwareTasksPerResourceProfileId = mutable.HashMap(defaultRPId -> 0) @@ -513,7 +515,7 @@ class YarnAllocatorSuite extends SparkFunSuite val container1 = createContainer("host1") val container2 = createContainer("host2") - handler.handleAllocatedContainers(Array(container1, container2)) + handler.handleAllocatedContainers(Array(container1, container2).toImmutableArraySeq) 
handler.getNumExecutorsRunning should be (2) handler.getNumContainersPendingAllocate should be (0) @@ -540,7 +542,7 @@ class YarnAllocatorSuite extends SparkFunSuite val container1 = createContainer("host1") val container2 = createContainer("host2") - handler.handleAllocatedContainers(Array(container1, container2)) + handler.handleAllocatedContainers(Array(container1, container2).toImmutableArraySeq) handler.getNumExecutorsRunning should be (2) handler.getNumContainersPendingAllocate should be (0) @@ -560,7 +562,7 @@ class YarnAllocatorSuite extends SparkFunSuite val container1 = createContainer("host1") val container2 = createContainer("host2") - handler.handleAllocatedContainers(Array(container1, container2)) + handler.handleAllocatedContainers(Array(container1, container2).toImmutableArraySeq) val resourceProfileToTotalExecs = mutable.HashMap(defaultRP -> 2) val numLocalityAwareTasksPerResourceProfileId = mutable.HashMap(defaultRPId -> 0) @@ -853,7 +855,7 @@ class YarnAllocatorSuite extends SparkFunSuite handler.getNumContainersPendingAllocate should be(1) val container = createContainer("host1") - handler.handleAllocatedContainers(Array(container)) + handler.handleAllocatedContainers(Array(container).toImmutableArraySeq) handler.getNumExecutorsRunning should be(1) handler.getNumContainersPendingAllocate should be(0) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 874303d849d7c..a945cb720b014 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Type-inference utilities for POJOs and Java collections. 
@@ -147,7 +148,7 @@ object JavaTypeInference { Option(readMethod.getName), Option(property.getWriteMethod).map(_.getName)) } - JavaBeanEncoder(ClassTag(c), fields) + JavaBeanEncoder(ClassTag(c), fields.toImmutableArraySeq) case _ => throw ExecutionErrors.cannotFindEncoderForTypeError(t.toString) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala index 2e1c33809602f..85eba2b246143 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.AttributeNameParser import org.apache.spark.sql.catalyst.util.QuotingUtils.{quoted, quoteIdentifier, quoteNameParts} import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.util.ArrayImplicits._ /** * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception @@ -49,7 +50,7 @@ class NamespaceAlreadyExistsException private[sql]( def this(namespace: Array[String]) = { this(errorClass = "SCHEMA_ALREADY_EXISTS", - Map("schemaName" -> quoteNameParts(namespace))) + Map("schemaName" -> quoteNameParts(namespace.toImmutableArraySeq))) } def this(message: String) = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala index 48c61381b1d44..b7c8473c08c04 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkThrowableHelper import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.QuotingUtils.{quoted, quoteIdentifier, quoteNameParts} import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.util.ArrayImplicits._ /** * Thrown by a catalog when an item cannot be found. 
The analyzer will rethrow the exception @@ -87,7 +88,7 @@ class NoSuchNamespaceException private( def this(namespace: Array[String]) = { this(errorClass = "SCHEMA_NOT_FOUND", - Map("schemaName" -> quoteNameParts(namespace))) + Map("schemaName" -> quoteNameParts(namespace.toImmutableArraySeq))) } def this(message: String, cause: Option[Throwable] = None) = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 9cecc1837d9a8..69661c343c5b1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, B import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * A factory for constructing encoders that convert external row to/from the Spark SQL @@ -121,6 +122,6 @@ object RowEncoder { encoderForDataType(field.dataType, lenient), field.nullable, field.metadata) - }) + }.toImmutableArraySeq) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index fbd4c1d98379b..7f21ab25ad4e5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * A row implementation that uses an array of objects as the underlying storage. Note that, while @@ -34,7 +35,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { override def get(i: Int): Any = values(i) - override def toSeq: Seq[Any] = values.clone() + override def toSeq: Seq[Any] = values.clone().toImmutableArraySeq override def copy(): GenericRow = this } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index dd24dae16ba8c..d8469d3056d57 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.QueryContext +import org.apache.spark.util.ArrayImplicits._ /** * Contexts of TreeNodes, including location, SQL text, object type and object name. 
@@ -34,7 +35,7 @@ case class Origin( stackTrace: Option[Array[StackTraceElement]] = None) { lazy val context: QueryContext = if (stackTrace.isDefined) { - DataFrameQueryContext(stackTrace.get) + DataFrameQueryContext(stackTrace.get.toImmutableArraySeq) } else { SQLQueryContext( line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index c629302214c97..aa8826dd48b66 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -20,6 +20,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.ByteArrayUtils +import org.apache.spark.util.ArrayImplicits._ /** * Concatenation of sequence of strings to final string with cheap append method @@ -107,7 +108,7 @@ object SparkStringUtils extends Logging { def getHexString(bytes: Array[Byte]): String = bytes.map("%02X".format(_)).mkString("[", " ", "]") def sideBySide(left: String, right: String): Seq[String] = { - sideBySide(left.split("\n"), right.split("\n")) + sideBySide(left.split("\n").toImmutableArraySeq, right.split("\n").toImmutableArraySeq) } def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 4a0b215c81c17..eb2ad180e3670 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -28,6 +28,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ private[sql] object ArrowUtils { @@ -182,7 +183,7 @@ private[sql] object ArrowUtils { st.names } else { if (errorOnDuplicatedFieldNames) { - throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names) + throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names.toImmutableArraySeq) } val genNawName = st.names.groupBy(identity).map { case (name, names) if names.length > 1 => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index fd063120c268e..7603a002d640e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType._ import org.apache.spark.sql.types.YearMonthIntervalType._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.Utils /** @@ -183,13 +184,13 @@ object CatalystTypeConverters { if (catalystValue == null) { null } else if (isPrimitive(elementType)) { - catalystValue.toArray[Any](elementType) + catalystValue.toArray[Any](elementType).toImmutableArraySeq } else { val result = new Array[Any](catalystValue.numElements()) catalystValue.foreach(elementType, (i, e) => { result(i) = elementConverter.toScala(e) }) - result + 
result.toImmutableArraySeq } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 02b42f11cc0bc..9233cca2ace73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.ArrayImplicits._ /** * An abstract class for row used internally in Spark SQL, which only contains the columns as @@ -92,7 +93,7 @@ abstract class InternalRow extends SpecializedGetters with Serializable { values(i) = get(i, fieldTypes(i)) i += 1 } - values + values.toImmutableArraySeq } def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3a5f60eb376ca..0b0bf86cdd039 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -57,6 +57,7 @@ import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ /** * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and @@ -1347,7 +1348,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor table.partitioning.flatMap { case IdentityTransform(FieldReference(Seq(name))) => Some(name) case _ => None - } + }.toImmutableArraySeq } private def validatePartitionSpec( @@ -2033,7 +2034,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor f } else { val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) - val fullName = normalizeFuncName(catalog.name +: ident.namespace :+ ident.name) + val fullName = + normalizeFuncName((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq) if (externalFunctionNameSet.contains(fullName)) { f } else if (catalog.asFunctionCatalog.functionExists(ident)) { @@ -2086,7 +2088,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val fullName = catalog.name +: ident.namespace :+ ident.name CatalogV2Util.loadFunction(catalog, ident).map { func => ResolvedPersistentFunc(catalog.asFunctionCatalog, ident, func) - }.getOrElse(u.copy(possibleQualifiedName = Some(fullName))) + }.getOrElse(u.copy(possibleQualifiedName = Some(fullName.toImmutableArraySeq))) } // Resolve table-valued function references. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala index ebc30ae4a6ef8..43631e1afc403 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueE import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.ArrayImplicits._ object AssignmentUtils extends SQLConfHelper with CastSupport { @@ -199,7 +200,7 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = { val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case (field, expr) => Seq(Literal(field.name), expr) - } + }.toImmutableArraySeq CreateNamedStruct(namedStructExprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala index 032cdca12c050..f3e0c0aca29ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.quoteNameParts import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.util.ArrayImplicits._ class CannotReplaceMissingTableException( tableIdentifier: Identifier, @@ -28,5 +29,5 @@ class CannotReplaceMissingTableException( extends AnalysisException( errorClass = "TABLE_OR_VIEW_NOT_FOUND", messageParameters = Map("relationName" - -> quoteNameParts(tableIdentifier.namespace :+ tableIdentifier.name)), + -> quoteNameParts((tableIdentifier.namespace :+ tableIdentifier.name).toImmutableArraySeq)), cause = cause) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f9010d47508c2..352b3124a864b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -600,7 +601,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case create: V2CreateTablePlan => val references = create.partitioning.flatMap(_.references).toSet val badReferences = references.map(_.fieldNames).flatMap { column => - create.tableSchema.findNestedField(column) match { + create.tableSchema.findNestedField(column.toImmutableArraySeq) match { case Some(_) => None case _ => @@ -627,7 +628,8 @@ trait 
CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB if c.resolved && c.defaultExpr.child.containsPattern(PLAN_EXPRESSION) => val ident = c.name.asInstanceOf[ResolvedIdentifier] val varName = toSQLId( - ident.catalog.name +: ident.identifier.namespace :+ ident.identifier.name) + (ident.catalog.name +: ident.identifier.namespace :+ ident.identifier.name) + .toImmutableArraySeq) throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions( "DECLARE VARIABLE", varName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1449764cdd595..23d63011db53f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generat import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ - +import org.apache.spark.util.ArrayImplicits._ /** * A catalog for looking up user defined functions, used by an [[Analyzer]]. @@ -142,7 +142,7 @@ object FunctionRegistryBase { .filter(_.getParameterTypes.forall(_ == classOf[Expression])) .map(_.getParameterCount).distinct.sorted throw QueryCompilationErrors.wrongNumArgsError( - name, validParametersCount, params.length) + name, validParametersCount.toImmutableArraySeq, params.length) } try { f.newInstance(expressions : _*).asInstanceOf[T] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 788f79cde99e1..253c8eb190f9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, LookupCatalog} import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.util.ArrayImplicits._ /** * Resolves the catalog of the name parts for table/view/function/namespace. 
@@ -47,15 +48,20 @@ class ResolveCatalogs(val catalogManager: CatalogManager) ResolvedIdentifier(catalog, identifier) } case s @ ShowTables(UnresolvedNamespace(Seq()), _, _) => - s.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace)) + s.copy(namespace = ResolvedNamespace(currentCatalog, + catalogManager.currentNamespace.toImmutableArraySeq)) case s @ ShowTableExtended(UnresolvedNamespace(Seq()), _, _, _) => - s.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace)) + s.copy(namespace = ResolvedNamespace(currentCatalog, + catalogManager.currentNamespace.toImmutableArraySeq)) case s @ ShowViews(UnresolvedNamespace(Seq()), _, _) => - s.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace)) + s.copy(namespace = ResolvedNamespace(currentCatalog, + catalogManager.currentNamespace.toImmutableArraySeq)) case s @ ShowFunctions(UnresolvedNamespace(Seq()), _, _, _, _) => - s.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace)) + s.copy(namespace = ResolvedNamespace(currentCatalog, + catalogManager.currentNamespace.toImmutableArraySeq)) case a @ AnalyzeTables(UnresolvedNamespace(Seq()), _) => - a.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace)) + a.copy(namespace = ResolvedNamespace(currentCatalog, + catalogManager.currentNamespace.toImmutableArraySeq)) case UnresolvedNamespace(Seq()) => ResolvedNamespace(currentCatalog, Seq.empty[String]) case UnresolvedNamespace(CatalogAndNamespace(catalog, ns)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 90a502653d043..8eccd4b104c60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.PartitioningUtils.{castPartitionSpec, normalizePartitionSpec, requireExactMatchedPartitionSpec} +import org.apache.spark.util.ArrayImplicits._ /** * Resolve [[UnresolvedPartitionSpec]] to [[ResolvedPartitionSpec]] in partition related commands. 
@@ -63,7 +64,8 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { tableName, conf.resolver) if (!allowPartitionSpec) { - requireExactMatchedPartitionSpec(tableName, normalizedSpec, partSchema.fieldNames) + requireExactMatchedPartitionSpec(tableName, normalizedSpec, + partSchema.fieldNames.toImmutableArraySeq) } val partitionNames = normalizedSpec.keySet val requestedFields = partSchema.filter(field => partitionNames.contains(field.name)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index 916c88c376610..cec44470e3a35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { @@ -72,7 +73,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { operation: RowLevelOperation): Seq[AttributeReference] = { V2ExpressionUtils.resolveRefs[AttributeReference]( - operation.requiredMetadataAttributes, + operation.requiredMetadataAttributes.toImmutableArraySeq, relation) } @@ -81,7 +82,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { operation: SupportsDelta): Seq[AttributeReference] = { val rowIdAttrs = V2ExpressionUtils.resolveRefs[AttributeReference]( - operation.rowId, + operation.rowId.toImmutableArraySeq, relation) val nullableRowIdAttrs = rowIdAttrs.filter(_.nullable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index 6634ce72d7bd3..8a9cc5706e5aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ /** * Holds the name of a namespace that has yet to be looked up in a catalog. 
It will be resolved to @@ -152,7 +153,7 @@ case class ResolvedTable( extends LeafNodeWithoutStats { override def output: Seq[Attribute] = { val qualifier = catalog.name +: identifier.namespace :+ identifier.name - outputAttributes.map(_.withQualifier(qualifier)) + outputAttributes.map(_.withQualifier(qualifier.toImmutableArraySeq)) } def name: String = (catalog.name +: identifier.namespace() :+ identifier.name()).quoted } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 066dbd9fad15b..20352129a06c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} +import org.apache.spark.util.ArrayImplicits._ /** @@ -212,7 +213,7 @@ object ClusterBySpec { resolver: Resolver): ClusterBySpec = { val normalizedColumns = clusterBySpec.columnNames.map { columnName => val position = SchemaUtils.findColumnPosition( - columnName.fieldNames(), schema, resolver) + columnName.fieldNames().toImmutableArraySeq, schema, resolver) FieldReference(SchemaUtils.getColumnName(position, schema)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 0ed89f8cba2d3..b61652f4b5234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, Interva import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ class UnivocityGenerator( schema: StructType, @@ -103,7 +104,7 @@ class UnivocityGenerator( } i += 1 } - values + values.toImmutableArraySeq } def writeHeaders(): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala index a1815cf3b3d40..b34ae6a80f576 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.connector.catalog.functions.ScalarFunction import org.apache.spark.sql.types.{AbstractDataType, DataType} +import org.apache.spark.util.ArrayImplicits._ case class ApplyFunctionExpression( function: ScalarFunction[_], @@ -35,7 +36,7 @@ case class ApplyFunctionExpression( children.forall(_.deterministic) override def foldable: Boolean = deterministic && children.forall(_.foldable) - private lazy val reusedRow = new SpecificInternalRow(function.inputTypes()) + private lazy val reusedRow = new 
SpecificInternalRow(function.inputTypes().toImmutableArraySeq) /** Returns the result of evaluating this expression on a given input Row */ override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 64b170d7e386b..a0e4039479c12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -169,8 +170,8 @@ case class CallMethodViaReflection( @transient private lazy val methodName = children(1).eval(null).asInstanceOf[UTF8String].toString /** The reflection method. */ - @transient lazy val method: Method = - CallMethodViaReflection.findMethod(className, methodName, argExprs.map(_.dataType)).orNull + @transient lazy val method: Method = CallMethodViaReflection + .findMethod(className, methodName, argExprs.map(_.dataType).toImmutableArraySeq).orNull /** A temporary buffer used to hold intermediate results returned by children. */ @transient private lazy val buffer = new Array[Object](argExprs.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ee022c068b987..f82f22772369a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} +import org.apache.spark.util.ArrayImplicits._ object Cast extends QueryErrorsBase { /** @@ -2086,7 +2087,7 @@ case class Cast( """ } val fieldsEvalCodes = ctx.splitExpressions( - expressions = fieldsEvalCode.map(_.code), + expressions = fieldsEvalCode.map(_.code).toImmutableArraySeq, funcName = "castStruct", arguments = ("InternalRow", tmpInput.code) :: (rowClass.code, tmpResult.code) :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index f2509367f934a..6aa5fefc73902 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{UserDefinedType, _} import org.apache.spark.unsafe.Platform +import org.apache.spark.util.ArrayImplicits._ /** * An interpreted unsafe projection. 
This class reuses the [[UnsafeRow]] it produces, a consumer @@ -35,7 +36,8 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe import InterpretedUnsafeProjection._ private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled - private[this] val exprs = prepareExpressions(expressions, subExprEliminationEnabled) + private[this] val exprs = + prepareExpressions(expressions.toImmutableArraySeq, subExprEliminationEnabled) /** Number of (top level) fields in the resulting row. */ private[this] val numFields = expressions.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 7d993d776d159..6733ad2981c43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.ArrayImplicits._ /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -41,7 +42,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { } override def initialize(partitionIndex: Int): Unit = { - initializeExprs(exprArray, partitionIndex) + initializeExprs(exprArray.toImmutableArraySeq, partitionIndex) } def apply(input: InternalRow): InternalRow = { @@ -141,7 +142,7 @@ object UnsafeProjection * CAUTION: the returned projection object is *not* thread-safe. */ def create(fields: Array[DataType]): UnsafeProjection = { - create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true))) + create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)).toImmutableArraySeq) } /** @@ -184,7 +185,7 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio * Returns a SafeProjection for given Array of DataTypes. 
*/ def create(fields: Array[DataType]): Projection = { - createObject(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) + createObject(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)).toImmutableArraySeq) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala index 43bf8459be59e..66a017578c351 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.ArrayImplicits._ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => @@ -221,7 +222,8 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => val row = ctx.freshVariable("row", classOf[InternalRow]) val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder]) val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder]) - val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) + val writeStructCode = + writeStructToStringBuilder(fields.map(_.dataType).toImmutableArraySeq, row, buffer, ctx) code""" |InternalRow $row = $c; |$bufferClass $buffer = new $bufferClass(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index d23ba3867dfe8..6fb6333df3b1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_MET import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * A utility class that converts public connector expressions into Catalyst expressions. 
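The V2ExpressionUtils hunk that follows (like many of the hunks above) wraps an Array-returning connector accessor such as NamedReference.fieldNames() before handing it to a Seq-taking Catalyst helper like plan.resolve. A minimal sketch of that pattern, assuming the extension simply wraps the existing array via ArraySeq.unsafeWrapArray; the RichArray and resolve names and the sample data below are illustrative stand-ins, not code from this patch:

    import scala.collection.immutable.ArraySeq

    object WrapArrayAsSeqSketch {
      // Illustrative stand-in for the Spark-internal ArrayImplicits extension; assumed to
      // wrap the existing array in O(1) via ArraySeq.unsafeWrapArray rather than copying it.
      implicit class RichArray[T](xs: Array[T]) {
        def toImmutableArraySeq: ArraySeq[T] = ArraySeq.unsafeWrapArray(xs)
      }

      // Stand-in for a Seq-taking Catalyst helper such as LogicalPlan.resolve.
      def resolve(nameParts: Seq[String]): String = nameParts.mkString(".")

      def main(args: Array[String]): Unit = {
        val fieldNames: Array[String] = Array("point", "x") // e.g. what fieldNames() returns
        println(resolve(fieldNames.toImmutableArraySeq))    // prints "point.x"
      }
    }

Wrapping at the call site keeps the hot paths free of the element-by-element copy that the default Array-to-Seq conversion performs on Scala 2.13.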
@@ -40,7 +41,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper def resolveRef[T <: NamedExpression](ref: NamedReference, plan: LogicalPlan): T = { - plan.resolve(ref.fieldNames, conf.resolver) match { + plan.resolve(ref.fieldNames.toImmutableArraySeq, conf.resolver) match { case Some(namedExpr) => namedExpr.asInstanceOf[T] case None => @@ -61,7 +62,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { ordering: Array[V2SortOrder], query: LogicalPlan, funCatalogOpt: Option[FunctionCatalog] = None): Seq[SortOrder] = { - ordering.map(toCatalyst(_, query, funCatalogOpt).asInstanceOf[SortOrder]) + ordering.map(toCatalyst(_, query, funCatalogOpt).asInstanceOf[SortOrder]).toImmutableArraySeq } def toCatalyst( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 2476c1307187c..4987e31b49911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * The ApproximatePercentile function returns the approximate percentile(s) of a column at the given @@ -310,9 +311,9 @@ object ApproximatePercentile { def getPercentiles(percentages: Array[Double]): Seq[Double] = { if (!isCompressed) compress() if (summaries.count == 0 || percentages.length == 0) { - Array.emptyDoubleArray + Array.emptyDoubleArray.toImmutableArraySeq } else { - summaries.query(percentages).get + summaries.query(percentages.toImmutableArraySeq).get } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 133a39d987459..320f1d97a4504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * A central moment is the expected value of a specified power of the deviation of a random @@ -74,7 +75,8 @@ abstract class CentralMomentAgg(child: Expression, nullOnDivideByZero: Boolean) override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4)) - override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) + override val initialValues: Seq[Expression] = + Array.fill(momentOrder + 1)(Literal(0.0)).toImmutableArraySeq override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index c798004fe7843..ea154b9bce88c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Base class for computing Pearson correlation between two expressions. @@ -55,7 +56,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression, nullOnDivideByZe override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck, xMk, yMk) - override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) + override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)).toImmutableArraySeq override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index ff31fb1128b9b..d261e78579738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Compute the covariance between two expressions. 
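The aggregate hunks around this point (CentralMomentAgg and PearsonCorrelation above, Covariance below) expose Array.fill(...) buffers as Seq[Expression] through the same wrapper. A short sketch of why the wrap differs from a plain .toSeq on Scala 2.13, again assuming the wrapper delegates to ArraySeq.unsafeWrapArray; the names below are illustrative only:

    import scala.collection.immutable.ArraySeq

    object WrapVsCopySketch {
      def main(args: Array[String]): Unit = {
        val buffer: Array[Double] = Array.fill(4)(0.0)
        val wrapped: Seq[Double] = ArraySeq.unsafeWrapArray(buffer) // O(1) wrap, shares the array
        val copied: Seq[Double]  = buffer.toSeq                     // copies the elements (Scala 2.13)
        buffer(0) = 1.0
        println(wrapped.head) // 1.0 -- the wrapper observes the mutation, hence "unsafe" in the name
        println(copied.head)  // 0.0 -- the copy is unaffected
      }
    }

In these hunks the wrapped arrays are freshly built literals that are not mutated afterwards, so the zero-copy wrap looks safe.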
@@ -48,7 +49,7 @@ abstract class Covariance(val left: Expression, val right: Expression, nullOnDiv override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck) - override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) + override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)).toImmutableArraySeq override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala index a85ac5dbd30d0..da27ba4b128db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.types.TypeCollection.NumericAndAnsiInterval +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.OpenHashMap abstract class PercentileBase @@ -166,7 +167,7 @@ abstract class PercentileBase case ((key1, count1), (key2, count2)) => (key2, count1 + count2) }.tail - percentages.map(getPercentile(accumulatedCounts, _)) + percentages.map(getPercentile(accumulatedCounts, _)).toImmutableArraySeq } private def generateOutput(percentiles: Seq[Double]): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 5d00519d27c53..4bf6eafdb2d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -57,7 +58,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], BaseOrdering] with def genComparisons(ctx: CodegenContext, schema: StructType): String = { val ordering = schema.fields.map(_.dataType).zipWithIndex.map { case(dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending) - } + }.toImmutableArraySeq genComparisons(ctx, ordering) } @@ -202,7 +203,8 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) } override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { - generatedOrdering = GenerateOrdering.generate(kryo.readObject(in, classOf[Array[SortOrder]])) + generatedOrdering = GenerateOrdering + .generate(kryo.readObject(in, classOf[Array[SortOrder]]).toImmutableArraySeq) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 64a31ed44b2b4..b35a7b412e485 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ /** * Trait to indicate the expression does not throw an exception by itself when they are evaluated. @@ -729,7 +730,7 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat } val fieldsWithIndex = structExpr.dataType.asInstanceOf[StructType].fields.zipWithIndex val existingFieldExprs: Seq[(StructField, Expression)] = - fieldsWithIndex.map { case (field, i) => (field, getFieldExpr(i)) } + fieldsWithIndex.map { case (field, i) => (field, getFieldExpr(i)) }.toImmutableArraySeq fieldOps.foldLeft(existingFieldExprs)((exprs, op) => op(exprs)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 74f2e56e64d03..0f5d4707a164b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.ArrayImplicits._ //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the expressions for hashing. @@ -456,7 +457,7 @@ abstract class HashExpression[E] extends Expression { } val hashResultType = CodeGenerator.javaType(dataType) val code = ctx.splitExpressions( - expressions = fieldsHash, + expressions = fieldsHash.toImmutableArraySeq, funcName = "computeHashForStruct", arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result), returnType = hashResultType, @@ -857,7 +858,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { } val code = ctx.splitExpressions( - expressions = fieldsHash, + expressions = fieldsHash.toImmutableArraySeq, funcName = "computeHashForStruct", arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result), returnType = CodeGenerator.JAVA_INT, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ec76c70002da7..217ed562db779 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -53,6 +53,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet import org.apache.spark.util.collection.ImmutableBitSet @@ -198,7 +199,8 @@ object Literal { case arr: ArrayType => create(Array(), arr) case map: MapType => create(Map(), map) case struct: StructType => - create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct) + create(InternalRow.fromSeq( + 
struct.fields.map(f => default(f.dataType).value).toImmutableArraySeq), struct) case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt) case other => throw QueryExecutionErrors.noDefaultForDataTypeError(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 91bd6a6ea2a99..643e108f4da4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.connector.catalog.functions.ScalarFunction import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -609,7 +610,7 @@ case class NewInstance( null } else { try { - constructor(evaluatedArgs) + constructor(evaluatedArgs.toImmutableArraySeq) } catch { // Re-throw the original exception. case e: java.lang.reflect.InvocationTargetException if e.getCause != null => @@ -1441,7 +1442,7 @@ case class ExternalMapToCatalyst private( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { - val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + val rowBuffer = InternalRow.fromSeq(Array[Any](1).toImmutableArraySeq) def rowWrapper(data: Any): InternalRow = { rowBuffer.update(0, data) rowBuffer diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 761bd3f33586e..312493c949911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * A base class for generated/interpreted predicate @@ -173,7 +174,7 @@ trait PredicateHelper extends AliasHelper with Logging { } i += 2 } - currentResult = nextResult + currentResult = nextResult.toImmutableArraySeq } currentResult.head } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 09d78f79edd17..296d093a13de6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.ArrayImplicits._ /** * An extended version of [[InternalRow]] that implements all special getters, toString @@ -170,7 +171,7 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow override protected def genericGet(ordinal: Int) = 
values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values.clone() + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values.clone().toImmutableArraySeq override def numFields: Int = values.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index cf6c7780cc82f..811d6e013ab76 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.ArrayImplicits._ //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines expressions for string operations. @@ -313,7 +314,7 @@ case class Elt( ) ) } - TypeUtils.checkForSameTypeInputExpr(inputTypes, prettyName) + TypeUtils.checkForSameTypeInputExpr(inputTypes.toImmutableArraySeq, prettyName) } } @@ -356,7 +357,7 @@ case class Elt( } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignInputValue, + expressions = assignInputValue.toImmutableArraySeq, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, returnType = CodeGenerator.JAVA_BOOLEAN, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index a1e25eb4c948a..0a243c63685cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * `JackGenerator` can only be initialized with a `StructType`, a `MapType` or an `ArrayType`. 
@@ -283,7 +284,7 @@ private[sql] class JacksonGenerator( */ def write(row: InternalRow): Unit = { writeObject(writeFields( - fieldWriters = rootFieldWriters, + fieldWriters = rootFieldWriters.toImmutableArraySeq, row = row, schema = dataType.asInstanceOf[StructType])) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonFilters.scala index 01de1e3f54127..c4e696736ee0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonFilters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonFilters.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, StructFilters} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * The class provides API for applying pushed down source filters to rows with @@ -98,7 +99,7 @@ class JsonFilters(pushedFilters: Seq[sources.Filter], schema: StructType) // Combine all filters from the same group by `And` because all filters should // return `true` to do not skip a row. The result is compiled to a predicate. .map { case (refSet, refsFilters) => - (refSet, JsonPredicate(toPredicate(refsFilters), refSet.size)) + (refSet, JsonPredicate(toPredicate(refsFilters.toImmutableArraySeq), refSet.size)) } // Apply predicates w/o references like `AlwaysTrue` and `AlwaysFalse` to all fields. // We cannot set such predicates to a particular position because skipRow() can diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 91d5e180c5967..2c73e25cb5ed1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ /* * Optimization rules defined in this file should not affect the structure of the logical plan. 
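The optimizer hunk that follows passes small Array(...) literals to the Seq-taking TreeNode.withNewChildren. A self-contained sketch of that call shape, with a toy Node standing in for Catalyst's TreeNode; all names here are illustrative, and the zero-copy wrapper is again assumed to behave like ArraySeq.unsafeWrapArray:

    import scala.collection.immutable.ArraySeq

    object WithNewChildrenSketch {
      // Toy stand-in for a Catalyst TreeNode: children are exposed and replaced as a Seq.
      final case class Node(name: String, children: Seq[Node]) {
        def withNewChildren(newChildren: Seq[Node]): Node = copy(children = newChildren)
      }

      def main(args: Array[String]): Unit = {
        val ifNode = Node("if", Seq(Node("trueValue", Nil), Node("falseValue", Nil)))
        // Wrap the Array literal explicitly instead of relying on the default implicit
        // Array-to-Seq conversion, which copies on Scala 2.13:
        val rewritten = ifNode.withNewChildren(
          ArraySeq.unsafeWrapArray(Array(Node("abs(trueValue)", Nil), Node("abs(falseValue)", Nil))))
        println(rewritten.children.map(_.name).mkString(", ")) // abs(trueValue), abs(falseValue)
      }
    }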
@@ -655,38 +656,41 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] { case u @ UnaryExpression(i @ If(_, trueValue, falseValue)) if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( - trueValue = u.withNewChildren(Array(trueValue)), - falseValue = u.withNewChildren(Array(falseValue))) + trueValue = u.withNewChildren(Array(trueValue).toImmutableArraySeq), + falseValue = u.withNewChildren(Array(falseValue).toImmutableArraySeq)) case u @ UnaryExpression(c @ CaseWhen(branches, elseValue)) if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( - branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))), - Some(u.withNewChildren(Array(elseValue.getOrElse(Literal(null, c.dataType)))))) + branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2).toImmutableArraySeq))), + Some(u.withNewChildren( + Array(elseValue.getOrElse(Literal(null, c.dataType))).toImmutableArraySeq))) case SupportedBinaryExpr(b, i @ If(_, trueValue, falseValue), right) if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( - trueValue = b.withNewChildren(Array(trueValue, right)), - falseValue = b.withNewChildren(Array(falseValue, right))) + trueValue = b.withNewChildren(Array(trueValue, right).toImmutableArraySeq), + falseValue = b.withNewChildren(Array(falseValue, right).toImmutableArraySeq)) case SupportedBinaryExpr(b, left, i @ If(_, trueValue, falseValue)) if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( - trueValue = b.withNewChildren(Array(left, trueValue)), - falseValue = b.withNewChildren(Array(left, falseValue))) + trueValue = b.withNewChildren(Array(left, trueValue).toImmutableArraySeq), + falseValue = b.withNewChildren(Array(left, falseValue).toImmutableArraySeq)) case SupportedBinaryExpr(b, c @ CaseWhen(branches, elseValue), right) if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( - branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))), - Some(b.withNewChildren(Array(elseValue.getOrElse(Literal(null, c.dataType)), right)))) + branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right).toImmutableArraySeq))), + Some(b.withNewChildren( + Array(elseValue.getOrElse(Literal(null, c.dataType)), right).toImmutableArraySeq))) case SupportedBinaryExpr(b, left, c @ CaseWhen(branches, elseValue)) if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( - branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))), - Some(b.withNewChildren(Array(left, elseValue.getOrElse(Literal(null, c.dataType)))))) + branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2).toImmutableArraySeq))), + Some(b.withNewChildren( + Array(left, elseValue.getOrElse(Literal(null, c.dataType))).toImmutableArraySeq))) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index eb501f56d81ce..ddc91db5967ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -51,6 +51,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import 
org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.random.RandomSampler /** @@ -2622,7 +2623,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { .split("\\s") .map(_.toLowerCase(Locale.ROOT).stripSuffix("s")) .filter(s => s != "interval" && s.matches("[a-z]+")) - constructMultiUnitsIntervalLiteral(ctx, interval, units) + constructMultiUnitsIntervalLiteral(ctx, interval, units.toImmutableArraySeq) } else { Literal(interval, CalendarIntervalType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 837315b96fa66..4353f535aaa4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.Utils import org.apache.spark.util.random.RandomSampler @@ -103,7 +104,8 @@ object Project { def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = { assert(plan.resolved) - val projectList = reorderFields(plan.output.map(a => (a.name, a)), schema.fields, Nil, conf) + val projectList = + reorderFields(plan.output.map(a => (a.name, a)), schema.fields.toImmutableArraySeq, Nil, conf) Project(projectList, plan) } @@ -125,8 +127,8 @@ object Project { } else { (f.name, GetStructField(col, index)) } - }, - expected.fields, + }.toImmutableArraySeq, + expected.fields.toImmutableArraySeq, columnPath, conf) if (col.nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index b02c4fac12dec..9c66e68d686d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, TypeUtils} import org.apache.spark.sql.connector.catalog.{TableCatalog, TableChange} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DataType +import org.apache.spark.util.ArrayImplicits._ /** * The base trait for commands that need to alter a v2 table with [[TableChange]]s. 
@@ -158,7 +159,7 @@ case class ReplaceColumns( null, col.getV2Default) } - deleteChanges ++ addChanges + (deleteChanges ++ addChanges).toImmutableArraySeq } override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index ac4098d4e4101..55d71ff6c24e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils // For v2 DML commands, it may end up with the v1 fallback code path and need to build a DataFrame @@ -344,7 +345,7 @@ case class WriteDelta( // validates row ID projection output is compatible with row ID attributes private def rowIdAttrsResolved: Boolean = { val rowIdAttrs = V2ExpressionUtils.resolveRefs[AttributeReference]( - operation.rowId, + operation.rowId.toImmutableArraySeq, originalTable) val projectionSchema = projections.rowIdProjection.schema @@ -358,7 +359,7 @@ case class WriteDelta( projections.metadataProjection match { case Some(projection) => val metadataAttrs = V2ExpressionUtils.resolveRefs[AttributeReference]( - operation.requiredMetadataAttributes, + operation.requiredMetadataAttributes.toImmutableArraySeq, originalTable) val projectionSchema = projection.schema @@ -406,7 +407,8 @@ trait V2CreateTableAsSelectPlan // the table schema is created from the query schema, so the only resolution needed is to check // that the columns referenced by the table's partitioning exist in the query schema val references = partitioning.flatMap(_.references).toSet - references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined) + references.map(_.fieldNames.toImmutableArraySeq) + .forall(query.schema.findNestedField(_).isDefined) } override def childrenToAnalyze: Seq[LogicalPlan] = Seq(name, query) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 017b20077cf66..cb11ec4b5f1bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet @@ -829,7 +830,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] // Sort elements for deterministic behaviours truncatedString(set.toSeq.map(formatArg(_, maxFields)).sorted, "{", ", ", "}", maxFields) case array: Array[_] => - truncatedString(array.map(formatArg(_, maxFields)), "[", ", ", "]", maxFields) + truncatedString( + 
array.map(formatArg(_, maxFields)).toImmutableArraySeq, "[", ", ", "]", maxFields) case other => other.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index d392557e650e3..29d7a39ace3c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, SQLOrderingUtil} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType} import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.ArrayImplicits._ sealed abstract class PhysicalDataType { private[sql] type InternalType @@ -316,7 +317,7 @@ case class PhysicalArrayType( case class PhysicalStructType(fields: Array[StructField]) extends PhysicalDataType { override private[sql] type InternalType = Any override private[sql] def ordering = - forSchema(this.fields.map(_.dataType)).asInstanceOf[Ordering[InternalType]] + forSchema(this.fields.map(_.dataType).toImmutableArraySeq).asInstanceOf[Ordering[InternalType]] @transient private[sql] lazy val tag = typeTag[InternalType] private[sql] def forSchema(dataTypes: Seq[DataType]): InterpretedOrdering = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index b9d83d444909d..192d812cc7aaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { @@ -178,7 +179,7 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => Seq(Literal(f.name), processStringForCharVarchar( GetStructField(expr, i, Some(f.name)), f.dataType, charFuncName, varcharFuncName)) - }) + }.toImmutableArraySeq) if (struct.valExprs.forall(_.isInstanceOf[GetStructField])) { // No field needs char/varchar processing, just return the original expression. 
expr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index e0cd6139a945f..62b6ebde4e09a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.util import scala.collection.mutable.{ArrayBuffer, ListBuffer} import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats +import org.apache.spark.util.ArrayImplicits._ /** * Helper class to compute approximate quantile summary. @@ -136,8 +137,8 @@ class QuantileSummaries( val inserted = this.withHeadBufferInserted assert(inserted.headSampled.isEmpty) assert(inserted.count == count + headSampled.size) - val compressed = - compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) + val compressed = compressImmut( + inserted.sampled.toImmutableArraySeq, mergeThreshold = 2 * relativeError * inserted.count) new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true) } @@ -304,7 +305,7 @@ class QuantileSummaries( result(pos) = approxQuantile } } - Some(result) + Some(result.toImmutableArraySeq) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 50ff3eeab0c16..65767b51c7786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ /** * This object contains fields to help process DEFAULT columns. @@ -85,7 +86,7 @@ object ResolveDefaultColumns extends QueryErrorsBase with ResolveDefaultColumnsU } else { field } - } + }.toImmutableArraySeq StructType(newFields) } else { tableSchema diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 0c49f9e46730c..133011ad9fac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, QuotingUtils} import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * Conversion helpers for working with v2 [[CatalogPlugin]]. @@ -149,7 +150,7 @@ private[sql] object CatalogV2Implicits { def original: String = ident.namespace() :+ ident.name() mkString "." 
- def asMultipartIdentifier: Seq[String] = ident.namespace :+ ident.name + def asMultipartIdentifier: Seq[String] = (ident.namespace :+ ident.name).toImmutableArraySeq def asTableIdentifier: TableIdentifier = ident.namespace match { case ns if ns.isEmpty => TableIdentifier(ident.name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index e51d65070bfa7..06887b0b95038 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.connector.expressions.LiteralValue import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{ArrayType, MapType, Metadata, MetadataBuilder, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ private[sql] object CatalogV2Util { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -151,7 +152,7 @@ private[sql] object CatalogV2Util { Option(add.comment).map(fieldWithDefault.withComment).getOrElse(fieldWithDefault) addField(schema, fieldWithComment, add.position(), tableProvider, statementType, true) case names => - replace(schema, names.init, parent => parent.dataType match { + replace(schema, names.init.toImmutableArraySeq, parent => parent.dataType match { case parentType: StructType => val field = StructField(names.last, add.dataType, nullable = add.isNullable) val fieldWithDefault: StructField = encodeDefaultValue(add.defaultValue(), field) @@ -167,21 +168,21 @@ private[sql] object CatalogV2Util { } case rename: RenameColumn => - replace(schema, rename.fieldNames, field => + replace(schema, rename.fieldNames.toImmutableArraySeq, field => Some(StructField(rename.newName, field.dataType, field.nullable, field.metadata))) case update: UpdateColumnType => - replace(schema, update.fieldNames, field => { + replace(schema, update.fieldNames.toImmutableArraySeq, field => { Some(field.copy(dataType = update.newDataType)) }) case update: UpdateColumnNullability => - replace(schema, update.fieldNames, field => { + replace(schema, update.fieldNames.toImmutableArraySeq, field => { Some(field.copy(nullable = update.nullable)) }) case update: UpdateColumnComment => - replace(schema, update.fieldNames, field => + replace(schema, update.fieldNames.toImmutableArraySeq, field => Some(field.withComment(update.newComment))) case update: UpdateColumnPosition => @@ -198,7 +199,7 @@ private[sql] object CatalogV2Util { case Array(name) => updateFieldPos(schema, name) case names => - replace(schema, names.init, parent => parent.dataType match { + replace(schema, names.init.toImmutableArraySeq, parent => parent.dataType match { case parentType: StructType => Some(parent.copy(dataType = updateFieldPos(parentType, names.last))) case _ => @@ -207,7 +208,7 @@ private[sql] object CatalogV2Util { } case update: UpdateColumnDefaultValue => - replace(schema, update.fieldNames, field => + replace(schema, update.fieldNames.toImmutableArraySeq, field => // The new DEFAULT value string will be non-empty for any DDL commands that set the // default value, such as "ALTER TABLE t ALTER COLUMN c SET DEFAULT ..." (this is // enforced by the parser). 
On the other hand, commands that drop the default value such @@ -219,7 +220,7 @@ private[sql] object CatalogV2Util { }) case delete: DeleteColumn => - replace(schema, delete.fieldNames, _ => None, delete.ifExists) + replace(schema, delete.fieldNames.toImmutableArraySeq, _ => None, delete.ifExists) case _ => // ignore non-schema changes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/distributions/distributions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/distributions/distributions.scala index 599f82b4dc528..884202960d8f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/distributions/distributions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/distributions/distributions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.distributions import org.apache.spark.sql.connector.expressions.{Expression, SortOrder} +import org.apache.spark.util.ArrayImplicits._ private[sql] object LogicalDistributions { @@ -26,11 +27,11 @@ private[sql] object LogicalDistributions { } def clustered(clustering: Array[Expression]): ClusteredDistribution = { - ClusteredDistributionImpl(clustering) + ClusteredDistributionImpl(clustering.toImmutableArraySeq) } def ordered(ordering: Array[SortOrder]): OrderedDistribution = { - OrderedDistributionImpl(ordering) + OrderedDistributionImpl(ordering.toImmutableArraySeq) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 0037f52a21b73..6fabb43a895dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.util.ArrayImplicits._ /** * Helper methods for working with the logical expressions API. 
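The connector expressions.scala hunk that follows also wraps Array-backed accessors inside Scala extractors (Ref, NamedTransform) so that pattern matches see an immutable Seq. A minimal sketch of that extractor shape; the Ref and RefParts classes below are toy stand-ins rather than the connector API:

    import scala.collection.immutable.ArraySeq

    object ExtractorWrapSketch {
      // Toy stand-in for a connector NamedReference, kept Array-based to stay Java-friendly.
      final class Ref(parts: Array[String]) {
        def fieldNames(): Array[String] = parts
      }

      // Extractor returning an immutable Seq, mirroring the `object Ref { def unapply ... }` change.
      object RefParts {
        def unapply(ref: Ref): Some[Seq[String]] = Some(ArraySeq.unsafeWrapArray(ref.fieldNames()))
      }

      def main(args: Array[String]): Unit = {
        new Ref(Array("a", "b", "c")) match {
          case RefParts(parts) => println(parts.mkString(".")) // a.b.c
        }
      }
    }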
@@ -44,16 +45,17 @@ private[sql] object LogicalExpressions { def apply(name: String, arguments: Expression*): Transform = ApplyTransform(name, arguments) def bucket(numBuckets: Int, references: Array[NamedReference]): BucketTransform = - BucketTransform(literal(numBuckets, IntegerType), references) + BucketTransform(literal(numBuckets, IntegerType), references.toImmutableArraySeq) def bucket( numBuckets: Int, references: Array[NamedReference], sortedCols: Array[NamedReference]): SortedBucketTransform = - SortedBucketTransform(literal(numBuckets, IntegerType), references, sortedCols) + SortedBucketTransform(literal(numBuckets, IntegerType), + references.toImmutableArraySeq, sortedCols.toImmutableArraySeq) def clusterBy(references: Array[NamedReference]): ClusterByTransform = - ClusterByTransform(references) + ClusterByTransform(references.toImmutableArraySeq) def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference) @@ -234,7 +236,7 @@ private object Lit { */ private object Ref { def unapply(named: NamedReference): Some[Seq[String]] = { - Some(named.fieldNames) + Some(named.fieldNames.toImmutableArraySeq) } } @@ -243,7 +245,7 @@ private object Ref { */ private[sql] object NamedTransform { def unapply(transform: Transform): Some[(String, Seq[Expression])] = { - Some((transform.name, transform.arguments)) + Some((transform.name, transform.arguments.toImmutableArraySeq)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index e772b3497ac34..7399f6c621cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.internal.SQLConf.LEGACY_CTE_PRECEDENCE_POLICY import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Object for grouping error messages from exceptions thrown during query compilation. 
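The QueryCompilationErrors hunk that follows wraps a catalog.name +: ident.namespace :+ ident.name expression: because namespace() returns an Array[String], +: and :+ produce another Array, so the combined name still needs wrapping before it reaches a Seq[String] parameter. A small sketch of that detail; the catalog and namespace values are made up:

    import scala.collection.immutable.ArraySeq

    object NamePartsSketch {
      def main(args: Array[String]): Unit = {
        val namespace: Array[String] = Array("ns1", "ns2") // e.g. what Identifier.namespace() returns
        // +: and :+ on an Array yield another Array, so the whole expression still needs an
        // explicit wrap to become a Seq[String] without a copying conversion:
        val nameParts: Seq[String] = ArraySeq.unsafeWrapArray("spark_catalog" +: namespace :+ "tbl")
        println(nameParts.mkString(".")) // spark_catalog.ns1.ns2.tbl
      }
    }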
@@ -845,7 +846,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ident: Identifier, operation: String): Throwable = { unsupportedTableOperationError( - catalog.name +: ident.namespace :+ ident.name, operation) + (catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq, operation) } def unsupportedTableOperationError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 4b0e5308e651e..76d0a516a1322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpress import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkSchemaUtils @@ -53,7 +54,7 @@ private[spark] object SchemaUtils { checkSchemaColumnNameDuplication(valueType, caseSensitiveAnalysis) case structType: StructType => val fields = structType.fields - checkColumnNameDuplication(fields.map(_.name), caseSensitiveAnalysis) + checkColumnNameDuplication(fields.map(_.name).toImmutableArraySeq, caseSensitiveAnalysis) fields.foreach { field => checkSchemaColumnNameDuplication(field.dataType, caseSensitiveAnalysis) } @@ -167,7 +168,8 @@ private[spark] object SchemaUtils { isCaseSensitive: Boolean): Unit = { val extractedTransforms = transforms.map { case b: BucketTransform => - val colNames = b.columns.map(c => UnresolvedAttribute(c.fieldNames()).name) + val colNames = + b.columns.map(c => UnresolvedAttribute(c.fieldNames().toImmutableArraySeq).name) // We need to check that we're not duplicating columns within our bucketing transform checkColumnNameDuplication(colNames, isCaseSensitive) b.name -> colNames diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index bff96019e97fb..22f24d8266177 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType._ import org.apache.spark.sql.types.YearMonthIntervalType.YEAR import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.Utils /** * Random data generators for Spark SQL DataTypes. 
These generators do not generate uniformly random @@ -347,7 +348,7 @@ object RandomDataGenerator { case StructType(fields) => val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => forType(field.dataType, nullable = field.nullable, rand) - } + }.toImmutableArraySeq if (maybeFieldGenerators.forall(_.isDefined)) { val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get) Some(() => Row.fromSeq(fieldGenerators.map(_.apply()))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index ec40989e6b78e..840d80ffed13f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ class RowTest extends AnyFunSpec with Matchers { @@ -128,7 +129,7 @@ class RowTest extends AnyFunSpec with Matchers { def modifyValues(values: Seq[Any]): Seq[Any] = { val array = values.toArray array(2) = "42" - array + array.toImmutableArraySeq } it("copy should return same ref for external rows") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index 882bed48f8ffb..c9e37e255ab44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -26,13 +26,14 @@ import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, Table, Tabl import org.apache.spark.sql.connector.expressions.Expressions import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ class CreateTablePartitioningValidationSuite extends AnalysisTest { val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, false) test("CreateTableAsSelect: fail missing top-level column") { val plan = CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, "does_not_exist") :: Nil, TestRelation2, tableSpec, @@ -47,7 +48,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { test("CreateTableAsSelect: fail missing top-level column nested reference") { val plan = CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, "does_not_exist.z") :: Nil, TestRelation2, tableSpec, @@ -62,7 +63,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { test("CreateTableAsSelect: fail missing nested column") { val plan = CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, "point.z") :: Nil, TestRelation2, tableSpec, @@ -77,7 +78,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { test("CreateTableAsSelect: fail with multiple errors") { val plan = 
CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, "does_not_exist", "point.z") :: Nil, TestRelation2, tableSpec, @@ -92,7 +93,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { test("CreateTableAsSelect: success with top-level column") { val plan = CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, "id") :: Nil, TestRelation2, tableSpec, @@ -104,7 +105,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { test("CreateTableAsSelect: success using nested column") { val plan = CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, "point.x") :: Nil, TestRelation2, tableSpec, @@ -116,7 +117,7 @@ class CreateTablePartitioningValidationSuite extends AnalysisTest { test("CreateTableAsSelect: success using complex column") { val plan = CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, "point") :: Nil, TestRelation2, tableSpec, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 6f5f22a84bad8..474c27c0de9a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.types.{ArrayType, StructType, _} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val random = new scala.util.Random @@ -623,7 +624,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { def checkResult(schema: StructType, input: InternalRow): Unit = { val exprs = schema.fields.zipWithIndex.map { case (f, i) => BoundReference(i, f.dataType, true) - } + }.toImmutableArraySeq val murmur3HashExpr = Murmur3Hash(exprs, 42) val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr)) val murmursHashEval = Murmur3Hash(exprs, 42).eval(input) @@ -672,7 +673,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 until O).map(_ => StructField("structOfStructOfStructOfString", outer)).toArray) val exprs = schema.fields.zipWithIndex.map { case (f, i) => BoundReference(i, f.dataType, true) - } + }.toImmutableArraySeq val murmur3HashExpr = Murmur3Hash(exprs, 42) val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index b79df0e40e99e..5143039281b48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.ArrayImplicits._ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -39,7 +40,8 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { StructType.fromDDL("a INT, b STRING"), ObjectType(classOf[java.lang.Integer])) def createMutableProjection(dataTypes: Array[DataType]): MutableProjection = { - MutableProjection.create(dataTypes.zipWithIndex.map(x => BoundReference(x._2, x._1, true))) + MutableProjection.create( + dataTypes.zipWithIndex.map(x => BoundReference(x._2, x._1, true)).toImmutableArraySeq) } testBothCodegenAndInterpreted("fixed-length types") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index 448890223c879..52ab9ed46c6da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, GenerateOrdering, LazilyGeneratedOrdering} import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -132,10 +133,10 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { val sortOrder = Literal("abc").asc // this is passing prior to SPARK-16845, and it should also be passing after SPARK-16845 - GenerateOrdering.generate(Array.fill(40)(sortOrder)) + GenerateOrdering.generate(Array.fill(40)(sortOrder).toImmutableArraySeq) // verify that we can support up to 5000 ordering comparisons, which should be sufficient - GenerateOrdering.generate(Array.fill(5000)(sortOrder)) + GenerateOrdering.generate(Array.fill(5000)(sortOrder).toImmutableArraySeq) } test("SPARK-21344: BinaryType comparison does signed byte array comparison") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 44264a846630e..790aa94b5840f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.ArrayImplicits._ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase with ExpressionEvalHelper { @@ -40,7 +41,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val factory = UnsafeProjection val fieldTypes: Array[DataType] = 
Array(LongType, LongType, IntegerType) val converter = factory.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) row.setLong(0, 0) row.setLong(1, 1) row.setInt(2, 2) @@ -79,7 +80,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = factory.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) row.update(2, "World".getBytes(StandardCharsets.UTF_8)) @@ -100,7 +101,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = factory.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))) @@ -131,7 +132,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val fieldTypes: Array[DataType] = Array(LongType, StringType, CalendarIntervalType) val converter = factory.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) val interval1 = new CalendarInterval(3, 1, 1000L) @@ -175,7 +176,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val converter = factory.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { - val r = new SpecificInternalRow(fieldTypes) + val r = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) for (i <- fieldTypes.indices) { r.setNullAt(i) } @@ -204,7 +205,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB // columns, then the serialized row representation should be identical to what we would get by // creating an entirely null row via the converter val rowWithNoNullColumns: InternalRow = { - val r = new SpecificInternalRow(fieldTypes) + val r = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) r.setNullAt(0) r.setBoolean(1, false) r.setByte(2, 20) @@ -282,7 +283,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val fieldTypes: Array[DataType] = Array(CalendarIntervalType, CalendarIntervalType) val converter = factory.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) for (i <- 0 until row.numFields) { row.setInterval(i, null) } @@ -312,13 +313,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB val fieldTypes: Array[DataType] = Array(ArrayType(CalendarIntervalType)) val converter = factory.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) + val row = new SpecificInternalRow(fieldTypes.toImmutableArraySeq) val values = Array(new CalendarInterval(0, 7, 0L), null) - import org.apache.spark.util.ArrayImplicits._ row.update(0, createArray(values.toImmutableArraySeq: _*)) val unsafeRow: UnsafeRow = converter.apply(row) - testArrayInterval(unsafeRow.getArray(0), values) + testArrayInterval(unsafeRow.getArray(0), values.toImmutableArraySeq) } 
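// A minimal, self-contained sketch of the conversion these hunks apply throughout.
// Assumption: org.apache.spark.util.ArrayImplicits.toImmutableArraySeq behaves like a thin
// wrapper over scala.collection.immutable.ArraySeq.unsafeWrapArray, exposing an Array[T] as
// an immutable Seq[T] without copying its elements; plain Scala 2.13 collections are used
// here so the snippet runs without a Spark dependency.
import scala.collection.immutable.ArraySeq

object ImmutableArraySeqSketch {
  def main(args: Array[String]): Unit = {
    val fieldTypes: Array[String] = Array("LongType", "StringType", "BinaryType")

    // Wraps the existing array in place; no element copy is performed.
    val wrapped: Seq[String] = ArraySeq.unsafeWrapArray(fieldTypes)

    // Array.toSeq on Scala 2.13 also returns an immutable Seq, but via a defensive copy;
    // the explicit wrap above is what avoids that copy on hot paths.
    val copied: Seq[String] = fieldTypes.toSeq

    println(wrapped == copied) // true: same elements, different allocation behavior
  }
}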
testBothCodegenAndInterpreted("basic conversion with struct type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala index 57f1e8e251b31..efc90716e31ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundRefer import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { @@ -78,13 +79,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), - endpointsExpression = CreateArray(Array(10L).map(Literal(_)))) + endpointsExpression = CreateArray(Array(10L).map(Literal(_)).toImmutableArraySeq)) assert(wrongEndpoints.checkInputDataTypes() == DataTypeMismatch("WRONG_NUM_ENDPOINTS", Map("actualNumber" -> "1"))) wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), - endpointsExpression = CreateArray(Array("foobar").map(Literal(_)))) + endpointsExpression = CreateArray(Array("foobar").map(Literal(_)).toImmutableArraySeq)) // scalastyle:off line.size.limit assert(wrongEndpoints.checkInputDataTypes() == DataTypeMismatch( @@ -104,7 +105,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = { val input = new SpecificInternalRow(Seq(dt)) val aggFunc = ApproxCountDistinctForIntervals( - BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd) + BoundReference(0, dt, nullable = true), + CreateArray(endpoints.map(Literal(_)).toImmutableArraySeq), rsd) (aggFunc, input, aggFunc.createAggregationBuffer()) } @@ -151,7 +153,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { value: Double, expectedIntervalIndex: Int): Unit = { val aggFunc = ApproxCountDistinctForIntervals( - BoundReference(0, DoubleType, nullable = true), CreateArray(endpoints.map(Literal(_)))) + BoundReference(0, DoubleType, nullable = true), + CreateArray(endpoints.map(Literal(_)).toImmutableArraySeq)) assert(aggFunc.findHllppIndex(value) == expectedIntervalIndex) } val endpoints = Array[Double](0, 3, 6, 10) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 180665e653727..3a322b8de02be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import 
org.apache.spark.util.ArrayImplicits._ /** * A test suite for generated projections @@ -75,7 +76,7 @@ class GeneratedProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { // test generated MutableProjection val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => BoundReference(i, f.dataType, true) - } + }.toImmutableArraySeq val mutableProj = GenerateMutableProjection.generate(exprs) val row1 = mutableProj(result) assert(result === row1) @@ -126,7 +127,7 @@ class GeneratedProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { // test generated MutableProjection val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => BoundReference(i, f.dataType, true) - } + }.toImmutableArraySeq val mutableProj = GenerateMutableProjection.generate(exprs) val row1 = mutableProj(result) assert(result === row1) @@ -240,7 +241,7 @@ class GeneratedProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { // test generated MutableProjection val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => BoundReference(i, f.dataType, true) - } + }.toImmutableArraySeq val mutableProj = GenerateMutableProjection.generate(exprs) val row1 = mutableProj(result) assert(result === row1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index db2d8c24fd0ba..e421d5f392926 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.SupportsNamespaces import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, ByteType, IntegerType, LongType} +import org.apache.spark.util.ArrayImplicits._ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { val attribute = attr("key") @@ -268,7 +269,8 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { test("command should report a dummy stats") { val plan = CommentOnNamespace( - ResolvedNamespace(mock(classOf[SupportsNamespaces]), Array("ns")), "comment") + ResolvedNamespace(mock(classOf[SupportsNamespaces]), + Array("ns").toImmutableArraySeq), "comment") checkStats( plan, expectedStatsCboOn = Statistics.DUMMY, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/GenericArrayDataBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/GenericArrayDataBenchmark.scala index a2800b3faaeed..7a72258b43c9c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/GenericArrayDataBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/GenericArrayDataBenchmark.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.util.ArrayImplicits._ /** * Benchmark for [[GenericArrayData]]. 
@@ -58,7 +59,7 @@ object GenericArrayDataBenchmark extends BenchmarkBase { } benchmark.addCase("arrayOfAnyAsSeq") { _ => - val arr: Seq[Any] = new Array[Any](arraySize) + val arr: Seq[Any] = new Array[Any](arraySize).toImmutableArraySeq var n = 0 while (n < valuesPerIteration) { new GenericArrayData(arr) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala index dd3d77f26cdd3..8c9039c723eb1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionsAlreadyExistException} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * This class is used to test SupportsAtomicPartitionManagement API. @@ -61,7 +62,7 @@ class InMemoryAtomicPartitionTable ( properties: Array[util.Map[String, String]]): Unit = { if (idents.exists(partitionExists)) { throw new PartitionsAlreadyExistException( - name, idents.filter(partitionExists), partitionSchema) + name, idents.filter(partitionExists).toImmutableArraySeq, partitionSchema) } idents.zip(properties).foreach { case (ident, property) => createPartition(ident, property) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 318cbf6962c19..c1967f558c171 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ /** * A simple in-memory table. Rows are stored as a buffered group produced by each output task. 
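// A sketch of the recurring "map over an Array, then wrap" pattern in these hunks (for
// example fields.map { ... }.toImmutableArraySeq): mapping an Array yields another Array,
// so a trailing conversion is still needed wherever an immutable Seq is required. The
// Field case class and helper below are hypothetical and only mirror that shape.
import scala.collection.immutable.ArraySeq

object MapThenWrapSketch {
  final case class Field(name: String, nullable: Boolean)

  def main(args: Array[String]): Unit = {
    val fields: Array[Field] = Array(Field("id", false), Field("name", true))

    // Array.map returns an Array, not a Seq, so the result is wrapped once at the end.
    val descriptions: Seq[String] =
      ArraySeq.unsafeWrapArray(fields.map(f => s"${f.name}(nullable=${f.nullable})"))

    descriptions.foreach(println)
  }
}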
@@ -104,7 +105,7 @@ abstract class InMemoryBaseTable( def rows: Seq[InternalRow] = dataMap.values.flatten.flatMap(_.rows).toSeq val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref => - schema.findNestedField(ref.fieldNames(), includeCollections = false) match { + schema.findNestedField(ref.fieldNames().toImmutableArraySeq, includeCollections = false) match { case Some(_) => ref.fieldNames() case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.") } @@ -193,7 +194,7 @@ abstract class InMemoryBaseTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } - } + }.toImmutableArraySeq } protected def addPartitionKey(key: Seq[Any]): Unit = {} @@ -312,7 +313,8 @@ abstract class InMemoryBaseTable( private var _pushedFilters: Array[Filter] = Array.empty override def build: Scan = { - val scan = InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema) + val scan = InMemoryBatchScan( + data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema) if (evaluableFilters.nonEmpty) { scan.filter(evaluableFilters) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 04d68ba3afa52..af04816e6b6f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ /** * A simple in-memory table. Rows are stored as a buffered group produced by each output task. 
@@ -50,7 +51,8 @@ class InMemoryTable( override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper - dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) + dataMap --= InMemoryTable + .filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, filters) } override def withData(data: Array[BufferedRows]): InMemoryTable = { @@ -116,7 +118,7 @@ class InMemoryTable( import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val deleteKeys = InMemoryTable.filtersToKeys( - dataMap.keys, partCols.map(_.toSeq.quoted), filters) + dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, filters) dataMap --= deleteKeys withData(messages.map(_.asInstanceOf[BufferedRows])) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala index f4e3d9e5e1b1b..20ada0d622bca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, S import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwriteV2, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ class InMemoryTableWithV2Filter( name: String, @@ -42,7 +43,7 @@ class InMemoryTableWithV2Filter( override def deleteWhere(filters: Array[Predicate]): Unit = dataMap.synchronized { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper dataMap --= InMemoryTableWithV2Filter - .filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) + .filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, filters) } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { @@ -51,8 +52,8 @@ class InMemoryTableWithV2Filter( class InMemoryV2FilterScanBuilder(tableSchema: StructType) extends InMemoryScanBuilder(tableSchema) { - override def build: Scan = - InMemoryV2FilterBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema) + override def build: Scan = InMemoryV2FilterBatchScan( + data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema) } case class InMemoryV2FilterBatchScan( @@ -122,7 +123,7 @@ class InMemoryTableWithV2Filter( import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val deleteKeys = InMemoryTableWithV2Filter.filtersToKeys( - dataMap.keys, partCols.map(_.toSeq.quoted), predicates) + dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, predicates) dataMap --= deleteKeys withData(messages.map(_.asInstanceOf[BufferedRows])) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index 
8ac268df80bcd..0dabae43adeb5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst import org.apache.spark.sql.connector.expressions.LogicalExpressions.{bucket, clusterBy} import org.apache.spark.sql.types.DataType +import org.apache.spark.util.ArrayImplicits._ class TransformExtractorSuite extends SparkFunSuite { /** @@ -190,7 +191,7 @@ class TransformExtractorSuite extends SparkFunSuite { assert(arguments1(0).asInstanceOf[LiteralValue[Integer]].value === 16) assert(arguments1(1).asInstanceOf[NamedReference].fieldNames() === Seq("a")) assert(arguments1(2).asInstanceOf[NamedReference].fieldNames() === Seq("b")) - val copied1 = bucketTransform.withReferences(reference1) + val copied1 = bucketTransform.withReferences(reference1.toImmutableArraySeq) assert(copied1.equals(bucketTransform)) val sortedBucketTransform = bucket(16, col, sortedCol) @@ -207,7 +208,7 @@ class TransformExtractorSuite extends SparkFunSuite { assert(arguments2(2).asInstanceOf[LiteralValue[Integer]].value === 16) assert(arguments2(3).asInstanceOf[NamedReference].fieldNames() === Seq("c")) assert(arguments2(4).asInstanceOf[NamedReference].fieldNames() === Seq("d")) - val copied2 = sortedBucketTransform.withReferences(reference2) + val copied2 = sortedBucketTransform.withReferences(reference2.toImmutableArraySeq) assert(copied2.equals(sortedBucketTransform)) } @@ -248,7 +249,7 @@ class TransformExtractorSuite extends SparkFunSuite { assert(arguments.length == 2) assert(arguments(0).asInstanceOf[NamedReference].fieldNames() === Seq("a a", "b")) assert(arguments(1).asInstanceOf[NamedReference].fieldNames() === Seq("ts")) - val copied = clusterByTransform.withReferences(reference) + val copied = clusterByTransform.withReferences(reference.toImmutableArraySeq) assert(copied.equals(clusterByTransform)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 0ca55ef67fd38..bfbc3287c63c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.TypedAggUtils import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ private[sql] object Column { @@ -1131,7 +1132,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Array[String]): Column = withExpr { MultiAlias(expr, aliases) } + def as(aliases: Array[String]): Column = withExpr { + MultiAlias(expr, aliases.toImmutableArraySeq) + } /** * Gives the column an alias. 
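// A sketch of why the call sites above switch from passing a bare Array to an explicit
// conversion: with Scala 2.13, scala.Seq is immutable.Seq, so an Array argument only
// satisfies a Seq parameter through Predef's implicit, copying conversion. The `describe`
// method below is hypothetical and stands in for Spark APIs such as MultiAlias that take
// Seq[String].
import scala.collection.immutable.ArraySeq

object ExplicitSeqArgumentSketch {
  def describe(aliases: Seq[String]): String = aliases.mkString(", ")

  def main(args: Array[String]): Unit = {
    val aliases: Array[String] = Array("a", "b", "c")

    // Explicit and copy-free: the array is wrapped, not copied.
    println(describe(ArraySeq.unsafeWrapArray(aliases)))

    // Also compiles, but relies on the implicit Array-to-immutable-IndexedSeq conversion,
    // which copies the array and is deprecated in Scala 2.13.
    println(describe(aliases))
  }
}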
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 2f2857760526d..790d15267a574 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** @@ -98,11 +99,10 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { cols: Array[String], probabilities: Array[Double], relativeError: Double): Array[Array[Double]] = withOrigin { - import org.apache.spark.util.ArrayImplicits._ StatFunctions.multipleApproxQuantiles( df.select(cols.map(col).toImmutableArraySeq: _*), - cols, - probabilities, + cols.toImmutableArraySeq, + probabilities.toImmutableArraySeq, relativeError).map(_.toArray).toArray } @@ -259,7 +259,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @since 1.4.0 */ def freqItems(cols: Array[String], support: Double): DataFrame = withOrigin { - FrequentItems.singlePassFreqItems(df, cols, support) + FrequentItems.singlePassFreqItems(df, cols.toImmutableArraySeq, support) } /** @@ -278,7 +278,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @since 1.4.0 */ def freqItems(cols: Array[String]): DataFrame = withOrigin { - FrequentItems.singlePassFreqItems(df, cols, 0.01) + FrequentItems.singlePassFreqItems(df, cols.toImmutableArraySeq, 0.01) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a567a915daf66..d18b12964c6d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -64,6 +64,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils private[sql] object Dataset { @@ -254,7 +255,7 @@ class Dataset[T] private[sql]( private[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get - } + }.toImmutableArraySeq } /** @@ -282,7 +283,7 @@ class Dataset[T] private[sql]( // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." 
- schema.fieldNames.map(SchemaUtils.escapeMetaCharacters).toSeq +: data.map { row => + (schema.fieldNames.map(SchemaUtils.escapeMetaCharacters).toSeq +: data.map { row => row.toSeq.map { cell => assert(cell != null, "ToPrettyString is not nullable and should not return null value") // Escapes meta-characters not to break the `showString` format @@ -295,7 +296,7 @@ class Dataset[T] private[sql]( str } }: Seq[String] - } + }).toImmutableArraySeq } /** @@ -2075,8 +2076,8 @@ class Dataset[T] private[sql]( valueColumnName: String): DataFrame = withOrigin { withPlan { Unpivot( - Some(ids.map(_.named)), - Some(values.map(v => Seq(v.named))), + Some(ids.map(_.named).toImmutableArraySeq), + Some(values.map(v => Seq(v.named)).toImmutableArraySeq), None, variableColumnName, Seq(valueColumnName), @@ -2108,7 +2109,7 @@ class Dataset[T] private[sql]( valueColumnName: String): DataFrame = withOrigin { withPlan { Unpivot( - Some(ids.map(_.named)), + Some(ids.map(_.named).toImmutableArraySeq), None, None, variableColumnName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala index cb6fbfbb2ae3e..f4b518c1e9fbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala @@ -24,6 +24,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.spark.util.ArrayImplicits._ /** @@ -142,7 +143,7 @@ class Observation(val name: String) { case _ => false }) { val row = qe.observedMetrics.get(name) - this.metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames)) + this.metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq)) if (metrics.isDefined) { notifyAll() unregister() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index aec40c845fa00..779015ee13ebb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -54,6 +54,7 @@ import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * The entry point to programming Spark with the Dataset and DataFrame API. 
@@ -635,7 +636,7 @@ class SparkSession private( val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { val parsedPlan = sessionState.sqlParser.parsePlan(sqlText) if (args.nonEmpty) { - PosParameterizedQuery(parsedPlan, args.map(lit(_).expr)) + PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq) } else { parsedPlan } @@ -890,7 +891,7 @@ class SparkSession private( val (dataType, _) = JavaTypeInference.inferDataType(beanClass) dataType.asInstanceOf[StructType].fields.map { f => AttributeReference(f.name, f.dataType, f.nullable)() - } + }.toImmutableArraySeq } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 5f72dc59105eb..623f8136abbe3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ private[sql] object SQLUtils extends Logging { SerDe.setSQLReadObject(readSqlObject).setSQLWriteObject(writeSqlObject) @@ -192,7 +193,7 @@ private[sql] object SQLUtils extends Logging { case 's' => // Read StructType for DataFrame val fields = SerDe.readList(dis, jvmObjectTracker = null) - Row.fromSeq(fields) + Row.fromSeq(fields.toImmutableArraySeq) case _ => null } } @@ -232,7 +233,7 @@ private[sql] object SQLUtils extends Logging { sparkSession: SparkSession, filename: String): JavaRDD[Array[Byte]] = { // Parallelize the record batches to create an RDD - val batches = ArrowConverters.readArrowStreamFromFile(filename) + val batches = ArrowConverters.readArrowStreamFromFile(filename).toImmutableArraySeq JavaRDD.fromRDD(sparkSession.sparkContext.parallelize(batches, batches.length)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index c557ec4a486db..078b988eaf0bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ /** * Converts resolved v2 commands to v1 if the catalog is the session catalog. 
Since the v2 commands @@ -581,7 +582,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match { case ResolvedIdentifier(catalog, ident) if isSessionCatalog(catalog) => if (ident.namespace().length != 1) { - throw QueryCompilationErrors.requiresSinglePartNamespaceError(ident.namespace()) + throw QueryCompilationErrors + .requiresSinglePartNamespaceError(ident.namespace().toImmutableArraySeq) } Some(TableIdentifier(ident.name, Some(ident.namespace.head), Some(catalog.name))) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index e906c74f8a5ee..6678ae2dc1b98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, File import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK +import org.apache.spark.util.ArrayImplicits._ /** Holds a cached logical plan and its data */ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) @@ -186,7 +187,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { case SubqueryAlias(ident, DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _)) => isSameName(ident.qualifier :+ ident.name) && - isSameName(catalog.name() +: v2Ident.namespace() :+ v2Ident.name()) + isSameName((catalog.name() +: v2Ident.namespace() :+ v2Ident.name()).toImmutableArraySeq) case SubqueryAlias(ident, View(catalogTable, _, _)) => val v1Ident = catalogTable.identifier diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarEvaluatorFactory.scala index 95e7e509d7ba0..2ed94c36436ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarEvaluatorFactory.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.ArrayImplicits._ class ColumnarToRowEvaluatorFactory( childOutput: Seq[Attribute], @@ -73,9 +74,9 @@ class RowToColumnarEvaluatorFactory( new Iterator[ColumnarBatch] { private lazy val converters = new RowToColumnConverter(schema) private lazy val vectors: Seq[WritableColumnVector] = if (enableOffHeapColumnVector) { - OffHeapColumnVector.allocateColumns(numRows, schema) + OffHeapColumnVector.allocateColumns(numRows, schema).toImmutableArraySeq } else { - OnHeapColumnVector.allocateColumns(numRows, schema) + OnHeapColumnVector.allocateColumns(numRows, schema).toImmutableArraySeq } private lazy val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala index 45e3e41ab053d..27b8dae71711b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CommandResultExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.ArrayImplicits._ /** * Physical plan node for holding data from a command. @@ -55,7 +56,7 @@ case class CommandResultExec( } else { val numSlices = math.min( unsafeRows.length, session.leafNodeDefaultParallelism) - sparkContext.parallelize(unsafeRows, numSlices) + sparkContext.parallelize(unsafeRows.toImmutableArraySeq, numSlices) } } @@ -91,7 +92,7 @@ case class CommandResultExec( } override def executeTail(limit: Int): Array[InternalRow] = { - val taken: Seq[InternalRow] = unsafeRows.takeRight(limit) + val taken: Seq[InternalRow] = unsafeRows.takeRight(limit).toImmutableArraySeq longMetric("numOutputRows").add(taken.size) sendDriverMetrics() taken.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index c7bb3b6719157..b3b2b0eab0555 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet @@ -164,8 +165,10 @@ case class RowDataSourceScanExec( Map("ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> pushedFilters) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => - Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), - "PushedGroupByExpressions" -> seqToString(v.groupByExpressions.map(_.describe())))} ++ + Map("PushedAggregates" -> seqToString( + v.aggregateExpressions.map(_.describe()).toImmutableArraySeq), + "PushedGroupByExpressions" -> + seqToString(v.groupByExpressions.map(_.describe()).toImmutableArraySeq))} ++ topNOrLimitInfo ++ offsetInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> @@ -290,7 +293,7 @@ trait FileSourceScanLike extends DataSourceScanExec { val filePruningRunner = new FilePruningRunner(dynamicDataFilters) ret = ret.map(filePruningRunner.prune) } - setFilesNumAndSizeMetric(ret, false) + setFilesNumAndSizeMetric(ret.toImmutableArraySeq, false) val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 driverMetrics("pruningTime").set(timeTakenMs) ret @@ -691,7 +694,7 @@ case class FileSourceScanExec( selectedPartitions: Array[PartitionDirectory]): RDD[InternalRow] = { val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes val maxSplitBytes = - FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions) + FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions.toImmutableArraySeq) logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + s"open cost is considered as scanning $openCostInBytes bytes.") @@ -722,8 +725,8 @@ case class FileSourceScanExec( } }.sortBy(_.length)(implicitly[Ordering[Long]].reverse) - val partitions = - 
FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) + val partitions = FilePartition + .getFilePartitions(relation.sparkSession, splitFiles.toImmutableArraySeq, maxSplitBytes) new FileScanRDD(relation.sparkSession, readFile, partitions, new StructType(requiredSchema.fields ++ relation.partitionSchema.fields), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 602cca3327fd0..9811a1d3f33e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTab import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.ArrayImplicits._ /** * Runs a query returning the result in Hive compatible form. @@ -62,15 +63,15 @@ object HiveResult { // SHOW TABLES in Hive only output table names while our v1 command outputs // database, table name, isTemp. case ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => - executedPlan.executeCollect().map(_.getString(1)) + executedPlan.executeCollect().map(_.getString(1)).toImmutableArraySeq // SHOW TABLES in Hive only output table names while our v2 command outputs // namespace and table name. case _ : ShowTablesExec => - executedPlan.executeCollect().map(_.getString(1)) + executedPlan.executeCollect().map(_.getString(1)).toImmutableArraySeq // SHOW VIEWS in Hive only outputs view names while our v1 command outputs // namespace, viewName, and isTemporary. case ExecutedCommandExec(_: ShowViewsCommand) => - executedPlan.executeCollect().map(_.getString(1)) + executedPlan.executeCollect().map(_.getString(1)).toImmutableArraySeq case other => val timeFormatters = getTimeFormatters val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq @@ -87,7 +88,7 @@ object HiveResult { Seq(name, dataType, Option(comment.asInstanceOf[String]).getOrElse("")) .map(s => String.format("%-20s", s)) .mkString("\t") - } + }.toImmutableArraySeq } /** Formats a datum (based on the given data type) and returns the string representation. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index f178cd63dfeb3..cf974e19ef1b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.ArrayImplicits._ /** @@ -51,7 +52,7 @@ case class LocalTableScanExec( } else { val numSlices = math.min( unsafeRows.length, session.leafNodeDefaultParallelism) - sparkContext.parallelize(unsafeRows, numSlices) + sparkContext.parallelize(unsafeRows.toImmutableArraySeq, numSlices) } } @@ -85,7 +86,7 @@ case class LocalTableScanExec( } override def executeTail(limit: Int): Array[InternalRow] = { - val taken: Seq[InternalRow] = unsafeRows.takeRight(limit) + val taken: Seq[InternalRow] = unsafeRows.takeRight(limit).toImmutableArraySeq longMetric("numOutputRows").add(taken.size) sendDriverMetrics() taken.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 3d35300773bee..eb5b38d428819 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -120,7 +121,7 @@ class QueryExecution( qe.analyzed.output, qe.commandExecuted, qe.executedPlan, - result) + result.toImmutableArraySeq) case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 52bee7ecf3bae..a56488e780f8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.util.ArrayImplicits._ /** * The base class of [[SortBasedAggregationIterator]], [[TungstenAggregationIterator]] and @@ -142,7 +143,7 @@ abstract class AggregationIterator( // no-op expressions which are ignored during projection code-generation. case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) } - newMutableProjection(initExpressions, Nil) + newMutableProjection(initExpressions.toImmutableArraySeq, Nil) } // All imperative AggregateFunctions. 
@@ -222,7 +223,8 @@ abstract class AggregationIterator( } protected val processRow: (InternalRow, InternalRow) => Unit = - generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) + generateProcessRow(aggregateExpressions, + aggregateFunctions.toImmutableArraySeq, inputAttributes) protected val groupingProjection: UnsafeProjection = UnsafeProjection.create(groupingExpressions, inputAttributes) @@ -239,7 +241,8 @@ abstract class AggregationIterator( case agg: AggregateFunction => NoOp } val aggregateResult = new SpecificInternalRow(aggregateAttributes.map(_.dataType)) - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes) + val expressionAggEvalProjection = newMutableProjection( + evalExpressions.toImmutableArraySeq, bufferAttributes.toImmutableArraySeq) expressionAggEvalProjection.target(aggregateResult) val resultProjection = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b942907b6752f..72d7086e42cef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType} import org.apache.spark.unsafe.KVIterator +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -708,7 +709,8 @@ case class HashAggregateExec( // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while // generating input columns, we use `currentVars`. 
- ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input + ctx.currentVars = (new Array[ExprCode](aggregateBufferAttributes.length) ++ input) + .toImmutableArraySeq val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName) // Computes start offsets for each aggregation function code diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 6d05be72b36a6..4cc251a99db76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator +import org.apache.spark.util.ArrayImplicits._ class ObjectAggregationIterator( partIndex: Int, @@ -75,7 +76,8 @@ class ObjectAggregationIterator( } val newFunctions = initializeAggregateFunctions(newExpressions, 0) val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) - generateProcessRow(newExpressions, newFunctions, newInputAttributes) + generateProcessRow( + newExpressions, newFunctions.toImmutableArraySeq, newInputAttributes.toImmutableArraySeq) } /** @@ -119,7 +121,7 @@ class ObjectAggregationIterator( // - when creating the re-used buffer for sort-based aggregation private def createNewAggregationBuffer(): SpecificInternalRow = { val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType)) - val buffer = new SpecificInternalRow(bufferFieldTypes) + val buffer = new SpecificInternalRow(bufferFieldTypes.toImmutableArraySeq) initAggregationBuffer(buffer) buffer } @@ -186,7 +188,7 @@ class ObjectAggregationIterator( if (sortBased) { val sortIteratorFromHashMap = hashMap - .dumpToExternalSorter(groupingAttributes, aggregateFunctions) + .dumpToExternalSorter(groupingAttributes, aggregateFunctions.toImmutableArraySeq) .sortedIterator() sortBasedAggregationStore = new SortBasedAggregator( sortIteratorFromHashMap, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 7de2215037b2a..db567dcd15be8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.unsafe.KVIterator +import org.apache.spark.util.ArrayImplicits._ /** * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. 
@@ -141,7 +142,7 @@ class TungstenAggregationIterator( val groupingAttributes = groupingExpressions.map(_.toAttribute) val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) val groupingKeySchema = DataTypeUtils.fromAttributes(groupingAttributes) - val bufferSchema = DataTypeUtils.fromAttributes(bufferAttributes) + val bufferSchema = DataTypeUtils.fromAttributes(bufferAttributes.toImmutableArraySeq) val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { @@ -165,7 +166,8 @@ class TungstenAggregationIterator( // all groups and their corresponding aggregation buffers for hash-based aggregation. private[this] val hashMap = new UnsafeFixedWidthAggregationMap( initialAggregationBuffer, - DataTypeUtils.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), + DataTypeUtils.fromAttributes( + aggregateFunctions.flatMap(_.aggBufferAttributes).toImmutableArraySeq), DataTypeUtils.fromAttributes(groupingExpressions.map(_.toAttribute)), TaskContext.get(), 1024 * 16, // initial capacity @@ -258,7 +260,8 @@ class TungstenAggregationIterator( } val newFunctions = initializeAggregateFunctions(newExpressions, 0) val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) - sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes) + sortBasedProcessRow = generateProcessRow( + newExpressions, newFunctions.toImmutableArraySeq, newInputAttributes.toImmutableArraySeq) // Step 5: Get the sorted iterator from the externalSorter. sortedKVIterator = externalSorter.sortedIterator() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 9ddec74374abd..9e6a99ef9fb28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} +import org.apache.spark.util.ArrayImplicits._ /** @@ -398,7 +399,8 @@ private[sql] object ArrowConverters extends Logging { if (shouldUseRDD) { logDebug("Using RDD-based createDataFrame with Arrow optimization.") val timezone = session.sessionState.conf.sessionLocalTimeZone - val rdd = session.sparkContext.parallelize(batchesInDriver, batchesInDriver.length) + val rdd = session.sparkContext + .parallelize(batchesInDriver.toImmutableArraySeq, batchesInDriver.length) .mapPartitions { batchesInExecutors => ArrowConverters.fromBatchIterator( batchesInExecutors, @@ -419,7 +421,8 @@ private[sql] object ArrowConverters extends Logging { // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out. 
val proj = UnsafeProjection.create(attrs, attrs) - Dataset.ofRows(session, LocalRelation(attrs, data.map(r => proj(r).copy()).toArray)) + Dataset.ofRows(session, + LocalRelation(attrs, data.map(r => proj(r).copy()).toArray.toImmutableArraySeq)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 4758d739d8c25..32db94b3cdfc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{LongAccumulator, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * The default implementation of CachedBatch. @@ -192,7 +193,7 @@ class DefaultCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { }.toArray input.mapPartitionsInternal { cachedBatchIterator => - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + val columnarIterator = GenerateColumnAccessor.generate(columnTypes.toImmutableArraySeq) columnarIterator.initialize(cachedBatchIterator.asInstanceOf[Iterator[DefaultCachedBatch]], columnTypes, requestedColumnIndices.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 1a6eab27d08cc..a1e9c4229b194 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.IncrementalExecution import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ /** * A logical command that is executed for its side-effects. 
`RunnableCommand`s are @@ -167,7 +168,8 @@ case class ExplainCommand( .explainString(mode) Seq(Row(outputString)) } catch { case NonFatal(cause) => - ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) + ("Error occurred during query planning: \n" + cause.getMessage).split("\n") + .map(Row(_)).toImmutableArraySeq } } @@ -189,7 +191,8 @@ case class StreamingExplainCommand( } Seq(Row(outputString)) } catch { case NonFatal(cause) => - ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) + ("Error occurred during query planning: \n" + cause.getMessage).split("\n") + .map(Row(_)).toImmutableArraySeq } } @@ -207,6 +210,6 @@ case class ExternalCommandExecutor( override def run(sparkSession: SparkSession): Seq[Row] = { val output = runner.executeCommand(command, new CaseInsensitiveStringMap(options.asJava)) - output.map(Row(_)) + output.map(Row(_)).toImmutableArraySeq } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 9dc6d7aab8ba3..130872b10bcd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -51,6 +51,7 @@ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.PartitioningUtils import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} +import org.apache.spark.util.ArrayImplicits._ // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -267,7 +268,7 @@ case class DropTableCommand( case class DropTempViewCommand(ident: Identifier) extends LeafRunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { assert(ident.namespace().isEmpty || ident.namespace().length == 1) - val nameParts = ident.namespace() :+ ident.name() + val nameParts = (ident.namespace() :+ ident.name()).toImmutableArraySeq val catalog = sparkSession.sessionState.catalog catalog.getRawLocalOrGlobalTempView(nameParts).foreach { view => val hasViewText = view.tableMeta.viewText.isDefined @@ -765,7 +766,7 @@ case class RepairTableCommand( // scalastyle:on parvector parArray.seq } else { - statuses + statuses.toImmutableArraySeq } statusPar.flatMap { st => val name = st.getPath.getName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 5c748bcb084d8..7ed82b16cc5e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -51,6 +51,7 @@ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.PartitioningUtils import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.util.ArrayImplicits._ /** * A command to create a table with the same definition of the given existing table. 
@@ -857,7 +858,7 @@ case class DescribeColumnCommand( Row(s"bin_$index", s"lower_bound: ${bin.lo}, upper_bound: ${bin.hi}, distinct_count: ${bin.ndv}") } - header +: bins + (header +: bins).toImmutableArraySeq } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 88b7826bd91b3..6b0a1c34ed2b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types.{MetadataBuilder, StructType} import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.util.ArrayImplicits._ /** * Create or replace a view with given query plan. This command will generate some view-specific @@ -496,14 +497,15 @@ object ViewHelper extends SQLConfHelper with Logging { // Generate the query column names, throw an AnalysisException if there exists duplicate column // names. - SchemaUtils.checkColumnNameDuplication(fieldNames, conf.resolver) + SchemaUtils.checkColumnNameDuplication(fieldNames.toImmutableArraySeq, conf.resolver) // Generate the view default catalog and namespace, as well as captured SQL configs. val manager = session.sessionState.catalogManager removeReferredTempNames(removeSQLConfigs(removeQueryColumnNames(properties))) ++ - catalogAndNamespaceToProps(manager.currentCatalog.name, manager.currentNamespace) ++ + catalogAndNamespaceToProps( + manager.currentCatalog.name, manager.currentNamespace.toImmutableArraySeq) ++ sqlConfigsToProps(conf) ++ - generateQueryColumnNames(queryOutput) ++ + generateQueryColumnNames(queryOutput.toImmutableArraySeq) ++ referredTempNamesToProps(tempViewNames, tempFunctionNames, tempVariableNames) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 8bb6947488ac7..b3784dbf81373 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -53,6 +53,7 @@ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.{HadoopFSUtils, ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * The main class responsible for representing a pluggable Data Source in Spark SQL. 
In addition to @@ -270,7 +271,7 @@ case class DataSource( SourceInfo( s"FileSource[$path]", StructType(sourceDataSchema ++ partitionSchema), - partitionSchema.fieldNames) + partitionSchema.fieldNames.toImmutableArraySeq) case _ => throw QueryExecutionErrors.streamedOperatorUnsupportedByDataSourceError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index e9aa6d8d356be..448f1182a1be9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * Abstract class for writing out data in a single Spark task. @@ -80,7 +81,7 @@ abstract class FileFormatDataWriter( def writeWithMetrics(record: InternalRow, count: Long): Unit = { if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { - CustomMetrics.updateMetrics(currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics(currentMetricsValues.toImmutableArraySeq, customMetrics) } write(record) } @@ -92,7 +93,7 @@ abstract class FileFormatDataWriter( writeWithMetrics(iterator.next(), count) count += 1 } - CustomMetrics.updateMetrics(currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics(currentMetricsValues.toImmutableArraySeq, customMetrics) } /** @@ -486,7 +487,7 @@ class DynamicPartitionDataConcurrentWriter( writeWithMetrics(iterator.next(), count) count += 1 } - CustomMetrics.updateMetrics(currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics(currentMetricsValues.toImmutableArraySeq, customMetrics) if (iterator.hasNext) { count = 0L @@ -497,7 +498,7 @@ class DynamicPartitionDataConcurrentWriter( writeWithMetrics(sortIterator.next(), count) count += 1 } - CustomMetrics.updateMetrics(currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics(currentMetricsValues.toImmutableArraySeq, customMetrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 8321b1fac71ee..9dbadbd97ec79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.ArrayImplicits._ /** A helper object for writing FileFormat data out to a location. */ @@ -231,7 +232,7 @@ object FileFormatWriter extends Logging { // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single // partition rdd to make sure we at least set up one write task to write the metadata. 
val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { - sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1) + sparkSession.sparkContext.parallelize(Array.empty[InternalRow].toImmutableArraySeq, 1) } else { rdd } @@ -272,10 +273,12 @@ object FileFormatWriter extends Logging { val commitMsgs = ret.map(_.commitMsg) logInfo(s"Start to commit write Job ${description.uuid}.") - val (_, duration) = Utils.timeTakenMs { committer.commitJob(job, commitMsgs) } + val (_, duration) = Utils + .timeTakenMs { committer.commitJob(job, commitMsgs.toImmutableArraySeq) } logInfo(s"Write Job ${description.uuid} committed. Elapsed time: $duration ms.") - processStats(description.statsTrackers, ret.map(_.summary.stats), duration) + processStats( + description.statsTrackers, ret.map(_.summary.stats).toImmutableArraySeq, duration) logInfo(s"Finished processing stats for write job ${description.uuid}.") // return a set of all the partition paths that were updated during this job diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 2535440add19a..0346c9c570107 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -25,6 +25,7 @@ import org.apache.spark.paths.SparkPath import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * A file status augmented with optional metadata, which tasks and file readers can use however they @@ -101,7 +102,7 @@ class FilePruningRunner(filters: Seq[Expression]) { object PartitionDirectory { // For backward compat with code that does not know about extra file metadata def apply(values: InternalRow, files: Array[FileStatus]): PartitionDirectory = - PartitionDirectory(values, files.map(FileStatusWithMetadata(_))) + PartitionDirectory(values, files.map(FileStatusWithMetadata(_)).toImmutableArraySeq) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 0cca51cf4e393..af53e71501089 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat._ import org.apache.spark.sql.execution.vectorized.{ColumnVectorUtils, ConstantColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.NextIterator /** @@ -317,6 +318,6 @@ class FileScanRDD( override protected def getPartitions: Array[RDDPartition] = filePartitions.toArray override protected def getPreferredLocations(split: RDDPartition): Seq[String] = { - split.asInstanceOf[FilePartition].preferredLocations() + split.asInstanceOf[FilePartition].preferredLocations().toImmutableArraySeq } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 551fe253657c4..1b1eddecdb932 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{PLAN_EXPRESSION, SCALAR_ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.types.{DoubleType, FloatType, StructType} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.BitSet /** @@ -335,7 +336,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { // Here, we *explicitly* enforce the not null to `CreateStruct(structColumns)` // to avoid any risk of inconsistent schema nullability val metadataAlias = - Alias(KnownNotNull(CreateStruct(structColumns)), + Alias(KnownNotNull(CreateStruct(structColumns.toImmutableArraySeq)), FileFormat.METADATA_NAME)(exprId = metadataStruct.exprId) execution.ProjectExec( readDataColumns ++ partitionColumns :+ metadataAlias, scan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index ef4fff2360097..37de04a59e4b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * An abstract class that represents [[FileIndex]]s that are aware of partitioned tables. @@ -86,7 +87,7 @@ abstract class PartitioningAwareFileIndex( val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => // Directory has children files in it, return them - existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f)) + existingDir.filter(f => matchPathPattern(f) && isNonEmptyFile(f)).toImmutableArraySeq case None => // Directory does not exist, or has no children files diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 9b38155851e09..959a99c95cc4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. 
@@ -343,7 +344,7 @@ object PartitioningUtils extends SQLConfHelper { pathFragment.split("/").map { kv => val pair = kv.split("=", 2) (unescapePathName(pair(0)), unescapePathName(pair(1))) - } + }.toImmutableArraySeq } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala index d0ed6b02fef81..a4fd57e7dffad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec +import org.apache.spark.util.ArrayImplicits._ /** * The write files spec holds all information of [[V1WriteCommand]] if its provider is @@ -75,7 +76,7 @@ case class WriteFilesExec( // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single // partition rdd to make sure we at least set up one write task to write the metadata. val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { - session.sparkContext.parallelize(Array.empty[InternalRow], 1) + session.sparkContext.parallelize(Array.empty[InternalRow].toImmutableArraySeq, 1) } else { rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index f2b84810175e9..9be764e8b07d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDiale import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.NextIterator /** @@ -338,7 +339,8 @@ object JdbcUtils extends Logging with SQLConfHelper { new NextIterator[InternalRow] { private[this] val rs = resultSet private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema) - private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) + private[this] val mutableRow = + new SpecificInternalRow(schema.fields.map(x => x.dataType).toImmutableArraySeq) override protected def close(): Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala index e544409467360..4a608a47b3887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.sources.{And, Filter} import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ /** * Methods that can be shared when upgrading the built-in Hive. 
@@ -60,7 +61,7 @@ trait OrcFiltersBase { fields.flatMap { f => f.dataType match { case st: StructType => - getPrimitiveFields(st.fields, parentFieldNames :+ f.name) + getPrimitiveFields(st.fields.toImmutableArraySeq, parentFieldNames :+ f.name) case BinaryType => None case _: AtomicType => val fieldName = (parentFieldNames :+ f.name).quoted @@ -71,7 +72,7 @@ trait OrcFiltersBase { } } - val primitiveFields = getPrimitiveFields(schema.fields) + val primitiveFields = getPrimitiveFields(schema.fields.toImmutableArraySeq) if (caseSensitive) { primitiveFields.toMap } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index c4490b95e3b2d..b08f5546bdd8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, Schem import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.types._ import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ object OrcUtils extends Logging { @@ -502,7 +503,7 @@ object OrcUtils extends Logging { case (x, _) => throw new IllegalArgumentException( s"createAggInternalRowFromFooter should not take $x as the aggregate expression") - } + }.toImmutableArraySeq val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy, (0 until schemaWithoutGroupBy.length).toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.scala index cbe6eb99a9879..6ac96300ccd65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.scala @@ -23,6 +23,7 @@ import org.apache.parquet.io.PrimitiveColumnIO import org.apache.parquet.schema.Type.Repetition import org.apache.spark.sql.types.DataType +import org.apache.spark.util.ArrayImplicits._ /** * Rich information for a Parquet column together with its SparkSQL type. 
@@ -43,12 +44,12 @@ object ParquetColumn { def apply(sparkType: DataType, io: PrimitiveColumnIO): ParquetColumn = { this(sparkType, Some(io.getColumnDescriptor), io.getRepetitionLevel, io.getDefinitionLevel, io.getType.isRepetition(Repetition.REQUIRED), - io.getFieldPath, Seq.empty) + io.getFieldPath.toImmutableArraySeq, Seq.empty) } def apply(sparkType: DataType, io: GroupColumnIO, children: Seq[ParquetColumn]): ParquetColumn = { this(sparkType, None, io.getRepetitionLevel, io.getDefinitionLevel, io.getType.isRepetition(Repetition.REQUIRED), - io.getFieldPath, children) + io.getFieldPath.toImmutableArraySeq, children) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b325a2f54c182..bd6b5bfeb4da8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.internal.SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, MapType, StructField, StructType, UserDefinedType} +import org.apache.spark.util.ArrayImplicits._ object ParquetUtils extends Logging { @@ -141,11 +142,13 @@ object ParquetUtils extends Logging { val leaves = allFiles.toArray.sortBy(_.getPath.toString) FileTypes( - data = leaves.filterNot(f => isSummaryFile(f.getPath)), + data = leaves.filterNot(f => isSummaryFile(f.getPath)).toImmutableArraySeq, metadata = - leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE), + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + .toImmutableArraySeq, commonMetadata = - leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)) + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + .toImmutableArraySeq) } private def isSummaryFile(file: Path): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 65ebbb57fd325..12c183de19d06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.util.ArrayImplicits._ /** * Replaces [[UnresolvedRelation]]s if the plan is for direct query on files. 
@@ -273,10 +274,11 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[Logical case transform: RewritableTransform => val rewritten = transform.references().map { ref => // Throws an exception if the reference cannot be resolved - val position = SchemaUtils.findColumnPosition(ref.fieldNames(), schema, resolver) + val position = SchemaUtils + .findColumnPosition(ref.fieldNames().toImmutableArraySeq, schema, resolver) FieldReference(SchemaUtils.getColumnName(position, schema)) } - transform.withReferences(rewritten) + transform.withReferences(rewritten.toImmutableArraySeq) case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index afcc762e636a3..7cce599040189 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read._ +import org.apache.spark.util.ArrayImplicits._ /** * Physical plan node for scanning a batch of data from a data source v2. @@ -55,7 +56,8 @@ case class BatchScanExec( override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters) - @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions() + @transient override lazy val inputPartitions: Seq[InputPartition] = + batch.planInputPartitions().toImmutableArraySeq @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { val dataSourceFilters = runtimeFilters.flatMap { @@ -100,11 +102,12 @@ case class BatchScanExec( "partition values that are not present in the original partitioning.") } - groupPartitions(newPartitions).map(_.groupedParts.map(_.parts)).getOrElse(Seq.empty) + groupPartitions(newPartitions.toImmutableArraySeq) + .map(_.groupedParts.map(_.parts)).getOrElse(Seq.empty) case _ => // no validation is needed as the data source did not report any specific partitioning - newPartitions.map(Seq(_)) + newPartitions.map(Seq(_)).toImmutableArraySeq } } else { @@ -135,7 +138,7 @@ case class BatchScanExec( override lazy val inputRDD: RDD[InternalRow] = { val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) { // return an empty RDD with 1 partition if dynamic filtering removed the only split - sparkContext.parallelize(Array.empty[InternalRow], 1) + sparkContext.parallelize(Array.empty[InternalRow].toImmutableArraySeq, 1) } else { val finalPartitions = outputPartitioning match { case p: KeyGroupedPartitioning => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala index bcb7149fd0b17..288233e691453 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector.read.{InputPartition, Scan} import 
org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReaderFactory, ContinuousStream, Offset} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.util.ArrayImplicits._ /** * Physical plan node for scanning data from a streaming data source with continuous mode. @@ -43,7 +44,8 @@ case class ContinuousScanExec( override def hashCode(): Int = stream.hashCode() - override lazy val inputPartitions: Seq[InputPartition] = stream.planInputPartitions(start) + override lazy val inputPartitions: Seq[InputPartition] = + stream.planInputPartitions(start).toImmutableArraySeq override lazy val readerFactory: ContinuousPartitionReaderFactory = { stream.createContinuousReaderFactory() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 8288e7e4a845e..e46c0806ba2d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, Par import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.ArrayImplicits._ class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition]) extends Partition with Serializable @@ -89,7 +90,8 @@ class DataSourceRDD( context.addTaskCompletionListener[Unit] { _ => // In case of early stopping before consuming the entire iterator, // we need to do one more metric update at the end of the task. 
- CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics) + CustomMetrics + .updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics) iter.forceUpdateMetrics() reader.close() } @@ -128,7 +130,7 @@ private class PartitionIterator[T]( throw QueryExecutionErrors.endOfStreamError() } if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { - CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics) } numRow += 1 valuePrepared = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index b2f94cae2dfa7..45fc2a0765c0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils trait DataSourceV2ScanExecBase extends LeafExecNode { @@ -198,7 +199,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, - driveSQLMetrics) + driveSQLMetrics.toImmutableArraySeq) } override def doExecuteColumnar(): RDD[ColumnarBatch] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 0106a9c5aea0e..e6bce7a0990c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -48,6 +48,7 @@ import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper { @@ -349,7 +350,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case DropTable(r: ResolvedIdentifier, ifExists, purge) => val invalidateFunc = () => session.sharedState.cacheManager.uncacheTableOrView( - session, r.catalog.name() +: r.identifier.namespace() :+ r.identifier.name(), + session, + (r.catalog.name() +: r.identifier.namespace() :+ r.identifier.name()).toImmutableArraySeq, cascade = true) DropTableExec(r.catalog.asTableCatalog, r.identifier, ifExists, purge, invalidateFunc) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala index 9cade86829e6f..3d79a7113e0d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeTableExec.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, ResolveDefaultColumns} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsMetadataColumns, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.IdentityTransform +import org.apache.spark.util.ArrayImplicits._ case class DescribeTableExec( output: Seq[Attribute], @@ -104,7 +105,7 @@ case class DescribeTableExec( rows ++= table.partitioning .map(_.asInstanceOf[IdentityTransform].ref.fieldNames()) .map { fieldNames => - val nestedField = table.schema.findNestedField(fieldNames) + val nestedField = table.schema.findNestedField(fieldNames.toImmutableArraySeq) assert(nestedField.isDefined, s"Not found the partition column ${fieldNames.map(quoteIfNeeded).mkString(".")} " + s"in the table schema ${table.schema().catalogString}.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala index 2125b58813f85..f080a964c09ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.util.ArrayImplicits._ /** * Physical plan node for dropping a table. @@ -38,7 +39,7 @@ case class DropTableExec( if (purge) catalog.purgeTable(ident) else catalog.dropTable(ident) } else if (!ifExists) { throw QueryCompilationErrors.noSuchTableError( - catalog.name() +: ident.namespace() :+ ident.name()) + (catalog.name() +: ident.namespace() :+ ident.name()).toImmutableArraySeq) } Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala index ead5114d38680..2f443a0bb1fad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage} import org.apache.spark.sql.execution.datasources.{WriteJobDescription, WriteTaskResult} import org.apache.spark.sql.execution.datasources.FileFormatWriter.processStats +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class FileBatchWrite( @@ -33,10 +34,12 @@ class FileBatchWrite( override def commit(messages: Array[WriterCommitMessage]): Unit = { val results = messages.map(_.asInstanceOf[WriteTaskResult]) logInfo(s"Start to commit write Job ${description.uuid}.") - val (_, duration) = Utils.timeTakenMs { committer.commitJob(job, results.map(_.commitMsg)) } + val (_, duration) = Utils + .timeTakenMs { committer.commitJob(job, results.map(_.commitMsg).toImmutableArraySeq) } logInfo(s"Write Job ${description.uuid} committed. 
Elapsed time: $duration ms.") - processStats(description.statsTrackers, results.map(_.summary.stats), duration) + processStats( + description.statsTrackers, results.map(_.summary.stats).toImmutableArraySeq, duration) logInfo(s"Finished processing stats for write job ${description.uuid}.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala index 068870511ea55..52f44e33ea11f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SerializableConfiguration trait FileWrite extends Write { @@ -91,7 +92,8 @@ trait FileWrite extends Write { s"got: ${paths.mkString(", ")}") } val pathName = paths.head - SchemaUtils.checkColumnNameDuplication(schema.fields.map(_.name), caseSensitiveAnalysis) + SchemaUtils.checkColumnNameDuplication( + schema.fields.map(_.name).toImmutableArraySeq, caseSensitiveAnalysis) DataSource.validateSchema(schema, sqlConf) // TODO: [SPARK-36340] Unify check schema filed of DataSource V2 Insert. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala index c545b3dd50b59..07958987fa081 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} +import org.apache.spark.util.ArrayImplicits._ /** * Physical plan node for scanning a micro-batch of data from a data source. 
@@ -43,7 +44,8 @@ case class MicroBatchScanExec( override def hashCode(): Int = stream.hashCode() - override lazy val inputPartitions: Seq[InputPartition] = stream.planInputPartitions(start, end) + override lazy val inputPartitions: Seq[InputPartition] = + stream.planInputPartitions(start, end).toImmutableArraySeq override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala index 2a89ed8f80c99..6eede88c55bd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/OptimizeMetadataOnlyDeleteFromTable.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.RowLevelOperation import org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.util.ArrayImplicits._ /** * A rule that replaces a rewritten DELETE operation with a delete using filters if the data source @@ -46,7 +47,7 @@ object OptimizeMetadataOnlyDeleteFromTable extends Rule[LogicalPlan] with Predic val allPredicatesTranslated = normalizedPredicates.size == filters.length if (allPredicatesTranslated && table.canDeleteWhere(filters)) { logDebug(s"Switching to delete with filters: ${filters.mkString("[", ", ", "]")}") - DeleteFromTableWithFilters(relation, filters) + DeleteFromTableWithFilters(relation, filters.toImmutableArraySeq) } else { rowLevelPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index e78b1359021b7..3de4692c83b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.Utils object PushDownUtils { @@ -70,7 +71,8 @@ object PushDownUtils { // Normally translated filters (postScanFilters) are simple filters that can be evaluated // faster, while the untranslated filters are complicated filters that take more time to // evaluate, so we want to evaluate the postScanFilters filters first. - (Left(r.pushedFilters()), (postScanFilters ++ untranslatableExprs).toSeq) + (Left(r.pushedFilters().toImmutableArraySeq), + (postScanFilters ++ untranslatableExprs).toSeq) case r: SupportsPushDownV2Filters => // A map from translated data source leaf node filters to original catalyst filter @@ -102,11 +104,12 @@ object PushDownUtils { // Normally translated filters (postScanFilters) are simple filters that can be evaluated // faster, while the untranslated filters are complicated filters that take more time to // evaluate, so we want to evaluate the postScanFilters filters first. 
- (Right(r.pushedPredicates), (postScanFilters ++ untranslatableExprs).toSeq) + (Right(r.pushedPredicates.toImmutableArraySeq), + (postScanFilters ++ untranslatableExprs).toSeq) case f: FileScanBuilder => val postScanFilters = f.pushFilters(filters) - (Right(f.pushedFilters), postScanFilters) + (Right(f.pushedFilters.toImmutableArraySeq), postScanFilters) case _ => (Left(Nil), filters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowFunctionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowFunctionsExec.scala index b8a9003b559ac..b80f4ee2357d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowFunctionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowFunctionsExec.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.FunctionCatalog import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.util.ArrayImplicits._ /** * Physical plan node for showing functions. @@ -54,9 +55,10 @@ case class ShowFunctionsExec( // List all temporary functions in the session catalog applyPattern(session.sessionState.catalog.listTemporaryFunctions().map(_.unquotedString)) ++ // List all functions registered in the given namespace of the catalog - applyPattern(catalog.listFunctions(namespace.toArray).map(_.name())).map { funcName => - (catalog.name() +: namespace :+ funcName).quoted - } + applyPattern(catalog.listFunctions(namespace.toArray).map(_.name()).toImmutableArraySeq) + .map { funcName => + (catalog.name() +: namespace :+ funcName).quoted + } } else Seq.empty (userFunctions ++ systemFunctions).distinct.sorted.foreach { fn => rows += toCatalystRow(fn) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowPartitionsExec.scala index f298d042b1c31..65c329e3ad37e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowPartitionsExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsPartitionManagement, Tabl import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ /** * Physical plan node for showing partitions. 
@@ -60,6 +61,6 @@ case class ShowPartitionsExec( } partitions.mkString("/") } - output.sorted.map(p => InternalRow(UTF8String.fromString(p))) + output.sorted.map(p => InternalRow(UTF8String.fromString(p))).toImmutableArraySeq } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala index b7470ab5059cc..cb7c3efdbe482 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning} import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.Utils.sequenceToOption /** @@ -44,7 +45,8 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with SQLConfHelpe val catalystPartitioning = scan.outputPartitioning() match { case kgp: KeyGroupedPartitioning => val partitioning = sequenceToOption( - kgp.keys().map(V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog))) + kgp.keys().map(V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog)) + .toImmutableArraySeq) if (partitioning.isEmpty) { None } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index d4e40530058f1..8c262cf56e8b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType} import org.apache.spark.sql.util.SchemaUtils._ +import org.apache.spark.util.ArrayImplicits._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ @@ -541,7 +542,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample, sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates) - V1ScanWrapper(v1, pushedFilters, pushedDownOperators) + V1ScanWrapper(v1, pushedFilters.toImmutableArraySeq, pushedDownOperators) case _ => scan } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 6dd76973baa5a..933e82a259dbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types.StructType import 
org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ /** * A [[TableCatalog]] that translates calls to the v1 SessionCatalog. @@ -244,7 +245,7 @@ class V2SessionCatalog(catalog: SessionCatalog) case Array(db) => TableIdentifier(ident.name, Some(db)) case other => - throw QueryCompilationErrors.requiresSinglePartNamespaceError(other) + throw QueryCompilationErrors.requiresSinglePartNamespaceError(other.toImmutableArraySeq) } } @@ -253,7 +254,7 @@ class V2SessionCatalog(catalog: SessionCatalog) case Array(db) => FunctionIdentifier(ident.name, Some(db)) case other => - throw QueryCompilationErrors.requiresSinglePartNamespaceError(other) + throw QueryCompilationErrors.requiresSinglePartNamespaceError(other.toImmutableArraySeq) } } } @@ -346,7 +347,7 @@ class V2SessionCatalog(catalog: SessionCatalog) } def isTempView(ident: Identifier): Boolean = { - catalog.isTempView(ident.namespace() :+ ident.name()) + catalog.isTempView((ident.namespace() :+ ident.name()).toImmutableArraySeq) } override def loadFunction(ident: Identifier): UnboundFunction = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index a9e34ff2e1cf0..2527f201f3a81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{LongAccumulator, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * Deprecated logical plan for writing data into data source v2. This is being replaced by more @@ -362,7 +363,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single // partition rdd to make sure we at least set up one write task to write the metadata. if (tempRdd.partitions.length == 0) { - sparkContext.parallelize(Array.empty[InternalRow], 1) + sparkContext.parallelize(Array.empty[InternalRow].toImmutableArraySeq, 1) } else { tempRdd } @@ -440,7 +441,8 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial Utils.tryWithSafeFinallyAndFailureCallbacks(block = { while (iter.hasNext) { if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { - CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics( + dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics) } // Count is here. 
@@ -448,7 +450,8 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial write(dataWriter, iter.next()) } - CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics( + dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics) val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index b193e5199690d..05c6844c7c3b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SerializableConfiguration case class CSVScan( @@ -80,7 +81,8 @@ case class CSVScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, parsedOptions, actualFilters) + dataSchema, readDataSchema, readPartitionSchema, parsedOptions, + actualFilters.toImmutableArraySeq) } override def equals(obj: Any): Boolean = obj match { @@ -92,6 +94,6 @@ case class CSVScan( override def hashCode(): Int = super.hashCode() override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters.toImmutableArraySeq)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index ea642a3a5e510..d9ef94cee0438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ case class JDBCScan( relation: JDBCRelation, @@ -58,13 +59,13 @@ case class JDBCScan( override def description(): String = { val (aggString, groupByString) = if (groupByColumns.nonEmpty) { val groupByColumnsLength = groupByColumns.get.length - (seqToString(pushedAggregateColumn.drop(groupByColumnsLength)), - seqToString(pushedAggregateColumn.take(groupByColumnsLength))) + (seqToString(pushedAggregateColumn.drop(groupByColumnsLength).toImmutableArraySeq), + seqToString(pushedAggregateColumn.take(groupByColumnsLength).toImmutableArraySeq)) } else { ("[]", "[]") } super.description() + ", prunedSchema: " + seqToString(prunedSchema) + - ", PushedPredicates: " + seqToString(pushedPredicates) + + ", PushedPredicates: " + seqToString(pushedPredicates.toImmutableArraySeq) + ", PushedAggregates: " + aggString + ", 
PushedGroupBy: " + groupByString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index ff7273f2870b2..db92f645eedb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SerializableConfiguration case class JsonScan( @@ -80,7 +81,8 @@ case class JsonScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. JsonPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) + dataSchema, readDataSchema, readPartitionSchema, parsedOptions, + pushedFilters.toImmutableArraySeq) } override def equals(obj: Any): Boolean = obj match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index 2b7bdae6b31b4..be57252ecb293 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A factory used to create Orc readers. 
@@ -73,7 +74,7 @@ case class OrcPartitionReaderFactory( private def pushDownPredicates(orcSchema: TypeDescription, conf: Configuration): Unit = { if (orcFilterPushDown && filters.nonEmpty) { val fileSchema = OrcUtils.toCatalystSchema(orcSchema) - OrcFilters.createFilter(fileSchema, filters).foreach { f => + OrcFilters.createFilter(fileSchema, filters.toImmutableArraySeq).foreach { f => OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 4f410f6193275..894f7e765a4f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SerializableConfiguration case class OrcScan( @@ -86,14 +87,14 @@ case class OrcScan( override def hashCode(): Int = getClass.hashCode() lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { - (seqToString(pushedAggregate.get.aggregateExpressions), - seqToString(pushedAggregate.get.groupByExpressions)) + (seqToString(pushedAggregate.get.aggregateExpressions.toImmutableArraySeq), + seqToString(pushedAggregate.get.groupByExpressions.toImmutableArraySeq)) } else { ("[]", "[]") } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters.toImmutableArraySeq)) ++ Map("PushedAggregation" -> pushedAggregationsStr) ++ Map("PushedGroupBy" -> pushedGroupByStr) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 2ab1a2a1e210c..b4a857db4846b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ case class OrcScanBuilder( sparkSession: SparkSession, @@ -67,7 +68,7 @@ case class OrcScanBuilder( if (sparkSession.sessionState.conf.orcFilterPushDown) { val dataTypeMap = OrcFilters.getSearchableTypeMap( readDataSchema(), SQLConf.get.caseSensitiveAnalysis) - OrcFilters.convertibleFilters(dataTypeMap, dataFilters).toArray + OrcFilters.convertibleFilters(dataTypeMap, dataFilters.toImmutableArraySeq).toArray } else { Array.empty[Filter] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index f7fa4b7cb82a6..6b0552551547d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SerializableConfiguration case class ParquetScan( @@ -121,14 +122,14 @@ case class ParquetScan( override def hashCode(): Int = getClass.hashCode() lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { - (seqToString(pushedAggregate.get.aggregateExpressions), - seqToString(pushedAggregate.get.groupByExpressions)) + (seqToString(pushedAggregate.get.aggregateExpressions.toImmutableArraySeq), + seqToString(pushedAggregate.get.groupByExpressions.toImmutableArraySeq)) } else { ("[]", "[]") } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters.toImmutableArraySeq)) ++ Map("PushedAggregation" -> pushedAggregationsStr) ++ Map("PushedGroupBy" -> pushedGroupByStr) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index ae98f3b10301f..01367675e65b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.internal.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ case class ParquetScanBuilder( sparkSession: SparkSession, @@ -73,7 +74,7 @@ case class ParquetScanBuilder( // The rebase mode doesn't matter here because the filters are used to determine // whether they is convertible. 
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)) - parquetFilters.convertibleFilters(dataFilters).toArray + parquetFilters.convertibleFilters(dataFilters.toImmutableArraySeq).toArray } else { Array.empty[Filter] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 60ecc4b635e57..4e52137b74271 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.util.ArrayImplicits._ /** * Dynamic partition pruning optimization is performed based on the type and @@ -79,7 +80,8 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join None } case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) => - val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r) + val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute]( + scan.filterAttributes.toImmutableArraySeq, r) if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { Some(r) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala index 7360349284ec1..b8288c636c386 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.util.ArrayImplicits._ /** * A rule that assigns a subquery to filter groups in row-level operations at runtime. 
@@ -63,7 +64,7 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla val command = replaceData.operation.command val matchingRowsPlan = buildMatchingRowsPlan(relation, cond, tableAttrs, command) - val filterAttrs = scan.filterAttributes + val filterAttrs = scan.filterAttributes.toImmutableArraySeq val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan) val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r) val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 32d23136225d3..6dd41aca3a5e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class BroadcastNestedLoopJoinExec( @@ -220,7 +221,7 @@ case class BroadcastNestedLoopJoinExec( // Only need to know whether streamed side is empty or not. val streamExists = !streamed.executeTake(1).isEmpty if (streamExists == exists) { - sparkContext.makeRDD(relation.value) + sparkContext.makeRDD(relation.value.toImmutableArraySeq) } else { sparkContext.emptyRDD } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 77135d21a26ab..0dc4a69c07588 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.metric.{SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.Utils /** @@ -133,7 +134,7 @@ case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec { // If we use this execution plan separately like `Dataset.limit` without an actual // job launch, we might just have to mimic the implementation of `CollectLimitExec`. 
- sparkContext.parallelize(executeCollect(), numSlices = 1) + sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices = 1) } override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala index 762b64dbe31fe..8f1595cfdd714 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{InputRDDCodegen, LeafExecNode, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.ArrayImplicits._ /** * A physical plan node for scanning data from a list of data source partition values. @@ -50,7 +51,7 @@ case class PythonDataSourcePartitionsExec( if (numPartitions == 0) { sparkContext.emptyRDD } else { - sparkContext.parallelize(unsafeRows, numPartitions) + sparkContext.parallelize(unsafeRows.toImmutableArraySeq, numPartitions) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index 703c1e10ce265..0e7eb056f434c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.ArrayImplicits._ /** * A user-defined Python data source. This is used by the Python API. 
@@ -52,7 +53,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { val pickledDataSourceInstance = result.dataSource val dataSource = SimplePythonFunction( - command = pickledDataSourceInstance, + command = pickledDataSourceInstance.toImmutableArraySeq, envVars = dataSourceCls.envVars, pythonIncludes = dataSourceCls.pythonIncludes, pythonExec = dataSourceCls.pythonExec, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala index d5577c9e22f3b..12d484b12dacf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class WindowInPandasEvaluatorFactory( @@ -140,7 +141,7 @@ class WindowInPandasEvaluatorFactory( private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray private val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) = - computeWindowBoundHelpers(factories) + computeWindowBoundHelpers(factories.toImmutableArraySeq) private val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 } private val numFrames = factories.length @@ -287,7 +288,8 @@ class WindowInPandasEvaluatorFactory( new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) var bufferIterator: Iterator[UnsafeRow] = _ - val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) + val indexRow = + new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType).toImmutableArraySeq) val frames = factories.map(_ (indexRow)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index db26f8c7758e7..6b3b374ae9ad9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ object StatFunctions extends Logging { @@ -104,7 +105,7 @@ object StatFunctions extends Logging { case Some(q) => q case None => Seq() } - } + }.toImmutableArraySeq } /** Calculate the Pearson Correlation Coefficient for the given columns */ @@ -247,7 +248,7 @@ object StatFunctions extends Logging { } if (mapColumns.isEmpty) { - ds.sparkSession.createDataFrame(selectedStatistics.map(Tuple1.apply)) + ds.sparkSession.createDataFrame(selectedStatistics.map(Tuple1.apply).toImmutableArraySeq) .withColumnRenamed("_1", "summary") } else { val valueColumns = columnNames.map { columnName => diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 51f310dcc04e3..4dad4c0adeacb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, LogicalRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.ThreadUtils /** @@ -242,7 +243,7 @@ class FileStreamSource( val newDataSource = DataSource( sparkSession, - paths = files.map(_.sparkPath.toPath.toString), + paths = files.map(_.sparkPath.toPath.toString).toImmutableArraySeq, userSpecifiedSchema = Some(schema), partitionColumns = partitionColumns, className = fileFormatClassName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index 5fe9a39c91e0b..ceadbca2b1226 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -28,6 +28,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.FileStreamSource.FileEntry import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ class FileStreamSourceLog( metadataLogVersion: Int, @@ -115,7 +116,7 @@ class FileStreamSourceLog( val batches = (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1) if (startBatchId <= endBatchId) { - HDFSMetadataLog.verifyBatchIds(batches.map(_._1), startId, endId) + HDFSMetadataLog.verifyBatchIds(batches.map(_._1).toImmutableArraySeq, startId, endId) } batches } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index a5d114b240a9d..a7b0483ea08a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -33,6 +33,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ /** @@ -254,7 +255,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) }.sorted - HDFSMetadataLog.verifyBatchIds(batchIds, startId, endId) + HDFSMetadataLog.verifyBatchIds(batchIds.toImmutableArraySeq, startId, endId) batchIds.map(batchId => (batchId, getExistingBatch(batchId))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 6d27961fa0bcb..a7e16c83ffebe 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.streaming.ContinuousPartitionReaderFactory import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( @@ -96,7 +97,8 @@ class ContinuousDataSourceRDD( override def getNext(): InternalRow = { if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { - CustomMetrics.updateMetrics(partitionReader.currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics( + partitionReader.currentMetricsValues.toImmutableArraySeq, customMetrics) } numRow += 1 readerForPartition.next() match { @@ -112,6 +114,6 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - castPartition(split).inputPartition.preferredLocations() + castPartition(split).inputPartition.preferredLocations().toImmutableArraySeq } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 688b66716ea9b..1d6ba87145d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.DataWriter import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -59,12 +60,14 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDat var count = 0L while (dataIterator.hasNext) { if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { - CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics( + dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics) } count += 1 dataWriter.write(dataIterator.next()) } - CustomMetrics.updateMetrics(dataWriter.currentMetricsValues, customMetrics) + CustomMetrics.updateMetrics( + dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics) logInfo(s"Writer for partition ${context.partitionId()} " + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") val msg = dataWriter.commit() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala index 471fd7feedcf6..f1839ccceee1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.connector.write.{PhysicalWriteInfo, WriterCommitMess import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import 
org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ /** Common methods used to create writes for the console sink */ class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap) @@ -65,7 +66,7 @@ class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap) println(printMessage) println("-------------------------------------------") // scalastyle:off println - Dataset.ofRows(spark, LocalRelation(toAttributes(schema), rows)) + Dataset.ofRows(spark, LocalRelation(toAttributes(schema), rows.toImmutableArraySeq)) .show(numRowsToShow, isTruncated) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index d2cd5d3de36a1..504c56ae70826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SizeEstimator, Utils} +import org.apache.spark.util.ArrayImplicits._ /** @@ -706,7 +707,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with /** Fetch all the files that back the store */ private def fetchFiles(): Seq[StoreFile] = { val files: Seq[FileStatus] = try { - fm.list(baseDir) + fm.list(baseDir).toImmutableArraySeq } catch { case _: java.io.FileNotFoundException => Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala index 046cf69f1fcaa..ae342813338c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala @@ -42,6 +42,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -657,7 +658,8 @@ class RocksDBFileManager( // To ignore .log.crc files .filter(file => isLogFile(file.getName)) val (topLevelSstFiles, topLevelOtherFiles) = topLevelFiles.partition(f => isSstFile(f.getName)) - (topLevelSstFiles ++ archivedLogFiles, topLevelOtherFiles) + ((topLevelSstFiles ++ archivedLogFiles).toImmutableArraySeq, + topLevelOtherFiles.toImmutableArraySeq) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index d58cd001e9416..218774a21df8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types.StructType import 
org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ /** @@ -89,7 +90,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { val databases = qe.toRdd.collect().map { row => getNamespace(catalog, parseIdent(row.getString(0))) } - CatalogImpl.makeDataset(databases, sparkSession) + CatalogImpl.makeDataset(databases.toImmutableArraySeq, sparkSession) } /** @@ -107,7 +108,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { val databases = qe.toRdd.collect().map { row => getNamespace(catalog, parseIdent(row.getString(0))) } - CatalogImpl.makeDataset(databases, sparkSession) + CatalogImpl.makeDataset(databases.toImmutableArraySeq, sparkSession) } /** @@ -165,7 +166,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { makeTable(catalog.name() +: ns :+ tableName) } } - CatalogImpl.makeDataset(tables, sparkSession) + CatalogImpl.makeDataset(tables.toImmutableArraySeq, sparkSession) } private def makeTable(nameParts: Seq[String]): Table = { @@ -281,7 +282,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { functions += makeFunction(parseIdent(row.getString(0))) } - CatalogImpl.makeDataset(functions.result(), sparkSession) + CatalogImpl.makeDataset(functions.result().toImmutableArraySeq, sparkSession) } private def toFunctionIdent(functionName: String): Seq[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala index 8b02ddf3cdff2..8a3d387889369 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.internal.SQLConf.STATE_STORE_PROVIDER_CLASS import org.apache.spark.sql.internal.StaticSQLConf.ENABLED_STREAMING_UI_CUSTOM_METRIC_LIST import org.apache.spark.sql.streaming.ui.UIUtils._ import org.apache.spark.ui.{GraphUIData, JsCollector, UIUtils => SparkUIUtils, WebUIPage} +import org.apache.spark.util.ArrayImplicits._ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) extends WebUIPage("statistics") with Logging { @@ -166,7 +167,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( "watermark-gap-timeline", "watermark-gap-histogram", - watermarkData, + watermarkData.toImmutableArraySeq, minBatchTime, maxBatchTime, 0, @@ -222,7 +223,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( "aggregated-num-total-state-rows-timeline", "aggregated-num-total-state-rows-histogram", - numRowsTotalData, + numRowsTotalData.toImmutableArraySeq, minBatchTime, maxBatchTime, 0, @@ -234,7 +235,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( "aggregated-num-updated-state-rows-timeline", "aggregated-num-updated-state-rows-histogram", - numRowsUpdatedData, + numRowsUpdatedData.toImmutableArraySeq, minBatchTime, maxBatchTime, 0, @@ -246,7 +247,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( "aggregated-state-memory-used-bytes-timeline", "aggregated-state-memory-used-bytes-histogram", - memoryUsedBytesData, + memoryUsedBytesData.toImmutableArraySeq, minBatchTime, maxBatchTime, 0, @@ -258,7 +259,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new 
GraphUIData( "aggregated-num-rows-dropped-by-watermark-timeline", "aggregated-num-rows-dropped-by-watermark-histogram", - numRowsDroppedByWatermarkData, + numRowsDroppedByWatermarkData.toImmutableArraySeq, minBatchTime, maxBatchTime, 0, @@ -335,7 +336,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( s"aggregated-$metricName-timeline", s"aggregated-$metricName-histogram", - data, + data.toImmutableArraySeq, minBatchTime, maxBatchTime, 0, @@ -408,7 +409,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( "input-rate-timeline", "input-rate-histogram", - inputRateData, + inputRateData.toImmutableArraySeq, minBatchTime, maxBatchTime, minRecordRate, @@ -420,7 +421,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( "process-rate-timeline", "process-rate-histogram", - processRateData, + processRateData.toImmutableArraySeq, minBatchTime, maxBatchTime, minProcessRate, @@ -432,7 +433,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( "input-rows-timeline", "input-rows-histogram", - inputRowsData, + inputRowsData.toImmutableArraySeq, minBatchTime, maxBatchTime, minRows, @@ -444,7 +445,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) new GraphUIData( "batch-duration-timeline", "batch-duration-histogram", - batchDurations, + batchDurations.toImmutableArraySeq, minBatchTime, maxBatchTime, minBatchDuration, @@ -531,7 +532,8 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) generateTimeToValues(operationDurationData) ++ generateFormattedTimeTipStrings(batchToTimestamps) ++ - generateTimeMap(batchTimes) ++ generateTimeTipStrings(batchToTimestamps) ++ + generateTimeMap(batchTimes.toImmutableArraySeq) ++ + generateTimeTipStrings(batchToTimestamps) ++ table ++ jsCollector.toHtml } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 140daced32234..686b47741589f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -602,7 +603,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("||") { checkAnswer( booleanData.filter($"a" || true), - booleanData.collect()) + booleanData.collect().toImmutableArraySeq) checkAnswer( booleanData.filter($"a" || false), @@ -2558,7 +2559,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) } - checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + checkAnswer(resultDf, expectedDf.collect().toImmutableArraySeq, expectedDf.schema) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index e7c1d2c772c08..e2dd029d4b103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession, SQLTestData} import org.apache.spark.sql.test.SQLTestData.NullStrings import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper { @@ -321,7 +322,7 @@ class DataFrameSetOperationsSuite extends QueryTest case (data, index) => val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) data.map(_ => rng.nextDouble()).map(i => Row(i)) - } + }.toImmutableArraySeq ) val intersect = df1.intersect(df2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b28a23f13f86d..59759e34cab32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -54,6 +54,7 @@ import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStrin import org.apache.spark.sql.types._ import org.apache.spark.tags.SlowSQLTest import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -780,7 +781,7 @@ class DataFrameSuite extends QueryTest test("SPARK-36642: withMetadata: replace metadata of a column") { val metadata = new MetadataBuilder().putLong("key", 1L).build() - val df1 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df1 = sparkContext.parallelize(Array(1, 2, 3).toImmutableArraySeq).toDF("x") val df2 = df1.withMetadata("x", metadata) assert(df2.schema(0).metadata === metadata) @@ -794,7 +795,7 @@ class DataFrameSuite extends QueryTest } test("replace column using withColumn") { - val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sparkContext.parallelize(Array(1, 2, 3).toImmutableArraySeq).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index a12f9790a7c24..860d6fce604bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter { @@ -717,7 +718,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo .asInstanceOf[InMemoryTable] assert(table.name === "testcat.table_name") - assert(table.partitioning === Seq(IdentityTransform(FieldReference(Array("ts", "timezone"))))) + assert(table.partitioning === + Seq(IdentityTransform(FieldReference(Array("ts", "timezone").toImmutableArraySeq)))) checkAnswer(spark.table(table.name), data) assert(table.dataMap.toArray.length == 2) 
assert(table.dataMap(Seq(UTF8String.fromString("America/Los_Angeles"))).head.rows.size == 2) @@ -745,14 +747,14 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo assert(table.name === "testcat.table_name") assert(table.partitioning === Seq( - YearsTransform(FieldReference(Array("ts", "created"))), - MonthsTransform(FieldReference(Array("ts", "created"))), - DaysTransform(FieldReference(Array("ts", "created"))), - HoursTransform(FieldReference(Array("ts", "created"))), - YearsTransform(FieldReference(Array("ts", "modified"))), - MonthsTransform(FieldReference(Array("ts", "modified"))), - DaysTransform(FieldReference(Array("ts", "modified"))), - HoursTransform(FieldReference(Array("ts", "modified"))))) + YearsTransform(FieldReference(Array("ts", "created").toImmutableArraySeq)), + MonthsTransform(FieldReference(Array("ts", "created").toImmutableArraySeq)), + DaysTransform(FieldReference(Array("ts", "created").toImmutableArraySeq)), + HoursTransform(FieldReference(Array("ts", "created").toImmutableArraySeq)), + YearsTransform(FieldReference(Array("ts", "modified").toImmutableArraySeq)), + MonthsTransform(FieldReference(Array("ts", "modified").toImmutableArraySeq)), + DaysTransform(FieldReference(Array("ts", "modified").toImmutableArraySeq)), + HoursTransform(FieldReference(Array("ts", "modified").toImmutableArraySeq)))) } test("SPARK-30289 Create: partitioned by bucket(4, ts.timezone)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9285c31d70253..fe64e5abc5350 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2) case class TestDataPoint2(x: Int, s: String) @@ -1638,7 +1639,7 @@ class DatasetSuite extends QueryTest Route("b", "a", 1), Route("b", "a", 5), Route("b", "c", 6)) - val ds = sparkContext.parallelize(data).toDF().as[Route] + val ds = sparkContext.parallelize(data.toImmutableArraySeq).toDF().as[Route] val grouped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r))) .groupByKey(r => (r.src, r.dest)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala index 49811d8ac61bc..5e5e4d09c5274 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions.{length, struct, sum} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * Comprehensive tests for Dataset.unpivot. 
@@ -242,7 +243,7 @@ class DatasetUnpivotSuite extends QueryTest Row(row.id, "int1", row.int1.orNull), Row(row.id, "long1", row.long1.orNull) ) - }) + }.toImmutableArraySeq) } test("unpivot with id and value expressions") { @@ -275,7 +276,7 @@ class DatasetUnpivotSuite extends QueryTest // length of str2 if set, or null otherwise Option(row.str2).map(_.length).orNull) ) - }) + }.toImmutableArraySeq) } test("unpivot with variable / value columns") { @@ -458,7 +459,8 @@ class DatasetUnpivotSuite extends QueryTest test("unpivot after pivot") { // see test "pivot courses" in DataFramePivotSuite - val pivoted = courseSales.groupBy("year").pivot("course", Array("dotNET", "Java")) + val pivoted = courseSales.groupBy("year") + .pivot("course", Array("dotNET", "Java").toImmutableArraySeq) .agg(sum($"earnings")) val unpivoted = pivoted.unpivot(Array($"year"), "course", "earnings") val expected = courseSales.groupBy("year", "course").sum("earnings") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala index 5f95ba4f38547..00b757e4f78fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala @@ -28,6 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.functions.{col, rpad} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{CharType, StringType, StructField, StructType, VarcharType} +import org.apache.spark.util.ArrayImplicits._ // The classes in this file are basically moved from https://github.com/databricks/spark-sql-perf @@ -149,7 +150,7 @@ class TPCDSTables(spark: SparkSession, dsdgenDir: String, scaleFactor: Int) v } } - Row.fromSeq(values) + Row.fromSeq(values.toImmutableArraySeq) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index eab3b7d81b804..20158bc5cc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.{UserDefinedPythonDataSource, UserDefinedPythonFunction, UserDefinedPythonTableFunction} import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType} +import org.apache.spark.util.ArrayImplicits._ /** * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF, @@ -393,7 +394,7 @@ object IntegratedUDFTestUtils extends SQLHelper { private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( name = name, func = SimplePythonFunction( - command = pythonFunc, + command = pythonFunc.toImmutableArraySeq, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec, @@ -428,7 +429,7 @@ object IntegratedUDFTestUtils extends SQLHelper { pythonScript: String): UserDefinedPythonDataSource = { UserDefinedPythonDataSource( dataSourceCls = SimplePythonFunction( - command = createPythonDataSource(name, pythonScript), + command = createPythonDataSource(name, pythonScript).toImmutableArraySeq, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = 
pythonExec, @@ -446,7 +447,7 @@ object IntegratedUDFTestUtils extends SQLHelper { UserDefinedPythonTableFunction( name = name, func = SimplePythonFunction( - command = createPythonUDTF(name, pythonScript), + command = createPythonUDTF(name, pythonScript).toImmutableArraySeq, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec, @@ -1165,7 +1166,7 @@ object IntegratedUDFTestUtils extends SQLHelper { private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( name = name, func = SimplePythonFunction( - command = pandasFunc, + command = pandasFunc.toImmutableArraySeq, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec, @@ -1219,7 +1220,7 @@ object IntegratedUDFTestUtils extends SQLHelper { private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( name = name, func = SimplePythonFunction( - command = pandasGroupedAggFunc, + command = pandasGroupedAggFunc.toImmutableArraySeq, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec, @@ -1254,7 +1255,7 @@ object IntegratedUDFTestUtils extends SQLHelper { private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( name = name, func = SimplePythonFunction( - command = createPandasGroupedMapFuncWithState(pythonScript), + command = createPandasGroupedMapFuncWithState(pythonScript).toImmutableArraySeq, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8668d61317409..f5ba655e3e85f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.ArrayImplicits._ abstract class QueryTest extends PlanTest { @@ -157,7 +158,17 @@ abstract class QueryTest extends PlanTest { } protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { - checkAnswer(df, expectedAnswer.collect()) + checkAnswer(df, expectedAnswer.collect().toImmutableArraySeq) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Array]] of [[Row]]s. 
+ */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Array[Row]): Unit = { + checkAnswer(df, expectedAnswer.toImmutableArraySeq) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index eb5dbc6ef54c8..e011c2a24b18b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.TimestampTypes import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.tags.ExtendedSQLTest +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils // scalastyle:off line.size.limit @@ -401,7 +402,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper } queries.toSeq } else { - splitWithSemicolon(allCode).toSeq + splitWithSemicolon(allCode.toImmutableArraySeq).toSeq } // List of SQL queries to run @@ -416,7 +417,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper }) if (regenerateGoldenFiles) { - runQueries(queries, testCase, settings) + runQueries(queries, testCase, settings.toImmutableArraySeq) } else { // A config dimension has multiple config sets, and a config set has multiple configs. // - config dim: Seq[Seq[(String, String)]] @@ -438,7 +439,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper configSets.foreach { configSet => try { - runQueries(queries, testCase, settings ++ configSet) + runQueries(queries, testCase, (settings ++ configSet).toImmutableArraySeq) } catch { case e: Throwable => val configs = configSet.map { @@ -700,7 +701,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper // Filter out test files with invalid extensions such as temp files created // by vi (.swp), Mac (.DS_Store) etc. val filteredFiles = files.filter(_.getName.endsWith(validFileExtensions)) - filteredFiles ++ dirs.flatMap(listFilesRecursively) + (filteredFiles ++ dirs.flatMap(listFilesRecursively)).toImmutableArraySeq } /** Load built-in test tables into the SparkSession. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index e8525eee2960a..e61f8cb0bf069 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.ArrayImplicits._ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper { import testImplicits._ @@ -1064,7 +1065,7 @@ class OrderAndPartitionAwareDataSource extends PartitionAwareDataSource { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder( - Option(options.get("partitionKeys")).map(_.split(",")), + Option(options.get("partitionKeys")).map(_.split(",").toImmutableArraySeq), Option(options.get("orderKeys")).map(_.split(",").toSeq).getOrElse(Seq.empty) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 5df80c111ba99..68f996ba31367 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ abstract class RowLevelOperationSuiteBase extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { @@ -91,7 +92,7 @@ abstract class RowLevelOperationSuiteBase private def toDF(jsonData: String, schemaString: String = null): DataFrame = { val jsonRows = jsonData.split("\\n").filter(str => str.trim.nonEmpty) - val jsonDS = spark.createDataset(jsonRows)(Encoders.STRING) + val jsonDS = spark.createDataset(jsonRows.toImmutableArraySeq)(Encoders.STRING) if (schemaString == null) { spark.read.json(jsonDS) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 377e1e2b084c6..ec62739b9cf2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { @@ -386,7 +387,7 @@ class InMemoryTableWithV1Fallback( } else if (dataMap.contains(partition)) { throw new IllegalStateException("Partition was not removed properly") } else { - dataMap.put(partition, elements) + dataMap.put(partition, elements.toImmutableArraySeq) } } } @@ 
-414,7 +415,7 @@ class InMemoryTableWithV1Fallback( override def schema: StructType = requiredSchema override def buildScan(): RDD[Row] = { val data = InMemoryV1Provider.getTableData(context.sparkSession, name).collect() - context.sparkContext.makeRDD(data) + context.sparkContext.makeRDD(data.toImmutableArraySeq) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index 6a69691bea8c3..ee71bd3af1e02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.PreprocessTableCreation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{LongType, StringType} +import org.apache.spark.util.ArrayImplicits._ class V2CommandsCaseSensitivitySuite extends SharedSparkSession @@ -55,7 +56,7 @@ class V2CommandsCaseSensitivitySuite val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, false) val plan = CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.identity(ref) :: Nil, TestRelation2, tableSpec, @@ -79,7 +80,7 @@ class V2CommandsCaseSensitivitySuite val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, false) val plan = CreateTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, ref) :: Nil, TestRelation2, tableSpec, @@ -104,7 +105,7 @@ class V2CommandsCaseSensitivitySuite val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, false) val plan = ReplaceTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.identity(ref) :: Nil, TestRelation2, tableSpec, @@ -128,7 +129,7 @@ class V2CommandsCaseSensitivitySuite val tableSpec = UnresolvedTableSpec(Map.empty, None, OptionList(Seq.empty), None, None, None, false) val plan = ReplaceTableAsSelect( - UnresolvedIdentifier(Array("table_name")), + UnresolvedIdentifier(Array("table_name").toImmutableArraySeq), Expressions.bucket(4, ref) :: Nil, TestRelation2, tableSpec, @@ -153,7 +154,8 @@ class V2CommandsCaseSensitivitySuite AddColumns( table, Seq(QualifiedColType( - Some(UnresolvedFieldName(field.init)), field.last, LongType, true, None, None, None))), + Some(UnresolvedFieldName(field.init.toImmutableArraySeq)), + field.last, LongType, true, None, None, None))), Seq("Missing field " + field.head) ) } @@ -314,7 +316,7 @@ class V2CommandsCaseSensitivitySuite test("SPARK-36381: Check column name exist case sensitive and insensitive when rename column") { alterTableErrorClass( - RenameColumn(table, UnresolvedFieldName(Array("id")), "DATA"), + RenameColumn(table, UnresolvedFieldName(Array("id").toImmutableArraySeq), "DATA"), "FIELDS_ALREADY_EXISTS", Map( "op" -> "rename", @@ -331,7 +333,7 @@ class V2CommandsCaseSensitivitySuite } else { Seq("Missing field " + ref.quoted) } - val alter = DropColumns(table, Seq(UnresolvedFieldName(ref)), ifExists) + val alter = DropColumns(table, Seq(UnresolvedFieldName(ref.toImmutableArraySeq)), ifExists) if (ifExists) { 
// using IF EXISTS will silence all errors for missing columns assertAnalysisSuccess(alter, caseSensitive = true) @@ -346,7 +348,7 @@ class V2CommandsCaseSensitivitySuite test("AlterTable: rename column resolution") { Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => alterTableTest( - RenameColumn(table, UnresolvedFieldName(ref), "newName"), + RenameColumn(table, UnresolvedFieldName(ref.toImmutableArraySeq), "newName"), Seq("Missing field " + ref.quoted) ) } @@ -355,7 +357,8 @@ class V2CommandsCaseSensitivitySuite test("AlterTable: drop column nullability resolution") { Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => alterTableTest( - AlterColumn(table, UnresolvedFieldName(ref), None, Some(true), None, None, None), + AlterColumn(table, UnresolvedFieldName(ref.toImmutableArraySeq), + None, Some(true), None, None, None), Seq("Missing field " + ref.quoted) ) } @@ -364,7 +367,8 @@ class V2CommandsCaseSensitivitySuite test("AlterTable: change column type resolution") { Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => alterTableTest( - AlterColumn(table, UnresolvedFieldName(ref), Some(StringType), None, None, None, None), + AlterColumn(table, UnresolvedFieldName(ref.toImmutableArraySeq), + Some(StringType), None, None, None, None), Seq("Missing field " + ref.quoted) ) } @@ -373,7 +377,8 @@ class V2CommandsCaseSensitivitySuite test("AlterTable: change column comment resolution") { Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => alterTableTest( - AlterColumn(table, UnresolvedFieldName(ref), None, None, Some("comment"), None, None), + AlterColumn(table, UnresolvedFieldName(ref.toImmutableArraySeq), + None, None, Some("comment"), None, None), Seq("Missing field " + ref.quoted) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index a07d206a340f5..b2a46afb13b9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.ArrayImplicits._ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with BeforeAndAfterEach { @@ -84,7 +85,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = defaultIOSchema ), - rowsDf.collect()) + rowsDf.collect().toImmutableArraySeq) assert(uncaughtExceptionHandler.exception.isEmpty) } @@ -101,7 +102,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = ExceptionInjectingOperator(child), ioschema = defaultIOSchema ), - rowsDf.collect()) + rowsDf.collect().toImmutableArraySeq) } assert(e.getMessage().contains("intentional exception")) // Before SPARK-25158, uncaughtExceptionHandler will catch IllegalArgumentException @@ -137,7 +138,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).collect()) + 
$"e".cast("string")).collect().toImmutableArraySeq) } } @@ -161,7 +162,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ), df.select( $"a".cast("string").as("key"), - $"b".cast("string").as("value")).collect()) + $"b".cast("string").as("value")).collect().toImmutableArraySeq) checkAnswer( df.select($"a", $"b"), @@ -175,7 +176,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ), df.select( $"a".cast("string").as("key"), - $"b".cast("string").as("value")).collect()) + $"b".cast("string").as("value")).collect().toImmutableArraySeq) checkAnswer( df.select($"a"), @@ -189,7 +190,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ), df.select( $"a".cast("string").as("key"), - lit(null)).collect()) + lit(null)).collect().toImmutableArraySeq) } } @@ -243,7 +244,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ioschema = serde ), df.select($"a", $"b", $"c", $"d", $"e", - $"f", $"g", $"h", $"i", $"j").collect()) + $"f", $"g", $"h", $"i", $"j").collect().toImmutableArraySeq) } } } @@ -283,7 +284,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = defaultIOSchema ), - df.select($"a", $"b", $"c", $"d", $"e").collect()) + df.select($"a", $"b", $"c", $"d", $"e").collect().toImmutableArraySeq) } } @@ -305,7 +306,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU |USING 'cat' AS (a timestamp, b date) |FROM v """.stripMargin) - checkAnswer(query, identity, df.select($"a", $"b").collect()) + checkAnswer(query, identity, df.select($"a", $"b").collect().toImmutableArraySeq) } } } @@ -344,7 +345,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).collect()) + $"e".cast("string")).collect().toImmutableArraySeq) // input/output with different delimit and show result checkAnswer( @@ -367,7 +368,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string"))).collect()) + $"e".cast("string"))).collect().toImmutableArraySeq) } } @@ -394,7 +395,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU df.select( $"a".cast("string").as("a"), $"b".cast("string").as("b"), - lit(null), lit(null)).collect()) + lit(null), lit(null)).collect().toImmutableArraySeq) } test("SPARK-32106: TRANSFORM with non-existent command/file") { @@ -485,7 +486,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ioschema = defaultIOSchema ), df.select($"a", $"b", $"c", $"d", $"e", - $"f", $"g").collect()) + $"f", $"g").collect().toImmutableArraySeq) } } @@ -517,7 +518,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).collect()) + $"e".cast("string")).collect().toImmutableArraySeq) // test '/path/to/script.py' with script not executable val e1 = intercept[TestFailedException] { @@ -537,7 +538,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).collect()) + $"e".cast("string")).collect().toImmutableArraySeq) }.getMessage // Check with status exit code since in GA test, it may lose detail 
failed root cause. // Different root cause's exitcode is not same. @@ -562,7 +563,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).collect()) + $"e".cast("string")).collect().toImmutableArraySeq) scriptFilePath.setExecutable(false) sql(s"ADD FILE ${scriptFilePath.getAbsolutePath}") @@ -583,7 +584,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).collect()) + $"e".cast("string")).collect().toImmutableArraySeq) // test `python3 script.py` when file added checkAnswer( @@ -601,7 +602,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).collect()) + $"e".cast("string")).collect().toImmutableArraySeq) } } @@ -653,7 +654,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = defaultIOSchema ), - df.select($"ym", $"dt").collect()) + df.select($"ym", $"dt").collect().toImmutableArraySeq) } } @@ -668,7 +669,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = defaultIOSchema ), - df.select($"col").collect()) + df.select($"col").collect().toImmutableArraySeq) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala index 24a98dd83f33a..96aa9be6c924a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ class CoalesceShufflePartitionsSuite extends SparkFunSuite { @@ -102,7 +103,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite { // Check the answer first. QueryTest.checkAnswer( agg, - spark.range(0, 20).selectExpr("id", "50 as cnt").collect()) + spark.range(0, 20).selectExpr("id", "50 as cnt").collect().toImmutableArraySeq) // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. @@ -148,7 +149,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite { .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value")) QueryTest.checkAnswer( join, - expectedAnswer.collect()) + expectedAnswer.collect().toImmutableArraySeq) // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. @@ -199,7 +200,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite { .selectExpr("id", "2 as cnt") QueryTest.checkAnswer( join, - expectedAnswer.collect()) + expectedAnswer.collect().toImmutableArraySeq) // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
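All of the `.collect().toImmutableArraySeq` conversions in these test hunks come from the `org.apache.spark.util.ArrayImplicits._` import added at the top of each file. As a minimal standalone sketch (the extension method name is taken from the diffs, but the implementation and object names below are assumptions for illustration, not the actual Spark sources), the helper can be modelled as a zero-copy wrapper around `scala.collection.immutable.ArraySeq`:

import scala.collection.immutable

// Illustrative stand-in for the extension these tests import from
// org.apache.spark.util.ArrayImplicits._ ; names here are assumed for the sketch.
object ArrayImplicitsSketch {
  implicit class SparkArrayOps[T](private val xs: Array[T]) extends AnyVal {
    // Wraps the existing array without copying it, so the caller must not
    // mutate `xs` after handing out the returned immutable view.
    def toImmutableArraySeq: immutable.ArraySeq[T] =
      immutable.ArraySeq.unsafeWrapArray(xs)
  }
}

object ToImmutableArraySeqDemo {
  import ArrayImplicitsSketch._

  def main(args: Array[String]): Unit = {
    val collected: Array[Int] = Array(1, 2, 3)       // stand-in for a df.collect() result
    val rows: immutable.Seq[Int] = collected.toImmutableArraySeq
    println(rows)                                    // ArraySeq(1, 2, 3)
  }
}

Because the wrapper shares the backing array, the `checkAnswer`-style call sites above stay cheap even for large collected results.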
@@ -250,7 +251,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite { .selectExpr("id % 500 as key", "2 as cnt", "id as value") QueryTest.checkAnswer( join, - expectedAnswer.collect()) + expectedAnswer.collect().toImmutableArraySeq) // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. @@ -293,7 +294,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite { .union(spark.range(500, 1000).selectExpr("id % 500", "id as value")) QueryTest.checkAnswer( join, - expectedAnswer.collect()) + expectedAnswer.collect().toImmutableArraySeq) // Then, let's make sure we do not reduce number of post shuffle partitions. val finalPlan = join.queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala index 1afe742b988ee..621228fabf875 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ class RowToColumnConverterSuite extends SparkFunSuite { def convertRows(rows: Seq[InternalRow], schema: StructType): Seq[WritableColumnVector] = { @@ -32,7 +33,7 @@ class RowToColumnConverterSuite extends SparkFunSuite { for (row <- rows) { converter.convert(row, vectors) } - vectors + vectors.toImmutableArraySeq } test("integer column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala index 057cb527cf03b..b7907b99495a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ /** * A base suite contains a set of view related test cases for different kind of views @@ -475,7 +476,7 @@ abstract class TempViewTestSuite extends SQLViewTestSuite { val query = s"SELECT $funcName(max(a), min(a)) FROM VALUES (1), (2), (3) t(a)" val viewName = createView("tempView", query) withView(viewName) { - checkViewOutput(viewName, sql(query).collect()) + checkViewOutput(viewName, sql(query).collect().toImmutableArraySeq) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala index da05373125d31..69bcaebdb7823 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.{LocalSparkContext, MapOutputStatistics, MapOutputTracke import org.apache.spark.scheduler.MapStatus import 
org.apache.spark.sql.execution.adaptive.ShufflePartitionsUtil import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ArrayImplicits._ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { @@ -35,7 +36,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { Some(new MapOutputStatistics(index, bytesByPartitionId)) } val estimatedPartitionStartIndices = ShufflePartitionsUtil.coalescePartitions( - mapOutputStatistics, + mapOutputStatistics.toImmutableArraySeq, Seq.fill(mapOutputStatistics.length)(None), targetSize, minNumPartitions, @@ -103,7 +104,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq.fill(2)(None), targetSize, 1, 0) assert(coalesced.isEmpty) } @@ -340,7 +341,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -373,7 +374,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -412,7 +413,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -455,7 +456,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -495,7 +496,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -522,7 +523,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -542,7 +543,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new 
MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), None), @@ -558,7 +559,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - None), + None).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -577,7 +578,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -595,7 +596,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -614,7 +615,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -635,7 +636,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -653,7 +654,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -672,7 +673,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -691,7 +692,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), @@ -788,7 +789,7 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { val coalesced = ShufflePartitionsUtil.coalescePartitions( Array( Some(new MapOutputStatistics(0, bytesByPartitionId1)), - Some(new MapOutputStatistics(1, bytesByPartitionId2))), + Some(new MapOutputStatistics(1, bytesByPartitionId2))).toImmutableArraySeq, Seq( Some(specs1), Some(specs2)), diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index d3f3118664daa..c3768afa90f18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshR import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StringType +import org.apache.spark.util.ArrayImplicits._ /** * Parser test cases for rules defined in [[SparkSqlParser]]. @@ -548,17 +549,19 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { assertEqual("ADD FILE abc.txt", AddFilesCommand(Seq("abc.txt"))) assertEqual("ADD FILE 'abc.txt'", AddFilesCommand(Seq("abc.txt"))) assertEqual("ADD FILE \"/path/to/abc.txt\"", AddFilesCommand("/path/to/abc.txt"::Nil)) - assertEqual("LIST FILE abc.txt", ListFilesCommand(Array("abc.txt"))) - assertEqual("LIST FILE '/path//abc.txt'", ListFilesCommand(Array("/path//abc.txt"))) - assertEqual("LIST FILE \"/path2/abc.txt\"", ListFilesCommand(Array("/path2/abc.txt"))) + assertEqual("LIST FILE abc.txt", ListFilesCommand(Array("abc.txt").toImmutableArraySeq)) + assertEqual("LIST FILE '/path//abc.txt'", + ListFilesCommand(Array("/path//abc.txt").toImmutableArraySeq)) + assertEqual("LIST FILE \"/path2/abc.txt\"", + ListFilesCommand(Array("/path2/abc.txt").toImmutableArraySeq)) assertEqual("ADD JAR /path2/_2/abc.jar", AddJarsCommand(Seq("/path2/_2/abc.jar"))) assertEqual("ADD JAR '/test/path_2/jar/abc.jar'", AddJarsCommand(Seq("/test/path_2/jar/abc.jar"))) assertEqual("ADD JAR \"abc.jar\"", AddJarsCommand(Seq("abc.jar"))) assertEqual("LIST JAR /path-with-dash/abc.jar", - ListJarsCommand(Array("/path-with-dash/abc.jar"))) - assertEqual("LIST JAR 'abc.jar'", ListJarsCommand(Array("abc.jar"))) - assertEqual("LIST JAR \"abc.jar\"", ListJarsCommand(Array("abc.jar"))) + ListJarsCommand(Array("/path-with-dash/abc.jar").toImmutableArraySeq)) + assertEqual("LIST JAR 'abc.jar'", ListJarsCommand(Array("abc.jar").toImmutableArraySeq)) + assertEqual("LIST JAR \"abc.jar\"", ListJarsCommand(Array("abc.jar").toImmutableArraySeq)) assertEqual("ADD FILE '/path with space/abc.txt'", AddFilesCommand(Seq("/path with space/abc.txt"))) assertEqual("ADD JAR '/path with space/abc.jar'", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 0ced0f85fa621..a76360439e601 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.test.SQLTestData.TestData import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.tags.SlowSQLTest +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @SlowSQLTest @@ -85,7 +86,7 @@ class AdaptiveQueryExecSuite val result = dfAdaptive.collect() withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { val df = sql(query) - checkAnswer(df, result) + checkAnswer(df, result.toImmutableArraySeq) } val planAfter = 
dfAdaptive.queryExecution.executedPlan assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala index e3b1f467bafad..01ec025b4bf8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BuiltInDataSourceWriteBenchmark.scala @@ -22,6 +22,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql.execution.datasources.orc.OrcCompressionCodec import org.apache.spark.sql.execution.datasources.parquet.ParquetCompressionCodec import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ /** * Benchmark to measure built-in data sources write performance. @@ -50,7 +51,7 @@ object BuiltInDataSourceWriteBenchmark extends DataSourceWriteBenchmark { val formats: Seq[String] = if (mainArgs.isEmpty) { Seq("Parquet", "ORC", "JSON", "CSV") } else { - mainArgs + mainArgs.toImmutableArraySeq } spark.conf.set(SQLConf.PARQUET_COMPRESSION.key, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowFunctionsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowFunctionsSuiteBase.scala index 0f23cc699beba..19bd830500834 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowFunctionsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowFunctionsSuiteBase.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** @@ -127,7 +128,7 @@ trait ShowFunctionsSuiteBase extends QueryTest with DDLCommandTestUtils { createFunction(f) QueryTest.checkAnswer( sql(s"SHOW ALL FUNCTIONS IN $ns"), - allFuns :+ Row(qualifiedFunName("ns", "current_datei"))) + (allFuns :+ Row(qualifiedFunName("ns", "current_datei"))).toImmutableArraySeq) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala index 15d56050c2347..6ba60e245f9b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryCatalog, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.ArrayImplicits._ /** * The trait contains settings and utility functions. 
It can be mixed to the test suites for @@ -56,7 +57,8 @@ trait CommandSuiteBase extends SharedSparkSession { val partTable = catalogPlugin.asTableCatalog .loadTable(Identifier.of(namespaces, tableName)) .asInstanceOf[InMemoryPartitionTable] - val ident = ResolvePartitionSpec.convertToPartIdent(spec, partTable.partitionSchema.fields) + val ident = ResolvePartitionSpec.convertToPartIdent(spec, + partTable.partitionSchema.fields.toImmutableArraySeq) val partMetadata = partTable.loadPartitionMetadata(ident) assert(partMetadata.containsKey("location")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowFunctionsSuite.scala index f5630e5255972..d0685e8815fc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowFunctionsSuite.scala @@ -23,6 +23,7 @@ import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen.JavaStrL import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper import org.apache.spark.sql.execution.command +import org.apache.spark.util.ArrayImplicits._ /** * The class contains tests for the `SHOW FUNCTIONS` command to check V2 table catalogs. @@ -37,7 +38,7 @@ class ShowFunctionsSuite extends command.ShowFunctionsSuiteBase with CommandSuit private def funNameToId(name: String): Identifier = { val parts = name.split('.') assert(parts.head == funCatalog, s"${parts.head} is wrong catalog. Expected: $funCatalog.") - new MultipartIdentifierHelper(parts.tail).asIdentifier + new MultipartIdentifierHelper(parts.tail.toImmutableArraySeq).asIdentifier } override protected def createFunction(name: String): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index d906ae80a80ff..c7e4db2aa33ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.types.StructType.fromDDL import org.apache.spark.sql.types.TestUDT.{MyDenseVector, MyDenseVectorUDT} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class TestFileFilter extends PathFilter { @@ -1377,7 +1378,7 @@ abstract class JsonSuite val d1 = DataSource( spark, userSpecifiedSchema = None, - partitionColumns = Array.empty[String], + partitionColumns = Array.empty[String].toImmutableArraySeq, bucketSpec = None, className = classOf[JsonFileFormat].getCanonicalName, options = Map("path" -> path)).resolveRelation() @@ -1385,7 +1386,7 @@ abstract class JsonSuite val d2 = DataSource( spark, userSpecifiedSchema = None, - partitionColumns = Array.empty[String], + partitionColumns = Array.empty[String].toImmutableArraySeq, bucketSpec = None, className = classOf[JsonFileFormat].getCanonicalName, options = Map("path" -> path)).resolveRelation() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala 
index 4090898cec6dd..ab4389eceec62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.tags.ExtendedSQLTest +import org.apache.spark.util.ArrayImplicits._ /** * A test suite that tests Apache ORC filter API based filter pushdown optimization. @@ -64,7 +65,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _, _, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") assert(o.pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters) + val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters.toImmutableArraySeq) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for " + s"${o.pushedFilters.mkString("pushedFilters(", ", ", ")")}") checker(maybeFilter.get) @@ -546,7 +547,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { OrcFilters.createFilter(schema, Array( LessThan("a", 10), StringContains("b", "prefix") - )).get.asInstanceOf[SearchArgumentImpl].toOldString + ).toImmutableArraySeq).get.asInstanceOf[SearchArgumentImpl].toOldString } // The `LessThan` should be converted while the whole inner `And` shouldn't @@ -557,7 +558,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { GreaterThan("a", 1), StringContains("b", "prefix") )) - )).get.asInstanceOf[SearchArgumentImpl].toOldString + ).toImmutableArraySeq).get.asInstanceOf[SearchArgumentImpl].toOldString } // Safely remove unsupported `StringContains` predicate and push down `LessThan` @@ -567,7 +568,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { LessThan("a", 10), StringContains("b", "prefix") ) - )).get.asInstanceOf[SearchArgumentImpl].toOldString + ).toImmutableArraySeq).get.asInstanceOf[SearchArgumentImpl].toOldString } // Safely remove unsupported `StringContains` predicate, push down `LessThan` and `GreaterThan`. 
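The `OrcFilters.createFilter(schema, Array(...).toImmutableArraySeq)` changes above target parameters typed as `Seq`. Under Scala 2.13 the default `Seq` is `scala.collection.immutable.Seq`, and handing an `Array` to such a parameter goes through an implicit conversion that copies the array and is deprecated in favour of an explicit choice, so these tests opt into the explicit zero-copy wrap instead. A hedged sketch of the call-site shape, using an invented stand-in rather than the real Spark signature:

import scala.collection.immutable.ArraySeq

// Stand-in for an API such as OrcFilters.createFilter that takes a Seq of
// predicates; the name and behaviour below are assumptions for illustration.
object FilterCallSiteSketch {
  def createFilterSketch(predicates: Seq[String]): Option[String] =
    if (predicates.isEmpty) None
    else Some(predicates.mkString("(", ") and (", ")"))

  def main(args: Array[String]): Unit = {
    val pushed: Array[String] = Array("a < 10", "b contains 'prefix'")

    // Passing `pushed` directly would rely on the deprecated copying
    // conversion from Array to immutable.IndexedSeq; wrapping it explicitly
    // avoids both the deprecation warning and the copy.
    val sarg = createFilterSketch(ArraySeq.unsafeWrapArray(pushed))
    println(sarg)   // Some((a < 10) and (b contains 'prefix'))
  }
}

The same reasoning applies to the `In(attr, Array(...).map(Literal.apply).toImmutableArraySeq)` rewrites in the Parquet filter suite further down.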
@@ -581,7 +582,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { ), GreaterThan("a", 1) ) - )).get.asInstanceOf[SearchArgumentImpl].toOldString + ).toImmutableArraySeq).get.asInstanceOf[SearchArgumentImpl].toOldString } } @@ -604,7 +605,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { LessThan("a", 1) ) ) - )).get.asInstanceOf[SearchArgumentImpl].toOldString + ).toImmutableArraySeq).get.asInstanceOf[SearchArgumentImpl].toOldString } assertResult("leaf-0 = (LESS_THAN_EQUALS a 10), leaf-1 = (LESS_THAN a 1)," + @@ -620,7 +621,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { LessThan("a", 1) ) ) - )).get.asInstanceOf[SearchArgumentImpl].toOldString + ).toImmutableArraySeq).get.asInstanceOf[SearchArgumentImpl].toOldString } assert(OrcFilters.createFilter(schema, Array( @@ -631,7 +632,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { LessThan("a", 1) ) ) - )).isEmpty) + ).toImmutableArraySeq).isEmpty) } test("SPARK-27160: Fix casting of the DecimalType literal") { @@ -641,7 +642,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { OrcFilters.createFilter(schema, Array( LessThan( "a", - new java.math.BigDecimal(3.14, MathContext.DECIMAL64).setScale(2))) + new java.math.BigDecimal(3.14, MathContext.DECIMAL64).setScale(2))).toImmutableArraySeq ).get.asInstanceOf[SearchArgumentImpl].toOldString } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index c8c823b2018a3..8a38075932b3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION +import org.apache.spark.util.ArrayImplicits._ /** * OrcTest @@ -126,7 +127,8 @@ trait OrcTest extends QueryTest with FileBasedDataSourceTest with BeforeAndAfter assert(o.pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") } else { assert(o.pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters) + val maybeFilter = OrcFilters + .createFilter(query.schema, o.pushedFilters.toImmutableArraySeq) assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for " + s"${o.pushedFilters.mkString("pushedFilters(", ", ", ")")}") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala index d3e9819b9a054..5e59418f8f928 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.util.ArrayImplicits._ class 
ParquetFileMetadataStructRowIndexSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -75,7 +76,7 @@ class ParquetFileMetadataStructRowIndexSuite extends QueryTest with SharedSparkS case s: StructType => collectMetadataCols(s) case _ if allMetadataCols.contains(field.name) => Some(field.name) case _ => None - }} + }}.toImmutableArraySeq } for (useVectorizedReader <- Seq(false, true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4405fcd2f34e8..4ed5297ff4ead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -55,6 +55,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.tags.ExtendedSQLTest import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -200,7 +201,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(tsAttr, Array(ts2.ts, ts3.ts, ts4.ts, "2021-05-01 00:01:02".ts).map(Literal.apply)), + In(tsAttr, Array(ts2.ts, ts3.ts, ts4.ts, "2021-05-01 00:01:02".ts).map(Literal.apply) + .toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(ts2)), Row(resultFun(ts3)), Row(resultFun(ts4)))) } @@ -361,7 +363,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(intAttr, Array(2, 3, 4, 5, 6, 7).map(Literal.apply)), + In(intAttr, Array(2, 3, 4, 5, 6, 7).map(Literal.apply).toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(2)), Row(resultFun(3)), Row(resultFun(4)))) } @@ -403,7 +405,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(intAttr, Array(2, 3, 4, 5, 6, 7).map(Literal.apply)), + In(intAttr, Array(2, 3, 4, 5, 6, 7).map(Literal.apply).toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(2), Row(3), Row(4))) } @@ -446,7 +448,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(longAttr, Array(2L, 3L, 4L, 5L, 6L, 7L).map(Literal.apply)), + In(longAttr, Array(2L, 3L, 4L, 5L, 6L, 7L).map(Literal.apply).toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(2L)), Row(resultFun(3L)), Row(resultFun(4L)))) } @@ -488,7 +490,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> 
s"$threshold") { checkFilterPredicate( - In(longAttr, Array(2L, 3L, 4L, 5L, 6L, 7L).map(Literal.apply)), + In(longAttr, Array(2L, 3L, 4L, 5L, 6L, 7L).map(Literal.apply).toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(2L), Row(3L), Row(4L))) } @@ -531,7 +533,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(floatAttr, Array(2F, 3F, 4F, 5F, 6F, 7F).map(Literal.apply)), + In(floatAttr, Array(2F, 3F, 4F, 5F, 6F, 7F).map(Literal.apply).toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(2F)), Row(resultFun(3F)), Row(resultFun(4F)))) } @@ -575,7 +577,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(doubleAttr, Array(2.0D, 3.0D, 4.0D, 5.0D, 6.0D, 7.0D).map(Literal.apply)), + In(doubleAttr, Array(2.0D, 3.0D, 4.0D, 5.0D, 6.0D, 7.0D).map(Literal.apply) + .toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(2D)), Row(resultFun(3D)), Row(resultFun(4F)))) } @@ -619,7 +622,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(stringAttr, Array("2", "3", "4", "5", "6", "7").map(Literal.apply)), + In(stringAttr, Array("2", "3", "4", "5", "6", "7").map(Literal.apply) + .toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun("2")), Row(resultFun("3")), Row(resultFun("4")))) } @@ -668,7 +672,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(binaryAttr, Array(2.b, 3.b, 4.b, 5.b, 6.b, 7.b).map(Literal.apply)), + In(binaryAttr, Array(2.b, 3.b, 4.b, 5.b, 6.b, 7.b).map(Literal.apply) + .toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(2.b)), Row(resultFun(3.b)), Row(resultFun(4.b)))) } @@ -745,7 +750,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( In(dateAttr, Array("2018-03-19".date, "2018-03-20".date, "2018-03-21".date, - "2018-03-22".date).map(Literal.apply)), + "2018-03-22".date).map(Literal.apply).toList), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun("2018-03-19")), Row(resultFun("2018-03-20")), Row(resultFun("2018-03-21")))) @@ -854,7 +859,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( In(decimalAttr, Array(2, 3, 4, 5).map(Literal.apply) - .map(_.cast(DecimalType(precision, 2)))), + .map(_.cast(DecimalType(precision, 2))).toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(2)), Row(resultFun(3)), Row(resultFun(4)))) } @@ -2040,7 
+2045,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(iAttr, Array(2, 3, 4, 5, 6, 7).map(monthsLit)), + In(iAttr, Array(2, 3, 4, 5, 6, 7).map(monthsLit).toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(months(2))), Row(resultFun(months(3))), Row(resultFun(months(4))))) } @@ -2086,7 +2091,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared Seq(3, 20).foreach { threshold => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD.key -> s"$threshold") { checkFilterPredicate( - In(iAttr, Array(2, 3, 4, 5, 6, 7).map(secsLit)), + In(iAttr, Array(2, 3, 4, 5, 6, 7).map(secsLit).toImmutableArraySeq), if (threshold == 3) classOf[FilterIn[_]] else classOf[Operators.Or], Seq(Row(resultFun(secs(2))), Row(resultFun(secs(3))), Row(resultFun(secs(4))))) } @@ -2255,7 +2260,8 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) val parquetFilters = createParquetFilters(schema) // In this test suite, all the simple predicates are convertible here. - assert(parquetFilters.convertibleFilters(sourceFilters) === pushedFilters) + assert( + parquetFilters.convertibleFilters(sourceFilters.toImmutableArraySeq) === pushedFilters) val pushedParquetFilters = pushedFilters.map { pred => val maybeFilter = parquetFilters.createFilter(pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala index 24dd98da82580..aaaf0399d3158 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{LongType, StringType} import org.apache.spark.tags.SlowSQLTest +import org.apache.spark.util.ArrayImplicits._ @SlowSQLTest class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { @@ -49,7 +50,7 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { assert(dir.isDirectory) dir.listFiles() .filter { f => f.isFile && f.getName.endsWith("parquet") } - .map { f => readRowGroupRowCounts(f.getAbsolutePath) } + .map { f => readRowGroupRowCounts(f.getAbsolutePath) }.toImmutableArraySeq } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala index 91a8a38928224..9dbcae84aa75e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordR import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils import org.apache.spark.sql.test.SharedSparkSession import 
org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ /** * A test suite on the vectorized Parquet reader. Unlike `ParquetIOSuite`, this focuses on @@ -509,7 +510,7 @@ class ParquetVectorizedSuite extends QueryTest with ParquetTest with SharedSpark val pageFirstRowIndexes = ArrayBuffer.empty[Long] pageSizes.foreach { size => pageFirstRowIndexes += i - writeDataPage(cd, memPageStore, repetitionLevels.slice(i, i + size), + writeDataPage(cd, memPageStore, repetitionLevels.slice(i, i + size).toImmutableArraySeq, definitionLevels.slice(i, i + size), inputValues.slice(i, i + size), maxDef, dictionaryEnabled) i += size diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 4bfc34a80aa65..079ab994736b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.CompactBuffer class HashedRelationSuite extends SharedSparkSession { @@ -690,7 +691,7 @@ class HashedRelationSuite extends SharedSparkSession { } else { keyIndexToKeyMap(keyIndex) = i.toString } - keyIndexToValueMap(keyIndex) = actualValues + keyIndexToValueMap(keyIndex) = actualValues.toImmutableArraySeq // key index is non-negative assert(keyIndex >= 0) // values are expected diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index d2b751de71cc1..27cdeaeb46238 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path, RawLocalFileSy import org.apache.spark.SparkFunSuite import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.ArrayImplicits._ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSparkSession { @@ -296,7 +297,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSparkSession { private def readFromResource(dir: String): Seq[SinkFileStatus] = { val input = getClass.getResource(s"/structured-streaming/$dir") val log = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark, input.toString) - log.allFiles() + log.allFiles().toImmutableArraySeq } private def withCountOpenLocalFileSystemAsLocalFileSystem(body: => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala index b27aa37d531f4..7c967c7de152f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.SerializeFromObjectExec import 
org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming._ +import org.apache.spark.util.ArrayImplicits._ case class KV(key: Int, value: Long) @@ -220,7 +221,8 @@ class ForeachBatchSinkSuite extends StreamTest { def check(in: Int*)(out: T*): Test = Check(in, out) def checkMetrics: Test = CheckMetrics - def record(batchId: Long, ds: Dataset[T]): Unit = recordedOutput.put(batchId, ds.collect()) + def record(batchId: Long, ds: Dataset[T]): Unit = + recordedOutput.put(batchId, ds.collect().toImmutableArraySeq) implicit def conv(x: (Int, Long)): KV = KV(x._1, x._2) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index b12450167d7a5..324717d92c972 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, timestamp_seconds, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.ArrayImplicits._ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeAndAfter { @@ -310,7 +311,7 @@ object ForeachWriterSuite { } def allEvents(): Seq[Seq[Event]] = { - _allEvents.toArray(new Array[Seq[Event]](_allEvents.size())) + _allEvents.toArray(new Array[Seq[Event]](_allEvents.size())).toImmutableArraySeq } def clear(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index ac50e55202919..ddef26224f240 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.tags.SlowSQLTest import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ trait RocksDBStateStoreChangelogCheckpointingTestUtil { val rocksdbChangelogCheckpointingConfKey: String = RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + @@ -48,6 +49,7 @@ trait RocksDBStateStoreChangelogCheckpointingTestUtil { .map(_.getName.stripSuffix(".zip")) .map(_.toLong) .sorted + .toImmutableArraySeq } def changelogVersionsPresent(dir: File): Seq[Long] = { @@ -55,6 +57,7 @@ trait RocksDBStateStoreChangelogCheckpointingTestUtil { .map(_.getName.stripSuffix(".changelog")) .map(_.toLong) .sorted + .toImmutableArraySeq } } @@ -1307,6 +1310,7 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared def listFiles(file: File): Seq[File] = { if (!file.exists()) return Seq.empty file.listFiles.filter(file => !file.getName.endsWith("crc") && !file.isDirectory) + .toImmutableArraySeq } def listFiles(file: String): Seq[File] = listFiles(new File(file)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 3cbf0bb899930..7cce6086c6fd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { private def withVector( @@ -502,7 +503,7 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { test("CachedBatch boolean Apis") { val dataType = BooleanType val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) - val row = new SpecificInternalRow(Array(dataType)) + val row = new SpecificInternalRow(Array(dataType).toImmutableArraySeq) row.setNullAt(0) columnBuilder.appendFrom(row, 0) @@ -526,7 +527,7 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { test("CachedBatch byte Apis") { val dataType = ByteType val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) - val row = new SpecificInternalRow(Array(dataType)) + val row = new SpecificInternalRow(Array(dataType).toImmutableArraySeq) row.setNullAt(0) columnBuilder.appendFrom(row, 0) @@ -550,7 +551,7 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { test("CachedBatch short Apis") { val dataType = ShortType val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) - val row = new SpecificInternalRow(Array(dataType)) + val row = new SpecificInternalRow(Array(dataType).toImmutableArraySeq) row.setNullAt(0) columnBuilder.appendFrom(row, 0) @@ -574,7 +575,7 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { test("CachedBatch int Apis") { val dataType = IntegerType val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) - val row = new SpecificInternalRow(Array(dataType)) + val row = new SpecificInternalRow(Array(dataType).toImmutableArraySeq) row.setNullAt(0) columnBuilder.appendFrom(row, 0) @@ -598,7 +599,7 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { test("CachedBatch long Apis") { Seq(LongType, TimestampType, TimestampNTZType).foreach { dataType => val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) - val row = new SpecificInternalRow(Array(dataType)) + val row = new SpecificInternalRow(Array(dataType).toImmutableArraySeq) row.setNullAt(0) columnBuilder.appendFrom(row, 0) @@ -623,7 +624,7 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { test("CachedBatch float Apis") { val dataType = FloatType val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) - val row = new SpecificInternalRow(Array(dataType)) + val row = new SpecificInternalRow(Array(dataType).toImmutableArraySeq) row.setNullAt(0) columnBuilder.appendFrom(row, 0) @@ -647,7 +648,7 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { test("CachedBatch double Apis") { val dataType = DoubleType val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) - val row = new SpecificInternalRow(Array(dataType)) + val row = new SpecificInternalRow(Array(dataType).toImmutableArraySeq) row.setNullAt(0) columnBuilder.appendFrom(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index ec95257b0ee4c..933447354fd95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -46,6 +46,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column import org.apache.spark.tags.ExtendedSQLTest import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.ArrayImplicits._ @ExtendedSQLTest class ColumnarBatchSuite extends SparkFunSuite { @@ -1476,7 +1477,7 @@ class ColumnarBatchSuite extends SparkFunSuite { case _ => assert(a1 === a2, "Seed = " + seed) } case StructType(childFields) => - compareStruct(childFields, r1.getStruct(ordinal, fields.length), + compareStruct(childFields.toImmutableArraySeq, r1.getStruct(ordinal, fields.length), r2.getStruct(ordinal), seed) case _ => throw new UnsupportedOperationException("Not implemented " + field.dataType) @@ -1526,9 +1527,9 @@ class ColumnarBatchSuite extends SparkFunSuite { var i = 0 while (i < NUM_ITERS) { val schema = if (flatSchema) { - RandomDataGenerator.randomSchema(random, numFields, types) + RandomDataGenerator.randomSchema(random, numFields, types.toImmutableArraySeq) } else { - RandomDataGenerator.randomNestedSchema(random, numFields, types) + RandomDataGenerator.randomNestedSchema(random, numFields, types.toImmutableArraySeq) } val rows = mutable.ArrayBuffer.empty[Row] var j = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 42cfe7c81a8da..ccb202085910a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { @@ -92,13 +93,15 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { conn1.close() } - private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) - private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) + private lazy val arr2x2 = + Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)).toImmutableArraySeq + private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)).toImmutableArraySeq private lazy val schema2 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: Nil) - private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) + private lazy val arr2x3 = + Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)).toImmutableArraySeq private lazy val schema3 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 657ef5ca13bd9..780c29d693101 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ArrayImplicits._ class FilteredScanSource extends RelationProvider { override def createRelation( @@ -78,7 +79,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S Seq(c * 5 + c.toUpperCase(Locale.ROOT) * 5) } - FiltersPushed.list = filters + FiltersPushed.list = filters.toImmutableArraySeq ColumnsRequired.set = requiredColumns.toSet // Predicate test on integer column diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 5c8a55b3f63bb..04193d5189ae9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.tags.SlowSQLTest +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils abstract class FileStreamSinkSuite extends StreamTest { @@ -751,7 +752,7 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite { }.headOption.getOrElse { fail(s"No FileScan in query\n${df.queryExecution}") } - func(fileScan.planInputPartitions().map(_.asInstanceOf[FilePartition])) + func(fileScan.planInputPartitions().map(_.asInstanceOf[FilePartition]).toImmutableArraySeq) } // Read without pruning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 0d7d85acaebd1..bd0285db46d96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -46,6 +46,7 @@ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.tags.SlowSQLTest +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils abstract class FileStreamSourceTest @@ -1699,7 +1700,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { private def readLogFromResource(dir: String): Seq[FileEntry] = { val input = getClass.getResource(s"/structured-streaming/$dir") val log = new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, input.toString) - log.allFiles() + log.allFiles().toImmutableArraySeq } private def readOffsetFromResource(file: String): SerializedOffset = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 07837f5c06473..5f9ff25e16e73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.execution.streaming.StreamExecution +import 
org.apache.spark.util.ArrayImplicits._ trait StateStoreMetricsTest extends StreamTest { @@ -58,18 +59,20 @@ trait StateStoreMetricsTest extends StreamTest { val numTotalRows = progressesSinceLastCheck.last.stateOperators.map(_.numRowsTotal) assert(numTotalRows === total, s"incorrect total rows, $debugString") - val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + val numUpdatedRows = arraySum( + allNumUpdatedRowsSinceLastCheck.toImmutableArraySeq, numStateOperators) assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") - val numRowsDroppedByWatermark = arraySum(allNumRowsDroppedByWatermarkSinceLastCheck, - numStateOperators) + val numRowsDroppedByWatermark = arraySum( + allNumRowsDroppedByWatermarkSinceLastCheck.toImmutableArraySeq, numStateOperators) assert(numRowsDroppedByWatermark === droppedByWatermark, s"incorrect dropped rows by watermark, $debugString") if (removed.isDefined) { val allNumRowsRemovedSinceLastCheck = progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsRemoved)) - val numRemovedRows = arraySum(allNumRowsRemovedSinceLastCheck, numStateOperators) + val numRemovedRows = arraySum( + allNumRowsRemovedSinceLastCheck.toImmutableArraySeq, numStateOperators) assert(numRemovedRows === removed.get, s"incorrect removed rows, $debugString") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 03780478b331b..be84640f4bf36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualCloc import org.apache.spark.sql.types.{StructType, TimestampType} import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.tags.SlowSQLTest +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils object FailureSingleton { @@ -211,7 +212,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { } def stateOperatorProgresses: Seq[StateOperatorProgress] = { - lastExecutedBatch.stateOperators + lastExecutedBatch.stateOperators.toImmutableArraySeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 4c478486c6b39..8fe4ef39b2552 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._ import org.apache.spark.sql.streaming.StreamingQuerySuite.clock import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { test("StreamingQueryProgress - prettyJson") { @@ -278,7 +279,7 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { p.sources.size >= 1 && p.stateOperators.size >= 1 && p.sink != null }) - val array = spark.sparkContext.parallelize(progress).collect() + val array = 
spark.sparkContext.parallelize(progress.toImmutableArraySeq).collect() assert(array.length === progress.length) array.zip(progress).foreach { case (p1, p2) => // Make sure we did serialize and deserialize the object diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index fab5c49da8e41..3264141aa7abe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.UninterruptibleThread import org.apache.spark.util.Utils @@ -203,7 +204,7 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with */ protected def withTempPaths(numPaths: Int)(f: Seq[File] => Unit): Unit = { val files = Array.fill[File](numPaths)(Utils.createTempDir().getCanonicalFile) - try f(files) finally { + try f(files.toImmutableArraySeq) finally { // wait for all tasks to finish before deleting files waitForTasksToFinish() files.foreach(Utils.deleteRecursively) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bd323dc4b24e1..a65f405d6aefc 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.tags.SlowHiveTest +import org.apache.spark.util.ArrayImplicits._ /** * Runs the test cases that are included in the hive distribution. @@ -45,7 +46,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.conf.getConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT) def testCases: Seq[(String, File)] = { - hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f).toImmutableArraySeq } override def beforeAll(): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 4e1ae0ad23466..bfc4fac6280da 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A trait for subclasses that handle table scans. 
@@ -497,7 +498,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val unwrapper = unwrapperFor(oi) (value: Any, row: InternalRow, ordinal: Int) => row(ordinal) = unwrapper(value) } - } + }.toImmutableArraySeq val converter = ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 68b0f2176fbc0..4027cd94d4150 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.internal.NonClosableMutableURLClassLoader import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{MavenUtils, MutableURLClassLoader, Utils, VersionUtils} +import org.apache.spark.util.ArrayImplicits._ /** Factory for `IsolatedClientLoader` with specific versions of hive. */ private[hive] object IsolatedClientLoader extends Logging { @@ -143,7 +144,7 @@ private[hive] object IsolatedClientLoader extends Logging { val tempDir = Utils.createTempDir(namePrefix = s"hive-${version}") allFiles.foreach(f => FileUtils.copyFileToDirectory(f, tempDir)) logInfo(s"Downloaded metastore jars to ${tempDir.getCanonicalPath}") - tempDir.listFiles().map(_.toURI.toURL) + tempDir.listFiles().map(_.toURI.toURL).toImmutableArraySeq } // A map from a given pair of HiveVersion and Hadoop version to jar files. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1WritesHiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1WritesHiveUtils.scala index 6421dd184ae0d..e6b1019e717ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1WritesHiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1WritesHiveUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.util.ArrayImplicits._ trait V1WritesHiveUtils { @@ -99,7 +100,7 @@ trait V1WritesHiveUtils { // during `loadDynamicPartitions`. Spark needs to write partition directories with lower-cased // column names in order to make `loadDynamicPartitions` work. 
attr.withName(name.toLowerCase(Locale.ROOT)) - } + }.toImmutableArraySeq } def getOptionsWithHiveBucketWrite(bucketSpec: Option[BucketSpec]): Map[String, String] = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 258aff1f20623..89fe10d5c4bd9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType import org.apache.spark.storage.RDDBlockId import org.apache.spark.storage.StorageLevel.{DISK_ONLY, MEMORY_ONLY} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -212,7 +213,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto assertCached(table("refreshTable")) checkAnswer( table("refreshTable"), - table("src").union(table("src")).collect()) + table("src").union(table("src")).collect().toImmutableArraySeq) // Drop the table and create it again. sql("DROP TABLE refreshTable") @@ -224,7 +225,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("REFRESH TABLE refreshTable") checkAnswer( table("refreshTable"), - table("src").union(table("src")).collect()) + table("src").union(table("src")).collect().toImmutableArraySeq) // It is not cached. assert(!isCached("refreshTable"), "refreshTable should not be cached.") @@ -240,7 +241,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet") checkAnswer( table("refreshTable"), - table("src").collect()) + table("src").collect().toImmutableArraySeq) // Cache the table. sql("CACHE TABLE refreshTable") assertCached(table("refreshTable")) @@ -252,7 +253,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto assertCached(table("refreshTable")) checkAnswer( table("refreshTable"), - table("src").union(table("src")).collect()) + table("src").union(table("src")).collect().toImmutableArraySeq) // Drop the table and create it again. sql("DROP TABLE refreshTable") @@ -264,7 +265,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql(s"REFRESH ${tempPath.toString}") checkAnswer( table("refreshTable"), - table("src").union(table("src")).collect()) + table("src").union(table("src")).collect().toImmutableArraySeq) // It is not cached. assert(!isCached("refreshTable"), "refreshTable should not be cached.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 40d983e10616e..f63e75c9e4e1a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.tags.{ExtendedHiveTest, SlowHiveTest} import org.apache.spark.util.{Utils, VersionUtils} +import org.apache.spark.util.ArrayImplicits._ /** * Test HiveExternalCatalog backward compatibility. 
@@ -270,7 +271,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { .filter(_.contains("""""".r.findFirstMatchIn(_).get.group(1)) - .filter(_ < org.apache.spark.SPARK_VERSION) + .filter(_ < org.apache.spark.SPARK_VERSION).toImmutableArraySeq } catch { // Do not throw exception during object initialization. case NonFatal(_) => Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUDFDynamicLoadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUDFDynamicLoadSuite.scala index ee8e6f4f78be5..fa54d4898f1d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUDFDynamicLoadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUDFDynamicLoadSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class HiveUDFDynamicLoadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -51,7 +52,7 @@ class HiveUDFDynamicLoadSuite extends QueryTest with SQLTestUtils with TestHiveS HiveFunctionWrapper("org.apache.hadoop.hive.contrib.udf.example.UDFExampleAdd2"), Array( AttributeReference("a", IntegerType, nullable = false)(), - AttributeReference("b", IntegerType, nullable = false)())) + AttributeReference("b", IntegerType, nullable = false)()).toImmutableArraySeq) }), // GenericUDF @@ -67,7 +68,7 @@ class HiveUDFDynamicLoadSuite extends QueryTest with SQLTestUtils with TestHiveS HiveGenericUDF( "default.generic_udf_trim2", HiveFunctionWrapper("org.apache.hadoop.hive.contrib.udf.example.GenericUDFTrim2"), - Array(AttributeReference("a", StringType, nullable = false)()) + Array(AttributeReference("a", StringType, nullable = false)()).toImmutableArraySeq ) } ), @@ -89,7 +90,7 @@ class HiveUDFDynamicLoadSuite extends QueryTest with SQLTestUtils with TestHiveS HiveUDAFFunction( "default.generic_udaf_sum2", HiveFunctionWrapper("org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum2"), - Array(AttributeReference("a", IntegerType, nullable = false)()) + Array(AttributeReference("a", IntegerType, nullable = false)()).toImmutableArraySeq ) } ), @@ -111,7 +112,7 @@ class HiveUDFDynamicLoadSuite extends QueryTest with SQLTestUtils with TestHiveS HiveUDAFFunction( "default.udaf_max2", HiveFunctionWrapper("org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax2"), - Array(AttributeReference("a", IntegerType, nullable = false)()), + Array(AttributeReference("a", IntegerType, nullable = false)()).toImmutableArraySeq, isUDAFBridgeRequired = true ) } @@ -133,11 +134,11 @@ class HiveUDFDynamicLoadSuite extends QueryTest with SQLTestUtils with TestHiveS HiveGenericUDTF( "default.udtf_count3", HiveFunctionWrapper("org.apache.hadoop.hive.contrib.udtf.example.GenericUDTFCount3"), - Array.empty[Expression] + Array.empty[Expression].toImmutableArraySeq ) } ) - ) + ).toImmutableArraySeq udfTestInfos.foreach { udfInfo => // The test jars are built from below commit: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 4e8a62acddd72..5c6b8fda71cc7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType._ import org.apache.spark.sql.types.YearMonthIntervalType._ import org.apache.spark.tags.SlowHiveTest import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.ArrayImplicits._ @SlowHiveTest class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with TestHiveSingleton { @@ -77,7 +78,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T child = child, ioschema = hiveIOSchema ), - rowsDf.collect()) + rowsDf.collect().toImmutableArraySeq) assert(uncaughtExceptionHandler.exception.isEmpty) } @@ -94,7 +95,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T child = ExceptionInjectingOperator(child), ioschema = hiveIOSchema ), - rowsDf.collect()) + rowsDf.collect().toImmutableArraySeq) } assert(e.getMessage().contains("intentional exception")) // Before SPARK-25158, uncaughtExceptionHandler will catch IllegalArgumentException @@ -135,7 +136,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T child = child, ioschema = hiveIOSchema ), - rowsDf.select("name").collect()) + rowsDf.select("name").collect().toImmutableArraySeq) assert(uncaughtExceptionHandler.exception.isEmpty) } @@ -183,7 +184,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).as("value")).collect()) + $"e".cast("string")).as("value")).collect().toImmutableArraySeq) // In hive default serde mode, if we don't define output schema, // when output column size > 2 and just specify serde, @@ -207,7 +208,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T identity, df.select( $"a".cast("string").as("key"), - $"b".cast("string").as("value")).collect()) + $"b".cast("string").as("value")).collect().toImmutableArraySeq) // In hive default serde mode, if we don't define output schema, @@ -239,7 +240,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T $"b".cast("string"), $"c".cast("string"), $"d".cast("string"), - $"e".cast("string")).as("value")).collect()) + $"e".cast("string")).as("value")).collect().toImmutableArraySeq) // In hive default serde mode, if we don't define output schema, // when output column size > 2 and specify serde @@ -265,7 +266,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T identity, df.select( $"a".cast("string").as("key"), - $"b".cast("string").as("value")).collect()) + $"b".cast("string").as("value")).collect().toImmutableArraySeq) // In hive default serde mode, if we don't define output schema, // when output column size = 2 and specify serde, it will these two column as @@ -290,7 +291,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T identity, df.select( $"a".cast("string").as("key"), - $"b".cast("string").as("value")).collect()) + $"b".cast("string").as("value")).collect().toImmutableArraySeq) // In hive default serde mode, if we don't define output schema, // when output column size < 2 and specify serde, it will return null for deficiency @@ -315,7 +316,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T identity, df.select( $"a".cast("string").as("key"), - lit(null)).collect()) + 
lit(null)).collect().toImmutableArraySeq) } } @@ -346,7 +347,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T child = child, ioschema = hiveIOSchema ), - df.select($"c", $"d", $"e").collect()) + df.select($"c", $"d", $"e").collect().toImmutableArraySeq) } } @@ -367,7 +368,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T |USING 'cat' AS (c array, d map, e struct) |FROM v """.stripMargin) - checkAnswer(query, identity, df.select($"c", $"d", $"e").collect()) + checkAnswer(query, identity, df.select($"c", $"d", $"e").collect().toImmutableArraySeq) } } @@ -571,7 +572,8 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T AttributeReference("j", DayTimeIntervalType(SECOND))()), child = child, ioschema = hiveIOSchema), - df.select($"a", $"b", $"c", $"d", $"e", $"f", $"g", $"h", $"i", $"j").collect()) + df.select($"a", $"b", $"c", $"d", $"e", $"f", $"g", $"h", $"i", $"j").collect() + .toImmutableArraySeq) } } @@ -598,7 +600,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T AttributeReference("c", YearMonthIntervalType(MONTH))()), child = child, ioschema = hiveIOSchema), - df.select($"a", $"b", $"c").collect()) + df.select($"a", $"b", $"c").collect().toImmutableArraySeq) } } @@ -616,7 +618,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T output = Seq(AttributeReference("a", DayTimeIntervalType())()), child = child, ioschema = hiveIOSchema), - df.select($"a").collect()) + df.select($"a").collect().toImmutableArraySeq) }.getMessage assert(e.contains("java.lang.ArithmeticException: long overflow")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index f183c732250e4..86401bf923927 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.internal.LegacyBehaviorPolicy._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.util.ArrayImplicits._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -607,7 +608,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes require(dir.listFiles().exists(!_.isDirectory)) require(subdir.exists()) require(subdir.listFiles().exists(!_.isDirectory)) - testWithPath(dir, dataInDir.collect()) + testWithPath(dir, dataInDir.collect().toImmutableArraySeq) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 3fee7df19ab5e..9aa4cb77b226a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SerializableConfiguration class SimpleTextSource extends 
TextBasedFileFormat with DataSourceRegister { @@ -67,7 +68,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { SimpleTextRelation.lastHadoopConf = Option(hadoopConf) - SimpleTextRelation.requiredColumns = requiredSchema.fieldNames + SimpleTextRelation.requiredColumns = requiredSchema.fieldNames.toImmutableArraySeq SimpleTextRelation.pushedFilters = filters.toSet val fieldTypes = dataSchema.map(_.dataType) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 1500f72763f9f..8034b1b21715c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.UI._ import org.apache.spark.io.CompressionCodec import org.apache.spark.streaming.scheduler.JobGenerator +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils private[streaming] @@ -136,7 +137,7 @@ object Checkpoint extends Logging { if (statuses != null) { val paths = statuses.filterNot(_.isDirectory).map(_.getPath) val filtered = paths.filter(p => REGEX.findFirstIn(p.getName).nonEmpty) - filtered.sortWith(sortFunc) + filtered.sortWith(sortFunc).toImmutableArraySeq } else { logWarning(s"Listing $path returned null") Seq.empty diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 850628080eca7..414fdf5d619dd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamInputInfo import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * This class represents an input stream that monitors a Hadoop-compatible filesystem for new @@ -149,7 +150,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( batchTimeToSelectedFiles += ((validTime, newFiles)) } recentlySelectedFiles ++= newFiles - val rdds = Some(filesToRDD(newFiles)) + val rdds = Some(filesToRDD(newFiles.toImmutableArraySeq)) // Copy newFiles to immutable.List to prevent from being modified by the user val metadata = Map( "files" -> newFiles.toList, @@ -343,7 +344,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( f.mkString("[", ", ", "]") ) batchTimeToSelectedFiles.synchronized { batchTimeToSelectedFiles += ((t, f)) } recentlySelectedFiles ++= f - generatedRDDs += ((t, filesToRDD(f))) + generatedRDDs += ((t, filesToRDD(f.toImmutableArraySeq))) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 6494e512713f8..9c461f0d4270e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -28,6 +28,7 @@ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, 
StorageLevel} import org.apache.spark.streaming.util._ import org.apache.spark.util._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.io.ChunkedByteBuffer /** @@ -195,6 +196,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( try { HdfsUtils.getFileSegmentLocations( fileSegment.path, fileSegment.offset, fileSegment.length, hadoopConfig) + .toImmutableArraySeq } catch { case NonFatal(e) => logError("Error getting WAL file segment locations", e) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 685ddf67237b3..d5bc658b4b50c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -32,6 +32,7 @@ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.util.{SerializableConfiguration, ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ /** Enumeration to identify current state of a Receiver */ @@ -108,7 +109,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private val receivedBlockTracker = new ReceivedBlockTracker( ssc.sparkContext.conf, ssc.sparkContext.hadoopConfiguration, - receiverInputStreamIds, + receiverInputStreamIds.toImmutableArraySeq, ssc.scheduler.clock, ssc.isCheckpointPresent, Option(ssc.checkpointDir) @@ -443,7 +444,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false runDummySparkJob() logInfo("Starting " + receivers.length + " receivers") - endpoint.send(StartAllReceivers(receivers)) + endpoint.send(StartAllReceivers(receivers.toImmutableArraySeq)) } /** Check if tracker has been marked for starting */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 908d155908fce..a2e29a1cfa005 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.{CompletionIterator, ThreadUtils} +import org.apache.spark.util.ArrayImplicits._ /** * This class manages write ahead log files. @@ -246,7 +247,7 @@ private[streaming] class FileBasedWriteAheadLog( // leads to much clearer code. 
if (fileSystem.getFileStatus(logDirectoryPath).isDirectory) { val logFileInfo = logFilesTologInfo( - fileSystem.listStatus(logDirectoryPath).map { _.getPath }) + fileSystem.listStatus(logDirectoryPath).map { _.getPath }.toImmutableArraySeq) pastLogs.clear() pastLogs ++= logFileInfo logInfo(s"Recovered ${logFileInfo.size} write ahead log files from $logDirectory") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index afe6c73a7b8ca..75edbb173faad 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils /** Testsuite for testing the network receiver behavior */ @@ -223,6 +224,7 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { try { if (logDirectory.exists()) { logDirectory.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } + .toImmutableArraySeq } else { Seq.empty } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index dc6ebaf2162e6..214b1466c47ad 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, ForEachDStream, InputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.util.ArrayImplicits._ /** * A dummy stream that does absolutely nothing. 
@@ -97,7 +98,7 @@ class TestOutputStream[T: ClassTag]( new ConcurrentLinkedQueue[Seq[T]]() ) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() - output.add(collected) + output.add(collected.toImmutableArraySeq) }, false) { // This is to clear the output buffer every it is read from a checkpoint @@ -121,7 +122,7 @@ class TestOutputStreamWithPartitions[T: ClassTag]( new ConcurrentLinkedQueue[Seq[Seq[T]]]()) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.glom().collect().map(_.toSeq) - output.add(collected) + output.add(collected.toImmutableArraySeq) }, false) { // This is to clear the output buffer every it is read from a checkpoint diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index ad35841caa921..96a48574f30c8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils class WriteAheadLogBackedBlockRDDSuite extends SparkFunSuite { @@ -178,7 +179,7 @@ class WriteAheadLogBackedBlockRDDSuite extends SparkFunSuite { // Generate write ahead log record handles val recordHandles = generateFakeRecordHandles(numPartitions - numPartitionsInWAL) ++ generateWALRecordHandles(data.takeRight(numPartitionsInWAL), - blockIds.takeRight(numPartitionsInWAL)) + blockIds.takeRight(numPartitionsInWAL).toImmutableArraySeq) // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not require( @@ -255,6 +256,6 @@ class WriteAheadLogBackedBlockRDDSuite extends SparkFunSuite { } private def generateFakeRecordHandles(count: Int): Seq[FileBasedWriteAheadLogSegment] = { - Array.fill(count)(new FileBasedWriteAheadLogSegment("random", 0L, 0)) + Array.fill(count)(FileBasedWriteAheadLogSegment("random", 0L, 0)).toImmutableArraySeq } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 97f37cb1f03da..43b8480568dac 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -43,6 +43,7 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{CompletionIterator, ManualClock, ThreadUtils, Utils} +import org.apache.spark.util.ArrayImplicits._ /** Common tests for WriteAheadLogs that we would like to test with different configurations. */ abstract class CommonWriteAheadLogTests( @@ -718,7 +719,7 @@ object WriteAheadLogSuite { val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) val data = wal.readAll().asScala.map(byteBufferToString).toArray wal.close() - data + data.toImmutableArraySeq } /** Get the log files in a directory. 
*/ @@ -732,7 +733,7 @@ object WriteAheadLogSuite { _.getName().split("-")(1).toLong }.map { _.toString.stripPrefix("file:") - } + }.toImmutableArraySeq } else { Seq.empty } diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 07157da5c3517..acab365418afa 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -18,7 +18,7 @@ // scalastyle:off classforname package org.apache.spark.tools -import scala.collection.mutable +import scala.collection.{immutable, mutable} import scala.reflect.runtime.{universe => unv} import scala.reflect.runtime.universe.runtimeMirror import scala.util.Try @@ -100,8 +100,10 @@ object GenerateMIMAIgnore { */ def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = { try { - Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName) + val ret = Class.forName(classSymbol.fullName, false, classLoader) + .getMethods.map(_.getName) .filter(_.contains("$$")).map(classSymbol.fullName + "." + _) + immutable.ArraySeq.unsafeWrapArray(ret) } catch { case t: Throwable => // scalastyle:off println From b501a223bfcf4ddbcb0b2447aa06c549051630b0 Mon Sep 17 00:00:00 2001 From: "longfei.jiang" Date: Sat, 11 Nov 2023 13:49:18 +0800 Subject: [PATCH 092/121] [MINOR][DOCS] Fix the example value in the docs ### What changes were proposed in this pull request? fix the example value ### Why are the changes needed? for doc ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? Just example value in the docs, no need to test. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43750 from jlfsdtc/fix_typo_in_doc. Authored-by: longfei.jiang Signed-off-by: Kent Yao --- docs/sql-ref-datetime-pattern.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-ref-datetime-pattern.md b/docs/sql-ref-datetime-pattern.md index 5e28a18acefa4..e5d5388f262e4 100644 --- a/docs/sql-ref-datetime-pattern.md +++ b/docs/sql-ref-datetime-pattern.md @@ -41,7 +41,7 @@ Spark uses pattern letters in the following table for date and timestamp parsing |**a**|am-pm-of-day|am-pm|PM| |**h**|clock-hour-of-am-pm (1-12)|number(2)|12| |**K**|hour-of-am-pm (0-11)|number(2)|0| -|**k**|clock-hour-of-day (1-24)|number(2)|0| +|**k**|clock-hour-of-day (1-24)|number(2)|1| |**H**|hour-of-day (0-23)|number(2)|0| |**m**|minute-of-hour|number(2)|30| |**s**|second-of-minute|number(2)|55| From f9c8c7a3312e533afb95e527ca0c451148bad6a4 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sun, 12 Nov 2023 13:45:09 +0800 Subject: [PATCH 093/121] [SPARK-45871][CONNECT] Optimizations collection conversion related to `.toBuffer` in the `connect` modules ### What changes were proposed in this pull request? This PR includes the following optimizations related to `.toBuffer` in the `connect`s module: 1. For the two functions `sql(String, java.util.Map[String, Any]): DataFrame` and `sql(String, Array[_]): DataFrame` in `SparkSession`, the approach of using `.find` directly on the `CloseableIterator` to locate the target and utilizing `.foreach` to consume the remaining elements replaced the previous method of converting to a collection using `toBuffer.toSeq` and then searching for the target using `.find`. This approach avoids the need for an unnecessary collection creation. 2. 
For function `execute(proto.Relation.Builder => Unit): Unit` in `SparkSession`, as no elements are returned, `.foreach` is used instead of `.toBuffer` to avoid an unnecessary collection creation. 3. For function `execute(command: proto.Command): Seq[ExecutePlanResponse]` in `SparkSession`, in Scala 2.12, `s.c.TraversableOnce#toSeq` returns an `immutable.Stream`, which is a tail-lazy structure that may not consume all elements. Therefore, it is necessary to use `s.c.TraversableOnce#toBuffer` for materialization. However, in Scala 2.13, `s.c.IterableOnceOps#toSeq` constructs an `immutable.Seq`, which is not a lazy data structure and ensures consumption of all elements. Therefore, in Scala 2.13, `.toSeq` can be used directly to replace `toBuffer.toSeq`, saving an extra collection transformation. 4. The optimizations for the two functions `listAbandonedExecutions: Seq[ExecuteInfo]` and `listExecuteHolders` in `SparkConnectExecutionManager` are consistent with item 3 above. Additionally, to prevent helper function used for testing from creating copies, a `private[connect]` scope helper function was added to the companion object of `ExecutePlanResponseReattachableIterator`. ### Why are the changes needed? Avoid unnecessary collection copies ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Action - Added new test cases to demonstrate that both `.toSeq` and `.foreach` can consume all elements in `ResponseReattachableIterator` in Scala 2.13. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43745 from LuciferYang/SPARK-45871. Authored-by: yangjie01 Signed-off-by: yangjie01 --- .../org/apache/spark/sql/SparkSession.scala | 48 +++++++++++-------- .../client/SparkConnectClientSuite.scala | 48 +++++++++++++++++++ ...cutePlanResponseReattachableIterator.scala | 10 ++++ .../SparkConnectExecutionManager.scala | 6 +-- .../sql/connect/SparkConnectServerTest.scala | 12 +---- 5 files changed, 90 insertions(+), 34 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 34756f9a440bb..ca692d2d4f8de 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -250,15 +250,18 @@ class SparkSession private[sql] ( .setSql(sqlText) .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) - // .toBuffer forces that the iterator is consumed and closed - val responseSeq = client.execute(plan.build()).toBuffer.toSeq + val responseIter = client.execute(plan.build()) - val response = responseSeq - .find(_.hasSqlCommandResult) - .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) - - // Update the builder with the values from the result. - builder.mergeFrom(response.getSqlCommandResult.getRelation) + try { + val response = responseIter + .find(_.hasSqlCommandResult) + .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) + // Update the builder with the values from the result. 
+ builder.mergeFrom(response.getSqlCommandResult.getRelation) + } finally { + // consume the rest of the iterator + responseIter.foreach(_ => ()) + } } /** @@ -309,15 +312,18 @@ class SparkSession private[sql] ( .setSql(sqlText) .putAllNamedArguments(args.asScala.view.mapValues(lit(_).expr).toMap.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) - // .toBuffer forces that the iterator is consumed and closed - val responseSeq = client.execute(plan.build()).toBuffer.toSeq - - val response = responseSeq - .find(_.hasSqlCommandResult) - .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) - - // Update the builder with the values from the result. - builder.mergeFrom(response.getSqlCommandResult.getRelation) + val responseIter = client.execute(plan.build()) + + try { + val response = responseIter + .find(_.hasSqlCommandResult) + .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) + // Update the builder with the values from the result. + builder.mergeFrom(response.getSqlCommandResult.getRelation) + } finally { + // consume the rest of the iterator + responseIter.foreach(_ => ()) + } } /** @@ -543,14 +549,14 @@ class SparkSession private[sql] ( f(builder) builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement()) val plan = proto.Plan.newBuilder().setRoot(builder).build() - // .toBuffer forces that the iterator is consumed and closed - client.execute(plan).toBuffer + // .foreach forces that the iterator is consumed and closed + client.execute(plan).foreach(_ => ()) } private[sql] def execute(command: proto.Command): Seq[ExecutePlanResponse] = { val plan = proto.Plan.newBuilder().setCommand(command).build() - // .toBuffer forces that the iterator is consumed and closed - client.execute(plan).toBuffer.toSeq + // .toSeq forces that the iterator is consumed and closed + client.execute(plan).toSeq } private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index d0c85da5f212e..b93713383b209 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -387,6 +387,54 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } assert(dummyFn.counter == 2) } + + test("SPARK-45871: Client execute iterator.toSeq consumes the reattachable iterator") { + startDummyServer(0) + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .enableReattachableExecute() + .build() + val session = SparkSession.builder().client(client).create() + val cmd = session.newCommand(b => + b.setSqlCommand( + proto.SqlCommand + .newBuilder() + .setSql("select * from range(10000000)"))) + val plan = proto.Plan.newBuilder().setCommand(cmd) + val iter = client.execute(plan.build()) + val reattachableIter = + ExecutePlanResponseReattachableIterator.fromIterator(iter) + iter.toSeq + // In several places in SparkSession, we depend on `.toSeq` to consume and close the iterator. + // If this assertion fails, we need to double check the correctness of that. 
+ // In scala 2.12 `s.c.TraversableOnce#toSeq` builds an `immutable.Stream`, + // which is a tail lazy structure and this would fail. + // In scala 2.13 `s.c.IterableOnceOps#toSeq` builds an `immutable.Seq` which is not + // lazy and will consume and close the iterator. + assert(reattachableIter.resultComplete) + } + + test("SPARK-45871: Client execute iterator.foreach consumes the reattachable iterator") { + startDummyServer(0) + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .enableReattachableExecute() + .build() + val session = SparkSession.builder().client(client).create() + val cmd = session.newCommand(b => + b.setSqlCommand( + proto.SqlCommand + .newBuilder() + .setSql("select * from range(10000000)"))) + val plan = proto.Plan.newBuilder().setCommand(cmd) + val iter = client.execute(plan.build()) + val reattachableIter = + ExecutePlanResponseReattachableIterator.fromIterator(iter) + iter.foreach(_ => ()) + assert(reattachableIter.resultComplete) + } } class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 2b61463c343fb..cfa492ef063ca 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -326,3 +326,13 @@ class ExecutePlanResponseReattachableIterator( private def retry[T](fn: => T): T = GrpcRetryHandler.retry(retryPolicy)(fn) } + +private[connect] object ExecutePlanResponseReattachableIterator { + @scala.annotation.tailrec + private[connect] def fromIterator( + iter: Iterator[proto.ExecutePlanResponse]): ExecutePlanResponseReattachableIterator = + iter match { + case e: ExecutePlanResponseReattachableIterator => e + case w: WrappedCloseableIterator[proto.ExecutePlanResponse] => fromIterator(w.innerIterator) + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index c004358e1cf18..36c6f73329b88 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -153,7 +153,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * cache, and the tombstones will be eventually removed. 
*/ def listAbandonedExecutions: Seq[ExecuteInfo] = { - abandonedTombstones.asMap.asScala.values.toBuffer.toSeq + abandonedTombstones.asMap.asScala.values.toSeq } private[connect] def shutdown(): Unit = executionsLock.synchronized { @@ -236,7 +236,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { executions.values.foreach(_.interruptGrpcResponseSenders()) } - private[connect] def listExecuteHolders = executionsLock.synchronized { - executions.values.toBuffer.toSeq + private[connect] def listExecuteHolders: Seq[ExecuteHolder] = executionsLock.synchronized { + executions.values.toSeq } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index c4a5539ce0b7b..1c0d9a68ab6be 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -27,7 +27,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, SparkConnectStubState, WrappedCloseableIterator} +import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, SparkConnectStubState} import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.connect.common.config.ConnectCommon import org.apache.spark.sql.connect.config.Connect @@ -147,15 +147,7 @@ trait SparkConnectServerTest extends SharedSparkSession { protected def getReattachableIterator( stubIterator: CloseableIterator[proto.ExecutePlanResponse]) = { - // This depends on the wrapping in CustomSparkConnectBlockingStub.executePlanReattachable: - // GrpcExceptionConverter.convertIterator - stubIterator - .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]] - .innerIterator - .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]] - // ExecutePlanResponseReattachableIterator - .innerIterator - .asInstanceOf[ExecutePlanResponseReattachableIterator] + ExecutePlanResponseReattachableIterator.fromIterator(stubIterator) } protected def assertNoActiveRpcs(): Unit = { From 2605b87990c9826d05ad0943045e8dfa79af13e9 Mon Sep 17 00:00:00 2001 From: Bhuwan Sahni Date: Sun, 12 Nov 2023 16:50:01 +0900 Subject: [PATCH 094/121] [SPARK-45655][SQL][SS] Allow non-deterministic expressions inside AggregateFunctions in CollectMetrics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR allows non-deterministic expressions wrapped inside an `AggregateFunction` such as `count` inside `CollectMetrics` node. `CollectMetrics` is used to collect arbitrary metrics from the query, in certain scenarios user would like to collect metrics for filtering based on non-deterministic expressions (see query example below). Currently, Analyzer does not allow non-deterministic expressions inside a `AggregateFunction` for `CollectMetrics`. This constraint is relaxed to allow collection of such metrics. 
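To make the `checkMetric` discussion below easier to follow, here is a minimal, hypothetical sketch of the intended match ordering. This is a simplified illustration under assumed semantics, not the `CheckAnalysis` code touched by this patch: the `AggregateExpression` cases are handled before the generic non-deterministic check, so a non-deterministic child that sits under an aggregate function is accepted, while a bare non-deterministic metric expression is still rejected.
```
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

object CheckMetricSketch {
  // Returns the first validation error found, or None if the metric expression is valid.
  def checkMetric(e: Expression, seenAggregate: Boolean = false): Option[String] = e match {
    case a: AggregateExpression if a.isDistinct =>
      Some("distinct aggregates are not allowed in observed metrics")
    case a: AggregateExpression if a.filter.isDefined =>
      Some("aggregates with FILTER are not allowed in observed metrics")
    case a: AggregateExpression =>
      // Recurse with seenAggregate = true so that non-deterministic children of an
      // aggregate function are accepted.
      a.children.flatMap(child => checkMetric(child, seenAggregate = true)).headOption
    case _: Attribute if !seenAggregate =>
      Some("attributes are only allowed inside an aggregate function")
    case _ if !e.deterministic && !seenAggregate =>
      // Matched only after the AggregateExpression cases above, because an
      // AggregateExpression reports deterministic = false whenever any of its
      // children is non-deterministic.
      Some("non-deterministic expressions are only allowed inside an aggregate function")
    case other =>
      other.children.flatMap(child => checkMetric(child, seenAggregate)).headOption
  }
}
```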
Note that the metrics are relevant for a completed batch, and can change if the batch is replayed (because non-deterministic expressions can behave differently for different runs).

While working on this feature, I found an issue with `checkMetric` logic to validate non-deterministic expressions inside an AggregateExpression. An expression is determined as non-deterministic if any of its children is non-deterministic, hence we need to match the case for `!e.deterministic && !seenAggregate` after we have matched if the current expression is an AggregateExpression. If the current expression is an AggregateExpression, we should validate further down in the tree recursively - otherwise we will fail for any non-deterministic expression.

```
val inputData = MemoryStream[Timestamp]

inputData.toDF()
  .filter("value < current_date()")
  .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
  .writeStream
  .queryName("ts_metrics_test")
  .format("memory")
  .outputMode("append")
  .start()
```

### Why are the changes needed?

1. Added a testcase to calculate dropped rows (by `CurrentBatchTimestamp`) and ensure the query is successful. As an example, the query below fails (without this change) due to the observe call on the DataFrame.

```
val inputData = MemoryStream[Timestamp]

inputData.toDF()
  .filter("value < current_date()")
  .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
  .writeStream
  .queryName("ts_metrics_test")
  .format("memory")
  .outputMode("append")
  .start()
```

2. Added testing in AnalysisSuite for non-deterministic expressions inside an AggregateFunction.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test cases added.

```
[warn] 20 warnings found
WARNING: Using incubator modules: jdk.incubator.vector, jdk.incubator.foreign
[info] StreamingQueryStatusAndProgressSuite:
09:14:39.684 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
[info] Passed: Total 0, Failed 0, Errors 0, Passed 0
[info] No tests to run for hive / Test / testOnly
[info] - StreamingQueryProgress - prettyJson (436 milliseconds)
[info] - StreamingQueryProgress - json (3 milliseconds)
[info] - StreamingQueryProgress - toString (5 milliseconds)
[info] - StreamingQueryProgress - jsonString and fromJson (163 milliseconds)
[info] - StreamingQueryStatus - prettyJson (1 millisecond)
[info] - StreamingQueryStatus - json (1 millisecond)
[info] - StreamingQueryStatus - toString (2 milliseconds)
09:14:41.674 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-34d2749f-f4d0-46d8-bc51-29da6411e1c5. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:41.710 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - progress classes should be Serializable (5 seconds, 552 milliseconds)
09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-3a41d397-c3c1-490b-9cc7-d775b0c42208. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-19378: Continue reporting stateOp metrics even if there is no active trigger (1 second, 337 milliseconds)
09:14:47.677 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-29973: Make `processedRowsPerSecond` calculated more accurately and meaningfully (455 milliseconds)
09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-360fc3b9-a2c5-430c-a892-c9869f1f8339. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
[info] - SPARK-45655: Use current batch timestamp in observe API (587 milliseconds)
09:14:48.768 WARN org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite:
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43517 from sahnib/SPARK-45655.
Authored-by: Bhuwan Sahni Signed-off-by: Jungtaek Lim --- .../sql/catalyst/analysis/CheckAnalysis.scala | 25 +++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 17 +++++-- ...StreamingQueryStatusAndProgressSuite.scala | 44 +++++++++++++++++++ 3 files changed, 74 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 352b3124a864b..d41345f38c2ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Median, PercentileCont, PercentileDisc} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Median, PercentileCont, PercentileDisc} import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -476,10 +476,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB e.failAnalysis( "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED", Map("expr" -> toSQLExpr(s))) - case _ if !e.deterministic && !seenAggregate => - e.failAnalysis( - "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC", - Map("expr" -> toSQLExpr(s))) case a: AggregateExpression if seenAggregate => e.failAnalysis( "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED", @@ -492,12 +488,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB e.failAnalysis( "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED", Map("expr" -> toSQLExpr(s))) + case _: AggregateExpression | _: AggregateFunction => + e.children.foreach(checkMetric (s, _, seenAggregate = true)) case _: Attribute if !seenAggregate => e.failAnalysis( "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE", Map("expr" -> toSQLExpr(s))) - case _: AggregateExpression => - e.children.foreach(checkMetric (s, _, seenAggregate = true)) + case a: Alias => + checkMetric(s, a.child, seenAggregate) + case a if !e.deterministic && !seenAggregate => + e.failAnalysis( + "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC", + Map("expr" -> toSQLExpr(s))) case _ => e.children.foreach(checkMetric (s, _, seenAggregate)) } @@ -734,8 +736,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB "dataType" -> toSQLType(mapCol.dataType))) case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && - !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] && + !o.isInstanceOf[Project] && + // non-deterministic expressions inside CollectMetrics have been + // already validated inside checkMetric function + !o.isInstanceOf[CollectMetrics] && + !o.isInstanceOf[Filter] && + !o.isInstanceOf[Aggregate] && + !o.isInstanceOf[Window] && !o.isInstanceOf[Expand] && !o.isInstanceOf[Generate] && !o.isInstanceOf[CreateVariable] && diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ca22c55b49e89..8e514e245cb9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -794,9 +794,20 @@ class AnalysisSuite extends AnalysisTest with Matchers { // No columns assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved) - def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = { - assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors) - } + // non-deterministic expression inside an aggregate function is valid + val tsLiteral = Literal.create(java.sql.Timestamp.valueOf("2023-11-30 21:05:00.000000"), + TimestampType) + + assertAnalysisSuccess( + CollectMetrics( + "invalid", + Count( + GreaterThan(tsLiteral, CurrentBatchTimestamp(1699485296000L, TimestampType)) + ).as("count") :: Nil, + testRelation, + 0 + ) + ) // Unwrapped attribute assertAnalysisErrorClass( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 8fe4ef39b2552..8ff71473f271b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.streaming +import java.sql.Timestamp +import java.time.Instant +import java.time.temporal.ChronoUnit import java.util.UUID import scala.jdk.CollectionConverters._ @@ -355,6 +358,47 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually { ) } + test("SPARK-45655: Use current batch timestamp in observe API") { + import testImplicits._ + + val inputData = MemoryStream[Timestamp] + + // current_date() internally uses current batch timestamp on streaming query + val query = inputData.toDF() + .filter("value < current_date()") + .observe("metrics", count(expr("value >= current_date()")).alias("dropped")) + .writeStream + .queryName("ts_metrics_test") + .format("memory") + .outputMode("append") + .start() + + val timeNow = Instant.now().truncatedTo(ChronoUnit.SECONDS) + + // this value would be accepted by the filter and would not count towards + // dropped metrics. 
+ val validValue = Timestamp.from(timeNow.minus(2, ChronoUnit.DAYS)) + inputData.addData(validValue) + + // would be dropped by the filter and count towards dropped metrics + inputData.addData(Timestamp.from(timeNow.plus(2, ChronoUnit.DAYS))) + + query.processAllAvailable() + query.stop() + + val dropped = query.recentProgress.map { p => + val metricVal = Option(p.observedMetrics.get("metrics")) + metricVal.map(_.getLong(0)).getOrElse(0L) + }.sum + // ensure dropped metrics are correct + assert(dropped == 1) + + val data = spark.read.table("ts_metrics_test").collect() + + // ensure valid value ends up in output + assert(data(0).getAs[Timestamp](0).equals(validValue)) + } + def waitUntilBatchProcessed: AssertOnQuery = Execute { q => eventually(Timeout(streamingTimeout)) { if (q.exception.isEmpty) { From eccf07694e3ea72753e4942ce47872c96890354a Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sun, 12 Nov 2023 14:00:04 -0800 Subject: [PATCH 095/121] [SPARK-45870][INFRA] Upgrade action/checkout to v4 ### What changes were proposed in this pull request? The pr aims to upgrade action/checkout from v3(node16) to v4(node20) (point to v4.1.1 now). ### Why are the changes needed? - After V4.0.0, the action/checkout default runtime is node20. image - Node 16 has reached its [end of life](https://github.com/nodejs/Release/#end-of-life-releases) image https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/ ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43744 from panbingkun/github_checkout_action. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .github/workflows/benchmark.yml | 6 ++--- .github/workflows/build_and_test.yml | 24 +++++++++---------- .../workflows/build_infra_images_cache.yml | 2 +- .github/workflows/maven_test.yml | 2 +- .github/workflows/publish_snapshot.yml | 2 +- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index f4b17b4117a44..8e7551fa7738a 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -65,7 +65,7 @@ jobs: SPARK_LOCAL_IP: localhost steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # In order to get diff files with: fetch-depth: 0 @@ -95,7 +95,7 @@ jobs: key: tpcds-${{ hashFiles('.github/workflows/benchmark.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }} - name: Checkout tpcds-kit repository if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true' - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: databricks/tpcds-kit ref: 2a5078a782192ddb6efbcead8de9973d6ab4f069 @@ -134,7 +134,7 @@ jobs: SPARK_TPCDS_DATA: ${{ github.workspace }}/tpcds-sf-1 steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # In order to get diff files with: fetch-depth: 0 diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 1e28f5530513a..897d7a68d8a9b 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -64,7 +64,7 @@ jobs: }} steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 repository: apache/spark @@ -202,7 +202,7 @@ jobs: SKIP_PACKAGING: true steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: 
actions/checkout@v4 # In order to fetch changed files with: fetch-depth: 0 @@ -302,7 +302,7 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # In order to fetch changed files with: fetch-depth: 0 @@ -381,7 +381,7 @@ jobs: BRANCH: ${{ inputs.branch }} steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # In order to fetch changed files with: fetch-depth: 0 @@ -496,7 +496,7 @@ jobs: SKIP_PACKAGING: true steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # In order to fetch changed files with: fetch-depth: 0 @@ -564,7 +564,7 @@ jobs: runs-on: ubuntu-22.04 steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 repository: apache/spark @@ -604,7 +604,7 @@ jobs: image: ${{ needs.precondition.outputs.image_url }} steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 repository: apache/spark @@ -786,7 +786,7 @@ jobs: timeout-minutes: 300 steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 repository: apache/spark @@ -841,7 +841,7 @@ jobs: SPARK_LOCAL_IP: localhost steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 repository: apache/spark @@ -883,7 +883,7 @@ jobs: key: tpcds-${{ hashFiles('.github/workflows/build_and_test.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }} - name: Checkout tpcds-kit repository if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true' - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: databricks/tpcds-kit ref: 2a5078a782192ddb6efbcead8de9973d6ab4f069 @@ -946,7 +946,7 @@ jobs: SKIP_PACKAGING: true steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 repository: apache/spark @@ -1006,7 +1006,7 @@ jobs: timeout-minutes: 300 steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 repository: apache/spark diff --git a/.github/workflows/build_infra_images_cache.yml b/.github/workflows/build_infra_images_cache.yml index b8aae945599de..3e0258830848c 100644 --- a/.github/workflows/build_infra_images_cache.yml +++ b/.github/workflows/build_infra_images_cache.yml @@ -38,7 +38,7 @@ jobs: packages: write steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU uses: docker/setup-qemu-action@v2 - name: Set up Docker Buildx diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index 169def690ae49..6d2c25f07708b 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -115,7 +115,7 @@ jobs: GITHUB_PREV_SHA: ${{ github.event.before }} steps: - name: Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # In order to fetch changed files with: fetch-depth: 0 diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml index c78b633a980c5..9ea214e5be8d7 100644 --- a/.github/workflows/publish_snapshot.yml +++ b/.github/workflows/publish_snapshot.yml @@ -41,7 +41,7 @@ jobs: branch: ${{ fromJSON( inputs.branch || '["master", "branch-3.5", "branch-3.4", "branch-3.3"]' ) }} steps: - name: 
Checkout Spark repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ matrix.branch }} - name: Cache Maven local repository From 89ff8298dcc7d8872281830db837252cf63492e8 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sun, 12 Nov 2023 14:26:35 -0800 Subject: [PATCH 096/121] [SPARK-45875][CORE] Remove `MissingStageTableRowData` from `core` module ### What changes were proposed in this pull request? SPARK-15591(https://github.com/apache/spark/pull/13708) introduced the `MissingStageTableRowData`, but it is no longer used after SPARK-20648(https://github.com/apache/spark/pull/19698), so this PR removes it. ### Why are the changes needed? Clean up unused code. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43748 from LuciferYang/SPARK-45875. Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/ui/jobs/StageTable.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 9e78f29e92e5d..0f611c7472ebd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -90,13 +90,6 @@ private[ui] class StageTableRowData( val shuffleWrite: Long, val shuffleWriteWithUnit: String) -private[ui] class MissingStageTableRowData( - stageInfo: v1.StageData, - stageId: Int, - attemptId: Int) extends StageTableRowData( - stageInfo, None, stageId, attemptId, "", None, new Date(0), "", -1, "", 0, "", 0, "", 0, "", 0, - "") - /** Page showing list of all ongoing and recently finished stages */ private[ui] class StagePagedTable( store: AppStatusStore, From 375aacd2e7308395b5a5503210cbfc5177be2925 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Sun, 12 Nov 2023 14:28:44 -0800 Subject: [PATCH 097/121] [SPARK-45886][SQL] Output full stack trace in `callSite` of DataFrame context ### What changes were proposed in this pull request? In the PR, I propose to include all available stack traces in DataFrame context to the `callSite` field and apparently to the `summary`. For now, DataFrame context contains only one item of stack trace, but later we'll add a config to control the number of items in stack traces (see https://github.com/apache/spark/pull/43695). ### Why are the changes needed? To improve user experience with Spark SQL while debugging some issue. Users can see all available stack trace, and see from where the issue comes in user code from. ### Does this PR introduce _any_ user-facing change? No, should not. Even if user's code parses the summary. ### How was this patch tested? By running new test suite: ``` $ build/sbt "test:testOnly *QueryContextSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43758 from MaxGekk/output-stack-trace. 
Authored-by: Max Gekk Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/trees/QueryContexts.scala | 4 +- .../spark/sql/errors/QueryContextSuite.scala | 39 +++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/errors/QueryContextSuite.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 874c834b75585..57271e535afbf 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -153,7 +153,7 @@ case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends Que }.getOrElse("") } - override val callSite: String = stackTrace.tail.headOption.map(_.toString).getOrElse("") + override val callSite: String = stackTrace.tail.mkString("\n") override lazy val summary: String = { val builder = new StringBuilder @@ -162,7 +162,7 @@ case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends Que builder ++= fragment builder ++= "\"" - builder ++= " was called from " + builder ++= " was called from\n" builder ++= callSite builder += '\n' builder.result() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryContextSuite.scala new file mode 100644 index 0000000000000..7d57eeb01bfa1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryContextSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.errors + +import org.apache.spark.SparkArithmeticException +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class QueryContextSuite extends QueryTest with SharedSparkSession { + + test("summary of DataFrame context") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val e = intercept[SparkArithmeticException] { + spark.range(1).select(lit(1) / lit(0)).collect() + } + assert(e.getQueryContext.head.summary() == + """== DataFrame == + |"div" was called from + |org.apache.spark.sql.errors.QueryContextSuite.$anonfun$new$3(QueryContextSuite.scala:30) + |""".stripMargin) + } + } +} From 202884f4c5712492d1d893b590c63f835dbe3c0a Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sun, 12 Nov 2023 14:30:57 -0800 Subject: [PATCH 098/121] [SPARK-45770][SQL][PYTHON][CONNECT] Introduce plan `DataFrameDropColumns` for `Dataframe.drop` ### What changes were proposed in this pull request? 
Fix column resolution in DataFrame.drop ### Why are the changes needed? ``` from pyspark.sql.functions import col # create first dataframe left_df = spark.createDataFrame([(1, 'a'), (2, 'b'), (3, 'c')], ['join_key', 'value1']) # create second dataframe right_df = spark.createDataFrame([(1, 'aa'), (2, 'bb'), (4, 'dd')], ['join_key', 'value2']) joined_df = left_df.join(right_df, on=left_df['join_key'] == right_df['join_key'], how='left') display(joined_df) cleaned_df = joined_df.drop(left_df['join_key']) display(cleaned_df) # error here JVM stacktrace: org.apache.spark.sql.AnalysisException: [AMBIGUOUS_REFERENCE] Reference `join_key` is ambiguous, could be: [`join_key`, `join_key`]. SQLSTATE: 42704 at org.apache.spark.sql.errors.QueryCompilationErrors$.ambiguousReferenceError(QueryCompilationErrors.scala:1957) at org.apache.spark.sql.catalyst.expressions.package$AttributeSeq.resolve(package.scala:377) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolve(LogicalPlan.scala:156) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveQuoted(LogicalPlan.scala:167) at org.apache.spark.sql.Dataset.$anonfun$drop$4(Dataset.scala:3071) ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #43683 from zhengruifeng/sql_drop_plan. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/tests/test_dataframe.py | 37 ++++++++++++++ .../sql/catalyst/analysis/Analyzer.scala | 1 + .../ResolveDataFrameDropColumns.scala | 49 +++++++++++++++++++ .../plans/logical/basicLogicalOperators.scala | 14 ++++++ .../sql/catalyst/trees/TreePatterns.scala | 1 + .../scala/org/apache/spark/sql/Dataset.scala | 17 ++----- 6 files changed, 106 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 0a2e3a5394602..527cf702bce9e 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -105,6 +105,43 @@ def test_drop(self): self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"]) self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"]) + def test_drop_join(self): + left_df = self.spark.createDataFrame( + [(1, "a"), (2, "b"), (3, "c")], + ["join_key", "value1"], + ) + right_df = self.spark.createDataFrame( + [(1, "aa"), (2, "bb"), (4, "dd")], + ["join_key", "value2"], + ) + joined_df = left_df.join( + right_df, + on=left_df["join_key"] == right_df["join_key"], + how="left", + ) + + dropped_1 = joined_df.drop(left_df["join_key"]) + self.assertEqual(dropped_1.columns, ["value1", "join_key", "value2"]) + self.assertEqual( + dropped_1.sort("value1").collect(), + [ + Row(value1="a", join_key=1, value2="aa"), + Row(value1="b", join_key=2, value2="bb"), + Row(value1="c", join_key=None, value2=None), + ], + ) + + dropped_2 = joined_df.drop(right_df["join_key"]) + self.assertEqual(dropped_2.columns, ["join_key", "value1", "value2"]) + self.assertEqual( + dropped_2.sort("value1").collect(), + [ + Row(join_key=1, value1="a", value2="aa"), + Row(join_key=2, value1="b", value2="bb"), + Row(join_key=3, value1="c", value2=None), + ], + ) + def test_with_columns_renamed(self): df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"]) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0b0bf86cdd039..0e07ef64a9d4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -304,6 +304,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: ResolveOutputRelation :: + new ResolveDataFrameDropColumns(catalogManager) :: new ResolveSetVariable(catalogManager) :: ExtractWindowExpressions :: GlobalAggregates :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala new file mode 100644 index 0000000000000..2642b4a1c5daa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS +import org.apache.spark.sql.connector.catalog.CatalogManager + +/** + * A rule that rewrites DataFrameDropColumns to Project. + * Note that DataFrameDropColumns allows and ignores non-existing columns. + */ +class ResolveDataFrameDropColumns(val catalogManager: CatalogManager) + extends Rule[LogicalPlan] with ColumnResolutionHelper { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( + _.containsPattern(DF_DROP_COLUMNS)) { + case d: DataFrameDropColumns if d.childrenResolved => + // expressions in dropList can be unresolved, e.g. 
+ // df.drop(col("non-existing-column")) + val dropped = d.dropList.map { + case u: UnresolvedAttribute => + resolveExpressionByPlanChildren(u, d.child) + case e => e + } + val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr))) + if (remaining.size == d.child.output.size) { + d.child + } else { + Project(remaining, d.child) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 4353f535aaa4a..05abaf3090644 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -237,6 +237,20 @@ object Project { } } +case class DataFrameDropColumns(dropList: Seq[Expression], child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = Nil + + override def maxRows: Option[Long] = child.maxRows + override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition + + final override val nodePatterns: Seq[TreePattern] = Seq(DF_DROP_COLUMNS) + + override lazy val resolved: Boolean = false + + override protected def withNewChildInternal(newChild: LogicalPlan): DataFrameDropColumns = + copy(child = newChild) +} + /** * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 9b3337d1a9406..1f0df8f3b8ab2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -106,6 +106,7 @@ object TreePattern extends Enumeration { val AS_OF_JOIN: Value = Value val COMMAND: Value = Value val CTE: Value = Value + val DF_DROP_COLUMNS: Value = Value val DISTINCT_LIKE: Value = Value val EVAL_PYTHON_UDF: Value = Value val EVAL_PYTHON_UDTF: Value = Value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d18b12964c6d3..5a372f9a0f917 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3064,19 +3064,10 @@ class Dataset[T] private[sql]( * @since 3.4.0 */ @scala.annotation.varargs - def drop(col: Column, cols: Column*): DataFrame = { - val allColumns = col +: cols - val expressions = (for (col <- allColumns) yield col match { - case Column(u: UnresolvedAttribute) => - queryExecution.analyzed.resolveQuoted( - u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) - case Column(expr: Expression) => expr - }) - val attrs = this.logicalPlan.output - val colsAfterDrop = attrs.filter { attr => - expressions.forall(expression => !attr.semanticEquals(expression)) - }.map(attr => Column(attr)) - select(colsAfterDrop : _*) + def drop(col: Column, cols: Column*): DataFrame = withOrigin { + withPlan { + DataFrameDropColumns((col +: cols).map(_.expr), logicalPlan) + } } /** From e440f3245243a31e7bdfe945e1ce7194609b78fb Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 12 Nov 2023 14:34:32 -0800 Subject: [PATCH 099/121] [SPARK-45896][SQL] Construct 
`ValidateExternalType` with the correct expected type ### What changes were proposed in this pull request? When creating a serializer for a `Map` or `Seq` with an element of type `Option`, pass an expected type of `Option` to `ValidateExternalType` rather than the `Option`'s type argument. ### Why are the changes needed? In 3.4.1, 3.5.0, and master, the following code gets an error: ``` scala> val df = Seq(Seq(Some(Seq(0)))).toDF("a") val df = Seq(Seq(Some(Seq(0)))).toDF("a") org.apache.spark.SparkRuntimeException: [EXPRESSION_ENCODING_FAILED] Failed to encode a value of the expressions: mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -2), assertnotnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -2), IntegerType, IntegerType)), unwrapoption(ObjectType(interface scala.collection.immutable.Seq), validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), ArrayType(IntegerType,false), ObjectType(class scala.Option))), None), input[0, scala.collection.immutable.Seq, true], None) AS value#0 to a row. SQLSTATE: 42846 ... Caused by: java.lang.RuntimeException: scala.Some is not a valid external type for schema of array at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.MapObjects_0$(Unknown Source) ... ``` However, this code works in 3.3.3. Similarly, this code gets an error: ``` scala> val df = Seq(Seq(Some(java.sql.Timestamp.valueOf("2023-01-01 00:00:00")))).toDF("a") val df = Seq(Seq(Some(java.sql.Timestamp.valueOf("2023-01-01 00:00:00")))).toDF("a") org.apache.spark.SparkRuntimeException: [EXPRESSION_ENCODING_FAILED] Failed to encode a value of the expressions: mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), staticinvoke(class org.apache.spark.sql.catalyst.util.DateTimeUtils$, TimestampType, fromJavaTimestamp, unwrapoption(ObjectType(class java.sql.Timestamp), validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), TimestampType, ObjectType(class scala.Option))), true, false, true), input[0, scala.collection.immutable.Seq, true], None) AS value#0 to a row. SQLSTATE: 42846 ... Caused by: java.lang.RuntimeException: scala.Some is not a valid external type for schema of timestamp ... ``` As with the first example, this code works in 3.3.3. `SerializerBuildHelper#validateAndSerializeElement` will construct `ValidateExternalType` with an expected type of the `Option`'s type parameter. Therefore, for element types `Option[Seq/Date/Timestamp/BigDecimal]`, `ValidateExternalType` will try to validate that the element is of the contained type (e.g., `BigDecimal`) rather than of type `Option`. Since the element type is of type `Option`, the validation fails. 
Validation currently works by accident for element types `Option[Map/ Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/SerializerBuildHelper.scala | 7 ++++++- .../catalyst/encoders/ExpressionEncoderSuite.scala | 12 ++++++++++++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 9 +++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 27090ff6fa5d6..cd087514f4be3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -450,10 +450,15 @@ object SerializerBuildHelper { private def validateAndSerializeElement( enc: AgnosticEncoder[_], nullable: Boolean): Expression => Expression = { input => + val expected = enc match { + case OptionEncoder(_) => lenientExternalDataTypeFor(enc) + case _ => enc.dataType + } + expressionWithNullSafety( createSerializer( enc, - ValidateExternalType(input, enc.dataType, lenientExternalDataTypeFor(enc))), + ValidateExternalType(input, expected, lenientExternalDataTypeFor(enc))), nullable, WalkedTypePath()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index dc5e22f0571ea..35d8327b93086 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -479,6 +479,18 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(Option.empty[Int], "empty option of int") encodeDecodeTest(Option("abc"), "option of string") encodeDecodeTest(Option.empty[String], "empty option of string") + encodeDecodeTest(Seq(Some(Seq(0))), "SPARK-45896: seq of option of seq") + encodeDecodeTest(Map(0 -> Some(Seq(0))), "SPARK-45896: map of option of seq") + encodeDecodeTest(Seq(Some(Timestamp.valueOf("2023-01-01 00:00:00"))), + "SPARK-45896: seq of option of timestamp") + encodeDecodeTest(Map(0 -> Some(Timestamp.valueOf("2023-01-01 00:00:00"))), + "SPARK-45896: map of option of timestamp") + encodeDecodeTest(Seq(Some(Date.valueOf("2023-01-01"))), + "SPARK-45896: seq of option of date") + encodeDecodeTest(Map(0 -> Some(Date.valueOf("2023-01-01"))), + "SPARK-45896: map of option of date") + encodeDecodeTest(Seq(Some(BigDecimal(200))), "SPARK-45896: seq of option of bigdecimal") + encodeDecodeTest(Map(0 -> Some(BigDecimal(200))), "SPARK-45896: map of option of bigdecimal") encodeDecodeTest(ScroogeLikeExample(1), "SPARK-40385 class with only a companion object constructor") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index fe64e5abc5350..152fd0d7d8ed7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -274,6 +274,13 @@ class DatasetSuite extends QueryTest (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } + test("SPARK-45896: seq of option of seq") { + val ds = Seq(DataSeqOptSeq(Seq(Some(Seq(0))))).toDS() + checkDataset( + ds, + DataSeqOptSeq(Seq(Some(List(0))))) + } + test("select") { val ds = Seq(("a", 1), 
("b", 2), ("c", 3)).toDS() checkDataset( @@ -2760,6 +2767,8 @@ case class ClassNullableData(a: String, b: Integer) case class NestedStruct(f: ClassData) case class DeepNestedStruct(f: NestedStruct) +case class DataSeqOptSeq(a: Seq[Option[Seq[Int]]]) + /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. From 9ab75f65cee6548e305d172d6fe79e57d3ce94f6 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sun, 12 Nov 2023 14:39:59 -0800 Subject: [PATCH 100/121] [MINOR][CONNECT][TESTS] Use `Some(())` instead of `Some()` for `Option[Unit]` return type ### What changes were proposed in this pull request? This pr fix a compile warning like ``` [warn] /Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala:95:9: adaptation of an empty argument list by inserting () is deprecated: this is unlikely to be what you want [warn] signature: Some.apply[A](value: A): Some[A] [warn] given arguments: [warn] after adaptation: Some((): Unit) [warn] Applicable -Wconf / nowarn filters for this fatal warning: msg=, cat=deprecation, site=org.apache.spark.sql.connect.plugin.ExampleCommandPlugin.process, origin=scala.Some.apply, version=2.11.0 [warn] Some() [warn] ^ ``` ### Why are the changes needed? Clean up a deprecated usage since Scala 2.11.0 ### Does this PR introduce _any_ user-facing change? No, just for test ### How was this patch tested? Pass GitHub Action ### Was this patch authored or co-authored using generative AI tooling? No Closes #43733 from LuciferYang/minor-option-unit. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- .../sql/connect/plugin/SparkConnectPluginRegistrySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala index e1de6b04d211e..ded5ca6415b94 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala @@ -92,7 +92,7 @@ class ExampleCommandPlugin extends CommandPlugin { val cmd = command.unpack(classOf[proto.ExamplePluginCommand]) assert(planner.session != null) SparkContext.getActive.get.setLocalProperty("testingProperty", cmd.getCustomField) - Some() + Some(()) } } From 36d57f8a29736c91b241d2dfa7d7062a6ea8027d Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sun, 12 Nov 2023 14:43:08 -0800 Subject: [PATCH 101/121] [SPARK-45874][SQL] Remove Java version check from `IsolatedClientLoader` ### What changes were proposed in this pull request? This pr remove unnecessary Java version check from `IsolatedClientLoader`. ### Why are the changes needed? Apache Spark 4.0.0 has a minimum requirement of Java 17, so the version check for Java 9 is not necessary. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43747 from LuciferYang/SPARK-45874. 
Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../hive/client/IsolatedClientLoader.scala | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 4027cd94d4150..74b33e6437fb6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -25,7 +25,6 @@ import java.util import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} -import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.ShimLoader @@ -233,22 +232,16 @@ private[hive] class IsolatedClientLoader( private[hive] val classLoader: MutableURLClassLoader = { val isolatedClassLoader = if (isolationOn) { - val rootClassLoader: ClassLoader = - if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) { - // In Java 9, the boot classloader can see few JDK classes. The intended parent - // classloader for delegation is now the platform classloader. - // See http://java9.wtf/class-loading/ - val platformCL = - classOf[ClassLoader].getMethod("getPlatformClassLoader"). - invoke(null).asInstanceOf[ClassLoader] - // Check to make sure that the root classloader does not know about Hive. - assert(Try(platformCL.loadClass("org.apache.hadoop.hive.conf.HiveConf")).isFailure) - platformCL - } else { - // The boot classloader is represented by null (the instance itself isn't accessible) - // and before Java 9 can see all JDK classes - null - } + val rootClassLoader: ClassLoader = { + // In Java 9, the boot classloader can see few JDK classes. The intended parent + // classloader for delegation is now the platform classloader. + // See http://java9.wtf/class-loading/ + val platformCL = classOf[ClassLoader].getMethod("getPlatformClassLoader") + .invoke(null).asInstanceOf[ClassLoader] + // Check to make sure that the root classloader does not know about Hive. + assert(Try(platformCL.loadClass("org.apache.hadoop.hive.conf.HiveConf")).isFailure) + platformCL + } new URLClassLoader(allJars, rootClassLoader) { override def loadClass(name: String, resolve: Boolean): Class[_] = { val loaded = findLoadedClass(name) From 6e6a3a1f631a53e3cf5332f88e52d6c6686ba529 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sun, 12 Nov 2023 14:46:31 -0800 Subject: [PATCH 102/121] [SPARK-45830][CORE] Refactor `StorageUtils#bufferCleaner` ### What changes were proposed in this pull request? This pr refactor `StorageUtils#bufferCleaner` as follows: - Change the return value of `bufferCleaner` from `DirectBuffer => Unit` to `ByteBuffer => Unit` - Directly calling `unsafe.invokeCleaner` instead of reflecting calls ### Why are the changes needed? 1. After Scala 2.13.9, it is recommended to use the `-release` instead of the `-target` for compilation. 
However, due to `sun.nio.ch` module was not exported, this can lead to the issue of class invisibility during Java cross compilation, such as building or testing using Java 21 with `-release:17` After this pr, the following compilation errors will not occur again when build core module using Java 21 with `-release:17`: ``` [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala:71: object security is not a member of package sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:26: object nio is not a member of package sun [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:200: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:206: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:220: not found: type DirectBuffer [ERROR] [Error] /Users/yangjie01/SourceCode/git/spark-mine-13/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala:26: Unused import ``` 2. Direct use of `unsafe.invokeCleaner` provides better performance, compared to reflection calls, it is at least 30% faster ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Manual check building core module using Java 21 with `-release:17`, no longer compilation failure logs above Note: There is still an issue with other classes being invisible, which needs to be fixed in follow up ### Was this patch authored or co-authored using generative AI tooling? No Closes #43675 from LuciferYang/bufferCleaner. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/storage/StorageUtils.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e73a65e09cb47..c409ee37a06a5 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -23,7 +23,6 @@ import scala.collection.Map import scala.collection.mutable import sun.misc.Unsafe -import sun.nio.ch.DirectBuffer import org.apache.spark.SparkConf import org.apache.spark.internal.{config, Logging} @@ -197,13 +196,11 @@ private[spark] class StorageStatus( /** Helper methods for storage-related objects. 
*/ private[spark] object StorageUtils extends Logging { - private val bufferCleaner: DirectBuffer => Unit = { - val cleanerMethod = - Utils.classForName("sun.misc.Unsafe").getMethod("invokeCleaner", classOf[ByteBuffer]) + private val bufferCleaner: ByteBuffer => Unit = { val unsafeField = classOf[Unsafe].getDeclaredField("theUnsafe") unsafeField.setAccessible(true) val unsafe = unsafeField.get(null).asInstanceOf[Unsafe] - buffer: DirectBuffer => cleanerMethod.invoke(unsafe, buffer) + buffer: ByteBuffer => unsafe.invokeCleaner(buffer) } /** @@ -217,7 +214,7 @@ private[spark] object StorageUtils extends Logging { def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { logTrace(s"Disposing of $buffer") - bufferCleaner(buffer.asInstanceOf[DirectBuffer]) + bufferCleaner(buffer) } } From 4397c09bc0d79fe70bcc8253d584684b84d3b768 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sun, 12 Nov 2023 14:50:15 -0800 Subject: [PATCH 103/121] [SPARK-45848][BUILD] Make `spark-version-info.properties` generated by `spark-build-info.ps1` include `docroot` ### What changes were proposed in this pull request? The `spark-version-info.properties` generated by `spark-build-info` include `docroot=https://spark.apache.org/docs/latest`, this pr make `spark-version-info.properties` generated by `spark-build-info.ps1` also include `docroot` part. ### Why are the changes needed? Keep the items generated by `spark-build-info` and `spark-build-info.ps1` consistent ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manual verification with this pr: - `spark-version-info.properties` generated by `spark-build-info` ``` version=4.0.0-SNAPSHOT user=yangjie01 revision=d92f634bdaf8d040b7c7a5ca675db3cb265486b5 branch=SPARK-45848 date=2023-11-09T06:21:05Z url=gitgithub.com:LuciferYang/spark.git docroot=https://spark.apache.org/docs/latest ``` - `spark-version-info.properties` generated by `spark-build-info.ps1` ``` version=4.0.0-SNAPSHOT user=yangjie01 revision=d92f634bdaf8d040b7c7a5ca675db3cb265486b5 branch=SPARK-45848 date=2023-11-09T06:22:25Z url=gitgithub.com:LuciferYang/spark.git, docroot=https://spark.apache.org/docs/latest ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #43726 from LuciferYang/SPARK-45848. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- build/spark-build-info.ps1 | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build/spark-build-info.ps1 b/build/spark-build-info.ps1 index 43db8823340c6..def98fec85f55 100644 --- a/build/spark-build-info.ps1 +++ b/build/spark-build-info.ps1 @@ -41,6 +41,7 @@ user=$($Env:USERNAME) revision=$(git rev-parse HEAD) branch=$(git rev-parse --abbrev-ref HEAD) date=$([DateTime]::UtcNow | Get-Date -UFormat +%Y-%m-%dT%H:%M:%SZ) -url=$(git config --get remote.origin.url)" +url=$(git config --get remote.origin.url), +docroot=https://spark.apache.org/docs/latest" Set-Content -Path $SparkBuildInfoPath -Value $SparkBuildInfoContent From 3013d8b4d310997c8a6f6021f860851cd4f3c32a Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Sun, 12 Nov 2023 14:56:13 -0800 Subject: [PATCH 104/121] [SPARK-45857][SQL] Enforce the error classes in sub-classes of `AnalysisException` ### What changes were proposed in this pull request? In the PR, I propose to enforce creation of `AnalysisException` sub-class exceptions with an error class always. 
In particular, it converts the constructor with a message to private one, so, callers have to create a sub-class of `AnalysisException` with an error class. ### Why are the changes needed? This simplifies migration on error classes. ### Does this PR introduce _any_ user-facing change? No, since user code doesn't throw `AnalysisException` and its sub-classes in regular cases. ### How was this patch tested? By existing test suites, for instance: ``` $ build/sbt "sql/testOnly *QueryParsingErrorsSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43731 from MaxGekk/ban-message-subclasses-AnalysisException. Authored-by: Max Gekk Signed-off-by: Dongjoon Hyun --- .../main/resources/error/error-classes.json | 5 ++ .../client/GrpcExceptionConverter.scala | 32 +++----- .../catalyst/analysis/NonEmptyException.scala | 2 +- .../analysis/alreadyExistException.scala | 45 ++--------- .../analysis/noSuchItemsExceptions.scala | 79 ++++++------------- .../analysis/AlreadyExistException.scala | 14 ---- .../analysis/NoSuchItemException.scala | 14 ---- .../catalog/InvalidUDFClassException.scala | 2 +- .../sql/errors/QueryCompilationErrors.scala | 6 ++ .../org/apache/spark/sql/jdbc/H2Dialect.scala | 6 +- .../sql/hive/HiveSessionStateBuilder.scala | 4 +- 11 files changed, 60 insertions(+), 149 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 3b7a3a6006ef3..e3b9f3161b24d 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -6334,6 +6334,11 @@ "Operation not allowed: only works on table with location provided: " ] }, + "_LEGACY_ERROR_TEMP_2450" : { + "message" : [ + "No handler for UDF/UDAF/UDTF ''" + ] + }, "_LEGACY_ERROR_TEMP_3000" : { "message" : [ "Unexpected Py4J server ." 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 52bd276b0c4b5..73e2db2f4ac7a 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -204,34 +204,24 @@ private[client] object GrpcExceptionConverter { messageParameters = params.messageParameters, context = params.queryContext)), errorConstructor(params => - new NamespaceAlreadyExistsException( - params.message, - params.errorClass, - params.messageParameters)), + new NamespaceAlreadyExistsException(params.errorClass.orNull, params.messageParameters)), errorConstructor(params => new TableAlreadyExistsException( - params.message, - params.cause, - params.errorClass, - params.messageParameters)), + params.errorClass.orNull, + params.messageParameters, + params.cause)), errorConstructor(params => new TempTableAlreadyExistsException( - params.message, - params.cause, - params.errorClass, - params.messageParameters)), + params.errorClass.orNull, + params.messageParameters, + params.cause)), errorConstructor(params => new NoSuchDatabaseException( - params.message, - params.cause, - params.errorClass, - params.messageParameters)), + params.errorClass.orNull, + params.messageParameters, + params.cause)), errorConstructor(params => - new NoSuchTableException( - params.message, - params.cause, - params.errorClass, - params.messageParameters)), + new NoSuchTableException(params.errorClass.orNull, params.messageParameters, params.cause)), errorConstructor[NumberFormatException](params => new SparkNumberFormatException( params.message, diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala index ecd57672b6168..2aea9bac12fed 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.QuotingUtils.quoted * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. 
*/ -case class NonEmptyNamespaceException( +case class NonEmptyNamespaceException private( override val message: String, override val cause: Option[Throwable] = None) extends AnalysisException(message, cause = cause) { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala index 85eba2b246143..8932a0296428f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala @@ -32,7 +32,7 @@ class DatabaseAlreadyExistsException(db: String) extends NamespaceAlreadyExistsException(Array(db)) // any changes to this class should be backward compatible as it may be used by external connectors -class NamespaceAlreadyExistsException private[sql]( +class NamespaceAlreadyExistsException private( message: String, errorClass: Option[String], messageParameters: Map[String, String]) @@ -52,17 +52,10 @@ class NamespaceAlreadyExistsException private[sql]( this(errorClass = "SCHEMA_ALREADY_EXISTS", Map("schemaName" -> quoteNameParts(namespace.toImmutableArraySeq))) } - - def this(message: String) = { - this( - message, - errorClass = Some("SCHEMA_ALREADY_EXISTS"), - messageParameters = Map.empty[String, String]) - } } // any changes to this class should be backward compatible as it may be used by external connectors -class TableAlreadyExistsException private[sql]( +class TableAlreadyExistsException private( message: String, cause: Option[Throwable], errorClass: Option[String], @@ -106,21 +99,13 @@ class TableAlreadyExistsException private[sql]( messageParameters = Map("relationName" -> quoted(tableIdent)), cause = None) } - - def this(message: String, cause: Option[Throwable] = None) = { - this( - message, - cause, - errorClass = Some("TABLE_OR_VIEW_ALREADY_EXISTS"), - messageParameters = Map.empty[String, String]) - } } -class TempTableAlreadyExistsException private[sql]( - message: String, - cause: Option[Throwable], - errorClass: Option[String], - messageParameters: Map[String, String]) +class TempTableAlreadyExistsException private( + message: String, + cause: Option[Throwable], + errorClass: Option[String], + messageParameters: Map[String, String]) extends AnalysisException( message, cause = cause, @@ -144,14 +129,6 @@ class TempTableAlreadyExistsException private[sql]( messageParameters = Map("relationName" -> quoteNameParts(AttributeNameParser.parseAttributeName(table)))) } - - def this(message: String, cause: Option[Throwable]) = { - this( - message, - cause, - errorClass = Some("TEMP_TABLE_OR_VIEW_ALREADY_EXISTS"), - messageParameters = Map.empty[String, String]) - } } // any changes to this class should be backward compatible as it may be used by external connectors @@ -203,12 +180,4 @@ class IndexAlreadyExistsException private( def this(indexName: String, tableName: String, cause: Option[Throwable]) = { this("INDEX_ALREADY_EXISTS", Map("indexName" -> indexName, "tableName" -> tableName), cause) } - - def this(message: String, cause: Option[Throwable] = None) = { - this( - message, - cause, - errorClass = Some("INDEX_ALREADY_EXISTS"), - messageParameters = Map.empty[String, String]) - } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala index b7c8473c08c04..ac22d26ccfd18 100644 --- 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala @@ -27,21 +27,21 @@ import org.apache.spark.util.ArrayImplicits._ * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. */ -class NoSuchDatabaseException private[sql]( - message: String, - cause: Option[Throwable], - errorClass: Option[String], - messageParameters: Map[String, String]) +class NoSuchDatabaseException private( + message: String, + cause: Option[Throwable], + errorClass: Option[String], + messageParameters: Map[String, String]) extends AnalysisException( message, cause = cause, errorClass = errorClass, messageParameters = messageParameters) { - def this(errorClass: String, messageParameters: Map[String, String]) = { + def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = { this( SparkThrowableHelper.getMessage(errorClass, messageParameters), - cause = None, + cause = cause, Some(errorClass), messageParameters) } @@ -49,15 +49,8 @@ class NoSuchDatabaseException private[sql]( def this(db: String) = { this( errorClass = "SCHEMA_NOT_FOUND", - messageParameters = Map("schemaName" -> quoteIdentifier(db))) - } - - def this(message: String, cause: Option[Throwable]) = { - this( - message = message, - cause = cause, - errorClass = Some("SCHEMA_NOT_FOUND"), - messageParameters = Map.empty[String, String]) + messageParameters = Map("schemaName" -> quoteIdentifier(db)), + cause = None) } } @@ -90,18 +83,10 @@ class NoSuchNamespaceException private( this(errorClass = "SCHEMA_NOT_FOUND", Map("schemaName" -> quoteNameParts(namespace.toImmutableArraySeq))) } - - def this(message: String, cause: Option[Throwable] = None) = { - this( - message, - cause, - errorClass = Some("SCHEMA_NOT_FOUND"), - messageParameters = Map.empty[String, String]) - } } // any changes to this class should be backward compatible as it may be used by external connectors -class NoSuchTableException private[sql]( +class NoSuchTableException private( message: String, cause: Option[Throwable], errorClass: Option[String], @@ -112,36 +97,34 @@ class NoSuchTableException private[sql]( errorClass = errorClass, messageParameters = messageParameters) { - def this(errorClass: String, messageParameters: Map[String, String]) = { + def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = { this( SparkThrowableHelper.getMessage(errorClass, messageParameters), - cause = None, + cause = cause, Some(errorClass), messageParameters) } def this(db: String, table: String) = { - this(errorClass = "TABLE_OR_VIEW_NOT_FOUND", + this( + errorClass = "TABLE_OR_VIEW_NOT_FOUND", messageParameters = Map("relationName" -> - (quoteIdentifier(db) + "." + quoteIdentifier(table)))) + (quoteIdentifier(db) + "." 
+ quoteIdentifier(table))), + cause = None) } def this(name : Seq[String]) = { - this(errorClass = "TABLE_OR_VIEW_NOT_FOUND", - messageParameters = Map("relationName" -> quoteNameParts(name))) + this( + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + messageParameters = Map("relationName" -> quoteNameParts(name)), + cause = None) } def this(tableIdent: Identifier) = { - this(errorClass = "TABLE_OR_VIEW_NOT_FOUND", - messageParameters = Map("relationName" -> quoted(tableIdent))) - } - - def this(message: String, cause: Option[Throwable] = None) = { this( - message, - cause, - errorClass = Some("TABLE_OR_VIEW_NOT_FOUND"), - messageParameters = Map.empty[String, String]) + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + messageParameters = Map("relationName" -> quoted(tableIdent)), + cause = None) } } @@ -186,14 +169,6 @@ class NoSuchFunctionException private( def this(identifier: Identifier) = { this(errorClass = "ROUTINE_NOT_FOUND", Map("routineName" -> quoted(identifier))) } - - def this(message: String, cause: Option[Throwable] = None) = { - this( - message, - cause, - errorClass = Some("ROUTINE_NOT_FOUND"), - messageParameters = Map.empty[String, String]) - } } class NoSuchTempFunctionException(func: String) @@ -225,12 +200,4 @@ class NoSuchIndexException private( def this(indexName: String, tableName: String, cause: Option[Throwable]) = { this("INDEX_NOT_FOUND", Map("indexName" -> indexName, "tableName" -> tableName), cause) } - - def this(message: String, cause: Option[Throwable] = None) = { - this( - message, - cause, - errorClass = Some("INDEX_NOT_FOUND"), - messageParameters = Map.empty[String, String]) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index bbac5ab7db3ef..4662f1c6ede6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -55,13 +55,6 @@ class PartitionAlreadyExistsException private( .map( kv => quoteIdentifier(s"${kv._2}") + s" = ${kv._1}").mkString(", ") + ")"), "tableName" -> quoteNameParts(UnresolvedAttribute.parseAttributeName(tableName)))) } - - def this(message: String) = { - this( - message, - errorClass = Some("PARTITIONS_ALREADY_EXIST"), - messageParameters = Map.empty[String, String]) - } } // any changes to this class should be backward compatible as it may be used by external connectors @@ -105,11 +98,4 @@ class PartitionsAlreadyExistException private( def this(tableName: String, partitionIdent: InternalRow, partitionSchema: StructType) = this(tableName, Seq(partitionIdent), partitionSchema) - - def this(message: String) = { - this( - message, - errorClass = Some("PARTITIONS_ALREADY_EXIST"), - messageParameters = Map.empty[String, String]) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 217c293900ec4..5db713066ff96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -56,13 +56,6 @@ class NoSuchPartitionException private( .map( kv => quoteIdentifier(s"${kv._2}") + s" = ${kv._1}").mkString(", ") + ")"), "tableName" -> 
quoteNameParts(UnresolvedAttribute.parseAttributeName(tableName)))) } - - def this(message: String) = { - this( - message, - errorClass = Some("PARTITIONS_NOT_FOUND"), - messageParameters = Map.empty[String, String]) - } } // any changes to this class should be backward compatible as it may be used by external connectors @@ -98,11 +91,4 @@ class NoSuchPartitionsException private( .mkString(", ")).mkString("), PARTITION (") + ")"), "tableName" -> quoteNameParts(UnresolvedAttribute.parseAttributeName(tableName)))) } - - def this(message: String) = { - this( - message, - errorClass = Some("PARTITIONS_NOT_FOUND"), - messageParameters = Map.empty[String, String]) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InvalidUDFClassException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InvalidUDFClassException.scala index 658ddb21c6d9d..bfd8ba7d5a59c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InvalidUDFClassException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InvalidUDFClassException.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.AnalysisException * Thrown when a query failed for invalid function class, usually because a SQL * function's class does not follow the rules of the UDF/UDAF/UDTF class definition. */ -class InvalidUDFClassException private[sql]( +class InvalidUDFClassException private( message: String, errorClass: Option[String] = None, messageParameters: Map[String, String] = Map.empty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 7399f6c621cc7..c3249a4c02d8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3825,4 +3825,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("rowTag" -> toSQLId(optionName)) ) } + + def invalidUDFClassError(invalidClass: String): Throwable = { + new InvalidUDFClassException( + errorClass = "_LEGACY_ERROR_TEMP_2450", + messageParameters = Map("invalidClass" -> invalidClass)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 43888d0ffedda..9bed6a6f873e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -197,8 +197,10 @@ private[sql] object H2Dialect extends JdbcDialect { // TABLE_OR_VIEW_NOT_FOUND_1 case 42102 => val quotedName = quoteNameParts(UnresolvedAttribute.parseAttributeName(message)) - throw new NoSuchTableException(errorClass = "TABLE_OR_VIEW_NOT_FOUND", - messageParameters = Map("relationName" -> quotedName)) + throw new NoSuchTableException( + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + messageParameters = Map("relationName" -> quotedName), + cause = Some(e)) // SCHEMA_NOT_FOUND_1 case 90079 => val regex = """"((?:[^"\\]|\\[\\"ntbrf])+)"""".r diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 0b5e98d0a3e40..e991665e2887c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, Inval import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin @@ -207,8 +208,7 @@ object HiveUDFExpressionBuilder extends SparkUDFExpressionBuilder { throw analysisException } udfExpr.getOrElse { - throw new InvalidUDFClassException( - s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'") + throw QueryCompilationErrors.invalidUDFClassError(clazz.getCanonicalName) } } } From ba0e098307573bbbba24e8afb992baee9f9bdbed Mon Sep 17 00:00:00 2001 From: Yihong He Date: Mon, 13 Nov 2023 08:46:14 +0900 Subject: [PATCH 105/121] [SPARK-45899][CONNECT] Set errorClass in errorInfoToThrowable ### What changes were proposed in this pull request? - Set errorClass in errorInfoToThrowable ### Why are the changes needed? - errorClass should be set even when error enrichment is not working ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite"` ### Was this patch authored or co-authored using generative AI tooling? No Closes #43772 from heyihong/SPARK-45899. Authored-by: Yihong He Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/ClientE2ETestSuite.scala | 1 + .../client/GrpcExceptionConverter.scala | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 10c928f130416..ee238c5492f91 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -85,6 +85,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM |""".stripMargin) .collect() } + assert(ex.getErrorClass != null) if (enrichErrorEnabled) { assert(ex.getCause.isInstanceOf[DateTimeException]) } else { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 73e2db2f4ac7a..f4254c75a2a8c 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -168,7 +168,7 @@ private[client] object GrpcExceptionConverter { private[client] case class ErrorParams( message: String, cause: Option[Throwable], - // errorClass will only be set if the error is both enriched and SparkThrowable. + // errorClass will only be set if the error is SparkThrowable. errorClass: Option[String], // messageParameters will only be set if the error is both enriched and SparkThrowable. 
messageParameters: Map[String, String], @@ -357,14 +357,20 @@ private[client] object GrpcExceptionConverter { implicit val formats = DefaultFormats val classes = JsonMethods.parse(info.getMetadataOrDefault("classes", "[]")).extract[Array[String]] - - errorsToThrowable( - 0, - Seq( - FetchErrorDetailsResponse.Error + val errorClass = info.getMetadataOrDefault("errorClass", null) + val builder = FetchErrorDetailsResponse.Error + .newBuilder() + .setMessage(message) + .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava) + + if (errorClass != null) { + builder.setSparkThrowable( + FetchErrorDetailsResponse.SparkThrowable .newBuilder() - .setMessage(message) - .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava) - .build())) + .setErrorClass(errorClass) + .build()) + } + + errorsToThrowable(0, Seq(builder.build())) } } From ef240cdf6eaaa95f85aadc0f1272e991cc50bd35 Mon Sep 17 00:00:00 2001 From: Alice Sayutina Date: Mon, 13 Nov 2023 09:37:48 +0900 Subject: [PATCH 106/121] [SPARK-45733][CONNECT][PYTHON] Support multiple retry policies ### What changes were proposed in this pull request? Support multiple retry policies defined at the same time. Each policy determines which error types it can retry and how exactly those should be spread out. ### Why are the changes needed? Different error types should be treated differently For instance, networking connectivity errors and remote resources being initialized should be treated separately. ### Does this PR introduce _any_ user-facing change? No (as long as user doesn't poke within client internals). ### How was this patch tested? Unit tests, some hand testing. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43591 from cdkrot/SPARK-45733. Authored-by: Alice Sayutina Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/core.py | 218 ++----------- python/pyspark/sql/connect/client/reattach.py | 35 +-- python/pyspark/sql/connect/client/retries.py | 293 ++++++++++++++++++ .../sql/tests/connect/client/test_client.py | 53 ++-- .../sql/tests/connect/test_connect_basic.py | 203 ++++++------ 5 files changed, 468 insertions(+), 334 deletions(-) create mode 100644 python/pyspark/sql/connect/client/retries.py diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 7eafcc501f5f9..b98de0f9ceead 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -19,7 +19,6 @@ "SparkConnectClient", ] - from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) @@ -27,12 +26,9 @@ import threading import os import platform -import random -import time import urllib.parse import uuid import sys -from types import TracebackType from typing import ( Iterable, Iterator, @@ -45,9 +41,6 @@ Set, NoReturn, cast, - Callable, - Generator, - Type, TYPE_CHECKING, Sequence, ) @@ -66,10 +59,8 @@ from pyspark.resource.information import ResourceInformation from pyspark.sql.connect.client.artifact import ArtifactManager from pyspark.sql.connect.client.logging import logger -from pyspark.sql.connect.client.reattach import ( - ExecutePlanResponseReattachableIterator, - RetryException, -) +from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator +from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level import pyspark.sql.connect.proto as pb2 import 
pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib @@ -555,38 +546,6 @@ class SparkConnectClient(object): Conceptually the remote spark session that communicates with the server """ - @classmethod - def retry_exception(cls, e: Exception) -> bool: - """ - Helper function that is used to identify if an exception thrown by the server - can be retried or not. - - Parameters - ---------- - e : Exception - The GRPC error as received from the server. Typed as Exception, because other exception - thrown during client processing can be passed here as well. - - Returns - ------- - True if the exception can be retried, False otherwise. - - """ - if not isinstance(e, grpc.RpcError): - return False - - if e.code() in [grpc.StatusCode.INTERNAL]: - msg = str(e) - - # This error happens if another RPC preempts this RPC. - if "INVALID_CURSOR.DISCONNECTED" in msg: - return True - - if e.code() == grpc.StatusCode.UNAVAILABLE: - return True - - return False - def __init__( self, connection: Union[str, ChannelBuilder], @@ -634,7 +593,9 @@ def __init__( else ChannelBuilder(connection, channel_options) ) self._user_id = None - self._retry_policy = { + self._retry_policies: List[RetryPolicy] = [] + + default_policy_args = { # Please synchronize changes here with Scala side # GrpcRetryHandler.scala # @@ -648,7 +609,10 @@ def __init__( "min_jitter_threshold": 2000, } if retry_policy: - self._retry_policy.update(retry_policy) + default_policy_args.update(retry_policy) + + default_policy = DefaultPolicy(**default_policy_args) + self.set_retry_policies([default_policy]) if self._builder.session_id is None: # Generate a unique session ID for this client. This UUID must be unique to allow @@ -680,9 +644,7 @@ def __init__( self._server_session_id: Optional[str] = None def _retrying(self) -> "Retrying": - return Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy # type: ignore - ) + return Retrying(self._retry_policies) def disable_reattachable_execute(self) -> "SparkConnectClient": self._use_reattachable_execute = False @@ -692,6 +654,20 @@ def enable_reattachable_execute(self) -> "SparkConnectClient": self._use_reattachable_execute = True return self + def set_retry_policies(self, policies: Iterable[RetryPolicy]) -> None: + """ + Sets list of policies to be used for retries. + I.e. set_retry_policies([DefaultPolicy(), CustomPolicy()]). + + """ + self._retry_policies = list(policies) + + def get_retry_policies(self) -> List[RetryPolicy]: + """ + Return list of currently used policies + """ + return list(self._retry_policies) + def register_udf( self, function: Any, @@ -1152,7 +1128,7 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None: if self._use_reattachable_execute: # Don't use retryHandler - own retry handling is inside. generator = ExecutePlanResponseReattachableIterator( - req, self._stub, self._retry_policy, self._builder.metadata() + req, self._stub, self._retrying, self._builder.metadata() ) for b in generator: handle_response(b) @@ -1262,7 +1238,7 @@ def handle_response( if self._use_reattachable_execute: # Don't use retryHandler - own retry handling is inside. generator = ExecutePlanResponseReattachableIterator( - req, self._stub, self._retry_policy, self._builder.metadata() + req, self._stub, self._retrying, self._builder.metadata() ) for b in generator: yield from handle_response(b) @@ -1641,145 +1617,3 @@ def _verify_response_integrity( else: # Update the server side session ID. 
self._server_session_id = response.server_side_session_id - - -class RetryState: - """ - Simple state helper that captures the state between retries of the exceptions. It - keeps track of the last exception thrown and how many in total. When the task - finishes successfully done() returns True. - """ - - def __init__(self) -> None: - self._exception: Optional[BaseException] = None - self._done = False - self._count = 0 - - def set_exception(self, exc: BaseException) -> None: - self._exception = exc - self._count += 1 - - def throw(self) -> None: - raise self.exception() - - def exception(self) -> BaseException: - if self._exception is None: - raise RuntimeError("No exception is set") - return self._exception - - def set_done(self) -> None: - self._done = True - - def count(self) -> int: - return self._count - - def done(self) -> bool: - return self._done - - -class AttemptManager: - """ - Simple ContextManager that is used to capture the exception thrown inside the context. - """ - - def __init__(self, check: Callable[..., bool], retry_state: RetryState) -> None: - self._retry_state = retry_state - self._can_retry = check - - def __enter__(self) -> None: - pass - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - if isinstance(exc_val, BaseException): - # Swallow the exception. - if self._can_retry(exc_val) or isinstance(exc_val, RetryException): - self._retry_state.set_exception(exc_val) - return True - # Bubble up the exception. - return False - else: - self._retry_state.set_done() - return None - - def is_first_try(self) -> bool: - return self._retry_state._count == 0 - - -class Retrying: - """ - This helper class is used as a generator together with a context manager to - allow retrying exceptions in particular code blocks. The Retrying can be configured - with a lambda function that is can be filtered what kind of exceptions should be - retried. - - In addition, there are several parameters that are used to configure the exponential - backoff behavior. - - An example to use this class looks like this: - - .. code-block:: python - - for attempt in Retrying(can_retry=lambda x: isinstance(x, TransientError)): - with attempt: - # do the work. - - """ - - def __init__( - self, - max_retries: int, - initial_backoff: int, - max_backoff: int, - backoff_multiplier: float, - jitter: int, - min_jitter_threshold: int, - can_retry: Callable[..., bool] = lambda x: True, - sleep: Callable[[float], None] = time.sleep, - ) -> None: - self._can_retry = can_retry - self._max_retries = max_retries - self._initial_backoff = initial_backoff - self._max_backoff = max_backoff - self._backoff_multiplier = backoff_multiplier - self._jitter = jitter - self._min_jitter_threshold = min_jitter_threshold - self._sleep = sleep - - def __iter__(self) -> Generator[AttemptManager, None, None]: - """ - Generator function to wrap the exception producing code block. - - Returns - ------- - A generator that yields the current attempt. 
- """ - retry_state = RetryState() - next_backoff: float = self._initial_backoff - - if self._max_retries < 0: - raise ValueError("Can't have negative number of retries") - - while not retry_state.done() and retry_state.count() <= self._max_retries: - # Do backoff - if retry_state.count() > 0: - # Randomize backoff for this iteration - backoff = next_backoff - next_backoff = min(self._max_backoff, next_backoff * self._backoff_multiplier) - - if backoff >= self._min_jitter_threshold: - backoff += random.uniform(0, self._jitter) - - logger.debug( - f"Will retry call after {backoff} ms sleep (error: {retry_state.exception()})" - ) - self._sleep(backoff / 1000.0) - yield AttemptManager(self._can_retry, retry_state) - - if not retry_state.done(): - # Exceeded number of retries, throw last exception we had - retry_state.throw() diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 6addb5bd2c652..9fa0f25413375 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from pyspark.sql.connect.client.retries import Retrying, RetryException from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) @@ -22,7 +23,7 @@ import warnings import uuid from collections.abc import Generator -from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar +from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar from multiprocessing.pool import ThreadPool import os @@ -83,12 +84,12 @@ def __init__( self, request: pb2.ExecutePlanRequest, stub: grpc_lib.SparkConnectServiceStub, - retry_policy: Dict[str, Any], + retrying: Callable[[], Retrying], metadata: Iterable[Tuple[str, str]], ): ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary() self._request = request - self._retry_policy = retry_policy + self._retrying = retrying if request.operation_id: self._operation_id = request.operation_id else: @@ -143,17 +144,12 @@ def send(self, value: Any) -> pb2.ExecutePlanResponse: return ret def _has_next(self) -> bool: - from pyspark.sql.connect.client.core import SparkConnectClient - from pyspark.sql.connect.client.core import Retrying - if self._result_complete: # After response complete response return False else: try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: if self._current is None: try: @@ -199,16 +195,11 @@ def _release_until(self, until_response_id: str) -> None: if self._result_complete: return - from pyspark.sql.connect.client.core import SparkConnectClient - from pyspark.sql.connect.client.core import Retrying - request = self._create_release_execute_request(until_response_id) def target() -> None: try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: self._stub.ReleaseExecute(request, metadata=self._metadata) except Exception as e: @@ -228,16 +219,11 @@ def _release_all(self) -> None: if self._result_complete: return - from pyspark.sql.connect.client.core import SparkConnectClient - from pyspark.sql.connect.client.core import Retrying - request = self._create_release_execute_request(None) def target() -> None: try: - for attempt in Retrying( - 
can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: self._stub.ReleaseExecute(request, metadata=self._metadata) except Exception as e: @@ -331,10 +317,3 @@ def close(self) -> None: def __del__(self) -> None: return self.close() - - -class RetryException(Exception): - """ - An exception that can be thrown upstream when inside retry and which will be retryable - regardless of policy. - """ diff --git a/python/pyspark/sql/connect/client/retries.py b/python/pyspark/sql/connect/client/retries.py new file mode 100644 index 0000000000000..6aa959e09b5b0 --- /dev/null +++ b/python/pyspark/sql/connect/client/retries.py @@ -0,0 +1,293 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import grpc +import random +import time +import typing +from typing import Optional, Callable, Generator, List, Type +from types import TracebackType +from pyspark.sql.connect.client.logging import logger + +""" +This module contains retry system. The system is designed to be +significantly customizable. + +A key aspect of retries is RetryPolicy class, describing a single policy. +There can be more than one policy defined at the same time. Each policy +determines which error types it can retry and how exactly. + +For instance, networking errors should likely be retried differently that +remote resource being unavailable. + +Given a sequence of policies, retry logic applies all of them in sequential +order, keeping track of different policies budgets. +""" + + +class RetryPolicy: + """ + Describes key aspects of RetryPolicy. + + It's advised that different policies are implemented as different subclasses. + """ + + def __init__( + self, + max_retries: Optional[int] = None, + initial_backoff: int = 1000, + max_backoff: Optional[int] = None, + backoff_multiplier: float = 1.0, + jitter: int = 0, + min_jitter_threshold: int = 0, + ): + self.max_retries = max_retries + self.initial_backoff = initial_backoff + self.max_backoff = max_backoff + self.backoff_multiplier = backoff_multiplier + self.jitter = jitter + self.min_jitter_threshold = min_jitter_threshold + self._name = self.__class__.__name__ + + @property + def name(self) -> str: + return self._name + + def can_retry(self, exception: BaseException) -> bool: + return False + + def to_state(self) -> "RetryPolicyState": + return RetryPolicyState(self) + + +class RetryPolicyState: + """ + This class represents stateful part of the specific policy. 
+ """ + + def __init__(self, policy: RetryPolicy): + self._policy = policy + + # Will allow attempts [0, self._policy.max_retries) + self._attempt = 0 + self._next_wait: float = self._policy.initial_backoff + + @property + def policy(self) -> RetryPolicy: + return self._policy + + @property + def name(self) -> str: + return self.policy.name + + def can_retry(self, exception: BaseException) -> bool: + return self.policy.can_retry(exception) + + def next_attempt(self) -> Optional[int]: + """ + Returns + ------- + Randomized time (in milliseconds) to wait until this attempt + or None if this policy doesn't allow more retries. + """ + + if self.policy.max_retries is not None and self._attempt >= self.policy.max_retries: + # No more retries under this policy + return None + + self._attempt += 1 + wait_time = self._next_wait + + # Calculate future backoff + if self.policy.max_backoff is not None: + self._next_wait = min( + float(self.policy.max_backoff), wait_time * self.policy.backoff_multiplier + ) + + # Jitter current backoff, after the future backoff was computed + if wait_time >= self.policy.min_jitter_threshold: + wait_time += random.uniform(0, self.policy.jitter) + + # Round to whole number of milliseconds + return int(wait_time) + + +class AttemptManager: + """ + Simple ContextManager that is used to capture the exception thrown inside the context. + """ + + def __init__(self, retrying: "Retrying") -> None: + self._retrying = retrying + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exception: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + if isinstance(exception, BaseException): + # Swallow the exception. + if self._retrying.accept_exception(exception): + return True + # Bubble up the exception. + return False + else: + self._retrying.accept_succeeded() + return None + + +class Retrying: + """ + This class is a point of entry into the retry logic. + The class accepts a list of retry policies and applies them in given order. + The first policy accepting an exception will be used. + + The usage of the class should be as follows: + for attempt in Retrying(...): + with attempt: + Do something that can throw exception + + In case error is considered retriable, it would be retried based on policies, and + RetriesExceeded will be raised if the retries limit would exceed. + + Exceptions not considered retriable will be passed through transparently. 
+ """ + + def __init__( + self, + policies: typing.Union[RetryPolicy, typing.Iterable[RetryPolicy]], + sleep: Callable[[float], None] = time.sleep, + ) -> None: + if isinstance(policies, RetryPolicy): + policies = [policies] + self._policies: List[RetryPolicyState] = [policy.to_state() for policy in policies] + self._sleep = sleep + + self._exception: Optional[BaseException] = None + self._done = False + + def can_retry(self, exception: BaseException) -> bool: + return any(policy.can_retry(exception) for policy in self._policies) + + def accept_exception(self, exception: BaseException) -> bool: + if self.can_retry(exception): + self._exception = exception + return True + return False + + def accept_succeeded(self) -> None: + self._done = True + + def _last_exception(self) -> BaseException: + if self._exception is None: + raise RuntimeError("No active exception") + return self._exception + + def _wait(self) -> None: + exception = self._last_exception() + + # Attempt to find a policy to wait with + + for policy in self._policies: + if not policy.can_retry(exception): + continue + + wait_time = policy.next_attempt() + if wait_time is not None: + logger.debug( + f"Got error: {repr(exception)}. " + + f"Will retry after {wait_time} ms (policy: {policy.name})" + ) + + self._sleep(wait_time / 1000) + return + + # Exceeded retries + logger.debug(f"Given up on retrying. error: {repr(exception)}") + raise RetriesExceeded from exception + + def __iter__(self) -> Generator[AttemptManager, None, None]: + """ + Generator function to wrap the exception producing code block. + + Returns + ------- + A generator that yields the current attempt. + """ + + # First attempt is free, no need to do waiting. + yield AttemptManager(self) + + while not self._done: + self._wait() + yield AttemptManager(self) + + +class RetryException(Exception): + """ + An exception that can be thrown upstream when inside retry and which is always retryable + """ + + +class DefaultPolicy(RetryPolicy): + def __init__(self, **kwargs): # type: ignore[no-untyped-def] + super().__init__(**kwargs) + + def can_retry(self, e: BaseException) -> bool: + """ + Helper function that is used to identify if an exception thrown by the server + can be retried or not. + + Parameters + ---------- + e : Exception + The GRPC error as received from the server. Typed as Exception, because other exception + thrown during client processing can be passed here as well. + + Returns + ------- + True if the exception can be retried, False otherwise. + + """ + if isinstance(e, RetryException): + return True + + if not isinstance(e, grpc.RpcError): + return False + + if e.code() in [grpc.StatusCode.INTERNAL]: + msg = str(e) + + # This error happens if another RPC preempts this RPC. 
+ if "INVALID_CURSOR.DISCONNECTED" in msg: + return True + + if e.code() == grpc.StatusCode.UNAVAILABLE: + return True + + return False + + +class RetriesExceeded(Exception): + """ + Represents an exception which is considered retriable, but retry limits + were exceeded + """ diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index fb137662f42ff..580ebc3965bb5 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -28,11 +28,13 @@ import pandas as pd import pyarrow as pa from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder - from pyspark.sql.connect.client.core import Retrying - from pyspark.sql.connect.client.reattach import ( + from pyspark.sql.connect.client.retries import ( + Retrying, + DefaultPolicy, RetryException, - ExecutePlanResponseReattachableIterator, + RetriesExceeded, ) + from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator import pyspark.sql.connect.proto as proto @@ -107,17 +109,25 @@ def sleep(t): total_sleep += t try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, sleep=sleep, **client._retry_policy - ): + for attempt in Retrying(client._retry_policies, sleep=sleep): with attempt: raise RetryException() - except RetryException: + except RetriesExceeded: pass # tolerated at least 10 mins of fails self.assertGreaterEqual(total_sleep, 600) + def test_retry_client_unit(self): + client = SparkConnectClient("sc://foo/;token=bar") + + policyA = TestPolicy() + policyB = DefaultPolicy() + + client.set_retry_policies([policyA, policyB]) + + self.assertEqual(client.get_retry_policies(), [policyA, policyB]) + def test_channel_builder_with_session(self): dummy = str(uuid.uuid4()) chan = ChannelBuilder(f"sc://foo/;session_id={dummy}") @@ -125,18 +135,23 @@ def test_channel_builder_with_session(self): self.assertEqual(client._session_id, chan.session_id) +class TestPolicy(DefaultPolicy): + def __init__(self): + super().__init__( + max_retries=3, + backoff_multiplier=4.0, + initial_backoff=10, + max_backoff=10, + jitter=10, + min_jitter_threshold=10, + ) + + @unittest.skipIf(not should_test_connect, connect_requirement_message) class SparkConnectClientReattachTestCase(unittest.TestCase): def setUp(self) -> None: self.request = proto.ExecutePlanRequest() - self.policy = { - "max_retries": 3, - "backoff_multiplier": 4.0, - "initial_backoff": 10, - "max_backoff": 10, - "jitter": 10, - "min_jitter_threshold": 10, - } + self.retrying = lambda: Retrying(TestPolicy()) self.response = proto.ExecutePlanResponse( response_id="1", ) @@ -153,7 +168,7 @@ def _stub_with(self, execute=None, attach=None): def test_basic_flow(self): stub = self._stub_with([self.response, self.finished]) - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) for b in ite: pass @@ -171,7 +186,7 @@ def fatal(): stub = self._stub_with([self.response, fatal]) with self.assertRaises(TestException): - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) for b in ite: pass @@ -190,7 +205,7 @@ def non_fatal(): stub = self._stub_with( [self.response, non_fatal], [self.response, self.response, self.finished] ) - ite = ExecutePlanResponseReattachableIterator(self.request, 
stub, self.policy, []) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) for b in ite: pass @@ -216,7 +231,7 @@ def non_fatal(): stub = self._stub_with( [self.response, non_fatal], [self.response, non_fatal, self.response, self.finished] ) - ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, []) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) for b in ite: pass diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 7a224d68219b0..e926eb835a80e 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -34,6 +34,7 @@ ) from pyspark.errors.exceptions.base import SessionNotSameException from pyspark.sql import SparkSession as PySparkSession, Row +from pyspark.sql.connect.client.retries import RetryPolicy, RetriesExceeded from pyspark.sql.types import ( StructType, StructField, @@ -3484,128 +3485,140 @@ def test_config(self): self.assertEqual(self.spark.conf.get("integer"), "1") +class TestError(grpc.RpcError, Exception): + def __init__(self, code: grpc.StatusCode): + self._code = code + + def code(self): + return self._code + + +class TestPolicy(RetryPolicy): + # Put a small value for initial backoff so that tests don't spend + # Time waiting + def __init__(self, initial_backoff=10, **kwargs): + super().__init__(initial_backoff=initial_backoff, **kwargs) + + def can_retry(self, exception: BaseException): + return isinstance(exception, TestError) + + +class TestPolicySpecificError(TestPolicy): + def __init__(self, specific_code: grpc.StatusCode, **kwargs): + super().__init__(**kwargs) + self.specific_code = specific_code + + def can_retry(self, exception: BaseException): + return exception.code() == self.specific_code + + @unittest.skipIf(not should_test_connect, connect_requirement_message) -class ClientTests(unittest.TestCase): - def test_retry_error_handling(self): - # Helper class for wrapping the test. - class TestError(grpc.RpcError, Exception): - def __init__(self, code: grpc.StatusCode): - self._code = code - - def code(self): - return self._code - - def stub(retries, w, code): - w["attempts"] += 1 - if w["attempts"] < retries: - w["raised"] += 1 - raise TestError(code) +class RetryTests(unittest.TestCase): + def setUp(self) -> None: + self.call_wrap = defaultdict(int) + def stub(self, retries, code): + self.call_wrap["attempts"] += 1 + if self.call_wrap["attempts"] < retries: + self.call_wrap["raised"] += 1 + raise TestError(code) + + def test_simple(self): # Check that max_retries 1 is only one retry so two attempts. - call_wrap = defaultdict(int) - for attempt in Retrying( - can_retry=lambda x: True, - max_retries=1, - backoff_multiplier=1, - initial_backoff=1, - max_backoff=10, - jitter=0, - min_jitter_threshold=0, - ): + for attempt in Retrying(TestPolicy(max_retries=1)): with attempt: - stub(2, call_wrap, grpc.StatusCode.INTERNAL) + self.stub(2, grpc.StatusCode.INTERNAL) - self.assertEqual(2, call_wrap["attempts"]) - self.assertEqual(1, call_wrap["raised"]) + self.assertEqual(2, self.call_wrap["attempts"]) + self.assertEqual(1, self.call_wrap["raised"]) + def test_below_limit(self): # Check that if we have less than 4 retries all is ok. 
- call_wrap = defaultdict(int) - for attempt in Retrying( - can_retry=lambda x: True, - max_retries=4, - backoff_multiplier=1, - initial_backoff=1, - max_backoff=10, - jitter=0, - min_jitter_threshold=0, - ): + for attempt in Retrying(TestPolicy(max_retries=4)): with attempt: - stub(2, call_wrap, grpc.StatusCode.INTERNAL) + self.stub(2, grpc.StatusCode.INTERNAL) - self.assertTrue(call_wrap["attempts"] < 4) - self.assertEqual(call_wrap["raised"], 1) + self.assertLess(self.call_wrap["attempts"], 4) + self.assertEqual(self.call_wrap["raised"], 1) + def test_exceed_retries(self): # Exceed the retries. - call_wrap = defaultdict(int) - with self.assertRaises(TestError): - for attempt in Retrying( - can_retry=lambda x: True, - max_retries=2, - max_backoff=50, - backoff_multiplier=1, - initial_backoff=50, - jitter=0, - min_jitter_threshold=0, - ): + with self.assertRaises(RetriesExceeded): + for attempt in Retrying(TestPolicy(max_retries=2)): with attempt: - stub(5, call_wrap, grpc.StatusCode.INTERNAL) + self.stub(5, grpc.StatusCode.INTERNAL) - self.assertTrue(call_wrap["attempts"] < 5) - self.assertEqual(call_wrap["raised"], 3) + self.assertLess(self.call_wrap["attempts"], 5) + self.assertEqual(self.call_wrap["raised"], 3) + def test_throw_not_retriable_error(self): + with self.assertRaises(ValueError): + for attempt in Retrying(TestPolicy(max_retries=2)): + with attempt: + raise ValueError + + def test_specific_exception(self): # Check that only specific exceptions are retried. # Check that if we have less than 4 retries all is ok. - call_wrap = defaultdict(int) - for attempt in Retrying( - can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE, - max_retries=4, - backoff_multiplier=1, - initial_backoff=1, - max_backoff=10, - jitter=0, - min_jitter_threshold=0, - ): + policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE) + + for attempt in Retrying(policy): with attempt: - stub(2, call_wrap, grpc.StatusCode.UNAVAILABLE) + self.stub(2, grpc.StatusCode.UNAVAILABLE) - self.assertTrue(call_wrap["attempts"] < 4) - self.assertEqual(call_wrap["raised"], 1) + self.assertLess(self.call_wrap["attempts"], 4) + self.assertEqual(self.call_wrap["raised"], 1) + def test_specific_exception_exceed_retries(self): # Exceed the retries. - call_wrap = defaultdict(int) - with self.assertRaises(TestError): - for attempt in Retrying( - can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE, - max_retries=2, - max_backoff=50, - backoff_multiplier=1, - initial_backoff=50, - jitter=0, - min_jitter_threshold=0, - ): + policy = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE) + with self.assertRaises(RetriesExceeded): + for attempt in Retrying(policy): with attempt: - stub(5, call_wrap, grpc.StatusCode.UNAVAILABLE) + self.stub(5, grpc.StatusCode.UNAVAILABLE) - self.assertTrue(call_wrap["attempts"] < 4) - self.assertEqual(call_wrap["raised"], 3) + self.assertLess(self.call_wrap["attempts"], 4) + self.assertEqual(self.call_wrap["raised"], 3) + def test_rejected_by_policy(self): # Test that another error is always thrown. 
- call_wrap = defaultdict(int) + policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE) + with self.assertRaises(TestError): - for attempt in Retrying( - can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE, - max_retries=4, - backoff_multiplier=1, - initial_backoff=1, - max_backoff=10, - jitter=0, - min_jitter_threshold=0, - ): + for attempt in Retrying(policy): + with attempt: + self.stub(5, grpc.StatusCode.INTERNAL) + + self.assertEqual(self.call_wrap["attempts"], 1) + self.assertEqual(self.call_wrap["raised"], 1) + + def test_multiple_policies(self): + policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE) + policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL) + + # Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors + + error_suply = iter([grpc.StatusCode.UNAVAILABLE] * 2 + [grpc.StatusCode.INTERNAL] * 4) + + for attempt in Retrying([policy1, policy2]): + with attempt: + error = next(error_suply, None) + if error: + raise TestError(error) + + self.assertEqual(next(error_suply, None), None) + + def test_multiple_policies_exceed(self): + policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.INTERNAL) + policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL) + + with self.assertRaises(RetriesExceeded): + for attempt in Retrying([policy1, policy2]): with attempt: - stub(5, call_wrap, grpc.StatusCode.INTERNAL) + self.stub(10, grpc.StatusCode.INTERNAL) - self.assertEqual(call_wrap["attempts"], 1) - self.assertEqual(call_wrap["raised"], 1) + self.assertEqual(self.call_wrap["attempts"], 7) + self.assertEqual(self.call_wrap["raised"], 7) @unittest.skipIf(not should_test_connect, connect_requirement_message) From 9a990b5a40de3c3a17801dd4dbd4f48e3f399815 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 12 Nov 2023 16:57:26 -0800 Subject: [PATCH 107/121] [SPARK-45901][DOCS] Collect and update `spark-standalone.md` with new confs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to collect and make `spark-standalone.md` up-to-date with new configurations to improve UX. ### Why are the changes needed? By collecting and syncing with the code, the users can find all informations in a single page. 1. For example, the existing `zookeeper`-related configuration in `Configuration -> Deploy` section although it's only applicable in `Spark Standalone` cluster. In addition, it's partially described. We had better move it to `Spark Standalone` page. - From: https://spark.apache.org/docs/latest/configuration.html#deploy - To: https://spark.apache.org/docs/latest/spark-standalone.html 2. Add missing configurations. - spark.dead.worker.persistence 3. Revise the default value - spark.deploy.defaultCores 3. Add all Spark 4.0.0 configurations. - spark.deploy.maxDrivers - spark.deploy.appNumberModulo - spark.deploy.driverIdPattern - spark.deploy.appIdPattern - spark.worker.idPattern ### Does this PR introduce _any_ user-facing change? Yes, but this is a documentation only. ### How was this patch tested? Manual review. **CHANGED PART 1** Screenshot 2023-11-12 at 4 18 21 PM **CHANGED PART 2** Screenshot 2023-11-12 at 4 18 08 PM **CHANGED PART 3** Screenshot 2023-11-12 at 4 54 38 PM ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43774 from dongjoon-hyun/SPARK-45901. 
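Before moving on, an aside on the retry-policy API introduced by SPARK-45733 a few patches above, since the new `pyspark.sql.connect.client.retries` module is the main user-visible surface in this stretch of the series. The sketch below shows how a custom policy could be stacked next to the built-in `DefaultPolicy`; the policy name, endpoint, and backoff numbers are illustrative assumptions rather than values taken from the patch.

```python
import grpc

from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.client.retries import DefaultPolicy, RetryPolicy


class ResourceBusyPolicy(RetryPolicy):
    """Hypothetical policy: retry RESOURCE_EXHAUSTED with a slow, capped backoff."""

    def __init__(self) -> None:
        super().__init__(
            max_retries=10,
            initial_backoff=2000,  # milliseconds
            max_backoff=60000,
            backoff_multiplier=2.0,
            jitter=500,
            min_jitter_threshold=2000,
        )

    def can_retry(self, exception: BaseException) -> bool:
        return (
            isinstance(exception, grpc.RpcError)
            and exception.code() == grpc.StatusCode.RESOURCE_EXHAUSTED
        )


client = SparkConnectClient("sc://localhost")  # placeholder endpoint
# Policies are consulted in order; the first one that accepts an error
# also drives the backoff used for that error.
client.set_retry_policies([DefaultPolicy(), ResourceBusyPolicy()])
print(client.get_retry_policies())
```

Inside the client, retryable blocks then iterate over these policies with the same `for attempt in Retrying(policies): with attempt: ...` pattern used by the module's docstring and tests.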
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- docs/configuration.md | 26 ------------- docs/spark-standalone.md | 84 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 81 insertions(+), 29 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 3d54aaf6518be..75f597fdb4c6c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -3495,32 +3495,6 @@ External users can query the static sql config values via `SparkSession.conf` or -### Deploy - - - - - - - - - - - - - - - - - - - - - -
-  <tr><th>Property Name</th><th>Default</th><th>Meaning</th><th>Since Version</th></tr>
-  <tr><td>spark.deploy.recoveryMode</td><td>NONE</td><td>The recovery mode setting to recover submitted Spark jobs with cluster mode when it failed and relaunches. This is only applicable for cluster mode when running with Standalone.</td><td>0.8.1</td></tr>
-  <tr><td>spark.deploy.zookeeper.url</td><td>None</td><td>When `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper URL to connect to.</td><td>0.8.1</td></tr>
-  <tr><td>spark.deploy.zookeeper.dir</td><td>None</td><td>When `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper directory to store recovery state.</td><td>0.8.1</td></tr>
    - - ### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 93f0818b6ce8c..bc13693e28050 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -218,7 +218,7 @@ SPARK_MASTER_OPTS supports the following system properties: spark.deploy.defaultCores - (infinite) + Int.MaxValue Default number of cores to give to applications in Spark's standalone mode if they don't set spark.cores.max. If not set, applications always get all available @@ -244,6 +244,43 @@ SPARK_MASTER_OPTS supports the following system properties: 1.6.3 + + spark.deploy.maxDrivers + Int.MaxValue + + The maximum number of running drivers. + + 4.0.0 + + + spark.deploy.appNumberModulo + (None) + + The modulo for app number. By default, the next of `app-yyyyMMddHHmmss-9999` is + `app-yyyyMMddHHmmss-10000`. If we have 10000 as modulo, it will be `app-yyyyMMddHHmmss-0000`. + In most cases, the prefix `app-yyyyMMddHHmmss` is increased already during creating 10000 applications. + + 4.0.0 + + + spark.deploy.driverIdPattern + driver-%s-%04d + + The pattern for driver ID generation based on Java `String.format` method. + The default value is `driver-%s-%04d` which represents the existing driver id string, e.g., `driver-20231031224459-0019`. Please be careful to generate unique IDs. + + 4.0.0 + + + spark.deploy.appIdPattern + app-%s-%04d + + The pattern for app ID generation based on Java `String.format` method. + The default value is `app-%s-%04d` which represents the existing app id string, e.g., + `app-20231031224509-0008`. Plesae be careful to generate unique IDs. + + 4.0.0 + spark.worker.timeout 60 @@ -253,6 +290,14 @@ SPARK_MASTER_OPTS supports the following system properties: 0.6.2 + + spark.dead.worker.persistence + 15 + + Number of iterations to keep the deae worker information in UI. By default, the dead worker is visible for (15 + 1) * spark.worker.timeout since its last heartbeat. + + 0.8.0 + spark.worker.resource.{resourceName}.amount (none) @@ -363,6 +408,16 @@ SPARK_WORKER_OPTS supports the following system properties: 2.0.2 + + spark.worker.idPattern + worker-%s-%s-%d + + The pattern for worker ID generation based on Java `String.format` method. + The default value is `worker-%s-%s-%d` which represents the existing worker id string, e.g., + `worker-20231109183042-[fe80::1%lo0]-39729`. Please be careful to generate unique IDs + + 4.0.0 + # Resource Allocation and Configuration Overview @@ -542,17 +597,40 @@ ZooKeeper is the best way to go for production-level high availability, but if y In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration: - + - + + + + + + + + + + + + + + + + + + + + +
-  <tr><th>System property</th><th>Meaning</th><th>Since Version</th></tr>
+  <tr><th>System property</th><th>Default Value</th><th>Meaning</th><th>Since Version</th></tr>
-  <tr><td>spark.deploy.recoveryMode</td><td>Set to FILESYSTEM to enable single-node recovery mode (default: NONE).</td><td>0.8.1</td></tr>
+  <tr><td>spark.deploy.recoveryMode</td><td>NONE</td><td>The recovery mode setting to recover submitted Spark jobs with cluster mode when it failed and relaunches. Set to FILESYSTEM to enable single-node recovery mode, ZOOKEEPER to use Zookeeper-based recovery mode, and CUSTOM to provide a custom provider class via additional `spark.deploy.recoveryMode.factory` configuration.</td><td>0.8.1</td></tr>
    spark.deploy.recoveryDirectory"" The directory in which Spark will store recovery state, accessible from the Master's perspective. 0.8.1
    spark.deploy.recoveryMode.factory""A class to implement StandaloneRecoveryModeFactory interface1.2.0
    spark.deploy.zookeeper.urlNoneWhen `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper URL to connect to.0.8.1
    spark.deploy.zookeeper.dirNoneWhen `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper directory to store recovery state.0.8.1
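(Editor note: as a quick illustration of the recovery settings documented in the table above — a minimal sketch, not part of the patch. The recovery directory is a placeholder path; in practice these properties are normally supplied through SPARK_DAEMON_JAVA_OPTS as described earlier.)

```scala
import org.apache.spark.SparkConf

// Illustrative only: single-node FILESYSTEM recovery expressed as plain
// configuration entries. "/var/spark/recovery" is a placeholder path.
val recoveryConf = new SparkConf()
  .set("spark.deploy.recoveryMode", "FILESYSTEM")
  .set("spark.deploy.recoveryDirectory", "/var/spark/recovery")
```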
    **Details** From bfbd3df699560b901c71ffef5912ace97106bca3 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 13 Nov 2023 17:40:22 +0900 Subject: [PATCH 108/121] [MINOR][PYTHON] Better error message when Python worker crushes ### What changes were proposed in this pull request? This PR improves the Python UDF error messages to be more actionable. ### Why are the changes needed? Suppose you face a segfault error: ```python from pyspark.sql.functions import udf import ctypes spark.range(1).select(udf(lambda x: ctypes.string_at(0))("id")).collect() ``` The current error message is not actionable: ``` Traceback (most recent call last): File "", line 1, in ... get_return_value raise Py4JJavaError( py4j.protocol.Py4JJavaError: An error occurred while calling o82.collectToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 15 in stage 1.0 failed 1 times, most recent failure: Lost task 15.0 in stage 1.0 (TID 31) (192.168.123.102 executor driver): org.apache.spark.SparkException: Python worker exited unexpectedly (crashed) ``` After this PR, it fixes the error message as below: ``` Traceback (most recent call last): File "", line 1, in ... get_return_value raise Py4JJavaError( py4j.protocol.Py4JJavaError: An error occurred while calling o59.collectToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 15 in stage 0.0 failed 1 times, most recent failure: Lost task 15.0 in stage 0.0 (TID 15) (192.168.123.102 executor driver): org.apache.spark.SparkException: Python worker exited unexpectedly (crashed). Consider setting 'spark.sql.execution.pyspark.udf.faulthandler.enabled' or 'spark.python.worker.faulthandler.enabled' configuration to 'true' forthe better Python traceback. ``` So you can try this out ```python from pyspark.sql.functions import udf import ctypes spark.conf.set("spark.sql.execution.pyspark.udf.faulthandler.enabled", "true") spark.range(1).select(udf(lambda x: ctypes.string_at(0))("id")).collect() ``` that now shows where the segfault happens: ``` Caused by: org.apache.spark.SparkException: Python worker exited unexpectedly (crashed): Fatal Python error: Segmentation fault Current thread 0x00007ff84ae4b700 (most recent call first): File "/.../envs/python3.9/lib/python3.9/ctypes/__init__.py", line 525 in string_at File "", line 1 in File "/.../lib/pyspark.zip/pyspark/util.py", line 88 in wrapper File "/.../lib/pyspark.zip/pyspark/worker.py", line 99 in File "/.../lib/pyspark.zip/pyspark/worker.py", line 1403 in File "/.../lib/pyspark.zip/pyspark/worker.py", line 1403 in mapper ``` ### Does this PR introduce _any_ user-facing change? Yes, it fixes the error message actionable. ### How was this patch tested? Manually tested as above. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43778 from HyukjinKwon/minor-error-improvement. 
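(Editor note: the same fault handler can also be switched on when the session is built, rather than at runtime as in the PySpark snippet above. A minimal Scala sketch, using only the two configuration keys named in the new error message; the master URL is an example value.)

```scala
import org.apache.spark.sql.SparkSession

// Editor sketch: enable the Python fault handler up front so that a crashed
// Python worker reports a Python-side traceback.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.python.worker.faulthandler.enabled", "true")
  .config("spark.sql.execution.pyspark.udf.faulthandler.enabled", "true")
  .getOrCreate()
```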
Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/api/python/PythonRunner.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 1a01ad1bc219a..d6363182606d9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -31,7 +31,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} +import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES, Python} import org.apache.spark.internal.config.Python._ import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} @@ -549,6 +549,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( JavaFiles.deleteIfExists(path) throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", e) + case eof: EOFException if !faultHandlerEnabled => + throw new SparkException( + s"Python worker exited unexpectedly (crashed). " + + "Consider setting 'spark.sql.execution.pyspark.udf.faulthandler.enabled' or" + + s"'${Python.PYTHON_WORKER_FAULTHANLDER_ENABLED.key}' configuration to 'true' for" + + "the better Python traceback.", eof) + case eof: EOFException => throw new SparkException("Python worker exited unexpectedly (crashed)", eof) } From adcea1c37b802cc1317244f91d3d84867f0ad6f5 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 13 Nov 2023 17:54:56 +0900 Subject: [PATCH 109/121] [SPARK-45859][ML] Make UDF objects in ml.functions lazy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? 
Since JVM runs static codes only once, if loading functions$ fails, it will always report java.lang.NoClassDefFoundError: Could not initialize class ``` 23/11/01 23:06:21 WARN TaskSetManager: Lost task 136.0 in stage 9565.0 (TID 4557384) (10.4.35.209 executor 16): TaskKilled (Stage cancelled: Job aborted due to stage failure: Task 2 in stage 9565.0 failed 4 times, most recent failure: Lost task 2.3 in stage 9565.0 (TID 4558369) (10.4.56.6 executor 71): java.io.IOException: unexpected exception type at java.io.ObjectStreamClass.throwMiscException(ObjectStreamClass.java:1750) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1280) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2222) … at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:900) at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23) at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:795) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:750) Caused by: java.lang.reflect.InvocationTargetException at sun.reflect.GeneratedMethodAccessor520.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:230) at sun.reflect.GeneratedMethodAccessor224.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1274) ... 388 more Caused by: java.lang.NoClassDefFoundError: Could not initialize class org.apache.spark.ml.functions$ ... 396 more ``` This PR just changes `functions.*` as lazy avoid hitting this issue because the initialization codes of a lazy val is not in static codes. ### Why are the changes needed? to fix a intermittent bug ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new UT ### Was this patch authored or co-authored using generative AI tooling? no Closes #43739 from zhengruifeng/ml_lazy_udfs. 
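(Editor note: a small hypothetical sketch, not Spark code, of why moving the field out of static initialization helps — an eager `val` that throws during class initialization poisons the enclosing object for the rest of the JVM's lifetime, while a `lazy val` initializer is simply retried on the next access.)

```scala
object EagerHolder {
  // Runs inside the static initializer of EagerHolder$. If riskyInit() throws
  // here, the class is marked as failed and every later access reports
  // java.lang.NoClassDefFoundError: Could not initialize class EagerHolder$
  val handle: Int = riskyInit()
  private def riskyInit(): Int = sys.error("simulated broken class loader")
}

object LazyHolder {
  // Initialization happens on first access instead of at class loading time.
  // If the initializer threw, it would be re-attempted on the next access
  // rather than permanently poisoning the class.
  lazy val handle: Int = riskyInit()
  private def riskyInit(): Int = 42
}
```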
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/ml/functions.scala | 6 +- .../spark/ml/FunctionsLoadingSuite.scala | 60 +++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/FunctionsLoadingSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala index fa524ae8ac47e..46c437812233a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.functions.udf @Since("3.0.0") object functions { // scalastyle:on - private[spark] val vectorToArrayUdf = udf { vec: Any => + private[spark] lazy val vectorToArrayUdf = udf { vec: Any => vec match { case v: Vector => v.toArray case v: OldVector => v.toArray @@ -38,7 +38,7 @@ object functions { } }.asNonNullable() - private[spark] val vectorToArrayFloatUdf = udf { vec: Any => + private[spark] lazy val vectorToArrayFloatUdf = udf { vec: Any => vec match { case v: SparseVector => val data = new Array[Float](v.size) @@ -76,7 +76,7 @@ object functions { } } - private[spark] val arrayToVectorUdf = udf { array: Seq[Double] => + private[spark] lazy val arrayToVectorUdf = udf { array: Seq[Double] => Vectors.dense(array.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/FunctionsLoadingSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/FunctionsLoadingSuite.scala new file mode 100644 index 0000000000000..89480b86428de --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/FunctionsLoadingSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.apache.spark._ +import org.apache.spark.ml.functions.{array_to_vector, vector_to_array} +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.functions.col + +class FunctionsLoadingSuite extends SparkFunSuite with LocalSparkContext { + + test("SPARK-45859: 'functions$' should not be affected by a broken class loader") { + quietly { + val conf = new SparkConf() + .setAppName("FunctionsLoadingSuite") + .setMaster("local-cluster[1,1,1024]") + sc = new SparkContext(conf) + // Make `functions$` be loaded by a broken class loader + intercept[SparkException] { + sc.parallelize(1 to 1).foreach { _ => + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + Thread.currentThread.setContextClassLoader(new BrokenClassLoader) + vector_to_array(col("vector")) + array_to_vector(col("array")) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } + } + + // We should be able to use `functions$` even it was loaded by a broken class loader + sc.parallelize(1 to 1).foreach { _ => + vector_to_array(col("vector")) + array_to_vector(col("array")) + } + } + } +} + +class BrokenClassLoader extends ClassLoader { + override def findClass(name: String): Class[_] = { + throw new Error(s"class $name") + } +} From 1a276bdb3d369efaa0ad806fb0a5d2f6f2920214 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 13 Nov 2023 17:29:37 +0800 Subject: [PATCH 110/121] [SPARK-45723][PYTHON][CONNECT][FOLLOWUP] Replace `toPandas` with `_to_table` in catalog methods ### What changes were proposed in this pull request? followup of https://github.com/apache/spark/pull/43583, replace `toPandas` with `_to_table` in catalog methods ### Why are the changes needed? pandas conversion not needed ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #43780 from zhengruifeng/py_catalog_arrow_followup. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/connect/catalog.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py index 657aa7b6fb41d..e725e381b8dbe 100644 --- a/python/pyspark/sql/connect/catalog.py +++ b/python/pyspark/sql/connect/catalog.py @@ -223,7 +223,7 @@ def createExternalTable( options=options, ) df = DataFrame.withPlan(catalog, session=self._sparkSession) - df.toPandas() # Eager execution. + df._to_table() # Eager execution. return df createExternalTable.__doc__ = PySparkCatalog.createExternalTable.__doc__ @@ -246,7 +246,7 @@ def createTable( options=options, ) df = DataFrame.withPlan(catalog, session=self._sparkSession) - df.toPandas() # Eager execution. + df._to_table() # Eager execution. return df createTable.__doc__ = PySparkCatalog.createTable.__doc__ From c29b127dcdd99b0038e96b90177b44b828b32c4b Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Mon, 13 Nov 2023 19:31:22 +0800 Subject: [PATCH 111/121] [SPARK-45906][YARN] Fix error message extraction from ResourceNotFoundException ### What changes were proposed in this pull request? This PR aims to fix the error message extraction from `ResourceNotFoundException`, the current wrong implementation also has a potential NPE issue. ### Why are the changes needed? 
This bug is introduced in SPARK-43202, previously, `e.getCause()` is used to unwrap `InvocationTargetException`, after replacing reflection invocation with direct API calling, we should not apply `getCause()`. ### Does this PR introduce _any_ user-facing change? Yes, bug fix. ### How was this patch tested? Review. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43782 from pan3793/SPARK-45906. Authored-by: Cheng Pan Signed-off-by: Kent Yao --- .../org/apache/spark/deploy/yarn/ResourceRequestHelper.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala index 0dd4e0a6c8ad9..f9aa11c4d48d6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala @@ -168,7 +168,7 @@ private object ResourceRequestHelper extends Logging { if (numResourceErrors < 2) { logWarning(s"YARN doesn't know about resource $name, your resource discovery " + s"has to handle properly discovering and isolating the resource! Error: " + - s"${e.getCause.getMessage}") + s"${e.getMessage}") numResourceErrors += 1 } } From d2a1a31a3800c2b64f1e45793267e84ef046e888 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Mon, 13 Nov 2023 21:17:27 +0900 Subject: [PATCH 112/121] [SPARK-45783][PYTHON][CONNECT] Improve error messages when Spark Connect mode is enabled but remote URL is not set ### What changes were proposed in this pull request? This PR improves the error messages when `SPARK_CONNECT_MODE_ENABLED` is defined but neither `spark.remote` option nor the `SPARK_REMOTE` env var is set. ### Why are the changes needed? To improve the error message. Currently the error looks like a bug: ``` url = opts.get("spark.remote", os.environ.get("SPARK_REMOTE")) > if url.startswith("local"): E AttributeError: 'NoneType' object has no attribute 'startswith' ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes #43653 from allisonwang-db/spark-45783-fix-url-err. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/session.py | 8 ++++++++ python/pyspark/sql/tests/test_session.py | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 85aff09aa3df1..b4fad7ad29da2 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -481,6 +481,14 @@ def getOrCreate(self) -> "SparkSession": ): url = opts.get("spark.remote", os.environ.get("SPARK_REMOTE")) + if url is None: + raise RuntimeError( + "Cannot create a Spark Connect session because the " + "Spark Connect remote URL has not been set. Please define " + "the remote URL by setting either the 'spark.remote' option " + "or the 'SPARK_REMOTE' environment variable." 
+ ) + if url.startswith("local"): os.environ["SPARK_LOCAL_REMOTE"] = "1" RemoteSparkSession._start_connect_server(url, opts) diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index 706b041bb514f..da27bf9257493 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -17,6 +17,7 @@ import os import unittest +import unittest.mock from pyspark import SparkConf, SparkContext from pyspark.sql import SparkSession, SQLContext, Row @@ -187,6 +188,11 @@ def test_active_session_with_None_and_not_None_context(self): if sc is not None: sc.stop() + def test_session_with_spark_connect_mode_enabled(self): + with unittest.mock.patch.dict(os.environ, {"SPARK_CONNECT_MODE_ENABLED": "1"}): + with self.assertRaisesRegex(RuntimeError, "Cannot create a Spark Connect session"): + SparkSession.builder.appName("test").getOrCreate() + class SparkSessionTests4(ReusedSQLTestCase): def test_get_active_session_after_create_dataframe(self): From 1ece299b15fd198a23d64743e152e61fd74750f5 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 13 Nov 2023 08:26:31 -0800 Subject: [PATCH 113/121] [SPARK-45902][SQL] Remove unused function `resolvePartitionColumns` from `DataSource` ### What changes were proposed in this pull request? `resolvePartitionColumns` was introduced by SPARK-37287 (https://github.com/apache/spark/pull/37099) and become unused after SPARK-41713 (https://github.com/apache/spark/pull/39220), so this pr remove it from `DataSource`. ### Why are the changes needed? Clean up unused code. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43779 from LuciferYang/SPARK-45902. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- .../execution/datasources/DataSource.scala | 25 +------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b3784dbf81373..cd295f3b17bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -29,9 +29,8 @@ import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} -import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils} import org.apache.spark.sql.connector.catalog.TableProvider @@ -822,26 +821,4 @@ object DataSource extends Logging { throw QueryCompilationErrors.writeEmptySchemasUnsupportedByDataSourceError() } } - - /** - * Resolve partition columns using output columns of the query plan. 
- */ - def resolvePartitionColumns( - partitionColumns: Seq[Attribute], - outputColumns: Seq[Attribute], - plan: LogicalPlan, - resolver: Resolver): Seq[Attribute] = { - partitionColumns.map { col => - // The partition columns created in `planForWritingFileFormat` should always be - // `UnresolvedAttribute` with a single name part. - assert(col.isInstanceOf[UnresolvedAttribute]) - val unresolved = col.asInstanceOf[UnresolvedAttribute] - assert(unresolved.nameParts.length == 1) - val name = unresolved.nameParts.head - outputColumns.find(a => resolver(a.name, name)).getOrElse { - throw QueryCompilationErrors.cannotResolveAttributeError( - name, plan.output.map(_.name).mkString(", ")) - } - } - } } From 2b046d2eb5aaa70f8bebf47020b09ec41ff58df9 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 13 Nov 2023 09:42:57 -0800 Subject: [PATCH 114/121] [SPARK-45907][CORE] Use Java9+ ProcessHandle APIs to computeProcessTree in ProcfsMetricsGetter ### What changes were proposed in this pull request? This PR uses Java9+ ProcessHandle APIs to computeProcessTree in ProcfsMetricsGetter. ### Why are the changes needed? Simplify the code ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #43783 from yaooqinn/SPARK-45907. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .../spark/executor/ProcfsMetricsGetter.scala | 98 +++---------------- .../executor/ProcfsMetricsGetterSuite.scala | 46 ++++++++- 2 files changed, 55 insertions(+), 89 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala index 3120f69822830..b9a462d62e413 100644 --- a/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala +++ b/core/src/main/scala/org/apache/spark/executor/ProcfsMetricsGetter.scala @@ -22,11 +22,10 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} import java.util.Locale -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ import scala.util.Try -import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.SparkEnv import org.apache.spark.internal.{config, Logging} import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -47,7 +46,7 @@ private[spark] class ProcfsMetricsGetter(procfsDir: String = "/proc/") extends L private val testing = Utils.isTesting private val pageSize = computePageSize() private var isAvailable: Boolean = isProcfsAvailable - private val pid = computePid() + private val currentProcessHandle = ProcessHandle.current() private lazy val isProcfsAvailable: Boolean = { if (testing) { @@ -65,26 +64,6 @@ private[spark] class ProcfsMetricsGetter(procfsDir: String = "/proc/") extends L } } - private def computePid(): Int = { - if (!isAvailable || testing) { - return -1; - } - try { - // This can be simplified in java9: - // https://docs.oracle.com/javase/9/docs/api/java/lang/ProcessHandle.html - val cmd = Array("bash", "-c", "echo $PPID") - val out = Utils.executeAndGetOutput(cmd.toImmutableArraySeq) - Integer.parseInt(out.split("\n")(0)) - } - catch { - case e: SparkException => - logDebug("Exception when trying to compute process tree." 
+ - " As a result reporting of ProcessTree metrics is stopped", e) - isAvailable = false - -1 - } - } - private def computePageSize(): Long = { if (testing) { return 4096; @@ -102,70 +81,10 @@ private[spark] class ProcfsMetricsGetter(procfsDir: String = "/proc/") extends L } } - // Exposed for testing - private[executor] def computeProcessTree(): Set[Int] = { - if (!isAvailable || testing) { - return Set() - } - var ptree: Set[Int] = Set() - ptree += pid - val queue = mutable.Queue.empty[Int] - queue += pid - while ( !queue.isEmpty ) { - val p = queue.dequeue() - val c = getChildPids(p) - if (!c.isEmpty) { - queue ++= c - ptree ++= c.toSet - } - } - ptree - } - - private def getChildPids(pid: Int): ArrayBuffer[Int] = { - try { - val builder = new ProcessBuilder("pgrep", "-P", pid.toString) - val process = builder.start() - val childPidsInInt = mutable.ArrayBuffer.empty[Int] - def appendChildPid(s: String): Unit = { - if (s != "") { - logTrace("Found a child pid:" + s) - childPidsInInt += Integer.parseInt(s) - } - } - val stdoutThread = Utils.processStreamByLine("read stdout for pgrep", - process.getInputStream, appendChildPid) - val errorStringBuilder = new StringBuilder() - val stdErrThread = Utils.processStreamByLine( - "stderr for pgrep", - process.getErrorStream, - line => errorStringBuilder.append(line)) - val exitCode = process.waitFor() - stdoutThread.join() - stdErrThread.join() - val errorString = errorStringBuilder.toString() - // pgrep will have exit code of 1 if there are more than one child process - // and it will have a exit code of 2 if there is no child process - if (exitCode != 0 && exitCode > 2) { - val cmd = builder.command().toArray.mkString(" ") - logWarning(s"Process $cmd exited with code $exitCode and stderr: $errorString") - throw SparkException.internalError(msg = s"Process $cmd exited with code $exitCode", - category = "EXECUTOR") - } - childPidsInInt - } catch { - case e: Exception => - logDebug("Exception when trying to compute process tree." 
+ - " As a result reporting of ProcessTree metrics is stopped.", e) - isAvailable = false - mutable.ArrayBuffer.empty[Int] - } - } - // Exposed for testing private[executor] def addProcfsMetricsFromOneProcess( allMetrics: ProcfsMetrics, - pid: Int): ProcfsMetrics = { + pid: Long): ProcfsMetrics = { // The computation of RSS and Vmem are based on proc(5): // http://man7.org/linux/man-pages/man5/proc.5.html @@ -207,6 +126,15 @@ private[spark] class ProcfsMetricsGetter(procfsDir: String = "/proc/") extends L } } + private[executor] def computeProcessTree(): Set[Long] = { + if (!isAvailable) { + Set.empty + } else { + val children = currentProcessHandle.descendants().map(_.pid()).toList.asScala.toSet + children + currentProcessHandle.pid() + } + } + private[spark] def computeAllMetrics(): ProcfsMetrics = { if (!isAvailable) { return ProcfsMetrics(0, 0, 0, 0, 0, 0) diff --git a/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala b/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala index a5b4814946f9d..573540180e6cb 100644 --- a/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ProcfsMetricsGetterSuite.scala @@ -17,13 +17,19 @@ package org.apache.spark.executor -import org.mockito.Mockito.{spy, when} +import org.mockito.Mockito.{mock, spy, when} +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.{interval, timeout} +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config.EXECUTOR_PROCESS_TREE_METRICS_ENABLED +import org.apache.spark.util.Utils class ProcfsMetricsGetterSuite extends SparkFunSuite { - + private val sparkHome = + sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) test("testGetProcessInfo") { val p = new ProcfsMetricsGetter(getTestResourcePath("ProcfsMetrics")) @@ -45,7 +51,7 @@ class ProcfsMetricsGetterSuite extends SparkFunSuite { val p = new ProcfsMetricsGetter(getTestResourcePath("ProcfsMetrics")) val mockedP = spy[ProcfsMetricsGetter](p) - var ptree: Set[Int] = Set(26109, 22763) + var ptree: Set[Long] = Set(26109, 22763) when(mockedP.computeProcessTree()).thenReturn(ptree) var r = mockedP.computeAllMetrics() assert(r.jvmVmemTotal == 4769947648L) @@ -62,4 +68,36 @@ class ProcfsMetricsGetterSuite extends SparkFunSuite { assert(r.pythonVmemTotal == 0) assert(r.pythonRSSTotal == 0) } + + test("SPARK-45907: Use ProcessHandle APIs to computeProcessTree in ProcfsMetricsGetter") { + val originalSparkEnv = SparkEnv.get + val sparkEnv = mock(classOf[SparkEnv]) + val conf = new SparkConf(false) + .set(EXECUTOR_PROCESS_TREE_METRICS_ENABLED, true) + when(sparkEnv.conf).thenReturn(conf) + try { + SparkEnv.set(sparkEnv) + val p = new ProcfsMetricsGetter() + val currentPid = ProcessHandle.current().pid() + val process = Utils.executeCommand(Seq( + s"$sparkHome/bin/spark-class", + this.getClass.getCanonicalName.stripSuffix("$"), + currentPid.toString)) + val child = process.toHandle.pid() + eventually(timeout(10.seconds), interval(100.milliseconds)) { + val pids = p.computeProcessTree() + assert(pids.size === 3) + assert(pids.contains(currentPid)) + assert(pids.contains(child)) + } + } finally { + SparkEnv.set(originalSparkEnv) + } + } +} + +object ProcfsMetricsGetterSuite { + def main(args: Array[String]): Unit = { + Utils.executeCommand(Seq("jstat", 
"-gcutil", args(0), "50", "100")) + } } From a6b089fa00c2736bafd7bd374f401c392da9cf80 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 13 Nov 2023 10:50:55 -0800 Subject: [PATCH 115/121] [SPARK-45909][SQL] Remove `NumericType` cast if it can safely up-cast in `IsNotNull` ### What changes were proposed in this pull request? Similar to SPARK-37922. We can remove the cast if it can safely up-cast in `IsNotNull`. ### Why are the changes needed? Improve the query performance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43785 from wangyum/SPARK-45909. Authored-by: Yuming Wang Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/optimizer/expressions.scala | 2 ++ .../optimizer/SimplifyCastsSuite.scala | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 2c73e25cb5ed1..9e9e6bd905b94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -1063,6 +1063,8 @@ object SimplifyCasts extends Rule[LogicalPlan] { if fromKey == toKey && fromValue == toValue => e case _ => c } + case IsNotNull(Cast(e, dataType: NumericType, _, _)) if isWiderCast(e.dataType, dataType) => + IsNotNull(e) } // Returns whether the from DataType can be safely casted to the to DataType without losing diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index 741b1bb8c082c..2eb4830bca98b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -117,4 +117,25 @@ class SimplifyCastsSuite extends PlanTest { input.select($"d".cast(LongType).cast(StringType).as("casted")).analyze), input.select($"d".cast(LongType).cast(StringType).as("casted")).analyze) } + + test("SPARK-45909: Remove the cast if it can safely up-cast in IsNotNull") { + val input = LocalRelation($"a".int, $"b".decimal(18, 0)) + // Remove cast + comparePlans( + Optimize.execute( + input.select($"a".cast(DecimalType(18, 1)).isNotNull.as("v")).analyze), + input.select($"a".isNotNull.as("v")).analyze) + comparePlans( + Optimize.execute(input.select($"a".cast(LongType).isNotNull.as("v")).analyze), + input.select($"a".isNotNull.as("v")).analyze) + comparePlans( + Optimize.execute(input.select($"b".cast(LongType).isNotNull.as("v")).analyze), + input.select($"b".isNotNull.as("v")).analyze) + + // Can not remove cast + comparePlans( + Optimize.execute( + input.select($"a".cast(DecimalType(2, 1)).as("v")).analyze), + input.select($"a".cast(DecimalType(2, 1)).as("v")).analyze) + } } From 7cea52c96f5be1bc565a033bfd77370ab5527a35 Mon Sep 17 00:00:00 2001 From: Xi Liang Date: Tue, 14 Nov 2023 05:28:07 +0800 Subject: [PATCH 116/121] [SPARK-45892][SQL] Refactor optimizer plan validation to decouple `validateSchemaOutput` and `validateExprIdUniqueness` ### What changes were proposed in this pull request? 
Currently, the expressionIDUniquenessValidation is [closely coupled with outputSchemaValidation](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala#L403C7-L411C8). This PR refactors the code to improve readability and maintainability. ### Why are the changes needed? Improve code readability and maintainability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43761 from xil-db/SPARK-45892-validation-refactor. Authored-by: Xi Liang Signed-off-by: Wenchen Fan --- .../catalyst/plans/logical/LogicalPlan.scala | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ae3029b279da4..cce385e8d9d16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -381,6 +381,15 @@ object LogicalPlanIntegrity { }.flatten } + def validateSchemaOutput(previousPlan: LogicalPlan, currentPlan: LogicalPlan): Option[String] = { + if (!DataTypeUtils.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) { + Some(s"The plan output schema has changed from ${previousPlan.schema.sql} to " + + currentPlan.schema.sql + s". The previous plan: ${previousPlan.treeString}\nThe new " + + "plan:\n" + currentPlan.treeString) + } else { + None + } + } /** * Validate the structural integrity of an optimized plan. @@ -400,17 +409,11 @@ object LogicalPlanIntegrity { } else if (currentPlan.exists(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty)) { Some("Special expressions are placed in the wrong plan: " + currentPlan.treeString) } else { - LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan).orElse { - if (!DataTypeUtils.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) { - Some(s"The plan output schema has changed from ${previousPlan.schema.sql} to " + - currentPlan.schema.sql + s". The previous plan: ${previousPlan.treeString}\nThe new " + - "plan:\n" + currentPlan.treeString) - } else { - None - } - } + None } validation = validation + .orElse(LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan)) + .orElse(LogicalPlanIntegrity.validateSchemaOutput(previousPlan, currentPlan)) .orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan)) .orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan)) .orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan)) From 961dcdfb98455f341c3f6279fa65aa1dd58ca199 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Tue, 14 Nov 2023 05:42:13 +0800 Subject: [PATCH 117/121] [SPARK-45882][SQL] BroadcastHashJoinExec propagate partitioning should respect CoalescedHashPartitioning ### What changes were proposed in this pull request? Add HashPartitioningLike trait and make HashPartitioning and CoalescedHashPartitioning extend it. When we propagate output partiitoning, we should handle HashPartitioningLike instead of HashPartitioning. This pr also changes the BroadcastHashJoinExec to use HashPartitioningLike to avoid regression. ### Why are the changes needed? Avoid unnecessary shuffle exchange. ### Does this PR introduce _any_ user-facing change? 
yes, avoid regression ### How was this patch tested? add test ### Was this patch authored or co-authored using generative AI tooling? no Closes #43753 from ulysses-you/partitioning. Authored-by: ulysses-you Signed-off-by: Wenchen Fan --- .../plans/physical/partitioning.scala | 46 +++++++++---------- .../joins/BroadcastHashJoinExec.scala | 11 +++-- .../org/apache/spark/sql/JoinSuite.scala | 28 ++++++++++- 3 files changed, 54 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0ae2857161c8c..60e6e42bedf87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -258,18 +258,8 @@ case object SinglePartition extends Partitioning { SinglePartitionShuffleSpec } -/** - * Represents a partitioning where rows are split up across partitions based on the hash - * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be - * in the same partition. - * - * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires - * stateful operators to retain the same physical partitioning during the lifetime of the query - * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged - * across Spark versions. Violation of this requirement may bring silent correctness issue. - */ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression with Partitioning with Unevaluable { +trait HashPartitioningLike extends Expression with Partitioning with Unevaluable { + def expressions: Seq[Expression] override def children: Seq[Expression] = expressions override def nullable: Boolean = false @@ -294,6 +284,20 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) } } } +} + +/** + * Represents a partitioning where rows are split up across partitions based on the hash + * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be + * in the same partition. + * + * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires + * stateful operators to retain the same physical partitioning during the lifetime of the query + * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged + * across Spark versions. Violation of this requirement may bring silent correctness issue. + */ +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends HashPartitioningLike { override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = HashShuffleSpec(this, distribution) @@ -306,7 +310,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) - } case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int) @@ -316,25 +319,18 @@ case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int) * fewer number of partitions. 
*/ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary]) - extends Expression with Partitioning with Unevaluable { - - override def children: Seq[Expression] = from.expressions - override def nullable: Boolean = from.nullable - override def dataType: DataType = from.dataType + extends HashPartitioningLike { - override def satisfies0(required: Distribution): Boolean = from.satisfies0(required) + override def expressions: Seq[Expression] = from.expressions override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions) - override protected def withNewChildrenInternal( - newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning = - copy(from = from.copy(expressions = newChildren)) - override val numPartitions: Int = partitions.length - override def toString: String = from.toString - override def sql: String = from.sql + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning = + copy(from = from.copy(expressions = newChildren)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 68022757ff241..368534d05b1f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -73,7 +73,7 @@ case class BroadcastHashJoinExec( joinType match { case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => streamedPlan.outputPartitioning match { - case h: HashPartitioning => expandOutputPartitioning(h) + case h: HashPartitioningLike => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) case other => other } @@ -99,7 +99,7 @@ case class BroadcastHashJoinExec( private def expandOutputPartitioning( partitioning: PartitioningCollection): PartitioningCollection = { PartitioningCollection(partitioning.partitionings.flatMap { - case h: HashPartitioning => expandOutputPartitioning(h).partitionings + case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings case c: PartitioningCollection => Seq(expandOutputPartitioning(c)) case other => Seq(other) }) @@ -111,11 +111,12 @@ case class BroadcastHashJoinExec( // the expanded partitioning will have the following expressions: // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). // The expanded expressions are returned as PartitioningCollection. 
- private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { + private def expandOutputPartitioning( + partitioning: HashPartitioningLike): PartitioningCollection = { PartitioningCollection(partitioning.multiTransformDown { case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) => e +: streamedKeyToBuildKeyMapping(e.canonicalized) - }.asInstanceOf[LazyList[HashPartitioning]] + }.asInstanceOf[LazyList[HashPartitioningLike]] .take(conf.broadcastHashJoinOutputPartitioningExpandLimit)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index c41b85f75e584..909a05ce26f78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION} import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.BatchEvalPythonExec import org.apache.spark.sql.internal.SQLConf @@ -1729,4 +1729,30 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan checkAnswer(joined, expected) } + + test("SPARK-45882: BroadcastHashJoinExec propagate partitioning should respect " + + "CoalescedHashPartitioning") { + val cached = spark.sql( + """ + |select /*+ broadcast(testData) */ key, value, a + |from testData join ( + | select a from testData2 group by a + |)tmp on key = a + |""".stripMargin).cache() + try { + val df = cached.groupBy("key").count() + val expected = Seq(Row(1, 1), Row(2, 1), Row(3, 1)) + assert(find(df.queryExecution.executedPlan) { + case _: ShuffleExchangeLike => true + case _ => false + }.size == 1, df.queryExecution) + checkAnswer(df, expected) + assert(find(df.queryExecution.executedPlan) { + case _: ShuffleExchangeLike => true + case _ => false + }.isEmpty, df.queryExecution) + } finally { + cached.unpersist() + } + } } From ac293423543f6a8801dc88fd3b09c56a3b7c95ea Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 14 Nov 2023 06:48:50 +0900 Subject: [PATCH 118/121] [SPARK-45845][SS][UI] Add number of evicted state rows to streaming UI ### What changes were proposed in this pull request? In a stateful query, a watermark has two responsibilities: 1. drop late rows 2. Evict state rows from state store Before we only log "aggregated number of rows dropped by watermark". This is case 1. But people would confuse this with case 2. This PR purpose we also add a chart for case 2. Also made the explanation for case 1 more verbose. Now: image ### Why are the changes needed? UI improvement ### Does this PR introduce _any_ user-facing change? Yes, users will be able to view the new UI ### How was this patch tested? Manually tested ### Was this patch authored or co-authored using generative AI tooling? No Closes #43723 from WweiL/streaming-ui-update. 
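(Editor note: for reference, the value plotted by the new timeline is the per-batch sum of `numRowsRemoved` that the page computes from the progress data, and the same number is available programmatically. A minimal sketch, assuming an active stateful streaming query.)

```scala
import org.apache.spark.sql.streaming.StreamingQuery

// Editor sketch: the chart plots the per-batch sum of
// StateOperatorProgress.numRowsRemoved across all stateful operators.
// `lastProgress` is null until the first batch completes, so guard for it.
def evictedRowsInLastBatch(query: StreamingQuery): Long =
  Option(query.lastProgress)
    .map(_.stateOperators.map(_.numRowsRemoved).sum)
    .getOrElse(0L)
```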
Authored-by: Wei Liu Signed-off-by: Jungtaek Lim --- .../ui/StreamingQueryStatisticsPage.scala | 27 ++++++++++++++++++- .../sql/streaming/ui/UISeleniumSuite.scala | 3 ++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala index 8a3d387889369..d499f8f4a96ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ui/StreamingQueryStatisticsPage.scala @@ -210,6 +210,10 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) p.stateOperators.map(_.numRowsUpdated).sum.toDouble)) val maxNumRowsUpdated = numRowsUpdatedData.maxBy(_._2)._2 + val numRowsRemovedData = query.recentProgress.map(p => (parseProgressTimestamp(p.timestamp), + p.stateOperators.map(_.numRowsRemoved).sum.toDouble)) + val maxNumRowsRemoved = numRowsRemovedData.maxBy(_._2)._2 + val memoryUsedBytesData = query.recentProgress.map(p => (parseProgressTimestamp(p.timestamp), p.stateOperators.map(_.memoryUsedBytes).sum.toDouble)) val maxMemoryUsedBytes = memoryUsedBytesData.maxBy(_._2)._2 @@ -243,6 +247,18 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) "records") graphUIDataForNumberUpdatedRows.generateDataJs(jsCollector) + val graphUIDataForNumberRemovedRows = + new GraphUIData( + "aggregated-num-removed-state-rows-timeline", + "aggregated-num-removed-state-rows-histogram", + numRowsRemovedData, + minBatchTime, + maxBatchTime, + 0, + maxNumRowsRemoved, + "records") + graphUIDataForNumberRemovedRows.generateDataJs(jsCollector) + val graphUIDataForMemoryUsedBytes = new GraphUIData( "aggregated-state-memory-used-bytes-timeline", @@ -287,6 +303,15 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab) {graphUIDataForNumberUpdatedRows.generateTimelineHtml(jsCollector)} {graphUIDataForNumberUpdatedRows.generateHistogramHtml(jsCollector)} + + +
    +
Aggregated Number Of Removed State Rows{SparkUIUtils.tooltip("Aggregated number of state rows removed from the state. Normally it means the number of rows evicted from the state because watermark has passed, except in flatMapGroupsWithState, where users can manually remove the state.", "right")}
    +
    + + {graphUIDataForNumberRemovedRows.generateTimelineHtml(jsCollector)} + {graphUIDataForNumberRemovedRows.generateHistogramHtml(jsCollector)} +
    @@ -299,7 +324,7 @@ private[ui] class StreamingQueryStatisticsPage(parent: StreamingQueryTab)
    -
    Aggregated Number Of Rows Dropped By Watermark {SparkUIUtils.tooltip("Accumulates all input rows being dropped in stateful operators by watermark. 'Inputs' are relative to operators.", "right")}
    +
    Aggregated Number Of Late Rows Dropped By Watermark {SparkUIUtils.tooltip("Accumulates all late input rows being dropped in stateful operators by watermark. This only represents the late rows ever reached to stateful operators, not rows from the source. A row could be filtered out at an earlier stage.", "right")}
    {graphUIDataForNumRowsDroppedByWatermark.generateTimelineHtml(jsCollector)} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala index f27fc883ac9af..6ad940dadc674 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ui/UISeleniumSuite.scala @@ -156,8 +156,9 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers { summaryText should contain ("Global Watermark Gap (?)") summaryText should contain ("Aggregated Number Of Total State Rows (?)") summaryText should contain ("Aggregated Number Of Updated State Rows (?)") + summaryText should contain ("Aggregated Number Of Removed State Rows (?)") summaryText should contain ("Aggregated State Memory Used In Bytes (?)") - summaryText should contain ("Aggregated Number Of Rows Dropped By Watermark (?)") + summaryText should contain ("Aggregated Number Of Late Rows Dropped By Watermark (?)") summaryText should contain ("Aggregated Custom Metric stateOnCurrentVersionSizeBytes" + " (?)") summaryText should not contain ("Aggregated Custom Metric loadedMapCacheHitCount (?)") From cd19d6c299a7bc6d8a785208654dff132ca5fe1b Mon Sep 17 00:00:00 2001 From: Phil Dakin Date: Tue, 14 Nov 2023 11:25:57 +0900 Subject: [PATCH 119/121] [SPARK-44733][PYTHON][DOCS] Add Python to Spark type conversion page to PySpark docs. allisonwang-db ### What changes were proposed in this pull request? Add documentation page showing Python to Spark type mappings for PySpark. ### Why are the changes needed? Surface this information to users navigating the PySpark docs per https://issues.apache.org/jira/browse/SPARK-44733. ### Does this PR introduce _any_ user-facing change? Yes, adds new page to PySpark docs. ### How was this patch tested? Build HTML docs file using Sphinx, inspect visually. ### Was this patch authored or co-authored using generative AI tooling? No. ![full](https://github.com/apache/spark/assets/15946757/fde09420-5dc1-461c-9dc8-5e3c830740bd) Closes #43369 from PhilDakin/20231013.SPARK-44733. Authored-by: Phil Dakin Signed-off-by: Hyukjin Kwon --- docs/sql-ref-datatypes.md | 8 +- python/docs/source/user_guide/sql/index.rst | 1 + .../user_guide/sql/type_conversions.rst | 248 ++++++++++++++++++ 3 files changed, 253 insertions(+), 4 deletions(-) create mode 100644 python/docs/source/user_guide/sql/type_conversions.rst diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index 25dc00f18a4e9..041d22baf6593 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -119,10 +119,10 @@ from pyspark.sql.types import * |Data type|Value type in Python|API to access or create a data type| |---------|--------------------|-----------------------------------| -|**ByteType**|int or long
    **Note:** Numbers will be converted to 1-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -128 to 127.|ByteType()| -|**ShortType**|int or long
    **Note:** Numbers will be converted to 2-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -32768 to 32767.|ShortType()| -|**IntegerType**|int or long|IntegerType()| -|**LongType**|long
    **Note:** Numbers will be converted to 8-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -9223372036854775808 to 9223372036854775807. Otherwise, please convert data to decimal.Decimal and use DecimalType.|LongType()| +|**ByteType**|int
    **Note:** Numbers will be converted to 1-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -128 to 127.|ByteType()| +|**ShortType**|int
    **Note:** Numbers will be converted to 2-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -32768 to 32767.|ShortType()| +|**IntegerType**|int|IntegerType()| +|**LongType**|int
    **Note:** Numbers will be converted to 8-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -9223372036854775808 to 9223372036854775807. Otherwise, please convert data to decimal.Decimal and use DecimalType.|LongType()| |**FloatType**|float
    **Note:** Numbers will be converted to 4-byte single-precision floating point numbers at runtime.|FloatType()| |**DoubleType**|float|DoubleType()| |**DecimalType**|decimal.Decimal|DecimalType()| diff --git a/python/docs/source/user_guide/sql/index.rst b/python/docs/source/user_guide/sql/index.rst index c0369de67865b..118cf139d9b38 100644 --- a/python/docs/source/user_guide/sql/index.rst +++ b/python/docs/source/user_guide/sql/index.rst @@ -25,4 +25,5 @@ Spark SQL arrow_pandas python_udtf + type_conversions diff --git a/python/docs/source/user_guide/sql/type_conversions.rst b/python/docs/source/user_guide/sql/type_conversions.rst new file mode 100644 index 0000000000000..b63e7dfa88518 --- /dev/null +++ b/python/docs/source/user_guide/sql/type_conversions.rst @@ -0,0 +1,248 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +================================ +Python to Spark Type Conversions +================================ + +.. TODO: Add additional information on conversions when Arrow is enabled. +.. TODO: Add in-depth explanation and table for type conversions (SPARK-44734). + +.. currentmodule:: pyspark.sql.types + +When working with PySpark, you will often need to consider the conversions between Python-native +objects to their Spark equivalents. For instance, when working with user-defined functions, the +function return type will be cast by Spark to an appropriate Spark SQL type. Or, when creating a +``DataFrame``, you may supply ``numpy`` or ``pandas`` objects as the inputted data. This guide will cover +the various conversions between Python and Spark SQL types. + +Browsing Type Conversions +------------------------- + +Though this document provides a comprehensive list of type conversions, you may find it easier to +interactively check the conversion behavior of Spark. To do so, you can test small examples of +user-defined functions, and use the ``spark.createDataFrame`` interface. + +All data types of Spark SQL are located in the package of ``pyspark.sql.types``. +You can access them by doing: + +.. code-block:: python + + from pyspark.sql.types import * + +Configuration +------------- +There are several configurations that affect the behavior of type conversions. These configurations +are listed below: + +.. list-table:: + :header-rows: 1 + + * - Configuration + - Description + - Default + * - spark.sql.execution.pythonUDF.arrow.enabled + - Enable PyArrow in PySpark. See more `here `_. + - False + * - spark.sql.pyspark.inferNestedDictAsStruct.enabled + - When enabled, nested dictionaries are inferred as StructType. Otherwise, they are inferred as MapType. + - False + * - spark.sql.timestampType + - If set to `TIMESTAMP_NTZ`, the default timestamp type is ``TimestampNTZType``. Otherwise, the default timestamp type is TimestampType. 
+ - "" + +All Conversions +--------------- +.. list-table:: + :header-rows: 1 + + * - Data type + - Value type in Python + - API to access or create a data type + * - **ByteType** + - int + .. note:: Numbers will be converted to 1-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -128 to 127. + - ByteType() + * - **ShortType** + - int + .. note:: Numbers will be converted to 2-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -32768 to 32767. + - ShortType() + * - **IntegerType** + - int + - IntegerType() + * - **LongType** + - int + .. note:: Numbers will be converted to 8-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -9223372036854775808 to 9223372036854775807. Otherwise, please convert data to decimal.Decimal and use DecimalType. + - LongType() + * - **FloatType** + - float + .. note:: Numbers will be converted to 4-byte single-precision floating point numbers at runtime. + - FloatType() + * - **DoubleType** + - float + - DoubleType() + * - **DecimalType** + - decimal.Decimal + - DecimalType()| + * - **StringType** + - string + - StringType() + * - **BinaryType** + - bytearray + - BinaryType() + * - **BooleanType** + - bool + - BooleanType() + * - **TimestampType** + - datetime.datetime + - TimestampType() + * - **TimestampNTZType** + - datetime.datetime + - TimestampNTZType() + * - **DateType** + - datetime.date + - DateType() + * - **DayTimeIntervalType** + - datetime.timedelta + - DayTimeIntervalType() + * - **ArrayType** + - list, tuple, or array + - ArrayType(*elementType*, [*containsNull*]) + .. note:: The default value of *containsNull* is True. + * - **MapType** + - dict + - MapType(*keyType*, *valueType*, [*valueContainsNull]*) + .. note:: The default value of *valueContainsNull* is True. + * - **StructType** + - list or tuple + - StructType(*fields*) + .. note:: *fields* is a Seq of StructFields. Also, two fields with the same name are not allowed. + * - **StructField** + - The value type in Python of the data type of this field. For example, Int for a StructField with the data type IntegerType. + - StructField(*name*, *dataType*, [*nullable*]) + .. note:: The default value of *nullable* is True. + +Conversions in Practice - UDFs +------------------------------ +A common conversion case is returning a Python value from a UDF. In this case, the return type of +the UDF must match the provided return type. + +.. note:: If the actual return type of your function does not match the provided return type, Spark will implicitly cast the value to null. + +.. code-block:: python + + from pyspark.sql.types import ( + StructType, + StructField, + IntegerType, + StringType, + FloatType, + ) + from pyspark.sql.functions import udf, col + + df = spark.createDataFrame( + [[1]], schema=StructType([StructField("int", IntegerType())]) + ) + + @udf(returnType=StringType()) + def to_string(value): + return str(value) + + @udf(returnType=FloatType()) + def to_float(value): + return float(value) + + df.withColumn("cast_int", to_float(col("int"))).withColumn( + "cast_str", to_string(col("int")) + ).printSchema() + # root + # |-- int: integer (nullable = true) + # |-- cast_int: float (nullable = true) + # |-- cast_str: string (nullable = true) + +Conversions in Practice - Creating DataFrames +--------------------------------------------- +Another common conversion case is when creating a DataFrame from values in Python. 
In this case, +you can supply a schema, or allow Spark to infer the schema from the provided data. + +.. code-block:: python + + data = [ + ["Wei", "Math", 93.0, 1], + ["Jerry", "Physics", 85.0, 4], + ["Katrina", "Geology", 90.0, 2], + ] + cols = ["Name", "Subject", "Score", "Period"] + + spark.createDataFrame(data, cols).printSchema() + # root + # |-- Name: string (nullable = true) + # |-- Subject: string (nullable = true) + # |-- Score: double (nullable = true) + # |-- Period: long (nullable = true) + + import pandas as pd + + df = pd.DataFrame(data, columns=cols) + spark.createDataFrame(df).printSchema() + # root + # |-- Name: string (nullable = true) + # |-- Subject: string (nullable = true) + # |-- Score: double (nullable = true) + # |-- Period: long (nullable = true) + + import numpy as np + + spark.createDataFrame(np.zeros([3, 2], "int8")).printSchema() + # root + # |-- _1: byte (nullable = true) + # |-- _2: byte (nullable = true) + +Conversions in Practice - Nested Data Types +------------------------------------------- +Nested data types will convert to ``StructType``, ``MapType``, and ``ArrayType``, depending on the passed data. + +.. code-block:: python + + data = [ + ["Wei", [[1, 2]], {"RecordType": "Scores", "Math": { "H1": 93.0, "H2": 85.0}}], + ] + cols = ["Name", "ActiveHalfs", "Record"] + + spark.createDataFrame(data, cols).printSchema() + # root + # |-- Name: string (nullable = true) + # |-- ActiveHalfs: array (nullable = true) + # | |-- element: array (containsNull = true) + # | | |-- element: long (containsNull = true) + # |-- Record: map (nullable = true) + # | |-- key: string + # | |-- value: string (valueContainsNull = true) + + spark.conf.set('spark.sql.pyspark.inferNestedDictAsStruct.enabled', True) + + spark.createDataFrame(data, cols).printSchema() + # root + # |-- Name: string (nullable = true) + # |-- ActiveHalfs: array (nullable = true) + # | |-- element: array (containsNull = true) + # | | |-- element: long (containsNull = true) + # |-- Record: struct (nullable = true) + # | |-- RecordType: string (nullable = true) + # | |-- Math: struct (nullable = true) + # | | |-- H1: double (nullable = true) + # | | |-- H2: double (nullable = true) From aa10ac79fecd1a88bc0fcd54551b5df6cffff480 Mon Sep 17 00:00:00 2001 From: Chenhao Li Date: Tue, 14 Nov 2023 13:42:22 +0900 Subject: [PATCH 120/121] [SPARK-45827][SQL] Add Variant data type in Spark ## What changes were proposed in this pull request? This PR adds Variant data type in Spark. It doesn't actually introduce any binary encoding, but just has the `value` and `metadata` binaries. This PR includes: - The in-memory Variant representation in different types of Spark rows. All rows except `UnsafeRow` use the `VariantVal` object to store an Variant value. In the `UnsafeRow`, the two binaries are stored contiguously. - Spark parquet writer and reader support for the Variant type. This is agnostic to the detailed binary encoding but just transparently reads the two binaries. - A dummy Spark `parse_json` implementation so that I can manually test the writer and reader. It currently returns an `VariantVal` with value being the raw bytes of the input string and empty metadata. This is **not** a valid Variant value in the final binary encoding. ## How was this patch tested? Manual testing. 
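As a rough illustration only (this sketch is not part of the patch's diff), the dummy evaluation described above boils down to wrapping the raw UTF-8 bytes of the input string as the variant `value` with empty `metadata`; it assumes `VariantVal` and `UTF8String` from the `common/unsafe` module are on the classpath:

```
import java.nio.charset.StandardCharsets

import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

object DummyParseJsonSketch {
  // Placeholder semantics mirroring the description above: the variant "value"
  // is the raw JSON text and the metadata is empty. This is NOT the final
  // variant binary encoding.
  def dummyParseJson(json: UTF8String): VariantVal =
    new VariantVal(json.toString.getBytes(StandardCharsets.UTF_8), Array.emptyByteArray)

  def main(args: Array[String]): Unit = {
    val v = dummyParseJson(UTF8String.fromString("""{"a":1,"b":0.8}"""))
    // VariantVal.toString currently just reinterprets the value bytes as text,
    // which is why the SQL examples below print the original JSON string back.
    println(v)
  }
}
```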
Some supported usages: ``` > sql("create table T using parquet as select parse_json('1') as o") > sql("select * from T").show +---+ | o| +---+ | 1| +---+ > sql("insert into T select parse_json('[2]') as o") > sql("select * from T").show +---+ | o| +---+ |[2]| | 1| +---+ ``` Closes #43707 from chenhao-db/variant-type. Authored-by: Chenhao Li Signed-off-by: Hyukjin Kwon --- .../apache/spark/unsafe/types/VariantVal.java | 110 ++++++++++++++++++ .../main/resources/error/error-classes.json | 6 + .../org/apache/spark/sql/avro/AvroUtils.scala | 2 + docs/sql-error-conditions.md | 6 + docs/sql-ref-ansi-compliance.md | 1 + .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 + .../sql/catalyst/parser/SqlBaseParser.g4 | 3 + .../catalyst/encoders/AgnosticEncoder.scala | 3 +- .../sql/catalyst/encoders/RowEncoder.scala | 3 +- .../catalyst/parser/DataTypeAstBuilder.scala | 3 +- .../org/apache/spark/sql/types/DataType.scala | 3 +- .../apache/spark/sql/types/VariantType.scala | 43 +++++++ .../catalyst/expressions/ExpressionInfo.java | 2 +- .../expressions/SpecializedGetters.java | 3 + .../expressions/SpecializedGettersReader.java | 3 + .../catalyst/expressions/UnsafeArrayData.java | 7 ++ .../sql/catalyst/expressions/UnsafeRow.java | 7 ++ .../expressions/codegen/UnsafeWriter.java | 22 ++++ .../spark/sql/vectorized/ColumnVector.java | 11 ++ .../spark/sql/vectorized/ColumnarArray.java | 6 + .../sql/vectorized/ColumnarBatchRow.java | 8 ++ .../spark/sql/vectorized/ColumnarRow.java | 8 ++ .../sql/catalyst/ProjectingInternalRow.scala | 6 +- .../catalyst/analysis/FunctionRegistry.scala | 4 + .../sql/catalyst/encoders/EncoderUtils.scala | 3 +- .../InterpretedUnsafeProjection.scala | 2 + .../sql/catalyst/expressions/JoinedRow.scala | 5 +- .../expressions/codegen/CodeGenerator.scala | 4 + .../sql/catalyst/expressions/literals.scala | 3 + .../spark/sql/catalyst/expressions/rows.scala | 3 +- .../variant/variantExpressions.scala | 54 +++++++++ .../sql/catalyst/types/PhysicalDataType.scala | 17 ++- .../sql/catalyst/util/GenericArrayData.scala | 3 +- .../sql/errors/QueryCompilationErrors.scala | 6 + .../ansi-sql-2016-reserved-keywords.txt | 1 + .../GenerateUnsafeProjectionSuite.scala | 4 +- .../codegen/UnsafeRowWriterSuite.scala | 13 ++- .../vectorized/ColumnVectorUtils.java | 10 +- .../vectorized/ConstantColumnVector.java | 13 +++ .../vectorized/MutableColumnarRow.java | 6 + .../vectorized/WritableColumnVector.java | 6 +- .../apache/spark/sql/execution/Columnar.scala | 10 ++ .../spark/sql/execution/HiveResult.scala | 3 +- .../execution/datasources/DataSource.scala | 12 +- .../datasources/csv/CSVFileFormat.scala | 2 + .../datasources/json/JsonFileFormat.scala | 2 + .../datasources/orc/OrcFileFormat.scala | 2 + .../parquet/ParquetSchemaConverter.scala | 11 ++ .../parquet/ParquetWriteSupport.scala | 12 ++ .../datasources/xml/XmlFileFormat.scala | 2 + .../sql-functions/sql-expression-schema.md | 1 + .../sql-tests/results/ansi/keywords.sql.out | 1 + .../sql-tests/results/keywords.sql.out | 1 + .../org/apache/spark/sql/VariantSuite.scala | 77 ++++++++++++ .../ThriftServerWithSparkContextSuite.scala | 2 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 2 + 56 files changed, 545 insertions(+), 19 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala create mode 100644 
sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java new file mode 100644 index 0000000000000..40bd1c7abc75f --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import org.apache.spark.unsafe.Platform; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * The physical data representation of {@link org.apache.spark.sql.types.VariantType} that + * represents a semi-structured value. It consists of two binary values: {@link VariantVal#value} + * and {@link VariantVal#metadata}. The value encodes types and values, but not field names. The + * metadata currently contains a version flag and a list of field names. We can extend/modify the + * detailed binary format given the version flag. + *
    + * A {@link VariantVal} can be produced by casting another value into the Variant type or parsing a + * JSON string in the {@link org.apache.spark.sql.catalyst.expressions.variant.ParseJson} + * expression. We can extract a path consisting of field names and array indices from it, cast it + * into a concrete data type, or rebuild a JSON string from it. + *
    + * The storage layout of this class in {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} + * and {@link org.apache.spark.sql.catalyst.expressions.UnsafeArrayData} is: the fixed-size part is + * a long value "offsetAndSize". The upper 32 bits is the offset that points to the start position + * of the actual binary content. The lower 32 bits is the total length of the binary content. The + * binary content contains: 4 bytes representing the length of {@link VariantVal#value}, content of + * {@link VariantVal#value}, content of {@link VariantVal#metadata}. This is an internal and + * transient format and can be modified at any time. + */ +public class VariantVal implements Serializable { + protected final byte[] value; + protected final byte[] metadata; + + public VariantVal(byte[] value, byte[] metadata) { + this.value = value; + this.metadata = metadata; + } + + public byte[] getValue() { + return value; + } + + public byte[] getMetadata() { + return metadata; + } + + /** + * This function reads the binary content described in `writeIntoUnsafeRow` from `baseObject`. The + * offset is computed by adding the offset in {@code offsetAndSize} and {@code baseOffset}. + */ + public static VariantVal readFromUnsafeRow( + long offsetAndSize, + Object baseObject, + long baseOffset) { + // offset and totalSize is the upper/lower 32 bits in offsetAndSize. + int offset = (int) (offsetAndSize >> 32); + int totalSize = (int) offsetAndSize; + int valueSize = Platform.getInt(baseObject, baseOffset + offset); + int metadataSize = totalSize - 4 - valueSize; + byte[] value = new byte[valueSize]; + byte[] metadata = new byte[metadataSize]; + Platform.copyMemory( + baseObject, + baseOffset + offset + 4, + value, + Platform.BYTE_ARRAY_OFFSET, + valueSize + ); + Platform.copyMemory( + baseObject, + baseOffset + offset + 4 + valueSize, + metadata, + Platform.BYTE_ARRAY_OFFSET, + metadataSize + ); + return new VariantVal(value, metadata); + } + + public String debugString() { + return "VariantVal{" + + "value=" + Arrays.toString(value) + + ", metadata=" + Arrays.toString(metadata) + + '}'; + } + + /** + * @return A human-readable representation of the Variant value. It is always a JSON string at + * this moment. + */ + @Override + public String toString() { + // NOTE: the encoding is not yet implemented, this is not the final implementation. + return new String(value); + } +} diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index e3b9f3161b24d..1b4c10acaf7bc 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -317,6 +317,12 @@ ], "sqlState" : "58030" }, + "CANNOT_SAVE_VARIANT" : { + "message" : [ + "Cannot save variant data type into external storage." 
+ ], + "sqlState" : "0A000" + }, "CANNOT_UPDATE_FIELD" : { "message" : [ "Cannot update field type:" diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 0e27e4a604c46..e235c13d413e2 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -72,6 +72,8 @@ private[sql] object AvroUtils extends Logging { } def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AtomicType => true case st: StructType => st.forall { f => supportsDataType(f.dataType) } diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index a811019e0a57b..ee9c2fd67b307 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -280,6 +280,12 @@ SQLSTATE: 58030 Failed to set permissions on created path `` back to ``. +### CANNOT_SAVE_VARIANT + +[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported) + +Cannot save variant data type into external storage. + ### [CANNOT_UPDATE_FIELD](sql-error-conditions-cannot-update-field-error-class.html) [SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 09c38a0099599..4729db16d63f3 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -671,6 +671,7 @@ Below is a list of all the keywords in Spark SQL. |VARCHAR|non-reserved|non-reserved|reserved| |VAR|non-reserved|non-reserved|non-reserved| |VARIABLE|non-reserved|non-reserved|non-reserved| +|VARIANT|non-reserved|non-reserved|reserved| |VERSION|non-reserved|non-reserved|non-reserved| |VIEW|non-reserved|non-reserved|non-reserved| |VIEWS|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index e8b5cb012fcae..9b3dcbc6d194f 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -408,6 +408,7 @@ VALUES: 'VALUES'; VARCHAR: 'VARCHAR'; VAR: 'VAR'; VARIABLE: 'VARIABLE'; +VARIANT: 'VARIANT'; VERSION: 'VERSION'; VIEW: 'VIEW'; VIEWS: 'VIEWS'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index bd449a4e194e8..609bd72e21935 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1086,6 +1086,7 @@ type | DECIMAL | DEC | NUMERIC | VOID | INTERVAL + | VARIANT | ARRAY | STRUCT | MAP | unsupportedType=identifier ; @@ -1545,6 +1546,7 @@ ansiNonReserved | VARCHAR | VAR | VARIABLE + | VARIANT | VERSION | VIEW | VIEWS @@ -1893,6 +1895,7 @@ nonReserved | VARCHAR | VAR | VARIABLE + | VARIANT | VERSION | VIEW | VIEWS diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index e5e9ba644b814..9133abce88adc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -25,7 +25,7 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.SparkClassUtils /** @@ -216,6 +216,7 @@ object AgnosticEncoders { case object CalendarIntervalEncoder extends LeafEncoder[CalendarInterval](CalendarIntervalType) case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType()) case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType()) + case object VariantEncoder extends LeafEncoder[VariantVal](VariantType) case class DateEncoder(override val lenientSerialization: Boolean) extends LeafEncoder[jsql.Date](DateType) case class LocalDateEncoder(override val lenientSerialization: Boolean) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 69661c343c5b1..a201da9c95c9e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.reflect.classTag import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VariantEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types._ @@ -90,6 +90,7 @@ object RowEncoder { case CalendarIntervalType => CalendarIntervalEncoder case _: DayTimeIntervalType => DayTimeIntervalEncoder case _: YearMonthIntervalType => YearMonthIntervalEncoder + case _: VariantType => VariantEncoder case p: PythonUserDefinedType => // TODO check if this works. 
encoderForDataType(p.sqlType, lenient) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index b30c6fa29e829..3a2e704ffe9f7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -82,6 +82,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { DecimalType(precision.getText.toInt, scale.getText.toInt) case (VOID, Nil) => NullType case (INTERVAL, Nil) => CalendarIntervalType + case (VARIANT, Nil) => VariantType case (CHARACTER | CHAR | VARCHAR, Nil) => throw QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx) case (ARRAY | STRUCT | MAP, Nil) => diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 5f563e3b7a8f1..94252de48d1ea 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -173,7 +173,8 @@ object DataType { YearMonthIntervalType(YEAR), YearMonthIntervalType(MONTH), YearMonthIntervalType(YEAR, MONTH), - TimestampNTZType) + TimestampNTZType, + VariantType) .map(t => t.typeName -> t).toMap } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala new file mode 100644 index 0000000000000..103fe7a59fc83 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.annotation.Unstable + +/** + * The data type representing semi-structured values with arbitrary hierarchical data structures. It + * is intended to store parsed JSON values and most other data types in the system (e.g., it cannot + * store a map with a non-string key type). + * + * @since 4.0.0 + */ +@Unstable +class VariantType private () extends AtomicType { + // The default size is used in query planning to drive optimization decisions. 2048 is arbitrarily + // picked and we currently don't have any data to support it. This may need revisiting later. + override def defaultSize: Int = 2048 + + /** This is a no-op because values with VARIANT type are always nullable. */ + private[spark] override def asNullable: VariantType = this +} + +/** + * @since 4.0.0 + */ +@Unstable +case object VariantType extends VariantType diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index e7af5a68b4663..ffc3c8eaf8f84 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -45,7 +45,7 @@ public class ExpressionInfo { "collection_funcs", "predicate_funcs", "conditional_funcs", "conversion_funcs", "csv_funcs", "datetime_funcs", "generator_funcs", "hash_funcs", "json_funcs", "lambda_funcs", "map_funcs", "math_funcs", "misc_funcs", "string_funcs", "struct_funcs", - "window_funcs", "xml_funcs", "table_funcs", "url_funcs")); + "window_funcs", "xml_funcs", "table_funcs", "url_funcs", "variant_funcs")); private static final Set validSources = new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index eea7149d02594..b88a892db4b46 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; public interface SpecializedGetters { @@ -51,6 +52,8 @@ public interface SpecializedGetters { CalendarInterval getInterval(int ordinal); + VariantVal getVariant(int ordinal); + InternalRow getStruct(int ordinal, int numFields); ArrayData getArray(int ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java index 91f04c3d327ac..9e508dbb271cf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java @@ -66,6 +66,9 @@ public static Object read( if (physicalDataType instanceof PhysicalBinaryType) { return obj.getBinary(ordinal); } + if (physicalDataType instanceof 
PhysicalVariantType) { + return obj.getVariant(ordinal); + } if (physicalDataType instanceof PhysicalStructType) { return obj.getStruct(ordinal, ((PhysicalStructType) physicalDataType).fields().length); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index ea6f1e05422b5..700e42cb843c8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -38,6 +38,7 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; @@ -231,6 +232,12 @@ public CalendarInterval getInterval(int ordinal) { return new CalendarInterval(months, days, microseconds); } + @Override + public VariantVal getVariant(int ordinal) { + if (isNullAt(ordinal)) return null; + return VariantVal.readFromUnsafeRow(getLong(ordinal), baseObject, baseOffset); + } + @Override public UnsafeRow getStruct(int ordinal, int numFields) { if (isNullAt(ordinal)) return null; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 8f9d5919e1d9f..fca45c58beed0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -36,6 +36,7 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; @@ -417,6 +418,12 @@ public CalendarInterval getInterval(int ordinal) { } } + @Override + public VariantVal getVariant(int ordinal) { + if (isNullAt(ordinal)) return null; + return VariantVal.readFromUnsafeRow(getLong(ordinal), baseObject, baseOffset); + } + @Override public UnsafeRow getStruct(int ordinal, int numFields) { if (isNullAt(ordinal)) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 8d4e187d01a12..d651e5ab5b3e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -25,6 +25,7 @@ import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * Base class for writing Unsafe* structures. @@ -149,6 +150,27 @@ public void write(int ordinal, CalendarInterval input) { increaseCursor(16); } + public void write(int ordinal, VariantVal input) { + // See the class comment of VariantVal for the format of the binary content. 
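+    // The bytes written below are: [ 4-byte length of `value` | `value` | `metadata` ],
+    // rounded up to a whole number of 8-byte words. setOffsetAndSize() then stores
+    // (offset << 32) | totalSize in the fixed-size slot, which is exactly the long that
+    // VariantVal.readFromUnsafeRow() decodes.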
+ byte[] value = input.getValue(); + byte[] metadata = input.getMetadata(); + int totalSize = 4 + value.length + metadata.length; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSize); + grow(roundedSize); + zeroOutPaddingBytes(totalSize); + Platform.putInt(getBuffer(), cursor(), value.length); + Platform.copyMemory(value, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor() + 4, value.length); + Platform.copyMemory( + metadata, + Platform.BYTE_ARRAY_OFFSET, + getBuffer(), + cursor() + 4 + value.length, + metadata.length + ); + setOffsetAndSize(ordinal, totalSize); + increaseCursor(roundedSize); + } + public final void write(int ordinal, UnsafeRow row) { writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes()); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 73c2cf2cc05f8..cd3c30fa69335 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -22,6 +22,7 @@ import org.apache.spark.sql.types.UserDefinedType; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * An interface representing in-memory columnar data in Spark. This interface defines the main APIs @@ -299,6 +300,16 @@ public CalendarInterval getInterval(int rowId) { return new CalendarInterval(months, days, microseconds); } + /** + * Returns the Variant value for {@code rowId}. Similar to {@link #getInterval(int)}, the + * implementation must implement {@link #getChild(int)} and define 2 child vectors of binary type + * for the Variant value and metadata. + */ + public final VariantVal getVariant(int rowId) { + if (isNullAt(rowId)) return null; + return new VariantVal(getChild(0).getBinary(rowId), getChild(1).getBinary(rowId)); + } + /** * @return child {@link ColumnVector} at the given ordinal. */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index bd7c3d7c0fd49..e0141a575b299 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * Array abstraction in {@link ColumnVector}. 
@@ -160,6 +161,11 @@ public CalendarInterval getInterval(int ordinal) { return data.getInterval(offset + ordinal); } + @Override + public VariantVal getVariant(int ordinal) { + return data.getVariant(offset + ordinal); + } + @Override public ColumnarRow getStruct(int ordinal, int numFields) { return data.getStruct(offset + ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java index c0d2ae8e7d0e8..ac23f70584e89 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * This class wraps an array of {@link ColumnVector} and provides a row view. @@ -133,6 +134,11 @@ public CalendarInterval getInterval(int ordinal) { return columns[ordinal].getInterval(rowId); } + @Override + public VariantVal getVariant(int ordinal) { + return columns[ordinal].getVariant(rowId); + } + @Override public ColumnarRow getStruct(int ordinal, int numFields) { return columns[ordinal].getStruct(rowId); @@ -182,6 +188,8 @@ public Object get(int ordinal, DataType dataType) { return getStruct(ordinal, ((StructType)dataType).fields().length); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof VariantType) { + return getVariant(ordinal); } else { throw new UnsupportedOperationException("Datatype not supported " + dataType); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 1df4653f55276..18f6779cccb96 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * Row abstraction in {@link ColumnVector}. 
@@ -140,6 +141,11 @@ public CalendarInterval getInterval(int ordinal) { return data.getChild(ordinal).getInterval(rowId); } + @Override + public VariantVal getVariant(int ordinal) { + return data.getChild(ordinal).getVariant(rowId); + } + @Override public ColumnarRow getStruct(int ordinal, int numFields) { return data.getChild(ordinal).getStruct(rowId); @@ -187,6 +193,8 @@ public Object get(int ordinal, DataType dataType) { return getStruct(ordinal, ((StructType)dataType).fields().length); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof VariantType) { + return getVariant(ordinal); } else { throw new UnsupportedOperationException("Datatype not supported " + dataType); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala index 429ce805bf2c4..034b959c5a383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{DataType, Decimal, StructType} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} /** * An [[InternalRow]] that projects particular columns from another [[InternalRow]] without copying @@ -99,6 +99,10 @@ case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) exte row.getInterval(colOrdinals(ordinal)) } + override def getVariant(ordinal: Int): VariantVal = { + row.getVariant(colOrdinals(ordinal)) + } + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { row.getStruct(colOrdinals(ordinal), numFields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 23d63011db53f..4fb8d88f6eab1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.variant._ import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range} import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -808,6 +809,9 @@ object FunctionRegistry { expression[LengthOfJsonArray]("json_array_length"), expression[JsonObjectKeys]("json_object_keys"), + // Variant + expression[ParseJson]("parse_json"), + // cast expression[Cast]("cast"), // Cast aliases (SPARK-16730) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala index 4540ecffe0d21..793dd373d6899 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.collection.Map import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} @@ -68,6 +68,7 @@ object EncoderUtils { case CalendarIntervalEncoder => true case BinaryEncoder => true case _: SparkDecimalEncoder => true + case VariantEncoder => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 6aa5fefc73902..50408b41c1a76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -162,6 +162,8 @@ object InterpretedUnsafeProjection { case PhysicalStringType => (v, i) => writer.write(i, v.getUTF8String(i)) + case PhysicalVariantType => (v, i) => writer.write(i, v.getVariant(i)) + case PhysicalStructType(fields) => val numFields = fields.length val rowWriter = new UnsafeRowWriter(writer, numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index 86871223d66ad..345f2b3030b58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} /** * A mutable wrapper that makes two rows appear as a single concatenated row. 
Designed to @@ -120,6 +120,9 @@ class JoinedRow extends InternalRow { override def getInterval(i: Int): CalendarInterval = if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) + override def getVariant(i: Int): VariantVal = + if (i < row1.numFields) row1.getVariant(i) else row2.getVariant(i - row1.numFields) + override def getMap(i: Int): MapData = if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3595e43fcb987..4c32f682c275f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1475,6 +1475,7 @@ object CodeGenerator extends Logging { classOf[UTF8String].getName, classOf[Decimal].getName, classOf[CalendarInterval].getName, + classOf[VariantVal].getName, classOf[ArrayData].getName, classOf[UnsafeArrayData].getName, classOf[MapData].getName, @@ -1641,6 +1642,7 @@ object CodeGenerator extends Logging { case PhysicalNullType => "null" case PhysicalStringType => s"$input.getUTF8String($ordinal)" case t: PhysicalStructType => s"$input.getStruct($ordinal, ${t.fields.size})" + case PhysicalVariantType => s"$input.getVariant($ordinal)" case _ => s"($jt)$input.get($ordinal, null)" } } @@ -1928,6 +1930,7 @@ object CodeGenerator extends Logging { case PhysicalShortType => JAVA_SHORT case PhysicalStringType => "UTF8String" case _: PhysicalStructType => "InternalRow" + case _: PhysicalVariantType => "VariantVal" case _ => "Object" } } @@ -1951,6 +1954,7 @@ object CodeGenerator extends Logging { case _: MapType => classOf[MapData] case udt: UserDefinedType[_] => javaClass(udt.sqlType) case ObjectType(cls) => cls + case VariantType => classOf[VariantVal] case _ => classOf[Object] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 217ed562db779..c406ba0707b3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -97,6 +97,7 @@ object Literal { val convert = CatalystTypeConverters.createToCatalystConverter(dataType) Literal(convert(a), dataType) case i: CalendarInterval => Literal(i, CalendarIntervalType) + case v: VariantVal => Literal(v, VariantType) case null => Literal(null, NullType) case v: Literal => v case _ => @@ -143,6 +144,7 @@ object Literal { case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + case _ if clz == classOf[VariantVal] => VariantType case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType)) @@ -235,6 +237,7 @@ object Literal { case PhysicalNullType => true case PhysicalShortType => v.isInstanceOf[Short] case PhysicalStringType => v.isInstanceOf[UTF8String] + case PhysicalVariantType => v.isInstanceOf[VariantVal] case st: PhysicalStructType => v.isInstanceOf[InternalRow] && { val row = v.asInstanceOf[InternalRow] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 296d093a13de6..8379069c53d9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} import org.apache.spark.util.ArrayImplicits._ /** @@ -47,6 +47,7 @@ trait BaseGenericInternalRow extends InternalRow { override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) override def getArray(ordinal: Int): ArrayData = getAs(ordinal) override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getVariant(ordinal: Int): VariantVal = getAs(ordinal) override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala new file mode 100644 index 0000000000000..136ae4a3ef436 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.variant + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types._ + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr) - Parse a JSON string as an Variant value. Throw an exception when the string is not valid JSON value.", + examples = """ + Examples: + > SELECT _FUNC_('{"a":1,"b":0.8}'); + {"a":1,"b":0.8} + """, + since = "4.0.0", + group = "variant_funcs" +) +// scalastyle:on line.size.limit +case class ParseJson(child: Expression) extends UnaryExpression + with NullIntolerant with ExpectsInputTypes with CodegenFallback { + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def dataType: DataType = VariantType + + override def prettyName: String = "parse_json" + + protected override def nullSafeEval(input: Any): Any = { + // A dummy implementation: the value is the raw bytes of the input string. 
This is not the final + // implementation, but only intended for debugging. + // TODO(SPARK-45891): Have an actual parse_json implementation. + new VariantVal(input.asInstanceOf[UTF8String].toString.getBytes, Array()) + } + + override protected def withNewChildInternal(newChild: Expression): ParseJson = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index 29d7a39ace3c1..290a35eb8e3b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -23,8 +23,8 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, InterpretedOrdering, SortOrder} import org.apache.spark.sql.catalyst.util.{ArrayData, SQLOrderingUtil} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType} -import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.unsafe.types.{ByteArray, UTF8String, VariantVal} import org.apache.spark.util.ArrayImplicits._ sealed abstract class PhysicalDataType { @@ -58,6 +58,7 @@ object PhysicalDataType { case StructType(fields) => PhysicalStructType(fields) case MapType(keyType, valueType, valueContainsNull) => PhysicalMapType(keyType, valueType, valueContainsNull) + case VariantType => PhysicalVariantType case _ => UninitializedPhysicalType } @@ -327,6 +328,18 @@ case class PhysicalStructType(fields: Array[StructField]) extends PhysicalDataTy } } +class PhysicalVariantType extends PhysicalDataType { + private[sql] type InternalType = VariantVal + @transient private[sql] lazy val tag = typeTag[InternalType] + + // TODO(SPARK-45891): Support comparison for the Variant type. 
+ override private[sql] def ordering = + throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( + "PhysicalVariantType") +} + +object PhysicalVariantType extends PhysicalVariantType + object UninitializedPhysicalType extends PhysicalDataType { override private[sql] def ordering = throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index bdf8d36321e64..7ff36bef5a4b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} class GenericArrayData(val array: Array[Any]) extends ArrayData { @@ -73,6 +73,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getVariant(ordinal: Int): VariantVal = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) override def getArray(ordinal: Int): ArrayData = getAs(ordinal) override def getMap(ordinal: Int): MapData = getAs(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index c3249a4c02d8c..9cc99e9bfa335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1564,6 +1564,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map.empty) } + def cannotSaveVariantIntoExternalStorageError(): Throwable = { + new AnalysisException( + errorClass = "CANNOT_SAVE_VARIANT", + messageParameters = Map.empty) + } + def cannotResolveAttributeError(name: String, outputStr: String): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1137", diff --git a/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt b/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt index 921491a4a4761..47a3f02ac1656 100644 --- a/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt +++ b/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt @@ -355,6 +355,7 @@ VAR_POP VAR_SAMP VARBINARY VARCHAR +VARIANT VARYING VERSIONING WHEN diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala index 01aa3579aea98..eeb05139a3e5b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -22,7 +22,7 
@@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} class GenerateUnsafeProjectionSuite extends SparkFunSuite { test("Test unsafe projection string access pattern") { @@ -88,6 +88,7 @@ object AlwaysNull extends InternalRow { override def getUTF8String(ordinal: Int): UTF8String = notSupported override def getBinary(ordinal: Int): Array[Byte] = notSupported override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getVariant(ordinal: Int): VariantVal = notSupported override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported override def getArray(ordinal: Int): ArrayData = notSupported override def getMap(ordinal: Int): MapData = notSupported @@ -117,6 +118,7 @@ object AlwaysNonNull extends InternalRow { override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test") override def getBinary(ordinal: Int): Array[Byte] = notSupported override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getVariant(ordinal: Int): VariantVal = notSupported override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3")) val keyArray = stringToUTF8Array(Array("1", "2", "3")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala index eaed279679251..e2a416b773aa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.Decimal -import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} class UnsafeRowWriterSuite extends SparkFunSuite { @@ -61,4 +61,15 @@ class UnsafeRowWriterSuite extends SparkFunSuite { rowWriter.write(1, interval) assert(rowWriter.getRow.getInterval(1) === interval) } + + test("write and get variant through UnsafeRowWriter") { + val rowWriter = new UnsafeRowWriter(2) + rowWriter.resetRowWriter() + rowWriter.setNullAt(0) + assert(rowWriter.getRow.isNullAt(0)) + assert(rowWriter.getRow.getVariant(0) === null) + val variant = new VariantVal(Array[Byte](1, 2, 3), Array[Byte](-1, -2, -3, -4)) + rowWriter.write(1, variant) + assert(rowWriter.getRow.getVariant(1).debugString() == variant.debugString()) + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 7b841ab9933e2..29c106651acf0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -38,6 +38,7 @@ import org.apache.spark.sql.vectorized.ColumnarMap; import 
org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly @@ -89,6 +90,8 @@ public static void populate(ConstantColumnVector col, InternalRow row, int field } else if (pdt instanceof PhysicalCalendarIntervalType) { // The value of `numRows` is irrelevant. col.setCalendarInterval((CalendarInterval) row.get(fieldIdx, t)); + } else if (pdt instanceof PhysicalVariantType) { + col.setVariant((VariantVal)row.get(fieldIdx, t)); } else { throw new RuntimeException(String.format("DataType %s is not supported" + " in column vectorized reader.", t.sql())); @@ -124,7 +127,7 @@ public static Map toJavaIntMap(ColumnarMap map) { private static void appendValue(WritableColumnVector dst, DataType t, Object o) { if (o == null) { - if (t instanceof CalendarIntervalType) { + if (t instanceof CalendarIntervalType || t instanceof VariantType) { dst.appendStruct(true); } else { dst.appendNull(); @@ -167,6 +170,11 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) dst.getChild(0).appendInt(c.months); dst.getChild(1).appendInt(c.days); dst.getChild(2).appendLong(c.microseconds); + } else if (t instanceof VariantType) { + VariantVal v = (VariantVal) o; + dst.appendStruct(false); + dst.getChild(0).appendByteArray(v.getValue(), 0, v.getValue().length); + dst.getChild(1).appendByteArray(v.getMetadata(), 0, v.getMetadata().length); } else if (t instanceof DateType) { dst.appendInt(DateTimeUtils.fromJavaDate((Date) o)); } else if (t instanceof TimestampType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java index 5095e6b0c9c6b..43854c2300fde 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * This class adds the constant support to ColumnVector. 
@@ -67,6 +68,10 @@ public ConstantColumnVector(int numRows, DataType type) { this.childData[0] = new ConstantColumnVector(1, DataTypes.IntegerType); this.childData[1] = new ConstantColumnVector(1, DataTypes.IntegerType); this.childData[2] = new ConstantColumnVector(1, DataTypes.LongType); + } else if (type instanceof VariantType) { + this.childData = new ConstantColumnVector[2]; + this.childData[0] = new ConstantColumnVector(1, DataTypes.BinaryType); + this.childData[1] = new ConstantColumnVector(1, DataTypes.BinaryType); } else { this.childData = null; } @@ -307,4 +312,12 @@ public void setCalendarInterval(CalendarInterval value) { this.childData[1].setInt(value.days); this.childData[2].setLong(value.microseconds); } + + /** + * Sets the Variant `value` for all rows + */ + public void setVariant(VariantVal value) { + this.childData[0].setBinary(value.getValue()); + this.childData[1].setBinary(value.getMetadata()); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index eda58815f3b3a..0a110a204e04b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -28,6 +28,7 @@ import org.apache.spark.sql.vectorized.ColumnarRow; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash @@ -142,6 +143,11 @@ public CalendarInterval getInterval(int ordinal) { return columns[ordinal].getInterval(rowId); } + @Override + public VariantVal getVariant(int ordinal) { + return columns[ordinal].getVariant(rowId); + } + @Override public ColumnarRow getStruct(int ordinal, int numFields) { return columns[ordinal].getStruct(rowId); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 4c8ceff356595..10907c69c2260 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -694,7 +694,7 @@ public final int appendStruct(boolean isNull) { putNull(elementsAppended); elementsAppended++; for (WritableColumnVector c: childColumns) { - if (c.type instanceof StructType) { + if (c.type instanceof StructType || c.type instanceof VariantType) { c.appendStruct(true); } else { c.appendNull(); @@ -975,6 +975,10 @@ protected WritableColumnVector(int capacity, DataType dataType) { this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[1] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[2] = reserveNewColumn(capacity, DataTypes.LongType); + } else if (type instanceof VariantType) { + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, DataTypes.BinaryType); + this.childColumns[1] = reserveNewColumn(capacity, DataTypes.BinaryType); } else { this.childColumns = null; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala index 3fec13a7f9ba9..7c117e0cace97 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala @@ -267,6 +267,7 @@ private object RowToColumnConverter { case DoubleType => DoubleConverter case StringType => StringConverter case CalendarIntervalType => CalendarConverter + case VariantType => VariantConverter case at: ArrayType => ArrayConverter(getConverterForType(at.elementType, at.containsNull)) case st: StructType => new StructConverter(st.fields.map( (f) => getConverterForType(f.dataType, f.nullable))) @@ -346,6 +347,15 @@ private object RowToColumnConverter { } } + private object VariantConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + val v = row.getVariant(column) + cv.appendStruct(false) + cv.getChild(0).appendByteArray(v.getValue, 0, v.getValue.length) + cv.getChild(1).appendByteArray(v.getMetadata, 0, v.getMetadata.length) + } + } + private case class ArrayConverter(childConverter: TypeConverter) extends TypeConverter { override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { val values = row.getArray(column) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 9811a1d3f33e4..f6b5ba15afbd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedComm import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTablesExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.ArrayImplicits._ /** @@ -131,6 +131,7 @@ object HiveResult { HIVE_STYLE, startField, endField) + case (v: VariantVal, VariantType) => v.toString case (other, _: UserDefinedType[_]) => other.toString } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index cd295f3b17bd6..835308f3d0248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, Tex import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructField, StructType, VariantType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.{HadoopFSUtils, ThreadUtils, Utils} import org.apache.spark.util.ArrayImplicits._ @@ -503,6 +503,7 @@ case class DataSource( providingInstance() match { case dataSource: CreatableRelationProvider => disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = true) + disallowWritingVariant(outputColumns.map(_.dataType)) dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => @@ -524,6 
+525,7 @@ case class DataSource( providingInstance() match { case dataSource: CreatableRelationProvider => disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = true) + disallowWritingVariant(data.schema.map(_.dataType)) SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = false) @@ -560,6 +562,14 @@ case class DataSource( throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError() }) } + + private def disallowWritingVariant(dataTypes: Seq[DataType]): Unit = { + dataTypes.foreach { dt => + if (dt.existsRecursively(_.isInstanceOf[VariantType])) { + throw QueryCompilationErrors.cannotSaveVariantIntoExternalStorageError() + } + } + } } object DataSource extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 069ad9562a7d5..32370562003f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -145,6 +145,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: BinaryType => false case _: AtomicType => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 9c6c77a8b9622..7fb6e98fb0468 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -134,6 +134,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AtomicType => true case st: StructType => st.forall { f => supportDataType(f.dataType) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b7e6f11f67d69..623f97499cd55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -237,6 +237,8 @@ class OrcFileFormat } override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AtomicType => true case st: StructType => st.forall { f => supportDataType(f.dataType) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index eedd165278aed..f60f7c11eefa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -185,6 +185,11 @@ class ParquetToSparkSchemaConverter( } field match { case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType) + case groupColumn: GroupColumnIO if targetType.contains(VariantType) => + ParquetColumn(VariantType, groupColumn, Seq( + convertField(groupColumn.getChild(0), Some(BinaryType)), + convertField(groupColumn.getChild(1), Some(BinaryType)) + )) case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType) } } @@ -719,6 +724,12 @@ class SparkToParquetSchemaConverter( // Other types // =========== + case VariantType => + Types.buildGroup(repetition) + .addField(convertField(StructField("value", BinaryType, nullable = false))) + .addField(convertField(StructField("metadata", BinaryType, nullable = false))) + .named(field.name) + case StructType(fields) => fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) => builder.addField(convertField(field)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 9535bbd585bce..e410789504e70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -238,6 +238,18 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { case DecimalType.Fixed(precision, scale) => makeDecimalWriter(precision, scale) + case VariantType => + (row: SpecializedGetters, ordinal: Int) => + val v = row.getVariant(ordinal) + consumeGroup { + consumeField("value", 0) { + recordConsumer.addBinary(Binary.fromReusedByteArray(v.getValue)) + } + consumeField("metadata", 1) { + recordConsumer.addBinary(Binary.fromReusedByteArray(v.getMetadata)) + } + } + case t: StructType => val fieldWriters = t.map(_.dataType).map(makeWriter).toArray[ValueWriter] (row: SpecializedGetters, ordinal: Int) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala index 776192992789a..300c0f5004252 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala @@ -140,6 +140,8 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister { override def equals(other: Any): Boolean = other.isInstanceOf[XmlFileFormat] override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AtomicType => true case st: StructType => st.forall { f => supportDataType(f.dataType) } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 017cc474ea028..8e6bad11c09a9 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -429,6 +429,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp 
| var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct | +| org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct | | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('1','a/b') | struct1, a/b):boolean> | | org.apache.spark.sql.catalyst.expressions.xml.XPathDouble | xpath_double | SELECT xpath_double('12', 'sum(a/b)') | struct12, sum(a/b)):double> | | org.apache.spark.sql.catalyst.expressions.xml.XPathDouble | xpath_number | SELECT xpath_number('12', 'sum(a/b)') | struct12, sum(a/b)):double> | diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index f88dcbd465852..10fcee1469398 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -320,6 +320,7 @@ VALUES false VAR false VARCHAR false VARIABLE false +VARIANT false VERSION false VIEW false VIEWS false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index b618299ea61a8..be2303a716da5 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -320,6 +320,7 @@ VALUES false VAR false VARCHAR false VARIABLE false +VARIANT false VERSION false VIEW false VIEWS false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala new file mode 100644 index 0000000000000..dde986c555b10 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import scala.util.Random + +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.VariantVal + +class VariantSuite extends QueryTest with SharedSparkSession { + test("basic tests") { + def verifyResult(df: DataFrame): Unit = { + val result = df.collect() + .map(_.get(0).asInstanceOf[VariantVal].toString) + .sorted + .toSeq + val expected = (1 until 10).map(id => "1" * id) + assert(result == expected) + } + + // At this point, JSON parsing logic is not really implemented. We just construct some number + // inputs that are also valid JSON. This exercises passing VariantVal throughout the system. 
+ val query = spark.sql("select parse_json(repeat('1', id)) as v from range(1, 10)") + verifyResult(query) + + // Write into and read from Parquet. + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + query.write.parquet(tempDir) + verifyResult(spark.read.parquet(tempDir)) + } + } + + test("round trip tests") { + val rand = new Random(42) + val input = Seq.fill(50) { + if (rand.nextInt(10) == 0) { + null + } else { + val value = new Array[Byte](rand.nextInt(50)) + rand.nextBytes(value) + val metadata = new Array[Byte](rand.nextInt(50)) + rand.nextBytes(metadata) + new VariantVal(value, metadata) + } + } + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(input.map(Row(_))), + StructType.fromDDL("v variant") + ) + val result = df.collect().map(_.get(0).asInstanceOf[VariantVal]) + + def prepareAnswer(values: Seq[VariantVal]): Seq[String] = { + values.map(v => if (v == null) "null" else v.debugString()).sorted + } + assert(prepareAnswer(input) == prepareAnswer(result)) + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 72e6fae92cbc2..9bb35bb8719ea 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 3cf6fcbc65ace..5ccd40aefa255 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -191,6 +191,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AnsiIntervalType => false case _: AtomicType => true From 2cac768fabaa7cad40390b2205dd9c5000011e4c Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Tue, 14 Nov 2023 13:55:25 +0900 Subject: [PATCH 121/121] [SPARK-45844][SQL] Implement case-insensitivity for XML ### What changes were proposed in this pull request? This PR addresses the current lack of support for case-insensitive schema handling in XML file formats. 
Our approach now follows the `SQLConf` case insensitivity setting in both schema inference and file read operations. We handle duplicate keys as follows: 1. When we encounter duplicate field names within a row (exact duplicates or names differing only in case), we convert them into an array and use the first name we encounter as the array's name. 2. When we encounter duplicates across rows, we likewise keep the first name we encounter (see the sketch following this patch). Keys of map-type data are strings and are not treated as field names, so they do not require case-sensitivity checks. ### Why are the changes needed? To stay consistent with the other file formats and to reduce maintenance effort. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #43722 from shujingyang-db/case-sensitive. Lead-authored-by: Shujing Yang Co-authored-by: Shujing Yang <135740748+shujingyang-db@users.noreply.github.com> Signed-off-by: Hyukjin Kwon --- .../catalyst/expressions/xmlExpressions.scala | 3 +- .../sql/catalyst/xml/StaxXmlParser.scala | 17 ++- .../sql/catalyst/xml/XmlInferSchema.scala | 91 ++++++++---- .../datasources/xml/XmlDataSource.scala | 7 +- .../attributes-case-sensitive.xml | 12 ++ .../execution/datasources/xml/XmlSuite.scala | 139 ++++++++++++++++++ 6 files changed, 235 insertions(+), 34 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index c581643460f65..27c0a09fa1f06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -189,7 +189,8 @@ case class SchemaOfXml( private lazy val xmlFactory = xmlOptions.buildXmlFactory() @transient - private lazy val xmlInferSchema = new XmlInferSchema(xmlOptions) + private lazy val xmlInferSchema = + new XmlInferSchema(xmlOptions, caseSensitive = SQLConf.get.caseSensitiveAnalysis) @transient private lazy val xml = child.eval().asInstanceOf[UTF8String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 77a0bd1dff179..754b54ce1575c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -35,10 +35,11 @@ import org.apache.spark.SparkUpgradeException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, CaseInsensitiveMap, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter} import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import
org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -87,6 +88,14 @@ class StaxXmlParser( } } + private def getFieldNameToIndex(schema: StructType): Map[String, Int] = { + if (SQLConf.get.caseSensitiveAnalysis) { + schema.map(_.name).zipWithIndex.toMap + } else { + CaseInsensitiveMap(schema.map(_.name).zipWithIndex.toMap) + } + } + def parseStream( inputStream: InputStream, schema: StructType): Iterator[InternalRow] = { @@ -274,7 +283,7 @@ class StaxXmlParser( val convertedValuesMap = collection.mutable.Map.empty[String, Any] val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options) valuesMap.foreach { case (f, v) => - val nameToIndex = schema.map(_.name).zipWithIndex.toMap + val nameToIndex = getFieldNameToIndex(schema) nameToIndex.get(f).foreach { i => convertedValuesMap(f) = convertTo(v, schema(i).dataType) } @@ -313,7 +322,7 @@ class StaxXmlParser( // Here we merge both to a row. val valuesMap = fieldsMap ++ attributesMap valuesMap.foreach { case (f, v) => - val nameToIndex = schema.map(_.name).zipWithIndex.toMap + val nameToIndex = getFieldNameToIndex(schema) nameToIndex.get(f).foreach { row(_) = v } } @@ -335,7 +344,7 @@ class StaxXmlParser( rootAttributes: Array[Attribute] = Array.empty, isRootAttributesOnly: Boolean = false): InternalRow = { val row = new Array[Any](schema.length) - val nameToIndex = schema.map(_.name).zipWithIndex.toMap + val nameToIndex = getFieldNameToIndex(schema) // If there are attributes, then we process them first. convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) => nameToIndex.get(f).foreach { row(_) = v } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index 777dd69fd7fa0..25f33e7f1bbdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -36,7 +36,9 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, Timest import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.types._ -private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with Logging { +private[sql] class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) + extends Serializable + with Logging { private val decimalParser = ExprUtils.getDecimalParser(options.locale) @@ -115,8 +117,7 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with } } - def infer(xml: String, - xsdSchema: Option[Schema] = None): Option[DataType] = { + def infer(xml: String, xsdSchema: Option[Schema] = None): Option[DataType] = { try { val xsd = xsdSchema.orElse(Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)) xsd.foreach { schema => @@ -199,14 +200,50 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with private def inferObject( parser: XMLEventReader, rootAttributes: Array[Attribute] = Array.empty): DataType = { - val builder = ArrayBuffer[StructField]() - val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]] + /** + * Retrieves the field name with respect to the case 
sensitivity setting. + * We pick the first name we encountered. + * + * If case sensitivity is enabled, the original field name is returned. + * If not, the field name is managed in a case-insensitive map. + * + * For instance, if we encounter the following field names: + * foo, Foo, FOO + * + * In case-sensitive mode: we will infer three fields: foo, Foo, FOO + * In case-insensitive mode, we will infer an array named by foo + * (as it's the first one we encounter) + */ + val caseSensitivityOrdering: Ordering[String] = (x: String, y: String) => + if (caseSensitive) { + x.compareTo(y) + } else { + x.compareToIgnoreCase(y) + } + + val nameToDataType = + collection.mutable.TreeMap.empty[String, DataType](caseSensitivityOrdering) + + def addOrUpdateType(fieldName: String, newType: DataType): Unit = { + val oldTypeOpt = nameToDataType.get(fieldName) + oldTypeOpt match { + // If the field name exists in the map, + // merge the type and infer the combined field as an array type if necessary + case Some(oldType) if !oldType.isInstanceOf[ArrayType] => + nameToDataType.update(fieldName, ArrayType(compatibleType(oldType, newType))) + case Some(oldType) => + nameToDataType.update(fieldName, compatibleType(oldType, newType)) + case None => + nameToDataType.put(fieldName, newType) + } + } + // If there are attributes, then we should process them first. val rootValuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options) rootValuesMap.foreach { case (f, v) => - nameToDataType += (f -> ArrayBuffer(inferFrom(v))) + addOrUpdateType(f, inferFrom(v)) } var shouldStop = false while (!shouldStop) { @@ -239,14 +276,12 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with } // Add the field and datatypes so that we can check if this is ArrayType. val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options) - val dataTypes = nameToDataType.getOrElse(field, ArrayBuffer.empty[DataType]) - dataTypes += inferredType - nameToDataType += (field -> dataTypes) + addOrUpdateType(field, inferredType) case c: Characters if !c.isWhiteSpace => // This can be an attribute-only object val valueTagType = inferFrom(c.getData) - nameToDataType += options.valueTag -> ArrayBuffer(valueTagType) + addOrUpdateType(options.valueTag, valueTagType) case _: EndElement => shouldStop = StaxXmlParserUtils.checkEndElement(parser) @@ -258,25 +293,17 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with // if it only consists of attributes and valueTags. // If not, we will remove the valueTag field from the schema val attributesOnly = nameToDataType.forall { - case (fieldName, dataTypes) => - dataTypes.length == 1 && - (fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix)) + case (fieldName, _) => + fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix) } if (!attributesOnly) { nameToDataType -= options.valueTag } - // We need to manually merges the fields having the sames so that - // This can be inferred as ArrayType. - nameToDataType.foreach { - case (field, dataTypes) if dataTypes.length > 1 => - val elementType = dataTypes.reduceLeft(compatibleType) - builder += StructField(field, ArrayType(elementType), nullable = true) - case (field, dataTypes) => - builder += StructField(field, dataTypes.head, nullable = true) - } // Note: other code relies on this sorting for correctness, so don't remove it! 
- StructType(builder.sortBy(_.name).toArray) + StructType(nameToDataType.map{ + case (name, dataType) => StructField(name, dataType) + }.toList.sortBy(_.name)) } /** @@ -384,7 +411,12 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType): DataType = { + private[xml] def compatibleType(t1: DataType, t2: DataType): DataType = { + + def normalize(name: String): String = { + if (caseSensitive) name else name.toLowerCase(Locale.ROOT) + } + // TODO: Optimise this logic. findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. @@ -406,10 +438,15 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with } case (StructType(fields1), StructType(fields2)) => - val newFields = (fields1 ++ fields2).groupBy(_.name).map { - case (name, fieldTypes) => + val newFields = (fields1 ++ fields2) + // normalize field name and pair it with original field + .map(field => (normalize(field.name), field)) + .groupBy(_._1) // group by normalized field name + .map { case (_: String, fields: Array[(String, StructField)]) => + val fieldTypes = fields.map(_._2) val dataType = fieldTypes.map(_.dataType).reduce(compatibleType) - StructField(name, dataType, nullable = true) + // we pick up the first field name that we've encountered for the field + StructField(fields.head._2.name, dataType) } StructType(newFields.toArray.sortBy(_.name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala index b09be84130abb..4b3c82bd83bc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala @@ -122,7 +122,8 @@ object TextInputXmlDataSource extends XmlDataSource { xml: Dataset[String], parsedOptions: XmlOptions): StructType = { SQLExecution.withSQLConfPropagated(xml.sparkSession) { - new XmlInferSchema(parsedOptions).infer(xml.rdd) + new XmlInferSchema(parsedOptions, xml.sparkSession.sessionState.conf.caseSensitiveAnalysis) + .infer(xml.rdd) } } @@ -179,7 +180,9 @@ object MultiLineXmlDataSource extends XmlDataSource { parsedOptions) } SQLExecution.withSQLConfPropagated(sparkSession) { - val schema = new XmlInferSchema(parsedOptions).infer(tokenRDD) + val schema = + new XmlInferSchema(parsedOptions, sparkSession.sessionState.conf.caseSensitiveAnalysis) + .infer(tokenRDD) schema } } diff --git a/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml b/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml new file mode 100644 index 0000000000000..40a78fb279ba3 --- /dev/null +++ b/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml @@ -0,0 +1,12 @@ + + + + 1 + 2 + 3 + 4 + + + 5 + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 21122676c46be..5a901dadff94d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -39,6 +39,7 @@ import 
org.apache.spark.sql.catalyst.xml.XmlOptions._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.xml.TestUtils._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1976,4 +1977,142 @@ class XmlSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Row(true, 10.1, -10, 10, "8E9D", "8E9F", Timestamp.valueOf("2015-01-01 00:00:00"))) } + + test("case sensitivity test - attributes-only object") { + val schemaCaseSensitive = new StructType() + .add("array", ArrayType( + new StructType() + .add("_Attr2", LongType) + .add("_VALUE", LongType) + .add("_aTTr2", LongType) + .add("_attr2", LongType))) + .add("struct", new StructType() + .add("_Attr1", LongType) + .add("_VALUE", LongType) + .add("_attr1", LongType)) + + val dfCaseSensitive = Seq( + Row( + Array( + Row(null, 2, null, 2), + Row(3, 3, null, null), + Row(null, 4, 4, null)), + Row(null, 1, 1) + ), + Row( + null, + Row(5, 5, null) + ) + ) + val schemaCaseInSensitive = new StructType() + .add("array", ArrayType(new StructType().add("_VALUE", LongType).add("_attr2", LongType))) + .add("struct", new StructType().add("_VALUE", LongType).add("_attr1", LongType)) + val dfCaseInsensitive = + Seq( + Row( + Array(Row(2, 2), Row(3, 3), Row(4, 4)), + Row(1, 1)), + Row(null, Row(5, 5))) + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val df = spark.read + .option("rowTag", "ROW") + .xml(getTestResourcePath(resDir + "attributes-case-sensitive.xml")) + assert(df.schema == (if (caseSensitive) schemaCaseSensitive else schemaCaseInSensitive)) + checkAnswer( + df, + if (caseSensitive) dfCaseSensitive else dfCaseInsensitive) + } + } + } + + testCaseSensitivity( + "basic", + writeData = Seq(Row(1L, null), Row(null, 2L)), + writeSchema = new StructType() + .add("A1", LongType) + .add("a1", LongType), + expectedSchema = new StructType() + .add("A1", LongType), + readDataCaseInsensitive = Seq(Row(1L), Row(2L))) + + testCaseSensitivity( + "nested struct", + writeData = Seq(Row(Row(1L), null), Row(null, Row(2L))), + writeSchema = new StructType() + .add("A1", new StructType().add("B1", LongType)) + .add("a1", new StructType().add("b1", LongType)), + expectedSchema = new StructType() + .add("A1", new StructType().add("B1", LongType)), + readDataCaseInsensitive = Seq(Row(Row(1L)), Row(Row(2L))) + ) + + testCaseSensitivity( + "convert fields into array", + writeData = Seq(Row(1L, 2L)), + writeSchema = new StructType() + .add("A1", LongType) + .add("a1", LongType), + expectedSchema = new StructType() + .add("A1", ArrayType(LongType)), + readDataCaseInsensitive = Seq(Row(Array(1L, 2L)))) + + testCaseSensitivity( + "basic array", + writeData = Seq(Row(Array(1L, 2L), Array(3L, 4L))), + writeSchema = new StructType() + .add("A1", ArrayType(LongType)) + .add("a1", ArrayType(LongType)), + expectedSchema = new StructType() + .add("A1", ArrayType(LongType)), + readDataCaseInsensitive = Seq(Row(Array(1L, 2L, 3L, 4L)))) + + testCaseSensitivity( + "nested array", + writeData = + Seq(Row(Array(Row(1L, 2L), Row(3L, 4L)), null), Row(null, Array(Row(5L, 6L), Row(7L, 8L)))), + writeSchema = new StructType() + .add("A1", ArrayType(new StructType().add("B1", LongType).add("d", LongType))) + .add("a1", ArrayType(new StructType().add("b1", LongType).add("c", LongType))), + 
expectedSchema = new StructType() + .add( + "A1", + ArrayType( + new StructType() + .add("B1", LongType) + .add("c", LongType) + .add("d", LongType))), + readDataCaseInsensitive = Seq( + Row(Array(Row(1L, null, 2L), Row(3L, null, 4L))), + Row(Array(Row(5L, 6L, null), Row(7L, 8L, null))))) + + def testCaseSensitivity( + name: String, + writeData: Seq[Row], + writeSchema: StructType, + expectedSchema: StructType, + readDataCaseInsensitive: Seq[Row]): Unit = { + test(s"case sensitivity test - $name") { + withTempDir { dir => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + spark + .createDataFrame(writeData.asJava, writeSchema) + .repartition(1) + .write + .option("rowTag", "ROW") + .format("xml") + .mode("overwrite") + .save(dir.getCanonicalPath) + } + + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val df = spark.read.option("rowTag", "ROW").xml(dir.getCanonicalPath) + assert(df.schema == (if (caseSensitive) writeSchema else expectedSchema)) + checkAnswer(df, if (caseSensitive) writeData else readDataCaseInsensitive) + } + } + } + } + } }
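To make the duplicate-field handling in `XmlInferSchema` above easier to follow, here is a hedged, standalone sketch of the merge strategy; `inferFields` and `widen` are illustrative names that do not exist in the patch, and `widen` merely stands in for the real `compatibleType` logic.

```scala
// Simplified sketch of XmlInferSchema's duplicate-field merge (assumptions noted above).
import scala.collection.mutable
import org.apache.spark.sql.types._

def inferFields(fieldTypes: Seq[(String, DataType)], caseSensitive: Boolean): StructType = {
  // Case-aware key ordering: in case-insensitive mode, "foo", "Foo" and "FOO" share one entry.
  val ordering: Ordering[String] = (x: String, y: String) =>
    if (caseSensitive) x.compareTo(y) else x.compareToIgnoreCase(y)
  val nameToType = mutable.TreeMap.empty[String, DataType](ordering)
  // Placeholder for the real compatibleType widening logic.
  def widen(a: DataType, b: DataType): DataType = if (a == b) a else StringType

  fieldTypes.foreach { case (name, tpe) =>
    nameToType.get(name) match {
      // First duplicate: widen the two types and wrap the result in an array.
      case Some(old) if !old.isInstanceOf[ArrayType] =>
        nameToType.update(name, ArrayType(widen(old, tpe)))
      // Further duplicates: keep widening the array's element type (simplified).
      case Some(old) =>
        val elem = old.asInstanceOf[ArrayType].elementType
        nameToType.update(name, ArrayType(widen(elem, tpe)))
      // First occurrence: the stored key fixes the spelling that later variants inherit.
      case None =>
        nameToType.put(name, tpe)
    }
  }
  StructType(nameToType.toSeq.map { case (n, t) => StructField(n, t) }.sortBy(_.name))
}

// For example (mirroring the "convert fields into array" test above):
//   inferFields(Seq("A1" -> LongType, "a1" -> LongType), caseSensitive = false)
//     => struct<A1: array<bigint>>        -- first-seen spelling "A1" wins
//   inferFields(Seq("A1" -> LongType, "a1" -> LongType), caseSensitive = true)
//     => struct<A1: bigint, a1: bigint>
```

Keying the map with a case-aware ordering is what lets the first-encountered spelling win: later case-variants only update the stored value, never the stored key.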