diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 8ef8cb83e..170854443 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -741,9 +741,37 @@ class CometSparkSessionExtensions
     }
 
     // Set up logical links
-    newPlan = newPlan.transform { case op: CometExec =>
-      op.originalPlan.logicalLink.foreach(op.setLogicalLink)
-      op
+    newPlan = newPlan.transform {
+      case op: CometExec =>
+        if (op.originalPlan.logicalLink.isEmpty) {
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+        } else {
+          op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+        }
+        op
+      case op: CometShuffleExchangeExec =>
+        // The original Spark shuffle exchange operator might have an empty logical
+        // link, but the `setLogicalLink` call above on a downstream operator of
+        // `CometShuffleExchangeExec` would set this exchange's logical link to that
+        // of the downstream operators, which makes AQE behave incorrectly. So we
+        // need to unset the logical link here.
+        if (op.originalPlan.logicalLink.isEmpty) {
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+        } else {
+          op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+        }
+        op
+
+      case op: CometBroadcastExchangeExec =>
+        if (op.originalPlan.logicalLink.isEmpty) {
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+        } else {
+          op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+        }
+        op
     }
 
     // Convert native execution block by linking consecutive native operators.
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 232b6bf17..fb2f2a209 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -50,6 +50,8 @@ import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
 import org.apache.spark.util.random.XORShiftRandom
 
+import com.google.common.base.Objects
+
 import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde.serializeDataType
@@ -61,6 +63,7 @@ import org.apache.comet.shims.ShimCometShuffleExchangeExec
 case class CometShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
+    originalPlan: ShuffleExchangeLike,
     shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
     shuffleType: ShuffleType = CometNativeShuffle,
     advisoryPartitionSize: Option[Long] = None)
@@ -192,6 +195,24 @@ case class CometShuffleExchangeExec(
 
   override protected def withNewChildInternal(newChild: SparkPlan): CometShuffleExchangeExec =
     copy(child = newChild)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometShuffleExchangeExec =>
+        this.outputPartitioning == other.outputPartitioning &&
+        this.shuffleOrigin == other.shuffleOrigin && this.child == other.child &&
+        this.shuffleType == other.shuffleType &&
+        this.advisoryPartitionSize == other.advisoryPartitionSize
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(outputPartitioning, shuffleOrigin, shuffleType, advisoryPartitionSize, child)
+
+  override def stringArgs: Iterator[Any] =
+    Iterator(outputPartitioning, shuffleOrigin, shuffleType, child) ++ Iterator(s"[plan_id=$id]")
 }
 
 object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index a8579757d..adbe412de 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -239,6 +239,7 @@ abstract class CometNativeExec extends CometExec {
     val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find {
       case (_: CometBroadcastExchangeExec, _) => false
       case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false
+      case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false
       case _ => true
     }
 
@@ -263,6 +264,13 @@ abstract class CometNativeExec extends CometExec {
         inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
       case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
         inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+      case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
+        inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+      case BroadcastQueryStageExec(
+            _,
+            ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
+            _) =>
+        inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
       case _ if idx == firstNonBroadcastPlan.get._2 =>
         inputs += firstNonBroadcastPlanRDD
       case _ =>
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
index f89dbb8db..6b4fad974 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -32,6 +32,7 @@ trait ShimCometShuffleExchangeExec {
     CometShuffleExchangeExec(
       s.outputPartitioning,
       s.child,
+      s,
       s.shuffleOrigin,
       shuffleType,
       advisoryPartitionSize)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index e5b3523dc..25a5fe72d 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.expressions.Window
@@ -62,6 +62,29 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("CometShuffleExchangeExec logical link should be correct") {
+    withTempView("v") {
+      spark.sparkContext
+        .parallelize((1 to 4).map(i => TestData(i, i.toString)), 2)
+        .toDF("c1", "c2")
+        .createOrReplaceTempView("v")
+
+      Seq(true, false).foreach { columnarShuffle =>
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> columnarShuffle.toString) {
+          val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
+          val shuffle = find(df.queryExecution.executedPlan) {
+            case _: CometShuffleExchangeExec if columnarShuffle => true
+            case _: ShuffleExchangeExec if !columnarShuffle => true
+            case _ => false
+          }.get
+          assert(shuffle.logicalLink.isEmpty)
+        }
+      }
+    }
+  }
+
   test("Ensure that the correct outputPartitioning of CometSort") {
     withTable("test_data") {
       val tableDF = spark.sparkContext
@@ -302,7 +325,8 @@ class CometExecSuite extends CometTestBase {
     withSQLConf(
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
-      "spark.sql.autoBroadcastJoinThreshold" -> "0",
+      "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "-1",
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1",
       "spark.sql.join.preferSortMergeJoin" -> "true") {
       withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl1") {
         withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl2") {
@@ -373,6 +397,7 @@ class CometExecSuite extends CometTestBase {
     withSQLConf(
       SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
       CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
       CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
       withTable(tableName, dim) {
@@ -1306,3 +1331,5 @@ case class BucketedTableTestSpec(
     expectedShuffle: Boolean = true,
     expectedSort: Boolean = true,
     expectedNumOutputPartitions: Option[Int] = None)
+
+case class TestData(key: Int, value: String)
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
index 6ec25dd19..8d7111e8f 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
@@ -91,6 +91,7 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestH
     conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
     conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
     conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
+    conf.set(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key, "true")
     conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
     conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
   }
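
A note on the `equals`/`hashCode` overrides added to `CometShuffleExchangeExec` above: Spark's exchange-reuse rule deduplicates exchanges by equality, replacing later duplicates with a `ReusedExchangeExec` that points at the first instance. Once the non-semantic `originalPlan` field joins the constructor, the default case-class equality would compare it too, so two otherwise-identical shuffles would never be deduplicated; the overrides (and `stringArgs`) therefore deliberately leave it out. Below is a minimal, self-contained sketch of that pattern, using hypothetical stand-in types (`Exchange`, `reuse`) rather than the real Comet or Spark API:

    import scala.collection.mutable

    // Stand-in for an exchange node carrying a non-semantic `original` field,
    // mirroring `originalPlan` in CometShuffleExchangeExec.
    case class Exchange(partitioning: String, child: String, original: AnyRef) {
      // Equality deliberately ignores `original`, as the override above does.
      override def equals(obj: Any): Boolean = obj match {
        case other: Exchange =>
          partitioning == other.partitioning && child == other.child
        case _ => false
      }
      override def hashCode(): Int = (partitioning, child).##
    }

    // Reuse-style deduplication: later duplicates resolve to the first
    // equivalent instance, as ReusedExchangeExec does for real plans.
    def reuse(exchanges: Seq[Exchange]): Seq[Exchange] = {
      val seen = mutable.Map.empty[Exchange, Exchange]
      exchanges.map(e => seen.getOrElseUpdate(e, e))
    }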