diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index 9b486629c2..c7a17b9b44 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -671,10 +671,10 @@ public static Stream testBlockIdLayouts() {
   @ParameterizedTest
   @MethodSource("testBlockIdLayouts")
   public void multipleShuffleResultTest(BlockIdLayout layout) throws Exception {
+    String appId = "multipleShuffleResultTest_" + layout.sequenceNoBits;
     Set expectedBlockIds = Sets.newConcurrentHashSet();
     RssRegisterShuffleRequest rrsr =
-        new RssRegisterShuffleRequest(
-            "multipleShuffleResultTest", 100, Lists.newArrayList(new PartitionRange(0, 1)), "");
+        new RssRegisterShuffleRequest(appId, 100, Lists.newArrayList(new PartitionRange(0, 1)), "");
     grpcShuffleServerClient.registerShuffle(rrsr);
 
     Runnable r1 =
@@ -687,7 +687,7 @@ public void multipleShuffleResultTest(BlockIdLayout layout) throws Exception {
             blockIds.add(blockId);
             ptbs.put(1, blockIds);
             RssReportShuffleResultRequest req1 =
-                new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 0, ptbs, 1);
+                new RssReportShuffleResultRequest(appId, 1, 0, ptbs, 1);
             grpcShuffleServerClient.reportShuffleResult(req1);
           }
         };
@@ -701,7 +701,7 @@ public void multipleShuffleResultTest(BlockIdLayout layout) throws Exception {
             blockIds.add(blockId);
             ptbs.put(1, blockIds);
             RssReportShuffleResultRequest req1 =
-                new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 1, ptbs, 1);
+                new RssReportShuffleResultRequest(appId, 1, 1, ptbs, 1);
             grpcShuffleServerClient.reportShuffleResult(req1);
           }
         };
@@ -715,7 +715,7 @@ public void multipleShuffleResultTest(BlockIdLayout layout) throws Exception {
             blockIds.add(blockId);
             ptbs.put(1, blockIds);
             RssReportShuffleResultRequest req1 =
-                new RssReportShuffleResultRequest("multipleShuffleResultTest", 1, 2, ptbs, 1);
+                new RssReportShuffleResultRequest(appId, 1, 2, ptbs, 1);
             grpcShuffleServerClient.reportShuffleResult(req1);
           }
         };
@@ -734,8 +734,7 @@ public void multipleShuffleResultTest(BlockIdLayout layout) throws Exception {
       blockIdBitmap.addLong(blockId);
     }
 
