diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 461af1cd11965..8f76b91753157 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -21,8 +21,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 
 import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
 
-import org.apache.spark.rdd.RDD
-
 /** Listener class used for testing when any item has been cleaned by the Cleaner class */
 private[spark] trait CleanerListener {
   def rddCleaned(rddId: Int)
@@ -32,12 +30,12 @@ private[spark] trait CleanerListener {
 /**
  * Cleans RDDs and shuffle data.
  */
-private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
+private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   /** Classes to represent cleaning tasks */
   private sealed trait CleaningTask
-  private case class CleanRDD(sc: SparkContext, id: Int) extends CleaningTask
-  private case class CleanShuffle(id: Int) extends CleaningTask
+  private case class CleanRDD(rddId: Int) extends CleaningTask
+  private case class CleanShuffle(shuffleId: Int) extends CleaningTask
   // TODO: add CleanBroadcast
 
   private val queue = new LinkedBlockingQueue[CleaningTask]
@@ -47,7 +45,7 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
 
   private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
 
-  private var stopped = false
+  @volatile private var stopped = false
 
   /** Start the cleaner */
   def start() {
@@ -57,26 +55,37 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
   /** Stop the cleaner */
   def stop() {
-    synchronized { stopped = true }
+    stopped = true
     cleaningThread.interrupt()
   }
 
-  /** Clean (unpersist) RDD data. */
-  def cleanRDD(rdd: RDD[_]) {
-    enqueue(CleanRDD(rdd.sparkContext, rdd.id))
-    logDebug("Enqueued RDD " + rdd + " for cleaning up")
+  /**
+   * Clean (unpersist) RDD data. Do not perform any time or resource intensive
+   * computation in this function as this is called from a finalize() function.
+   */
+  def cleanRDD(rddId: Int) {
+    enqueue(CleanRDD(rddId))
+    logDebug("Enqueued RDD " + rddId + " for cleaning up")
   }
 
-  /** Clean shuffle data. */
+  /**
+   * Clean shuffle data. Do not perform any time or resource intensive
+   * computation in this function as this is called from a finalize() function.
+   */
   def cleanShuffle(shuffleId: Int) {
     enqueue(CleanShuffle(shuffleId))
     logDebug("Enqueued shuffle " + shuffleId + " for cleaning up")
   }
 
+  /** Attach a listener object to get information of when objects are cleaned. */
   def attachListener(listener: CleanerListener) {
     listeners += listener
   }
 
-  /** Enqueue a cleaning task */
+
+  /**
+   * Enqueue a cleaning task. Do not perform any time or resource intensive
+   * computation in this function as this is called from a finalize() function.
+   */
   private def enqueue(task: CleaningTask) {
     queue.put(task)
   }
@@ -86,16 +95,16 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
     try {
       while (!isStopped) {
         val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
-        if (taskOpt.isDefined) {
+        taskOpt.foreach(task => {
           logDebug("Got cleaning task " + taskOpt.get)
-          taskOpt.get match {
-            case CleanRDD(sc, rddId) => doCleanRDD(sc, rddId)
+          task match {
+            case CleanRDD(rddId) => doCleanRDD(sc, rddId)
             case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId)
           }
-        }
+        })
       }
     } catch {
-      case ie: java.lang.InterruptedException =>
+      case ie: InterruptedException =>
         if (!isStopped) logWarning("Cleaning thread interrupted")
     }
   }
@@ -103,7 +112,7 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
   /** Perform RDD cleaning */
   private def doCleanRDD(sc: SparkContext, rddId: Int) {
     logDebug("Cleaning rdd " + rddId)
-    sc.env.blockManager.master.removeRdd(rddId, false)
+    blockManagerMaster.removeRdd(rddId, false)
     sc.persistentRdds.remove(rddId)
     listeners.foreach(_.rddCleaned(rddId))
     logInfo("Cleaned rdd " + rddId)
@@ -113,14 +122,14 @@ private[spark] class ContextCleaner(env: SparkEnv) extends Logging {
   private def doCleanShuffle(shuffleId: Int) {
     logDebug("Cleaning shuffle " + shuffleId)
     mapOutputTrackerMaster.unregisterShuffle(shuffleId)
-    blockManager.master.removeShuffle(shuffleId)
+    blockManagerMaster.removeShuffle(shuffleId)
     listeners.foreach(_.shuffleCleaned(shuffleId))
     logInfo("Cleaned shuffle " + shuffleId)
   }
 
-  private def mapOutputTrackerMaster = env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+  private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 
-  private def blockManager = env.blockManager
+  private def blockManagerMaster = sc.env.blockManager.master
 
-  private def isStopped = synchronized { stopped }
+  private def isStopped = stopped
 }
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4d0f3dd6cdb71..27f94ce0e42d0 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -17,22 +17,18 @@
 
 package org.apache.spark
 
-import scala.Some
-import scala.collection.mutable.{HashSet, Map}
-import scala.concurrent.Await
-
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.HashSet
+import scala.Some
+import scala.collection.mutable.{HashSet, Map}
 import scala.concurrent.Await
 
 import akka.actor._
 import akka.pattern.ask
-
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{AkkaUtils, TimeStampedHashMap, BoundedHashMap}
+import org.apache.spark.util._
 
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int)
@@ -55,7 +51,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
 }
 
 /**
- * Class that keeps track of the location of the location of the mapt output of
+ * Class that keeps track of the location of the map output of
  * a stage. This is abstract because different versions of MapOutputTracker
 * (driver and worker) use different HashMap to store its metadata.
 */
@@ -155,10 +151,6 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }
 
-  protected def cleanup(cleanupTime: Long) {
-    mapStatuses.asInstanceOf[TimeStampedHashMap[_, _]].clearOldValues(cleanupTime)
-  }
-
   def stop() {
     communicate(StopMapOutputTracker)
     mapStatuses.clear()
@@ -195,10 +187,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
   /**
    * Bounded HashMap for storing serialized statuses in the worker. This allows
    * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be
-   * automatically repopulated by fetching them again from the driver.
+   * automatically repopulated by fetching them again from the driver. It's okay to
+   * keep the cache size small as it is unlikely that there will be a very large number
+   * of stages active simultaneously in the worker.
    */
-  protected val MAX_MAP_STATUSES = 100
-  protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true)
+  protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](
+    conf.getInt("spark.mapOutputTracker.cacheSize", 100), true
+  )
 }
 
 /**
@@ -212,20 +207,18 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   private var cacheEpoch = epoch
 
   /**
-   * Timestamp based HashMap for storing mapStatuses in the master, so that statuses are dropped
-   * only by explicit deregistering or by ttl-based cleaning (if set). Other than these two
+   * Timestamp based HashMap for storing mapStatuses and cached serialized statuses
+   * in the master, so that statuses are dropped only by explicit deregistering or
+   * by TTL-based cleaning (if set). Other than these two
    * scenarios, nothing should be dropped from this HashMap.
    */
