Skip to content

Commit

Permalink
[CELEBORN-1048] Align fetchWaitTime metrics to spark implementation
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Align fetchWaitTime metrics to spark implementation

### Why are the changes needed?
In our production environment, there are variations in the fetchWaitTime metric for the same stage of the same job.

ON YARN ESS:
![image](https://github.com/apache/incubator-celeborn/assets/68682646/601a8315-1317-48dc-b9a6-7ea651d5122d)
ON CELEBORN
![image](https://github.com/apache/incubator-celeborn/assets/68682646/e00ed60f-3789-4330-a7ed-fdd5754acf1d)
Then, based on the implementation of Spark ShuffleBlockFetcherIterator, I made adjustments to the fetchWaitTime metrics code

Now, looks like more reasonable, 
![image](https://github.com/apache/incubator-celeborn/assets/68682646/ce5e46e4-8ed2-422e-b54b-cd594aad73dd)
### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
yes, tested in our production environment

Closes #2000 from TongWei1105/CELEBORN-1048.

Lead-authored-by: TongWei1105 <vvtwow@gmail.com>
Co-authored-by: Keyong Zhou <zhouky@apache.org>
Co-authored-by: zky.zhoukeyong <zky.zhoukeyong@alibaba-inc.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
  • Loading branch information
3 people authored and pan3793 committed Nov 2, 2023
1 parent e437228 commit 0583cdb
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class CelebornShuffleReader[K, C](
partitionId,
context.attemptNumber(),
startMapIndex,
endMapIndex)
endMapIndex,
metricsCallback)
streams.put(partitionId, inputStream)
} catch {
case e: IOException =>
Expand Down Expand Up @@ -119,7 +120,6 @@ class CelebornShuffleReader[K, C](
}
metricsCallback.incReadTime(
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
inputStream.setCallback(metricsCallback)
// ensure inputStream is closed when task completes
context.addTaskCompletionListener(_ => inputStream.close())
inputStream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class CelebornShuffleReader[K, C](
partitionId,
context.attemptNumber(),
startMapIndex,
endMapIndex)
endMapIndex,
metricsCallback)
streams.put(partitionId, inputStream)
} catch {
case e: IOException =>
Expand Down Expand Up @@ -121,7 +122,6 @@ class CelebornShuffleReader[K, C](
}
metricsCallback.incReadTime(
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
inputStream.setCallback(metricsCallback)
// ensure inputStream is closed when task completes
context.addTaskCompletionListener[Unit](_ => inputStream.close())
inputStream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.slf4j.LoggerFactory;

import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.protocol.PartitionLocation;
Expand Down Expand Up @@ -191,7 +192,12 @@ public abstract void mapPartitionMapperEnd(
* @throws IOException
*/
public abstract CelebornInputStream readPartition(
int shuffleId, int partitionId, int attemptNumber, int startMapIndex, int endMapIndex)
int shuffleId,
int partitionId,
int attemptNumber,
int startMapIndex,
int endMapIndex,
MetricsCallback metricsCallback)
throws IOException;

public abstract boolean cleanupShuffle(int shuffleId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import org.apache.celeborn.client.compress.Compressor;
import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.identity.UserIdentifier;
Expand Down Expand Up @@ -1585,7 +1586,12 @@ protected ReduceFileGroups loadFileGroup(int shuffleId, int partitionId) throws

@Override
public CelebornInputStream readPartition(
int shuffleId, int partitionId, int attemptNumber, int startMapIndex, int endMapIndex)
int shuffleId,
int partitionId,
int attemptNumber,
int startMapIndex,
int endMapIndex,
MetricsCallback metricsCallback)
throws IOException {
ReduceFileGroups fileGroups = loadFileGroup(shuffleId, partitionId);

Expand All @@ -1604,7 +1610,8 @@ public CelebornInputStream readPartition(
attemptNumber,
startMapIndex,
endMapIndex,
fetchExcludedWorkers);
fetchExcludedWorkers,
metricsCallback);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public static CelebornInputStream create(
int attemptNumber,
int startMapIndex,
int endMapIndex,
ConcurrentHashMap<String, Long> fetchExcludedWorkers)
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
MetricsCallback metricsCallback)
throws IOException {
if (locations == null || locations.length == 0) {
return emptyInputStream;
Expand All @@ -70,16 +71,15 @@ public static CelebornInputStream create(
attemptNumber,
startMapIndex,
endMapIndex,
fetchExcludedWorkers);
fetchExcludedWorkers,
metricsCallback);
}
}

public static CelebornInputStream empty() {
return emptyInputStream;
}

public abstract void setCallback(MetricsCallback callback);

private static final CelebornInputStream emptyInputStream =
new CelebornInputStream() {
@Override
Expand All @@ -92,9 +92,6 @@ public int read(byte[] b, int off, int len) throws IOException {
return -1;
}

@Override
public void setCallback(MetricsCallback callback) {}

@Override
public int totalPartitionsToRead() {
return 0;
Expand Down Expand Up @@ -164,7 +161,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
int attemptNumber,
int startMapIndex,
int endMapIndex,
ConcurrentHashMap<String, Long> fetchExcludedWorkers)
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
MetricsCallback metricsCallback)
throws IOException {
this.conf = conf;
this.clientFactory = clientFactory;
Expand Down Expand Up @@ -202,6 +200,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
TransportConf transportConf =
Utils.fromCelebornConf(conf, TransportModuleConstants.DATA_MODULE, 0);
retryWaitMs = transportConf.ioRetryWaitTimeMs();
this.callback = metricsCallback;
moveToNextReader();
}

Expand Down Expand Up @@ -418,7 +417,7 @@ private PartitionReader createReader(
logger.debug("Read local shuffle file {}", localHostAddress);
containLocalRead = true;
return new LocalPartitionReader(
conf, shuffleKey, location, clientFactory, startMapIndex, endMapIndex);
conf, shuffleKey, location, clientFactory, startMapIndex, endMapIndex, callback);
} else {
return new WorkerPartitionReader(
conf,
Expand All @@ -428,22 +427,18 @@ private PartitionReader createReader(
startMapIndex,
endMapIndex,
fetchChunkRetryCnt,
fetchChunkMaxRetry);
fetchChunkMaxRetry,
callback);
}
case HDFS:
return new DfsPartitionReader(
conf, shuffleKey, location, clientFactory, startMapIndex, endMapIndex);
conf, shuffleKey, location, clientFactory, startMapIndex, endMapIndex, callback);
default:
throw new CelebornIOException(
String.format("Unknown storage info %s to read location %s", storageInfo, location));
}
}

