Skip to content

Commit

Permalink
record the reported shuffle fetch failure tasks (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
turboFei authored Dec 20, 2024
1 parent 6d1b4ad commit 1c22a21
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
Expand Down Expand Up @@ -51,6 +53,7 @@

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;

Expand Down Expand Up @@ -258,6 +261,14 @@ protected static List<TaskInfo> getTaskAttempts(TaskSetManager taskSetManager, l
}
}

public static Map<Integer, List<TaskInfo>> reportedStageShuffleFetchFailureTasks =
JavaUtils.newConcurrentHashMap();

/**
* Only check for the shuffle fetch failure task whether another attempt is running or successful.
* If task another attempt has reported fetch failure, return false. If another attempt is running
* or successful, return true. Otherwise, return false.
*/
public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
TaskSetManager taskSetManager = getTaskSetManager(taskId);
if (taskSetManager != null) {
Expand All @@ -267,6 +278,23 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst();
if (taskInfoOpt.isPresent()) {
TaskInfo taskInfo = taskInfoOpt.get();

List<TaskInfo> reportedStageTasks =
reportedStageShuffleFetchFailureTasks.computeIfAbsent(stageId, k -> new ArrayList<>());
for (TaskInfo reportedTi : reportedStageTasks) {
if (taskInfo.index() == reportedTi.index()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} already reported fetch failure.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber(),
reportedTi.attemptNumber());
return false;
}
}
reportedStageTasks.add(taskInfo);

int taskIndex = taskInfo.index();
for (TaskInfo ti : taskAttempts) {
if (ti.taskId() != taskId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.shuffle.celeborn;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
Expand Down Expand Up @@ -54,6 +56,7 @@

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.reflect.DynConstructors;
import org.apache.celeborn.reflect.DynFields;
import org.apache.celeborn.reflect.DynMethods;
Expand Down Expand Up @@ -374,7 +377,15 @@ protected static List<TaskInfo> getTaskAttempts(TaskSetManager taskSetManager, l
}
}

public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
public static Map<Integer, List<TaskInfo>> reportedStageShuffleFetchFailureTasks =
JavaUtils.newConcurrentHashMap();

/**
* Only check for the shuffle fetch failure task whether another attempt is running or successful.
* If task another attempt has reported fetch failure, return false. If another attempt is running
* or successful, return true. Otherwise, return false.
*/
public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
TaskSetManager taskSetManager = getTaskSetManager(taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
Expand All @@ -383,6 +394,23 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst();
if (taskInfoOpt.isPresent()) {
TaskInfo taskInfo = taskInfoOpt.get();

List<TaskInfo> reportedStageTasks =
reportedStageShuffleFetchFailureTasks.computeIfAbsent(stageId, k -> new ArrayList<>());
for (TaskInfo reportedTi : reportedStageTasks) {
if (taskInfo.index() == reportedTi.index()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} already reported fetch failure.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber(),
reportedTi.attemptNumber());
return false;
}
}
reportedStageTasks.add(taskInfo);

int taskIndex = taskInfo.index();
for (TaskInfo ti : taskAttempts) {
if (ti.taskId() != taskId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class SparkUtilsSuite extends AnyFunSuite {
assert(taskSetManager != null)
assert(SparkUtils.getTaskAttempts(taskSetManager, taskId).size() == 1)
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId))
SparkUtils.reportedStageShuffleFetchFailureTasks.clear();
}

jobThread.interrupt()
Expand Down

0 comments on commit 1c22a21

Please sign in to comment.