Skip to content

Commit

Permalink
Merge pull request #87 from aarondav/shuffle-base
Browse files Browse the repository at this point in the history
Basic shuffle file consolidation

The Spark shuffle phase can produce a large number of files, as one file is created
per mapper per reducer. For large or repeated jobs, this often produces millions of
shuffle files, which sees extremely degredaded performance from the OS file system.
This patch seeks to reduce that burden by combining multipe shuffle files into one.

This PR draws upon the work of @jason-dai in mesos/spark#669.
However, it simplifies the design in order to get the majority of the gain with less
overall intellectual and code burden. The vast majority of code in this pull request
is a refactor to allow the insertion of a clean layer of indirection between logical
block ids and physical files. This, I feel, provides some design clarity in addition
to enabling shuffle file consolidation.

The main goal is to produce one shuffle file per reducer per active mapper thread.
This allows us to isolate the mappers (simplifying the failure modes), while still
allowing us to reduce the number of mappers tremendously for large tasks. In order
to accomplish this, we simply create a new set of shuffle files for every parallel
task, and return the files to a pool which will be given out to the next run task.

I have run some ad hoc query testing on 5 m1.xlarge EC2 nodes with 2g of executor memory and the following microbenchmark:

    scala> val nums = sc.parallelize(1 to 1000, 1000).flatMap(x => (1 to 1e6.toInt))
    scala> def time(x: => Unit) = { val now = System.currentTimeMillis; x; System.currentTimeMillis - now }
    scala> (1 to 8).map(_ => time(nums.map(x => (x % 100000, 2000, x)).reduceByKey(_ + _).count) / 1000.0)

For this particular workload, with 1000 mappers and 2000 reducers, I saw the old method running at around 15 minutes, with the consolidated shuffle files running at around 4 minutes. There was a very sharp increase in running time for the non-consolidated version after around 1 million total shuffle files. Below this threshold, however, there wasn't a significant difference between the two.

Better performance measurement of this patch is warranted, and I plan on doing so in the near future as part of a general investigation of our shuffle file bottlenecks and performance.
  • Loading branch information
rxin committed Oct 22, 2013
2 parents a51359c + 053ef94 commit 48952d6
Show file tree
Hide file tree
Showing 13 changed files with 460 additions and 319 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.netty.channel.DefaultFileRegion;

import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;

