fix: CometShuffleExchangeExec logical link should be correct (#324)
* fix: CometShuffleExchangeExec logical link should be correct

* Implement equals and hashCode for CometShuffleExchangeExec

* Update plan stability

* Restore plan stability

* Dedup test

* Remove unused import

* Fix test

* Use columnar shuffle
viirya authored Apr 30, 2024
1 parent 1865284 commit b326637
Showing 6 changed files with 91 additions and 5 deletions.
@@ -741,9 +741,37 @@ class CometSparkSessionExtensions
}

// Set up logical links
newPlan = newPlan.transform { case op: CometExec =>
  op.originalPlan.logicalLink.foreach(op.setLogicalLink)
  op
newPlan = newPlan.transform {
  case op: CometExec =>
    if (op.originalPlan.logicalLink.isEmpty) {
      op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
      op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
    } else {
      op.originalPlan.logicalLink.foreach(op.setLogicalLink)
    }
    op
  case op: CometShuffleExchangeExec =>
    // The original Spark shuffle exchange operator might have an empty logical link.
    // But the `setLogicalLink` call above on a downstream operator of
    // `CometShuffleExchangeExec` would set this node's logical link to that downstream
    // operator's logical plan, which makes AQE behave incorrectly. So we need to unset
    // the logical link here.
    if (op.originalPlan.logicalLink.isEmpty) {
      op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
      op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
    } else {
      op.originalPlan.logicalLink.foreach(op.setLogicalLink)
    }
    op

  case op: CometBroadcastExchangeExec =>
    if (op.originalPlan.logicalLink.isEmpty) {
      op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
      op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
    } else {
      op.originalPlan.logicalLink.foreach(op.setLogicalLink)
    }
    op
}

// Convert native execution block by linking consecutive native operators.
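The transform above applies the same rule in each case: copy the original Spark operator's logical link when one exists, otherwise clear both logical-plan tags so the node does not keep a link inherited from a downstream operator. A minimal standalone sketch of that rule, using a hypothetical helper name that is not part of this change:

import org.apache.spark.sql.execution.SparkPlan

// Mirror `original`'s logical link onto `op`, or clear the tags when there is none.
def syncLogicalLink(op: SparkPlan, original: SparkPlan): Unit = {
  original.logicalLink match {
    // The original operator carries a logical link: copy it onto the Comet operator.
    case Some(link) => op.setLogicalLink(link)
    // No logical link on the original: unset both tags, including the inherited one,
    // so AQE does not treat a downstream operator's logical plan as this node's link.
    case None =>
      op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
      op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
  }
}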
@@ -50,6 +50,8 @@ import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
import org.apache.spark.util.random.XORShiftRandom

import com.google.common.base.Objects

import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.serializeDataType
@@ -61,6 +63,7 @@ import org.apache.comet.shims.ShimCometShuffleExchangeExec
case class CometShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
originalPlan: ShuffleExchangeLike,
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
shuffleType: ShuffleType = CometNativeShuffle,
advisoryPartitionSize: Option[Long] = None)
@@ -192,6 +195,24 @@ case class CometShuffleExchangeExec(

  override protected def withNewChildInternal(newChild: SparkPlan): CometShuffleExchangeExec =
    copy(child = newChild)

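  // `originalPlan` is excluded from the equality and hashing below: only the structural
  // fields (partitioning, shuffle origin, shuffle type, advisory partition size, child)
  // determine whether two Comet shuffle exchanges are equal.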
  override def equals(obj: Any): Boolean = {
    obj match {
      case other: CometShuffleExchangeExec =>
        this.outputPartitioning == other.outputPartitioning &&
          this.shuffleOrigin == other.shuffleOrigin && this.child == other.child &&
          this.shuffleType == other.shuffleType &&
          this.advisoryPartitionSize == other.advisoryPartitionSize
      case _ =>
        false
    }
  }

  override def hashCode(): Int =
    Objects.hashCode(outputPartitioning, shuffleOrigin, shuffleType, advisoryPartitionSize, child)

  override def stringArgs: Iterator[Any] =
    Iterator(outputPartitioning, shuffleOrigin, shuffleType, child) ++ Iterator(s"[plan_id=$id]")
}

object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
@@ -240,6 +240,7 @@ abstract class CometNativeExec extends CometExec {
    val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find {
      case (_: CometBroadcastExchangeExec, _) => false
      case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false
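      // A broadcast query stage may also wrap a reused broadcast exchange; it is not a
      // candidate for the first non-broadcast plan either.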
      case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false
      case _ => true
    }

@@ -264,6 +265,13 @@ abstract class CometNativeExec extends CometExec {
          inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
        case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
          inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
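        // Broadcast exchanges that are reused, directly or wrapped in a broadcast query
        // stage, also feed columnar input using the partition count of the first
        // non-broadcast plan.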
        case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
          inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
        case BroadcastQueryStageExec(
              _,
              ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
              _) =>
          inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
        case _ if idx == firstNonBroadcastPlan.get._2 =>
          inputs += firstNonBroadcastPlanRDD
        case _ =>
@@ -32,6 +32,7 @@ trait ShimCometShuffleExchangeExec {
    CometShuffleExchangeExec(
      s.outputPartitioning,
      s.child,
      s,
      s.shuffleOrigin,
      shuffleType,
      advisoryPartitionSize)
31 changes: 29 additions & 2 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.Window
@@ -62,6 +62,29 @@ class CometExecSuite extends CometTestBase {
}
}

test("CometShuffleExchangeExec logical link should be correct") {
withTempView("v") {
spark.sparkContext
.parallelize((1 to 4).map(i => TestData(i, i.toString)), 2)
.toDF("c1", "c2")
.createOrReplaceTempView("v")

Seq(true, false).foreach { columnarShuffle =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> columnarShuffle.toString) {
val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
val shuffle = find(df.queryExecution.executedPlan) {
case _: CometShuffleExchangeExec if columnarShuffle => true
case _: ShuffleExchangeExec if !columnarShuffle => true
case _ => false
}.get
assert(shuffle.logicalLink.isEmpty)
}
}
}
}

test("Ensure that the correct outputPartitioning of CometSort") {
withTable("test_data") {
val tableDF = spark.sparkContext
@@ -302,7 +325,8 @@ class CometExecSuite extends CometTestBase {
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
"spark.sql.autoBroadcastJoinThreshold" -> "0",
"spark.sql.adaptive.autoBroadcastJoinThreshold" -> "-1",
"spark.sql.autoBroadcastJoinThreshold" -> "-1",
"spark.sql.join.preferSortMergeJoin" -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl1") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl2") {
@@ -407,6 +431,7 @@ class CometExecSuite extends CometTestBase {
withSQLConf(
SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
withTable(tableName, dim) {
@@ -1340,3 +1365,5 @@ case class BucketedTableTestSpec(
expectedShuffle: Boolean = true,
expectedSort: Boolean = true,
expectedNumOutputPartitions: Option[Int] = None)

case class TestData(key: Int, value: String)
@@ -91,6 +91,7 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestH
conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
conf.set(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key, "true")
conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
}
