diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 0600bb2a2bbc7..ad9f7697f680c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -30,14 +30,4 @@ class ExchangeSuite extends SparkPlanTest {
       input.map(Row.fromTuple)
     )
   }
-
-  test("EnsureRequirements shouldn't add exchange to SMJ inputs if both are SinglePartition") {
-    val df = (1 to 10).map(Tuple1.apply).toDF("a").repartition(1)
-    val keys = Seq(df.col("a").expr)
-    val smj = SortMergeJoin(keys, keys, df.queryExecution.sparkPlan, df.queryExecution.sparkPlan)
-    val afterEnsureRequirements = EnsureRequirements(df.sqlContext).apply(smj)
-    if (afterEnsureRequirements.collect { case Exchange(_, _) => true }.nonEmpty) {
-      fail(s"No Exchanges should have been added:\n$afterEnsureRequirements")
-    }
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 18b0e54dc7c53..6cb751c7bf737 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
 import org.apache.spark.sql.test.TestSQLContext._
@@ -202,4 +202,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
       }
     }
   }
+
+  test("EnsureRequirements shouldn't add exchange to SMJ inputs if both are SinglePartition") {
+    val df = (1 to 10).map(Tuple1.apply).toDF("a").repartition(1)
+    val keys = Seq(df.col("a").expr)
+    val smj = SortMergeJoin(keys, keys, df.queryExecution.sparkPlan, df.queryExecution.sparkPlan)
+    val afterEnsureRequirements = EnsureRequirements(df.sqlContext).apply(smj)
+    if (afterEnsureRequirements.collect { case Exchange(_, _) => true }.nonEmpty) {
+      fail(s"No Exchanges should have been added:\n$afterEnsureRequirements")
+    }
+  }
 }