diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 188dded7c02f7..b3ca150195a5f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -31,12 +31,16 @@ object FakeTask { * locations for each task (given as varargs) if this sequence is not empty. */ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createTaskSet(numTasks, 0, prefLocs: _*) + } + + def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { if (prefLocs.size != 0 && prefLocs.size != numTasks) { throw new IllegalArgumentException("Wrong number of task locations") } val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, 0, 0, 0, null) + new TaskSet(tasks, 0, stageAttemptId, 0, null) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 55be409afcf31..199d51275c51d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. - val dagScheduler = new DAGScheduler(sc, taskScheduler) { + new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } @@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. - val dagScheduler = new DAGScheduler(sc, taskScheduler) { + new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } @@ -138,18 +138,103 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L override def executorAdded(execId: String, host: String) {} } taskScheduler.setDAGScheduler(dagScheduler) - val attempt1 = new TaskSet(Array(new FakeTask(0)), 0, 0, 0, null) - val attempt2 = new TaskSet(Array(new FakeTask(0)), 0, 1, 0, null) + val attempt1 = FakeTask.createTaskSet(1, 0) + val attempt2 = FakeTask.createTaskSet(1, 1) taskScheduler.submitTasks(attempt1) intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) } // OK to submit multiple if previous attempts are all zombie - taskScheduler.activeTaskSets(attempt1.id).isZombie = true + taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId) + .get.isZombie = true taskScheduler.submitTasks(attempt2) - val attempt3 = new TaskSet(Array(new FakeTask(0)), 0, 2, 0, null) + val attempt3 = FakeTask.createTaskSet(1, 2) intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) } - taskScheduler.activeTaskSets(attempt2.id).isZombie = true + taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId) + .get.isZombie = true taskScheduler.submitTasks(attempt3) } + test("don't schedule more tasks after a taskset is zombie") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val numFreeCores = 1 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val attempt1 = FakeTask.createTaskSet(10) + + // submit attempt 1, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions.length) + + // now mark attempt 1 as a zombie + taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId) + .get.isZombie = true + + // don't schedule anything on another resource offer + val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions2.length) + + // if we schedule another attempt for the same stage, it should get scheduled + val attempt2 = FakeTask.createTaskSet(10, 1) + + // submit attempt 2, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt2) + val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions3.length) + val mgr = taskScheduler.taskSetManagerForTask(taskDescriptions3(0).taskId).get + assert(mgr.taskSet.stageAttemptId === 1) + } + + test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val numFreeCores = 10 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val attempt1 = FakeTask.createTaskSet(10) + + // submit attempt 1, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(10 === taskDescriptions.length) + + // now mark attempt 1 as a zombie + val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get + mgr1.isZombie = true + + // don't schedule anything on another resource offer + val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions2.length) + + //submit attempt 2 + val attempt2 = FakeTask.createTaskSet(10, 1) + taskScheduler.submitTasks(attempt2) + + // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were + // already submitted, and then they finish) + taskScheduler.taskSetFinished(mgr1) + + // now with another resource offer, we should still schedule all the tasks in attempt2 + val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten + assert(10 === taskDescriptions3.length) + + taskDescriptions3.foreach{ task => + val mgr = taskScheduler.taskSetManagerForTask(task.taskId).get + assert(mgr.taskSet.stageAttemptId === 1) + } + } + }