Added more documentation on Broadcast implementations, especially which blocks are reported to the driver. Also, fixed Broadcast API to hide destroy functionality.
tdas committed Apr 7, 2014
1 parent 41c9ece commit 2b95b5e
Showing 7 changed files with 122 additions and 38 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -76,7 +76,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   /** Start the cleaner. */
   def start() {
     cleaningThread.setDaemon(true)
-    cleaningThread.setName("ContextCleaner")
+    cleaningThread.setName("Spark Context Cleaner")
     cleaningThread.start()
   }

58 changes: 44 additions & 14 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -28,7 +28,8 @@ import org.apache.spark.SparkException
  * attempts to distribute broadcast variables using efficient broadcast algorithms to reduce
  * communication cost.
  *
- * Broadcast variables are created from a variable `v` by calling [[SparkContext#broadcast]].
+ * Broadcast variables are created from a variable `v` by calling
+ * [[org.apache.spark.SparkContext#broadcast]].
  * The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the
  * `value` method. The interpreter session below shows this:
  *
@@ -51,15 +52,17 @@ import org.apache.spark.SparkException
*/
abstract class Broadcast[T](val id: Long) extends Serializable {

protected var _isValid: Boolean = true

/**
* Whether this Broadcast is actually usable. This should be false once persisted state is
* removed from the driver.
* Flag signifying whether the broadcast variable is valid
* (that is, not already destroyed) or not.
*/
def isValid: Boolean = _isValid
@volatile private var _isValid = true

def value: T
/** Get the broadcasted value. */
def value: T = {
assertValid()
getValue()
}

/**
* Asynchronously delete cached copies of this broadcast on the executors.
@@ -74,23 +77,50 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
    * this is called, it will need to be re-sent to each executor.
    * @param blocking Whether to block until unpersisting has completed
    */
-  def unpersist(blocking: Boolean)
+  def unpersist(blocking: Boolean) {
+    assertValid()
+    doUnpersist(blocking)
+  }

   /**
-   * Remove all persisted state associated with this broadcast on both the executors and
-   * the driver.
+   * Destroy all data and metadata related to this broadcast variable. Use this with caution;
+   * once a broadcast variable has been destroyed, it cannot be used again.
    */
   private[spark] def destroy(blocking: Boolean) {
+    assertValid()
     _isValid = false
-    onDestroy(blocking)
+    doDestroy(blocking)
   }

-  protected def onDestroy(blocking: Boolean)
+  /**
+   * Whether this Broadcast is actually usable. This should be false once persisted state is
+   * removed from the driver.
+   */
+  private[spark] def isValid: Boolean = {
+    _isValid
+  }
+
+  /**
+   * Actually get the broadcasted value. Concrete implementations of the Broadcast class must
+   * define their own way to get the value.
+   */
+  private[spark] def getValue(): T

   /**
-   * If this broadcast is no longer valid, throw an exception.
+   * Actually unpersist the broadcasted value on the executors. Concrete implementations of
+   * the Broadcast class must define their own logic to unpersist their own data.
    */
-  protected def assertValid() {
+  private[spark] def doUnpersist(blocking: Boolean)
+
+  /**
+   * Actually destroy all data and metadata related to this broadcast variable.
+   * Implementations of the Broadcast class must define their own logic to destroy
+   * their own state.
+   */
+  private[spark] def doDestroy(blocking: Boolean)
+
+  /** Check if this broadcast is valid. If not, an exception is thrown. */
+  private[spark] def assertValid() {
     if (!_isValid) {
       throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
     }
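Taken together, the Broadcast.scala hunks introduce a template-method split: the public value and unpersist entry points (and the now private[spark] destroy) run assertValid() centrally, while concrete implementations only fill in getValue(), doUnpersist(), and doDestroy(). A minimal sketch of a hypothetical subclass, not part of this commit (the class name and no-op bodies are illustrative only):

  package org.apache.spark.broadcast

  // Hypothetical minimal subclass, for illustration only.
  private[spark] class LocalBroadcast[T](@transient private var value_ : T, id: Long)
    extends Broadcast[T](id) {

    // Called by Broadcast.value only after assertValid() has passed.
    private[spark] def getValue(): T = value_

    // Drop executor-side copies; the variable stays valid and usable afterwards.
    private[spark] def doUnpersist(blocking: Boolean) { /* no-op in this sketch */ }

    // Drop all state. Broadcast.destroy has already flipped _isValid to false, so
    // any later call to value, unpersist, or destroy will throw SparkException.
    private[spark] def doDestroy(blocking: Boolean) { /* no-op in this sketch */ }
  }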
36 changes: 28 additions & 8 deletions core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

+/**
+ * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses an HTTP server
+ * as the broadcast mechanism. The first time an HTTP broadcast variable (sent as part
+ * of a task) is deserialized in the executor, the broadcasted data is fetched from the
+ * driver (through an HTTP server running at the driver) and stored in the BlockManager
+ * of the executor to speed up future accesses.
+ */
 private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {

-  def value: T = {
-    assertValid()
-    value_
-  }
+  def getValue = value_

   val blockId = BroadcastBlockId(id)

+  /*
+   * Broadcasted data is also stored in the BlockManager of the driver. The
+   * BlockManagerMaster does not need to be told about this block as no one
+   * needs to know about this data block.
+   */
   HttpBroadcast.synchronized {
     SparkEnv.get.blockManager.putSingle(
       blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
@@ -50,21 +60,24 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
   /**
    * Remove all persisted state associated with this HTTP broadcast on the executors.
    */
-  def unpersist(blocking: Boolean) {
+  def doUnpersist(blocking: Boolean) {
     HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
   }

-  protected def onDestroy(blocking: Boolean) {
+  /**
+   * Remove all persisted state associated with this HTTP broadcast on the executors and driver.
+   */
+  def doDestroy(blocking: Boolean) {
     HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
   }

-  // Used by the JVM when serializing this object
+  /** Used by the JVM when serializing this object. */
   private def writeObject(out: ObjectOutputStream) {
     assertValid()
     out.defaultWriteObject()
   }

-  // Used by the JVM when deserializing this object
+  /** Used by the JVM when deserializing this object. */
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     HttpBroadcast.synchronized {
@@ -74,6 +87,13 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
         logInfo("Started reading broadcast variable " + id)
         val start = System.nanoTime
         value_ = HttpBroadcast.read[T](id)
+        /*
+         * Storing the broadcast data in BlockManager so that all subsequent tasks using
+         * the broadcast variable do not need to fetch it again. The BlockManagerMaster
+         * does not need to be told about this block as no one needs to know about this
+         * data block.
+         */
         SparkEnv.get.blockManager.putSingle(
           blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
         val time = (System.nanoTime - start) / 1e9
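Both putSingle call sites in this file now carry the same rationale: the value block is a purely local cache, so it is stored with tellMaster = false. For orientation, a condensed sketch of the executor-side read path the hunk above documents (simplified; the surrounding class and error handling are omitted, and the getSingle-based cache check is assumed from context rather than shown in this diff):

  // Condensed sketch of HttpBroadcast.readObject (simplified, assumed from context).
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject()
    HttpBroadcast.synchronized {
      SparkEnv.get.blockManager.getSingle(blockId) match {
        case Some(x) => value_ = x.asInstanceOf[T]   // already cached on this executor
        case None =>
          value_ = HttpBroadcast.read[T](id)         // fetch from the driver's HTTP server
          // Cache locally; tellMaster = false since no other node needs to locate it.
          SparkEnv.get.blockManager.putSingle(
            blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
      }
    }
  }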
core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -20,7 +20,9 @@ package org.apache.spark.broadcast
 import org.apache.spark.{SecurityManager, SparkConf}

 /**
- * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
+ * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses an
+ * HTTP server as the broadcast mechanism. Refer to
+ * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism.
  */
 class HttpBroadcastFactory extends BroadcastFactory {
   def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -26,13 +26,28 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.Utils

+/**
+ * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
+ * protocol to do a distributed transfer of the broadcasted data to the executors.
+ * The mechanism is as follows. The driver serializes the broadcasted data, divides it
+ * into smaller chunks, and stores them in the BlockManager of the driver. These chunks
+ * are reported to the BlockManagerMaster so that all the executors can learn the
+ * location of those chunks. The first time the broadcast variable (sent as part of a
+ * task) is deserialized at an executor, all the chunks are fetched using the
+ * BlockManager. When all the chunks are fetched (initially from the driver's
+ * BlockManager), they are combined and deserialized to recreate the broadcasted data.
+ * However, the chunks are also stored in the BlockManager and reported to the
+ * BlockManagerMaster. As more executors fetch the chunks, the BlockManagerMaster learns
+ * of multiple locations for each chunk. Hence, subsequent fetches of each chunk can be
+ * made from other executors that already have those chunks, resulting in a distributed
+ * fetch. This prevents the driver from being the bottleneck in sending out multiple
+ * copies of the broadcast data (one per executor), as is done by the
+ * [[org.apache.spark.broadcast.HttpBroadcast]].
+ */
 private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {

-  def value = {
-    assertValid()
-    value_
-  }
+  def getValue = value_

   val broadcastId = BroadcastBlockId(id)

@@ -53,15 +68,19 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
   /**
    * Remove all persisted state associated with this Torrent broadcast on the executors.
    */
-  def unpersist(blocking: Boolean) {
+  def doUnpersist(blocking: Boolean) {
     TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
   }

-  protected def onDestroy(blocking: Boolean) {
+  /**
+   * Remove all persisted state associated with this Torrent broadcast on the executors
+   * and driver.
+   */
+  def doDestroy(blocking: Boolean) {
     TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
   }

-  private def sendBroadcast() {
+  def sendBroadcast() {
     val tInfo = TorrentBroadcast.blockifyObject(value_)
     totalBlocks = tInfo.totalBlocks
     totalBytes = tInfo.totalBytes
@@ -85,13 +104,13 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     }
   }

-  // Used by the JVM when serializing this object
+  /** Used by the JVM when serializing this object. */
   private def writeObject(out: ObjectOutputStream) {
     assertValid()
     out.defaultWriteObject()
   }

-  // Used by the JVM when deserializing this object
+  /** Used by the JVM when deserializing this object. */
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     TorrentBroadcast.synchronized {
@@ -111,7 +130,11 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo

         /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
          * This creates a trade-off between memory usage and latency. Storing copy doubles
-         * the memory footprint; not storing doubles deserialization cost. */
+         * the memory footprint; not storing doubles deserialization cost. Also,
+         * this does not need to be reported to the BlockManagerMaster since other
+         * executors do not need to access this block (they only need to fetch the
+         * chunks, which are reported).
+         */
         SparkEnv.get.blockManager.putSingle(
           broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)

@@ -135,7 +158,8 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
   }

   def receiveBroadcast(): Boolean = {
-    // Receive meta-info
+    // Receive meta-info about the size of broadcast data,
+    // the number of chunks it is divided into, etc.
     val metaId = BroadcastBlockId(id, "meta")
     var attemptId = 10
     while (attemptId > 0 && totalBlocks == -1) {
@@ -158,7 +182,11 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
       return false
     }

-    // Receive actual blocks
+    /*
+     * Fetch the actual chunks of data. Note that all these chunks are stored in
+     * the BlockManager and reported to the master, so that other executors can
+     * find out and pull the chunks from this executor.
+     */
     val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
     for (pid <- recvOrder) {
       val pieceId = BroadcastBlockId(id, "piece" + pid)
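The new comments in this file all encode one policy: chunk blocks are reported to the BlockManagerMaster, while the merged value is kept strictly local. A hedged sketch of the block ids involved (ids taken from the diff above; the putBytes/putSingle calls are left as comments since their exact call sites are not part of this change):

  // Block layout for a TorrentBroadcast with a given id (ids as used in the diff):
  val metaId = BroadcastBlockId(id, "meta")                    // totalBlocks, totalBytes
  def pieceId(pid: Int) = BroadcastBlockId(id, "piece" + pid)  // one per chunk

  // Chunks are reported (tellMaster = true), so the master learns every replica and
  // later fetches can be served by any executor that already holds a chunk:
  //   blockManager.putBytes(pieceId(pid), bytes, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)

  // The merged, deserialized value is cached only locally (tellMaster = false),
  // because peers exchange chunks, never the merged copy:
  //   blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)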
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -20,7 +20,9 @@ package org.apache.spark.broadcast
 import org.apache.spark.{SecurityManager, SparkConf}

 /**
- * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast.
+ * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a
+ * BitTorrent-like protocol to do a distributed transfer of the broadcasted data to
+ * the executors. Refer to [[org.apache.spark.broadcast.TorrentBroadcast]] for more details.
  */
 class TorrentBroadcastFactory extends BroadcastFactory {

4 changes: 3 additions & 1 deletion core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark
 import org.scalatest.FunSuite

 import org.apache.spark.storage._
-import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.broadcast.{Broadcast, HttpBroadcast}
 import org.apache.spark.storage.BroadcastBlockId

 class BroadcastSuite extends FunSuite with LocalSparkContext {
@@ -298,6 +298,8 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
       // Using this variable on the executors crashes them, which hangs the test.
       // Instead, crash the driver by directly accessing the broadcast value.
       intercept[SparkException] { broadcast.value }
+      intercept[SparkException] { broadcast.unpersist() }
+      intercept[SparkException] { broadcast.destroy(blocking = true) }
     } else {
       val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
       assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
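After this commit, only value and unpersist remain user-facing; destroy is reachable solely from Spark-internal code such as this suite, which lives in the org.apache.spark package. A sketch of the lifecycle the new assertions pin down (assuming a running SparkContext sc):

  val broadcast = sc.broadcast(List(1, 2, 3))
  broadcast.value                        // ok
  broadcast.unpersist(blocking = true)   // executors drop their copies; still valid
  broadcast.value                        // ok: the value is re-sent on next use
  broadcast.destroy(blocking = true)     // compiles here only because the suite is in org.apache.spark
  intercept[SparkException] { broadcast.value }      // every entry point now fails fast
  intercept[SparkException] { broadcast.unpersist(blocking = true) }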
