
[SPARK-41471][SQL] Reduce Spark shuffle when only one side of a join is KeyGroupedPartitioning #42194

Closed
wants to merge 18 commits into from
15 changes: 15 additions & 0 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -137,6 +137,21 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
* A [[org.apache.spark.Partitioner]] that partitions all records using a partition value map,
* which maps each partition value to a partition id. The map is generated by
* [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]] and is used to
* partition the other side of a join so that records with the same partition value end up in
* the same partition.
*/
private[spark] class KeyGroupedPartitioner(
valueMap: mutable.Map[Seq[Any], Int],
override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = {
val keys = key.asInstanceOf[Seq[Any]]
valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
}
}
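A small illustrative sketch (not part of the patch; the values are made up, and it assumes code sitting inside the org.apache.spark package since the class is private[spark]): keys whose partition value was reported by the key-grouped side get that side's partition id, while unseen keys fall back to a hash-based id that is then cached in the map.

import scala.collection.mutable

// partition values reported by the key-grouped side, mapped to their partition ids
val valueMap = mutable.Map[Seq[Any], Int](Seq(1) -> 0, Seq(3) -> 1, Seq(4) -> 2)
val partitioner = new KeyGroupedPartitioner(valueMap, numPartitions = 3)

partitioner.getPartition(Seq(3))   // 1: matches a reported partition value
partitioner.getPartition(Seq(99))  // unseen key: nonNegativeMod(hashCode, 3), remembered in valueMap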

/**
* A [[org.apache.spark.Partitioner]] that partitions all records into a single partition.
*/
@@ -735,7 +735,13 @@ case class KeyGroupedShuffleSpec(
case _ => false
}

override def canCreatePartitioning: Boolean = false
override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// For now, only support partition expressions that are AttributeReference
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}
}
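A minimal sketch (assumed construction, not from the patch) of what this change enables: the key-grouped side's spec can now build a matching KeyGroupedPartitioning over the other side's join keys, reusing its own partition values, so only that side needs an exchange.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, KeyGroupedPartitioning, KeyGroupedShuffleSpec}
import org.apache.spark.sql.types.IntegerType

val id = AttributeReference("id", IntegerType)()          // key-grouped (storage-partitioned) side
val itemId = AttributeReference("item_id", IntegerType)() // side that will be shuffled

val partValues = Seq(1, 3, 4).map(v => InternalRow(v))
val spec = KeyGroupedShuffleSpec(
  KeyGroupedPartitioning(id :: Nil, partValues.length, partValues),
  ClusteredDistribution(id :: Nil))

// With spark.sql.sources.v2.bucketing.shuffle.enabled set, canCreatePartitioning returns true
// and EnsureRequirements can ask for a partitioning that shuffles only the other side:
val forOtherSide = spec.createPartitioning(itemId :: Nil)
// forOtherSide == KeyGroupedPartitioning(item_id :: Nil, 3, partValues)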

case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
@@ -1500,6 +1500,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_SHUFFLE_ENABLED =
buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled")
.doc("During a storage-partitioned join, whether to allow to shuffle only one side." +
"When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " +
"only shuffle the other side. This optimization will reduce the amount of data that " +
Contributor:
Shall we make the algorithm smarter? If the other side is large, doing a KeyGroupedPartitioning may lead to skew and it's still better to shuffle both sides with hash partitioning.

Let's think of an extreme case: one side reports KeyGroupedPartitioning with only one partition; with this optimization, we end up doing the join using a single thread.

Member:
I think the ShuffleSpec "framework" in EnsureRequirements already takes this into consideration. This PR mainly makes KeyGroupedShuffleSpec behave similarly to HashShuffleSpec and be able to shuffle the other side (via making canCreatePartitioning return true).

s"needs to be shuffle. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
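A minimal usage sketch (not part of the diff; catalog and table names are hypothetical). Both flags below must be enabled for the one-sided shuffle to apply:

// the new flag requires the existing storage-partitioned-join flag to be on
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.shuffle.enabled", "true")

// items is key-grouped (storage-partitioned) on id, purchases is unpartitioned; with the flags
// on, only the purchases side should get a shuffle exchange in the plan
spark.sql(
  "SELECT i.id, p.price FROM testcat.ns.items i JOIN testcat.ns.purchases p ON i.id = p.item_id"
).explain()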

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4879,6 +4889,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)

def v2BucketingShuffleEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

@@ -182,7 +182,16 @@ case class BatchScanExec(

// Now fill missing partition keys with empty partitions
val partitionMapping = nestGroupedPartitions.toMap
finalPartitions = spjParams.commonPartitionValues.get.flatMap {

// SPARK-41471: We sort the partition values in `commonPartitionValues` to
// make sure the order of partitions is deterministic across different cases.
val partitionDataTypes = p.expressions.map(_.dataType)
val partitionOrdering: Ordering[(InternalRow, Int)] = {
RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
}
val sortedCommonPartitionValues = spjParams.commonPartitionValues.get
.sorted(partitionOrdering)
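// Illustration (not part of the patch): with an Int partition key, common partition values
// reported as [(row(3), 2), (row(1), 1), (row(4), 1)] sort to
// [(row(1), 1), (row(3), 2), (row(4), 1)], so the resulting partition layout no longer
// depends on the order in which the data source reported them.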
finalPartitions = sortedCommonPartitionValues.flatMap {
case (partValue, numSplits) =>
// Use empty partition for those partition values that are not present.
partitionMapping.getOrElse(
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange

import java.util.function.Supplier

import scala.collection.mutable
import scala.concurrent.Future

import org.apache.spark._
@@ -29,6 +30,7 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
@@ -299,6 +301,11 @@ object ShuffleExchangeExec {
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
case k @ KeyGroupedPartitioning(expressions, n, _) =>
val valueMap = k.uniquePartitionValues.zipWithIndex.map {
case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
}.toMap
new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n)
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
@@ -325,6 +332,8 @@
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case KeyGroupedPartitioning(expressions, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
}

@@ -1040,6 +1040,187 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}

test("SPARK-41471: shuffle one side: only one side reports partitioning") {
val items_partitions = Array(identity("id"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 1, "should only shuffle the side that does not report partitioning")
} else {
assert(shuffles.size == 2, "should shuffle both sides when shuffling only one side" +
" is not enabled")
}

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
}
}
}

test("SPARK-41471: shuffle one side: shuffle side has more partition value") {
val items_partitions = Array(identity("id"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
"(6, 50.0, cast('2023-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
Seq("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN").foreach { joinType =>
val df = sql(s"SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i $joinType testcat.ns.$purchases p " +
"ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 1, "should only shuffle the side that does not report partitioning")
} else {
assert(shuffles.size == 2, "should shuffle both sides when shuffling only one " +
"side is not enabled")
}
joinType match {
case "JOIN" =>
checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
case "LEFT OUTER JOIN" =>
checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5),
Row(4, "cc", 15.5, null)))
case "RIGHT OUTER JOIN" =>
checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, null, 50.0),
Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
case "FULL OUTER JOIN" =>
checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, null, 50.0),
Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5),
Row(4, "cc", 15.5, null)))
}
}
}
}
}

test("SPARK-41471: shuffle one side: only one side reports partitioning with two identity") {
val items_partitions = Array(identity("id"), identity("arrive_time"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.id = p.item_id and i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 1, "should only shuffle the side that does not report partitioning")
} else {
assert(shuffles.size == 2, "should shuffle both sides when shuffling only one side" +
" is not enabled")
}

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
}
}
}

test("SPARK-41471: shuffle one side: partitioning with transform") {
val items_partitions = Array(years("arrive_time"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2021-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 2, "partitioning with transform not work now")
} else {
assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" +
" is not enabled")
}

checkAnswer(df, Seq(
Row(1, "aa", 40.0, 42.0),
Row(3, "bb", 10.0, 42.0),
Row(4, "cc", 15.5, 19.5)))
}
}
}

test("SPARK-41471: shuffle one side: work with group partition split") {
val items_partitions = Array(identity("id"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
"(6, 50.0, cast('2023-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
}
}
}

test("SPARK-44641: duplicated records when SPJ is not triggered") {
val items_partitions = Array(bucket(8, "id"))
createTable(items, items_schema, items_partitions)
@@ -18,15 +18,17 @@
package org.apache.spark.sql.execution.exchange

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.optimizer.BuildRight
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, _}
import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
@@ -1109,6 +1111,32 @@ class EnsureRequirementsSuite extends SharedSparkSession {
}
}

test(s"SPARK-41471: shuffle right side when" +
s" spark.sql.sources.v2.bucketing.shuffle.enabled is true") {
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {

val a1 = AttributeReference("a1", IntegerType)()

val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v)))
val plan1 = DummySparkPlan(outputPartitioning = KeyGroupedPartitioning(
identity(a1) :: Nil, 4, partitionValue))
val plan2 = DummySparkPlan(outputPartitioning = SinglePartition)

val smjExec = ShuffledHashJoinExec(
a1 :: Nil, a1 :: Nil, Inner, BuildRight, None, plan1, plan2)
EnsureRequirements.apply(smjExec) match {
case ShuffledHashJoinExec(_, _, _, _, _,
DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv),
DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
assert(left.expressions == a1 :: Nil)
assert(attrs == a1 :: Nil)
assert(partitionValue == pv)
case other => fail(other.toString)
}
}
}

test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
val lKey = AttributeReference("key", IntegerType)()
val lKey2 = AttributeReference("key2", IntegerType)()