From 8681d73bb53d7747d3122ae0f17ba5acdb4ea9a4 Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Mon, 13 Apr 2015 22:45:48 -0700
Subject: [PATCH] refactor Exchange and fix copy for sorting

---
 .../org/apache/spark/sql/execution/Exchange.scala    | 10 +++++-----
 .../apache/spark/sql/execution/SparkStrategies.scala |  3 ++-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 7bdfbb8ec4e7a..ba866357f8bf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -34,8 +34,8 @@ import org.apache.spark.util.MutablePair
 @DeveloperApi
 case class Exchange(
     newPartitioning: Partitioning,
-    child: SparkPlan,
-    sort: Boolean = false)
+    sort: Boolean,
+    child: SparkPlan)
   extends UnaryNode {
 
   override def outputPartitioning: Partitioning = newPartitioning
@@ -59,7 +59,7 @@ case class Exchange(
         // we can avoid the defensive copies to improve performance. In the long run, we probably
         // want to include information in shuffle dependencies to indicate whether elements in the
         // source RDD should be copied.
-        val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
+        val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -178,7 +178,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
       val needSort = child.outputOrdering != rowOrdering
       if (child.outputPartitioning != partitioning || needSort) {
         // TODO: if only needSort, we need only sort each partition instead of an Exchange
-        Exchange(partitioning, child, sort = needSort)
+        Exchange(partitioning, sort = needSort, child)
       } else {
         child
       }
@@ -197,7 +197,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
           addExchangeIfNecessary(SinglePartition, child)
         case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
           addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
-        case (OrderedDistribution(ordering), (child, _)) =>
+        case (OrderedDistribution(ordering), (child, None)) =>
          addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
         case (UnspecifiedDistribution, (child, _)) => child
         case (dist, _) => sys.error(s"Don't know how to ensure $dist")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 519ef5d93154c..c6ff8c30c24e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -307,7 +307,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
-        execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+        execution.Exchange(
+          HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil
       case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) =>
         PhysicalRDD(output, rdd) :: Nil
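
Why the added `|| sort` condition matters: SparkPlan.execute() may return an
iterator that reuses a single mutable row object per partition. Anything that
buffers those rows (a sort, or the sort-based shuffle writer) must take
defensive copies first, or every buffered entry ends up aliasing the last row
written. The following stand-alone Scala sketch reproduces that failure mode;
MutableRow and partitionIterator here are illustrative stand-ins, not Spark's
actual Row API.

// A row type mutated in place with an explicit copy(), mimicking the
// reuse pattern of Spark SQL's mutable rows.
final class MutableRow(var value: Int) {
  def copy(): MutableRow = new MutableRow(value)
  override def toString: String = s"MutableRow($value)"
}

object CopyForSortingDemo {
  // Simulates a partition iterator that reuses one row object, as
  // child.execute() is allowed to do.
  def partitionIterator(data: Seq[Int]): Iterator[MutableRow] = {
    val reused = new MutableRow(0)
    data.iterator.map { v => reused.value = v; reused }
  }

  def main(args: Array[String]): Unit = {
    // Buffering without copy(): all three array slots alias the same
    // object, so a sort over them sees only the final value written.
    val aliased = partitionIterator(Seq(3, 1, 2)).toArray
    println(aliased.mkString(", "))
    // prints: MutableRow(2), MutableRow(2), MutableRow(2)

    // Defensive copy before buffering, as the patched Exchange does
    // whenever `sort` is true or sort-based shuffle will buffer rows.
    val copied = partitionIterator(Seq(3, 1, 2)).map(_.copy()).toArray
    println(copied.sortBy(_.value).mkString(", "))
    // prints: MutableRow(1), MutableRow(2), MutableRow(3)
  }
}

On the constructor reorder: moving `sort` before `child` keeps the child plan
as the last argument, so call sites like Exchange(partitioning, sort =
needSort, child) read naturally. Dropping the `sort: Boolean = false` default
also forces every caller to state intent explicitly, which is why the
SparkStrategies hunk now passes sort = false for Repartition.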