diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index b8b1a7a2f2..589a1d4fd1 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -320,6 +320,20 @@ class CometSparkSessionExtensions
               c
           }
 
+        case s: TakeOrderedAndProjectExec
+            if isCometNative(s.child) && isCometOperatorEnabled(conf, "takeOrderedAndProjectExec")
+              && isCometShuffleEnabled(conf) &&
+              CometTakeOrderedAndProjectExec.isSupported(s.projectList, s.sortOrder, s.child) =>
+          // TODO: support offset for Spark 3.4
+          QueryPlanSerde.operator2Proto(s) match {
+            case Some(nativeOp) =>
+              val cometOp =
+                CometTakeOrderedAndProjectExec(s, s.limit, s.sortOrder, s.projectList, s.child)
+              CometSinkPlaceHolder(nativeOp, s, cometOp)
+            case None =>
+              s
+          }
+
         case u: UnionExec
             if isCometOperatorEnabled(conf, "union") &&
               u.children.forall(isCometNative) =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
new file mode 100644
index 0000000000..f46c6e1ec3
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.JavaConverters.asJavaIterableConverter
+
+import org.apache.spark.rdd.{ParallelCollectionRDD, RDD}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode, UnsafeRowSerializer}
+import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.serde.OperatorOuterClass
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+
+/**
+ * Comet physical plan node for Spark `TakeOrderedAndProjectExec`.
+ *
+ * It is used to execute a `TakeOrderedAndProjectExec` physical operator using the Comet native
+ * engine. Unlike other physical plan nodes, which are wrapped by `CometExec`, it contains two
+ * native executions separated by a Comet shuffle exchange.
+ */
+case class CometTakeOrderedAndProjectExec(
+    override val originalPlan: SparkPlan,
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Seq[NamedExpression],
+    child: SparkPlan)
+    extends CometExec
+    with UnaryExecNode {
+  override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "shuffleReadElapsedCompute" ->
+      SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
+    "numPartitions" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of partitions")) ++ readMetrics ++ writeMetrics
+
+  private lazy val serializer: Serializer =
+    new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
+
+  // Exposed for testing.
+  lazy val orderingSatisfies: Boolean =
+    SortOrder.orderingSatisfies(child.outputOrdering, sortOrder)
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val childRDD = child.executeColumnar()
+    if (childRDD.getNumPartitions == 0) {
+      new ParallelCollectionRDD(sparkContext, Seq.empty[ColumnarBatch], 1, Map.empty)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        // First native execution: per-partition top-K, or a plain limit when the child output
+        // already satisfies the required ordering.
+        val localTopK = if (orderingSatisfies) {
+          childRDD.mapPartitionsInternal { iter =>
+            val limitOp =
+              CometTakeOrderedAndProjectExec.getLimitNativePlan(output, limit).get
+            CometExec.getCometIterator(Seq(iter), limitOp)
+          }
+        } else {
+          childRDD.mapPartitionsInternal { iter =>
+            val topK =
+              CometTakeOrderedAndProjectExec
+                .getTopKNativePlan(output, sortOrder, child, limit)
+                .get
+            CometExec.getCometIterator(Seq(iter), topK)
+          }
+        }
+
+        // Shuffle to Single Partition using Comet native shuffle
+        val dep = CometShuffleExchangeExec.prepareShuffleDependency(
+          localTopK,
+          child.output,
+          outputPartitioning,
+          serializer,
+          metrics)
+        metrics("numPartitions").set(dep.partitioner.numPartitions)
+
+        new CometShuffledBatchRDD(dep, readMetrics)
+      }
+
+      // Second native execution: final top-K and projection over the single shuffled partition.
+      singlePartitionRDD.mapPartitionsInternal { iter =>
+        val topKAndProjection = CometTakeOrderedAndProjectExec
+          .getProjectionNativePlan(projectList, output, sortOrder, child, limit)
+          .get
+        CometExec.getCometIterator(Seq(iter), topKAndProjection)
+      }
+    }
+  }
+
+  override def simpleString(maxFields: Int): String = {
+    val orderByString = truncatedString(sortOrder, "[", ",", "]", maxFields)
+    val outputString = truncatedString(output, "[", ",", "]", maxFields)
+
+    s"CometTakeOrderedAndProjectExec(limit=$limit, orderBy=$orderByString, output=$outputString)"
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+}
+
+object CometTakeOrderedAndProjectExec {
+  def isSupported(
+      projectList: Seq[NamedExpression],
+      sortOrder: Seq[SortOrder],
+      child: SparkPlan): Boolean = {
+    val exprs = projectList.map(exprToProto(_, child.output))
+    val sortOrders = sortOrder.map(exprToProto(_, child.output))
+    exprs.forall(_.isDefined) && sortOrders.forall(_.isDefined)
+  }
+
+  /**
+   * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec.
+   */
+  def getProjectionNativePlan(
+      projectList: Seq[NamedExpression],
+      outputAttributes: Seq[Attribute],
+      sortOrder: Seq[SortOrder],
+      child: SparkPlan,
+      limit: Int): Option[Operator] = {
+    getTopKNativePlan(outputAttributes, sortOrder, child, limit).flatMap { topK =>
+      val exprs = projectList.map(exprToProto(_, child.output))
+
+      if (exprs.forall(_.isDefined)) {
+        val projectBuilder = OperatorOuterClass.Projection.newBuilder()
+        projectBuilder.addAllProjectList(exprs.map(_.get).asJava)
+        val opBuilder = OperatorOuterClass.Operator
+          .newBuilder()
+          .addChildren(topK)
+        Some(opBuilder.setProjection(projectBuilder).build())
+      } else {
+        None
+      }
+    }
+  }
+
+  /**
+   * Prepare Limit native plan for CometTakeOrderedAndProjectExec. It is used instead of a TopK
+   * plan when the child output already satisfies the required ordering.
+   */
+  def getLimitNativePlan(outputAttributes: Seq[Attribute], limit: Int): Option[Operator] = {
+    val scanBuilder = OperatorOuterClass.Scan.newBuilder()
+    val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
+
+    val scanTypes = outputAttributes.flatten { attr =>
+      serializeDataType(attr.dataType)
+    }
+
+    if (scanTypes.length == outputAttributes.length) {
+      scanBuilder.addAllFields(scanTypes.asJava)
+
+      val limitBuilder = OperatorOuterClass.Limit.newBuilder()
+      limitBuilder.setLimit(limit)
+
+      val limitOpBuilder = OperatorOuterClass.Operator
+        .newBuilder()
+        .addChildren(scanOpBuilder.setScan(scanBuilder))
+      Some(limitOpBuilder.setLimit(limitBuilder).build())
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Prepare TopK native plan for CometTakeOrderedAndProjectExec.
+   */
+  def getTopKNativePlan(
+      outputAttributes: Seq[Attribute],
+      sortOrder: Seq[SortOrder],
+      child: SparkPlan,
+      limit: Int): Option[Operator] = {
+    val scanBuilder = OperatorOuterClass.Scan.newBuilder()
+    val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
+
+    val scanTypes = outputAttributes.flatten { attr =>
+      serializeDataType(attr.dataType)
+    }
+
+    if (scanTypes.length == outputAttributes.length) {
+      scanBuilder.addAllFields(scanTypes.asJava)
+
+      val sortOrders = sortOrder.map(exprToProto(_, child.output))
+
+      if (sortOrders.forall(_.isDefined)) {
+        val sortBuilder = OperatorOuterClass.Sort.newBuilder()
+        sortBuilder.addAllSortOrders(sortOrders.map(_.get).asJava)
+        sortBuilder.setFetch(limit)
+
+        val sortOpBuilder = OperatorOuterClass.Operator
+          .newBuilder()
+          .addChildren(scanOpBuilder.setScan(scanBuilder))
+        Some(sortOpBuilder.setSort(sortBuilder).build())
+      } else {
+        None
+      }
+    } else {
+      None
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index d3a1bd2c95..e675026a0a 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,11 +31,12 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.functions.{date_add, expr}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
@@ -964,6 +965,74 @@ class CometExecSuite extends CometTestBase {
       }
     }
   }
+
+  test("TakeOrderedAndProjectExec") {
+    Seq("true", "false").foreach(aqeEnabled =>
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
+        withTable("t1") {
+          val numRows = 10
+          spark
+            .range(numRows)
+            .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
+            .repartition(3) // Force multiple partitions so data must be shuffled into a single partition
+            .write
+            .saveAsTable("t1")
+
+          val df1 = spark.sql("""
+              |SELECT a, b, ROW_NUMBER() OVER(ORDER BY a, b) AS rn
+              |FROM t1 LIMIT 3
+              |""".stripMargin)
+
+          assert(df1.rdd.getNumPartitions == 1)
+          checkSparkAnswerAndOperator(df1, classOf[WindowExec])
+
+          val df2 = spark.sql("""
+              |SELECT b, RANK() OVER(ORDER BY a, b) AS rk, DENSE_RANK(b) OVER(ORDER BY a, b) AS s
+              |FROM t1 LIMIT 2
+              |""".stripMargin)
+          assert(df2.rdd.getNumPartitions == 1)
+          checkSparkAnswerAndOperator(df2, classOf[WindowExec], classOf[ProjectExec])
+
+          // Other Comet native operators can take input from `CometTakeOrderedAndProjectExec`.
+          val df3 = sql("SELECT * FROM t1 ORDER BY a, b LIMIT 3").groupBy($"a").sum("b")
+          checkSparkAnswerAndOperator(df3)
+        }
+      })
+  }
+
+  test("TakeOrderedAndProjectExec without sorting") {
+    Seq("true", "false").foreach(aqeEnabled =>
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled,
+        SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+          "org.apache.spark.sql.catalyst.optimizer.EliminateSorts") {
+        withTable("t1") {
+          val numRows = 10
+          spark
+            .range(numRows)
+            .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
+            .repartition(3) // Force multiple partitions so data must be shuffled into a single partition
+            .write
+            .saveAsTable("t1")
+
+          val df = spark
+            .table("t1")
+            .select("a", "b")
+            .sortWithinPartitions("b", "a")
+            .orderBy("b")
+            .select($"b" + 1, $"a")
+            .limit(3)
+
+          val takeOrdered = stripAQEPlan(df.queryExecution.executedPlan).collect {
+            case b: CometTakeOrderedAndProjectExec => b
+          }
+          assert(takeOrdered.length == 1)
+          assert(takeOrdered.head.orderingSatisfies)
+
+          checkSparkAnswerAndOperator(df)
+        }
+      })
+  }
 }
 
 case class BucketedTableTestSpec(
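
For context, a minimal sketch (not part of the patch) of how the new plan node can be observed once this change is applied. It assumes a session where the Comet extension, the "takeOrderedAndProjectExec" operator, and Comet shuffle are all enabled; the DataFrame below is illustrative only and is not taken from this diff.

    import org.apache.spark.sql.comet.CometTakeOrderedAndProjectExec

    // ORDER BY + LIMIT is planned by Spark as TakeOrderedAndProjectExec, which the new rule in
    // CometSparkSessionExtensions rewrites when the child is native and the expressions pass
    // CometTakeOrderedAndProjectExec.isSupported.
    val df = spark
      .range(100)
      .selectExpr("id AS a", "100 - id AS b")
      .repartition(3) // more than one partition, so the local top-K + native shuffle path runs
      .orderBy("b")
      .limit(3)

    // With AQE enabled the node may sit under an adaptive plan wrapper (the test suite uses
    // stripAQEPlan for this), so an empty result here does not necessarily mean the rule did
    // not fire; df.explain() shows the chosen plan either way.
    df.explain()
    val nodes = df.queryExecution.executedPlan.collect {
      case t: CometTakeOrderedAndProjectExec => t
    }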