Cleanup addition of ordering requirements #3

Merged: 3 commits, Apr 15, 2015
@@ -234,7 +234,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
}

object RowOrdering {
def getOrderingFromDataTypes(dataTypes: Seq[DataType]): RowOrdering =
def forSchema(dataTypes: Seq[DataType]): RowOrdering =
new RowOrdering(dataTypes.zipWithIndex.map {
case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
})
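For reference, the renamed helper builds an ascending ordering over every column of a schema. A minimal sketch of what `forSchema` expands to, assuming a hypothetical two-column (Int, String) schema and the Catalyst import paths used elsewhere in this patch:

import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, RowOrdering, SortOrder}
import org.apache.spark.sql.types.{IntegerType, StringType}

// Ordering built from the schema alone...
val bySchema: RowOrdering = RowOrdering.forSchema(Seq(IntegerType, StringType))
// ...is equivalent to spelling out one ascending, nullable SortOrder per column:
val explicit = new RowOrdering(Seq(
  SortOrder(BoundReference(0, IntegerType, nullable = true), Ascending),
  SortOrder(BoundReference(1, StringType, nullable = true), Ascending)))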
@@ -94,6 +94,9 @@ sealed trait Partitioning {
* only compatible if the `numPartitions` of them is the same.
*/
def compatibleWith(other: Partitioning): Boolean

/** Returns the expressions that are used to key the partitioning. */
def keyExpressions: Seq[Expression]
}

case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case UnknownPartitioning(_) => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
case SinglePartition => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
case SinglePartition => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

/**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}

override def keyExpressions: Seq[Expression] = expressions

override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}

override def keyExpressions: Seq[Expression] = ordering.map(_.child)

override def eval(input: Row): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
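In short, every Partitioning now advertises the expressions it partitions on, with Nil meaning "no usable key". A rough illustration of what the overrides above return, using made-up attributes (the AttributeReference construction is only for the example and is not part of the patch):

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, SinglePartition}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType, nullable = true)()
val b = AttributeReference("b", IntegerType, nullable = true)()

HashPartitioning(Seq(a, b), 200).keyExpressions                      // Seq(a, b)
RangePartitioning(Seq(SortOrder(a, Ascending)), 200).keyExpressions  // Seq(a), the SortOrder children
SinglePartition.keyExpressions                                       // Nil: no key to sort against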
@@ -1080,7 +1080,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
Batch("Add exchange", Once, AddExchange(self)) :: Nil
Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
}
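For context, this batch is what turns the planner's raw physical plan into an executable one. A usage sketch only (`sparkPlan` stands for whatever plan the strategies produced; the call follows the usual RuleExecutor API):

// Run the single preparation batch so that any required Exchange or Sort
// operators are inserted before execution.
val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan)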

protected[sql] def openSession(): SQLSession = {
149 changes: 102 additions & 47 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -28,21 +28,30 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair

object Exchange {
/** Returns true when the ordering expressions are a subset of the key. */
def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
}
}
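The check passes only when every ordering column is also a partitioning column, since those are the only columns the shuffle keys on. A small illustration under assumed attributes and partition count (not part of the patch):

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType, nullable = true)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val byA = HashPartitioning(Seq(a), 200)

Exchange.canSortWithShuffle(byA, Seq(SortOrder(a, Ascending)))  // true: ordering only touches key columns
Exchange.canSortWithShuffle(byA, Seq(SortOrder(b, Ascending)))  // false: b is not part of the key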

/**
* Shuffle data according to a new partition rule, and sort inside each partition if necessary.
* @param newPartitioning The new partitioning way that required by parent
* @param sort Whether we will sort inside each partition
* @param child Child operator
* :: DeveloperApi ::
* Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
* resulting partition based on expressions from the partition key. It is invalid to construct an
* exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
*/
@DeveloperApi
case class Exchange(
newPartitioning: Partitioning,
sort: Boolean,
newOrdering: Seq[SortOrder],
child: SparkPlan)
extends UnaryNode {

override def outputPartitioning: Partitioning = newPartitioning

override def outputOrdering: Seq[SortOrder] = newOrdering

override def output: Seq[Attribute] = child.output

/** We must copy rows when sort based shuffle is on */
@@ -51,6 +60,20 @@ case class Exchange(
private val bypassMergeThreshold =
child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)

private val keyOrdering = {
if (newOrdering.nonEmpty) {
val key = newPartitioning.keyExpressions
val boundOrdering = newOrdering.map { o =>
val ordinal = key.indexOf(o.child)
if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
}
new RowOrdering(boundOrdering)
} else {
null // Ordering will not be used
}
}

override def execute(): RDD[Row] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
@@ -62,7 +85,9 @@ case class Exchange(
// we can avoid the defensive copies to improve performance. In the long run, we probably
// want to include information in shuffle dependencies to indicate whether elements in the
// source RDD should be copied.
val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold

val rdd = if (willMergeSort || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -75,21 +100,17 @@ case class Exchange(
}
}
val part = new HashPartitioner(numPartitions)
val shuffled = sort match {
case false => new ShuffledRDD[Row, Row, Row](rdd, part)
case true =>
val sortingExpressions = expressions.zipWithIndex.map {
case (exp, index) =>
new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
}
val ordering = new RowOrdering(sortingExpressions, child.output)
new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
}
val shuffled =
if (newOrdering.nonEmpty) {
new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
} else {
new ShuffledRDD[Row, Row, Row](rdd, part)
}
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)

case RangePartitioning(sortingExpressions, numPartitions) =>
val rdd = if (sortBasedShuffleOn) {
val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
} else {
child.execute().mapPartitions { iter =>
@@ -102,7 +123,12 @@ case class Exchange(
implicit val ordering = new RowOrdering(sortingExpressions, child.output)

val part = new RangePartitioner(numPartitions, rdd, ascending = true)
val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
val shuffled =
if (newOrdering.nonEmpty) {
new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
} else {
new ShuffledRDD[Row, Null, Null](rdd, part)
}
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))

shuffled.map(_._1)
@@ -135,27 +161,35 @@ case class Exchange(
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
* of input data meets the
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
* each operator by inserting [[Exchange]] Operators where required.
* each operator by inserting [[Exchange]] Operators where required. Also ensure that the
* required input partition ordering requirements are met.
*/
private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
def numPartitions: Int = sqlContext.conf.numShufflePartitions

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
// Check if every child's outputPartitioning satisfies the corresponding
// True iff every child's outputPartitioning satisfies the corresponding
// required data distribution.
def meetsRequirements: Boolean =
!operator.requiredChildDistribution.zip(operator.children).map {
operator.requiredChildDistribution.zip(operator.children).forall {
case (required, child) =>
val valid = child.outputPartitioning.satisfies(required)
logDebug(
s"${if (valid) "Valid" else "Invalid"} distribution," +
s"required: $required current: ${child.outputPartitioning}")
valid
}.exists(!_)
}

// True iff any of the children are incorrectly sorted.
def needsAnySort: Boolean =
operator.requiredChildOrdering.zip(operator.children).exists {
case (required, child) => required.nonEmpty && required != child
}


// Check if outputPartitionings of children are compatible with each other.
// True iff outputPartitionings of children are compatible with each other.
// It is possible that every child satisfies its required data distribution
// but two children have incompatible outputPartitionings. For example,
// A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -172,40 +206,61 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
case Seq(a,b) => a compatibleWith b
}.exists(!_)

// Check if the partitioning we want to ensure is the same as the child's output
// partitioning. If so, we do not need to add the Exchange operator.
def addExchangeIfNecessary(
// Adds Exchange or Sort operators as required
def addOperatorsIfNecessary(
partitioning: Partitioning,
child: SparkPlan,
rowOrdering: Option[Ordering[Row]] = None): SparkPlan = {
val needSort = child.outputOrdering != rowOrdering
if (child.outputPartitioning != partitioning || needSort) {
// TODO: if only needSort, we need only sort each partition instead of an Exchange
Exchange(partitioning, sort = needSort, child)
rowOrdering: Seq[SortOrder],
child: SparkPlan): SparkPlan = {
val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
val needsShuffle = child.outputPartitioning != partitioning
val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)

if (needSort && needsShuffle && canSortWithShuffle) {
Exchange(partitioning, rowOrdering, child)
} else {
child
val withShuffle = if (needsShuffle) {
Exchange(partitioning, Nil, child)
} else {
child
}

val withSort = if (needSort) {
Sort(rowOrdering, global = false, withShuffle)
} else {
withShuffle
}

withSort
}
}

if (meetsRequirements && compatible) {
if (meetsRequirements && compatible && !needsAnySort) {
operator
} else {
// At least one child does not satisfies its required data distribution or
// at least one child's outputPartitioning is not compatible with another child's
// outputPartitioning. In this case, we need to add Exchange operators.
val repartitionedChildren = operator.requiredChildDistribution.zip(
operator.children.zip(operator.requiredChildOrdering)
).map {
case (AllTuples, (child, _)) =>
addExchangeIfNecessary(SinglePartition, child)
case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
case (OrderedDistribution(ordering), (child, None)) =>
addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
case (UnspecifiedDistribution, (child, _)) => child
case (dist, _) => sys.error(s"Don't know how to ensure $dist")
val requirements =
(operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)

val fixedChildren = requirements.zipped.map {
case (AllTuples, rowOrdering, child) =>
addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
case (ClusteredDistribution(clustering), rowOrdering, child) =>
addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
case (OrderedDistribution(ordering), rowOrdering, child) =>
addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), Nil, child)

case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
Sort(rowOrdering, global = false, child)

case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}
operator.withNewChildren(repartitionedChildren)

operator.withNewChildren(fixedChildren)
}
}
}
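Putting the pieces together, the rewritten rule leaves a child in one of three shapes, depending on which requirement is violated and whether the ordering can piggyback on the shuffle. A condensed sketch of the three outcomes; the attributes `a` and `b`, the partition count `n`, and `child` are placeholders, not from the patch:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.{Exchange, Sort, SparkPlan}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType, nullable = true)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val n = 200                 // stand-in for sqlContext.conf.numShufflePartitions
def child: SparkPlan = ???  // placeholder for an arbitrary child plan

// 1. Needs a shuffle, and the ordering uses only key columns:
//    one Exchange repartitions and sorts each partition in a single pass.
Exchange(HashPartitioning(Seq(a), n), Seq(SortOrder(a, Ascending)), child)

// 2. Needs a shuffle, but the ordering is not covered by the partitioning key:
//    shuffle first, then sort within each resulting partition.
Sort(Seq(SortOrder(b, Ascending)), global = false,
  Exchange(HashPartitioning(Seq(a), n), Nil, child))

// 3. Partitioning already satisfied, only the ordering is missing:
//    sort in place, no Exchange at all.
Sort(Seq(SortOrder(a, Ascending)), global = false, child)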
@@ -73,10 +73,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
Seq.fill(children.size)(UnspecifiedDistribution)

/** Specifies how data is ordered in each partition. */
def outputOrdering: Option[Ordering[Row]] = None
def outputOrdering: Seq[SortOrder] = Nil

/** Specifies sort order for each partition requirements on the input data for this operator. */
def requiredChildOrdering: Seq[Option[Ordering[Row]]] = Seq.fill(children.size)(None)
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
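Because ordering requirements are now plain `Seq[SortOrder]` rather than opaque `Ordering[Row]` objects, an operator can both demand and advertise an ordering in terms the planner can reason about. A hypothetical operator using both hooks might look like the following; the operator is invented for illustration, assumed to live in the same execution package as SparkPlan, and the import paths are assumptions:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}

// Invented example: needs its single child sorted by sortExprs within each
// partition, and preserves that ordering for its own parent.
case class RequiresSortedInput(sortExprs: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(sortExprs)
  override def outputOrdering: Seq[SortOrder] = sortExprs
  override def execute(): RDD[Row] = child.execute()  // pass-through; a real operator would do work here
}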

/**
* Runs this query returning the result as an RDD.
@@ -308,7 +308,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
execution.Exchange(
HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil
HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
@@ -42,12 +42,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
iter.map(resuableProjection)
}

/**
* outputOrdering of Project is not always same with child's outputOrdering if the certain
* key is pruned, however, if the key is pruned then we must not require child using this
* ordering from upper layer, so it is fine to keep it to avoid some unnecessary sorting.
*/
override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}

/**
@@ -63,7 +58,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
iter.filter(conditionEvaluator)
}

override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}

/**
@@ -111,7 +106,7 @@ case class Limit(limit: Int, child: SparkPlan)
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition

override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def executeCollect(): Array[Row] = child.executeTake(limit)

@@ -158,7 +153,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Pick num splits based on |limit|.
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)

override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
override def outputOrdering: Seq[SortOrder] = sortOrder
}

/**
Expand All @@ -185,7 +180,7 @@ case class Sort(

override def output: Seq[Attribute] = child.output

override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
override def outputOrdering: Seq[SortOrder] = sortOrder
}

/**
@@ -217,7 +212,7 @@ case class ExternalSort(

override def output: Seq[Attribute] = child.output

override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
override def outputOrdering: Seq[SortOrder] = sortOrder
}
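One practical consequence of these overrides: the planner can now compare orderings structurally, so a child that is already sorted the required way is not sorted again. A small sketch reusing an assumed attribute `a` and a placeholder `child` plan:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.execution.{Sort, SparkPlan}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType, nullable = true)()
def child: SparkPlan = ???  // placeholder for an arbitrary physical plan

val sorted = Sort(Seq(SortOrder(a, Ascending)), global = false, child)
// EnsureRequirements compares the two Seq[SortOrder] values directly, so:
sorted.outputOrdering == Seq(SortOrder(a, Ascending))  // true -- no second Sort is inserted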

/**