[SPARK-40407][SQL] Fix the potential data skew caused by df.repartition
### What changes were proposed in this pull request?

``` scala
val df = spark.range(0, 100, 1, 50).repartition(4)
val v = df.rdd.mapPartitions { iter =>
  Iterator.single(iter.length)
}.collect()
println(v.mkString(","))
```

The above simple code outputs `50,0,0,50`, meaning partitions 1 and 2 receive no data at all.

RoundRobin partitioning distributes the records of *a single input partition* evenly across the output partitions, but it makes no such guarantee *across* different input partitions.

Below is the code that generates the shuffle key:

``` scala
case RoundRobinPartitioning(numPartitions) =>
  // Distributes elements evenly across output partitions, starting from a random partition.
  var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
  (row: InternalRow) => {
    // The HashPartitioner will handle the `mod` by the number of partitions
    position += 1
    position
  }
```

In this case there are 50 input partitions, and each partition computes only 2 elements. The issue is that every task starts its round-robin at position = 2, because `new Random(partitionId).nextInt(4)` returns 2 for every partition ID.

See the output of `Random`:
``` scala
scala> (1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(4) + " "))  // the position is always 2.
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
```

Similarly, each of the snippets below prints a single constant value for all 200 seeds:

``` scala
(1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(2) + " "))
(1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(4) + " "))
(1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(8) + " "))
(1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(16) + " "))
(1 to 200).foreach(partitionId => print(new Random(partitionId).nextInt(32) + " "))
```
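
For contrast, here is a minimal sketch (not from the PR itself) of the seed scrambling that the patch below adopts, `scala.util.hashing.byteswap32` (see SPARK-21782): with a hashed seed, the starting positions vary with the partition ID instead of collapsing to a single value.

``` scala
import java.util.Random
import scala.util.hashing

// Same loop as above, but with the seed scrambled by byteswap32 first.
// The printed positions now vary with the partition ID instead of all being 2.
(1 to 200).foreach(partitionId =>
  print(new Random(hashing.byteswap32(partitionId)).nextInt(4) + " "))
```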

Returning to the skewed example: consider partition 0, whose elements are [0, 1]. At shuffle write time, element 0 gets key (position + 1) = (2 + 1) % 4 = 3, and element 1 gets key (3 + 1) % 4 = 0.
Likewise for partition 1, whose elements are [2, 3]: element 2 gets key (2 + 1) % 4 = 3, and element 3 gets key (3 + 1) % 4 = 0.

The same calculation applies to all the remaining partitions, since the starting position is always 2 in this case.

So, as you can see, each input partition writes all of its elements to output partitions 0 and 3, leaving output partitions 1 and 2 without any data.
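
To make the skew concrete, the following self-contained sketch (not from the PR) simulates the old key assignment for this exact scenario. Since the keys are small non-negative Ints, `HashPartitioner`'s mod reduces to a plain `% 4` here, and the sketch reproduces the `50,0,0,50` counts observed at the top:

``` scala
import java.util.Random

// Simulate the old behavior: 50 input partitions with 2 elements each,
// round-robin repartitioned into 4 output partitions.
val counts = Array.fill(4)(0)
for (partitionId <- 0 until 50) {
  // Old seeding: the starting position is 2 for every partition ID here.
  var position = new Random(partitionId).nextInt(4)
  for (_ <- 0 until 2) { // two elements per input partition
    position += 1
    counts(position % 4) += 1 // HashPartitioner mods the key by numPartitions
  }
}
println(counts.mkString(",")) // 50,0,0,50 -- matches the observed output
```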

This PR changes the starting position of the round-robin. The default position, calculated by `new Random(partitionId).nextInt(numPartitions)`, can be identical for every input partition, which means every partition writes its data to the same set of shuffle keys, and in special cases some keys receive no data at all.
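
Re-running the same simulation with the patched seeding shows why the fix works: the hashed starting positions are effectively random per task, so the spread is close to, though not exactly, 25 elements per partition. The sorted counts are the values asserted by the unit test added in this PR:

``` scala
import java.util.Random
import scala.util.hashing

// Same simulation as before, but with the new byteswap32-scrambled seed.
val counts = Array.fill(4)(0)
for (partitionId <- 0 until 50) {
  var position = new Random(hashing.byteswap32(partitionId)).nextInt(4)
  for (_ <- 0 until 2) {
    position += 1
    counts(position % 4) += 1
  }
}
println(counts.sorted.mkString(",")) // 19,25,25,31 -- as asserted in the new test
```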

### Why are the changes needed?

The PR fixes the data skew issue in these special cases.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Will add some tests and watch CI pass

Closes #37855 from wbo4958/roundrobin-data-skew.

Authored-by: Bobby Wang <wbo4958@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit f6c4e58)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
wbo4958 authored and cloud-fan committed Sep 22, 2022
1 parent b608ba3 commit 2bae604
Showing 3 changed files with 17 additions and 3 deletions.
ShuffleExchangeExec.scala:

``` diff
@@ -21,6 +21,7 @@ import java.util.Random
 import java.util.function.Supplier
 
 import scala.concurrent.Future
+import scala.util.hashing
 
 import org.apache.spark._
 import org.apache.spark.internal.config
@@ -306,7 +307,14 @@ object ShuffleExchangeExec {
     def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) =>
         // Distributes elements evenly across output partitions, starting from a random partition.
-        var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions)
+        // nextInt(numPartitions) implementation has a special case when bound is a power of 2,
+        // which is basically taking several highest bits from the initial seed, with only a
+        // minimal scrambling. Due to deterministic seed, using the generator only once,
+        // and lack of scrambling, the position values for power-of-two numPartitions always
+        // end up being almost the same regardless of the index. Substantially scrambling the
+        // seed by hashing will help. Refer to SPARK-21782 for more details.
+        val partitionId = TaskContext.get().partitionId()
+        var position = new Random(hashing.byteswap32(partitionId)).nextInt(numPartitions)
         (row: InternalRow) => {
           // The HashPartitioner will handle the `mod` by the number of partitions
           position += 1
```
DatasetSuite.scala:

``` diff
@@ -2132,6 +2132,12 @@ class DatasetSuite extends QueryTest
       (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
       (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13))
   }
+
+  test("SPARK-40407: repartition should not result in severe data skew") {
+    val df = spark.range(0, 100, 1, 50).repartition(4)
+    val result = df.mapPartitions(iter => Iterator.single(iter.length)).collect()
+    assert(result.sorted.toSeq === Seq(19, 25, 25, 31))
+  }
 }
 
 case class Bar(a: Int)
```
AdaptiveQueryExecSuite.scala:

``` diff
@@ -2090,8 +2090,8 @@ class AdaptiveQueryExecSuite
     withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "150") {
       // partition size [0,258,72,72,72]
       checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4)
-      // partition size [72,216,216,144,72]
-      checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 4, 7)
+      // partition size [144,72,144,216,144]
+      checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 2, 6)
     }
 
     // no skewed partition should be optimized
```
