Skip to content

Commit

Permalink
[CELEBORN-1720] Prevent stage re-run if task another attempt is running
Browse files Browse the repository at this point in the history
  • Loading branch information
turboFei committed Nov 15, 2024
1 parent b755765 commit 8c43cc2
Show file tree
Hide file tree
Showing 14 changed files with 162 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ private void initializeLifecycleManager(String appId) {
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId));
lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

import scala.Option;
Expand All @@ -35,6 +37,9 @@
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.scheduler.TaskInfo;
import org.apache.spark.scheduler.TaskSchedulerImpl;
import org.apache.spark.scheduler.TaskSetManager;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.storage.BlockManagerId;
Expand Down Expand Up @@ -203,4 +208,47 @@ public static void cancelShuffle(int shuffleId, String reason) {
logger.error("Can not get active SparkContext, skip cancelShuffle.");
}
}

private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>>
TASK_ID_TO_TASK_SET_MANAGER_FIELD =
DynFields.builder()
.hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
.defaultAlwaysNull()
.build();
private static final DynFields.UnboundField<HashMap<Long, TaskInfo>> TASK_INFOS_FIELD =
DynFields.builder().hiddenImpl(TaskSetManager.class, "taskInfos").defaultAlwaysNull().build();

public static boolean taskAnotherAttemptRunning(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
TaskSchedulerImpl taskScheduler =
(TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId);
if (taskSetManager != null) {
HashMap<Long, TaskInfo> taskInfos = TASK_INFOS_FIELD.bind(taskSetManager).get();
TaskInfo taskInfo = taskInfos.get(taskId);
if (taskInfo != null) {
return taskSetManager.taskAttempts()[taskInfo.index()].exists(
ti -> {
if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) {
LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti);
return true;
} else {
return false;
}
});
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class CelebornShuffleReader[K, C](
shuffleId,
partitionId,
encodedAttemptId,
context.taskAttemptId(),
startMapIndex,
endMapIndex,
metricsCallback)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ private void initializeLifecycleManager(String appId) {
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();

lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId));
lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.shuffle.celeborn;

import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

import scala.Option;
Expand All @@ -33,6 +35,9 @@
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.scheduler.TaskInfo;
import org.apache.spark.scheduler.TaskSchedulerImpl;
import org.apache.spark.scheduler.TaskSetManager;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
Expand Down Expand Up @@ -319,4 +324,47 @@ public static void cancelShuffle(int shuffleId, String reason) {
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
}
}

private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>>
TASK_ID_TO_TASK_SET_MANAGER_FIELD =
DynFields.builder()
.hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
.defaultAlwaysNull()
.build();
private static final DynFields.UnboundField<HashMap<Long, TaskInfo>> TASK_INFOS_FIELD =
DynFields.builder().hiddenImpl(TaskSetManager.class, "taskInfos").defaultAlwaysNull().build();

