Skip to content

Commit

Permalink
refactor Exchange and fix copy for sorting
Browse files (browse the repository at this point in the history)
  • Loading branch information
adrian-wang committed Apr 14, 2015
1 parent 2875ef2 commit 8681d73
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ import org.apache.spark.util.MutablePair
@DeveloperApi
case class Exchange(
newPartitioning: Partitioning,
child: SparkPlan,
sort: Boolean = false)
sort: Boolean,
child: SparkPlan)
extends UnaryNode {

override def outputPartitioning: Partitioning = newPartitioning
Expand All @@ -59,7 +59,7 @@ case class Exchange(
// we can avoid the defensive copies to improve performance. In the long run, we probably
// want to include information in shuffle dependencies to indicate whether elements in the
// source RDD should be copied.
val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
Expand Down Expand Up @@ -178,7 +178,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
val needSort = child.outputOrdering != rowOrdering
if (child.outputPartitioning != partitioning || needSort) {
// TODO: if only needSort is true (the partitioning already matches), we only need to sort each partition rather than add an Exchange
Exchange(partitioning, child, sort = needSort)
Exchange(partitioning, sort = needSort, child)
} else {
child
}
Expand All @@ -197,7 +197,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
addExchangeIfNecessary(SinglePartition, child)
case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
case (OrderedDistribution(ordering), (child, _)) =>
case (OrderedDistribution(ordering), (child, None)) =>
addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
case (UnspecifiedDistribution, (child, _)) => child
case (dist, _) => sys.error(s"Don't know how to ensure $dist")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
execution.Exchange(
HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
Expand Down

0 comments on commit 8681d73

Please sign in to comment.