From a7260d346882bcdfe6e5014c52960017fb602300 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Mon, 17 Mar 2014 15:49:50 -0700
Subject: [PATCH] Added try-catch in context cleaner and null value cleaning in TimeStampedWeakValueHashMap.

---
 .../org/apache/spark/ContextCleaner.scala          | 50 +++++++++++--------
 .../org/apache/spark/MapOutputTracker.scala        |  1 -
 .../util/TimeStampedWeakValueHashMap.scala         | 47 ++++++++++++-----
 3 files changed, 64 insertions(+), 34 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 8f76b91753157..7636c6cf64972 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -50,6 +50,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   /** Start the cleaner */
   def start() {
     cleaningThread.setDaemon(true)
+    cleaningThread.setName("ContextCleaner")
     cleaningThread.start()
   }
 
@@ -60,7 +61,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }
 
   /**
-   * Clean (unpersist) RDD data. Do not perform any time or resource intensive
+   * Clean RDD data. Do not perform any time or resource intensive
    * computation in this function as this is called from a finalize() function.
    */
   def cleanRDD(rddId: Int) {
@@ -92,39 +93,48 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   /** Keep cleaning RDDs and shuffle data */
   private def keepCleaning() {
-    try {
-      while (!isStopped) {
+    while (!isStopped) {
+      try {
         val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
-        taskOpt.foreach(task => {
+        taskOpt.foreach { task =>
          logDebug("Got cleaning task " + taskOpt.get)
           task match {
-            case CleanRDD(rddId) => doCleanRDD(sc, rddId)
+            case CleanRDD(rddId) => doCleanRDD(rddId)
             case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId)
           }
-        })
+        }
+      } catch {
+        case ie: InterruptedException =>
+          if (!isStopped) logWarning("Cleaning thread interrupted")
+        case t: Throwable => logError("Error in cleaning thread", t)
       }
-    } catch {
-      case ie: InterruptedException =>
-        if (!isStopped) logWarning("Cleaning thread interrupted")
     }
   }
 
   /** Perform RDD cleaning */
-  private def doCleanRDD(sc: SparkContext, rddId: Int) {
-    logDebug("Cleaning rdd " + rddId)
-    blockManagerMaster.removeRdd(rddId, false)
-    sc.persistentRdds.remove(rddId)
-    listeners.foreach(_.rddCleaned(rddId))
-    logInfo("Cleaned rdd " + rddId)
+  private def doCleanRDD(rddId: Int) {
+    try {
+      logDebug("Cleaning RDD " + rddId)
+      blockManagerMaster.removeRdd(rddId, false)
+      sc.persistentRdds.remove(rddId)
+      listeners.foreach(_.rddCleaned(rddId))
+      logInfo("Cleaned RDD " + rddId)
+    } catch {
+      case t: Throwable => logError("Error cleaning RDD " + rddId, t)
+    }
   }
 
   /** Perform shuffle cleaning */
   private def doCleanShuffle(shuffleId: Int) {
-    logDebug("Cleaning shuffle " + shuffleId)
-    mapOutputTrackerMaster.unregisterShuffle(shuffleId)
-    blockManagerMaster.removeShuffle(shuffleId)
-    listeners.foreach(_.shuffleCleaned(shuffleId))
-    logInfo("Cleaned shuffle " + shuffleId)
+    try {
+      logDebug("Cleaning shuffle " + shuffleId)
+      mapOutputTrackerMaster.unregisterShuffle(shuffleId)
+      blockManagerMaster.removeShuffle(shuffleId)
+      listeners.foreach(_.shuffleCleaned(shuffleId))
+      logInfo("Cleaned shuffle " + shuffleId)
+    } catch {
+      case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t)
+    }
   }
 
   private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 27f94ce0e42d0..f37a9d41b2237 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,7 +20,6 @@ package org.apache.spark
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.Some
 import scala.collection.mutable.{HashSet, Map}
 import scala.concurrent.Await
 
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
index ea0fde87c56d0..bd86d78b8010f 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
@@ -24,6 +24,7 @@ import java.lang.ref.WeakReference
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.Logging
+import java.util.concurrent.atomic.AtomicInteger
 
 private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) {
   def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value))
@@ -44,6 +45,12 @@ private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: Wea
 private[spark] class TimeStampedWeakValueHashMap[A, B]()
   extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging {
 
+  /** Number of inserts after which keys whose weak ref values are null will be cleaned */
+  private val CLEANUP_INTERVAL = 1000
+
+  /** Counter for counting the number of inserts */
+  private val insertCounts = new AtomicInteger(0)
+
   protected[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = {
     new ConcurrentHashMap[A, TimeStampedWeakValue[B]]()
   }
@@ -52,11 +59,21 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
     new TimeStampedWeakValueHashMap[K1, V1]()
   }
 
+  override def +=(kv: (A, B)): this.type = {
+    // Cleanup null value at certain intervals
+    if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0) {
+      cleanNullValues()
+    }
+    super.+=(kv)
+  }
+
   override def get(key: A): Option[B] = {
     Option(internalJavaMap.get(key)) match {
       case Some(weakValue) =>
         val value = weakValue.weakValue.get
-        if (value == null) cleanupKey(key)
+        if (value == null) {
+          internalJavaMap.remove(key)
+        }
         Option(value)
       case None =>
         None
@@ -72,16 +89,10 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
   }
 
   override def iterator: Iterator[(A, B)] = {
-    val jIterator = internalJavaMap.entrySet().iterator()
-    JavaConversions.asScalaIterator(jIterator).flatMap(kv => {
-      val key = kv.getKey
-      val value = kv.getValue.weakValue.get
-      if (value == null) {
-        cleanupKey(key)
-        Seq.empty
-      } else {
-        Seq((key, value))
-      }
+    val iterator = internalJavaMap.entrySet().iterator()
+    JavaConversions.asScalaIterator(iterator).flatMap(kv => {
+      val (key, value) = (kv.getKey, kv.getValue.weakValue.get)
+      if (value != null) Seq((key, value)) else Seq.empty
     })
   }
 
@@ -104,8 +115,18 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
     }
   }
 
-  private def cleanupKey(key: A) {
-    // TODO: Consider cleaning up keys to empty weak ref values automatically in future.
+  /**
+   * Removes keys whose weak referenced values have become null.
+   */
+  private def cleanNullValues() {
+    val iterator = internalJavaMap.entrySet().iterator()
+    while (iterator.hasNext) {
+      val entry = iterator.next()
+      if (entry.getValue.weakValue.get == null) {
+        logDebug("Removing key " + entry.getKey)
+        iterator.remove()
+      }
+    }
   }
 
   private def currentTime = System.currentTimeMillis()
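
For context, the pattern the patch introduces in TimeStampedWeakValueHashMap boils down to: count inserts with an AtomicInteger and, every CLEANUP_INTERVAL inserts, sweep the backing map and drop entries whose weakly referenced values have already been garbage collected. The following is a minimal, standalone Scala sketch of that pattern; the class and member names (WeakValueMap, cleanupInterval, insertCount) are illustrative only and are not part of the patch or of Spark.

// Standalone illustration (not part of the patch above): interval-based cleanup
// of entries whose weak values have been garbage collected.
import java.lang.ref.WeakReference
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

class WeakValueMap[A, B <: AnyRef](cleanupInterval: Int = 1000) {
  private val map = new ConcurrentHashMap[A, WeakReference[B]]()
  private val insertCount = new AtomicInteger(0)

  def put(key: A, value: B): Unit = {
    // Every `cleanupInterval` inserts, sweep out entries whose values were GCed.
    if (insertCount.incrementAndGet() % cleanupInterval == 0) {
      cleanNullValues()
    }
    map.put(key, new WeakReference[B](value))
  }

  def get(key: A): Option[B] = {
    Option(map.get(key)) match {
      case Some(ref) =>
        val value = ref.get
        // Remove the key eagerly if its value has already been collected.
        if (value == null) map.remove(key)
        Option(value)
      case None => None
    }
  }

  private def cleanNullValues(): Unit = {
    val it = map.entrySet().iterator()
    while (it.hasNext) {
      if (it.next().getValue.get == null) it.remove()
    }
  }
}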