diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index deabf6f5c8c5f..f856a13f84dec 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -21,27 +21,41 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -/** Listener class used for testing when any item has been cleaned by the Cleaner class */ -private[spark] trait CleanerListener { - def rddCleaned(rddId: Int) - def shuffleCleaned(shuffleId: Int) -} +/** + * Classes that represent cleaning tasks. + */ +private sealed trait CleanupTask +private case class CleanRDD(rddId: Int) extends CleanupTask +private case class CleanShuffle(shuffleId: Int) extends CleanupTask +private case class CleanBroadcast(broadcastId: Long) extends CleanupTask /** - * Cleans RDDs and shuffle data. + * A WeakReference associated with a CleanupTask. + * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. + */ +private class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) + +/** + * An asynchronous cleaner for RDD, shuffle, and broadcast state. + * + * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest, + * to be processed when the associated object goes out of scope of the application. Actual + * cleanup is performed in a separate daemon thread. */ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - /** Classes to represent cleaning tasks */ - private sealed trait CleanupTask - private case class CleanRDD(rddId: Int) extends CleanupTask - private case class CleanShuffle(shuffleId: Int) extends CleanupTask - // TODO: add CleanBroadcast + private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] + with SynchronizedBuffer[CleanupTaskWeakReference] - private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask] - with SynchronizedBuffer[WeakReferenceWithCleanupTask] private val referenceQueue = new ReferenceQueue[AnyRef] private val listeners = new ArrayBuffer[CleanerListener] @@ -49,77 +63,64 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} - private val REF_QUEUE_POLL_TIMEOUT = 100 - @volatile private var stopped = false - private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask) - extends WeakReference(referent, referenceQueue) + /** Attach a listener object to get information of when objects are cleaned. */ + def attachListener(listener: CleanerListener) { + listeners += listener + } - /** Start the cleaner */ + /** Start the cleaner. */ def start() { cleaningThread.setDaemon(true) cleaningThread.setName("ContextCleaner") cleaningThread.start() } - /** Stop the cleaner */ + /** Stop the cleaner. */ def stop() { stopped = true cleaningThread.interrupt() } - /** - * Register a RDD for cleanup when it is garbage collected. - */ + /** Register a RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]) { registerForCleanup(rdd, CleanRDD(rdd.id)) } - /** - * Register a shuffle dependency for cleanup when it is garbage collected. - */ + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } - /** Cleanup RDD. */ - def cleanupRDD(rdd: RDD[_]) { - doCleanupRDD(rdd.id) - } - - /** Cleanup shuffle. */ - def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { - doCleanupShuffle(shuffleDependency.shuffleId) - } - - /** Attach a listener object to get information of when objects are cleaned. */ - def attachListener(listener: CleanerListener) { - listeners += listener + /** Register a Broadcast for cleanup when it is garbage collected. */ + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) } /** Register an object for cleanup. */ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { - referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task) + referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) } - /** Keep cleaning RDDs and shuffle data */ + /** Keep cleaning RDD, shuffle, and broadcast state. */ private def keepCleaning() { - while (!isStopped) { + while (!stopped) { try { - val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT)) - .map(_.asInstanceOf[WeakReferenceWithCleanupTask]) + val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) reference.map(_.task).foreach { task => logDebug("Got cleaning task " + task) referenceBuffer -= reference.get task match { case CleanRDD(rddId) => doCleanupRDD(rddId) case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId) + case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId) } } } catch { case ie: InterruptedException => - if (!isStopped) logWarning("Cleaning thread interrupted") + if (!stopped) logWarning("Cleaning thread interrupted") case t: Throwable => logError("Error in cleaning thread", t) } } @@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def doCleanupRDD(rddId: Int) { try { logDebug("Cleaning RDD " + rddId) - sc.unpersistRDD(rddId, false) + sc.unpersistRDD(rddId, blocking = false) listeners.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { @@ -150,10 +151,47 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - private def mapOutputTrackerMaster = - sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + /** Perform broadcast cleanup. */ + private def doCleanupBroadcast(broadcastId: Long) { + try { + logDebug("Cleaning broadcast " + broadcastId) + broadcastManager.unbroadcast(broadcastId, removeFromDriver = true) + listeners.foreach(_.broadcastCleaned(broadcastId)) + logInfo("Cleaned broadcast " + broadcastId) + } catch { + case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t) + } + } private def blockManagerMaster = sc.env.blockManager.master + private def broadcastManager = sc.env.broadcastManager + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + + // Used for testing + + private[spark] def cleanupRDD(rdd: RDD[_]) { + doCleanupRDD(rdd.id) + } + + private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { + doCleanupShuffle(shuffleDependency.shuffleId) + } - private def isStopped = stopped + private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) { + doCleanupBroadcast(broadcast.id) + } + +} + +private object ContextCleaner { + private val REF_QUEUE_POLL_TIMEOUT = 100 +} + +/** + * Listener class used for testing when any item has been cleaned by the Cleaner class. + */ +private[spark] trait CleanerListener { + def rddCleaned(rddId: Int) + def shuffleCleaned(shuffleId: Int) + def broadcastCleaned(broadcastId: Long) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5cd2caed10297..689180fcd719b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -642,7 +642,11 @@ class SparkContext( * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T) = { + val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + cleaner.registerBroadcastForCleanup(bc) + bc + } /** * Add a file to be downloaded with this Spark job on every node. diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index e3e1e4f29b107..d75b9acfb7aa0 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -50,6 +50,12 @@ import java.io.Serializable abstract class Broadcast[T](val id: Long) extends Serializable { def value: T + /** + * Remove all persisted state associated with this broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + def unpersist(removeFromDriver: Boolean) + // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 0a0bb6cca336c..850650951e603 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -29,5 +29,6 @@ import org.apache.spark.SparkConf trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def unbroadcast(id: Long, removeFromDriver: Boolean) def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 746e23e81931a..85d62aae03959 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -60,4 +60,8 @@ private[spark] class BroadcastManager( broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } + def unbroadcast(id: Long, removeFromDriver: Boolean) { + broadcastFactory.unbroadcast(id, removeFromDriver) + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 374180e472805..89361efec44a4 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -21,10 +21,9 @@ import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit -import it.unimi.dsi.fastutil.io.FastBufferedInputStream -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream +import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} -import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} +import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} @@ -32,18 +31,27 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + override def value = value_ - def blockId = BroadcastBlockId(id) + val blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } if (!isLocal) { HttpBroadcast.write(id, value_) } + /** + * Remove all persisted state associated with this HTTP broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + override def unpersist(removeFromDriver: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver) + } + // Called by JVM when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() @@ -54,7 +62,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -63,7 +72,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } } -private object HttpBroadcast extends Logging { +private[spark] object HttpBroadcast extends Logging { private var initialized = false private var broadcastDir: File = null @@ -74,7 +83,7 @@ private object HttpBroadcast extends Logging { private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist - private val files = new TimeStampedHashSet[String] + val files = new TimeStampedHashSet[String] private var cleaner: MetadataCleaner = null private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt @@ -122,8 +131,10 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -146,7 +157,7 @@ private object HttpBroadcast extends Logging { if (securityManager.isAuthenticationEnabled()) { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL().openConnection() + uc = newuri.toURL.openConnection() uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") @@ -155,7 +166,7 @@ private object HttpBroadcast extends Logging { val in = { uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream(); + val inputStream = uc.getInputStream if (compress) { compressionCodec.compressedInputStream(inputStream) } else { @@ -169,20 +180,50 @@ private object HttpBroadcast extends Logging { obj } - def cleanup(cleanupTime: Long) { + /** + * Remove all persisted blocks associated with this HTTP broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver + * and delete the associated broadcast file. + */ + def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { + //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + if (removeFromDriver) { + val file = new File(broadcastDir, BroadcastBlockId(id).name) + files.remove(file.toString) + deleteBroadcastFile(file) + } + } + + /** + * Periodically clean up old broadcasts by removing the associated map entries and + * deleting the associated files. + */ + private def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) - } + iterator.remove() + deleteBroadcastFile(new File(file.toString)) } } } + + /** Delete the given broadcast file. */ + private def deleteBroadcastFile(file: File) { + try { + if (!file.exists()) { + logWarning("Broadcast file to be deleted does not exist: %s".format(file)) + } else if (file.delete()) { + logInfo("Deleted broadcast file: %s".format(file)) + } else { + logWarning("Could not delete broadcast file: %s".format(file)) + } + } catch { + case e: Exception => + logWarning("Exception while deleting broadcast file: %s".format(file), e) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index c4f0f149534a5..4affa922156c9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -31,4 +31,12 @@ class HttpBroadcastFactory extends BroadcastFactory { new HttpBroadcast[T](value_, isLocal, id) def stop() { HttpBroadcast.stop() } + + /** + * Remove all persisted state associated with the HTTP broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. + */ + def unbroadcast(id: Long, removeFromDriver: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver) + } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 0828035c5d217..07ef54bb120b9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,12 +29,13 @@ import org.apache.spark.util.Utils private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + override def value = value_ - def broadcastId = BroadcastBlockId(id) + val broadcastId = BroadcastBlockId(id) TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } @transient var arrayOfBlocks: Array[TorrentBlock] = null @@ -47,8 +48,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } def sendBroadcast() { - var tInfo = TorrentBroadcast.blockifyObject(value_) - + val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes hasBlocks = tInfo.totalBlocks @@ -58,7 +58,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) } // Store individual pieces @@ -66,11 +66,19 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) } } } + /** + * Remove all persisted state associated with this HTTP broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + override def unpersist(removeFromDriver: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver) + } + // Called by JVM when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() @@ -86,18 +94,18 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Initialize @transient variables that will receive garbage values from the master. resetWorkerVariables() - if (receiveBroadcast(id)) { + if (receiveBroadcast()) { value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - // Store the merged copy in cache so that the next worker doesn't need to rebuild it. - // This creates a tradeoff between memory usage and latency. - // Storing copy doubles the memory footprint; not storing doubles deserialization cost. + /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. + * This creates a trade-off between memory usage and latency. Storing copy doubles + * the memory footprint; not storing doubles deserialization cost. */ SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() - } else { + } else { logError("Reading broadcast variable " + id + " failed") } @@ -114,7 +122,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo hasBlocks = 0 } - def receiveBroadcast(variableID: Long): Boolean = { + def receiveBroadcast(): Boolean = { // Receive meta-info val metaId = BroadcastHelperBlockId(broadcastId, "meta") var attemptId = 10 @@ -148,7 +156,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] hasBlocks += 1 SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) case None => throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) @@ -156,15 +164,17 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } } - (hasBlocks == totalBlocks) + hasBlocks == totalBlocks } } -private object TorrentBroadcast extends Logging { - +private[spark] object TorrentBroadcast extends Logging { private var initialized = false private var conf: SparkConf = null + + lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests synchronized { @@ -178,39 +188,37 @@ private object TorrentBroadcast extends Logging { initialized = false } - lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) val bais = new ByteArrayInputStream(byteArray) - var blockNum = (byteArray.length / BLOCK_SIZE) + var blockNum = byteArray.length / BLOCK_SIZE if (byteArray.length % BLOCK_SIZE != 0) { blockNum += 1 } - var retVal = new Array[TorrentBlock](blockNum) - var blockID = 0 + val blocks = new Array[TorrentBlock](blockNum) + var blockId = 0 for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + val tempByteArray = new Array[Byte](thisBlockSize) + bais.read(tempByteArray, 0, thisBlockSize) - retVal(blockID) = new TorrentBlock(blockID, tempByteArray) - blockID += 1 + blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blockId += 1 } bais.close() - val tInfo = TorrentInfo(retVal, blockNum, byteArray.length) - tInfo.hasBlocks = blockNum - - tInfo + val info = TorrentInfo(blocks, blockNum, byteArray.length) + info.hasBlocks = blockNum + info } - def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { + def unBlockifyObject[T]( + arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { val retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, @@ -219,6 +227,14 @@ private object TorrentBroadcast extends Logging { Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) } + /** + * Remove all persisted blocks associated with this torrent broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver. + */ + def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { + //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + } + } private[spark] case class TorrentBlock( @@ -227,7 +243,7 @@ private[spark] case class TorrentBlock( extends Serializable private[spark] case class TorrentInfo( - @transient arrayOfBlocks : Array[TorrentBlock], + @transient arrayOfBlocks: Array[TorrentBlock], totalBlocks: Int, totalBytes: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index a51c438c57717..eabe792b550bb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -33,4 +33,11 @@ class TorrentBroadcastFactory extends BroadcastFactory { def stop() { TorrentBroadcast.stop() } + /** + * Remove all persisted state associated with the torrent broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. + */ + def unbroadcast(id: Long, removeFromDriver: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 50ea4e31ce509..4c5b31d0abe44 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -35,7 +35,7 @@ private[storage] object BlockManagerMessages { case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave // Remove all blocks belonging to a specific shuffle. - case class RemoveShuffle(shuffleId: Int) + case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave ////////////////////////////////////////////////////////////////////////////////// diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index a6ff147c1d3e6..9a12481b7f6d5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -29,8 +29,9 @@ import org.apache.spark.storage.BlockManagerMessages._ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, - mapOutputTracker: MapOutputTracker - ) extends Actor { + mapOutputTracker: MapOutputTracker) + extends Actor { + override def receive = { case RemoveBlock(blockId) => diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index b07f8817b7974..11e22145ebb88 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark +import java.lang.ref.WeakReference + import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} +import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually @@ -26,9 +29,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} -import org.apache.spark.rdd.{ShuffleCoGroupSplitDep, RDD} -import scala.util.Random -import java.lang.ref.WeakReference +import org.apache.spark.rdd.RDD class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -67,7 +68,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo test("automatically cleanup RDD") { var rdd = newRDD.persist() rdd.count() - + // test that GC does not cause RDD cleanup due to a strong reference val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) runGC() @@ -171,11 +172,16 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo /** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */ -class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[Int] = Nil) +class CleanerTester( + sc: SparkContext, + rddIds: Seq[Int] = Seq.empty, + shuffleIds: Seq[Int] = Seq.empty, + broadcastIds: Seq[Long] = Seq.empty) extends Logging { val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds + val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { @@ -187,6 +193,11 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In toBeCleanedShuffleIds -= shuffleId logInfo("Shuffle " + shuffleId + " cleaned") } + + def broadcastCleaned(broadcastId: Long): Unit = { + toBeCleanedBroadcstIds -= broadcastId + logInfo("Broadcast" + broadcastId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10