class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {

Expand All @@ -37,40 +38,34 @@ public FileServerHandler(PathResolver pResolver){
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.apply(blockIdString);
String path = pResolver.getAbsolutePath(blockId.name());
// if getFilePath returns null, close the channel
if (path == null) {
FileSegment fileSegment = pResolver.getBlockLocation(blockId);
// if getBlockLocation returns null, close the channel
if (fileSegment == null) {
//ctx.close();
return;
}
File file = new File(path);
File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
//logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
long length = file.length();
long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
//logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
//logger.info("Sending block "+blockId+" filelen = "+len);
//logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
.getChannel(), 0, file.length()));
.getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
//logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
//logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@

package org.apache.spark.network.netty;

import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;

public interface PathResolver {
/**
* Get the absolute path of the file
*
* @param fileId
* @return the absolute path of file
*/
public String getAbsolutePath(String fileId);
/** Get the file segment in which the given block resides. */
public FileSegment getBlockLocation(BlockId blockId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.io.File

import org.apache.spark.Logging
import org.apache.spark.util.Utils
import org.apache.spark.storage.BlockId
import org.apache.spark.storage.{BlockId, FileSegment}


private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
Expand Down Expand Up @@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))

val pResovler = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId(blockIdString)
override def getBlockLocation(blockId: BlockId): FileSegment = {
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
Expand All @@ -65,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
return file.getAbsolutePath
return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ private[spark] class ShuffleMapTask(
var totalTime = 0L
val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
writer.commit()
writer.close()
val size = writer.size()
val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
Expand All @@ -191,6 +190,7 @@ private[spark] class ShuffleMapTask(
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && buckets != null) {
buckets.writers.foreach(_.close())
shuffle.releaseWriters(buckets)
}
// Execute the callbacks on task completion.
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

def run(attemptId: Long): T = {
final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
if (_killed) {
kill()
Expand Down
34 changes: 25 additions & 9 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import akka.dispatch.{Await, Future}
import akka.util.Duration
import akka.util.duration._

import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}

import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
Expand Down Expand Up @@ -102,18 +102,19 @@ private[spark] class BlockManager(
}

val shuffleBlockManager = new ShuffleBlockManager(this)
val diskBlockManager = new DiskBlockManager(
System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))

private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]

private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore: DiskStore =
new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private[storage] val diskStore = new DiskStore(this, diskBlockManager)

// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}

val connectionManager = new ConnectionManager(0)
Expand Down Expand Up @@ -512,16 +513,20 @@ private[spark] class BlockManager(

/**
* A short circuited method to get a block writer that can write data directly to disk.
* The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
writer.registerCloseEventHandler(() => {
diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
blockInfo.put(blockId, myInfo)
myInfo.markReady(writer.size())
myInfo.markReady(writer.fileSegment().length)
})
writer
}
Expand Down Expand Up @@ -862,13 +867,24 @@ private[spark] class BlockManager(
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}

/** Serializes into a stream. */
def dataSerializeStream(
blockId: BlockId,
outputStream: OutputStream,
values: Iterator[Any],
serializer: Serializer = defaultSerializer) {
val byteStream = new FastBufferedOutputStream(outputStream)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
}

/** Serializes into a byte buffer. */
def dataSerialize(
blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
Expand Down
128 changes: 126 additions & 2 deletions core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@

package org.apache.spark.storage

import java.io.{FileOutputStream, File, OutputStream}
import java.nio.channels.FileChannel

import it.unimi.dsi.fastutil.io.FastBufferedOutputStream

import org.apache.spark.Logging
import org.apache.spark.serializer.{SerializationStream, Serializer}

/**
* An interface for writing JVM objects to some underlying storage. This interface allows
Expand Down Expand Up @@ -59,12 +66,129 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
def write(value: Any)

/**
* Size of the valid writes, in bytes.
* Returns the file segment of committed data that this Writer has written.
*/
def size(): Long
def fileSegment(): FileSegment

/**
* Cumulative time spent performing blocking writes, in ns.
*/
def timeWriting(): Long
}

/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
class DiskBlockObjectWriter(
blockId: BlockId,
file: File,
serializer: Serializer,
bufferSize: Int,
compressStream: OutputStream => OutputStream)
extends BlockObjectWriter(blockId)
with Logging
{

/** Intercepts write calls and tracks total time spent writing. Not thread safe. */
private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
def timeWriting = _timeWriting
private var _timeWriting = 0L

private def callWithTiming(f: => Unit) = {
val start = System.nanoTime()
f
_timeWriting += (System.nanoTime() - start)
}

def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
}

private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean

/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
private var bs: OutputStream = null
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private var initialPosition = 0L
private var lastValidPosition = 0L
private var initialized = false
private var _timeWriting = 0L

override def open(): BlockObjectWriter = {
fos = new FileOutputStream(file, true)
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
initialPosition = channel.position
lastValidPosition = initialPosition
bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
initialized = true
this
}

override def close() {
if (initialized) {
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
val start = System.nanoTime()
fos.getFD.sync()
_timeWriting += System.nanoTime() - start
}
objOut.close()

_timeWriting += ts.timeWriting

channel = null
bs = null
fos = null
ts = null
objOut = null
}
// Invoke the close callback handler.
super.close()
}

override def isOpen: Boolean = objOut != null

override def commit(): Long = {
if (initialized) {
// NOTE: Flush the serializer first and then the compressed/buffered output stream
objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
} else {
// lastValidPosition is zero if stream is uninitialized
lastValidPosition
}
}

override def revertPartialWrites() {
if (initialized) {
// Discard current writes. We do this by flushing the outstanding writes and
// truncate the file to the last valid position.
objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
}
}

override def write(value: Any) {
if (!initialized) {
open()
}
objOut.writeObject(value)
}

override def fileSegment(): FileSegment = {
val bytesWritten = lastValidPosition - initialPosition
new FileSegment(file, initialPosition, bytesWritten)
}

// Only valid if called after close()
override def timeWriting() = _timeWriting
}
Loading

0 comments on commit 48952d6

Please sign in to comment.