public static boolean taskAnotherAttemptRunning(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
TaskSchedulerImpl taskScheduler =
(TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId);
if (taskSetManager != null) {
HashMap<Long, TaskInfo> taskInfos = TASK_INFOS_FIELD.bind(taskSetManager).get();
TaskInfo taskInfo = taskInfos.get(taskId);
if (taskInfo != null) {
return taskSetManager.taskAttempts()[taskInfo.index()].exists(
ti -> {
if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) {
LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti);
return true;
} else {
return false;
}
});
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ class CelebornShuffleReader[K, C](
handle.shuffleId,
partitionId,
encodedAttemptId,
context.taskAttemptId(),
startMapIndex,
endMapIndex,
if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
Expand Down Expand Up @@ -371,7 +372,10 @@ class CelebornShuffleReader[K, C](

private def handleFetchExceptions(shuffleId: Int, partitionId: Int, ce: Throwable) = {
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
shuffleClient.reportShuffleFetchFailure(
handle.shuffleId,
shuffleId,
context.taskAttemptId())) {
logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce)
throw new FetchFailedException(
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ public CelebornInputStream readPartition(
int shuffleId,
int partitionId,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
MetricsCallback metricsCallback)
Expand All @@ -233,6 +234,7 @@ public CelebornInputStream readPartition(
shuffleId,
partitionId,
attemptNumber,
taskId,
startMapIndex,
endMapIndex,
null,
Expand All @@ -247,6 +249,7 @@ public abstract CelebornInputStream readPartition(
int appShuffleId,
int partitionId,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
Expand Down Expand Up @@ -276,7 +279,7 @@ public abstract int getShuffleId(
* cleanup for spark app. It must be a sync call and make sure the cleanup is done, otherwise,
* incorrect shuffle data can be fetched in re-run tasks
*/
public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId);
public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId);

/**
* Report barrier task failure. When any barrier task fails, all (shuffle) output for that stage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,11 +622,12 @@ public int getShuffleId(
}

@Override
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) {
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
PbReportShuffleFetchFailure pbReportShuffleFetchFailure =
PbReportShuffleFetchFailure.newBuilder()
.setAppShuffleId(appShuffleId)
.setShuffleId(shuffleId)
.setTaskId(taskId)
.build();
PbReportShuffleFetchFailureResponse pbReportShuffleFetchFailureResponse =
lifecycleManagerRef.askSync(
Expand Down Expand Up @@ -1752,6 +1753,7 @@ public CelebornInputStream readPartition(
int appShuffleId,
int partitionId,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
Expand Down Expand Up @@ -1790,6 +1792,7 @@ public CelebornInputStream readPartition(
streamHandlers,
mapAttempts,
attemptNumber,
taskId,
startMapIndex,
endMapIndex,
fetchExcludedWorkers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public static CelebornInputStream create(
ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
Expand All @@ -77,6 +78,7 @@ public static CelebornInputStream create(
streamHandlers,
attempts,
attemptNumber,
taskId,
startMapIndex,
endMapIndex,
fetchExcludedWorkers,
Expand Down Expand Up @@ -130,6 +132,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private ArrayList<PbStreamHandler> streamHandlers;
private int[] attempts;
private final int attemptNumber;
private final long taskId;
private final int startMapIndex;
private final int endMapIndex;

Expand Down Expand Up @@ -179,6 +182,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
Expand All @@ -198,6 +202,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
}
this.attempts = attempts;
this.attemptNumber = attemptNumber;
this.taskId = taskId;
this.startMapIndex = startMapIndex;
this.endMapIndex = endMapIndex;
this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
Expand Down Expand Up @@ -673,7 +678,7 @@ private boolean fillBuffer() throws IOException {
ioe = new IOException(e);
}
if (exceptionMaker != null) {
if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId)) {
if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId, taskId)) {
/*
* [[ExceptionMaker.makeException]], for spark applications with celeborn.client.spark.fetch.throwsFetchFailure enabled will result in creating
* a FetchFailedException; and that will make the TaskContext as failed with shuffle fetch issues - see SPARK-19276 for more.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
case pb: PbReportShuffleFetchFailure =>
val appShuffleId = pb.getAppShuffleId
val shuffleId = pb.getShuffleId
logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId")
handleReportShuffleFetchFailure(context, appShuffleId, shuffleId)
val taskId = pb.getTaskId
logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId taskId $taskId")
handleReportShuffleFetchFailure(context, appShuffleId, shuffleId, taskId)

case pb: PbReportBarrierStageAttemptFailure =>
val appShuffleId = pb.getAppShuffleId
Expand Down Expand Up @@ -931,7 +932,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
private def handleReportShuffleFetchFailure(
context: RpcCallContext,
appShuffleId: Int,
shuffleId: Int): Unit = {
shuffleId: Int,
taskId: Long): Unit = {

val shuffleIds = shuffleIdMapping.get(appShuffleId)
if (shuffleIds == null) {
Expand All @@ -941,9 +943,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleIds.synchronized {
shuffleIds.find(e => e._2._1 == shuffleId) match {
case Some((appShuffleIdentifier, (shuffleId, true))) =>
logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId")
ret = invokeAppShuffleTrackerCallback(appShuffleId)
shuffleIds.put(appShuffleIdentifier, (shuffleId, false))
if (invokeReportTaskShuffleFetchFailurePreCheck(taskId)) {
logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId")
ret = invokeAppShuffleTrackerCallback(appShuffleId)
shuffleIds.put(appShuffleIdentifier, (shuffleId, false))
} else {
logInfo(
s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId taskId $taskId")
}
case Some((appShuffleIdentifier, (shuffleId, false))) =>
logInfo(
s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId, " +
Expand Down Expand Up @@ -1006,6 +1013,22 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
}

private def invokeReportTaskShuffleFetchFailurePreCheck(taskId: Long): Boolean = {
reportTaskShuffleFetchFailurePreCheck match {
case Some(precheck) =>
try {
precheck.apply(taskId)
} catch {
case t: Throwable =>
logError(t.toString)
false
}
case None =>
throw new UnsupportedOperationException(
"unexpected! reportTaskShuffleFetchFailurePreCheck is not registered")
}
}

private def handleStageEnd(shuffleId: Int): Unit = {
// check whether shuffle has registered
if (!registeredShuffle.contains(shuffleId)) {
Expand Down Expand Up @@ -1766,6 +1789,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
workerStatusTracker.registerWorkerStatusListener(workerStatusListener)
}

@volatile private var reportTaskShuffleFetchFailurePreCheck
: Option[Function[java.lang.Long, Boolean]] = None
def registerReportTaskShuffleFetchFailurePreCheck(preCheck: Function[java.lang.Long, Boolean])
: Unit = {
reportTaskShuffleFetchFailurePreCheck = Some(preCheck)
}

@volatile private var appShuffleTrackerCallback: Option[Consumer[Integer]] = None
def registerShuffleTrackerCallback(callback: Consumer[Integer]): Unit = {
appShuffleTrackerCallback = Some(callback)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ public CelebornInputStream readPartition(
int appShuffleId,
int partitionId,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
Expand Down Expand Up @@ -179,7 +180,7 @@ public int getShuffleId(
}

@Override
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) {
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
return true;
}

Expand Down
Loading

0 comments on commit 8c43cc2

Please sign in to comment.