[Improve] improve avro format convert
liunaijie committed Dec 27, 2023
1 parent a9ec9d5 commit 4474b70
Showing 6 changed files with 107 additions and 79 deletions.
@@ -36,6 +36,7 @@
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.container.TestContainerId;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;
+import org.apache.seatunnel.format.avro.AvroDeserializationSchema;
import org.apache.seatunnel.format.text.TextSerializationSchema;

import org.apache.kafka.clients.consumer.ConsumerConfig;
@@ -47,6 +48,7 @@
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.serialization.ByteArrayDeserializer;
import org.apache.kafka.common.serialization.ByteArraySerializer;
import org.apache.kafka.common.serialization.StringDeserializer;

@@ -303,8 +305,16 @@ public void testFakeSourceToKafkaAvroFormat(TestContainer container)
}

@TestTemplate
-@DisabledOnContainer(TestContainerId.SPARK_2_4)
-public void testKafkaAvroToConsole(TestContainer container)
+@DisabledOnContainer(
+value = {
+TestContainerId.SPARK_2_4,
+TestContainerId.SPARK_3_3,
+TestContainerId.FLINK_1_13,
+TestContainerId.FLINK_1_14,
+TestContainerId.FLINK_1_15,
+TestContainerId.FLINK_1_16
+})
+public void testKafkaAvroToAssert(TestContainer container)
throws IOException, InterruptedException {
DefaultSeaTunnelRowSerializer serializer =
DefaultSeaTunnelRowSerializer.create(
@@ -313,8 +323,25 @@ public void testKafkaAvroToConsole(TestContainer container)
MessageFormat.AVRO,
DEFAULT_FIELD_DELIMITER);
generateTestData(row -> serializer.serializeRow(row), 0, 100);
Container.ExecResult execResult = container.executeJob("/avro/kafka_avro_to_console.conf");
Container.ExecResult execResult = container.executeJob("/avro/kafka_avro_to_assert.conf");
Assertions.assertEquals(0, execResult.getExitCode(), execResult.getStderr());

+AvroDeserializationSchema avroDeserializationSchema =
+new AvroDeserializationSchema(SEATUNNEL_ROW_TYPE);
+List<SeaTunnelRow> kafkaSTRow =
+getKafkaSTRow(
+"test_avro_topic",
+value -> {
+try {
+return avroDeserializationSchema.deserialize(value);
+} catch (IOException e) {
+throw new RuntimeException(e);
+}
+});
+Assertions.assertEquals(100, kafkaSTRow.size());
+kafkaSTRow.forEach(row -> Assertions.assertEquals("string", row.getField(3).toString()));
+kafkaSTRow.forEach(row -> Assertions.assertEquals(false, row.getField(4)));
+kafkaSTRow.forEach(row -> Assertions.assertEquals(Byte.parseByte("1"), row.getField(5)));
}

public void testKafkaLatestToConsole(TestContainer container)
@@ -373,6 +400,22 @@ private Properties kafkaConsumerConfig() {
return props;
}

+private Properties kafkaByteConsumerConfig() {
+Properties props = new Properties();
+props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafkaContainer.getBootstrapServers());
+props.put(ConsumerConfig.GROUP_ID_CONFIG, "seatunnel-kafka-sink-group");
+props.put(
+ConsumerConfig.AUTO_OFFSET_RESET_CONFIG,
+OffsetResetStrategy.EARLIEST.toString().toLowerCase());
+props.setProperty(
+ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
+ByteArrayDeserializer.class.getName());
+props.setProperty(
+ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
+ByteArrayDeserializer.class.getName());
+return props;
+}

private void generateTestData(ProducerRecordConverter converter, int start, int end) {
for (int i = start; i < end; i++) {
SeaTunnelRow row =
@@ -480,7 +523,34 @@ private List<String> getKafkaConsumerListData(String topicName) {
return data;
}

+private List<SeaTunnelRow> getKafkaSTRow(String topicName, ConsumerRecordConverter converter) {
+List<SeaTunnelRow> data = new ArrayList<>();
+try (KafkaConsumer<byte[], byte[]> consumer =
+new KafkaConsumer<>(kafkaByteConsumerConfig())) {
+consumer.subscribe(Arrays.asList(topicName));
+Map<TopicPartition, Long> offsets =
+consumer.endOffsets(Arrays.asList(new TopicPartition(topicName, 0)));
+Long endOffset = offsets.entrySet().iterator().next().getValue();
+Long lastProcessedOffset = -1L;
+
+do {
+ConsumerRecords<byte[], byte[]> records = consumer.poll(Duration.ofMillis(100));
+for (ConsumerRecord<byte[], byte[]> record : records) {
+if (lastProcessedOffset < record.offset()) {
+data.add(converter.convert(record.value()));
+}
+lastProcessedOffset = record.offset();
+}
+} while (lastProcessedOffset < endOffset - 1);
+}
+return data;
+}

interface ProducerRecordConverter {
ProducerRecord<byte[], byte[]> convert(SeaTunnelRow row);
}

+interface ConsumerRecordConverter {
+SeaTunnelRow convert(byte[] value);
+}
}
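The new getKafkaSTRow helper above subscribes with the byte-array consumer config and polls until it has seen every offset up to the partition's end offset. A minimal standalone sketch of the same read-to-end pattern, shown here with assign()/seekToBeginning() instead of subscribe() and a running next-offset cursor; the class and method names are illustrative, not part of this commit:

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.TopicPartition;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.function.Function;

public class ReadToEndSketch {

    // Drains a single-partition topic from the beginning up to the end offset
    // captured at call time, converting each record value along the way.
    public static <T> List<T> readToEnd(
            Properties byteArrayConsumerProps, String topic, Function<byte[], T> converter) {
        List<T> result = new ArrayList<>();
        try (KafkaConsumer<byte[], byte[]> consumer =
                new KafkaConsumer<>(byteArrayConsumerProps)) {
            TopicPartition tp = new TopicPartition(topic, 0);
            consumer.assign(Collections.singletonList(tp));
            consumer.seekToBeginning(Collections.singletonList(tp));
            long endOffset = consumer.endOffsets(Collections.singletonList(tp)).get(tp);
            long nextExpected = 0L;
            while (nextExpected < endOffset) {
                for (ConsumerRecord<byte[], byte[]> record :
                        consumer.poll(Duration.ofMillis(100))) {
                    result.add(converter.apply(record.value()));
                    nextExpected = record.offset() + 1;
                }
            }
        }
        return result;
    }
}

Compared with the subscribe() variant in the diff, assigning the partition explicitly skips consumer-group rebalancing, which keeps the drain deterministic in tests.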
---- next changed file ----
@@ -82,7 +82,21 @@ sink {
rule_value = 99
}
]
-}
+},
+{
+field_name = c_string
+field_type = string
+field_value = [
+{
+rule_type = MIN_LENGTH
+rule_value = 6
+},
+{
+rule_type = MAX_LENGTH
+rule_value = 6
+}
+]
+}
]
}
}
---- next changed file ----
@@ -52,14 +52,15 @@ public AvroSerializationSchema(SeaTunnelRowType rowType) {
public byte[] serialize(SeaTunnelRow element) {
GenericRecord record = converter.convertRowToGenericRecord(element);
try {
+out.reset();
writer.write(record, encoder);
encoder.flush();
return out.toByteArray();
} catch (IOException e) {
throw new SeaTunnelAvroFormatException(
AvroFormatErrorCode.SERIALIZATION_ERROR,
"Serialization error on record : " + element);
} finally {
out.reset();
}
}
}
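With a single AvroSerializationSchema instance, the ByteArrayOutputStream and encoder are reused for every row, so the added out.reset() at the top of the try block guarantees that bytes left behind by a previous failed write are discarded before the next record is encoded. A minimal sketch of that reuse pattern, assuming one schema instance serializing many rows on a single thread (the class name is illustrative):

import org.apache.avro.Schema;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;

final class ReusableAvroWriterSketch {

    private final ByteArrayOutputStream out = new ByteArrayOutputStream();
    private final BinaryEncoder encoder = EncoderFactory.get().directBinaryEncoder(out, null);
    private final GenericDatumWriter<GenericRecord> writer;

    ReusableAvroWriterSketch(Schema schema) {
        this.writer = new GenericDatumWriter<>(schema);
    }

    byte[] write(GenericRecord record) throws IOException {
        out.reset(); // drop any bytes a previously failed write left in the buffer
        writer.write(record, encoder);
        encoder.flush();
        return out.toByteArray();
    }
}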
---- next changed file ----
@@ -33,6 +33,7 @@
import org.apache.avro.io.DatumReader;

import java.io.Serializable;
+import java.lang.reflect.Array;
import java.util.List;

public class AvroToRowConverter implements Serializable {
@@ -72,19 +73,16 @@ public SeaTunnelRow converter(GenericRecord record, SeaTunnelRowType rowType) {
values[i] = null;
continue;
}
-values[i] =
-convertField(
-rowType.getFieldType(i),
-record.getSchema().getField(fieldNames[i]),
-record.get(fieldNames[i]));
+values[i] = convertField(rowType.getFieldType(i), record.get(fieldNames[i]));
}
return new SeaTunnelRow(values);
}

-private Object convertField(SeaTunnelDataType<?> dataType, Schema.Field field, Object val) {
+private Object convertField(SeaTunnelDataType<?> dataType, Object val) {
switch (dataType.getSqlType()) {
-case MAP:
case STRING:
return val.toString();
+case MAP:
case BOOLEAN:
case SMALLINT:
case INT:
@@ -121,67 +119,15 @@ private Object convertField(SeaTunnelDataType<?> dataType, Schema.Field field, Object val) {
}
}

-protected static Object convertArray(List<Object> val, SeaTunnelDataType<?> dataType) {
+protected Object convertArray(List<Object> val, SeaTunnelDataType<?> dataType) {
if (val == null) {
return null;
}
int length = val.size();
-switch (dataType.getSqlType()) {
-case STRING:
-String[] strings = new String[length];
-for (int i = 0; i < strings.length; i++) {
-strings[i] = val.get(i).toString();
-}
-return strings;
-case BOOLEAN:
-Boolean[] booleans = new Boolean[length];
-for (int i = 0; i < booleans.length; i++) {
-booleans[i] = (Boolean) val.get(i);
-}
-return booleans;
-case BYTES:
-Byte[] bytes = new Byte[length];
-for (int i = 0; i < bytes.length; i++) {
-bytes[i] = (Byte) val.get(i);
-}
-return bytes;
-case SMALLINT:
-Short[] shorts = new Short[length];
-for (int i = 0; i < shorts.length; i++) {
-shorts[i] = (Short) val.get(i);
-}
-return shorts;
-case INT:
-Integer[] integers = new Integer[length];
-for (int i = 0; i < integers.length; i++) {
-integers[i] = (Integer) val.get(i);
-}
-return integers;
-case BIGINT:
-Long[] longs = new Long[length];
-for (int i = 0; i < longs.length; i++) {
-longs[i] = (Long) val.get(i);
-}
-return longs;
-case FLOAT:
-Float[] floats = new Float[length];
-for (int i = 0; i < floats.length; i++) {
-floats[i] = (Float) val.get(i);
-}
-return floats;
-case DOUBLE:
-Double[] doubles = new Double[length];
-for (int i = 0; i < doubles.length; i++) {
-doubles[i] = (Double) val.get(i);
-}
-return doubles;
-default:
-String errorMsg =
-String.format(
-"SeaTunnel avro array format is not supported for this data type [%s]",
-dataType.getSqlType());
-throw new SeaTunnelAvroFormatException(
-AvroFormatErrorCode.UNSUPPORTED_DATA_TYPE, errorMsg);
+Object instance = Array.newInstance(dataType.getTypeClass(), length);
+for (int i = 0; i < val.size(); i++) {
+Array.set(instance, i, convertField(dataType, val.get(i)));
+}
+return instance;
}
}
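The rewritten convertArray collapses eight near-identical branches into one reflective loop: java.lang.reflect.Array allocates an array of the element class and fills it element by element, with each value still routed through convertField. A standalone sketch of the technique, with illustrative names:

import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.List;

public class ReflectiveArraySketch {

    // Builds a typed array (e.g. String[], Integer[]) from a list without a per-type switch.
    static Object toTypedArray(List<?> values, Class<?> elementClass) {
        Object array = Array.newInstance(elementClass, values.size());
        for (int i = 0; i < values.size(); i++) {
            Array.set(array, i, values.get(i)); // a converter would map each element here
        }
        return array;
    }

    public static void main(String[] args) {
        String[] strings = (String[]) toTypedArray(Arrays.asList("a", "b"), String.class);
        Integer[] ints = (Integer[]) toTypedArray(Arrays.asList(1, 2, 3), Integer.class);
        System.out.println(Arrays.toString(strings) + " " + Arrays.toString(ints));
    }
}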
---- next changed file ----
@@ -35,9 +35,9 @@
import org.apache.avro.io.DatumWriter;

import java.io.Serializable;
+import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.util.ArrayList;
-import java.util.List;

public class RowToAvroConverter implements Serializable {

@@ -111,14 +111,11 @@ private Object resolveObject(Object data, SeaTunnelDataType<?> seaTunnelDataType) {
case BYTES:
return ByteBuffer.wrap((byte[]) data);
case ARRAY:
-// BasicType<?> basicType = ((ArrayType<?, ?>)
-// seaTunnelDataType).getElementType();
-// return Util.convertArray((Object[]) data, basicType);
BasicType<?> basicType = ((ArrayType<?, ?>) seaTunnelDataType).getElementType();
-List<Object> records = new ArrayList<>(((Object[]) data).length);
-for (Object object : (Object[]) data) {
-Object resolvedObject = resolveObject(object, basicType);
-records.add(resolvedObject);
+int length = Array.getLength(data);
+ArrayList<Object> records = new ArrayList<>(length);
+for (int i = 0; i < length; i++) {
+records.add(resolveObject(Array.get(data, i), basicType));
}
return records;
case ROW:
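On the serialization side, replacing the (Object[]) cast with Array.getLength/Array.get means resolveObject can now walk primitive arrays as well as boxed ones: an int[] is not an Object[], so the old cast would throw ClassCastException for primitive element types. A small demonstration (class name illustrative):

import java.lang.reflect.Array;

public class PrimitiveArraySketch {

    public static void main(String[] args) {
        Object boxed = new Integer[] {1, 2, 3};
        Object primitive = new int[] {4, 5, 6};
        for (Object data : new Object[] {boxed, primitive}) {
            int length = Array.getLength(data); // works for both array kinds
            for (int i = 0; i < length; i++) {
                System.out.print(Array.get(data, i) + " "); // primitives come back boxed
            }
        }
        // prints: 1 2 3 4 5 6
    }
}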
---- next changed file ----
@@ -161,9 +161,9 @@ public void testSerialization() throws IOException {
SeaTunnelRowType rowType = buildSeaTunnelRowType();
SeaTunnelRow seaTunnelRow = buildSeaTunnelRow();
AvroSerializationSchema serializationSchema = new AvroSerializationSchema(rowType);
-byte[] serialize = serializationSchema.serialize(seaTunnelRow);
+byte[] bytes = serializationSchema.serialize(seaTunnelRow);
AvroDeserializationSchema deserializationSchema = new AvroDeserializationSchema(rowType);
-SeaTunnelRow deserialize = deserializationSchema.deserialize(serialize);
+SeaTunnelRow deserialize = deserializationSchema.deserialize(bytes);
String[] strArray1 = (String[]) seaTunnelRow.getField(1);
String[] strArray2 = (String[]) deserialize.getField(1);
Assertions.assertArrayEquals(strArray1, strArray2);
