Skip to content

Commit

Permalink
Address comments from mridul (#44)
Browse files Browse the repository at this point in the history
* revert logger => LOG

* taskScheduler instance lock and stage uniq id

* docs

* listener

* spark 2

* comments

* test
  • Loading branch information
turboFei authored Dec 21, 2024
1 parent 156c6f8 commit 8b4ba08
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ private void initializeLifecycleManager(String appId) {

lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId));
SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener());

lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
Expand All @@ -34,6 +32,7 @@
import scala.Some;
import scala.Tuple2;

import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.BarrierTaskContext;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
Expand All @@ -43,6 +42,7 @@
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.scheduler.SparkListener;
import org.apache.spark.scheduler.TaskInfo;
import org.apache.spark.scheduler.TaskSchedulerImpl;
import org.apache.spark.scheduler.TaskSetManager;
Expand All @@ -59,7 +59,7 @@
import org.apache.celeborn.reflect.DynFields;

public class SparkUtils {
private static final Logger LOG = LoggerFactory.getLogger(SparkUtils.class);
private static final Logger logger = LoggerFactory.getLogger(SparkUtils.class);

public static final String FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure with shuffle id ";

Expand Down Expand Up @@ -105,7 +105,7 @@ public static SQLMetric getUnsafeRowSerializerDataSizeMetric(UnsafeRowSerializer
field.setAccessible(true);
return (SQLMetric) field.get(serializer);
} catch (NoSuchFieldException | IllegalAccessException e) {
LOG.warn("Failed to get dataSize metric, aqe won`t work properly.");
logger.warn("Failed to get dataSize metric, aqe won`t work properly.");
}
return null;
}
Expand Down Expand Up @@ -212,7 +212,7 @@ public static void cancelShuffle(int shuffleId, String reason) {
scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
}
} else {
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
logger.error("Can not get active SparkContext, skip cancelShuffle.");
}
}

