SPARK-1582 Invoke Thread.interrupt() when cancelling jobs
Sometimes executor threads are blocked waiting for IO or monitors,
and the current implementation of job killing may never recover
these threads. By simply invoking Thread.interrupt() during
cancellation, we can often safely unblock the threads and use them
for subsequent work.

Note that this feature must remain optional for now because of a
bug in HDFS where Thread.interrupt() may cause nodes to be marked
as permanently dead (as the InterruptedException is reinterpreted
as an IOException during communication with some node).
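
The new flag is driver-facing: a thread opts its jobs into interrupt-on-cancel via setJobGroup, and a later cancelJobGroup then interrupts the running task threads of that group. A minimal sketch of how a driver might use it; the app name, group id, and sleep-based job are illustrative, not part of this commit:

import org.apache.spark.{SparkConf, SparkContext}

object InterruptOnCancelExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("interrupt-on-cancel-demo").setMaster("local[2]"))

    // Submit a long-running job from one thread, tagged with a job group that opts in.
    val runner = new Thread(new Runnable {
      def run(): Unit = {
        // Job-group properties are thread-local, so set them in the thread that submits the job.
        sc.setJobGroup("demo-group", "job that may block on IO", interruptOnCancel = true)
        try {
          sc.parallelize(1 to 4, 4).foreach(_ => Thread.sleep(60 * 1000))
        } catch {
          case e: Exception => println("job ended: " + e.getMessage)
        }
      }
    })
    runner.start()

    Thread.sleep(2000)                // give the tasks time to start (and block)
    sc.cancelJobGroup("demo-group")   // with interruptOnCancel, the sleeping task threads are interrupted
    runner.join()
    sc.stop()
  }
}

Without interruptOnCancel = true, the sleeping tasks above would keep their executor threads until each Thread.sleep returned on its own.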
aarondav committed Apr 23, 2014
1 parent 26d35f3 commit 4cb9fd6
Showing 14 changed files with 51 additions and 29 deletions.
11 changes: 10 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -381,16 +381,23 @@ class SparkContext(config: SparkConf) extends Logging {
* // In a separate thread:
* sc.cancelJobGroup("some_job_to_cancel")
* }}}
*
* If interruptOnCancel is set to true for the job group, then job cancellation will result
* in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure
* that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208,
* where HDFS may respond to Thread.interrupt() by marking nodes as dead.
*/
def setJobGroup(groupId: String, description: String) {
def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString)
}

/** Clear the current thread's job group ID and its description. */
def clearJobGroup() {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null)
}

// Post init
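
A note on plumbing: the flag is stored as a string-valued local property (see the setLocalProperty call above) and parsed back in the DAGScheduler when a job is failed, so job groups that never set it default to false. A tiny stand-alone sketch of that round trip, using only the property name introduced by this commit; the rest is illustrative:

import java.util.Properties

object InterruptFlagRoundTrip {
  def main(args: Array[String]): Unit = {
    val props = new Properties()

    // Driver side: setJobGroup stores the boolean as a string.
    props.setProperty("spark.job.interruptOnCancel", true.toString)

    // Scheduler side: read it back with a "false" default, as the DAGScheduler change below does.
    val shouldInterruptThread =
      props.getProperty("spark.job.interruptOnCancel", "false").toBoolean
    println(shouldInterruptThread)   // prints: true
  }
}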
@@ -1244,6 +1251,8 @@ object SparkContext extends Logging {

private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"

private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"

private[spark] val SPARK_UNKNOWN_USER = "<unknown>"

implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -162,7 +162,6 @@ private[spark] class PythonRDD[T: ClassTag](
val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator += Collections.singletonList(update)

}
Array.empty[Byte]
}
core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -69,12 +69,12 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}

case KillTask(taskId, _) =>
case KillTask(taskId, _, interruptThread) =>
if (executor == null) {
logError("Received KillTask command but executor was null")
System.exit(1)
} else {
executor.killTask(taskId)
executor.killTask(taskId, interruptThread)
}

case x: DisassociatedEvent =>
10 changes: 5 additions & 5 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -136,10 +136,10 @@ private[spark] class Executor(
threadPool.execute(tr)
}

def killTask(taskId: Long) {
def killTask(taskId: Long, interruptThread: Boolean) {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill()
tr.kill(interruptThread)
}
}

@@ -166,11 +166,11 @@
@volatile private var killed = false
@volatile private var task: Task[Any] = _

def kill() {
def kill(interruptThread: Boolean) {
logInfo("Executor is trying to kill task " + taskId)
killed = true
if (task != null) {
task.kill()
task.kill(interruptThread)
}
}

@@ -257,7 +257,7 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

case TaskKilledException => {
case TaskKilledException | _: InterruptedException if task.killed => {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
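
The widened match arm above treats an InterruptedException as a kill only when the task has already been marked killed; an interrupt that arrives for any other reason still falls through to the normal failure handling. A stand-alone JVM sketch of that ordering, with no Spark types involved (names are illustrative):

object InterruptGuardDemo {
  @volatile private var killed = false     // plays the role of the task's killed flag

  def main(args: Array[String]): Unit = {
    val worker = new Thread(new Runnable {
      def run(): Unit =
        try {
          Thread.sleep(60 * 1000)          // stand-in for a blocking IO call or monitor wait
        } catch {
          // Mirror the executor's guard: an interrupt only counts as a kill if killed was set first.
          case _: InterruptedException if killed => println("task killed")
          case e: InterruptedException           => println("unexpected interrupt: " + e)
        }
    })

    worker.start()
    Thread.sleep(500)      // let the worker block
    killed = true          // kill() sets the flag first...
    worker.interrupt()     // ...and only then interrupts the task thread
    worker.join()
  }
}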
core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -76,7 +76,8 @@ private[spark] class MesosExecutorBackend
if (executor == null) {
logError("Received KillTask but executor was null")
} else {
executor.killTask(t.getValue.toLong)
// TODO: Determine the 'interruptOnCancel' property set for the given job.
executor.killTask(t.getValue.toLong, interruptThread = false)
}
}

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1055,6 +1055,9 @@ class DAGScheduler(
val error = new SparkException(failureReason)
job.listener.jobFailed(error)

val shouldInterruptThread =
job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean

// Cancel all independent, running stages.
val stages = jobIdToStageIds(job.jobId)
if (stages.isEmpty) {
@@ -1073,7 +1076,7 @@
// This is the only job that uses this stage, so fail the stage if it is running.
val stage = stageIdToStage(stageId)
if (runningStages.contains(stage)) {
taskScheduler.cancelTasks(stageId)
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
val stageInfo = stageToInfos(stage)
stageInfo.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -28,5 +28,6 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int

def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
throw new UnsupportedOperationException
}
12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,8 +44,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex

final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
taskThread = Thread.currentThread()
if (_killed) {
kill()
kill(interruptThread = false)
}
runTask(context)
}
@@ -62,6 +63,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
// Task context, to be initialized in run().
@transient protected var context: TaskContext = _

// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _

// A flag to indicate whether the task is killed. This is used in case context is not yet
// initialized when kill() is invoked.
@volatile @transient private var _killed = false
@@ -75,12 +79,16 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
* Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
* code and user code to properly handle the flag. This function should be idempotent so it can
* be called multiple times.
* If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
*/
def kill() {
def kill(interruptThread: Boolean) {
_killed = true
if (context != null) {
context.interrupted = true
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
}
}
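
As the comment above says, kill() normally relies on upper-level Spark code and user code polling the interrupted flag; interruptThread only helps threads that are blocked and never reach the next check. A simplified, self-contained sketch of that cooperative path (not Spark's actual InterruptibleIterator; the names are illustrative):

class CancellableIterator[T](underlying: Iterator[T]) extends Iterator[T] {
  @volatile var cancelled = false          // flipped from another thread, like context.interrupted

  def hasNext: Boolean = !cancelled && underlying.hasNext
  def next(): T = underlying.next()
}

object CancellableIteratorDemo {
  def main(args: Array[String]): Unit = {
    val it = new CancellableIterator(Iterator.from(1))
    new Thread(new Runnable {
      def run(): Unit = { Thread.sleep(100); it.cancelled = true }
    }).start()

    var last = 0
    while (it.hasNext) { last = it.next() }   // the record loop stops soon after cancellation
    println("stopped after element " + last)
  }
}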

core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -47,7 +47,7 @@ private[spark] trait TaskScheduler {
def submitTasks(taskSet: TaskSet): Unit

// Cancel a stage.
def cancelTasks(stageId: Int)
def cancelTasks(stageId: Int, interruptThread: Boolean)

// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -170,7 +170,7 @@ private[spark] class TaskSchedulerImpl(
backend.reviveOffers()
}

override def cancelTasks(stageId: Int): Unit = synchronized {
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
@@ -181,7 +181,7 @@
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId)
backend.killTask(tid, execId, interruptThread)
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,8 +31,8 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt

def kill() {
tasks.foreach(_.kill())
def kill(interruptThread: Boolean) {
tasks.foreach(_.kill(interruptThread))
}

override def toString: String = "TaskSet " + id
core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessages.scala
@@ -30,7 +30,8 @@ private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage

case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage
case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
extends CoarseGrainedClusterMessage

case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
extends CoarseGrainedClusterMessage
core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -101,8 +101,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
case ReviveOffers =>
makeOffers()

case KillTask(taskId, executorId) =>
executorActor(executorId) ! KillTask(taskId, executorId)
case KillTask(taskId, executorId, interruptThread) =>
executorActor(executorId) ! KillTask(taskId, executorId, interruptThread)

case StopDriver =>
sender ! true
@@ -207,8 +207,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
driverActor ! ReviveOffers
}

override def killTask(taskId: Long, executorId: String) {
driverActor ! KillTask(taskId, executorId)
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
driverActor ! KillTask(taskId, executorId, interruptThread)
}

override def defaultParallelism(): Int = {
core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -30,7 +30,7 @@ private case class ReviveOffers()

private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)

private case class KillTask(taskId: Long)
private case class KillTask(taskId: Long, interruptThread: Boolean)

/**
* Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
@@ -61,8 +61,8 @@ private[spark] class LocalActor(
reviveOffers()
}

case KillTask(taskId) =>
executor.killTask(taskId)
case KillTask(taskId, interruptThread) =>
executor.killTask(taskId, interruptThread)
}

def reviveOffers() {
@@ -99,8 +99,8 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:

override def defaultParallelism() = totalCores

override def killTask(taskId: Long, executorId: String) {
localActor ! KillTask(taskId)
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
localActor ! KillTask(taskId, interruptThread)
}

override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {