diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java index 4403120d52..06973d0797 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java @@ -80,6 +80,7 @@ import org.apache.uniffle.common.util.Constants; import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.hadoop.shim.HadoopShimImpl; +import org.apache.uniffle.proto.RssProtos.MergeContext; import org.apache.uniffle.storage.util.StorageType; import static org.apache.hadoop.mapreduce.RssMRConfig.RSS_REMOTE_MERGE_CLASS_LOADER; @@ -285,17 +286,20 @@ public Thread newThread(Runnable r) { RssMRConfig.toRssConf(conf) .get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE), 0, - remoteMergeEnable ? conf.getMapOutputKeyClass().getName() : null, remoteMergeEnable - ? conf.getMapOutputValueClass().getName() - : null, - remoteMergeEnable - ? conf.getOutputKeyComparator().getClass().getName() - : null, - conf.getInt( - RssMRConfig.RSS_MERGED_BLOCK_SZIE, - RssMRConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT), - conf.get(RSS_REMOTE_MERGE_CLASS_LOADER))); + ? MergeContext.newBuilder() + .setKeyClass(conf.getMapOutputKeyClass().getName()) + .setValueClass(conf.getMapOutputValueClass().getName()) + .setComparatorClass( + conf.getOutputKeyComparator().getClass().getName()) + .setMergedBlockSize( + conf.getInt( + RssMRConfig.RSS_MERGED_BLOCK_SZIE, + RssMRConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT)) + .setMergeClassLoader( + conf.get(RSS_REMOTE_MERGE_CLASS_LOADER, "")) + .build() + : null)); LOG.info( "Finish register shuffle with " + (System.currentTimeMillis() - start) diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java index 813ea1218f..0385cb58e0 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java @@ -62,6 +62,7 @@ import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.util.JavaUtils; +import org.apache.uniffle.proto.RssProtos; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -722,11 +723,7 @@ public void registerShuffle( ShuffleDataDistributionType distributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) {} + RssProtos.MergeContext mergeContext) {} @Override public boolean sendCommit( diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java index a66f16f56d..d2aaebe045 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java @@ -77,6 +77,7 @@ import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.hadoop.shim.HadoopShimImpl; +import org.apache.uniffle.proto.RssProtos; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; @@ -507,11 +508,7 @@ public void registerShuffle( ShuffleDataDistributionType distributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) {} + RssProtos.MergeContext mergeContext) {} @Override public boolean sendCommit( diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index d82d3a509f..47f9e271de 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -1028,10 +1028,6 @@ protected void registerShuffleServers( ShuffleDataDistributionType.NORMAL, maxConcurrencyPerPartitionToWrite, stageAttemptNumber, - null, - null, - null, - -1, null); }); LOG.info( diff --git a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java index 85a138c13e..f44ad0c5e2 100644 --- a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java +++ b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java @@ -66,6 +66,7 @@ import org.apache.uniffle.common.util.Constants; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.RetryUtils; +import org.apache.uniffle.proto.RssProtos; import static org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE; @@ -305,13 +306,23 @@ public ShuffleAssignmentsInfo run() throws Exception { RssTezConfig.toRssConf(conf) .get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE), 0, - keyClassName, - valueClassName, - comparatorClassName, - conf.getInt( - RssTezConfig.RSS_MERGED_BLOCK_SZIE, - RssTezConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT), - conf.get(RssTezConfig.RSS_REMOTE_MERGE_CLASS_LOADER))); + StringUtils.isBlank(keyClassName) + ? null + : RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClassName) + .setValueClass(valueClassName) + .setComparatorClass(comparatorClassName) + .setMergedBlockSize( + conf.getInt( + RssTezConfig.RSS_MERGED_BLOCK_SZIE, + RssTezConfig + .RSS_MERGED_BLOCK_SZIE_DEFAULT)) + .setMergeClassLoader( + conf.get( + RssTezConfig + .RSS_REMOTE_MERGE_CLASS_LOADER, + "")) + .build())); LOG.info( "Finish register shuffle with " + (System.currentTimeMillis() - start) diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java index 724cec6c08..ce0458219f 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java @@ -78,6 +78,7 @@ import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.util.JavaUtils; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.storage.util.StorageType; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -719,11 +720,7 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) {} + RssProtos.MergeContext mergeContext) {} @Override public boolean sendCommit( diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java index 121271e361..d21c7e67b7 100644 --- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java +++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java @@ -33,6 +33,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.proto.RssProtos.MergeContext; public interface ShuffleWriteClient { @@ -72,10 +73,6 @@ default void registerShuffle( dataDistributionType, maxConcurrencyPerPartitionToWrite, 0, - null, - null, - null, - -1, null); } @@ -88,11 +85,7 @@ void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader); + MergeContext mergeContext); boolean sendCommit( Set shuffleServerInfoSet, String appId, int shuffleId, int numMaps); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index cba7ccc065..c81d3c7255 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -95,6 +95,7 @@ import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.ThreadUtils; +import org.apache.uniffle.proto.RssProtos.MergeContext; public class ShuffleWriteClientImpl implements ShuffleWriteClient { @@ -564,11 +565,7 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) { + MergeContext mergeContext) { String user = null; try { user = UserGroupInformation.getCurrentUser().getShortUserName(); @@ -586,11 +583,7 @@ public void registerShuffle( dataDistributionType, maxConcurrencyPerPartitionToWrite, stageAttemptNumber, - keyClassName, - valueClassName, - comparatorClassName, - mergedBlockSize, - mergeClassLoader); + mergeContext); RssRegisterShuffleResponse response = getShuffleServerClient(shuffleServerInfo).registerShuffle(request); diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java index 7d4cbf980b..6798a792cb 100644 --- a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java @@ -34,6 +34,7 @@ import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.proto.RssProtos; public class MockedShuffleWriteClient implements ShuffleWriteClient { @@ -63,11 +64,7 @@ public void registerShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) {} + RssProtos.MergeContext mergeContext) {} @Override public boolean sendCommit( diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java index 17ec0c2126..f341175864 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java @@ -64,6 +64,7 @@ import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.ChecksumUtils; import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; import org.apache.uniffle.storage.util.StorageType; @@ -174,11 +175,13 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { ShuffleDataDistributionType.NORMAL, 0, -1, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // task 0 attempt 0 generate three blocks @@ -337,11 +340,13 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception ShuffleDataDistributionType.NORMAL, 0, -1, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // task 0 attempt 0 generate three blocks @@ -508,11 +513,13 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except ShuffleDataDistributionType.NORMAL, 0, -1, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // this shuffle have three partition, which is hash by key index mode 3 @@ -714,11 +721,13 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th ShuffleDataDistributionType.NORMAL, 0, -1, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // this shuffle have three partition, which is hash by key index mode 3 diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java index c0af5e1bff..d12b286c2a 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java @@ -64,6 +64,7 @@ import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.ChecksumUtils; import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; import org.apache.uniffle.storage.util.StorageType; @@ -179,11 +180,13 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { ShuffleDataDistributionType.NORMAL, -1, 0, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // task 0 attempt 0 generate three blocks @@ -342,11 +345,13 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception ShuffleDataDistributionType.NORMAL, -1, 0, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); // 3 report shuffle result @@ -514,11 +519,13 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except ShuffleDataDistributionType.NORMAL, -1, 0, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // this shuffle have three partition, which is hash by key index mode 3 @@ -721,11 +728,13 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th ShuffleDataDistributionType.NORMAL, -1, 0, - keyClass.getName(), - valueClass.getName(), - comparator.getClass().getName(), - -1, - null); + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClass.getName()) + .setValueClass(valueClass.getName()) + .setComparatorClass(comparator.getClass().getName()) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 3 report shuffle result // this shuffle have three partition, which is hash by key index mode 3 diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 63081041d6..20b6bf98b1 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -31,7 +31,6 @@ import com.google.protobuf.ByteString; import com.google.protobuf.UnsafeByteOperations; import io.netty.buffer.Unpooled; -import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -95,6 +94,7 @@ import org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartResponse; import org.apache.uniffle.proto.RssProtos.GetShuffleResultRequest; import org.apache.uniffle.proto.RssProtos.GetShuffleResultResponse; +import org.apache.uniffle.proto.RssProtos.MergeContext; import org.apache.uniffle.proto.RssProtos.PartitionToBlockIds; import org.apache.uniffle.proto.RssProtos.RemoteStorage; import org.apache.uniffle.proto.RssProtos.RemoteStorageConfItem; @@ -198,11 +198,7 @@ private ShuffleRegisterResponse doRegisterShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) { + MergeContext mergeContext) { ShuffleRegisterRequest.Builder reqBuilder = ShuffleRegisterRequest.newBuilder(); reqBuilder .setAppId(appId) @@ -212,16 +208,8 @@ private ShuffleRegisterResponse doRegisterShuffle( .setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite) .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges)) .setStageAttemptNumber(stageAttemptNumber); - if (StringUtils.isNotBlank(keyClassName)) { - reqBuilder.setKeyClass(keyClassName); - reqBuilder.setValueClass(valueClassName); - if (StringUtils.isNotBlank(comparatorClassName)) { - reqBuilder.setComparatorClass(comparatorClassName); - } - reqBuilder.setMergedBlockSize(mergedBlockSize); - if (StringUtils.isNotBlank(mergeClassLoader)) { - reqBuilder.setMergeClassLoader(mergeClassLoader); - } + if (mergeContext != null) { + reqBuilder.setMergeContext(mergeContext); } RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder(); rsBuilder.setPath(remoteStorageInfo.getPath()); @@ -496,11 +484,7 @@ public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest requ request.getDataDistributionType(), request.getMaxConcurrencyPerPartitionToWrite(), request.getStageAttemptNumber(), - request.getKeyClassName(), - request.getValueClassName(), - request.getComparatorClassName(), - request.getMergedBlockSize(), - request.getMergeClassLoader()); + request.getMergeContext()); RssRegisterShuffleResponse response; RssProtos.StatusCode statusCode = rpcResponse.getStatus(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java index 1db40a0d1f..92ed1e15e9 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java @@ -25,6 +25,7 @@ import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.config.RssClientConf; +import org.apache.uniffle.proto.RssProtos.MergeContext; public class RssRegisterShuffleRequest { @@ -36,11 +37,8 @@ public class RssRegisterShuffleRequest { private ShuffleDataDistributionType dataDistributionType; private int maxConcurrencyPerPartitionToWrite; private int stageAttemptNumber; - private String keyClassName; - private String valueClassName; - private String comparatorClassName; - private int mergedBlockSize; - private String mergeClassLoader; + + private final MergeContext mergeContext; public RssRegisterShuffleRequest( String appId, @@ -59,10 +57,6 @@ public RssRegisterShuffleRequest( dataDistributionType, maxConcurrencyPerPartitionToWrite, 0, - null, - null, - null, - -1, null); } @@ -75,11 +69,7 @@ public RssRegisterShuffleRequest( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String mergeClassLoader) { + MergeContext mergeContext) { this.appId = appId; this.shuffleId = shuffleId; this.partitionRanges = partitionRanges; @@ -88,11 +78,7 @@ public RssRegisterShuffleRequest( this.dataDistributionType = dataDistributionType; this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite; this.stageAttemptNumber = stageAttemptNumber; - this.keyClassName = keyClassName; - this.valueClassName = valueClassName; - this.comparatorClassName = comparatorClassName; - this.mergedBlockSize = mergedBlockSize; - this.mergeClassLoader = mergeClassLoader; + this.mergeContext = mergeContext; } public RssRegisterShuffleRequest( @@ -111,10 +97,6 @@ public RssRegisterShuffleRequest( dataDistributionType, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), 0, - null, - null, - null, - -1, null); } @@ -129,10 +111,6 @@ public RssRegisterShuffleRequest( ShuffleDataDistributionType.NORMAL, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), 0, - null, - null, - null, - -1, null); } @@ -168,23 +146,7 @@ public int getStageAttemptNumber() { return stageAttemptNumber; } - public String getKeyClassName() { - return keyClassName; - } - - public String getValueClassName() { - return valueClassName; - } - - public String getComparatorClassName() { - return comparatorClassName; - } - - public int getMergedBlockSize() { - return mergedBlockSize; - } - - public String getMergeClassLoader() { - return mergeClassLoader; + public MergeContext getMergeContext() { + return mergeContext; } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 7e4b19696b..d92ec40c7a 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -179,6 +179,14 @@ message ShufflePartitionRange { int32 end = 2; } +message MergeContext { + string keyClass = 1; + string valueClass = 2; + string comparatorClass = 3; + int32 mergedBlockSize = 4; + string mergeClassLoader = 5; +} + message ShuffleRegisterRequest { string appId = 1; int32 shuffleId = 2; @@ -188,11 +196,7 @@ message ShuffleRegisterRequest { DataDistribution shuffleDataDistribution = 6; int32 maxConcurrencyPerPartitionToWrite = 7; int32 stageAttemptNumber = 8; - string keyClass = 9; - string valueClass = 10; - string comparatorClass = 11; - int32 mergedBlockSize = 12; - string mergeClassLoader = 13; + MergeContext mergeContext = 9; } enum DataDistribution { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index ee780b872c..994a25c890 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -325,7 +325,7 @@ public void registerShuffle( maxConcurrencyPerPartitionToWrite); if (StatusCode.SUCCESS == result && shuffleServer.isRemoteMergeEnable() - && StringUtils.isNotBlank(req.getKeyClass())) { + && req.hasMergeContext()) { // The merged block is in a different domain from the original block, // so you need to register a new app for holding the merged block. result = @@ -343,14 +343,7 @@ public void registerShuffle( result = shuffleServer .getShuffleMergeManager() - .registerShuffle( - appId, - shuffleId, - req.getKeyClass(), - req.getValueClass(), - req.getComparatorClass(), - req.getMergedBlockSize(), - req.getMergeClassLoader()); + .registerShuffle(appId, shuffleId, req.getMergeContext()); } } auditContext.withStatusCode(result); diff --git a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java index f50bfd2715..f8a6c1bfd2 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java @@ -43,6 +43,7 @@ import org.apache.uniffle.common.merger.Segment; import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.util.JavaUtils; +import org.apache.uniffle.proto.RssProtos.MergeContext; import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; @@ -145,22 +146,16 @@ public ClassLoader getClassLoader(String label) { return cachedClassLoader.getOrDefault(label, cachedClassLoader.get("")); } - public StatusCode registerShuffle( - String appId, - int shuffleId, - String keyClassName, - String valueClassName, - String comparatorClassName, - int mergedBlockSize, - String classLoaderLabel) { + public StatusCode registerShuffle(String appId, int shuffleId, MergeContext mergeContext) { try { - ClassLoader classLoader = getClassLoader(classLoaderLabel); - Class kClass = ClassUtils.getClass(classLoader, keyClassName); - Class vClass = ClassUtils.getClass(classLoader, valueClassName); + ClassLoader classLoader = getClassLoader(mergeContext.getMergeClassLoader()); + Class kClass = ClassUtils.getClass(classLoader, mergeContext.getKeyClass()); + Class vClass = ClassUtils.getClass(classLoader, mergeContext.getValueClass()); Comparator comparator; - if (StringUtils.isNotBlank(comparatorClassName)) { + if (StringUtils.isNotBlank(mergeContext.getComparatorClass())) { Constructor constructor = - ClassUtils.getClass(classLoader, comparatorClassName).getDeclaredConstructor(); + ClassUtils.getClass(classLoader, mergeContext.getComparatorClass()) + .getDeclaredConstructor(); constructor.setAccessible(true); comparator = (Comparator) constructor.newInstance(); } else { @@ -180,7 +175,7 @@ public StatusCode registerShuffle( kClass, vClass, comparator, - mergedBlockSize, + mergeContext.getMergedBlockSize(), classLoader)); } catch (ClassNotFoundException | InstantiationException diff --git a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java index 449881dc33..4ea82750c7 100644 --- a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java +++ b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java @@ -45,6 +45,7 @@ import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.serializer.writable.WritableSerializer; import org.apache.uniffle.common.util.BlockIdLayout; +import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.ShuffleServerMetrics; @@ -131,7 +132,15 @@ public void testMergerManager(String classes, @TempDir File tmpDir) throws Excep new RemoteStorageInfo(""), USER); mergeManager.registerShuffle( - APP_ID, SHUFFLE_ID, keyClassName, valueClassName, comparatorClassName, -1, ""); + APP_ID, + SHUFFLE_ID, + RssProtos.MergeContext.newBuilder() + .setKeyClass(keyClassName) + .setValueClass(valueClassName) + .setComparatorClass(comparatorClassName) + .setMergedBlockSize(-1) + .setMergeClassLoader("") + .build()); // 4 report blocks // 4.1 send shuffle data