Skip to content

Commit

Permalink
Merge pull request alteryx#50 from kayousterhout/SPARK-908
Browse files Browse the repository at this point in the history
Fix race condition in SparkListenerSuite (fixes SPARK-908).
  • Loading branch information
rxin committed Oct 9, 2013
2 parents 7827efc + 36966f6 commit 215238c
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 17 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class SparkContext(
}
taskScheduler.start()

@volatile private var dagScheduler = new DAGScheduler(taskScheduler)
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()

ui.start()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class DAGScheduler(

private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]

private val listenerBus = new SparkListenerBus()
private[spark] val listenerBus = new SparkListenerBus()

// Contains the locations that each RDD's partitions are cached on
private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,23 @@ private[spark] class SparkListenerBus() extends Logging {
queueFullErrorMessageLogged = true
}
}

/**
* Waits until there are no more events in the queue, or until the specified time has elapsed.
* Used for testing only. Returns true if the queue has emptied and false is the specified time
* elapsed before the queue emptied.
*/
def waitUntilEmpty(timeoutMillis: Int): Boolean = {
val finishTime = System.currentTimeMillis + timeoutMillis
while (!eventQueue.isEmpty()) {
if (System.currentTimeMillis > finishTime) {
return false
}
/* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
* add overhead in the general case. */
Thread.sleep(10)
}
return true
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,9 @@ import scala.collection.mutable
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.SparkContext._

/**
*
*/

class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {

// TODO: This test has a race condition since the DAGScheduler now reports results
// asynchronously. It needs to be updated for that patch.
ignore("local metrics") {
test("local metrics") {
sc = new SparkContext("local[4]", "test")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
Expand All @@ -45,7 +39,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc

val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
d.count
Thread.sleep(1000)
val WAIT_TIMEOUT_MILLIS = 10000
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)

val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1")
Expand All @@ -57,18 +52,25 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc

d4.collectAsMap

Thread.sleep(1000)
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (4)
listener.stageInfos.foreach {stageInfo =>
//small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms
listener.stageInfos.foreach { stageInfo =>
/* small test, so some tasks might take less than 1 millisecond, but average should be greater
* than 0 ms. */
checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration")
checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime")
checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime")
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.executorRunTime.toLong},
stageInfo + " executorRunTime")
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong},
stageInfo + " executorDeserializeTime")
if (stageInfo.stage.rdd.name == d4.name) {
checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime")
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime},
stageInfo + " fetchWaitTime")
}

stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) =>
stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
taskMetrics.resultSize should be > (0l)
if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
taskMetrics.shuffleWriteMetrics should be ('defined)
Expand Down

0 comments on commit 215238c

Please sign in to comment.