diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 95b4f746c4d..00b28d3a930 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -20,8 +20,12 @@ import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.Collections; +import java.util.List; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; +import java.util.stream.Collectors; import scala.Option; import scala.Some; @@ -221,57 +225,79 @@ public static void cancelShuffle(int shuffleId, String reason) { .defaultAlwaysNull() .build(); - public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { + protected static TaskSetManager getTaskSetManager(long taskId) { if (SparkContext$.MODULE$.getActive().nonEmpty()) { TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler(); ConcurrentHashMap taskIdToTaskSetManager = TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get(); - TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId); - if (taskSetManager != null) { - int stageId = taskSetManager.stageId(); - scala.Option taskInfoOption = - TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId); - if (taskInfoOption.isDefined()) { - TaskInfo taskInfo = taskInfoOption.get(); - int taskIndex = taskInfo.index(); - if (taskSetManager.successful()[taskIndex]) { - LOG.info( - "StageId={} index={} taskId={} attempt={} another attempt has been successful.", - stageId, - taskIndex, - taskId, - taskInfo.attemptNumber()); - return true; + return taskIdToTaskSetManager.get(taskId); + } else { + LOG.error("Can not get active SparkContext."); + return null; + } + } + + protected static List getTaskAttempts(TaskSetManager taskSetManager, long taskId) { + if (taskSetManager != null) { + scala.Option taskInfoOption = + TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId); + if (taskInfoOption.isDefined()) { + int taskIndex = taskInfoOption.get().index(); + return scala.collection.JavaConverters.asJavaCollectionConverter( + taskSetManager.taskAttempts()[taskIndex]) + .asJavaCollection().stream() + .collect(Collectors.toList()); + } else { + LOG.error("Can not get TaskInfo for taskId: {}", taskId); + return Collections.emptyList(); + } + } else { + LOG.error("Can not get TaskSetManager for taskId: {}", taskId); + return Collections.emptyList(); + } + } + + public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { + TaskSetManager taskSetManager = getTaskSetManager(taskId); + if (taskSetManager != null) { + int stageId = taskSetManager.stageId(); + List taskAttempts = getTaskAttempts(taskSetManager, taskId); + Optional taskInfoOpt = + taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst(); + if (taskInfoOpt.isPresent()) { + TaskInfo taskInfo = taskInfoOpt.get(); + int taskIndex = taskInfo.index(); + for (TaskInfo ti : taskAttempts) { + if (ti.taskId() != taskId) { + if (ti.successful()) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} is finished.", + stageId, + taskIndex, + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + return true; + } else if (ti.running()) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} is running.", + stageId, + taskIndex, + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + return true; + } } - return scala.collection.JavaConverters.asJavaCollectionConverter( - taskSetManager.taskAttempts()[taskIndex]) - .asJavaCollection().stream() - .anyMatch( - ti -> { - if (!ti.finished() && ti.attemptNumber() != taskInfo.attemptNumber()) { - LOG.info( - "StageId={} index={} taskId={} attempt={} another attempt {} is running.", - stageId, - taskIndex, - taskId, - taskInfo.attemptNumber(), - ti.attemptNumber()); - return true; - } else { - return false; - } - }); - } else { - LOG.error("Can not get TaskInfo for taskId: {}", taskId); - return false; } + return false; } else { - LOG.error("Can not get TaskSetManager for taskId: {}", taskId); + LOG.error("Can not get TaskInfo for taskId: {}", taskId); return false; } } else { - LOG.error("Can not get active SparkContext, skip checking."); + LOG.error("Can not get TaskSetManager for taskId: {}", taskId); return false; } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 9f4bd69b921..378acb300c3 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -17,8 +17,12 @@ package org.apache.spark.shuffle.celeborn; +import java.util.Collections; +import java.util.List; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; +import java.util.stream.Collectors; import scala.Option; import scala.Some; @@ -337,57 +341,79 @@ public static void cancelShuffle(int shuffleId, String reason) { .defaultAlwaysNull() .build(); - public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { + protected static TaskSetManager getTaskSetManager(long taskId) { if (SparkContext$.MODULE$.getActive().nonEmpty()) { TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler(); ConcurrentHashMap taskIdToTaskSetManager = TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get(); - TaskSetManager taskSetManager = taskIdToTaskSetManager.get(taskId); - if (taskSetManager != null) { - int stageId = taskSetManager.stageId(); - scala.Option taskInfoOption = - TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId); - if (taskInfoOption.isDefined()) { - TaskInfo taskInfo = taskInfoOption.get(); - int taskIndex = taskInfo.index(); - if (taskSetManager.successful()[taskIndex]) { - LOG.info( - "StageId={} index={} taskId={} attempt={} another attempt has been successful.", - stageId, - taskIndex, - taskId, - taskInfo.attemptNumber()); - return true; + return taskIdToTaskSetManager.get(taskId); + } else { + LOG.error("Can not get active SparkContext."); + return null; + } + } + + protected static List getTaskAttempts(TaskSetManager taskSetManager, long taskId) { + if (taskSetManager != null) { + scala.Option taskInfoOption = + TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId); + if (taskInfoOption.isDefined()) { + int taskIndex = taskInfoOption.get().index(); + return scala.collection.JavaConverters.asJavaCollectionConverter( + taskSetManager.taskAttempts()[taskIndex]) + .asJavaCollection().stream() + .collect(Collectors.toList()); + } else { + LOG.error("Can not get TaskInfo for taskId: {}", taskId); + return Collections.emptyList(); + } + } else { + LOG.error("Can not get TaskSetManager for taskId: {}", taskId); + return Collections.emptyList(); + } + } + + public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { + TaskSetManager taskSetManager = getTaskSetManager(taskId); + if (taskSetManager != null) { + int stageId = taskSetManager.stageId(); + List taskAttempts = getTaskAttempts(taskSetManager, taskId); + Optional taskInfoOpt = + taskAttempts.stream().filter(ti -> ti.taskId() == taskId).findFirst(); + if (taskInfoOpt.isPresent()) { + TaskInfo taskInfo = taskInfoOpt.get(); + int taskIndex = taskInfo.index(); + for (TaskInfo ti : taskAttempts) { + if (ti.taskId() != taskId) { + if (ti.successful()) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} is finished.", + stageId, + taskIndex, + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + return true; + } else if (ti.running()) { + LOG.info( + "StageId={} index={} taskId={} attempt={} another attempt {} is running.", + stageId, + taskIndex, + taskId, + taskInfo.attemptNumber(), + ti.attemptNumber()); + return true; + } } - return scala.collection.JavaConverters.asJavaCollectionConverter( - taskSetManager.taskAttempts()[taskIndex]) - .asJavaCollection().stream() - .anyMatch( - ti -> { - if (!ti.finished() && ti.attemptNumber() != taskInfo.attemptNumber()) { - LOG.info( - "StageId={} index={} taskId={} attempt={} another attempt {} is running.", - stageId, - taskIndex, - taskId, - taskInfo.attemptNumber(), - ti.attemptNumber()); - return true; - } else { - return false; - } - }); - } else { - LOG.error("Can not get TaskInfo for taskId: {}", taskId); - return false; } + return false; } else { - LOG.error("Can not get TaskSetManager for taskId: {}", taskId); + LOG.error("Can not get TaskInfo for taskId: {}", taskId); return false; } } else { - LOG.error("Can not get active SparkContext, skip checking."); + LOG.error("Can not get TaskSetManager for taskId: {}", taskId); return false; } } diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala new file mode 100644 index 00000000000..e8374bf450a --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.{interval, timeout} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime + +class SparkUtilsSuite extends AnyFunSuite { + test("another task running or successful") { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2]") + val sparkSession = SparkSession.builder() + .config(sparkConf) + .config("spark.sql.shuffle.partitions", 2) + .getOrCreate() + + try { + val sc = sparkSession.sparkContext + val jobThread = new Thread { + override def run(): Unit = { + try { + val rdd = sc.parallelize(1 to 100, 2) + rdd.mapPartitions { iter => + Thread.sleep(5000) + iter + }.collect() + } catch { + case _: InterruptedException => + } + } + } + jobThread.start() + + eventually(timeout(5.seconds), interval(100.milliseconds)) { + val taskId = 0 + val taskSetManager = SparkUtils.getTaskSetManager(taskId) + assert(taskSetManager != null) + assert(SparkUtils.getTaskAttempts(taskSetManager, taskId).size() == 1) + assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)) + } + + jobThread.interrupt() + } finally { + sparkSession.stop() + } + } +}