Commit 8907e29

format
andygrove committed Jun 26, 2024
1 parent dc1b96c
Showing 4 changed files with 19 additions and 34 deletions.
@@ -100,11 +100,11 @@ class CometCostEvaluator extends CostEvaluator with Logging {
         0
       }
 
-
     val totalCost = operatorCost + transitionCost + childPlanCost
 
-    logWarning(s"total cost is $totalCost ($operatorCost + $transitionCost + $childPlanCost) " +
-      s"for ${plan.nodeName}")
+    logWarning(
+      s"total cost is $totalCost ($operatorCost + $transitionCost + $childPlanCost) " +
+        s"for ${plan.nodeName}")
 
     totalCost
   }
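
For context, the hunk above shows only the tail of the cost computation. A minimal sketch of the overall shape of such an evaluator, assuming Spark's @Unstable CostEvaluator / Cost / SimpleCost API in org.apache.spark.sql.execution.adaptive is available, and using an illustrative flat per-operator weight that is not Comet's actual cost model:

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{Cost, CostEvaluator, SimpleCost}

// Illustrative evaluator only: plan cost = operator cost + transition cost + children's costs.
class NaivePlanCostEvaluator extends CostEvaluator {

  override def evaluateCost(plan: SparkPlan): Cost = SimpleCost(planCost(plan))

  private def planCost(plan: SparkPlan): Long = {
    val operatorCost = 1L // assumption: every operator gets the same flat weight
    val transitionCost = 0L // assumption: row/columnar transitions not modeled in this sketch
    val childPlanCost = plan.children.map(planCost).sum
    operatorCost + transitionCost + childPlanCost
  }
}
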
@@ -32,13 +32,13 @@ import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -195,27 +195,18 @@ class CometSparkSessionExtensions
 
   case class CometQueryStagePrepRule(session: SparkSession) extends Rule[SparkPlan] {
     def apply(plan: SparkPlan): SparkPlan = {
-
-
       val newPlan = CometExecRule(session).apply(plan)
-
-
       if (CometConf.COMET_CBO_ENABLED.get()) {
         val costEvaluator = new CometCostEvaluator()
-        println(plan)
-        println(newPlan)
         val sparkCost = costEvaluator.evaluateCost(plan)
         val cometCost = costEvaluator.evaluateCost(newPlan)
-        println(s"sparkCost = $sparkCost, cometCost = $cometCost")
         if (cometCost > sparkCost) {
           val msg = s"Comet plan is more expensive than Spark plan ($cometCost > $sparkCost)" +
             s"\nSPARK: $plan\n" +
             s"\nCOMET: $newPlan\n"
           logWarning(msg)
-          println(msg)
-          println(s"CometQueryStagePrepRule:\nIN: ${plan.getClass}\nOUT: ${plan.getClass}")
 
-          def fallbackRecursively(plan: SparkPlan) : Unit = {
+          def fallbackRecursively(plan: SparkPlan): Unit = {
             plan.setTagValue(CANNOT_RUN_NATIVE, true)
             plan match {
               case a: AdaptiveSparkPlanExec => fallbackRecursively(a.inputPlan)
@@ -228,19 +219,13 @@ class CometSparkSessionExtensions
           return plan
         }
       }
-
-
-      println(s"CometQueryStagePrepRule:\nIN: ${plan.getClass}\nOUT: ${newPlan.getClass}")
-
       newPlan
     }
   }
 
   case class CometPreColumnarRule(session: SparkSession) extends Rule[SparkPlan] {
     def apply(plan: SparkPlan): SparkPlan = {
-      val newPlan = CometExecRule(session).apply(plan)
-      println(s"CometPreColumnarRule:\nIN: ${plan.getClass}\nOUT: ${newPlan.getClass}")
-      newPlan
+      CometExecRule(session).apply(plan)
     }
   }
 
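How these two rules plug into Spark is not part of this diff. Purely as a hypothetical sketch, assuming CometQueryStagePrepRule is registered through the standard SparkSessionExtensions query-stage-prep hook and CometPreColumnarRule through a ColumnarRule (the class ExampleExtensions and this wiring are assumptions, not Comet's actual registration code):

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}

// Hypothetical wiring; the real registration lives elsewhere in CometSparkSessionExtensions.
class ExampleExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Runs on each plan fragment before AQE creates a query stage, which is where the
    // cost comparison above gets a chance to veto the Comet version of the plan.
    extensions.injectQueryStagePrepRule(session => CometQueryStagePrepRule(session))
    // Runs before Spark inserts row/columnar transitions.
    extensions.injectColumnar(session =>
      new ColumnarRule {
        override def preColumnarTransitions: Rule[SparkPlan] = CometPreColumnarRule(session)
      })
  }
}
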
@@ -780,7 +765,7 @@ class CometSparkSessionExtensions
     if (!isCometEnabled(conf)) return plan
 
     if (plan.getTagValue(CANNOT_RUN_NATIVE).getOrElse(false)) {
-      println("Cannot run native - too slow")
+      logWarning("Will not run plan natively because it may be slower")
       return plan
     }
 
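The CANNOT_RUN_NATIVE flag written by fallbackRecursively and read in the hunk above goes through Spark's plan-tag mechanism; its declaration is not shown in this commit. A minimal sketch of how such a tag is typically declared and used, assuming it is a plain TreeNodeTag[Boolean] (the object and helper names here are illustrative):

import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.SparkPlan

object NativeFallbackTag {
  // Assumed declaration (not part of this diff): a boolean tag attached to plan nodes.
  val CANNOT_RUN_NATIVE: TreeNodeTag[Boolean] = TreeNodeTag[Boolean]("CANNOT_RUN_NATIVE")

  // Mark a plan so later rules keep it on the regular Spark (JVM) execution path.
  def markCannotRunNative(plan: SparkPlan): Unit =
    plan.setTagValue(CANNOT_RUN_NATIVE, true)

  // Absence of the tag means the plan is still a candidate for native execution.
  def cannotRunNative(plan: SparkPlan): Boolean =
    plan.getTagValue(CANNOT_RUN_NATIVE).getOrElse(false)
}
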
@@ -2301,7 +2301,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           .addAllSortOrders(sortOrders.map(_.get).asJava)
         Some(result.setSort(sortBuilder).build())
       } else {
-        withInfo(op, "sort not allowed", sortOrder: _*)
+        withInfo(op, "sort order not supported", sortOrder: _*)
         None
       }
 
spark/src/test/scala/org/apache/comet/CostBasedOptimizerSuite.scala (20 changes: 10 additions & 10 deletions)
@@ -28,7 +28,8 @@ class CostBasedOptimizerSuite extends CometTestBase with AdaptiveSparkPlanHelper
   private val dataGen = DataGenerator.DEFAULT
 
   test("tbd") {
-    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
       SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false",
       CometConf.COMET_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -37,23 +38,22 @@ class CostBasedOptimizerSuite extends CometTestBase with AdaptiveSparkPlanHelper
       CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
       val table = "t1"
       withTable(table, "t2") {
-        sql(s"create table t1(col string, a int, b float) using parquet")
-        sql(s"create table t2(col string, a int, b float) using parquet")
+        sql("create table t1(col string, a int, b float) using parquet")
+        sql("create table t2(col string, a int, b float) using parquet")
         val tableSchema = spark.table(table).schema
-        val rows = dataGen.generateRows(
-          1000,
-          tableSchema,
-          Some(() => dataGen.generateString("tbd:", 6)))
+        val rows =
+          dataGen.generateRows(1000, tableSchema, Some(() => dataGen.generateString("tbd:", 6)))
         val data = spark.createDataFrame(spark.sparkContext.parallelize(rows), tableSchema)
         data.write
           .mode("append")
           .insertInto(table)
         data.write
           .mode("append")
           .insertInto("t2")
-        val x = checkSparkAnswer/*AndOperator*/("select t1.col as x " +
-          "from t1 join t2 on cast(t1.col as timestamp) = cast(t2.col as timestamp) " +
-          "order by x")
+        val x = checkSparkAnswer /*AndOperator*/ (
+          "select t1.col as x " +
+            "from t1 join t2 on cast(t1.col as timestamp) = cast(t2.col as timestamp) " +
+            "order by x")
 
         // TODO assert that we fell back for whole plan
         println(x._1)
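
One way the TODO above might be addressed, sketched as a snippet for inside the test body under two assumptions: that every Comet physical operator's nodeName starts with "Comet", and that stripAQEPlan from AdaptiveSparkPlanHelper (which this suite extends) is in scope. The exact Comet test helpers may differ.

// Hypothetical assertion for the TODO: verify the whole plan fell back to Spark.
val df = sql(
  "select t1.col as x " +
    "from t1 join t2 on cast(t1.col as timestamp) = cast(t2.col as timestamp) " +
    "order by x")
df.collect() // force execution so AQE finalizes the physical plan
val finalPlan = stripAQEPlan(df.queryExecution.executedPlan)
val cometOps = finalPlan.collect { case op if op.nodeName.startsWith("Comet") => op.nodeName }
assert(cometOps.isEmpty, s"expected a full fallback to Spark but found Comet operators: $cometOps")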