diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9d7ea6148757d..6b64f493f4052 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3929,7 +3929,7 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
 object EliminateEventTimeWatermark extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
     _.containsPattern(EVENT_TIME_WATERMARK)) {
-    case EventTimeWatermark(_, _, child) if child.resolved && !child.isStreaming => child
+    case EventTimeWatermark(_, _, _, child) if child.resolved && !child.isStreaming => child
     case UpdateEventTimeWatermarkColumn(_, _, child) if child.resolved && !child.isStreaming =>
       child
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala
index 31c4f068a83eb..cddc519d0887e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUpdateEventTimeWatermarkColumn.scala
@@ -36,7 +36,7 @@ object ResolveUpdateEventTimeWatermarkColumn extends Rule[LogicalPlan] {
     _.containsPattern(UPDATE_EVENT_TIME_WATERMARK_COLUMN), ruleId) {
     case u: UpdateEventTimeWatermarkColumn if u.delay.isEmpty && u.childrenResolved =>
       val existingWatermarkDelay = u.child.collect {
-        case EventTimeWatermark(_, delay, _) => delay
+        case EventTimeWatermark(_, _, delay, _) => delay
       }
 
       if (existingWatermarkDelay.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index 8cfc939755ef7..0d7f2b1d0f3f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import java.util.UUID
 import java.util.concurrent.TimeUnit
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -69,6 +70,7 @@ object EventTimeWatermark {
  * Used to mark a user specified column as holding the event time for a row.
  */
 case class EventTimeWatermark(
+    nodeId: UUID,
     eventTime: Attribute,
     delay: CalendarInterval,
     child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 8409f454bfb88..939801e3f07af 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import java.util.TimeZone
+import java.util.{TimeZone, UUID}
 
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -1763,7 +1763,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-46064 Basic functionality of elimination for watermark node in batch query") {
-    val dfWithEventTimeWatermark = EventTimeWatermark($"ts",
+    val dfWithEventTimeWatermark = EventTimeWatermark(UUID.randomUUID(), $"ts",
       IntervalUtils.fromIntervalString("10 seconds"), batchRelationWithTs)
 
     val analyzed = getAnalyzer.executeAndCheck(dfWithEventTimeWatermark, new QueryPlanningTracker)
@@ -1776,7 +1776,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     "EventTimeWatermark changes the isStreaming flag during resolution") {
     // UnresolvedRelation which is batch initially and will be resolved as streaming
     val dfWithTempView = UnresolvedRelation(TableIdentifier("streamingTable"))
-    val dfWithEventTimeWatermark = EventTimeWatermark($"ts",
+    val dfWithEventTimeWatermark = EventTimeWatermark(UUID.randomUUID(), $"ts",
       IntervalUtils.fromIntervalString("10 seconds"), dfWithTempView)
 
     val analyzed = getAnalyzer.executeAndCheck(dfWithEventTimeWatermark, new QueryPlanningTracker)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 5027222be6b80..9424ecda0ed8b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import java.util.UUID
+
 import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -1229,9 +1231,10 @@ class FilterPushdownSuite extends PlanTest {
 
     // Verify that all conditions except the watermark touching condition are pushed down
     // by the optimizer and others are not.
-    val originalQuery = EventTimeWatermark($"b", interval, relation)
+    val nodeId = UUID.randomUUID()
+    val originalQuery = EventTimeWatermark(nodeId, $"b", interval, relation)
       .where($"a" === 5 && $"b" === new java.sql.Timestamp(0) && $"c" === 5)
-    val correctAnswer = EventTimeWatermark(
+    val correctAnswer = EventTimeWatermark(nodeId,
       $"b", interval, relation.where($"a" === 5 && $"c" === 5))
       .where($"b" === new java.sql.Timestamp(0))
 
@@ -1244,9 +1247,10 @@ class FilterPushdownSuite extends PlanTest {
 
     // Verify that all conditions except the watermark touching condition are pushed down
    // by the optimizer and others are not.
-    val originalQuery = EventTimeWatermark($"c", interval, relation)
+    val nodeId = UUID.randomUUID()
+    val originalQuery = EventTimeWatermark(nodeId, $"c", interval, relation)
       .where($"a" === 5 && $"b" === Rand(10) && $"c" === new java.sql.Timestamp(0))
-    val correctAnswer = EventTimeWatermark(
+    val correctAnswer = EventTimeWatermark(nodeId,
       $"c", interval, relation.where($"a" === 5))
       .where($"b" === Rand(10) && $"c" === new java.sql.Timestamp(0))
 
@@ -1260,9 +1264,10 @@ class FilterPushdownSuite extends PlanTest {
 
     // Verify that all conditions except the watermark touching condition are pushed down
     // by the optimizer and others are not.
-    val originalQuery = EventTimeWatermark($"c", interval, relation)
+    val nodeId = UUID.randomUUID()
+    val originalQuery = EventTimeWatermark(nodeId, $"c", interval, relation)
       .where($"a" === 5 && $"b" === 10)
-    val correctAnswer = EventTimeWatermark(
+    val correctAnswer = EventTimeWatermark(nodeId,
       $"c", interval, relation.where($"a" === 5 && $"b" === 10))
 
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
@@ -1273,9 +1278,10 @@ class FilterPushdownSuite extends PlanTest {
     val interval = new CalendarInterval(2, 2, 2000L)
     val relation = LocalRelation(Seq($"a".timestamp, attrB, attrC), Nil, isStreaming = true)
 
-    val originalQuery = EventTimeWatermark($"a", interval, relation)
+    val nodeId = UUID.randomUUID()
+    val originalQuery = EventTimeWatermark(nodeId, $"a", interval, relation)
       .where($"a" === new java.sql.Timestamp(0) && $"b" === 10)
-    val correctAnswer = EventTimeWatermark(
+    val correctAnswer = EventTimeWatermark(nodeId,
       $"a", interval, relation.where($"b" === 10)).where($"a" === new java.sql.Timestamp(0))
 
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b489f33cd63b9..3953a5c3704f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -573,7 +573,8 @@ class Dataset[T] private[sql](
     require(!IntervalUtils.isNegative(parsedDelay),
       s"delay threshold ($delayThreshold) should not be negative.")
     EliminateEventTimeWatermark(
-      EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan))
+      EventTimeWatermark(util.UUID.randomUUID(), UnresolvedAttribute(eventTime),
+        parsedDelay, logicalPlan))
   }
 
   /** @inheritdoc */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 53c335c1eced6..30b395d0c1369 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -425,8 +425,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case _ if !plan.isStreaming => Nil
 
-      case EventTimeWatermark(columnName, delay, child) =>
-        EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil
+      case EventTimeWatermark(nodeId, columnName, delay, child) =>
+        EventTimeWatermarkExec(nodeId, columnName, delay, planLater(child)) :: Nil
 
       case UpdateEventTimeWatermarkColumn(columnName, delay, child) =>
         // we expect watermarkDelay to be resolved before physical planning.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 54041abdc9ab4..d25c4be0fb84a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Predicate, SortOrder, UnsafeProjection}
@@ -90,6 +92,7 @@ class EventTimeStatsAccum(protected var currentStats: EventTimeStats = EventTime
  * period. Note that event time is measured in milliseconds.
  */
 case class EventTimeWatermarkExec(
+    nodeId: UUID,
     eventTime: Attribute,
     delay: CalendarInterval,
     child: SparkPlan) extends UnaryExecNode {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index dc141b21780e7..5ce9e13eb8fac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -485,7 +485,7 @@ class MicroBatchExecution(
           OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.sessionState.conf)
           execCtx.offsetSeqMetadata = OffsetSeqMetadata(
             metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
-          watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf)
+          watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf, logicalPlan)
           watermarkTracker.setWatermark(metadata.batchWatermarkMs)
         }
 
@@ -539,7 +539,7 @@ class MicroBatchExecution(
       case None => // We are starting this stream for the first time.
logInfo(s"Starting new streaming query.") execCtx.batchId = 0 - watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf, logicalPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 3e6f122f463d3..7228767c4d18a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.streaming -import java.util.Locale +import java.util.{Locale, UUID} import scala.collection.mutable import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.internal.SQLConf @@ -79,8 +80,21 @@ case object MaxWatermark extends MultipleWatermarkPolicy { } /** Tracks the watermark value of a streaming query based on a given `policy` */ -case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { - private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() +class WatermarkTracker( + policy: MultipleWatermarkPolicy, + initialPlan: LogicalPlan) extends Logging { + + private val operatorToWatermarkMap: mutable.Map[UUID, Option[Long]] = { + val map = mutable.HashMap[UUID, Option[Long]]() + val watermarkOperators = initialPlan.collect { + case e: EventTimeWatermark => e + } + watermarkOperators.foreach { op => + map.put(op.nodeId, None) + } + map + } + private var globalWatermarkMs: Long = 0 def setWatermark(newWatermarkMs: Long): Unit = synchronized { @@ -93,26 +107,33 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { } if (watermarkOperators.isEmpty) return - watermarkOperators.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + watermarkOperators.foreach { + case e if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats ${e.nodeId}: ${e.eventTimeStats.value}") + + if (!operatorToWatermarkMap.isDefinedAt(e.nodeId)) { + throw new IllegalStateException(s"Unknown watermark node ID: ${e.nodeId}, known IDs: " + + s"${operatorToWatermarkMap.keys.mkString("[", ",", "]")}") + } + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = operatorToWatermarkMap.get(index) + val prevWatermarkMs = operatorToWatermarkMap(e.nodeId) if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - operatorToWatermarkMap.put(index, newWatermarkMs) + operatorToWatermarkMap.put(e.nodeId, Some(newWatermarkMs)) } - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!operatorToWatermarkMap.isDefinedAt(index)) { - operatorToWatermarkMap.put(index, 0) + case e => + if (!operatorToWatermarkMap.isDefinedAt(e.nodeId)) { + throw new IllegalStateException(s"Unknown watermark node ID: ${e.nodeId}, known IDs: " + + s"${operatorToWatermarkMap.keys.mkString("[", ",", "]")}") } } // Update the global watermark accordingly to the chosen policy. 
     // and their semantics, please check the comments of
     // `org.apache.spark.sql.execution.streaming.MultipleWatermarkPolicy` implementations.
-    val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq)
+    val chosenGlobalWatermark = policy.chooseGlobalWatermark(
+      operatorToWatermarkMap.values.map(_.getOrElse(0L)).toSeq)
     if (chosenGlobalWatermark > globalWatermarkMs) {
       logInfo(log"Updating event-time watermark from " +
         log"${MDC(GLOBAL_WATERMARK, globalWatermarkMs)} " +
@@ -124,10 +145,14 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging {
   }
 
   def currentWatermark: Long = synchronized { globalWatermarkMs }
+
+  private[sql] def watermarkMap: Map[UUID, Option[Long]] = synchronized {
+    operatorToWatermarkMap.toMap
+  }
 }
 
 object WatermarkTracker {
-  def apply(conf: RuntimeConfig): WatermarkTracker = {
+  def apply(conf: RuntimeConfig, initialPlan: LogicalPlan): WatermarkTracker = {
     // If the session has been explicitly configured to use non-default policy then use it,
     // otherwise use the default `min` policy as thats the safe thing to do.
     // When recovering from a checkpoint location, it is expected that the `conf` will already
@@ -137,6 +162,6 @@ object WatermarkTracker {
     val policyName = conf.get(
       SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME)
-    new WatermarkTracker(MultipleWatermarkPolicy(policyName))
+    new WatermarkTracker(MultipleWatermarkPolicy(policyName), initialPlan)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/WatermarkTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/WatermarkTrackerSuite.scala
new file mode 100644
index 0000000000000..6018d286fc21e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/WatermarkTrackerSuite.scala
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.UUID
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
+import org.apache.spark.sql.functions.timestamp_seconds
+import org.apache.spark.sql.streaming.StreamTest
+
+class WatermarkTrackerSuite extends StreamTest {
+
+  import testImplicits._
+
+  test("SPARK-50046 proper watermark advancement with dropped watermark nodes") {
+    val inputStream1 = MemoryStream[Int]
+    val inputStream2 = MemoryStream[Int]
+    val inputStream3 = MemoryStream[Int]
+
+    val df1 = inputStream1.toDF()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "10 seconds")
+
+    val df2 = inputStream2.toDF()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "20 seconds")
+
+    val df3 = inputStream3.toDF()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "30 seconds")
+
+    val union = df1.union(df2).union(df3)
+
+    testStream(union)(
+      // Just to ensure that the executed plan has a watermark node for every stream.
+      MultiAddData(
+        (inputStream1, Seq(0)),
+        (inputStream2, Seq(0)),
+        (inputStream3, Seq(0))
+      ),
+      ProcessAllAvailable(),
+      Execute { q =>
+        val initialPlan = q.logicalPlan
+        val executedPlan = q.lastExecution.executedPlan
+
+        val tracker = WatermarkTracker(spark.conf, initialPlan)
+        tracker.setWatermark(5)
+
+        val delayMsToNodeId = executedPlan.collect {
+          case e: EventTimeWatermarkExec => e.delayMs -> e.nodeId
+        }.toMap
+
+        def setupScenario(
+            data: Map[Long, Seq[Long]])(fnToPruneSubtree: UnionExec => UnionExec): SparkPlan = {
+          val eventTimeStatsMap = new mutable.HashMap[Long, EventTimeStatsAccum]()
+          executedPlan.foreach {
+            case e: EventTimeWatermarkExec =>
+              eventTimeStatsMap.put(e.delayMs, e.eventTimeStats)
+
+            case _ =>
+          }
+
+          data.foreach { case (delayMs, values) =>
+            val stats = eventTimeStatsMap(delayMs)
+            values.foreach { value =>
+              stats.add(value)
+            }
+          }
+
+          executedPlan.transform {
+            case e: UnionExec => fnToPruneSubtree(e)
+          }
+        }
+
+        def verifyWatermarkMap(expectation: Map[UUID, Option[Long]]): Unit = {
+          expectation.foreach { case (nodeId, watermarkValue) =>
+            assert(tracker.watermarkMap(nodeId) === watermarkValue,
+              s"Watermark value for nodeId $nodeId is ${tracker.watermarkMap(nodeId)}, whereas " +
+                s"we expect $watermarkValue")
+          }
+        }
+
+        // Before SPARK-50046, WatermarkTracker simply assumed that a watermark node would never
+        // be dropped and that the order of watermark nodes would never change. We have not found
+        // a case that breaks this assumption, but it has happened for other operators (e.g.
+        // PruneFilters), so it is better to guard against it proactively.
+
+        // Scenario: we have three streams, each with its own watermark. The query has executed
+        // the first batch of the run, and (for some reason) Spark drops one of the subtrees.
+        // This should be treated as if the stream in the dropped subtree had no data (since we
+        // cannot know), hence the watermark should not be advanced. But before SPARK-50046,
+        // WatermarkTracker had no way to tell that a watermark node had been dropped, hence the
+        // watermark was advanced based on the calculation over the remaining two streams.
+
+        val executedPlanFor1stBatch = setupScenario(
+          Map(
+            // watermark value for this node: 22 - 10 = 12
+            10000L -> Seq(20000L, 21000L, 22000L),
+            // watermark value for this node: 42 - 20 = 22
+            20000L -> Seq(40000L, 41000L, 42000L),
+            // watermark value for this node: 62 - 30 = 32
+            30000L -> Seq(60000L, 61000L, 62000L)
+          )
+        ) { unionExec =>
+          // Drop the subtree whose watermark node has the 10 second delay.
+          unionExec.copy(unionExec.children.drop(1))
+        }
+
+        tracker.updateWatermark(executedPlanFor1stBatch)
+
+        // The watermark hasn't advanced, hence it keeps the default value.
+        assert(tracker.currentWatermark === 5)
+
+        verifyWatermarkMap(
+          Map(
+            delayMsToNodeId(10000L) -> None,
+            delayMsToNodeId(20000L) -> Some(22000L),
+            delayMsToNodeId(30000L) -> Some(32000L))
+        )
+
+        // NOTE: Before SPARK-50046, the verification above failed and the one below passed.
+        // WatermarkTracker couldn't track the dropped node, hence it advanced the watermark
+        // from the remaining nodes, i.e. min(22, 32) = 22.
+        //
+        // assert(tracker.currentWatermark === 22000)
+        //
+        // WatermarkTracker updated the map with shifted indices. It should only have updated
+        // indices 1 and 2, but it updated 0 and 1.
+        // verifyWatermarkMap(Map(0 -> Some(22000L), 1 -> Some(32000L)))
+
+        // Scenario: after the first batch, the query executes the second batch. In the second
+        // batch, (for some reason) Spark retains only the middle subtree. Before SPARK-50046,
+        // WatermarkTracker tracked the watermark nodes of the physical plan by positional
+        // index, so the node at index 1 in the logical plan shifted to index 0, updating the
+        // map incorrectly and also advancing the watermark. The correct behavior is that the
+        // watermark node for the first stream has been dropped in both batches, hence the
+        // watermark must not be advanced.
+
+        val executedPlanFor2ndBatch = setupScenario(
+          Map(
+            // watermark value for this node: 52 - 10 = 42
+            10000L -> Seq(50000L, 51000L, 52000L),
+            // watermark value for this node: 72 - 20 = 52
+            20000L -> Seq(70000L, 71000L, 72000L),
+            // watermark value for this node: 92 - 30 = 62
+            30000L -> Seq(90000L, 91000L, 92000L)
+          )
+        ) { unionExec =>
+          // Keep only the middle subtree, dropping the rest.
+          unionExec.copy(Seq(unionExec.children(1)))
+        }
+
+        tracker.updateWatermark(executedPlanFor2ndBatch)
+
+        // The watermark hasn't advanced, hence it keeps the default value.
+        assert(tracker.currentWatermark === 5)
+
+        // WatermarkTracker properly updates the map entry for the middle watermark node.
+        verifyWatermarkMap(
+          Map(
+            delayMsToNodeId(10000L) -> None,
+            delayMsToNodeId(20000L) -> Some(52000L),
+            delayMsToNodeId(30000L) -> Some(32000L))
+        )
+      }
+    )
+  }
+}
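
Note (not part of the patch): a minimal, self-contained sketch of the semantics the reworked WatermarkTracker relies on, for readers of this diff. Per-node watermarks are keyed by a stable UUID assigned to each EventTimeWatermark node; a node that has never reported data stays None, and the default "min" policy substitutes 0 ms for None, so a dropped or silent watermark node pins the global watermark. The names below (WatermarkAdvancementSketch, minPolicy, WatermarkMap) are illustrative only and do not exist in Spark.

import java.util.UUID

object WatermarkAdvancementSketch {
  // nodeId -> latest per-node watermark in ms; None means the node has not reported data yet.
  type WatermarkMap = Map[UUID, Option[Long]]

  // The default "min" policy: minimum across all nodes, substituting 0 for None.
  def minPolicy(perNode: WatermarkMap): Long =
    perNode.values.map(_.getOrElse(0L)).min

  def main(args: Array[String]): Unit = {
    val node10s = UUID.randomUUID()
    val node20s = UUID.randomUUID()
    val node30s = UUID.randomUUID()

    // Batch 1 of the test above: the 10s node was dropped from the executed plan, so it stays None.
    val afterBatch1: WatermarkMap = Map(
      node10s -> None,
      node20s -> Some(22000L),
      node30s -> Some(32000L))

    // min(0, 22000, 32000) = 0, so the global watermark cannot advance past its previous value.
    assert(minPolicy(afterBatch1) == 0L)
  }
}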