diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 519ecde50a163..129a426f283d2 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -46,7 +46,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { if (loading.contains(key)) { logInfo("Another thread is loading %s, waiting for it to finish...".format(key)) while (loading.contains(key)) { - try {loading.wait()} catch {case _ : Throwable =>} + try { + loading.wait() + } catch { + case e: Exception => + logWarning("Got an exception while waiting for another thread to load " + key, e) + } } logInfo("Finished waiting for %s".format(key)) // See whether someone else has successfully loaded it. The main way this would fail @@ -74,7 +79,10 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { val elements = new ArrayBuffer[Any] elements ++= computedValues blockManager.put(key, elements, storageLevel, tellMaster = true) - return elements.iterator.asInstanceOf[Iterator[T]] + val returnValue: Iterator[T] = elements.iterator.asInstanceOf[Iterator[T]] + + new InterruptibleIterator(context, returnValue) + } finally { loading.synchronized { loading.remove(key) diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index 56e0b8d2c0b9b..c92d2faae51f6 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -24,7 +24,17 @@ package org.apache.spark class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T]) extends Iterator[T] { - def hasNext: Boolean = !context.interrupted && delegate.hasNext + def hasNext: Boolean = { + // TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt + // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read + // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which + // introduces an expensive read fence. + if (context.interrupted) { + throw new TaskKilledException + } else { + delegate.hasNext + } + } def next(): T = delegate.next() } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index eb5bb17539fb6..e86e7923f8622 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -355,16 +355,27 @@ class SparkContext( * // In a separate thread: * sc.cancelJobGroup("some_job_to_cancel") * }}} + * + * If interruptOnCancel is set to true for the job group, then job cancellation will result + * in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure + * that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208, + * where HDFS may respond to Thread.interrupt() by marking nodes as dead. */ - def setJobGroup(groupId: String, description: String) { + def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description) setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + // Note: Specifying interruptOnCancel in setJobGroup (rather than cancelJobGroup) avoids + // changing several public APIs and allows Spark cancellations outside of the cancelJobGroup + // APIs to also take advantage of this property (e.g., internal job failures or canceling from + // JobProgressTab UI) on a per-job basis. + setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString) } /** Clear the job group id and its description. */ def clearJobGroup() { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null) setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null) + setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } // Post init @@ -1022,6 +1033,8 @@ object SparkContext { private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" + private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" + private[spark] val SPARK_UNKNOWN_USER = "" implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala new file mode 100644 index 0000000000000..cbd6b2866e4f9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Exception for a task getting killed. + */ +private[spark] class TaskKilledException extends RuntimeException diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8332631838d0a..3033151b19cba 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -69,12 +69,12 @@ private[spark] class CoarseGrainedExecutorBackend( executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) } - case KillTask(taskId, _) => + case KillTask(taskId, _, interruptThread) => if (executor == null) { logError("Received KillTask command but executor was null") System.exit(1) } else { - executor.killTask(taskId) + executor.killTask(taskId, interruptThread) } case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e3a8d4a224839..893e05837040c 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -138,10 +138,10 @@ private[spark] class Executor( threadPool.execute(tr) } - def killTask(taskId: Long) { + def killTask(taskId: Long, interruptThread: Boolean) { val tr = runningTasks.get(taskId) if (tr != null) { - tr.kill() + tr.kill(interruptThread) } } @@ -163,16 +163,14 @@ private[spark] class Executor( class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) extends Runnable { - object TaskKilledException extends Exception - @volatile private var killed = false @volatile private var task: Task[Any] = _ - def kill() { + def kill(interruptThread: Boolean) { logInfo("Executor is trying to kill task " + taskId) killed = true if (task != null) { - task.kill() + task.kill(interruptThread) } } @@ -202,7 +200,7 @@ private[spark] class Executor( // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl // exception will be caught by the catch block, leading to an incorrect ExceptionFailure // for the task. - throw TaskKilledException + throw new TaskKilledException } attemptedTask = Some(task) @@ -216,7 +214,7 @@ private[spark] class Executor( // If the task has been killed, let's fail it. if (task.killed) { - throw TaskKilledException + throw new TaskKilledException } for (m <- task.metrics) { @@ -254,7 +252,7 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) } - case TaskKilledException => { + case _: TaskKilledException | _: InterruptedException if task.killed => { logInfo("Executor killed task " + taskId) execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) } diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index b56d8c99124df..384494d1afb1b 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -78,7 +78,8 @@ private[spark] class MesosExecutorBackend if (executor == null) { logError("Received KillTask but executor was null") } else { - executor.killTask(t.getValue.toLong) + // TODO: Determine the 'interruptOnCancel' property set for the given job. + executor.killTask(t.getValue.toLong, interruptThread = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f34e98f86b86b..159ee7e7feb37 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -962,10 +962,13 @@ class DAGScheduler( if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { + val job = idToActiveJob(jobId) val independentStages = removeJobAndIndependentStages(jobId) - independentStages.foreach { taskSched.cancelTasks } + val shouldInterruptThread = + if (job.properties == null) false + else job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean + independentStages.foreach { taskSched.cancelTasks(_, shouldInterruptThread) } val error = new SparkException("Job %d cancelled".format(jobId)) - val job = idToActiveJob(jobId) job.listener.jobFailed(error) jobIdToStageIds -= jobId activeJobs -= job diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 69b42e86eae3e..b7b4cb6b27720 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -47,8 +47,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + taskThread = Thread.currentThread() if (_killed) { - kill() + kill(interruptThread = false) } runTask(context) } @@ -65,6 +66,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex // Task context, to be initialized in run(). @transient protected var context: TaskContext = _ + // The actual Thread on which the task is running, if any. Initialized in run(). + @volatile @transient private var taskThread: Thread = _ + // A flag to indicate whether the task is killed. This is used in case context is not yet // initialized when kill() is invoked. @volatile @transient private var _killed = false @@ -78,12 +82,16 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. This function should be idempotent so it can * be called multiple times. + * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread. */ - def kill() { + def kill(interruptThread: Boolean) { _killed = true if (context != null) { context.interrupted = true } + if (interruptThread && taskThread != null) { + taskThread.interrupt() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 10e047810827c..edcfe05576c05 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -45,7 +45,7 @@ private[spark] trait TaskScheduler { def submitTasks(taskSet: TaskSet): Unit // Cancel a stage. - def cancelTasks(stageId: Int) + def cancelTasks(stageId: Int, interruptThread: Boolean) // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index 03bf76083761f..613fa7850bb25 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -31,8 +31,8 @@ private[spark] class TaskSet( val properties: Properties) { val id: String = stageId + "." + attempt - def kill() { - tasks.foreach(_.kill()) + def kill(interruptThread: Boolean) { + tasks.foreach(_.kill(interruptThread)) } override def toString: String = "TaskSet " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index e42146d021eae..059ab51f40b5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -165,7 +165,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } - override def cancelTasks(stageId: Int): Unit = synchronized { + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => // There are two possible cases here: @@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (taskIds.size > 0) { taskIds.foreach { tid => val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId) + backend.killTask(tid, execId, interruptThread) } } logInfo("Stage %d was cancelled".format(stageId)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 53316dae2a6c8..b8aaa097e9b99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -31,7 +31,8 @@ private[spark] object CoarseGrainedClusterMessages { // Driver to executors case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage - case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage + case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) + extends CoarseGrainedClusterMessage case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index d0ba5bf55dcfd..0811fff8ff64a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -100,8 +100,8 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac case ReviveOffers => makeOffers() - case KillTask(taskId, executorId) => - executorActor(executorId) ! KillTask(taskId, executorId) + case KillTask(taskId, executorId, interruptThread) => + executorActor(executorId) ! KillTask(taskId, executorId, interruptThread) case StopDriver => sender ! true @@ -215,8 +215,8 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac driverActor ! ReviveOffers } - override def killTask(taskId: Long, executorId: String) { - driverActor ! KillTask(taskId, executorId) + override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { + driverActor ! KillTask(taskId, executorId, interruptThread) } override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala index 5367218faa685..23a6d938c8e9f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -30,7 +30,8 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int - def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException + def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = + throw new UnsupportedOperationException // Memory used by each executor (in megabytes) protected val executorMemory: Int = SparkContext.executorMemoryRequested diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 1c227fefe48d3..2dfb57e96f470 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -44,7 +44,7 @@ private[local] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) private[local] -case class KillTask(taskId: Long) +case class KillTask(taskId: Long, interruptThread: Boolean) private[spark] class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) @@ -62,8 +62,8 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) launchTask(localScheduler.resourceOffer(freeCores)) } - case KillTask(taskId) => - executor.killTask(taskId) + case KillTask(taskId, interruptThread) => + executor.killTask(taskId, interruptThread) } private def launchTask(tasks: Seq[TaskDescription]) { @@ -128,7 +128,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } - override def cancelTasks(stageId: Int): Unit = synchronized { + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId)) activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => @@ -141,7 +141,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val taskIds = taskSetTaskIds(tsm.taskSet.id) if (taskIds.size > 0) { taskIds.foreach { tid => - localActor ! KillTask(tid) + localActor ! KillTask(tid, interruptThread) } } logInfo("Stage %d was cancelled".format(stageId)) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 1121e06e2e6cc..f498d0e54723a 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.util.concurrent.Semaphore import scala.concurrent.Await -import scala.concurrent.duration.Duration +import scala.concurrent.duration._ import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global @@ -85,6 +85,35 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf assert(sc.parallelize(1 to 10, 2).count === 10) } + test("do not put partially executed partitions into cache") { + // In this test case, we create a scenario in which a partition is only partially executed, + // and make sure CacheManager does not put that partially executed partition into the + // BlockManager. + import JobCancellationSuite._ + sc = new SparkContext("local", "test") + + // Run from 1 to 10, and then block and wait for the task to be killed. + val rdd = sc.parallelize(1 to 1000, 2).map { x => + if (x > 10) { + taskStartedSemaphore.release() + taskCancelledSemaphore.acquire() + } + x + }.cache() + + val rdd1 = rdd.map(x => x) + + future { + taskStartedSemaphore.acquire() + sc.cancelAllJobs() + taskCancelledSemaphore.release(100000) + } + + intercept[SparkException] { rdd1.count() } + // If the partial block is put into cache, rdd.count() would return a number less than 1000. + assert(rdd.count() === 1000) + } + test("job group") { sc = new SparkContext("local[2]", "test") @@ -102,27 +131,57 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() } + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + sc.clearJobGroup() val jobB = sc.parallelize(1 to 100, 2).countAsync() + sc.cancelJobGroup("jobA") + val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) } + assert(e.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + assert(jobB.get() === 100) + } + + test("job group with interruption") { + sc = new SparkContext("local[2]", "test") + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + // jobA is the one to be cancelled. + val jobA = future { + sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) + sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100000); i }.count() + } // Block until both tasks of job A have started and cancel job A. sem.acquire(2) + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() sc.cancelJobGroup("jobA") - val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) } + val e = intercept[SparkException] { Await.result(jobA, 5.seconds) } assert(e.getMessage contains "cancel") // Once A is cancelled, job B should finish fairly quickly. assert(jobB.get() === 100) } -/* - test("two jobs sharing the same stage") { + + ignore("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched // sem2: make sure the first stage is not finished until cancel is issued val sem1 = new Semaphore(0) val sem2 = new Semaphore(0) sc = new SparkContext("local[2]", "test") - sc.dagScheduler.addSparkListener(new SparkListener { + sc.addSparkListener(new SparkListener { override def onTaskStart(taskStart: SparkListenerTaskStart) { sem1.release() } @@ -148,7 +207,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf intercept[SparkException] { f1.get() } intercept[SparkException] { f2.get() } } - */ + def testCount() { // Cancel before launching any tasks { @@ -207,3 +266,9 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf } } } + + +object JobCancellationSuite { + val taskStartedSemaphore = new Semaphore(0) + val taskCancelledSemaphore = new Semaphore(0) +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 03fc8c020e005..6f1925ef092e0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import scala.collection.mutable.{Map, HashMap} +import scala.collection.mutable.{Map, HashMap, HashSet} import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter @@ -49,6 +49,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont /** Set of TaskSets the DAGScheduler has requested executed. */ val taskSets = scala.collection.mutable.Buffer[TaskSet]() + + /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */ + val cancelledStages = new HashSet[Int]() + val taskScheduler = new TaskScheduler() { override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE @@ -59,7 +63,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) taskSets += taskSet } - override def cancelTasks(stageId: Int) {} + override def cancelTasks(stageId: Int, interruptThread: Boolean) { + cancelledStages += stageId + } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 } @@ -97,6 +103,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont before { sc = new SparkContext("local", "DAGSchedulerSuite") taskSets.clear() + cancelledStages.clear() cacheLocations.clear() results.clear() mapOutputTracker = new MapOutputTracker()