Commit: Fix more tests

viirya committed Apr 30, 2024
1 parent ae5ca18 commit edfce1f
Showing 2 changed files with 267 additions and 3 deletions.
267 changes: 265 additions & 2 deletions dev/diffs/3.4.2.diff
@@ -474,6 +474,34 @@ index 00000000000..4b31bea33de
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index 1792b4c32eb..1616e6f39bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
@@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
val executedPlan = df.queryExecution.executedPlan
val shuffleHashJoins = collect(executedPlan) {
case s: ShuffledHashJoinExec => s
+ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec]
}
assert(shuffleHashJoins.size == 1)
assert(shuffleHashJoins.head.buildSide == buildSide)
@@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP
val executedPlan = df.queryExecution.executedPlan
val shuffleMergeJoins = collect(executedPlan) {
case s: SortMergeJoinExec => s
+ case c: CometSortMergeJoinExec => c
}
assert(shuffleMergeJoins.size == 1)
}
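The idiom this patch applies throughout the suites: wherever a test expects a vanilla Spark operator, also accept the Comet operator that may have replaced it. When the test needs operator details (here, the hash join build side), the Comet node is unwrapped via originalPlan, which the hunk above casts back to ShuffledHashJoinExec. A minimal sketch of that idiom as a standalone helper (the helper name is hypothetical; it assumes originalPlan holds the Spark operator the Comet node was converted from, as the patch relies on):

    import org.apache.spark.sql.comet.CometHashJoinExec
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec

    // Hypothetical helper: collect shuffled hash joins whether or not Comet
    // replaced them, unwrapping Comet nodes back to the original Spark operator
    // so assertions on buildSide keep working unchanged.
    def collectShuffledHashJoins(plan: SparkPlan): Seq[ShuffledHashJoinExec] =
      plan.collect {
        case s: ShuffledHashJoinExec => s
        case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec]
      }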
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 5125708be32..123f58c522a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -794,11 +822,168 @@ index 9e9d717db3b..91a4f9a38d5 100644
assert(actual == expected)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
index 30ce940b032..0d3f6c6c934 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.comet.CometSortExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.joins.ShuffledJoin
import org.apache.spark.sql.internal.SQLConf
@@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase

private def checkNumSorts(df: DataFrame, count: Int): Unit = {
val plan = df.queryExecution.executedPlan
- assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count)
+ assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec => 1 }.length == count)
}

private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
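In checkNumSorts the match arm returns the constant 1 rather than the matched node: SortExec and CometSortExec share no useful common supertype, and only the number of sort operators is asserted, so the partial function just has to produce one element per match. A short sketch of the equivalent count (hypothetical helper name):

    import org.apache.spark.sql.comet.CometSortExec
    import org.apache.spark.sql.execution.{SortExec, SparkPlan}

    // Hypothetical helper: count sort operators, Spark or Comet. The mapped value
    // is irrelevant; only the length of the collected sequence matters.
    def numSorts(plan: SparkPlan): Int =
      plan.collect { case _: SortExec | _: CometSortExec => 1 }.length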
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
index 47679ed7865..9ffbaecb98e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.internal.SQLConf
@@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase
private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
val plan = df.queryExecution.executedPlan
assert(collectWithSubqueries(plan) {
- case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s
+ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _: CometHashAggregateExec ) => s
}.length == hashAggCount)
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index ac710c32296..37746bd470d 100644
index ac710c32296..88a5329e74e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -616,7 +616,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
+import org.apache.spark.sql.comet.{CometProjectExec, CometSortExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -192,7 +193,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val joinUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer")
assert(joinUniqueDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
- case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometSortMergeJoinExec)))
+ if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
@@ -202,7 +204,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val joinNonUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2" % 3, "full_outer")
assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
- case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometSortMergeJoinExec)))
+ if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
@@ -224,6 +227,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
@@ -240,7 +244,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// test one left outer sort merge join
val oneLeftOuterJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_outer")
assert(oneLeftOuterJoinDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometSortMergeJoinExec))) => true
}.size === 1)
checkAnswer(oneLeftOuterJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, null),
Row(5, null), Row(6, null), Row(7, null), Row(8, null), Row(9, null)))
@@ -248,7 +252,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// test one right outer sort merge join
val oneRightOuterJoinDF = df2.join(df3.hint("SHUFFLE_MERGE"), $"k2" === $"k3", "right_outer")
assert(oneRightOuterJoinDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometSortMergeJoinExec))) => true
}.size === 1)
checkAnswer(oneRightOuterJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(null, 4),
Row(null, 5)))
@@ -258,6 +262,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "right_outer")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, null, 4), Row(5, null, 5),
@@ -272,7 +277,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// test one left semi sort merge join
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")
assert(oneJoinDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
+ case WholeStageCodegenExec(
+ ColumnarToRowExec(InputAdapter(
+ CometProjectExec(_, _, _, _, _: CometSortMergeJoinExec, _)))) => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(0), Row(1), Row(2), Row(3)))

@@ -280,8 +287,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi")
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3)))
}
@@ -294,7 +300,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
// test one left anti sort merge join
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
assert(oneJoinDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
+ case WholeStageCodegenExec(
+ ColumnarToRowExec(InputAdapter(
+ CometProjectExec(_, _, _, _, _: CometSortMergeJoinExec, _)))) => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(4), Row(5), Row(6), Row(7), Row(8), Row(9)))

@@ -302,8 +310,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val twoJoinsDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
.join(df3.hint("SHUFFLE_MERGE"), $"k1" === $"k3", "left_anti")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(6), Row(7), Row(8), Row(9)))
}
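When Comet replaces an operator inside a whole-stage-codegen subtree, the replaced node no longer sits directly under WholeStageCodegenExec: as the hunks above show, it is reached through ColumnarToRowExec and InputAdapter, the columnar-to-row transition Spark inserts around columnar operators. Tests that pin the exact shape therefore match through that wrapper, while tests where the nesting varies (two joins in one stage) simply look for the Comet node anywhere in the plan. A sketch of the wrapped match, assuming the same classes the patch imports (helper name hypothetical):

    import org.apache.spark.sql.comet.CometSortMergeJoinExec
    import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec}

    // Hypothetical helper: does the plan contain a Comet sort merge join reached
    // through the columnar-to-row transition under a whole-stage-codegen node?
    def hasCodegenWrappedCometSmj(plan: SparkPlan): Boolean =
      plan.collect {
        case WholeStageCodegenExec(
            ColumnarToRowExec(InputAdapter(_: CometSortMergeJoinExec))) => true
      }.nonEmpty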
@@ -436,7 +443,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]))
+ p.asInstanceOf[WholeStageCodegenExec].collect {
+ case _: CometSortExec => true
+ }.nonEmpty))
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}

@@ -616,7 +625,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.write.mode(SaveMode.Overwrite).parquet(path)

withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
@@ -1606,6 +1791,84 @@ index ef5b8a769fe..84fe1bfabc9 100644
s"Local limit was not LocalLimitExec:\n$execPlan")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
index b4c4ec7acbf..20579284856 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils
import org.scalatest.Assertions

import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
+import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec}
import org.apache.spark.sql.functions.count
@@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends StreamTest
// verify aggregations in between, except partial aggregation
val allAggregateExecs = query.lastExecution.executedPlan.collect {
case a: BaseAggregateExec => a
+ case c: CometHashAggregateExec => c.originalPlan
}

val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
@@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends StreamTest
// verify aggregations in between, except partial aggregation
val allAggregateExecs = executedPlan.collect {
case a: BaseAggregateExec => a
+ case c: CometHashAggregateExec => c.originalPlan
}

val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
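Here the Comet hash aggregate is unwrapped through originalPlan so the suite's later filtering and distribution checks can keep treating every collected node as a regular Spark aggregate. A sketch of the collect-and-filter shape, assuming (as the UnspecifiedDistribution import above suggests) that partial aggregates are the ones declaring no required child distribution, and that originalPlan returns the replaced Spark plan; the helper name is hypothetical:

    import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
    import org.apache.spark.sql.comet.CometHashAggregateExec
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.aggregate.BaseAggregateExec

    // Hypothetical sketch: gather aggregate operators (unwrapping Comet ones) and
    // keep only those that require a specific child distribution, i.e. drop the
    // partial aggregation that runs before the shuffle.
    def nonPartialAggregates(plan: SparkPlan): Seq[SparkPlan] =
      plan.collect {
        case a: BaseAggregateExec => a
        case c: CometHashAggregateExec => c.originalPlan
      }.filter(_.requiredChildDistribution.exists(_ != UnspecifiedDistribution))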
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 4d92e270539..33f1c2eb75e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper}
import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId}
import org.apache.spark.sql.functions._
@@ -619,14 +619,28 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {

val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS)

- assert(query.lastExecution.executedPlan.collect {
- case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _,
- ShuffleExchangeExec(opA: HashPartitioning, _, _),
- ShuffleExchangeExec(opB: HashPartitioning, _, _))
- if partitionExpressionsColumns(opA.expressions) === Seq("a", "b")
- && partitionExpressionsColumns(opB.expressions) === Seq("a", "b")
- && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j
- }.size == 1)
+ val join = query.lastExecution.executedPlan.collect {
+ case j: StreamingSymmetricHashJoinExec => j
+ }.head
+ val opA = join.left.collect {
+ case s: ShuffleExchangeLike
+ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
+ partitionExpressionsColumns(
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") =>
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning]
+ }.head
+ val opB = join.right.collect {
+ case s: ShuffleExchangeLike
+ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
+ partitionExpressionsColumns(
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") =>
+ s.outputPartitioning
+ .asInstanceOf[HashPartitioning]
+ }.head
+ assert(opA.numPartitions == numPartitions && opB.numPartitions == numPartitions)
})
}
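The streaming join assertion is rewritten because Comet's shuffle is not a ShuffleExchangeExec, so the old deep pattern over the join's children no longer matches; matching on the ShuffleExchangeLike interface and reading outputPartitioning works for both Spark and Comet shuffles. A sketch of the extraction the rewritten test performs on each side of the join (hypothetical helper name):

    import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike

    // Hypothetical helper: the hash partitioning produced by the first shuffle in
    // the subtree, whichever ShuffleExchangeLike implementation produced it.
    def firstHashPartitioning(plan: SparkPlan): Option[HashPartitioning] =
      plan.collectFirst {
        case s: ShuffleExchangeLike if s.outputPartitioning.isInstanceOf[HashPartitioning] =>
          s.outputPartitioning.asInstanceOf[HashPartitioning]
      }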

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index abe606ad9c1..2d930b64cca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
2 changes: 2 additions & 1 deletion
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin
trait AliasAwareOutputExpression extends SQLConfHelper {
// `SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT` is Spark 3.4+ only.
// Use a default value for now.
-protected val aliasCandidateLimit = 100
+protected val aliasCandidateLimit =
+  conf.getConfString("spark.sql.optimizer.expressionProjectionCandidateLimit", "100").toInt
protected def outputExpressions: Seq[NamedExpression]

/**
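The second changed file adjusts a Comet shim: SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT only exists as a typed config entry in Spark 3.4+, so instead of hard-coding 100 the trait now reads the setting by its string key through SQLConfHelper's conf, falling back to 100 when it is unset. The same version-tolerant lookup can be illustrated outside the trait roughly as follows (a standalone sketch with hypothetical session setup, not code from the patch):

    import org.apache.spark.sql.SparkSession

    // Standalone illustration (hypothetical): read a config that may have no typed
    // ConfigEntry in the Spark version compiled against, with an explicit default.
    val spark = SparkSession.builder().master("local[1]").appName("conf-demo").getOrCreate()
    val aliasCandidateLimit = spark.conf
      .get("spark.sql.optimizer.expressionProjectionCandidateLimit", "100")
      .toInt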