diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java index 98c3e4a0e8..13b3f215cd 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java @@ -121,6 +121,8 @@ public void init(Context context) { } Map> serverInfoMap = new HashMap<>(); serverInfoMap.put(partitionId, new ArrayList<>(serverInfoSet)); + String clientType = + rssJobConf.get(RssMRConfig.RSS_CLIENT_TYPE, RssMRConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE); this.reader = new RMRecordsReader( appId, @@ -134,7 +136,8 @@ public void init(Context context) { true, combiner, combiner != null, - new MRMetricsReporter(context.getReporter())); + new MRMetricsReporter(context.getReporter()), + clientType); } @Override diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java index 0385cb58e0..159da8a84d 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java @@ -57,7 +57,7 @@ import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.records.RecordsReader; import org.apache.uniffle.common.rpc.StatusCode; -import org.apache.uniffle.common.serializer.PartialInputStreamImpl; +import org.apache.uniffle.common.serializer.SerInputStream; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; @@ -523,11 +523,8 @@ public void testWriteNormalWithRemoteMerge() throws Exception { ByteBuf byteBuf = blockInfos.get(0).getData(); RecordsReader reader = new RecordsReader<>( - rssConf, - PartialInputStreamImpl.newInputStream(byteBuf.nioBuffer()), - Text.class, - Text.class, - false); + rssConf, SerInputStream.newInputStream(byteBuf), Text.class, Text.class, false, false); + reader.init(); int index = 0; while (reader.next()) { assertEquals(SerializerUtils.genData(Text.class, index), reader.getCurrentKey()); @@ -610,10 +607,12 @@ public void testWriteNormalWithRemoteMergeAndCombine() throws Exception { RecordsReader reader = new RecordsReader<>( rssConf, - PartialInputStreamImpl.newInputStream(byteBuf.nioBuffer()), + SerInputStream.newInputStream(byteBuf), Text.class, IntWritable.class, + false, false); + reader.init(); int index = 0; while (reader.next()) { int aimValue = index; diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffleTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffleTest.java index 575be5a54b..8eec9d2ff0 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffleTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffleTest.java @@ -17,8 +17,6 @@ package org.apache.hadoop.mapreduce.task.reduce; -import java.io.ByteArrayOutputStream; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -29,6 +27,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import 
io.netty.buffer.ByteBuf; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.RawComparator; @@ -58,13 +57,15 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.merger.Merger; import org.apache.uniffle.common.merger.Segment; +import org.apache.uniffle.common.serializer.DynBufferSerOutputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.Serializer; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; import static org.apache.uniffle.common.serializer.SerializerUtils.genData; -import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes; +import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; @@ -136,12 +137,10 @@ public void testReadShuffleWithoutCombine() throws Exception { combiner, false, null); - ByteBuffer byteBuffer = - ByteBuffer.wrap( - genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1)); + ByteBuf byteBuf = genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1); ShuffleServerClient serverClient = new MockedShuffleServerClient( - new int[] {PARTITION_ID}, new ByteBuffer[][] {{byteBuffer}}, blockIds); + new int[] {PARTITION_ID}, new ByteBuf[][] {{byteBuf}}, blockIds); RMRecordsReader readerSpy = spy(reader); doReturn(serverClient).when(readerSpy).createShuffleServerClient(any()); @@ -170,6 +169,7 @@ public void testReadShuffleWithoutCombine() throws Exception { index++; } assertEquals(RECORDS_NUM, index); + byteBuf.release(); } } @@ -219,20 +219,21 @@ public void testReadShuffleWithCombine() throws Exception { List segments = new ArrayList<>(); segments.add( SerializerUtils.genMemorySegment( - rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM, true)); + rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM, true, false)); segments.add( SerializerUtils.genMemorySegment( - rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM, true)); + rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM, true, false)); segments.add( SerializerUtils.genMemorySegment( - rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM, true)); - ByteArrayOutputStream output = new ByteArrayOutputStream(); + rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM, true, false)); + segments.forEach(segment -> segment.init()); + SerOutputStream output = new DynBufferSerOutputStream(); Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, true); output.close(); - ByteBuffer byteBuffer = ByteBuffer.wrap(output.toByteArray()); + ByteBuf byteBuf = output.toByteBuf(); ShuffleServerClient serverClient = new MockedShuffleServerClient( - new int[] {PARTITION_ID}, new ByteBuffer[][] {{byteBuffer}}, blockIds); + new int[] {PARTITION_ID}, new ByteBuf[][] {{byteBuf}}, blockIds); RMRecordsReader reader = new RMRecordsReader( APP_ID, @@ -280,6 +281,7 @@ public void testReadShuffleWithCombine() throws Exception { index++; } assertEquals(RECORDS_NUM * 2, index); + byteBuf.release(); } } } diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java 
b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java index 542d251e7e..2b0ae1e7c4 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffle.java @@ -73,6 +73,7 @@ public class RMRssShuffle implements ExceptionReporter { private ShuffleInputEventHandlerOrderedGrouped eventHandler; private final TezTaskAttemptID tezTaskAttemptID; private final String srcNameTrimmed; + private final String clientType; private Map> partitionToServers; private AtomicBoolean isShutDown = new AtomicBoolean(false); @@ -101,6 +102,8 @@ public RMRssShuffle( this.numInputs = numInputs; this.shuffleId = shuffleId; this.applicationAttemptId = applicationAttemptId; + this.clientType = + conf.get(RssTezConfig.RSS_CLIENT_TYPE, RssTezConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE); this.appId = this.applicationAttemptId.toString(); this.srcNameTrimmed = TezUtilsInternal.cleanVertexName(inputContext.getSourceVertexName()); LOG.info(srcNameTrimmed + ": Shuffle assigned with " + numInputs + " inputs."); @@ -254,7 +257,8 @@ public RMRecordsReader createRMRecordsReader(Set partitionIds) { false, (inc) -> { inputRecordCounter.increment(inc); - }); + }, + this.clientType); } @Override diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java index b42b57bda5..df5ced4a53 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RMRssShuffleTest.java @@ -20,6 +20,7 @@ import java.lang.reflect.Method; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Comparator; import java.util.Iterator; import java.util.List; @@ -28,6 +29,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import io.netty.buffer.ByteBuf; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.IntWritable; @@ -74,7 +76,7 @@ import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS; import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS; import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS; -import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes; +import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; @@ -158,12 +160,9 @@ public void testReadShuffleData() throws Exception { false, null); RMRecordsReader recordsReaderSpy = spy(recordsReader); - ByteBuffer[][] buffers = - new ByteBuffer[][] { - { - ByteBuffer.wrap( - genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, duplicated)) - } + ByteBuf[][] buffers = + new ByteBuf[][] { + {genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, duplicated)} }; ShuffleServerClient serverClient = new 
MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds); @@ -228,6 +227,7 @@ public void testReadShuffleData() throws Exception { index++; } assertEquals(RECORDS_NUM, index); + Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release())); } @Test @@ -309,15 +309,13 @@ public void testReadMultiPartitionShuffleData() throws Exception { false, null); RMRecordsReader recordsReaderSpy = spy(recordsReader); - ByteBuffer[][] buffers = new ByteBuffer[3][2]; + ByteBuf[][] buffers = new ByteBuf[3][2]; for (int i = 0; i < 3; i++) { buffers[i][0] = - ByteBuffer.wrap( - genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, duplicated)); + genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, duplicated); buffers[i][1] = - ByteBuffer.wrap( - genSortedRecordBytes( - rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, duplicated)); + genSortedRecordBuffer( + rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, duplicated); } ShuffleServerClient serverClient = new MockedShuffleServerClient( @@ -396,6 +394,7 @@ public void testReadMultiPartitionShuffleData() throws Exception { index++; } assertEquals(RECORDS_NUM * 6, index); + Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release())); } public static DataMovementEvent createDataMovementEvent(int partition, String path) diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java index ce0458219f..3765dc83b4 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java @@ -19,7 +19,6 @@ import java.io.File; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -73,7 +72,7 @@ import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.records.RecordsReader; import org.apache.uniffle.common.rpc.StatusCode; -import org.apache.uniffle.common.serializer.PartialInputStreamImpl; +import org.apache.uniffle.common.serializer.SerInputStream; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; @@ -617,11 +616,8 @@ public void testWriteWithRemoteMerge() throws Exception { buf.readBytes(bytes); RecordsReader reader = new RecordsReader<>( - rssConf, - PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes)), - Text.class, - Text.class, - false); + rssConf, SerInputStream.newInputStream(buf), Text.class, Text.class, false, false); + reader.init(); int index = 0; while (reader.next()) { assertEquals(SerializerUtils.genData(Text.class, index), reader.getCurrentKey()); diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java index 73d399062a..6102c8a176 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferTest.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Random; +import 
io.netty.buffer.Unpooled; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.IntWritable; @@ -41,8 +42,7 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.serializer.DeserializationStream; -import org.apache.uniffle.common.serializer.PartialInputStream; -import org.apache.uniffle.common.serializer.PartialInputStreamImpl; +import org.apache.uniffle.common.serializer.SerInputStream; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; @@ -203,9 +203,11 @@ public void testReadWriteWithRemoteMergeAndNoSort() throws IOException { buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i)); } byte[] bytes = buffer.getData(); - PartialInputStream inputStream = PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes)); + SerInputStream inputStream = + SerInputStream.newInputStream(Unpooled.wrappedBuffer(ByteBuffer.wrap(bytes))); DeserializationStream dStream = - instance.deserializeStream(inputStream, Text.class, IntWritable.class, false); + instance.deserializeStream(inputStream, Text.class, IntWritable.class, false, false); + dStream.init(); for (int i = 0; i < RECORDS_NUM; i++) { assertTrue(dStream.nextRecord()); assertEquals(genData(Text.class, i), dStream.getCurrentKey()); @@ -240,9 +242,11 @@ public void testReadWriteWithRemoteMergeAndSort() throws IOException { buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i)); } byte[] bytes = buffer.getData(); - PartialInputStream inputStream = PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes)); + SerInputStream inputStream = + SerInputStream.newInputStream(Unpooled.wrappedBuffer(ByteBuffer.wrap(bytes))); DeserializationStream dStream = - instance.deserializeStream(inputStream, Text.class, IntWritable.class, false); + instance.deserializeStream(inputStream, Text.class, IntWritable.class, false, false); + dStream.init(); for (int i = 0; i < RECORDS_NUM; i++) { assertTrue(dStream.nextRecord()); assertEquals(genData(Text.class, i), dStream.getCurrentKey()); diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java index 24157c8832..57e0ee72c7 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/input/RMRssOrderedGroupedKVInputTest.java @@ -17,7 +17,6 @@ package org.apache.tez.runtime.library.input; -import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Comparator; @@ -29,6 +28,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import io.netty.buffer.ByteBuf; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.IntWritable; @@ -72,12 +72,14 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.merger.Merger; import org.apache.uniffle.common.merger.Segment; +import org.apache.uniffle.common.serializer.DynBufferSerOutputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.util.BlockIdLayout; import static 
org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID; import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID; -import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes; +import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; @@ -166,10 +168,11 @@ public Void answer(InvocationOnMock invocation) throws Throwable { SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM)); segments.add( SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM)); - ByteArrayOutputStream output = new ByteArrayOutputStream(); + segments.forEach(segment -> segment.init()); + SerOutputStream output = new DynBufferSerOutputStream(); Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, false); output.close(); - ByteBuffer[][] buffers = new ByteBuffer[][] {{ByteBuffer.wrap(output.toByteArray())}}; + ByteBuf[][] buffers = new ByteBuf[][] {{output.toByteBuf()}}; ShuffleServerClient serverClient = new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds); RMRecordsReader recordsReader = @@ -362,15 +365,12 @@ public Void answer(InvocationOnMock invocation) throws Throwable { false, null); RMRecordsReader recordsReaderSpy = spy(recordsReader); - ByteBuffer[][] buffers = new ByteBuffer[3][2]; + ByteBuf[][] buffers = new ByteBuf[3][2]; for (int i = 0; i < 3; i++) { - buffers[i][0] = - ByteBuffer.wrap( - genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1)); + buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1); buffers[i][1] = - ByteBuffer.wrap( - genSortedRecordBytes( - rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1)); + genSortedRecordBuffer( + rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1); } ShuffleServerClient serverClient = new MockedShuffleServerClient( diff --git a/client/src/main/java/org/apache/uniffle/client/record/reader/BufferedSegment.java b/client/src/main/java/org/apache/uniffle/client/record/reader/BufferedSegment.java index d7f2afe8a1..4ae8a9bf5c 100644 --- a/client/src/main/java/org/apache/uniffle/client/record/reader/BufferedSegment.java +++ b/client/src/main/java/org/apache/uniffle/client/record/reader/BufferedSegment.java @@ -58,4 +58,10 @@ public void close() throws IOException { this.recordBuffer = null; } } + + @Override + public long getSize() { + // Should never use + return -1; + } } diff --git a/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java b/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java index 30290a08f7..83856bb2aa 100644 --- a/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java +++ b/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java @@ -18,7 +18,6 @@ package org.apache.uniffle.client.record.reader; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Comparator; import java.util.Iterator; @@ -29,6 +28,7 @@ import java.util.concurrent.TimeUnit; import com.google.common.annotations.VisibleForTesting; +import io.netty.buffer.ByteBuf; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.RawComparator; import org.slf4j.Logger; @@ -43,21 
+43,23 @@ import org.apache.uniffle.client.record.writer.Combiner; import org.apache.uniffle.client.request.RssGetSortedShuffleDataRequest; import org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse; +import org.apache.uniffle.client.util.RssClientConfig; +import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.merger.MergeState; import org.apache.uniffle.common.merger.Merger; +import org.apache.uniffle.common.netty.buffer.ManagedBuffer; import org.apache.uniffle.common.records.RecordsReader; import org.apache.uniffle.common.rpc.StatusCode; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; import org.apache.uniffle.common.serializer.Serializer; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.writable.ComparativeOutputBuffer; import org.apache.uniffle.common.util.JavaUtils; -import static org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE; import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_MERGE_FETCH_INIT_SLEEP_MS; import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_MERGE_FETCH_MAX_SLEEP_MS; import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_MERGE_READER_MAX_BUFFER; @@ -78,8 +80,11 @@ public class RMRecordsReader { private final Combiner combiner; private boolean isMapCombine; private final MetricsReporter metrics; + private final String clientType; private SerializerInstance serializerInstance; + private final int retryMax; + private final long retryIntervalMax; private final long initFetchSleepTime; private final long maxFetchSleepTime; private final int maxBufferPerPartition; @@ -87,7 +92,7 @@ public class RMRecordsReader { private Map> shuffleServerInfoMap; private volatile boolean stop = false; - private volatile String errorMessage = null; + private volatile Throwable error = null; private Map> combineBuffers = JavaUtils.newConcurrentMap(); private Map> mergeBuffers = JavaUtils.newConcurrentMap(); @@ -106,6 +111,36 @@ public RMRecordsReader( Combiner combiner, boolean isMapCombine, MetricsReporter metrics) { + this( + appId, + shuffleId, + partitionIds, + shuffleServerInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + isMapCombine, + metrics, + ClientType.GRPC.name()); + } + + public RMRecordsReader( + String appId, + int shuffleId, + Set partitionIds, + Map> shuffleServerInfoMap, + RssConf rssConf, + Class keyClass, + Class valueClass, + Comparator comparator, + boolean raw, + Combiner combiner, + boolean isMapCombine, + MetricsReporter metrics, + String clientType) { this.appId = appId; this.shuffleId = shuffleId; this.partitionIds = partitionIds; @@ -131,6 +166,7 @@ public int compare(K o1, K o2) { this.combiner = combiner; this.isMapCombine = isMapCombine; this.metrics = metrics; + this.clientType = clientType; if (this.raw) { SerializerFactory factory = new SerializerFactory(rssConf); Serializer serializer = factory.getSerializer(keyClass); @@ -145,6 +181,14 @@ public int compare(K o1, K o2) { this.maxRecordsNumPerBuffer = rssConf.get(RSS_CLIENT_REMOTE_MERGE_READER_MAX_RECORDS_PER_BUFFER); this.results = new Queue(maxBufferPerPartition * 
maxRecordsNumPerBuffer * partitionIds.size()); + this.retryMax = + rssConf.getInteger( + RssClientConfig.RSS_CLIENT_RETRY_MAX, + RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE); + this.retryIntervalMax = + rssConf.getLong( + RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX, + RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE); LOG.info("RMRecordsReader constructed for partitions {}", partitionIds); } @@ -167,7 +211,7 @@ public void start() { } public void close() { - errorMessage = null; + error = null; stop = true; for (Queue buffer : mergeBuffers.values()) { buffer.clear(); @@ -376,8 +420,8 @@ public E take() throws InterruptedException { return e; } } - if (errorMessage != null) { - throw new RssException("RMShuffleReader fetch record failed, caused by " + errorMessage); + if (error != null) { + throw new RssException("RMShuffleReader fetch record failed, caused by " + error); } return this.queue.poll(100, TimeUnit.MILLISECONDS); } @@ -425,7 +469,8 @@ public void run() { while (!stop) { try { RssGetSortedShuffleDataRequest request = - new RssGetSortedShuffleDataRequest(appId, shuffleId, partitionId, blockId); + new RssGetSortedShuffleDataRequest( + appId, shuffleId, partitionId, blockId, retryMax, retryIntervalMax); RssGetSortedShuffleDataResponse response = client.getSortedShuffleData(request); if (response.getStatusCode() != StatusCode.SUCCESS || response.getMergeState() == MergeState.INTERNAL_ERROR.code()) { @@ -455,37 +500,53 @@ public void run() { } else if (response.getMergeState() == MergeState.DONE.code() || response.getMergeState() == MergeState.MERGING.code()) { this.sleepTime = initFetchSleepTime; - ByteBuffer byteBuffer = response.getData(); blockId = response.getNextBlockId(); - // Fetch blocks and parsing blocks are a synchronous process. If the two processes are - // split into two - // different threads, then will be asynchronous processes. Although it seems to save - // time, it actually - // consumes more memory. - RecordsReader reader = - new RecordsReader( - rssConf, - PartialInputStream.newInputStream(byteBuffer), - keyClass, - valueClass, - raw); - while (reader.next()) { - if (metrics != null) { - metrics.incRecordsRead(1); + ManagedBuffer managedBuffer = null; + ByteBuf byteBuf = null; + RecordsReader reader = null; + try { + managedBuffer = response.getData(); + byteBuf = managedBuffer.byteBuf(); + // Fetch blocks and parsing blocks are a synchronous process. If the two processes are + // split into two different threads, then will be asynchronous processes. Although it + // seems to save time, it actually consumes more memory. 
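+              // Note: the reader consumes the ByteBuf obtained from the ManagedBuffer directly;
+              // the finally block below closes the reader and releases both buffers, so the
+              // fetched Netty buffer is not leaked even if parsing fails.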
+ reader = + new RecordsReader( + rssConf, + SerInputStream.newInputStream(byteBuf), + keyClass, + valueClass, + raw, + false); + reader.init(); + while (reader.next()) { + if (metrics != null) { + metrics.incRecordsRead(1); + } + if (recordBuffer.size() >= maxRecordsNumPerBuffer) { + nextQueue.put(recordBuffer); + recordBuffer = new RecordBuffer(partitionId); + } + recordBuffer.addRecord(reader.getCurrentKey(), reader.getCurrentValue()); + } + } finally { + if (reader != null) { + reader.close(); + } + if (byteBuf != null) { + byteBuf.release(); } - if (recordBuffer.size() >= maxRecordsNumPerBuffer) { - nextQueue.put(recordBuffer); - recordBuffer = new RecordBuffer(partitionId); + if (managedBuffer != null) { + managedBuffer.release(); } - recordBuffer.addRecord(reader.getCurrentKey(), reader.getCurrentValue()); } } else { fetchError = "Receive wrong offset from server, offset is " + response.getNextBlockId(); nextShuffleServerInfo(); break; } - } catch (Exception e) { - errorMessage = e.getMessage(); + } catch (Throwable e) { + error = e; stop = true; LOG.info("Found exception when fetch sorted record, caused by ", e); } @@ -581,28 +642,33 @@ public void run() { } } Merger.MergeQueue mergeQueue = - new Merger.MergeQueue(rssConf, segments, keyClass, valueClass, comparator, raw); - mergeQueue.init(); - mergeQueue.setPopSegmentHook( - pid -> { - try { - RecordBuffer recordBuffer = mergeBuffers.get(pid).take(); - if (recordBuffer == null) { - return null; + new Merger.MergeQueue(rssConf, segments, keyClass, valueClass, comparator, raw, false); + try { + // Here are BufferedSegment, no need to init + mergeQueue.init(); + mergeQueue.setPopSegmentHook( + pid -> { + try { + RecordBuffer recordBuffer = mergeBuffers.get(pid).take(); + if (recordBuffer == null) { + return null; + } + return new BufferedSegment(recordBuffer); + } catch (InterruptedException ex) { + throw new RssException(ex); } - return new BufferedSegment(recordBuffer); - } catch (InterruptedException ex) { - throw new RssException(ex); - } - }); - while (!stop && mergeQueue.next()) { - results.put(Record.create(mergeQueue.getCurrentKey(), mergeQueue.getCurrentValue())); + }); + while (!stop && mergeQueue.next()) { + results.put(Record.create(mergeQueue.getCurrentKey(), mergeQueue.getCurrentValue())); + } + } finally { + mergeQueue.close(); } if (!stop) { results.setProducerDone(true); } } catch (InterruptedException | IOException e) { - errorMessage = e.getMessage(); + error = e; stop = true; } } @@ -611,6 +677,6 @@ public void run() { @VisibleForTesting public ShuffleServerClient createShuffleServerClient(ShuffleServerInfo shuffleServerInfo) { return ShuffleServerClientFactory.getInstance() - .getShuffleServerClient(RSS_CLIENT_TYPE_DEFAULT_VALUE, shuffleServerInfo, rssConf); + .getShuffleServerClient(this.clientType, shuffleServerInfo, rssConf); } } diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/BufferedSegmentTest.java b/client/src/test/java/org/apache/uniffle/client/record/reader/BufferedSegmentTest.java index c4e32b5722..aee0d87875 100644 --- a/client/src/test/java/org/apache/uniffle/client/record/reader/BufferedSegmentTest.java +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/BufferedSegmentTest.java @@ -65,7 +65,8 @@ public void testMergeResolvedSegmentWithHook() throws Exception { } } Merger.MergeQueue mergeQueue = - new Merger.MergeQueue(rssConf, segments, Text.class, IntWritable.class, comparator, false); + new Merger.MergeQueue( + rssConf, segments, Text.class, 
IntWritable.class, comparator, false, false); mergeQueue.init(); mergeQueue.setPopSegmentHook( id -> { diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleServerClient.java b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleServerClient.java index 93da7c4c9d..931f8990c5 100644 --- a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleServerClient.java +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleServerClient.java @@ -18,12 +18,13 @@ package org.apache.uniffle.client.record.reader; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.apache.uniffle.client.api.ShuffleServerClient; @@ -58,16 +59,17 @@ import org.apache.uniffle.client.response.RssUnregisterShuffleResponse; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.merger.MergeState; +import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.util.RssUtils; public class MockedShuffleServerClient implements ShuffleServerClient { - private Map> shuffleData; + private Map> shuffleData; private Map indexes; private long[] blockIds; - public MockedShuffleServerClient(int[] partitionIds, ByteBuffer[][] buffers, long[] blockIds) { + public MockedShuffleServerClient(int[] partitionIds, ByteBuf[][] buffers, long[] blockIds) { if (partitionIds.length != buffers.length) { throw new RssException("partition id length is not matched"); } @@ -75,7 +77,7 @@ public MockedShuffleServerClient(int[] partitionIds, ByteBuffer[][] buffers, lon for (int i = 0; i < partitionIds.length; i++) { int partition = partitionIds[i]; shuffleData.put(partition, new ArrayList<>()); - for (ByteBuffer byteBuffer : buffers[i]) { + for (ByteBuf byteBuffer : buffers[i]) { shuffleData.get(partition).add(byteBuffer); } } @@ -100,12 +102,18 @@ public RssGetSortedShuffleDataResponse getSortedShuffleData( response = new RssGetSortedShuffleDataResponse( StatusCode.SUCCESS, - shuffleData.get(partitionId).get(index), + "", + new NettyManagedBuffer(shuffleData.get(partitionId).get(index).retain()), 10000, MergeState.DONE.code()); } else { response = - new RssGetSortedShuffleDataResponse(StatusCode.SUCCESS, null, -1, MergeState.DONE.code()); + new RssGetSortedShuffleDataResponse( + StatusCode.SUCCESS, + "", + new NettyManagedBuffer(Unpooled.buffer(0)), + -1, + MergeState.DONE.code()); } indexes.put(partitionId, index + 1); return response; diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java b/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java index 65d513ced1..2c543ad25e 100644 --- a/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java @@ -17,14 +17,14 @@ package org.apache.uniffle.client.record.reader; -import java.io.ByteArrayOutputStream; -import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Comparator; import java.util.List; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; +import io.netty.buffer.ByteBuf; import 
org.apache.hadoop.io.IntWritable; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; @@ -37,11 +37,13 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.merger.Merger; import org.apache.uniffle.common.merger.Segment; +import org.apache.uniffle.common.serializer.DynBufferSerOutputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; -import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes; +import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; @@ -90,10 +92,9 @@ public void testNormalReadWithoutCombine(String classes) throws Exception { combiner, false, null); - byte[] buffers = genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1); + ByteBuf byteBuf = genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1); ShuffleServerClient serverClient = - new MockedShuffleServerClient( - new int[] {partitionId}, new ByteBuffer[][] {{ByteBuffer.wrap(buffers)}}, null); + new MockedShuffleServerClient(new int[] {partitionId}, new ByteBuf[][] {{byteBuf}}, null); RMRecordsReader readerSpy = spy(reader); doReturn(serverClient).when(readerSpy).createShuffleServerClient(any()); @@ -107,6 +108,7 @@ public void testNormalReadWithoutCombine(String classes) throws Exception { index++; } assertEquals(RECORDS_NUM, index); + byteBuf.release(); } @Timeout(30) @@ -142,13 +144,13 @@ public void testNormalReadWithCombine(String classes) throws Exception { SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM)); segments.add( SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM)); - ByteArrayOutputStream output = new ByteArrayOutputStream(); + segments.forEach(segment -> segment.init()); + SerOutputStream output = new DynBufferSerOutputStream(); Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, false); output.close(); - byte[] buffers = output.toByteArray(); + ByteBuf byteBuf = output.toByteBuf(); ShuffleServerClient serverClient = - new MockedShuffleServerClient( - new int[] {partitionId}, new ByteBuffer[][] {{ByteBuffer.wrap(buffers)}}, null); + new MockedShuffleServerClient(new int[] {partitionId}, new ByteBuf[][] {{byteBuf}}, null); RMRecordsReader reader = new RMRecordsReader( APP_ID, @@ -185,6 +187,7 @@ public void testNormalReadWithCombine(String classes) throws Exception { index++; } assertEquals(RECORDS_NUM * 2, index); + byteBuf.release(); } @Timeout(30) @@ -231,15 +234,12 @@ public void testReadMulitPartitionWithoutCombine(String classes) throws Exceptio false, null); RMRecordsReader readerSpy = spy(reader); - ByteBuffer[][] buffers = new ByteBuffer[3][2]; + ByteBuf[][] buffers = new ByteBuf[3][2]; for (int i = 0; i < 3; i++) { - buffers[i][0] = - ByteBuffer.wrap( - genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1)); + buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1); buffers[i][1] = - ByteBuffer.wrap( - genSortedRecordBytes( - rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1)); + genSortedRecordBuffer( + 
rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1); } ShuffleServerClient serverClient = new MockedShuffleServerClient( @@ -256,6 +256,7 @@ public void testReadMulitPartitionWithoutCombine(String classes) throws Exceptio index++; } assertEquals(RECORDS_NUM * 6, index); + Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release())); } @Timeout(30) @@ -305,15 +306,12 @@ public void testReadMulitPartitionWithCombine(String classes) throws Exception { false, null); RMRecordsReader readerSpy = spy(reader); - ByteBuffer[][] buffers = new ByteBuffer[3][2]; + ByteBuf[][] buffers = new ByteBuf[3][2]; for (int i = 0; i < 3; i++) { - buffers[i][0] = - ByteBuffer.wrap( - genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 2)); + buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 2); buffers[i][1] = - ByteBuffer.wrap( - genSortedRecordBytes( - rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 2)); + genSortedRecordBuffer( + rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 2); } ShuffleServerClient serverClient = new MockedShuffleServerClient( @@ -331,5 +329,6 @@ public void testReadMulitPartitionWithCombine(String classes) throws Exception { index++; } assertEquals(RECORDS_NUM * 6, index); + Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release())); } } diff --git a/client/src/test/java/org/apache/uniffle/client/record/writer/RecordCollectionTest.java b/client/src/test/java/org/apache/uniffle/client/record/writer/RecordCollectionTest.java index b5f0b5bd7b..2cc1ad8f73 100644 --- a/client/src/test/java/org/apache/uniffle/client/record/writer/RecordCollectionTest.java +++ b/client/src/test/java/org/apache/uniffle/client/record/writer/RecordCollectionTest.java @@ -17,12 +17,11 @@ package org.apache.uniffle.client.record.writer; -import java.io.ByteArrayOutputStream; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import io.netty.buffer.ByteBuf; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -31,7 +30,9 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.records.RecordsReader; import org.apache.uniffle.common.records.RecordsWriter; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.DynBufferSerOutputStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.SerializerUtils; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -68,19 +69,19 @@ public void testSortAndSerializeRecords(String classes) throws Exception { // 4 serialize records RssConf rssConf = new RssConf(); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - RecordsWriter writer = new RecordsWriter(rssConf, outputStream, keyClass, valueClass, false); + SerOutputStream outputStream = new DynBufferSerOutputStream(); + RecordsWriter writer = + new RecordsWriter(rssConf, outputStream, keyClass, valueClass, false, false); + writer.init(); recordBuffer.serialize(writer); writer.close(); // 5 check the serialized data + ByteBuf byteBuf = outputStream.toByteBuf(); RecordsReader reader = new RecordsReader<>( - rssConf, - PartialInputStream.newInputStream(ByteBuffer.wrap(outputStream.toByteArray())), - keyClass, - valueClass, - false); + 
rssConf, SerInputStream.newInputStream(byteBuf), keyClass, valueClass, false, false); + reader.init(); int index = 0; while (reader.next()) { assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); @@ -89,6 +90,7 @@ public void testSortAndSerializeRecords(String classes) throws Exception { } assertEquals(RECORDS, index); reader.close(); + byteBuf.release(); } @ParameterizedTest @@ -126,8 +128,10 @@ public void testSortCombineAndSerializeRecords(String classes) throws Exception // 4 serialize records RssConf rssConf = new RssConf(); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - RecordsWriter writer = new RecordsWriter(rssConf, outputStream, keyClass, valueClass, false); + SerOutputStream outputStream = new DynBufferSerOutputStream(); + RecordsWriter writer = + new RecordsWriter(rssConf, outputStream, keyClass, valueClass, false, false); + writer.init(); recordBlob.serialize(writer); writer.close(); @@ -135,10 +139,12 @@ public void testSortCombineAndSerializeRecords(String classes) throws Exception RecordsReader reader = new RecordsReader<>( rssConf, - PartialInputStream.newInputStream(ByteBuffer.wrap(outputStream.toByteArray())), + SerInputStream.newInputStream(outputStream.toByteBuf()), keyClass, valueClass, + false, false); + reader.init(); int index = 0; while (reader.next()) { int aimValue = index; diff --git a/common/src/main/java/org/apache/uniffle/common/merger/Merger.java b/common/src/main/java/org/apache/uniffle/common/merger/Merger.java index 7fcbe2d9b4..c5dc21efaf 100644 --- a/common/src/main/java/org/apache/uniffle/common/merger/Merger.java +++ b/common/src/main/java/org/apache/uniffle/common/merger/Merger.java @@ -18,12 +18,12 @@ package org.apache.uniffle.common.merger; import java.io.IOException; -import java.io.OutputStream; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.function.Function; +import io.netty.buffer.ByteBuf; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.RawComparator; import org.apache.hadoop.util.PriorityQueue; @@ -31,6 +31,7 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.records.RecordsWriter; +import org.apache.uniffle.common.serializer.SerOutputStream; public class Merger { @@ -42,6 +43,7 @@ public static class MergeQueue extends PriorityQueue implements K private final Class valueClass; private Comparator comparator; private boolean raw; + private boolean buffered; private Object currentKey; private Object currentValue; @@ -54,7 +56,8 @@ public MergeQueue( Class keyClass, Class valueClass, Comparator comparator, - boolean raw) { + boolean raw, + boolean buffered) { this.rssConf = rssConf; this.segments = segments; this.keyClass = keyClass; @@ -62,8 +65,9 @@ public MergeQueue( if (comparator == null) { throw new RssException("comparator is null!"); } - this.raw = raw; this.comparator = comparator; + this.raw = raw; + this.buffered = buffered; } public void setPopSegmentHook(Function popSegmentHook) { @@ -73,14 +77,33 @@ public void setPopSegmentHook(Function popSegmentHook) { @Override protected boolean lessThan(Object o1, Object o2) { if (raw) { - Segment s1 = (Segment) o1; - Segment s2 = (Segment) o2; - DataOutputBuffer key1 = (DataOutputBuffer) s1.getCurrentKey(); - DataOutputBuffer key2 = (DataOutputBuffer) s2.getCurrentKey(); - int c = - ((RawComparator) comparator) - .compare(key1.getData(), 0, key1.getLength(), key2.getData(), 0, 
key2.getLength()); - return c < 0 || ((c == 0) && s1.getId() < s2.getId()); + if (buffered) { + Segment s1 = (Segment) o1; + Segment s2 = (Segment) o2; + ByteBuf key1 = (ByteBuf) s1.getCurrentKey(); + ByteBuf key2 = (ByteBuf) s2.getCurrentKey(); + // make sure key buffer is in heap, avoid byte array copy + int c = + ((RawComparator) comparator) + .compare( + key1.array(), + key1.arrayOffset() + key1.readerIndex(), + key1.readableBytes(), + key2.array(), + key2.arrayOffset() + key2.readerIndex(), + key2.readableBytes()); + return c < 0 || ((c == 0) && s1.getId() < s2.getId()); + } else { + Segment s1 = (Segment) o1; + Segment s2 = (Segment) o2; + DataOutputBuffer key1 = (DataOutputBuffer) s1.getCurrentKey(); + DataOutputBuffer key2 = (DataOutputBuffer) s2.getCurrentKey(); + int c = + ((RawComparator) comparator) + .compare( + key1.getData(), 0, key1.getLength(), key2.getData(), 0, key2.getLength()); + return c < 0 || ((c == 0) && s1.getId() < s2.getId()); + } } else { Segment s1 = (Segment) o1; Segment s2 = (Segment) o2; @@ -153,6 +176,7 @@ private void adjustPriorityQueue(Segment segment) throws IOException { if (popSegmentHook != null) { Segment newSegment = popSegmentHook.apply((int) segment.getId()); if (newSegment != null) { + newSegment.init(); if (newSegment.next()) { put(newSegment); } else { @@ -163,46 +187,40 @@ private void adjustPriorityQueue(Segment segment) throws IOException { } } - void merge(OutputStream output) throws IOException { + public void merge(SerOutputStream output) throws IOException { RecordsWriter writer = - new RecordsWriter(rssConf, output, keyClass, valueClass, raw); - boolean recorded = true; - while (this.next()) { - writer.append(this.getCurrentKey(), this.getCurrentValue()); - if (output instanceof Recordable) { - recorded = - ((Recordable) output) - .record(writer.getTotalBytesWritten(), () -> writer.flush(), false); + new RecordsWriter(rssConf, output, keyClass, valueClass, raw, buffered); + try { + writer.init(); + while (this.next()) { + writer.append(this.getCurrentKey(), this.getCurrentValue()); } + writer.flush(); + } finally { + writer.close(); } - writer.flush(); - if (!recorded) { - ((Recordable) output).record(writer.getTotalBytesWritten(), null, true); - } - writer.close(); } @Override - public void close() throws IOException { - Segment segment; - while ((segment = pop()) != null) { - segment.close(); - } - } + public void close() throws IOException {} } public static void merge( RssConf conf, - OutputStream output, + SerOutputStream output, List segments, Class keyClass, Class valueClass, Comparator comparator, boolean raw) throws IOException { - MergeQueue mergeQueue = new MergeQueue(conf, segments, keyClass, valueClass, comparator, raw); - mergeQueue.init(); - mergeQueue.merge(output); - mergeQueue.close(); + MergeQueue mergeQueue = + new MergeQueue(conf, segments, keyClass, valueClass, comparator, raw, true); + try { + mergeQueue.init(); + mergeQueue.merge(output); + } finally { + mergeQueue.close(); + } } } diff --git a/common/src/main/java/org/apache/uniffle/common/merger/Segment.java b/common/src/main/java/org/apache/uniffle/common/merger/Segment.java index f8a7301229..9cc9c7054b 100644 --- a/common/src/main/java/org/apache/uniffle/common/merger/Segment.java +++ b/common/src/main/java/org/apache/uniffle/common/merger/Segment.java @@ -27,6 +27,8 @@ public Segment(long id) { this.id = id; } + public void init() {} + public abstract boolean next() throws IOException; public abstract Object getCurrentKey(); @@ -38,4 +40,6 @@ public long 
getId() { } public abstract void close() throws IOException; + + public abstract long getSize(); } diff --git a/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java b/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java index 7276b140e9..62292c7088 100644 --- a/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java +++ b/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java @@ -17,62 +17,33 @@ package org.apache.uniffle.common.merger; -import java.io.File; import java.io.IOException; -import java.nio.ByteBuffer; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.records.RecordsReader; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; public class StreamedSegment extends Segment { private RecordsReader reader; + private final long size; public StreamedSegment( RssConf rssConf, - PartialInputStream inputStream, + SerInputStream inputStream, long blockId, Class keyClass, Class valueClass, + long size, boolean raw) { super(blockId); - this.reader = new RecordsReader<>(rssConf, inputStream, keyClass, valueClass, raw); + this.reader = new RecordsReader<>(rssConf, inputStream, keyClass, valueClass, raw, true); + this.size = size; } - // The buffer must be sorted by key - public StreamedSegment( - RssConf rssConf, - ByteBuffer byteBuffer, - long blockId, - Class keyClass, - Class valueClass, - boolean raw) - throws IOException { - super(blockId); - this.reader = - new RecordsReader<>( - rssConf, PartialInputStream.newInputStream(byteBuffer), keyClass, valueClass, raw); - } - - public StreamedSegment( - RssConf rssConf, - File file, - long start, - long end, - long blockId, - Class keyClass, - Class valueClass, - boolean raw) - throws IOException { - super(blockId); - this.reader = - new RecordsReader( - rssConf, - PartialInputStream.newInputStream(file, start, end), - keyClass, - valueClass, - raw); + @Override + public void init() { + this.reader.init(); } @Override @@ -97,4 +68,8 @@ public void close() throws IOException { this.reader = null; } } + + public long getSize() { + return size; + } } diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetSortedShuffleDataRequest.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetSortedShuffleDataRequest.java new file mode 100644 index 0000000000..ff5d6ce218 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetSortedShuffleDataRequest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.netty.DecodeException;
+import org.apache.uniffle.common.netty.EncodeException;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class GetSortedShuffleDataRequest extends RequestMessage {
+  private final String appId;
+  private final int shuffleId;
+  private final int partitionId;
+  private final long blockId;
+  private final int length;
+  private final long timestamp;
+
+  public GetSortedShuffleDataRequest(
+      long requestId,
+      String appId,
+      int shuffleId,
+      int partitionId,
+      long blockId,
+      int length,
+      long timestamp) {
+    super(requestId);
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.partitionId = partitionId;
+    this.blockId = blockId;
+    this.length = length;
+    this.timestamp = timestamp;
+  }
+
+  public String getOperationType() {
+    return "getSortedShuffleData";
+  }
+
+  public Type type() {
+    return Type.GET_SORTED_SHUFFLE_DATA_REQUEST;
+  }
+
+  public int encodedLength() {
+    return REQUEST_ID_ENCODE_LENGTH
+        + ByteBufUtils.encodedLength(appId)
+        + 3 * Integer.BYTES
+        + 2 * Long.BYTES;
+  }
+
+  public void encode(ByteBuf buf) throws EncodeException {
+    buf.writeLong(getRequestId());
+    ByteBufUtils.writeLengthAndString(buf, appId);
+    buf.writeInt(shuffleId);
+    buf.writeInt(partitionId);
+    buf.writeLong(blockId);
+    buf.writeInt(length);
+    buf.writeLong(timestamp);
+  }
+
+  public static GetSortedShuffleDataRequest decode(ByteBuf buf) throws DecodeException {
+    long requestId = buf.readLong();
+    String appId = ByteBufUtils.readLengthAndString(buf);
+    int shuffleId = buf.readInt();
+    int partitionId = buf.readInt();
+    long blockId = buf.readLong();
+    int length = buf.readInt();
+    long timestamp = buf.readLong();
+    return new GetSortedShuffleDataRequest(
+        requestId, appId, shuffleId, partitionId, blockId, length, timestamp);
+  }
+
+  public String getAppId() {
+    return appId;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  public int getPartitionId() {
+    return partitionId;
+  }
+
+  public long getBlockId() {
+    return blockId;
+  }
+
+  public int getLength() {
+    return length;
+  }
+
+  public long getTimestamp() {
+    return timestamp;
+  }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetSortedShuffleDataResponse.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetSortedShuffleDataResponse.java
new file mode 100644
index 0000000000..fcf42814d2
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetSortedShuffleDataResponse.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+public class GetSortedShuffleDataResponse extends RpcResponse {
+  private long nextBlockId;
+  private int mergeState;
+
+  public GetSortedShuffleDataResponse(
+      long requestId,
+      StatusCode statusCode,
+      String retMessage,
+      long nextBlockId,
+      int mergeState,
+      ByteBuf data) {
+    this(requestId, statusCode, retMessage, nextBlockId, mergeState, new NettyManagedBuffer(data));
+  }
+
+  public GetSortedShuffleDataResponse(
+      long requestId,
+      StatusCode statusCode,
+      String retMessage,
+      long nextBlockId,
+      int mergeState,
+      ManagedBuffer managedBuffer) {
+    super(requestId, statusCode, retMessage, managedBuffer);
+    this.nextBlockId = nextBlockId;
+    this.mergeState = mergeState;
+  }
+
+  public long getNextBlockId() {
+    return nextBlockId;
+  }
+
+  public int getMergeState() {
+    return mergeState;
+  }
+
+  @Override
+  public int encodedLength() {
+    return super.encodedLength() + 8 + 4;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    super.encode(buf);
+    buf.writeLong(nextBlockId);
+    buf.writeInt(mergeState);
+  }
+
+  public static GetSortedShuffleDataResponse decode(ByteBuf buf, boolean decodeBody) {
+    long requestId = buf.readLong();
+    StatusCode statusCode = StatusCode.fromCode(buf.readInt());
+    String retMessage = ByteBufUtils.readLengthAndString(buf);
+    NettyManagedBuffer nettyManagedBuffer;
+    if (decodeBody) {
+      nettyManagedBuffer = new NettyManagedBuffer(buf);
+    } else {
+      nettyManagedBuffer = NettyManagedBuffer.EMPTY_BUFFER;
+    }
+    long nextBlockId = buf.readLong();
+    int mergeState = buf.readInt();
+    return new GetSortedShuffleDataResponse(
+        requestId, statusCode, retMessage, nextBlockId, mergeState, nettyManagedBuffer);
+  }
+
+  @Override
+  public Type type() {
+    return Type.GET_SORTED_SHUFFLE_DATA_RESPONSE;
+  }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
index 8e925fac80..ed5167e77b 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
@@ -62,7 +62,9 @@ public enum Type implements Encodable {
     SHUFFLE_COMMIT_RESPONSE(17),
     GET_SHUFFLE_RESULT_RESPONSE(18),
     GET_SHUFFLE_RESULT_FOR_MULTI_PART_RESPONSE(19),
-    REQUIRE_BUFFER_RESPONSE(20);
+    REQUIRE_BUFFER_RESPONSE(20),
+    GET_SORTED_SHUFFLE_DATA_REQUEST(21),
+    GET_SORTED_SHUFFLE_DATA_RESPONSE(22);
 
     private final byte id;
 
@@ -132,6 +134,10 @@ public static Type decode(ByteBuf buf) {
         return GET_SHUFFLE_RESULT_FOR_MULTI_PART_RESPONSE;
       case 20:
         return REQUIRE_BUFFER_RESPONSE;
+      case 21:
+        return GET_SORTED_SHUFFLE_DATA_REQUEST;
+      case 22:
+        return GET_SORTED_SHUFFLE_DATA_RESPONSE;
       case -1:
         throw new IllegalArgumentException("User type messages cannot be decoded.");
       default:
@@ -158,6 +164,10 @@ public static Message decode(Type msgType, ByteBuf in) {
         return GetMemoryShuffleDataRequest.decode(in);
       case GET_MEMORY_SHUFFLE_DATA_RESPONSE:
         return GetMemoryShuffleDataResponse.decode(in, true);
+      case GET_SORTED_SHUFFLE_DATA_REQUEST:
+        return GetSortedShuffleDataRequest.decode(in);
+      case GET_SORTED_SHUFFLE_DATA_RESPONSE:
+        return GetSortedShuffleDataResponse.decode(in, true);
       default:
         throw new
IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/src/main/java/org/apache/uniffle/common/records/RecordsReader.java b/common/src/main/java/org/apache/uniffle/common/records/RecordsReader.java index 370239c409..0594623a77 100644 --- a/common/src/main/java/org/apache/uniffle/common/records/RecordsReader.java +++ b/common/src/main/java/org/apache/uniffle/common/records/RecordsReader.java @@ -21,7 +21,7 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.serializer.DeserializationStream; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; import org.apache.uniffle.common.serializer.Serializer; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; @@ -34,15 +34,20 @@ public class RecordsReader { public RecordsReader( RssConf rssConf, - PartialInputStream input, + SerInputStream input, Class keyClass, Class valueClass, - boolean raw) { + boolean raw, + boolean buffered) { SerializerFactory factory = new SerializerFactory(rssConf); Serializer serializer = factory.getSerializer(keyClass); assert factory.getSerializer(valueClass).getClass().equals(serializer.getClass()); SerializerInstance instance = serializer.newInstance(); - stream = instance.deserializeStream(input, keyClass, valueClass, raw); + stream = instance.deserializeStream(input, keyClass, valueClass, raw, buffered); + } + + public void init() { + this.stream.init(); } public boolean next() throws IOException { diff --git a/common/src/main/java/org/apache/uniffle/common/records/RecordsWriter.java b/common/src/main/java/org/apache/uniffle/common/records/RecordsWriter.java index 9ac9b2734c..d68add57a5 100644 --- a/common/src/main/java/org/apache/uniffle/common/records/RecordsWriter.java +++ b/common/src/main/java/org/apache/uniffle/common/records/RecordsWriter.java @@ -18,9 +18,9 @@ package org.apache.uniffle.common.records; import java.io.IOException; -import java.io.OutputStream; import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.SerializationStream; import org.apache.uniffle.common.serializer.Serializer; import org.apache.uniffle.common.serializer.SerializerFactory; @@ -31,12 +31,21 @@ public class RecordsWriter { private SerializationStream stream; public RecordsWriter( - RssConf rssConf, OutputStream out, Class keyClass, Class valueClass, boolean raw) { + RssConf rssConf, + SerOutputStream out, + Class keyClass, + Class valueClass, + boolean raw, + boolean buffered) { SerializerFactory factory = new SerializerFactory(rssConf); Serializer serializer = factory.getSerializer(keyClass); assert factory.getSerializer(valueClass).getClass().equals(serializer.getClass()); SerializerInstance instance = serializer.newInstance(); - stream = instance.serializeStream(out, raw); + stream = instance.serializeStream(out, raw, buffered); + } + + public void init() { + this.stream.init(); } public void append(Object key, Object value) throws IOException { diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/BufferPartialInputStreamImpl.java b/common/src/main/java/org/apache/uniffle/common/serializer/BufferSerInputStream.java similarity index 64% rename from common/src/main/java/org/apache/uniffle/common/serializer/BufferPartialInputStreamImpl.java rename to 
common/src/main/java/org/apache/uniffle/common/serializer/BufferSerInputStream.java index 81fb826e12..ae1919ae76 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/BufferPartialInputStreamImpl.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/BufferSerInputStream.java @@ -18,36 +18,30 @@ package org.apache.uniffle.common.serializer; import java.io.IOException; -import java.nio.ByteBuffer; -public class BufferPartialInputStreamImpl extends PartialInputStream { +import io.netty.buffer.ByteBuf; - private ByteBuffer buffer; - private final long start; // the start of source input stream - private final long end; // the end of source input stream +public class BufferSerInputStream extends SerInputStream { - public BufferPartialInputStreamImpl(ByteBuffer byteBuffer, long start, long end) - throws IOException { - if (start < 0) { - throw new IOException("Negative position for channel!"); - } - this.buffer = byteBuffer; + private ByteBuf buffer; + private final int start; + private final int end; + + private final int size; + + public BufferSerInputStream(ByteBuf byteBuf, int start, int end) { + assert start >= 0; + // TODO: the byteBuf should retain outside. + this.buffer = byteBuf; this.start = start; this.end = end; - this.buffer.position((int) start); - } - - @Override - public int read() throws IOException { - if (available() <= 0) { - return -1; - } - return this.buffer.get() & 0xff; + this.buffer.readerIndex(start); + this.size = end - start; } @Override - public int available() throws IOException { - return (int) (end - this.buffer.position()); + public int available() { + return end - this.buffer.readerIndex(); } @Override @@ -59,4 +53,17 @@ public long getStart() { public long getEnd() { return end; } + + @Override + public void transferTo(ByteBuf to, int len) throws IOException { + to.writeBytes(buffer, len); + } + + @Override + public int read() throws IOException { + if (available() <= 0) { + return -1; + } + return this.buffer.readByte() & 0xFF; + } } diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/DeserializationStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/DeserializationStream.java index cdef70ba99..138e84af9e 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/DeserializationStream.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/DeserializationStream.java @@ -21,6 +21,8 @@ public abstract class DeserializationStream { + public abstract void init(); + public abstract boolean nextRecord() throws IOException; public abstract K getCurrentKey() throws IOException; diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/DynBufferSerOutputStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/DynBufferSerOutputStream.java new file mode 100644 index 0000000000..29ae11c447 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/serializer/DynBufferSerOutputStream.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.serializer; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +public class DynBufferSerOutputStream extends SerOutputStream { + + private WrappedByteArrayOutputStream buffer; + + public DynBufferSerOutputStream() { + this.buffer = new WrappedByteArrayOutputStream(); + } + + public DynBufferSerOutputStream(int capacity) { + this.buffer = new WrappedByteArrayOutputStream(capacity); + } + + @Override + public void write(ByteBuf from) throws IOException { + // We copy the bytes, but it doesn't matter, only for test + byte[] bytes = new byte[from.readableBytes()]; + from.readBytes(bytes); + buffer.write(bytes); + } + + @Override + public void write(int b) throws IOException { + this.buffer.write(b); + } + + @Override + public ByteBuf toByteBuf() { + return Unpooled.wrappedBuffer(ByteBuffer.wrap(buffer.toByteArray())); + } + + @Override + public void flush() throws IOException { + super.flush(); + } + + @Override + public void close() throws IOException { + super.close(); + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStreamImpl.java b/common/src/main/java/org/apache/uniffle/common/serializer/FileSerInputStream.java similarity index 64% rename from common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStreamImpl.java rename to common/src/main/java/org/apache/uniffle/common/serializer/FileSerInputStream.java index b16a7c8d8d..5159b4bd4c 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStreamImpl.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/FileSerInputStream.java @@ -17,42 +17,72 @@ package org.apache.uniffle.common.serializer; -import java.io.Closeable; +import java.io.File; +import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.channels.SeekableByteChannel; +import java.nio.channels.FileChannel; -/* - * PartialInputStream is a configurable partial input stream, which - * only allows reading from start to end of the source input stream. - * */ -public class PartialInputStreamImpl extends PartialInputStream { +import io.netty.buffer.ByteBuf; + +public class FileSerInputStream extends SerInputStream { - private final SeekableByteChannel ch; // the source input channel private final long start; // the start of source input stream private final long end; // the end of source input stream - private long pos; // the read offset + private FileInputStream input; // the input stream of the source + private FileChannel fileChannel; + // In FileSerInputStream, buffer is not direct memory. This stream will read file + // content to direct memory, then copy the direct memory to heap. But it doesn't + // matter, because it is only used for testing. 
private ByteBuffer bb = null; private byte[] bs = null; private byte[] b1; - private Closeable closeable; + private long pos; - public PartialInputStreamImpl(SeekableByteChannel ch, long start, long end, Closeable closeable) - throws IOException { + public FileSerInputStream(File file, long start, long end) throws IOException { if (start < 0) { throw new IOException("Negative position for channel!"); } - this.ch = ch; + this.input = new FileInputStream(file); + this.fileChannel = input.getChannel(); + if (this.fileChannel == null) { + throw new IOException("channel is null!"); + } this.start = start; this.end = end; - this.closeable = closeable; this.pos = start; - ch.position(start); + this.fileChannel.position(start); + } + + @Override + public int available() { + return (int) (end - pos); + } + + @Override + public long getStart() { + return start; + } + + @Override + public long getEnd() { + return end; + } + + @Override + public void transferTo(ByteBuf to, int len) throws IOException { + // We copy the bytes here, which is acceptable because this stream is only used for tests. + byte[] bytes = new byte[len]; + int off = 0; + while (off < len) { + int c = read(bytes, off, len - off); + if (c < 0) { + throw new IOException("Unexpected EOF when reading file segment!"); + } + off += c; + } + to.writeBytes(bytes); + } private int read(ByteBuffer bb) throws IOException { - return ch.read(bb); + return fileChannel.read(bb); } @Override @@ -62,7 +92,7 @@ public synchronized int read() throws IOException { } int n = read(b1); if (n == 1) { - return b1[0] & 0xff; + return b1[0] & 0xFF; } return -1; } @@ -90,25 +120,11 @@ public synchronized int read(byte[] bs, int off, int len) throws IOException { return ret; } - @Override - public int available() throws IOException { - return (int) (end - pos); - } - - public long getStart() { - return start; - } - - public long getEnd() { - return end; - } - @Override public void close() throws IOException { - if (closeable != null) { - closeable.close(); + if (this.input != null) { + this.input.close(); + this.input = null; } } } diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/FileSerOutputStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/FileSerOutputStream.java new file mode 100644 index 0000000000..053720a2b8 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/serializer/FileSerOutputStream.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.common.serializer; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import io.netty.buffer.ByteBuf; + +public class FileSerOutputStream extends SerOutputStream { + + private OutputStream outputStream; + + public FileSerOutputStream(File file) throws IOException { + this.outputStream = new FileOutputStream(file); + } + + @Override + public void write(ByteBuf from) throws IOException { + // We copy the bytes, but it doesn't matter, only for test + byte[] bytes = new byte[from.readableBytes()]; + from.readBytes(bytes); + outputStream.write(bytes); + } + + @Override + public void write(int b) throws IOException { + outputStream.write(b); + } + + @Override + public void flush() throws IOException { + outputStream.flush(); + } + + @Override + public void close() throws IOException { + if (outputStream != null) { + outputStream.close(); + outputStream = null; + } + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStream.java deleted file mode 100644 index f1a18595ee..0000000000 --- a/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStream.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.uniffle.common.serializer; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; - -/* - * PartialInputStream is a configurable partial input stream, which - * only allows reading from start to end of the source input stream. 
- * */ -public abstract class PartialInputStream extends InputStream { - - @Override - public abstract int available() throws IOException; - - public abstract long getStart(); - - public abstract long getEnd(); - - public static PartialInputStream newInputStream(File file, long start, long end) - throws IOException { - FileInputStream input = new FileInputStream(file); - FileChannel fc = input.getChannel(); - if (fc == null) { - throw new NullPointerException("channel is null!"); - } - long size = fc.size(); - return new PartialInputStreamImpl( - fc, - start, - Math.min(end, size), - () -> { - input.close(); - }); - } - - public static PartialInputStream newInputStream(File file) throws IOException { - return PartialInputStream.newInputStream(file, 0, file.length()); - } - - public static PartialInputStream newInputStream(ByteBuffer byteBuffer, long start, long end) - throws IOException { - return new BufferPartialInputStreamImpl(byteBuffer, start, Math.min(byteBuffer.limit(), end)); - } - - public static PartialInputStream newInputStream(ByteBuffer byteBuffer) throws IOException { - return new BufferPartialInputStreamImpl(byteBuffer, 0, byteBuffer.limit()); - } -} diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/SerInputStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/SerInputStream.java new file mode 100644 index 0000000000..1ea6325ac4 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/serializer/SerInputStream.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.common.serializer; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; + +import io.netty.buffer.ByteBuf; + +public abstract class SerInputStream extends InputStream { + + public void init() {} + + @Override + public abstract int available(); + + public abstract long getStart(); + + public abstract long getEnd(); + + public abstract void transferTo(ByteBuf to, int len) throws IOException; + + public static SerInputStream newInputStream(File file) throws IOException { + return SerInputStream.newInputStream(file, 0, file.length()); + } + + public static SerInputStream newInputStream(File file, long start) throws IOException { + return SerInputStream.newInputStream(file, start, file.length()); + } + + public static SerInputStream newInputStream(File file, long start, long end) throws IOException { + return new FileSerInputStream(file, start, Math.min(end, file.length())); + } + + public static SerInputStream newInputStream(ByteBuf byteBuf) { + return SerInputStream.newInputStream(byteBuf, 0, byteBuf.writerIndex()); + } + + public static SerInputStream newInputStream(ByteBuf byteBuf, int start) { + return SerInputStream.newInputStream(byteBuf, start, byteBuf.writerIndex()); + } + + public static SerInputStream newInputStream(ByteBuf byteBuf, int start, int end) { + return new BufferSerInputStream(byteBuf, start, Math.min(byteBuf.writerIndex(), end)); + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/SerOutputStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/SerOutputStream.java new file mode 100644 index 0000000000..9c599523c9 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/serializer/SerOutputStream.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.common.serializer; + +import java.io.IOException; +import java.io.OutputStream; + +import io.netty.buffer.ByteBuf; + +import org.apache.uniffle.common.exception.RssException; + +public abstract class SerOutputStream extends OutputStream { + + public abstract void write(ByteBuf from) throws IOException; + + public ByteBuf toByteBuf() { + throw new RssException("toByteBuf is not supported"); + } + + public void preAllocate(int length) throws IOException {} +} diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/SerializationStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/SerializationStream.java index 421b530318..239abbb9c5 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/SerializationStream.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/SerializationStream.java @@ -21,6 +21,8 @@ public abstract class SerializationStream { + public abstract void init(); + public abstract void writeRecord(Object key, Object value) throws IOException; public abstract void flush() throws IOException; diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/SerializerInstance.java b/common/src/main/java/org/apache/uniffle/common/serializer/SerializerInstance.java index e079fb9a91..a00dfebe56 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/SerializerInstance.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/SerializerInstance.java @@ -19,7 +19,6 @@ import java.io.DataOutputStream; import java.io.IOException; -import java.io.OutputStream; import org.apache.hadoop.io.DataInputBuffer; @@ -29,8 +28,9 @@ public abstract class SerializerInstance { public abstract T deserialize(DataInputBuffer buffer, Class vClass) throws IOException; - public abstract SerializationStream serializeStream(OutputStream output, boolean raw); + public abstract SerializationStream serializeStream( + SerOutputStream output, boolean raw, boolean buffered); public abstract DeserializationStream deserializeStream( - PartialInputStream input, Class keyClass, Class valueClass, boolean raw); + SerInputStream input, Class keyClass, Class valueClass, boolean raw, boolean buffered); } diff --git a/common/src/main/java/org/apache/uniffle/common/merger/Recordable.java b/common/src/main/java/org/apache/uniffle/common/serializer/WrappedByteArrayOutputStream.java similarity index 66% rename from common/src/main/java/org/apache/uniffle/common/merger/Recordable.java rename to common/src/main/java/org/apache/uniffle/common/serializer/WrappedByteArrayOutputStream.java index 79604b7f61..4200417ef1 100644 --- a/common/src/main/java/org/apache/uniffle/common/merger/Recordable.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/WrappedByteArrayOutputStream.java @@ -15,16 +15,22 @@ * limitations under the License. */ -package org.apache.uniffle.common.merger; +package org.apache.uniffle.common.serializer; -import java.io.IOException; +import java.io.ByteArrayOutputStream; -public interface Recordable { +/** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ +public class WrappedByteArrayOutputStream extends ByteArrayOutputStream { - @FunctionalInterface - interface Flushable { - void flush() throws IOException; + public WrappedByteArrayOutputStream() { + super(); } - boolean record(long written, Flushable flush, boolean force) throws IOException; + public WrappedByteArrayOutputStream(int size) { + super(size); + } + + public byte[] getBuf() { + return buf; + } } diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/writable/BufferedRawWritableDeserializationStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/writable/BufferedRawWritableDeserializationStream.java new file mode 100644 index 0000000000..55e2d55417 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/serializer/writable/BufferedRawWritableDeserializationStream.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.serializer.writable; + +import java.io.DataInputStream; +import java.io.IOException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableUtils; + +import org.apache.uniffle.common.serializer.DeserializationStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.util.NettyUtils; + +// Compared to RawWritableDeserializationStream, BufferedRawWritableDeserializationStream uses a +// shared buffer to store the current record: after each call to nextRecord, the key and value are +// held in that shared buffer, so they must be consumed before the next call to nextRecord. +// Usually, BufferedRawWritableDeserializationStream is used on the server side and +// RawWritableDeserializationStream is used on the client side. On the server side the records are +// quickly consumed to build merged blocks, so the shared buffer avoids frequent memory +// allocations. On the client side the records are generally kept for subsequent data processing +// and must be independent copies, so RawWritableDeserializationStream is used there. 
+public class BufferedRawWritableDeserializationStream + extends DeserializationStream { + + private static final int INIT_BUFFER_SIZE = 256; + private static final int EOF_MARKER = -1; // End of File Marker + + private SerInputStream inputStream; + private DataInputStream dataIn; + + private ByteBuf currentKeyBuffer; + private ByteBuf currentValueBuffer; + + public BufferedRawWritableDeserializationStream( + WritableSerializerInstance instance, SerInputStream inputStream) { + this.inputStream = inputStream; + } + + @Override + public void init() { + this.inputStream.init(); + this.dataIn = new DataInputStream(inputStream); + // We will use the key for comparison, so allocate heap memory. Since we keep a copy of the + // data, the intermediate results could be placed either on-heap or off-heap; using heap + // memory here is also safer. + UnpooledByteBufAllocator allocator = NettyUtils.getSharedUnpooledByteBufAllocator(true); + this.currentKeyBuffer = allocator.heapBuffer(INIT_BUFFER_SIZE); + this.currentValueBuffer = allocator.heapBuffer(INIT_BUFFER_SIZE); + } + + @Override + public boolean nextRecord() throws IOException { + if (inputStream.available() <= 0) { + return false; + } + int currentKeyLength = WritableUtils.readVInt(dataIn); + int currentValueLength = WritableUtils.readVInt(dataIn); + if (currentKeyLength == EOF_MARKER && currentValueLength == EOF_MARKER) { + return false; + } + currentKeyBuffer.clear(); + inputStream.transferTo(currentKeyBuffer, currentKeyLength); + currentValueBuffer.clear(); + inputStream.transferTo(currentValueBuffer, currentValueLength); + return true; + } + + @Override + public ByteBuf getCurrentKey() { + return currentKeyBuffer; + } + + @Override + public ByteBuf getCurrentValue() { + return currentValueBuffer; + } + + @Override + public void close() throws IOException { + if (currentKeyBuffer != null) { + currentKeyBuffer.release(); + currentKeyBuffer = null; + } + if (currentValueBuffer != null) { + currentValueBuffer.release(); + currentValueBuffer = null; + } + if (inputStream != null) { + inputStream.close(); + inputStream = null; + } + if (dataIn != null) { + dataIn.close(); + dataIn = null; + } + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/writable/BufferedRawWritableSerializationStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/writable/BufferedRawWritableSerializationStream.java new file mode 100644 index 0000000000..9eff8352e3 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/serializer/writable/BufferedRawWritableSerializationStream.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.common.serializer.writable; + +import java.io.DataOutputStream; +import java.io.IOException; + +import io.netty.buffer.ByteBuf; +import org.apache.hadoop.io.WritableUtils; + +import org.apache.uniffle.common.serializer.SerOutputStream; +import org.apache.uniffle.common.serializer.SerializationStream; + +public class BufferedRawWritableSerializationStream extends SerializationStream { + + // DataOutputStream::size return int, can not support big file which is larger than + // Integer.MAX_VALUE. + // Here introduce totalBytesWritten to record the written bytes. + private long totalBytesWritten = 0; + private SerOutputStream output; + private DataOutputStream dataOut; + + public BufferedRawWritableSerializationStream( + WritableSerializerInstance instance, SerOutputStream output) { + this.output = output; + } + + @Override + public void init() { + this.dataOut = new DataOutputStream(this.output); + } + + @Override + public void writeRecord(Object key, Object value) throws IOException { + ByteBuf keyBuffer = (ByteBuf) key; + ByteBuf valueBuffer = (ByteBuf) value; + int keyLength = keyBuffer.readableBytes(); + int valueLength = valueBuffer.readableBytes(); + int toWriteLength = + WritableUtils.getVIntSize(keyLength) + + WritableUtils.getVIntSize(valueLength) + + keyLength + + valueLength; + this.output.preAllocate(toWriteLength); + WritableUtils.writeVInt(dataOut, keyLength); + WritableUtils.writeVInt(dataOut, valueLength); + output.write(keyBuffer); + output.write(valueBuffer); + totalBytesWritten += toWriteLength; + } + + @Override + public void flush() throws IOException { + dataOut.flush(); + } + + @Override + public void close() throws IOException { + if (dataOut != null) { + dataOut.close(); + dataOut = null; + } + } + + @Override + public long getTotalBytesWritten() { + return totalBytesWritten; + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/writable/RawWritableDeserializationStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/writable/RawWritableDeserializationStream.java index 54e4b01bf5..d29cd43c69 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/writable/RawWritableDeserializationStream.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/writable/RawWritableDeserializationStream.java @@ -24,22 +24,27 @@ import org.apache.hadoop.io.WritableUtils; import org.apache.uniffle.common.serializer.DeserializationStream; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; public class RawWritableDeserializationStream extends DeserializationStream { public static final int EOF_MARKER = -1; // End of File Marker - private PartialInputStream inputStream; + private SerInputStream inputStream; private DataInputStream dataIn; private ComparativeOutputBuffer currentKeyBuffer; private ComparativeOutputBuffer currentValueBuffer; public RawWritableDeserializationStream( - WritableSerializerInstance instance, PartialInputStream inputStream) { + WritableSerializerInstance instance, SerInputStream inputStream) { this.inputStream = inputStream; - this.dataIn = new DataInputStream(inputStream); + } + + @Override + public void init() { + this.inputStream.init(); + this.dataIn = new DataInputStream(this.inputStream); } @Override @@ -71,6 +76,10 @@ public ComparativeOutputBuffer getCurrentValue() { @Override public void close() throws IOException { + if (inputStream != null) { + inputStream.close(); + inputStream = 
null; + } if (dataIn != null) { dataIn.close(); dataIn = null; diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/writable/RawWritableSerializationStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/writable/RawWritableSerializationStream.java index e5fa2f1ef0..9e940075bd 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/writable/RawWritableSerializationStream.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/writable/RawWritableSerializationStream.java @@ -19,29 +19,30 @@ import java.io.DataOutputStream; import java.io.IOException; -import java.io.OutputStream; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.WritableUtils; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.SerializationStream; public class RawWritableSerializationStream extends SerializationStream { - public static final int EOF_MARKER = -1; // End of File Marker - private DataOutputStream dataOut; + private SerOutputStream output; // DataOutputStream::size return int, can not support big file which is larger than // Integer.MAX_VALUE. // Here introduce totalBytesWritten to record the written bytes. private long totalBytesWritten = 0; - public RawWritableSerializationStream(WritableSerializerInstance instance, OutputStream out) { - if (out instanceof DataOutputStream) { - dataOut = (DataOutputStream) out; - } else { - dataOut = new DataOutputStream(out); - } + public RawWritableSerializationStream( + WritableSerializerInstance instance, SerOutputStream output) { + this.output = output; + } + + @Override + public void init() { + this.dataOut = new DataOutputStream(this.output); } @Override @@ -50,16 +51,18 @@ public void writeRecord(Object key, Object value) throws IOException { DataOutputBuffer valueBuffer = (DataOutputBuffer) value; int keyLength = keyBuffer.getLength(); int valueLength = valueBuffer.getLength(); + int toWriteLength = + WritableUtils.getVIntSize(keyLength) + + WritableUtils.getVIntSize(valueLength) + + keyBuffer.getLength() + + valueBuffer.getLength(); + output.preAllocate(toWriteLength); // write size and buffer to output WritableUtils.writeVInt(dataOut, keyLength); WritableUtils.writeVInt(dataOut, valueLength); keyBuffer.writeTo(dataOut); valueBuffer.writeTo(dataOut); - totalBytesWritten += - WritableUtils.getVIntSize(keyLength) - + WritableUtils.getVIntSize(valueLength) - + keyBuffer.getLength() - + valueBuffer.getLength(); + totalBytesWritten += toWriteLength; } @Override @@ -70,9 +73,6 @@ public void flush() throws IOException { @Override public void close() throws IOException { if (dataOut != null) { - WritableUtils.writeVInt(dataOut, EOF_MARKER); - WritableUtils.writeVInt(dataOut, EOF_MARKER); - totalBytesWritten += 2 * WritableUtils.getVIntSize(EOF_MARKER); dataOut.close(); dataOut = null; } diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableDeserializationStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableDeserializationStream.java index 8b016fa275..caadfbaf1f 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableDeserializationStream.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableDeserializationStream.java @@ -25,14 +25,14 @@ import org.apache.hadoop.util.ReflectionUtils; import org.apache.uniffle.common.serializer.DeserializationStream; -import 
org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; public class WritableDeserializationStream extends DeserializationStream { public static final int EOF_MARKER = -1; // End of File Marker - private PartialInputStream inputStream; + private SerInputStream inputStream; private DataInputStream dataIn; private Class keyClass; private Class valueClass; @@ -41,12 +41,17 @@ public class WritableDeserializationStream keyClass, Class valueClass) { this.inputStream = inputStream; this.keyClass = keyClass; this.valueClass = valueClass; + } + + @Override + public void init() { + this.inputStream.init(); this.dataIn = new DataInputStream(inputStream); } @@ -79,6 +84,10 @@ public V getCurrentValue() { @Override public void close() throws IOException { + if (inputStream != null) { + inputStream.close(); + inputStream = null; + } if (dataIn != null) { dataIn.close(); dataIn = null; diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableSerializationStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableSerializationStream.java index 07ed07a14e..dfdfcb3319 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableSerializationStream.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableSerializationStream.java @@ -19,19 +19,18 @@ import java.io.DataOutputStream; import java.io.IOException; -import java.io.OutputStream; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.WritableUtils; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.SerializationStream; public class WritableSerializationStream extends SerializationStream { - public static final int EOF_MARKER = -1; // End of File Marker - + private SerOutputStream output; private DataOutputStream dataOut; // DataOutputStream::size return int, can not support big file which is larger than // Integer.MAX_VALUE. 
@@ -40,12 +39,13 @@ public class WritableSerializationStream DataOutputBuffer buffer = new DataOutputBuffer(); DataOutputBuffer sizebuffer = new DataOutputBuffer(); - public WritableSerializationStream(WritableSerializerInstance instance, OutputStream out) { - if (out instanceof DataOutputStream) { - dataOut = (DataOutputStream) out; - } else { - dataOut = new DataOutputStream(out); - } + public WritableSerializationStream(WritableSerializerInstance instance, SerOutputStream out) { + this.output = out; + } + + @Override + public void init() { + this.dataOut = new DataOutputStream(this.output); } @Override @@ -56,6 +56,12 @@ public void writeRecord(Object key, Object value) throws IOException { int keyLength = buffer.getLength(); ((Writable) value).write(buffer); int valueLength = buffer.getLength() - keyLength; + int toWriteLength = + WritableUtils.getVIntSize(keyLength) + + WritableUtils.getVIntSize(valueLength) + + keyLength + + valueLength; + output.preAllocate(toWriteLength); // write size and buffer to output sizebuffer.reset(); @@ -63,7 +69,7 @@ public void writeRecord(Object key, Object value) throws IOException { WritableUtils.writeVInt(sizebuffer, valueLength); sizebuffer.writeTo(dataOut); buffer.writeTo(dataOut); - totalBytesWritten += sizebuffer.getLength() + buffer.getLength(); + totalBytesWritten += toWriteLength; } @Override @@ -74,8 +80,6 @@ public void flush() throws IOException { @Override public void close() throws IOException { if (dataOut != null) { - WritableUtils.writeVInt(dataOut, EOF_MARKER); - WritableUtils.writeVInt(dataOut, EOF_MARKER); dataOut.close(); dataOut = null; } diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableSerializerInstance.java b/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableSerializerInstance.java index 610a3463fc..7a4b84a584 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableSerializerInstance.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/writable/WritableSerializerInstance.java @@ -19,7 +19,6 @@ import java.io.DataOutputStream; import java.io.IOException; -import java.io.OutputStream; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.Writable; @@ -27,7 +26,8 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.serializer.DeserializationStream; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.SerializationStream; import org.apache.uniffle.common.serializer.SerializerInstance; @@ -47,9 +47,14 @@ public T deserialize(DataInputBuffer buffer, Class vClass) throws IOExceptio } @Override - public SerializationStream serializeStream(OutputStream output, boolean raw) { + public SerializationStream serializeStream( + SerOutputStream output, boolean raw, boolean buffered) { if (raw) { - return new RawWritableSerializationStream(this, output); + if (buffered) { + return new BufferedRawWritableSerializationStream(this, output); + } else { + return new RawWritableSerializationStream(this, output); + } } else { return new WritableSerializationStream(this, output); } @@ -57,9 +62,13 @@ public SerializationStream serializeStream(OutputStream output, boolean r @Override public DeserializationStream deserializeStream( - PartialInputStream input, Class keyClass, Class valueClass, boolean raw) { + 
SerInputStream input, Class keyClass, Class valueClass, boolean raw, boolean buffered) { if (raw) { - return new RawWritableDeserializationStream(this, input); + if (buffered) { + return new BufferedRawWritableDeserializationStream(this, input); + } else { + return new RawWritableDeserializationStream(this, input); + } } else { return new WritableDeserializationStream(this, input, keyClass, valueClass); } diff --git a/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java b/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java index 970b05d66c..001c1175d6 100644 --- a/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java +++ b/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java @@ -18,21 +18,27 @@ package org.apache.uniffle.common.merger; import java.io.File; -import java.io.FileOutputStream; import java.util.ArrayList; import java.util.Comparator; import java.util.List; -import org.apache.hadoop.io.RawComparator; +import io.netty.buffer.ByteBuf; +import org.apache.hadoop.io.DataInputBuffer; import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.records.RecordsReader; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.DynBufferSerOutputStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; +import org.apache.uniffle.common.serializer.Serializer; +import org.apache.uniffle.common.serializer.SerializerFactory; +import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; +import static org.apache.uniffle.common.serializer.SerializerUtils.genData; import static org.junit.jupiter.api.Assertions.assertEquals; public class MergerTest { @@ -43,66 +49,70 @@ public class MergerTest { @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false,false", }) public void testMergeSegmentToFile(String classes, @TempDir File tmpDir) throws Exception { // 1 Parse arguments String[] classArray = classes.split(","); Class keyClass = SerializerUtils.getClassByName(classArray[0]); Class valueClass = SerializerUtils.getClassByName(classArray[1]); + boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + boolean direct = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; // 2 Construct segments, then merge RssConf rssConf = new RssConf(); List segments = new ArrayList<>(); Comparator comparator = SerializerUtils.getComparator(keyClass); for (int i = 0; i < SEGMENTS; i++) { - if (i % 2 == 0) { - segments.add( - SerializerUtils.genMemorySegment( - rssConf, - keyClass, - valueClass, - i, - i, - SEGMENTS, - RECORDS, - comparator instanceof RawComparator)); - } else { - segments.add( - SerializerUtils.genFileSegment( - rssConf, - keyClass, - valueClass, - i, - i, - SEGMENTS, - RECORDS, - tmpDir, - comparator instanceof RawComparator)); - } + Segment segment = + i % 2 == 0 + ? 
SerializerUtils.genMemorySegment( + rssConf, keyClass, valueClass, i, i, SEGMENTS, RECORDS, raw, direct) + : SerializerUtils.genFileSegment( + rssConf, keyClass, valueClass, i, i, SEGMENTS, RECORDS, tmpDir, raw); + segment.init(); + segments.add(segment); } - File mergedFile = new File(tmpDir, "data.merged"); - FileOutputStream outputStream = new FileOutputStream(mergedFile); - Merger.merge( - rssConf, - outputStream, - segments, - keyClass, - valueClass, - comparator, - comparator instanceof RawComparator); + SerOutputStream outputStream = new DynBufferSerOutputStream(); + Merger.merge(rssConf, outputStream, segments, keyClass, valueClass, comparator, raw); outputStream.close(); - // 3 Check the merged file + // 3 Check the merged + ByteBuf byteBuf = outputStream.toByteBuf(); RecordsReader reader = new RecordsReader( - rssConf, PartialInputStream.newInputStream(mergedFile), keyClass, valueClass, false); + rssConf, SerInputStream.newInputStream(byteBuf), keyClass, valueClass, raw, true); + reader.init(); + SerializerFactory factory = new SerializerFactory(rssConf); + Serializer serializer = factory.getSerializer(keyClass); + assert factory.getSerializer(valueClass).getClass().equals(serializer.getClass()); + SerializerInstance instance = serializer.newInstance(); int index = 0; while (reader.next()) { - assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); - assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); + if (raw) { + ByteBuf keyByteBuffer = (ByteBuf) reader.getCurrentKey(); + ByteBuf valueByteBuffer = (ByteBuf) reader.getCurrentValue(); + byte[] keyBytes = new byte[keyByteBuffer.readableBytes()]; + byte[] valueBytes = new byte[valueByteBuffer.readableBytes()]; + keyByteBuffer.readBytes(keyBytes); + valueByteBuffer.readBytes(valueBytes); + DataInputBuffer keyInputBuffer = new DataInputBuffer(); + keyInputBuffer.reset(keyBytes, 0, keyBytes.length); + assertEquals(genData(keyClass, index), instance.deserialize(keyInputBuffer, keyClass)); + DataInputBuffer valueInputBuffer = new DataInputBuffer(); + valueInputBuffer.reset(valueBytes, 0, valueBytes.length); + assertEquals( + genData(valueClass, index), instance.deserialize(valueInputBuffer, valueClass)); + } else { + assertEquals(genData(keyClass, index), reader.getCurrentKey()); + assertEquals(genData(valueClass, index), reader.getCurrentValue()); + } index++; } + byteBuf.release(); assertEquals(RECORDS * SEGMENTS, index); reader.close(); } diff --git a/common/src/test/java/org/apache/uniffle/common/netty/TransportFrameDecoderTest.java b/common/src/test/java/org/apache/uniffle/common/netty/TransportFrameDecoderTest.java index 1f907ebe9b..539c910a01 100644 --- a/common/src/test/java/org/apache/uniffle/common/netty/TransportFrameDecoderTest.java +++ b/common/src/test/java/org/apache/uniffle/common/netty/TransportFrameDecoderTest.java @@ -31,6 +31,7 @@ import org.apache.uniffle.common.BufferSegment; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.netty.buffer.ManagedBuffer; import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest; import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse; @@ -38,6 +39,8 @@ import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse; import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest; import 
org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse; +import org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataRequest; +import org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataResponse; import org.apache.uniffle.common.netty.protocol.Message; import org.apache.uniffle.common.netty.protocol.RpcResponse; import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest; @@ -96,6 +99,18 @@ public void testShouldRpcResponsesToBeReleased() { assertFalse(TransportFrameDecoder.shouldRelease(message4)); // after processing some business logic in the code, and finally release the body buffer message4.body().release(); + + GetSortedShuffleDataResponse rpcResponse5 = generateGetSortedShuffleDataResponse(); + int length5 = rpcResponse5.encodedLength(); + byte[] body5 = generateBody(); + ByteBuf byteBuf5 = Unpooled.buffer(length5 + body5.length); + rpcResponse5.encode(byteBuf5); + assertEquals(byteBuf5.readableBytes(), length5); + byteBuf5.writeBytes(body5); + Message message5 = Message.decode(rpcResponse5.type(), byteBuf5); + assertFalse(TransportFrameDecoder.shouldRelease(message5)); + // after processing some business logic in the code, and finally release the body buffer + message5.body().release(); } /** test if the RPC request should be released after decoding */ @@ -136,6 +151,15 @@ public void testShouldRpcRequestsToBeReleased() { Message message4 = Message.decode(rpcRequest4.type(), byteBuf4); assertTrue(TransportFrameDecoder.shouldRelease(message4)); byteBuf4.release(); + + GetSortedShuffleDataRequest rpcRequest5 = generateGetSortedShuffleDataRequest(); + int length5 = rpcRequest5.encodedLength(); + ByteBuf byteBuf5 = Unpooled.buffer(length5); + rpcRequest5.encode(byteBuf5); + assertEquals(byteBuf5.readableBytes(), length5); + Message message5 = Message.decode(rpcRequest5.type(), byteBuf5); + assertTrue(TransportFrameDecoder.shouldRelease(message5)); + byteBuf5.release(); } private byte[] generateBody() { @@ -250,4 +274,15 @@ private GetMemoryShuffleDataRequest generateGetMemoryShuffleDataRequest() { return new GetMemoryShuffleDataRequest( 1, "test_app", 1, 1, 1, 64, System.currentTimeMillis(), expectedTaskIdsBitmap); } + + private GetSortedShuffleDataRequest generateGetSortedShuffleDataRequest() { + return new GetSortedShuffleDataRequest(1, "test_app", 2, 3, 4, 100, System.currentTimeMillis()); + } + + private GetSortedShuffleDataResponse generateGetSortedShuffleDataResponse() { + byte[] data = new byte[] {1, 2, 3, 4, 5}; + ByteBuf byteBuf = Unpooled.wrappedBuffer(data); + ManagedBuffer managedBuffer = new NettyManagedBuffer(byteBuf); + return new GetSortedShuffleDataResponse(1, StatusCode.SUCCESS, "OK", 5L, 0, managedBuffer); + } } diff --git a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java index a370828c38..072985d096 100644 --- a/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java +++ b/common/src/test/java/org/apache/uniffle/common/netty/protocol/NettyProtocolTest.java @@ -31,6 +31,7 @@ import org.apache.uniffle.common.BufferSegment; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.netty.buffer.ManagedBuffer; import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import org.apache.uniffle.common.rpc.StatusCode; @@ -295,4 +296,52 @@ public void testGetMemoryShuffleDataResponse() { 
getMemoryShuffleDataResponse1.getBufferSegments().get(i)); } } + + @Test + public void testGetSortedShuffleDataRequest() { + GetSortedShuffleDataRequest request = + new GetSortedShuffleDataRequest(1, "test_app", 2, 3, 4, 100, System.currentTimeMillis()); + + int encodeLength = request.encodedLength(); + ByteBuf byteBuf = Unpooled.buffer(encodeLength); + request.encode(byteBuf); + + assertEquals(byteBuf.readableBytes(), encodeLength); + + GetSortedShuffleDataRequest decodedRequest = GetSortedShuffleDataRequest.decode(byteBuf); + + assertEquals(request.getRequestId(), decodedRequest.getRequestId()); + assertEquals(request.getAppId(), decodedRequest.getAppId()); + assertEquals(request.getShuffleId(), decodedRequest.getShuffleId()); + assertEquals(request.getPartitionId(), decodedRequest.getPartitionId()); + assertEquals(request.getBlockId(), decodedRequest.getBlockId()); + assertEquals(request.getLength(), decodedRequest.getLength()); + assertEquals(request.getTimestamp(), decodedRequest.getTimestamp()); + + byteBuf.release(); + } + + @Test + public void testGetSortedShuffleDataResponse() { + byte[] data = new byte[] {1, 2, 3, 4, 5}; + ManagedBuffer managedBuffer = new NettyManagedBuffer(Unpooled.wrappedBuffer(data).retain()); + GetSortedShuffleDataResponse response = + new GetSortedShuffleDataResponse(1, StatusCode.SUCCESS, "OK", -1, 0, managedBuffer); + + int encodeLength = response.encodedLength(); + ByteBuf byteBuf = Unpooled.buffer(encodeLength); + response.encode(byteBuf); + + assertEquals(byteBuf.readableBytes(), encodeLength); + + GetSortedShuffleDataResponse decodedResponse = + GetSortedShuffleDataResponse.decode(byteBuf, true); + + assertEquals(response.getRequestId(), decodedResponse.getRequestId()); + assertEquals(response.getStatusCode(), decodedResponse.getStatusCode()); + assertEquals(response.getRetMessage(), decodedResponse.getRetMessage()); + assertEquals(response.getNextBlockId(), decodedResponse.getNextBlockId()); + + byteBuf.release(); + } } diff --git a/common/src/test/java/org/apache/uniffle/common/records/RecordsReaderWriterTest.java b/common/src/test/java/org/apache/uniffle/common/records/RecordsReaderWriterTest.java index a5f26b9a43..e3f8afc45b 100644 --- a/common/src/test/java/org/apache/uniffle/common/records/RecordsReaderWriterTest.java +++ b/common/src/test/java/org/apache/uniffle/common/records/RecordsReaderWriterTest.java @@ -17,13 +17,12 @@ package org.apache.uniffle.common.records; -import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileOutputStream; -import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.Random; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.DataOutputBuffer; import org.junit.jupiter.api.io.TempDir; @@ -31,7 +30,10 @@ import org.junit.jupiter.params.provider.ValueSource; import org.apache.uniffle.common.config.RssConf; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.DynBufferSerOutputStream; +import org.apache.uniffle.common.serializer.FileSerOutputStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.serializer.Serializer; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; @@ -46,107 +48,27 @@ public class RecordsReaderWriterTest { private static final int RECORDS 
= 1009; private static final int LOOP = 5; - // Test 1: both write and read will use common api @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem,false,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file,false,false", }) - public void testWriteAndReadRecordFile1(String classes, @TempDir File tmpDir) throws Exception { - RssConf rssConf = new RssConf(); - // 1 Parse arguments - String[] classArray = classes.split(","); - Class keyClass = SerializerUtils.getClassByName(classArray[0]); - Class valueClass = SerializerUtils.getClassByName(classArray[1]); - boolean isFileMode = classArray[2].equals("file"); - File tmpFile = new File(tmpDir, "tmp.data"); - - // 2 Write - long[] offsets = new long[RECORDS]; - OutputStream outputStream = - isFileMode ? new FileOutputStream(tmpFile) : new ByteArrayOutputStream(); - RecordsWriter writer = new RecordsWriter(rssConf, outputStream, keyClass, valueClass, false); - for (int i = 0; i < RECORDS; i++) { - writer.append(SerializerUtils.genData(keyClass, i), SerializerUtils.genData(valueClass, i)); - offsets[i] = writer.getTotalBytesWritten(); - } - writer.close(); - - // 3 Read - // 3.1 read from start - PartialInputStream inputStream = - isFileMode - ? PartialInputStream.newInputStream(tmpFile) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray())); - RecordsReader reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); - int index = 0; - while (reader.next()) { - assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); - assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); - index++; - } - assertEquals(RECORDS, index); - reader.close(); - - // 3.2 read from end - inputStream = - isFileMode - ? PartialInputStream.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - offsets[RECORDS - 1], - ((ByteArrayOutputStream) outputStream).size()); - reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); - assertFalse(reader.next()); - reader.close(); - - // 3.3 read from random position to end - Random random = new Random(); - long[][] indexAndOffsets = new long[LOOP + 3][2]; - indexAndOffsets[0] = new long[] {0, 0}; - indexAndOffsets[1] = new long[] {RECORDS - 1, offsets[RECORDS - 2]}; // Last record - indexAndOffsets[2] = new long[] {RECORDS, offsets[RECORDS - 1]}; // Records that don't exist - for (int i = 0; i < LOOP; i++) { - int off = random.nextInt(RECORDS - 2) + 1; - indexAndOffsets[i + 3] = new long[] {off + 1, offsets[off]}; - } - for (long[] indexAndOffset : indexAndOffsets) { - index = (int) indexAndOffset[0]; - long offset = indexAndOffset[1]; - inputStream = - isFileMode - ? 
PartialInputStream.newInputStream(tmpFile, offset, tmpFile.length()) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - offset, - ((ByteArrayOutputStream) outputStream).size()); - reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); - while (reader.next()) { - assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); - assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); - index++; - } - assertEquals(RECORDS, index); - } - reader.close(); - } - - // Test 2: write with common api, read with raw api - @ParameterizedTest - @ValueSource( - strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file", - }) - public void testWriteAndReadRecordFile2(String classes, @TempDir File tmpDir) throws Exception { + public void testWriteAndReadRecordFile(String classes, @TempDir File tmpDir) throws Exception { RssConf rssConf = new RssConf(); // 1 Parse arguments String[] classArray = classes.split(","); Class keyClass = SerializerUtils.getClassByName(classArray[0]); Class valueClass = SerializerUtils.getClassByName(classArray[1]); boolean isFileMode = classArray[2].equals("file"); + final boolean serRaw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; + final boolean derRaw = classArray.length > 4 ? Boolean.parseBoolean(classArray[4]) : false; File tmpFile = new File(tmpDir, "tmp.data"); SerializerFactory factory = new SerializerFactory(rssConf); Serializer serializer = factory.getSerializer(keyClass); @@ -155,75 +77,39 @@ public void testWriteAndReadRecordFile2(String classes, @TempDir File tmpDir) th // 2 Write long[] offsets = new long[RECORDS]; - OutputStream outputStream = - isFileMode ? new FileOutputStream(tmpFile) : new ByteArrayOutputStream(); - RecordsWriter writer = new RecordsWriter(rssConf, outputStream, keyClass, valueClass, false); + SerOutputStream outputStream = + isFileMode ? new FileSerOutputStream(tmpFile) : new DynBufferSerOutputStream(); + RecordsWriter writer = + new RecordsWriter(rssConf, outputStream, keyClass, valueClass, serRaw, false); + writer.init(); for (int i = 0; i < RECORDS; i++) { - writer.append(SerializerUtils.genData(keyClass, i), SerializerUtils.genData(valueClass, i)); - offsets[i] = writer.getTotalBytesWritten(); + if (serRaw) { + DataOutputBuffer keyBuffer = new DataOutputBuffer(); + DataOutputBuffer valueBuffer = new DataOutputBuffer(); + instance.serialize(genData(keyClass, i), keyBuffer); + instance.serialize(genData(valueClass, i), valueBuffer); + writer.append(keyBuffer, valueBuffer); + offsets[i] = writer.getTotalBytesWritten(); + } else { + writer.append(SerializerUtils.genData(keyClass, i), SerializerUtils.genData(valueClass, i)); + offsets[i] = writer.getTotalBytesWritten(); + } } writer.close(); // 3 Read // 3.1 read from start - PartialInputStream inputStream = + ByteBuf byteBuf = isFileMode ? null : outputStream.toByteBuf(); + SerInputStream inputStream = isFileMode - ? PartialInputStream.newInputStream(tmpFile) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray())); - RecordsReader reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); + ? 
SerInputStream.newInputStream(tmpFile) + : SerInputStream.newInputStream(byteBuf); + RecordsReader reader = + new RecordsReader(rssConf, inputStream, keyClass, valueClass, derRaw, false); + reader.init(); int index = 0; while (reader.next()) { - DataOutputBuffer keyBuffer = (DataOutputBuffer) reader.getCurrentKey(); - DataInputBuffer keyInputBuffer = new DataInputBuffer(); - keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); - assertEquals( - SerializerUtils.genData(keyClass, index), instance.deserialize(keyInputBuffer, keyClass)); - DataOutputBuffer valueBuffer = (DataOutputBuffer) reader.getCurrentValue(); - DataInputBuffer valueInputBuffer = new DataInputBuffer(); - valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); - assertEquals( - SerializerUtils.genData(valueClass, index), - instance.deserialize(valueInputBuffer, valueClass)); - index++; - } - assertEquals(RECORDS, index); - reader.close(); - - // 3.2 read from end - inputStream = - isFileMode - ? PartialInputStream.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - offsets[RECORDS - 1], - ((ByteArrayOutputStream) outputStream).size()); - reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); - assertFalse(reader.next()); - reader.close(); - - // 3.3 read from random position to end - Random random = new Random(); - long[][] indexAndOffsets = new long[LOOP + 3][2]; - indexAndOffsets[0] = new long[] {0, 0}; - indexAndOffsets[1] = new long[] {RECORDS - 1, offsets[RECORDS - 2]}; // Last record - indexAndOffsets[2] = new long[] {RECORDS, offsets[RECORDS - 1]}; // Records that don't exist - for (int i = 0; i < LOOP; i++) { - int off = random.nextInt(RECORDS - 2) + 1; - indexAndOffsets[i + 3] = new long[] {off + 1, offsets[off]}; - } - for (long[] indexAndOffset : indexAndOffsets) { - index = (int) indexAndOffset[0]; - long offset = indexAndOffset[1]; - inputStream = - isFileMode - ? 
PartialInputStream.newInputStream(tmpFile, offset, tmpFile.length()) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - offset, - ((ByteArrayOutputStream) outputStream).size()); - reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); - while (reader.next()) { + if (derRaw) { DataOutputBuffer keyBuffer = (DataOutputBuffer) reader.getCurrentKey(); DataInputBuffer keyInputBuffer = new DataInputBuffer(); keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); @@ -236,60 +122,10 @@ public void testWriteAndReadRecordFile2(String classes, @TempDir File tmpDir) th assertEquals( SerializerUtils.genData(valueClass, index), instance.deserialize(valueInputBuffer, valueClass)); - index++; + } else { + assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); } - assertEquals(RECORDS, index); - } - reader.close(); - } - - // Test 3: write with raw api, read with common api - @ParameterizedTest - @ValueSource( - strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file", - }) - public void testWriteAndReadRecordFile3(String classes, @TempDir File tmpDir) throws Exception { - RssConf rssConf = new RssConf(); - // 1 Parse arguments - String[] classArray = classes.split(","); - Class keyClass = SerializerUtils.getClassByName(classArray[0]); - Class valueClass = SerializerUtils.getClassByName(classArray[1]); - boolean isFileMode = classArray[2].equals("file"); - File tmpFile = new File(tmpDir, "tmp.data"); - SerializerFactory factory = new SerializerFactory(rssConf); - Serializer serializer = factory.getSerializer(keyClass); - assert factory.getSerializer(valueClass).getClass().equals(serializer.getClass()); - SerializerInstance instance = serializer.newInstance(); - - // 2 Write - long[] offsets = new long[RECORDS]; - OutputStream outputStream = - isFileMode ? new FileOutputStream(tmpFile) : new ByteArrayOutputStream(); - RecordsWriter writer = new RecordsWriter(rssConf, outputStream, keyClass, valueClass, true); - for (int i = 0; i < RECORDS; i++) { - DataOutputBuffer keyBuffer = new DataOutputBuffer(); - DataOutputBuffer valueBuffer = new DataOutputBuffer(); - instance.serialize(genData(keyClass, i), keyBuffer); - instance.serialize(genData(valueClass, i), valueBuffer); - writer.append(keyBuffer, valueBuffer); - offsets[i] = writer.getTotalBytesWritten(); - } - writer.close(); - - // 3 Read - // 3.1 read from start - PartialInputStream inputStream = - isFileMode - ? PartialInputStream.newInputStream(tmpFile) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray())); - RecordsReader reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); - int index = 0; - while (reader.next()) { - assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); - assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); index++; } assertEquals(RECORDS, index); @@ -298,12 +134,10 @@ public void testWriteAndReadRecordFile3(String classes, @TempDir File tmpDir) th // 3.2 read from end inputStream = isFileMode - ? 
PartialInputStream.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - offsets[RECORDS - 1], - ((ByteArrayOutputStream) outputStream).size()); - reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); + ? SerInputStream.newInputStream(tmpFile, offsets[RECORDS - 1]) + : SerInputStream.newInputStream(byteBuf, (int) offsets[RECORDS - 1]); + reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, derRaw, false); + reader.init(); assertFalse(reader.next()); reader.close(); @@ -322,30 +156,46 @@ public void testWriteAndReadRecordFile3(String classes, @TempDir File tmpDir) th long offset = indexAndOffset[1]; inputStream = isFileMode - ? PartialInputStream.newInputStream(tmpFile, offset, tmpFile.length()) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - offset, - ((ByteArrayOutputStream) outputStream).size()); - reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); + ? SerInputStream.newInputStream(tmpFile, offset) + : SerInputStream.newInputStream(byteBuf, (int) offset); + reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, derRaw, false); + reader.init(); while (reader.next()) { - assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); - assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); + if (derRaw) { + DataOutputBuffer keyBuffer = (DataOutputBuffer) reader.getCurrentKey(); + DataInputBuffer keyInputBuffer = new DataInputBuffer(); + keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); + assertEquals( + SerializerUtils.genData(keyClass, index), + instance.deserialize(keyInputBuffer, keyClass)); + DataOutputBuffer valueBuffer = (DataOutputBuffer) reader.getCurrentValue(); + DataInputBuffer valueInputBuffer = new DataInputBuffer(); + valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); + assertEquals( + SerializerUtils.genData(valueClass, index), + instance.deserialize(valueInputBuffer, valueClass)); + } else { + assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); + } index++; } assertEquals(RECORDS, index); } + if (!isFileMode) { + byteBuf.release(); + } reader.close(); } - // Test 4: both write and read use raw api @ParameterizedTest @ValueSource( strings = { "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem", "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file", }) - public void testWriteAndReadRecordFile4(String classes, @TempDir File tmpDir) throws Exception { + public void testWriteAndReadRecordFileUseDirect(String classes, @TempDir File tmpDir) + throws Exception { RssConf rssConf = new RssConf(); // 1 Parse arguments String[] classArray = classes.split(","); @@ -360,40 +210,51 @@ public void testWriteAndReadRecordFile4(String classes, @TempDir File tmpDir) th // 2 Write long[] offsets = new long[RECORDS]; - OutputStream outputStream = - isFileMode ? new FileOutputStream(tmpFile) : new ByteArrayOutputStream(); - RecordsWriter writer = new RecordsWriter(rssConf, outputStream, keyClass, valueClass, true); + SerOutputStream outputStream = + isFileMode ? 
new FileSerOutputStream(tmpFile) : new DynBufferSerOutputStream(); + RecordsWriter writer = + new RecordsWriter(rssConf, outputStream, keyClass, valueClass, true, true); + writer.init(); for (int i = 0; i < RECORDS; i++) { DataOutputBuffer keyBuffer = new DataOutputBuffer(); DataOutputBuffer valueBuffer = new DataOutputBuffer(); instance.serialize(genData(keyClass, i), keyBuffer); instance.serialize(genData(valueClass, i), valueBuffer); - writer.append(keyBuffer, valueBuffer); + ByteBuf kBuffer = Unpooled.buffer(keyBuffer.getLength()); + kBuffer.writeBytes(ByteBuffer.wrap(keyBuffer.getData(), 0, keyBuffer.getLength())); + ByteBuf vBuffer = Unpooled.buffer(valueBuffer.getLength()); + vBuffer.writeBytes(ByteBuffer.wrap(valueBuffer.getData(), 0, valueBuffer.getLength())); + writer.append(kBuffer, vBuffer); + kBuffer.release(); + vBuffer.release(); offsets[i] = writer.getTotalBytesWritten(); } writer.close(); // 3 Read // 3.1 read from start - PartialInputStream inputStream = + ByteBuf byteBuf = isFileMode ? null : outputStream.toByteBuf(); + SerInputStream inputStream = isFileMode - ? PartialInputStream.newInputStream(tmpFile) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray())); - RecordsReader reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); + ? SerInputStream.newInputStream(tmpFile) + : SerInputStream.newInputStream(byteBuf); + RecordsReader reader = + new RecordsReader(rssConf, inputStream, keyClass, valueClass, true, true); + reader.init(); int index = 0; while (reader.next()) { - DataOutputBuffer keyBuffer = (DataOutputBuffer) reader.getCurrentKey(); + ByteBuf keyByteBuf = (ByteBuf) reader.getCurrentKey(); + ByteBuf valueByteBuf = (ByteBuf) reader.getCurrentValue(); + byte[] keyBytes = new byte[keyByteBuf.readableBytes()]; + byte[] valueBytes = new byte[valueByteBuf.readableBytes()]; + keyByteBuf.readBytes(keyBytes); + valueByteBuf.readBytes(valueBytes); DataInputBuffer keyInputBuffer = new DataInputBuffer(); - keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); - assertEquals( - SerializerUtils.genData(keyClass, index), instance.deserialize(keyInputBuffer, keyClass)); - DataOutputBuffer valueBuffer = (DataOutputBuffer) reader.getCurrentValue(); + keyInputBuffer.reset(keyBytes, 0, keyBytes.length); + assertEquals(genData(keyClass, index), instance.deserialize(keyInputBuffer, keyClass)); DataInputBuffer valueInputBuffer = new DataInputBuffer(); - valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); - assertEquals( - SerializerUtils.genData(valueClass, index), - instance.deserialize(valueInputBuffer, valueClass)); + valueInputBuffer.reset(valueBytes, 0, valueBytes.length); + assertEquals(genData(valueClass, index), instance.deserialize(valueInputBuffer, valueClass)); index++; } assertEquals(RECORDS, index); @@ -402,12 +263,10 @@ public void testWriteAndReadRecordFile4(String classes, @TempDir File tmpDir) th // 3.2 read from end inputStream = isFileMode - ? PartialInputStream.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - offsets[RECORDS - 1], - ((ByteArrayOutputStream) outputStream).size()); - reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); + ? 
SerInputStream.newInputStream(tmpFile, offsets[RECORDS - 1]) + : SerInputStream.newInputStream(byteBuf, (int) offsets[RECORDS - 1]); + reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true, true); + reader.init(); assertFalse(reader.next()); reader.close(); @@ -426,29 +285,31 @@ public void testWriteAndReadRecordFile4(String classes, @TempDir File tmpDir) th long offset = indexAndOffset[1]; inputStream = isFileMode - ? PartialInputStream.newInputStream(tmpFile, offset, tmpFile.length()) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - offset, - ((ByteArrayOutputStream) outputStream).size()); - reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); + ? SerInputStream.newInputStream(tmpFile, offset) + : SerInputStream.newInputStream(byteBuf, (int) offset); + reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true, true); + reader.init(); while (reader.next()) { - DataOutputBuffer keyBuffer = (DataOutputBuffer) reader.getCurrentKey(); + ByteBuf keyByteBuf = (ByteBuf) reader.getCurrentKey(); + ByteBuf valueByteBuf = (ByteBuf) reader.getCurrentValue(); + byte[] keyBytes = new byte[keyByteBuf.readableBytes()]; + byte[] valueBytes = new byte[valueByteBuf.readableBytes()]; + keyByteBuf.readBytes(keyBytes); + valueByteBuf.readBytes(valueBytes); DataInputBuffer keyInputBuffer = new DataInputBuffer(); - keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); - assertEquals( - SerializerUtils.genData(keyClass, index), - instance.deserialize(keyInputBuffer, keyClass)); - DataOutputBuffer valueBuffer = (DataOutputBuffer) reader.getCurrentValue(); + keyInputBuffer.reset(keyBytes, 0, keyBytes.length); + assertEquals(genData(keyClass, index), instance.deserialize(keyInputBuffer, keyClass)); DataInputBuffer valueInputBuffer = new DataInputBuffer(); - valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); + valueInputBuffer.reset(valueBytes, 0, valueBytes.length); assertEquals( - SerializerUtils.genData(valueClass, index), - instance.deserialize(valueInputBuffer, valueClass)); + genData(valueClass, index), instance.deserialize(valueInputBuffer, valueClass)); index++; } assertEquals(RECORDS, index); } + if (!isFileMode) { + byteBuf.release(); + } reader.close(); } } diff --git a/common/src/test/java/org/apache/uniffle/common/serializer/PartialInputStreamTest.java b/common/src/test/java/org/apache/uniffle/common/serializer/SerInputOutputStreamTest.java similarity index 72% rename from common/src/test/java/org/apache/uniffle/common/serializer/PartialInputStreamTest.java rename to common/src/test/java/org/apache/uniffle/common/serializer/SerInputOutputStreamTest.java index 6be871a93b..940f573c59 100644 --- a/common/src/test/java/org/apache/uniffle/common/serializer/PartialInputStreamTest.java +++ b/common/src/test/java/org/apache/uniffle/common/serializer/SerInputOutputStreamTest.java @@ -18,86 +18,86 @@ package org.apache.uniffle.common.serializer; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.Random; -import org.junit.jupiter.api.BeforeAll; +import io.netty.buffer.ByteBuf; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; -public class PartialInputStreamTest { +public class SerInputOutputStreamTest { private static final 
int BYTES_LEN = 10240; - private static ByteBuffer testBuffer; private static final int LOOP = 10; @TempDir private static File tempDir; - private static File tempFile; - @BeforeAll - public static void initData() throws IOException { - byte[] bytes = new byte[BYTES_LEN]; + @Test + public void testReadMemoryInputStream() throws IOException { + SerOutputStream outputStream = new DynBufferSerOutputStream(); for (int i = 0; i < BYTES_LEN; i++) { - bytes[i] = (byte) (i & 0x7F); + outputStream.write((byte) (i & 0x7F)); } - testBuffer = ByteBuffer.wrap(bytes); - tempFile = new File(tempDir, "data"); - FileOutputStream output = new FileOutputStream(tempFile); - output.write(bytes); - output.close(); - } + ByteBuf testBuf = outputStream.toByteBuf(); - @Test - public void testReadMemroyInputStream() throws IOException { - // 1 test whole file - testRandomReadMemory(testBuffer, 0, BYTES_LEN); + // 1 test whole buffer + testRandomReadMemory(testBuf, 0, BYTES_LEN); // 2 test from start to random end Random random = new Random(); for (int i = 0; i < LOOP; i++) { - testRandomReadMemory(testBuffer, 0, random.nextInt(BYTES_LEN - 1)); + testRandomReadMemory(testBuf, 0, random.nextInt(BYTES_LEN - 1)); } // 3 test from random start to end for (int i = 0; i < LOOP; i++) { - testRandomReadMemory(testBuffer, random.nextInt(BYTES_LEN - 1), BYTES_LEN); + testRandomReadMemory(testBuf, random.nextInt(BYTES_LEN - 1), BYTES_LEN); } // 4 test from random start to random end for (int i = 0; i < LOOP; i++) { int r1 = random.nextInt(BYTES_LEN - 2) + 1; int r2 = random.nextInt(BYTES_LEN - 2) + 1; - testRandomReadMemory(testBuffer, Math.min(r1, r2), Math.max(r1, r2)); + testRandomReadMemory(testBuf, Math.min(r1, r2), Math.max(r1, r2)); } // 5 Test when bytes is from start to start - testRandomReadMemory(testBuffer, 0, 0); + testRandomReadMemory(testBuf, 0, 0); // 6 Test when bytes is from end to end - testRandomReadMemory(testBuffer, BYTES_LEN, BYTES_LEN); + testRandomReadMemory(testBuf, BYTES_LEN, BYTES_LEN); // 7 Test when bytes is from random to this random for (int i = 0; i < LOOP; i++) { int r = random.nextInt(BYTES_LEN - 2) + 1; - testRandomReadMemory(testBuffer, r, r); + testRandomReadMemory(testBuf, r, r); } + testBuf.release(); } @Test public void testReadNullBytes() throws IOException { // Test when bytes is byte[0] - PartialInputStream input = PartialInputStream.newInputStream(ByteBuffer.wrap(new byte[0])); + SerOutputStream outputStream = new DynBufferSerOutputStream(); + ByteBuf testBuf = outputStream.toByteBuf(); + SerInputStream input = SerInputStream.newInputStream(testBuf, 0, 0); assertEquals(0, input.available()); assertEquals(-1, input.read()); input.close(); + testBuf.release(); } @Test public void testReadFileInputStream() throws IOException { + File tempFile = new File(tempDir, "data"); + SerOutputStream outputStream = new FileSerOutputStream(tempFile); + for (int i = 0; i < BYTES_LEN; i++) { + outputStream.write((byte) (i & 0x7F)); + } + // 1 test whole file testRandomReadFile(tempFile, 0, BYTES_LEN); @@ -132,28 +132,17 @@ public void testReadFileInputStream() throws IOException { } } - private void testRandomReadMemory(ByteBuffer byteBuffer, long start, long end) - throws IOException { - PartialInputStream input = PartialInputStream.newInputStream(byteBuffer, start, end); - testRandomReadOneBytePerTime(input, start, end); - input.close(); - - input = PartialInputStream.newInputStream(byteBuffer, start, end); - testRandomReadMultiBytesPerTime(input, start, end); - input.close(); - } - - private
void testRandomReadFile(File file, long start, long end) throws IOException { - PartialInputStream input = PartialInputStream.newInputStream(file, start, end); + private void testRandomReadMemory(ByteBuf byteBuf, int start, int end) throws IOException { + SerInputStream input = SerInputStream.newInputStream(byteBuf, start, end); testRandomReadOneBytePerTime(input, start, end); input.close(); - input = PartialInputStream.newInputStream(file, start, end); + input = SerInputStream.newInputStream(byteBuf, start, end); testRandomReadMultiBytesPerTime(input, start, end); input.close(); } - private void testRandomReadOneBytePerTime(PartialInputStream input, long start, long end) + private void testRandomReadOneBytePerTime(SerInputStream input, long start, long end) throws IOException { // test read one byte per time long index = start; @@ -171,7 +160,7 @@ private void testRandomReadOneBytePerTime(PartialInputStream input, long start, } } - void testRandomReadMultiBytesPerTime(PartialInputStream input, long start, long end) + private void testRandomReadMultiBytesPerTime(SerInputStream input, long start, long end) throws IOException { // test read multi bytes per times long index = start; @@ -194,4 +183,14 @@ void testRandomReadMultiBytesPerTime(PartialInputStream input, long start, long assertEquals(-1, input.read()); } } + + private void testRandomReadFile(File file, long start, long end) throws IOException { + SerInputStream input = SerInputStream.newInputStream(file, start, end); + testRandomReadOneBytePerTime(input, start, end); + input.close(); + + input = SerInputStream.newInputStream(file, start, end); + testRandomReadMultiBytesPerTime(input, start, end); + input.close(); + } } diff --git a/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java b/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java index e6679a138b..ec58b05bd7 100644 --- a/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java +++ b/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java @@ -17,15 +17,16 @@ package org.apache.uniffle.common.serializer; -import java.io.ByteArrayOutputStream; +import java.io.BufferedOutputStream; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; -import java.nio.ByteBuffer; import java.util.Comparator; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; @@ -137,20 +138,6 @@ public int compare(Integer o1, Integer o2) { return null; } - public static byte[] genSortedRecordBytes( - RssConf rssConf, - Class keyClass, - Class valueClass, - int start, - int interval, - int length, - int replica) - throws IOException { - ByteArrayOutputStream output = new ByteArrayOutputStream(); - genSortedRecord(rssConf, keyClass, valueClass, start, interval, length, output, replica); - return output.toByteArray(); - } - public static Segment genMemorySegment( RssConf rssConf, Class keyClass, @@ -160,7 +147,8 @@ public static Segment genMemorySegment( int interval, int length) throws IOException { - return genMemorySegment(rssConf, keyClass, valueClass, blockId, start, interval, length, false); + return genMemorySegment( + rssConf, keyClass, valueClass, blockId, start, interval, length, false, false); } public static Segment genMemorySegment( @@ -171,12 +159,19 @@ public static Segment genMemorySegment( int start, int interval, int 
length, - boolean raw) + boolean raw, + boolean direct) throws IOException { - ByteArrayOutputStream output = new ByteArrayOutputStream(); - genSortedRecord(rssConf, keyClass, valueClass, start, interval, length, output, 1); + ByteBuf byteBuf = + genSortedRecordBuffer(rssConf, keyClass, valueClass, start, interval, length, 1, direct); return new StreamedSegment( - rssConf, ByteBuffer.wrap(output.toByteArray()), blockId, keyClass, valueClass, raw); + rssConf, + SerInputStream.newInputStream(byteBuf), + blockId, + keyClass, + valueClass, + byteBuf.readableBytes(), + raw); } public static Segment genFileSegment( @@ -187,40 +182,53 @@ public static Segment genFileSegment( int start, int interval, int length, - File tmpDir) + File tmpDir, + boolean raw) throws IOException { - return genFileSegment( - rssConf, keyClass, valueClass, blockId, start, interval, length, tmpDir, false); + File file = new File(tmpDir, "data." + start); + ByteBuf byteBuffer = + genSortedRecordBuffer(rssConf, keyClass, valueClass, start, interval, length, 1); + OutputStream outputStream = new BufferedOutputStream(new FileOutputStream(file)); + while (byteBuffer.readableBytes() > 0) { + outputStream.write(byteBuffer.readByte()); + } + outputStream.close(); + return new StreamedSegment( + rssConf, + SerInputStream.newInputStream(file), + blockId, + keyClass, + valueClass, + file.length(), + raw); } - public static Segment genFileSegment( + public static ByteBuf genSortedRecordBuffer( RssConf rssConf, Class keyClass, Class valueClass, - long blockId, int start, int interval, int length, - File tmpDir, - boolean raw) + int replica) throws IOException { - File file = new File(tmpDir, "data." + start); - genSortedRecord( - rssConf, keyClass, valueClass, start, interval, length, new FileOutputStream(file), 1); - return new StreamedSegment(rssConf, file, 0, file.length(), blockId, keyClass, valueClass, raw); + return genSortedRecordBuffer( + rssConf, keyClass, valueClass, start, interval, length, replica, false); } - private static void genSortedRecord( + public static ByteBuf genSortedRecordBuffer( RssConf rssConf, Class keyClass, Class valueClass, int start, int interval, int length, - OutputStream output, - int replica) + int replica, + boolean direct) throws IOException { - RecordsWriter writer = new RecordsWriter(rssConf, output, keyClass, valueClass, false); + SerOutputStream output = new DynBufferSerOutputStream(); + RecordsWriter writer = new RecordsWriter(rssConf, output, keyClass, valueClass, false, false); + writer.init(); for (int i = 0; i < length; i++) { for (int j = 0; j < replica; j++) { writer.append( @@ -229,5 +237,13 @@ private static void genSortedRecord( } } writer.close(); + ByteBuf heapBuf = output.toByteBuf(); + if (direct) { + ByteBuf directBuf = Unpooled.directBuffer(heapBuf.readableBytes()); + directBuf.writeBytes(heapBuf); + return directBuf; + } else { + return heapBuf; + } } } diff --git a/common/src/test/java/org/apache/uniffle/common/serializer/WritableSerializerTest.java b/common/src/test/java/org/apache/uniffle/common/serializer/WritableSerializerTest.java index 99853f3c6d..39acbd6931 100644 --- a/common/src/test/java/org/apache/uniffle/common/serializer/WritableSerializerTest.java +++ b/common/src/test/java/org/apache/uniffle/common/serializer/WritableSerializerTest.java @@ -17,13 +17,11 @@ package org.apache.uniffle.common.serializer; -import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileOutputStream; -import java.io.OutputStream; -import java.nio.ByteBuffer; import 
java.util.Random; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.io.IntWritable; @@ -44,175 +42,94 @@ public class WritableSerializerTest { private static final int LOOP = 1009; private static RssConf rssConf = new RssConf(); - // Test 1: both write and read will use common api @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file" + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem,false,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file,false,false", }) - public void testSerDeKeyValues1(String classes, @TempDir File tmpDir) throws Exception { + public void testSerDeKeyValues(String classes, @TempDir File tmpDir) throws Exception { // 1 Construct serializer String[] classArray = classes.split(","); Class keyClass = SerializerUtils.getClassByName(classArray[0]); Class valueClass = SerializerUtils.getClassByName(classArray[1]); boolean isFileMode = classArray[2].equals("file"); + boolean serRaw = Boolean.parseBoolean(classArray[3]); + boolean derRaw = Boolean.parseBoolean(classArray[4]); WritableSerializer serializer = new WritableSerializer(rssConf); SerializerInstance instance = serializer.newInstance(); // 2 Write - OutputStream outputStream = + SerOutputStream outputStream = isFileMode - ? new FileOutputStream(new File(tmpDir, "tmp.data")) - : new ByteArrayOutputStream(); - SerializationStream serializationStream = instance.serializeStream(outputStream, false); + ? new FileSerOutputStream(new File(tmpDir, "tmp.data")) + : new DynBufferSerOutputStream(); + SerializationStream serializationStream = instance.serializeStream(outputStream, serRaw, false); + serializationStream.init(); long[] offsets = new long[LOOP]; for (int i = 0; i < LOOP; i++) { - serializationStream.writeRecord(genData(keyClass, i), genData(valueClass, i)); - offsets[i] = serializationStream.getTotalBytesWritten(); - } - serializationStream.close(); - - // 3 Random read - for (int i = 0; i < LOOP; i++) { - long off = offsets[i]; - PartialInputStream inputStream = - isFileMode - ? 
PartialInputStream.newInputStream(new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - off, - ((ByteArrayOutputStream) outputStream).size()); - DeserializationStream deserializationStream = - instance.deserializeStream(inputStream, keyClass, valueClass, false); - for (int j = i + 1; j < LOOP; j++) { - assertTrue(deserializationStream.nextRecord()); - assertEquals(genData(keyClass, j), deserializationStream.getCurrentKey()); - assertEquals(genData(valueClass, j), deserializationStream.getCurrentValue()); + if (serRaw) { + DataOutputBuffer keyBuffer = new DataOutputBuffer(); + DataOutputBuffer valueBuffer = new DataOutputBuffer(); + instance.serialize(genData(keyClass, i), keyBuffer); + instance.serialize(genData(valueClass, i), valueBuffer); + serializationStream.writeRecord(keyBuffer, valueBuffer); + offsets[i] = serializationStream.getTotalBytesWritten(); + } else { + serializationStream.writeRecord(genData(keyClass, i), genData(valueClass, i)); + offsets[i] = serializationStream.getTotalBytesWritten(); } - deserializationStream.close(); - } - } - - // Test 2: write with common api, read with raw api - @ParameterizedTest - @ValueSource( - strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file" - }) - public void testSerDeKeyValues2(String classes, @TempDir File tmpDir) throws Exception { - // 1 Construct serializer - String[] classArray = classes.split(","); - Class keyClass = SerializerUtils.getClassByName(classArray[0]); - Class valueClass = SerializerUtils.getClassByName(classArray[1]); - boolean isFileMode = classArray[2].equals("file"); - WritableSerializer serializer = new WritableSerializer(rssConf); - SerializerInstance instance = serializer.newInstance(); - - // 2 Write - OutputStream outputStream = - isFileMode - ? new FileOutputStream(new File(tmpDir, "tmp.data")) - : new ByteArrayOutputStream(); - SerializationStream serializationStream = instance.serializeStream(outputStream, false); - long[] offsets = new long[LOOP]; - for (int i = 0; i < LOOP; i++) { - serializationStream.writeRecord(genData(keyClass, i), genData(valueClass, i)); - offsets[i] = serializationStream.getTotalBytesWritten(); } serializationStream.close(); // 3 Random read + ByteBuf byteBuf = isFileMode ? null : outputStream.toByteBuf(); for (int i = 0; i < LOOP; i++) { long off = offsets[i]; - PartialInputStream inputStream = + SerInputStream inputStream = isFileMode - ? PartialInputStream.newInputStream(new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - off, - ((ByteArrayOutputStream) outputStream).size()); - + ? 
SerInputStream.newInputStream(new File(tmpDir, "tmp.data"), off) + : SerInputStream.newInputStream(byteBuf, (int) off); DeserializationStream deserializationStream = - instance.deserializeStream(inputStream, keyClass, valueClass, true); + instance.deserializeStream(inputStream, keyClass, valueClass, derRaw, false); + deserializationStream.init(); for (int j = i + 1; j < LOOP; j++) { - assertTrue(deserializationStream.nextRecord()); - DataOutputBuffer keyBuffer = (DataOutputBuffer) deserializationStream.getCurrentKey(); - DataInputBuffer keyInputBuffer = new DataInputBuffer(); - keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); - assertEquals(genData(keyClass, j), instance.deserialize(keyInputBuffer, keyClass)); - DataOutputBuffer valueBuffer = (DataOutputBuffer) deserializationStream.getCurrentValue(); - DataInputBuffer valueInputBuffer = new DataInputBuffer(); - valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); - assertEquals(genData(valueClass, j), instance.deserialize(valueInputBuffer, valueClass)); + if (derRaw) { + assertTrue(deserializationStream.nextRecord()); + DataOutputBuffer keyBuffer = (DataOutputBuffer) deserializationStream.getCurrentKey(); + DataInputBuffer keyInputBuffer = new DataInputBuffer(); + keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); + assertEquals(genData(keyClass, j), instance.deserialize(keyInputBuffer, keyClass)); + DataOutputBuffer valueBuffer = (DataOutputBuffer) deserializationStream.getCurrentValue(); + DataInputBuffer valueInputBuffer = new DataInputBuffer(); + valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); + assertEquals(genData(valueClass, j), instance.deserialize(valueInputBuffer, valueClass)); + } else { + assertTrue(deserializationStream.nextRecord()); + assertEquals(genData(keyClass, j), deserializationStream.getCurrentKey()); + assertEquals(genData(valueClass, j), deserializationStream.getCurrentValue()); + } } deserializationStream.close(); } - } - - // Test 3: write with raw api, read with common api - @ParameterizedTest - @ValueSource( - strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file" - }) - public void testSerDeKeyValues3(String classes, @TempDir File tmpDir) throws Exception { - // 1 Construct serializer - String[] classArray = classes.split(","); - Class keyClass = SerializerUtils.getClassByName(classArray[0]); - Class valueClass = SerializerUtils.getClassByName(classArray[1]); - boolean isFileMode = classArray[2].equals("file"); - WritableSerializer serializer = new WritableSerializer(rssConf); - SerializerInstance instance = serializer.newInstance(); - - // 2 Write - OutputStream outputStream = - isFileMode - ? 
new FileOutputStream(new File(tmpDir, "tmp.data")) - : new ByteArrayOutputStream(); - SerializationStream serializationStream = instance.serializeStream(outputStream, true); - long[] offsets = new long[LOOP]; - for (int i = 0; i < LOOP; i++) { - DataOutputBuffer keyBuffer = new DataOutputBuffer(); - DataOutputBuffer valueBuffer = new DataOutputBuffer(); - instance.serialize(genData(keyClass, i), keyBuffer); - instance.serialize(genData(valueClass, i), valueBuffer); - serializationStream.writeRecord(keyBuffer, valueBuffer); - offsets[i] = serializationStream.getTotalBytesWritten(); - } - serializationStream.close(); - - // 3 Random read - for (int i = 0; i < LOOP; i++) { - long off = offsets[i]; - PartialInputStream inputStream = - isFileMode - ? PartialInputStream.newInputStream(new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - off, - ((ByteArrayOutputStream) outputStream).size()); - DeserializationStream deserializationStream = - instance.deserializeStream(inputStream, keyClass, valueClass, false); - for (int j = i + 1; j < LOOP; j++) { - assertTrue(deserializationStream.nextRecord()); - assertEquals(genData(keyClass, j), deserializationStream.getCurrentKey()); - assertEquals(genData(valueClass, j), deserializationStream.getCurrentValue()); - } - deserializationStream.close(); + if (!isFileMode) { + byteBuf.release(); } } - // Test 4: both write and read use raw api @ParameterizedTest @ValueSource( strings = { "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,mem", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file" + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,file", }) - public void testSerDeKeyValues4(String classes, @TempDir File tmpDir) throws Exception { + public void testSerDeKeyValuesUseDirect(String classes, @TempDir File tmpDir) throws Exception { // 1 Construct serializer String[] classArray = classes.split(","); Class keyClass = SerializerUtils.getClassByName(classArray[0]); @@ -222,48 +139,60 @@ public void testSerDeKeyValues4(String classes, @TempDir File tmpDir) throws Exc SerializerInstance instance = serializer.newInstance(); // 2 Write - OutputStream outputStream = + SerOutputStream outputStream = isFileMode - ? new FileOutputStream(new File(tmpDir, "tmp.data")) - : new ByteArrayOutputStream(); - SerializationStream serializationStream = instance.serializeStream(outputStream, true); + ? 
new FileSerOutputStream(new File(tmpDir, "tmp.data")) + : new DynBufferSerOutputStream(); + SerializationStream serializationStream = instance.serializeStream(outputStream, true, true); + serializationStream.init(); long[] offsets = new long[LOOP]; for (int i = 0; i < LOOP; i++) { DataOutputBuffer keyBuffer = new DataOutputBuffer(); DataOutputBuffer valueBuffer = new DataOutputBuffer(); instance.serialize(genData(keyClass, i), keyBuffer); instance.serialize(genData(valueClass, i), valueBuffer); - serializationStream.writeRecord(keyBuffer, valueBuffer); + ByteBuf kBuffer = Unpooled.buffer(keyBuffer.getLength()); + kBuffer.writeBytes(keyBuffer.getData(), 0, keyBuffer.getLength()); + ByteBuf vBuffer = Unpooled.buffer(valueBuffer.getLength()); + vBuffer.writeBytes(valueBuffer.getData(), 0, valueBuffer.getLength()); + serializationStream.writeRecord(kBuffer, vBuffer); + kBuffer.release(); + vBuffer.release(); offsets[i] = serializationStream.getTotalBytesWritten(); } serializationStream.close(); // 3 Random read + ByteBuf byteBuf = isFileMode ? null : outputStream.toByteBuf(); for (int i = 0; i < LOOP; i++) { long off = offsets[i]; - PartialInputStream inputStream = + SerInputStream inputStream = isFileMode - ? PartialInputStream.newInputStream(new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) - : PartialInputStream.newInputStream( - ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), - off, - ((ByteArrayOutputStream) outputStream).size()); - + ? SerInputStream.newInputStream(new File(tmpDir, "tmp.data"), off) + : SerInputStream.newInputStream(byteBuf, (int) off); DeserializationStream deserializationStream = - instance.deserializeStream(inputStream, keyClass, valueClass, true); + instance.deserializeStream(inputStream, keyClass, valueClass, true, true); + deserializationStream.init(); for (int j = i + 1; j < LOOP; j++) { assertTrue(deserializationStream.nextRecord()); - DataOutputBuffer keyBuffer = (DataOutputBuffer) deserializationStream.getCurrentKey(); + ByteBuf keyByteBuf = (ByteBuf) deserializationStream.getCurrentKey(); + ByteBuf valueByteBuf = (ByteBuf) deserializationStream.getCurrentValue(); + byte[] keyBytes = new byte[keyByteBuf.readableBytes()]; + byte[] valueBytes = new byte[valueByteBuf.readableBytes()]; + keyByteBuf.readBytes(keyBytes); + valueByteBuf.readBytes(valueBytes); DataInputBuffer keyInputBuffer = new DataInputBuffer(); - keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); + keyInputBuffer.reset(keyBytes, 0, keyBytes.length); assertEquals(genData(keyClass, j), instance.deserialize(keyInputBuffer, keyClass)); - DataOutputBuffer valueBuffer = (DataOutputBuffer) deserializationStream.getCurrentValue(); DataInputBuffer valueInputBuffer = new DataInputBuffer(); - valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); + valueInputBuffer.reset(valueBytes, 0, valueBytes.length); assertEquals(genData(valueClass, j), instance.deserialize(valueInputBuffer, valueClass)); } deserializationStream.close(); } + if (!isFileMode) { + byteBuf.release(); + } } @ParameterizedTest diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java index f341175864..90591a3b87 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java +++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; import java.net.ServerSocket; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -33,10 +34,10 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import io.netty.buffer.ByteBuf; import org.apache.hadoop.io.IntWritable; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.ParameterizedTest; @@ -49,7 +50,6 @@ import org.apache.uniffle.client.record.reader.RMRecordsReader; import org.apache.uniffle.client.record.writer.Combiner; import org.apache.uniffle.client.record.writer.SumByKeyCombiner; -import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.PartitionRange; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleBlockInfo; @@ -65,6 +65,7 @@ import org.apache.uniffle.common.util.ChecksumUtils; import org.apache.uniffle.coordinator.CoordinatorConf; import org.apache.uniffle.proto.RssProtos; +import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; import org.apache.uniffle.storage.util.StorageType; @@ -76,6 +77,7 @@ public class RemoteMergeShuffleWithRssClientTest extends ShuffleReadWriteBase { private static final int SHUFFLE_ID = 0; private static final int PARTITION_ID = 0; + private static final int RECORD_NUMBER = 1009; private static ShuffleServerInfo shuffleServerInfo; private ShuffleWriteClientImpl shuffleWriteClientImpl; @@ -85,7 +87,7 @@ public static void setupServers(@TempDir File tmpDir) throws Exception { CoordinatorConf coordinatorConf = getCoordinatorConf(); coordinatorConf.setBoolean(COORDINATOR_DYNAMIC_CLIENT_CONF_ENABLED, false); createCoordinatorServer(coordinatorConf); - ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC); + ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY); shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_ENABLE, true); shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE, "1k"); shuffleServerConf.set( @@ -101,8 +103,13 @@ public static void setupServers(@TempDir File tmpDir) throws Exception { shuffleServerConf.setInteger("rss.jetty.http.port", ports.get(1)); createShuffleServer(shuffleServerConf); startServers(); + ShuffleServer shuffleServer = nettyShuffleServers.get(0); shuffleServerInfo = - new ShuffleServerInfo("127.0.0.1-20001", grpcShuffleServers.get(0).getIp(), ports.get(0)); + new ShuffleServerInfo( + "127.0.0.1-20001", + shuffleServer.getIp(), + shuffleServer.getGrpcPort(), + shuffleServer.getNettyPort()); } private static List findAvailablePorts(int num) throws IOException { @@ -122,12 +129,11 @@ private static List findAvailablePorts(int num) throws IOException { return ports; } - @BeforeEach - public void createClient() { + public void createClient(String clientType) { shuffleWriteClientImpl = new ShuffleWriteClientImpl( ShuffleClientFactory.newWriteBuilder() - .clientType(ClientType.GRPC.name()) + .clientType(clientType) .retryMax(3) .retryIntervalMax(1000) .heartBeatThreadNum(1) @@ -149,8 +155,10 @@ public 
void closeClient() { @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,false", }) @Timeout(10) public void remoteMergeWriteReadTest(String classes) throws Exception { @@ -160,11 +168,13 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { final String valueClassName = classArray[1]; final Class keyClass = SerializerUtils.getClassByName(keyClassName); final Class valueClass = SerializerUtils.getClassByName(valueClassName); - final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final String clientType = classArray[2]; + final boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final RssConf rssConf = new RssConf(); // 2 register shuffle + createClient(clientType); String testAppId = "remoteMergeWriteReadTest" + classes; shuffleWriteClientImpl.registerShuffle( shuffleServerInfo, @@ -198,7 +208,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 0, 5, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -211,7 +221,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 2, 5, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -224,7 +234,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 4, 5, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); // task 1 attempt 0 generate two blocks @@ -240,7 +250,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 1, 5, - 1009, + RECORD_NUMBER, 1)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -253,7 +263,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 3, 5, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); Map> partitionToServers = @@ -294,7 +304,8 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { raw, null, false, - null); + null, + clientType); reader.start(); int index = 0; KeyValueReader keyValueReader = reader.keyValueReader(); @@ -303,15 +314,17 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); index++; } - assertEquals(5 * 1009, index); + assertEquals(5 * RECORD_NUMBER, index); shuffleWriteClientImpl.unregisterShuffle(testAppId); } @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,false", }) @Timeout(10) public void 
remoteMergeWriteReadTestWithCombine(String classes) throws Exception { @@ -321,7 +334,8 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception final String valueClassName = classArray[1]; final Class keyClass = SerializerUtils.getClassByName(keyClassName); final Class valueClass = SerializerUtils.getClassByName(valueClassName); - final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final String clientType = classArray[2]; + final boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final RssConf rssConf = new RssConf(); SerializerFactory factory = new SerializerFactory(rssConf); @@ -330,6 +344,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); // 2 register shuffle + createClient(clientType); String testAppId = "remoteMergeWriteReadTestWithCombine" + classes; shuffleWriteClientImpl.registerShuffle( shuffleServerInfo, @@ -363,7 +378,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 0, 3, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -376,7 +391,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 1, 3, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -389,7 +404,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 2, 3, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); // task 1 attempt 0 generate two blocks @@ -405,7 +420,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 0, 3, - 1009, + RECORD_NUMBER, 1)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -418,7 +433,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 2, 3, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); Map> partitionToServers = @@ -459,7 +474,8 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception raw, combiner, false, - null); + null, + clientType); reader.start(); int index = 0; KeyValueReader keyValueReader = reader.keyValueReader(); @@ -477,15 +493,17 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception assertEquals(newValue, keyValueReader.getCurrentValue()); index++; } - assertEquals(3 * 1009, index); + assertEquals(3 * RECORD_NUMBER, index); shuffleWriteClientImpl.unregisterShuffle(testAppId); } @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,false", }) @Timeout(10) public void remoteMergeWriteReadTestMultiPartition(String classes) throws Exception { @@ -495,11 +513,13 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except final String valueClassName = classArray[1]; final Class keyClass = 
SerializerUtils.getClassByName(keyClassName); final Class valueClass = SerializerUtils.getClassByName(valueClassName); - final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final String clientType = classArray[2]; + final boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final RssConf rssConf = new RssConf(); // 2 register shuffle + createClient(clientType); String testAppId = "remoteMergeWriteReadTestMultiPartition" + classes; shuffleWriteClientImpl.registerShuffle( shuffleServerInfo, @@ -537,7 +557,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 0, 6, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -550,7 +570,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 2, 6, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -563,7 +583,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 4, 6, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); // task 1 attempt 0 generate two blocks @@ -579,7 +599,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 1, 6, - 1009, + RECORD_NUMBER, 1)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -592,7 +612,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 3, 6, - 1009, + RECORD_NUMBER, 1)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -605,7 +625,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 5, 6, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); Map> partitionToServers = @@ -672,7 +692,8 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except raw, null, false, - null); + null, + clientType); reader.start(); int index = 0; KeyValueReader keyValueReader = reader.keyValueReader(); @@ -681,15 +702,17 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); index++; } - assertEquals(6 * 1009, index); + assertEquals(6 * RECORD_NUMBER, index); shuffleWriteClientImpl.unregisterShuffle(testAppId); } @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,false", }) @Timeout(10) public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) throws Exception { @@ -699,7 +722,8 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th final String valueClassName = classArray[1]; final Class keyClass = SerializerUtils.getClassByName(keyClassName); final Class valueClass = SerializerUtils.getClassByName(valueClassName); - final boolean raw = classArray.length > 2 ? 
Boolean.parseBoolean(classArray[2]) : false; + final String clientType = classArray[2]; + final boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final RssConf rssConf = new RssConf(); SerializerFactory factory = new SerializerFactory(rssConf); @@ -708,6 +732,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); // 2 register shuffle + createClient(clientType); String testAppId = "remoteMergeWriteReadTestMultiPartitionWithCombine" + classes; shuffleWriteClientImpl.registerShuffle( shuffleServerInfo, @@ -745,7 +770,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 0, 6, - 1009, + RECORD_NUMBER, 2)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -758,7 +783,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 2, 6, - 1009, + RECORD_NUMBER, 2)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -771,7 +796,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 4, 6, - 1009, + RECORD_NUMBER, 2)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); // task 1 attempt 0 generate two blocks @@ -787,7 +812,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 1, 6, - 1009, + RECORD_NUMBER, 2)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -800,7 +825,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 3, 6, - 1009, + RECORD_NUMBER, 2)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -813,7 +838,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 5, 6, - 1009, + RECORD_NUMBER, 2)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); Map> partitionToServers = @@ -880,7 +905,8 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th raw, combiner, false, - null); + null, + clientType); reader.start(); int index = 0; KeyValueReader keyValueReader = reader.keyValueReader(); @@ -890,7 +916,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th SerializerUtils.genData(valueClass, index * 2), keyValueReader.getCurrentValue()); index++; } - assertEquals(6 * 1009, index); + assertEquals(6 * RECORD_NUMBER, index); shuffleWriteClientImpl.unregisterShuffle(testAppId); } @@ -911,18 +937,19 @@ public static ShuffleBlockInfo createShuffleBlockForRemoteMerge( throws IOException { long blockId = blockIdLayout.getBlockId(ATOMIC_INT_SORTED.getAndIncrement(), PARTITION_ID, taskAttemptId); - byte[] buf = - SerializerUtils.genSortedRecordBytes( + ByteBuf byteBuf = + SerializerUtils.genSortedRecordBuffer( rssConf, keyClass, valueClass, start, interval, samples, duplicated); + ByteBuffer byteBuffer = byteBuf.nioBuffer(); return new ShuffleBlockInfo( SHUFFLE_ID, partitionId, blockId, - buf.length, - ChecksumUtils.getCrc32(buf), - buf, + byteBuf.readableBytes(), + ChecksumUtils.getCrc32(byteBuffer), + byteBuffer.array(), shuffleServerInfoList, - buf.length, + byteBuf.readableBytes(), 0, taskAttemptId); } diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java 
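Reviewer note (not part of the patch): createShuffleBlockForRemoteMerge now builds the payload as a Netty ByteBuf and derives the checksum and byte[] body from an NIO view of it. A minimal sketch of that conversion, assuming a heap-backed buffer, which is what makes the array() call valid here:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import java.nio.ByteBuffer;
    import java.util.zip.CRC32;

    public class BlockPayloadSketch {
      public static void main(String[] args) {
        ByteBuf byteBuf = Unpooled.buffer(4).writeBytes(new byte[] {1, 2, 3, 4});
        ByteBuffer byteBuffer = byteBuf.nioBuffer();   // view over the readable bytes, no copy
        CRC32 crc32 = new CRC32();
        crc32.update(byteBuffer.duplicate());          // duplicate() so the payload view keeps its position
        byte[] payload = byteBuffer.array();           // valid only because the buffer is heap-backed
        System.out.println(payload.length + " bytes, crc=" + crc32.getValue());
      }
    }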
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java index d12b286c2a..85499f3b0a 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; import java.net.ServerSocket; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -33,10 +34,10 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import io.netty.buffer.ByteBuf; import org.apache.hadoop.io.IntWritable; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.ParameterizedTest; @@ -49,7 +50,6 @@ import org.apache.uniffle.client.record.reader.RMRecordsReader; import org.apache.uniffle.client.record.writer.Combiner; import org.apache.uniffle.client.record.writer.SumByKeyCombiner; -import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.PartitionRange; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleBlockInfo; @@ -65,6 +65,7 @@ import org.apache.uniffle.common.util.ChecksumUtils; import org.apache.uniffle.coordinator.CoordinatorConf; import org.apache.uniffle.proto.RssProtos; +import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; import org.apache.uniffle.storage.util.StorageType; @@ -78,6 +79,7 @@ public class RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff private static final int SHUFFLE_ID = 0; private static final int PARTITION_ID = 0; + private static final int RECORD_NUMBER = 1009; private static ShuffleServerInfo shuffleServerInfo; private ShuffleWriteClientImpl shuffleWriteClientImpl; @@ -87,7 +89,7 @@ public static void setupServers(@TempDir File tmpDir) throws Exception { CoordinatorConf coordinatorConf = getCoordinatorConf(); coordinatorConf.setBoolean(COORDINATOR_DYNAMIC_CLIENT_CONF_ENABLED, false); createCoordinatorServer(coordinatorConf); - ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC); + ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY); shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_ENABLE, true); shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE, "1k"); shuffleServerConf.set( @@ -106,8 +108,13 @@ public static void setupServers(@TempDir File tmpDir) throws Exception { shuffleServerConf.setInteger("rss.jetty.http.port", ports.get(1)); createShuffleServer(shuffleServerConf); startServers(); + ShuffleServer shuffleServer = nettyShuffleServers.get(0); shuffleServerInfo = - new ShuffleServerInfo("127.0.0.1-20001", grpcShuffleServers.get(0).getIp(), ports.get(0)); + new ShuffleServerInfo( + "127.0.0.1-20001", + shuffleServer.getIp(), + shuffleServer.getGrpcPort(), + shuffleServer.getNettyPort()); } private static List findAvailablePorts(int num) throws IOException { @@ -127,12 +134,11 @@ private static List findAvailablePorts(int num) throws IOException { return ports; } - @BeforeEach - public void 
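Reviewer note (not part of the patch): the test now starts a GRPC_NETTY shuffle server but builds a ShuffleServerInfo carrying both ports, so the GRPC and GRPC_NETTY parameterizations can share one registered server. A sketch of that four-argument constructor call; host and port values are illustrative:

    ShuffleServerInfo info =
        new ShuffleServerInfo(
            "127.0.0.1-20001",   // id
            "127.0.0.1",         // host
            19999,               // gRPC port
            21000);              // Netty port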
createClient() { + public void createClient(String clientType) { shuffleWriteClientImpl = new ShuffleWriteClientImpl( ShuffleClientFactory.newWriteBuilder() - .clientType(ClientType.GRPC.name()) + .clientType(clientType) .retryMax(3) .retryIntervalMax(1000) .heartBeatThreadNum(1) @@ -154,8 +160,10 @@ public void closeClient() { @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,false", }) @Timeout(10) public void remoteMergeWriteReadTest(String classes) throws Exception { @@ -165,11 +173,13 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { final String valueClassName = classArray[1]; final Class keyClass = SerializerUtils.getClassByName(keyClassName); final Class valueClass = SerializerUtils.getClassByName(valueClassName); - final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final String clientType = classArray[2]; + final boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final RssConf rssConf = new RssConf(); // 2 register shuffle + createClient(clientType); String testAppId = "remoteMergeWriteReadTest" + classes; shuffleWriteClientImpl.registerShuffle( shuffleServerInfo, @@ -203,7 +213,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 0, 5, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -216,7 +226,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 2, 5, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -229,7 +239,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 4, 5, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); // task 1 attempt 0 generate two blocks @@ -245,7 +255,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 1, 5, - 1009, + RECORD_NUMBER, 1)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -258,7 +268,7 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { valueClass, 3, 5, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); Map> partitionToServers = @@ -299,7 +309,8 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { raw, null, false, - null); + null, + clientType); reader.start(); int index = 0; KeyValueReader keyValueReader = reader.keyValueReader(); @@ -308,15 +319,17 @@ public void remoteMergeWriteReadTest(String classes) throws Exception { assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); index++; } - assertEquals(5 * 1009, index); + assertEquals(5 * RECORD_NUMBER, index); shuffleWriteClientImpl.unregisterShuffle(testAppId); } @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + 
"org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,false", }) @Timeout(10) public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception { @@ -326,7 +339,8 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception final String valueClassName = classArray[1]; final Class keyClass = SerializerUtils.getClassByName(keyClassName); final Class valueClass = SerializerUtils.getClassByName(valueClassName); - final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final String clientType = classArray[2]; + final boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final RssConf rssConf = new RssConf(); SerializerFactory factory = new SerializerFactory(rssConf); @@ -335,6 +349,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); // 2 register shuffle + createClient(clientType); String testAppId = "remoteMergeWriteReadTestWithCombine" + classes; shuffleWriteClientImpl.registerShuffle( shuffleServerInfo, @@ -369,7 +384,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 0, 3, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -382,7 +397,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 1, 3, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -395,7 +410,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 2, 3, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); // task 1 attempt 0 generate two blocks @@ -411,7 +426,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 0, 3, - 1009, + RECORD_NUMBER, 1)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -424,7 +439,7 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception valueClass, 2, 3, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); Map> partitionToServers = @@ -465,7 +480,8 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception raw, combiner, false, - null); + null, + clientType); reader.start(); int index = 0; KeyValueReader keyValueReader = reader.keyValueReader(); @@ -483,15 +499,17 @@ public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception assertEquals(newValue, keyValueReader.getCurrentValue()); index++; } - assertEquals(3 * 1009, index); + assertEquals(3 * RECORD_NUMBER, index); shuffleWriteClientImpl.unregisterShuffle(testAppId); } @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,false", + 
"org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,false", }) @Timeout(10) public void remoteMergeWriteReadTestMultiPartition(String classes) throws Exception { @@ -501,11 +519,13 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except final String valueClassName = classArray[1]; final Class keyClass = SerializerUtils.getClassByName(keyClassName); final Class valueClass = SerializerUtils.getClassByName(valueClassName); - final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final String clientType = classArray[2]; + final boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final RssConf rssConf = new RssConf(); // 2 register shuffle + createClient(clientType); String testAppId = "remoteMergeWriteReadTestMultiPartition" + classes; shuffleWriteClientImpl.registerShuffle( shuffleServerInfo, @@ -543,7 +563,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 0, 6, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -556,7 +576,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 2, 6, - 1009, + RECORD_NUMBER, 1)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -569,7 +589,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 4, 6, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); // task 1 attempt 0 generate two blocks @@ -585,7 +605,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 1, 6, - 1009, + RECORD_NUMBER, 1)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -598,7 +618,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 3, 6, - 1009, + RECORD_NUMBER, 1)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -611,7 +631,7 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except valueClass, 5, 6, - 1009, + RECORD_NUMBER, 1)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); Map> partitionToServers = @@ -678,7 +698,8 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except raw, null, false, - null); + null, + clientType); reader.start(); int index = 0; KeyValueReader keyValueReader = reader.keyValueReader(); @@ -687,15 +708,17 @@ public void remoteMergeWriteReadTestMultiPartition(String classes) throws Except assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); index++; } - assertEquals(6 * 1009, index); + assertEquals(6 * RECORD_NUMBER, index); shuffleWriteClientImpl.unregisterShuffle(testAppId); } @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,GRPC_NETTY,false", }) @Timeout(10) public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) throws Exception { @@ -705,7 +728,8 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String 
classes) th final String valueClassName = classArray[1]; final Class keyClass = SerializerUtils.getClassByName(keyClassName); final Class valueClass = SerializerUtils.getClassByName(valueClassName); - final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final String clientType = classArray[2]; + final boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final RssConf rssConf = new RssConf(); SerializerFactory factory = new SerializerFactory(rssConf); @@ -714,6 +738,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); // 2 register shuffle + createClient(clientType); String testAppId = "remoteMergeWriteReadTestMultiPartitionWithCombine" + classes; shuffleWriteClientImpl.registerShuffle( shuffleServerInfo, @@ -752,7 +777,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 0, 6, - 1009, + RECORD_NUMBER, 2)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -765,7 +790,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 2, 6, - 1009, + RECORD_NUMBER, 2)); blocks1.add( createShuffleBlockForRemoteMerge( @@ -778,7 +803,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 4, 6, - 1009, + RECORD_NUMBER, 2)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); // task 1 attempt 0 generate two blocks @@ -794,7 +819,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 1, 6, - 1009, + RECORD_NUMBER, 2)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -807,7 +832,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 3, 6, - 1009, + RECORD_NUMBER, 2)); blocks2.add( createShuffleBlockForRemoteMerge( @@ -820,7 +845,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th valueClass, 5, 6, - 1009, + RECORD_NUMBER, 2)); shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); Map> partitionToServers = @@ -887,7 +912,8 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th raw, combiner, false, - null); + null, + clientType); reader.start(); int index = 0; KeyValueReader keyValueReader = reader.keyValueReader(); @@ -897,7 +923,7 @@ public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) th SerializerUtils.genData(valueClass, index * 2), keyValueReader.getCurrentValue()); index++; } - assertEquals(6 * 1009, index); + assertEquals(6 * RECORD_NUMBER, index); shuffleWriteClientImpl.unregisterShuffle(testAppId); } @@ -918,18 +944,19 @@ public static ShuffleBlockInfo createShuffleBlockForRemoteMerge( throws IOException { long blockId = blockIdLayout.getBlockId(ATOMIC_INT_SORTED.getAndIncrement(), PARTITION_ID, taskAttemptId); - byte[] buf = - SerializerUtils.genSortedRecordBytes( + ByteBuf byteBuf = + SerializerUtils.genSortedRecordBuffer( rssConf, keyClass, valueClass, start, interval, samples, duplicated); + ByteBuffer byteBuffer = byteBuf.nioBuffer(); return new ShuffleBlockInfo( SHUFFLE_ID, partitionId, blockId, - buf.length, - ChecksumUtils.getCrc32(buf), - buf, + byteBuf.readableBytes(), + ChecksumUtils.getCrc32(byteBuffer), + byteBuffer.array(), shuffleServerInfoList, - buf.length, + 
byteBuf.readableBytes(), 0, taskAttemptId); } diff --git a/integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java b/integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java index 925ef02ab4..47beded679 100644 --- a/integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java +++ b/integration-test/tez/src/test/java/org/apache/uniffle/test/RMTezOrderedWordCountTest.java @@ -71,6 +71,7 @@ public void run() throws Exception { // Run RSS tests with different configurations runRemoteMergeRssTest(ClientType.GRPC, "rss-grpc", originPath); + runRemoteMergeRssTest(ClientType.GRPC_NETTY, "rss-netty", originPath); } private void runRemoteMergeRssTest(ClientType clientType, String testName, String originPath) diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 20b6bf98b1..dccd9f9383 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -1170,8 +1170,6 @@ public RssGetSortedShuffleDataResponse getSortedShuffleData( .setMergedBlockId(request.getBlockId()) .setTimestamp(start) .build(); - RssProtos.GetSortedShuffleDataResponse rpcResponse = - getBlockingStub().getSortedShuffleData(rpcRequest); String requestInfo = "appId[" + request.getAppId() @@ -1182,6 +1180,17 @@ public RssGetSortedShuffleDataResponse getSortedShuffleData( + "], blockId[" + request.getBlockId() + "]"; + int retry = 0; + RssProtos.GetSortedShuffleDataResponse rpcResponse; + while (true) { + rpcResponse = getBlockingStub().getSortedShuffleData(rpcRequest); + if (rpcResponse.getStatus() != NO_BUFFER) { + break; + } + waitOrThrow( + request, retry, requestInfo, StatusCode.fromProto(rpcResponse.getStatus()), start); + retry++; + } LOG.info( "GetSortedShuffleData from {}:{} for {} cost {} ms", host, @@ -1197,7 +1206,8 @@ public RssGetSortedShuffleDataResponse getSortedShuffleData( response = new RssGetSortedShuffleDataResponse( StatusCode.SUCCESS, - ByteBuffer.wrap(rpcResponse.getData().toByteArray()), + rpcResponse.getRetMsg(), + rpcResponse.getData().asReadOnlyByteBuffer(), rpcResponse.getNextBlockId(), rpcResponse.getMState()); break; diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index c2fde9176f..723b0ecf34 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -30,10 +30,12 @@ import org.apache.uniffle.client.request.RssGetInMemoryShuffleDataRequest; import org.apache.uniffle.client.request.RssGetShuffleDataRequest; import org.apache.uniffle.client.request.RssGetShuffleIndexRequest; +import org.apache.uniffle.client.request.RssGetSortedShuffleDataRequest; import org.apache.uniffle.client.request.RssSendShuffleDataRequest; import org.apache.uniffle.client.response.RssGetInMemoryShuffleDataResponse; import org.apache.uniffle.client.response.RssGetShuffleDataResponse; import org.apache.uniffle.client.response.RssGetShuffleIndexResponse; +import 
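Reviewer note (not part of the patch): getSortedShuffleData on the gRPC client now loops while the server answers NO_BUFFER, delegating the wait and the give-up decision to the existing waitOrThrow helper. A self-contained sketch of that retry contract; every name below is illustrative rather than the client's actual API:

    import java.util.function.Predicate;
    import java.util.function.Supplier;

    public class NoBufferRetrySketch {
      static <R> R retryOnNoBuffer(
          Supplier<R> call, Predicate<R> isNoBuffer, int retryMax, long retryIntervalMs)
          throws InterruptedException {
        for (int retry = 0; ; retry++) {
          R response = call.get();
          if (!isNoBuffer.test(response)) {
            return response;             // SUCCESS or a non-retryable error; the caller decides
          }
          if (retry >= retryMax) {
            throw new IllegalStateException("still NO_BUFFER after " + retryMax + " retries");
          }
          Thread.sleep(retryIntervalMs); // the real helper (waitOrThrow) backs off and logs per attempt
        }
      }
    }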
org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse; import org.apache.uniffle.client.response.RssSendShuffleDataResponse; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.config.RssClientConf; @@ -51,6 +53,8 @@ import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse; import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest; import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse; +import org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataRequest; +import org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataResponse; import org.apache.uniffle.common.netty.protocol.RpcResponse; import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest; import org.apache.uniffle.common.rpc.StatusCode; @@ -417,6 +421,67 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request } } + @Override + public RssGetSortedShuffleDataResponse getSortedShuffleData( + RssGetSortedShuffleDataRequest request) { + TransportClient transportClient = getTransportClient(); + GetSortedShuffleDataRequest getSortedShuffleDataRequest = + new GetSortedShuffleDataRequest( + requestId(), + request.getAppId(), + request.getShuffleId(), + request.getPartitionId(), + request.getBlockId(), + 0, + System.currentTimeMillis()); + + String requestInfo = + String.format( + "appId[%s], shuffleId[%d], partitionId[%d], blockId[%d]", + request.getAppId(), + request.getShuffleId(), + request.getPartitionId(), + request.getBlockId()); + + long start = System.currentTimeMillis(); + int retry = 0; + RpcResponse rpcResponse; + GetSortedShuffleDataResponse getSortedShuffleDataResponse; + + while (true) { + rpcResponse = transportClient.sendRpcSync(getSortedShuffleDataRequest, rpcTimeout); + getSortedShuffleDataResponse = (GetSortedShuffleDataResponse) rpcResponse; + if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) { + break; + } + waitOrThrow(request, retry, requestInfo, rpcResponse.getStatusCode(), start); + retry++; + } + + switch (rpcResponse.getStatusCode()) { + case SUCCESS: + LOG.info( + "GetSortedShuffleData from {}:{} for {} cost {} ms", + host, + nettyPort, + requestInfo, + System.currentTimeMillis() - start); + return new RssGetSortedShuffleDataResponse( + StatusCode.SUCCESS, + getSortedShuffleDataResponse.getRetMessage(), + getSortedShuffleDataResponse.body(), + getSortedShuffleDataResponse.getNextBlockId(), + getSortedShuffleDataResponse.getMergeState()); + default: + String msg = + String.format( + "Can't get sorted shuffle data from %s:%d for %s, errorMsg: %s", + host, nettyPort, requestInfo, getSortedShuffleDataResponse.getRetMessage()); + LOG.error(msg); + throw new RssFetchFailedException(msg); + } + } + private static final AtomicLong counter = new AtomicLong(); public static long requestId() { diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetSortedShuffleDataRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetSortedShuffleDataRequest.java index f3b4eb0789..ab02c06fd4 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetSortedShuffleDataRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetSortedShuffleDataRequest.java @@ -17,7 +17,7 @@ package org.apache.uniffle.client.request; -public class RssGetSortedShuffleDataRequest { +public class RssGetSortedShuffleDataRequest extends RetryableRequest { private final String appId; private 
final int shuffleId; @@ -25,11 +25,18 @@ public class RssGetSortedShuffleDataRequest { private final long blockId; public RssGetSortedShuffleDataRequest( - String appId, int shuffleId, int partitionId, long blockId) { + String appId, + int shuffleId, + int partitionId, + long blockId, + int retryMax, + long retryIntervalMax) { this.appId = appId; this.shuffleId = shuffleId; this.partitionId = partitionId; this.blockId = blockId; + this.retryMax = retryMax; + this.retryIntervalMax = retryIntervalMax; } public String getAppId() { @@ -47,4 +54,9 @@ public int getPartitionId() { public long getBlockId() { return blockId; } + + @Override + public String operationType() { + return "GetSortedShuffleData"; + } } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetSortedShuffleDataResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetSortedShuffleDataResponse.java index fa153e3d6e..d3819aa55b 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetSortedShuffleDataResponse.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetSortedShuffleDataResponse.java @@ -19,23 +19,37 @@ import java.nio.ByteBuffer; +import io.netty.buffer.Unpooled; + +import org.apache.uniffle.common.netty.buffer.ManagedBuffer; +import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import org.apache.uniffle.common.rpc.StatusCode; public class RssGetSortedShuffleDataResponse extends ClientResponse { - private final ByteBuffer data; + private final ManagedBuffer data; private final long nextBlockId; private final int mergeState; public RssGetSortedShuffleDataResponse( - StatusCode statusCode, ByteBuffer data, long nextBlockId, int mergeState) { - super(statusCode); + StatusCode statusCode, String message, ByteBuffer data, long nextBlockId, int mergeState) { + this( + statusCode, + message, + new NettyManagedBuffer(Unpooled.wrappedBuffer(data)), + nextBlockId, + mergeState); + } + + public RssGetSortedShuffleDataResponse( + StatusCode statusCode, String message, ManagedBuffer data, long nextBlockId, int mergeState) { + super(statusCode, message); this.data = data; this.nextBlockId = nextBlockId; this.mergeState = mergeState; } - public ByteBuffer getData() { + public ManagedBuffer getData() { return data; } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index 994a25c890..460b64d693 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -484,6 +484,9 @@ public void sendShuffleData( hasFailureOccurred = true; break; } else { + if (shuffleServer.isRemoteMergeEnable()) { + shuffleServer.getShuffleMergeManager().setDirect(appId, shuffleId, false); + } long toReleasedSize = spd.getTotalBlockEncodedLength(); // after each cacheShuffleData call, the `preAllocatedSize` is updated timely. 
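Reviewer note (not part of the patch): the request now extends RetryableRequest and carries the retry budget that waitOrThrow consults, and the response hands data back as a ManagedBuffer so both transports look the same to readers. A sketch of constructing the request; the argument values are illustrative and would normally come from the client's configuration:

    RssGetSortedShuffleDataRequest request =
        new RssGetSortedShuffleDataRequest(
            "application_123",   // appId (illustrative)
            0,                   // shuffleId
            0,                   // partitionId
            1L,                  // blockId
            3,                   // retryMax
            1000L);              // retryIntervalMax (ms)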
manager.releasePreAllocatedSize(toReleasedSize); @@ -1616,7 +1619,7 @@ public void getSortedShuffleData( .setMState(mergeState.code()) .setStatus(status.toProto()) .setRetMsg(msg) - .setData(UnsafeByteOperations.unsafeWrap(sdr.getData())) + .setData(UnsafeByteOperations.unsafeWrap(sdr.getData(), 0, sdr.getDataLength())) .build(); } catch (Exception e) { status = StatusCode.INTERNAL_ERROR; @@ -1634,8 +1637,8 @@ public void getSortedShuffleData( shuffleServer.getShuffleBufferManager().releaseReadMemory(blockSize); } } else { - status = StatusCode.INTERNAL_ERROR; - msg = "Can't require memory to get shuffle data"; + status = StatusCode.NO_BUFFER; + msg = "Can't require read memory to get sorted shuffle data"; LOG.error(msg + " for " + requestInfo); reply = RssProtos.GetSortedShuffleDataResponse.newBuilder() diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index fa531ba024..5478d51c0c 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -591,6 +591,15 @@ public long requireBuffer(int requireSize) { return requireBuffer("EMPTY", requireSize); } + public boolean requireMemory(int requireSize, boolean isPreAllocated) { + return shuffleBufferManager.requireMemory(requireSize, isPreAllocated); + } + + public void releaseMemory( + int requireSize, boolean isReleaseFlushMemory, boolean isReleasePreAllocation) { + shuffleBufferManager.releaseMemory(requireSize, isReleaseFlushMemory, isReleasePreAllocation); + } + public byte[] getFinishedBlockIds( String appId, Integer shuffleId, Set partitions, BlockIdLayout blockIdLayout) throws IOException { diff --git a/server/src/main/java/org/apache/uniffle/server/merge/BlockFlushFileReader.java b/server/src/main/java/org/apache/uniffle/server/merge/BlockFlushFileReader.java index 76fa2cc1d0..97c3c06272 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/BlockFlushFileReader.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/BlockFlushFileReader.java @@ -27,12 +27,17 @@ import java.util.Map; import java.util.concurrent.locks.ReentrantLock; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.util.internal.OutOfDirectMemoryError; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.netty.buffer.FileSegmentManagedBuffer; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; import org.apache.uniffle.common.util.JavaUtils; +import org.apache.uniffle.common.util.NettyUtils; import org.apache.uniffle.storage.common.FileBasedShuffleSegment; /** @@ -53,6 +58,7 @@ public class BlockFlushFileReader { private static final Logger LOG = LoggerFactory.getLogger(BlockFlushFileReader.class); private static final int BUFFER_SIZE = 4096; + private final boolean direct; private String dataFile; private FileInputStream dataInput; private FileChannel dataFileChannel; @@ -74,17 +80,21 @@ public class BlockFlushFileReader { private final int ringBufferSize; private final int mask; - public BlockFlushFileReader(String dataFile, String indexFile, int ringBufferSize) - throws IOException { + public BlockFlushFileReader( + String dataFile, String indexFile, int ringBufferSize, boolean direct) { // Make sure flush file will not be 
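Reviewer note (not part of the patch): ShuffleTaskManager now exposes requireMemory/releaseMemory passthroughs, and the sorted-shuffle-data RPC reports NO_BUFFER (retryable) instead of INTERNAL_ERROR when read memory cannot be reserved. A sketch of the require/release pairing a caller is expected to follow; the enclosing method and the Runnable are illustrative:

    void readWithReservation(ShuffleTaskManager taskManager, int requireSize, Runnable read) {
      // If read memory cannot be reserved, back off; the RPC layer surfaces this as
      // NO_BUFFER so clients retry instead of treating it as a hard failure.
      if (!taskManager.requireMemory(requireSize, false)) {
        return;
      }
      try {
        read.run();
      } finally {
        taskManager.releaseMemory(requireSize, false, false);
      }
    }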
updated this.ringBufferSize = ringBufferSize; + this.direct = direct; this.mask = ringBufferSize - 1; - loadShuffleIndex(indexFile); this.dataFile = dataFile; - this.dataInput = new FileInputStream(dataFile); - this.dataFileChannel = dataInput.getChannel(); + loadShuffleIndex(indexFile); // Avoid flushFileReader noop loop this.lock.lock(); + } + + void start() throws IOException { + this.dataInput = new FileInputStream(dataFile); + this.dataFileChannel = dataInput.getChannel(); this.flushFileReader = new FlushFileReader(); this.flushFileReader.start(); } @@ -109,9 +119,14 @@ public void loadShuffleIndex(String indexFileName) { } } - public void close() throws IOException, InterruptedException { - if (!this.stop) { - stop = true; + public void close() throws IOException { + stop = true; + for (BlockInputStream is : inputStreamMap.values()) { + is.close(); + } + inputStreamMap.clear(); + indexSegments.clear(); + if (flushFileReader != null) { flushFileReader.interrupt(); flushFileReader = null; } @@ -119,6 +134,7 @@ public void close() throws IOException, InterruptedException { this.dataInput.close(); this.dataInput = null; this.dataFile = null; + this.dataFileChannel = null; } } @@ -128,7 +144,7 @@ public BlockInputStream registerBlockInputStream(long blockId) { } if (!inputStreamMap.containsKey(blockId)) { inputStreamMap.put( - blockId, new BlockInputStream(blockId, this.indexSegments.get(blockId).getLength())); + blockId, new BlockInputStream(this.indexSegments.get(blockId).getLength())); } return inputStreamMap.get(blockId); } @@ -184,29 +200,44 @@ public void run() { class Buffer { - private byte[] bytes = new byte[BUFFER_SIZE]; - private int cap = BUFFER_SIZE; - private int pos = cap; + private ByteBuf buffer; + + Buffer() { + UnpooledByteBufAllocator allocator = NettyUtils.getSharedUnpooledByteBufAllocator(true); + this.buffer = + direct ? allocator.directBuffer(BUFFER_SIZE) : allocator.heapBuffer(BUFFER_SIZE); + } public int get() { - return this.bytes[pos++] & 0xFF; + return this.buffer.readByte() & 0xFF; } public int get(byte[] bs, int off, int len) { - int r = Math.min(cap - pos, len); - System.arraycopy(bytes, pos, bs, off, r); - pos += r; + int r = Math.min(this.buffer.readableBytes(), len); + this.buffer.readBytes(bs, off, r); return r; } public boolean readable() { - return pos < cap; + return this.buffer.readableBytes() > 0; } public void writeBuffer(int length) throws IOException { - dataFileChannel.read(ByteBuffer.wrap(this.bytes, 0, length)); - this.pos = 0; - this.cap = length; + ByteBuffer byteBuffer = this.buffer.nioBuffer(0, length); + dataFileChannel.read(byteBuffer); + this.buffer.readerIndex(0); + this.buffer.writerIndex(length); + } + + public ByteBuf getByteBuf() { + return this.buffer; + } + + public void release() { + if (this.buffer != null) { + this.buffer.release(); + this.buffer = null; + } } } @@ -220,9 +251,20 @@ class RingBuffer { int writeIndex = 0; RingBuffer() { - this.buffers = new Buffer[ringBufferSize]; - for (int i = 0; i < ringBufferSize; i++) { - this.buffers[i] = new Buffer(); + try { + this.buffers = new Buffer[ringBufferSize]; + for (int i = 0; i < ringBufferSize; i++) { + this.buffers[i] = new Buffer(); + } + } catch (OutOfDirectMemoryError error) { + // If out of direct memory here, previously created buffers + // cannot be released. 
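Reviewer note (not part of the patch): each ring-buffer slot is now a Netty ByteBuf that is allocated off-heap or on-heap depending on the reader's direct flag. A minimal sketch of that allocation choice using the same shared unpooled allocator the patch uses:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.UnpooledByteBufAllocator;
    import org.apache.uniffle.common.util.NettyUtils;

    public class BufferAllocSketch {
      static ByteBuf newRingBufferSlot(boolean direct, int size) {
        UnpooledByteBufAllocator allocator = NettyUtils.getSharedUnpooledByteBufAllocator(true);
        return direct ? allocator.directBuffer(size) : allocator.heapBuffer(size);
      }
    }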
+ for (int i = 0; i < ringBufferSize; i++) { + if (this.buffers[i] != null) { + this.buffers[i].release(); + } + } + throw error; } } @@ -268,25 +310,45 @@ int read(byte[] bs, int off, int len) { } return total; } + + Buffer getReadBuffer() { + if (!empty()) { + return this.buffers[readIndex & mask]; + } + return null; + } + + void incReadIndex() { + readIndex++; + } + + void release() { + for (Buffer buffer : this.buffers) { + buffer.release(); + } + } } - public class BlockInputStream extends PartialInputStream { + public class BlockInputStream extends SerInputStream { - private long blockId; private RingBuffer ringBuffer; private boolean eof = false; private final int length; private int pos = 0; private int offsetInThisBlock = 0; - public BlockInputStream(long blockId, int length) { - this.blockId = blockId; + public BlockInputStream(int length) { this.length = length; - this.ringBuffer = new RingBuffer(); + } + + public void init() { + if (this.ringBuffer == null) { + this.ringBuffer = new RingBuffer(); + } } @Override - public int available() throws IOException { + public int available() { return length - pos; } @@ -300,20 +362,67 @@ public long getEnd() { return length; } + @Override + public void transferTo(ByteBuf to, int len) throws IOException { + while (len > 0) { + int c = internalTransferTo(to, len); + len -= c; + } + } + + private int internalTransferTo(ByteBuf out, int len) { + if (stop) { + throw new RssException("Block flush file reader is closed, caused by " + readThrowable); + } + if (len == 0) { + return 0; + } else if (eof || len < 0) { + throw new IndexOutOfBoundsException(); + } + + while (ringBuffer.empty() && !stop) { + if (lock.isHeldByCurrentThread()) { + lock.unlock(); + } + try { + lock.lockInterruptibly(); + } catch (InterruptedException e) { + throw new RssException(e); + } + } + + int c = 0; + while (len > 0) { + Buffer buffer = this.ringBuffer.getReadBuffer(); + if (buffer == null) { + break; + } + ByteBuf byteBuf = buffer.getByteBuf(); + int toRead = len; + if (len >= byteBuf.readableBytes()) { + this.ringBuffer.incReadIndex(); + toRead = byteBuf.readableBytes(); + } + len -= toRead; + out.writeBytes(byteBuf, toRead); + pos += toRead; + c += toRead; + } + if (pos >= length) { + eof = true; + } + return c; + } + public long getOffsetInThisBlock() { return this.offsetInThisBlock; } @Override - public void close() throws IOException { - try { - inputStreamMap.remove(blockId); - indexSegments.remove(blockId); - if (inputStreamMap.size() == 0) { - BlockFlushFileReader.this.close(); - } - } catch (InterruptedException e) { - throw new IOException(e); + public void close() { + if (ringBuffer != null) { + ringBuffer.release(); + ringBuffer = null; } } @@ -326,6 +435,7 @@ public void writeBuffer() throws IOException { this.offsetInThisBlock += size; } + @Override public int read(byte[] bs, int off, int len) throws IOException { if (stop) { throw new IOException("Block flush file reader is closed, caused by " + readThrowable); diff --git a/server/src/main/java/org/apache/uniffle/server/merge/DefaultMergeEventHandler.java b/server/src/main/java/org/apache/uniffle/server/merge/DefaultMergeEventHandler.java index 05c9a37238..4555143c66 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/DefaultMergeEventHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/DefaultMergeEventHandler.java @@ -93,9 +93,12 @@ private void handleEventAndUpdateMetrics(MergeEvent event) { } @Override - public void handle(MergeEvent event) { + public boolean 
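Reviewer note (not part of the patch): the ring buffer indexes its slots with a bit mask, which is why ringBufferSize has to be a power of two. A tiny standalone illustration of that wrap-around arithmetic:

    public class RingIndexSketch {
      public static void main(String[] args) {
        int ringBufferSize = 4;          // must be a power of two
        int mask = ringBufferSize - 1;   // 0b011
        for (int readIndex = 0; readIndex < 8; readIndex++) {
          System.out.println(readIndex + " -> slot " + (readIndex & mask)); // 0,1,2,3,0,1,2,3
        }
      }
    }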
handle(MergeEvent event) { if (queue.offer(event)) { ShuffleServerMetrics.gaugeMergeEventQueueSize.inc(); + return true; + } else { + return false; } } diff --git a/server/src/main/java/org/apache/uniffle/server/merge/MergeEventHandler.java b/server/src/main/java/org/apache/uniffle/server/merge/MergeEventHandler.java index c4a248e3a0..fb2540fa20 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/MergeEventHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/MergeEventHandler.java @@ -19,7 +19,7 @@ public interface MergeEventHandler { - void handle(MergeEvent event); + boolean handle(MergeEvent event); int getEventNumInMerge(); diff --git a/server/src/main/java/org/apache/uniffle/server/merge/MergedResult.java b/server/src/main/java/org/apache/uniffle/server/merge/MergedResult.java index 6c7ce056fb..039c231d34 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/MergedResult.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/MergedResult.java @@ -17,14 +17,17 @@ package org.apache.uniffle.server.merge; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.OutputStream; import java.util.ArrayList; import java.util.List; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; + import org.apache.uniffle.common.config.RssConf; -import org.apache.uniffle.common.merger.Recordable; +import org.apache.uniffle.common.serializer.SerOutputStream; +import org.apache.uniffle.common.util.NettyUtils; import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE; @@ -35,9 +38,13 @@ public class MergedResult { // raw offset by blockId private final List offsets = new ArrayList<>(); private final CacheMergedBlockFuntion cachedMergedBlock; + private final Partition partition; public MergedResult( - RssConf rssConf, CacheMergedBlockFuntion cachedMergedBlock, int mergedBlockSize) { + RssConf rssConf, + CacheMergedBlockFuntion cachedMergedBlock, + int mergedBlockSize, + Partition partition) { this.rssConf = rssConf; this.cachedMergedBlock = cachedMergedBlock; this.mergedBlockSize = @@ -46,11 +53,12 @@ public MergedResult( : this.rssConf.getSizeAsBytes( SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE.key(), SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE.defaultValue()); + this.partition = partition; offsets.add(0L); } - public OutputStream getOutputStream() { - return new MergedSegmentOutputStream(); + public SerOutputStream getOutputStream(boolean direct, long totalBytes) { + return new MergedSegmentOutputStream(direct, totalBytes); } public boolean isOutOfBound(long blockId) { @@ -63,47 +71,95 @@ public long getBlockSize(long blockId) { @FunctionalInterface public interface CacheMergedBlockFuntion { - void cache(byte[] buffer, long blockId, int length); + boolean cache(ByteBuf byteBuf, long blockId, int length); } - class MergedSegmentOutputStream extends OutputStream implements Recordable { + public class MergedSegmentOutputStream extends SerOutputStream { + + private final boolean direct; + private final long totalBytes; + private ByteBuf byteBuf; + private long written = 0; + + MergedSegmentOutputStream(boolean direct, long totalBytes) { + this.direct = direct; + this.totalBytes = totalBytes; + } + + public void finalizeBlock() throws IOException { + // Avoid write empty block. 
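Reviewer note (not part of the patch): handle() now returns whether the event was actually queued, so a full queue is surfaced to the caller (which marks the partition INTERNAL_ERROR) rather than being silently dropped. A minimal sketch of that non-blocking hand-off; the queue capacity and event type are illustrative:

    import java.util.concurrent.ArrayBlockingQueue;
    import java.util.concurrent.BlockingQueue;

    public class MergeEventQueueSketch {
      private final BlockingQueue<Object> queue = new ArrayBlockingQueue<>(1024);

      public boolean handle(Object event) {
        // offer() never blocks: false means the queue is full and the caller
        // decides how to fail, instead of the event vanishing.
        return queue.offer(event);
      }
    }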
+ if (written <= offsets.get(offsets.size() - 1)) { + return; + } + int requireSize = byteBuf.readableBytes(); + // In fact, requireBuffer makes more sense before creating ByteBuf. + // However, it is not easy to catch exceptions to release buffer. + if (partition == null) { + throw new IOException("Can't find partition!"); + } + partition.requireMemory(requireSize); + boolean success = false; + try { + success = cachedMergedBlock.cache(byteBuf, offsets.size(), requireSize); + } finally { + if (!success) { + partition.releaseMemory(requireSize); + } + } + offsets.add(written); + } - ByteArrayOutputStream current; + // If some record is bigger than mergedBlockSize, we should allocate enough buffer for this. + // So when preallocate, we need to make sure the allocated buffer can write for the big record. + private void allocateNewBuffer(int preAllocateSize) { + if (this.byteBuf != null) { + byteBuf.release(); + byteBuf = null; + } + int alloc = Math.max((int) Math.min(mergedBlockSize, totalBytes - written), preAllocateSize); + UnpooledByteBufAllocator allocator = NettyUtils.getSharedUnpooledByteBufAllocator(true); + // In grpc mode, we may use array to visit the underlying buffer directly. + // We may still use array after release. But The pooled buffer may change + // the underlying buffer. So we can not use pooled buffer. + this.byteBuf = direct ? allocator.directBuffer(alloc) : Unpooled.buffer(alloc); + } - MergedSegmentOutputStream() { - current = new ByteArrayOutputStream((int) mergedBlockSize); + @Override + public void write(ByteBuf from) throws IOException { + preAllocate(from.readableBytes()); + int c = from.readableBytes(); + this.byteBuf.writeBytes(from); + written += c; } @Override public void write(int b) throws IOException { - current.write(b); + preAllocate(1); + this.byteBuf.writeByte((byte) (b & 0xFF)); + written++; } @Override - public void close() throws IOException { - if (current != null) { - current.close(); - current = null; + public void preAllocate(int length) throws IOException { + if (this.byteBuf == null || this.byteBuf.writableBytes() < length) { + finalizeBlock(); + allocateNewBuffer(length); } } + // Unlike the traditional flush, this flush can and must + // only be called once before close. 
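Reviewer note (not part of the patch): allocateNewBuffer sizes each merged-block buffer as max(min(mergedBlockSize, remaining bytes), preAllocateSize), so a record larger than the configured block size still gets a buffer it fits into. A small worked example of that arithmetic with illustrative numbers:

    public class MergedBlockSizingSketch {
      public static void main(String[] args) {
        long mergedBlockSize = 1024;     // e.g. SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE set to "1k"
        long totalBytes = 10_000;        // total bytes this merge will produce
        long written = 9_500;            // already written, so only 500 bytes remain
        int preAllocateSize = 2_000;     // a single record larger than the block size

        int alloc = Math.max((int) Math.min(mergedBlockSize, totalBytes - written), preAllocateSize);
        System.out.println(alloc);       // min(1024, 500) = 500, max(500, 2000) = 2000
      }
    }

The oversized-record case is the reason for the outer max(); without it the record could not be written into a single block buffer.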
@Override - public boolean record(long written, Flushable flushable, boolean force) throws IOException { - assert written >= 0; - long currentOffsetInThisBlock = written - offsets.get(offsets.size() - 1); - if (currentOffsetInThisBlock >= mergedBlockSize || (currentOffsetInThisBlock > 0 && force)) { - if (flushable != null) { - flushable.flush(); - } - cachedMergedBlock.cache( - current.toByteArray(), offsets.size(), (int) (currentOffsetInThisBlock)); - offsets.add(written); - if (!force) { - current = new ByteArrayOutputStream((int) mergedBlockSize); - } - return true; + public void flush() throws IOException { + finalizeBlock(); + } + + @Override + public void close() throws IOException { + if (this.byteBuf != null) { + this.byteBuf.release(); + this.byteBuf = null; } - return false; } } } diff --git a/server/src/main/java/org/apache/uniffle/server/merge/Partition.java b/server/src/main/java/org/apache/uniffle/server/merge/Partition.java index 667de657d5..48a9e0fb63 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/Partition.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/Partition.java @@ -19,18 +19,15 @@ import java.io.File; import java.io.IOException; -import java.io.OutputStream; import java.nio.ByteBuffer; -import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Set; import com.google.common.collect.Range; -import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.ByteBuf; +import io.netty.util.IllegalReferenceCountException; import org.apache.hadoop.io.RawComparator; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; @@ -52,6 +49,8 @@ import org.apache.uniffle.common.netty.buffer.ManagedBuffer; import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.server.ShuffleDataReadEvent; import org.apache.uniffle.server.buffer.ShuffleBuffer; import org.apache.uniffle.server.buffer.ShuffleBufferWithSkipList; @@ -92,7 +91,8 @@ public Partition(Shuffle shuffle, int partitionId) throws IOException { this.shuffle = shuffle; this.partitionId = partitionId; this.result = - new MergedResult(shuffle.serverConf, this::cachedMergedBlock, shuffle.mergedBlockSize); + new MergedResult( + shuffle.serverConf, this::cachedMergedBlock, shuffle.mergedBlockSize, this); this.initSleepTime = shuffle.serverConf.get(SERVER_MERGE_CACHE_MERGED_BLOCK_INIT_SLEEP_MS); this.maxSleepTime = shuffle.serverConf.get(SERVER_MERGE_CACHE_MERGED_BLOCK_MAX_SLEEP_MS); int tmpRingBufferSize = shuffle.serverConf.get(SERVER_MERGE_BLOCK_RING_BUFFER_SIZE); @@ -107,7 +107,7 @@ public Partition(Shuffle shuffle, int partitionId) throws IOException { } // startSortMerge is used to trigger to merger - synchronized void startSortMerge(Roaring64NavigableMap expectedBlockIdMap) throws IOException { + synchronized void startSortMerge(Roaring64NavigableMap expectedBlockIdMap) { if (getState() != INITED) { LOG.warn("Partition is already merging, so ignore duplicate reports, partition is {}", this); } else { @@ -121,7 +121,9 @@ synchronized void startSortMerge(Roaring64NavigableMap expectedBlockIdMap) throw shuffle.kClass, shuffle.vClass, expectedBlockIdMap); - shuffle.eventHandler.handle(event); + if (!shuffle.eventHandler.handle(event)) { + setState(INTERNAL_ERROR); 
+ } } else { setState(DONE); } @@ -144,49 +146,63 @@ private ShufflePartitionedBlock getShufflePartitionedBlock(long blockId, boolean return null; } - // getSegments is used to get segments from original shuffle blocks - public List getSegments( - RssConf rssConf, Iterator blockIds, Class keyClass, Class valueClass) - throws IOException { - List segments = new ArrayList<>(); - Set blocksFlushed = new HashSet<>(); + public boolean collectBlocks(Iterator blockIds, Map cachedBlocks) { + boolean allCached = true; while (blockIds.hasNext()) { long blockId = blockIds.next(); ShufflePartitionedBlock block = getShufflePartitionedBlock(blockId, false); - if (block != null && ByteBufUtil.isAccessible(block.getData())) { - try { - StreamedSegment segment = - new StreamedSegment( - rssConf, - block.getData().nioBuffer(0, block.getDataLength()), - blockId, - keyClass, - valueClass, - (shuffle.comparator instanceof RawComparator)); - segments.add(segment); - } catch (Exception e) { - // If ByteBuf is released by flush cleanup will throw ConcurrentModificationException. - // So we need get block buffer from file - LOG.warn("construct segment failed, caused by ", e); - blocksFlushed.add(blockId); - } - } else { - blocksFlushed.add(blockId); + if (block == null) { + allCached = false; + continue; + } + try { + // If ByteBuf is released by flush cleanup will throw IllegalReferenceCountException. + // Then we need get block buffer from file + ByteBuf byteBuf = block.getData().retain().duplicate(); + cachedBlocks.put(blockId, byteBuf.slice(0, block.getDataLength())); + } catch (IllegalReferenceCountException irce) { + allCached = false; + LOG.warn("Can't read bytes from block in memory, maybe already been flushed!"); } } - if (blocksFlushed.isEmpty()) { - return segments; - } - try { - LocalFileServerReadHandler handler = getLocalFileServerReadHandler(rssConf, shuffle.appId); - this.reader = - new BlockFlushFileReader( - handler.getDataFileName(), handler.getIndexFileName(), ringBufferSize); - for (Long blockId : blocksFlushed) { + return allCached; + } + + BlockFlushFileReader createReader(RssConf rssConf) { + LocalFileServerReadHandler handler = getLocalFileServerReadHandler(rssConf, shuffle.appId); + return new BlockFlushFileReader( + handler.getDataFileName(), handler.getIndexFileName(), ringBufferSize, shuffle.direct); + } + + public boolean collectSegments( + RssConf rssConf, + Iterator blockIds, + Class keyClass, + Class valueClass, + Map cachedBlock, + List segments, + BlockFlushFileReader reader) { + while (blockIds.hasNext()) { + long blockId = blockIds.next(); + if (cachedBlock.containsKey(blockId)) { + ByteBuf byteBuf = cachedBlock.get(blockId); + SerInputStream serInputStream = SerInputStream.newInputStream(byteBuf); + StreamedSegment segment = + new StreamedSegment( + rssConf, + serInputStream, + blockId, + keyClass, + valueClass, + byteBuf.readableBytes(), + (shuffle.comparator instanceof RawComparator)); + segments.add(segment); + } else { BlockFlushFileReader.BlockInputStream inputStream = reader.registerBlockInputStream(blockId); if (inputStream == null) { - throw new IOException("Can not find any buffer or file for block " + blockId); + LOG.warn("Can not find any buffer or file for block {}", blockId); + return false; } segments.add( new StreamedSegment( @@ -195,20 +211,27 @@ public List getSegments( blockId, keyClass, valueClass, + inputStream.available(), (shuffle.comparator instanceof RawComparator))); } - return segments; - } catch (Throwable throwable) { - throw new 
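Reviewer note (not part of the patch): collectBlocks prefers blocks still cached in memory and takes an independent, retained read view of each; if flush cleanup released the ByteBuf first, the IllegalReferenceCountException is caught and the block falls back to the BlockFlushFileReader path. A minimal sketch of that view-taking pattern:

    import io.netty.buffer.ByteBuf;
    import io.netty.util.IllegalReferenceCountException;

    public class CachedBlockViewSketch {
      static ByteBuf readView(ByteBuf blockData, int dataLength) {
        try {
          // retain() keeps the data alive across the merge; duplicate()/slice() give this
          // thread private reader indexes over exactly dataLength bytes.
          return blockData.retain().duplicate().slice(0, dataLength);
        } catch (IllegalReferenceCountException e) {
          return null; // already released by flush cleanup; read the block back from file
        }
      }
    }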
IOException(throwable); } + return true; } - void merge(List segments) throws IOException { + SerOutputStream createSerOutputStream(long totalBytes) { + return result.getOutputStream(shuffle.direct, totalBytes); + } + + void merge(List segments, SerOutputStream output, BlockFlushFileReader reader) { try { - OutputStream outputStream = result.getOutputStream(); + segments.forEach(segment -> segment.init()); + // The reader must be started only after the segments are initialized, since init allocates the ring buffers. + if (reader != null) { + reader.start(); + } Merger.merge( shuffle.serverConf, - outputStream, + output, segments, shuffle.kClass, shuffle.vClass, @@ -216,10 +239,29 @@ void merge(List segments) throws IOException { (shuffle.comparator instanceof RawComparator)); setState(DONE); } catch (Exception e) { - // TODO: should retry!!! - LOG.error("Partition {} remote merge failed, caused by {}", this, e); + LOG.info("Found exception when merging for {}, caused by ", this, e); setState(INTERNAL_ERROR); - throw new IOException(e); + } finally { + try { + if (reader != null) { + reader.close(); + } + } catch (IOException ioe) { + LOG.warn("Failed to close reader for {}, caused by ", this, ioe); + } + try { + output.close(); + } catch (IOException ioe) { + LOG.warn("Failed to close output, caused by ", ioe); + } + segments.forEach( + segment -> { + try { + segment.close(); + } catch (IOException ioe) { + LOG.warn("Failed to close segment, caused by ", ioe); + } + }); } } @@ -245,59 +287,59 @@ public MergeStatus tryGetBlock(long blockId) { return new MergeStatus(currentState, size); } + public void requireMemory(int requireSize) throws IOException { + while (!shuffle.shuffleServer.getShuffleTaskManager().requireMemory(requireSize, false)) { + try { + LOG.debug("Can not allocate enough memory for {}, then will sleep {}ms", this, sleepTime); + Thread.sleep(sleepTime); + sleepTime = Math.min(maxSleepTime, sleepTime * 2); + } catch (InterruptedException ex) { + LOG.warn("Found InterruptedException while sleeping to wait for the required buffer for {}", this); + throw new IOException(ex); + } + } + } + + public void releaseMemory(int requireSize) { + shuffle.shuffleServer.getShuffleTaskManager().releaseMemory(requireSize, false, false); + } + // When we merge data, we will divide the merge results into blocks according to the specified // block size. // The merged block in a new appId field (${appd} + MERGE_APP_SUFFIX). We will process the merged // blocks in the // original way, cache them first, and flush them to disk when necessary. 
- private void cachedMergedBlock(byte[] buffer, long blockId, int length) { + private boolean cachedMergedBlock(ByteBuf byteBuf, long blockId, int length) { String appId = shuffle.appId + MERGE_APP_SUFFIX; ShufflePartitionedBlock spb = - new ShufflePartitionedBlock(length, length, -1, blockId, -1, buffer); + new ShufflePartitionedBlock(length, length, -1, blockId, -1, byteBuf.retain()); ShufflePartitionedData spd = new ShufflePartitionedData(partitionId, new ShufflePartitionedBlock[] {spb}); - while (true) { - StatusCode ret = - shuffle - .shuffleServer - .getShuffleTaskManager() - .cacheShuffleData(appId, shuffle.shuffleId, false, spd); - if (ret == StatusCode.SUCCESS) { + StatusCode ret = shuffle .shuffleServer .getShuffleTaskManager() - .updateCachedBlockIds( - appId, shuffle.shuffleId, spd.getPartitionId(), spd.getBlockList()); - sleepTime = initSleepTime; - break; - } else if (ret == StatusCode.NO_BUFFER) { - try { - LOG.info( - "Can not allocate enough memory for " - + this - + ", then will sleep " - + sleepTime - + "ms"); - Thread.sleep(sleepTime); - sleepTime = Math.min(maxSleepTime, sleepTime * 2); - } catch (InterruptedException ex) { - throw new RssException(ex); - } - } else { - String shuffleDataInfo = - "appId[" - + appId - + "], shuffleId[" - + shuffle.shuffleId - + "], partitionId[" - + spd.getPartitionId() - + "]"; - throw new RssException( - "Error happened when shuffleEngine.write for " - + shuffleDataInfo - + ", statusCode=" - + ret); - } + .cacheShuffleData(appId, shuffle.shuffleId, true, spd); + if (ret == StatusCode.SUCCESS) { + shuffle + .shuffleServer + .getShuffleTaskManager() + .updateCachedBlockIds(appId, shuffle.shuffleId, spd.getPartitionId(), spd.getBlockList()); + sleepTime = initSleepTime; + return true; + } else { + String shuffleDataInfo = + "appId[" + + appId + + "], shuffleId[" + + shuffle.shuffleId + + "], partitionId[" + + spd.getPartitionId() + + "]"; + LOG.warn( + "Error happened when shuffleEngine.write for {}, statusCode={}", shuffleDataInfo, ret); + byteBuf.release(); + return false; } } @@ -321,11 +363,12 @@ private NettyManagedBuffer getMergedBlockBufferInMemory(long blockId) { try { ShufflePartitionedBlock block = this.getShufflePartitionedBlock(blockId, true); // We must make sure refCnt > 0, it means the ByteBuf is not released by flush cleanup - if (block != null && block.getData().refCnt() > 0) { - return new NettyManagedBuffer(block.getData().retain()); + if (block != null) { + ByteBuf byteBuf = block.getData().retain(); + return new NettyManagedBuffer(byteBuf.duplicate()); } return null; - } catch (Exception e) { + } catch (IllegalReferenceCountException e) { // If release that is triggered by flush cleanup before we retain, may throw // IllegalReferenceCountException. // It means ByteBuf is not available, we must get the block buffer from file. 
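Note (illustration, not part of the patch): the merged result produced above is cached under the ${appId} + MERGE_APP_SUFFIX appId with block ids assigned sequentially from 1, and readers learn whether the next merged block is available through tryGetBlock / MergeStatus, which the GetSortedShuffleData handler later in this patch relies on. A minimal sketch of that contract, assuming the class and method names from this patch; the wrapper class SortedBlockPoller is hypothetical and only for illustration:

import org.apache.uniffle.common.merger.MergeState;
import org.apache.uniffle.server.merge.MergeStatus;
import org.apache.uniffle.server.merge.ShuffleMergeManager;

class SortedBlockPoller {
  // Returns true when blockId of the merged (sorted) result can be fetched via
  // ShuffleMergeManager#getShuffleData. A false result means: retry later (INITED, or
  // MERGING with size == -1), stop (DONE with size == -1), or give up (INTERNAL_ERROR).
  static boolean isBlockReady(
      ShuffleMergeManager mergeManager, String appId, int shuffleId, int partitionId, long blockId) {
    MergeStatus status = mergeManager.tryGetBlock(appId, shuffleId, partitionId, blockId);
    MergeState state = status.getState();
    long size = status.getSize();
    return size != -1 && (state == MergeState.MERGING || state == MergeState.DONE);
  }
}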
@@ -411,9 +454,6 @@ private LocalFileServerReadHandler getLocalFileServerReadHandler(RssConf rssConf void cleanup() { try { - if (reader != null) { - reader.close(); - } shuffleMeta.clear(); } catch (Exception e) { LOG.warn("Partition {} clean up failed, caused by {}", this, e); diff --git a/server/src/main/java/org/apache/uniffle/server/merge/Shuffle.java b/server/src/main/java/org/apache/uniffle/server/merge/Shuffle.java index 3aa7df6f82..9f0099702f 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/Shuffle.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/Shuffle.java @@ -42,6 +42,7 @@ public class Shuffle { private final Map> partitions = JavaUtils.newConcurrentMap(); final int mergedBlockSize; final ClassLoader classLoader; + boolean direct = false; public Shuffle( RssConf rssConf, @@ -79,6 +80,10 @@ void cleanup() { this.partitions.clear(); } + public void setDirect(boolean direct) throws IOException { + this.direct = direct; + } + public ClassLoader getClassLoader() { return classLoader; } diff --git a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java index f8a6c1bfd2..027b63d9df 100644 --- a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java +++ b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java @@ -32,6 +32,7 @@ import java.util.Map; import com.google.common.annotations.VisibleForTesting; +import io.netty.buffer.ByteBuf; import org.apache.commons.lang3.ClassUtils; import org.apache.commons.lang3.StringUtils; import org.roaringbitmap.longlong.Roaring64NavigableMap; @@ -39,14 +40,15 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.common.ShuffleDataResult; -import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.merger.Segment; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.serializer.SerOutputStream; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.proto.RssProtos.MergeContext; import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; +import static org.apache.uniffle.common.merger.MergeState.INTERNAL_ERROR; import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_CLASS_LOADER_JARS_PATH; public class ShuffleMergeManager { @@ -230,24 +232,60 @@ public void startSortMerge( } public void processEvent(MergeEvent event) { + boolean success = false; + Partition partition = null; + Map cachedBlocks = new HashMap<>(); try { - ClassLoader original = Thread.currentThread().getContextClassLoader(); Thread.currentThread() .setContextClassLoader( this.getShuffle(event.getAppId(), event.getShuffleId()).getClassLoader()); - List segments = - this.getPartition(event.getAppId(), event.getShuffleId(), event.getPartitionId()) - .getSegments( - serverConf, - event.getExpectedBlockIdMap().iterator(), - event.getKeyClass(), - event.getValueClass()); - this.getPartition(event.getAppId(), event.getShuffleId(), event.getPartitionId()) - .merge(segments); - Thread.currentThread().setContextClassLoader(original); - } catch (Exception e) { - LOG.info("Found exception when merge, caused by ", e); - throw new RssException(e); + partition = this.getPartition(event.getAppId(), event.getShuffleId(), event.getPartitionId()); + if (partition == null) { + LOG.info("Can not find partition for event: {}", event); + return; + } + + // 1 collect blocks, retain block from 
bufferPool, need to release. + boolean allCached = + partition.collectBlocks(event.getExpectedBlockIdMap().iterator(), cachedBlocks); + + // 2 If the size of cacheBlock is less than total block, we will read from file, so construct + // reader + BlockFlushFileReader reader = null; + if (!allCached) { + // create reader do not allocate resource. + reader = partition.createReader(serverConf); + } + + // 3 collect input segments, but not init. So do not allocate any resource. + List segments = new ArrayList<>(); + boolean allFound = + partition.collectSegments( + serverConf, + event.getExpectedBlockIdMap().iterator(), + event.getKeyClass(), + event.getValueClass(), + cachedBlocks, + segments, + reader); + if (!allFound) { + return; + } + + // 4 create output, but not init. So do not allocate any resource. + // Because of the presence of EOF, the totalBytes are generally slightly larger than the + // required space. + long totalBytes = segments.stream().mapToLong(segment -> segment.getSize()).sum(); + SerOutputStream output = partition.createSerOutputStream(totalBytes); + + // 5 merge segments to output + partition.merge(segments, output, reader); + success = true; + } finally { + if (!success && partition != null) { + partition.setState(INTERNAL_ERROR); + } + cachedBlocks.values().forEach(byteBuf -> byteBuf.release()); } } @@ -256,6 +294,12 @@ public ShuffleDataResult getShuffleData( return this.getPartition(appId, shuffleId, partitionId).getShuffleData(blockId); } + public void setDirect(String appId, int shuffleId, boolean direct) throws IOException { + if (this.shuffles.containsKey(appId) && this.shuffles.get(appId).containsKey(shuffleId)) { + this.getShuffle(appId, shuffleId).setDirect(direct); + } + } + public MergeStatus tryGetBlock(String appId, int shuffleId, int partitionId, long blockId) { return this.getPartition(appId, shuffleId, partitionId).tryGetBlock(blockId); } @@ -271,7 +315,12 @@ Shuffle getShuffle(String appId, int shuffleId) { @VisibleForTesting Partition getPartition(String appId, int shuffleId, int partitionId) { - return this.shuffles.get(appId).get(shuffleId).getPartition(partitionId); + if (this.shuffles.containsKey(appId)) { + if (this.shuffles.get(appId).containsKey(shuffleId)) { + return this.shuffles.get(appId).get(shuffleId).getPartition(partitionId); + } + } + return null; } public void refreshAppId(String appId) { diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java index b352fdae9c..eb1951a5f3 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java @@ -45,6 +45,7 @@ import org.apache.uniffle.common.exception.ExceedHugePartitionHardLimitException; import org.apache.uniffle.common.exception.FileNotFoundException; import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.merger.MergeState; import org.apache.uniffle.common.netty.buffer.ManagedBuffer; import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import org.apache.uniffle.common.netty.client.TransportClient; @@ -55,6 +56,8 @@ import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse; import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest; import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse; +import 
org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataRequest; +import org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataResponse; import org.apache.uniffle.common.netty.protocol.RequestMessage; import org.apache.uniffle.common.netty.protocol.RpcResponse; import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest; @@ -68,6 +71,7 @@ import org.apache.uniffle.server.audit.ServerRpcAuditContext; import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo; import org.apache.uniffle.server.buffer.ShuffleBufferManager; +import org.apache.uniffle.server.merge.MergeStatus; import org.apache.uniffle.storage.common.Storage; import org.apache.uniffle.storage.common.StorageReadMetrics; import org.apache.uniffle.storage.util.ShuffleStorageUtils; @@ -124,6 +128,8 @@ public void receive(TransportClient client, RequestMessage msg) { handleGetLocalShuffleIndexRequest(client, (GetLocalShuffleIndexRequest) msg); } else if (msg instanceof GetMemoryShuffleDataRequest) { handleGetMemoryShuffleDataRequest(client, (GetMemoryShuffleDataRequest) msg); + } else if (msg instanceof GetSortedShuffleDataRequest) { + handleGetSortedShuffleDataRequest(client, (GetSortedShuffleDataRequest) msg); } else { throw new RssException("Can not handle message " + msg.type()); } @@ -302,6 +308,9 @@ public void handleSendShuffleDataRequest(TransportClient client, SendShuffleData responseMessage = errorMsg; hasFailureOccurred = true; } else { + if (shuffleServer.isRemoteMergeEnable()) { + shuffleServer.getShuffleMergeManager().setDirect(appId, shuffleId, true); + } long toReleasedSize = spd.getTotalBlockEncodedLength(); // after each cacheShuffleData call, the `preAllocatedSize` is updated timely. shuffleTaskManager.releasePreAllocatedSize(toReleasedSize); @@ -751,6 +760,109 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat } } + public void handleGetSortedShuffleDataRequest( + TransportClient client, GetSortedShuffleDataRequest req) { + final long start = System.currentTimeMillis(); + long requestId = req.getRequestId(); + String appId = req.getAppId(); + int shuffleId = req.getShuffleId(); + int partitionId = req.getPartitionId(); + long blockId = req.getBlockId(); + long timestamp = req.getTimestamp(); + + if (timestamp > 0) { + long transportTime = start - timestamp; + if (transportTime > 0) { + shuffleServer + .getNettyMetrics() + .recordTransportTime(GetSortedShuffleDataRequest.class.getName(), transportTime); + } + } + StatusCode status = StatusCode.SUCCESS; + String msg = "OK"; + GetSortedShuffleDataResponse response; + String requestInfo = + "appId[" + + appId + + "], shuffleId[" + + shuffleId + + "], partitionId[" + + partitionId + + "], blockId[" + + blockId + + "]"; + if (!shuffleServer.isRemoteMergeEnable()) { + msg = "Remote merge is disabled"; + status = StatusCode.INTERNAL_ERROR; + response = + new GetSortedShuffleDataResponse( + requestId, status, msg, -1, MergeState.INTERNAL_ERROR.code(), Unpooled.EMPTY_BUFFER); + client.getChannel().writeAndFlush(response); + return; + } + MergeStatus mergeStatus = + shuffleServer.getShuffleMergeManager().tryGetBlock(appId, shuffleId, partitionId, blockId); + MergeState mergeState = mergeStatus.getState(); + long readBlockSize = mergeStatus.getSize(); + + if (mergeState == MergeState.INITED + || (mergeState == MergeState.MERGING && readBlockSize == -1) + || (mergeState == MergeState.DONE && readBlockSize == -1) + || mergeState == MergeState.INTERNAL_ERROR) { + msg = mergeState.name(); + response = + new 
GetSortedShuffleDataResponse( + requestId, status, msg, -1, mergeState.code(), Unpooled.EMPTY_BUFFER); + client.getChannel().writeAndFlush(response); + return; + } + + if (shuffleServer.getShuffleBufferManager().requireReadMemory(readBlockSize)) { + ShuffleDataResult sdr = null; + try { + sdr = + shuffleServer + .getShuffleMergeManager() + .getShuffleData(appId, shuffleId, partitionId, blockId); + + response = + new GetSortedShuffleDataResponse( + requestId, status, msg, blockId + 1, mergeState.code(), sdr.getManagedBuffer()); + + ReleaseMemoryAndRecordReadTimeListener listener = + new ReleaseMemoryAndRecordReadTimeListener( + start, readBlockSize, sdr.getDataLength(), requestInfo, req, response, client); + + client.getChannel().writeAndFlush(response).addListener(listener); + } catch (Exception e) { + shuffleServer.getShuffleBufferManager().releaseReadMemory(readBlockSize); + if (sdr != null) { + sdr.release(); + } + status = StatusCode.INTERNAL_ERROR; + msg = "Error happened when get shuffle data for " + requestInfo + ", " + e.getMessage(); + LOG.error(msg, e); + response = + new GetSortedShuffleDataResponse( + requestId, + status, + msg, + -1, + MergeState.INTERNAL_ERROR.code(), + Unpooled.EMPTY_BUFFER); + client.getChannel().writeAndFlush(response); + } + } else { + status = StatusCode.NO_BUFFER; + msg = "Can't require read memory to get sorted shuffle data"; + LOG.error(msg + " for " + requestInfo); + response = + new GetSortedShuffleDataResponse( + requestId, status, msg, -1, mergeState.code(), Unpooled.EMPTY_BUFFER); + client.getChannel().writeAndFlush(response); + } + } + private List toPartitionedDataList(SendShuffleDataRequest req) { List ret = Lists.newArrayList(); @@ -872,6 +984,15 @@ public void operationComplete(ChannelFuture future) { errorMsg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER); + } else if (request instanceof GetSortedShuffleDataRequest) { + errorResponse = + new GetSortedShuffleDataResponse( + request.getRequestId(), + StatusCode.INTERNAL_ERROR, + errorMsg, + -1L, + MergeState.INTERNAL_ERROR.code(), + Unpooled.EMPTY_BUFFER); } else { LOG.error("Cannot handle request {}", request.type(), cause); return; diff --git a/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java b/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java index 6b6c870511..ca3bc65ab3 100644 --- a/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java +++ b/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java @@ -18,7 +18,6 @@ package org.apache.uniffle.server.merge; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.Comparator; @@ -28,7 +27,7 @@ import java.util.concurrent.atomic.AtomicInteger; import com.google.common.collect.Lists; -import org.apache.hadoop.io.RawComparator; +import io.netty.buffer.ByteBuf; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.ParameterizedTest; @@ -41,7 +40,9 @@ import org.apache.uniffle.common.merger.Segment; import org.apache.uniffle.common.merger.StreamedSegment; import org.apache.uniffle.common.records.RecordsReader; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.DynBufferSerOutputStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; import 
org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler; @@ -58,9 +59,14 @@ public class BlockFlushFileReaderTest { @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,4", - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,32", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2,false,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,8,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,8,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,8,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,8,false,false", }) public void writeTestWithMerge(String classes, @TempDir File tmpDir) throws Exception { final String[] classArray = classes.split(","); @@ -68,8 +74,9 @@ public void writeTestWithMerge(String classes, @TempDir File tmpDir) throws Exce final Class valueClass = SerializerUtils.getClassByName(classArray[1]); final Comparator comparator = SerializerUtils.getComparator(keyClass); final int ringBufferSize = Integer.parseInt(classArray[2]); + boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; + boolean direct = classArray.length > 4 ? Boolean.parseBoolean(classArray[4]) : false; - final File dataOutput = new File(tmpDir, "dataOutput"); final File dataDir = new File(tmpDir, "data"); final String[] basePaths = new String[] {dataDir.getAbsolutePath()}; final LocalFileWriteHandler writeHandler1 = @@ -91,42 +98,40 @@ public void writeTestWithMerge(String classes, @TempDir File tmpDir) throws Exce String indexFileName = readHandler.getIndexFileName(); BlockFlushFileReader blockFlushFileReader = - new BlockFlushFileReader(dataFileName, indexFileName, ringBufferSize); + new BlockFlushFileReader(dataFileName, indexFileName, ringBufferSize, direct); List segments = new ArrayList<>(); for (Long blockId : expectedBlockIds) { - PartialInputStream partialInputStream = - blockFlushFileReader.registerBlockInputStream(blockId); - segments.add( + SerInputStream inputStream = blockFlushFileReader.registerBlockInputStream(blockId); + Segment segment = new StreamedSegment( - conf, - partialInputStream, - blockId, - keyClass, - valueClass, - comparator instanceof RawComparator)); + conf, inputStream, blockId, keyClass, valueClass, inputStream.available(), raw); + segments.add(segment); } - FileOutputStream outputStream = new FileOutputStream(dataOutput); - Merger.merge( - conf, - outputStream, - segments, - keyClass, - valueClass, - comparator, - comparator instanceof RawComparator); + SerOutputStream outputStream = new DynBufferSerOutputStream(); + segments.forEach(segment -> segment.init()); + blockFlushFileReader.start(); + Merger.merge(conf, outputStream, segments, keyClass, valueClass, comparator, raw); + blockFlushFileReader.close(); outputStream.close(); + for (Segment segment : segments) { + segment.close(); + } int index = 0; + ByteBuf byteBuf = outputStream.toByteBuf(); RecordsReader reader = new RecordsReader( - conf, PartialInputStream.newInputStream(dataOutput), keyClass, valueClass, false); + conf, 
SerInputStream.newInputStream(byteBuf), keyClass, valueClass, false, false); + reader.init(); while (reader.next()) { assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); index++; } assertEquals(100900, index); + blockFlushFileReader.close(); + byteBuf.release(); } public static void writeTestData( @@ -143,11 +148,13 @@ public static List generateBlocks( throws IOException { BlockIdLayout layout = BlockIdLayout.DEFAULT; List blocks = Lists.newArrayList(); - byte[] bytes = - SerializerUtils.genSortedRecordBytes( + ByteBuf byteBuf = + SerializerUtils.genSortedRecordBuffer( rssConf, keyClass, valueClass, start, interval, length, 1); long blockId = layout.getBlockId(ATOMIC_INT.incrementAndGet(), 0, 100); - blocks.add(new ShufflePartitionedBlock(bytes.length, bytes.length, 0, blockId, 100, bytes)); + blocks.add( + new ShufflePartitionedBlock( + byteBuf.readableBytes(), byteBuf.readableBytes(), 0, blockId, 100, byteBuf)); return blocks; } @@ -155,15 +162,20 @@ public static List generateBlocks( @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,2,false,false", }) public void writeTestWithMergeWhenInterrupted(String classes, @TempDir File tmpDir) throws Exception { String[] classArray = classes.split(","); - Class keyClass = SerializerUtils.getClassByName(classArray[0]); - Class valueClass = SerializerUtils.getClassByName(classArray[1]); - Comparator comparator = SerializerUtils.getComparator(keyClass); + final Class keyClass = SerializerUtils.getClassByName(classArray[0]); + final Class valueClass = SerializerUtils.getClassByName(classArray[1]); + final Comparator comparator = SerializerUtils.getComparator(keyClass); int ringBufferSize = Integer.parseInt(classArray[2]); + boolean raw = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; + boolean direct = classArray.length > 4 ? 
Boolean.parseBoolean(classArray[4]) : false; File dataDir = new File(tmpDir, "data"); String[] basePaths = new String[] {dataDir.getAbsolutePath()}; @@ -187,37 +199,37 @@ public void writeTestWithMergeWhenInterrupted(String classes, @TempDir File tmpD String indexFileName = readHandler.getIndexFileName(); BlockFlushFileReader blockFlushFileReader = - new BlockFlushFileReader(dataFileName, indexFileName, ringBufferSize); + new BlockFlushFileReader(dataFileName, indexFileName, ringBufferSize, direct); + blockFlushFileReader.start(); List segments = new ArrayList<>(); for (Long blockId : expectedBlockIds) { - PartialInputStream partialInputStream = - blockFlushFileReader.registerBlockInputStream(blockId); - segments.add( + SerInputStream inputStream = blockFlushFileReader.registerBlockInputStream(blockId); + Segment segment = new MockedStreamedSegment( conf, - partialInputStream, + inputStream, blockId, keyClass, valueClass, - comparator instanceof RawComparator, - blockFlushFileReader)); + inputStream.available(), + raw, + blockFlushFileReader); + segments.add(segment); } - FileOutputStream outputStream = new FileOutputStream(dataOutput); + SerOutputStream outputStream = new DynBufferSerOutputStream(); + segments.forEach(segment -> segment.init()); + blockFlushFileReader.start(); assertThrows( Exception.class, - () -> { - Merger.merge( - conf, - outputStream, - segments, - keyClass, - valueClass, - comparator, - comparator instanceof RawComparator); - }); + () -> Merger.merge(conf, outputStream, segments, keyClass, valueClass, comparator, raw)); + outputStream.close(); + blockFlushFileReader.close(); outputStream.close(); + for (Segment segment : segments) { + segment.close(); + } } class MockedStreamedSegment extends StreamedSegment { @@ -227,24 +239,21 @@ class MockedStreamedSegment extends StreamedSegment { MockedStreamedSegment( RssConf rssConf, - PartialInputStream inputStream, + SerInputStream inputStream, long blockId, Class keyClass, Class valueClass, + long size, boolean raw, BlockFlushFileReader reader) { - super(rssConf, inputStream, blockId, keyClass, valueClass, raw); + super(rssConf, inputStream, blockId, keyClass, valueClass, size, raw); this.reader = reader; } public boolean next() throws IOException { boolean ret = super.next(); if (this.count++ > 200) { - try { - this.reader.close(); - } catch (InterruptedException e) { - throw new IOException(e); - } + this.reader.close(); } return ret; } diff --git a/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java b/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java index e8dcda7abf..c9cb292a5b 100644 --- a/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java +++ b/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java @@ -19,14 +19,13 @@ import java.io.File; import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Comparator; import java.util.List; +import io.netty.buffer.ByteBuf; import org.apache.commons.lang3.tuple.Pair; -import org.apache.hadoop.io.RawComparator; +import org.apache.hadoop.io.DataInputBuffer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.ParameterizedTest; @@ -34,15 +33,20 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.merger.Merger; -import org.apache.uniffle.common.merger.Recordable; import org.apache.uniffle.common.merger.Segment; import 
org.apache.uniffle.common.records.RecordsReader; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerOutputStream; +import org.apache.uniffle.common.serializer.Serializer; +import org.apache.uniffle.common.serializer.SerializerFactory; +import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; +import static org.apache.uniffle.common.serializer.SerializerUtils.genData; import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; public class MergedResultTest { @@ -53,61 +57,70 @@ public class MergedResultTest { @Test public void testMergedResult() throws IOException { // 1 Construct cache - List> blocks = new ArrayList<>(); + List> blocks = new ArrayList<>(); MergedResult.CacheMergedBlockFuntion cache = - (byte[] buffer, long blockId, int length) -> { + (ByteBuf byteBuf, long blockId, int length) -> { + byteBuf.retain(); assertEquals(blockId - 1, blocks.size()); - blocks.add(Pair.of(length, buffer)); + blocks.add(Pair.of(length, byteBuf)); + return true; }; // 2 Write to merged result RssConf rssConf = new RssConf(); rssConf.set(SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE, String.valueOf(BYTES_LEN / 10)); - MergedResult result = new MergedResult(rssConf, cache, -1); - OutputStream output = result.getOutputStream(); + Partition partition = mock(Partition.class); + MergedResult result = new MergedResult(rssConf, cache, -1, partition); + SerOutputStream output = result.getOutputStream(false, BYTES_LEN); for (int i = 0; i < BYTES_LEN; i++) { output.write((byte) (i & 0x7F)); - if (output instanceof Recordable) { - ((Recordable) output).record(i + 1, null, false); - } } + output.flush(); output.close(); // 3 check blocks number - // Max merged block is 1024, every record have 2 bytes, so will result to 10 block + // Max merged block is 1024, every record have one byte, so will result to 10 block assertEquals(10, blocks.size()); // 4 check the blocks int index = 0; for (int i = 0; i < blocks.size(); i++) { int length = blocks.get(i).getLeft(); - byte[] buffer = blocks.get(i).getRight(); - assertTrue(buffer.length >= length); + ByteBuf byteBuf = blocks.get(i).getRight(); + assertTrue(byteBuf.readableBytes() >= length); for (int j = 0; j < length; j++) { - assertEquals(index & 0x7F, buffer[j]); + assertEquals(index & 0x7F, byteBuf.readByte()); index++; } } assertEquals(BYTES_LEN, index); + blocks.forEach(block -> block.getRight().release()); } @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false,false", }) public void testMergeSegmentToMergeResult(String classes, @TempDir File tmpDir) throws Exception { // 1 Parse arguments String[] classArray = classes.split(","); Class keyClass = SerializerUtils.getClassByName(classArray[0]); Class valueClass = SerializerUtils.getClassByName(classArray[1]); + boolean raw = classArray.length > 2 ? 
Boolean.parseBoolean(classArray[2]) : false; + boolean direct = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; // 2 Construct cache - List> blocks = new ArrayList<>(); + List> blocks = new ArrayList<>(); MergedResult.CacheMergedBlockFuntion cache = - (byte[] buffer, long blockId, int length) -> { + (ByteBuf byteBuf, long blockId, int length) -> { assertEquals(blockId - 1, blocks.size()); - blocks.add(Pair.of(length, buffer)); + byteBuf.retain(); + blocks.add(Pair.of(length, byteBuf)); + return true; }; // 3 Construct segments, then merge @@ -115,63 +128,62 @@ public void testMergeSegmentToMergeResult(String classes, @TempDir File tmpDir) List segments = new ArrayList<>(); Comparator comparator = SerializerUtils.getComparator(keyClass); for (int i = 0; i < SEGMENTS; i++) { - if (i % 2 == 0) { - segments.add( - SerializerUtils.genMemorySegment( - rssConf, - keyClass, - valueClass, - i, - i, - SEGMENTS, - RECORDS, - comparator instanceof RawComparator)); - } else { - segments.add( - SerializerUtils.genFileSegment( - rssConf, - keyClass, - valueClass, - i, - i, - SEGMENTS, - RECORDS, - tmpDir, - comparator instanceof RawComparator)); - } + Segment segment = + i % 2 == 0 + ? SerializerUtils.genMemorySegment( + rssConf, keyClass, valueClass, i, i, SEGMENTS, RECORDS, raw, direct) + : SerializerUtils.genFileSegment( + rssConf, keyClass, valueClass, i, i, SEGMENTS, RECORDS, tmpDir, raw); + segment.init(); + segments.add(segment); } - MergedResult result = new MergedResult(rssConf, cache, -1); - OutputStream mergedOutputStream = result.getOutputStream(); - Merger.merge( - rssConf, - mergedOutputStream, - segments, - keyClass, - valueClass, - comparator, - comparator instanceof RawComparator); - mergedOutputStream.flush(); + Partition partition = mock(Partition.class); + MergedResult result = new MergedResult(rssConf, cache, -1, partition); + long totalBytes = segments.stream().mapToLong(segment -> segment.getSize()).sum(); + SerOutputStream mergedOutputStream = result.getOutputStream(direct, totalBytes); + Merger.merge(rssConf, mergedOutputStream, segments, keyClass, valueClass, comparator, raw); mergedOutputStream.close(); // 4 check merged blocks int index = 0; + SerializerFactory factory = new SerializerFactory(rssConf); + Serializer serializer = factory.getSerializer(keyClass); + assert factory.getSerializer(valueClass).getClass().equals(serializer.getClass()); + SerializerInstance instance = serializer.newInstance(); for (int i = 0; i < blocks.size(); i++) { - int length = blocks.get(i).getLeft(); - byte[] buffer = blocks.get(i).getRight(); RecordsReader reader = new RecordsReader( rssConf, - PartialInputStream.newInputStream(ByteBuffer.wrap(buffer)), + SerInputStream.newInputStream(blocks.get(i).getRight()), keyClass, valueClass, - false); + raw, + true); + reader.init(); while (reader.next()) { - assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); - assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); + if (raw) { + ByteBuf keyByteBuf = (ByteBuf) reader.getCurrentKey(); + ByteBuf valueByteBuf = (ByteBuf) reader.getCurrentValue(); + byte[] keyBytes = new byte[keyByteBuf.readableBytes()]; + byte[] valueBytes = new byte[valueByteBuf.readableBytes()]; + keyByteBuf.readBytes(keyBytes); + valueByteBuf.readBytes(valueBytes); + DataInputBuffer keyInputBuffer = new DataInputBuffer(); + keyInputBuffer.reset(keyBytes, 0, keyBytes.length); + assertEquals(genData(keyClass, index), instance.deserialize(keyInputBuffer, 
keyClass)); + DataInputBuffer valueInputBuffer = new DataInputBuffer(); + valueInputBuffer.reset(valueBytes, 0, valueBytes.length); + assertEquals( + genData(valueClass, index), instance.deserialize(valueInputBuffer, valueClass)); + } else { + assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); + } index++; } reader.close(); } assertEquals(RECORDS * SEGMENTS, index); + blocks.forEach(block -> block.getRight().release()); } } diff --git a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java index 4ea82750c7..112f7f5e43 100644 --- a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java +++ b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java @@ -24,6 +24,8 @@ import java.util.concurrent.TimeUnit; import com.google.common.collect.ImmutableMap; +import io.netty.buffer.ByteBuf; +import org.apache.hadoop.io.DataInputBuffer; import org.awaitility.Awaitility; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -41,7 +43,8 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.merger.MergeState; import org.apache.uniffle.common.records.RecordsReader; -import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerInputStream; +import org.apache.uniffle.common.serializer.SerializerInstance; import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.serializer.writable.WritableSerializer; import org.apache.uniffle.common.util.BlockIdLayout; @@ -53,6 +56,7 @@ import org.apache.uniffle.server.buffer.ShuffleBufferType; import org.apache.uniffle.storage.util.StorageType; +import static org.apache.uniffle.common.serializer.SerializerUtils.genData; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -83,6 +87,7 @@ public void beforeEach() { serverConf.setLong(ShuffleServerConf.SERVER_APP_EXPIRED_WITHOUT_HEARTBEAT, 60L * 1000L * 60L); serverConf.set(ShuffleServerConf.SERVER_MERGE_ENABLE, true); serverConf.set(ShuffleServerConf.SERVER_SHUFFLE_BUFFER_TYPE, ShuffleBufferType.SKIP_LIST); + serverConf.set(ShuffleServerConf.SERVER_BUFFER_CAPACITY, 100 * 1024 * 1024L); ShuffleServerMetrics.clear(); ShuffleServerMetrics.register(); assertTrue(this.tempDir1.isDirectory()); @@ -102,18 +107,24 @@ public void afterEach() throws Exception { @ParameterizedTest @ValueSource( strings = { - "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true,false", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false,false", }) - public void testMergerManager(String classes, @TempDir File tmpDir) throws Exception { + public void testMergerManager(String classes) throws Exception { // 1 Construct serializer and comparator final String[] classArray = classes.split(","); final String keyClassName = classArray[0]; final String valueClassName = classArray[1]; final Class keyClass = SerializerUtils.getClassByName(keyClassName); final Class valueClass = 
SerializerUtils.getClassByName(valueClassName); + boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + boolean direct = classArray.length > 3 ? Boolean.parseBoolean(classArray[3]) : false; final Comparator comparator = SerializerUtils.getComparator(keyClass); final String comparatorClassName = comparator.getClass().getName(); final WritableSerializer serializer = new WritableSerializer(new RssConf()); + final SerializerInstance instance = serializer.newInstance(); // 2 Construct shuffle task manager and merge manager shuffleServer = new ShuffleServer(serverConf); @@ -153,20 +164,21 @@ public void testMergerManager(String classes, @TempDir File tmpDir) throws Excep blocks[3] = blockIdLayout.getBlockId(1, PARTITION_ID, 1); ShufflePartitionedBlock[] shufflePartitionedBlocks = new ShufflePartitionedBlock[4]; for (int i = 0; i < 4; i++) { - byte[] buffer = - SerializerUtils.genSortedRecordBytes( - serverConf, keyClass, valueClass, i, 4, RECORDS_NUMBER, 1); + ByteBuf byteBuf = + SerializerUtils.genSortedRecordBuffer( + serverConf, keyClass, valueClass, i, 4, RECORDS_NUMBER, 1, direct); shufflePartitionedBlocks[i] = new ShufflePartitionedBlock( - buffer.length, - buffer.length, + byteBuf.readableBytes(), + byteBuf.readableBytes(), 0, blocks[i], blockIdLayout.getTaskAttemptId(blocks[i]), - buffer); + byteBuf); } ShufflePartitionedData spd = new ShufflePartitionedData(PARTITION_ID, shufflePartitionedBlocks); shuffleTaskManager.cacheShuffleData(APP_ID, SHUFFLE_ID, false, spd); + mergeManager.setDirect(APP_ID, SHUFFLE_ID, direct); // 4.2 report shuffle result shuffleTaskManager.addFinishedBlockIds( APP_ID, SHUFFLE_ID, ImmutableMap.of(PARTITION_ID, blocks), 1); @@ -204,13 +216,31 @@ public void testMergerManager(String classes, @TempDir File tmpDir) throws Excep if (blockSize != -1) { ShuffleDataResult shuffleDataResult = mergeManager.getShuffleData(APP_ID, SHUFFLE_ID, PARTITION_ID, blockId); - PartialInputStream inputStream = - PartialInputStream.newInputStream(shuffleDataResult.getDataBuffer()); + SerInputStream inputStream = + SerInputStream.newInputStream(shuffleDataResult.getDataBuf()); RecordsReader reader = - new RecordsReader(serverConf, inputStream, keyClass, valueClass, false); + new RecordsReader(serverConf, inputStream, keyClass, valueClass, raw, true); + reader.init(); while (reader.next()) { - assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); - assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); + if (raw) { + ByteBuf keyByteBuf = (ByteBuf) reader.getCurrentKey(); + ByteBuf valueByteBuf = (ByteBuf) reader.getCurrentValue(); + byte[] keyBytes = new byte[keyByteBuf.readableBytes()]; + byte[] valueBytes = new byte[valueByteBuf.readableBytes()]; + keyByteBuf.readBytes(keyBytes); + valueByteBuf.readBytes(valueBytes); + DataInputBuffer keyInputBuffer = new DataInputBuffer(); + keyInputBuffer.reset(keyBytes, 0, keyBytes.length); + assertEquals( + genData(keyClass, index), instance.deserialize(keyInputBuffer, keyClass)); + DataInputBuffer valueInputBuffer = new DataInputBuffer(); + valueInputBuffer.reset(valueBytes, 0, valueBytes.length); + assertEquals( + genData(valueClass, index), instance.deserialize(valueInputBuffer, valueClass)); + } else { + assertEquals(genData(keyClass, index), reader.getCurrentKey()); + assertEquals(genData(valueClass, index), reader.getCurrentValue()); + } index++; } shuffleDataResult.release();