[SPARK-41471][SQL] Reduce Spark shuffle when only one side of a join is KeyGroupedPartitioning

### What changes were proposed in this pull request?
When only one side of an SPJ (Storage-Partitioned Join) is KeyGroupedPartitioning, Spark currently shuffles both sides using HashPartitioning. However, it may be enough to shuffle only the other side according to the partition transforms defined in the KeyGroupedPartitioning. This is especially useful when the other side is relatively small. (A usage sketch follows the list below.)
1. Add a new config `spark.sql.sources.v2.bucketing.shuffle.enabled` to control whether this feature is enabled.
2. Add `KeyGroupedPartitioner`, used to partition the other side when the transform values of the KeyGroupedPartitioning side are known. In `EnsureRequirements`, Spark already knows the mapping from partition value to partition id on the KeyGroupedPartitioning side; that mapping is stored in `KeyGroupedPartitioner` and used to shuffle the other side, so that records with the same key end up in the same partition.
3. Only the `identity` transform works for now, because the same transform can report different values depending on whether it comes from the DS V2 connector implementation or from the catalog function. Until this problem is solved, only `identity` is supported. E.g., in the test package, compare `YearFunction` in https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala#L47 with https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala#L143
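
A minimal usage sketch (not part of this PR; the catalog, table, and column names are hypothetical): with the new config enabled, a join between a table that reports KeyGroupedPartitioning and an unpartitioned table should plan a single shuffle on the unpartitioned side only.

```scala
// Hypothetical session-level setup; both configs are assumed to be required.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.shuffle.enabled", "true")

// `cat.db.items` is partitioned by identity(id) (reports KeyGroupedPartitioning);
// `cat.db.purchases` is unpartitioned.
val df = spark.sql(
  """SELECT i.id, p.price
    |FROM cat.db.items i JOIN cat.db.purchases p ON i.id = p.item_id""".stripMargin)

// With the feature enabled, the physical plan should contain a single
// ShuffleExchangeExec (on the purchases side) instead of shuffling both sides.
df.explain()
```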

### Why are the changes needed?
Reduces data shuffle in specific SPJ scenarios.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added new tests.

Closes #42194 from Hisoka-X/SPARK-41471_one_side_keygroup.

Authored-by: Jia Fan <fanjiaeminem@qq.com>
Signed-off-by: Chao Sun <sunchao@apple.com>
Hisoka-X authored and sunchao committed Aug 24, 2023
1 parent 517fcd3 commit ce12f6d
Showing 8 changed files with 263 additions and 5 deletions.
16 changes: 16 additions & 0 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -137,6 +137,22 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
* A [[org.apache.spark.Partitioner]] that partitions all records using a partition value map.
* The `valueMap` is a map that contains tuples of (partition value, partition id). It is generated
* by [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]] and used to partition
* the other side of a join to make sure records with the same partition value end up in the same
* partition.
*/
private[spark] class KeyGroupedPartitioner(
valueMap: mutable.Map[Seq[Any], Int],
override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = {
val keys = key.asInstanceOf[Seq[Any]]
valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
}
}
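A minimal sketch of how the new partitioner routes keys (the values are made up; since the class is `private[spark]`, this assumes the call site lives inside the `org.apache.spark` package):

```scala
import scala.collection.mutable

// Partition values 10, 20, 30 of the KeyGroupedPartitioning side already map to
// partition ids 0, 1, 2; the other side is shuffled with this partitioner so that
// rows with the same key land in the same partition.
val valueMap = mutable.Map[Seq[Any], Int](Seq(10) -> 0, Seq(20) -> 1, Seq(30) -> 2)
val partitioner = new KeyGroupedPartitioner(valueMap, numPartitions = 3)

partitioner.getPartition(Seq(20))  // 1: the key matches a known partition value
partitioner.getPartition(Seq(99))  // unseen key: falls back to a hash-based partition id
```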

/**
* A [[org.apache.spark.Partitioner]] that partitions all records into a single partition.
*/
@@ -735,7 +735,13 @@ case class KeyGroupedShuffleSpec(
case _ => false
}

override def canCreatePartitioning: Boolean = false
override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only partition expressions that are AttributeReference are supported for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}
}

case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
@@ -79,6 +79,11 @@ object InternalRowComparableWrapper {
rightPartitioning.partitionValues
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
.foreach(partition => partitionsSet.add(partition))
partitionsSet.map(_.row).toSeq
// SPARK-41471: Sort the partitions to make sure the partition order is
// deterministic across different cases.
val partitionOrdering: Ordering[InternalRow] = {
RowOrdering.createNaturalAscendingOrdering(partitionDataTypes)
}
partitionsSet.map(_.row).toSeq.sorted(partitionOrdering)
}
}
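A small illustration of the effect of the sort (a sketch, not part of the commit; it only uses the catalyst `RowOrdering` and `InternalRow` APIs already referenced in the hunk above):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types.IntegerType

// Partition values merged from the two sides may be collected in different orders;
// sorting with a natural ascending ordering makes the merged sequence deterministic.
val ordering = RowOrdering.createNaturalAscendingOrdering(Seq(IntegerType))
val merged = Seq(InternalRow(3), InternalRow(1), InternalRow(2)).sorted(ordering)
// merged now holds the partition values 1, 2, 3 in a stable order
```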
@@ -1500,6 +1500,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_SHUFFLE_ENABLED =
buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled")
.doc("During a storage-partitioned join, whether to allow to shuffle only one side." +
"When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " +
"only shuffle the other side. This optimization will reduce the amount of data that " +
s"needs to be shuffle. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
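For reference, a sketch of toggling the new flag together with its prerequisite (`withSQLConf` is the test helper used in the suites below; the body is illustrative only):

```scala
withSQLConf(
  SQLConf.V2_BUCKETING_ENABLED.key -> "true",
  SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
  // Plans built here may shuffle only the side that does not report
  // KeyGroupedPartitioning, instead of shuffling both sides.
}
```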

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4899,6 +4909,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)

def v2BucketingShuffleEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

@@ -153,7 +153,7 @@ case class BatchScanExec(
if (spjParams.commonPartitionValues.isDefined &&
spjParams.applyPartialClustering) {
// A mapping from the common partition values to how many splits the partition
// should contain. Note this no longer maintain the partition key ordering.
// should contain.
val commonPartValuesMap = spjParams.commonPartitionValues
.get
.map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange

import java.util.function.Supplier

import scala.collection.mutable
import scala.concurrent.Future

import org.apache.spark._
@@ -29,6 +30,7 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
@@ -299,6 +301,11 @@ object ShuffleExchangeExec {
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
case k @ KeyGroupedPartitioning(expressions, n, _) =>
val valueMap = k.uniquePartitionValues.zipWithIndex.map {
case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
}.toMap
new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n)
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
@@ -325,6 +332,8 @@ object ShuffleExchangeExec {
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case KeyGroupedPartitioning(expressions, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
}

@@ -1040,6 +1040,187 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}

test("SPARK-41471: shuffle one side: only one side reports partitioning") {
val items_partitions = Array(identity("id"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 1, "should only shuffle the side that does not report partitioning")
} else {
assert(shuffles.size == 2, "should shuffle both sides when one-side bucketing shuffle" +
" is not enabled")
}

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
}
}
}

test("SPARK-41471: shuffle one side: shuffle side has more partition value") {
val items_partitions = Array(identity("id"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
"(6, 50.0, cast('2023-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
Seq("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN").foreach { joinType =>
val df = sql(s"SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i $joinType testcat.ns.$purchases p " +
"ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 1, "should only shuffle the side that does not report partitioning")
} else {
assert(shuffles.size == 2, "should shuffle both sides when one-side bucketing " +
"shuffle is not enabled")
}
joinType match {
case "JOIN" =>
checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
case "LEFT OUTER JOIN" =>
checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5),
Row(4, "cc", 15.5, null)))
case "RIGHT OUTER JOIN" =>
checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, null, 50.0),
Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
case "FULL OUTER JOIN" =>
checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, null, 50.0),
Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5),
Row(4, "cc", 15.5, null)))
}
}
}
}
}

test("SPARK-41471: shuffle one side: only one side reports partitioning with two identity") {
val items_partitions = Array(identity("id"), identity("arrive_time"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.id = p.item_id and i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 1, "should only shuffle the side that does not report partitioning")
} else {
assert(shuffles.size == 2, "should shuffle both sides when one-side bucketing shuffle" +
" is not enabled")
}

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
}
}
}

test("SPARK-41471: shuffle one side: partitioning with transform") {
val items_partitions = Array(years("arrive_time"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2021-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")

val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
assert(shuffles.size == 2, "partitioning with transform not work now")
} else {
assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" +
" is not enabled")
}

checkAnswer(df, Seq(
Row(1, "aa", 40.0, 42.0),
Row(3, "bb", 10.0, 42.0),
Row(4, "cc", 15.5, 19.5)))
}
}
}

test("SPARK-41471: shuffle one side: work with group partition split") {
val items_partitions = Array(identity("id"))
createTable(items, items_schema, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchases_schema, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
"(6, 50.0, cast('2023-02-01' as timestamp))")

Seq(true, false).foreach { shuffle =>
withSQLConf(
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") {
val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " +
s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
"ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
}
}
}

test("SPARK-44641: duplicated records when SPJ is not triggered") {
val items_partitions = Array(bucket(8, "id"))
createTable(items, items_schema, items_partitions)
@@ -18,15 +18,17 @@
package org.apache.spark.sql.execution.exchange

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.optimizer.BuildRight
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, _}
import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
@@ -1109,6 +1111,32 @@ class EnsureRequirementsSuite extends SharedSparkSession {
}
}

test("SPARK-41471: shuffle right side when" +
" spark.sql.sources.v2.bucketing.shuffle.enabled is true") {
withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {

val a1 = AttributeReference("a1", IntegerType)()

val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v)))
val plan1 = DummySparkPlan(outputPartitioning = KeyGroupedPartitioning(
identity(a1) :: Nil, 4, partitionValue))
val plan2 = DummySparkPlan(outputPartitioning = SinglePartition)

val smjExec = ShuffledHashJoinExec(
a1 :: Nil, a1 :: Nil, Inner, BuildRight, None, plan1, plan2)
EnsureRequirements.apply(smjExec) match {
case ShuffledHashJoinExec(_, _, _, _, _,
DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv),
DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
assert(left.expressions == a1 :: Nil)
assert(attrs == a1 :: Nil)
assert(partitionValue == pv)
case other => fail(other.toString)
}
}
}

test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
val lKey = AttributeReference("key", IntegerType)()
val lKey2 = AttributeReference("key2", IntegerType)()
