Skip to content

Commit

Permalink
Some review comments
Browse files Browse the repository at this point in the history
- Always notifyAll if a new thread was added in tryToAcquire
- Log when a thread blocks
  • Loading branch information
mateiz committed Aug 4, 2014
1 parent b810120 commit 315e3a5
Showing 1 changed file with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.shuffle

import org.apache.spark.{SparkException, SparkConf}
import scala.collection.mutable

import org.apache.spark.{Logging, SparkException, SparkConf}

/**
* Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
* collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
Expand All @@ -34,7 +35,7 @@ import scala.collection.mutable
* this set changes. This is all done by synchronizing access on "this" to mutate state and using
* wait() and notifyAll() to signal changes.
*/
private[spark] class ShuffleMemoryManager(maxMemory: Long) {
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes

def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
Expand All @@ -50,7 +51,10 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) {

// Add this thread to the threadMemory map just so we can keep an accurate count of the number
// of active threads, to let other threads ramp down their memory in calls to tryToAcquire
threadMemory.getOrElseUpdate(threadId, 0L)
if (!threadMemory.contains(threadId)) {
threadMemory(threadId) = 0L
notifyAll() // Will later cause waiting threads to wake up and check numActiveThreads again
}

// Keep looping until we're either sure that we don't want to grant this request (because this
// thread would have more than 1 / numActiveThreads of the memory) or we have enough free
Expand All @@ -66,16 +70,11 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) {
if (bytesFree >= numBytes) {
// Grant the request
threadMemory(threadId) = curMem + numBytes
// Notify other waiting threads because the # active of threads may have increased, so
// they may cancel their current waits
notifyAll()
return true
} else if (curMem + numBytes <= maxMemory / (2 * numActiveThreads)) {
// This thread has so little memory that we want it to block and acquire a bigger
// amount instead of cancelling the request. Wait on "this" for a thread to call notify.
// Before doing the wait, however, also notify other current waiters in case our thread
// becoming active just pushed them over the limit to give up their own waits.
notifyAll()
logInfo(s"Thread ${threadId} blocking for shuffle memory pool to free up")
wait()
} else {
// Thread would have between 1 / (2 * numActiveThreads) and 1 / numActiveThreads memory
Expand Down Expand Up @@ -106,6 +105,11 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) {
}

private object ShuffleMemoryManager {
/**
* Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
* of the memory pool and a safety factor since collections can sometimes grow bigger than
* the size we target before we estimate their size again.
*/
def getMaxMemory(conf: SparkConf): Long = {
val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
Expand Down

0 comments on commit 315e3a5

Please sign in to comment.