Skip to content

Commit

Permalink
ut
Browse files Browse the repository at this point in the history
  • Loading branch information
turboFei committed Nov 21, 2024
1 parent 2be9682 commit c2ed014
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1788,7 +1788,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
workerStatusTracker.registerWorkerStatusListener(workerStatusListener)
}

@volatile private var reportTaskShuffleFetchFailurePreCheck
@volatile private[celeborn] var reportTaskShuffleFetchFailurePreCheck
: Option[java.util.function.Function[java.lang.Long, Boolean]] = None
def registerReportTaskShuffleFetchFailurePreCheck(preCheck: java.util.function.Function[
java.lang.Long,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,63 @@ class CelebornFetchFailureSuite extends AnyFunSuite
}
}

test("celeborn spark integration test - prevent stage re-run if another task attempt is running or successful") {
  // Speculation is enabled with an aggressive quantile/multiplier so Spark
  // launches a second attempt for slow tasks; the first attempt of the first
  // partition then reports a fetch failure. The expectation under review is
  // that the shuffle is NOT unregistered (no stage re-run) because another
  // attempt of the same task can still succeed.
  val baseConf = new SparkConf().setAppName("rss-demo").setMaster("local[2]")
  val spark = SparkSession.builder()
    .config(updateSparkConf(baseConf, ShuffleMode.HASH))
    .config("spark.sql.shuffle.partitions", 2)
    .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
    .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
    .config(
      "spark.shuffle.manager",
      "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
    .config("spark.speculation", "true")
    .config("spark.speculation.multiplier", "0.0")
    .config("spark.speculation.quantile", "0.1")
    .getOrCreate()

  try {
    val ctx = spark.sparkContext

    // Two partitions of (v, v) pairs; groupByKey introduces the shuffle
    // whose app shuffle id we inspect below.
    val grouped = ctx
      .parallelize(0 until 10000, 2)
      .map(v => (v, v))
      .groupByKey()

    val appShuffleId = findAppShuffleId(grouped)

    val collected = grouped.mapPartitions { iter =>
      val taskCtx = TaskContext.get()
      val firstAttemptOfFirstPartition =
        taskCtx.stageAttemptNumber() == 0 &&
          taskCtx.partitionId() == 0 &&
          taskCtx.attemptNumber() == 0
      if (firstAttemptOfFirstPartition) {
        // Stall long enough for a speculative attempt to start, then
        // simulate a shuffle fetch failure from this original attempt.
        Thread.sleep(5000)
        throw ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER.makeFetchFailureException(
          appShuffleId,
          -1,
          taskCtx.partitionId(),
          new IOException("forced"))
      }
      if (taskCtx.attemptNumber() > 0) {
        // Slow down speculative attempts as well so the timing window holds.
        Thread.sleep(5000)
      }
      iter
    }.collect()

    // All 10000 distinct keys must survive despite the injected failure.
    assert(collected.size == 10000)

    val manager = SparkContextHelper.env
      .shuffleManager
      .asInstanceOf[TestCelebornShuffleManager]
    val lcm = manager.getLifecycleManager

    // No shuffle was unregistered during the job => no stage re-run happened.
    assert(lcm.getUnregisterShuffleTime().isEmpty)
    manager.unregisterShuffle(0)
    assert(lcm.getUnregisterShuffleTime().containsKey(0))
    assert(!lcm.getUnregisterShuffleTime().containsKey(1))
  } finally {
    spark.stop()
  }
}

private def findAppShuffleId(rdd: RDD[_]): Int = {
val deps = rdd.dependencies
if (deps.size != 1 && !deps.head.isInstanceOf[ShuffleDependency[_, _, _]]) {
Expand Down

0 comments on commit c2ed014

Please sign in to comment.