[SPARK-41471][SQL] Reduce Spark shuffle when only one side of a join is KeyGroupedPartitioning #42194

Closed · wants to merge 18 commits
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -137,6 +137,18 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
* A [[org.apache.spark.Partitioner]] that partitions all records using a map from partition value to partition id.
*/
private[spark] class KeyGroupedPartitioner(
valueMap: mutable.Map[Seq[Any], Int],
override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = {
val keys = key.asInstanceOf[Seq[Any]]
valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
}
}
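
For intuition, here is a minimal standalone sketch of the partitioner's lookup behavior (plain Scala, not the Spark class above; keys modeled as Seq[Any], and nonNegativeMod re-implemented under the same contract as Utils.nonNegativeMod): keys whose partition values the key-grouped side reported hit the precomputed map, and unseen keys fall back to a non-negative hash modulo numPartitions.

import scala.collection.mutable

// Standalone sketch (assumed simplification), not the Spark class above.
object KeyGroupedPartitionerSketch {
  // Same contract as Utils.nonNegativeMod: result always falls in [0, mod).
  private def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    // Partition values reported by the key-grouped (storage-partitioned) side.
    val valueMap = mutable.Map[Seq[Any], Int](Seq(1) -> 0, Seq(3) -> 1, Seq(4) -> 2)

    def getPartition(key: Seq[Any]): Int =
      valueMap.getOrElseUpdate(key, nonNegativeMod(key.hashCode, numPartitions))

    println(getPartition(Seq(3)))   // 1: matches a known partition value
    println(getPartition(Seq(99)))  // hashed fallback for a key the table never reported
  }
}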

/**
* A [[org.apache.spark.Partitioner]] that partitions all records into a single partition.
*/
@@ -735,7 +735,13 @@ case class KeyGroupedShuffleSpec(
case _ => false
}

override def canCreatePartitioning: Boolean = false
override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleOneSideEnabled &&
// Only support partition expressions that are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}
}

case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
@@ -1500,6 +1500,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_SHUFFLE_ONE_SIDE_ENABLED =
buildConf("spark.sql.sources.v2.bucketing.shuffleOneSide.enabled")
.doc("During a storage-partitioned join, whether to allow to shuffle only one side." +
"When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " +
"only shuffle the other side. This optimization will reduce the amount of data that " +
Review comment (Contributor):
Shall we make the algorithm smarter? If the other side is large, doing a KeyGroupedPartitioning may lead to skew and it's still better to shuffle both sides with hash partitioning.

Let's think of an extreme case: one side reports KeyGroupedPartitioning with only one partition; with this optimization, we end up doing the join using a single thread.

Reply (Member):
I think the ShuffleSpec "framework" in EnsureRequirements already takes this into consideration. This PR mainly makes KeyGroupedShuffleSpec behave similarly to HashShuffleSpec and become able to shuffle the other side (by making canCreatePartitioning return true).

s"needs to be shuffle. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
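
A usage sketch (assumed session-level setup, not part of the diff): the new key comes from this PR, the pre-existing v2 bucketing flag is assumed to be spark.sql.sources.v2.bucketing.enabled, and the table names mirror the test suite below for illustration only.

// Both flags need to be on for the one-side shuffle to apply.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.shuffleOneSide.enabled", "true")

// Join a key-grouped (storage-partitioned) table with an unpartitioned one.
val df = spark.sql(
  """SELECT i.id, i.name, p.price
    |FROM testcat.ns.items i JOIN testcat.ns.purchases p ON i.id = p.item_id""".stripMargin)

// With the flags enabled, the plan is expected to contain a single shuffle:
// only the unpartitioned purchases side is re-partitioned to match items.
df.explain()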

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4877,6 +4887,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)

def v2BucketingShuffleOneSideEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_SHUFFLE_ONE_SIDE_ENABLED)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange

import java.util.function.Supplier

import scala.collection.mutable
import scala.concurrent.Future

import org.apache.spark._
@@ -29,6 +30,7 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
@@ -299,6 +301,12 @@ object ShuffleExchangeExec {
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
case KeyGroupedPartitioning(expressions, n, partitionValues) =>
val partitionValueMap = mutable.Map[Seq[Any], Int]()
partitionValues.zipWithIndex.foreach(partAndIndex => {
partitionValueMap(partAndIndex._1.toSeq(expressions.map(_.dataType))) = partAndIndex._2
})
new KeyGroupedPartitioner(partitionValueMap, n)
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
@@ -325,6 +333,8 @@
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case KeyGroupedPartitioning(expressions, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
}

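Putting the two new branches together, a plain-Scala sketch of the routing (stand-ins for InternalRow and bound expressions; names are illustrative): the key-grouped side's partition values seed the value map, the extractor turns each row into its key values, and each row lands in the partition that already holds the matching table split.

import scala.collection.mutable

val partitionValuesFromTable = Seq(Seq[Any](1), Seq[Any](3), Seq[Any](4))
val valueMap: mutable.Map[Seq[Any], Int] =
  mutable.Map(partitionValuesFromTable.zipWithIndex: _*)

// Stand-in for bindReferences(expressions, outputAttributes).map(_.eval(row)).
val extractKey: Map[String, Any] => Seq[Any] = row => Seq(row("item_id"))

val incomingRows = Seq(
  Map[String, Any]("item_id" -> 3, "price" -> 19.5),
  Map[String, Any]("item_id" -> 1, "price" -> 42.0))

incomingRows.foreach { row =>
  val key = extractKey(row)
  // Known keys hit the map; unseen keys would fall back to a hash,
  // as in the KeyGroupedPartitioner sketch above.
  val pid = valueMap.getOrElseUpdate(key, Math.floorMod(key.hashCode, partitionValuesFromTable.size))
  println(s"$row -> partition $pid")
}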
@@ -1040,6 +1040,108 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}

test("SPARK-41471: shuffle one side: only one side reports partitioning") {
val items_partitions = Array(identity("id"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

Seq(true, false).foreach { shuffleOneSide =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ONE_SIDE_ENABLED.key -> shuffleOneSide.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffleOneSide) {
assert(shuffles.size == 1, "should shuffle only the side that does not report partitioning")
} else {
assert(shuffles.size == 2, "should shuffle both sides when shuffle-one-side is not enabled")
}

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
}
}
}

test("SPARK-41471: shuffle one side: only one side reports partitioning with two identity") {
val items_partitions = Array(identity("id"), identity("arrive_time"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

Seq(true, false).foreach { shuffleOneSide =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ONE_SIDE_ENABLED.key -> shuffleOneSide.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.id = p.item_id and i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffleOneSide) {
assert(shuffles.size == 1, "should shuffle only the side that does not report partitioning")
} else {
assert(shuffles.size == 2, "should shuffle both sides when shuffle-one-side is not enabled")
}

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
}
}
}

test("SPARK-41471: shuffle one side: partitioning with transform") {
val items_partitions = Array(years("arrive_time"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2021-02-01' as timestamp))")

Seq(true, false).foreach { shuffleOneSide =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ONE_SIDE_ENABLED.key -> shuffleOneSide.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffleOneSide) {
assert(shuffles.size == 2, "shuffling only one side does not support transform expressions yet, so both sides are shuffled")
} else {
assert(shuffles.size == 2, "should shuffle both sides when shuffle-one-side is not enabled")
}

checkAnswer(df, Seq(
Row(1, "aa", 40.0, 42.0),
Row(3, "bb", 10.0, 42.0),
Row(4, "cc", 15.5, 19.5)))
}
}
}

test("SPARK-44641: duplicated records when SPJ is not triggered") {
val items_partitions = Array(bucket(8, "id"))
createTable(items, items_schema, items_partitions)