[SPARK-9703] [SQL] Refactor EnsureRequirements to avoid certain unnecessary shuffles #7988

Closed
wants to merge 22 commits
Changes from 16 commits
Commits
2dfc648
Add failing test illustrating bad exchange planning.
JoshRosen Aug 6, 2015
cc5669c
Adding outputPartitioning to Repartition does not fix the test.
JoshRosen Aug 6, 2015
0675956
Preserving ordering and partitioning in row format converters also do…
JoshRosen Aug 6, 2015
adcc742
Move test to PlannerSuite.
JoshRosen Aug 6, 2015
c9fb231
Rewrite exchange to better handle this case.
JoshRosen Aug 6, 2015
c628daf
Revert accidental ExchangeSuite change.
JoshRosen Aug 6, 2015
752b8de
style fix
JoshRosen Aug 6, 2015
2e0f33a
Write a more generic test for EnsureRequirements.
JoshRosen Aug 6, 2015
5172ac5
Add test for requiresChildrenToProduceSameNumberOfPartitions.
JoshRosen Aug 6, 2015
0725a34
Small assertion cleanup.
JoshRosen Aug 6, 2015
a1c12b9
Add failing test to demonstrate allCompatible bug
JoshRosen Aug 6, 2015
4f08278
Fix the test by adding the compatibility check to EnsureRequirements
JoshRosen Aug 6, 2015
8dbc845
Add even more tests.
JoshRosen Aug 6, 2015
06aba0c
Add more comments
JoshRosen Aug 6, 2015
642b0bb
Further expand comment / reasoning
JoshRosen Aug 6, 2015
fee65c4
Further refinement to comments / reasoning
JoshRosen Aug 6, 2015
2c7e126
Merge remote-tracking branch 'origin/master' into exchange-fixes
JoshRosen Aug 8, 2015
18cddeb
Rename DummyPlan to DummySparkPlan.
JoshRosen Aug 8, 2015
1307c50
Update conditions for requiring child compatibility.
JoshRosen Aug 8, 2015
8784bd9
Giant comment explaining compatibleWith vs. guarantees
JoshRosen Aug 8, 2015
0983f75
More guarantees vs. compatibleWith cleanup; delete BroadcastPartition…
JoshRosen Aug 9, 2015
38006e7
Rewrite EnsureRequirements _yet again_ to make things even simpler
JoshRosen Aug 9, 2015
@@ -95,6 +95,22 @@ sealed trait Partitioning {
def guarantees(other: Partitioning): Boolean
}

object Partitioning {
def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
// Note: this assumes transitivity
partitionings.sliding(2).map {
Contributor: Nit: we can use forall instead of map here and remove the forall at the end.

case Seq(a) => true
case Seq(a, b) =>
if (a.numPartitions != b.numPartitions) {
assert(!a.guarantees(b) && !b.guarantees(a))
Contributor: Since we are using guarantees to check whether sibling operators are compatible, how about we just bring back compatibleWith?

Contributor Author: This sounds good to me. Let's revisit this naming later once we actually break this symmetry.

false
} else {
a.guarantees(b) && b.guarantees(a)
}
}.forall(_ == true)
}
}

case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
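A minimal, self-contained sketch of the two suggestions in the threads above — forall over adjacent pairs, and "compatible" expressed as mutual guarantees (a possible compatibleWith). The Partitioning type here is a hypothetical stand-in, not Spark's physical.Partitioning:

  // Toy stand-in for Spark's Partitioning, used only to illustrate the shape
  // of the reviewers' suggestions.
  final case class Partitioning(numPartitions: Int, keys: Seq[String]) {
    def guarantees(other: Partitioning): Boolean =
      numPartitions == other.numPartitions && keys == other.keys
  }

  // "Compatible" as discussed above: mutual guarantees.
  def compatible(a: Partitioning, b: Partitioning): Boolean =
    a.guarantees(b) && b.guarantees(a)

  // forall over adjacent pairs replaces the map(...).forall(_ == true) pattern;
  // like the original, this relies on compatibility being transitive.
  def allCompatible(partitionings: Seq[Partitioning]): Boolean =
    partitionings.sliding(2).forall {
      case Seq(_)    => true
      case Seq(a, b) => compatible(a, b)
    }

  // allCompatible(Seq(Partitioning(200, Seq("x")), Partitioning(200, Seq("x"))))  // true
  // allCompatible(Seq(Partitioning(200, Seq("x")), Partitioning(100, Seq("x"))))  // false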
132 changes: 87 additions & 45 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -197,66 +197,108 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
* of input data meets the
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
* each operator by inserting [[Exchange]] Operators where required. Also ensure that the
* required input partition ordering requirements are met.
* input partition ordering requirements are met.
*/
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
def numPartitions: Int = sqlContext.conf.numShufflePartitions
private def numPartitions: Int = sqlContext.conf.numShufflePartitions

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
// Adds Exchange or Sort operators as required
def addOperatorsIfNecessary(
partitioning: Partitioning,
rowOrdering: Seq[SortOrder],
child: SparkPlan): SparkPlan = {

def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
if (!child.outputPartitioning.guarantees(partitioning)) {
Exchange(partitioning, child)
} else {
child
}
}
/**
* Given a required distribution, returns a partitioning that satisfies that distribution.
*/
private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
requiredDistribution match {
case AllTuples => SinglePartition
case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
case dist => sys.error(s"Do not know how to satisfy distribution $dist")
}
}

def addSortIfNecessary(child: SparkPlan): SparkPlan = {
/**
* Return true if all of the operator's children satisfy their output distribution requirements.
*/
private def childPartitioningsSatisfyDistributionRequirements(operator: SparkPlan): Boolean = {
operator.children.zip(operator.requiredChildDistribution).forall {
case (child, distribution) => child.outputPartitioning.satisfies(distribution)
}
}

if (rowOrdering.nonEmpty) {
// If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
} else {
/**
* Given an operator, check whether the operator requires its children to have compatible
* output partitionings and add Exchanges to fix any detected incompatibilities.
*/
private def ensureChildPartitioningsAreCompatible(operator: SparkPlan): SparkPlan = {
if (operator.requiresChildPartitioningsToBeCompatible) {
Contributor: Seems requiresChildPartitioningsToBeCompatible is equivalent to operator.children.length > 1?

Contributor Author: Hmm. What about Union?

Contributor: The default return value of requiredChildDistribution is Seq.fill(children.size)(UnspecifiedDistribution). Union and those broadcast joins should be fine.

if (!Partitioning.allCompatible(operator.children.map(_.outputPartitioning))) {
val newChildren = operator.children.zip(operator.requiredChildDistribution).map {
case (child, requiredDistribution) =>
val targetPartitioning = canonicalPartitioning(requiredDistribution)
if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
} else {
Exchange(targetPartitioning, child)
}
} else {
child
}
}

addSortIfNecessary(addShuffleIfNecessary(child))
val newOperator = operator.withNewChildren(newChildren)
assert(childPartitioningsSatisfyDistributionRequirements(newOperator))
newOperator
} else {
operator
}
} else {
operator
}
}
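To make the Union point from the thread above concrete, here is a small self-contained sketch. The Distribution and Op types and the needsCompatibleChildren check are toy stand-ins for illustration only, not part of this patch; they show why children.length > 1 alone would be the wrong trigger:

  sealed trait Distribution
  case object UnspecifiedDistribution extends Distribution
  final case class ClusteredDistribution(keys: Seq[String]) extends Distribution

  // Toy operator: only the required child distributions matter here.
  final case class Op(name: String, requiredChildDistribution: Seq[Distribution]) {
    def numChildren: Int = requiredChildDistribution.length
  }

  // Hypothetical alternative to a dedicated flag: demand compatible child
  // partitionings only when there are several children and at least one child
  // has a real (non-Unspecified) distribution requirement.
  def needsCompatibleChildren(op: Op): Boolean =
    op.numChildren > 1 && op.requiredChildDistribution.exists(_ != UnspecifiedDistribution)

  val union = Op("Union", Seq.fill(2)(UnspecifiedDistribution))
  val shuffledHashJoin = Op("ShuffledHashJoin",
    Seq(ClusteredDistribution(Seq("a")), ClusteredDistribution(Seq("b"))))

  assert(!needsCompatibleChildren(union))            // two children, but exempt
  assert(needsCompatibleChildren(shuffledHashJoin))  // clustered requirement on both sides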

val requirements =
(operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {

val fixedChildren = requirements.zipped.map {
case (AllTuples, rowOrdering, child) =>
addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
case (ClusteredDistribution(clustering), rowOrdering, child) =>
addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
case (OrderedDistribution(ordering), rowOrdering, child) =>
addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
def addShuffleIfNecessary(child: SparkPlan, requiredDistribution: Distribution): SparkPlan = {
// A pre-condition of ensureDistributionAndOrdering is that joins' children have compatible
Contributor Author: This piece of reasoning is the trickiest part of this entire patch. Is this a valid argument given the current semantics of guarantees() and satisfies()?

Contributor Author: This will be cleaned up tomorrow when I consolidate the shuffle steps and make the assertions / invariants clearer.

// partitionings. Thus, we only need to check whether the output partitionings satisfy
// the required distribution. In the case where the children are all compatible, then they
// will either all satisfy the required distribution or will all fail to satisfy it, since
// A.guarantees(B) implies that A and B satisfy the same set of distributions.
// Therefore, if all children are compatible then either all or none of them will be shuffled to
// ensure that the distribution requirements are met.
//
// Note that this reasoning implicitly assumes that operators which require compatible
// child partitionings have equivalent required distributions for those children.
if (child.outputPartitioning.satisfies(requiredDistribution)) {
child
} else {
Exchange(canonicalPartitioning(requiredDistribution), child)
}
}

case (UnspecifiedDistribution, Seq(), child) =>
def addSortIfNecessary(child: SparkPlan, requiredOrdering: Seq[SortOrder]): SparkPlan = {
if (requiredOrdering.nonEmpty) {
// If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child)
} else {
child
case (UnspecifiedDistribution, rowOrdering, child) =>
sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)

case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}
} else {
child
}
}

val children = operator.children
val requiredChildDistribution = operator.requiredChildDistribution
val requiredChildOrdering = operator.requiredChildOrdering
assert(children.length == requiredChildDistribution.length)
assert(children.length == requiredChildOrdering.length)
val newChildren = (children, requiredChildDistribution, requiredChildOrdering).zipped.map {
case (child, requiredDistribution, requiredOrdering) =>
addSortIfNecessary(addShuffleIfNecessary(child, requiredDistribution), requiredOrdering)
}
operator.withNewChildren(newChildren)
}

operator.withNewChildren(fixedChildren)
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
ensureDistributionAndOrdering(ensureChildPartitioningsAreCompatible(operator))
}
}
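The comment flagged above as the trickiest reasoning in the patch rests on one property: if two partitionings guarantee each other, they satisfy exactly the same distributions, so compatible siblings are either all left in place or all shuffled. Below is a self-contained toy model of that property; HashPartitioning and ClusteredDistribution here are simplified stand-ins, not Spark's actual classes:

  sealed trait Distribution
  final case class ClusteredDistribution(keys: Seq[String]) extends Distribution

  final case class HashPartitioning(keys: Seq[String], numPartitions: Int) {
    // Satisfied when rows are hash-clustered by a subset of the required keys.
    def satisfies(d: Distribution): Boolean = d match {
      case ClusteredDistribution(required) => keys.nonEmpty && keys.forall(required.contains)
    }
    // Guarantees another hash partitioning only when it is identical.
    def guarantees(other: HashPartitioning): Boolean =
      keys == other.keys && numPartitions == other.numPartitions
  }

  val left  = HashPartitioning(Seq("x"), 200)
  val right = HashPartitioning(Seq("x"), 200)
  val requirement = ClusteredDistribution(Seq("x", "y"))

  // Compatible siblings...
  assert(left.guarantees(right) && right.guarantees(left))
  // ...agree on every distribution requirement, so EnsureRequirements either
  // leaves both children in place or shuffles both of them.
  assert(left.satisfies(requirement) == right.satisfies(requirement))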
@@ -109,6 +109,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Specifies sort order for each partition requirements on the input data for this operator. */
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)

/**
* Specifies whether this operator requires all of its children to have [[outputPartitioning]]s
* that are compatible with each other.
*/
def requiresChildPartitioningsToBeCompatible: Boolean = false

/** Specifies whether this operator outputs UnsafeRows */
def outputsUnsafeRows: Boolean = false

@@ -245,6 +245,11 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan)
extends UnaryNode {
override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = {
if (numPartitions == 1) SinglePartition
else UnknownPartitioning(numPartitions)
}

protected override def doExecute(): RDD[InternalRow] = {
child.execute().map(_.copy()).coalesce(numPartitions, shuffle)
}
@@ -42,6 +42,8 @@ case class LeftSemiJoinHash(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def requiresChildPartitioningsToBeCompatible: Boolean = true

protected override def doExecute(): RDD[InternalRow] = {
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
if (condition.isEmpty) {
@@ -46,6 +46,8 @@ case class ShuffledHashJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def requiresChildPartitioningsToBeCompatible: Boolean = true

protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
@@ -44,6 +44,8 @@ case class ShuffledHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def requiresChildPartitioningsToBeCompatible: Boolean = true

override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
@@ -48,6 +48,8 @@ case class SortMergeJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def requiresChildPartitioningsToBeCompatible: Boolean = true

override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule

/**
@@ -33,6 +34,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe")

override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = false
override def canProcessSafeRows: Boolean = true
@@ -51,6 +54,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
@DeveloperApi
case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def outputsUnsafeRows: Boolean = false
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = false
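The converter changes above forward the child's partitioning and ordering. A toy sketch (hypothetical Plan/Scan/Convert types, not Spark's classes) of why such pass-through overrides matter for avoiding needless exchanges:

  // Without a pass-through override, a row-format converter would report an
  // unknown partitioning, and the planner would insert a shuffle even though
  // its child is already hash-partitioned on the required keys.
  sealed trait Partitioning
  case object UnknownPartitioning extends Partitioning
  final case class HashPartitioning(keys: Seq[String]) extends Partitioning

  trait Plan { def outputPartitioning: Partitioning = UnknownPartitioning }
  final case class Scan(partitioning: Partitioning) extends Plan {
    override def outputPartitioning: Partitioning = partitioning
  }
  // Pass-through converter, mirroring the ConvertToSafe/ConvertToUnsafe change.
  final case class Convert(child: Plan) extends Plan {
    override def outputPartitioning: Partitioning = child.outputPartitioning
  }

  val child = Scan(HashPartitioning(Seq("x")))
  assert(Convert(child).outputPartitioning == HashPartitioning(Seq("x")))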