diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6d4d4834dfc75..4a53e90dabcb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -212,9 +212,9 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } } - private def ensureChildNumPartitionsAgreementIfNecessary(operator: SparkPlan): SparkPlan = { + private def ensureChildPartitioningsAreCompatible(operator: SparkPlan): SparkPlan = { if (operator.requiresChildPartitioningsToBeCompatible) { - if (operator.children.map(_.outputPartitioning.numPartitions).distinct.size > 1) { + if (!Partitioning.allCompatible(operator.children.map(_.outputPartitioning))) { val newChildren = operator.children.zip(operator.requiredChildDistribution).map { case (child, requiredDistribution) => val targetPartitioning = canonicalPartitioning(requiredDistribution) @@ -271,6 +271,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - ensureDistributionAndOrdering(ensureChildNumPartitionsAgreementIfNecessary(operator)) + ensureDistributionAndOrdering(ensureChildPartitioningsAreCompatible(operator)) } }