diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index de49fdfb0b..edb76a18bb 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -139,15 +139,6 @@ object CometConf { .booleanConf .createWithDefault(false) - val COMET_EXEC_BROADCAST_ENABLED: ConfigEntry[Boolean] = - conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled") - .doc( - "Whether to enable broadcasting for Comet native operators. By default, " + - "this config is false. Note that this feature is not fully supported yet " + - "and only enabled for test purpose.") - .booleanConf - .createWithDefault(false) - val COMET_EXEC_SHUFFLE_CODEC: ConfigEntry[String] = conf( s"$COMET_EXEC_CONFIG_PREFIX.shuffle.codec") .doc( diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 87c2265fcb..b3d70f816f 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.SparkSession import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} @@ -42,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported} +import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported} import org.apache.comet.parquet.{CometParquetScan, SupportsComet} import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde @@ -368,15 +367,24 @@ class CometSparkSessionExtensions u } - case b: BroadcastExchangeExec - if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") && - isCometBroadCastEnabled(conf) => - QueryPlanSerde.operator2Proto(b) match { - case Some(nativeOp) => - val cometOp = CometBroadcastExchangeExec(b, b.child) - CometSinkPlaceHolder(nativeOp, b, cometOp) - case None => b + // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast + // exchange. It is only used for Comet native execution. + case plan + if isCometNative(plan) && + plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) => + val newChildren = plan.children.map { + case b: BroadcastExchangeExec + if isCometNative(b.child) && + isCometOperatorEnabled(conf, "broadcastExchangeExec") => + QueryPlanSerde.operator2Proto(b) match { + case Some(nativeOp) => + val cometOp = CometBroadcastExchangeExec(b, b.child) + CometSinkPlaceHolder(nativeOp, b, cometOp) + case None => b + } + case other => other } + plan.withNewChildren(newChildren) // Native shuffle for Comet operators case s: ShuffleExchangeExec @@ -547,11 +555,9 @@ object CometSparkSessionExtensions extends Logging { private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = { val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled" - conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) - } - - private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = { - COMET_EXEC_BROADCAST_ENABLED.get(conf) + val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled" + conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) && + !conf.getConfString(operatorDisabledFlag, "false").toBoolean } private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean =