Skip to content

Commit

Permalink
fix: Remove original plan parameter from CometNativeExec
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jul 7, 2024
1 parent 335146e commit b272551
Show file tree
Hide file tree
Showing 145 changed files with 2,312 additions and 2,285 deletions.
178 changes: 118 additions & 60 deletions spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,13 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometProjectExec(
val newPlan = CometProjectExec(
nativeOp,
op,
op.projectList,
op.output,
op.projectList,
op.child,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -343,7 +343,9 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometFilterExec(nativeOp, op, op.condition, op.child, SerializedPlan(None))
val newPlan =
CometFilterExec(nativeOp, op.output, op.condition, op.child, SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -352,7 +354,15 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometSortExec(nativeOp, op, op.sortOrder, op.child, SerializedPlan(None))
val newPlan =
CometSortExec(
nativeOp,
op.output,
op.outputOrdering,
op.sortOrder,
op.child,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -361,7 +371,9 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometLocalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
val newPlan =
CometLocalLimitExec(nativeOp, op.limit, op.child, SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -370,7 +382,9 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometGlobalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
val newPlan =
CometGlobalLimitExec(nativeOp, op.limit, op.child, SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -382,8 +396,9 @@ class CometSparkSessionExtensions
QueryPlanSerde.operator2Proto(op) match {
case Some(nativeOp) =>
val offset = getOffset(op)
val cometOp =
CometCollectLimitExec(op, op.limit, offset, op.child)
val newPlan =
CometCollectLimitExec(op.limit, offset, op.child)
val cometOp = setLogicalLink(newPlan, op)
CometSinkPlaceHolder(nativeOp, op, cometOp)
case None =>
op
Expand All @@ -393,12 +408,28 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometExpandExec(nativeOp, op, op.projections, op.child, SerializedPlan(None))
val newPlan =
CometExpandExec(
nativeOp,
op.output,
op.projections,
op.child,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}

case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, child) =>
case op @ HashAggregateExec(
_,
_,
_,
groupingExprs,
aggExprs,
_,
_,
resultExpressions,
child) =>
val modes = aggExprs.map(_.mode).distinct

if (!modes.isEmpty && modes.size != 1) {
Expand All @@ -422,15 +453,17 @@ class CometSparkSessionExtensions
// modes is empty too. If aggExprs is not empty, we need to verify all the
// aggregates have the same mode.
assert(modes.length == 1 || modes.length == 0)
CometHashAggregateExec(
val newPlan = CometHashAggregateExec(
nativeOp,
op,
op.output,
groupingExprs,
aggExprs,
resultExpressions,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -443,9 +476,10 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometHashJoinExec(
val newPlan = CometHashJoinExec(
nativeOp,
op,
op.output,
op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
Expand All @@ -454,6 +488,7 @@ class CometSparkSessionExtensions
op.left,
op.right,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -475,9 +510,10 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometBroadcastHashJoinExec(
val newPlan = CometBroadcastHashJoinExec(
nativeOp,
op,
op.output,
op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
Expand All @@ -486,6 +522,7 @@ class CometSparkSessionExtensions
op.left,
op.right,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand All @@ -496,16 +533,18 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometSortMergeJoinExec(
val newPlan = CometSortMergeJoinExec(
nativeOp,
op,
op.output,
op.outputOrdering,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.left,
op.right,
SerializedPlan(None))
setLogicalLink(newPlan, op)
case None =>
op
}
Expand Down Expand Up @@ -535,7 +574,8 @@ class CometSparkSessionExtensions
&& isCometNative(child) =>
QueryPlanSerde.operator2Proto(c) match {
case Some(nativeOp) =>
val cometOp = CometCoalesceExec(c, numPartitions, child)
val newPlan = CometCoalesceExec(c.output, numPartitions, child)
val cometOp = setLogicalLink(newPlan, c)
CometSinkPlaceHolder(nativeOp, c, cometOp)
case None =>
c
Expand All @@ -558,8 +598,14 @@ class CometSparkSessionExtensions
CometTakeOrderedAndProjectExec.isSupported(s) =>
QueryPlanSerde.operator2Proto(s) match {
case Some(nativeOp) =>
val cometOp =
CometTakeOrderedAndProjectExec(s, s.limit, s.sortOrder, s.projectList, s.child)
val newPlan =
CometTakeOrderedAndProjectExec(
s.output,
s.limit,
s.sortOrder,
s.projectList,
s.child)
val cometOp = setLogicalLink(newPlan, s)
CometSinkPlaceHolder(nativeOp, s, cometOp)
case None =>
s
Expand All @@ -579,8 +625,14 @@ class CometSparkSessionExtensions
val newOp = transform1(w)
newOp match {
case Some(nativeOp) =>
val cometOp =
CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
val newPlan =
CometWindowExec(
w.output,
w.windowExpression,
w.partitionSpec,
w.orderSpec,
w.child)
val cometOp = setLogicalLink(newPlan, w)
CometSinkPlaceHolder(nativeOp, w, cometOp)
case None =>
w
Expand All @@ -591,7 +643,8 @@ class CometSparkSessionExtensions
u.children.forall(isCometNative) =>
QueryPlanSerde.operator2Proto(u) match {
case Some(nativeOp) =>
val cometOp = CometUnionExec(u, u.children)
val newPlan = CometUnionExec(u.output, u.children)
val cometOp = setLogicalLink(newPlan, u)
CometSinkPlaceHolder(nativeOp, u, cometOp)
case None =>
u
Expand Down Expand Up @@ -631,7 +684,8 @@ class CometSparkSessionExtensions
isSpark34Plus => // Spark 3.4+ only
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
val cometOp = CometBroadcastExchangeExec(b, b.child)
val newPlan = CometBroadcastExchangeExec(b.output, b.child)
val cometOp = setLogicalLink(newPlan, b)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
}
Expand Down Expand Up @@ -822,40 +876,6 @@ class CometSparkSessionExtensions
case CometScanWrapper(_, s) => s
}

// Set up logical links
newPlan = newPlan.transform {
case op: CometExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
case op: CometShuffleExchangeExec =>
// Original Spark shuffle exchange operator might have empty logical link.
// But the `setLogicalLink` call above on downstream operator of
// `CometShuffleExchangeExec` will set its logical link to the downstream
// operators which cause AQE behavior to be incorrect. So we need to unset
// the logical link here.
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op

case op: CometBroadcastExchangeExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
}

// Convert native execution block by linking consecutive native operators.
var firstNativeOp = true
newPlan.transformDown {
Expand Down Expand Up @@ -887,6 +907,44 @@ class CometSparkSessionExtensions
}.flatten
}

/**
* Set up logical links for transformed Comet operators.
*/
def setLogicalLink(newPlan: SparkPlan, originalPlan: SparkPlan): SparkPlan = {
newPlan match {
case op: CometExec =>
if (originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
case op: CometShuffleExchangeExec =>
// Original Spark shuffle exchange operator might have empty logical link.
// But the `setLogicalLink` call above on downstream operator of
// `CometShuffleExchangeExec` will set its logical link to the downstream
// operators which cause AQE behavior to be incorrect. So we need to unset
// the logical link here.
if (originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op

case op: CometBroadcastExchangeExec =>
if (originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
}
}

/**
* Returns true if a given spark plan is Comet shuffle operator.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
Expand Down Expand Up @@ -60,7 +61,9 @@ import org.apache.comet.CometRuntimeException
* Note that this only supports Spark 3.4 and later, because the serialization class
* `ChunkedByteBuffer` is only serializable in Spark 3.4 and later.
*/
case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
case class CometBroadcastExchangeExec(
override val output: Seq[Attribute],
override val child: SparkPlan)
extends BroadcastExchangeLike
with ShimCometBroadcastExchangeExec {
import CometBroadcastExchangeExec._
Expand All @@ -74,10 +77,6 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"),
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))

override def doCanonicalize(): SparkPlan = {
CometBroadcastExchangeExec(originalPlan.canonicalized, child.canonicalized)
}

override def runtimeStatistics: Statistics = {
val dataSize = metrics("dataSize").value
val rowCount = metrics("numOutputRows").value
Expand Down Expand Up @@ -237,7 +236,7 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastExchangeExec =>
this.originalPlan == other.originalPlan &&
this.output == other.output &&
this.child == other.child
case _ =>
false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.spark.sql.comet

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand All @@ -31,7 +32,7 @@ import com.google.common.base.Objects
* more efficient when including it in a Comet query plan.
*/
case class CometCoalesceExec(
override val originalPlan: SparkPlan,
override val output: Seq[Attribute],
numPartitions: Int,
child: SparkPlan)
extends CometExec
Expand Down
Loading

0 comments on commit b272551

Please sign in to comment.