In memory shuffle #135

Open
wants to merge 2 commits into base: vldb
ShuffleMapTask.scala
@@ -28,6 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
import org.apache.spark.storage._
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import java.nio.ByteBuffer

private[spark] object ShuffleMapTask {

@@ -167,7 +168,11 @@ private[spark] class ShuffleMapTask(
var totalBytes = 0L
var totalTime = 0L
val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
-      writer.commit()
+      // writer.commit()
+      val bytes = writer.commit()
+      if (bytes != null) {
+        blockManager.putBytes(writer.blockId, ByteBuffer.wrap(bytes), StorageLevel.MEMORY_ONLY_SER, tellMaster = false)
+      }
writer.close()
val size = writer.fileSegment().length
totalBytes += size
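This is the map-side half of the change: writer.commit() now hands back the serialized bytes, and the task registers them with the block manager under the writer's block id at MEMORY_ONLY_SER instead of relying only on the shuffle file. Below is a minimal, self-contained sketch of that pattern; InMemoryBlockStore and ShuffleWriter are hypothetical stand-ins for BlockManager and BlockObjectWriter, and plain Java serialization stands in for Spark's shuffle serializer.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import scala.collection.concurrent.TrieMap

// Stand-in for the block manager's memory store: serialized blocks keyed by block id.
class InMemoryBlockStore {
  private val blocks = TrieMap.empty[String, ByteBuffer]
  def putBytes(blockId: String, bytes: ByteBuffer): Unit = { blocks.put(blockId, bytes) }
  def getBytes(blockId: String): Option[ByteBuffer] = blocks.get(blockId)
}

// Stand-in for the patched writer: commit() returns the serialized bytes buffered so far
// instead of a byte count, leaving the caller free to decide where they go.
class ShuffleWriter(val blockId: String) {
  private val buffer = new ByteArrayOutputStream()
  private val out = new ObjectOutputStream(buffer)
  def write(value: AnyRef): Unit = out.writeObject(value)
  def commit(): Array[Byte] = { out.flush(); buffer.toByteArray }
}

object MapSideSketch {
  def main(args: Array[String]): Unit = {
    val store = new InMemoryBlockStore
    val writer = new ShuffleWriter("shuffle_0_5_2") // hypothetical block id
    Seq("a" -> 1, "b" -> 2).foreach(writer.write)

    // Mirrors the ShuffleMapTask change: keep the committed bytes in memory.
    val bytes = writer.commit()
    if (bytes != null) {
      store.putBytes(writer.blockId, ByteBuffer.wrap(bytes))
    }
    println(s"stored ${store.getBytes(writer.blockId).map(_.remaining).getOrElse(0)} bytes")
  }
}
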
BlockManager.scala
@@ -50,7 +50,7 @@ private[spark] class BlockManager(

private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]

-  private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+  private[storage] val memoryStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore = new DiskStore(this, diskBlockManager)

// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
@@ -261,7 +261,7 @@ private[spark] class BlockManager(
* never deletes (recent) items.
*/
def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
-    diskStore.getValues(blockId, serializer).orElse(
+    memoryStore.getValues(blockId, serializer).orElse(
sys.error("Block " + blockId + " not found on disk, though it should be"))
}

@@ -281,7 +281,7 @@
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
if (blockId.isShuffle) {
-      diskStore.getBytes(blockId) match {
+      memoryStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
case None =>
@@ -742,7 +742,7 @@ private[spark] class BlockManager(
if (info != null) info.synchronized {
// Removals are idempotent in disk store and memory store. At worst, we get a warning.
val removedFromMemory = memoryStore.remove(blockId)
-      val removedFromDisk = diskStore.remove(blockId)
+      val removedFromDisk = false //diskStore.remove(blockId)
if (!removedFromMemory && !removedFromDisk) {
logWarning("Block " + blockId + " could not be removed as it was not found in either " +
"the disk or memory store")
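Read side of the same idea: the lookups that used to go to diskStore now go to memoryStore, and a missing shuffle block becomes a hard error because there is no disk copy to fall back to. A rough sketch of that short-circuit, again with simplified stand-ins (a plain Map instead of the memory store, Java serialization instead of Spark's serializer):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, EOFException, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer

object ReduceSideSketch {
  // Stand-in for dataDeserialize: turn a serialized block back into an iterator of records.
  def deserialize(bytes: ByteBuffer): Iterator[Any] = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.array()))
    Iterator
      .continually(try Some(in.readObject()) catch { case _: EOFException => None })
      .takeWhile(_.isDefined)
      .map(_.get)
  }

  // Mirrors the patched getLocalFromDisk: consult only the in-memory store and fail loudly
  // if the shuffle block is missing, since there is no longer a disk copy behind it.
  def getLocalShuffleBlock(memoryStore: collection.Map[String, ByteBuffer],
                           blockId: String): Iterator[Any] =
    memoryStore.get(blockId)
      .map(deserialize)
      .getOrElse(sys.error("Block " + blockId + " not found in memory, though it should be"))

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    Seq("a" -> 1, "b" -> 2).foreach(out.writeObject)
    out.flush()
    val store = Map("shuffle_0_5_2" -> ByteBuffer.wrap(bos.toByteArray))
    getLocalShuffleBlock(store, "shuffle_0_5_2").foreach(println)
  }
}
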
BlockObjectWriter.scala
@@ -17,7 +17,7 @@

package org.apache.spark.storage

-import java.io.{FileOutputStream, File, OutputStream}
+import java.io.{ByteArrayOutputStream, FileOutputStream, File, OutputStream}
import java.nio.channels.FileChannel

import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -44,7 +44,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
* Flush the partial writes and commit them as a single atomic block. Return the
* number of bytes written for this commit.
*/
-  def commit(): Long
+  def commit(): Array[Byte]

/**
* Reverts writes that haven't been flushed yet. Callers should invoke this function
@@ -106,7 +106,7 @@ private[spark] class DiskBlockObjectWriter(
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
private var bs: OutputStream = null
-  private var fos: FileOutputStream = null
+  private var fos: ByteArrayOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private val initialPosition = file.length()
@@ -115,9 +115,8 @@
private var _timeWriting = 0L

override def open(): BlockObjectWriter = {
-    fos = new FileOutputStream(file, true)
+    fos = new ByteArrayOutputStream()
ts = new TimeTrackingOutputStream(fos)
-    channel = fos.getChannel()
lastValidPosition = initialPosition
bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
@@ -130,9 +129,6 @@
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
-      val start = System.nanoTime()
-      fos.getFD.sync()
-      _timeWriting += System.nanoTime() - start
}
objOut.close()

Expand All @@ -149,18 +145,18 @@ private[spark] class DiskBlockObjectWriter(

override def isOpen: Boolean = objOut != null

-  override def commit(): Long = {
+  override def commit(): Array[Byte] = {
if (initialized) {
// NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
// serializer stream and the lower level stream.
objOut.flush()
bs.flush()
-      val prevPos = lastValidPosition
-      lastValidPosition = channel.position()
-      lastValidPosition - prevPos
+      lastValidPosition = fos.size()
+      fos.toByteArray
} else {
// lastValidPosition is zero if stream is uninitialized
-      lastValidPosition
+      null
}
}

@@ -170,7 +166,7 @@
// truncate the file to the last valid position.
objOut.flush()
bs.flush()
-      channel.truncate(lastValidPosition)
+      throw new UnsupportedOperationException("Revert temporarily broken due to in memory shuffle code changes.")
}
}

@@ -182,7 +178,7 @@
}

override def fileSegment(): FileSegment = {
-    new FileSegment(file, initialPosition, bytesWritten)
+    new FileSegment(null, initialPosition, bytesWritten)
}

// Only valid if called after close()
MemoryStore.scala
@@ -23,6 +23,7 @@ import java.util.LinkedHashMap
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.util.{SizeEstimator, Utils}
+import org.apache.spark.serializer.Serializer

/**
* Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as
@@ -109,6 +110,14 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}

+  /**
+   * A version of getValues that allows a custom serializer. This is used as part of the
+   * shuffle short-circuit code.
+   */
+  def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
+    getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
+  }

override def remove(blockId: BlockId): Boolean = {
entries.synchronized {
val entry = entries.remove(blockId)
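The new MemoryStore.getValues overload takes a serializer because shuffle blocks are written with the shuffle's own serializer (e.g. Kryo, per the NOTE in the writer), which may differ from the default one; the method itself is just getBytes plus deserialization. A small sketch of that shape, with BlockSerializer and SketchMemoryStore as hypothetical stand-ins for Spark's Serializer and MemoryStore:

import java.nio.ByteBuffer

// Hypothetical, simplified stand-in for org.apache.spark.serializer.Serializer:
// only the one operation getValues needs.
trait BlockSerializer {
  def deserializeStream(bytes: ByteBuffer): Iterator[Any]
}

// Stand-in for MemoryStore, showing the shape of the added method: look up the serialized
// block, then hand the bytes to the caller's serializer.
class SketchMemoryStore(entries: collection.Map[String, ByteBuffer]) {
  def getBytes(blockId: String): Option[ByteBuffer] = entries.get(blockId)

  def getValues(blockId: String, serializer: BlockSerializer): Option[Iterator[Any]] =
    getBytes(blockId).map(serializer.deserializeStream)
}
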
ShuffleBlockManager.scala
@@ -187,6 +187,17 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
}
})
}

+  def removeAllShuffleStuff() {
+    for (state <- shuffleStates.values;
+         group <- state.allFileGroups;
+         (mapId, _) <- group.mapIdToIndex.iterator;
+         reducerId <- 0 until group.files.length) {
+      val blockId = new ShuffleBlockId(group.shuffleId, mapId, reducerId)
+      blockManager.removeBlock(blockId, tellMaster = false)
+    }
+    shuffleStates.clear()
+  }
}

private[spark]
@@ -200,7 +211,7 @@ object ShuffleBlockManager {
* Stores the absolute index of each mapId in the files of this group. For instance,
* if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
*/
-  private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
+  val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()

/**
* Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
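removeAllShuffleStuff enumerates every registered shuffle state, every file group, every mapId in that group, and every reducer index, rebuilds the corresponding ShuffleBlockId, asks the block manager to drop it, and finally clears the shuffle state map. A self-contained sketch of that enumeration, using plain case classes as stand-ins for Spark's ShuffleState and ShuffleFileGroup:

import scala.collection.mutable

// Hypothetical, simplified stand-ins for ShuffleBlockManager's internal bookkeeping.
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) {
  def name: String = s"shuffle_${shuffleId}_${mapId}_$reduceId"
}
case class FileGroup(shuffleId: Int, mapIds: Seq[Int], numReducers: Int)
case class ShuffleState(fileGroups: Seq[FileGroup])

object ShuffleCleanupSketch {
  val shuffleStates = mutable.Map[Int, ShuffleState]()
  val memoryStore = mutable.Map[String, Array[Byte]]()

  // Mirrors removeAllShuffleStuff(): enumerate shuffle x map x reducer, drop each block,
  // then forget the shuffle state itself.
  def removeAllShuffleBlocks(): Unit = {
    for {
      state <- shuffleStates.values
      group <- state.fileGroups
      mapId <- group.mapIds
      reducerId <- 0 until group.numReducers
    } {
      memoryStore.remove(ShuffleBlockId(group.shuffleId, mapId, reducerId).name)
    }
    shuffleStates.clear()
  }

  def main(args: Array[String]): Unit = {
    shuffleStates(0) = ShuffleState(Seq(FileGroup(0, Seq(0, 1), numReducers = 2)))
    for (m <- 0 to 1; r <- 0 to 1) memoryStore(ShuffleBlockId(0, m, r).name) = Array[Byte](1)
    removeAllShuffleBlocks()
    println(s"blocks left: ${memoryStore.size}") // 0
  }
}
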
7 changes: 7 additions & 0 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -18,6 +18,7 @@
package org.apache.spark.graphx

import scala.reflect.ClassTag
+import org.apache.spark.SparkEnv


/**
@@ -142,6 +143,12 @@ object Pregel {
// hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
// vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
activeMessages = messages.count()

+      // Very ugly code to clear the in-memory shuffle data
+      messages.foreachPartition { iter =>
+        SparkEnv.get.blockManager.shuffleBlockManager.removeAllShuffleStuff()
+      }

// Unpersist the RDDs hidden by newly-materialized RDDs
oldMessages.unpersist(blocking=false)
newVerts.unpersist(blocking=false)
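The cleanup has to run on the executors, where the in-memory shuffle blocks live, so the patch piggybacks on a foreachPartition over the just-counted messages RDD (the comment above flags this as ugly). A hedged sketch of the same hook against the public RDD API; clearExecutorLocalShuffleState() is a hypothetical stand-in for SparkEnv.get.blockManager.shuffleBlockManager.removeAllShuffleStuff():

import org.apache.spark.rdd.RDD

object PregelCleanupSketch {
  // Hypothetical stand-in for the executor-local cleanup done in the patch.
  def clearExecutorLocalShuffleState(): Unit = ()

  def iterationStep(messages: RDD[(Long, Int)]): Long = {
    // Materialize the messages for this superstep first...
    val activeMessages = messages.count()

    // ...then run an empty pass over the partitions: foreachPartition executes once per
    // partition on the executors, so each executor clears its own in-memory shuffle blocks.
    messages.foreachPartition { _ => clearExecutorLocalShuffleState() }

    activeMessages
  }
}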