Commit 4e8e8c2

[CELEBORN-1094] Optimize mechanism of ChunkManager expired shuffle key cleanup to avoid memory leak

### What changes were proposed in this pull request?

The `cleaner` of `Worker` now executes `StorageManager#cleanupExpiredShuffleKey` on a daemon cached thread pool to clean expired shuffle keys. This optimization speeds up cleanup, including the expired shuffle keys tracked by `ChunkManager`, to avoid a memory leak.
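
As a rough illustration of the pattern (not the actual `Worker` code), the sketch below hands the expensive disk cleanup to a bounded pool of daemon threads so the in-memory cleanup path is never blocked by slow I/O. `CleanupSketch`, `cleanupOnDisk`, and the fixed-size pool are illustrative stand-ins; the thread name and the 64-thread bound are taken from the diff below.

```scala
import java.util.concurrent.{Executors, ExecutorService, ThreadFactory}

object CleanupSketch {
  // Hypothetical stand-in for StorageManager#cleanupExpiredShuffleKey (slow disk I/O).
  private def cleanupOnDisk(expiredShuffleKeys: Set[String]): Unit =
    expiredShuffleKeys.foreach(key => println(s"deleting directories of shuffle $key"))

  private val daemonFactory: ThreadFactory = (r: Runnable) => {
    val t = new Thread(r, "worker-clean-expired-shuffle-keys")
    t.setDaemon(true)
    t
  }

  // Bounded by a configurable thread count, analogous to celeborn.worker.clean.threads.
  private val cleanPool: ExecutorService = Executors.newFixedThreadPool(64, daemonFactory)

  def cleanup(expiredShuffleKeys: Set[String]): Unit = {
    // In-memory state (partition info, stream ids, ...) is cleaned synchronously elsewhere;
    // the expensive directory deletion is handed off so it cannot delay the next round.
    cleanPool.execute(() => cleanupOnDisk(expiredShuffleKeys))
  }
}
```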

### Why are the changes needed?

`ChunkManager#streams` can leak memory when the worker cleans up expired shuffle keys more slowly than they expire. The way `ChunkStreamManager` cleans up expired shuffle keys should therefore be optimized to avoid this leak, which can drive a worker VM thread to 100% CPU.
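
A simplified, assumed model of the two maps involved and of the per-key removal the diff below performs; `StreamState` and `ChunkStreamSketch` are placeholders, and only the leak-relevant fields are kept.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

// Placeholder for the real stream state, which also holds buffers, metrics, etc.
final case class StreamState(shuffleKey: String)

object ChunkStreamSketch {
  // streamId -> open stream state; the map that grows without bound if cleanup falls behind.
  val streams = new ConcurrentHashMap[Long, StreamState]()
  // shuffleKey -> stream ids registered for that shuffle key.
  val shuffleStreamIds = new ConcurrentHashMap[String, java.util.Set[Long]]()

  def cleanupExpiredShuffleKey(expiredShuffleKeys: Set[String]): Unit =
    expiredShuffleKeys.foreach { key =>
      val expiredStreamIds = shuffleStreamIds.remove(key)
      // Normally empty, because stream ids are dropped once fully read; otherwise remove
      // every leftover stream so `streams` cannot accumulate dead entries.
      if (expiredStreamIds != null) expiredStreamIds.asScala.foreach(id => streams.remove(id))
    }
}
```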

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

`WorkerSuite#clean up`.
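
For reference, the two thread-count settings introduced by this PR (defaults 64 and 4, see the diff below) could be tuned as in the hedged, REPL-style snippet below, which assumes `CelebornConf` exposes a SparkConf-style `set(key, value)`.

```scala
import org.apache.celeborn.common.CelebornConf

// loadDefaults = false, matching CelebornConf's primary constructor shown in the diff below.
val conf = new CelebornConf(false)
// Threads draining expired shuffle keys handed to Worker#cleanup.
conf.set("celeborn.worker.clean.threads", "64")
// Threads per disk-cleaner pool deleting expired shuffle directories.
conf.set("celeborn.worker.disk.clean.threads", "4")
```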

Closes #2053 from SteNicholas/CELEBORN-1094.

Authored-by: SteNicholas <programgeek@163.com>
Signed-off-by: mingji <fengmingxiao.fmx@alibaba-inc.com>
SteNicholas authored and FMX committed Nov 2, 2023
1 parent 0583cdb commit 4e8e8c2
Showing 6 changed files with 90 additions and 74 deletions.

@@ -659,6 +659,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def workerReplicateThreads: Int = get(WORKER_REPLICATE_THREADS)
def workerCommitThreads: Int =
if (hasHDFSStorage) Math.max(128, get(WORKER_COMMIT_THREADS)) else get(WORKER_COMMIT_THREADS)
def workerCleanThreads: Int = get(WORKER_CLEAN_THREADS)
def workerShuffleCommitTimeout: Long = get(WORKER_SHUFFLE_COMMIT_TIMEOUT)
def minPartitionSizeToEstimate: Long = get(ESTIMATED_PARTITION_SIZE_MIN_SIZE)
def partitionSorterSortPartitionTimeout: Long = get(PARTITION_SORTER_SORT_TIMEOUT)
@@ -973,6 +974,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def workerDiskTimeSlidingWindowMinFetchCount: Int =
get(WORKER_DISKTIME_SLIDINGWINDOW_MINFETCHCOUNT)
def workerDiskReserveSize: Long = get(WORKER_DISK_RESERVE_SIZE)
def workerDiskCleanThreads: Int = get(WORKER_DISK_CLEAN_THREADS)
def workerDiskMonitorEnabled: Boolean = get(WORKER_DISK_MONITOR_ENABLED)
def workerDiskMonitorCheckList: Seq[String] = get(WORKER_DISK_MONITOR_CHECKLIST)
def workerDiskMonitorCheckInterval: Long = get(WORKER_DISK_MONITOR_CHECK_INTERVAL)
@@ -2140,6 +2142,14 @@ object CelebornConf extends Logging {
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("5G")

val WORKER_DISK_CLEAN_THREADS: ConfigEntry[Int] =
buildConf("celeborn.worker.disk.clean.threads")
.categories("worker")
.version("0.3.2")
.doc("Thread number of worker to clean up directories of expired shuffle keys on disk.")
.intConf
.createWithDefault(4)

val WORKER_CHECK_FILE_CLEAN_MAX_RETRIES: ConfigEntry[Int] =
buildConf("celeborn.worker.storage.checkDirsEmpty.maxRetries")
.withAlternative("celeborn.worker.disk.checkFileClean.maxRetries")
@@ -2292,6 +2302,14 @@
.intConf
.createWithDefault(32)

val WORKER_CLEAN_THREADS: ConfigEntry[Int] =
buildConf("celeborn.worker.clean.threads")
.categories("worker")
.version("0.3.2")
.doc("Thread number of worker to clean up expired shuffle keys.")
.intConf
.createWithDefault(64)

val WORKER_SHUFFLE_COMMIT_TIMEOUT: ConfigEntry[Long] =
buildConf("celeborn.worker.commitFiles.timeout")
.withAlternative("celeborn.worker.shuffle.commit.timeout")

2 changes: 2 additions & 0 deletions docs/configuration/worker.md
@@ -26,6 +26,7 @@ license: |
| celeborn.storage.hdfs.dir | &lt;undefined&gt; | HDFS base directory for Celeborn to store shuffle data. | 0.2.0 |
| celeborn.worker.activeConnection.max | &lt;undefined&gt; | If the number of active connections on a worker exceeds this configuration value, the worker will be marked as high-load in the heartbeat report, and the master will not include that node in the response of RequestSlots. | 0.3.1 |
| celeborn.worker.bufferStream.threadsPerMountpoint | 8 | Threads count for read buffer per mount point. | 0.3.0 |
| celeborn.worker.clean.threads | 64 | Thread number of worker to clean up expired shuffle keys. | 0.3.2 |
| celeborn.worker.closeIdleConnections | false | Whether worker will close idle connections. | 0.2.0 |
| celeborn.worker.commitFiles.threads | 32 | Thread number of worker to commit shuffle data files asynchronously. It's recommended to set at least `128` when `HDFS` is enabled in `celeborn.storage.activeTypes`. | 0.3.0 |
| celeborn.worker.commitFiles.timeout | 120s | Timeout for a Celeborn worker to commit files of a shuffle. It's recommended to set at least `240s` when `HDFS` is enabled in `celeborn.storage.activeTypes`. | 0.3.0 |
@@ -42,6 +43,7 @@ license: |
| celeborn.worker.directMemoryRatioToPauseReceive | 0.85 | If direct memory usage reaches this limit, the worker will stop to receive data from Celeborn shuffle clients. | 0.2.0 |
| celeborn.worker.directMemoryRatioToPauseReplicate | 0.95 | If direct memory usage reaches this limit, the worker will stop to receive replication data from other workers. This value should be higher than celeborn.worker.directMemoryRatioToPauseReceive. | 0.2.0 |
| celeborn.worker.directMemoryRatioToResume | 0.7 | If direct memory usage is less than this limit, worker will resume. | 0.2.0 |
| celeborn.worker.disk.clean.threads | 4 | Thread number of worker to clean up directories of expired shuffle keys on disk. | 0.3.2 |
| celeborn.worker.fetch.heartbeat.enabled | false | enable the heartbeat from worker to client when fetching data | 0.3.0 |
| celeborn.worker.fetch.io.threads | &lt;undefined&gt; | Netty IO thread number of worker to handle client fetch data. The default threads number is the number of flush thread. | 0.2.0 |
| celeborn.worker.fetch.port | 0 | Server port for Worker to receive fetch data request from ShuffleClient. | 0.2.0 |

@@ -26,8 +26,6 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -91,9 +89,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len
}

FileManagedBuffers buffers = state.buffers;
ManagedBuffer nextChunk = buffers.chunk(chunkIndex, offset, len);

return nextChunk;
return buffers.chunk(chunkIndex, offset, len);
}

public TimeWindow getFetchTimeMetric(long streamId) {
@@ -105,20 +101,6 @@ public TimeWindow getFetchTimeMetric(long streamId) {
}
}

public static String genStreamChunkId(long streamId, int chunkId) {
return String.format("%d_%d", streamId, chunkId);
}

// Parse streamChunkId to be stream id and chunk id. This is used when fetch remote chunk as a
// stream.
public static Pair<Long, Integer> parseStreamChunkId(String streamChunkId) {
String[] array = streamChunkId.split("_");
assert array.length == 2 : "Stream id and chunk index should be specified.";
long streamId = Long.parseLong(array[0]);
int chunkIndex = Integer.parseInt(array[1]);
return ImmutablePair.of(streamId, chunkIndex);
}

public void chunkBeingSent(long streamId) {
StreamState streamState = streams.get(streamId);
if (streamState != null) {
@@ -184,14 +166,21 @@ public long nextStreamId() {
}

public void cleanupExpiredShuffleKey(Set<String> expiredShuffleKeys) {
logger.info(
"Clean up expired shuffle keys {}",
String.join(",", expiredShuffleKeys.toArray(new String[0])));
for (String expiredShuffleKey : expiredShuffleKeys) {
Set<Long> expiredStreamIds = shuffleStreamIds.remove(expiredShuffleKey);

// normally expiredStreamIds set will be empty as streamId will be removed when be fully read
if (expiredStreamIds != null && !expiredStreamIds.isEmpty()) {
streams.keySet().removeAll(expiredStreamIds);
expiredStreamIds.forEach(streams::remove);
}
}
logger.info(
"Cleaned up expired shuffle keys. The count of shuffle keys and streams: {}, {}",
shuffleStreamIds.size(),
streams.size());
}

public Tuple2<String, String> getShuffleKeyAndFileName(long streamId) {

@@ -255,7 +255,11 @@ private[celeborn] class Worker(
val replicateThreadPool: ThreadPoolExecutor =
ThreadUtils.newDaemonCachedThreadPool("worker-replicate-data", conf.workerReplicateThreads)
val commitThreadPool: ThreadPoolExecutor =
ThreadUtils.newDaemonCachedThreadPool("Worker-CommitFiles", conf.workerCommitThreads)
ThreadUtils.newDaemonCachedThreadPool("worker-commit-files", conf.workerCommitThreads)
val cleanThreadPool: ThreadPoolExecutor =
ThreadUtils.newDaemonCachedThreadPool(
"worker-clean-expired-shuffle-keys",
conf.workerCleanThreads)
val asyncReplyPool: ScheduledExecutorService =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("async-reply")
val timer = new HashedWheelTimer()
@@ -405,7 +409,7 @@ private[celeborn] class Worker(
while (true) {
val expiredShuffleKeys = cleanTaskQueue.take()
try {
cleanup(expiredShuffleKeys)
cleanup(expiredShuffleKeys, cleanThreadPool)
} catch {
case e: Throwable =>
logError("Cleanup failed", e)
@@ -562,20 +566,23 @@ private[celeborn] class Worker(
}

@VisibleForTesting
def cleanup(expiredShuffleKeys: JHashSet[String]): Unit = synchronized {
expiredShuffleKeys.asScala.foreach { shuffleKey =>
partitionLocationInfo.removeShuffle(shuffleKey)
shufflePartitionType.remove(shuffleKey)
shufflePushDataTimeout.remove(shuffleKey)
shuffleMapperAttempts.remove(shuffleKey)
shuffleCommitInfos.remove(shuffleKey)
workerInfo.releaseSlots(shuffleKey)
logInfo(s"Cleaned up expired shuffle $shuffleKey")
def cleanup(expiredShuffleKeys: JHashSet[String], threadPool: ThreadPoolExecutor): Unit =
synchronized {
expiredShuffleKeys.asScala.foreach { shuffleKey =>
partitionLocationInfo.removeShuffle(shuffleKey)
shufflePartitionType.remove(shuffleKey)
shufflePushDataTimeout.remove(shuffleKey)
shuffleMapperAttempts.remove(shuffleKey)
shuffleCommitInfos.remove(shuffleKey)
workerInfo.releaseSlots(shuffleKey)
logInfo(s"Cleaned up expired shuffle $shuffleKey")
}
partitionsSorter.cleanup(expiredShuffleKeys)
fetchHandler.cleanupExpiredShuffleKey(expiredShuffleKeys)
threadPool.execute(new Runnable {
override def run(): Unit = storageManager.cleanupExpiredShuffleKey(expiredShuffleKeys)
})
}
partitionsSorter.cleanup(expiredShuffleKeys)
storageManager.cleanupExpiredShuffleKey(expiredShuffleKeys)
fetchHandler.cleanupExpiredShuffleKey(expiredShuffleKeys)
}

override def getWorkerInfo: String = {
val sb = new StringBuilder

@@ -29,6 +29,7 @@ import scala.collection.JavaConverters._
import scala.concurrent.duration._

import io.netty.buffer.PooledByteBufAllocator
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.permission.FsPermission

@@ -88,7 +89,9 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
diskInfo =>
cleaners.put(
diskInfo.mountPoint,
ThreadUtils.newDaemonCachedThreadPool(s"Disk-cleaner-${diskInfo.mountPoint}", 1))
ThreadUtils.newDaemonCachedThreadPool(
s"disk-cleaner-${diskInfo.mountPoint}",
conf.workerDiskCleanThreads))
}
cleaners
}
@@ -156,7 +159,7 @@

override def notifyError(mountPoint: String, diskStatus: DiskStatus): Unit = this.synchronized {
if (diskStatus == DiskStatus.CRITICAL_ERROR) {
logInfo(s"Disk ${mountPoint} faces critical error, will remove its disk operator.")
logInfo(s"Disk $mountPoint faces critical error, will remove its disk operator.")
val operator = diskOperators.remove(mountPoint)
if (operator != null) {
operator.shutdown()
@@ -168,15 +171,17 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
if (!diskOperators.containsKey(mountPoint)) {
diskOperators.put(
mountPoint,
ThreadUtils.newDaemonCachedThreadPool(s"Disk-cleaner-${mountPoint}", 1))
ThreadUtils.newDaemonCachedThreadPool(
s"disk-cleaner-$mountPoint",
conf.workerDiskCleanThreads))
}
}

private val counter = new AtomicInteger()
private val counterOperator = new IntUnaryOperator() {
override def applyAsInt(operand: Int): Int = {
val dirs = healthyWorkingDirs()
if (dirs.length > 0) {
if (dirs.nonEmpty) {
(operand + 1) % dirs.length
} else 0
}
@@ -254,12 +259,12 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
val shuffleKey = parseDbShuffleKey(key)
try {
val files = PbSerDeUtils.fromPbFileInfoMap(entry.getValue, cache)
logDebug(s"Reload DB: ${shuffleKey} -> ${files}")
logDebug(s"Reload DB: $shuffleKey -> $files")
fileInfos.put(shuffleKey, files)
db.delete(entry.getKey)
} catch {
case exception: Exception =>
logError(s"Reload DB: ${shuffleKey} failed.", exception)
logError(s"Reload DB: $shuffleKey failed.", exception)
}
} else {
return
@@ -523,7 +528,7 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
val hdfsFileWriter = hdfsWriters.get(fileInfo.getFilePath)
if (hdfsFileWriter != null) {
hdfsFileWriter.destroy(new IOException(
s"Destroy FileWriter ${hdfsFileWriter} caused by shuffle ${shuffleKey} expired."))
s"Destroy FileWriter $hdfsFileWriter caused by shuffle $shuffleKey expired."))
hdfsWriters.remove(fileInfo.getFilePath)
}
} else {
@@ -534,7 +539,7 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
val fileWriter = writers.get(fileInfo.getFilePath)
if (fileWriter != null) {
fileWriter.destroy(new IOException(
s"Destroy FileWriter ${fileWriter} caused by shuffle ${shuffleKey} expired."))
s"Destroy FileWriter $fileWriter caused by shuffle $shuffleKey expired."))
writers.remove(fileInfo.getFilePath)
}
}
@@ -611,9 +616,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
.filter(diskInfo =>
diskInfo.status == DiskStatus.HEALTHY
|| diskInfo.status == DiskStatus.HIGH_DISK_USAGE)
.map { case diskInfo =>
(diskInfo, diskInfo.dirs.filter(_.exists).flatMap(_.listFiles()))
}
.map(diskInfo =>
(diskInfo, diskInfo.dirs.filter(_.exists).flatMap(_.listFiles())))
val appIds = shuffleKeySet().asScala.map(key => Utils.splitShuffleKey(key)._1)

diskInfoAndAppDirs.foreach { case (diskInfo, appDirs) =>
@@ -629,34 +633,25 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
}

private def deleteDirectory(dir: File, threadPool: ThreadPoolExecutor): Unit = {
val allContents = dir.listFiles
if (allContents != null) {
for (file <- allContents) {
deleteDirectory(file, threadPool)
}
if (dir.exists()) {
threadPool.submit(new Runnable {
override def run(): Unit = {
deleteDirectoryWithRetry(dir)
}
})
}
threadPool.submit(new Runnable {
override def run(): Unit = {
deleteFileWithRetry(dir)
}
})
}

private def deleteFileWithRetry(file: File): Unit = {
if (file.exists()) {
var retryCount = 0
var deleteSuccess = false
while (!deleteSuccess && retryCount <= 3) {
deleteSuccess = file.delete()
retryCount = retryCount + 1
if (!deleteSuccess) {
Thread.sleep(200 * retryCount)
}
}
if (deleteSuccess) {
logDebug(s"Deleted expired shuffle file $file.")
} else {
logWarning(s"Failed to delete expired shuffle file $file.")
private def deleteDirectoryWithRetry(dir: File): Unit = {
var retryCount = 0
var deleteSuccess = false
while (!deleteSuccess && retryCount <= 3) {
try {
FileUtils.deleteDirectory(dir)
deleteSuccess = true
} catch {
case _: IOException =>
retryCount = retryCount + 1
}
}
}
@@ -696,7 +691,7 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs
retryTimes += 1
if (retryTimes < conf.workerCheckFileCleanMaxRetries) {
logInfo(s"Working directory's files have not been cleaned up completely, " +
s"will start ${retryTimes + 1}th attempt after ${workerCheckFileCleanTimeout} milliseconds.")
s"will start ${retryTimes + 1}th attempt after $workerCheckFileCleanTimeout milliseconds.")
}
Thread.sleep(workerCheckFileCleanTimeout)
}

@@ -31,7 +31,7 @@ import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType}
import org.apache.celeborn.common.util.{CelebornExitKind, JavaUtils}
import org.apache.celeborn.common.util.{CelebornExitKind, JavaUtils, ThreadUtils}
import org.apache.celeborn.service.deploy.worker.{Worker, WorkerArguments}

class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach {
@@ -83,7 +83,12 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach {
val shuffleKey2 = "2-2"
expiredShuffleKeys.add(shuffleKey1)
expiredShuffleKeys.add(shuffleKey2)
worker.cleanup(expiredShuffleKeys)
worker.cleanup(
expiredShuffleKeys,
ThreadUtils.newDaemonCachedThreadPool(
"worker-clean-expired-shuffle-keys",
conf.workerCleanThreads))
Thread.sleep(3000)
worker.storageManager.workingDirWriters.values().asScala.map(t => assert(t.size() == 0))
}
