Skip to content

Commit

Permalink
fix: address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
vibhatha committed Nov 22, 2023
1 parent ab13149 commit 82484f4
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
import org.apache.arrow.vector.GenerateSampleData;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionCodec.Factory;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.Dictionary;
Expand All @@ -58,6 +56,7 @@
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;

public class TestArrowReaderWriterWithCompression {

Expand All @@ -69,6 +68,7 @@ public class TestArrowReaderWriterWithCompression {
public void setup() {
// Per-test fixture reset: every test starts with its own allocator and output buffer.
// The allocator limit is Integer.MAX_VALUE.
allocator = new RootAllocator(Integer.MAX_VALUE);
// In-memory sink; the file-reader checks in this class read it back via out.toByteArray().
out = new ByteArrayOutputStream();
// Drop any reference left over from a previous test.
// NOTE(review): presumably this lets tearDown tell whether a test created a root - confirm.
root = null;
}

@After
Expand All @@ -79,7 +79,10 @@ public void tearDown() {
if (allocator != null) {
allocator.close();
}
out.reset();
if (out != null) {
out.reset();
}

}

private void createAndWriteArrowFile(DictionaryProvider provider,
Expand All @@ -100,22 +103,6 @@ private void createAndWriteArrowFile(DictionaryProvider provider,
}
}

/**
 * Reads the Arrow file currently held in {@code out} and checks the outcome.
 *
 * <p>The file must contain exactly one record block. When {@code expectSuccess} is true the
 * single batch must load and compare equal to {@code root}, and no further batch may follow.
 * Otherwise loading the batch must throw an {@link IllegalArgumentException} whose message is
 * exactly {@code expectedErrorMessage}.
 *
 * @param factory compression codec factory handed to the reader
 * @param expectSuccess whether decoding the batch is expected to work
 * @param expectedErrorMessage message expected on failure; unused when {@code expectSuccess}
 * @throws IOException if the reader fails on I/O
 */
private void readArrowFile(Factory factory, boolean expectSuccess, String expectedErrorMessage)
    throws IOException {
  final byte[] writtenBytes = out.toByteArray();
  try (ArrowFileReader fileReader =
           new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(writtenBytes), allocator, factory)) {
    Assert.assertEquals(1, fileReader.getRecordBlocks().size());

    // Failure path first: the batch must be rejected with the expected message.
    if (!expectSuccess) {
      Exception thrown = Assert.assertThrows(IllegalArgumentException.class, fileReader::loadNextBatch);
      Assert.assertEquals(expectedErrorMessage, thrown.getMessage());
      return;
    }

    // Success path: exactly one batch, identical to what was written.
    Assert.assertTrue(fileReader.loadNextBatch());
    Assert.assertTrue(root.equals(fileReader.getVectorSchemaRoot()));
    Assert.assertFalse(fileReader.loadNextBatch());
  }
}

private Dictionary createDictionary(VarCharVector dictionaryVector) {
setVector(dictionaryVector,
"foo".getBytes(StandardCharsets.UTF_8),
Expand Down Expand Up @@ -161,30 +148,32 @@ private File writeArrowStream(VectorSchemaRoot root, DictionaryProvider provider
return tempFile;
}

/**
 * Opens {@code tempFile} as an Arrow stream and checks the outcome of reading it.
 *
 * <p>When {@code shouldSucceed} is true, one batch must load and compare equal to {@code root},
 * and no further batch may follow. Otherwise loading the batch must throw an
 * {@link IllegalArgumentException} whose message is exactly {@code expectedExceptionMessage}.
 *
 * @param tempFile file containing the Arrow stream
 * @param allocator allocator handed to the stream reader
 * @param compressionFactory compression codec factory handed to the stream reader
 * @param shouldSucceed whether decoding the batch is expected to work
 * @param expectedExceptionMessage message expected on failure; unused when {@code shouldSucceed}
 * @throws IOException if opening or reading the file fails
 */
private void readArrowStream(File tempFile, BufferAllocator allocator,
    CompressionCodec.Factory compressionFactory,
    boolean shouldSucceed, String expectedExceptionMessage) throws IOException {
  try (SeekableByteChannel input = FileChannel.open(tempFile.toPath());
       ArrowStreamReader streamReader = new ArrowStreamReader(input, allocator, compressionFactory)) {
    // Failure path first: the batch must be rejected with the expected message.
    if (!shouldSucceed) {
      Exception thrown = Assert.assertThrows(IllegalArgumentException.class, streamReader::loadNextBatch);
      Assert.assertEquals(expectedExceptionMessage, thrown.getMessage());
      return;
    }

    // Success path: exactly one batch, identical to what was written.
    Assert.assertTrue(streamReader.loadNextBatch());
    Assert.assertTrue(root.equals(streamReader.getVectorSchemaRoot()));
    Assert.assertFalse(streamReader.loadNextBatch());
  }
}


/**
 * Writes a ZSTD-compressed Arrow file and reads it back with and without a compression factory.
 *
 * <p>With {@code CommonsCompressionFactory} the single record batch must load and compare equal
 * to {@code root}. With {@code NoCompressionCodec.Factory} loading must fail with an
 * {@link IllegalArgumentException} that tells the user to add the arrow-compression module.
 *
 * <p>The test is skipped. This class uses JUnit 4 ({@code org.junit.Test}), whose runner only
 * honours {@code org.junit.Ignore}; the JUnit Jupiter {@code @Disabled} annotation used before
 * had no effect here. No reason for skipping was recorded - add one to {@code @Ignore} when known.
 */
@Test
@Ignore
public void testArrowFileZstdRoundTrip() throws Exception {
  createAndWriteArrowFile(null, CompressionUtil.CodecType.ZSTD);

  // First pass through the shared helper: success with the codec, failure without it.
  readArrowFile(CommonsCompressionFactory.INSTANCE, true, null);
  readArrowFile(NoCompressionCodec.Factory.INSTANCE, false,
      "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD");

  // with compression
  try (ArrowFileReader reader =
           new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
               CommonsCompressionFactory.INSTANCE)) {
    Assert.assertEquals(1, reader.getRecordBlocks().size());
    Assert.assertTrue(reader.loadNextBatch());
    Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
    Assert.assertFalse(reader.loadNextBatch());
  }

  // without compression
  try (ArrowFileReader reader =
           new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
               NoCompressionCodec.Factory.INSTANCE)) {
    Assert.assertEquals(1, reader.getRecordBlocks().size());
    Exception exception = Assert.assertThrows(IllegalArgumentException.class,
        reader::loadNextBatch);
    Assert.assertEquals(
        "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
        exception.getMessage()
    );
  }
}

@Test
Expand All @@ -196,9 +185,28 @@ public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
provider.put(dictionary);

createAndWriteArrowFile(provider, CompressionUtil.CodecType.ZSTD);
readArrowFile(CommonsCompressionFactory.INSTANCE, true, null);
readArrowFile(NoCompressionCodec.Factory.INSTANCE, false,
"Please add arrow-compression module to use CommonsCompressionFactory for ZSTD");

// with compression
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
CommonsCompressionFactory.INSTANCE)) {
Assert.assertEquals(1, reader.getRecordBlocks().size());
Assert.assertTrue(reader.loadNextBatch());
Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assert.assertFalse(reader.loadNextBatch());
}
// without compression
try (ArrowFileReader reader =
new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Assert.assertEquals(1, reader.getRecordBlocks().size());
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
reader::loadNextBatch);
Assert.assertEquals(
"Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage()
);
}
dictionaryVector.close();
}

