Skip to content

Commit

Permalink
fix: address reviews v2
Browse files: browse the repository at this point in the history
  • Loading branch information
vibhatha committed Feb 1, 2024
1 parent 907195a commit 8f482da
Showing 1 changed file with 88 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
Expand All @@ -46,6 +48,7 @@
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.After;
import org.junit.Assert;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -96,6 +99,24 @@ private void createAndWriteArrowFile(DictionaryProvider provider,
}
}

/**
 * Populates {@code root} with a single non-nullable UTF-8 column named "col" holding ten rows of
 * generated sample data, then writes it as one record batch in the Arrow IPC stream format to
 * {@code out}.
 *
 * @param provider dictionary provider handed to the stream writer
 * @param codecType compression codec the writer is configured with
 * @throws IOException declared for failures while writing the stream
 */
private void createAndWriteArrowStream(DictionaryProvider provider,
    CompressionUtil.CodecType codecType) throws IOException {
  final int numRows = 10;

  // Schema: exactly one required string column with no children.
  final Field column = new Field("col", FieldType.notNullable(new ArrowType.Utf8()), new ArrayList<>());
  final List<Field> schemaFields = new ArrayList<>();
  schemaFields.add(column);

  root = VectorSchemaRoot.create(new Schema(schemaFields), allocator);
  GenerateSampleData.generateTestData(root.getVector(0), numRows);
  root.setRowCount(numRows);

  // NOTE(review): Optional.of(7) is presumably the compression level -- confirm against the
  // ArrowStreamWriter constructor signature.
  try (final ArrowStreamWriter streamWriter = new ArrowStreamWriter(
      root,
      provider,
      Channels.newChannel(out),
      IpcOption.DEFAULT,
      CommonsCompressionFactory.INSTANCE,
      codecType,
      Optional.of(7))) {
    streamWriter.start();
    streamWriter.writeBatch();
    streamWriter.end();
  }
}

private Dictionary createDictionary(VarCharVector dictionaryVector) {
setVector(dictionaryVector,
"foo".getBytes(StandardCharsets.UTF_8),
Expand All @@ -113,29 +134,51 @@ public void testArrowFileZstdRoundTrip() throws Exception {
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
CommonsCompressionFactory.INSTANCE)) {
Assert.assertEquals(1, reader.getRecordBlocks().size());
Assert.assertTrue(reader.loadNextBatch());
Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assert.assertFalse(reader.loadNextBatch());
Assertions.assertEquals(1, reader.getRecordBlocks().size());
Assertions.assertTrue(reader.loadNextBatch());
Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assertions.assertFalse(reader.loadNextBatch());
}
// without compression
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Assert.assertEquals(1, reader.getRecordBlocks().size());
Assertions.assertEquals(1, reader.getRecordBlocks().size());
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
reader::loadNextBatch);
Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage());
}
}

@Test
public void testArrowStreamZstdRoundTrip() throws Exception {
createAndWriteArrowStream(null, CompressionUtil.CodecType.ZSTD);
// with compression
try (ArrowStreamReader reader =
new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
CommonsCompressionFactory.INSTANCE)) {
Assert.assertTrue(reader.loadNextBatch());
Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assert.assertFalse(reader.loadNextBatch());
}
// without compression
try (ArrowStreamReader reader =
new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
reader::loadNextBatch);
Assert.assertEquals(
"Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage()
"Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage()
);
}
}

@Test
public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
VarCharVector dictionaryVector = (VarCharVector)
FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1", allocator, null);
FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1_file", allocator, null);
Dictionary dictionary = createDictionary(dictionaryVector);
DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
provider.put(dictionary);
Expand All @@ -146,22 +189,50 @@ public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
CommonsCompressionFactory.INSTANCE)) {
Assert.assertEquals(1, reader.getRecordBlocks().size());
Assert.assertTrue(reader.loadNextBatch());
Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assert.assertFalse(reader.loadNextBatch());
Assertions.assertEquals(1, reader.getRecordBlocks().size());
Assertions.assertTrue(reader.loadNextBatch());
Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assertions.assertFalse(reader.loadNextBatch());
}
// without compression
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Assert.assertEquals(1, reader.getRecordBlocks().size());
Assertions.assertEquals(1, reader.getRecordBlocks().size());
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
reader::loadNextBatch);
Assert.assertEquals(
"Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage()
);
Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage());
}
dictionaryVector.close();
}

/**
 * Same round trip as the plain stream test, but a dictionary provider containing one dictionary
 * is handed to the stream writer.
 *
 * <p>The dictionary vector is managed by try-with-resources so that it is closed even when an
 * assertion fails part-way through the test.
 */
@Test
public void testArrowStreamZstdRoundTripWithDictionary() throws Exception {
  try (VarCharVector dictionaryVector = (VarCharVector)
      FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1_stream", allocator, null)) {
    Dictionary dictionary = createDictionary(dictionaryVector);
    DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary);

    createAndWriteArrowStream(provider, CompressionUtil.CodecType.ZSTD);

    // with compression
    try (ArrowStreamReader reader =
        new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
            CommonsCompressionFactory.INSTANCE)) {
      Assertions.assertTrue(reader.loadNextBatch());
      Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
      Assertions.assertFalse(reader.loadNextBatch());
    }
    // without compression
    try (ArrowStreamReader reader =
        new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
            NoCompressionCodec.Factory.INSTANCE)) {
      Exception exception = Assertions.assertThrows(IllegalArgumentException.class,
          reader::loadNextBatch);
      Assertions.assertEquals("Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
          exception.getMessage());
    }
  }
}
Expand Down

0 comments on commit 8f482da

Please sign in to comment.