From cbeb4317ef512ec55c4f3b603a3fe6ebfc9bdb02 Mon Sep 17 00:00:00 2001
From: Sumedh Wale
Date: Thu, 15 Dec 2016 10:56:41 +0530
Subject: [PATCH] [SNAP-1251] Avoid exchange when number of shuffle partitions
 > child partitions (#37)

- reason is that shuffle is added first with default shuffle partitions, then
  the child with maximum partitions is selected; now marking children where
  implicit shuffle was introduced then taking max of rest (except if there are
  no others in which case the negative value gets chosen and its abs returns
  default shuffle partitions)
- use child.outputPartitioning.numPartitions for shuffle partition case
  instead of depending on it being defaultNumPreShufflePartitions
---
 .../exchange/EnsureRequirements.scala         | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index bdb9a3129f036..82712822a6d06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -145,16 +145,19 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     assert(requiredChildOrderings.length == children.length)
 
     // Ensure that the operator's children satisfy their output distribution requirements.
-    children = children.zip(requiredChildDistributions).map {
+    // The second boolean element in the result is true when a ShuffleExchange
+    // was introduced to satisfy the output distribution.
+    val newChildren = children.zip(requiredChildDistributions).map {
       case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-        child
+        (child, false)
       case (child, BroadcastDistribution(mode)) =>
-        BroadcastExchangeExec(mode, child)
+        (BroadcastExchangeExec(mode, child), false)
       case (child, distribution) =>
         val numPartitions = distribution.requiredNumPartitions
           .getOrElse(defaultNumPreShufflePartitions)
-        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
+        (ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child), true)
     }
+    children = newChildren.map(_._1)
 
     // Get the indexes of children which have specified distribution requirements and need to have
     // same number of partitions.
@@ -178,7 +181,12 @@
       numPartitionsSet.headOption
     }
 
-    val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max)
+    val maxChildrenNumPartitions = math.abs(newChildren.map {
+      case (child, false) => child.outputPartitioning.numPartitions
+      case (child, true) => -child.outputPartitioning.numPartitions
+    }.max)
+
+    val targetNumPartitions = requiredNumPartitions.getOrElse(maxChildrenNumPartitions)
 
     children = children.zip(requiredChildDistributions).zipWithIndex.map {
       case ((child, distribution), index) if childrenIndexes.contains(index) =>