From 7b369ca66075c40dafef5ee29607b2a956b25787 Mon Sep 17 00:00:00 2001 From: "Joseph E. Gonzalez" Date: Sat, 1 Mar 2014 17:03:40 -0800 Subject: [PATCH 1/2] Switching to in-memory-shuffle based on changes described in https://github.com/amplab/graphx/commits/sigmod-memory-shuffle --- .../spark/scheduler/ShuffleMapTask.scala | 7 +++++- .../apache/spark/storage/BlockManager.scala | 8 +++---- .../spark/storage/BlockObjectWriter.scala | 24 ++++++++----------- .../apache/spark/storage/MemoryStore.scala | 9 +++++++ .../spark/storage/ShuffleBlockManager.scala | 13 +++++++++- .../org/apache/spark/graphx/Pregel.scala | 8 +++++++ 6 files changed, 49 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 77789031f..f69f6a604 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -28,6 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDDCheckpointData import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import java.nio.ByteBuffer private[spark] object ShuffleMapTask { @@ -167,7 +168,11 @@ private[spark] class ShuffleMapTask( var totalBytes = 0L var totalTime = 0L val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() + // writer.commit() + val bytes = writer.commit() + if (bytes != null) { + blockManager.putBytes(writer.blockId, ByteBuffer.wrap(bytes), StorageLevel.MEMORY_ONLY_SER, tellMaster = false) + } writer.close() val size = writer.fileSegment().length totalBytes += size diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a734ddc1e..f7d147845 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -50,7 +50,7 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. @@ -261,7 +261,7 @@ private[spark] class BlockManager( * never deletes (recent) items. */ def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - diskStore.getValues(blockId, serializer).orElse( + memoryStore.getValues(blockId, serializer).orElse( sys.error("Block " + blockId + " not found on disk, though it should be")) } @@ -281,7 +281,7 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - diskStore.getBytes(blockId) match { + memoryStore.getBytes(blockId) match { case Some(bytes) => Some(bytes) case None => @@ -742,7 +742,7 @@ private[spark] class BlockManager( if (info != null) info.synchronized { // Removals are idempotent in disk store and memory store. At worst, we get a warning. 
val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) + val removedFromDisk = false //diskStore.remove(blockId) if (!removedFromMemory && !removedFromDisk) { logWarning("Block " + blockId + " could not be removed as it was not found in either " + "the disk or memory store") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 696b930a2..4d7e1852f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, File, OutputStream} +import java.io.{ByteArrayOutputStream, FileOutputStream, File, OutputStream} import java.nio.channels.FileChannel import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -44,7 +44,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { * Flush the partial writes and commit them as a single atomic block. Return the * number of bytes written for this commit. */ - def commit(): Long + def commit(): Array[Byte] /** * Reverts writes that haven't been flushed yet. Callers should invoke this function @@ -106,7 +106,7 @@ private[spark] class DiskBlockObjectWriter( /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null private var bs: OutputStream = null - private var fos: FileOutputStream = null + private var fos: ByteArrayOutputStream = null private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private val initialPosition = file.length() @@ -115,9 +115,8 @@ private[spark] class DiskBlockObjectWriter( private var _timeWriting = 0L override def open(): BlockObjectWriter = { - fos = new FileOutputStream(file, true) + fos = new ByteArrayOutputStream() ts = new TimeTrackingOutputStream(fos) - channel = fos.getChannel() lastValidPosition = initialPosition bs = compressStream(new FastBufferedOutputStream(ts, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) @@ -130,9 +129,6 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - val start = System.nanoTime() - fos.getFD.sync() - _timeWriting += System.nanoTime() - start } objOut.close() @@ -149,18 +145,18 @@ private[spark] class DiskBlockObjectWriter( override def isOpen: Boolean = objOut != null - override def commit(): Long = { + override def commit(): Array[Byte] = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. objOut.flush() bs.flush() val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos + lastValidPosition = fos.size() + fos.toByteArray } else { // lastValidPosition is zero if stream is uninitialized - lastValidPosition + null } } @@ -170,7 +166,7 @@ private[spark] class DiskBlockObjectWriter( // truncate the file to the last valid position. 
objOut.flush() bs.flush() - channel.truncate(lastValidPosition) + throw new UnsupportedOperationException("Revert temporarily broken due to in memory shuffle code changes.") } } @@ -182,7 +178,7 @@ private[spark] class DiskBlockObjectWriter( } override def fileSegment(): FileSegment = { - new FileSegment(file, initialPosition, bytesWritten) + new FileSegment(null, initialPosition, bytesWritten) } // Only valid if called after close() diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 181417565..8b1a2e06f 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable.ArrayBuffer import org.apache.spark.util.{SizeEstimator, Utils} +import org.apache.spark.serializer.Serializer /** * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as @@ -109,6 +110,14 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * A version of getValues that allows a custom serializer. This is used as part of the + * shuffle short-circuit code. + */ + def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { + getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) + } + override def remove(blockId: BlockId): Boolean = { entries.synchronized { val entry = entries.remove(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index bb07c8cb1..a80b19743 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -187,6 +187,17 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { } }) } + + def removeAllShuffleStuff() { + for (state <- shuffleStates.values; + group <- state.allFileGroups; + (mapId, _) <- group.mapIdToIndex.iterator; + reducerId <- 0 until group.files.length) { + val blockId = new ShuffleBlockId(group.shuffleId, mapId, reducerId) + blockManager.removeBlock(blockId, tellMaster = false) + } + shuffleStates.clear() + } } private[spark] @@ -200,7 +211,7 @@ object ShuffleBlockManager { * Stores the absolute index of each mapId in the files of this group. For instance, * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() + val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() /** * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index ac07a594a..e5f109e04 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx import scala.reflect.ClassTag +import org.apache.spark.SparkEnv /** @@ -142,6 +143,13 @@ object Pregel { // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). 
activeMessages = messages.count() + + // Very ugly code to clear the in-memory shuffle data + messages.mapPartitions { iter => + SparkEnv.get.blockManager.shuffleBlockManager.removeAllShuffleStuff() + iter + } + // Unpersist the RDDs hidden by newly-materialized RDDs oldMessages.unpersist(blocking=false) newVerts.unpersist(blocking=false) From d67c73b1629832ff85f4548c2582ffb795509741 Mon Sep 17 00:00:00 2001 From: "Joseph E. Gonzalez" Date: Sat, 1 Mar 2014 17:42:45 -0800 Subject: [PATCH 2/2] using a foreach instead of mapPartitions --- graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index e5f109e04..b0c380d18 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -145,9 +145,8 @@ object Pregel { activeMessages = messages.count() // Very ugly code to clear the in-memory shuffle data - messages.mapPartitions { iter => + messages.foreachPartition { iter => SparkEnv.get.blockManager.shuffleBlockManager.removeAllShuffleStuff() - iter } // Unpersist the RDDs hidden by newly-materialized RDDs
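
Editorial note (not part of the patches): the two patches above are easiest to read as a single data-flow change. Shuffle map output is serialized into an in-memory buffer (commit() now returns the bytes), the committed bytes are registered with the BlockManager as MEMORY_ONLY_SER blocks, shuffle reads are answered from the MemoryStore instead of the DiskStore, and the Pregel loop explicitly drops the buffered shuffle blocks once each iteration's messages have been counted. The sketch below is a minimal, self-contained model of that flow; InMemoryBlockStore, InMemoryBlockWriter, and the simplified ShuffleBlockId are illustrative stand-ins, not the Spark classes the patch actually modifies (BlockManager, DiskBlockObjectWriter, MemoryStore, ShuffleBlockManager).

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import scala.collection.concurrent.TrieMap

// Illustrative stand-in for Spark's ShuffleBlockId.
final case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)

// Illustrative stand-in for the MemoryStore/BlockManager pair used by the patch.
final class InMemoryBlockStore {
  private val entries = TrieMap.empty[ShuffleBlockId, ByteBuffer]

  // Mirrors blockManager.putBytes(..., StorageLevel.MEMORY_ONLY_SER, tellMaster = false)
  // in the modified ShuffleMapTask.
  def putBytes(id: ShuffleBlockId, bytes: Array[Byte]): Unit =
    entries.update(id, ByteBuffer.wrap(bytes))

  // Mirrors the read path that now asks memoryStore.getBytes instead of diskStore.getBytes.
  def getBytes(id: ShuffleBlockId): Option[ByteBuffer] = entries.get(id)

  // Mirrors removeAllShuffleStuff(): drop every buffered shuffle block at once.
  def clear(): Unit = entries.clear()
}

// Buffers serialized records in memory and hands the bytes back on commit(),
// the way the ByteArrayOutputStream-backed DiskBlockObjectWriter does after the patch.
final class InMemoryBlockWriter(val blockId: ShuffleBlockId) {
  private val buffer = new ByteArrayOutputStream()
  def write(record: Array[Byte]): Unit = buffer.write(record)
  // Returns the accumulated bytes rather than a byte count.
  def commit(): Array[Byte] = buffer.toByteArray
}

object InMemoryShuffleSketch {
  def main(args: Array[String]): Unit = {
    val store = new InMemoryBlockStore

    // Map side: one writer per reducer; committed bytes go straight into the block store.
    val writers = (0 until 2).map { r =>
      new InMemoryBlockWriter(ShuffleBlockId(shuffleId = 0, mapId = 0, reduceId = r))
    }
    writers.foreach(w => w.write(s"record-for-reducer-${w.blockId.reduceId}".getBytes("UTF-8")))
    writers.foreach(w => store.putBytes(w.blockId, w.commit()))

    // Reduce side: fetch the block from memory rather than from a shuffle file on disk.
    val fetched = store.getBytes(ShuffleBlockId(0, 0, 1)).map { buf =>
      val arr = new Array[Byte](buf.remaining()); buf.get(arr); new String(arr, "UTF-8")
    }
    println(fetched.getOrElse("missing")) // prints: record-for-reducer-1

    // Once an iteration's messages have been consumed, drop the buffered shuffle data,
    // which is what the Pregel loop's removeAllShuffleStuff() call accomplishes.
    store.clear()
  }
}

As PATCH 2/2 reflects, the per-iteration cleanup has to run inside an action such as foreachPartition; wrapping it in a lazy mapPartitions transformation whose result is never consumed would leave the cleanup unevaluated.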