From 25550e8059ad05b4b04c6c929bfbab8b6196a819 Mon Sep 17 00:00:00 2001 From: "Wang, Fei" Date: Fri, 27 Dec 2024 15:12:02 -0800 Subject: [PATCH 1/4] ut --- .../spark/shuffle/celeborn/SparkUtils.java | 7 ++ .../spark/shuffle/celeborn/SparkUtils.java | 7 ++ .../spark/CelebornFetchFailureSuite.scala | 11 --- .../celeborn/tests/spark/SparkTestBase.scala | 11 +++ .../shuffle/celeborn/SparkUtilsSuite.scala | 75 ++++++++++++++++--- 5 files changed, 91 insertions(+), 20 deletions(-) diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 7ac0f658310..c1ae7fdee4d 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -265,6 +266,9 @@ protected static Tuple2> getTaskAttempts( } } + // For testing only + protected static Optional firstReportedShuffleFetchFailureTaskId = Optional.empty(); + protected static Map> reportedStageShuffleFetchFailureTaskIds = JavaUtils.newConcurrentHashMap(); @@ -295,6 +299,9 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { reportedStageShuffleFetchFailureTaskIds.computeIfAbsent( stageUniqId, k -> new HashSet<>()); reportedStageTaskIds.add(taskId); + if (!firstReportedShuffleFetchFailureTaskId.isPresent()) { + firstReportedShuffleFetchFailureTaskId = Optional.of(taskId); + } Tuple2> taskAttempts = getTaskAttempts(taskSetManager, taskId); diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 6c2e5120e04..9ef5b1312a2 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -20,6 +20,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -381,6 +382,9 @@ protected static Tuple2> getTaskAttempts( } } + // For testing only + protected static Optional firstReportedShuffleFetchFailureTaskId = Optional.empty(); + protected static Map> reportedStageShuffleFetchFailureTaskIds = JavaUtils.newConcurrentHashMap(); @@ -411,6 +415,9 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { reportedStageShuffleFetchFailureTaskIds.computeIfAbsent( stageUniqId, k -> new HashSet<>()); reportedStageTaskIds.add(taskId); + if (!firstReportedShuffleFetchFailureTaskId.isPresent()) { + firstReportedShuffleFetchFailureTaskId = Optional.of(taskId); + } Tuple2> taskAttempts = getTaskAttempts(taskSetManager, taskId); diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala index 1703ad0b8f2..b4cce90f079 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala @@ -32,7 +32,6 @@ import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.protocol.ShuffleMode -import org.apache.celeborn.service.deploy.worker.Worker class CelebornFetchFailureSuite extends AnyFunSuite with SparkTestBase @@ -46,16 +45,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite System.gc() } - var workerDirs: Seq[String] = Seq.empty - - override def createWorker(map: Map[String, String]): Worker = { - val storageDir = createTmpDir() - this.synchronized { - workerDirs = workerDirs :+ storageDir - } - super.createWorker(map, storageDir) - } - class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook { var executed: AtomicBoolean = new AtomicBoolean(false) val lock = new Object diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala index c92ec4c9d3c..d4153648fcc 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala @@ -30,6 +30,7 @@ import org.apache.celeborn.common.CelebornConf._ import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.protocol.ShuffleMode import org.apache.celeborn.service.deploy.MiniClusterFeature +import org.apache.celeborn.service.deploy.worker.Worker trait SparkTestBase extends AnyFunSuite with Logging with MiniClusterFeature with BeforeAndAfterAll with BeforeAndAfterEach { @@ -52,6 +53,16 @@ trait SparkTestBase extends AnyFunSuite shutdownMiniCluster() } + var workerDirs: Seq[String] = Seq.empty + + override def createWorker(map: Map[String, String]): Worker = { + val storageDir = createTmpDir() + this.synchronized { + workerDirs = workerDirs :+ storageDir + } + super.createWorker(map, storageDir) + } + def updateSparkConf(sparkConf: SparkConf, mode: ShuffleMode): SparkConf = { sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") sparkConf.set( diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index 2edfaf898e4..16175dd8232 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -17,8 +17,13 @@ package org.apache.spark.shuffle.celeborn -import org.apache.spark.SparkConf +import java.io.File +import java.util.Optional +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.sql.SparkSession import org.scalatest.BeforeAndAfterEach import org.scalatest.concurrent.Eventually.eventually @@ -27,6 +32,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.time.SpanSugar.convertIntToGrainOfTime import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.protocol.ShuffleMode import org.apache.celeborn.tests.spark.SparkTestBase @@ -54,13 +60,19 @@ class SparkUtilsSuite extends AnyFunSuite "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") .getOrCreate() + val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) + val hook = new ShuffleReaderGetHook(celebornConf) + TestCelebornShuffleManager.registerReaderGetHook(hook) + try { val sc = sparkSession.sparkContext val jobThread = new Thread { override def run(): Unit = { try { - sc.parallelize(1 to 100, 2) - .repartition(1) + val value = Range(1, 10000).mkString(",") + sc.parallelize(1 to 10000, 2) + .map { i => (i, value) } + .groupByKey(10) .mapPartitions { iter => Thread.sleep(3000) iter @@ -70,16 +82,20 @@ class SparkUtilsSuite extends AnyFunSuite } } } + + SparkUtils.firstReportedShuffleFetchFailureTaskId = Optional.empty(); + jobThread.start() val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] - eventually(timeout(3.seconds), interval(100.milliseconds)) { - val taskId = 0 - val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId) + eventually(timeout(30.seconds), interval(100.milliseconds)) { + assert(hook.executed.get() == true) + assert(SparkUtils.firstReportedShuffleFetchFailureTaskId.isPresent) + val reportedTaskId = SparkUtils.firstReportedShuffleFetchFailureTaskId.get() + val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId) assert(taskSetManager != null) - assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1) - assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)) - assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1) + assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1) + assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId)) } sparkSession.sparkContext.cancelAllJobs() @@ -93,4 +109,45 @@ class SparkUtilsSuite extends AnyFunSuite sparkSession.stop() } } + + class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook { + var executed: AtomicBoolean = new AtomicBoolean(false) + val lock = new Object + + override def exec( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): Unit = { + if (executed.get() == true) return + + lock.synchronized { + handle match { + case h: CelebornShuffleHandle[_, _, _] => { + val appUniqueId = h.appUniqueId + val shuffleClient = ShuffleClient.get( + h.appUniqueId, + h.lifecycleManagerHost, + h.lifecycleManagerPort, + conf, + h.userIdentifier, + h.extension) + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) + val allFiles = workerDirs.map(dir => { + new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") + }) + val datafile = allFiles.filter(_.exists()) + .flatMap(_.listFiles().iterator).headOption + datafile match { + case Some(file) => file.delete() + case None => throw new RuntimeException("unexpected, there must be some data file" + + s" under ${workerDirs.mkString(",")}") + } + } + case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here") + } + executed.set(true) + } + } + } } From bba02e670038884f38338b6f7966e5b4aa159a73 Mon Sep 17 00:00:00 2001 From: "Wang, Fei" Date: Fri, 27 Dec 2024 15:25:22 -0800 Subject: [PATCH 2/4] reuse --- .../spark/CelebornFetchFailureSuite.scala | 55 ++----------------- .../celeborn/tests/spark/SparkTestBase.scala | 51 ++++++++++++++++- .../shuffle/celeborn/SparkUtilsSuite.scala | 49 +---------------- 3 files changed, 57 insertions(+), 98 deletions(-) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala index b4cce90f079..dd0f3840149 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala @@ -17,20 +17,18 @@ package org.apache.celeborn.tests.spark -import java.io.{File, IOException} +import java.io.IOException import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.rdd.RDD -import org.apache.spark.shuffle.ShuffleHandle -import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager} +import org.apache.spark.shuffle.celeborn.{SparkShuffleManager, SparkUtils, TestCelebornShuffleManager} import org.apache.spark.sql.SparkSession import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.client.ShuffleClient -import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.protocol.ShuffleMode class CelebornFetchFailureSuite extends AnyFunSuite @@ -45,47 +43,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite System.gc() } - class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook { - var executed: AtomicBoolean = new AtomicBoolean(false) - val lock = new Object - - override def exec( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): Unit = { - if (executed.get() == true) return - - lock.synchronized { - handle match { - case h: CelebornShuffleHandle[_, _, _] => { - val appUniqueId = h.appUniqueId - val shuffleClient = ShuffleClient.get( - h.appUniqueId, - h.lifecycleManagerHost, - h.lifecycleManagerPort, - conf, - h.userIdentifier, - h.extension) - val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) - val allFiles = workerDirs.map(dir => { - new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") - }) - val datafile = allFiles.filter(_.exists()) - .flatMap(_.listFiles().iterator).headOption - datafile match { - case Some(file) => file.delete() - case None => throw new RuntimeException("unexpected, there must be some data file" + - s" under ${workerDirs.mkString(",")}") - } - } - case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here") - } - executed.set(true) - } - } - } - test("celeborn spark integration test - Fetch Failure") { if (Spark3OrNewer) { val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") @@ -100,7 +57,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) val value = Range(1, 10000).mkString(",") @@ -173,7 +130,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) import sparkSession.implicits._ @@ -204,7 +161,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) val sc = sparkSession.sparkContext @@ -244,7 +201,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) val sc = sparkSession.sparkContext diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala index d4153648fcc..95d496beeec 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala @@ -17,15 +17,21 @@ package org.apache.celeborn.tests.spark +import java.io.File +import java.util.concurrent.atomic.AtomicBoolean + import scala.util.Random -import org.apache.spark.SPARK_VERSION -import org.apache.spark.SparkConf +import org.apache.spark.{SPARK_VERSION, SparkConf, TaskContext} +import org.apache.spark.shuffle.ShuffleHandle +import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkUtils} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.SQLConf import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.CelebornConf._ import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.protocol.ShuffleMode @@ -109,4 +115,45 @@ trait SparkTestBase extends AnyFunSuite val outMap = result.collect().map(row => row.getString(0) -> row.getLong(1)).toMap outMap } + + class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends ShuffleManagerHook { + var executed: AtomicBoolean = new AtomicBoolean(false) + val lock = new Object + + override def exec( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): Unit = { + if (executed.get() == true) return + + lock.synchronized { + handle match { + case h: CelebornShuffleHandle[_, _, _] => { + val appUniqueId = h.appUniqueId + val shuffleClient = ShuffleClient.get( + h.appUniqueId, + h.lifecycleManagerHost, + h.lifecycleManagerPort, + conf, + h.userIdentifier, + h.extension) + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) + val allFiles = workerDirs.map(dir => { + new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") + }) + val datafile = allFiles.filter(_.exists()) + .flatMap(_.listFiles().iterator).headOption + datafile match { + case Some(file) => file.delete() + case None => throw new RuntimeException("unexpected, there must be some data file" + + s" under ${workerDirs.mkString(",")}") + } + } + case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here") + } + executed.set(true) + } + } + } } diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index 16175dd8232..a5df47e1953 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -17,13 +17,10 @@ package org.apache.spark.shuffle.celeborn -import java.io.File import java.util.Optional -import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.SparkConf import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.shuffle.ShuffleHandle import org.apache.spark.sql.SparkSession import org.scalatest.BeforeAndAfterEach import org.scalatest.concurrent.Eventually.eventually @@ -32,7 +29,6 @@ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.time.SpanSugar.convertIntToGrainOfTime import org.apache.celeborn.client.ShuffleClient -import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.protocol.ShuffleMode import org.apache.celeborn.tests.spark.SparkTestBase @@ -61,7 +57,7 @@ class SparkUtilsSuite extends AnyFunSuite .getOrCreate() val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) - val hook = new ShuffleReaderGetHook(celebornConf) + val hook = new ShuffleReaderFetchFailureGetHook(celebornConf) TestCelebornShuffleManager.registerReaderGetHook(hook) try { @@ -109,45 +105,4 @@ class SparkUtilsSuite extends AnyFunSuite sparkSession.stop() } } - - class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook { - var executed: AtomicBoolean = new AtomicBoolean(false) - val lock = new Object - - override def exec( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): Unit = { - if (executed.get() == true) return - - lock.synchronized { - handle match { - case h: CelebornShuffleHandle[_, _, _] => { - val appUniqueId = h.appUniqueId - val shuffleClient = ShuffleClient.get( - h.appUniqueId, - h.lifecycleManagerHost, - h.lifecycleManagerPort, - conf, - h.userIdentifier, - h.extension) - val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) - val allFiles = workerDirs.map(dir => { - new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") - }) - val datafile = allFiles.filter(_.exists()) - .flatMap(_.listFiles().iterator).headOption - datafile match { - case Some(file) => file.delete() - case None => throw new RuntimeException("unexpected, there must be some data file" + - s" under ${workerDirs.mkString(",")}") - } - } - case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here") - } - executed.set(true) - } - } - } } From ae338a95a6368abd8e3393015803e5c84f6783d7 Mon Sep 17 00:00:00 2001 From: "Wang, Fei" Date: Fri, 27 Dec 2024 15:35:03 -0800 Subject: [PATCH 3/4] sort --- .../scala/org/apache/celeborn/tests/spark/SparkTestBase.scala | 2 +- .../org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala index 95d496beeec..999abc053d3 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala @@ -143,7 +143,7 @@ trait SparkTestBase extends AnyFunSuite new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") }) val datafile = allFiles.filter(_.exists()) - .flatMap(_.listFiles().iterator).headOption + .flatMap(_.listFiles().iterator).sortBy(_.getName).headOption datafile match { case Some(file) => file.delete() case None => throw new RuntimeException("unexpected, there must be some data file" + diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index a5df47e1953..be154b62572 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -84,7 +84,7 @@ class SparkUtilsSuite extends AnyFunSuite jobThread.start() val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] - eventually(timeout(30.seconds), interval(100.milliseconds)) { + eventually(timeout(30.seconds), interval(0.milliseconds)) { assert(hook.executed.get() == true) assert(SparkUtils.firstReportedShuffleFetchFailureTaskId.isPresent) val reportedTaskId = SparkUtils.firstReportedShuffleFetchFailureTaskId.get() From fa0b22a2bd885080f3d21d313eefe890d0c22f85 Mon Sep 17 00:00:00 2001 From: "Wang, Fei" Date: Fri, 27 Dec 2024 15:40:31 -0800 Subject: [PATCH 4/4] ut --- .../org/apache/spark/shuffle/celeborn/SparkUtils.java | 7 ------- .../org/apache/spark/shuffle/celeborn/SparkUtils.java | 7 ------- .../spark/shuffle/celeborn/SparkUtilsSuite.scala | 10 ++++------ 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index c1ae7fdee4d..7ac0f658310 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -23,7 +23,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -266,9 +265,6 @@ protected static Tuple2> getTaskAttempts( } } - // For testing only - protected static Optional firstReportedShuffleFetchFailureTaskId = Optional.empty(); - protected static Map> reportedStageShuffleFetchFailureTaskIds = JavaUtils.newConcurrentHashMap(); @@ -299,9 +295,6 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { reportedStageShuffleFetchFailureTaskIds.computeIfAbsent( stageUniqId, k -> new HashSet<>()); reportedStageTaskIds.add(taskId); - if (!firstReportedShuffleFetchFailureTaskId.isPresent()) { - firstReportedShuffleFetchFailureTaskId = Optional.of(taskId); - } Tuple2> taskAttempts = getTaskAttempts(taskSetManager, taskId); diff --git a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 9ef5b1312a2..6c2e5120e04 100644 --- a/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -20,7 +20,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -382,9 +381,6 @@ protected static Tuple2> getTaskAttempts( } } - // For testing only - protected static Optional firstReportedShuffleFetchFailureTaskId = Optional.empty(); - protected static Map> reportedStageShuffleFetchFailureTaskIds = JavaUtils.newConcurrentHashMap(); @@ -415,9 +411,6 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) { reportedStageShuffleFetchFailureTaskIds.computeIfAbsent( stageUniqId, k -> new HashSet<>()); reportedStageTaskIds.add(taskId); - if (!firstReportedShuffleFetchFailureTaskId.isPresent()) { - firstReportedShuffleFetchFailureTaskId = Optional.of(taskId); - } Tuple2> taskAttempts = getTaskAttempts(taskSetManager, taskId); diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index be154b62572..2d753ff7b17 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.celeborn -import java.util.Optional +import scala.collection.JavaConverters._ import org.apache.spark.SparkConf import org.apache.spark.scheduler.TaskSchedulerImpl @@ -78,16 +78,14 @@ class SparkUtilsSuite extends AnyFunSuite } } } - - SparkUtils.firstReportedShuffleFetchFailureTaskId = Optional.empty(); - jobThread.start() val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] eventually(timeout(30.seconds), interval(0.milliseconds)) { assert(hook.executed.get() == true) - assert(SparkUtils.firstReportedShuffleFetchFailureTaskId.isPresent) - val reportedTaskId = SparkUtils.firstReportedShuffleFetchFailureTaskId.get() + val reportedTaskId = + SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap( + _.asScala).head val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId) assert(taskSetManager != null) assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1)