Skip to content

Commit

Permalink
ut
Browse files Browse the repository at this point in the history
  • Loading branch information
turboFei committed Nov 21, 2024
1 parent 2be9682 commit 323ff65
Showing 1 changed file with 57 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,63 @@ class CelebornFetchFailureSuite extends AnyFunSuite
}
}

test("celeborn spark integration test - prevent stage re-run if another task attempt is running or successful") {
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2]")
val sparkSession = SparkSession.builder()
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))
.config("spark.sql.shuffle.partitions", 2)
.config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
.config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
.config(
"spark.shuffle.manager",
"org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
.config("spark.speculation", "true")
.config("spark.speculation.multiplier", "0.0")
.config("spark.speculation.quantile", "0.1")
.getOrCreate()

try {
val sc = sparkSession.sparkContext

val groupedRdd = sc
.parallelize(0 until 10000, 2)
.map(v => (v, v))
.groupByKey()

val appShuffleId = findAppShuffleId(groupedRdd)

val result = groupedRdd.mapPartitions { iter =>
val context = TaskContext.get()
if (context.stageAttemptNumber() == 0 && context.partitionId() == 0 && context.attemptNumber() == 0) {
Thread.sleep(3000)
throw ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER.makeFetchFailureException(
appShuffleId,
-1,
context.partitionId(),
new IOException("forced"))
} else if (context.attemptNumber() > 0) {
Thread.sleep(3000)
}
iter
}.collect()

assert(result.size == 10000)

val shuffleMgr = SparkContextHelper.env
.shuffleManager
.asInstanceOf[TestCelebornShuffleManager]
val lifecycleManager = shuffleMgr.getLifecycleManager

// assert no stage re-run
assert(lifecycleManager.getUnregisterShuffleTime().isEmpty)
shuffleMgr.unregisterShuffle(0)
assert(lifecycleManager.getUnregisterShuffleTime().containsKey(0))
assert(!lifecycleManager.getUnregisterShuffleTime().containsKey(1))
} finally {
sparkSession.stop()
}
}

private def findAppShuffleId(rdd: RDD[_]): Int = {
val deps = rdd.dependencies
if (deps.size != 1 && !deps.head.isInstanceOf[ShuffleDependency[_, _, _]]) {
Expand Down

0 comments on commit 323ff65

Please sign in to comment.