Merge pull request #1655 from apache/master
GulajavaMinistudio authored Jun 6, 2024
2 parents 7eabe1d + f4434c3 commit 2823ce4
Showing 49 changed files with 1,155 additions and 199 deletions.
1 change: 0 additions & 1 deletion LICENSE-binary
@@ -218,7 +218,6 @@ com.google.crypto.tink:tink
com.google.flatbuffers:flatbuffers-java
com.google.guava:guava
com.jamesmurty.utils:java-xmlbuilder
com.jolbox:bonecp
com.ning:compress-lzf
com.squareup.okhttp3:logging-interceptor
com.squareup.okhttp3:okhttp
9 changes: 5 additions & 4 deletions NOTICE-binary
@@ -33,11 +33,12 @@ services.
// Version 2.0, in this case for
// ------------------------------------------------------------------

Hive Beeline
Copyright 2016 The Apache Software Foundation
=== NOTICE FOR com.clearspring.analytics:streams ===
stream-api
Copyright 2016 AddThis

This product includes software developed at
The Apache Software Foundation (http://www.apache.org/).
This product includes software developed by AddThis.
=== END OF NOTICE FOR com.clearspring.analytics:streams ===

Apache Avro
Copyright 2009-2014 The Apache Software Foundation
MavenUtils.scala
@@ -462,14 +462,13 @@ private[spark] object MavenUtils extends Logging {
val sysOut = System.out
// Default configuration name for ivy
val ivyConfName = "default"

// A Module descriptor must be specified. Entries are dummy strings
val md = getModuleDescriptor

md.setDefaultConf(ivyConfName)
var md: DefaultModuleDescriptor = null
try {
// To prevent ivy from logging to system out
System.setOut(printStream)
// A Module descriptor must be specified. Entries are dummy strings
md = getModuleDescriptor
md.setDefaultConf(ivyConfName)
val artifacts = extractMavenCoordinates(coordinates)
// Directories for caching downloads through ivy and storing the jars when maven coordinates
// are supplied to spark-submit
@@ -548,7 +547,9 @@ }
}
} finally {
System.setOut(sysOut)
clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings.getDefaultCache, ivyConfName)
if (md != null) {
clearIvyResolutionFiles(md.getModuleRevisionId, ivySettings.getDefaultCache, ivyConfName)
}
}
}
}
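The MavenUtils change above declares the module descriptor as a nullable `var`, builds it inside the `try` so a failure during construction still restores `System.out`, and guards the cleanup in `finally` so only a descriptor that was actually created is cleaned up. A minimal, self-contained sketch of the same shape — the `Resource` type and helper names are illustrative only, not the actual MavenUtils API:

```scala
import java.io.{ByteArrayOutputStream, PrintStream}

object NullGuardedCleanup {
  // Stand-in "resource" for illustration; in MavenUtils this role is played by
  // the Ivy module descriptor.
  final class Resource { def close(): Unit = println("cleaned up") }

  def run(): Unit = {
    val savedOut = System.out
    var resource: Resource = null
    try {
      // Redirect stdout while the work runs, as the Ivy code does.
      System.setOut(new PrintStream(new ByteArrayOutputStream()))
      resource = new Resource // may throw before the resource exists
      // ... do work with the resource ...
    } finally {
      System.setOut(savedOut)   // always restore stdout
      if (resource != null) {   // only clean up what was actually created
        resource.close()
      }
    }
  }
}
```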
SparkSessionE2ESuite.scala
@@ -108,6 +108,35 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}

test("interrupt all - streaming queries") {
val q1 = spark.readStream
.format("rate")
.option("rowsPerSecond", 1)
.load()
.writeStream
.format("console")
.start()

val q2 = spark.readStream
.format("rate")
.option("rowsPerSecond", 1)
.load()
.writeStream
.format("console")
.start()

assert(q1.isActive)
assert(q2.isActive)

val interrupted = spark.interruptAll()

q1.awaitTermination(timeoutMs = 20 * 1000)
q2.awaitTermination(timeoutMs = 20 * 1000)
assert(!q1.isActive)
assert(!q2.isActive)
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}

// TODO(SPARK-48139): Re-enable `SparkSessionE2ESuite.interrupt tag`
ignore("interrupt tag") {
val session = spark
@@ -230,6 +259,53 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}

test("interrupt tag - streaming query") {
spark.addTag("foo")
val q1 = spark.readStream
.format("rate")
.option("rowsPerSecond", 1)
.load()
.writeStream
.format("console")
.start()
assert(spark.getTags() == Set("foo"))

spark.addTag("bar")
val q2 = spark.readStream
.format("rate")
.option("rowsPerSecond", 1)
.load()
.writeStream
.format("console")
.start()
assert(spark.getTags() == Set("foo", "bar"))

spark.clearTags()

spark.addTag("zoo")
val q3 = spark.readStream
.format("rate")
.option("rowsPerSecond", 1)
.load()
.writeStream
.format("console")
.start()
assert(spark.getTags() == Set("zoo"))

assert(q1.isActive)
assert(q2.isActive)
assert(q3.isActive)

val interrupted = spark.interruptTag("foo")

q1.awaitTermination(timeoutMs = 20 * 1000)
q2.awaitTermination(timeoutMs = 20 * 1000)
assert(!q1.isActive)
assert(!q2.isActive)
assert(q3.isActive)
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}

test("progress is available for the spark result") {
val result = spark
.range(10000)
21 changes: 14 additions & 7 deletions connector/connect/docs/client-connection-string.md
@@ -22,8 +22,8 @@ cannot contain arbitrary characters. Configuration parameter are passed in the
style of the HTTP URL Path Parameter Syntax. This is similar to the JDBC connection
strings. The path component must be empty. All parameters are interpreted **case sensitive**.

```shell
sc://hostname:port/;param1=value;param2=value
```text
sc://host:port/;param1=value;param2=value
```

<table>
@@ -34,7 +34,7 @@ sc://hostname:port/;param1=value;param2=value
<td>Examples</td>
</tr>
<tr>
<td>hostname</td>
<td>host</td>
<td>String</td>
<td>
The hostname of the endpoint for Spark Connect. Since the endpoint
@@ -49,8 +49,8 @@ sc://hostname:port/;param1=value;param2=value
</tr>
<tr>
<td>port</td>
<td>Numeric</td>
<td>The portname to be used when connecting to the GRPC endpoint. The
<td>Numeric</td>
<td>The port to be used when connecting to the GRPC endpoint. The
default values is: <b>15002</b>. Any valid port number can be used.</td>
<td><pre>15002</pre><pre>443</pre></td>
</tr>
@@ -75,7 +75,7 @@ sc://hostname:port/;param1=value;param2=value
<td>user_id</td>
<td>String</td>
<td>User ID to automatically set in the Spark Connect UserContext message.
This is necssary for the appropriate Spark Session management. This is an
This is necessary for the appropriate Spark Session management. This is an
*optional* parameter and depending on the deployment this parameter might
be automatically injected using other means.</td>
<td>
@@ -99,9 +99,16 @@ sc://hostname:port/;param1=value;param2=value
allows to provide this session ID to allow sharing Spark Sessions for the same users
for example across multiple languages. The value must be provided in a valid UUID
string format.<br/>
<i>Default: A UUID generated randomly.</td>
<i>Default: </i><pre>A UUID generated randomly</pre></td>
<td><pre>session_id=550e8400-e29b-41d4-a716-446655440000</pre></td>
</tr>
<tr>
<td>grpc_max_message_size</td>
<td>Numeric</td>
<td>Maximum message size allowed for gRPC messages in bytes.<br/>
<i>Default: </i><pre> 128 * 1024 * 1024</pre></td>
<td><pre>grpc_max_message_size=134217728</pre></td>
</tr>
</table>

## Examples
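To tie the parameter table above together, here is a minimal sketch of a client connecting with such a string, including the newly documented `grpc_max_message_size` parameter. It assumes the Spark Connect Scala client's `SparkSession.builder().remote(...)` entry point; the host, port, and `user_id` values are placeholders:

```scala
import org.apache.spark.sql.SparkSession

object ConnectClientExample {
  def main(args: Array[String]): Unit = {
    // Placeholder endpoint and user; grpc_max_message_size is set to its
    // documented default of 128 * 1024 * 1024 bytes.
    val connectionString =
      "sc://localhost:15002/;user_id=alice;grpc_max_message_size=134217728"

    val spark = SparkSession.builder()
      .remote(connectionString)
      .getOrCreate()

    spark.range(5).show()
    spark.stop()
  }
}
```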
SparkConnectPlanner.scala
@@ -3163,7 +3163,11 @@ class SparkConnectPlanner(
}

// Register the new query so that its reference is cached and is stopped on session timeout.
SparkConnectService.streamingSessionManager.registerNewStreamingQuery(sessionHolder, query)
SparkConnectService.streamingSessionManager.registerNewStreamingQuery(
sessionHolder,
query,
executeHolder.sparkSessionTags,
executeHolder.operationId)
// Register the runner with the query if Python foreachBatch is enabled.
foreachBatchRunnerCleaner.foreach { cleaner =>
sessionHolder.streamingForeachBatchRunnerCleanerCache.registerCleanerForQuery(
@@ -3228,7 +3232,9 @@

// Find the query in connect service level cache, otherwise check session's active streams.
val query = SparkConnectService.streamingSessionManager
.getCachedQuery(id, runId, session) // Common case: query is cached in the cache.
// Common case: query is cached in the cache.
.getCachedQuery(id, runId, executeHolder.sparkSessionTags, session)
.map(_.query)
.orElse { // Else try to find it in active streams. Mostly will not be found here either.
Option(session.streams.get(id))
} match {
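The planner now registers each new streaming query together with the session tags and the operation id of the command that started it, and lookups through the cache are tag-aware. As a rough illustration only — not the actual SparkConnectStreamingQueryCache implementation, which also tracks run ids, sessions, and expiry — a simplified tag-keyed registry could look like this:

```scala
import scala.collection.mutable

import org.apache.spark.sql.streaming.StreamingQuery

// Simplified, hypothetical sketch of a tag-aware registry of streaming queries.
case class TaggedQuery(query: StreamingQuery, operationId: String, tags: Set[String])

class StreamingQueryRegistry {
  // Keyed by query id, as each started query has a stable id.
  private val queries = mutable.Map[String, TaggedQuery]()

  def register(query: StreamingQuery, tags: Set[String], operationId: String): Unit =
    synchronized {
      queries(query.id.toString) = TaggedQuery(query, operationId, tags)
    }

  // Returns every cached query carrying the given session tag.
  def getTaggedQuery(tag: String): Seq[TaggedQuery] = synchronized {
    queries.values.filter(_.tags.contains(tag)).toSeq
  }
}
```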
SessionHolder.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.jdk.CollectionConverters._
import scala.util.Try

@@ -179,12 +180,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
*/
private[service] def interruptAll(): Seq[String] = {
val interruptedIds = new mutable.ArrayBuffer[String]()
val operationsIds =
SparkConnectService.streamingSessionManager.cleanupRunningQueries(this, blocking = false)
executions.asScala.values.foreach { execute =>
if (execute.interrupt()) {
interruptedIds += execute.operationId
}
}
interruptedIds.toSeq
interruptedIds.toSeq ++ operationsIds
}

/**
@@ -194,14 +197,16 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
*/
private[service] def interruptTag(tag: String): Seq[String] = {
val interruptedIds = new mutable.ArrayBuffer[String]()
val queries = SparkConnectService.streamingSessionManager.getTaggedQuery(tag, session)
queries.foreach(q => Future(q.query.stop())(ExecutionContext.global))
executions.asScala.values.foreach { execute =>
if (execute.sparkSessionTags.contains(tag)) {
if (execute.interrupt()) {
interruptedIds += execute.operationId
}
}
}
interruptedIds.toSeq
interruptedIds.toSeq ++ queries.map(_.operationId)
}

/**
@@ -298,7 +303,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio

// Clean up running streaming queries.
// Note: there can be concurrent streaming queries being started.
SparkConnectService.streamingSessionManager.cleanupRunningQueries(this)
SparkConnectService.streamingSessionManager.cleanupRunningQueries(this, blocking = true)
streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers.
removeAllListeners() // removes all listener and stop python listener processes if necessary.

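In SessionHolder, `interruptAll` and `interruptTag` now also cover streaming queries: `interruptAll` cleans up running queries without blocking, `interruptTag` dispatches `StreamingQuery.stop()` for tagged queries to a `Future` (stop can block while the stream shuts down), and in both cases the affected operation ids are appended to the result. A small, hypothetical helper showing the same non-blocking pattern:

```scala
import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.sql.streaming.StreamingQuery

object AsyncStop {
  // Hypothetical helper: dispatch each stop() to a Future so the caller is not
  // blocked, and immediately return the operation ids of the queries being stopped.
  def stopAllAsync(queries: Seq[(StreamingQuery, String)]): Seq[String] = {
    queries.foreach { case (query, _) =>
      Future(query.stop())(ExecutionContext.global)
    }
    queries.map { case (_, operationId) => operationId }
  }
}
```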