public void setCallback(MetricsCallback callback) {
// callback must set before read()
this.callback = callback;
}

@Override
public int read() throws IOException {
if (position < limit) {
Expand Down Expand Up @@ -539,8 +534,6 @@ private boolean fillBuffer() throws IOException {
return false;
}

long startTime = System.nanoTime();

boolean hasData = false;
while (currentChunk.isReadable() || moveToNextChunk()) {
currentChunk.readBytes(sizeBuf);
Expand Down Expand Up @@ -572,9 +565,7 @@ private boolean fillBuffer() throws IOException {
Set<Integer> batchSet = batchesRead.get(mapId);
if (!batchSet.contains(batchId)) {
batchSet.add(batchId);
if (callback != null) {
callback.incBytesRead(BATCH_HEADER_SIZE + size);
}
callback.incBytesRead(BATCH_HEADER_SIZE + size);
if (shuffleCompressionEnabled) {
// decompress data
int originalLength = decompressor.getOriginalLen(compressedBuf);
Expand All @@ -598,9 +589,6 @@ private boolean fillBuffer() throws IOException {
}
}

if (callback != null) {
callback.incReadTime(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime));
}
return hasData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,22 @@ public class DfsPartitionReader implements PartitionReader {
private int currentChunkIndex = 0;
private TransportClient client;
private PbStreamHandler streamHandler;
private MetricsCallback metricsCallback;

public DfsPartitionReader(
CelebornConf conf,
String shuffleKey,
PartitionLocation location,
TransportClientFactory clientFactory,
int startMapIndex,
int endMapIndex)
int endMapIndex,
MetricsCallback metricsCallback)
throws IOException {
shuffleChunkSize = conf.dfsReadChunkSize();
fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight();
results = new LinkedBlockingQueue<>();

this.metricsCallback = metricsCallback;
this.location = location;

final List<Long> chunkOffsets = new ArrayList<>();
Expand Down Expand Up @@ -224,7 +227,10 @@ public ByteBuf next() throws IOException, InterruptedException {
try {
while (chunk == null) {
checkException();
Long startFetchWait = System.nanoTime();
chunk = results.poll(500, TimeUnit.MILLISECONDS);
metricsCallback.incReadTime(
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait));
logger.debug("poll result with result size: {}", results.size());
}
} catch (InterruptedException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,16 @@ public class LocalPartitionReader implements PartitionReader {
private AtomicBoolean pendingFetchTask = new AtomicBoolean(false);
private PbStreamHandler streamHandler;
private TransportClient client;
private MetricsCallback metricsCallback;

public LocalPartitionReader(
CelebornConf conf,
String shuffleKey,
PartitionLocation location,
TransportClientFactory clientFactory,
int startMapIndex,
int endMapIndex)
int endMapIndex,
MetricsCallback metricsCallback)
throws IOException {
if (readLocalShufflePool == null) {
synchronized (LocalPartitionReader.class) {
Expand All @@ -88,6 +90,7 @@ public LocalPartitionReader(
fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight();
results = new LinkedBlockingQueue<>();
this.location = location;
this.metricsCallback = metricsCallback;
long fetchTimeoutMs = conf.clientFetchTimeoutMs();
try {
client = clientFactory.createClient(location.getHost(), location.getFetchPort(), 0);
Expand Down Expand Up @@ -199,7 +202,10 @@ public ByteBuf next() throws IOException, InterruptedException {
try {
while (chunk == null) {
checkException();
Long startFetchWait = System.nanoTime();
chunk = results.poll(100, TimeUnit.MILLISECONDS);
metricsCallback.incReadTime(
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait));
logger.debug("Poll result with result size: {}", results.size());
}
} catch (InterruptedException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class WorkerPartitionReader implements PartitionReader {
private final TransportClientFactory clientFactory;
private PbStreamHandler streamHandler;
private TransportClient client;
private MetricsCallback metricsCallback;

private int returnedChunks;
private int chunkIndex;
Expand All @@ -76,11 +77,13 @@ public class WorkerPartitionReader implements PartitionReader {
int startMapIndex,
int endMapIndex,
int fetchChunkRetryCnt,
int fetchChunkMaxRetry)
int fetchChunkMaxRetry,
MetricsCallback metricsCallback)
throws IOException, InterruptedException {
fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight();
results = new LinkedBlockingQueue<>();
fetchTimeoutMs = conf.clientFetchTimeoutMs();
this.metricsCallback = metricsCallback;
// only add the buffer to results queue if this reader is not closed.
callback =
new ChunkReceivedCallback() {
Expand Down Expand Up @@ -144,7 +147,10 @@ public ByteBuf next() throws IOException, InterruptedException {
try {
while (chunk == null) {
checkException();
Long startFetchWait = System.nanoTime();
chunk = results.poll(500, TimeUnit.MILLISECONDS);
metricsCallback.incReadTime(
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait));
}
} catch (InterruptedException e) {
logger.error("PartitionReader thread interrupted while polling data.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.slf4j.LoggerFactory;

import org.apache.celeborn.client.read.CelebornInputStream;
import org.apache.celeborn.client.read.MetricsCallback;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
Expand Down Expand Up @@ -112,7 +113,12 @@ public void cleanup(int shuffleId, int mapId, int attemptId) {}

@Override
public CelebornInputStream readPartition(
int shuffleId, int partitionId, int attemptNumber, int startMapIndex, int endMapIndex) {
int shuffleId,
int partitionId,
int attemptNumber,
int startMapIndex,
int endMapIndex,
MetricsCallback metricsCallback) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
import org.junit.Assert

import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.util.JavaUtils.timeOutOrMeetCondition
Expand Down Expand Up @@ -140,12 +141,17 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
// reduce file group size (for empty partitions)
Assert.assertEquals(shuffleClient.getReduceFileGroupsMap.size(), 0)

val metricsCallback = new MetricsCallback {
override def incBytesRead(bytesWritten: Long): Unit = {}
override def incReadTime(time: Long): Unit = {}
}

// reduce normal empty CelebornInputStream
var stream = shuffleClient.readPartition(shuffleId, 1, 1, 0, Integer.MAX_VALUE)
var stream = shuffleClient.readPartition(shuffleId, 1, 1, 0, Integer.MAX_VALUE, metricsCallback)
Assert.assertEquals(stream.read(), -1)

// reduce normal null partition for CelebornInputStream
stream = shuffleClient.readPartition(shuffleId, 3, 1, 0, Integer.MAX_VALUE)
stream = shuffleClient.readPartition(shuffleId, 3, 1, 0, Integer.MAX_VALUE, metricsCallback)
Assert.assertEquals(stream.read(), -1)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.internal.Logging
Expand Down Expand Up @@ -102,7 +103,11 @@ trait ReadWriteTestBase extends AnyFunSuite

shuffleClient.mapperEnd(1, 0, 0, 1)

val inputStream = shuffleClient.readPartition(1, 0, 0, 0, Integer.MAX_VALUE)
val metricsCallback = new MetricsCallback {
override def incBytesRead(bytesWritten: Long): Unit = {}
override def incReadTime(time: Long): Unit = {}
}
val inputStream = shuffleClient.readPartition(1, 0, 0, 0, Integer.MAX_VALUE, metricsCallback)
val outputStream = new ByteArrayOutputStream()

var b = inputStream.read()
Expand Down

0 comments on commit 0583cdb

Please sign in to comment.