diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2cf471e47d746..4d31bcd2d1949 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1436,6 +1436,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE =
+    buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize")
+      .doc("The maximum number of entries to be stored in the queue to wait for late epochs. " +
+        "If the size of the queue exceeds this limit, the stream will stop with an error.")
+      .intConf
+      .createWithDefault(10000)
+
   val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
     buildConf("spark.sql.streaming.continuous.executorQueueSize")
       .internal()
@@ -2066,6 +2073,9 @@ class SQLConf extends Serializable with Logging {
 
   def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
 
+  def continuousStreamingEpochBacklogQueueSize: Int =
+    getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE)
+
   def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)
 
   def continuousStreamingExecutorPollIntervalMs: Long =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 20101c7fda320..3f88bcf48b86d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.continuous
 
 import java.util.UUID
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
 import java.util.function.UnaryOperator
 
 import scala.collection.JavaConverters._
@@ -58,6 +59,9 @@ class ContinuousExecution(
   // For use only in test harnesses.
   private[sql] var currentEpochCoordinatorId: String = _
 
+  // Throwable that caused the execution to fail
+  private val failure: AtomicReference[Throwable] = new AtomicReference[Throwable](null)
+
   override val logicalPlan: LogicalPlan = {
     val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]()
     var nextSourceId = 0
@@ -261,6 +265,11 @@ class ContinuousExecution(
           lastExecution.toRdd
         }
       }
+
+      val f = failure.get()
+      if (f != null) {
+        throw f
+      }
     } catch {
       case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) &&
           state.get() == RECONFIGURING =>
@@ -373,6 +382,35 @@ class ContinuousExecution(
     }
   }
 
+  /**
+   * Stores the error and stops the query execution thread to terminate the query in a new thread.
+   */
+  def stopInNewThread(error: Throwable): Unit = {
+    if (failure.compareAndSet(null, error)) {
+      logError(s"Query $prettyIdString received exception $error")
+      stopInNewThread()
+    }
+  }
+
+  /**
+   * Stops the query execution thread in a new thread to terminate the query.
+   */
+  private def stopInNewThread(): Unit = {
+    new Thread("stop-continuous-execution") {
+      setDaemon(true)
+
+      override def run(): Unit = {
+        try {
+          ContinuousExecution.this.stop()
+        } catch {
+          case e: Throwable =>
+            logError(e.getMessage, e)
+            throw e
+        }
+      }
+    }.start()
+  }
+
   /**
    * Stops the query execution thread to terminate the query.
   */
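Not part of the patch itself: a minimal usage sketch, assuming a Spark build that includes the configuration added above. It shows how an application could tighten the new epoch backlog limit for a continuous query; the object name, master, checkpoint path, rate-source settings, and the limit of 1000 are illustrative stand-ins, not values taken from the patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object EpochBacklogConfExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("epoch-backlog-conf-example")
      // Cap how many late-epoch entries the coordinator may buffer before failing the query.
      .config("spark.sql.streaming.continuous.epochBacklogQueueSize", "1000")
      .getOrCreate()

    val query = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/epoch-backlog-example") // illustrative path
      .trigger(Trigger.Continuous("1 second"))
      .start()

    query.awaitTermination()
  }
}

If the coordinator's backlog ever exceeds the configured limit, the query is expected to fail with the IllegalStateException introduced in the EpochCoordinator changes below.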
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index a99842220424d..decf524f7167c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -123,6 +123,9 @@ private[continuous] class EpochCoordinator(
     override val rpcEnv: RpcEnv)
   extends ThreadSafeRpcEndpoint with Logging {
 
+  private val epochBacklogQueueSize =
+    session.sqlContext.conf.continuousStreamingEpochBacklogQueueSize
+
   private var queryWritesStopped: Boolean = false
 
   private var numReaderPartitions: Int = _
@@ -212,6 +215,7 @@ private[continuous] class EpochCoordinator(
       if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
         partitionCommits.put((epoch, partitionId), message)
         resolveCommitsAtEpoch(epoch)
+        checkProcessingQueueBoundaries()
       }
 
     case ReportPartitionOffset(partitionId, epoch, offset) =>
@@ -223,6 +227,22 @@ private[continuous] class EpochCoordinator(
         query.addOffset(epoch, stream, thisEpochOffsets.toSeq)
         resolveCommitsAtEpoch(epoch)
       }
+      checkProcessingQueueBoundaries()
+  }
+
+  private def checkProcessingQueueBoundaries() = {
+    if (partitionOffsets.size > epochBacklogQueueSize) {
+      query.stopInNewThread(new IllegalStateException("Size of the partition offset queue has " +
+        "exceeded its maximum"))
+    }
+    if (partitionCommits.size > epochBacklogQueueSize) {
+      query.stopInNewThread(new IllegalStateException("Size of the partition commit queue has " +
+        "exceeded its maximum"))
+    }
+    if (epochsWaitingToBeCommitted.size > epochBacklogQueueSize) {
+      query.stopInNewThread(new IllegalStateException("Size of the epoch queue has " +
+        "exceeded its maximum"))
+    }
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
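Not part of the patch: a toy, self-contained sketch of the bounded-backlog idea behind checkProcessingQueueBoundaries above. The map, the stop handler, and every name except the error message are simplified stand-ins rather than the coordinator's real state, and the limit of 10 mirrors the value used by the tests below.

import scala.collection.mutable

object BacklogBoundarySketch {
  // Stand-in for the value of spark.sql.streaming.continuous.epochBacklogQueueSize.
  private val epochBacklogQueueSize = 10

  // Simplified stand-in for the coordinator's partitionOffsets map, keyed by (epoch, partition).
  private val partitionOffsets = mutable.Map[(Long, Int), String]()

  // Stand-in for ContinuousExecution.stopInNewThread(error): here we only log the failure.
  private def stopInNewThread(error: Throwable): Unit =
    println(s"stopping query: ${error.getMessage}")

  def recordOffset(epoch: Long, partition: Int, offset: String): Unit = {
    partitionOffsets.put((epoch, partition), offset)
    // Same shape as checkProcessingQueueBoundaries: fail fast once the backlog exceeds the limit.
    if (partitionOffsets.size > epochBacklogQueueSize) {
      stopInNewThread(new IllegalStateException(
        "Size of the partition offset queue has exceeded its maximum"))
    }
  }

  def main(args: Array[String]): Unit = {
    // Nothing ever resolves an epoch here, so entries accumulate and the limit trips at entry 11.
    (1L to 12L).foreach(epoch => recordOffset(epoch, 0, s"offset-$epoch"))
  }
}

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index 344a8aa55f0c5..d2e489a7d4ad2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
 import org.apache.spark.sql.streaming.{StreamTest, Trigger}
 import org.apache.spark.sql.test.TestSparkSession
 
@@ -343,3 +344,33 @@ class ContinuousMetaSuite extends ContinuousSuiteBase {
     }
   }
 }
+
+class ContinuousEpochBacklogSuite extends ContinuousSuiteBase {
+  import testImplicits._
+
+  override protected def createSparkSession = new TestSparkSession(
+    new SparkContext(
+      "local[1]",
+      "continuous-stream-test-sql-context",
+      sparkConf.set("spark.sql.testkey", "true")))
+
+  // This test forces the backlog to overflow by not standing up enough executors for the query
+  // to make progress.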
+ test("epoch backlog overflow") { + withSQLConf((CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE.key, "10")) { + val df = spark.readStream + .format("rate") + .option("numPartitions", "2") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(1)), + ExpectFailure[IllegalStateException] { e => + e.getMessage.contains("queue has exceeded its maximum") + } + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index f74285f4b0fb3..e3498db4194e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.streaming.continuous +import org.mockito.{ArgumentCaptor, InOrder} import org.mockito.ArgumentMatchers.{any, eq => eqTo} -import org.mockito.InOrder -import org.mockito.Mockito.{inOrder, never, verify} +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfterEach import org.scalatest.mockito.MockitoSugar @@ -27,6 +27,7 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite @@ -43,6 +44,7 @@ class EpochCoordinatorSuite private var writeSupport: StreamingWrite = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ + private val epochBacklogQueueSize = 10 override def beforeEach(): Unit = { val stream = mock[ContinuousStream] @@ -50,7 +52,11 @@ class EpochCoordinatorSuite query = mock[ContinuousExecution] orderVerifier = inOrder(writeSupport, query) - spark = new TestSparkSession() + spark = new TestSparkSession( + new SparkContext( + "local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true") + .set(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE, epochBacklogQueueSize))) epochCoordinator = EpochCoordinatorRef.create(writeSupport, stream, query, "test", 1, spark, SparkEnv.get) @@ -186,6 +192,66 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4, 5)) } + test("several epochs, max epoch backlog reached by partitionOffsets") { + setWriterPartitions(1) + setReaderPartitions(1) + + reportPartitionOffset(0, 1) + // Commit messages not arriving + for (i <- 2 to epochBacklogQueueSize + 1) { + reportPartitionOffset(0, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 1) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the partition offset queue has exceeded its maximum") + } + + test("several epochs, max epoch backlog reached by partitionCommits") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + // Offset messages not arriving + for (i <- 2 to epochBacklogQueueSize + 1) { + commitPartitionEpoch(0, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 1) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the partition commit queue has exceeded its 
maximum") + } + + test("several epochs, max epoch backlog reached by epochsWaitingToBeCommitted") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + // For partition 2 epoch 1 messages never arriving + // +2 because the first epoch not yet arrived + for (i <- 2 to epochBacklogQueueSize + 2) { + commitPartitionEpoch(0, i) + reportPartitionOffset(0, i) + commitPartitionEpoch(1, i) + reportPartitionOffset(1, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 2) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the epoch queue has exceeded its maximum") + } + private def setWriterPartitions(numPartitions: Int): Unit = { epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions)) } @@ -221,4 +287,13 @@ class EpochCoordinatorSuite private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = { epochs.foreach(verifyCommit) } + + private def verifyStoppedWithException(msg: String): Unit = { + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]); + verify(query, atLeastOnce()).stopInNewThread(exceptionCaptor.capture()) + + import scala.collection.JavaConverters._ + val throwable = exceptionCaptor.getAllValues.asScala.find(_.getMessage === msg) + assert(throwable != null, "Stream stopped with an exception but expected message is missing") + } }