Skip to content

Commit

Permalink
[Improve] improve avro format convert
Browse files (browse the repository at this point in the history)
  • Loading branch information
liunaijie committed Dec 25, 2023
1 parent 93ebc39 commit 9b50e15
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ public void testFakeSourceToKafkaAvroFormat(TestContainer container)
Container.ExecResult execResult =
container.executeJob("/avro/fake_source_to_kafka_avro_format.conf");
Assertions.assertEquals(0, execResult.getExitCode(), execResult.getStderr());
List<String> dataList = getKafkaConsumerListData("test_avro_topic");
dataList.forEach(System.out::println);
}

@TestTemplate
Expand All @@ -313,7 +315,7 @@ public void testKafkaAvroToConsole(TestContainer container)
MessageFormat.AVRO,
DEFAULT_FIELD_DELIMITER);
generateTestData(row -> serializer.serializeRow(row), 0, 100);
Container.ExecResult execResult = container.executeJob("/avro/kafka_avro_to_console.conf");
Container.ExecResult execResult = container.executeJob("/avro/kafka_avro_to_assert.conf");
Assertions.assertEquals(0, execResult.getExitCode(), execResult.getStderr());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.avro.io.DatumReader;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.List;

public class AvroToRowConverter implements Serializable {
Expand Down Expand Up @@ -72,19 +73,16 @@ public SeaTunnelRow converter(GenericRecord record, SeaTunnelRowType rowType) {
values[i] = null;
continue;
}
values[i] =
convertField(
rowType.getFieldType(i),
record.getSchema().getField(fieldNames[i]),
record.get(fieldNames[i]));
values[i] = convertField(rowType.getFieldType(i), record.get(fieldNames[i]));
}
return new SeaTunnelRow(values);
}

private Object convertField(SeaTunnelDataType<?> dataType, Schema.Field field, Object val) {
private Object convertField(SeaTunnelDataType<?> dataType, Object val) {
switch (dataType.getSqlType()) {
case MAP:
case STRING:
return val.toString();
case MAP:
case BOOLEAN:
case SMALLINT:
case INT:
Expand Down Expand Up @@ -121,67 +119,15 @@ private Object convertField(SeaTunnelDataType<?> dataType, Schema.Field field, O
}
}

protected static Object convertArray(List<Object> val, SeaTunnelDataType<?> dataType) {
protected Object convertArray(List<Object> val, SeaTunnelDataType<?> dataType) {
if (val == null) {
return null;
}
int length = val.size();
switch (dataType.getSqlType()) {
case STRING:
String[] strings = new String[length];
for (int i = 0; i < strings.length; i++) {
strings[i] = val.get(i).toString();
}
return strings;
case BOOLEAN:
Boolean[] booleans = new Boolean[length];
for (int i = 0; i < booleans.length; i++) {
booleans[i] = (Boolean) val.get(i);
}
return booleans;
case BYTES:
Byte[] bytes = new Byte[length];
for (int i = 0; i < bytes.length; i++) {
bytes[i] = (Byte) val.get(i);
}
return bytes;
case SMALLINT:
Short[] shorts = new Short[length];
for (int i = 0; i < shorts.length; i++) {
shorts[i] = (Short) val.get(i);
}
return shorts;
case INT:
Integer[] integers = new Integer[length];
for (int i = 0; i < integers.length; i++) {
integers[i] = (Integer) val.get(i);
}
return integers;
case BIGINT:
Long[] longs = new Long[length];
for (int i = 0; i < longs.length; i++) {
longs[i] = (Long) val.get(i);
}
return longs;
case FLOAT:
Float[] floats = new Float[length];
for (int i = 0; i < floats.length; i++) {
floats[i] = (Float) val.get(i);
}
return floats;
case DOUBLE:
Double[] doubles = new Double[length];
for (int i = 0; i < doubles.length; i++) {
doubles[i] = (Double) val.get(i);
}
return doubles;
default:
String errorMsg =
String.format(
"SeaTunnel avro array format is not supported for this data type [%s]",
dataType.getSqlType());
throw new SeaTunnelAvroFormatException(
AvroFormatErrorCode.UNSUPPORTED_DATA_TYPE, errorMsg);
Object instance = Array.newInstance(dataType.getTypeClass(), length);
for (int i = 0; i < val.size(); i++) {
Array.set(instance, i, convertField(dataType, val.get(i)));
}
return instance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import org.apache.avro.io.DatumWriter;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

public class RowToAvroConverter implements Serializable {

Expand Down Expand Up @@ -111,14 +111,11 @@ private Object resolveObject(Object data, SeaTunnelDataType<?> seaTunnelDataType
case BYTES:
return ByteBuffer.wrap((byte[]) data);
case ARRAY:
// BasicType<?> basicType = ((ArrayType<?, ?>)
// seaTunnelDataType).getElementType();
// return Util.convertArray((Object[]) data, basicType);
BasicType<?> basicType = ((ArrayType<?, ?>) seaTunnelDataType).getElementType();
List<Object> records = new ArrayList<>(((Object[]) data).length);
for (Object object : (Object[]) data) {
Object resolvedObject = resolveObject(object, basicType);
records.add(resolvedObject);
int length = Array.getLength(data);
ArrayList<Object> records = new ArrayList<>(length);
for (int i = 0; i < length; i++) {
records.add(resolveObject(Array.get(data, i), basicType));
}
return records;
case ROW:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ public void testSerialization() throws IOException {
SeaTunnelRowType rowType = buildSeaTunnelRowType();
SeaTunnelRow seaTunnelRow = buildSeaTunnelRow();
AvroSerializationSchema serializationSchema = new AvroSerializationSchema(rowType);
byte[] serialize = serializationSchema.serialize(seaTunnelRow);
byte[] bytes = serializationSchema.serialize(seaTunnelRow);
AvroDeserializationSchema deserializationSchema = new AvroDeserializationSchema(rowType);
SeaTunnelRow deserialize = deserializationSchema.deserialize(serialize);
SeaTunnelRow deserialize = deserializationSchema.deserialize(bytes);
String[] strArray1 = (String[]) seaTunnelRow.getField(1);
String[] strArray2 = (String[]) deserialize.getField(1);
Assertions.assertArrayEquals(strArray1, strArray2);
Expand Down

0 comments on commit 9b50e15

Please sign in to comment.