Expand All @@ -229,72 +229,90 @@ public static void cancelShuffle(int shuffleId, String reason) {
.defaultAlwaysNull()
.build();

protected static TaskSetManager getTaskSetManager(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
TaskSchedulerImpl taskScheduler =
(TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
/**
 * TaskSetManager is not designed to be used outside the Spark scheduler; use it with care.
 */
@VisibleForTesting
protected static TaskSetManager getTaskSetManager(TaskSchedulerImpl taskScheduler, long taskId) {
synchronized (taskScheduler) {
ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
return taskIdToTaskSetManager.get(taskId);
} else {
LOG.error("Can not get active SparkContext.");
return null;
}
}

protected static List<TaskInfo> getTaskAttempts(TaskSetManager taskSetManager, long taskId) {
@VisibleForTesting
protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts(
TaskSetManager taskSetManager, long taskId) {
if (taskSetManager != null) {
scala.Option<TaskInfo> taskInfoOption =
TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
if (taskInfoOption.isDefined()) {
int taskIndex = taskInfoOption.get().index();
return scala.collection.JavaConverters.asJavaCollectionConverter(
taskSetManager.taskAttempts()[taskIndex])
.asJavaCollection().stream()
.collect(Collectors.toList());
TaskInfo taskInfo = taskInfoOption.get();
List<TaskInfo> taskAttempts =
scala.collection.JavaConverters.asJavaCollectionConverter(
taskSetManager.taskAttempts()[taskInfo.index()])
.asJavaCollection().stream()
.collect(Collectors.toList());
return Tuple2.apply(taskInfo, taskAttempts);
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
return Collections.emptyList();
logger.error("Can not get TaskInfo for taskId: {}", taskId);
return null;
}
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return Collections.emptyList();
logger.error("Can not get TaskSetManager for taskId: {}", taskId);
return null;
}
}

public static Map<Integer, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
JavaUtils.newConcurrentHashMap();

/**
 * Removes the entry kept in {@code reportedStageShuffleFetchFailureTaskIds} for one stage attempt.
 *
 * <p>The entry is addressed by the key {@code stageId + "-" + stageAttemptId}. Removing a key that
 * is not present is a no-op.
 *
 * @param stageId id of the stage whose entry should be removed
 * @param stageAttemptId attempt id of that stage
 */
protected static void removeStageReportedShuffleFetchFailureTaskIds(
    int stageId, int stageAttemptId) {
  String stageUniqId = stageId + "-" + stageAttemptId;
  reportedStageShuffleFetchFailureTaskIds.remove(stageUniqId);
}

/**
* Only check for the shuffle fetch failure task whether another attempt is running or successful.
* If another attempt(excluding the reported shuffle fetch failure tasks in current stage) is
* running or successful, return true. Otherwise, return false.
* Only used to check for the shuffle fetch failure task whether another attempt is running or
* successful. If another attempt(excluding the reported shuffle fetch failure tasks in current
* stage) is running or successful, return true. Otherwise, return false.
*/
public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
TaskSetManager taskSetManager = getTaskSetManager(taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
Set<Long> reportedStageTaskIds =
reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(stageId, k -> new HashSet<>());
reportedStageTaskIds.add(taskId);

List<TaskInfo> taskAttempts = getTaskAttempts(taskSetManager, taskId);
Optional<TaskInfo> taskInfoOpt =
taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst();
if (taskInfoOpt.isPresent()) {
TaskInfo taskInfo = taskInfoOpt.get();
for (TaskInfo ti : taskAttempts) {
public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
if (sparkContext == null) {
logger.error("Can not get active SparkContext.");
return false;
}
TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) sparkContext.taskScheduler();
synchronized (taskScheduler) {
TaskSetManager taskSetManager = getTaskSetManager(taskScheduler, taskId);
if (taskSetManager != null) {
int stageId = taskSetManager.stageId();
int stageAttemptId = taskSetManager.taskSet().stageAttemptId();
String stageUniqId = stageId + "-" + stageAttemptId;
Set<Long> reportedStageTaskIds =
reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(
stageUniqId, k -> new HashSet<>());
reportedStageTaskIds.add(taskId);

Tuple2<TaskInfo, List<TaskInfo>> taskAttempts = getTaskAttempts(taskSetManager, taskId);

if (taskAttempts == null) return false;

TaskInfo taskInfo = taskAttempts._1();
for (TaskInfo ti : taskAttempts._2()) {
if (ti.taskId() != taskId) {
if (reportedStageTaskIds.contains(ti.taskId())) {
LOG.info(
logger.info(
"StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
} else if (ti.successful()) {
LOG.info(
logger.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
stageId,
taskInfo.index(),
Expand All @@ -303,7 +321,7 @@ public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long ta
ti.attemptNumber());
return true;
} else if (ti.running()) {
LOG.info(
logger.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is running.",
stageId,
taskInfo.index(),
Expand All @@ -316,12 +334,16 @@ public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long ta
}
return false;
} else {
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
logger.error("Can not get TaskSetManager for taskId: {}", taskId);
return false;
}
} else {
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
return false;
}
}

/**
 * Registers {@code listener} on the currently active SparkContext.
 *
 * <p>When there is no active SparkContext this method does nothing.
 *
 * @param listener the listener to add to the active SparkContext
 */
public static void addSparkListener(SparkListener listener) {
  // Do not use Option.getOrElse(null) from Java: its argument is a by-name parameter (a Function0
  // on the Java side), so a null argument is dereferenced and throws NullPointerException exactly
  // when the Option is empty. Check isDefined() and call get() instead.
  scala.Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
  if (activeContext.isDefined()) {
    activeContext.get().addSparkListener(listener);
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn

import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}

/**
 * Spark listener that, whenever a stage completes, asks [[SparkUtils]] to remove the reported
 * shuffle fetch failure task ids kept for that stage id and stage attempt number.
 */
class ShuffleFetchFailureReportTaskCleanListener extends SparkListener {

  /** Forwards the completed stage's id and attempt number to the cleanup in [[SparkUtils]]. */
  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    val completedStage = stageCompleted.stageInfo
    SparkUtils.removeStageReportedShuffleFetchFailureTaskIds(
      completedStage.stageId,
      completedStage.attemptNumber())
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ private void initializeLifecycleManager(String appId) {

lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId));
SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener());

lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
Expand Down
Loading

0 comments on commit 8b4ba08

Please sign in to comment.