[#2173] feat(remote merge): support netty for remote merge.
zhengchenyu committed Oct 21, 2024
1 parent 43323bb commit 3a68303
Showing 67 changed files with 2,628 additions and 1,390 deletions.
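In outline, judging from the hunks shown below: the change threads a configurable client type (gRPC vs. Netty) from the MR and Tez configs into RMRecordsReader, and moves the remote-merge serialization path from heap-based PartialInputStream/ByteArrayOutputStream onto Netty ByteBufs via SerInputStream/SerOutputStream; the tests are updated to init() readers, segments, and streams explicitly, and to release() the reference-counted buffers.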
@@ -121,6 +121,8 @@ public void init(Context<K, V> context) {
}
Map<Integer, List<ShuffleServerInfo>> serverInfoMap = new HashMap<>();
serverInfoMap.put(partitionId, new ArrayList<>(serverInfoSet));
String clientType =
rssJobConf.get(RssMRConfig.RSS_CLIENT_TYPE, RssMRConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
this.reader =
new RMRecordsReader(
appId,
@@ -134,7 +136,8 @@ public void init(Context<K, V> context) {
true,
combiner,
combiner != null,
new MRMetricsReporter(context.getReporter()));
new MRMetricsReporter(context.getReporter()),
clientType);
}

@Override
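The hunk above wires the remote-merge reader to the configured transport. A minimal sketch of the new wiring, using only names visible in this diff (the reader's other constructor arguments are elided in the hunk and assumed unchanged):

```java
// Resolve the RPC type from the job conf; with this commit the remote-merge
// reader can use a Netty-based client instead of only the gRPC default.
String clientType =
    rssJobConf.get(RssMRConfig.RSS_CLIENT_TYPE, RssMRConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);

// clientType is now the trailing constructor argument of RMRecordsReader:
//   new RMRecordsReader(appId, ..., new MRMetricsReporter(context.getReporter()), clientType);
```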
@@ -57,7 +57,7 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.records.RecordsReader;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
import org.apache.uniffle.common.serializer.SerInputStream;
import org.apache.uniffle.common.serializer.SerializerFactory;
import org.apache.uniffle.common.serializer.SerializerInstance;
import org.apache.uniffle.common.serializer.SerializerUtils;
@@ -522,11 +522,8 @@ public void testWriteNormalWithRemoteMerge() throws Exception {
ByteBuf byteBuf = blockInfos.get(0).getData();
RecordsReader<Text, Text> reader =
new RecordsReader<>(
rssConf,
PartialInputStreamImpl.newInputStream(byteBuf.nioBuffer()),
Text.class,
Text.class,
false);
rssConf, SerInputStream.newInputStream(byteBuf), Text.class, Text.class, false, false);
reader.init();
int index = 0;
while (reader.next()) {
assertEquals(SerializerUtils.genData(Text.class, index), reader.getCurrentKey());
@@ -609,10 +606,12 @@ public void testWriteNormalWithRemoteMergeAndCombine() throws Exception {
RecordsReader<Text, IntWritable> reader =
new RecordsReader<>(
rssConf,
PartialInputStreamImpl.newInputStream(byteBuf.nioBuffer()),
SerInputStream.newInputStream(byteBuf),
Text.class,
IntWritable.class,
false,
false);
reader.init();
int index = 0;
while (reader.next()) {
int aimValue = index;
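Both hunks above follow the same new read pattern: RecordsReader wraps a Netty ByteBuf through SerInputStream, takes an extra boolean flag, and needs an explicit init() before iteration. A condensed sketch under those assumptions (genSortedRecordBuffer with start 0 and interval 1 is assumed to yield keys genData(Text.class, i) in order, as these tests rely on):

```java
// Build a buffer of RECORDS_NUM sorted, serialized Text/Text records.
ByteBuf byteBuf = genSortedRecordBuffer(rssConf, Text.class, Text.class, 0, 1, RECORDS_NUM, 1);

RecordsReader<Text, Text> reader =
    new RecordsReader<>(
        rssConf, SerInputStream.newInputStream(byteBuf), Text.class, Text.class, false, false);
reader.init(); // required before next()/getCurrentKey()

int index = 0;
while (reader.next()) {
  assertEquals(SerializerUtils.genData(Text.class, index), reader.getCurrentKey());
  index++;
}
assertEquals(RECORDS_NUM, index);
byteBuf.release(); // the buffer is reference-counted
```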
@@ -17,8 +17,6 @@

package org.apache.hadoop.mapreduce.task.reduce;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
@@ -29,6 +27,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import io.netty.buffer.ByteBuf;
import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.RawComparator;
@@ -58,13 +57,15 @@
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.merger.Merger;
import org.apache.uniffle.common.merger.Segment;
import org.apache.uniffle.common.serializer.DynBufferSerOutputStream;
import org.apache.uniffle.common.serializer.SerOutputStream;
import org.apache.uniffle.common.serializer.Serializer;
import org.apache.uniffle.common.serializer.SerializerFactory;
import org.apache.uniffle.common.serializer.SerializerInstance;
import org.apache.uniffle.common.serializer.SerializerUtils;

import static org.apache.uniffle.common.serializer.SerializerUtils.genData;
import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
@@ -136,12 +137,10 @@ public void testReadShuffleWithoutCombine() throws Exception {
combiner,
false,
null);
ByteBuffer byteBuffer =
ByteBuffer.wrap(
genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1));
ByteBuf byteBuf = genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1);
ShuffleServerClient serverClient =
new MockedShuffleServerClient(
new int[] {PARTITION_ID}, new ByteBuffer[][] {{byteBuffer}}, blockIds);
new int[] {PARTITION_ID}, new ByteBuf[][] {{byteBuf}}, blockIds);
RMRecordsReader readerSpy = spy(reader);
doReturn(serverClient).when(readerSpy).createShuffleServerClient(any());

@@ -170,6 +169,7 @@ public void testReadShuffleWithoutCombine() throws Exception {
index++;
}
assertEquals(RECORDS_NUM, index);
byteBuf.release();
}
}

@@ -219,20 +219,21 @@ public void testReadShuffleWithCombine() throws Exception {
List<Segment> segments = new ArrayList<>();
segments.add(
SerializerUtils.genMemorySegment(
rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM, true));
rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM, true, false));
segments.add(
SerializerUtils.genMemorySegment(
rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM, true));
rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM, true, false));
segments.add(
SerializerUtils.genMemorySegment(
rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM, true));
ByteArrayOutputStream output = new ByteArrayOutputStream();
rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM, true, false));
segments.forEach(segment -> segment.init());
SerOutputStream output = new DynBufferSerOutputStream();
Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, true);
output.close();
ByteBuffer byteBuffer = ByteBuffer.wrap(output.toByteArray());
ByteBuf byteBuf = output.toByteBuf();
ShuffleServerClient serverClient =
new MockedShuffleServerClient(
new int[] {PARTITION_ID}, new ByteBuffer[][] {{byteBuffer}}, blockIds);
new int[] {PARTITION_ID}, new ByteBuf[][] {{byteBuf}}, blockIds);
RMRecordsReader reader =
new RMRecordsReader(
APP_ID,
@@ -280,6 +281,7 @@ public void testReadShuffleWithCombine() throws Exception {
index++;
}
assertEquals(RECORDS_NUM * 2, index);
byteBuf.release();
}
}
}
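The combine test above also shows the new merge plumbing end to end: segments are initialized up front, Merger writes into a SerOutputStream, and the merged bytes come back as a reference-counted ByteBuf. A condensed sketch with the class names from this diff (the segment parameters and comparator are placeholders taken from the test):

```java
List<Segment> segments = new ArrayList<>();
segments.add(SerializerUtils.genMemorySegment(
    rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM, true, false));
segments.add(SerializerUtils.genMemorySegment(
    rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM, true, false));
segments.forEach(segment -> segment.init()); // segments must now be initialized before merging

SerOutputStream output = new DynBufferSerOutputStream(); // ByteBuf-backed, grows on demand
Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, true);
output.close();

ByteBuf byteBuf = output.toByteBuf(); // replaces ByteArrayOutputStream.toByteArray()
try {
  // hand byteBuf to the MockedShuffleServerClient / reader under test
} finally {
  byteBuf.release(); // reference-counted: release to avoid a Netty buffer leak
}
```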
@@ -73,6 +73,7 @@ public class RMRssShuffle implements ExceptionReporter {
private ShuffleInputEventHandlerOrderedGrouped eventHandler;
private final TezTaskAttemptID tezTaskAttemptID;
private final String srcNameTrimmed;
private final String clientType;
private Map<Integer, List<ShuffleServerInfo>> partitionToServers;

private AtomicBoolean isShutDown = new AtomicBoolean(false);
@@ -101,6 +102,8 @@ public RMRssShuffle(
this.numInputs = numInputs;
this.shuffleId = shuffleId;
this.applicationAttemptId = applicationAttemptId;
this.clientType =
conf.get(RssTezConfig.RSS_CLIENT_TYPE, RssTezConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
this.appId = this.applicationAttemptId.toString();
this.srcNameTrimmed = TezUtilsInternal.cleanVertexName(inputContext.getSourceVertexName());
LOG.info(srcNameTrimmed + ": Shuffle assigned with " + numInputs + " inputs.");
@@ -254,7 +257,8 @@ public RMRecordsReader createRMRecordsReader(Set partitionIds) {
false,
(inc) -> {
inputRecordCounter.increment(inc);
});
},
this.clientType);
}

@Override
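The Tez side mirrors the MR change: RMRssShuffle resolves the client type once in its constructor and passes it to every reader it creates. A sketch using the constants from this hunk:

```java
// In the RMRssShuffle constructor:
this.clientType =
    conf.get(RssTezConfig.RSS_CLIENT_TYPE, RssTezConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);

// ...and createRMRecordsReader() forwards it as the reader's final argument,
// right after the record-counter callback:
//   (inc) -> { inputRecordCounter.increment(inc); },
//   this.clientType);
```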
@@ -20,6 +20,7 @@
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
@@ -28,6 +29,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import io.netty.buffer.ByteBuf;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IntWritable;
@@ -74,7 +76,7 @@
import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS;
import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS;
import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS;
import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
@@ -158,12 +160,9 @@ public void testReadShuffleData() throws Exception {
false,
null);
RMRecordsReader recordsReaderSpy = spy(recordsReader);
ByteBuffer[][] buffers =
new ByteBuffer[][] {
{
ByteBuffer.wrap(
genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, duplicated))
}
ByteBuf[][] buffers =
new ByteBuf[][] {
{genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, duplicated)}
};
ShuffleServerClient serverClient =
new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds);
@@ -228,6 +227,7 @@ public void testReadShuffleData() throws Exception {
index++;
}
assertEquals(RECORDS_NUM, index);
Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release()));
}

@Test
@@ -309,15 +309,13 @@ public void testReadMultiPartitionShuffleData() throws Exception {
false,
null);
RMRecordsReader recordsReaderSpy = spy(recordsReader);
ByteBuffer[][] buffers = new ByteBuffer[3][2];
ByteBuf[][] buffers = new ByteBuf[3][2];
for (int i = 0; i < 3; i++) {
buffers[i][0] =
ByteBuffer.wrap(
genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, duplicated));
genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, duplicated);
buffers[i][1] =
ByteBuffer.wrap(
genSortedRecordBytes(
rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, duplicated));
genSortedRecordBuffer(
rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, duplicated);
}
ShuffleServerClient serverClient =
new MockedShuffleServerClient(
@@ -396,6 +394,7 @@ public void testReadMultiPartitionShuffleData() throws Exception {
index++;
}
assertEquals(RECORDS_NUM * 6, index);
Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release()));
}

public static DataMovementEvent createDataMovementEvent(int partition, String path)
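A recurring pattern in these test hunks: mocked server payloads are now ByteBuf[][] rather than ByteBuffer[][], and every buffer must be released once the reader has been drained. A sketch of the setup/cleanup idiom, with the multi-partition shape used above:

```java
ByteBuf[][] buffers = new ByteBuf[3][2];
for (int i = 0; i < 3; i++) {
  buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1);
  buffers[i][1] =
      genSortedRecordBuffer(
          rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1);
}

// ... run the RMRecordsReader against a MockedShuffleServerClient(partitionIds, buffers, blockIds) ...

// Release every reference-counted buffer to keep Netty's leak detector quiet.
Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release()));
```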
@@ -19,7 +19,6 @@

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -73,7 +72,7 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.records.RecordsReader;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
import org.apache.uniffle.common.serializer.SerInputStream;
import org.apache.uniffle.common.serializer.SerializerFactory;
import org.apache.uniffle.common.serializer.SerializerInstance;
import org.apache.uniffle.common.serializer.SerializerUtils;
@@ -616,11 +615,8 @@ public void testWriteWithRemoteMerge() throws Exception {
buf.readBytes(bytes);
RecordsReader<Text, Text> reader =
new RecordsReader<>(
rssConf,
PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes)),
Text.class,
Text.class,
false);
rssConf, SerInputStream.newInputStream(buf), Text.class, Text.class, false, false);
reader.init();
int index = 0;
while (reader.next()) {
assertEquals(SerializerUtils.genData(Text.class, index), reader.getCurrentKey());
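The hunk above removes an intermediate heap copy: the old code drained the ByteBuf into a byte[] and wrapped it in a ByteBuffer-backed stream, while the new code streams over the ByteBuf directly. A before/after sketch (the "before" allocation is reconstructed from the removed lines' context):

```java
// Before: copy the Netty buffer onto the heap, then wrap the copy.
byte[] bytes = new byte[buf.readableBytes()];
buf.readBytes(bytes);
// ... PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes)) ...

// After: read straight from the ByteBuf, no copy.
RecordsReader<Text, Text> reader =
    new RecordsReader<>(
        rssConf, SerInputStream.newInputStream(buf), Text.class, Text.class, false, false);
reader.init();
```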
@@ -27,6 +27,7 @@
import java.util.Map;
import java.util.Random;

import io.netty.buffer.Unpooled;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.IntWritable;
@@ -41,8 +42,7 @@

import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.serializer.DeserializationStream;
import org.apache.uniffle.common.serializer.PartialInputStream;
import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
import org.apache.uniffle.common.serializer.SerInputStream;
import org.apache.uniffle.common.serializer.SerializerFactory;
import org.apache.uniffle.common.serializer.SerializerInstance;

@@ -203,9 +203,11 @@ public void testReadWriteWithRemoteMergeAndNoSort() throws IOException {
buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i));
}
byte[] bytes = buffer.getData();
PartialInputStream inputStream = PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes));
SerInputStream inputStream =
SerInputStream.newInputStream(Unpooled.wrappedBuffer(ByteBuffer.wrap(bytes)));
DeserializationStream dStream =
instance.deserializeStream(inputStream, Text.class, IntWritable.class, false);
instance.deserializeStream(inputStream, Text.class, IntWritable.class, false, false);
dStream.init();
for (int i = 0; i < RECORDS_NUM; i++) {
assertTrue(dStream.nextRecord());
assertEquals(genData(Text.class, i), dStream.getCurrentKey());
@@ -240,9 +242,11 @@ public void testReadWriteWithRemoteMergeAndSort() throws IOException {
buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i));
}
byte[] bytes = buffer.getData();
PartialInputStream inputStream = PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes));
SerInputStream inputStream =
SerInputStream.newInputStream(Unpooled.wrappedBuffer(ByteBuffer.wrap(bytes)));
DeserializationStream dStream =
instance.deserializeStream(inputStream, Text.class, IntWritable.class, false);
instance.deserializeStream(inputStream, Text.class, IntWritable.class, false, false);
dStream.init();
for (int i = 0; i < RECORDS_NUM; i++) {
assertTrue(dStream.nextRecord());
assertEquals(genData(Text.class, i), dStream.getCurrentKey());
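Both serializer tests above converge on the same entry points: a SerInputStream built from a ByteBuf (here via Unpooled.wrappedBuffer over heap bytes), a deserializeStream overload with an extra boolean, and an explicit init(). A minimal sketch (getCurrentValue is assumed symmetric to the getCurrentKey call visible in the hunk):

```java
byte[] bytes = buffer.getData(); // serialized records produced by the writer under test
SerInputStream inputStream =
    SerInputStream.newInputStream(Unpooled.wrappedBuffer(ByteBuffer.wrap(bytes)));

DeserializationStream dStream =
    instance.deserializeStream(inputStream, Text.class, IntWritable.class, false, false);
dStream.init(); // streams must be initialized before nextRecord()

for (int i = 0; i < RECORDS_NUM; i++) {
  assertTrue(dStream.nextRecord());
  assertEquals(genData(Text.class, i), dStream.getCurrentKey());
  assertEquals(genData(IntWritable.class, i), dStream.getCurrentValue());
}
```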
@@ -17,7 +17,6 @@

package org.apache.tez.runtime.library.input;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Comparator;
@@ -29,6 +28,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import io.netty.buffer.ByteBuf;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IntWritable;
@@ -72,12 +72,14 @@
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.merger.Merger;
import org.apache.uniffle.common.merger.Segment;
import org.apache.uniffle.common.serializer.DynBufferSerOutputStream;
import org.apache.uniffle.common.serializer.SerOutputStream;
import org.apache.uniffle.common.serializer.SerializerUtils;
import org.apache.uniffle.common.util.BlockIdLayout;

import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
@@ -166,10 +168,11 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM));
segments.add(
SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM));
ByteArrayOutputStream output = new ByteArrayOutputStream();
segments.forEach(segment -> segment.init());
SerOutputStream output = new DynBufferSerOutputStream();
Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, false);
output.close();
ByteBuffer[][] buffers = new ByteBuffer[][] {{ByteBuffer.wrap(output.toByteArray())}};
ByteBuf[][] buffers = new ByteBuf[][] {{output.toByteBuf()}};
ShuffleServerClient serverClient =
new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds);
RMRecordsReader recordsReader =
@@ -362,15 +365,12 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
false,
null);
RMRecordsReader recordsReaderSpy = spy(recordsReader);
ByteBuffer[][] buffers = new ByteBuffer[3][2];
ByteBuf[][] buffers = new ByteBuf[3][2];
for (int i = 0; i < 3; i++) {
buffers[i][0] =
ByteBuffer.wrap(
genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1));
buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1);
buffers[i][1] =
ByteBuffer.wrap(
genSortedRecordBytes(
rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1));
genSortedRecordBuffer(
rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1);
}
ShuffleServerClient serverClient =
new MockedShuffleServerClient(