From 315e3a5caa90af0ae5dc153c195f65dc1fd67232 Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Fri, 1 Aug 2014 19:50:54 -0700
Subject: [PATCH] Some review comments

- Always notifyAll if a new thread was added in tryToAcquire
- Log when a thread blocks
---
 .../spark/shuffle/ShuffleMemoryManager.scala | 22 +++++++++++--------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index 82c335a59972c..d9e53af5e147e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.shuffle
 
-import org.apache.spark.{SparkException, SparkConf}
 import scala.collection.mutable
 
+import org.apache.spark.{Logging, SparkException, SparkConf}
+
 /**
  * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
  * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
@@ -34,7 +35,7 @@ import scala.collection.mutable
  * this set changes. This is all done by synchronizing access on "this" to mutate state and using
  * wait() and notifyAll() to signal changes.
  */
-private[spark] class ShuffleMemoryManager(maxMemory: Long) {
+private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
   private val threadMemory = new mutable.HashMap[Long, Long]()  // threadId -> memory bytes
 
   def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
@@ -50,7 +51,10 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) {
 
     // Add this thread to the threadMemory map just so we can keep an accurate count of the number
     // of active threads, to let other threads ramp down their memory in calls to tryToAcquire
-    threadMemory.getOrElseUpdate(threadId, 0L)
+    if (!threadMemory.contains(threadId)) {
+      threadMemory(threadId) = 0L
+      notifyAll()  // Will later cause waiting threads to wake up and check numActiveThreads again
+    }
 
     // Keep looping until we're either sure that we don't want to grant this request (because this
     // thread would have more than 1 / numActiveThreads of the memory) or we have enough free
@@ -66,16 +70,11 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) {
       if (bytesFree >= numBytes) {
         // Grant the request
         threadMemory(threadId) = curMem + numBytes
-        // Notify other waiting threads because the # active of threads may have increased, so
-        // they may cancel their current waits
-        notifyAll()
         return true
       } else if (curMem + numBytes <= maxMemory / (2 * numActiveThreads)) {
         // This thread has so little memory that we want it to block and acquire a bigger
         // amount instead of cancelling the request. Wait on "this" for a thread to call notify.
-        // Before doing the wait, however, also notify other current waiters in case our thread
-        // becoming active just pushed them over the limit to give up their own waits.
-        notifyAll()
+        logInfo(s"Thread ${threadId} blocking for shuffle memory pool to free up")
         wait()
       } else {
         // Thread would have between 1 / (2 * numActiveThreads) and 1 / numActiveThreads memory
@@ -106,6 +105,11 @@
 }
 
 private object ShuffleMemoryManager {
+  /**
+   * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
+   * of the memory pool and a safety factor since collections can sometimes grow bigger than
+   * the size we target before we estimate their size again.
+   */
   def getMaxMemory(conf: SparkConf): Long = {
     val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
     val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
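
For illustration only (not part of the patch, and not Spark's API): a minimal, self-contained sketch of the policy the class comment and this change describe. Every thread registers itself before requesting memory and calls notifyAll() so existing waiters re-check the active-thread count; each of the N active threads may hold at most 1/N of the pool; a request that would still leave the thread under 1/(2N) blocks until memory is released; anything in between is refused so the caller can spill. The names SimpleShufflePool, acquire and release below are made up for this sketch.

import scala.collection.mutable

// Simplified stand-in for the manager above; the shape of the policy matches the
// class comment, but the class and method names are assumptions.
class SimpleShufflePool(maxMemory: Long) {
  private val threadMemory = new mutable.HashMap[Long, Long]()  // threadId -> bytes held

  def acquire(numBytes: Long): Boolean = synchronized {
    val threadId = Thread.currentThread().getId
    if (!threadMemory.contains(threadId)) {
      threadMemory(threadId) = 0L
      notifyAll()  // a new active thread lowers everyone's 1/N cap, so waiters must re-check
    }
    while (true) {
      val numActiveThreads = threadMemory.size
      val curMem = threadMemory(threadId)
      val bytesFree = maxMemory - threadMemory.values.sum
      if (curMem + numBytes <= maxMemory / numActiveThreads && numBytes <= bytesFree) {
        threadMemory(threadId) = curMem + numBytes
        return true   // grant: stays within this thread's 1/N share and fits in free space
      } else if (curMem + numBytes <= maxMemory / (2 * numActiveThreads)) {
        wait()        // still under 1/(2N): block until release() or a new registration notifies
      } else {
        return false  // has between 1/(2N) and 1/N already: caller should spill instead
      }
    }
    false  // unreachable; keeps the result type Boolean
  }

  def release(numBytes: Long): Unit = synchronized {
    val threadId = Thread.currentThread().getId
    threadMemory(threadId) = math.max(0L, threadMemory.getOrElse(threadId, 0L) - numBytes)
    notifyAll()  // freed memory may unblock waiting threads
  }
}

With the defaults visible in getMaxMemory above, the pool itself is capped at spark.shuffle.memoryFraction * spark.shuffle.safetyFraction = 0.2 * 0.8 = 16% of whatever base memory figure the rest of that method derives.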