+  protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
+  private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()
 
-  /**
-   * Bounded HashMap for storing serialized statuses in the master. This allows
-   * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be
-   * automatically repopulated by serializing the lost statuses again .
-   */
-  protected val MAX_SERIALIZED_STATUSES = 100
-  private val cachedSerializedStatuses =
-    new BoundedHashMap[Int, Array[Byte]](MAX_SERIALIZED_STATUSES, true)
+  // For cleaning up TimeStampedHashMaps
+  private val metadataCleaner =
+    new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
     if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
@@ -264,6 +257,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
 
   def unregisterShuffle(shuffleId: Int) {
     mapStatuses.remove(shuffleId)
+    cachedSerializedStatuses.remove(shuffleId)
   }
 
   def incrementEpoch() {
@@ -303,11 +297,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   }
 
   def contains(shuffleId: Int): Boolean = {
-    mapStatuses.contains(shuffleId)
+    cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
   }
 
   override def stop() {
     super.stop()
+    metadataCleaner.cancel()
     cachedSerializedStatuses.clear()
   }
 
@@ -315,8 +310,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
     // This might be called on the MapOutputTrackerMaster if we're running in local mode.
   }
 
-  def has(shuffleId: Int): Boolean = {
-    cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
+  protected def cleanup(cleanupTime: Long) {
+    mapStatuses.clearOldValues(cleanupTime)
+    cachedSerializedStatuses.clearOldValues(cleanupTime)
   }
 }
 
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 74d10196980cf..b80c58489cb52 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -206,7 +206,7 @@ class SparkContext(
   @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
   dagScheduler.start()
 
-  private[spark] val cleaner = new ContextCleaner(env)
+  private[spark] val cleaner = new ContextCleaner(this)
   cleaner.start()
 
   ui.start()
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e1367131cf569..f2e20a108630a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1027,7 +1027,7 @@ abstract class RDD[T: ClassTag](
 
   def cleanup() {
     logInfo("Cleanup called on RDD " + id)
-    sc.cleaner.cleanRDD(this)
+    sc.cleaner.cleanRDD(id)
     dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]])
       .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId)
       .foreach(sc.cleaner.cleanShuffle)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 38628e949a4a6..1a5cd82571a08 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -266,7 +266,7 @@ class DAGScheduler(
     : Stage =
   {
     val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
-    if (mapOutputTracker.has(shuffleDep.shuffleId)) {
+    if (mapOutputTracker.contains(shuffleDep.shuffleId)) {
       val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
       val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
       for (i <- 0 until locs.size) {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index ed03f189fb4ac..cf83a60ffb9e8 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -169,8 +169,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
     throw new IllegalStateException("Failed to find shuffle block: " + id)
   }
 
-  /** Remove all the blocks / files related to a particular shuffle */
+  /** Remove all the blocks / files and metadata related to a particular shuffle */
   def removeShuffle(shuffleId: ShuffleId) {
+    removeShuffleBlocks(shuffleId)
+    shuffleStates.remove(shuffleId)
+  }
+
+  /** Remove all the blocks / files related to a particular shuffle */
+  private def removeShuffleBlocks(shuffleId: ShuffleId) {
     shuffleStates.get(shuffleId) match {
       case Some(state) =>
         if (consolidateShuffleFiles) {
@@ -194,7 +200,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   }
 
   private def cleanup(cleanupTime: Long) {
-    shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffle(shuffleId))
+    shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
   }
 }
 
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 2553db4ad589e..2ef853710a554 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -62,8 +62,8 @@ private[spark] class MetadataCleaner(
 
 private[spark] object MetadataCleanerType extends Enumeration {
 
-  val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, BLOCK_MANAGER,
-  SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS, CLEANER = Value
+  val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER,
+  SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
 
   type MetadataCleanerType = Value
 
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 2ec314aa632f3..cb827b9e955a9 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -25,7 +25,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     val rdd = newRDD.persist()
     rdd.count()
     val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
-    cleaner.cleanRDD(rdd)
+    cleaner.cleanRDD(rdd.id)
     tester.assertCleanup
   }
 
diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala
index 7ad65c9681812..f0a84064ab9fb 100644
--- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala
@@ -206,4 +206,4 @@ class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] {
   protected[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = {
     new TestMap[K1, V1]
   }
-}
\ No newline at end of file
+}