[SPARK-50046][SS] Use stable order of EventTimeWatermark node to calculate watermark

### What changes were proposed in this pull request?

This PR proposes to use a stable identifier for each EventTimeWatermark node (instead of its traversal order) to calculate the watermark.

### Why are the changes needed?

WatermarkTracker only looks at the physical plan when calculating the new watermark value. It identifies each watermark node by its index in the plan, which leads to several issues when a watermark node is lost during the optimization phase:

1) The watermark is advanced even when a watermark node has been dropped from the plan (a dropped node should be treated as having produced no data, so the watermark should not advance).
2) The watermark tracker incorrectly updates its in-memory map of each watermark node's previous value (the index is not a stable key, yet it is used to update the map).

The new unit test describes the expected behavior and how it was broken before this PR.
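
To make the keying issue concrete, here is a minimal, self-contained Scala sketch (no Spark dependencies; `WatermarkNode`, `IndexKeyedTracker`, and `IdKeyedTracker` are hypothetical stand-ins, not Spark APIs). It simulates a micro-batch in which one of two watermark operators has been pruned from the executed plan: keying the per-operator map by traversal index lets the global watermark advance, while keying it by a stable, pre-registered node ID holds the watermark back as expected.

```scala
import java.util.UUID
import scala.collection.mutable

object WatermarkKeyingSketch {

  // Stand-in for an EventTimeWatermark operator as observed in the executed plan.
  case class WatermarkNode(nodeId: UUID, maxEventTimeMs: Option[Long], delayMs: Long)

  // Old behavior: per-operator watermarks keyed by traversal index in the executed plan.
  class IndexKeyedTracker {
    private val perOperator = mutable.HashMap[Int, Long]()
    private var global = 0L

    def update(nodes: Seq[WatermarkNode]): Long = {
      nodes.zipWithIndex.foreach {
        case (n, i) if n.maxEventTimeMs.isDefined =>
          val wm = n.maxEventTimeMs.get - n.delayMs
          if (wm > perOperator.getOrElse(i, Long.MinValue)) perOperator.put(i, wm)
        case (_, i) =>
          // Populate 0 if this index has never reported data.
          if (!perOperator.isDefinedAt(i)) perOperator.put(i, 0L)
      }
      val chosen = perOperator.values.min // "min" multiple-watermark policy
      if (chosen > global) global = chosen
      global
    }
  }

  // New behavior: keyed by a stable node ID, pre-registered from the analyzed plan.
  class IdKeyedTracker(allNodeIds: Seq[UUID]) {
    private val perOperator: mutable.Map[UUID, Option[Long]] =
      mutable.HashMap(allNodeIds.map(id => id -> Option.empty[Long]): _*)
    private var global = 0L

    def update(nodes: Seq[WatermarkNode]): Long = {
      nodes.foreach { n =>
        n.maxEventTimeMs.foreach { max =>
          val wm = max - n.delayMs
          if (perOperator(n.nodeId).forall(wm > _)) perOperator.put(n.nodeId, Some(wm))
        }
      }
      // An operator that has never reported data counts as 0, holding the watermark back.
      val chosen = perOperator.values.map(_.getOrElse(0L)).min
      if (chosen > global) global = chosen
      global
    }
  }

  def main(args: Array[String]): Unit = {
    val idA = UUID.randomUUID()
    val idB = UUID.randomUUID()

    // One micro-batch in which operator B's side has no data and B has been pruned
    // from the executed plan, so only operator A is collected.
    val batch = Seq(WatermarkNode(idA, maxEventTimeMs = Some(100000L), delayMs = 10000L))

    // Index-keyed: B never gets an entry, so "min" sees only A and the watermark
    // incorrectly jumps to 90000 ms even though B has produced no data yet.
    println(s"index-keyed: ${new IndexKeyedTracker().update(batch)} ms")

    // ID-keyed: B is pre-registered with no value, so the watermark correctly stays at 0.
    println(s"id-keyed:    ${new IdKeyedTracker(Seq(idA, idB)).update(batch)} ms")
  }
}
```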

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48570 from HeartSaVioR/SPARK-50046.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
HeartSaVioR committed Oct 26, 2024
1 parent c086163 commit b8c2a32
Showing 11 changed files with 256 additions and 33 deletions.
@@ -3929,7 +3929,7 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
object EliminateEventTimeWatermark extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
_.containsPattern(EVENT_TIME_WATERMARK)) {
case EventTimeWatermark(_, _, child) if child.resolved && !child.isStreaming => child
case EventTimeWatermark(_, _, _, child) if child.resolved && !child.isStreaming => child
case UpdateEventTimeWatermarkColumn(_, _, child) if child.resolved && !child.isStreaming =>
child
}
@@ -36,7 +36,7 @@ object ResolveUpdateEventTimeWatermarkColumn extends Rule[LogicalPlan] {
_.containsPattern(UPDATE_EVENT_TIME_WATERMARK_COLUMN), ruleId) {
case u: UpdateEventTimeWatermarkColumn if u.delay.isEmpty && u.childrenResolved =>
val existingWatermarkDelay = u.child.collect {
case EventTimeWatermark(_, delay, _) => delay
case EventTimeWatermark(_, _, delay, _) => delay
}

if (existingWatermarkDelay.isEmpty) {
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import java.util.UUID
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -69,6 +70,7 @@ object EventTimeWatermark {
* Used to mark a user specified column as holding the event time for a row.
*/
case class EventTimeWatermark(
nodeId: UUID,
eventTime: Attribute,
delay: CalendarInterval,
child: LogicalPlan) extends UnaryNode {
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import java.util.TimeZone
import java.util.{TimeZone, UUID}

import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
@@ -1763,7 +1763,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

test("SPARK-46064 Basic functionality of elimination for watermark node in batch query") {
val dfWithEventTimeWatermark = EventTimeWatermark($"ts",
val dfWithEventTimeWatermark = EventTimeWatermark(UUID.randomUUID(), $"ts",
IntervalUtils.fromIntervalString("10 seconds"), batchRelationWithTs)

val analyzed = getAnalyzer.executeAndCheck(dfWithEventTimeWatermark, new QueryPlanningTracker)
@@ -1776,7 +1776,7 @@
"EventTimeWatermark changes the isStreaming flag during resolution") {
// UnresolvedRelation which is batch initially and will be resolved as streaming
val dfWithTempView = UnresolvedRelation(TableIdentifier("streamingTable"))
val dfWithEventTimeWatermark = EventTimeWatermark($"ts",
val dfWithEventTimeWatermark = EventTimeWatermark(UUID.randomUUID(), $"ts",
IntervalUtils.fromIntervalString("10 seconds"), dfWithTempView)

val analyzed = getAnalyzer.executeAndCheck(dfWithEventTimeWatermark, new QueryPlanningTracker)
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.optimizer

import java.util.UUID

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -1229,9 +1231,10 @@ class FilterPushdownSuite extends PlanTest {

// Verify that all conditions except the watermark touching condition are pushed down
// by the optimizer and others are not.
val originalQuery = EventTimeWatermark($"b", interval, relation)
val nodeId = UUID.randomUUID()
val originalQuery = EventTimeWatermark(nodeId, $"b", interval, relation)
.where($"a" === 5 && $"b" === new java.sql.Timestamp(0) && $"c" === 5)
val correctAnswer = EventTimeWatermark(
val correctAnswer = EventTimeWatermark(nodeId,
$"b", interval, relation.where($"a" === 5 && $"c" === 5))
.where($"b" === new java.sql.Timestamp(0))

@@ -1244,9 +1247,10 @@

// Verify that all conditions except the watermark touching condition are pushed down
// by the optimizer and others are not.
val originalQuery = EventTimeWatermark($"c", interval, relation)
val nodeId = UUID.randomUUID()
val originalQuery = EventTimeWatermark(nodeId, $"c", interval, relation)
.where($"a" === 5 && $"b" === Rand(10) && $"c" === new java.sql.Timestamp(0))
val correctAnswer = EventTimeWatermark(
val correctAnswer = EventTimeWatermark(nodeId,
$"c", interval, relation.where($"a" === 5))
.where($"b" === Rand(10) && $"c" === new java.sql.Timestamp(0))

@@ -1260,9 +1264,10 @@

// Verify that all conditions except the watermark touching condition are pushed down
// by the optimizer and others are not.
val originalQuery = EventTimeWatermark($"c", interval, relation)
val nodeId = UUID.randomUUID()
val originalQuery = EventTimeWatermark(nodeId, $"c", interval, relation)
.where($"a" === 5 && $"b" === 10)
val correctAnswer = EventTimeWatermark(
val correctAnswer = EventTimeWatermark(nodeId,
$"c", interval, relation.where($"a" === 5 && $"b" === 10))

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
@@ -1273,9 +1278,10 @@
val interval = new CalendarInterval(2, 2, 2000L)
val relation = LocalRelation(Seq($"a".timestamp, attrB, attrC), Nil, isStreaming = true)

val originalQuery = EventTimeWatermark($"a", interval, relation)
val nodeId = UUID.randomUUID()
val originalQuery = EventTimeWatermark(nodeId, $"a", interval, relation)
.where($"a" === new java.sql.Timestamp(0) && $"b" === 10)
val correctAnswer = EventTimeWatermark(
val correctAnswer = EventTimeWatermark(nodeId,
$"a", interval, relation.where($"b" === 10)).where($"a" === new java.sql.Timestamp(0))

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (2 additions & 1 deletion)
@@ -573,7 +573,8 @@ class Dataset[T] private[sql](
require(!IntervalUtils.isNegative(parsedDelay),
s"delay threshold ($delayThreshold) should not be negative.")
EliminateEventTimeWatermark(
EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan))
EventTimeWatermark(util.UUID.randomUUID(), UnresolvedAttribute(eventTime),
parsedDelay, logicalPlan))
}

/** @inheritdoc */
@@ -425,8 +425,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case _ if !plan.isStreaming => Nil

case EventTimeWatermark(columnName, delay, child) =>
EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil
case EventTimeWatermark(nodeId, columnName, delay, child) =>
EventTimeWatermarkExec(nodeId, columnName, delay, planLater(child)) :: Nil

case UpdateEventTimeWatermarkColumn(columnName, delay, child) =>
// we expect watermarkDelay to be resolved before physical planning.
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.streaming

import java.util.UUID

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Predicate, SortOrder, UnsafeProjection}
@@ -90,6 +92,7 @@ class EventTimeStatsAccum(protected var currentStats: EventTimeStats = EventTime
* period. Note that event time is measured in milliseconds.
*/
case class EventTimeWatermarkExec(
nodeId: UUID,
eventTime: Attribute,
delay: CalendarInterval,
child: SparkPlan) extends UnaryExecNode {
@@ -485,7 +485,7 @@ class MicroBatchExecution(
OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.sessionState.conf)
execCtx.offsetSeqMetadata = OffsetSeqMetadata(
metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf)
watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf, logicalPlan)
watermarkTracker.setWatermark(metadata.batchWatermarkMs)
}

@@ -539,7 +539,7 @@
case None => // We are starting this stream for the first time.
logInfo(s"Starting new streaming query.")
execCtx.batchId = 0
watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf)
watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf, logicalPlan)
}
}

@@ -17,13 +17,14 @@

package org.apache.spark.sql.execution.streaming

import java.util.Locale
import java.util.{Locale, UUID}

import scala.collection.mutable

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf

@@ -79,8 +80,21 @@ case object MaxWatermark extends MultipleWatermarkPolicy {
}

/** Tracks the watermark value of a streaming query based on a given `policy` */
case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging {
private val operatorToWatermarkMap = mutable.HashMap[Int, Long]()
class WatermarkTracker(
policy: MultipleWatermarkPolicy,
initialPlan: LogicalPlan) extends Logging {

private val operatorToWatermarkMap: mutable.Map[UUID, Option[Long]] = {
val map = mutable.HashMap[UUID, Option[Long]]()
val watermarkOperators = initialPlan.collect {
case e: EventTimeWatermark => e
}
watermarkOperators.foreach { op =>
map.put(op.nodeId, None)
}
map
}

private var globalWatermarkMs: Long = 0

def setWatermark(newWatermarkMs: Long): Unit = synchronized {
@@ -93,26 +107,33 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging {
}
if (watermarkOperators.isEmpty) return

watermarkOperators.zipWithIndex.foreach {
case (e, index) if e.eventTimeStats.value.count > 0 =>
logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}")
watermarkOperators.foreach {
case e if e.eventTimeStats.value.count > 0 =>
logDebug(s"Observed event time stats ${e.nodeId}: ${e.eventTimeStats.value}")

if (!operatorToWatermarkMap.isDefinedAt(e.nodeId)) {
throw new IllegalStateException(s"Unknown watermark node ID: ${e.nodeId}, known IDs: " +
s"${operatorToWatermarkMap.keys.mkString("[", ",", "]")}")
}

val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs
val prevWatermarkMs = operatorToWatermarkMap.get(index)
val prevWatermarkMs = operatorToWatermarkMap(e.nodeId)
if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) {
operatorToWatermarkMap.put(index, newWatermarkMs)
operatorToWatermarkMap.put(e.nodeId, Some(newWatermarkMs))
}

// Populate 0 if we haven't seen any data yet for this watermark node.
case (_, index) =>
if (!operatorToWatermarkMap.isDefinedAt(index)) {
operatorToWatermarkMap.put(index, 0)
case e =>
if (!operatorToWatermarkMap.isDefinedAt(e.nodeId)) {
throw new IllegalStateException(s"Unknown watermark node ID: ${e.nodeId}, known IDs: " +
s"${operatorToWatermarkMap.keys.mkString("[", ",", "]")}")
}
}

// Update the global watermark accordingly to the chosen policy. To find all available policies
// and their semantics, please check the comments of
// `org.apache.spark.sql.execution.streaming.MultipleWatermarkPolicy` implementations.
val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq)
val chosenGlobalWatermark = policy.chooseGlobalWatermark(
operatorToWatermarkMap.values.map(_.getOrElse(0L)).toSeq)
if (chosenGlobalWatermark > globalWatermarkMs) {
logInfo(log"Updating event-time watermark from " +
log"${MDC(GLOBAL_WATERMARK, globalWatermarkMs)} " +
@@ -124,10 +145,14 @@ case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging {
}

def currentWatermark: Long = synchronized { globalWatermarkMs }

private[sql] def watermarkMap: Map[UUID, Option[Long]] = synchronized {
operatorToWatermarkMap.toMap
}
}

object WatermarkTracker {
def apply(conf: RuntimeConfig): WatermarkTracker = {
def apply(conf: RuntimeConfig, initialPlan: LogicalPlan): WatermarkTracker = {
// If the session has been explicitly configured to use non-default policy then use it,
// otherwise use the default `min` policy as thats the safe thing to do.
// When recovering from a checkpoint location, it is expected that the `conf` will already
@@ -137,6 +162,6 @@ object WatermarkTracker {
val policyName = conf.get(
SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key,
MultipleWatermarkPolicy.DEFAULT_POLICY_NAME)
new WatermarkTracker(MultipleWatermarkPolicy(policyName))
new WatermarkTracker(MultipleWatermarkPolicy(policyName), initialPlan)
}
}