
Commit

[SPARK-32056][SQL][FOLLOW-UP] Coalesce partitions for repartition hint and sql when AQE is enabled

### What changes were proposed in this pull request?

As a follow-up to #28900, this patch extends partition coalescing to repartitioning done via hints and SQL syntax when no number of partitions is specified and AQE is enabled.
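
For illustration, a minimal sketch of the query shapes this covers (the view name `t` and column `col` are hypothetical; assumes `spark.sql.adaptive.enabled=true` with partition coalescing enabled). None of these specify a partition count, so the resulting shuffle partitions can now be coalesced:

```scala
// Hypothetical view `t` and column `col`; no explicit partition count in any form,
// so with AQE enabled the shuffle partitions may be coalesced.
val byHint    = spark.sql("SELECT /*+ REPARTITION(col) */ * FROM t")          // repartition hint
val byRange   = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(col) */ * FROM t") // range repartition hint
val bySql     = spark.sql("SELECT * FROM t DISTRIBUTE BY col")                // SQL syntax
val byCluster = spark.sql("SELECT * FROM t CLUSTER BY col")                   // SQL syntax
```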

### Why are the changes needed?

When repartitioning using hints or SQL syntax, the shuffle should follow the behavior of repartition by expression/range and coalesce partitions when AQE is enabled.

### Does this PR introduce _any_ user-facing change?

Yes. After this change, if users don't specify the number of partitions when repartitioning with the `REPARTITION`/`REPARTITION_BY_RANGE` hints or `DISTRIBUTE BY`/`CLUSTER BY`, AQE will coalesce partitions.
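
A sketch of the resulting behavior, mirroring the test added in this PR (assumes a view `test` registered from `spark.range(10)`, `spark.sql.shuffle.partitions=10`, and coalescing enabled):

```scala
// No partition count: with AQE enabled the shuffle may be coalesced to fewer than 10 partitions.
val coalesced = spark.sql("SELECT /*+ REPARTITION(id) */ * FROM test")
assert(coalesced.rdd.getNumPartitions <= 10)

// Explicit partition count: coalescing is skipped and exactly 10 partitions are produced.
val pinned = spark.sql("SELECT /*+ REPARTITION(10, id) */ * FROM test")
assert(pinned.rdd.getNumPartitions == 10)
```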

### How was this patch tested?

Unit tests.

Closes #28952 from viirya/SPARK-32056-sql.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
viirya authored and dongjoon-hyun committed Jul 1, 2020
1 parent bcf2330 commit 2a52a1b
Showing 5 changed files with 78 additions and 31 deletions.
@@ -183,7 +183,7 @@ object ResolveHints {
       val hintName = hint.name.toUpperCase(Locale.ROOT)

       def createRepartitionByExpression(
-          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
         val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
         if (sortOrders.nonEmpty) throw new IllegalArgumentException(
           s"""Invalid partitionExprs specified: $sortOrders
@@ -208,11 +208,11 @@ object ResolveHints {
           throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")

         case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(numPartitions: Int, _*) if shuffle =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) if shuffle =>
-          createRepartitionByExpression(conf.numShufflePartitions, param)
+          createRepartitionByExpression(None, param)
       }
     }

@@ -224,7 +224,7 @@ object ResolveHints {
       val hintName = hint.name.toUpperCase(Locale.ROOT)

       def createRepartitionByExpression(
-          numPartitions: Int, partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
         val invalidParams = partitionExprs.filter(!_.isInstanceOf[UnresolvedAttribute])
         if (invalidParams.nonEmpty) {
           throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
@@ -239,11 +239,11 @@

       hint.parameters match {
         case param @ Seq(IntegerLiteral(numPartitions), _*) =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(numPartitions: Int, _*) =>
-          createRepartitionByExpression(numPartitions, param.tail)
+          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) =>
-          createRepartitionByExpression(conf.numShufflePartitions, param)
+          createRepartitionByExpression(None, param)
       }
     }

@@ -163,7 +163,7 @@ class ResolveHintsSuite extends AnalysisTest {
     checkAnalysis(
       UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("a")), table("TaBlE")),
       RepartitionByExpression(
-        Seq(AttributeReference("a", IntegerType)()), testRelation, conf.numShufflePartitions))
+        Seq(AttributeReference("a", IntegerType)()), testRelation, None))

     val e = intercept[IllegalArgumentException] {
       checkAnalysis(
@@ -187,7 +187,7 @@
         "REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("a")), table("TaBlE")),
       RepartitionByExpression(
         Seq(SortOrder(AttributeReference("a", IntegerType)(), Ascending)),
-        testRelation, conf.numShufflePartitions))
+        testRelation, None))

     val errMsg2 = "REPARTITION Hint parameter should include columns, but"

@@ -746,7 +746,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
       ctx: QueryOrganizationContext,
       expressions: Seq[Expression],
       query: LogicalPlan): LogicalPlan = {
-    RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+    RepartitionByExpression(expressions, query, None)
   }

   /**
@@ -199,20 +199,20 @@ class SparkSqlParserSuite extends AnalysisTest {
     assertEqual(s"$baseSql distribute by a, b",
       RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
         basePlan,
-        numPartitions = newConf.numShufflePartitions))
+        None))
     assertEqual(s"$baseSql distribute by a sort by b",
       Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
     assertEqual(s"$baseSql cluster by a, b",
       Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
           SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          None)))
   }

   test("pipeline concatenation") {
@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level

 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
@@ -130,6 +130,17 @@ class AdaptiveQueryExecSuite
     assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
   }

+  private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = {
+    // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
+    val plan = df.queryExecution.executedPlan
+    assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+    val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+      case s: ShuffleExchangeExec => s
+    }
+    assert(shuffle.size == 1)
+    assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
+  }
+
   test("Change merge join to broadcast join") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
@@ -1040,14 +1051,8 @@ class AdaptiveQueryExecSuite
         assert(partitionsNum1 < 10)
         assert(partitionsNum2 < 10)

-        // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-        val plan = df1.queryExecution.executedPlan
-        assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-        val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-          case s: ShuffleExchangeExec => s
-        }
-        assert(shuffle.size == 1)
-        assert(shuffle(0).outputPartitioning.numPartitions == 10)
+        checkInitialPartitionNum(df1, 10)
+        checkInitialPartitionNum(df2, 10)
       } else {
         assert(partitionsNum1 === 10)
         assert(partitionsNum2 === 10)
@@ -1081,14 +1086,8 @@ class AdaptiveQueryExecSuite
         assert(partitionsNum1 < 10)
         assert(partitionsNum2 < 10)

-        // repartition obeys initialPartitionNum when adaptiveExecutionEnabled
-        val plan = df1.queryExecution.executedPlan
-        assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
-        val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
-          case s: ShuffleExchangeExec => s
-        }
-        assert(shuffle.size == 1)
-        assert(shuffle(0).outputPartitioning.numPartitions == 10)
+        checkInitialPartitionNum(df1, 10)
+        checkInitialPartitionNum(df2, 10)
       } else {
         assert(partitionsNum1 === 10)
         assert(partitionsNum2 === 10)
@@ -1100,4 +1099,52 @@
       }
     }
   }
+
+  test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
+    Seq(true, false).foreach { enableAQE =>
+      withTempView("test") {
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+          SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+          SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+          SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+          spark.range(10).toDF.createTempView("test")
+
+          val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
+          val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
+          val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
+          val df4 = spark.sql("SELECT * from test CLUSTER BY id")
+
+          val partitionsNum1 = df1.rdd.collectPartitions().length
+          val partitionsNum2 = df2.rdd.collectPartitions().length
+          val partitionsNum3 = df3.rdd.collectPartitions().length
+          val partitionsNum4 = df4.rdd.collectPartitions().length
+
+          if (enableAQE) {
+            assert(partitionsNum1 < 10)
+            assert(partitionsNum2 < 10)
+            assert(partitionsNum3 < 10)
+            assert(partitionsNum4 < 10)
+
+            checkInitialPartitionNum(df1, 10)
+            checkInitialPartitionNum(df2, 10)
+            checkInitialPartitionNum(df3, 10)
+            checkInitialPartitionNum(df4, 10)
+          } else {
+            assert(partitionsNum1 === 10)
+            assert(partitionsNum2 === 10)
+            assert(partitionsNum3 === 10)
+            assert(partitionsNum4 === 10)
+          }
+
+          // Don't coalesce partitions if the number of partitions is specified.
+          val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
+          val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
+          assert(df5.rdd.collectPartitions().length == 10)
+          assert(df6.rdd.collectPartitions().length == 10)
+        }
+      }
+    }
+  }
 }
