[SPARK-32399][SQL] Full outer shuffled hash join
### What changes were proposed in this pull request?

Add support for full outer join inside shuffled hash join. Currently, if the query is a full outer join, we only use sort merge join as the physical operator. However, sort merge join can be CPU- and IO-intensive when the input tables are large. Shuffled hash join avoids the sort, saving CPU and IO compared to sort merge join, especially when the tables are large.

This PR implements the full outer join as follows:
* Process rows from the stream side by looking up the hash relation, and mark the matched rows from the build side:
  * for joins with a unique key, a `BitSet` is used to record matched rows from the build side (a `key index` represents each row);
  * for joins with a non-unique key, a `HashSet[Long]` is used to record matched rows from the build side (a `key index` combined with a `value index` represents each row).

  The `key index` is defined as the index into the key addressing array `longArray` in `BytesToBytesMap`, and the `value index` as the iterator index of the values for the same key.

* Process rows from the build side by iterating the hash relation, and filter out rows from the build side that were already looked up (done in `ShuffledHashJoinExec.fullOuterJoin`; see the sketch below).
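A minimal sketch of this matched-row bookkeeping (the `MatchedRowTracker` name, the `maxNumKeysIndex` parameter and the exact bit packing are illustrative assumptions; the actual logic lives in `ShuffledHashJoinExec.fullOuterJoin` and `HashedRelation`):

```
import java.util.BitSet
import scala.collection.mutable

class MatchedRowTracker(maxNumKeysIndex: Int, uniqueKey: Boolean) {
  // Unique-key case: one bit per key index is enough to mark a matched build row.
  private val matchedKeys = new BitSet(maxNumKeysIndex)
  // Non-unique-key case: mark each (key index, value index) pair individually.
  private val matchedRows = mutable.HashSet.empty[Long]

  // Pack the two indices into one Long: high 32 bits hold the key index,
  // low 32 bits hold the value index (assumed packing, for illustration).
  private def pack(keyIndex: Int, valueIndex: Int): Long =
    (keyIndex.toLong << 32) | (valueIndex & 0xFFFFFFFFL)

  def markMatched(keyIndex: Int, valueIndex: Int): Unit =
    if (uniqueKey) matchedKeys.set(keyIndex)
    else matchedRows += pack(keyIndex, valueIndex)

  def isMatched(keyIndex: Int, valueIndex: Int): Boolean =
    if (uniqueKey) matchedKeys.get(keyIndex)
    else matchedRows.contains(pack(keyIndex, valueIndex))
}
```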

For context, this PR was originally implemented as follows (up to commit e332276):
1. Construct the hash relation from the build side, with an extra boolean value at the end of each row to track lookup information (done in `ShuffledHashJoinExec.buildHashedRelation` and `UnsafeHashedRelation.apply`).
2. Process rows from the stream side by looking up the hash relation, and mark the matched rows from the build side as looked up (done in `ShuffledHashJoinExec.fullOuterJoin`).
3. Process rows from the build side by iterating the hash relation, and filter out rows from the build side that were already looked up (done in `ShuffledHashJoinExec.fullOuterJoin`).

See the discussion of the pros and cons of these two approaches [here](#29342 (comment)), [here](#29342 (comment)) and [here](#29342 (comment)).

TODO: codegen for full outer shuffled hash join can be implemented in a follow-up PR.

### Why are the changes needed?

As implemented in this PR, full outer shuffled hash join has the overhead of iterating the build side twice (once for building the hash map, and once for outputting non-matching rows), and iterating the stream side once. However, full outer sort merge join needs to iterate both sides twice, and sorting a large table can be more CPU- and IO-intensive. So full outer shuffled hash join can be more efficient than sort merge join when the stream side is much larger than the build side.

For the example query below, full outer SHJ saved 30% of wall clock time compared to full outer SMJ.

```
def shuffleHashJoin(): Unit = {
  val N: Long = 4 << 22
  withSQLConf(
    SQLConf.SHUFFLE_PARTITIONS.key -> "2",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000") {
    codegenBenchmark("shuffle hash join", N) {
      val df1 = spark.range(N).selectExpr("cast(id as string) as k1")
      val df2 = spark.range(N / 10).selectExpr("cast(id * 10 as string) as k2")
      val df = df1.join(df2, col("k1") === col("k2"), "full_outer")
      df.noop()
    }
  }
}
```

```
Running benchmark: shuffle hash join
  Running case: shuffle hash join off
  Stopped after 2 iterations, 16602 ms
  Running case: shuffle hash join on
  Stopped after 5 iterations, 31911 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.4
Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
shuffle hash join:                        Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
shuffle hash join off                              7900           8301         567          2.1         470.9       1.0X
shuffle hash join on                               6250           6382          95          2.7         372.5       1.3X
```
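For reference, a minimal sketch of how to steer the planner toward shuffled hash join for such a query, using existing configs (the DataFrames are illustrative; SHJ selection is still subject to the planner's size checks):

```
import org.apache.spark.sql.functions.col

// Prefer shuffled hash join over sort merge join, and rule out broadcast joins.
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val df1 = spark.range(1000000L).selectExpr("cast(id as string) as k1")
val df2 = spark.range(100000L).selectExpr("cast(id * 10 as string) as k2")
// With this patch, full outer joins are no longer forced to sort merge join.
df1.join(df2, col("k1") === col("k2"), "full_outer").explain()
```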

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests in `JoinSuite.scala`, `AbstractBytesToBytesMapSuite.java` and `HashedRelationSuite.scala`.

Closes #29342 from c21/full-outer-shj.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
c21 authored and maropu committed Aug 16, 2020
1 parent 9a79bbc commit 8f0fef1
Showing 12 changed files with 693 additions and 48 deletions.
@@ -428,6 +428,68 @@ public MapIterator destructiveIterator() {
    return new MapIterator(numValues, new Location(), true);
  }

  /**
   * Iterator for the entries of this map. It first iterates over the key indices in
   * `longArray` and then accesses the values in `dataPages`. NOTE: this is different from
   * `MapIterator` in that the key index is preserved here
   * (see `UnsafeHashedRelation` for an example of usage).
   */
  public final class MapIteratorWithKeyIndex implements Iterator<Location> {

    /**
     * The index in `longArray` where the key is stored.
     */
    private int keyIndex = 0;

    private int numRecords;
    private final Location loc;

    private MapIteratorWithKeyIndex() {
      this.numRecords = numValues;
      this.loc = new Location();
    }

    @Override
    public boolean hasNext() {
      return numRecords > 0;
    }

    @Override
    public Location next() {
      if (!loc.isDefined() || !loc.nextValue()) {
        // Advance to the next occupied slot in `longArray`; each key occupies
        // two longs (address and hashcode), and an address of 0 marks an empty slot.
        while (longArray.get(keyIndex * 2) == 0) {
          keyIndex++;
        }
        loc.with(keyIndex, 0, true);
        keyIndex++;
      }
      numRecords--;
      return loc;
    }
  }

  /**
   * Returns an iterator for iterating over the entries of this map,
   * by first iterating over the key index inside hash map's `longArray`.
   *
   * For efficiency, all calls to `next()` will return the same {@link Location} object.
   *
   * The returned iterator is NOT thread-safe. If the map is modified while iterating over it,
   * the behavior of the returned iterator is undefined.
   */
  public MapIteratorWithKeyIndex iteratorWithKeyIndex() {
    return new MapIteratorWithKeyIndex();
  }

  /**
   * The maximum number of key indices. A valid key index is in the range
   * of [0, maxNumKeysIndex - 1].
   */
  public int maxNumKeysIndex() {
    return (int) (longArray.size() / 2);
  }
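For context, a hedged sketch (not the Spark implementation) of how a caller such as `UnsafeHashedRelation` can combine `iteratorWithKeyIndex()` and `getKeyIndex()` to emit the non-matched build rows; `binaryMap`, `tracker`, `decodeRow` and `emit` are assumed helpers:

```
// Walk all build rows with stable key indices; the value index is the
// per-key position, reset whenever the iterator moves to a new key index.
val iter = binaryMap.iteratorWithKeyIndex()
var lastKeyIndex = -1
var valueIndex = 0
while (iter.hasNext) {
  val loc = iter.next()
  if (loc.getKeyIndex != lastKeyIndex) {
    lastKeyIndex = loc.getKeyIndex
    valueIndex = 0
  }
  // Emit only build rows never marked during the stream-side probe; they are
  // joined with nulls on the stream side for the full outer join output.
  if (!tracker.isMatched(loc.getKeyIndex, valueIndex)) {
    emit(decodeRow(loc))
  }
  valueIndex += 1
}
```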

  /**
   * Looks up a key, and return a {@link Location} handle that can be used to test existence
   * and read/write values.
@@ -601,6 +663,14 @@ public boolean isDefined() {
      return isDefined;
    }

    /**
     * Returns the key index of the current key.
     */
    public int getKeyIndex() {
      assert (isDefined);
      return pos;
    }

    /**
     * Returns the base object for key.
     */
@@ -172,6 +172,7 @@ public void emptyMap() {
      final byte[] key = getRandomByteArray(keyLengthInWords);
      Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
      Assert.assertFalse(map.iterator().hasNext());
      Assert.assertFalse(map.iteratorWithKeyIndex().hasNext());
    } finally {
      map.free();
    }
@@ -233,9 +234,10 @@ public void setAndRetrieveAKey() {
    }
  }

-  private void iteratorTestBase(boolean destructive) throws Exception {
+  private void iteratorTestBase(boolean destructive, boolean isWithKeyIndex) throws Exception {
    final int size = 4096;
    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES);
    Assert.assertEquals(size / 2, map.maxNumKeysIndex());
    try {
      for (long i = 0; i < size; i++) {
        final long[] value = new long[] { i };
@@ -267,6 +269,8 @@ private void iteratorTestBase(boolean destructive) throws Exception {
      final Iterator<BytesToBytesMap.Location> iter;
      if (destructive) {
        iter = map.destructiveIterator();
      } else if (isWithKeyIndex) {
        iter = map.iteratorWithKeyIndex();
      } else {
        iter = map.iterator();
      }
@@ -291,6 +295,12 @@ private void iteratorTestBase(boolean destructive) throws Exception {
          countFreedPages++;
        }
      }
      if (keyLength != 0 && isWithKeyIndex) {
        final BytesToBytesMap.Location expectedLoc = map.lookup(
          loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength());
        Assert.assertTrue(expectedLoc.isDefined() &&
          expectedLoc.getKeyIndex() == loc.getKeyIndex());
      }
    }
    if (destructive) {
      // Latest page is not freed by iterator but by map itself
@@ -304,12 +314,17 @@ private void iteratorTestBase(boolean destructive) throws Exception {

  @Test
  public void iteratorTest() throws Exception {
-    iteratorTestBase(false);
+    iteratorTestBase(false, false);
  }

  @Test
  public void destructiveIteratorTest() throws Exception {
-    iteratorTestBase(true);
+    iteratorTestBase(true, false);
  }

  @Test
  public void iteratorWithKeyIndexTest() throws Exception {
    iteratorTestBase(false, true);
  }

@Test
@@ -603,6 +618,12 @@ public void multipleValuesForSameKey() {
        final BytesToBytesMap.Location loc = iter.next();
        assert loc.isDefined();
      }
      BytesToBytesMap.MapIteratorWithKeyIndex iterWithKeyIndex = map.iteratorWithKeyIndex();
      for (i = 0; i < 2048; i++) {
        assert iterWithKeyIndex.hasNext();
        final BytesToBytesMap.Location loc = iterWithKeyIndex.next();
        assert loc.isDefined() && loc.getKeyIndex() >= 0;
      }
    } finally {
      map.free();
    }
@@ -235,8 +235,8 @@ trait JoinSelectionHelper {
      canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint)
    }
    getBuildSide(
-      canBuildLeft(joinType) && buildLeft,
-      canBuildRight(joinType) && buildRight,
+      canBuildBroadcastLeft(joinType) && buildLeft,
+      canBuildBroadcastRight(joinType) && buildRight,
      left,
      right
    )
@@ -260,8 +260,8 @@ trait JoinSelectionHelper {
      canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left)
    }
    getBuildSide(
-      canBuildLeft(joinType) && buildLeft,
-      canBuildRight(joinType) && buildRight,
+      canBuildShuffledHashJoinLeft(joinType) && buildLeft,
+      canBuildShuffledHashJoinRight(joinType) && buildRight,
      left,
      right
    )
@@ -278,20 +278,35 @@ trait JoinSelectionHelper {
    plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
  }

-  def canBuildLeft(joinType: JoinType): Boolean = {
+  def canBuildBroadcastLeft(joinType: JoinType): Boolean = {
    joinType match {
      case _: InnerLike | RightOuter => true
      case _ => false
    }
  }

-  def canBuildRight(joinType: JoinType): Boolean = {
+  def canBuildBroadcastRight(joinType: JoinType): Boolean = {
    joinType match {
      case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
      case _ => false
    }
  }

  def canBuildShuffledHashJoinLeft(joinType: JoinType): Boolean = {
    joinType match {
      case _: InnerLike | RightOuter | FullOuter => true
      case _ => false
    }
  }

  def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = {
    joinType match {
      case _: InnerLike | LeftOuter | FullOuter |
        LeftSemi | LeftAnti | _: ExistenceJoin => true
      case _ => false
    }
  }

  def hintToBroadcastLeft(hint: JoinHint): Boolean = {
    hint.leftHint.exists(_.strategy.contains(BROADCAST))
  }
@@ -329,7 +329,10 @@ object SQLConf {

  val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin")
    .internal()
-    .doc("When true, prefer sort merge join over shuffle hash join.")
+    .doc("When true, prefer sort merge join over shuffled hash join. " +
+      "Sort merge join consumes less memory than shuffled hash join and it works efficiently " +
+      "when both join tables are large. On the other hand, shuffled hash join can improve " +
+      "performance (e.g., of full outer joins) when one of join tables is much smaller.")
    .version("2.0.0")
    .booleanConf
    .createWithDefault(true)
@@ -116,7 +116,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   *
   * - Shuffle hash join:
   *     Only supported for equi-joins, while the join keys do not need to be sortable.
-   *     Supported for all join types except full outer joins.
+   *     Supported for all join types.
+   *     Building hash map from table is a memory-intensive operation and it could cause OOM
+   *     when the build side is big.
   *
   * - Shuffle sort merge join (SMJ):
   *     Only supported for equi-joins and the join keys have to be sortable.
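Since shuffled hash join now supports all join types, the existing `shuffle_hash` join hint (available since Spark 3.0) can also take effect for full outer joins; a small illustrative sketch (`df1`/`df2` are assumed DataFrames):

```
import org.apache.spark.sql.functions.col

// Explicitly request shuffled hash join via a join hint; with this patch the
// hint can be honored for full outer joins as well.
val joined = df1.hint("shuffle_hash")
  .join(df2, col("k1") === col("k2"), "full_outer")
```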
@@ -260,7 +262,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
          // it's a right join, and broadcast right side if it's a left join.
          // TODO: revisit it. If left side is much smaller than the right side, it may be better
          // to broadcast the left side even if it's a left join.
-          if (canBuildLeft(joinType)) BuildLeft else BuildRight
+          if (canBuildBroadcastLeft(joinType)) BuildLeft else BuildRight
        }

      def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = {
@@ -114,7 +114,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    }
  }

-  @transient private lazy val (buildOutput, streamedOutput) = {
+  @transient protected lazy val (buildOutput, streamedOutput) = {
    buildSide match {
      case BuildLeft => (left.output, right.output)
      case BuildRight => (right.output, left.output)
@@ -133,7 +133,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
  protected def streamSideKeyGenerator(): UnsafeProjection =
    UnsafeProjection.create(streamedBoundKeys)

-  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
+  @transient protected[this] lazy val boundCondition = if (condition.isDefined) {
    Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
  } else {
    (r: InternalRow) => true