[SPARK-49905][SQL][SS] Use different ShuffleOrigin for the shuffle required from stateful operators

HeartSaVioR committed Oct 8, 2024
1 parent 78135dc commit 4151e40
Showing 3 changed files with 38 additions and 2 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

@@ -70,7 +70,16 @@ case class EnsureRequirements(
       case (child, distribution) =>
         val numPartitions = distribution.requiredNumPartitions
           .getOrElse(conf.numShufflePartitions)
-        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin)
+        distribution match {
+          case _: StatefulOpClusteredDistribution =>
+            ShuffleExchangeExec(
+              distribution.createPartitioning(numPartitions), child,
+              REQUIRED_BY_STATEFUL_OPERATOR)
+
+          case _ =>
+            ShuffleExchangeExec(
+              distribution.createPartitioning(numPartitions), child, shuffleOrigin)
+        }
     }
 
     // Get the indexes of children which have specified distribution requirements and need to be
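The dispatch above keys off `StatefulOpClusteredDistribution`, the child distribution that all stateful operators require: it pins both the clustering keys and the exact number of partitions, since the checkpointed state store layout is tied to that partitioning across restarts. A minimal sketch of how a stateful operator could state such a requirement — it mirrors the shape of the `requiredChildDistribution` hook, but the parameter names are illustrative, not copied from a specific operator:

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, StatefulOpClusteredDistribution}

// Illustrative sketch: a stateful operator pins its required child distribution
// to fixed grouping keys and a fixed partition count recorded with its state.
// `keyExpressions` and `numStatePartitions` stand in for the operator's real fields.
def requiredChildDistribution(
    keyExpressions: Seq[Attribute],
    numStatePartitions: Int): Seq[Distribution] =
  StatefulOpClusteredDistribution(keyExpressions, numStatePartitions) :: Nil

`EnsureRequirements` satisfies that distribution with a shuffle and, after this change, tags it with `REQUIRED_BY_STATEFUL_OPERATOR` instead of the generic `ENSURE_REQUIREMENTS` origin.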
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

@@ -177,6 +177,11 @@ case object REBALANCE_PARTITIONS_BY_NONE extends ShuffleOrigin
 // the output needs to be partitioned by the given columns.
 case object REBALANCE_PARTITIONS_BY_COL extends ShuffleOrigin
 
+// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule, but
+// was required by a stateful operator. The physical partitioning is static and Spark shouldn't
+// change it.
+case object REQUIRED_BY_STATEFUL_OPERATOR extends ShuffleOrigin
+
 /**
  * Performs a shuffle that will result in the desired partitioning.
  */
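The distinct origin matters because adaptive execution rules consult a shuffle's `ShuffleOrigin` before rewriting it; an origin outside a rule's allow-list leaves the exchange untouched. A rough sketch of that gating pattern — the object and method names below are hypothetical, not Spark's actual AQE rule:

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleOrigin}

// Hypothetical sketch of how an AQE-style rewrite gates on ShuffleOrigin.
object CoalesceLikeRule {
  // Origins this sketch is allowed to rewrite; REQUIRED_BY_STATEFUL_OPERATOR is
  // deliberately absent, so a stateful shuffle keeps its fixed partitioning.
  private val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)

  def canOptimize(plan: SparkPlan): Boolean = plan match {
    case s: ShuffleExchangeExec => supportedShuffleOrigins.contains(s.shuffleOrigin)
    case _ => false
  }
}

Because `REQUIRED_BY_STATEFUL_OPERATOR` does not appear in any such allow-list, runtime optimizations such as partition coalescing skip these exchanges — the "Spark shouldn't change it" guarantee stated in the comment above.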
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala

@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit}
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter}
 import org.apache.spark.sql.functions._

@@ -1448,6 +1448,28 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     }
   }
 
+  test("SPARK-49905 shuffle added by stateful operator should use the shuffle origin " +
+    "`REQUIRED_BY_STATEFUL_OPERATOR`") {
+    val inputData = MemoryStream[Int]
+
+    // Use the streaming aggregation as an example - all stateful operators are using the same
+    // distribution, named `StatefulOpClusteredDistribution`.
+    val df = inputData.toDF().groupBy("value").count()
+
+    testStream(df, OutputMode.Update())(
+      AddData(inputData, 1, 2, 3, 1, 2, 3),
+      CheckAnswer((1, 2), (2, 2), (3, 2)),
+      Execute { qe =>
+        val shuffleOpt = qe.lastExecution.executedPlan.collect {
+          case s: ShuffleExchangeExec => s
+        }
+
+        assert(shuffleOpt.nonEmpty, "No shuffle exchange found in the query plan")
+        assert(shuffleOpt.head.shuffleOrigin === REQUIRED_BY_STATEFUL_OPERATOR)
+      }
+    )
+  }
+
   private def checkAppendOutputModeException(df: DataFrame): Unit = {
     withTempDir { outputDir =>
       withTempDir { checkpointDir =>
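The test exercises streaming aggregation, but since every stateful operator requires the same `StatefulOpClusteredDistribution`, the assertion should hold for other stateful queries as well. A hedged sketch of a variant using streaming deduplication — not part of this commit, and assuming the same test harness with a fresh `MemoryStream[Int]`:

// Sketch only: a variant of the above check for a different stateful operator.
val inputData = MemoryStream[Int]
val dedupDf = inputData.toDF().dropDuplicates("value")

testStream(dedupDf, OutputMode.Update())(
  AddData(inputData, 1, 2, 2),
  CheckAnswer(1, 2),
  Execute { qe =>
    val shuffles = qe.lastExecution.executedPlan.collect {
      case s: ShuffleExchangeExec => s
    }
    assert(shuffles.nonEmpty, "No shuffle exchange found in the query plan")
    assert(shuffles.forall(_.shuffleOrigin === REQUIRED_BY_STATEFUL_OPERATOR))
  }
)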
