Write a more generic test for EnsureRequirements.
JoshRosen committed Aug 6, 2015
1 parent 752b8de commit 2e0f33a
Showing 1 changed file with 37 additions and 8 deletions.
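
For readers outside the Spark codebase: EnsureRequirements is the physical-planning rule that inserts Exchange (shuffle) operators, and sorts where an ordering requirement is unmet, so that every operator's distribution and ordering requirements are satisfied. The diff below replaces an end-to-end SortMergeJoin test with a unit test that applies the rule directly to a hand-built plan. As a minimal sketch of that direct-invocation pattern, using the names that appear in the diff:

    // Apply the rule to a physical plan, then inspect the result for Exchanges.
    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
    val exchanges = outputPlan.collect { case e @ Exchange(_, _) => e }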
@@ -18,10 +18,14 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
 import org.apache.spark.sql.test.TestSQLContext._
@@ -203,13 +207,38 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
     }
   }

-  test("EnsureRequirements shouldn't add exchange to SMJ inputs if both are SinglePartition") {
-    val df = (1 to 10).map(Tuple1.apply).toDF("a").repartition(1)
-    val keys = Seq(df.col("a").expr)
-    val smj = SortMergeJoin(keys, keys, df.queryExecution.sparkPlan, df.queryExecution.sparkPlan)
-    val afterEnsureRequirements = EnsureRequirements(df.sqlContext).apply(smj)
-    if (afterEnsureRequirements.collect { case Exchange(_, _) => true }.nonEmpty) {
-      fail(s"No Exchanges should have been added:\n$afterEnsureRequirements")
+  // --- Unit tests of EnsureRequirements ---------------------------------------------------------
+
+  test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") {
+    val outputOrdering = Seq(SortOrder(Literal(1), Ascending))
+    val distribution = ClusteredDistribution(Literal(1) :: Nil)
+    val inputPlan = DummyPlan(
+      children = Seq(
+        DummyPlan(outputPartitioning = SinglePartition),
+        DummyPlan(outputPartitioning = SinglePartition)
+      ),
+      requiresChildrenToProduceSameNumberOfPartitions = true,
+      requiredChildDistribution = Seq(distribution, distribution),
+      requiredChildOrdering = Seq(outputOrdering, outputOrdering)
+    )
+    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
+      fail(s"No Exchanges should have been added:\n$outputPlan")
     }
   }
+
+  // ---------------------------------------------------------------------------------------------
 }
+
+// Used for unit-testing EnsureRequirements
+private case class DummyPlan(
+    override val children: Seq[SparkPlan] = Nil,
+    override val outputOrdering: Seq[SortOrder] = Nil,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+    override val requiresChildrenToProduceSameNumberOfPartitions: Boolean = false,
+    override val requiredChildDistribution: Seq[Distribution] = Nil,
+    override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil
+  ) extends SparkPlan {
+  override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError
+  override def output: Seq[Attribute] = Seq.empty
+}
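
As a companion sketch (hypothetical, not part of this commit): the same DummyPlan harness can assert the opposite behavior, that an Exchange is indeed added when a child's partitioning cannot satisfy a required distribution. Assuming the Spark 1.5-era EnsureRequirements and Exchange APIs shown above, such a test could look like this:

    // Hypothetical companion test: an unsatisfied ClusteredDistribution
    // should cause EnsureRequirements to insert an Exchange.
    test("EnsureRequirements should repartition when distribution requirement is unsatisfied") {
      val distribution = ClusteredDistribution(Literal(1) :: Nil)
      val inputPlan = DummyPlan(
        children = Seq(DummyPlan(outputPartitioning = UnknownPartitioning(1))),
        requiredChildDistribution = Seq(distribution),
        requiredChildOrdering = Seq(Nil)
      )
      val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
      if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
        fail(s"An Exchange should have been added:\n$outputPlan")
      }
    }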
