Skip to content

Commit

Permalink
[SPARK-32056][SQL][FOLLOW-UP] Coalesce partitions for repartiotion hi…
Browse files Browse the repository at this point in the history
…nt and sql when AQE is enabled

As the followup of apache#28900, this patch extends coalescing partitions to repartitioning using hints and SQL syntax without specifying number of partitions, when AQE is enabled.

When repartitionning using hints and SQL syntax, we should follow the shuffling behavior of repartition by expression/range to coalesce partitions when AQE is enabled.

Yes. After this change, if users don't specify the number of partitions when repartitioning using `REPARTITION`/`REPARTITION_BY_RANGE` hint or `DISTRIBUTE BY`/`CLUSTER BY`, AQE will coalesce partitions.

Unit tests.

Closes apache#28952 from viirya/SPARK-32056-sql.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
viirya authored and Matt Hawes committed May 16, 2021
1 parent f4b9159 commit c098917
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ object ResolveHints {
val hintName = hint.name.toUpperCase(Locale.ROOT)

def createRepartitionByExpression(
numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
if (sortOrders.nonEmpty) throw new IllegalArgumentException(
s"""Invalid partitionExprs specified: $sortOrders
Expand All @@ -208,11 +208,11 @@ object ResolveHints {
throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")

case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
createRepartitionByExpression(numPartitions, param.tail)
createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(numPartitions: Int, _*) if shuffle =>
createRepartitionByExpression(numPartitions, param.tail)
createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(_*) if shuffle =>
createRepartitionByExpression(conf.numShufflePartitions, param)
createRepartitionByExpression(None, param)
}
}

Expand All @@ -224,7 +224,7 @@ object ResolveHints {
val hintName = hint.name.toUpperCase(Locale.ROOT)

def createRepartitionByExpression(
numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
if (invalidParams.nonEmpty) {
throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
Expand All @@ -239,11 +239,11 @@ object ResolveHints {

hint.parameters match {
case param @ Seq(IntegerLiteral(numPartitions), _*) =>
createRepartitionByExpression(numPartitions, param.tail)
createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(numPartitions: Int, _*) =>
createRepartitionByExpression(numPartitions, param.tail)
createRepartitionByExpression(Some(numPartitions), param.tail)
case param @ Seq(_*) =>
createRepartitionByExpression(conf.numShufflePartitions, param)
createRepartitionByExpression(None, param)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class ResolveHintsSuite extends AnalysisTest {
checkAnalysis(
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
Seq(AttributeReference("a", IntegerType)()), testRelation, None))

val e = intercept[IllegalArgumentException] {
checkAnalysis(
Expand All @@ -187,7 +187,7 @@ class ResolveHintsSuite extends AnalysisTest {
"REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
RepartitionByExpression(
Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
testRelation, conf.numShufflePartitions))
testRelation, None))

val errMsg2 = "REPARTITION Hint parameter should include columns, but"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ctx: QueryOrganizationContext,
expressions: Seq[Expression],
query: LogicalPlan): LogicalPlan = {
RepartitionByExpression(expressions, query, conf.numShufflePartitions)
RepartitionByExpression(expressions, query, None)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,20 +209,20 @@ class SparkSqlParserSuite extends AnalysisTest {
assertEqual(s"$baseSql distribute by a, b",
RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
basePlan,
numPartitions = newConf.numShufflePartitions))
None))
assertEqual(s"$baseSql distribute by a sort by b",
Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
global = false,
RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
basePlan,
numPartitions = newConf.numShufflePartitions)))
None)))
assertEqual(s"$baseSql cluster by a, b",
Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
global = false,
RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
basePlan,
numPartitions = newConf.numShufflePartitions)))
None)))
}

test("pipeline concatenation") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.net.URI
import org.apache.log4j.Level

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
import org.apache.spark.sql.execution.adaptive.OptimizeLocalShuffleReader.LOCAL_SHUFFLE_READER_DESCRIPTION
Expand Down Expand Up @@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
}

private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = {
// repartition obeys initialPartitionNum when adaptiveExecutionEnabled
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
case s: ShuffleExchangeExec => s
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
}

test("Change merge join to broadcast join") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
Expand Down Expand Up @@ -892,14 +903,8 @@ class AdaptiveQueryExecSuite
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)

// repartition obeys initialPartitionNum when adaptiveExecutionEnabled
val plan = df1.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
case s: ShuffleExchangeExec => s
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == 10)
checkInitialPartitionNum(df1, 10)
checkInitialPartitionNum(df2, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
Expand Down Expand Up @@ -933,14 +938,8 @@ class AdaptiveQueryExecSuite
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)

// repartition obeys initialPartitionNum when adaptiveExecutionEnabled
val plan = df1.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
case s: ShuffleExchangeExec => s
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == 10)
checkInitialPartitionNum(df1, 10)
checkInitialPartitionNum(df2, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
Expand All @@ -966,4 +965,52 @@ class AdaptiveQueryExecSuite
}
}
}

test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
Seq(true, false).foreach { enableAQE =>
withTempView("test") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
SQLConf.SHUFFLE_PARTITIONS.key -> "10") {

spark.range(10).toDF.createTempView("test")

val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
val df4 = spark.sql("SELECT * from test CLUSTER BY id")

val partitionsNum1 = df1.rdd.collectPartitions().length
val partitionsNum2 = df2.rdd.collectPartitions().length
val partitionsNum3 = df3.rdd.collectPartitions().length
val partitionsNum4 = df4.rdd.collectPartitions().length

if (enableAQE) {
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)
assert(partitionsNum3 < 10)
assert(partitionsNum4 < 10)

checkInitialPartitionNum(df1, 10)
checkInitialPartitionNum(df2, 10)
checkInitialPartitionNum(df3, 10)
checkInitialPartitionNum(df4, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
assert(partitionsNum3 === 10)
assert(partitionsNum4 === 10)
}

// Don't coalesce partitions if the number of partitions is specified.
val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
assert(df5.rdd.collectPartitions().length == 10)
assert(df6.rdd.collectPartitions().length == 10)
}
}
}
}
}

0 comments on commit c098917

Please sign in to comment.