Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: CometShuffleExchangeExec logical link should be correct #324

Merged
merged 8 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -741,9 +741,37 @@ class CometSparkSessionExtensions
}

// Set up logical links
newPlan = newPlan.transform { case op: CometExec =>
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
op
newPlan = newPlan.transform {
case op: CometExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
case op: CometShuffleExchangeExec =>
// Original Spark shuffle exchange operator might have empty logical link.
// But the `setLogicalLink` call above on downstream operator of
// `CometShuffleExchangeExec` will set its logical link to the downstream
// operators which cause AQE behavior to be incorrect. So we need to unset
// the logical link here.
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op

case op: CometBroadcastExchangeExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
}

// Convert native execution block by linking consecutive native operators.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
import org.apache.spark.util.random.XORShiftRandom

import com.google.common.base.Objects

import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.serializeDataType
Expand All @@ -61,6 +63,7 @@ import org.apache.comet.shims.ShimCometShuffleExchangeExec
case class CometShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
originalPlan: ShuffleExchangeLike,
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
shuffleType: ShuffleType = CometNativeShuffle,
advisoryPartitionSize: Option[Long] = None)
Expand Down Expand Up @@ -192,6 +195,24 @@ case class CometShuffleExchangeExec(

override protected def withNewChildInternal(newChild: SparkPlan): CometShuffleExchangeExec =
copy(child = newChild)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometShuffleExchangeExec =>
this.outputPartitioning == other.outputPartitioning &&
this.shuffleOrigin == other.shuffleOrigin && this.child == other.child &&
this.shuffleType == other.shuffleType &&
this.advisoryPartitionSize == other.advisoryPartitionSize
case _ =>
false
}
}
Comment on lines +199 to +209
Copy link
Member Author

@viirya viirya Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because CometShuffleExchangeExec adds originalPlan parameter which is not covered by canonicalization in Spark, we need to exclude it when compare two CometShuffleExchangeExec to make sure Spark reuse shuffle rule work.


override def hashCode(): Int =
Objects.hashCode(outputPartitioning, shuffleOrigin, shuffleType, advisoryPartitionSize, child)

override def stringArgs: Iterator[Any] =
Iterator(outputPartitioning, shuffleOrigin, shuffleType, child) ++ Iterator(s"[plan_id=$id]")
Comment on lines +214 to +215
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow Spark Exchange.stringArgs.

}

object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ abstract class CometNativeExec extends CometExec {
val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find {
case (_: CometBroadcastExchangeExec, _) => false
case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false
case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false
case _ => true
}

Expand All @@ -263,6 +264,13 @@ abstract class CometNativeExec extends CometExec {
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
case BroadcastQueryStageExec(
_,
ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
_) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
Comment on lines +267 to +273
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are necessary but missed to add previously. This fix exposes that.

case _ if idx == firstNonBroadcastPlan.get._2 =>
inputs += firstNonBroadcastPlanRDD
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ trait ShimCometShuffleExchangeExec {
CometShuffleExchangeExec(
s.outputPartitioning,
s.child,
s,
s.shuffleOrigin,
shuffleType,
advisoryPartitionSize)
Expand Down
31 changes: 29 additions & 2 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.Window
Expand All @@ -62,6 +62,29 @@ class CometExecSuite extends CometTestBase {
}
}

test("CometShuffleExchangeExec logical link should be correct") {
withTempView("v") {
spark.sparkContext
.parallelize((1 to 4).map(i => TestData(i, i.toString)), 2)
.toDF("c1", "c2")
.createOrReplaceTempView("v")

Seq(true, false).foreach { columnarShuffle =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> columnarShuffle.toString) {
val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
val shuffle = find(df.queryExecution.executedPlan) {
case _: CometShuffleExchangeExec if columnarShuffle => true
case _: ShuffleExchangeExec if !columnarShuffle => true
case _ => false
}.get
assert(shuffle.logicalLink.isEmpty)
}
}
}
}

test("Ensure that the correct outputPartitioning of CometSort") {
withTable("test_data") {
val tableDF = spark.sparkContext
Expand Down Expand Up @@ -302,7 +325,8 @@ class CometExecSuite extends CometTestBase {
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
"spark.sql.autoBroadcastJoinThreshold" -> "0",
"spark.sql.adaptive.autoBroadcastJoinThreshold" -> "-1",
"spark.sql.autoBroadcastJoinThreshold" -> "-1",
Comment on lines +328 to +329
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Query plan is changed due to the logical link fix. In order to have CometSortMergeJoin, we need to disable broadcast join and AQE broadcast join.

"spark.sql.join.preferSortMergeJoin" -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl1") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl2") {
Expand Down Expand Up @@ -373,6 +397,7 @@ class CometExecSuite extends CometTestBase {
withSQLConf(
SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
Copy link
Member Author

@viirya viirya Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar reason here. The query plan is changed and AQE interferes with a broadcast join.

CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
withTable(tableName, dim) {
Expand Down Expand Up @@ -1306,3 +1331,5 @@ case class BucketedTableTestSpec(
expectedShuffle: Boolean = true,
expectedSort: Boolean = true,
expectedNumOutputPartitions: Option[Int] = None)

case class TestData(key: Int, value: String)
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestH
conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
conf.set(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key, "true")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #336

conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
}
Expand Down
Loading