Skip to content

Commit

Permalink
record the reported shuffle fetch failure tasks (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
turboFei committed Dec 20, 2024
1 parent 4ae7cee commit 52c3066
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -51,6 +54,7 @@

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;

Expand Down Expand Up @@ -258,21 +262,42 @@ protected static List<TaskInfo> getTaskAttempts(TaskSetManager taskSetManager, l
}
}

public static Map<Integer, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
JavaUtils.newConcurrentHashMap();

/**
* Only check for the shuffle fetch failure task whether another attempt is running or successful.
* If another attempt(excluding the reported shuffle fetch failure tasks in current stage) is
* running or successful, return true. Otherwise, return false.
*/
public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
TaskSetManager taskSetManager = getTaskSetManager(taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
Set<Long> reportedStageTaskIds =
reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(stageId, k -> new HashSet<>());
reportedStageTaskIds.add(taskId);

List<TaskInfo> taskAttempts = getTaskAttempts(taskSetManager, taskId);
Optional<TaskInfo> taskInfoOpt =
taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst();
if (taskInfoOpt.isPresent()) {
TaskInfo taskInfo = taskInfoOpt.get();

int taskIndex = taskInfo.index();
for (TaskInfo ti : taskAttempts) {
if (ti.taskId() != taskId) {
if (ti.successful()) {
if (reportedStageTaskIds.contains(ti.taskId())) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
} else if (ti.successful()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is finished.",
"StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
stageId,
taskIndex,
taskId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
package org.apache.spark.shuffle.celeborn;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -54,6 +57,7 @@

import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.reflect.DynConstructors;
import org.apache.celeborn.reflect.DynFields;
import org.apache.celeborn.reflect.DynMethods;
Expand Down Expand Up @@ -374,21 +378,42 @@ protected static List<TaskInfo> getTaskAttempts(TaskSetManager taskSetManager, l
}
}

public static Map<Integer, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
JavaUtils.newConcurrentHashMap();

/**
* Only check for the shuffle fetch failure task whether another attempt is running or successful.
* If another attempt(excluding the reported shuffle fetch failure tasks in current stage) is
* running or successful, return true. Otherwise, return false.
*/
public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
TaskSetManager taskSetManager = getTaskSetManager(taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
Set<Long> reportedStageTaskIds =
reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(stageId, k -> new HashSet<>());
reportedStageTaskIds.add(taskId);

List<TaskInfo> taskAttempts = getTaskAttempts(taskSetManager, taskId);
Optional<TaskInfo> taskInfoOpt =
taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst();
if (taskInfoOpt.isPresent()) {
TaskInfo taskInfo = taskInfoOpt.get();

int taskIndex = taskInfo.index();
for (TaskInfo ti : taskAttempts) {
if (ti.taskId() != taskId) {
if (ti.successful()) {
if (reportedStageTaskIds.contains(ti.taskId())) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.",
stageId,
taskIndex,
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
} else if (ti.successful()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is finished.",
"StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
stageId,
taskIndex,
taskId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class SparkUtilsSuite extends AnyFunSuite {
assert(taskSetManager != null)
assert(SparkUtils.getTaskAttempts(taskSetManager, taskId).size() == 1)
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId))
SparkUtils.reportedStageShuffleFetchFailureTaskIds.clear();
}

jobThread.interrupt()
Expand Down

0 comments on commit 52c3066

Please sign in to comment.