feat: Add support of TakeOrderedAndProjectExec in Comet
viirya committed Feb 22, 2024
1 parent 637dba9 commit f084203
Showing 3 changed files with 309 additions and 2 deletions.
@@ -320,6 +320,20 @@ class CometSparkSessionExtensions
            c
        }

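      // Rewrite TakeOrderedAndProjectExec only when the child is already Comet-native, the
      // operator and Comet shuffle are enabled, and all projection and sort expressions can be
      // converted to the native plan.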
      case s: TakeOrderedAndProjectExec
          if isCometNative(s.child) && isCometOperatorEnabled(conf, "takeOrderedAndProjectExec")
            && isCometShuffleEnabled(conf) &&
            CometTakeOrderedAndProjectExec.isSupported(s.projectList, s.sortOrder, s.child) =>
        // TODO: support offset for Spark 3.4
        QueryPlanSerde.operator2Proto(s) match {
          case Some(nativeOp) =>
            val cometOp =
              CometTakeOrderedAndProjectExec(s, s.limit, s.sortOrder, s.projectList, s.child)
            CometSinkPlaceHolder(nativeOp, s, cometOp)
          case None =>
            s
        }

      case u: UnionExec
          if isCometOperatorEnabled(conf, "union") &&
            u.children.forall(isCometNative) =>
@@ -0,0 +1,224 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet

import scala.collection.JavaConverters.asJavaIterableConverter

import org.apache.spark.rdd.{ParallelCollectionRDD, RDD}
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode, UnsafeRowSerializer}
import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}

/**
 * Comet physical plan node for Spark `TakeOrderedAndProjectExec`.
 *
 * It executes a `TakeOrderedAndProjectExec` physical operator using the Comet native engine.
 * Unlike other physical plan nodes wrapped by `CometExec`, it contains two native executions
 * separated by a Comet shuffle exchange.
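 *
 * Execution first runs a per-partition native top-K (or a plain limit when the child output
 * already satisfies the required ordering), shuffles the results into a single partition with
 * Comet native shuffle, and then runs a final native top-K plus projection over that partition.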
 */
case class CometTakeOrderedAndProjectExec(
    override val originalPlan: SparkPlan,
    limit: Int,
    sortOrder: Seq[SortOrder],
    projectList: Seq[NamedExpression],
    child: SparkPlan)
    extends CometExec
    with UnaryExecNode {
  override def output: Seq[Attribute] = projectList.map(_.toAttribute)

  private lazy val writeMetrics =
    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
  private lazy val readMetrics =
    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
  override lazy val metrics: Map[String, SQLMetric] = Map(
    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
    "shuffleReadElapsedCompute" ->
      SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
    "numPartitions" -> SQLMetrics.createMetric(
      sparkContext,
      "number of partitions")) ++ readMetrics ++ writeMetrics

  private lazy val serializer: Serializer =
    new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))

  // Exposed for testing.
  lazy val orderingSatisfies: Boolean =
    SortOrder.orderingSatisfies(child.outputOrdering, sortOrder)

  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
    val childRDD = child.executeColumnar()
    if (childRDD.getNumPartitions == 0) {
      new ParallelCollectionRDD(sparkContext, Seq.empty[ColumnarBatch], 1, Map.empty)
    } else {
      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
        childRDD
      } else {
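        // If the child output already satisfies the required ordering, a per-partition limit is
        // enough; otherwise each partition runs a native top-K sort before the shuffle.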
        val localTopK = if (orderingSatisfies) {
          childRDD.mapPartitionsInternal { iter =>
            val limitOp =
              CometTakeOrderedAndProjectExec.getLimitNativePlan(output, limit).get
            CometExec.getCometIterator(Seq(iter), limitOp)
          }
        } else {
          childRDD.mapPartitionsInternal { iter =>
            val topK =
              CometTakeOrderedAndProjectExec
                .getTopKNativePlan(output, sortOrder, child, limit)
                .get
            CometExec.getCometIterator(Seq(iter), topK)
          }
        }

        // Shuffle to Single Partition using Comet native shuffle
        val dep = CometShuffleExchangeExec.prepareShuffleDependency(
          localTopK,
          child.output,
          outputPartitioning,
          serializer,
          metrics)
        metrics("numPartitions").set(dep.partitioner.numPartitions)

        new CometShuffledBatchRDD(dep, readMetrics)
      }

      singlePartitionRDD.mapPartitionsInternal { iter =>
        val topKAndProjection = CometTakeOrderedAndProjectExec
          .getProjectionNativePlan(projectList, output, sortOrder, child, limit)
          .get
        CometExec.getCometIterator(Seq(iter), topKAndProjection)
      }
    }
  }

  override def simpleString(maxFields: Int): String = {
    val orderByString = truncatedString(sortOrder, "[", ",", "]", maxFields)
    val outputString = truncatedString(output, "[", ",", "]", maxFields)

    s"CometTakeOrderedAndProjectExec(limit=$limit, orderBy=$orderByString, output=$outputString)"
  }

  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    this.copy(child = newChild)
}

object CometTakeOrderedAndProjectExec {
  def isSupported(
      projectList: Seq[NamedExpression],
      sortOrder: Seq[SortOrder],
      child: SparkPlan): Boolean = {
    val exprs = projectList.map(exprToProto(_, child.output))
    val sortOrders = sortOrder.map(exprToProto(_, child.output))
    exprs.forall(_.isDefined) && sortOrders.forall(_.isDefined)
  }

  /**
   * Prepare Projection + TopK native plan for CometTakeOrderedAndProjectExec.
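   * The returned operator wraps the top-K plan from `getTopKNativePlan` in a native `Projection`;
   * it returns `None` if any projection expression cannot be converted to protobuf.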
   */
  def getProjectionNativePlan(
      projectList: Seq[NamedExpression],
      outputAttributes: Seq[Attribute],
      sortOrder: Seq[SortOrder],
      child: SparkPlan,
      limit: Int): Option[Operator] = {
    getTopKNativePlan(outputAttributes, sortOrder, child, limit).flatMap { topK =>
      val exprs = projectList.map(exprToProto(_, child.output))

      if (exprs.forall(_.isDefined)) {
        val projectBuilder = OperatorOuterClass.Projection.newBuilder()
        projectBuilder.addAllProjectList(exprs.map(_.get).asJava)
        val opBuilder = OperatorOuterClass.Operator
          .newBuilder()
          .addChildren(topK)
        Some(opBuilder.setProjection(projectBuilder).build())
      } else {
        None
      }
    }
  }

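  /**
   * Prepare a plain Limit (Scan + Limit) native plan, used per partition when the child output
   * already satisfies the required ordering.
   */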
  def getLimitNativePlan(outputAttributes: Seq[Attribute], limit: Int): Option[Operator] = {
    val scanBuilder = OperatorOuterClass.Scan.newBuilder()
    val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()

    val scanTypes = outputAttributes.flatten { attr =>
      serializeDataType(attr.dataType)
    }

    if (scanTypes.length == outputAttributes.length) {
      scanBuilder.addAllFields(scanTypes.asJava)

      val limitBuilder = OperatorOuterClass.Limit.newBuilder()
      limitBuilder.setLimit(limit)

      val limitOpBuilder = OperatorOuterClass.Operator
        .newBuilder()
        .addChildren(scanOpBuilder.setScan(scanBuilder))
      Some(limitOpBuilder.setLimit(limitBuilder).build())
    } else {
      None
    }
  }

  /**
   * Prepare TopK native plan for CometTakeOrderedAndProjectExec.
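   * The plan is a native `Sort` with `fetch` set to `limit` over a `Scan` of the output schema;
   * it returns `None` if a data type or sort expression cannot be converted to protobuf.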
   */
  def getTopKNativePlan(
      outputAttributes: Seq[Attribute],
      sortOrder: Seq[SortOrder],
      child: SparkPlan,
      limit: Int): Option[Operator] = {
    val scanBuilder = OperatorOuterClass.Scan.newBuilder()
    val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()

    val scanTypes = outputAttributes.flatten { attr =>
      serializeDataType(attr.dataType)
    }

    if (scanTypes.length == outputAttributes.length) {
      scanBuilder.addAllFields(scanTypes.asJava)

      val sortOrders = sortOrder.map(exprToProto(_, child.output))

      if (sortOrders.forall(_.isDefined)) {
        val sortBuilder = OperatorOuterClass.Sort.newBuilder()
        sortBuilder.addAllSortOrders(sortOrders.map(_.get).asJava)
        sortBuilder.setFetch(limit)

        val sortOpBuilder = OperatorOuterClass.Operator
          .newBuilder()
          .addChildren(scanOpBuilder.setScan(scanBuilder))
        Some(sortOpBuilder.setSort(sortBuilder).build())
      } else {
        None
      }
    } else {
      None
    }
  }
}
73 changes: 71 additions & 2 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,11 +31,12 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, UnionExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.functions.{date_add, expr}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
@@ -964,6 +965,74 @@ class CometExecSuite extends CometTestBase {
      }
    }
  }

  test("TakeOrderedAndProjectExec") {
    Seq("true", "false").foreach(aqeEnabled =>
      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
        withTable("t1") {
          val numRows = 10
          spark
            .range(numRows)
            .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
            .repartition(3) // Repartition so the test exercises shuffling data into a single partition
            .write
            .saveAsTable("t1")

          val df1 = spark.sql("""
              |SELECT a, b, ROW_NUMBER() OVER(ORDER BY a, b) AS rn
              |FROM t1 LIMIT 3
              |""".stripMargin)

          assert(df1.rdd.getNumPartitions == 1)
          checkSparkAnswerAndOperator(df1, classOf[WindowExec])

          val df2 = spark.sql("""
              |SELECT b, RANK() OVER(ORDER BY a, b) AS rk, DENSE_RANK(b) OVER(ORDER BY a, b) AS s
              |FROM t1 LIMIT 2
              |""".stripMargin)
          assert(df2.rdd.getNumPartitions == 1)
          checkSparkAnswerAndOperator(df2, classOf[WindowExec], classOf[ProjectExec])

          // Other Comet native operators can take input from `CometTakeOrderedAndProjectExec`.
          val df3 = sql("SELECT * FROM t1 ORDER BY a, b LIMIT 3").groupBy($"a").sum("b")
          checkSparkAnswerAndOperator(df3)
        }
      })
  }

  test("TakeOrderedAndProjectExec without sorting") {
    Seq("true", "false").foreach(aqeEnabled =>
      withSQLConf(
        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled,
        SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
          "org.apache.spark.sql.catalyst.optimizer.EliminateSorts") {
        withTable("t1") {
          val numRows = 10
          spark
            .range(numRows)
            .selectExpr("if (id % 2 = 0, null, id) AS a", s"$numRows - id AS b")
            .repartition(3) // Repartition so the test exercises shuffling data into a single partition
            .write
            .saveAsTable("t1")

          val df = spark
            .table("t1")
            .select("a", "b")
            .sortWithinPartitions("b", "a")
            .orderBy("b")
            .select($"b" + 1, $"a")
            .limit(3)

          val takeOrdered = stripAQEPlan(df.queryExecution.executedPlan).collect {
            case b: CometTakeOrderedAndProjectExec => b
          }
          assert(takeOrdered.length == 1)
          assert(takeOrdered.head.orderingSatisfies)

          checkSparkAnswerAndOperator(df)
        }
      })
  }
}

case class BucketedTableTestSpec(
