diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8df37c247d0d4..23b06612fd7ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -25,6 +25,7 @@ import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet +import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -207,9 +208,11 @@ private[spark] class TaskSchedulerImpl( } } - // Build a list of tasks to assign to each worker - val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray + // Randomly shuffle offers to avoid always placing tasks on the same set of workers. + val shuffledOffers = Random.shuffle(offers) + // Build a list of tasks to assign to each worker. + val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue() for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( @@ -222,9 +225,9 @@ private[spark] class TaskSchedulerImpl( for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { do { launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).host + for (i <- 0 until shuffledOffers.size) { + val execId = shuffledOffers(i).executorId + val host = shuffledOffers(i).host for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { tasks(i) += task val tid = task.taskId diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 0b90c4e74c8a4..0a7cb69416a08 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -24,3 +24,19 @@ class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int override def preferredLocations: Seq[TaskLocation] = prefLocs } + +object FakeTask { + /** + * Utility method to create a TaskSet, potentially setting a particular sequence of preferred + * locations for each task (given as varargs) if this sequence is not empty. + */ + def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) + } + new TaskSet(tasks, 0, 0, 0, null) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index f4e62c64daf12..6b0800af9c6d0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -25,6 +25,13 @@ import org.scalatest.FunSuite import org.apache.spark._ +class FakeSchedulerBackend extends SchedulerBackend { + def start() {} + def stop() {} + def reviveOffers() {} + def defaultParallelism() = 1 +} + class FakeTaskSetManager( initPriority: Int, initStageId: Int, @@ -107,7 +114,8 @@ class FakeTaskSetManager( class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = { + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, + taskSet: TaskSet): FakeTaskSetManager = { new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) } @@ -135,10 +143,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin test("FIFO Scheduler Test") { sc = new SparkContext("local", "TaskSchedulerImplSuite") val taskScheduler = new TaskSchedulerImpl(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + val taskSet = FakeTask.createTaskSet(1) val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) @@ -162,10 +167,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin test("Fair Scheduler Test") { sc = new SparkContext("local", "TaskSchedulerImplSuite") val taskScheduler = new TaskSchedulerImpl(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + val taskSet = FakeTask.createTaskSet(1) val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() System.setProperty("spark.scheduler.allocation.file", xmlPath) @@ -219,10 +221,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin test("Nested Pool Test") { sc = new SparkContext("local", "TaskSchedulerImplSuite") val taskScheduler = new TaskSchedulerImpl(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + val taskSet = FakeTask.createTaskSet(1) val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) @@ -265,4 +264,35 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin checkTaskSetId(rootPool, 6) checkTaskSetId(rootPool, 2) } + + test("Scheduler does not always schedule tasks on the same workers") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + var dagScheduler = new DAGScheduler(taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorGained(execId: String, host: String) {} + } + + val numFreeCores = 1 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores), + new WorkerOffer("executor1", "host1", numFreeCores)) + // Repeatedly try to schedule a 1-task job, and make sure that it doesn't always + // get scheduled on the same executor. While there is a chance this test will fail + // because the task randomly gets placed on the first executor all 1000 times, the + // probability of that happening is 2^-1000 (so sufficiently small to be considered + // negligible). + val numTrials = 1000 + val selectedExecutorIds = 1.to(numTrials).map { _ => + val taskSet = FakeTask.createTaskSet(1) + taskScheduler.submitTasks(taskSet) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions.length) + taskDescriptions(0).executorId + } + var count = selectedExecutorIds.count(_ == workerOffers(0).executorId) + assert(count > 0) + assert(count < numTrials) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 20f6e503872ac..33cc7588b919c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -88,7 +88,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("TaskSet with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) + val taskSet = FakeTask.createTaskSet(1) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // Offer a host with no CPUs @@ -114,7 +114,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("multiple offers with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(3) + val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // First three offers should all find tasks @@ -145,7 +145,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("basic delay scheduling") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(4, + val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1", "exec1")), Seq(TaskLocation("host2", "exec2")), Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), @@ -190,7 +190,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) - val taskSet = createTaskSet(5, + val taskSet = FakeTask.createTaskSet(5, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), Seq(TaskLocation("host2")), @@ -229,7 +229,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("delay scheduling with failed hosts") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(3, + val taskSet = FakeTask.createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), Seq(TaskLocation("host3")) @@ -261,7 +261,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("task result lost") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) + val taskSet = FakeTask.createTaskSet(1) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) @@ -278,7 +278,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("repeated failures lead to task set abortion") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) + val taskSet = FakeTask.createTaskSet(1) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) @@ -298,21 +298,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { } } - - /** - * Utility method to create a TaskSet, potentially setting a particular sequence of preferred - * locations for each task (given as varargs) if this sequence is not empty. - */ - def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { - if (prefLocs.size != 0 && prefLocs.size != numTasks) { - throw new IllegalArgumentException("Wrong number of task locations") - } - val tasks = Array.tabulate[Task[_]](numTasks) { i => - new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) - } - new TaskSet(tasks, 0, 0, 0, null) - } - def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)