diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
index 1703ad0b8f2..dd0f3840149 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
@@ -17,22 +17,19 @@
 package org.apache.celeborn.tests.spark
 
-import java.io.{File, IOException}
+import java.io.IOException
 import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext}
 import org.apache.spark.celeborn.ExceptionMakerHelper
 import org.apache.spark.rdd.RDD
-import org.apache.spark.shuffle.ShuffleHandle
-import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
+import org.apache.spark.shuffle.celeborn.{SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
 import org.apache.spark.sql.SparkSession
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
-import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.protocol.ShuffleMode
-import org.apache.celeborn.service.deploy.worker.Worker
 
 class CelebornFetchFailureSuite extends AnyFunSuite
   with SparkTestBase
   with BeforeAndAfterEach {
@@ -46,57 +43,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite
     System.gc()
   }
 
-  var workerDirs: Seq[String] = Seq.empty
-
-  override def createWorker(map: Map[String, String]): Worker = {
-    val storageDir = createTmpDir()
-    this.synchronized {
-      workerDirs = workerDirs :+ storageDir
-    }
-    super.createWorker(map, storageDir)
-  }
-
-  class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook {
-    var executed: AtomicBoolean = new AtomicBoolean(false)
-    val lock = new Object
-
-    override def exec(
-        handle: ShuffleHandle,
-        startPartition: Int,
-        endPartition: Int,
-        context: TaskContext): Unit = {
-      if (executed.get() == true) return
-
-      lock.synchronized {
-        handle match {
-          case h: CelebornShuffleHandle[_, _, _] => {
-            val appUniqueId = h.appUniqueId
-            val shuffleClient = ShuffleClient.get(
-              h.appUniqueId,
-              h.lifecycleManagerHost,
-              h.lifecycleManagerPort,
-              conf,
-              h.userIdentifier,
-              h.extension)
-            val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
-            val allFiles = workerDirs.map(dir => {
-              new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
-            })
-            val datafile = allFiles.filter(_.exists())
-              .flatMap(_.listFiles().iterator).headOption
-            datafile match {
-              case Some(file) => file.delete()
-              case None => throw new RuntimeException("unexpected, there must be some data file" +
-                s" under ${workerDirs.mkString(",")}")
-            }
-          }
-          case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
-        }
-        executed.set(true)
-      }
-    }
-  }
-
   test("celeborn spark integration test - Fetch Failure") {
     if (Spark3OrNewer) {
       val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
@@ -111,7 +57,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
         .getOrCreate()
 
       val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderGetHook(celebornConf)
+      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       val value = Range(1, 10000).mkString(",")
@@ -184,7 +130,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
         .getOrCreate()
 
       val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderGetHook(celebornConf)
+      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       import sparkSession.implicits._
@@ -215,7 +161,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
         .getOrCreate()
 
      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderGetHook(celebornConf)
+      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       val sc = sparkSession.sparkContext
@@ -255,7 +201,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
        .getOrCreate()
 
      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderGetHook(celebornConf)
+      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       val sc = sparkSession.sparkContext
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index c92ec4c9d3c..999abc053d3 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -17,19 +17,26 @@
 package org.apache.celeborn.tests.spark
 
+import java.io.File
+import java.util.concurrent.atomic.AtomicBoolean
+
 import scala.util.Random
 
-import org.apache.spark.SPARK_VERSION
-import org.apache.spark.SparkConf
+import org.apache.spark.{SPARK_VERSION, SparkConf, TaskContext}
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkUtils}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.internal.SQLConf
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 import org.scalatest.funsuite.AnyFunSuite
 
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.CelebornConf._
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.protocol.ShuffleMode
 import org.apache.celeborn.service.deploy.MiniClusterFeature
+import org.apache.celeborn.service.deploy.worker.Worker
 
 trait SparkTestBase extends AnyFunSuite
   with Logging with MiniClusterFeature with BeforeAndAfterAll with BeforeAndAfterEach {
 
@@ -52,6 +59,16 @@ trait SparkTestBase extends AnyFunSuite
     shutdownMiniCluster()
   }
 
+  var workerDirs: Seq[String] = Seq.empty
+
+  override def createWorker(map: Map[String, String]): Worker = {
+    val storageDir = createTmpDir()
+    this.synchronized {
+      workerDirs = workerDirs :+ storageDir
+    }
+    super.createWorker(map, storageDir)
+  }
+
   def updateSparkConf(sparkConf: SparkConf, mode: ShuffleMode): SparkConf = {
     sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
     sparkConf.set(
@@ -98,4 +115,45 @@ trait SparkTestBase extends AnyFunSuite
     val outMap = result.collect().map(row => row.getString(0) -> row.getLong(1)).toMap
     outMap
   }
+
+  class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends ShuffleManagerHook {
+    var executed: AtomicBoolean = new AtomicBoolean(false)
+    val lock = new Object
+
+    override def exec(
+        handle: ShuffleHandle,
+        startPartition: Int,
+        endPartition: Int,
+        context: TaskContext): Unit = {
+      if (executed.get() == true) return
+
+      lock.synchronized {
+        handle match {
+          case h: CelebornShuffleHandle[_, _, _] => {
+            val appUniqueId = h.appUniqueId
+            val shuffleClient = ShuffleClient.get(
+              h.appUniqueId,
+              h.lifecycleManagerHost,
+              h.lifecycleManagerPort,
+              conf,
+              h.userIdentifier,
+              h.extension)
+            val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
+            val allFiles = workerDirs.map(dir => {
+              new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
+            })
+            val datafile = allFiles.filter(_.exists())
+              .flatMap(_.listFiles().iterator).sortBy(_.getName).headOption
+            datafile match {
+              case Some(file) => file.delete()
+              case None => throw new RuntimeException("unexpected, there must be some data file" +
+                s" under ${workerDirs.mkString(",")}")
+            }
+          }
+          case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
+        }
+        executed.set(true)
+      }
+    }
+  }
 }
diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
index 2edfaf898e4..2d753ff7b17 100644
--- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.shuffle.celeborn
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.SparkConf
 import org.apache.spark.scheduler.TaskSchedulerImpl
 import org.apache.spark.sql.SparkSession
 
@@ -54,13 +56,19 @@ class SparkUtilsSuite extends AnyFunSuite
         "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
       .getOrCreate()
 
+    val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
+    val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+    TestCelebornShuffleManager.registerReaderGetHook(hook)
+
     try {
       val sc = sparkSession.sparkContext
       val jobThread = new Thread {
         override def run(): Unit = {
           try {
-            sc.parallelize(1 to 100, 2)
-              .repartition(1)
+            val value = Range(1, 10000).mkString(",")
+            sc.parallelize(1 to 10000, 2)
+              .map { i => (i, value) }
+              .groupByKey(10)
               .mapPartitions { iter =>
                 Thread.sleep(3000)
                 iter
@@ -73,13 +81,15 @@ class SparkUtilsSuite extends AnyFunSuite
       jobThread.start()
 
       val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
-      eventually(timeout(3.seconds), interval(100.milliseconds)) {
-        val taskId = 0
-        val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId)
+      eventually(timeout(30.seconds), interval(0.milliseconds)) {
+        assert(hook.executed.get() == true)
+        val reportedTaskId =
+          SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap(
+            _.asScala).head
+        val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId)
         assert(taskSetManager != null)
-        assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1)
-        assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId))
-        assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1)
+        assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1)
+        assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId))
       }
 
       sparkSession.sparkContext.cancelAllJobs()