[SPARK-20629][CORE][K8S] Copy shuffle data when nodes are being shutdown
### What is changed?

This pull request adds the ability to migrate shuffle files during Spark's decommissioning. The design document associated with this change is at https://docs.google.com/document/d/1xVO1b6KAwdUhjEJBolVPl9C6sLj7oOveErwDSYdT-pE .

To support this, the `MapOutputTracker` has been extended so that the location of shuffle files can be updated with `updateMapOutput`. When a shuffle block is put, a block update message is sent, which triggers `updateMapOutput`.
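
As a rough illustration of that flow, here is a minimal, self-contained sketch using hypothetical `ShuffleStatusSketch`/`MapStatusSketch` types in place of the real `ShuffleStatus` and `MapStatus` classes touched below; the update amounts to finding the tracked map output by its `mapId` and repointing its location:

```scala
import scala.collection.mutable

// Hypothetical stand-ins for MapStatus/ShuffleStatus; locations are plain strings here.
final case class MapStatusSketch(mapId: Long, var location: String)

final class ShuffleStatusSketch {
  private val statuses = mutable.ArrayBuffer.empty[MapStatusSketch]

  def addMapOutput(status: MapStatusSketch): Unit = statuses += status

  // Mirrors the new updateMapOutput: find the entry by mapId and repoint it.
  def updateMapOutput(mapId: Long, newLocation: String): Unit =
    statuses.find(_.mapId == mapId) match {
      case Some(status) => status.location = newLocation
      case None         => println(s"Asked to update untracked map output $mapId")
    }
}

object UpdateFlowDemo extends App {
  val shuffle = new ShuffleStatusSketch
  shuffle.addMapOutput(MapStatusSketch(mapId = 7L, location = "executor-1"))
  // The block update message sent after a successful put drives this call.
  shuffle.updateMapOutput(mapId = 7L, newLocation = "executor-2")
}
```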

Instead of rejecting remote puts of shuffle blocks, `BlockManager` now delegates the storage of shuffle blocks to its shuffle manager's resolver (if supported). A new, experimental trait is added for shuffle resolvers to indicate that they handle remote puts of blocks.
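
A minimal sketch of that delegation, with invented `RemotePutCapable`/`BlockManagerSketch`/`DiskBackedResolver` names standing in for the real `MigratableResolver` trait, `BlockManager`, and `IndexShuffleBlockResolver` changes in this diff:

```scala
// Invented trait standing in for MigratableResolver: resolvers opt in to remote puts.
trait RemotePutCapable {
  def putShuffleBlock(blockId: String, bytes: Array[Byte]): Unit
}

// Stand-in for a disk-based resolver that writes shuffle blocks to local disk.
final class DiskBackedResolver extends RemotePutCapable {
  override def putShuffleBlock(blockId: String, bytes: Array[Byte]): Unit =
    println(s"writing ${bytes.length} bytes for $blockId to local disk")
}

// Stand-in for BlockManager: delegate instead of rejecting, but only if supported.
final class BlockManagerSketch(resolver: AnyRef) {
  def putRemoteShuffleBlock(blockId: String, bytes: Array[Byte]): Boolean =
    resolver match {
      case r: RemotePutCapable => r.putShuffleBlock(blockId, bytes); true
      case _                   => false // old behavior: shuffle blocks cannot be stored remotely
    }
}
```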

The existing block migration code is moved out into a separate file, and a producer/consumer model is introduced for migrating shuffle files off the host as quickly as possible without overwhelming other executors.
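
The following is only a shape sketch of that producer/consumer model, using invented names (`ShuffleToMigrate`, `MigrationSketch`); the real implementation lives in the new block-migration file and is bounded by `spark.storage.decommission.shuffleBlocks.maxThreads`:

```scala
import java.util.concurrent.{Executors, LinkedBlockingQueue, TimeUnit}

final case class ShuffleToMigrate(shuffleId: Int, mapId: Long)

object MigrationSketch extends App {
  val queue = new LinkedBlockingQueue[ShuffleToMigrate]()
  val maxThreads = 8 // cf. spark.storage.decommission.shuffleBlocks.maxThreads
  val pool = Executors.newFixedThreadPool(maxThreads)

  // Producer: enumerate the locally stored shuffles once and enqueue them.
  Seq(ShuffleToMigrate(0, 1L), ShuffleToMigrate(0, 2L)).foreach(queue.put)

  // Consumers: a bounded pool drains the queue so no single peer is overwhelmed.
  (1 to maxThreads).foreach { _ =>
    pool.execute(() => {
      var block = queue.poll(100, TimeUnit.MILLISECONDS)
      while (block != null) {
        println(s"migrating shuffle ${block.shuffleId} map ${block.mapId} to a peer")
        block = queue.poll(100, TimeUnit.MILLISECONDS)
      }
    })
  }
  pool.shutdown()
  pool.awaitTermination(5, TimeUnit.SECONDS)
}
```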

### Why are the changes needed?

Recomputing shuffle blocks can be expensive, so we should take advantage of the decommissioning window to migrate these blocks instead.

### Does this PR introduce any user-facing change?

This PR introduces two new configuration parameters, `spark.storage.decommission.shuffleBlocks.enabled` and `spark.storage.decommission.rddBlocks.enabled`, which control which blocks are migrated during storage decommissioning.
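
For example, a job could opt in to both migrations like this (a usage sketch; the same keys can also be passed as `--conf` options to `spark-submit`):

```scala
import org.apache.spark.SparkConf

object DecommissionConfExample {
  // Both settings default to false; shuffle migration also requires a migratable
  // shuffle resolver (the sort-based shuffle's IndexShuffleBlockResolver).
  val conf: SparkConf = new SparkConf()
    .set("spark.storage.decommission.shuffleBlocks.enabled", "true")
    .set("spark.storage.decommission.rddBlocks.enabled", "true")
}
```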

### How was this patch tested?

A new unit test, plus an expansion of the Spark on K8s decommissioning test asserting that decommissioning with shuffle block migration means the results are not recomputed even when the original executor is terminated.

This PR is a cleaned-up version of my previous WIP PR #28331 (thanks to attilapiros for his very helpful reviews on it :)).

Closes #28708 from holdenk/SPARK-20629-copy-shuffle-data-when-nodes-are-being-shutdown-cleaned-up.

Lead-authored-by: Holden Karau <hkarau@apple.com>
Co-authored-by: Holden Karau <holden@pigscanfly.ca>
Co-authored-by: “attilapiros” <piros.attila.zsolt@gmail.com>
Co-authored-by: Attila Zsolt Piros <attilazsoltpiros@apiros-mbp16.lan>
Signed-off-by: Holden Karau <hkarau@apple.com>
4 people committed Jul 20, 2020
1 parent ef3cad1 commit a4ca355
Showing 26 changed files with 1,150 additions and 255 deletions.
38 changes: 35 additions & 3 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -49,7 +49,7 @@ import org.apache.spark.util._
*
* All public methods of this class are thread-safe.
*/
private class ShuffleStatus(numPartitions: Int) {
private class ShuffleStatus(numPartitions: Int) extends Logging {

private val (readLock, writeLock) = {
val lock = new ReentrantReadWriteLock()
@@ -121,12 +121,28 @@ private class ShuffleStatus(numPartitions: Int) {
mapStatuses(mapIndex) = status
}

/**
* Update the map output location (e.g. during migration).
*/
def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock {
val mapStatusOpt = mapStatuses.find(_.mapId == mapId)
mapStatusOpt match {
case Some(mapStatus) =>
logInfo(s"Updating map output for ${mapId} to ${bmAddress}")
mapStatus.updateLocation(bmAddress)
invalidateSerializedMapOutputStatusCache()
case None =>
logError(s"Asked to update map output ${mapId} for untracked map status.")
}
}

/**
* Remove the map output which was served by the specified block manager.
* This is a no-op if there is no registered map output or if the registered output is from a
* different block manager.
*/
def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
_numAvailableOutputs -= 1
mapStatuses(mapIndex) = null
@@ -139,6 +155,7 @@ private class ShuffleStatus(numPartitions: Int) {
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = withWriteLock {
logDebug(s"Removing outputs for host ${host}")
removeOutputsByFilter(x => x.host == host)
}

@@ -148,6 +165,7 @@ private class ShuffleStatus(numPartitions: Int) {
* still registered with that execId.
*/
def removeOutputsOnExecutor(execId: String): Unit = withWriteLock {
logDebug(s"Removing outputs for execId ${execId}")
removeOutputsByFilter(x => x.executorId == execId)
}

Expand Down Expand Up @@ -265,7 +283,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}")
tracker.post(new GetMapOutputMessage(shuffleId, context))

case StopMapOutputTracker =>
@@ -465,6 +483,15 @@ private[spark] class MapOutputTrackerMaster(
}
}

def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.updateMapOutput(mapId, bmAddress)
case None =>
logError(s"Asked to update map output for unknown shuffle ${shuffleId}")
}
}

def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
}
@@ -745,7 +772,12 @@ private[spark] class MapOutputTrackerMaster(
override def stop(): Unit = {
mapOutputRequests.offer(PoisonPill)
threadpool.shutdown()
sendTracker(StopMapOutputTracker)
try {
sendTracker(StopMapOutputTracker)
} catch {
case e: SparkException =>
logError("Could not tell tracker we are stopping.", e)
}
trackerEndpoint = null
shuffleStatuses.clear()
}
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -367,7 +367,8 @@ object SparkEnv extends Logging {
externalShuffleClient
} else {
None
}, blockManagerInfo)),
}, blockManagerInfo,
mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])),
registerOrLookupEndpoint(
BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),
23 changes: 23 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -420,6 +420,29 @@ package object config {
.booleanConf
.createWithDefault(false)

private[spark] val STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED =
ConfigBuilder("spark.storage.decommission.shuffleBlocks.enabled")
.doc("Whether to transfer shuffle blocks during block manager decommissioning. Requires " +
"a migratable shuffle resolver (like sort based shuffe)")
.version("3.1.0")
.booleanConf
.createWithDefault(false)

private[spark] val STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS =
ConfigBuilder("spark.storage.decommission.shuffleBlocks.maxThreads")
.doc("Maximum number of threads to use in migrating shuffle files.")
.version("3.1.0")
.intConf
.checkValue(_ > 0, "The maximum number of threads should be positive")
.createWithDefault(8)

private[spark] val STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED =
ConfigBuilder("spark.storage.decommission.rddBlocks.enabled")
.doc("Whether to transfer RDD blocks during block manager decommissioning.")
.version("3.1.0")
.booleanConf
.createWithDefault(false)

private[spark] val STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK =
ConfigBuilder("spark.storage.decommission.maxReplicationFailuresPerBlock")
.internal()
core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -168,7 +168,10 @@ private[spark] class NettyBlockTransferService(
// Everything else is encoded using our binary protocol.
val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))

val asStream = blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
// We always transfer shuffle blocks as a stream for simplicity with the receiving code since
// they are always written to disk. Otherwise we check the block size.
val asStream = (blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) ||
blockId.isShuffle)
val callback = new RpcResponseCallback {
override def onSuccess(response: ByteBuffer): Unit = {
logTrace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}")
15 changes: 13 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -30,12 +30,15 @@ import org.apache.spark.util.Utils

/**
* Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
* task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing
* on to the reduce tasks.
*/
private[spark] sealed trait MapStatus {
/** Location where this task was run. */
/** Location where this task output is. */
def location: BlockManagerId

def updateLocation(newLoc: BlockManagerId): Unit

/**
* Estimated size for the reduce block, in bytes.
*
@@ -126,6 +129,10 @@ private[spark] class CompressedMapStatus(

override def location: BlockManagerId = loc

override def updateLocation(newLoc: BlockManagerId): Unit = {
loc = newLoc
}

override def getSizeForBlock(reduceId: Int): Long = {
MapStatus.decompressSize(compressedSizes(reduceId))
}
@@ -178,6 +185,10 @@ private[spark] class HighlyCompressedMapStatus private (

override def location: BlockManagerId = loc

override def updateLocation(newLoc: BlockManagerId): Unit = {
loc = newLoc
}

override def getSizeForBlock(reduceId: Int): Long = {
assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -44,7 +44,7 @@ private[spark] class StandaloneSchedulerBackend(
with StandaloneAppClientListener
with Logging {

private var client: StandaloneAppClient = null
private[spark] var client: StandaloneAppClient = null
private val stopping = new AtomicBoolean(false)
private val launcherBackend = new LauncherBackend() {
override protected def conf: SparkConf = sc.conf
core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -18,15 +18,18 @@
package org.apache.spark.shuffle

import java.io._
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.file.Files

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.client.StreamCallbackWithID
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExecutorDiskUtils
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -44,9 +47,10 @@ import org.apache.spark.util.Utils
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
private[spark] class IndexShuffleBlockResolver(
conf: SparkConf,
_blockManager: BlockManager = null)
// var for testing
var _blockManager: BlockManager = null)
extends ShuffleBlockResolver
with Logging {
with Logging with MigratableResolver {

private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)

@@ -55,6 +59,19 @@ private[spark] class IndexShuffleBlockResolver(

def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None)

/**
* Get the shuffle files that are stored locally. Used for block migrations.
*/
override def getStoredShuffles(): Seq[ShuffleBlockInfo] = {
val allBlocks = blockManager.diskBlockManager.getAllBlocks()
allBlocks.flatMap {
case ShuffleIndexBlockId(shuffleId, mapId, _) =>
Some(ShuffleBlockInfo(shuffleId, mapId))
case _ =>
None
}
}

/**
* Get the shuffle data file.
*
@@ -148,6 +165,82 @@ private[spark] class IndexShuffleBlockResolver(
}
}

/**
* Write a provided shuffle block as a stream. Used for block migrations.
* ShuffleBlockBatchIds must contain the full range represented in the ShuffleIndexBlock.
* Requires the caller to delete any shuffle index blocks where the shuffle block fails to
* put.
*/
override def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
StreamCallbackWithID = {
val file = blockId match {
case ShuffleIndexBlockId(shuffleId, mapId, _) =>
getIndexFile(shuffleId, mapId)
case ShuffleDataBlockId(shuffleId, mapId, _) =>
getDataFile(shuffleId, mapId)
case _ =>
throw new IllegalStateException(s"Unexpected shuffle block transfer ${blockId} as " +
s"${blockId.getClass().getSimpleName()}")
}
val fileTmp = Utils.tempFileWith(file)
val channel = Channels.newChannel(
serializerManager.wrapStream(blockId,
new FileOutputStream(fileTmp)))

new StreamCallbackWithID {

override def getID: String = blockId.name

override def onData(streamId: String, buf: ByteBuffer): Unit = {
while (buf.hasRemaining) {
channel.write(buf)
}
}

override def onComplete(streamId: String): Unit = {
logTrace(s"Done receiving shuffle block $blockId, now storing on local disk.")
channel.close()
val diskSize = fileTmp.length()
this.synchronized {
if (file.exists()) {
file.delete()
}
if (!fileTmp.renameTo(file)) {
throw new IOException(s"failed to rename file ${fileTmp} to ${file}")
}
}
blockManager.reportBlockStatus(blockId, BlockStatus(StorageLevel.DISK_ONLY, 0, diskSize))
}

override def onFailure(streamId: String, cause: Throwable): Unit = {
// the framework handles the connection itself, we just need to do local cleanup
logWarning(s"Error while uploading $blockId", cause)
channel.close()
fileTmp.delete()
}
}
}

/**
* Get the index & data block for migration.
*/
def getMigrationBlocks(shuffleBlockInfo: ShuffleBlockInfo): List[(BlockId, ManagedBuffer)] = {
val shuffleId = shuffleBlockInfo.shuffleId
val mapId = shuffleBlockInfo.mapId
// Load the index block
val indexFile = getIndexFile(shuffleId, mapId)
val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
val indexFileSize = indexFile.length()
val indexBlockData = new FileSegmentManagedBuffer(transportConf, indexFile, 0, indexFileSize)

// Load the data block
val dataFile = getDataFile(shuffleId, mapId)
val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
val dataBlockData = new FileSegmentManagedBuffer(transportConf, dataFile, 0, dataFile.length())
List((indexBlockId, indexBlockData), (dataBlockId, dataBlockData))
}


/**
* Write an index file with the offsets of each block, plus a final offset at the end for the
* end of the output file. This will be used by getBlockData to figure out where each block
@@ -169,7 +262,7 @@
val dataFile = getDataFile(shuffleId, mapId)
// There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
// the following check and rename are atomic.
synchronized {
this.synchronized {
val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
if (existingLengths != null) {
// Another attempt for the same task has already written our map outputs successfully,
core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala (new file)
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.StreamCallbackWithID
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.BlockId

/**
* :: Experimental ::
* An experimental trait to allow Spark to migrate shuffle blocks.
*/
@Experimental
@Since("3.1.0")
trait MigratableResolver {
/**
* Get the shuffle ids that are stored locally. Used for block migrations.
*/
def getStoredShuffles(): Seq[ShuffleBlockInfo]

/**
* Write a provided shuffle block as a stream. Used for block migrations.
*/
def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
StreamCallbackWithID

/**
* Get the blocks for migration for a particular shuffle and map.
*/
def getMigrationBlocks(shuffleBlockInfo: ShuffleBlockInfo): List[(BlockId, ManagedBuffer)]
}