Skip to content

Commit

Permalink
Minimize changes to UnionRDD
Browse files Browse the repository at this point in the history
  • Loading branch information
ankurdave committed Oct 20, 2023
1 parent 8123d37 commit 922b878
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package org.apache.spark.rdd
import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.collection.parallel.ForkJoinTaskSupport
import scala.collection.parallel.immutable.ParVector
import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
Expand Down Expand Up @@ -57,6 +59,11 @@ private[spark] class UnionPartition[T: ClassTag](
}
}

object UnionRDD {
private[spark] lazy val partitionEvalTaskSupport =
new ForkJoinTaskSupport(ThreadUtils.newForkJoinPool("partition-eval-task-support", 8))
}

@DeveloperApi
class UnionRDD[T: ClassTag](
sc: SparkContext,
Expand All @@ -67,16 +74,17 @@ class UnionRDD[T: ClassTag](
private[spark] val isPartitionListingParallel: Boolean =
rdds.length > conf.get(RDD_PARALLEL_LISTING_THRESHOLD)

private def countParentPartitions(): Int = {
if (isPartitionListingParallel) {
ThreadUtils.parmap(rdds, "UnionRDD-parallel-eval", maxThreads = 8)(_.partitions.length).sum
override def getPartitions: Array[Partition] = {
val parRDDs = if (isPartitionListingParallel) {
// scalastyle:off parvector
val parArray = new ParVector(rdds.toVector)
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
// scalastyle:on parvector
parArray
} else {
parRDDs.map(_.partitions.length).sum
rdds
}
}

override def getPartitions: Array[Partition] = {
val array = new Array[Partition](countParentPartitions())
val array = new Array[Partition](parRDDs.map(_.partitions.length).sum)
var pos = 0
for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
Expand Down

0 comments on commit 922b878

Please sign in to comment.