SPARK-1582 Invoke Thread.interrupt() when cancelling jobs #498

Closed
wants to merge 4 commits
15 changes: 14 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -381,16 +381,27 @@ class SparkContext(config: SparkConf) extends Logging {
* // In a separate thread:
* sc.cancelJobGroup("some_job_to_cancel")
* }}}
*
* If interruptOnCancel is set to true for the job group, then job cancellation will result
* in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure
* that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208,
* where HDFS may respond to Thread.interrupt() by marking nodes as dead.
*/
def setJobGroup(groupId: String, description: String) {
def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) {
Contributor Author:
Note that this is not the ideal way to set this property, as this API is mainly intended for initializing the job group name. However, it avoids changing a number of internal and external APIs (there are 4 calls in this function itself that reach the cancellation API through different routes to DAGScheduler#failJobAndIndependentStages). Additionally, it has the unique benefit that if the job is cancelled by another source (e.g., Spark fails the job, or the user uses the recently added cancel-job feature in the JobProgressTab), we can still set the flag based on this property.

Contributor:
Thanks, Aaron. Can we put the above reasoning in inline comments? That helps when we revisit this in the future.

Contributor Author:
Added.

setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
// Note: Specifying interruptOnCancel in setJobGroup (rather than cancelJobGroup) avoids
// changing several public APIs and allows Spark cancellations outside of the cancelJobGroup
// APIs to also take advantage of this property (e.g., internal job failures or canceling from
// JobProgressTab UI) on a per-job basis.
setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString)
}

/** Clear the current thread's job group ID and its description. */
def clearJobGroup() {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null)
}

// Post init
@@ -1244,6 +1255,8 @@ object SparkContext extends Logging {

private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"

private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"

private[spark] val SPARK_UNKNOWN_USER = "<unknown>"

implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
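For context, a minimal usage sketch of the new parameter (the group id and the job body below are illustrative, and an existing SparkContext `sc` is assumed): setting interruptOnCancel = true when the group is created means a later cancelJobGroup() also interrupts the executor threads running that group's tasks.

```scala
// Minimal sketch, assuming an existing SparkContext `sc`; the group id and
// job body are only illustrative. Because interruptOnCancel = true was set
// for this group, cancelJobGroup() will also call Thread.interrupt() on the
// executor threads running the group's tasks.
sc.setJobGroup("some_job_to_cancel", "slow job", interruptOnCancel = true)
val result = sc.parallelize(1 to 100000).map { i => Thread.sleep(10); i }.count()

// In a separate thread:
sc.cancelJobGroup("some_job_to_cancel")
```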
@@ -162,7 +162,6 @@ private[spark] class PythonRDD[T: ClassTag](
val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator += Collections.singletonList(update)

}
Array.empty[Byte]
}
@@ -69,12 +69,12 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}

case KillTask(taskId, _) =>
case KillTask(taskId, _, interruptThread) =>
if (executor == null) {
logError("Received KillTask command but executor was null")
System.exit(1)
} else {
executor.killTask(taskId)
executor.killTask(taskId, interruptThread)
}

case x: DisassociatedEvent =>
10 changes: 5 additions & 5 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -136,10 +136,10 @@ private[spark] class Executor(
threadPool.execute(tr)
}

def killTask(taskId: Long) {
def killTask(taskId: Long, interruptThread: Boolean) {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill()
tr.kill(interruptThread)
}
}

@@ -166,11 +166,11 @@ private[spark] class Executor(
@volatile private var killed = false
@volatile private var task: Task[Any] = _

def kill() {
def kill(interruptThread: Boolean) {
logInfo("Executor is trying to kill task " + taskId)
killed = true
if (task != null) {
task.kill()
task.kill(interruptThread)
}
}

@@ -257,7 +257,7 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

case TaskKilledException => {
case TaskKilledException | _: InterruptedException if task.killed => {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
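A standalone sketch (not Spark code; the names are illustrative) of why the interrupt matters and why the `if task.killed` guard above is needed: a thread blocked in a call such as Thread.sleep() or interruptible I/O only notices a cooperative kill flag once the call returns, whereas Thread.interrupt() wakes it immediately; the guard ensures an InterruptedException that arrives without the kill flag is still treated as a task failure rather than a kill.

```scala
// Standalone illustration of the flag + interrupt combination; the names
// here (InterruptSketch, killed, worker) are illustrative, not Spark's.
object InterruptSketch {
  @volatile private var killed = false

  def main(args: Array[String]): Unit = {
    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        try {
          while (!killed) {
            Thread.sleep(60000) // stands in for a blocking call, e.g. an HDFS read
          }
        } catch {
          // Same shape as the case clause above: the interrupt only counts
          // as a kill when the kill flag has actually been set.
          case _: InterruptedException if killed =>
            println("worker stopped promptly via interrupt")
        }
      }
    })
    worker.start()
    killed = true      // cooperative flag, mirrors TaskRunner.killed
    worker.interrupt() // mirrors kill(interruptThread = true)
    worker.join()
  }
}
```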
@@ -76,7 +76,8 @@ private[spark] class MesosExecutorBackend
if (executor == null) {
logError("Received KillTask but executor was null")
} else {
executor.killTask(t.getValue.toLong)
// TODO: Determine the 'interruptOnCancel' property set for the given job.
executor.killTask(t.getValue.toLong, interruptThread = false)
}
}

@@ -1055,6 +1055,10 @@ class DAGScheduler(
val error = new SparkException(failureReason)
job.listener.jobFailed(error)

val shouldInterruptThread =
if (job.properties == null) false
else job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean

// Cancel all independent, running stages.
val stages = jobIdToStageIds(job.jobId)
if (stages.isEmpty) {
@@ -1073,7 +1077,7 @@
// This is the only job that uses this stage, so fail the stage if it is running.
val stage = stageIdToStage(stageId)
if (runningStages.contains(stage)) {
taskScheduler.cancelTasks(stageId)
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
val stageInfo = stageToInfos(stage)
stageInfo.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
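For reference, a tiny standalone sketch of the defaulting behavior used above (plain java.util.Properties, suitable for the Scala REPL; the helper name is made up): a null Properties object and an unset key both yield false, so thread interruption remains strictly opt-in.

```scala
import java.util.Properties

// Mirrors the shouldInterruptThread logic above: null properties or an
// unset key default to "false", keeping Thread.interrupt() opt-in.
def interruptOnCancel(props: Properties): Boolean =
  if (props == null) false
  else props.getProperty("spark.job.interruptOnCancel", "false").toBoolean

val opted = new Properties()
opted.setProperty("spark.job.interruptOnCancel", "true")

assert(!interruptOnCancel(null))
assert(!interruptOnCancel(new Properties()))
assert(interruptOnCancel(opted))
```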
@@ -28,5 +28,6 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int

def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
throw new UnsupportedOperationException
}
12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,8 +44,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex

final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
taskThread = Thread.currentThread()
if (_killed) {
kill()
kill(interruptThread = false)
}
runTask(context)
}
@@ -62,6 +63,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
// Task context, to be initialized in run().
@transient protected var context: TaskContext = _

// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _

// A flag to indicate whether the task is killed. This is used in case context is not yet
// initialized when kill() is invoked.
@volatile @transient private var _killed = false
@@ -75,12 +79,16 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
* Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
* code and user code to properly handle the flag. This function should be idempotent so it can
* be called multiple times.
* If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
*/
def kill() {
def kill(interruptThread: Boolean) {
_killed = true
if (context != null) {
context.interrupted = true
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
}
}

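Outside of Spark, the same capture-then-interrupt pattern looks like the following standalone sketch (class and member names are illustrative): run() records its own thread, and kill() is safe to call before, during, or after run(), any number of times.

```scala
// Standalone sketch of the pattern above, not Spark code: the volatile
// fields make the kill flag and the captured thread visible across threads.
class CancellableWork {
  @volatile private var killed = false
  @volatile private var workThread: Thread = _

  def run(): Unit = {
    workThread = Thread.currentThread()
    if (killed) return // a kill may have arrived before run() started
    while (!killed) {
      // do one bounded chunk of work, then re-check the flag
    }
  }

  // Idempotent: setting the flag twice or interrupting twice is harmless.
  def kill(interruptThread: Boolean): Unit = {
    killed = true
    if (interruptThread && workThread != null) {
      workThread.interrupt()
    }
  }
}
```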
@@ -47,7 +47,7 @@ private[spark] trait TaskScheduler {
def submitTasks(taskSet: TaskSet): Unit

// Cancel a stage.
def cancelTasks(stageId: Int)
def cancelTasks(stageId: Int, interruptThread: Boolean)

// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
@@ -170,7 +170,7 @@ private[spark] class TaskSchedulerImpl(
backend.reviveOffers()
}

override def cancelTasks(stageId: Int): Unit = synchronized {
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
@@ -181,7 +181,7 @@
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId)
backend.killTask(tid, execId, interruptThread)
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,8 +31,8 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt

def kill() {
tasks.foreach(_.kill())
def kill(interruptThread: Boolean) {
tasks.foreach(_.kill(interruptThread))
}

override def toString: String = "TaskSet " + id
@@ -30,7 +30,8 @@ private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage

case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage
case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
extends CoarseGrainedClusterMessage

case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
extends CoarseGrainedClusterMessage
@@ -101,8 +101,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
case ReviveOffers =>
makeOffers()

case KillTask(taskId, executorId) =>
executorActor(executorId) ! KillTask(taskId, executorId)
case KillTask(taskId, executorId, interruptThread) =>
executorActor(executorId) ! KillTask(taskId, executorId, interruptThread)

case StopDriver =>
sender ! true
@@ -207,8 +207,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
driverActor ! ReviveOffers
}

override def killTask(taskId: Long, executorId: String) {
driverActor ! KillTask(taskId, executorId)
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
driverActor ! KillTask(taskId, executorId, interruptThread)
}

override def defaultParallelism(): Int = {
@@ -30,7 +30,7 @@ private case class ReviveOffers()

private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)

private case class KillTask(taskId: Long)
private case class KillTask(taskId: Long, interruptThread: Boolean)

/**
* Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
@@ -61,8 +61,8 @@ private[spark] class LocalActor(
reviveOffers()
}

case KillTask(taskId) =>
executor.killTask(taskId)
case KillTask(taskId, interruptThread) =>
executor.killTask(taskId, interruptThread)
}

def reviveOffers() {
@@ -99,8 +99,8 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:

override def defaultParallelism() = totalCores

override def killTask(taskId: Long, executorId: String) {
localActor ! KillTask(taskId)
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
localActor ! KillTask(taskId, interruptThread)
}

override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
@@ -58,7 +58,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
taskSets += taskSet
}
override def cancelTasks(stageId: Int) {
override def cancelTasks(stageId: Int, interruptThread: Boolean) {
cancelledStages += stageId
}
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}