diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 713bed2abc4..e0888c61704 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -20,8 +20,10 @@ import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -258,6 +260,14 @@ protected static List getTaskAttempts(TaskSetManager taskSetManager, l } } + public static Map> reportedStageShuffleFetchFailureTasks = + JavaUtils.newConcurrentHashMap(); + + /** + * Only check for the shuffle fetch failure task whether another attempt is running or successful. + * If another attempt of the task has reported fetch failure, return false. If another attempt is + * running or successful, return true. Otherwise, return false. 
+ */ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { TaskSetManager taskSetManager = getTaskSetManager(taskId); if (taskSetManager != null) { @@ -267,6 +277,23 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst(); if (taskInfoOpt.isPresent()) { TaskInfo taskInfo = taskInfoOpt.get(); + + List reportedStageTasks = + reportedStageShuffleFetchFailureTasks.computeIfAbsent(stageId, k -> new ArrayList<>()); + for (TaskInfo reportedTi : reportedStageTasks) { + if (taskInfo.index() == reportedTi.index()) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} already reported fetch failure.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber(), + reportedTi.attemptNumber()); + return false; + } + } + reportedStageTasks.add(taskInfo); + int taskIndex = taskInfo.index(); for (TaskInfo ti : taskAttempts) { if (ti.taskId() != taskId) { diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 5aa005a49d9..d70f7a95253 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -17,8 +17,10 @@ package org.apache.spark.shuffle.celeborn; +import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -54,6 +56,7 @@ import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.util.JavaUtils; import org.apache.celeborn.reflect.DynConstructors; import org.apache.celeborn.reflect.DynFields; import org.apache.celeborn.reflect.DynMethods; @@ 
-374,7 +377,15 @@ protected static List getTaskAttempts(TaskSetManager taskSetManager, l } } - public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { + public static Map> reportedStageShuffleFetchFailureTasks = + JavaUtils.newConcurrentHashMap(); + + /** + * Only check for the shuffle fetch failure task whether another attempt is running or successful. + * If another attempt of the task has reported fetch failure, return false. If another attempt is + * running or successful, return true. Otherwise, return false. + */ + public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { TaskSetManager taskSetManager = getTaskSetManager(taskId); if (taskSetManager != null) { int stageId = taskSetManager.stageId(); @@ -383,6 +394,23 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst(); if (taskInfoOpt.isPresent()) { TaskInfo taskInfo = taskInfoOpt.get(); + + List reportedStageTasks = + reportedStageShuffleFetchFailureTasks.computeIfAbsent(stageId, k -> new ArrayList<>()); + for (TaskInfo reportedTi : reportedStageTasks) { + if (taskInfo.index() == reportedTi.index()) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} already reported fetch failure.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber(), + reportedTi.attemptNumber()); + return false; + } + } + reportedStageTasks.add(taskInfo); + int taskIndex = taskInfo.index(); for (TaskInfo ti : taskAttempts) { if (ti.taskId() != taskId) { diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index e8374bf450a..bf8aca8bf8d 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -55,6 
+55,7 @@ class SparkUtilsSuite extends AnyFunSuite { assert(taskSetManager != null) assert(SparkUtils.getTaskAttempts(taskSetManager, taskId).size() == 1) assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)) + SparkUtils.reportedStageShuffleFetchFailureTasks.clear(); } jobThread.interrupt()