[WIP] Use CollectMetrics for numOutputRows in streaming sources
HeartSaVioR committed Sep 2, 2024
1 parent c58148d commit a7b343f
Showing 6 changed files with 184 additions and 106 deletions.
@@ -2024,10 +2024,17 @@ case class CollectMetrics(
dataframeId: Long)
extends UnaryNode {

import CollectMetrics._

override lazy val resolved: Boolean = {
name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
}

if (isForStreamSource(name)) {
assert(references.isEmpty,
"The node should not refer to any column if it's used for the stream source output counter!")
}

override def maxRows: Option[Long] = child.maxRows
override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
override def output: Seq[Attribute] = child.output
@@ -2040,6 +2047,14 @@ case class CollectMetrics(
}
}

object CollectMetrics {
val STREAM_SOURCE_PREFIX = "__stream_source_"

def nameForStreamSource(name: String): String = s"$STREAM_SOURCE_PREFIX$name"

def isForStreamSource(name: String): Boolean = name.startsWith(STREAM_SOURCE_PREFIX)
}

/**
* A placeholder for domain join that can be added when decorrelating subqueries.
* It should be rewritten during the optimization phase.
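The companion object above marks a CollectMetrics node as an internal per-source row counter via a reserved name prefix, so planner rules can tell these counters apart from user-supplied observe() metrics. A minimal sketch of how such a name would be produced and recognized, assuming the UUID-based naming used later in this commit; the values are illustrative:

import java.util.UUID

import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics

// Build a prefixed, unique name for a stream-source counter node ("__stream_source_<uuid>").
val counterName = CollectMetrics.nameForStreamSource(UUID.randomUUID().toString)

// The prefix check is what FileSourceStrategy and the resolved-plan assert key off.
assert(CollectMetrics.isForStreamSource(counterName))
assert(!CollectMetrics.isForStreamSource("my_observed_metrics"))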
@@ -24,11 +24,11 @@ import scala.collection.mutable
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{NUM_PRUNED, POST_SCAN_FILTERS, PUSHED_FILTERS, TOTAL}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreePattern.{PLAN_EXPRESSION, SCALAR_SUBQUERY}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
@@ -61,6 +61,25 @@ import org.apache.spark.util.collection.BitSet
*/
object FileSourceStrategy extends Strategy with PredicateHelper with Logging {

private type HadoopFsRelationHolderRetType =
(LogicalRelation, HadoopFsRelation, Option[CatalogTable], Option[CollectMetrics])

private object HadoopFsRelationHolder {
def unapply(plan: LogicalPlan): Option[HadoopFsRelationHolderRetType] = {
plan match {
case c @ CollectMetrics(name, _,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _), _)
if CollectMetrics.isForStreamSource(name) =>
Some((l, fsRelation, table, Some(c)))

case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _) =>
Some((l, fsRelation, table, None))

case _ => None
}
}
}

// should prune buckets iff num buckets is greater than 1 and there is only one bucket column
private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = {
bucketSpec match {
@@ -151,7 +170,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ScanOperation(projects, stayUpFilters, filters,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
HadoopFsRelationHolder(l, fsRelation, table, collectMetricsOpt)) =>
// Filters on this relation fall into four categories based on where we can use them to avoid
// reading unneeded data:
// - partition keys only - used to prune directories to read
@@ -342,9 +361,25 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
val metadataAlias =
Alias(KnownNotNull(CreateStruct(structColumns.toImmutableArraySeq)),
FileFormat.METADATA_NAME)(exprId = metadataStruct.exprId)

val nodeExec = if (collectMetricsOpt.isDefined) {
val collectMetricsLogical = collectMetricsOpt.get
execution.CollectMetricsExec(
collectMetricsLogical.name, collectMetricsLogical.metrics, scan)
} else {
scan
}
execution.ProjectExec(
readDataColumns ++ partitionColumns :+ metadataAlias, scan)
}.getOrElse(scan)
readDataColumns ++ partitionColumns :+ metadataAlias, nodeExec)
}.getOrElse {
if (collectMetricsOpt.isDefined) {
val collectMetricsLogical = collectMetricsOpt.get
execution.CollectMetricsExec(
collectMetricsLogical.name, collectMetricsLogical.metrics, scan)
} else {
scan
}
}

// bottom-most filters are put in the left of the list.
val finalFilters = afterScanFilters.toSeq.reduceOption(expressions.And).toSeq ++ stayUpFilters
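The FileSourceStrategy change can be read as: when the logical file relation arrives wrapped in a stream-source CollectMetrics node, the extractor captures that node and the planner places a CollectMetricsExec directly on top of the file scan, below any metadata projection, so the row count is observed right at the source. A hedged, self-contained sketch of that wrapping shape, using stand-in types rather than Spark's planner API:

// Stand-in physical nodes (not Spark classes) to show where the counter lands.
sealed trait PhysicalNode
case class FileScan(path: String) extends PhysicalNode
case class CollectMetricsNode(name: String, child: PhysicalNode) extends PhysicalNode
case class ProjectNode(columns: Seq[String], child: PhysicalNode) extends PhysicalNode

def planFileScan(
    scan: PhysicalNode,
    counterName: Option[String],
    outputColumns: Seq[String]): PhysicalNode = {
  // Same effect as the if/else above; could also be written as
  // collectMetricsOpt.map(cm => CollectMetricsExec(cm.name, cm.metrics, scan)).getOrElse(scan).
  val withCounter = counterName.map(CollectMetricsNode(_, scan)).getOrElse(scan)
  ProjectNode(outputColumns, withCounter)
}

// Streaming file source: the counter sits right above the scan, below the project.
planFileScan(FileScan("/data"), Some("__stream_source_<uuid>"), Seq("value"))
// => ProjectNode(List(value), CollectMetricsNode(__stream_source_<uuid>, FileScan(/data)))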
@@ -17,14 +17,17 @@

package org.apache.spark.sql.execution.streaming

import java.util.UUID

import scala.collection.mutable.{Map => MutableMap}
import scala.collection.mutable

import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.{Column, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LeafNode, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream}
import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -35,6 +38,8 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
import org.apache.spark.sql.functions.count
import org.apache.spark.sql.internal
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.util.{Clock, Utils}
@@ -731,10 +736,15 @@ class MicroBatchExecution(
}

// Replace sources in the logical plan with data that has arrived since the last batch.
import sparkSessionToRunBatch.RichColumn

val uuidToStream = new mutable.HashMap[String, SparkDataStream]()
val streamToCollectMetrics = new mutable.HashMap[SparkDataStream, CollectMetrics]()

val newBatchesPlan = logicalPlan transform {
// For v1 sources.
case StreamingExecutionRelation(source, output, catalogTable) =>
mutableNewData.get(source).map { dataPlan =>
val node = mutableNewData.get(source).map { dataPlan =>
val hasFileMetadata = output.exists {
case FileSourceMetadataAttribute(_) => true
case _ => false
@@ -782,16 +792,54 @@ class MicroBatchExecution(
LocalRelation(output, isStreaming = true)
}

val collectMetricsName = CollectMetrics.nameForStreamSource(
UUID.randomUUID().toString)
uuidToStream.put(collectMetricsName, source)
val cachedCollectMetrics = streamToCollectMetrics.getOrElseUpdate(source,
CollectMetrics(
collectMetricsName,
Seq(
count(
new Column(internal.Literal(1))).as("row_count")
).map(_.named),
UnresolvedRelation(Seq("dummy")),
-1
)
)

val colMetrics = cachedCollectMetrics.copy(child = node)
sparkSessionToRunBatch.sessionState.analyzer.execute(colMetrics)

// For v2 sources.
case r: StreamingDataSourceV2ScanRelation =>
mutableNewData.get(r.stream).map {
case r: StreamingDataSourceV2ScanRelation
if r.startOffset.isEmpty && r.endOffset.isEmpty =>
val node = mutableNewData.get(r.stream).map {
case OffsetHolder(start, end) =>
r.copy(startOffset = Some(start), endOffset = Some(end))
}.getOrElse {
LocalRelation(r.output, isStreaming = true)
}

val collectMetricsName = CollectMetrics.nameForStreamSource(
UUID.randomUUID().toString)
uuidToStream.put(collectMetricsName, r.stream)
val cachedCollectMetrics = streamToCollectMetrics.getOrElseUpdate(r.stream,
CollectMetrics(
collectMetricsName,
Seq(
count(
new Column(internal.Literal(1))).as("row_count")
).map(_.named),
UnresolvedRelation(Seq("dummy")),
-1
)
)

val colMetrics = cachedCollectMetrics.copy(child = node)
sparkSessionToRunBatch.sessionState.analyzer.execute(colMetrics)
}
execCtx.newData = mutableNewData.toMap
execCtx.uuidToStream = uuidToStream.toMap
// Rewire the plan to use the new attributes that were returned by the source.
val newAttributePlan = newBatchesPlan.transformAllExpressionsWithPruning(
_.containsPattern(CURRENT_LIKE)) {
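In MicroBatchExecution, each replaced source plan is wrapped in a CollectMetrics node whose single metric is count(1) aliased as "row_count", keyed by a UUID-based name recorded in uuidToStream. This is essentially the logical-plan form of the public Dataset.observe() API; a hedged sketch of the equivalent user-level construction (illustrative names, using lit(1) in place of the internal Literal the commit builds directly):

import java.util.UUID

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, lit}

val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()

// Stand-in for one source's micro-batch data.
val batchDf = spark.range(10).toDF("value")

// Same naming scheme as CollectMetrics.nameForStreamSource(UUID.randomUUID().toString).
val counterName = s"__stream_source_${UUID.randomUUID()}"

// observe() inserts a CollectMetrics(name, Seq(count(1) as "row_count"), plan) node into the
// logical plan; after the batch runs, the physical CollectMetricsExec carries the observed
// row count, which ProgressContext reads back via collectedMetrics.getAs[Long]("row_count").
val counted = batchDf.observe(counterName, count(lit(1)).as("row_count"))
counted.collect()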
@@ -27,14 +27,13 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.optimizer.InlineCTE
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, WithCTE}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
import org.apache.spark.sql.connector.read.streaming.{ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress}
import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent}
import org.apache.spark.util.{Clock, Utils}
@@ -144,6 +143,8 @@ abstract class ProgressContext(
// the most recent input data for each source.
protected def newData: Map[SparkDataStream, LogicalPlan]

protected def uuidToStream: Map[String, SparkDataStream]

/** Flag that signals whether any error with input metrics have already been logged */
protected var metricWarningLogged: Boolean = false

@@ -409,103 +410,21 @@ abstract class ProgressContext(
tuples.groupBy(_._1).transform((_, v) => v.map(_._2).sum) // sum up rows for each source
}

def unrollCTE(plan: LogicalPlan): LogicalPlan = {
val containsCTE = plan.exists {
case _: WithCTE => true
case _ => false
}

if (containsCTE) {
InlineCTE(alwaysInline = true).apply(plan)
} else {
plan
}
}

val onlyDataSourceV2Sources = {
// Check whether the streaming query's logical plan has only V2 micro-batch data sources
val allStreamingLeaves = progressReporter.logicalPlan().collect {
case s: StreamingDataSourceV2ScanRelation => s.stream.isInstanceOf[MicroBatchStream]
case _: StreamingExecutionRelation => false
}
allStreamingLeaves.forall(_ == true)
}
import org.apache.spark.sql.execution.CollectMetricsExec

if (onlyDataSourceV2Sources) {
// It's possible that multiple DataSourceV2ScanExec instances may refer to the same source
// (can happen with self-unions or self-joins). This means the source is scanned multiple
// times in the query, so we should count the numRows for each scan.
if (uuidToStream != null) {
val sourceToInputRowsTuples = lastExecution.executedPlan.collect {
case s: MicroBatchScanExec =>
val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
val source = s.stream
source -> numRows
case c: CollectMetricsExec if uuidToStream.contains(c.name) =>
val stream = uuidToStream(c.name)
val numRows = c.collectedMetrics.getAs[Long]("row_count")
stream -> numRows
}

logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))
sumRows(sourceToInputRowsTuples)
sumRows(sourceToInputRowsTuples.toSeq)
} else {

// Since V1 sources do not generate execution plan leaves that directly link with the source
// that generated them, we can only do a best-effort association between execution plan leaves
// and the sources. This is known to fail in a few cases, see SPARK-24050.
//
// We want to associate execution plan leaves to the sources that generate them, so that we
// match their metrics (e.g. numOutputRows) to the sources. To do this we do the following.
// Consider the translation from the streaming logical plan to the final executed plan.
//
// streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan
//
// 1. We keep track of streaming sources associated with each leaf in trigger's logical plan
// - Each logical plan leaf will be associated with a single streaming source.
// - There can be multiple logical plan leaves associated with a streaming source.
// - There can be leaves not associated with any streaming source, because they were
// generated from a batch source (e.g. stream-batch joins)
//
// 2. Assuming that the executed plan has same number of leaves in the same order as that of
// the trigger logical plan, we associate executed plan leaves with corresponding
// streaming sources.
//
// 3. For each source, we sum the metrics of the associated execution plan leaves.
//
val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) =>
logicalPlan.collectLeaves().map { leaf => leaf -> source }
}

// SPARK-41198: CTE is inlined in optimization phase, which ends up with having different
// number of leaf nodes between (analyzed) logical plan and executed plan. Here we apply
// inlining CTE against logical plan manually if there is a CTE node.
val finalLogicalPlan = unrollCTE(lastExecution.logical)

val allLogicalPlanLeaves = finalLogicalPlan.collectLeaves() // includes non-streaming
val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves()
if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) {
val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap {
case (_, ep: MicroBatchScanExec) =>
// SPARK-41199: `logicalPlanLeafToSource` contains OffsetHolder instance for DSv2
// streaming source, hence we cannot lookup the actual source from the map.
// The physical node for DSv2 streaming source contains the information of the source
// by itself, so leverage it.
Some(ep -> ep.stream)
case (lp, ep) =>
logicalPlanLeafToSource.get(lp).map { source => ep -> source }
}
val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) =>
val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
source -> numRows
}
sumRows(sourceToInputRowsTuples)
} else {
if (!metricWarningLogged) {
def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}"

logWarning(log"Could not report metrics as number leaves in trigger logical plan did " +
log"not match that of the execution plan:\nlogical plan leaves: " +
log"${MDC(LogKeys.LOGICAL_PLAN_LEAVES, toString(allLogicalPlanLeaves))}\nexecution " +
log"plan leaves: ${MDC(LogKeys.EXECUTION_PLAN_LEAVES, toString(allExecPlanLeaves))}\n")
metricWarningLogged = true
}
Map.empty
}
logWarning("Association for streaming source output has been lost.")
Map.empty
}
}

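With the counters in place, computing numInputRows no longer depends on matching logical and physical plan leaves: the executed plan is scanned for CollectMetricsExec nodes whose names appear in uuidToStream, and their "row_count" values are summed per source. A simplified, self-contained sketch of that aggregation, with the SparkDataStream handles replaced by plain strings and made-up counts:

// Map from counter-node name to its stream (SparkDataStream replaced by String here).
val uuidToStream: Map[String, String] = Map(
  "__stream_source_aaaa" -> "kafkaSource",
  "__stream_source_bbbb" -> "fileSource")

// (CollectMetricsExec.name, collectedMetrics.getAs[Long]("row_count")) pairs collected from the
// executed plan; the same source may appear several times (self-union / self-join).
val collected: Seq[(String, Long)] = Seq(
  ("__stream_source_aaaa", 42L),
  ("__stream_source_bbbb", 7L),
  ("__stream_source_aaaa", 42L))

// Same shape as sumRows(): resolve each counter to its source, then sum rows per source.
val numInputRows: Map[String, Long] = collected
  .collect { case (name, rows) if uuidToStream.contains(name) => uuidToStream(name) -> rows }
  .groupBy(_._1)
  .map { case (source, tuples) => source -> tuples.map(_._2).sum }
// => Map(kafkaSource -> 84, fileSource -> 7)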
@@ -49,6 +49,8 @@ abstract class StreamExecutionContext(
/** Holds the most recent input data for each source. */
var newData: Map[SparkDataStream, LogicalPlan] = _

var uuidToStream: Map[String, SparkDataStream] = _

/**
* Stores the start offset for this batch.
* Only the scheduler thread should modify this field, and only in atomic steps.