Skip to content

Commit

Permalink
Test with a single partition in all operator join unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Aug 10, 2015
1 parent 2a9165e commit c700df8
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,96 +20,102 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.{execution, Row, DataFrame}
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution._

class InnerJoinSuite extends SparkPlanTest {
class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {

/**
 * Registers one test per physical inner-join operator for the given pair of inputs.
 *
 * A logical `Join` of type `Inner` is built from the two DataFrames' logical plans and handed
 * to `ExtractEquiJoinKeys.unapply`; everything else runs inside `foreach` on that result, so
 * no test is registered when the extraction yields nothing.
 *
 * Tests registered (each name is prefixed with `testName`):
 *  - BroadcastHashJoin, build=left and build=right
 *  - ShuffledHashJoin, build=left and build=right
 *  - SortMergeJoin
 *
 * Every test calls `checkAnswer2` with `leftRows`, `rightRows`, a plan-building function,
 * `expectedAnswer` converted through `Row.fromTuple`, and `sortAnswers = true`.
 *
 * @param testName       prefix for the generated test names
 * @param leftRows       left input of the join
 * @param rightRows      right input of the join
 * @param condition      join condition placed on the logical `Join`
 * @param expectedAnswer expected result rows, given as tuples/products
 */
private def testInnerJoin(
testName: String,
leftRows: DataFrame,
rightRows: DataFrame,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
// NOTE(review): the body appears twice in this listing -- first without and then with a
// withSQLConf wrapper. This looks like the removed and added sides of a diff shown back to
// back without +/- markers; confirm against the real file before relying on either copy.
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>

// Builds the broadcast hash join; when boundCondition is defined it is applied as a Filter
// on top of the join.
def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
}

// Same shape as above, but the (optionally filtered) plan is passed through
// EnsureRequirements, built here from the plan's own sqlContext.
def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val shuffledHashJoin =
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(filteredJoin.sqlContext).apply(filteredJoin)
}

// Sort-merge variant; also passed through EnsureRequirements.
def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(filteredJoin.sqlContext).apply(filteredJoin)
}

test(s"$testName using BroadcastHashJoin (build=left)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using BroadcastHashJoin (build=right)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using ShuffledHashJoin (build=left)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using ShuffledHashJoin (build=right)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using SortMergeJoin") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeSortMergeJoin(left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
// Second copy of the body starts here: identical registrations, now wrapped in withSQLConf
// with SHUFFLE_PARTITIONS set to "1", and EnsureRequirements built from the suite-level
// sqlContext instead of the plan's.
// NOTE(review): withSQLConf encloses the test(...) calls rather than sitting inside each
// test body. If test(...) only registers the body for later execution and withSQLConf
// restores the previous value when its block returns (neither is visible in this listing),
// the single-partition setting would no longer be active when the plans are built and run.
// Confirm, and consider moving the wrapper inside each test body.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>

def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
}

def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val shuffledHashJoin =
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}

def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}

test(s"$testName using BroadcastHashJoin (build=left)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using BroadcastHashJoin (build=right)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using ShuffledHashJoin (build=left)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using ShuffledHashJoin (build=right)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using SortMergeJoin") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeSortMergeJoin(left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}

{
val upperCaseData = Seq(
(1, "A"),
(2, "B"),
(3, "C"),
(4, "D"),
(5, "E"),
(6, "F")
).toDF("N", "L")

val lowerCaseData = Seq(
(1, "a"),
(2, "b"),
(3, "c"),
(4, "d")
).toDF("n", "l")
val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(1, "A"),
Row(2, "B"),
Row(3, "C"),
Row(4, "D"),
Row(5, "E"),
Row(6, "F"),
Row(null, "G")
)), new StructType().add("N", IntegerType).add("L", StringType))

val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(1, "a"),
Row(2, "b"),
Row(3, "c"),
Row(4, "d"),
Row(null, "e")
)), new StructType().add("n", IntegerType).add("l", StringType))

testInnerJoin(
"inner join, one match per row",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest}

class OuterJoinSuite extends SparkPlanTest {
class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {

private def testOuterJoin(
testName: String,
Expand All @@ -34,39 +35,41 @@ class OuterJoinSuite extends SparkPlanTest {
joinType: JoinType,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using ShuffledHashOuterJoin") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext).apply(
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = false)
}

if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using ShuffledHashOuterJoin") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
EnsureRequirements(sqlContext).apply(
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = false)
sortAnswers = true)
}
}
}

test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}

test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,60 +20,69 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}

class SemiJoinSuite extends SparkPlanTest {
class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {

/**
 * Registers one test per left-semi-join operator for the given pair of inputs.
 *
 * A logical `Join` of type `Inner` is built from the two DataFrames' logical plans and handed
 * to `ExtractEquiJoinKeys.unapply`. The two hash-based tests are registered inside `foreach`
 * on that result, so they are skipped when the extraction yields nothing. The LeftSemiJoinBNL
 * test does not use the extracted keys; it receives `Some(condition)` directly.
 *
 * Tests registered (each name is prefixed with `testName`):
 *  - LeftSemiJoinHash, passed through `EnsureRequirements`
 *  - BroadcastLeftSemiJoinHash
 *  - LeftSemiJoinBNL
 *
 * Every test calls `checkAnswer2` with `leftRows`, `rightRows`, a plan-building function,
 * `expectedAnswer` converted through `Row.fromTuple`, and `sortAnswers = true`.
 *
 * @param testName       prefix for the generated test names
 * @param leftRows       left input of the join
 * @param rightRows      right input of the join
 * @param condition      join condition placed on the logical `Join`
 * @param expectedAnswer expected result rows, given as tuples/products
 */
private def testLeftSemiJoin(
testName: String,
leftRows: DataFrame,
rightRows: DataFrame,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
// NOTE(review): each test registration below appears twice in this listing. This looks like
// the added and removed sides of a diff interleaved without +/- markers; confirm against the
// real file before relying on either copy.
// NOTE(review): withSQLConf encloses the test(...) calls rather than sitting inside each
// test body. If test(...) only registers the body for later execution and withSQLConf
// restores the previous value when its block returns (neither is visible in this listing),
// the SHUFFLE_PARTITIONS override would no longer be active when the plans are built and
// run. Confirm, and consider moving the wrapper inside each test body.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using LeftSemiJoinHash") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext).apply(
LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using LeftSemiJoinHash") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext).apply(
LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
test(s"$testName using BroadcastLeftSemiJoinHash") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

test(s"$testName using BroadcastLeftSemiJoinHash") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

// The nested-loop variant takes the original condition rather than the extracted keys.
test(s"$testName using LeftSemiJoinBNL") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinBNL(left, right, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
test(s"$testName using LeftSemiJoinBNL") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinBNL(left, right, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}

val left = Seq(
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0),
(3, 3.0)
).toDF("a", "b")
val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(1, 2.0),
Row(2, 1.0),
Row(2, 1.0),
Row(3, 3.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))

val right = Seq(
(2, 3.0),
(2, 3.0),
(3, 2.0),
(4, 1.0)
).toDF("c", "d")
val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(2, 3.0),
Row(2, 3.0),
Row(3, 2.0),
Row(4, 1.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))

val condition = {
And(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils

trait SQLTestUtils { this: SparkFunSuite =>
def sqlContext: SQLContext
protected def sqlContext: SQLContext

protected def configuration = sqlContext.sparkContext.hadoopConfiguration

Expand Down

0 comments on commit c700df8

Please sign in to comment.