diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java index 4403120d52..06973d0797 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java @@ -80,6 +80,7 @@ import org.apache.uniffle.common.util.Constants; import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.hadoop.shim.HadoopShimImpl; +import org.apache.uniffle.proto.RssProtos.MergeContext; import org.apache.uniffle.storage.util.StorageType; import static org.apache.hadoop.mapreduce.RssMRConfig.RSS_REMOTE_MERGE_CLASS_LOADER; @@ -285,17 +286,20 @@ public Thread newThread(Runnable r) { RssMRConfig.toRssConf(conf) .get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE), 0, - remoteMergeEnable ? conf.getMapOutputKeyClass().getName() : null, remoteMergeEnable - ? conf.getMapOutputValueClass().getName() - : null, - remoteMergeEnable - ? conf.getOutputKeyComparator().getClass().getName() - : null, - conf.getInt( - RssMRConfig.RSS_MERGED_BLOCK_SZIE, - RssMRConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT), - conf.get(RSS_REMOTE_MERGE_CLASS_LOADER))); + ? MergeContext.newBuilder() + .setKeyClass(conf.getMapOutputKeyClass().getName()) + .setValueClass(conf.getMapOutputValueClass().getName()) + .setComparatorClass( + conf.getOutputKeyComparator().getClass().getName()) + .setMergedBlockSize( + conf.getInt( + RssMRConfig.RSS_MERGED_BLOCK_SZIE, + RssMRConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT)) + .setMergeClassLoader( + conf.get(RSS_REMOTE_MERGE_CLASS_LOADER, "")) + .build() + : null)); LOG.info( "Finish register shuffle with " + (System.currentTimeMillis() - start) diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java index 6bd5cd992b..159da8a84d 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java @@ -62,6 +62,7 @@ import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.util.JavaUtils; +import org.apache.uniffle.proto.RssProtos; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -721,11 +722,7 @@ public void registerShuffle( ShuffleDataDistributionType distributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) {} + RssProtos.MergeContext mergeContext) {} @Override public boolean sendCommit( diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java index a66f16f56d..d2aaebe045 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java @@ -77,6 +77,7 @@ import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.hadoop.shim.HadoopShimImpl; +import org.apache.uniffle.proto.RssProtos; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; @@ -507,11 +508,7 @@ public void registerShuffle( ShuffleDataDistributionType distributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) {} + RssProtos.MergeContext mergeContext) {} @Override public boolean sendCommit( diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index d82d3a509f..47f9e271de 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -1028,10 +1028,6 @@ protected void registerShuffleServers( ShuffleDataDistributionType.NORMAL, maxConcurrencyPerPartitionToWrite, stageAttemptNumber, - null, - null, - null, - -1, null); }); LOG.info( diff --git a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java index 85a138c13e..f44ad0c5e2 100644 --- a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java +++ b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java @@ -66,6 +66,7 @@ import org.apache.uniffle.common.util.Constants; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.RetryUtils; +import org.apache.uniffle.proto.RssProtos; import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE; @@ -305,13 +306,23 @@ public ShuffleAssignmentsInfo run() throws Exception { RssTezConfig.toRssConf(conf) .get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE), 0, - keyClassName, - valueClassName, - comparatorClassName, - conf.getInt( - RssTezConfig.RSS_MERGED_BLOCK_SZIE, - RssTezConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT), - conf.get(RssTezConfig.RSS_REMOTE_MERGE_CLASS_LOADER))); + StringUtils.isBlank(keyClassName) + ? null + : RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClassName) + .setValueClass(valueClassName) + .setComparatorClass(comparatorClassName) + .setMergedBlockSize( + conf.getInt( + RssTezConfig.RSS_MERGED_BLOCK_SZIE, + RssTezConfig + .RSS_MERGED_BLOCK_SZIE_DEFAULT)) + .setMergeClassLoader( + conf.get( + RssTezConfig + .RSS_REMOTE_MERGE_CLASS_LOADER, + "")) + .build())); LOG.info( "Finish register shuffle with " + (System.currentTimeMillis() - start) diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java index 5edf74ef4e..3765dc83b4 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java @@ -77,6 +77,7 @@ import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.util.JavaUtils; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.storage.util.StorageType; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -715,11 +716,7 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) {} + RssProtos.MergeContext mergeContext) {} @Override public boolean sendCommit( diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java index 121271e361..d21c7e67b7 100644 --- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java +++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java @@ -33,6 +33,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.proto.RssProtos.MergeContext; public interface ShuffleWriteClient { @@ -72,10 +73,6 @@ default void registerShuffle( dataDistributionType, maxConcurrencyPerPartitionToWrite, 0, - null, - null, - null, - -1, null); } @@ -88,11 +85,7 @@ void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader); + MergeContext mergeContext); boolean sendCommit( Set shuffleServerInfoSet, String appId, int shuffleId, int numMaps); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index cba7ccc065..c81d3c7255 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -95,6 +95,7 @@ import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.ThreadUtils; +import org.apache.uniffle.proto.RssProtos.MergeContext; public class ShuffleWriteClientImpl implements ShuffleWriteClient { @@ -564,11 +565,7 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) { + MergeContext mergeContext) { String user = null; try { user = UserGroupInformation.getCurrentUser().getShortUserName(); @@ -586,11 +583,7 @@ public void registerShuffle( dataDistributionType, maxConcurrencyPerPartitionToWrite, stageAttemptNumber, - keyClassName, - valueClassName, - comparatorClassName, - mergedBlockSize, - mergeClassLoader); + mergeContext); RssRegisterShuffleResponse response = getShuffleServerClient(shuffleServerInfo).registerShuffle(request); diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java index 7d4cbf980b..6798a792cb 100644 --- a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java @@ -34,6 +34,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.proto.RssProtos; public class MockedShuffleWriteClient implements ShuffleWriteClient { @@ -63,11 +64,7 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) {} + RssProtos.MergeContext mergeContext) {} @Override public boolean sendCommit( diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssBaseConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssBaseConf.java index b2cd4e9d0a..4e197204ac 100644 --- a/common/src/main/java/org/apache/uniffle/common/config/RssBaseConf.java +++ b/common/src/main/java/org/apache/uniffle/common/config/RssBaseConf.java @@ -294,6 +294,18 @@ public class RssBaseConf extends RssConf { + " first combining the username and the password with a colon (uniffle:uniffle123)" + ", and then by encoding the resulting string in base64 (dW5pZmZsZTp1bmlmZmxlMTIz)."); + public static final ConfigOption RSS_STORAGE_WRITE_DATA_BUFFER_SIZE = + ConfigOptions.key("rss.storage.write.dataBufferSize") + .stringType() + .defaultValue("8k") + .withDescription("The buffer size to cache the write data content."); + + public static final ConfigOption RSS_STORAGE_WRITE_INDEX_BUFFER_SIZE = + ConfigOptions.key("rss.storage.write.indexBufferSize") + .stringType() + .defaultValue("8k") + .withDescription("The buffer size to cache the write index content."); + public boolean loadConfFromFile(String fileName, List> configOptions) { Map properties = RssUtils.getPropertiesFromFile(fileName); if (properties == null) { diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java index 9fefb98f6b..cde9a3ee26 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java @@ -103,7 +103,21 @@ private static Map> decodePartitionData(ByteBuf int lengthOfShuffleBlocks = byteBuf.readInt(); List shuffleBlockInfoList = Lists.newArrayList(); for (int j = 0; j < lengthOfShuffleBlocks; j++) { - shuffleBlockInfoList.add(Decoders.decodeShuffleBlockInfo(byteBuf)); + try { + shuffleBlockInfoList.add(Decoders.decodeShuffleBlockInfo(byteBuf)); + } catch (Throwable t) { + // An OutOfDirectMemoryError will be thrown, when the direct memory reaches the limit. + // OutOfDirectMemoryError will not cause the JVM to exit, but may lead to direct memory + // leaks. + // Note: You can refer to docs/server_guide.md to set MAX_DIRECT_MEMORY_SIZE to a + // reasonable value. + shuffleBlockInfoList.forEach(sbi -> sbi.getData().release()); + partitionToBlocks.forEach( + (integer, shuffleBlockInfos) -> { + shuffleBlockInfos.forEach(sbi -> sbi.getData().release()); + }); + throw t; + } } partitionToBlocks.put(partitionId, shuffleBlockInfoList); } diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java index 33e08c4742..e494300013 100644 --- a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java +++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java @@ -537,7 +537,8 @@ private ServerNode toServerNode(ShuffleServerHeartBeatRequest request) { request.getServerId().getJettyPort(), request.getStartTimeMs(), request.getVersion(), - request.getGitCommitId()); + request.getGitCommitId(), + request.getApplicationInfoList()); } /** diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/ServerNode.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/ServerNode.java index ad992f0bc3..356c4bfe9d 100644 --- a/coordinator/src/main/java/org/apache/uniffle/coordinator/ServerNode.java +++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/ServerNode.java @@ -17,14 +17,18 @@ package org.apache.uniffle.coordinator; +import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.uniffle.common.ServerStatus; import org.apache.uniffle.common.storage.StorageInfo; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.proto.RssProtos.ShuffleServerId; public class ServerNode implements Comparable { @@ -46,6 +50,7 @@ public class ServerNode implements Comparable { private long startTime = -1; private String version; private String gitCommitId; + Map appIdToInfos; public ServerNode(String id) { this(id, "", 0, 0, 0, 0, 0, Sets.newHashSet(), ServerStatus.EXCLUDED); @@ -181,7 +186,8 @@ public ServerNode( jettyPort, startTime, "", - ""); + "", + Collections.EMPTY_LIST); } public ServerNode( @@ -199,7 +205,8 @@ public ServerNode( int jettyPort, long startTime, String version, - String gitCommitId) { + String gitCommitId, + List appInfos) { this.id = id; this.ip = ip; this.grpcPort = grpcPort; @@ -221,6 +228,8 @@ public ServerNode( this.startTime = startTime; this.version = version; this.gitCommitId = gitCommitId; + this.appIdToInfos = new ConcurrentHashMap<>(); + appInfos.forEach(appInfo -> appIdToInfos.put(appInfo.getAppId(), appInfo)); } public ShuffleServerId convertToGrpcProto() { diff --git a/dashboard/src/main/webapp/src/pages/ApplicationPage.vue b/dashboard/src/main/webapp/src/pages/ApplicationPage.vue index 3de57355a4..49278ce25f 100644 --- a/dashboard/src/main/webapp/src/pages/ApplicationPage.vue +++ b/dashboard/src/main/webapp/src/pages/ApplicationPage.vue @@ -79,16 +79,13 @@ :formatter="dateFormatter" sortable /> - - + + + diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java index 3b19e8e7ab..90591a3b87 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java @@ -64,6 +64,7 @@ import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.ChecksumUtils; import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; @@ -184,11 +185,13 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { ShuffleDataDistributionType.NORMAL, 0, -1, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // task 0 attempt 0 generate three blocks @@ -352,11 +355,13 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception ShuffleDataDistributionType.NORMAL, 0, -1, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // task 0 attempt 0 generate three blocks @@ -528,11 +533,13 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except ShuffleDataDistributionType.NORMAL, 0, -1, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // this shuffle have three partition, which is hash by key index mode 3 @@ -739,11 +746,13 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th ShuffleDataDistributionType.NORMAL, 0, -1, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // this shuffle have three partition, which is hash by key index mode 3 diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java index ed6a48b17a..85499f3b0a 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java @@ -64,6 +64,7 @@ import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.ChecksumUtils; import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; @@ -189,11 +190,13 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { ShuffleDataDistributionType.NORMAL, -1, 0, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // task 0 attempt 0 generate three blocks @@ -357,11 +360,13 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception ShuffleDataDistributionType.NORMAL, -1, 0, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); // 3 report shuffle result @@ -534,11 +539,13 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except ShuffleDataDistributionType.NORMAL, -1, 0, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // this shuffle have three partition, which is hash by key index mode 3 @@ -746,11 +753,13 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th ShuffleDataDistributionType.NORMAL, -1, 0, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // this shuffle have three partition, which is hash by key index mode 3 diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java index 8583e952e1..fbfe578247 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java @@ -127,7 +127,8 @@ public ShuffleServerHeartBeatResponse doSendHeartBeat( Map storageInfo, int nettyPort, int jettyPort, - long startTimeMs) { + long startTimeMs, + List appInfos) { ShuffleServerId serverId = ShuffleServerId.newBuilder() .setId(id) @@ -149,6 +150,7 @@ public ShuffleServerHeartBeatResponse doSendHeartBeat( .setStartTimeMs(startTimeMs) .setVersion(Constants.VERSION) .setGitCommitId(Constants.REVISION_SHORT) + .addAllApplicationInfo(appInfos) .build(); RssProtos.StatusCode status; @@ -225,7 +227,8 @@ public RssSendHeartBeatResponse sendHeartBeat(RssSendHeartBeatRequest request) { request.getStorageInfo(), request.getNettyPort(), request.getJettyPort(), - request.getStartTimeMs()); + request.getStartTimeMs(), + request.getAppInfos()); RssSendHeartBeatResponse response; RssProtos.StatusCode statusCode = rpcResponse.getStatus(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 5477fc4ead..dccd9f9383 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -31,7 +31,6 @@ import com.google.protobuf.ByteString; import com.google.protobuf.UnsafeByteOperations; import io.netty.buffer.Unpooled; -import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -95,6 +94,7 @@ import org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartResponse; import org.apache.uniffle.proto.RssProtos.GetShuffleResultRequest; import org.apache.uniffle.proto.RssProtos.GetShuffleResultResponse; +import org.apache.uniffle.proto.RssProtos.MergeContext; import org.apache.uniffle.proto.RssProtos.PartitionToBlockIds; import org.apache.uniffle.proto.RssProtos.RemoteStorage; import org.apache.uniffle.proto.RssProtos.RemoteStorageConfItem; @@ -198,11 +198,7 @@ private ShuffleRegisterResponse doRegisterShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) { + MergeContext mergeContext) { ShuffleRegisterRequest.Builder reqBuilder = ShuffleRegisterRequest.newBuilder(); reqBuilder .setAppId(appId) @@ -212,16 +208,8 @@ private ShuffleRegisterResponse doRegisterShuffle( .setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite) .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges)) .setStageAttemptNumber(stageAttemptNumber); - if (StringUtils.isNotBlank(keyClassName)) { - reqBuilder.setKeyClass(keyClassName); - reqBuilder.setValueClass(valueClassName); - if (StringUtils.isNotBlank(comparatorClassName)) { - reqBuilder.setComparatorClass(comparatorClassName); - } - reqBuilder.setMergedBlockSize(mergedBlockSize); - if (StringUtils.isNotBlank(mergeClassLoader)) { - reqBuilder.setMergeClassLoader(mergeClassLoader); - } + if (mergeContext != null) { + reqBuilder.setMergeContext(mergeContext); } RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder(); rsBuilder.setPath(remoteStorageInfo.getPath()); @@ -496,11 +484,7 @@ public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest requ request.getDataDistributionType(), request.getMaxConcurrencyPerPartitionToWrite(), request.getStageAttemptNumber(), - request.getKeyClassName(), - request.getValueClassName(), - request.getComparatorClassName(), - request.getMergedBlockSize(), - request.getMergeClassLoader()); + request.getMergeContext()); RssRegisterShuffleResponse response; RssProtos.StatusCode statusCode = rpcResponse.getStatus(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java index 1db40a0d1f..92ed1e15e9 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java @@ -25,6 +25,7 @@ import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.config.RssClientConf; +import org.apache.uniffle.proto.RssProtos.MergeContext; public class RssRegisterShuffleRequest { @@ -36,11 +37,8 @@ public class RssRegisterShuffleRequest { private ShuffleDataDistributionType dataDistributionType; private int maxConcurrencyPerPartitionToWrite; private int stageAttemptNumber; - private String keyClassName; - private String valueClassName; - private String comparatorClassName; - private int mergedBlockSize; - private String mergeClassLoader; + + private final MergeContext mergeContext; public RssRegisterShuffleRequest( String appId, @@ -59,10 +57,6 @@ public RssRegisterShuffleRequest( dataDistributionType, maxConcurrencyPerPartitionToWrite, 0, - null, - null, - null, - -1, null); } @@ -75,11 +69,7 @@ public RssRegisterShuffleRequest( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) { + MergeContext mergeContext) { this.appId = appId; this.shuffleId = shuffleId; this.partitionRanges = partitionRanges; @@ -88,11 +78,7 @@ public RssRegisterShuffleRequest( this.dataDistributionType = dataDistributionType; this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite; this.stageAttemptNumber = stageAttemptNumber; - this.keyClassName = keyClassName; - this.valueClassName = valueClassName; - this.comparatorClassName = comparatorClassName; - this.mergedBlockSize = mergedBlockSize; - this.mergeClassLoader = mergeClassLoader; + this.mergeContext = mergeContext; } public RssRegisterShuffleRequest( @@ -111,10 +97,6 @@ public RssRegisterShuffleRequest( dataDistributionType, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), 0, - null, - null, - null, - -1, null); } @@ -129,10 +111,6 @@ public RssRegisterShuffleRequest( ShuffleDataDistributionType.NORMAL, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), 0, - null, - null, - null, - -1, null); } @@ -168,23 +146,7 @@ public int getStageAttemptNumber() { return stageAttemptNumber; } - public String getKeyClassName() { - return keyClassName; - } - - public String getValueClassName() { - return valueClassName; - } - - public String getComparatorClassName() { - return comparatorClassName; - } - - public int getMergedBlockSize() { - return mergedBlockSize; - } - - public String getMergeClassLoader() { - return mergeClassLoader; + public MergeContext getMergeContext() { + return mergeContext; } } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendHeartBeatRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendHeartBeatRequest.java index a31164195d..a4f23ba738 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendHeartBeatRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendHeartBeatRequest.java @@ -17,11 +17,13 @@ package org.apache.uniffle.client.request; +import java.util.List; import java.util.Map; import java.util.Set; import org.apache.uniffle.common.ServerStatus; import org.apache.uniffle.common.storage.StorageInfo; +import org.apache.uniffle.proto.RssProtos; public class RssSendHeartBeatRequest { @@ -39,6 +41,7 @@ public class RssSendHeartBeatRequest { private final int nettyPort; private final int jettyPort; private final long startTimeMs; + private final List appInfos; public RssSendHeartBeatRequest( String shuffleServerId, @@ -54,7 +57,8 @@ public RssSendHeartBeatRequest( Map storageInfo, int nettyPort, int jettyPort, - long startTimeMs) { + long startTimeMs, + List appInfos) { this.shuffleServerId = shuffleServerId; this.shuffleServerIp = shuffleServerIp; this.shuffleServerPort = shuffleServerPort; @@ -69,6 +73,7 @@ public RssSendHeartBeatRequest( this.nettyPort = nettyPort; this.jettyPort = jettyPort; this.startTimeMs = startTimeMs; + this.appInfos = appInfos; } public String getShuffleServerId() { @@ -126,4 +131,8 @@ public int getJettyPort() { public long getStartTimeMs() { return startTimeMs; } + + public List getAppInfos() { + return appInfos; + } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 06d781e134..d92ec40c7a 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -179,6 +179,14 @@ message ShufflePartitionRange { int32 end = 2; } +message MergeContext { + string keyClass = 1; + string valueClass = 2; + string comparatorClass = 3; + int32 mergedBlockSize = 4; + string mergeClassLoader = 5; +} + message ShuffleRegisterRequest { string appId = 1; int32 shuffleId = 2; @@ -188,11 +196,7 @@ message ShuffleRegisterRequest { DataDistribution shuffleDataDistribution = 6; int32 maxConcurrencyPerPartitionToWrite = 7; int32 stageAttemptNumber = 8; - string keyClass = 9; - string valueClass = 10; - string comparatorClass = 11; - int32 mergedBlockSize = 12; - string mergeClassLoader = 13; + MergeContext mergeContext = 9; } enum DataDistribution { @@ -273,6 +277,17 @@ enum ServerStatus { // todo: more status, such as UPGRADING } +message ApplicationInfo { + string appId = 1; + int64 partitionNum = 2; + int64 memorySize = 3; + int64 localFileNum = 4; + int64 localTotalSize = 5; + int64 hadoopFileNum = 6; + int64 hadoopTotalSize = 7; + int64 totalSize = 8; +} + message ShuffleServerHeartBeatRequest { ShuffleServerId serverId = 1; int64 usedMemory = 2; @@ -286,6 +301,7 @@ message ShuffleServerHeartBeatRequest { optional string version = 22; optional string gitCommitId = 23; optional int64 startTimeMs = 24; + repeated ApplicationInfo applicationInfo = 25; } message ShuffleServerHeartBeatResponse { diff --git a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java index 4b2d4607ad..5a44728b59 100644 --- a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java +++ b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java @@ -17,6 +17,7 @@ package org.apache.uniffle.server; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; @@ -33,6 +34,7 @@ import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.storage.StorageInfo; import org.apache.uniffle.common.util.ThreadUtils; +import org.apache.uniffle.proto.RssProtos; public class RegisterHeartBeat { @@ -84,7 +86,8 @@ public void startHeartBeat() { shuffleServer.getStorageManager().getStorageInfo(), shuffleServer.getNettyPort(), shuffleServer.getJettyPort(), - shuffleServer.getStartTimeMs()); + shuffleServer.getStartTimeMs(), + shuffleServer.getAppInfos()); } catch (Exception e) { LOG.warn("Error happened when send heart beat to coordinator"); } @@ -107,7 +110,8 @@ public boolean sendHeartBeat( Map localStorageInfo, int nettyPort, int jettyPort, - long startTimeMs) { + long startTimeMs, + List appInfos) { // use `rss.server.heartbeat.interval` as the timeout option RssSendHeartBeatRequest request = new RssSendHeartBeatRequest( @@ -124,7 +128,8 @@ public boolean sendHeartBeat( localStorageInfo, nettyPort, jettyPort, - startTimeMs); + startTimeMs, + appInfos); if (coordinatorClient.sendHeartBeat(request).getStatusCode() == StatusCode.SUCCESS) { return true; diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java index 0d3d5ca0d7..574b9ef0af 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleFlushManager.java @@ -158,6 +158,7 @@ public void processFlushEvent(ShuffleDataFlushEvent event) throws Exception { int maxConcurrencyPerPartitionToWrite = getMaxConcurrencyPerPartitionWrite(event); CreateShuffleWriteHandlerRequest request = new CreateShuffleWriteHandlerRequest( + this.shuffleServerConf, storageType, event.getAppId(), event.getShuffleId(), diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java index 92fa6b36bc..59f53c97a2 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java @@ -17,8 +17,10 @@ package org.apache.uniffle.server; +import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -57,6 +59,7 @@ import org.apache.uniffle.common.util.ThreadUtils; import org.apache.uniffle.common.web.CoalescedCollectorRegistry; import org.apache.uniffle.common.web.JettyServer; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.server.buffer.ShuffleBufferManager; import org.apache.uniffle.server.buffer.ShuffleBufferType; import org.apache.uniffle.server.merge.ShuffleMergeManager; @@ -584,6 +587,26 @@ public long getStartTimeMs() { return startTimeMs; } + public List getAppInfos() { + List appInfos = new ArrayList<>(); + Map taskInfos = getShuffleTaskManager().getShuffleTaskInfos(); + taskInfos.forEach( + (appId, taskInfo) -> { + RssProtos.ApplicationInfo applicationInfo = + RssProtos.ApplicationInfo.newBuilder() + .setAppId(appId) + .setPartitionNum(taskInfo.getPartitionNum()) + .setMemorySize(taskInfo.getInMemoryDataSize()) + .setLocalTotalSize(taskInfo.getOnLocalFileDataSize()) + .setHadoopTotalSize(taskInfo.getOnHadoopDataSize()) + .setTotalSize(taskInfo.getTotalDataSize()) + .build(); + + appInfos.add(applicationInfo); + }); + return appInfos; + } + @VisibleForTesting public void sendHeartbeat() { ShuffleServer shuffleServer = this; @@ -600,7 +623,8 @@ public void sendHeartbeat() { shuffleServer.getStorageManager().getStorageInfo(), shuffleServer.getNettyPort(), shuffleServer.getJettyPort(), - shuffleServer.getStartTimeMs()); + shuffleServer.getStartTimeMs(), + shuffleServer.getAppInfos()); } public ShuffleMergeManager getShuffleMergeManager() { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index b18ccda342..460b64d693 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -325,7 +325,7 @@ public void registerShuffle( maxConcurrencyPerPartitionToWrite); if (StatusCode.SUCCESS == result && shuffleServer.isRemoteMergeEnable() - && StringUtils.isNotBlank(req.getKeyClass())) { + && req.hasMergeContext()) { // The merged block is in a different domain from the original block, // so you need to register a new app for holding the merged block. result = @@ -343,14 +343,7 @@ public void registerShuffle( result = shuffleServer .getShuffleMergeManager() - .registerShuffle( - appId, - shuffleId, - req.getKeyClass(), - req.getValueClass(), - req.getComparatorClass(), - req.getMergedBlockSize(), - req.getMergeClassLoader()); + .registerShuffle(appId, shuffleId, req.getMergeContext()); } } auditContext.withStatusCode(result); diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java index d4e6eeb326..94987d6614 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java @@ -283,6 +283,10 @@ public ShuffleDetailInfo getShuffleDetailInfo(int shuffleId) { return shuffleDetailInfos.get(shuffleId); } + public long getPartitionNum() { + return partitionDataSizes.values().stream().mapToLong(Map::size).sum(); + } + @Override public String toString() { return "ShuffleTaskInfo{" diff --git a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java index ffe7f025f3..027b63d9df 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java @@ -44,6 +44,7 @@ import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.util.JavaUtils; +import org.apache.uniffle.proto.RssProtos.MergeContext; import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; @@ -147,22 +148,16 @@ public ClassLoader getClassLoader(String label) { return cachedClassLoader.getOrDefault(label, cachedClassLoader.get("")); } - public StatusCode registerShuffle( - String appId, - int shuffleId, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String classLoaderLabel) { + public StatusCode registerShuffle(String appId, int shuffleId, MergeContext mergeContext) { try { - ClassLoader classLoader = getClassLoader(classLoaderLabel); - Class kClass = ClassUtils.getClass(classLoader, keyClassName); - Class vClass = ClassUtils.getClass(classLoader, valueClassName); + ClassLoader classLoader = getClassLoader(mergeContext.getMergeClassLoader()); + Class kClass = ClassUtils.getClass(classLoader, mergeContext.getKeyClass()); + Class vClass = ClassUtils.getClass(classLoader, mergeContext.getValueClass()); Comparator comparator; - if (StringUtils.isNotBlank(comparatorClassName)) { + if (StringUtils.isNotBlank(mergeContext.getComparatorClass())) { Constructor constructor = - ClassUtils.getClass(classLoader, comparatorClassName).getDeclaredConstructor(); + ClassUtils.getClass(classLoader, mergeContext.getComparatorClass()) + .getDeclaredConstructor(); constructor.setAccessible(true); comparator = (Comparator) constructor.newInstance(); } else { @@ -182,7 +177,7 @@ public StatusCode registerShuffle( kClass, vClass, comparator, - mergedBlockSize, + mergeContext.getMergedBlockSize(), classLoader)); } catch (ClassNotFoundException | InstantiationException diff --git a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java index 6dfc8e5706..112f7f5e43 100644 --- a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java +++ b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java @@ -48,6 +48,7 @@ import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.serializer.writable.WritableSerializer; import org.apache.uniffle.common.util.BlockIdLayout; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.ShuffleServerMetrics; @@ -142,7 +143,15 @@ public void testMergerManager(String classes) throws Exception { new RemoteStorageInfo(""), USER); mergeManager.registerShuffle( - APP_ID, SHUFFLE_ID, keyClassName, valueClassName, comparatorClassName, -1, ""); + APP_ID, + SHUFFLE_ID, + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClassName) + .setValueClass(valueClassName) + .setComparatorClass(comparatorClassName) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 4 report blocks // 4.1 send shuffle data diff --git a/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java b/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java index f2c79ac1de..e748608ded 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java +++ b/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java @@ -132,6 +132,7 @@ public void updateReadMetrics(StorageReadMetrics metrics) { @Override ShuffleWriteHandler newWriteHandler(CreateShuffleWriteHandlerRequest request) { return new LocalFileWriteHandler( + request.getRssBaseConf(), request.getAppId(), request.getShuffleId(), request.getStartPartition(), diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java index 9f9fd0c5da..2b398fd929 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java @@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.common.ShufflePartitionedBlock; +import org.apache.uniffle.common.config.RssBaseConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.util.ByteBufUtils; import org.apache.uniffle.storage.common.FileBasedShuffleSegment; @@ -37,24 +38,57 @@ public class LocalFileWriteHandler implements ShuffleWriteHandler { private static final Logger LOG = LoggerFactory.getLogger(LocalFileWriteHandler.class); + private final RssBaseConf rssBaseConf; private String fileNamePrefix; private String basePath; + private final int dataBufferSize; + private final int indexBufferSize; public LocalFileWriteHandler( + RssBaseConf rssBaseConf, String appId, int shuffleId, int startPartition, int endPartition, String storageBasePath, String fileNamePrefix) { + this.rssBaseConf = rssBaseConf; this.fileNamePrefix = fileNamePrefix; this.basePath = ShuffleStorageUtils.getFullShuffleDataFolder( storageBasePath, ShuffleStorageUtils.getShuffleDataPath(appId, shuffleId, startPartition, endPartition)); + this.dataBufferSize = + (int) + this.rssBaseConf.getSizeAsBytes( + RssBaseConf.RSS_STORAGE_WRITE_DATA_BUFFER_SIZE.key(), + RssBaseConf.RSS_STORAGE_WRITE_DATA_BUFFER_SIZE.defaultValue()); + this.indexBufferSize = + (int) + this.rssBaseConf.getSizeAsBytes( + RssBaseConf.RSS_STORAGE_WRITE_INDEX_BUFFER_SIZE.key(), + RssBaseConf.RSS_STORAGE_WRITE_INDEX_BUFFER_SIZE.defaultValue()); createBasePath(); } + @VisibleForTesting + public LocalFileWriteHandler( + String appId, + int shuffleId, + int startPartition, + int endPartition, + String storageBasePath, + String fileNamePrefix) { + this( + new RssBaseConf(), + appId, + shuffleId, + startPartition, + endPartition, + storageBasePath, + fileNamePrefix); + } + private void createBasePath() { File baseFolder = new File(basePath); if (baseFolder.isDirectory()) { @@ -96,8 +130,8 @@ public synchronized void write(Collection shuffleBlocks String dataFileName = ShuffleStorageUtils.generateDataFileName(fileNamePrefix); String indexFileName = ShuffleStorageUtils.generateIndexFileName(fileNamePrefix); - try (LocalFileWriter dataWriter = createWriter(dataFileName); - LocalFileWriter indexWriter = createWriter(indexFileName)) { + try (LocalFileWriter dataWriter = createWriter(dataFileName, dataBufferSize); + LocalFileWriter indexWriter = createWriter(indexFileName, indexBufferSize); ) { long startTime = System.currentTimeMillis(); for (ShufflePartitionedBlock block : shuffleBlocks) { @@ -131,9 +165,10 @@ public synchronized void write(Collection shuffleBlocks } } - private LocalFileWriter createWriter(String fileName) throws IOException, IllegalStateException { + private LocalFileWriter createWriter(String fileName, int bufferSize) + throws IOException, IllegalStateException { File file = new File(basePath, fileName); - return new LocalFileWriter(file); + return new LocalFileWriter(file, bufferSize); } @VisibleForTesting diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriter.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriter.java index 01c188f3ff..5d5aae7b9f 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriter.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriter.java @@ -24,6 +24,8 @@ import java.io.FileOutputStream; import java.io.IOException; +import com.google.common.annotations.VisibleForTesting; + import org.apache.uniffle.storage.api.FileWriter; import org.apache.uniffle.storage.common.FileBasedShuffleSegment; @@ -33,10 +35,15 @@ public class LocalFileWriter implements FileWriter, Closeable { private FileOutputStream fileOutputStream; private long nextOffset; + @VisibleForTesting public LocalFileWriter(File file) throws IOException { + this(file, 8 * 1024); + } + + public LocalFileWriter(File file, int bufferSize) throws IOException { fileOutputStream = new FileOutputStream(file, true); // init fsDataOutputStream - dataOutputStream = new DataOutputStream(new BufferedOutputStream(fileOutputStream)); + dataOutputStream = new DataOutputStream(new BufferedOutputStream(fileOutputStream, bufferSize)); nextOffset = file.length(); } diff --git a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleWriteHandlerRequest.java b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleWriteHandlerRequest.java index 0d9c21f490..7ed7e6f502 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleWriteHandlerRequest.java +++ b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleWriteHandlerRequest.java @@ -17,10 +17,14 @@ package org.apache.uniffle.storage.request; +import com.google.common.annotations.VisibleForTesting; import org.apache.hadoop.conf.Configuration; +import org.apache.uniffle.common.config.RssBaseConf; + public class CreateShuffleWriteHandlerRequest { + private RssBaseConf rssBaseConf; private String storageType; private String appId; private int shuffleId; @@ -33,6 +37,7 @@ public class CreateShuffleWriteHandlerRequest { private String user; private int maxFileNumber; + @VisibleForTesting public CreateShuffleWriteHandlerRequest( String storageType, String appId, @@ -45,6 +50,7 @@ public CreateShuffleWriteHandlerRequest( int storageDataReplica, String user) { this( + new RssBaseConf(), storageType, appId, shuffleId, @@ -59,6 +65,7 @@ public CreateShuffleWriteHandlerRequest( } public CreateShuffleWriteHandlerRequest( + RssBaseConf rssBaseConf, String storageType, String appId, int shuffleId, @@ -70,6 +77,7 @@ public CreateShuffleWriteHandlerRequest( int storageDataReplica, String user, int maxFileNumber) { + this.rssBaseConf = rssBaseConf; this.storageType = storageType; this.appId = appId; this.shuffleId = shuffleId; @@ -83,6 +91,10 @@ public CreateShuffleWriteHandlerRequest( this.maxFileNumber = maxFileNumber; } + public RssBaseConf getRssBaseConf() { + return rssBaseConf; + } + public String getStorageType() { return storageType; }