-    RssGetShuffleResultRequest req =
-        new RssGetShuffleResultRequest("multipleShuffleResultTest", 1, 1, layout);
+    RssGetShuffleResultRequest req = new RssGetShuffleResultRequest(appId, 1, 1, layout);
     RssGetShuffleResultResponse result = grpcShuffleServerClient.getShuffleResult(req);
     Roaring64NavigableMap actualBlockIdBitmap = result.getBlockIdBitmap();
     assertEquals(blockIdBitmap, actualBlockIdBitmap);
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index 7b14d981d7..7e1eb88cfc 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -506,14 +506,23 @@ public void reportShuffleResult(
         "appId[" + appId + "], shuffleId[" + shuffleId + "], taskAttemptId[" + taskAttemptId + "]";
 
     try {
+      int expectedBlockCount = partitionToBlockIds.values().stream().mapToInt(x -> x.length).sum();
       LOG.info(
-          "Report "
-              + partitionToBlockIds.size()
-              + " blocks as shuffle result for the task of "
-              + requestInfo);
-      shuffleServer
-          .getShuffleTaskManager()
-          .addFinishedBlockIds(appId, shuffleId, partitionToBlockIds, bitmapNum);
+          "Accepted blockIds report for {} blocks across {} partitions as shuffle result for task {}",
+          expectedBlockCount,
+          partitionToBlockIds.size(),
+          requestInfo);
+      int updatedBlockCount =
+          shuffleServer
+              .getShuffleTaskManager()
+              .addFinishedBlockIds(appId, shuffleId, partitionToBlockIds, bitmapNum);
+      if (expectedBlockCount != updatedBlockCount) {
+        LOG.warn(
+            "Existing {} duplicated blockIds on blockId report for appId: {}, shuffleId: {}",
+            expectedBlockCount - updatedBlockCount,
+            appId,
+            shuffleId);
+      }
     } catch (Exception e) {
       status = StatusCode.INTERNAL_ERROR;
       msg = "error happened when report shuffle result, check shuffle server for detail";
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
index f45b1be944..b6806f6343 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java
@@ -64,6 +64,8 @@ public class ShuffleTaskInfo {
 
   private final AtomicReference specification;
 
+  private final Map<Integer, Map<Integer, AtomicLong>> partitionBlockCounters;
+
   public ShuffleTaskInfo(String appId) {
     this.appId = appId;
     this.currentTimes = System.currentTimeMillis();
@@ -75,6 +77,7 @@ public ShuffleTaskInfo(String appId) {
     this.hugePartitionTags = JavaUtils.newConcurrentMap();
     this.existHugePartition = new AtomicBoolean(false);
     this.specification = new AtomicReference<>();
+    this.partitionBlockCounters = JavaUtils.newConcurrentMap();
   }
 
   public Long getCurrentTimes() {
@@ -198,6 +201,27 @@ public Set getShuffleIds() {
     return partitionDataSizes.keySet();
   }
 
+  /** Adds {@code delta} to the block counter of the given shuffle partition. */
+  public void incBlockNumber(int shuffleId, int partitionId, int delta) {
+    this.partitionBlockCounters
+        .computeIfAbsent(shuffleId, x -> JavaUtils.newConcurrentMap())
+        .computeIfAbsent(partitionId, x -> new AtomicLong())
+        .addAndGet(delta);
+  }
+
+  /** Returns the block counter of the given shuffle partition, or 0 if nothing was counted. */
+  public long getBlockNumber(int shuffleId, int partitionId) {
+    Map<Integer, AtomicLong> partitionBlockCounters = this.partitionBlockCounters.get(shuffleId);
+    if (partitionBlockCounters == null) {
+      return 0L;
+    }
+    AtomicLong counter = partitionBlockCounters.get(partitionId);
+    if (counter == null) {
+      return 0L;
+    }
+    return counter.get();
+  }
+
   @Override
   public String toString() {
     return "ShuffleTaskInfo{"
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index 0a07f70f21..a98eac26c2 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -387,7 +387,16 @@ public StatusCode commitShuffle(String appId, int shuffleId) throws Exception {
     return StatusCode.SUCCESS;
   }
 
-  public void addFinishedBlockIds(
+  /**
+   * Add finished blockIds from client
+   *
+   * @param appId the application the reported blockIds belong to
+   * @param shuffleId the shuffle the reported blockIds belong to
+   * @param partitionToBlockIds the reported blockIds, keyed by partition id
+   * @param bitmapNum the number of bitmaps; a partition uses bitmap {@code partitionId % bitmapNum}
+   * @return the number of newly added blockIds; blockIds already in the bitmap are not counted
+   */
+  public int addFinishedBlockIds(
       String appId, Integer shuffleId, Map<Integer, long[]> partitionToBlockIds, int bitmapNum) {
     refreshAppId(appId);
     Map<Integer, Roaring64NavigableMap[]> shuffleIdToPartitions = partitionsToBlockIds.get(appId);
@@ -413,15 +422,28 @@ public void addFinishedBlockIds(
               + " bitmaps!");
     }
 
+    ShuffleTaskInfo taskInfo = getShuffleTaskInfo(appId);
+    if (taskInfo == null) {
+      throw new InvalidRequestException(
+          "ShuffleTaskInfo is not found that should not happen for appId: " + appId);
+    }
+    int totalUpdatedBlockCount = 0;
     for (Map.Entry<Integer, long[]> entry : partitionToBlockIds.entrySet()) {
       Integer partitionId = entry.getKey();
       Roaring64NavigableMap bitmap = blockIds[partitionId % bitmapNum];
+      int updatedBlockCount = 0;
       synchronized (bitmap) {
         for (long blockId : entry.getValue()) {
-          bitmap.addLong(blockId);
+          if (!bitmap.contains(blockId)) {
+            bitmap.addLong(blockId);
+            updatedBlockCount++;
+            totalUpdatedBlockCount++;
+          }
         }
       }
+      taskInfo.incBlockNumber(shuffleId, partitionId, updatedBlockCount);
     }
+    return totalUpdatedBlockCount;
   }
 
   public int updateAndGetCommitCount(String appId, int shuffleId) {
@@ -553,13 +575,18 @@ public byte[] getFinishedBlockIds(
     }
     Map<Integer, Roaring64NavigableMap[]> shuffleIdToPartitions = partitionsToBlockIds.get(appId);
     if (shuffleIdToPartitions == null) {
+      LOG.warn("Empty blockIds for app: {}. This should not happen", appId);
       return null;
     }
 
     Roaring64NavigableMap[] blockIds = shuffleIdToPartitions.get(shuffleId);
     if (blockIds == null) {
+      LOG.warn("Empty blockIds for app: {}, shuffleId: {}", appId, shuffleId);
       return new byte[] {};
     }
+
+    ShuffleTaskInfo taskInfo = getShuffleTaskInfo(appId);
+    long expectedBlockNumber = 0;
     Map<Integer, Set<Integer>> bitmapIndexToPartitions = Maps.newHashMap();
     for (int partitionId : partitions) {
       int bitmapIndex = partitionId % blockIds.length;
@@ -569,6 +596,7 @@ public byte[] getFinishedBlockIds(
         HashSet<Integer> newHashSet = Sets.newHashSet(partitionId);
         bitmapIndexToPartitions.put(bitmapIndex, newHashSet);
       }
+      expectedBlockNumber += taskInfo.getBlockNumber(shuffleId, partitionId);
     }
 
     Roaring64NavigableMap res = Roaring64NavigableMap.bitmapOf();
@@ -577,6 +605,17 @@ public byte[] getFinishedBlockIds(
       Roaring64NavigableMap bitmap = blockIds[entry.getKey()];
       getBlockIdsByPartitionId(requestPartitions, bitmap, res, blockIdLayout);
     }
+
+    if (res.getLongCardinality() != expectedBlockNumber) {
+      throw new RssException(
+          "Inconsistent block number for partitions: "
+              + partitions
+              + ". Expected: "
+              + expectedBlockNumber
+              + ", actual: "
+              + res.getLongCardinality());
+    }
+
     return RssUtils.serializeBitMap(res);
   }
 