[SPARK-2213] [SQL] sort merge join for spark sql
Thanks to Ishiihara for the initial work in apache#3173

This PR introduces a new join method, sort merge join. It first ensures that keys with the same value land in the same partition and that the rows inside each partition are sorted by key; we can then run down both sides together and find matching rows using [sort merge join](http://en.wikipedia.org/wiki/Sort-merge_join). This way we don't have to store the whole hash table of one side, as hash join does, so memory usage is lower. This PR would also benefit from apache#3438, which makes the sorting phase much more efficient.

We introduce a new configuration, "spark.sql.planner.sortMergeJoin", to switch between sort merge join (`true`) and ShuffledHashJoin (`false`); we probably want its default value to be `false` at first.
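For intuition, here is a minimal, self-contained sketch of the merge step in plain Scala. It assumes both inputs are already partitioned on the join key and sorted within each partition; the operator added by this PR works on Spark `Row`s instead, but the shape is the same, and only one key's run of right-side rows is ever buffered.

```scala
// Inner merge join over two key-sorted sequences (illustrative sketch,
// not the PR's SortMergeJoin operator).
def mergeJoin[K, L, R](left: Seq[(K, L)], right: Seq[(K, R)])
    (implicit ord: Ordering[K]): Seq[(L, R)] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[(L, R)]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val c = ord.compare(left(i)._1, right(j)._1)
    if (c < 0) i += 1        // left key is smaller, advance left
    else if (c > 0) j += 1   // right key is smaller, advance right
    else {
      val key = left(i)._1
      val start = j
      while (j < right.length && ord.equiv(right(j)._1, key)) j += 1
      val matches = right.slice(start, j)  // buffer one side, one key at a time
      while (i < left.length && ord.equiv(left(i)._1, key)) {
        matches.foreach { case (_, r) => out += ((left(i)._2, r)) }
        i += 1
      }
    }
  }
  out.toSeq
}

mergeJoin(Seq(1 -> "a", 2 -> "b", 2 -> "c"), Seq(2 -> "x", 3 -> "y"))
// => Seq((b,x), (c,x))
```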

Author: Daoyuan Wang <daoyuan.wang@intel.com>
Author: Michael Armbrust <michael@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes apache#5208 from adrian-wang/smj and squashes the following commits:

2493b9f [Daoyuan Wang] fix style
5049d88 [Daoyuan Wang] propagate rowOrdering for RangePartitioning
f91a2ae [Daoyuan Wang] yin's comment: use external sort if option is enabled, add comments
f515cd2 [Daoyuan Wang] yin's comment: outputOrdering, join suite refine
ec8061b [Daoyuan Wang] minor change
413fd24 [Daoyuan Wang] Merge pull request #3 from marmbrus/pr/5208
952168a [Michael Armbrust] add type
5492884 [Michael Armbrust] copy when ordering
7ddd656 [Michael Armbrust] Cleanup addition of ordering requirements
b198278 [Daoyuan Wang] inherit ordering in project
c8e82a3 [Daoyuan Wang] fix style
6e897dd [Daoyuan Wang] hide boundReference from manually construct RowOrdering for key compare in smj
8681d73 [Daoyuan Wang] refactor Exchange and fix copy for sorting
2875ef2 [Daoyuan Wang] fix changed configuration
61d7f49 [Daoyuan Wang] add omitted comment
00a4430 [Daoyuan Wang] fix bug
078d69b [Daoyuan Wang] address comments: add comments, do sort in shuffle, and others
3af6ba5 [Daoyuan Wang] use buffer for only one side
171001f [Daoyuan Wang] change default outputordering
47455c9 [Daoyuan Wang] add apache license ...
a28277f [Daoyuan Wang] fix style
645c70b [Daoyuan Wang] address comments using sort
068c35d [Daoyuan Wang] fix new style and add some tests
925203b [Daoyuan Wang] address comments
07ce92f [Daoyuan Wang] fix ArrayIndexOutOfBound
42fca0e [Daoyuan Wang] code clean
e3ec096 [Daoyuan Wang] fix comment style..
2edd235 [Daoyuan Wang] fix outputpartitioning
57baa40 [Daoyuan Wang] fix sort eval bug
303b6da [Daoyuan Wang] fix several errors
95db7ad [Daoyuan Wang] fix brackets for if-statement
4464f16 [Daoyuan Wang] fix error
880d8e9 [Daoyuan Wang] sort merge join for spark sql
adrian-wang authored and marmbrus committed Apr 15, 2015
1 parent 4754e16 commit 585638e
Showing 11 changed files with 534 additions and 33 deletions.
@@ -17,8 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}

import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType}

/**
* An extended interface to [[Row]] that allows the values for each column to be updated. Setting
@@ -239,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
return 0
}
}

object RowOrdering {
def forSchema(dataTypes: Seq[DataType]): RowOrdering =
new RowOrdering(dataTypes.zipWithIndex.map {
case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
})
}
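The new `RowOrdering.forSchema` builds a position-based, all-ascending comparator straight from a list of column types; judging from commit 6e897dd above, this is how SortMergeJoin constructs its key comparator without hand-writing `BoundReference`s. A usage sketch, assuming the `Row` factory and type imports of this era of Spark SQL:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType}

// Compare (Int, String) rows: column 0 first, column 1 breaks ties.
val ordering = RowOrdering.forSchema(Seq(IntegerType, StringType))

ordering.compare(Row(1, "z"), Row(2, "a")) // < 0, decided by column 0
ordering.compare(Row(1, "a"), Row(1, "b")) // < 0, tie on column 0, "a" < "b"
```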
@@ -94,6 +94,9 @@ sealed trait Partitioning {
* only compatible if the `numPartitions` of them is the same.
*/
def compatibleWith(other: Partitioning): Boolean

/** Returns the expressions that are used to key the partitioning. */
def keyExpressions: Seq[Expression]
}

case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case UnknownPartitioning(_) => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
case SinglePartition => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
case SinglePartition => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

/**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}

override def keyExpressions: Seq[Expression] = expressions

override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}

override def keyExpressions: Seq[Expression] = ordering.map(_.child)

override def eval(input: Row): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
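Every concrete `Partitioning` now advertises the expressions it keys on via `keyExpressions`; `Exchange.canSortWithShuffle` in the Exchange.scala diff below is a plain subset test over them. A small sketch of that test, assuming two resolved attribute references `a` and `b`:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType, nullable = true)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val part = HashPartitioning(Seq(a, b), numPartitions = 200)

// The desired ordering a.asc references only key columns, so the shuffle
// itself can deliver sorted partitions and no separate Sort is needed.
val wanted = Seq(SortOrder(a, Ascending))
wanted.map(_.child).toSet.subsetOf(part.keyExpressions.toSet) // true
```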
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -47,6 +47,7 @@ private[spark] object SQLConf {
// Options that control which operators can be chosen by the query planner. These should be
// considered hints and may be ignored by future versions of Spark SQL.
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"

// This is only used for the thriftserver
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -128,6 +129,13 @@ private[sql] class SQLConf extends Serializable {
/** When true the planner will use the external sort, which may spill to disk. */
private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean

/**
* Sort merge join sorts the two sides of the join first, and then iterates both sides together
* only once to get all matches. Using sort merge join can save a lot of memory compared
* to HashJoin.
*/
private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean

/**
* When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
* that evaluates expressions found in queries. In general this custom code runs much faster
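Enabling the feature is a per-session setting. A usage sketch, assuming a `sqlContext` is in scope:

```scala
// Prefer SortMergeJoin over ShuffledHashJoin for equi-joins.
sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")

// Optionally let any Sort operators inserted by the planner spill to disk.
sqlContext.setConf("spark.sql.planner.externalSort", "true")
```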
@@ -1081,7 +1081,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
Batch("Add exchange", Once, AddExchange(self)) :: Nil
Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
}

protected[sql] def openSession(): SQLSession = {
148 changes: 120 additions & 28 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,24 +19,42 @@ package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair

object Exchange {
/**
* Returns true when the ordering expressions are a subset of the key.
* If true, ShuffledRDD can use `setKeyOrdering(keyOrdering)` to sort within [[Exchange]].
*/
def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
}
}

/**
* :: DeveloperApi ::
* Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
* resulting partition based on expressions from the partition key. It is invalid to construct an
* exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
*/
@DeveloperApi
case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
case class Exchange(
newPartitioning: Partitioning,
newOrdering: Seq[SortOrder],
child: SparkPlan)
extends UnaryNode {

override def outputPartitioning: Partitioning = newPartitioning

override def outputOrdering: Seq[SortOrder] = newOrdering

override def output: Seq[Attribute] = child.output

/** We must copy rows when sort based shuffle is on */
@@ -45,6 +63,20 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
private val bypassMergeThreshold =
child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)

private val keyOrdering = {
if (newOrdering.nonEmpty) {
val key = newPartitioning.keyExpressions
val boundOrdering = newOrdering.map { o =>
val ordinal = key.indexOf(o.child)
if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
}
new RowOrdering(boundOrdering)
} else {
null // Ordering will not be used
}
}

override def execute(): RDD[Row] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
@@ -56,7 +88,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
// we can avoid the defensive copies to improve performance. In the long run, we probably
// want to include information in shuffle dependencies to indicate whether elements in the
// source RDD should be copied.
val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold

val rdd = if (willMergeSort || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -69,12 +103,17 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
}
}
val part = new HashPartitioner(numPartitions)
val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
val shuffled =
if (newOrdering.nonEmpty) {
new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
} else {
new ShuffledRDD[Row, Row, Row](rdd, part)
}
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)

case RangePartitioning(sortingExpressions, numPartitions) =>
val rdd = if (sortBasedShuffleOn) {
val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
} else {
child.execute().mapPartitions { iter =>
@@ -87,7 +126,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
implicit val ordering = new RowOrdering(sortingExpressions, child.output)

val part = new RangePartitioner(numPartitions, rdd, ascending = true)
val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
val shuffled =
if (newOrdering.nonEmpty) {
new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
} else {
new ShuffledRDD[Row, Null, Null](rdd, part)
}
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))

shuffled.map(_._1)
@@ -120,27 +164,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
* of input data meets the
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
* each operator by inserting [[Exchange]] Operators where required.
* each operator by inserting [[Exchange]] Operators where required. Also ensures that the
* input partition ordering requirements for each operator are met.
*/
private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
def numPartitions: Int = sqlContext.conf.numShufflePartitions

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
// Check if every child's outputPartitioning satisfies the corresponding
// True iff every child's outputPartitioning satisfies the corresponding
// required data distribution.
def meetsRequirements: Boolean =
!operator.requiredChildDistribution.zip(operator.children).map {
operator.requiredChildDistribution.zip(operator.children).forall {
case (required, child) =>
val valid = child.outputPartitioning.satisfies(required)
logDebug(
s"${if (valid) "Valid" else "Invalid"} distribution," +
s"required: $required current: ${child.outputPartitioning}")
valid
}.exists(!_)
}

// Check if outputPartitionings of children are compatible with each other.
// True iff any of the children are incorrectly sorted.
def needsAnySort: Boolean =
operator.requiredChildOrdering.zip(operator.children).exists {
case (required, child) => required.nonEmpty && required != child.outputOrdering
}

// True iff outputPartitionings of children are compatible with each other.
// It is possible that every child satisfies its required data distribution
// but two children have incompatible outputPartitionings. For example,
// A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -157,28 +208,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
case Seq(a,b) => a compatibleWith b
}.exists(!_)

// Check if the partitioning we want to ensure is the same as the child's output
// partitioning. If so, we do not need to add the Exchange operator.
def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
// Adds Exchange or Sort operators as required
def addOperatorsIfNecessary(
partitioning: Partitioning,
rowOrdering: Seq[SortOrder],
child: SparkPlan): SparkPlan = {
val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
val needsShuffle = child.outputPartitioning != partitioning
val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)

if (needSort && needsShuffle && canSortWithShuffle) {
Exchange(partitioning, rowOrdering, child)
} else {
val withShuffle = if (needsShuffle) {
Exchange(partitioning, Nil, child)
} else {
child
}

if (meetsRequirements && compatible) {
val withSort = if (needSort) {
if (sqlContext.conf.externalSortEnabled) {
ExternalSort(rowOrdering, global = false, withShuffle)
} else {
Sort(rowOrdering, global = false, withShuffle)
}
} else {
withShuffle
}

withSort
}
}

if (meetsRequirements && compatible && !needsAnySort) {
operator
} else {
// At least one child does not satisfy its required data distribution or
// at least one child's outputPartitioning is not compatible with another child's
// outputPartitioning. In this case, we need to add Exchange operators.
val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map {
case (AllTuples, child) =>
addExchangeIfNecessary(SinglePartition, child)
case (ClusteredDistribution(clustering), child) =>
addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
case (OrderedDistribution(ordering), child) =>
addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
case (UnspecifiedDistribution, child) => child
case (dist, _) => sys.error(s"Don't know how to ensure $dist")
val requirements =
(operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)

val fixedChildren = requirements.zipped.map {
case (AllTuples, rowOrdering, child) =>
addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
case (ClusteredDistribution(clustering), rowOrdering, child) =>
addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
case (OrderedDistribution(ordering), rowOrdering, child) =>
addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)

case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
if (sqlContext.conf.externalSortEnabled) {
ExternalSort(rowOrdering, global = false, child)
} else {
Sort(rowOrdering, global = false, child)
}

case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}
operator.withNewChildren(repartitionedChildren)

operator.withNewChildren(fixedChildren)
}
}
}
@@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
def requiredChildDistribution: Seq[Distribution] =
Seq.fill(children.size)(UnspecifiedDistribution)

/** Specifies how data is ordered in each partition. */
def outputOrdering: Seq[SortOrder] = Nil

/** Specifies the sort ordering required, within each partition, of the input data from each child. */
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)

/**
* Runs this query returning the result as an RDD.
*/
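These two hooks are what `EnsureRequirements` consumes: `requiredChildOrdering` requests sorted input, and `outputOrdering` advertises what an operator preserves, so redundant sorts can be skipped. A hypothetical operator sketch, as it would appear inside org.apache.spark.sql.execution (the name and logic are invented for illustration and are not part of this PR):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}

// Needs each input partition sorted by `keys`; since it only drops rows,
// it can promise the same ordering on its output.
case class DropAdjacentDuplicates(keys: Seq[SortOrder], child: SparkPlan)
  extends UnaryNode {

  override def output: Seq[Attribute] = child.output
  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(keys)
  override def outputOrdering: Seq[SortOrder] = keys

  override def execute(): RDD[Row] =
    child.execute() // elided: one pass per partition dropping adjacent duplicates
}
```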
@@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

// If the sort merge join option is set, use sort merge join in preference to hash join.
// For now only inner joins are supported; outer joins will be added later.
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.sortMergeJoinEnabled =>
val mergeJoin =
joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil

case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
@@ -309,7 +317,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
execution.Exchange(
HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
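With the flag on, equi-joins plan as `SortMergeJoin` (inner joins only for now, per the comment above). A quick way to verify, assuming temp tables `a` and `b` are registered:

```scala
sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")

val joined = sqlContext.sql("SELECT a.key, b.value FROM a JOIN b ON a.key = b.key")
joined.explain() // the physical plan should show SortMergeJoin, not ShuffledHashJoin
```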
(Diffs for the remaining changed files are not shown.)