From 544ac866edf21230140fe56ee7a428fe0ab86329 Mon Sep 17 00:00:00 2001
From: Andrew Or <andrewor14@gmail.com>
Date: Wed, 26 Mar 2014 15:11:42 -0700
Subject: [PATCH] Clean up broadcast blocks through BlockManager*
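
HttpBroadcast.unpersist and TorrentBroadcast.unpersist previously had
their calls to BlockManagerMaster.removeBroadcast commented out because
the plumbing behind them did not exist yet. This patch adds that
plumbing and re-enables the calls:

- BlockManagerMaster gains removeBroadcast, which sends a new
  RemoveBroadcast message to the BlockManagerMasterActor.
- The master actor fans the message out to every registered
  BlockManagerSlaveActor, skipping the driver's block manager unless
  removeFromDriver is true.
- Each slave actor forwards the request to BlockManager.removeBroadcast,
  which removes all BroadcastBlockId and BroadcastHelperBlockId blocks
  with the given broadcast id.

Also includes minor cleanups in BlockManagerMessages, Utils.parseHostPort
and the ContextCleanerSuite imports.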

---
 .../apache/spark/broadcast/HttpBroadcast.scala   |  2 +-
 .../spark/broadcast/TorrentBroadcast.scala       |  2 +-
 .../org/apache/spark/storage/BlockManager.scala  | 14 +++++++++++++-
 .../spark/storage/BlockManagerMaster.scala       |  7 +++++++
 .../spark/storage/BlockManagerMasterActor.scala  | 16 +++++++++++++---
 .../spark/storage/BlockManagerMessages.scala     | 13 ++++++++++---
 .../spark/storage/BlockManagerSlaveActor.scala   |  3 +++
 .../main/scala/org/apache/spark/util/Utils.scala |  8 ++++----
 .../org/apache/spark/ContextCleanerSuite.scala   |  2 +-
 9 files changed, 53 insertions(+), 14 deletions(-)
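
Note for reviewers: a minimal sketch of the cleanup path this patch wires
up, assuming a driver-side caller. BroadcastCleanupSketch, cleanup and
bcId are hypothetical names; the methods and messages they exercise are
the ones added below.

    package org.apache.spark.broadcast

    import org.apache.spark.SparkEnv

    // Hypothetical call site (not part of this patch): unpersist the
    // broadcast with id `bcId` on all executors, keeping the driver's
    // own copies.
    private[spark] object BroadcastCleanupSketch {
      def cleanup(bcId: Long) {
        // Routes through BlockManagerMaster.removeBroadcast, which asks
        // the BlockManagerMasterActor to send RemoveBroadcast(bcId) to
        // each BlockManagerSlaveActor (the driver's block manager is
        // skipped when removeFromDriver is false); every slave then
        // calls BlockManager.removeBroadcast(bcId) to drop its
        // BroadcastBlockId and BroadcastHelperBlockId blocks.
        SparkEnv.get.blockManager.master.removeBroadcast(bcId, removeFromDriver = false)
      }
    }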

diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 89361efec44a4..4985d4202ed6b 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -186,7 +186,7 @@ private[spark] object HttpBroadcast extends Logging {
    * and delete the associated broadcast file.
    */
   def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
-    //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
+    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
     if (removeFromDriver) {
       val file = new File(broadcastDir, BroadcastBlockId(id).name)
       files.remove(file.toString)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 07ef54bb120b9..51f1592cef752 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -232,7 +232,7 @@ private[spark] object TorrentBroadcast extends Logging {
    * If removeFromDriver is true, also remove these persisted blocks on the driver.
    */
   def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
-    //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
+    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
   }
 
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index ca23513c4dc64..3c0941e195724 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -820,10 +820,22 @@ private[spark] class BlockManager(
     // from RDD.id to blocks.
     logInfo("Removing RDD " + rddId)
     val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
-    blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
+    blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
     blocksToRemove.size
   }
 
+  /**
+   * Remove all blocks belonging to the given broadcast.
+   */
+  def removeBroadcast(broadcastId: Long) {
+    logInfo("Removing broadcast " + broadcastId)
+    val blocksToRemove = blockInfo.keys.collect {
+      case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid
+      case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid
+    }
+    blocksToRemove.foreach { blockId => removeBlock(blockId) }
+  }
+
   /**
    * Remove a block from both memory and disk.
    */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index ff3f22b3b092a..4579c0d959553 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -126,6 +126,13 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     askDriverWithReply(RemoveShuffle(shuffleId))
   }
 
+  /**
+   * Remove all blocks belonging to the given broadcast (driver blocks only if removeFromDriver).
+   */
+  def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
+    askDriverWithReply(RemoveBroadcast(broadcastId, removeFromDriver))
+  }
+
   /**
    * Return the memory status for each block manager, in the form of a map from
    * the block manager's id to two long values. The first value is the maximum
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 646ccb7fa74f6..4cc4227fd87e2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -100,6 +100,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
       removeShuffle(shuffleId)
       sender ! true
 
+    case RemoveBroadcast(broadcastId, removeFromDriver) =>
+      removeBroadcast(broadcastId, removeFromDriver)
+      sender ! true
+
     case RemoveBlock(blockId) =>
       removeBlockFromWorkers(blockId)
       sender ! true
@@ -151,9 +155,15 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
   private def removeShuffle(shuffleId: Int) {
     // Nothing to do in the BlockManagerMasterActor data structures
     val removeMsg = RemoveShuffle(shuffleId)
-    blockManagerInfo.values.foreach { bm =>
-      bm.slaveActor ! removeMsg
-    }
+    blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg }
+  }
+
+  private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
+    // TODO(aor): Consolidate usages of <driver>
+    val removeMsg = RemoveBroadcast(broadcastId)
+    blockManagerInfo.values
+      .filter { info => removeFromDriver || info.blockManagerId.executorId != "<driver>" }
+      .foreach { bm => bm.slaveActor ! removeMsg }
   }
 
   private def removeBlockManager(blockManagerId: BlockManagerId) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 4c5b31d0abe44..3ea710ebc786e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -22,9 +22,11 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import akka.actor.ActorRef
 
 private[storage] object BlockManagerMessages {
+
   //////////////////////////////////////////////////////////////////////////////////
   // Messages from the master to slaves.
   //////////////////////////////////////////////////////////////////////////////////
+
   sealed trait ToBlockManagerSlave
 
   // Remove a block from the slaves that have it. This can only be used to remove
@@ -37,10 +39,15 @@ private[storage] object BlockManagerMessages {
   // Remove all blocks belonging to a specific shuffle.
   case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave
 
+  // Remove all blocks belonging to a specific broadcast.
+  case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true)
+    extends ToBlockManagerSlave
+
 
   //////////////////////////////////////////////////////////////////////////////////
   // Messages from slaves to the master.
   //////////////////////////////////////////////////////////////////////////////////
+
   sealed trait ToBlockManagerMaster
 
   case class RegisterBlockManager(
@@ -57,8 +64,7 @@ private[storage] object BlockManagerMessages {
       var storageLevel: StorageLevel,
       var memSize: Long,
       var diskSize: Long)
-    extends ToBlockManagerMaster
-    with Externalizable {
+    extends ToBlockManagerMaster with Externalizable {
 
     def this() = this(null, null, null, 0, 0)  // For deserialization only
 
@@ -80,7 +86,8 @@ private[storage] object BlockManagerMessages {
   }
 
   object UpdateBlockInfo {
-    def apply(blockManagerId: BlockManagerId,
+    def apply(
+        blockManagerId: BlockManagerId,
         blockId: BlockId,
         storageLevel: StorageLevel,
         memSize: Long,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index 9a12481b7f6d5..8c2ccbe6a7e66 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -46,5 +46,8 @@ class BlockManagerSlaveActor(
       if (mapOutputTracker != null) {
         mapOutputTracker.unregisterShuffle(shuffleId)
       }
+
+    case RemoveBroadcast(broadcastId, _) =>
+      blockManager.removeBroadcast(broadcastId)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index ad87fda140476..e541591ee7582 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -461,10 +461,10 @@ private[spark] object Utils extends Logging {
   private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
 
   def parseHostPort(hostPort: String): (String,  Int) = {
-    {
-      // Check cache first.
-      val cached = hostPortParseResults.get(hostPort)
-      if (cached != null) return cached
+    // Check cache first.
+    val cached = hostPortParseResults.get(hostPort)
+    if (cached != null) {
+      return cached
     }
 
     val indx: Int = hostPort.lastIndexOf(':')
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 11e22145ebb88..77d9825434706 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -28,8 +28,8 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkContext._
-import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId}
 
 class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {