Skip to content

Commit

Permalink
Writing proper unit test for OutputCommitCoordinator and fixing bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
mccheah committed Jan 29, 2015
1 parent d63f63f commit 60a47f4
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 225 deletions.
3 changes: 1 addition & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,7 @@ object SparkEnv extends Logging {
val outputCommitCoordinator = new OutputCommitCoordinator(conf)
val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
new OutputCommitCoordinatorActor(outputCommitCoordinator))
outputCommitCoordinator.initialize(outputCommitCoordinatorActor, isDriver)

outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)
new SparkEnv(
executorId,
actorSystem,
Expand Down
9 changes: 3 additions & 6 deletions core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
val cmtr = getOutputCommitter()
if (cmtr.needsTaskCommit(taCtxt)) {
val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
val canCommit = outputCommitCoordinator.canCommit(jobID,
taID.value.getTaskID().getId(), splitID, attemptID)
val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID)
if (canCommit) {
try {
cmtr.commitTask(taCtxt)
Expand All @@ -124,13 +123,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
} else {
val msg: String = s"$taID: Not committed because the driver did not authorize commit"
logInfo(msg)
cmtr.abortTask(taCtxt)
throw new CommitDeniedException(msg, jobID, splitID, attemptID)
}
} else {
val msg: String = s"No need to commit output of task" +
" because needsTaskCommit=false: ${taID.value}"
logInfo(msg)
throw new CommitDeniedException(msg, jobID, splitID, attemptID)
logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ class DAGScheduler(
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
outputCommitCoordinator.stageStart(stage.id, partitionsToCompute)
outputCommitCoordinator.stageStart(stage.id)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
Expand Down Expand Up @@ -912,11 +912,9 @@ class DAGScheduler(
val task = event.task
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)
val isSuccess = event.reason == Success

outputCommitCoordinator.taskCompleted(stageId, event.taskInfo.taskId,
task.partitionId, event.taskInfo.attempt,
isSuccess)
outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
event.taskInfo.attempt, event.reason)

// The success case is dealt with separately below, since we need to compute accumulator
// updates before posting.
Expand All @@ -930,6 +928,7 @@ class DAGScheduler(
// Skip all the actions if the stage has been cancelled.
return
}

val stage = stageIdToStage(task.stageId)

def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,31 @@

package org.apache.spark.scheduler

import java.util.concurrent.{ExecutorService, TimeUnit, ConcurrentHashMap}

import scala.collection.{Map => ScalaImmutableMap}
import scala.collection.convert.decorateAsScala._
import scala.collection.mutable

import akka.actor.{ActorRef, Actor}

import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.util.{Utils, AkkaUtils, ActorLogReceive}
import org.apache.spark._
import org.apache.spark.util.{AkkaUtils, ActorLogReceive}

private[spark] sealed trait OutputCommitCoordinationMessage
private[spark] sealed trait OutputCommitCoordinationMessage extends Serializable

private[spark] case class StageStarted(stage: Int, partitionIds: Seq[Int])
extends OutputCommitCoordinationMessage
private[spark] case class StageStarted(stage: Int) extends OutputCommitCoordinationMessage
private[spark] case class StageEnded(stage: Int) extends OutputCommitCoordinationMessage
private[spark] case object StopCoordinator extends OutputCommitCoordinationMessage

private[spark] case class AskPermissionToCommitOutput(
stage: Int,
task: Long,
partId: Int,
taskAttempt: Long)
extends OutputCommitCoordinationMessage with Serializable
extends OutputCommitCoordinationMessage

private[spark] case class TaskCompleted(
stage: Int,
task: Long,
partId: Int,
attempt: Long,
successful: Boolean)
extends OutputCommitCoordinationMessage
reason: TaskEndReason)
extends OutputCommitCoordinationMessage

/**
* Authority that decides whether tasks can commit output to HDFS.
Expand All @@ -57,157 +51,100 @@ private[spark] case class TaskCompleted(
*/
private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {

private type StageId = Int
private type PartitionId = Int
private type TaskId = Long
private type TaskAttemptId = Long

// Wrapper for an int option that allows it to be locked via a synchronized block
// while still setting option itself to Some(...) or None.
private class LockableAttemptId(var value: Option[TaskAttemptId])

private type CommittersByStageHashMap =
ConcurrentHashMap[StageId, ScalaImmutableMap[PartitionId, LockableAttemptId]]

// Initialized by SparkEnv
private var coordinatorActor: Option[ActorRef] = None
var coordinatorActor: Option[ActorRef] = None
private val timeout = AkkaUtils.askTimeout(conf)
private val maxAttempts = AkkaUtils.numRetries(conf)
private val retryInterval = AkkaUtils.retryWaitMs(conf)
private val authorizedCommittersByStage = new CommittersByStageHashMap().asScala

private var executorRequestHandlingThreadPool: Option[ExecutorService] = None
private type StageId = Int
private type TaskId = Long
private type TaskAttemptId = Long
private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[TaskId, TaskAttemptId]]

def stageStart(stage: StageId, partitionIds: Seq[Int]): Unit = {
sendToActor(StageStarted(stage, partitionIds))
}
private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()

def stageEnd(stage: StageId): Unit = {
def stageStart(stage: StageId) {
sendToActor(StageStarted(stage))
}
def stageEnd(stage: StageId) {
sendToActor(StageEnded(stage))
}

def canCommit(
stage: StageId,
task: TaskId,
partId: PartitionId,
attempt: TaskAttemptId): Boolean = {
askActor(AskPermissionToCommitOutput(stage, task, partId, attempt))
askActor(AskPermissionToCommitOutput(stage, task, attempt))
}

def taskCompleted(
stage: StageId,
task: TaskId,
partId: PartitionId,
attempt: TaskAttemptId,
successful: Boolean): Unit = {
sendToActor(TaskCompleted(stage, task, partId, attempt, successful))
reason: TaskEndReason) {
sendToActor(TaskCompleted(stage, task, attempt, reason))
}

def stop(): Unit = {
executorRequestHandlingThreadPool.foreach { pool =>
pool.shutdownNow()
pool.awaitTermination(10, TimeUnit.SECONDS)
}
def stop() {
sendToActor(StopCoordinator)
coordinatorActor = None
executorRequestHandlingThreadPool = None
authorizedCommittersByStage.foreach(_._2.clear)
authorizedCommittersByStage.clear
}

def initialize(actor: ActorRef, isDriver: Boolean): Unit = {
coordinatorActor = Some(actor)
executorRequestHandlingThreadPool = {
if (isDriver) {
Some(Utils.newDaemonFixedThreadPool(8, "OutputCommitCoordinator"))
} else {
None
}
}
}

// Methods that mutate the internal state of the coordinator shouldn't be
// called directly, and are thus made private instead of public. The
// private methods should be called from the Actor, and callers use the
// public methods to send messages to the actor.
private def handleStageStart(stage: StageId, partitionIds: Seq[Int]): Unit = {
val initialLockStates = partitionIds.map(partId => {
partId -> new LockableAttemptId(None)
}).toMap
authorizedCommittersByStage.put(stage, initialLockStates)
private def handleStageStart(stage: StageId): Unit = {
authorizedCommittersByStage(stage) = mutable.HashMap[TaskId, TaskAttemptId]()
}

/**
 * Forgets all commit bookkeeping held for `stage`. Once the entry is gone, commit requests
 * for that stage no longer find a per-stage map in `authorizedCommittersByStage`.
 */
private def handleStageEnd(stage: StageId): Unit = {
  authorizedCommittersByStage -= stage
}

/**
 * Decides whether `attempt` may commit the output of `task` in `stage`.
 *
 * The first attempt that asks for a given task is recorded as that task's authorized
 * committer and is granted permission. While such a record exists, every further request
 * for the same task is denied. Requests for a stage that has no entry in
 * `authorizedCommittersByStage` are always denied.
 *
 * @param stage   stage the task belongs to
 * @param task    task whose output is to be committed
 * @param attempt attempt that is asking for permission
 * @return true if this call authorized `attempt` to commit, false otherwise
 */
private def handleAskPermissionToCommit(
    stage: StageId,
    task: TaskId,
    attempt: TaskAttemptId): Boolean = {
  authorizedCommittersByStage.get(stage) match {
    case Some(authorizedCommitters) =>
      // The inner map is keyed by task. Looking it up with `stage` still compiles because
      // the Int stage id widens to the Long key type, but it inspects the wrong entry and
      // lets several attempts of one task be authorized.
      authorizedCommitters.get(task) match {
        case Some(existingCommitter) =>
          logDebug(s"Denying $attempt to commit for stage=$stage, task=$task; " +
            s"existingCommitter = $existingCommitter")
          false
        case None =>
          logDebug(s"Authorizing $attempt to commit for stage=$stage, task=$task")
          authorizedCommitters(task) = attempt
          true
      }
    case None =>
      logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit")
      false
  }
}

private def handleAskPermissionToCommitOutput(
requester: ActorRef,
stage: StageId,
task: TaskId,
partId: PartitionId,
attempt: TaskAttemptId): Unit = {
executorRequestHandlingThreadPool match {
case Some(threadPool) =>
threadPool.submit(new AskCommitRunnable(requester, this, stage, task, partId, attempt))
case None =>
logWarning("Got a request to commit output, but the OutputCommitCoordinator was already" +
" shut down. Request is being denied.")
requester ! false
}

}

/**
 * Updates the commit bookkeeping after an attempt of `task` in `stage` has ended.
 *
 *  - `Success`: the record is left untouched.
 *  - `TaskCommitDenied`: the attempt was refused permission to commit; this is only logged.
 *  - any other reason: the attempt failed. If it is the attempt recorded as the authorized
 *    committer for the task, the record is removed so that a later attempt can be
 *    authorized. A failure of any other attempt leaves the record in place.
 *
 * Completions for a stage with no entry in `authorizedCommittersByStage` are ignored.
 *
 * @param stage   stage the task belongs to
 * @param task    task that ended
 * @param attempt attempt that ended
 * @param reason  why the attempt ended
 */
private def handleTaskCompletion(
    stage: StageId,
    task: TaskId,
    attempt: TaskAttemptId,
    reason: TaskEndReason): Unit = {
  authorizedCommittersByStage.get(stage) match {
    case Some(authorizedCommitters) =>
      reason match {
        case Success =>
          // Nothing to release: a successful attempt keeps its authorization.
        case TaskCommitDenied(_, _, _) =>
          logInfo(s"Task was denied committing, stage: $stage, taskId: $task, attempt: $attempt")
        case _ =>
          // Only the authorized committer may release the lock. Without this guard a failed
          // attempt that never held the authorization would clear it for the attempt that
          // does, allowing a second attempt to be authorized for the same task.
          if (authorizedCommitters.get(task).exists(_ == attempt)) {
            logDebug(s"Authorized committer $attempt (stage=$stage, task=$task) failed; clearing lock")
            authorizedCommitters.remove(task)
          }
      }
    case None =>
      logDebug(s"Ignoring task completion for completed stage")
  }
}

/**
 * Sends `msg` to the coordinator actor without waiting for a reply. Does nothing when
 * `coordinatorActor` is empty.
 */
private def sendToActor(msg: OutputCommitCoordinationMessage): Unit = {
  for (actor <- coordinatorActor) {
    actor ! msg
  }
}

Expand All @@ -216,39 +153,22 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
.map(AkkaUtils.askWithReply[Boolean](msg, _, maxAttempts, retryInterval, timeout))
.getOrElse(false)
}

class AskCommitRunnable(
private val requester: ActorRef,
private val outputCommitCoordinator: OutputCommitCoordinator,
private val stage: StageId,
private val task: TaskId,
private val partId: PartitionId,
private val taskAttempt: TaskAttemptId)
extends Runnable {
override def run(): Unit = {
requester ! outputCommitCoordinator.determineIfCommitAllowed(stage, task, partId, taskAttempt)
}
}
}

private[spark] object OutputCommitCoordinator {

// Actor is defined inside the OutputCommitCoordinator object so that receiveWithLogging()
// can call the private methods, where it is safe to do so because it is in the actor event
// loop.
class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator)
extends Actor with ActorLogReceive with Logging {

override def receiveWithLogging() = {
case StageStarted(stage, partitionIds) =>
outputCommitCoordinator.handleStageStart(stage, partitionIds)
override def receiveWithLogging = {
case StageStarted(stage) =>
outputCommitCoordinator.handleStageStart(stage)
case StageEnded(stage) =>
outputCommitCoordinator.handleStageEnd(stage)
case AskPermissionToCommitOutput(stage, task, partId, taskAttempt) =>
outputCommitCoordinator.handleAskPermissionToCommitOutput(
sender, stage, task, partId, taskAttempt)
case TaskCompleted(stage, task, partId, attempt, successful) =>
outputCommitCoordinator.handleTaskCompletion(stage, task, partId, attempt, successful)
case AskPermissionToCommitOutput(stage, task, taskAttempt) =>
sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, task, taskAttempt)
case TaskCompleted(stage, task, attempt, reason) =>
outputCommitCoordinator.handleTaskCompletion(stage, task, attempt, reason)
case StopCoordinator =>
logInfo("OutputCommitCoordinator stopped!")
context.stop(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ private[spark] class TaskSchedulerImpl(
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

Expand All @@ -180,6 +180,13 @@ private[spark] class TaskSchedulerImpl(
backend.reviveOffers()
}

// Builds the TaskSetManager for a submitted task set. Kept private[scheduler] rather than
// private so that tests can swap in a different task set manager if necessary.
private[scheduler] def createTaskSetManager(
    taskSet: TaskSet,
    maxTaskFailures: Int): TaskSetManager =
  new TaskSetManager(this, taskSet, maxTaskFailures)

override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
Expand Down
Loading

0 comments on commit 60a47f4

Please sign in to comment.