diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 37e3381dc68..142d451a11c 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -103,7 +103,7 @@ private void initializeLifecycleManager(String appId) {
                 (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
 
             lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
-                taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId));
+                taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId));
 
             lifecycleManager.registerShuffleTrackerCallback(
                 shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 1d5ed4b05ef..4272f9830ef 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -221,7 +221,7 @@ public static void cancelShuffle(int shuffleId, String reason) {
           .defaultAlwaysNull()
           .build();
 
-  public static boolean taskAnotherAttemptRunning(long taskId) {
+  public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
     if (SparkContext$.MODULE$.getActive().nonEmpty()) {
       TaskSchedulerImpl taskScheduler =
           (TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
@@ -238,7 +238,8 @@ public static boolean taskAnotherAttemptRunning(long taskId) {
               .asJavaCollection().stream()
               .anyMatch(
                   ti -> {
-                    if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) {
+                    if ((ti.running() || ti.successful())
+                        && ti.attemptNumber() != taskInfo.attemptNumber()) {
                       LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti);
                       return true;
                     } else {
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 307e877cb2f..b4849e3f036 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -145,7 +145,7 @@ private void initializeLifecycleManager(String appId) {
                 (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
 
             lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
-                taskId -> !SparkUtils.taskAnotherAttemptRunning(taskId));
+                taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId));
 
             lifecycleManager.registerShuffleTrackerCallback(
                 shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 46869a76d0a..05f6093df0f 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -337,7 +337,7 @@ public static void cancelShuffle(int shuffleId, String reason) {
           .defaultAlwaysNull()
           .build();
 
-  public static boolean taskAnotherAttemptRunning(long taskId) {
+  public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
     if (SparkContext$.MODULE$.getActive().nonEmpty()) {
       TaskSchedulerImpl taskScheduler =
           (TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
@@ -354,7 +354,8 @@ public static boolean taskAnotherAttemptRunning(long taskId) {
               .asJavaCollection().stream()
               .anyMatch(
                   ti -> {
-                    if (ti.running() && ti.attemptNumber() != taskInfo.attemptNumber()) {
+                    if ((ti.running() || ti.successful())
+                        && ti.attemptNumber() != taskInfo.attemptNumber()) {
                       LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti);
                       return true;
                     } else {
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
index f3cd382118c..f5e8b04c6f9 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext}
 import org.apache.spark.celeborn.ExceptionMakerHelper
 import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.SparkSchedulerHelper
 import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
 import org.apache.spark.sql.SparkSession
@@ -56,7 +57,8 @@ class CelebornFetchFailureSuite extends AnyFunSuite
     super.createWorker(map, storageDir)
   }
 
-  class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook {
+  class ShuffleReaderGetHook(conf: CelebornConf, speculation: Boolean = false)
+    extends ShuffleManagerHook {
     var executed: AtomicBoolean = new AtomicBoolean(false)
     val lock = new Object
 
@@ -65,6 +67,10 @@
         startPartition: Int,
         endPartition: Int,
         context: TaskContext): Unit = {
+      val taskIndex = SparkSchedulerHelper.getTaskIndex(context.taskAttemptId())
+      if (speculation && taskIndex == 0) {
+        Thread.sleep(3000) // sleep for speculation
+      }
       if (executed.get() == true) return
 
       lock.synchronized {
@@ -82,17 +88,19 @@
             val allFiles = workerDirs.map(dir => {
               new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
             })
-            val datafile = allFiles.filter(_.exists())
-              .flatMap(_.listFiles().iterator).headOption
-            datafile match {
-              case Some(file) => file.delete()
-              case None => throw new RuntimeException("unexpected, there must be some data file" +
-                s" under ${workerDirs.mkString(",")}")
+            val datafiles = allFiles.filter(_.exists())
+            if (datafiles.nonEmpty) {
+              if (taskIndex == 0) { // only cleanup the data file in the task with index 0
+                datafiles.foreach(_.delete())
+                executed.set(true)
+              }
+            } else {
+              throw new RuntimeException("unexpected, there must be some data file" +
+                s" under ${workerDirs.mkString(",")}")
             }
           }
           case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
         }
-        executed.set(true)
       }
     }
   }
@@ -439,6 +447,61 @@
     }
   }
 
+  test(s"celeborn spark integration test - do not rerun stage if task another attempt is running") {
+    val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+    val sparkSession = SparkSession.builder()
+      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+      .config("spark.sql.shuffle.partitions", 2)
+      .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+      .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
+      .config(
+        "spark.shuffle.manager",
+        "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+      .config("spark.speculation", "true")
+      .config("spark.speculation.multiplier", "2")
+      .config("spark.speculation.quantile", "0")
+      .getOrCreate()
+
+    val shuffleMgr = SparkContextHelper.env
+      .shuffleManager
+      .asInstanceOf[TestCelebornShuffleManager]
+    var preventUnnecessaryStageRerun = false
+    val lifecycleManager = shuffleMgr.getLifecycleManager
+    lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(new java.util.function.Function[
+      java.lang.Long,
+      Boolean] {
+      override def apply(taskId: java.lang.Long): Boolean = {
+        val anotherRunningOrSuccessful = SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)
+        if (anotherRunningOrSuccessful) {
+          preventUnnecessaryStageRerun = true
+        }
+        !anotherRunningOrSuccessful
+      }
+    })
+
+    val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
+    val hook = new ShuffleReaderGetHook(celebornConf, speculation = true)
+    TestCelebornShuffleManager.registerReaderGetHook(hook)
+
+    val value = Range(1, 10000).mkString(",")
+    val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2)
+      .map { i => (i, value) }.groupByKey(2).collect()
+
+    // verify result
+    assert(hook.executed.get() == true)
+    assert(preventUnnecessaryStageRerun)
+    assert(tuples.length == 10000)
+    for (elem <- tuples) {
+      assert(elem._2.mkString(",").equals(value))
+    }
+
+    shuffleMgr.unregisterShuffle(0)
+    assert(lifecycleManager.getUnregisterShuffleTime().containsKey(0))
+    assert(lifecycleManager.getUnregisterShuffleTime().containsKey(1))
+
+    sparkSession.stop()
+  }
+
   private def findAppShuffleId(rdd: RDD[_]): Int = {
     val deps = rdd.dependencies
     if (deps.size != 1 && !deps.head.isInstanceOf[ShuffleDependency[_, _, _]]) {
diff --git a/tests/spark-it/src/test/scala/org/apache/spark/scheduler/SparkSchedulerHelper.scala b/tests/spark-it/src/test/scala/org/apache/spark/scheduler/SparkSchedulerHelper.scala
new file mode 100644
index 00000000000..18c20ccf222
--- /dev/null
+++ b/tests/spark-it/src/test/scala/org/apache/spark/scheduler/SparkSchedulerHelper.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.SparkContext
+
+object SparkSchedulerHelper {
+  def getTaskIndex(taskId: Long): Int = {
+    val scheduler = SparkContext.getActive.get.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+    val taskSetManager = scheduler.taskIdToTaskSetManager.get(taskId)
+    taskSetManager.taskInfos.get(taskId).get.index
+  }
+}
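
The behavioral core of the patch above is the widened pre-check predicate: a task's shuffle fetch failure is reported to the driver (which unregisters the shuffle's map output and reruns the stage) only when no other attempt of the same task is still running or has already succeeded. The sketch below is a minimal, self-contained illustration of that predicate; TaskAttempt and AttemptPreCheckSketch are hypothetical stand-ins for Spark's TaskInfo and the patched SparkUtils logic, not real Spark or Celeborn APIs. It requires Java 16+ for records.

import java.util.List;

// Hypothetical stand-in for Spark's TaskInfo; it models only the fields
// the predicate inspects.
record TaskAttempt(int attemptNumber, boolean running, boolean successful) {}

class AttemptPreCheckSketch {

  // Mirrors the patched condition: another attempt counts when it is
  // running OR already successful, and its attempt number differs from
  // the current attempt's.
  static boolean anotherAttemptRunningOrSuccessful(
      List<TaskAttempt> attempts, int currentAttemptNumber) {
    return attempts.stream()
        .anyMatch(
            ti ->
                (ti.running() || ti.successful())
                    && ti.attemptNumber() != currentAttemptNumber);
  }

  public static void main(String[] args) {
    // Attempt 0 hit a fetch failure while speculative attempt 1 is still running.
    List<TaskAttempt> attempts =
        List.of(
            new TaskAttempt(0, true, false), // current attempt
            new TaskAttempt(1, true, false)); // speculative attempt
    // Prints "true": a pre-check built on this predicate (registered via
    // registerReportTaskShuffleFetchFailurePreCheck, as in the hunks above)
    // would return false, so the fetch failure is not reported and the
    // stage is not rerun.
    System.out.println(anotherAttemptRunningOrSuccessful(attempts, 0));
  }
}

Counting successful attempts alongside running ones matters under speculation: if a speculative attempt has already finished, the stage's output is available and rerunning it because a slower attempt hit a fetch failure would be pure waste, which is exactly what the integration test's preventUnnecessaryStageRerun flag verifies.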