diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index b03da328c76af..129a426f283d2 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -50,7 +50,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { loading.wait() } catch { case e: Exception => - logWarning(s"Got an exception while waiting for another thread to load $key", e) + logWarning("Got an exception while waiting for another thread to load " + key, e) } } logInfo("Finished waiting for %s".format(key)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c6b56a4ea5d3f..159ee7e7feb37 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -962,13 +962,13 @@ class DAGScheduler( if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { + val job = idToActiveJob(jobId) val independentStages = removeJobAndIndependentStages(jobId) val shouldInterruptThread = if (job.properties == null) false else job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean independentStages.foreach { taskSched.cancelTasks(_, shouldInterruptThread) } val error = new SparkException("Job %d cancelled".format(jobId)) - val job = idToActiveJob(jobId) job.listener.jobFailed(error) jobIdToStageIds -= jobId activeJobs -= job diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index e42146d021eae..059ab51f40b5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -165,7 +165,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } - override def cancelTasks(stageId: Int): Unit = synchronized { + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => // There are two possible cases here: @@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (taskIds.size > 0) { taskIds.foreach { tid => val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId) + backend.killTask(tid, execId, interruptThread) } } logInfo("Stage %d was cancelled".format(stageId)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala index 5367218faa685..23a6d938c8e9f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -30,7 +30,8 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int - def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException + def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = + throw new UnsupportedOperationException // Memory used by each executor (in megabytes) protected val executorMemory: Int = SparkContext.executorMemoryRequested diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 1c227fefe48d3..2dfb57e96f470 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -44,7 +44,7 @@ private[local] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) private[local] -case class KillTask(taskId: Long) +case class KillTask(taskId: Long, interruptThread: Boolean) private[spark] class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) @@ -62,8 +62,8 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) launchTask(localScheduler.resourceOffer(freeCores)) } - case KillTask(taskId) => - executor.killTask(taskId) + case KillTask(taskId, interruptThread) => + executor.killTask(taskId, interruptThread) } private def launchTask(tasks: Seq[TaskDescription]) { @@ -128,7 +128,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } - override def cancelTasks(stageId: Int): Unit = synchronized { + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId)) activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => @@ -141,7 +141,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val taskIds = taskSetTaskIds(tsm.taskSet.id) if (taskIds.size > 0) { taskIds.foreach { tid => - localActor ! KillTask(tid) + localActor ! KillTask(tid, interruptThread) } } logInfo("Stage %d was cancelled".format(stageId)) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 70d1acaba8f49..f498d0e54723a 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.util.concurrent.Semaphore import scala.concurrent.Await -import scala.concurrent.duration.Duration +import scala.concurrent.duration._ import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global @@ -131,6 +131,9 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() } + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + sc.clearJobGroup() val jobB = sc.parallelize(1 to 100, 2).countAsync() sc.cancelJobGroup("jobA") @@ -160,8 +163,11 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf // Block until both tasks of job A have started and cancel job A. sem.acquire(2) + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() sc.cancelJobGroup("jobA") - val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) } + val e = intercept[SparkException] { Await.result(jobA, 5.seconds) } assert(e.getMessage contains "cancel") // Once A is cancelled, job B should finish fairly quickly. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bab6e35c14478..6f1925ef092e0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import scala.collection.mutable.{Map, HashMap} +import scala.collection.mutable.{Map, HashMap, HashSet} import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter @@ -49,6 +49,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont /** Set of TaskSets the DAGScheduler has requested executed. */ val taskSets = scala.collection.mutable.Buffer[TaskSet]() + + /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */ + val cancelledStages = new HashSet[Int]() + val taskScheduler = new TaskScheduler() { override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE @@ -99,6 +103,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont before { sc = new SparkContext("local", "DAGSchedulerSuite") taskSets.clear() + cancelledStages.clear() cacheLocations.clear() results.clear() mapOutputTracker = new MapOutputTracker()