diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 566472e597958..9f017574aefdb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -290,16 +290,27 @@ class SparkContext( * // In a separate thread: * sc.cancelJobGroup("some_job_to_cancel") * }}} + * + * If interruptOnCancel is set to true for the job group, then job cancellation will result + * in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure + * that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208, + * where HDFS may respond to Thread.interrupt() by marking nodes as dead. */ - def setJobGroup(groupId: String, description: String) { + def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description) setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + // Note: Specifying interruptOnCancel in setJobGroup (rather than cancelJobGroup) avoids + // changing several public APIs and allows Spark cancellations outside of the cancelJobGroup + // APIs to also take advantage of this property (e.g., internal job failures or canceling from + // JobProgressTab UI) on a per-job basis. + setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString) } /** Clear the current thread's job group ID and its description. */ def clearJobGroup() { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null) setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null) + setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } // Post init @@ -1029,6 +1040,8 @@ object SparkContext { private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" + private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" + private[spark] val SPARK_UNKNOWN_USER = "" implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8ef527b2acebe..da7850f766fab 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -69,12 +69,12 @@ private[spark] class CoarseGrainedExecutorBackend( executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) } - case KillTask(taskId, _) => + case KillTask(taskId, _, interruptThread) => if (executor == null) { logError("Received KillTask command but executor was null") System.exit(1) } else { - executor.killTask(taskId) + executor.killTask(taskId, interruptThread) } case x: DisassociatedEvent => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 74bd2cfef2ce0..07411bfa1c172 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -133,10 +133,10 @@ private[spark] class Executor( threadPool.execute(tr) } - def killTask(taskId: Long) { + def killTask(taskId: Long, interruptThread: Boolean) { val tr = runningTasks.get(taskId) if (tr != null) { - tr.kill() + tr.kill(interruptThread) } } @@ -161,11 +161,11 @@ private[spark] class Executor( @volatile private var killed = false @volatile private var task: Task[Any] = _ - def kill() { + def kill(interruptThread: Boolean) { logInfo("Executor is trying to kill task " + taskId) killed = true if (task != null) { - task.kill() + task.kill(interruptThread) } } diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index d5ef00cb7531d..9e8c7b6179024 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -78,7 +78,8 @@ private[spark] class MesosExecutorBackend if (executor == null) { logError("Received KillTask but executor was null") } else { - executor.killTask(t.getValue.toLong) + // TODO: Determine the 'interruptOnCancel' property set for the given job. + executor.killTask(t.getValue.toLong, interruptThread = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 536d84f07e5ec..ce6eb756ece46 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -993,7 +993,10 @@ class DAGScheduler( logDebug("Trying to cancel unregistered job " + jobId) } else { val independentStages = removeJobAndIndependentStages(jobId) - independentStages.foreach { taskSched.cancelTasks } + val shouldInterruptThread = + if (job.properties == null) false + else job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean + independentStages.foreach { taskSched.cancelTasks(_, shouldInterruptThread) } val error = new SparkException("Job %d cancelled".format(jobId)) val job = idToActiveJob(jobId) job.listener.jobFailed(error) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 69b42e86eae3e..b7b4cb6b27720 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -47,8 +47,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + taskThread = Thread.currentThread() if (_killed) { - kill() + kill(interruptThread = false) } runTask(context) } @@ -65,6 +66,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex // Task context, to be initialized in run(). @transient protected var context: TaskContext = _ + // The actual Thread on which the task is running, if any. Initialized in run(). + @volatile @transient private var taskThread: Thread = _ + // A flag to indicate whether the task is killed. This is used in case context is not yet // initialized when kill() is invoked. @volatile @transient private var _killed = false @@ -78,12 +82,16 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. This function should be idempotent so it can * be called multiple times. + * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread. */ - def kill() { + def kill(interruptThread: Boolean) { _killed = true if (context != null) { context.interrupted = true } + if (interruptThread && taskThread != null) { + taskThread.interrupt() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 17b6d97e90e0a..5f215dad70d8f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -46,7 +46,7 @@ private[spark] trait TaskScheduler { def submitTasks(taskSet: TaskSet): Unit // Cancel a stage. - def cancelTasks(stageId: Int) + def cancelTasks(stageId: Int, interruptThread: Boolean) // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index 03bf76083761f..613fa7850bb25 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -31,8 +31,8 @@ private[spark] class TaskSet( val properties: Properties) { val id: String = stageId + "." + attempt - def kill() { - tasks.foreach(_.kill()) + def kill(interruptThread: Boolean) { + tasks.foreach(_.kill(interruptThread)) } override def toString: String = "TaskSet " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 53316dae2a6c8..b8aaa097e9b99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -31,7 +31,8 @@ private[spark] object CoarseGrainedClusterMessages { // Driver to executors case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage - case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage + case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) + extends CoarseGrainedClusterMessage case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 0208388e86680..2223236f4c3a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -101,8 +101,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A case ReviveOffers => makeOffers() - case KillTask(taskId, executorId) => - executorActor(executorId) ! KillTask(taskId, executorId) + case KillTask(taskId, executorId, interruptThread) => + executorActor(executorId) ! KillTask(taskId, executorId, interruptThread) case StopDriver => sender ! true @@ -204,8 +204,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A driverActor ! ReviveOffers } - override def killTask(taskId: Long, executorId: String) { - driverActor ! KillTask(taskId, executorId) + override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { + driverActor ! KillTask(taskId, executorId, interruptThread) } override def defaultParallelism(): Int = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index f0236ef1e975b..f9399139e82b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -52,7 +52,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) taskSets += taskSet } - override def cancelTasks(stageId: Int) {} + override def cancelTasks(stageId: Int, interruptThread: Boolean) { + cancelledStages += stageId + } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 }