[SPARK-40950][CORE] Fix isRemoteAddressMaxedOut performance overhead on Scala 2.13

### What changes were proposed in this pull request?
In `FetchRequest` we currently store a `Seq[FetchBlockInfo]`, and the function `isRemoteAddressMaxedOut` (probably other places as well, but this is the function that showed up in my profiling) uses the length of this `Seq`. In Scala 2.12 `Seq` is an alias for `scala.collection.Seq`, but in 2.13 it is an alias for `scala.collection.immutable.Seq`. This means that when we, for example, call `toSeq` on an `ArrayBuffer`, it is a no-op in 2.12 and the `blocks` in the `FetchRequest` are backed by a collection with a cheap `length`, but in 2.13 the data is copied into a `List`, whose `length` is O(n).
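
To make the length-cost difference concrete, here is a small standalone sketch (not part of the patch; the buffer size and loop count are arbitrary, and the `List` backing reflects the default `immutable.Seq.from` behavior in 2.13):

```scala
import scala.collection.mutable.ArrayBuffer

// Standalone sketch of the 2.12 vs 2.13 difference described above.
object SeqLengthCost {
  def main(args: Array[String]): Unit = {
    val buf = ArrayBuffer.tabulate(100000)(identity)

    // 2.12: Seq is scala.collection.Seq, so buf.toSeq returns the buffer itself
    // and blocks.length stays O(1).
    // 2.13: Seq is scala.collection.immutable.Seq, so buf.toSeq copies the
    // elements into an immutable Seq (a List by default), whose length is O(n).
    val blocks: Seq[Int] = buf.toSeq
    println(blocks.getClass) // on 2.13 this prints the List cons-cell class

    // Asking for the length repeatedly, as the fetch-request bookkeeping does,
    // then costs a full list traversal per call instead of a field read.
    var total = 0L
    (1 to 1000).foreach(_ => total += blocks.length)
    println(total)
  }
}
```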

This PR solves this by changing the `Seq` to an `IndexedSeq`, thereby making the expectation of a cheap `length` function explicit. This means that in some places we will do an extra copy in Scala 2.13 compared to 2.12 (which was also the case before this PR). If we wanted to avoid this copy we should instead change it to use `scala.collection.IndexedSeq`, so that we would have the same types in both 2.13 and 2.12.
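
A minimal sketch of the `scala.collection.IndexedSeq` variant mentioned above (a hypothetical, stripped-down stand-in for `FetchRequest`, not the actual Spark class): since `ArrayBuffer` is a `scala.collection.IndexedSeq` on both 2.12 and 2.13, it can be passed through without a copy and `length` stays cheap on both versions.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical simplified stand-in for FetchRequest; the real class holds
// FetchBlockInfo values and lives in object ShuffleBlockFetcherIterator.
final case class RequestSketch(blocks: scala.collection.IndexedSeq[String]) {
  // length on an IndexedSeq is O(1), so a check like isRemoteAddressMaxedOut
  // stays cheap regardless of the Scala version.
  def wouldExceed(inFlight: Int, maxPerAddress: Int): Boolean =
    inFlight + blocks.length > maxPerAddress
}

object RequestSketch {
  def main(args: Array[String]): Unit = {
    val buf = ArrayBuffer("shuffle_0_0_0", "shuffle_0_1_0", "shuffle_0_2_0")
    val req = RequestSketch(buf) // no toSeq copy needed
    println(req.wouldExceed(inFlight = 10, maxPerAddress = 12))
  }
}
```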

### Why are the changes needed?
The performance of ShuffleBlockFetcherIterator is much worse on Scala 2.13 than on 2.12. I have seen cases where the overhead of repeatedly calculating the length is as much as 20% of CPU time (and it could probably be even worse for larger shuffles).
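
For context, the hot check is roughly of this shape (a paraphrase with simplified types; the real method takes a `FetchRequest` and keys on `BlockManagerId`, so the names and details here are assumptions, not a copy of the source):

```scala
object MaxedOutCheck {
  // Roughly the shape of isRemoteAddressMaxedOut in ShuffleBlockFetcherIterator.
  def isRemoteAddressMaxedOut(
      numBlocksInFlightPerAddress: Map[String, Int],
      maxBlocksInFlightPerAddress: Int,
      address: String,
      requestBlocks: Seq[String]): Boolean = {
    // requestBlocks.size is what turns into an O(n) list traversal on 2.13,
    // and this check runs repeatedly while requests are queued and dispatched.
    numBlocksInFlightPerAddress.getOrElse(address, 0) + requestBlocks.size >
      maxBlocksInFlightPerAddress
  }

  def main(args: Array[String]): Unit = {
    val inFlight = Map("hostA:7337" -> 18)
    println(isRemoteAddressMaxedOut(inFlight, maxBlocksInFlightPerAddress = 20,
      address = "hostA:7337", requestBlocks = Seq("b1", "b2", "b3")))
  }
}
```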

### Does this PR introduce _any_ user-facing change?
No. I think the interface changes only affect private classes.

### How was this patch tested?
Existing specs.

Closes #38427 from eejbyfeldt/SPARK-40950.

Authored-by: Emil Ejbyfeldt <eejbyfeldt@liveintent.com>
Signed-off-by: Mridul <mridul<at>gmail.com>
eejbyfeldt authored and Mridul committed Nov 4, 2022
1 parent 5196ff5 commit 66c6aab
Showing 5 changed files with 48 additions and 37 deletions.
26 changes: 14 additions & 12 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection
import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap, ListBuffer, Map}
import scala.concurrent.{ExecutionContext, Future}
@@ -457,7 +458,7 @@ private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int,
private[spark] case class GetShufflePushMergersMessage(shuffleId: Int,
context: RpcCallContext) extends MapOutputTrackerMasterMessage
private[spark] case class MapSizesByExecutorId(
iter: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], enableBatchFetch: Boolean)
iter: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])], enableBatchFetch: Boolean)

/** RpcEndpoint class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterEndpoint(
@@ -535,7 +536,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging

// For testing
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = {
getMapSizesByExecutorId(shuffleId, 0, Int.MaxValue, reduceId, reduceId + 1)
}

@@ -563,7 +564,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
endPartition: Int): Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]

/**
* Called from executors to get the server URIs and output sizes for each shuffle block that
@@ -600,7 +601,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
*/
def getMapSizesForMergeResult(
shuffleId: Int,
partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
partitionId: Int): Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]

/**
* Called from executors upon fetch failure on a merged shuffle reduce partition chunk. This is
@@ -619,7 +620,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
def getMapSizesForMergeResult(
shuffleId: Int,
partitionId: Int,
chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]

/**
* Called from executors whenever a task with push based shuffle is enabled doesn't have shuffle
@@ -1147,7 +1148,7 @@ private[spark] class MapOutputTrackerMaster(
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
endPartition: Int): Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = {
val mapSizesByExecutorId = getPushBasedShuffleMapSizesByExecutorId(
shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
assert(mapSizesByExecutorId.enableBatchFetch == true)
@@ -1251,7 +1252,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
endPartition: Int): Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = {
val mapSizesByExecutorId = getMapSizesByExecutorIdImpl(
shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = false)
assert(mapSizesByExecutorId.enableBatchFetch == true)
@@ -1303,7 +1304,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr

override def getMapSizesForMergeResult(
shuffleId: Int,
partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
partitionId: Int): Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
// Fetch the map statuses and merge statuses again since they might have already been
// cleared by another task running in the same executor.
@@ -1328,7 +1329,8 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
override def getMapSizesForMergeResult(
shuffleId: Int,
partitionId: Int,
chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
chunkTracker: RoaringBitmap
): Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
// Fetch the map statuses and merge statuses again since they might have already been
// cleared by another task running in the same executor.
@@ -1660,7 +1662,7 @@ private[spark] object MapOutputTracker extends Logging {
}
}

MapSizesByExecutorId(splitsByAddress.mapValues(_.toSeq).iterator, enableBatchFetch)
MapSizesByExecutorId(splitsByAddress.iterator, enableBatchFetch)
}

/**
@@ -1683,7 +1685,7 @@
shuffleId: Int,
partitionId: Int,
mapStatuses: Array[MapStatus],
tracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
tracker: RoaringBitmap): Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = {
assert (mapStatuses != null && tracker != null)
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
for ((status, mapIndex) <- mapStatuses.zipWithIndex) {
@@ -1695,7 +1697,7 @@
status.getSizeForBlock(partitionId), mapIndex))
}
}
splitsByAddress.mapValues(_.toSeq).iterator
splitsByAddress.iterator
}

def validateStatus(status: ShuffleOutputStatus, shuffleId: Int, partition: Int) : Unit = {
core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -17,6 +17,8 @@

package org.apache.spark.shuffle

import scala.collection

import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.CompressionCodec
@@ -30,7 +32,7 @@ import org.apache.spark.util.collection.ExternalSorter
*/
private[spark] class BlockStoreShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])],
context: TaskContext,
readMetrics: ShuffleReadMetricsReporter,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage

import java.util.concurrent.TimeUnit

import scala.collection
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success}
@@ -292,7 +293,7 @@ private class PushBasedFetchHelper(
logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId")
// Increase the blocks processed since we will process another block in the next iteration of
// the while loop in ShuffleBlockFetcherIterator.next().
val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
val fallbackBlocksByAddr: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] =
blockId match {
case shuffleBlockId: ShuffleMergedBlockId =>
iterator.decreaseNumBlocksToFetch(1)
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import java.util.zip.CheckedInputStream
import javax.annotation.concurrent.GuardedBy

import scala.collection
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
import scala.util.{Failure, Success}
@@ -87,7 +88,7 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: BlockStoreClient,
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker,
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])],
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
@@ -276,7 +277,7 @@ final class ShuffleBlockFetcherIterator(
val (size, mapIndex) = infoMap(blockId)
FetchBlockInfo(BlockId(blockId), size, mapIndex)
}
results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq)))
results.put(DeferFetchRequestResult(FetchRequest(address, blocks)))
deferredBlocks.clear()
}
}
@@ -369,9 +370,10 @@ final class ShuffleBlockFetcherIterator(
* [[PushBasedFetchHelper]].
*/
private[this] def partitionBlocksByFetchMode(
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
blocksByAddress: Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])],
localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
hostLocalBlocksByExecutor:
mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]],
pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+ s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
@@ -448,7 +450,7 @@ final class ShuffleBlockFetcherIterator(
}

private def createFetchRequest(
blocks: Seq[FetchBlockInfo],
blocks: collection.Seq[FetchBlockInfo],
address: BlockManagerId,
forMergedMetas: Boolean): FetchRequest = {
logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address "
@@ -457,7 +459,7 @@
}

private def createFetchRequests(
curBlocks: Seq[FetchBlockInfo],
curBlocks: collection.Seq[FetchBlockInfo],
address: BlockManagerId,
isLast: Boolean,
collectedRemoteRequests: ArrayBuffer[FetchRequest],
@@ -485,7 +487,7 @@

private def collectFetchRequests(
address: BlockManagerId,
blockInfos: Seq[(BlockId, Long, Int)],
blockInfos: collection.Seq[(BlockId, Long, Int)],
collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
val iterator = blockInfos.iterator
var curRequestSize = 0L
@@ -502,20 +504,20 @@
case ShuffleBlockChunkId(_, _, _, _) =>
if (curRequestSize >= targetRemoteRequestSize ||
curBlocks.size >= maxBlocksInFlightPerAddress) {
curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
curBlocks = createFetchRequests(curBlocks, address, isLast = false,
collectedRemoteRequests, enableBatchFetch = false)
curRequestSize = curBlocks.map(_.size).sum
}
case ShuffleMergedBlockId(_, _, _) =>
if (curBlocks.size >= maxBlocksInFlightPerAddress) {
curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
curBlocks = createFetchRequests(curBlocks, address, isLast = false,
collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
}
case _ =>
// For batch fetch, the actual block in flight should count for merged block.
val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
curBlocks = createFetchRequests(curBlocks, address, isLast = false,
collectedRemoteRequests, doBatchFetch)
curRequestSize = curBlocks.map(_.size).sum
}
@@ -530,7 +532,7 @@
case _ => (doBatchFetch, false)
}
}
createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests,
createFetchRequests(curBlocks, address, isLast = true, collectedRemoteRequests,
enableBatchFetch = enableBatchFetch, forMergedMetas = forMergedMetas)
}
}
@@ -543,7 +545,7 @@
}
}

private def checkBlockSizes(blockInfos: Seq[(BlockId, Long, Int)]): Unit = {
private def checkBlockSizes(blockInfos: collection.Seq[(BlockId, Long, Int)]): Unit = {
blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) }
}

@@ -609,7 +611,8 @@
*/
private[this] def fetchHostLocalBlocks(
hostLocalDirManager: HostLocalDirManager,
hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]):
hostLocalBlocksByExecutor:
mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]):
Unit = {
val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs
val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = {
@@ -662,7 +665,7 @@
}

private def fetchMultipleHostLocalBlocks(
bmIdToBlocks: Map[BlockManagerId, Seq[(BlockId, Long, Int)]],
bmIdToBlocks: Map[BlockManagerId, collection.Seq[(BlockId, Long, Int)]],
localDirsByExecId: Map[String, Array[String]],
cached: Boolean): Unit = {
// We use `forall` because once there's a failed block fetch, `fetchHostLocalBlock` will put
@@ -686,7 +689,7 @@
// Local blocks to fetch, excluding zero-sized blocks.
val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
val hostLocalBlocksByExecutor =
mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]()
val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
// Partition blocks by the different fetch modes: local, host-local, push-merged-local and
// remote blocks.
@@ -715,7 +718,8 @@
}

private def fetchAllHostLocalBlocks(
hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]):
hostLocalBlocksByExecutor:
mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]):
Unit = {
if (hostLocalBlocksByExecutor.nonEmpty) {
blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor))
@@ -1191,10 +1195,11 @@ final class ShuffleBlockFetcherIterator(
* fallback.
*/
private[storage] def fallbackFetch(
originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
originalBlocksByAddr:
Iterator[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])]): Unit = {
val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
val originalHostLocalBlocksByExecutor =
mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
mutable.LinkedHashMap[BlockManagerId, collection.Seq[(BlockId, Long, Int)]]()
val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr,
originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks)
@@ -1374,8 +1379,8 @@ object ShuffleBlockFetcherIterator {
* @return the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true.
*/
def mergeContinuousShuffleBlockIdsIfNeeded(
blocks: Seq[FetchBlockInfo],
doBatchFetch: Boolean): Seq[FetchBlockInfo] = {
blocks: collection.Seq[FetchBlockInfo],
doBatchFetch: Boolean): collection.Seq[FetchBlockInfo] = {
val result = if (doBatchFetch) {
val curBlocks = new ArrayBuffer[FetchBlockInfo]
val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo]
@@ -1438,7 +1443,7 @@
} else {
blocks
}
result.toSeq
result
}

/**
@@ -1462,7 +1467,7 @@
*/
case class FetchRequest(
address: BlockManagerId,
blocks: Seq[FetchBlockInfo],
blocks: collection.Seq[FetchBlockInfo],
forMergedMetas: Boolean = false) {
val size = blocks.map(_.size).sum
}
core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -21,6 +21,7 @@ import java.io.File
import java.nio.file.Files
import java.nio.file.attribute.PosixFilePermission

import scala.collection
import scala.concurrent.Promise
import scala.concurrent.duration.Duration

@@ -201,7 +202,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi
.getOrElse(fail("No host local dir manager"))

val promises = mapOutputs.map { case (bmid, blocks) =>
val promise = Promise[Seq[File]]()
val promise = Promise[collection.Seq[File]]()
dirManager.getHostLocalDirs(bmid.host, bmid.port, Seq(bmid.executorId).toArray) {
case scala.util.Success(res) => res.foreach { case (eid, dirs) =>
val files = blocks.flatMap { case (blockId, _, _) =>
