diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index bdb9a3129f036..82712822a6d06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -145,16 +145,19 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildOrderings.length == children.length) // Ensure that the operator's children satisfy their output distribution requirements. - children = children.zip(requiredChildDistributions).map { + // The second boolean parameter in the result is true when a ShuffleExchange + // was introduced to satisfy the output distribution. + val newChildren = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => - child + (child, false) case (child, BroadcastDistribution(mode)) => - BroadcastExchangeExec(mode, child) + (BroadcastExchangeExec(mode, child), false) case (child, distribution) => val numPartitions = distribution.requiredNumPartitions .getOrElse(defaultNumPreShufflePartitions) - ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child) + (ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child), true) } + children = newChildren.map(_._1) // Get the indexes of children which have specified distribution requirements and need to have // same number of partitions. @@ -178,7 +181,12 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { numPartitionsSet.headOption } - val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max) + val maxChildrenNumPartitions = math.abs(newChildren.map { + case (child, false) => child.outputPartitioning.numPartitions + case (child, true) => -child.outputPartitioning.numPartitions + }.max) + + val targetNumPartitions = requiredNumPartitions.getOrElse(maxChildrenNumPartitions) children = children.zip(requiredChildDistributions).zipWithIndex.map { case ((child, distribution), index) if childrenIndexes.contains(index) =>