fetch failure integration #46

Merged: 4 commits, Dec 27, 2024
CelebornFetchFailureSuite.scala
@@ -17,22 +17,19 @@

package org.apache.celeborn.tests.spark

import java.io.{File, IOException}
import java.io.IOException
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.shuffle.celeborn.{SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.celeborn.service.deploy.worker.Worker

class CelebornFetchFailureSuite extends AnyFunSuite
with SparkTestBase
@@ -46,57 +43,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite
System.gc()
}

var workerDirs: Seq[String] = Seq.empty

override def createWorker(map: Map[String, String]): Worker = {
val storageDir = createTmpDir()
this.synchronized {
workerDirs = workerDirs :+ storageDir
}
super.createWorker(map, storageDir)
}

class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook {
var executed: AtomicBoolean = new AtomicBoolean(false)
val lock = new Object

override def exec(
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): Unit = {
if (executed.get()) return

lock.synchronized {
handle match {
case h: CelebornShuffleHandle[_, _, _] => {
val appUniqueId = h.appUniqueId
val shuffleClient = ShuffleClient.get(
h.appUniqueId,
h.lifecycleManagerHost,
h.lifecycleManagerPort,
conf,
h.userIdentifier,
h.extension)
val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
val allFiles = workerDirs.map(dir => {
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
})
val datafile = allFiles.filter(_.exists())
.flatMap(_.listFiles().iterator).headOption
datafile match {
case Some(file) => file.delete()
case None => throw new RuntimeException("unexpected, there must be some data file" +
s" under ${workerDirs.mkString(",")}")
}
}
case _ => throw new RuntimeException("unexpected, only support CelebornShuffleHandle here")
}
executed.set(true)
}
}
}

test("celeborn spark integration test - Fetch Failure") {
if (Spark3OrNewer) {
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
@@ -111,7 +57,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

val value = Range(1, 10000).mkString(",")
@@ -184,7 +130,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

import sparkSession.implicits._
@@ -215,7 +161,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

val sc = sparkSession.sparkContext
@@ -255,7 +201,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

val sc = sparkSession.sparkContext
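Each test in this suite follows the same pattern: register the fetch-failure hook with the test shuffle manager, run a shuffle job, and verify the job still succeeds once Spark recomputes the failed stage. A minimal condensed sketch of that pattern (not part of the diff; it reuses updateSparkConf, ShuffleReaderFetchFailureGetHook, and TestCelebornShuffleManager from the code above, and assumes a mini cluster is already running):

// Sketch only: condensed from the tests above.
val sparkConf = updateSparkConf(
  new SparkConf().setAppName("rss-demo").setMaster("local[2,3]"),
  ShuffleMode.HASH)
val spark = SparkSession.builder()
  .config(sparkConf)
  .config(
    "spark.shuffle.manager",
    "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
  .getOrCreate()

// The hook deletes one shuffle data file the first time a reader is
// requested, which surfaces as a fetch failure in the reduce stage.
val hook = new ShuffleReaderFetchFailureGetHook(
  SparkUtils.fromSparkConf(spark.sparkContext.getConf))
TestCelebornShuffleManager.registerReaderGetHook(hook)

// Spark re-runs the map stage, so the job still produces complete output.
val counts = spark.sparkContext.parallelize(1 to 10000, 2)
  .map(i => (i % 100, i))
  .groupByKey(10)
  .collect()
assert(hook.executed.get())
assert(counts.length == 100)
spark.stop()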
SparkTestBase.scala
@@ -17,19 +17,26 @@

package org.apache.celeborn.tests.spark

import java.io.File
import java.util.concurrent.atomic.AtomicBoolean

import scala.util.Random

import org.apache.spark.SPARK_VERSION
import org.apache.spark.SparkConf
import org.apache.spark.{SPARK_VERSION, SparkConf, TaskContext}
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkUtils}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf._
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.celeborn.service.deploy.MiniClusterFeature
import org.apache.celeborn.service.deploy.worker.Worker

trait SparkTestBase extends AnyFunSuite
with Logging with MiniClusterFeature with BeforeAndAfterAll with BeforeAndAfterEach {
@@ -52,6 +59,16 @@ trait SparkTestBase extends AnyFunSuite
shutdownMiniCluster()
}

var workerDirs: Seq[String] = Seq.empty

override def createWorker(map: Map[String, String]): Worker = {
val storageDir = createTmpDir()
this.synchronized {
workerDirs = workerDirs :+ storageDir
}
super.createWorker(map, storageDir)
}

def updateSparkConf(sparkConf: SparkConf, mode: ShuffleMode): SparkConf = {
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.set(
@@ -98,4 +115,45 @@ trait SparkTestBase extends AnyFunSuite
val outMap = result.collect().map(row => row.getString(0) -> row.getLong(1)).toMap
outMap
}

class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends ShuffleManagerHook {
var executed: AtomicBoolean = new AtomicBoolean(false)
val lock = new Object

override def exec(
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): Unit = {
if (executed.get()) return

lock.synchronized {
handle match {
case h: CelebornShuffleHandle[_, _, _] => {
val appUniqueId = h.appUniqueId
val shuffleClient = ShuffleClient.get(
h.appUniqueId,
h.lifecycleManagerHost,
h.lifecycleManagerPort,
conf,
h.userIdentifier,
h.extension)
val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
val allFiles = workerDirs.map(dir => {
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
})
val datafile = allFiles.filter(_.exists())
.flatMap(_.listFiles().iterator).sortBy(_.getName).headOption
datafile match {
case Some(file) => file.delete()
case None => throw new RuntimeException("unexpected, there must be some data file" +
s" under ${workerDirs.mkString(",")}")
}
}
case _ => throw new RuntimeException("unexpected, only support CelebornShuffleHandle here")
}
executed.set(true)
}
}
}
}
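The hook finds its victim by scanning every storage directory recorded by the createWorker override for the on-disk layout Celeborn workers use. A standalone sketch of that lookup (a hypothetical helper, not part of the diff; the path layout is taken from the hook above):

import java.io.File

// Hypothetical helper mirroring the hook's lookup: each tracked worker dir
// holds shuffle files under celeborn-worker/shuffle_data/<app>/<shuffleId>;
// sorting by name makes the chosen victim deterministic across runs.
def firstShuffleDataFile(
    workerDirs: Seq[String],
    appUniqueId: String,
    celebornShuffleId: Int): Option[File] = {
  workerDirs
    .map(dir => new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId"))
    .filter(_.exists())
    .flatMap(dir => Option(dir.listFiles()).toSeq.flatten)
    .sortBy(_.getName)
    .headOption
}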
SparkUtilsSuite.scala
@@ -17,6 +17,8 @@

package org.apache.spark.shuffle.celeborn

import scala.collection.JavaConverters._

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.sql.SparkSession
@@ -54,13 +56,19 @@ class SparkUtilsSuite extends AnyFunSuite
"org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

try {
val sc = sparkSession.sparkContext
val jobThread = new Thread {
override def run(): Unit = {
try {
sc.parallelize(1 to 100, 2)
.repartition(1)
val value = Range(1, 10000).mkString(",")
sc.parallelize(1 to 10000, 2)
.map { i => (i, value) }
.groupByKey(10)
.mapPartitions { iter =>
Thread.sleep(3000)
iter
@@ -73,13 +81,15 @@
jobThread.start()

val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
eventually(timeout(3.seconds), interval(100.milliseconds)) {
val taskId = 0
val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId)
eventually(timeout(30.seconds), interval(0.milliseconds)) {
assert(hook.executed.get())
val reportedTaskId =
SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap(
_.asScala).head
val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId)
assert(taskSetManager != null)
assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1)
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId))
assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1)
assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1)
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId))
}

sparkSession.sparkContext.cancelAllJobs()
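The new assertions poll instead of asserting once: the hook fires on a background job thread, so the test retries with ScalaTest's Eventually until the fetch failure has been reported and the task-set state can be inspected. A minimal self-contained sketch of that polling pattern (assuming ScalaTest 3.x; the suite and condition here are illustrative, not from the diff):

import org.scalatest.concurrent.Eventually
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.SpanSugar._

// Illustrative example of eventually(timeout(...), interval(...)).
class PollingExampleSuite extends AnyFunSuite with Eventually {
  test("retry an assertion until a background condition holds") {
    @volatile var done = false
    new Thread(() => { Thread.sleep(500); done = true }).start()
    // Re-evaluates the block every 100 ms, failing only if 30 s elapse
    // before the assertion passes.
    eventually(timeout(30.seconds), interval(100.milliseconds)) {
      assert(done)
    }
  }
}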