Expand All @@ -219,11 +227,24 @@ public void testArrowStreamZstdRoundTrip() throws Exception {

File tempFile = writeArrowStream(root, provider, CompressionUtil.CodecType.ZSTD);
// Read the on-disk compressed arrow file with CommonsCompressionFactory provided
readArrowStream(tempFile, allocator, CommonsCompressionFactory.INSTANCE, true, null);
try (SeekableByteChannel channel = FileChannel.open(tempFile.toPath());
ArrowStreamReader reader = new ArrowStreamReader(channel, allocator,
CommonsCompressionFactory.INSTANCE)) {
Assert.assertTrue(reader.loadNextBatch());
Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assert.assertFalse(reader.loadNextBatch());
}
// Read the on-disk compressed arrow file without CompressionFactory provided
readArrowStream(tempFile, allocator,
NoCompressionCodec.Factory.INSTANCE, false,
"Please add arrow-compression module to use CommonsCompressionFactory for ZSTD");
try (SeekableByteChannel channel = FileChannel.open(tempFile.toPath());
ArrowStreamReader reader = new ArrowStreamReader(channel, allocator,
NoCompressionCodec.Factory.INSTANCE)) {
Exception exception = Assert.assertThrows(IllegalArgumentException.class,
() -> reader.loadNextBatch());
Assert.assertEquals(
"Please add arrow-compression module to use CommonsCompressionFactory for ZSTD",
exception.getMessage()
);
}
dictionaryVector.close();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,7 @@ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
Collections.singletonList(vector.getField()),
Collections.singletonList(vector),
count);
VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true,
this.compressionLevel.isPresent() ?
this.compressionFactory.createCodec(this.codecType, this.compressionLevel.get()) :
this.compressionFactory.createCodec(this.codecType),
VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true, getCodec(),
/*alignBuffers*/ true);
ArrowRecordBatch batch = unloader.getRecordBatch();
ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
Expand Down

0 comments on commit 82484f4

Please sign in to comment.