[SPARK-48481][SQL][SS] Do not apply OptimizeOneRowPlan against streaming Dataset
HeartSaVioR committed May 31, 2024
1 parent 80addbb commit d8fad62
Showing 3 changed files with 107 additions and 4 deletions.
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * The rule is applied in both the normal and the AQE Optimizer. It optimizes plans using max rows:
@@ -31,19 +32,37 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
  *   it's grouping only (including the rewritten distinct plan), convert the aggregate to a project
  * - if the max rows of the child of the aggregate is less than or equal to 1,
  *   set distinct to false in all aggregate expressions
+ *
+ * Note: the rule should not be applied to a streaming source, since the number of rows it sees
+ * is just for the current microbatch. It does not mean the streaming source will only ever
+ * produce at most 1 row during the lifetime of the query. Consider the case where batch 0 runs
+ * with empty data in streaming source A, which triggers the rule for the Aggregate, and batch 1
+ * runs with several rows in streaming source A, which no longer triggers the rule. In that
+ * scenario the query could fail, as a stateful operator is expected to be planned for every
+ * batch, whereas here it would be planned "selectively".
  */
 object OptimizeOneRowPlan extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
+    val enableForStreaming = conf.getConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED)
+
     plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) {
-      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => child
-      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) => child
-      case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) =>
+      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) &&
+        isChildEligible(child, enableForStreaming) => child
+      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) &&
+        isChildEligible(child, enableForStreaming) => child
+      case agg @ Aggregate(_, _, child) if agg.groupOnly && child.maxRows.exists(_ <= 1L) &&
+        isChildEligible(child, enableForStreaming) =>
         Project(agg.aggregateExpressions, child)
-      case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) =>
+      case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) &&
+        isChildEligible(agg.child, enableForStreaming) =>
         agg.transformExpressions {
          case aggExpr: AggregateExpression if aggExpr.isDistinct =>
            aggExpr.copy(isDistinct = false)
        }
     }
   }
+
+  private def isChildEligible(child: LogicalPlan, enableForStreaming: Boolean): Boolean = {
+    enableForStreaming || !child.isStreaming
+  }
 }
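
To make the scaladoc Note concrete, here is a minimal sketch of the query shape behind the guard. This is illustrative only and not part of the commit: the session setup, object name, and query name are assumptions. A streaming `distinct()` is rewritten to a group-only Aggregate, and in a microbatch where the source happens to have no data, the child's `maxRows` is 0, which previously let this rule rewrite the Aggregate into a stateless Project for just that batch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

object Spark48481Sketch extends App {
  val spark = SparkSession.builder()
    .master("local[2]")
    .appName("spark-48481-sketch")
    .getOrCreate()
  import spark.implicits._
  implicit val sqlContext: org.apache.spark.sql.SQLContext = spark.sqlContext

  val input = MemoryStream[Int]
  // distinct() (or SQL DISTINCT) over a streaming source is rewritten to a
  // group-only Aggregate, which should be planned as a stateful operator in
  // every microbatch of the query.
  val deduped = input.toDS().distinct()

  val query = deduped.writeStream
    .format("memory")
    .queryName("dedup_sketch")
    .outputMode("update")
    .start()

  // If a microbatch runs while this source has no data (e.g. when it is
  // unioned with another source that does have data, as in the test added
  // below), the microbatch leaf is an empty relation with maxRows = Some(0).
  // Without isChildEligible, the Aggregate would be optimized to a Project
  // for that batch only, while later non-empty batches would plan the
  // stateful Aggregate again -- the "selective" planning the Note warns about.
  input.addData(1, 1, 2)
  query.processAllAvailable()
  spark.table("dedup_sketch").show() // expect rows 1 and 2

  query.stop()
  spark.stop()
}
```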
@@ -2334,6 +2334,17 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED =
+    buildConf("spark.sql.streaming.optimizeOneRowPlan.enabled")
+      .internal()
+      .doc("When true, enable the OptimizeOneRowPlan rule for the case where the child is a " +
+        "streaming Dataset. This is a fallback flag to revert to the 'incorrect' behavior, " +
+        "hence this configuration must not be used without an in-depth understanding. Use it " +
+        "only to quickly recover from a failure in an existing query!")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val VARIABLE_SUBSTITUTE_ENABLED =
     buildConf("spark.sql.variable.substitute")
       .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " +
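
As the doc string says, the new conf is strictly an escape hatch. A hypothetical usage sketch, using only the conf name introduced above, for recovering an existing query that starts failing after the fix while a proper rewrite is prepared:

```scala
// Fallback only: re-enables OptimizeOneRowPlan for streaming children and thus
// restores the pre-fix (incorrect) planning. Defaults to false.
spark.conf.set("spark.sql.streaming.optimizeOneRowPlan.enabled", "true")

// Or equivalently at submit time:
//   spark-submit --conf spark.sql.streaming.optimizeOneRowPlan.enabled=true ...
```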
@@ -22,6 +22,7 @@ import java.sql.Timestamp
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions.{expr, lit, window}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * This test ensures that any optimizations done by Spark SQL optimizer are
@@ -451,4 +452,76 @@ class StreamingQueryOptimizationCorrectnessSuite extends StreamTest {
       )
     }
   }
+
+  test("SPARK-48481: DISTINCT with empty stream source should retain AGGREGATE") {
+    def doTest(numExpectedStatefulOperatorsForOneEmptySource: Int): Unit = {
+      withTempView("tv1", "tv2") {
+        val inputStream1 = MemoryStream[Int]
+        val ds1 = inputStream1.toDS()
+        ds1.registerTempTable("tv1")
+
+        val inputStream2 = MemoryStream[Int]
+        val ds2 = inputStream2.toDS()
+        ds2.registerTempTable("tv2")
+
+        // DISTINCT is rewritten to AGGREGATE, hence an AGGREGATE for each source.
+        val unioned = spark.sql(
+          """
+            | WITH u AS (
+            |   SELECT DISTINCT value AS value FROM tv1
+            | ), v AS (
+            |   SELECT DISTINCT value AS value FROM tv2
+            | )
+            | SELECT value FROM u UNION ALL SELECT value FROM v
+            |""".stripMargin
+        )
+
+        testStream(unioned, OutputMode.Update())(
+          MultiAddData(inputStream1, 1, 1, 2)(inputStream2, 1, 1, 2),
+          CheckNewAnswer(1, 2, 1, 2),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            // Each Aggregate should be planned as a stateful operator.
+            assert(stateOperators.length === 2)
+            stateOperators.zipWithIndex.foreach { case (op, id) =>
+              assert(op.numRowsUpdated === 2, s"stateful OP ID: $id")
+            }
+          },
+          AddData(inputStream2, 2, 2, 3),
+          // NOTE: emitting 2 again is probably far from the user's expectation, given they
+          // intend to deduplicate, but the behavior is still correct for the rewritten node
+          // and output mode: Aggregate & Update mode.
+          // TODO: Should we disallow DISTINCT, or rewrite it to
+          //  dropDuplicates(WithinWatermark), for streaming sources?
+          CheckNewAnswer(2, 3),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            // The Aggregate should still be planned as a stateful operator, even though
+            // the other source was empty for this batch.
+            assert(stateOperators.length === numExpectedStatefulOperatorsForOneEmptySource)
+            val opWithUpdatedRows = stateOperators.zipWithIndex.filterNot(_._1.numRowsUpdated == 0)
+            assert(opWithUpdatedRows.length === 1)
+            // If this were dropDuplicates, numRowsUpdated should have been 1.
+            assert(opWithUpdatedRows.head._1.numRowsUpdated === 2,
+              s"stateful OP ID: ${opWithUpdatedRows.head._2}")
+          },
+          AddData(inputStream1, 4, 4, 5),
+          CheckNewAnswer(4, 5),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            assert(stateOperators.length === numExpectedStatefulOperatorsForOneEmptySource)
+            val opWithUpdatedRows = stateOperators.zipWithIndex.filterNot(_._1.numRowsUpdated == 0)
+            assert(opWithUpdatedRows.length === 1)
+            assert(opWithUpdatedRows.head._1.numRowsUpdated === 2,
+              s"stateful OP ID: ${opWithUpdatedRows.head._2}")
+          }
+        )
+      }
+    }
+
+    doTest(numExpectedStatefulOperatorsForOneEmptySource = 2)
+
+    withSQLConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED.key -> "true") {
+      doTest(numExpectedStatefulOperatorsForOneEmptySource = 1)
+    }
+  }
 }
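
Following the TODO inside the test: the intention-revealing alternative to SQL DISTINCT on a stream is `Dataset.dropDuplicates`, or `dropDuplicatesWithinWatermark` (available since Spark 3.5) when state should eventually be evicted. A rough sketch of that rewrite — illustrative, not part of this commit, and assuming the same session and implicits as the earlier sketch:

```scala
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.current_timestamp

val events = MemoryStream[Int]

// Streaming deduplication: state is written only for first-seen values, so a
// batch like (2, 2, 3) arriving after 2 was already seen updates one state row
// (for 3), not two -- matching the "numRowsUpdated should have been 1" note in
// the test above.
val dedupedForever = events.toDS().dropDuplicates("value")

// Bounded state: with a watermark, old keys are eventually evicted instead of
// accumulating for the lifetime of the query.
val dedupedWithinWatermark = events.toDS()
  .withColumn("ts", current_timestamp())
  .withWatermark("ts", "10 minutes")
  .dropDuplicatesWithinWatermark("value")
```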
