diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index c435651045..1d463c999f 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -386,9 +386,17 @@ public List buildBlockEvents(List shuffleBlockI + " bytes"); // Use final temporary variables for closures final long _memoryUsed = memoryUsed; + final List finalShuffleBlockInfosPerEvent = shuffleBlockInfoList; events.add( new AddBlockEvent( - taskId, shuffleBlockInfosPerEvent, () -> freeAllocatedMemory(_memoryUsed))); + taskId, + shuffleBlockInfosPerEvent, + () -> { + freeAllocatedMemory(_memoryUsed); + for (ShuffleBlockInfo shuffleBlockInfo : finalShuffleBlockInfosPerEvent) { + shuffleBlockInfo.getData().release(); + } + })); } return events; } diff --git a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java index fe0df6d158..8de75d90d4 100644 --- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java +++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java @@ -22,6 +22,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import org.apache.uniffle.common.util.ByteBufUtils; + public class ShuffleBlockInfo { private int partitionId; @@ -150,4 +152,8 @@ public String toString() { return sb.toString(); } + + public synchronized void copyDataTo(ByteBuf to) { + ByteBufUtils.copyByteBuf(data, to); + } } diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssConf.java index 73e787b97d..805dc30a88 100644 --- a/common/src/main/java/org/apache/uniffle/common/config/RssConf.java +++ b/common/src/main/java/org/apache/uniffle/common/config/RssConf.java @@ -26,6 +26,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Sets; import org.apache.uniffle.common.util.UnitConverter; @@ -665,4 +666,9 @@ public String toString() { public String getEnv(String key) { return System.getenv(key); } + + @VisibleForTesting + public void remove(String key) { + this.settings.remove(key); + } } diff --git a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java index c417cedc26..5221cc4287 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java @@ -55,6 +55,7 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) } catch (Exception e) { LOG.error("Unexpected exception during process encode!", e); byteBuf.release(); + throw e; } ctx.writeAndFlush(byteBuf); // do transferTo send data after encode buffer send. diff --git a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java index 674eb4dbb1..acd79a60c4 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java @@ -69,7 +69,7 @@ public ChannelFuture sendRpc(Message message, RpcResponseCallback callback) { if (logger.isTraceEnabled()) { logger.trace("Pushing data to {}", NettyUtils.getRemoteAddress(channel)); } - long requestId = requestId(); + long requestId = message.getRequestId(); handler.addResponseCallback(requestId, callback); RpcChannelListener listener = new RpcChannelListener(requestId, callback); return channel.writeAndFlush(message).addListener(listener); diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java index 7024ef8534..b74a517b90 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encoders.java @@ -43,8 +43,7 @@ public static void encodeShuffleBlockInfo(ShuffleBlockInfo shuffleBlockInfo, Byt byteBuf.writeLong(shuffleBlockInfo.getCrc()); byteBuf.writeLong(shuffleBlockInfo.getTaskAttemptId()); // todo: avoid copy - ByteBufUtils.copyByteBuf(shuffleBlockInfo.getData(), byteBuf); - shuffleBlockInfo.getData().release(); + shuffleBlockInfo.copyDataTo(byteBuf); List shuffleServerInfoList = shuffleBlockInfo.getShuffleServerInfos(); byteBuf.writeInt(shuffleServerInfoList.size()); for (ShuffleServerInfo shuffleServerInfo : shuffleServerInfoList) { @@ -64,7 +63,8 @@ public static int encodeLengthOfShuffleBlockInfo(ShuffleBlockInfo shuffleBlockIn int encodeLength = 4 * Long.BYTES + 4 * Integer.BYTES - + ByteBufUtils.encodedLength(shuffleBlockInfo.getData()) + + Integer.BYTES + + shuffleBlockInfo.getLength() + Integer.BYTES; for (ShuffleServerInfo shuffleServerInfo : shuffleBlockInfo.getShuffleServerInfos()) { encodeLength += encodeLengthOfShuffleServerInfo(shuffleServerInfo); diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java index b0a3da1f0b..c019099fdb 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java @@ -131,4 +131,6 @@ public static Message decode(Type msgType, ByteBuf in) { throw new IllegalArgumentException("Unexpected message type: " + msgType); } } + + public abstract long getRequestId(); } diff --git a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java index 09473dd2f0..48ff31ffa3 100644 --- a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java +++ b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java @@ -50,7 +50,7 @@ public void testSendShuffleDataRequest() { 1, 1, 1, - 10, + data.length, 123, Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList, @@ -61,7 +61,7 @@ public void testSendShuffleDataRequest() { 1, 1, 1, - 10, + data.length, 123, Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList, @@ -74,7 +74,7 @@ public void testSendShuffleDataRequest() { 1, 2, 1, - 10, + data.length, 123, Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList, @@ -85,7 +85,7 @@ public void testSendShuffleDataRequest() { 1, 1, 2, - 10, + data.length, 123, Unpooled.wrappedBuffer(data).retain(), shuffleServerInfoList, diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java index f252abd0d5..88a04190f4 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java @@ -73,6 +73,7 @@ public static void setupServers() throws Exception { coordinatorConf.setLong("rss.coordinator.server.heartbeat.timeout", 3000); createCoordinatorServer(coordinatorConf); ShuffleServerConf shuffleServerConf = getShuffleServerConf(); + shuffleServerConf.remove(ShuffleServerConf.NETTY_SERVER_PORT.key()); createShuffleServer(shuffleServerConf); shuffleServerConf.setInteger("rss.rpc.server.port", SHUFFLE_SERVER_PORT + 1); shuffleServerConf.setInteger("rss.jetty.http.port", 18081); @@ -155,6 +156,7 @@ public void getShuffleAssignmentsTest() throws Exception { withEnvironmentVariables("RSS_ENV_KEY", storageTypeJsonSource) .execute( () -> { + shuffleServerConf.remove(ShuffleServerConf.NETTY_SERVER_PORT.key()); ShuffleServer ss = new ShuffleServer((ShuffleServerConf) shuffleServerConf); ss.start(); shuffleServers.set(0, ss); diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/IntegrationTestBase.java b/integration-test/common/src/test/java/org/apache/uniffle/test/IntegrationTestBase.java index aa1f311bcc..e73ccfa7f1 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/IntegrationTestBase.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/IntegrationTestBase.java @@ -23,6 +23,7 @@ import java.nio.file.Files; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import com.google.common.collect.Lists; import org.junit.jupiter.api.AfterAll; @@ -60,6 +61,9 @@ public abstract class IntegrationTestBase extends HadoopTestBase { protected static List shuffleServers = Lists.newArrayList(); protected static List coordinators = Lists.newArrayList(); + protected static final int NETTY_PORT = 21000; + protected static AtomicInteger nettyPortCounter = new AtomicInteger(); + public static void startServers() throws Exception { for (CoordinatorServer coordinator : coordinators) { coordinator.start(); @@ -123,6 +127,9 @@ protected static ShuffleServerConf getShuffleServerConf() throws Exception { serverConf.setBoolean("rss.server.health.check.enable", false); serverConf.setBoolean(ShuffleServerConf.RSS_TEST_MODE_ENABLE, true); serverConf.set(ShuffleServerConf.SERVER_TRIGGER_FLUSH_CHECK_INTERVAL, 500L); + serverConf.setInteger( + ShuffleServerConf.NETTY_SERVER_PORT, NETTY_PORT + nettyPortCounter.getAndIncrement()); + serverConf.setString("rss.server.tags", "GRPC,GRPC_NETTY"); return serverConf; } diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleUnregisterWithHadoopTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleUnregisterWithHadoopTest.java index a824a645d1..12299ea3eb 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleUnregisterWithHadoopTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleUnregisterWithHadoopTest.java @@ -99,9 +99,8 @@ public Map runTest(SparkSession spark, String fileName) throws Exception { map = javaPairRDD.collectAsMap(); shufflePath = appPath + "/1"; assertTrue(fs.exists(new Path(shufflePath))); - } else { - runCounter++; } + runCounter++; return map; } } diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleUnregisterWithLocalfileTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleUnregisterWithLocalfileTest.java index 814e0f4c01..e68cc74a7a 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleUnregisterWithLocalfileTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleUnregisterWithLocalfileTest.java @@ -105,9 +105,8 @@ public Map runTest(SparkSession spark, String fileName) throws Exception { map = javaPairRDD.collectAsMap(); shufflePath = appPath + "/1"; assertTrue(new File(shufflePath).exists()); - } else { - runCounter++; } + runCounter++; return map; } } diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java index de87ac8fa8..ac37f03ceb 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java @@ -58,15 +58,23 @@ public void run() throws Exception { updateSparkConfCustomer(sparkConf); start = System.currentTimeMillis(); Map resultWithRss = runSparkApp(sparkConf, fileName); - long durationWithRss = System.currentTimeMillis() - start; + final long durationWithRss = System.currentTimeMillis() - start; + updateSparkConfWithRssNetty(sparkConf); + start = System.currentTimeMillis(); + Map resultWithRssNetty = runSparkApp(sparkConf, fileName); + final long durationWithRssNetty = System.currentTimeMillis() - start; verifyTestResult(resultWithoutRss, resultWithRss); + verifyTestResult(resultWithoutRss, resultWithRssNetty); LOG.info( "Test: durationWithoutRss[" + durationWithoutRss + "], durationWithRss[" + durationWithRss + + "]" + + "], durationWithRssNetty[" + + durationWithRssNetty + "]"); } @@ -110,6 +118,10 @@ public void updateSparkConfWithRss(SparkConf sparkConf) { sparkConf.set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true"); } + public void updateSparkConfWithRssNetty(SparkConf sparkConf) { + sparkConf.set(RssSparkConfig.RSS_CLIENT_TYPE, "GRPC_NETTY"); + } + protected void verifyTestResult(Map expected, Map actual) { assertEquals(expected.size(), actual.size()); for (Object expectedKey : expected.keySet()) { diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java index 410e95e18c..f15d26bf6c 100644 --- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java +++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/ContinuousSelectPartitionStrategyTest.java @@ -79,6 +79,7 @@ private static void createShuffleServers() throws Exception { ShuffleServerConf serverConf = new ShuffleServerConf(); dataFolder.deleteOnExit(); serverConf.setInteger("rss.rpc.server.port", SHUFFLE_SERVER_PORT + i); + serverConf.setInteger("rss.server.netty.port", NETTY_PORT + i); serverConf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE_HDFS.name()); serverConf.setString("rss.storage.basePath", dataFolder.getAbsolutePath()); serverConf.setString("rss.server.buffer.capacity", String.valueOf(671088640 - i)); @@ -94,6 +95,7 @@ private static void createShuffleServers() throws Exception { serverConf.setString("rss.server.hadoop.dfs.replication", "2"); serverConf.setLong("rss.server.disk.capacity", 10L * 1024L * 1024L * 1024L); serverConf.setBoolean("rss.server.health.check.enable", false); + serverConf.setString("rss.server.tags", "GRPC,GRPC_NETTY"); createMockedShuffleServer(serverConf); } enableRecordGetShuffleResult(); diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java index 9d4c98b877..f7944ceb3f 100644 --- a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java +++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java @@ -83,6 +83,7 @@ private static void createShuffleServers() throws Exception { ShuffleServerConf serverConf = new ShuffleServerConf(); dataFolder.deleteOnExit(); serverConf.setInteger("rss.rpc.server.port", SHUFFLE_SERVER_PORT + i); + serverConf.setInteger("rss.server.netty.port", NETTY_PORT + i); serverConf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE_HDFS.name()); serverConf.setString("rss.storage.basePath", dataFolder.getAbsolutePath()); serverConf.setString("rss.server.buffer.capacity", "671088640"); @@ -98,6 +99,7 @@ private static void createShuffleServers() throws Exception { serverConf.setString("rss.server.hadoop.dfs.replication", "2"); serverConf.setLong("rss.server.disk.capacity", 10L * 1024L * 1024L * 1024L); serverConf.setBoolean("rss.server.health.check.enable", false); + serverConf.setString("rss.server.tags", "GRPC,GRPC_NETTY"); createMockedShuffleServer(serverConf); } enableRecordGetShuffleResult(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/factory/CoordinatorClientFactory.java b/internal-client/src/main/java/org/apache/uniffle/client/factory/CoordinatorClientFactory.java index 2ba77d8d15..b1744dce63 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/factory/CoordinatorClientFactory.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/factory/CoordinatorClientFactory.java @@ -40,7 +40,7 @@ public CoordinatorClientFactory(ClientType clientType) { } public CoordinatorClient createCoordinatorClient(String host, int port) { - if (clientType.equals(ClientType.GRPC)) { + if (clientType.equals(ClientType.GRPC) || clientType.equals(ClientType.GRPC_NETTY)) { return new CoordinatorGrpcClient(host, port); } else { throw new UnsupportedOperationException("Unsupported client type " + clientType); diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java index 5c3d2f1a41..6abbd16ec3 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java @@ -133,7 +133,7 @@ public void handleSendShuffleDataRequest(TransportClient client, SendShuffleData responseMessage = errorMsg; rpcResponse = new RpcResponse(req.getRequestId(), StatusCode.INTERNAL_ERROR, responseMessage); - client.sendRpcSync(rpcResponse, RPC_TIMEOUT); + client.getChannel().writeAndFlush(rpcResponse); return; } final long start = System.currentTimeMillis(); @@ -209,7 +209,7 @@ public void handleSendShuffleDataRequest(TransportClient client, SendShuffleData new RpcResponse(req.getRequestId(), StatusCode.INTERNAL_ERROR, "No data in request"); } - client.sendRpcSync(rpcResponse, RPC_TIMEOUT); + client.getChannel().writeAndFlush(rpcResponse); } public void handleGetMemoryShuffleDataRequest( @@ -292,7 +292,7 @@ public void handleGetMemoryShuffleDataRequest( new GetMemoryShuffleDataResponse( req.getRequestId(), status, msg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER); } - client.sendRpcSync(response, RPC_TIMEOUT); + client.getChannel().writeAndFlush(response); } public void handleGetLocalShuffleIndexRequest( @@ -374,7 +374,7 @@ public void handleGetLocalShuffleIndexRequest( new GetLocalShuffleIndexResponse( req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L); } - client.sendRpcSync(response, RPC_TIMEOUT); + client.getChannel().writeAndFlush(response); } public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDataRequest req) { @@ -471,7 +471,7 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat new GetLocalShuffleDataResponse( req.getRequestId(), status, msg, new NettyManagedBuffer(Unpooled.EMPTY_BUFFER)); } - client.sendRpcSync(response, RPC_TIMEOUT); + client.getChannel().writeAndFlush(response); } private List toPartitionedData(SendShuffleDataRequest req) {