diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 973416670bdfa..4cb5f6cccc4c8 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -78,6 +78,12 @@ class StreamReader::StreamReaderImpl { int32_t message_length = *reinterpret_cast(buffer->data()); + if (message_length == 0) { + // Optional 0 EOS control message + *message = nullptr; + return Status::OK(); + } + RETURN_NOT_OK(stream_->Read(message_length, &buffer)); if (buffer->size() != message_length) { return Status::IOError("Unexpected end of stream trying to read message"); diff --git a/integration/integration_test.py b/integration/integration_test.py index 049436a751f38..5cd63c502bd20 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -680,12 +680,16 @@ def stream_to_file(self, stream_path, file_path): cmd = ['java', '-cp', self.ARROW_TOOLS_JAR, 'org.apache.arrow.tools.StreamToFile', stream_path, file_path] + if self.debug: + print(' '.join(cmd)) run_cmd(cmd) def file_to_stream(self, file_path, stream_path): cmd = ['java', '-cp', self.ARROW_TOOLS_JAR, 'org.apache.arrow.tools.FileToStream', file_path, stream_path] + if self.debug: + print(' '.join(cmd)) run_cmd(cmd) diff --git a/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java b/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java index c00620e44b064..7c0cadd9d77dd 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java @@ -18,23 +18,19 @@ package org.apache.arrow.tools; import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; import java.net.ServerSocket; import java.net.Socket; -import java.util.ArrayList; -import java.util.List; + +import com.google.common.base.Preconditions; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.stream.ArrowStreamReader; import org.apache.arrow.vector.stream.ArrowStreamWriter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.google.common.base.Preconditions; - public class EchoServer { private static final Logger LOGGER = LoggerFactory.getLogger(EchoServer.class); @@ -57,30 +53,28 @@ public ClientConnection(Socket socket) { public void run() throws IOException { BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - List batches = new ArrayList(); - try ( - InputStream in = socket.getInputStream(); - OutputStream out = socket.getOutputStream(); - ArrowStreamReader reader = new ArrowStreamReader(in, allocator); - ) { - // Read the entire input stream. - reader.init(); - while (true) { - ArrowRecordBatch batch = reader.nextRecordBatch(); - if (batch == null) break; - batches.add(batch); - } - LOGGER.info(String.format("Received %d batches", batches.size())); - - // Write it back - try (ArrowStreamWriter writer = new ArrowStreamWriter(out, reader.getSchema())) { - for (ArrowRecordBatch batch: batches) { - writer.writeRecordBatch(batch); + // Read the entire input stream and write it back + try (ArrowStreamReader reader = new ArrowStreamReader(socket.getInputStream(), allocator)) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + // load the first batch before instantiating the writer so that we have any dictionaries + reader.loadNextBatch(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(root, reader, socket.getOutputStream())) { + writer.start(); + int echoed = 0; + while (true) { + int rowCount = reader.getVectorSchemaRoot().getRowCount(); + if (rowCount == 0) { + break; + } else { + writer.writeBatch(); + echoed += rowCount; + reader.loadNextBatch(); + } } writer.end(); Preconditions.checkState(reader.bytesRead() == writer.bytesWritten()); + LOGGER.info(String.format("Echoed %d records", echoed)); } - LOGGER.info("Done writing stream back."); } } diff --git a/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java b/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java index db7a1c23f9ca6..9fa7b761a5772 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java @@ -23,18 +23,12 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.PrintStream; -import java.util.List; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.file.ArrowBlock; -import org.apache.arrow.vector.file.ArrowFooter; -import org.apache.arrow.vector.file.ArrowReader; -import org.apache.arrow.vector.file.ArrowWriter; -import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.file.ArrowFileReader; +import org.apache.arrow.vector.file.ArrowFileWriter; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; @@ -86,35 +80,27 @@ int run(String[] args) { File inFile = validateFile("input", inFileName); File outFile = validateFile("output", outFileName); BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); // TODO: close - try( - FileInputStream fileInputStream = new FileInputStream(inFile); - ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), allocator);) { + try (FileInputStream fileInputStream = new FileInputStream(inFile); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), allocator)) { - ArrowFooter footer = arrowReader.readFooter(); - Schema schema = footer.getSchema(); + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("Input file size: " + inFile.length()); LOGGER.debug("Found schema: " + schema); - try ( - FileOutputStream fileOutputStream = new FileOutputStream(outFile); - ArrowWriter arrowWriter = new ArrowWriter(fileOutputStream.getChannel(), schema); - ) { - - // initialize vectors - - List recordBatches = footer.getRecordBatches(); - for (ArrowBlock rbBlock : recordBatches) { - try (ArrowRecordBatch inRecordBatch = arrowReader.readRecordBatch(rbBlock); - VectorSchemaRoot root = new VectorSchemaRoot(schema, allocator);) { - - VectorLoader vectorLoader = new VectorLoader(root); - vectorLoader.load(inRecordBatch); - - VectorUnloader vectorUnloader = new VectorUnloader(root); - ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch(); - arrowWriter.writeRecordBatch(recordBatch); + try (FileOutputStream fileOutputStream = new FileOutputStream(outFile); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, arrowReader, fileOutputStream.getChannel())) { + arrowWriter.start(); + while (true) { + arrowReader.loadNextBatch(); + int loaded = root.getRowCount(); + if (loaded == 0) { + break; + } else { + arrowWriter.writeBatch(); } } + arrowWriter.end(); } LOGGER.debug("Output file size: " + outFile.length()); } diff --git a/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java b/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java index ba6505cb48d08..d5345535d19dc 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java @@ -25,10 +25,8 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.file.ArrowBlock; -import org.apache.arrow.vector.file.ArrowFooter; -import org.apache.arrow.vector.file.ArrowReader; -import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.file.ArrowFileReader; import org.apache.arrow.vector.stream.ArrowStreamWriter; /** @@ -36,19 +34,20 @@ * first argument and the output is written to standard out. */ public class FileToStream { + public static void convert(FileInputStream in, OutputStream out) throws IOException { BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - try( - ArrowReader reader = new ArrowReader(in.getChannel(), allocator);) { - ArrowFooter footer = reader.readFooter(); - try ( - ArrowStreamWriter writer = new ArrowStreamWriter(out, footer.getSchema()); - ) { - for (ArrowBlock block: footer.getRecordBatches()) { - try (ArrowRecordBatch batch = reader.readRecordBatch(block)) { - writer.writeRecordBatch(batch); - } + try (ArrowFileReader reader = new ArrowFileReader(in.getChannel(), allocator)) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + // load the first batch before instantiating the writer so that we have any dictionaries + reader.loadNextBatch(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(root, reader, out)) { + writer.start(); + while (root.getRowCount() > 0) { + writer.writeBatch(); + reader.loadNextBatch(); } + writer.end(); } } } diff --git a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java index 36d4ee5485470..5d4849c234383 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java @@ -28,16 +28,12 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.file.ArrowBlock; -import org.apache.arrow.vector.file.ArrowFooter; -import org.apache.arrow.vector.file.ArrowReader; -import org.apache.arrow.vector.file.ArrowWriter; +import org.apache.arrow.vector.file.ArrowFileReader; +import org.apache.arrow.vector.file.ArrowFileWriter; import org.apache.arrow.vector.file.json.JsonFileReader; import org.apache.arrow.vector.file.json.JsonFileWriter; -import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.Validator; import org.apache.commons.cli.CommandLine; @@ -69,24 +65,18 @@ enum Command { ARROW_TO_JSON(true, false) { @Override public void execute(File arrowFile, File jsonFile) throws IOException { - try( - BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + try(BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); FileInputStream fileInputStream = new FileInputStream(arrowFile); - ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), allocator);) { - ArrowFooter footer = arrowReader.readFooter(); - Schema schema = footer.getSchema(); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), allocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("Input file size: " + arrowFile.length()); LOGGER.debug("Found schema: " + schema); - try (JsonFileWriter writer = new JsonFileWriter(jsonFile, JsonFileWriter.config().pretty(true));) { + try (JsonFileWriter writer = new JsonFileWriter(jsonFile, JsonFileWriter.config().pretty(true))) { writer.start(schema); - List recordBatches = footer.getRecordBatches(); - for (ArrowBlock rbBlock : recordBatches) { - try (ArrowRecordBatch inRecordBatch = arrowReader.readRecordBatch(rbBlock); - VectorSchemaRoot root = new VectorSchemaRoot(schema, allocator);) { - VectorLoader vectorLoader = new VectorLoader(root); - vectorLoader.load(inRecordBatch); - writer.write(root); - } + for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { + arrowReader.loadRecordBatch(rbBlock); + writer.write(root); } } LOGGER.debug("Output file size: " + jsonFile.length()); @@ -96,27 +86,22 @@ public void execute(File arrowFile, File jsonFile) throws IOException { JSON_TO_ARROW(false, true) { @Override public void execute(File arrowFile, File jsonFile) throws IOException { - try ( - BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - JsonFileReader reader = new JsonFileReader(jsonFile, allocator); - ) { + try (BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + JsonFileReader reader = new JsonFileReader(jsonFile, allocator)) { Schema schema = reader.start(); LOGGER.debug("Input file size: " + jsonFile.length()); LOGGER.debug("Found schema: " + schema); - try ( - FileOutputStream fileOutputStream = new FileOutputStream(arrowFile); - ArrowWriter arrowWriter = new ArrowWriter(fileOutputStream.getChannel(), schema); - ) { - - // initialize vectors - VectorSchemaRoot root; - while ((root = reader.read()) != null) { - VectorUnloader vectorUnloader = new VectorUnloader(root); - try (ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch();) { - arrowWriter.writeRecordBatch(recordBatch); - } - root.close(); + try (FileOutputStream fileOutputStream = new FileOutputStream(arrowFile); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + // TODO json dictionaries + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel())) { + arrowWriter.start(); + reader.read(root); + while (root.getRowCount() != 0) { + arrowWriter.writeBatch(); + reader.read(root); } + arrowWriter.end(); } LOGGER.debug("Output file size: " + arrowFile.length()); } @@ -125,32 +110,26 @@ public void execute(File arrowFile, File jsonFile) throws IOException { VALIDATE(true, true) { @Override public void execute(File arrowFile, File jsonFile) throws IOException { - try ( - BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - JsonFileReader jsonReader = new JsonFileReader(jsonFile, allocator); - FileInputStream fileInputStream = new FileInputStream(arrowFile); - ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), allocator); - ) { + try (BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + JsonFileReader jsonReader = new JsonFileReader(jsonFile, allocator); + FileInputStream fileInputStream = new FileInputStream(arrowFile); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), allocator)) { Schema jsonSchema = jsonReader.start(); - ArrowFooter footer = arrowReader.readFooter(); - Schema arrowSchema = footer.getSchema(); + VectorSchemaRoot arrowRoot = arrowReader.getVectorSchemaRoot(); + Schema arrowSchema = arrowRoot.getSchema(); LOGGER.debug("Arrow Input file size: " + arrowFile.length()); LOGGER.debug("ARROW schema: " + arrowSchema); LOGGER.debug("JSON Input file size: " + jsonFile.length()); LOGGER.debug("JSON schema: " + jsonSchema); Validator.compareSchemas(jsonSchema, arrowSchema); - List recordBatches = footer.getRecordBatches(); + List recordBatches = arrowReader.getRecordBlocks(); Iterator iterator = recordBatches.iterator(); VectorSchemaRoot jsonRoot; while ((jsonRoot = jsonReader.read()) != null && iterator.hasNext()) { ArrowBlock rbBlock = iterator.next(); - try (ArrowRecordBatch inRecordBatch = arrowReader.readRecordBatch(rbBlock); - VectorSchemaRoot arrowRoot = new VectorSchemaRoot(arrowSchema, allocator);) { - VectorLoader vectorLoader = new VectorLoader(arrowRoot); - vectorLoader.load(inRecordBatch); - Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot); - } + arrowReader.loadRecordBatch(rbBlock); + Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot); jsonRoot.close(); } boolean hasMoreJSON = jsonRoot != null; diff --git a/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java b/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java index c8a5c8914afcc..3b79d5b05e116 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java @@ -27,8 +27,8 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.file.ArrowWriter; -import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.file.ArrowFileWriter; import org.apache.arrow.vector.stream.ArrowStreamReader; /** @@ -38,13 +38,16 @@ public class StreamToFile { public static void convert(InputStream in, OutputStream out) throws IOException { BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { - reader.init(); - try (ArrowWriter writer = new ArrowWriter(Channels.newChannel(out), reader.getSchema());) { - while (true) { - ArrowRecordBatch batch = reader.nextRecordBatch(); - if (batch == null) break; - writer.writeRecordBatch(batch); + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + // load the first batch before instantiating the writer so that we have any dictionaries + reader.loadNextBatch(); + try (ArrowFileWriter writer = new ArrowFileWriter(root, reader, Channels.newChannel(out))) { + writer.start(); + while (root.getRowCount() > 0) { + writer.writeBatch(); + reader.loadNextBatch(); } + writer.end(); } } } diff --git a/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java b/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java index 4cfc52fe08631..f752f7eaa74b9 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java @@ -23,13 +23,10 @@ import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; -import java.util.List; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.impl.ComplexWriterImpl; import org.apache.arrow.vector.complex.writer.BaseWriter.ComplexWriter; @@ -37,10 +34,8 @@ import org.apache.arrow.vector.complex.writer.BigIntWriter; import org.apache.arrow.vector.complex.writer.IntWriter; import org.apache.arrow.vector.file.ArrowBlock; -import org.apache.arrow.vector.file.ArrowFooter; -import org.apache.arrow.vector.file.ArrowReader; -import org.apache.arrow.vector.file.ArrowWriter; -import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.file.ArrowFileReader; +import org.apache.arrow.vector.file.ArrowFileWriter; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; @@ -63,26 +58,14 @@ static void writeData(int count, MapVector parent) { static void validateOutput(File testOutFile, BufferAllocator allocator) throws Exception { // read - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(testOutFile); - ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - ) { - ArrowFooter footer = arrowReader.readFooter(); - Schema schema = footer.getSchema(); - - // initialize vectors - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, readerAllocator)) { - VectorLoader vectorLoader = new VectorLoader(root); - - List recordBatches = footer.getRecordBatches(); - for (ArrowBlock rbBlock : recordBatches) { - try (ArrowRecordBatch recordBatch = arrowReader.readRecordBatch(rbBlock)) { - vectorLoader.load(recordBatch); - } - validateContent(COUNT, root); - } + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + FileInputStream fileInputStream = new FileInputStream(testOutFile); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); + for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { + arrowReader.loadRecordBatch(rbBlock); + validateContent(COUNT, root); } } } @@ -96,16 +79,10 @@ static void validateContent(int count, VectorSchemaRoot root) { } static void write(FieldVector parent, File file) throws FileNotFoundException, IOException { - Schema schema = new Schema(parent.getField().getChildren()); - int valueCount = parent.getAccessor().getValueCount(); - List fields = parent.getChildrenFromFields(); - VectorUnloader vectorUnloader = new VectorUnloader(schema, valueCount, fields); - try ( - FileOutputStream fileOutputStream = new FileOutputStream(file); - ArrowWriter arrowWriter = new ArrowWriter(fileOutputStream.getChannel(), schema); - ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch(); - ) { - arrowWriter.writeRecordBatch(recordBatch); + VectorSchemaRoot root = new VectorSchemaRoot(parent); + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel())) { + arrowWriter.writeBatch(); } } diff --git a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java index 48d6162f423a3..706f8e2ca4d36 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java @@ -24,106 +24,268 @@ import java.io.IOException; import java.net.Socket; import java.net.UnknownHostException; -import java.util.ArrayList; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import com.google.common.collect.ImmutableList; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.schema.ArrowFieldNode; -import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.NullableIntVector; +import org.apache.arrow.vector.NullableTinyIntVector; +import org.apache.arrow.vector.NullableVarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; import org.apache.arrow.vector.stream.ArrowStreamReader; import org.apache.arrow.vector.stream.ArrowStreamWriter; +import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.Int; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Test; -import io.netty.buffer.ArrowBuf; - public class EchoServerTest { - public static ArrowBuf buf(BufferAllocator alloc, byte[] bytes) { - ArrowBuf buffer = alloc.buffer(bytes.length); - buffer.writeBytes(bytes); - return buffer; + + private static EchoServer server; + private static int serverPort; + private static Thread serverThread; + + @BeforeClass + public static void startEchoServer() throws IOException { + server = new EchoServer(0); + serverPort = server.port(); + serverThread = new Thread() { + @Override + public void run() { + try { + server.run(); + } catch (IOException e) { + e.printStackTrace(); + } + } + }; + serverThread.start(); } - public static byte[] array(ArrowBuf buf) { - byte[] bytes = new byte[buf.readableBytes()]; - buf.readBytes(bytes); - return bytes; + @AfterClass + public static void stopEchoServer() throws IOException, InterruptedException { + server.close(); + serverThread.join(); } - private void testEchoServer(int serverPort, Schema schema, List batches) + private void testEchoServer(int serverPort, + Field field, + NullableTinyIntVector vector, + int batches) throws UnknownHostException, IOException { BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot root = new VectorSchemaRoot(asList(field), asList((FieldVector) vector), 0); try (Socket socket = new Socket("localhost", serverPort); - ArrowStreamWriter writer = new ArrowStreamWriter(socket.getOutputStream(), schema); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, socket.getOutputStream()); ArrowStreamReader reader = new ArrowStreamReader(socket.getInputStream(), alloc)) { - for (ArrowRecordBatch batch: batches) { - writer.writeRecordBatch(batch); + writer.start(); + for (int i = 0; i < batches; i++) { + vector.allocateNew(16); + for (int j = 0; j < 8; j++) { + vector.getMutator().set(j, j + i); + vector.getMutator().set(j + 8, 0, (byte) (j + i)); + } + vector.getMutator().setValueCount(16); + root.setRowCount(16); + writer.writeBatch(); } writer.end(); - reader.init(); - assertEquals(schema, reader.getSchema()); - for (int i = 0; i < batches.size(); i++) { - ArrowRecordBatch result = reader.nextRecordBatch(); - ArrowRecordBatch expected = batches.get(i); - assertTrue(result != null); - assertEquals(expected.getBuffers().size(), result.getBuffers().size()); - for (int j = 0; j < expected.getBuffers().size(); j++) { - assertTrue(expected.getBuffers().get(j).compareTo(result.getBuffers().get(j)) == 0); + assertEquals(new Schema(asList(field)), reader.getVectorSchemaRoot().getSchema()); + + NullableTinyIntVector readVector = (NullableTinyIntVector) reader.getVectorSchemaRoot().getFieldVectors().get(0); + for (int i = 0; i < batches; i++) { + reader.loadNextBatch(); + assertEquals(16, reader.getVectorSchemaRoot().getRowCount()); + assertEquals(16, readVector.getAccessor().getValueCount()); + for (int j = 0; j < 8; j++) { + assertEquals(j + i, readVector.getAccessor().get(j)); + assertTrue(readVector.getAccessor().isNull(j + 8)); } } - ArrowRecordBatch result = reader.nextRecordBatch(); - assertTrue(result == null); + reader.loadNextBatch(); + assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); assertEquals(reader.bytesRead(), writer.bytesWritten()); } } @Test public void basicTest() throws InterruptedException, IOException { - final EchoServer server = new EchoServer(0); - int serverPort = server.port(); - Thread serverThread = new Thread() { - @Override - public void run() { - try { - server.run(); - } catch (IOException e) { - e.printStackTrace(); - } - } - }; - serverThread.start(); - BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); - byte[] validity = new byte[] { (byte)255, 0}; - byte[] values = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - ArrowBuf validityb = buf(alloc, validity); - ArrowBuf valuesb = buf(alloc, values); - ArrowRecordBatch batch = new ArrowRecordBatch( - 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb)); - Schema schema = new Schema(asList(new Field( - "testField", true, new ArrowType.Int(8, true), Collections.emptyList()))); + Field field = new Field("testField", true, new ArrowType.Int(8, true), Collections.emptyList()); + NullableTinyIntVector vector = new NullableTinyIntVector("testField", alloc, null); + Schema schema = new Schema(asList(field)); // Try an empty stream, just the header. - testEchoServer(serverPort, schema, new ArrayList()); + testEchoServer(serverPort, field, vector, 0); // Try with one batch. - List batches = new ArrayList<>(); - batches.add(batch); - testEchoServer(serverPort, schema, batches); + testEchoServer(serverPort, field, vector, 1); // Try with a few - for (int i = 0; i < 10; i++) { - batches.add(batch); + testEchoServer(serverPort, field, vector, 10); + } + + @Test + public void testFlatDictionary() throws IOException { + DictionaryEncoding writeEncoding = new DictionaryEncoding(1L, false, null); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + NullableIntVector writeVector = new NullableIntVector("varchar", allocator, writeEncoding); + NullableVarCharVector writeDictionaryVector = new NullableVarCharVector("dict", allocator, null)) { + writeVector.allocateNewSafe(); + NullableIntVector.Mutator mutator = writeVector.getMutator(); + mutator.set(0, 0); + mutator.set(1, 1); + mutator.set(3, 2); + mutator.set(4, 1); + mutator.set(5, 2); + mutator.setValueCount(6); + + writeDictionaryVector.allocateNewSafe(); + NullableVarCharVector.Mutator dictionaryMutator = writeDictionaryVector.getMutator(); + dictionaryMutator.set(0, "foo".getBytes(StandardCharsets.UTF_8)); + dictionaryMutator.set(1, "bar".getBytes(StandardCharsets.UTF_8)); + dictionaryMutator.set(2, "baz".getBytes(StandardCharsets.UTF_8)); + dictionaryMutator.setValueCount(3); + + List fields = ImmutableList.of(writeVector.getField()); + List vectors = ImmutableList.of((FieldVector) writeVector); + VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, 6); + + DictionaryProvider writeProvider = new MapDictionaryProvider(new Dictionary(writeDictionaryVector, writeEncoding)); + + try (Socket socket = new Socket("localhost", serverPort); + ArrowStreamWriter writer = new ArrowStreamWriter(root, writeProvider, socket.getOutputStream()); + ArrowStreamReader reader = new ArrowStreamReader(socket.getInputStream(), allocator)) { + writer.start(); + writer.writeBatch(); + writer.end(); + + reader.loadNextBatch(); + VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); + Assert.assertEquals(6, readerRoot.getRowCount()); + + FieldVector readVector = readerRoot.getFieldVectors().get(0); + Assert.assertNotNull(readVector); + + DictionaryEncoding readEncoding = readVector.getField().getDictionary(); + Assert.assertNotNull(readEncoding); + Assert.assertEquals(1L, readEncoding.getId()); + + FieldVector.Accessor accessor = readVector.getAccessor(); + Assert.assertEquals(6, accessor.getValueCount()); + Assert.assertEquals(0, accessor.getObject(0)); + Assert.assertEquals(1, accessor.getObject(1)); + Assert.assertEquals(null, accessor.getObject(2)); + Assert.assertEquals(2, accessor.getObject(3)); + Assert.assertEquals(1, accessor.getObject(4)); + Assert.assertEquals(2, accessor.getObject(5)); + + Dictionary dictionary = reader.lookup(1L); + Assert.assertNotNull(dictionary); + NullableVarCharVector.Accessor dictionaryAccessor = ((NullableVarCharVector) dictionary.getVector()).getAccessor(); + Assert.assertEquals(3, dictionaryAccessor.getValueCount()); + Assert.assertEquals(new Text("foo"), dictionaryAccessor.getObject(0)); + Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); + Assert.assertEquals(new Text("baz"), dictionaryAccessor.getObject(2)); + } } - testEchoServer(serverPort, schema, batches); + } - server.close(); - serverThread.join(); + @Test + public void testNestedDictionary() throws IOException { + DictionaryEncoding writeEncoding = new DictionaryEncoding(2L, false, null); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + NullableVarCharVector writeDictionaryVector = new NullableVarCharVector("dictionary", allocator, null); + ListVector writeVector = new ListVector("list", allocator, null, null)) { + + // data being written: + // [['foo', 'bar'], ['foo'], ['bar']] -> [[0, 1], [0], [1]] + + writeDictionaryVector.allocateNew(); + writeDictionaryVector.getMutator().set(0, "foo".getBytes(StandardCharsets.UTF_8)); + writeDictionaryVector.getMutator().set(1, "bar".getBytes(StandardCharsets.UTF_8)); + writeDictionaryVector.getMutator().setValueCount(2); + + writeVector.addOrGetVector(MinorType.INT, writeEncoding); + writeVector.allocateNew(); + UnionListWriter listWriter = new UnionListWriter(writeVector); + listWriter.startList(); + listWriter.writeInt(0); + listWriter.writeInt(1); + listWriter.endList(); + listWriter.startList(); + listWriter.writeInt(0); + listWriter.endList(); + listWriter.startList(); + listWriter.writeInt(1); + listWriter.endList(); + listWriter.setValueCount(3); + + List fields = ImmutableList.of(writeVector.getField()); + List vectors = ImmutableList.of((FieldVector) writeVector); + VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, 3); + + DictionaryProvider writeProvider = new MapDictionaryProvider(new Dictionary(writeDictionaryVector, writeEncoding)); + + try (Socket socket = new Socket("localhost", serverPort); + ArrowStreamWriter writer = new ArrowStreamWriter(root, writeProvider, socket.getOutputStream()); + ArrowStreamReader reader = new ArrowStreamReader(socket.getInputStream(), allocator)) { + writer.start(); + writer.writeBatch(); + writer.end(); + + reader.loadNextBatch(); + VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); + Assert.assertEquals(3, readerRoot.getRowCount()); + + ListVector readVector = (ListVector) readerRoot.getFieldVectors().get(0); + Assert.assertNotNull(readVector); + + Assert.assertNull(readVector.getField().getDictionary()); + DictionaryEncoding readEncoding = readVector.getField().getChildren().get(0).getDictionary(); + Assert.assertNotNull(readEncoding); + Assert.assertEquals(2L, readEncoding.getId()); + + Field nestedField = readVector.getField().getChildren().get(0); + + DictionaryEncoding encoding = nestedField.getDictionary(); + Assert.assertNotNull(encoding); + Assert.assertEquals(2L, encoding.getId()); + Assert.assertEquals(new Int(32, true), encoding.getIndexType()); + + ListVector.Accessor accessor = readVector.getAccessor(); + Assert.assertEquals(3, accessor.getValueCount()); + Assert.assertEquals(Arrays.asList(0, 1), accessor.getObject(0)); + Assert.assertEquals(Arrays.asList(0), accessor.getObject(1)); + Assert.assertEquals(Arrays.asList(1), accessor.getObject(2)); + + Dictionary readDictionary = reader.lookup(2L); + Assert.assertNotNull(readDictionary); + NullableVarCharVector.Accessor dictionaryAccessor = ((NullableVarCharVector) readDictionary.getVector()).getAccessor(); + Assert.assertEquals(2, dictionaryAccessor.getValueCount()); + Assert.assertEquals(new Text("foo"), dictionaryAccessor.getObject(0)); + Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); + } + } } } diff --git a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java b/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java index 0ae32bebe0b30..9d4ef5c26505b 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java @@ -33,6 +33,11 @@ import java.io.StringReader; import java.util.Map; +import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; +import com.fasterxml.jackson.core.util.DefaultPrettyPrinter.NopIndenter; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.tools.Integration.Command; @@ -49,11 +54,6 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; -import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; -import com.fasterxml.jackson.core.util.DefaultPrettyPrinter.NopIndenter; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializationFeature; - public class TestIntegration { @Rule @@ -128,6 +128,34 @@ public void testJSONRoundTripWithVariableWidth() throws Exception { } } + @Test + public void testJSONRoundTripWithStruct() throws Exception { + File testJSONFile = new File("../../integration/data/struct_example.json"); + File testOutFile = testFolder.newFile("testOutStruct.arrow"); + File testRoundTripJSONFile = testFolder.newFile("testOutStruct.json"); + testOutFile.delete(); + testRoundTripJSONFile.delete(); + + Integration integration = new Integration(); + + // convert to arrow + String[] args1 = { "-arrow", testOutFile.getAbsolutePath(), "-json", testJSONFile.getAbsolutePath(), "-command", Command.JSON_TO_ARROW.name()}; + integration.run(args1); + + // convert back to json + String[] args2 = { "-arrow", testOutFile.getAbsolutePath(), "-json", testRoundTripJSONFile.getAbsolutePath(), "-command", Command.ARROW_TO_JSON.name()}; + integration.run(args2); + + BufferedReader orig = readNormalized(testJSONFile); + BufferedReader rt = readNormalized(testRoundTripJSONFile); + String i, o; + int j = 0; + while ((i = orig.readLine()) != null && (o = rt.readLine()) != null) { + assertEquals("line: " + j, i, o); + ++j; + } + } + private ObjectMapper om = new ObjectMapper(); { DefaultPrettyPrinter prettyPrinter = new DefaultPrettyPrinter(); diff --git a/java/tools/tmptestfilesio b/java/tools/tmptestfilesio new file mode 100644 index 0000000000000..d1b6b6cdb9387 Binary files /dev/null and b/java/tools/tmptestfilesio differ diff --git a/java/vector/src/main/codegen/templates/MapWriters.java b/java/vector/src/main/codegen/templates/MapWriters.java index 4af6eee91b6de..428ce0427d4b8 100644 --- a/java/vector/src/main/codegen/templates/MapWriters.java +++ b/java/vector/src/main/codegen/templates/MapWriters.java @@ -64,7 +64,7 @@ public class ${mode}MapWriter extends AbstractFieldWriter { list(child.getName()); break; case UNION: - UnionWriter writer = new UnionWriter(container.addOrGet(child.getName(), MinorType.UNION, UnionVector.class), getNullableMapWriterFactory()); + UnionWriter writer = new UnionWriter(container.addOrGet(child.getName(), MinorType.UNION, UnionVector.class, null), getNullableMapWriterFactory()); fields.put(handleCase(child.getName()), writer); break; <#list vv.types as type><#list type.minor as minor> @@ -113,7 +113,7 @@ public MapWriter map(String name) { FieldWriter writer = fields.get(finalName); if(writer == null){ int vectorCount=container.size(); - NullableMapVector vector = container.addOrGet(name, MinorType.MAP, NullableMapVector.class); + NullableMapVector vector = container.addOrGet(name, MinorType.MAP, NullableMapVector.class, null); writer = new PromotableWriter(vector, container, getNullableMapWriterFactory()); if(vectorCount != container.size()) { writer.allocate(); @@ -157,7 +157,7 @@ public ListWriter list(String name) { FieldWriter writer = fields.get(finalName); int vectorCount = container.size(); if(writer == null) { - writer = new PromotableWriter(container.addOrGet(name, MinorType.LIST, ListVector.class), container, getNullableMapWriterFactory()); + writer = new PromotableWriter(container.addOrGet(name, MinorType.LIST, ListVector.class, null), container, getNullableMapWriterFactory()); if (container.size() > vectorCount) { writer.allocate(); } @@ -222,7 +222,7 @@ public void end() { if(writer == null) { ValueVector vector; ValueVector currentVector = container.getChild(name); - ${vectName}Vector v = container.addOrGet(name, MinorType.${upperName}, ${vectName}Vector.class<#if minor.class == "Decimal"> , new int[] {precision, scale}); + ${vectName}Vector v = container.addOrGet(name, MinorType.${upperName}, ${vectName}Vector.class, null<#if minor.class == "Decimal"> , new int[] {precision, scale}); writer = new PromotableWriter(v, container, getNullableMapWriterFactory()); vector = v; if (currentVector == null || currentVector != vector) { diff --git a/java/vector/src/main/codegen/templates/NullableValueVectors.java b/java/vector/src/main/codegen/templates/NullableValueVectors.java index 6b25fb36b40c0..b3e10e3fa87a2 100644 --- a/java/vector/src/main/codegen/templates/NullableValueVectors.java +++ b/java/vector/src/main/codegen/templates/NullableValueVectors.java @@ -65,21 +65,21 @@ public final class ${className} extends BaseDataValueVector implements <#if type private final int precision; private final int scale; - public ${className}(String name, BufferAllocator allocator, int precision, int scale) { + public ${className}(String name, BufferAllocator allocator, DictionaryEncoding dictionary, int precision, int scale) { super(name, allocator); values = new ${valuesName}(valuesField, allocator, precision, scale); this.precision = precision; this.scale = scale; mutator = new Mutator(); accessor = new Accessor(); - field = new Field(name, true, new Decimal(precision, scale), null); + field = new Field(name, true, new Decimal(precision, scale), dictionary, null); innerVectors = Collections.unmodifiableList(Arrays.asList( bits, values )); } <#else> - public ${className}(String name, BufferAllocator allocator) { + public ${className}(String name, BufferAllocator allocator, DictionaryEncoding dictionary) { super(name, allocator); values = new ${valuesName}(valuesField, allocator); mutator = new Mutator(); @@ -88,38 +88,38 @@ public final class ${className} extends BaseDataValueVector implements <#if type minor.class == "SmallInt" || minor.class == "Int" || minor.class == "BigInt"> - field = new Field(name, true, new Int(${type.width} * 8, true), null); + field = new Field(name, true, new Int(${type.width} * 8, true), dictionary, null); <#elseif minor.class == "UInt1" || minor.class == "UInt2" || minor.class == "UInt4" || minor.class == "UInt8"> - field = new Field(name, true, new Int(${type.width} * 8, false), null); + field = new Field(name, true, new Int(${type.width} * 8, false), dictionary, null); <#elseif minor.class == "Date"> - field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Date(), null); + field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Date(), dictionary, null); <#elseif minor.class == "Time"> - field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Time(), null); + field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Time(), dictionary, null); <#elseif minor.class == "Float4"> - field = new Field(name, true, new FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE), null); + field = new Field(name, true, new FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE), dictionary, null); <#elseif minor.class == "Float8"> - field = new Field(name, true, new FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE), null); + field = new Field(name, true, new FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE), dictionary, null); <#elseif minor.class == "TimeStampSec"> - field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.SECOND), null); + field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.SECOND), dictionary, null); <#elseif minor.class == "TimeStampMilli"> - field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.MILLISECOND), null); + field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.MILLISECOND), dictionary, null); <#elseif minor.class == "TimeStampMicro"> - field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.MICROSECOND), null); + field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.MICROSECOND), dictionary, null); <#elseif minor.class == "TimeStampNano"> - field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.NANOSECOND), null); + field = new Field(name, true, new org.apache.arrow.vector.types.pojo.ArrowType.Timestamp(org.apache.arrow.vector.types.TimeUnit.NANOSECOND), dictionary, null); <#elseif minor.class == "IntervalDay"> - field = new Field(name, true, new Interval(org.apache.arrow.vector.types.IntervalUnit.DAY_TIME), null); + field = new Field(name, true, new Interval(org.apache.arrow.vector.types.IntervalUnit.DAY_TIME), dictionary, null); <#elseif minor.class == "IntervalYear"> - field = new Field(name, true, new Interval(org.apache.arrow.vector.types.IntervalUnit.YEAR_MONTH), null); + field = new Field(name, true, new Interval(org.apache.arrow.vector.types.IntervalUnit.YEAR_MONTH), dictionary, null); <#elseif minor.class == "VarChar"> - field = new Field(name, true, new Utf8(), null); + field = new Field(name, true, new Utf8(), dictionary, null); <#elseif minor.class == "VarBinary"> - field = new Field(name, true, new Binary(), null); + field = new Field(name, true, new Binary(), dictionary, null); <#elseif minor.class == "Bit"> - field = new Field(name, true, new Bool(), null); + field = new Field(name, true, new Bool(), dictionary, null); innerVectors = Collections.unmodifiableList(Arrays.asList( bits, @@ -378,9 +378,9 @@ private class TransferImpl implements TransferPair { public TransferImpl(String name, BufferAllocator allocator){ <#if minor.class == "Decimal"> - to = new ${className}(name, allocator, precision, scale); + to = new ${className}(name, allocator, field.getDictionary(), precision, scale); <#else> - to = new ${className}(name, allocator); + to = new ${className}(name, allocator, field.getDictionary()); } diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index 1a6908df2c40d..076ed93999623 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -118,11 +118,11 @@ public List getFieldBuffers() { public List getFieldInnerVectors() { return this.innerVectors; } - + public NullableMapVector getMap() { if (mapVector == null) { int vectorCount = internalMap.size(); - mapVector = internalMap.addOrGet("map", MinorType.MAP, NullableMapVector.class); + mapVector = internalMap.addOrGet("map", MinorType.MAP, NullableMapVector.class, null); if (internalMap.size() > vectorCount) { mapVector.allocateNew(); if (callBack != null) { @@ -144,7 +144,7 @@ public NullableMapVector getMap() { public Nullable${name}Vector get${name}Vector() { if (${uncappedName}Vector == null) { int vectorCount = internalMap.size(); - ${uncappedName}Vector = internalMap.addOrGet("${lowerCaseName}", MinorType.${name?upper_case}, Nullable${name}Vector.class); + ${uncappedName}Vector = internalMap.addOrGet("${lowerCaseName}", MinorType.${name?upper_case}, Nullable${name}Vector.class, null); if (internalMap.size() > vectorCount) { ${uncappedName}Vector.allocateNew(); if (callBack != null) { @@ -162,7 +162,7 @@ public NullableMapVector getMap() { public ListVector getList() { if (listVector == null) { int vectorCount = internalMap.size(); - listVector = internalMap.addOrGet("list", MinorType.LIST, ListVector.class); + listVector = internalMap.addOrGet("list", MinorType.LIST, ListVector.class, null); if (internalMap.size() > vectorCount) { listVector.allocateNew(); if (callBack != null) { @@ -262,7 +262,7 @@ public void copyFromSafe(int inIndex, int outIndex, UnionVector from) { public FieldVector addVector(FieldVector v) { String name = v.getMinorType().name().toLowerCase(); Preconditions.checkState(internalMap.getChild(name) == null, String.format("%s vector already exists", name)); - final FieldVector newVector = internalMap.addOrGet(name, v.getMinorType(), v.getClass()); + final FieldVector newVector = internalMap.addOrGet(name, v.getMinorType(), v.getClass(), v.getField().getDictionary()); v.makeTransferPair(newVector).transfer(); internalMap.putChild(name, newVector); if (callBack != null) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java index d1e9abe5dd111..179f2ee879f43 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BitVector.java @@ -81,6 +81,7 @@ public void load(ArrowFieldNode fieldNode, ArrowBuf data) { } else { super.load(fieldNode, data); } + this.valueCount = fieldNode.getLength(); } @Override @@ -451,7 +452,6 @@ public final void setToOne(int index) { /** * set count bits to 1 in data starting at firstBitIndex - * @param data the buffer to set * @param firstBitIndex the index of the first bit to set * @param count the number of bits to set */ diff --git a/java/vector/src/main/java/org/apache/arrow/vector/FieldVector.java b/java/vector/src/main/java/org/apache/arrow/vector/FieldVector.java index b28433cfd0d94..0fdbc48552aaa 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/FieldVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/FieldVector.java @@ -19,11 +19,10 @@ import java.util.List; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.schema.ArrowFieldNode; import org.apache.arrow.vector.types.pojo.Field; -import io.netty.buffer.ArrowBuf; - /** * A vector corresponding to a Field in the schema * It has inner vectors backed by buffers (validity, offsets, data, ...) @@ -61,5 +60,4 @@ public interface FieldVector extends ValueVector { * @return the inner vectors for this field as defined by the TypeLayout */ List getFieldInnerVectors(); - } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java index 5c1176cf95d26..76de250e0e972 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java @@ -36,15 +36,14 @@ * Loads buffers into vectors */ public class VectorLoader { + private final VectorSchemaRoot root; /** * will create children in root based on schema - * @param schema the expected schema * @param root the root to add vectors to based on schema */ public VectorLoader(VectorSchemaRoot root) { - super(); this.root = root; } @@ -57,18 +56,16 @@ public void load(ArrowRecordBatch recordBatch) { Iterator buffers = recordBatch.getBuffers().iterator(); Iterator nodes = recordBatch.getNodes().iterator(); List fields = root.getSchema().getFields(); - for (int i = 0; i < fields.size(); ++i) { - Field field = fields.get(i); + for (Field field: fields) { FieldVector fieldVector = root.getVector(field.getName()); loadBuffers(fieldVector, field, buffers, nodes); } root.setRowCount(recordBatch.getLength()); if (nodes.hasNext() || buffers.hasNext()) { - throw new IllegalArgumentException("not all nodes and buffers where consumed. nodes: " + Iterators.toString(nodes) + " buffers: " + Iterators.toString(buffers)); + throw new IllegalArgumentException("not all nodes and buffers were consumed. nodes: " + Iterators.toString(nodes) + " buffers: " + Iterators.toString(buffers)); } } - private void loadBuffers(FieldVector vector, Field field, Iterator buffers, Iterator nodes) { checkArgument(nodes.hasNext(), "no more field nodes for for field " + field + " and vector " + vector); @@ -82,7 +79,7 @@ private void loadBuffers(FieldVector vector, Field field, Iterator buf vector.loadFieldBuffers(fieldNode, ownBuffers); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load buffers for field " + - field + ". error message: " + e.getMessage(), e); + field + ". error message: " + e.getMessage(), e); } List children = field.getChildren(); if (children.size() > 0) { @@ -96,4 +93,4 @@ private void loadBuffers(FieldVector vector, Field field, Iterator buf } } -} +} \ No newline at end of file diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VectorSchemaRoot.java b/java/vector/src/main/java/org/apache/arrow/vector/VectorSchemaRoot.java index 1cbe18787ef45..7e626fb14305e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VectorSchemaRoot.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorSchemaRoot.java @@ -18,7 +18,6 @@ package org.apache.arrow.vector; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -29,6 +28,9 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +/** + * Holder for a set of vectors to be loaded/unloaded + */ public class VectorSchemaRoot implements AutoCloseable { private final Schema schema; @@ -37,9 +39,17 @@ public class VectorSchemaRoot implements AutoCloseable { private final Map fieldVectorsMap = new HashMap<>(); public VectorSchemaRoot(FieldVector parent) { - this.schema = new Schema(parent.getField().getChildren()); - this.rowCount = parent.getAccessor().getValueCount(); - this.fieldVectors = parent.getChildrenFromFields(); + this(parent.getField().getChildren(), parent.getChildrenFromFields(), parent.getAccessor().getValueCount()); + } + + public VectorSchemaRoot(List fields, List fieldVectors, int rowCount) { + if (fields.size() != fieldVectors.size()) { + throw new IllegalArgumentException("Fields must match field vectors. Found " + + fieldVectors.size() + " vectors and " + fields.size() + " fields"); + } + this.schema = new Schema(fields); + this.rowCount = rowCount; + this.fieldVectors = fieldVectors; for (int i = 0; i < schema.getFields().size(); ++i) { Field field = schema.getFields().get(i); FieldVector vector = fieldVectors.get(i); @@ -47,21 +57,19 @@ public VectorSchemaRoot(FieldVector parent) { } } - public VectorSchemaRoot(Schema schema, BufferAllocator allocator) { - super(); - this.schema = schema; + public static VectorSchemaRoot create(Schema schema, BufferAllocator allocator) { List fieldVectors = new ArrayList<>(); for (Field field : schema.getFields()) { MinorType minorType = Types.getMinorTypeForArrowType(field.getType()); - FieldVector vector = minorType.getNewVector(field.getName(), allocator, null); + FieldVector vector = minorType.getNewVector(field.getName(), allocator, field.getDictionary(), null); vector.initializeChildrenFromFields(field.getChildren()); fieldVectors.add(vector); - fieldVectorsMap.put(field.getName(), vector); } - this.fieldVectors = Collections.unmodifiableList(fieldVectors); - if (this.fieldVectors.size() != schema.getFields().size()) { - throw new IllegalArgumentException("The root vector did not create the right number of children. found " + fieldVectors.size() + " expected " + schema.getFields().size()); + if (fieldVectors.size() != schema.getFields().size()) { + throw new IllegalArgumentException("The root vector did not create the right number of children. found " + + fieldVectors.size() + " expected " + schema.getFields().size()); } + return new VectorSchemaRoot(schema.getFields(), fieldVectors, 0); } public List getFieldVectors() { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java index 92d8cb045ae31..8e9ff6d462c5c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java @@ -20,42 +20,27 @@ import java.util.ArrayList; import java.util.List; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.ValueVector.Accessor; import org.apache.arrow.vector.schema.ArrowFieldNode; import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.schema.ArrowVectorType; -import org.apache.arrow.vector.types.pojo.Schema; - -import io.netty.buffer.ArrowBuf; public class VectorUnloader { - private final Schema schema; - private final int valueCount; - private final List vectors; - - public VectorUnloader(Schema schema, int valueCount, List vectors) { - super(); - this.schema = schema; - this.valueCount = valueCount; - this.vectors = vectors; - } + private final VectorSchemaRoot root; public VectorUnloader(VectorSchemaRoot root) { - this(root.getSchema(), root.getRowCount(), root.getFieldVectors()); - } - - public Schema getSchema() { - return schema; + this.root = root; } public ArrowRecordBatch getRecordBatch() { List nodes = new ArrayList<>(); List buffers = new ArrayList<>(); - for (FieldVector vector : vectors) { + for (FieldVector vector : root.getFieldVectors()) { appendNodes(vector, nodes, buffers); } - return new ArrowRecordBatch(valueCount, nodes, buffers); + return new ArrowRecordBatch(root.getRowCount(), nodes, buffers); } private void appendNodes(FieldVector vector, List nodes, List buffers) { @@ -74,4 +59,4 @@ private void appendNodes(FieldVector vector, List nodes, List T addOrGet(String name, MinorType minorType, Class clazz, int... precisionScale); + public abstract T addOrGet(String name, MinorType minorType, Class clazz, DictionaryEncoding dictionary, int... precisionScale); // return the child vector with the input name public abstract T getChild(String name, Class clazz); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractMapVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractMapVector.java index f030d166ade8d..baeeb07873714 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractMapVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractMapVector.java @@ -26,6 +26,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.MapWithOrdinal; @@ -110,7 +111,7 @@ public boolean allocateNewSafe() { * @return resultant {@link org.apache.arrow.vector.ValueVector} */ @Override - public T addOrGet(String name, MinorType minorType, Class clazz, int... precisionScale) { + public T addOrGet(String name, MinorType minorType, Class clazz, DictionaryEncoding dictionary, int... precisionScale) { final ValueVector existing = getChild(name); boolean create = false; if (existing == null) { @@ -122,7 +123,7 @@ public T addOrGet(String name, MinorType minorType, Clas create = true; } if (create) { - final T vector = clazz.cast(minorType.getNewVector(name, allocator, callBack, precisionScale)); + final T vector = clazz.cast(minorType.getNewVector(name, allocator, dictionary, callBack, precisionScale)); putChild(name, vector); if (callBack!=null) { callBack.doWork(); @@ -162,12 +163,12 @@ public T getChild(String name, Class clazz) { return typeify(v, clazz); } - protected ValueVector add(String name, MinorType minorType, int... precisionScale) { + protected ValueVector add(String name, MinorType minorType, DictionaryEncoding dictionary, int... precisionScale) { final ValueVector existing = getChild(name); if (existing != null) { throw new IllegalStateException(String.format("Vector already exists: Existing[%s], Requested[%s] ", existing.getClass().getSimpleName(), minorType)); } - FieldVector vector = minorType.getNewVector(name, allocator, callBack, precisionScale); + FieldVector vector = minorType.getNewVector(name, allocator, dictionary, callBack, precisionScale); putChild(name, vector); if (callBack!=null) { callBack.doWork(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java index 7424df474ae89..eeb8f5830f404 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java @@ -28,6 +28,7 @@ import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.util.SchemaChangeRuntimeException; import com.google.common.base.Preconditions; @@ -150,10 +151,10 @@ public int size() { return vector == DEFAULT_DATA_VECTOR ? 0:1; } - public AddOrGetResult addOrGetVector(MinorType minorType) { + public AddOrGetResult addOrGetVector(MinorType minorType, DictionaryEncoding dictionary) { boolean created = false; if (vector instanceof ZeroVector) { - vector = minorType.getNewVector(DATA_VECTOR_NAME, allocator, null); + vector = minorType.getNewVector(DATA_VECTOR_NAME, allocator, dictionary, null); // returned vector must have the same field created = true; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/DictionaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/DictionaryVector.java deleted file mode 100644 index 84760eadf2253..0000000000000 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/DictionaryVector.java +++ /dev/null @@ -1,229 +0,0 @@ -/******************************************************************************* - - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ******************************************************************************/ -package org.apache.arrow.vector.complex; - -import io.netty.buffer.ArrowBuf; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.OutOfMemoryException; -import org.apache.arrow.vector.NullableIntVector; -import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.complex.reader.FieldReader; -import org.apache.arrow.vector.types.Dictionary; -import org.apache.arrow.vector.types.Types.MinorType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.util.TransferPair; - -import java.util.HashMap; -import java.util.Iterator; -import java.util.Map; - -public class DictionaryVector implements ValueVector { - - private ValueVector indices; - private Dictionary dictionary; - - public DictionaryVector(ValueVector indices, Dictionary dictionary) { - this.indices = indices; - this.dictionary = dictionary; - } - - /** - * Dictionary encodes a vector. The dictionary will be built using the values from the vector. - * - * @param vector vector to encode - * @return dictionary encoded vector - */ - public static DictionaryVector encode(ValueVector vector) { - validateType(vector.getMinorType()); - Map lookUps = new HashMap<>(); - Map transfers = new HashMap<>(); - - ValueVector.Accessor accessor = vector.getAccessor(); - int count = accessor.getValueCount(); - - NullableIntVector indices = new NullableIntVector(vector.getField().getName(), vector.getAllocator()); - indices.allocateNew(count); - NullableIntVector.Mutator mutator = indices.getMutator(); - - int nextIndex = 0; - for (int i = 0; i < count; i++) { - Object value = accessor.getObject(i); - if (value != null) { // if it's null leave it null - Integer index = lookUps.get(value); - if (index == null) { - index = nextIndex++; - lookUps.put(value, index); - transfers.put(i, index); - } - mutator.set(i, index); - } - } - mutator.setValueCount(count); - - // copy the dictionary values into the dictionary vector - TransferPair dictionaryTransfer = vector.getTransferPair(vector.getAllocator()); - ValueVector dictionaryVector = dictionaryTransfer.getTo(); - dictionaryVector.allocateNewSafe(); - for (Map.Entry entry: transfers.entrySet()) { - dictionaryTransfer.copyValueSafe(entry.getKey(), entry.getValue()); - } - dictionaryVector.getMutator().setValueCount(transfers.size()); - Dictionary dictionary = new Dictionary(dictionaryVector, false); - - return new DictionaryVector(indices, dictionary); - } - - /** - * Dictionary encodes a vector with a provided dictionary. The dictionary must contain all values in the vector. - * - * @param vector vector to encode - * @param dictionary dictionary used for encoding - * @return dictionary encoded vector - */ - public static DictionaryVector encode(ValueVector vector, Dictionary dictionary) { - validateType(vector.getMinorType()); - // load dictionary values into a hashmap for lookup - ValueVector.Accessor dictionaryAccessor = dictionary.getDictionary().getAccessor(); - Map lookUps = new HashMap<>(dictionaryAccessor.getValueCount()); - for (int i = 0; i < dictionaryAccessor.getValueCount(); i++) { - // for primitive array types we need a wrapper that implements equals and hashcode appropriately - lookUps.put(dictionaryAccessor.getObject(i), i); - } - - // vector to hold our indices (dictionary encoded values) - NullableIntVector indices = new NullableIntVector(vector.getField().getName(), vector.getAllocator()); - NullableIntVector.Mutator mutator = indices.getMutator(); - - ValueVector.Accessor accessor = vector.getAccessor(); - int count = accessor.getValueCount(); - - indices.allocateNew(count); - - for (int i = 0; i < count; i++) { - Object value = accessor.getObject(i); - if (value != null) { // if it's null leave it null - // note: this may fail if value was not included in the dictionary - mutator.set(i, lookUps.get(value)); - } - } - mutator.setValueCount(count); - - return new DictionaryVector(indices, dictionary); - } - - /** - * Decodes a dictionary encoded array using the provided dictionary. - * - * @param indices dictionary encoded values, must be int type - * @param dictionary dictionary used to decode the values - * @return vector with values restored from dictionary - */ - public static ValueVector decode(ValueVector indices, Dictionary dictionary) { - ValueVector.Accessor accessor = indices.getAccessor(); - int count = accessor.getValueCount(); - ValueVector dictionaryVector = dictionary.getDictionary(); - // copy the dictionary values into the decoded vector - TransferPair transfer = dictionaryVector.getTransferPair(indices.getAllocator()); - transfer.getTo().allocateNewSafe(); - for (int i = 0; i < count; i++) { - Object index = accessor.getObject(i); - if (index != null) { - transfer.copyValueSafe(((Number) index).intValue(), i); - } - } - - ValueVector decoded = transfer.getTo(); - decoded.getMutator().setValueCount(count); - return decoded; - } - - private static void validateType(MinorType type) { - // byte arrays don't work as keys in our dictionary map - we could wrap them with something to - // implement equals and hashcode if we want that functionality - if (type == MinorType.VARBINARY || type == MinorType.LIST || type == MinorType.MAP || type == MinorType.UNION) { - throw new IllegalArgumentException("Dictionary encoding for complex types not implemented"); - } - } - - public ValueVector getIndexVector() { return indices; } - - public ValueVector getDictionaryVector() { return dictionary.getDictionary(); } - - public Dictionary getDictionary() { return dictionary; } - - @Override - public MinorType getMinorType() { return indices.getMinorType(); } - - @Override - public Field getField() { return indices.getField(); } - - // note: dictionary vector is not closed, as it may be shared - @Override - public void close() { indices.close(); } - - @Override - public void allocateNew() throws OutOfMemoryException { indices.allocateNew(); } - - @Override - public boolean allocateNewSafe() { return indices.allocateNewSafe(); } - - @Override - public BufferAllocator getAllocator() { return indices.getAllocator(); } - - @Override - public void setInitialCapacity(int numRecords) { indices.setInitialCapacity(numRecords); } - - @Override - public int getValueCapacity() { return indices.getValueCapacity(); } - - @Override - public int getBufferSize() { return indices.getBufferSize(); } - - @Override - public int getBufferSizeFor(int valueCount) { return indices.getBufferSizeFor(valueCount); } - - @Override - public Iterator iterator() { - return indices.iterator(); - } - - @Override - public void clear() { indices.clear(); } - - @Override - public TransferPair getTransferPair(BufferAllocator allocator) { return indices.getTransferPair(allocator); } - - @Override - public TransferPair getTransferPair(String ref, BufferAllocator allocator) { return indices.getTransferPair(ref, allocator); } - - @Override - public TransferPair makeTransferPair(ValueVector target) { return indices.makeTransferPair(target); } - - @Override - public Accessor getAccessor() { return indices.getAccessor(); } - - @Override - public Mutator getMutator() { return indices.getMutator(); } - - @Override - public FieldReader getReader() { return indices.getReader(); } - - @Override - public ArrowBuf[] getBuffers(boolean clear) { return indices.getBuffers(clear); } -} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index 074b0aa7e58fa..a12440e39e8fe 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -24,6 +24,10 @@ import java.util.Collections; import java.util.List; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ObjectArrays; + +import io.netty.buffer.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.vector.AddOrGetResult; @@ -42,16 +46,12 @@ import org.apache.arrow.vector.schema.ArrowFieldNode; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.JsonStringArrayList; import org.apache.arrow.vector.util.TransferPair; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ObjectArrays; - -import io.netty.buffer.ArrowBuf; - public class ListVector extends BaseRepeatedValueVector implements FieldVector { final UInt4Vector offsets; @@ -62,14 +62,16 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector { private UnionListWriter writer; private UnionListReader reader; private CallBack callBack; + private final DictionaryEncoding dictionary; - public ListVector(String name, BufferAllocator allocator, CallBack callBack) { + public ListVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack) { super(name, allocator); this.bits = new BitVector("$bits$", allocator); this.offsets = getOffsetVector(); this.innerVectors = Collections.unmodifiableList(Arrays.asList(bits, offsets)); this.writer = new UnionListWriter(this); this.reader = new UnionListReader(this); + this.dictionary = dictionary; this.callBack = callBack; } @@ -80,7 +82,7 @@ public void initializeChildrenFromFields(List children) { } Field field = children.get(0); MinorType minorType = Types.getMinorTypeForArrowType(field.getType()); - AddOrGetResult addOrGetVector = addOrGetVector(minorType); + AddOrGetResult addOrGetVector = addOrGetVector(minorType, field.getDictionary()); if (!addOrGetVector.isCreated()) { throw new IllegalArgumentException("Child vector already existed: " + addOrGetVector.getVector()); } @@ -151,16 +153,16 @@ private class TransferImpl implements TransferPair { TransferPair pairs[] = new TransferPair[3]; public TransferImpl(String name, BufferAllocator allocator) { - this(new ListVector(name, allocator, null)); + this(new ListVector(name, allocator, dictionary, null)); } public TransferImpl(ListVector to) { this.to = to; - to.addOrGetVector(vector.getMinorType()); + to.addOrGetVector(vector.getMinorType(), vector.getField().getDictionary()); pairs[0] = offsets.makeTransferPair(to.offsets); pairs[1] = bits.makeTransferPair(to.bits); if (to.getDataVector() instanceof ZeroVector) { - to.addOrGetVector(vector.getMinorType()); + to.addOrGetVector(vector.getMinorType(), vector.getField().getDictionary()); } pairs[2] = getDataVector().makeTransferPair(to.getDataVector()); } @@ -232,8 +234,8 @@ public boolean allocateNewSafe() { return success; } - public AddOrGetResult addOrGetVector(MinorType minorType) { - AddOrGetResult result = super.addOrGetVector(minorType); + public AddOrGetResult addOrGetVector(MinorType minorType, DictionaryEncoding dictionary) { + AddOrGetResult result = super.addOrGetVector(minorType, dictionary); reader = new UnionListReader(this); return result; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java index 31a1bb74b8e98..4d750cad264db 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java @@ -160,7 +160,7 @@ protected MapTransferPair(MapVector from, MapVector to, boolean allocate) { // (This is similar to what happens in ScanBatch where the children cannot be added till they are // read). To take care of this, we ensure that the hashCode of the MaterializedField does not // include the hashCode of the children but is based only on MaterializedField$key. - final FieldVector newVector = to.addOrGet(child, vector.getMinorType(), vector.getClass()); + final FieldVector newVector = to.addOrGet(child, vector.getMinorType(), vector.getClass(), vector.getField().getDictionary()); if (allocate && to.size() != preSize) { newVector.allocateNew(); } @@ -314,12 +314,11 @@ public void close() { public void initializeChildrenFromFields(List children) { for (Field field : children) { MinorType minorType = Types.getMinorTypeForArrowType(field.getType()); - FieldVector vector = (FieldVector)this.add(field.getName(), minorType); + FieldVector vector = (FieldVector)this.add(field.getName(), minorType, field.getDictionary()); vector.initializeChildrenFromFields(field.getChildren()); } } - public List getChildrenFromFields() { return getChildren(); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java index 5fa35307ab683..bb1fdf841a305 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java @@ -34,6 +34,7 @@ import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.ComplexHolder; import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.TransferPair; @@ -48,14 +49,16 @@ public class NullableMapVector extends MapVector implements FieldVector { protected final BitVector bits; private final List innerVectors; + private final DictionaryEncoding dictionary; private final Accessor accessor; private final Mutator mutator; - public NullableMapVector(String name, BufferAllocator allocator, CallBack callBack) { + public NullableMapVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack) { super(name, checkNotNull(allocator), callBack); this.bits = new BitVector("$bits$", allocator); this.innerVectors = Collections.unmodifiableList(Arrays.asList(bits)); + this.dictionary = dictionary; this.accessor = new Accessor(); this.mutator = new Mutator(); } @@ -83,7 +86,7 @@ public FieldReader getReader() { @Override public TransferPair getTransferPair(BufferAllocator allocator) { - return new NullableMapTransferPair(this, new NullableMapVector(name, allocator, callBack), false); + return new NullableMapTransferPair(this, new NullableMapVector(name, allocator, dictionary, callBack), false); } @Override @@ -93,7 +96,7 @@ public TransferPair makeTransferPair(ValueVector to) { @Override public TransferPair getTransferPair(String ref, BufferAllocator allocator) { - return new NullableMapTransferPair(this, new NullableMapVector(ref, allocator, callBack), false); + return new NullableMapTransferPair(this, new NullableMapVector(ref, allocator, dictionary, callBack), false); } protected class NullableMapTransferPair extends MapTransferPair { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java index dbdd2050d13ed..6d0531678488a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java @@ -149,7 +149,8 @@ public MapWriter rootAsMap() { switch(mode){ case INIT: - NullableMapVector map = container.addOrGet(name, MinorType.MAP, NullableMapVector.class); + // TODO allow dictionaries in complex types + NullableMapVector map = container.addOrGet(name, MinorType.MAP, NullableMapVector.class, null); mapRoot = nullableMapWriterFactory.build(map); mapRoot.setPosition(idx()); mode = Mode.MAP; @@ -180,7 +181,8 @@ public ListWriter rootAsList() { case INIT: int vectorCount = container.size(); - ListVector listVector = container.addOrGet(name, MinorType.LIST, ListVector.class); + // TODO allow dictionaries in complex types + ListVector listVector = container.addOrGet(name, MinorType.LIST, ListVector.class, null); if (container.size() > vectorCount) { listVector.allocateNew(); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java index 1f7253bca93c8..e33319a2270b1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java @@ -125,7 +125,7 @@ protected FieldWriter getWriter(MinorType type) { // ??? return null; } - ValueVector v = listVector.addOrGetVector(type).getVector(); + ValueVector v = listVector.addOrGetVector(type, null).getVector(); v.allocateNew(); setWriter(v); writer.setPosition(position); @@ -150,7 +150,8 @@ private FieldWriter promoteToUnion() { TransferPair tp = vector.getTransferPair(vector.getMinorType().name().toLowerCase(), vector.getAllocator()); tp.transfer(); if (parentContainer != null) { - unionVector = parentContainer.addOrGet(name, MinorType.UNION, UnionVector.class); + // TODO allow dictionaries in complex types + unionVector = parentContainer.addOrGet(name, MinorType.UNION, UnionVector.class, null); unionVector.allocateNew(); } else if (listVector != null) { unionVector = listVector.promoteToUnion(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java new file mode 100644 index 0000000000000..0c1cadfdafdbf --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java @@ -0,0 +1,66 @@ +/******************************************************************************* + + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.apache.arrow.vector.dictionary; + +import java.util.Objects; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; + +public class Dictionary { + + private final DictionaryEncoding encoding; + private final FieldVector dictionary; + + public Dictionary(FieldVector dictionary, DictionaryEncoding encoding) { + this.dictionary = dictionary; + this.encoding = encoding; + } + + public FieldVector getVector() { + return dictionary; + } + + public DictionaryEncoding getEncoding() { + return encoding; + } + + public ArrowType getVectorType() { + return dictionary.getField().getType(); + } + + @Override + public String toString() { + return "Dictionary " + encoding + " " + dictionary; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Dictionary that = (Dictionary) o; + return Objects.equals(encoding, that.encoding) && Objects.equals(dictionary, that.dictionary); + } + + @Override + public int hashCode() { + return Objects.hash(encoding, dictionary); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java new file mode 100644 index 0000000000000..0666bc4137a9d --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -0,0 +1,144 @@ +/******************************************************************************* + + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.apache.arrow.vector.dictionary; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; + +import com.google.common.collect.ImmutableList; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.TransferPair; + +public class DictionaryEncoder { + + // TODO recursively examine fields? + + /** + * Dictionary encodes a vector with a provided dictionary. The dictionary must contain all values in the vector. + * + * @param vector vector to encode + * @param dictionary dictionary used for encoding + * @return dictionary encoded vector + */ + public static ValueVector encode(ValueVector vector, Dictionary dictionary) { + validateType(vector.getMinorType()); + // load dictionary values into a hashmap for lookup + ValueVector.Accessor dictionaryAccessor = dictionary.getVector().getAccessor(); + Map lookUps = new HashMap<>(dictionaryAccessor.getValueCount()); + for (int i = 0; i < dictionaryAccessor.getValueCount(); i++) { + // for primitive array types we need a wrapper that implements equals and hashcode appropriately + lookUps.put(dictionaryAccessor.getObject(i), i); + } + + Field valueField = vector.getField(); + Field indexField = new Field(valueField.getName(), valueField.isNullable(), + dictionary.getEncoding().getIndexType(), dictionary.getEncoding(), null); + + // vector to hold our indices (dictionary encoded values) + FieldVector indices = indexField.createVector(vector.getAllocator()); + ValueVector.Mutator mutator = indices.getMutator(); + + // use reflection to pull out the set method + // TODO implement a common interface for int vectors + Method setter = null; + for (Class c: ImmutableList.of(int.class, long.class)) { + try { + setter = mutator.getClass().getMethod("set", int.class, c); + break; + } catch(NoSuchMethodException e) { + // ignore + } + } + if (setter == null) { + throw new IllegalArgumentException("Dictionary encoding does not have a valid int type:" + indices.getClass()); + } + + ValueVector.Accessor accessor = vector.getAccessor(); + int count = accessor.getValueCount(); + + indices.allocateNew(); + + try { + for (int i = 0; i < count; i++) { + Object value = accessor.getObject(i); + if (value != null) { // if it's null leave it null + // note: this may fail if value was not included in the dictionary + Object encoded = lookUps.get(value); + if (encoded == null) { + throw new IllegalArgumentException("Dictionary encoding not defined for value:" + value); + } + setter.invoke(mutator, i, encoded); + } + } + } catch (IllegalAccessException e) { + throw new RuntimeException("IllegalAccessException invoking vector mutator set():", e); + } catch (InvocationTargetException e) { + throw new RuntimeException("InvocationTargetException invoking vector mutator set():", e.getCause()); + } + + mutator.setValueCount(count); + + return indices; + } + + /** + * Decodes a dictionary encoded array using the provided dictionary. + * + * @param indices dictionary encoded values, must be int type + * @param dictionary dictionary used to decode the values + * @return vector with values restored from dictionary + */ + public static ValueVector decode(ValueVector indices, Dictionary dictionary) { + ValueVector.Accessor accessor = indices.getAccessor(); + int count = accessor.getValueCount(); + ValueVector dictionaryVector = dictionary.getVector(); + int dictionaryCount = dictionaryVector.getAccessor().getValueCount(); + // copy the dictionary values into the decoded vector + TransferPair transfer = dictionaryVector.getTransferPair(indices.getAllocator()); + transfer.getTo().allocateNewSafe(); + for (int i = 0; i < count; i++) { + Object index = accessor.getObject(i); + if (index != null) { + int indexAsInt = ((Number) index).intValue(); + if (indexAsInt > dictionaryCount) { + throw new IllegalArgumentException("Provided dictionary does not contain value for index " + indexAsInt); + } + transfer.copyValueSafe(indexAsInt, i); + } + } + // TODO do we need to worry about the field? + ValueVector decoded = transfer.getTo(); + decoded.getMutator().setValueCount(count); + return decoded; + } + + private static void validateType(MinorType type) { + // byte arrays don't work as keys in our dictionary map - we could wrap them with something to + // implement equals and hashcode if we want that functionality + if (type == MinorType.VARBINARY || type == MinorType.LIST || type == MinorType.MAP || type == MinorType.UNION) { + throw new IllegalArgumentException("Dictionary encoding for complex types not implemented: type " + type); + } + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java new file mode 100644 index 0000000000000..63fde2536da8b --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.dictionary; + +import java.util.HashMap; +import java.util.Map; + +public interface DictionaryProvider { + + public Dictionary lookup(long id); + + public static class MapDictionaryProvider implements DictionaryProvider { + + private final Map map; + + public MapDictionaryProvider(Dictionary... dictionaries) { + this.map = new HashMap<>(); + for (Dictionary dictionary: dictionaries) { + put(dictionary); + } + } + + public void put(Dictionary dictionary) { + map.put(dictionary.getEncoding().getId(), dictionary); + } + + @Override + public Dictionary lookup(long id) { + return map.get(id); + } + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java new file mode 100644 index 0000000000000..28440a190ad43 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java @@ -0,0 +1,142 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SeekableByteChannel; +import java.util.Arrays; +import java.util.List; + +import org.apache.arrow.flatbuf.Footer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.schema.ArrowDictionaryBatch; +import org.apache.arrow.vector.schema.ArrowMessage; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.stream.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ArrowFileReader extends ArrowReader { + + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFileReader.class); + + private ArrowFooter footer; + private int currentDictionaryBatch = 0; + private int currentRecordBatch = 0; + + public ArrowFileReader(SeekableByteChannel in, BufferAllocator allocator) { + super(new SeekableReadChannel(in), allocator); + } + + public ArrowFileReader(SeekableReadChannel in, BufferAllocator allocator) { + super(in, allocator); + } + + @Override + protected Schema readSchema(SeekableReadChannel in) throws IOException { + if (footer == null) { + if (in.size() <= (ArrowMagic.MAGIC_LENGTH * 2 + 4)) { + throw new InvalidArrowFileException("file too small: " + in.size()); + } + ByteBuffer buffer = ByteBuffer.allocate(4 + ArrowMagic.MAGIC_LENGTH); + long footerLengthOffset = in.size() - buffer.remaining(); + in.setPosition(footerLengthOffset); + in.readFully(buffer); + buffer.flip(); + byte[] array = buffer.array(); + if (!ArrowMagic.validateMagic(Arrays.copyOfRange(array, 4, array.length))) { + throw new InvalidArrowFileException("missing Magic number " + Arrays.toString(buffer.array())); + } + int footerLength = MessageSerializer.bytesToInt(array); + if (footerLength <= 0 || footerLength + ArrowMagic.MAGIC_LENGTH * 2 + 4 > in.size()) { + throw new InvalidArrowFileException("invalid footer length: " + footerLength); + } + long footerOffset = footerLengthOffset - footerLength; + LOGGER.debug(String.format("Footer starts at %d, length: %d", footerOffset, footerLength)); + ByteBuffer footerBuffer = ByteBuffer.allocate(footerLength); + in.setPosition(footerOffset); + in.readFully(footerBuffer); + footerBuffer.flip(); + Footer footerFB = Footer.getRootAsFooter(footerBuffer); + this.footer = new ArrowFooter(footerFB); + } + return footer.getSchema(); + } + + @Override + protected ArrowMessage readMessage(SeekableReadChannel in, BufferAllocator allocator) throws IOException { + if (currentDictionaryBatch < footer.getDictionaries().size()) { + ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++); + return readDictionaryBatch(in, block, allocator); + } else if (currentRecordBatch < footer.getRecordBatches().size()) { + ArrowBlock block = footer.getRecordBatches().get(currentRecordBatch++); + return readRecordBatch(in, block, allocator); + } else { + return null; + } + } + + public List getDictionaryBlocks() throws IOException { + ensureInitialized(); + return footer.getDictionaries(); + } + + public List getRecordBlocks() throws IOException { + ensureInitialized(); + return footer.getRecordBatches(); + } + + public void loadRecordBatch(ArrowBlock block) throws IOException { + ensureInitialized(); + int blockIndex = footer.getRecordBatches().indexOf(block); + if (blockIndex == -1) { + throw new IllegalArgumentException("Arrow bock does not exist in record batches: " + block); + } + currentRecordBatch = blockIndex; + loadNextBatch(); + } + + private ArrowDictionaryBatch readDictionaryBatch(SeekableReadChannel in, + ArrowBlock block, + BufferAllocator allocator) throws IOException { + LOGGER.debug(String.format("DictionaryRecordBatch at %d, metadata: %d, body: %d", + block.getOffset(), block.getMetadataLength(), block.getBodyLength())); + in.setPosition(block.getOffset()); + ArrowDictionaryBatch batch = MessageSerializer.deserializeDictionaryBatch(in, block, allocator); + if (batch == null) { + throw new IOException("Invalid file. No batch at offset: " + block.getOffset()); + } + return batch; + } + + private ArrowRecordBatch readRecordBatch(SeekableReadChannel in, + ArrowBlock block, + BufferAllocator allocator) throws IOException { + LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", + block.getOffset(), block.getMetadataLength(), + block.getBodyLength())); + in.setPosition(block.getOffset()); + ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(in, block, allocator); + if (batch == null) { + throw new IOException("Invalid file. No batch at offset: " + block.getOffset()); + } + return batch; + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java new file mode 100644 index 0000000000000..23d210a3ee73b --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; +import java.util.List; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ArrowFileWriter extends ArrowWriter { + + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFileWriter.class); + + public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + super(root, provider, out); + } + + @Override + protected void startInternal(WriteChannel out) throws IOException { + ArrowMagic.writeMagic(out); + } + + @Override + protected void endInternal(WriteChannel out, + Schema schema, + List dictionaries, + List records) throws IOException { + long footerStart = out.getCurrentPosition(); + out.write(new ArrowFooter(schema, dictionaries, records), false); + int footerLength = (int)(out.getCurrentPosition() - footerStart); + if (footerLength <= 0) { + throw new InvalidArrowFileException("invalid footer"); + } + out.writeIntLittleEndian(footerLength); + LOGGER.debug(String.format("Footer starts at %d, length: %d", footerStart, footerLength)); + ArrowMagic.writeMagic(out); + LOGGER.debug(String.format("magic written, now at %d", out.getCurrentPosition())); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java index 38903068570c7..1c0008a9184a0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java @@ -38,7 +38,6 @@ public class ArrowFooter implements FBSerializable { private final List recordBatches; public ArrowFooter(Schema schema, List dictionaries, List recordBatches) { - super(); this.schema = schema; this.dictionaries = dictionaries; this.recordBatches = recordBatches; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java new file mode 100644 index 0000000000000..99ea96b3856d5 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +public class ArrowMagic { + + private static final byte[] MAGIC = "ARROW1".getBytes(StandardCharsets.UTF_8); + + public static final int MAGIC_LENGTH = MAGIC.length; + + public static void writeMagic(WriteChannel out) throws IOException { + out.write(MAGIC); + } + + public static boolean validateMagic(byte[] array) { + return Arrays.equals(MAGIC, array); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java index 8f4f4978d66cf..1646fbe803687 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java @@ -18,90 +18,188 @@ package org.apache.arrow.vector.file; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.SeekableByteChannel; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.google.common.collect.ImmutableList; -import org.apache.arrow.flatbuf.Footer; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.schema.ArrowDictionaryBatch; +import org.apache.arrow.vector.schema.ArrowMessage; +import org.apache.arrow.vector.schema.ArrowMessage.ArrowMessageVisitor; import org.apache.arrow.vector.schema.ArrowRecordBatch; -import org.apache.arrow.vector.stream.MessageSerializer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class ArrowReader implements AutoCloseable { - private static final Logger LOGGER = LoggerFactory.getLogger(ArrowReader.class); - - public static final byte[] MAGIC = "ARROW1".getBytes(); +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.Int; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; - private final SeekableByteChannel in; +public abstract class ArrowReader implements DictionaryProvider, AutoCloseable { + private final T in; private final BufferAllocator allocator; - private ArrowFooter footer; + private VectorLoader loader; + private VectorSchemaRoot root; + private Map dictionaries; - public ArrowReader(SeekableByteChannel in, BufferAllocator allocator) { - super(); + private boolean initialized = false; + + protected ArrowReader(T in, BufferAllocator allocator) { this.in = in; this.allocator = allocator; } - private int readFully(ByteBuffer buffer) throws IOException { - int total = 0; - int n; - do { - n = in.read(buffer); - total += n; - } while (n >= 0 && buffer.remaining() > 0); - buffer.flip(); - return total; + /** + * Returns the vector schema root. This will be loaded with new values on every call to loadNextBatch + * + * @return the vector schema root + * @throws IOException if reading of schema fails + */ + public VectorSchemaRoot getVectorSchemaRoot() throws IOException { + ensureInitialized(); + return root; } - public ArrowFooter readFooter() throws IOException { - if (footer == null) { - if (in.size() <= (MAGIC.length * 2 + 4)) { - throw new InvalidArrowFileException("file too small: " + in.size()); - } - ByteBuffer buffer = ByteBuffer.allocate(4 + MAGIC.length); - long footerLengthOffset = in.size() - buffer.remaining(); - in.position(footerLengthOffset); - readFully(buffer); - byte[] array = buffer.array(); - if (!Arrays.equals(MAGIC, Arrays.copyOfRange(array, 4, array.length))) { - throw new InvalidArrowFileException("missing Magic number " + Arrays.toString(buffer.array())); - } - int footerLength = MessageSerializer.bytesToInt(array); - if (footerLength <= 0 || footerLength + MAGIC.length * 2 + 4 > in.size()) { - throw new InvalidArrowFileException("invalid footer length: " + footerLength); - } - long footerOffset = footerLengthOffset - footerLength; - LOGGER.debug(String.format("Footer starts at %d, length: %d", footerOffset, footerLength)); - ByteBuffer footerBuffer = ByteBuffer.allocate(footerLength); - in.position(footerOffset); - readFully(footerBuffer); - Footer footerFB = Footer.getRootAsFooter(footerBuffer); - this.footer = new ArrowFooter(footerFB); + /** + * Returns any dictionaries + * + * @return dictionaries, if any + * @throws IOException if reading of schema fails + */ + public Map getDictionaryVectors() throws IOException { + ensureInitialized(); + return dictionaries; + } + + @Override + public Dictionary lookup(long id) { + if (initialized) { + return dictionaries.get(id); + } else { + return null; } - return footer; } - // TODO: read dictionaries - - public ArrowRecordBatch readRecordBatch(ArrowBlock block) throws IOException { - LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", - block.getOffset(), block.getMetadataLength(), - block.getBodyLength())); - in.position(block.getOffset()); - ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch( - new ReadChannel(in, block.getOffset()), block, allocator); - if (batch == null) { - throw new IOException("Invalid file. No batch at offset: " + block.getOffset()); + public void loadNextBatch() throws IOException { + ensureInitialized(); + // read in all dictionary batches, then stop after our first record batch + ArrowMessageVisitor visitor = new ArrowMessageVisitor() { + @Override + public Boolean visit(ArrowDictionaryBatch message) { + try { load(message); } finally { message.close(); } + return true; + } + @Override + public Boolean visit(ArrowRecordBatch message) { + try { loader.load(message); } finally { message.close(); } + return false; + } + }; + root.setRowCount(0); + ArrowMessage message = readMessage(in, allocator); + while (message != null && message.accepts(visitor)) { + message = readMessage(in, allocator); } - return batch; } + public long bytesRead() { return in.bytesRead(); } + @Override public void close() throws IOException { + if (initialized) { + root.close(); + for (Dictionary dictionary: dictionaries.values()) { + dictionary.getVector().close(); + } + } in.close(); } + + protected abstract Schema readSchema(T in) throws IOException; + + protected abstract ArrowMessage readMessage(T in, BufferAllocator allocator) throws IOException; + + protected void ensureInitialized() throws IOException { + if (!initialized) { + initialize(); + initialized = true; + } + } + + /** + * Reads the schema and initializes the vectors + */ + private void initialize() throws IOException { + Schema schema = readSchema(in); + List fields = new ArrayList<>(); + List vectors = new ArrayList<>(); + Map dictionaries = new HashMap<>(); + + for (Field field: schema.getFields()) { + Field updated = toMemoryFormat(field, dictionaries); + fields.add(updated); + vectors.add(updated.createVector(allocator)); + } + + this.root = new VectorSchemaRoot(fields, vectors, 0); + this.loader = new VectorLoader(root); + this.dictionaries = Collections.unmodifiableMap(dictionaries); + } + + // in the message format, fields have the dictionary type + // in the memory format, they have the index type + private Field toMemoryFormat(Field field, Map dictionaries) { + DictionaryEncoding encoding = field.getDictionary(); + List children = field.getChildren(); + + if (encoding == null && children.isEmpty()) { + return field; + } + + List updatedChildren = new ArrayList<>(children.size()); + for (Field child: children) { + updatedChildren.add(toMemoryFormat(child, dictionaries)); + } + + ArrowType type; + if (encoding == null) { + type = field.getType(); + } else { + // re-type the field for in-memory format + type = encoding.getIndexType(); + if (type == null) { + type = new Int(32, true); + } + // get existing or create dictionary vector + if (!dictionaries.containsKey(encoding.getId())) { + // create a new dictionary vector for the values + Field dictionaryField = new Field(field.getName(), field.isNullable(), field.getType(), null, children); + FieldVector dictionaryVector = dictionaryField.createVector(allocator); + dictionaries.put(encoding.getId(), new Dictionary(dictionaryVector, encoding)); + } + } + + return new Field(field.getName(), field.isNullable(), type, encoding, updatedChildren); + } + + private void load(ArrowDictionaryBatch dictionaryBatch) { + long id = dictionaryBatch.getDictionaryId(); + Dictionary dictionary = dictionaries.get(id); + if (dictionary == null) { + throw new IllegalArgumentException("Dictionary ID " + id + " not defined in schema"); + } + FieldVector vector = dictionary.getVector(); + VectorSchemaRoot root = new VectorSchemaRoot(ImmutableList.of(vector.getField()), ImmutableList.of(vector), 0); + VectorLoader loader = new VectorLoader(root); + loader.load(dictionaryBatch.getDictionary()); + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java index 24c667e67d98d..60a6afb565318 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -21,77 +21,172 @@ import java.nio.channels.WritableByteChannel; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import com.google.common.collect.ImmutableList; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.schema.ArrowDictionaryBatch; import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.stream.MessageSerializer; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ArrowWriter implements AutoCloseable { +public abstract class ArrowWriter implements AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class); + // schema with fields in message format, not memory format + private final Schema schema; private final WriteChannel out; - private final Schema schema; + private final VectorUnloader unloader; + private final List dictionaries; + + private final List dictionaryBlocks = new ArrayList<>(); + private final List recordBlocks = new ArrayList<>(); - private final List recordBatches = new ArrayList<>(); private boolean started = false; + private boolean ended = false; - public ArrowWriter(WritableByteChannel out, Schema schema) { + /** + * Note: fields are not closed when the writer is closed + * + * @param root + * @param provider + * @param out + */ + protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + this.unloader = new VectorUnloader(root); this.out = new WriteChannel(out); - this.schema = schema; + + List fields = new ArrayList<>(root.getSchema().getFields().size()); + Map dictionaryBatches = new HashMap<>(); + + for (Field field: root.getSchema().getFields()) { + fields.add(toMessageFormat(field, provider, dictionaryBatches)); + } + + this.schema = new Schema(fields); + this.dictionaries = Collections.unmodifiableList(new ArrayList<>(dictionaryBatches.values())); + } + + // in the message format, fields have the dictionary type + // in the memory format, they have the index type + private Field toMessageFormat(Field field, DictionaryProvider provider, Map batches) { + DictionaryEncoding encoding = field.getDictionary(); + List children = field.getChildren(); + + if (encoding == null && children.isEmpty()) { + return field; + } + + List updatedChildren = new ArrayList<>(children.size()); + for (Field child: children) { + updatedChildren.add(toMessageFormat(child, provider, batches)); + } + + ArrowType type; + if (encoding == null) { + type = field.getType(); + } else { + long id = encoding.getId(); + Dictionary dictionary = provider.lookup(id); + if (dictionary == null) { + throw new IllegalArgumentException("Could not find dictionary with ID " + id); + } + type = dictionary.getVectorType(); + + if (!batches.containsKey(id)) { + FieldVector vector = dictionary.getVector(); + int count = vector.getAccessor().getValueCount(); + VectorSchemaRoot root = new VectorSchemaRoot(ImmutableList.of(field), ImmutableList.of(vector), count); + VectorUnloader unloader = new VectorUnloader(root); + ArrowRecordBatch batch = unloader.getRecordBatch(); + batches.put(id, new ArrowDictionaryBatch(id, batch)); + } + } + + return new Field(field.getName(), field.isNullable(), type, encoding, updatedChildren); } - private void start() throws IOException { - writeMagic(); - MessageSerializer.serialize(out, schema); + public void start() throws IOException { + ensureStarted(); } - // TODO: write dictionaries + public void writeBatch() throws IOException { + ensureStarted(); + try (ArrowRecordBatch batch = unloader.getRecordBatch()) { + writeRecordBatch(batch); + } + } - public void writeRecordBatch(ArrowRecordBatch recordBatch) throws IOException { - checkStarted(); - ArrowBlock batchDesc = MessageSerializer.serialize(out, recordBatch); + protected void writeRecordBatch(ArrowRecordBatch batch) throws IOException { + ArrowBlock block = MessageSerializer.serialize(out, batch); LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", - batchDesc.getOffset(), batchDesc.getMetadataLength(), batchDesc.getBodyLength())); + block.getOffset(), block.getMetadataLength(), block.getBodyLength())); + recordBlocks.add(block); + } - // add metadata to footer - recordBatches.add(batchDesc); + public void end() throws IOException { + ensureStarted(); + ensureEnded(); } - private void checkStarted() throws IOException { + public long bytesWritten() { return out.getCurrentPosition(); } + + private void ensureStarted() throws IOException { if (!started) { started = true; - start(); + startInternal(out); + // write the schema - for file formats this is duplicated in the footer, but matches + // the streaming format + MessageSerializer.serialize(out, schema); + // write out any dictionaries + for (ArrowDictionaryBatch batch : dictionaries) { + try { + ArrowBlock block = MessageSerializer.serialize(out, batch); + LOGGER.debug(String.format("DictionaryRecordBatch at %d, metadata: %d, body: %d", + block.getOffset(), block.getMetadataLength(), block.getBodyLength())); + dictionaryBlocks.add(block); + } finally { + batch.close(); + } + } } } - @Override - public void close() throws IOException { - try { - long footerStart = out.getCurrentPosition(); - writeFooter(); - int footerLength = (int)(out.getCurrentPosition() - footerStart); - if (footerLength <= 0 ) { - throw new InvalidArrowFileException("invalid footer"); - } - out.writeIntLittleEndian(footerLength); - LOGGER.debug(String.format("Footer starts at %d, length: %d", footerStart, footerLength)); - writeMagic(); - } finally { - out.close(); + private void ensureEnded() throws IOException { + if (!ended) { + ended = true; + endInternal(out, schema, dictionaryBlocks, recordBlocks); } } - private void writeMagic() throws IOException { - out.write(ArrowReader.MAGIC); - LOGGER.debug(String.format("magic written, now at %d", out.getCurrentPosition())); - } + protected abstract void startInternal(WriteChannel out) throws IOException; + + protected abstract void endInternal(WriteChannel out, + Schema schema, + List dictionaries, + List records) throws IOException; - private void writeFooter() throws IOException { - // TODO: dictionaries - out.write(new ArrowFooter(schema, Collections.emptyList(), recordBatches), false); + @Override + public void close() { + try { + end(); + out.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java index a9dc1293b8193..b062f3826eab3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java @@ -32,16 +32,9 @@ public class ReadChannel implements AutoCloseable { private ReadableByteChannel in; private long bytesRead = 0; - // The starting byte offset into 'in'. - private final long startByteOffset; - - public ReadChannel(ReadableByteChannel in, long startByteOffset) { - this.in = in; - this.startByteOffset = startByteOffset; - } public ReadChannel(ReadableByteChannel in) { - this(in, 0); + this.in = in; } public long bytesRead() { return bytesRead; } @@ -72,8 +65,6 @@ public int readFully(ArrowBuf buffer, int l) throws IOException { return n; } - public long getCurrentPositiion() { return startByteOffset + bytesRead; } - @Override public void close() throws IOException { if (this.in != null) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java similarity index 57% rename from java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java rename to java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java index fbe1345f96aa3..914c3cb4b33a9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java @@ -1,5 +1,4 @@ -/******************************************************************************* - +/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -15,26 +14,26 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - ******************************************************************************/ -package org.apache.arrow.vector.types; + */ +package org.apache.arrow.vector.file; -import org.apache.arrow.vector.ValueVector; +import java.io.IOException; +import java.nio.channels.SeekableByteChannel; -public class Dictionary { +public class SeekableReadChannel extends ReadChannel { - private ValueVector dictionary; - private boolean ordered; + private final SeekableByteChannel in; - public Dictionary(ValueVector dictionary, boolean ordered) { - this.dictionary = dictionary; - this.ordered = ordered; + public SeekableReadChannel(SeekableByteChannel in) { + super(in); + this.in = in; } - public ValueVector getDictionary() { - return dictionary; + public void setPosition(long position) throws IOException { + in.position(position); } - public boolean isOrdered() { - return ordered; + public long size() throws IOException { + return in.size(); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java index d99c9a6c99958..42104d181a2d0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java @@ -21,13 +21,12 @@ import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; -import org.apache.arrow.vector.schema.FBSerializable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.google.flatbuffers.FlatBufferBuilder; import io.netty.buffer.ArrowBuf; +import org.apache.arrow.vector.schema.FBSerializable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Wrapper around a WritableByteChannel that maintains the position as well adding diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java index 24fdc184523b3..bdb63b92cb105 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java @@ -88,10 +88,34 @@ public Schema start() throws JsonParseException, IOException { } } + public void read(VectorSchemaRoot root) throws IOException { + JsonToken t = parser.nextToken(); + if (t == START_OBJECT) { + { + int count = readNextField("count", Integer.class); + root.setRowCount(count); + nextFieldIs("columns"); + readToken(START_ARRAY); + { + for (Field field : schema.getFields()) { + FieldVector vector = root.getVector(field.getName()); + readVector(field, vector); + } + } + readToken(END_ARRAY); + } + readToken(END_OBJECT); + } else if (t == END_ARRAY) { + root.setRowCount(0); + } else { + throw new IllegalArgumentException("Invalid token: " + t); + } + } + public VectorSchemaRoot read() throws IOException { JsonToken t = parser.nextToken(); if (t == START_OBJECT) { - VectorSchemaRoot recordBatch = new VectorSchemaRoot(schema, allocator); + VectorSchemaRoot recordBatch = VectorSchemaRoot.create(schema, allocator); { int count = readNextField("count", Integer.class); recordBatch.setRowCount(count); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java new file mode 100644 index 0000000000000..901877b7058cd --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.schema; + +import com.google.flatbuffers.FlatBufferBuilder; +import org.apache.arrow.flatbuf.DictionaryBatch; + +public class ArrowDictionaryBatch implements ArrowMessage { + + private final long dictionaryId; + private final ArrowRecordBatch dictionary; + + public ArrowDictionaryBatch(long dictionaryId, ArrowRecordBatch dictionary) { + this.dictionaryId = dictionaryId; + this.dictionary = dictionary; + } + + public long getDictionaryId() { return dictionaryId; } + public ArrowRecordBatch getDictionary() { return dictionary; } + + @Override + public int writeTo(FlatBufferBuilder builder) { + int dataOffset = dictionary.writeTo(builder); + DictionaryBatch.startDictionaryBatch(builder); + DictionaryBatch.addId(builder, dictionaryId); + DictionaryBatch.addData(builder, dataOffset); + return DictionaryBatch.endDictionaryBatch(builder); + } + + @Override + public int computeBodyLength() { return dictionary.computeBodyLength(); } + + @Override + public T accepts(ArrowMessageVisitor visitor) { return visitor.visit(this); } + + @Override + public String toString() { + return "ArrowDictionaryBatch [dictionaryId=" + dictionaryId + ", dictionary=" + dictionary + "]"; + } + + @Override + public void close() { + dictionary.close(); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java new file mode 100644 index 0000000000000..d307428889b0f --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.schema; + +public interface ArrowMessage extends FBSerializable, AutoCloseable { + + public int computeBodyLength(); + + public T accepts(ArrowMessageVisitor visitor); + + public static interface ArrowMessageVisitor { + public T visit(ArrowDictionaryBatch message); + public T visit(ArrowRecordBatch message); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java index 40c2fbfd984f8..6ef514e568d2d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java @@ -32,7 +32,8 @@ import io.netty.buffer.ArrowBuf; -public class ArrowRecordBatch implements FBSerializable, AutoCloseable { +public class ArrowRecordBatch implements ArrowMessage { + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowRecordBatch.class); /** number of records */ @@ -113,9 +114,13 @@ public int writeTo(FlatBufferBuilder builder) { return RecordBatch.endRecordBatch(builder); } + @Override + public T accepts(ArrowMessageVisitor visitor) { return visitor.visit(this); } + /** * releases the buffers */ + @Override public void close() { if (!closed) { closed = true; @@ -134,6 +139,7 @@ public String toString() { /** * Computes the size of the serialized body for this recordBatch. */ + @Override public int computeBodyLength() { int size = 0; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java index f32966c5d5217..2deef37cd4e56 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java @@ -17,79 +17,43 @@ */ package org.apache.arrow.vector.stream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.channels.Channels; -import java.nio.channels.ReadableByteChannel; - import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.file.ArrowReader; import org.apache.arrow.vector.file.ReadChannel; -import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.schema.ArrowMessage; import org.apache.arrow.vector.types.pojo.Schema; -import com.google.common.base.Preconditions; +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; /** * This classes reads from an input stream and produces ArrowRecordBatches. */ -public class ArrowStreamReader implements AutoCloseable { - private ReadChannel in; - private final BufferAllocator allocator; - private Schema schema; - - /** - * Constructs a streaming read, reading bytes from 'in'. Non-blocking. - */ - public ArrowStreamReader(ReadableByteChannel in, BufferAllocator allocator) { - super(); - this.in = new ReadChannel(in); - this.allocator = allocator; - } - - public ArrowStreamReader(InputStream in, BufferAllocator allocator) { - this(Channels.newChannel(in), allocator); - } - - /** - * Initializes the reader. Must be called before the other APIs. This is blocking. - */ - public void init() throws IOException { - Preconditions.checkState(this.schema == null, "Cannot call init() more than once."); - this.schema = readSchema(); - } +public class ArrowStreamReader extends ArrowReader { - /** - * Returns the schema for all records in this stream. - */ - public Schema getSchema () { - Preconditions.checkState(this.schema != null, "Must call init() first."); - return schema; - } - - public long bytesRead() { return in.bytesRead(); } + /** + * Constructs a streaming read, reading bytes from 'in'. Non-blocking. + */ + public ArrowStreamReader(ReadableByteChannel in, BufferAllocator allocator) { + super(new ReadChannel(in), allocator); + } - /** - * Reads and returns the next ArrowRecordBatch. Returns null if this is the end - * of stream. - */ - public ArrowRecordBatch nextRecordBatch() throws IOException { - Preconditions.checkState(this.in != null, "Cannot call after close()"); - Preconditions.checkState(this.schema != null, "Must call init() first."); - return MessageSerializer.deserializeRecordBatch(in, allocator); - } + public ArrowStreamReader(InputStream in, BufferAllocator allocator) { + this(Channels.newChannel(in), allocator); + } - @Override - public void close() throws IOException { - if (this.in != null) { - in.close(); - in = null; + /** + * Reads the schema message from the beginning of the stream. + */ + @Override + protected Schema readSchema(ReadChannel in) throws IOException { + return MessageSerializer.deserializeSchema(in); } - } - /** - * Reads the schema message from the beginning of the stream. - */ - private Schema readSchema() throws IOException { - return MessageSerializer.deserializeSchema(in); - } + @Override + protected ArrowMessage readMessage(ReadChannel in, BufferAllocator allocator) throws IOException { + return MessageSerializer.deserializeMessageBatch(in, allocator); + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java index 60dc5861c9242..ea29cd99804c8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java @@ -17,63 +17,40 @@ */ package org.apache.arrow.vector.stream; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.file.ArrowBlock; +import org.apache.arrow.vector.file.ArrowWriter; +import org.apache.arrow.vector.file.WriteChannel; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + import java.io.IOException; import java.io.OutputStream; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; +import java.util.List; -import org.apache.arrow.vector.file.WriteChannel; -import org.apache.arrow.vector.schema.ArrowRecordBatch; -import org.apache.arrow.vector.types.pojo.Schema; - -public class ArrowStreamWriter implements AutoCloseable { - private final WriteChannel out; - private final Schema schema; - private boolean headerSent = false; +public class ArrowStreamWriter extends ArrowWriter { - /** - * Creates the stream writer. non-blocking. - * totalBatches can be set if the writer knows beforehand. Can be -1 if unknown. - */ - public ArrowStreamWriter(WritableByteChannel out, Schema schema) { - this.out = new WriteChannel(out); - this.schema = schema; - } - - public ArrowStreamWriter(OutputStream out, Schema schema) - throws IOException { - this(Channels.newChannel(out), schema); - } - - public long bytesWritten() { return out.getCurrentPosition(); } - - public void writeRecordBatch(ArrowRecordBatch batch) throws IOException { - // Send the header if we have not yet. - checkAndSendHeader(); - MessageSerializer.serialize(out, batch); - } + public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, OutputStream out) { + this(root, provider, Channels.newChannel(out)); + } - /** - * End the stream. This is not required and this object can simply be closed. - */ - public void end() throws IOException { - checkAndSendHeader(); - out.writeIntLittleEndian(0); - } + public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + super(root, provider, out); + } - @Override - public void close() throws IOException { - // The header might not have been sent if this is an empty stream. Send it even in - // this case so readers see a valid empty stream. - checkAndSendHeader(); - out.close(); - } + @Override + protected void startInternal(WriteChannel out) throws IOException {} - private void checkAndSendHeader() throws IOException { - if (!headerSent) { - MessageSerializer.serialize(out, schema); - headerSent = true; + @Override + protected void endInternal(WriteChannel out, + Schema schema, + List dictionaries, + List records) throws IOException { + out.writeIntLittleEndian(0); } - } } - diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java index 92df2504bcb23..92a6c0c26ba6e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java @@ -22,7 +22,11 @@ import java.util.ArrayList; import java.util.List; +import com.google.flatbuffers.FlatBufferBuilder; + +import io.netty.buffer.ArrowBuf; import org.apache.arrow.flatbuf.Buffer; +import org.apache.arrow.flatbuf.DictionaryBatch; import org.apache.arrow.flatbuf.FieldNode; import org.apache.arrow.flatbuf.Message; import org.apache.arrow.flatbuf.MessageHeader; @@ -33,14 +37,12 @@ import org.apache.arrow.vector.file.ReadChannel; import org.apache.arrow.vector.file.WriteChannel; import org.apache.arrow.vector.schema.ArrowBuffer; +import org.apache.arrow.vector.schema.ArrowDictionaryBatch; import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.schema.ArrowMessage; import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; -import com.google.flatbuffers.FlatBufferBuilder; - -import io.netty.buffer.ArrowBuf; - /** * Utility class for serializing Messages. Messages are all serialized a similar way. * 1. 4 byte little endian message header prefix @@ -81,35 +83,39 @@ public static long serialize(WriteChannel out, Schema schema) throws IOException * Deserializes a schema object. Format is from serialize(). */ public static Schema deserializeSchema(ReadChannel in) throws IOException { - Message message = deserializeMessage(in, MessageHeader.Schema); + Message message = deserializeMessage(in); if (message == null) { throw new IOException("Unexpected end of input. Missing schema."); } + if (message.headerType() != MessageHeader.Schema) { + throw new IOException("Expected schema but header was " + message.headerType()); + } return Schema.convertSchema((org.apache.arrow.flatbuf.Schema) message.header(new org.apache.arrow.flatbuf.Schema())); } + /** * Serializes an ArrowRecordBatch. Returns the offset and length of the written batch. */ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) - throws IOException { + throws IOException { + long start = out.getCurrentPosition(); int bodyLength = batch.computeBodyLength(); FlatBufferBuilder builder = new FlatBufferBuilder(); int batchOffset = batch.writeTo(builder); - ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch, - batchOffset, bodyLength); + ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch, batchOffset, bodyLength); int metadataLength = serializedMessage.remaining(); - // Add extra padding bytes so that length prefix + metadata is a multiple - // of 8 after alignment - if ((start + metadataLength + 4) % 8 != 0) { - metadataLength += 8 - (start + metadataLength + 4) % 8; + // calculate alignment bytes so that metadata length points to the correct location after alignment + int padding = (int)((start + metadataLength + 4) % 8); + if (padding != 0) { + metadataLength += (8 - padding); } out.writeIntLittleEndian(metadataLength); @@ -118,6 +124,13 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) // Align the output to 8 byte boundary. out.align(); + long bufferLength = writeBatchBuffers(out, batch); + + // Metadata size in the Block account for the size prefix + return new ArrowBlock(start, metadataLength + 4, bufferLength); + } + + private static long writeBatchBuffers(WriteChannel out, ArrowRecordBatch batch) throws IOException { long bufferStart = out.getCurrentPosition(); List buffers = batch.getBuffers(); List buffersLayout = batch.getBuffersLayout(); @@ -135,22 +148,14 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) " != " + startPosition + layout.getSize()); } } - // Metadata size in the Block account for the size prefix - return new ArrowBlock(start, metadataLength + 4, out.getCurrentPosition() - bufferStart); + return out.getCurrentPosition() - bufferStart; } /** * Deserializes a RecordBatch */ - public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, - BufferAllocator alloc) throws IOException { - Message message = deserializeMessage(in, MessageHeader.RecordBatch); - if (message == null) return null; - - if (message.bodyLength() > Integer.MAX_VALUE) { - throw new IOException("Cannot currently deserialize record batches over 2GB"); - } - + private static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, Message message, BufferAllocator alloc) + throws IOException { RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch()); int bodyLength = (int) message.bodyLength(); @@ -191,9 +196,7 @@ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock // Now read the body final ArrowBuf body = buffer.slice(block.getMetadataLength(), (int) totalLen - block.getMetadataLength()); - ArrowRecordBatch result = deserializeRecordBatch(recordBatchFB, body); - - return result; + return deserializeRecordBatch(recordBatchFB, body); } // Deserializes a record batch given the Flatbuffer metadata and in-memory body @@ -218,6 +221,106 @@ private static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB return arrowRecordBatch; } + /** + * Serializes a dictionary ArrowRecordBatch. Returns the offset and length of the written batch. + */ + public static ArrowBlock serialize(WriteChannel out, ArrowDictionaryBatch batch) throws IOException { + long start = out.getCurrentPosition(); + int bodyLength = batch.computeBodyLength(); + + FlatBufferBuilder builder = new FlatBufferBuilder(); + int batchOffset = batch.writeTo(builder); + + ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.DictionaryBatch, batchOffset, bodyLength); + + int metadataLength = serializedMessage.remaining(); + + // Add extra padding bytes so that length prefix + metadata is a multiple + // of 8 after alignment + if ((start + metadataLength + 4) % 8 != 0) { + metadataLength += 8 - (start + metadataLength + 4) % 8; + } + + out.writeIntLittleEndian(metadataLength); + out.write(serializedMessage); + + // Align the output to 8 byte boundary. + out.align(); + + // write the embedded record batch + long bufferLength = writeBatchBuffers(out, batch.getDictionary()); + + // Metadata size in the Block account for the size prefix + return new ArrowBlock(start, metadataLength + 4, bufferLength + 8); + } + + /** + * Deserializes a DictionaryBatch + */ + private static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in, + Message message, + BufferAllocator alloc) throws IOException { + DictionaryBatch dictionaryBatchFB = (DictionaryBatch) message.header(new DictionaryBatch()); + + int bodyLength = (int) message.bodyLength(); + + // Now read the record batch body + ArrowBuf body = alloc.buffer(bodyLength); + if (in.readFully(body, bodyLength) != bodyLength) { + throw new IOException("Unexpected end of input trying to read batch."); + } + ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), body); + return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch); + } + + /** + * Deserializes a DictionaryBatch knowing the size of the entire message up front. This + * minimizes the number of reads to the underlying stream. + */ + public static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in, + ArrowBlock block, + BufferAllocator alloc) throws IOException { + // Metadata length contains integer prefix plus byte padding + long totalLen = block.getMetadataLength() + block.getBodyLength(); + + if (totalLen > Integer.MAX_VALUE) { + throw new IOException("Cannot currently deserialize record batches over 2GB"); + } + + ArrowBuf buffer = alloc.buffer((int) totalLen); + if (in.readFully(buffer, (int) totalLen) != totalLen) { + throw new IOException("Unexpected end of input trying to read batch."); + } + + ArrowBuf metadataBuffer = buffer.slice(4, block.getMetadataLength() - 4); + + Message messageFB = + Message.getRootAsMessage(metadataBuffer.nioBuffer().asReadOnlyBuffer()); + + DictionaryBatch dictionaryBatchFB = (DictionaryBatch) messageFB.header(new DictionaryBatch()); + + // Now read the body + final ArrowBuf body = buffer.slice(block.getMetadataLength(), + (int) totalLen - block.getMetadataLength()); + ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), body); + return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch); + } + + public static ArrowMessage deserializeMessageBatch(ReadChannel in, BufferAllocator alloc) throws IOException { + Message message = deserializeMessage(in); + if (message == null) { + return null; + } else if (message.bodyLength() > Integer.MAX_VALUE) { + throw new IOException("Cannot currently deserialize record batches over 2GB"); + } + + switch (message.headerType()) { + case MessageHeader.RecordBatch: return deserializeRecordBatch(in, message, alloc); + case MessageHeader.DictionaryBatch: return deserializeDictionaryBatch(in, message, alloc); + default: throw new IOException("Unexpected message header type " + message.headerType()); + } + } + /** * Serializes a message header. */ @@ -232,7 +335,7 @@ private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte heade return builder.dataBuffer(); } - private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException { + private static Message deserializeMessage(ReadChannel in) throws IOException { // Read the message size. There is an i32 little endian prefix. ByteBuffer buffer = ByteBuffer.allocate(4); if (in.readFully(buffer) != 4) return null; @@ -246,11 +349,6 @@ private static Message deserializeMessage(ReadChannel in, byte headerType) throw } buffer.rewind(); - Message message = Message.getRootAsMessage(buffer); - if (message.headerType() != headerType) { - throw new IOException("Invalid message: expecting " + headerType + - ". Message contained: " + message.headerType()); - } - return message; + return Message.getRootAsMessage(buffer); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java index ab539d5dc3b6e..8f2d04224c0fd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java @@ -33,10 +33,10 @@ import org.apache.arrow.vector.NullableIntervalDayVector; import org.apache.arrow.vector.NullableIntervalYearVector; import org.apache.arrow.vector.NullableSmallIntVector; -import org.apache.arrow.vector.NullableTimeStampSecVector; -import org.apache.arrow.vector.NullableTimeStampMilliVector; import org.apache.arrow.vector.NullableTimeStampMicroVector; +import org.apache.arrow.vector.NullableTimeStampMilliVector; import org.apache.arrow.vector.NullableTimeStampNanoVector; +import org.apache.arrow.vector.NullableTimeStampSecVector; import org.apache.arrow.vector.NullableTimeVector; import org.apache.arrow.vector.NullableTinyIntVector; import org.apache.arrow.vector.NullableUInt1Vector; @@ -61,10 +61,10 @@ import org.apache.arrow.vector.complex.impl.IntervalYearWriterImpl; import org.apache.arrow.vector.complex.impl.NullableMapWriter; import org.apache.arrow.vector.complex.impl.SmallIntWriterImpl; -import org.apache.arrow.vector.complex.impl.TimeStampSecWriterImpl; -import org.apache.arrow.vector.complex.impl.TimeStampMilliWriterImpl; import org.apache.arrow.vector.complex.impl.TimeStampMicroWriterImpl; +import org.apache.arrow.vector.complex.impl.TimeStampMilliWriterImpl; import org.apache.arrow.vector.complex.impl.TimeStampNanoWriterImpl; +import org.apache.arrow.vector.complex.impl.TimeStampSecWriterImpl; import org.apache.arrow.vector.complex.impl.TimeWriterImpl; import org.apache.arrow.vector.complex.impl.TinyIntWriterImpl; import org.apache.arrow.vector.complex.impl.UInt1WriterImpl; @@ -92,6 +92,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType.Timestamp; import org.apache.arrow.vector.types.pojo.ArrowType.Union; import org.apache.arrow.vector.types.pojo.ArrowType.Utf8; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.util.CallBack; @@ -129,7 +130,7 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { return ZeroVector.INSTANCE; } @@ -145,8 +146,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableMapVector(name, allocator, callBack); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableMapVector(name, allocator, dictionary, callBack); } @Override @@ -161,8 +162,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableTinyIntVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableTinyIntVector(name, allocator, dictionary); } @Override @@ -177,8 +178,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableSmallIntVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableSmallIntVector(name, allocator, dictionary); } @Override @@ -193,8 +194,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableIntVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableIntVector(name, allocator, dictionary); } @Override @@ -209,8 +210,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableBigIntVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableBigIntVector(name, allocator, dictionary); } @Override @@ -225,8 +226,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableDateVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableDateVector(name, allocator, dictionary); } @Override @@ -241,8 +242,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableTimeVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableTimeVector(name, allocator, dictionary); } @Override @@ -258,8 +259,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableTimeStampSecVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableTimeStampSecVector(name, allocator, dictionary); } @Override @@ -275,8 +276,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableTimeStampMilliVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableTimeStampMilliVector(name, allocator, dictionary); } @Override @@ -292,8 +293,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableTimeStampMicroVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableTimeStampMicroVector(name, allocator, dictionary); } @Override @@ -309,8 +310,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableTimeStampNanoVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableTimeStampNanoVector(name, allocator, dictionary); } @Override @@ -325,8 +326,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableIntervalDayVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableIntervalDayVector(name, allocator, dictionary); } @Override @@ -341,8 +342,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableIntervalDayVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableIntervalDayVector(name, allocator, dictionary); } @Override @@ -358,8 +359,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableFloat4Vector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableFloat4Vector(name, allocator, dictionary); } @Override @@ -375,8 +376,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableFloat8Vector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableFloat8Vector(name, allocator, dictionary); } @Override @@ -391,8 +392,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableBitVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableBitVector(name, allocator, dictionary); } @Override @@ -407,8 +408,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableVarCharVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableVarCharVector(name, allocator, dictionary); } @Override @@ -423,8 +424,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableVarBinaryVector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableVarBinaryVector(name, allocator, dictionary); } @Override @@ -443,8 +444,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableDecimalVector(name, allocator, precisionScale[0], precisionScale[1]); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableDecimalVector(name, allocator, dictionary, precisionScale[0], precisionScale[1]); } @Override @@ -459,8 +460,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableUInt1Vector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableUInt1Vector(name, allocator, dictionary); } @Override @@ -475,8 +476,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableUInt2Vector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableUInt2Vector(name, allocator, dictionary); } @Override @@ -491,8 +492,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableUInt4Vector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableUInt4Vector(name, allocator, dictionary); } @Override @@ -507,8 +508,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new NullableUInt8Vector(name, allocator); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new NullableUInt8Vector(name, allocator, dictionary); } @Override @@ -523,8 +524,8 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { - return new ListVector(name, allocator, callBack); + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + return new ListVector(name, allocator, dictionary, callBack); } @Override @@ -539,7 +540,10 @@ public Field getField() { } @Override - public FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale) { + public FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale) { + if (dictionary != null) { + throw new UnsupportedOperationException("Dictionary encoding not supported for complex types"); + } return new UnionVector(name, allocator, callBack); } @@ -561,7 +565,7 @@ public ArrowType getType() { public abstract Field getField(); - public abstract FieldVector getNewVector(String name, BufferAllocator allocator, CallBack callBack, int... precisionScale); + public abstract FieldVector getNewVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack, int... precisionScale); public abstract FieldWriter getNewFieldWriter(ValueVector vector); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java new file mode 100644 index 0000000000000..6d35cdef832f9 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/DictionaryEncoding.java @@ -0,0 +1,51 @@ +/******************************************************************************* + + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.apache.arrow.vector.types.pojo; + +import org.apache.arrow.vector.types.pojo.ArrowType.Int; + +public class DictionaryEncoding { + + private final long id; + private final boolean ordered; + private final Int indexType; + + public DictionaryEncoding(long id, boolean ordered, Int indexType) { + this.id = id; + this.ordered = ordered; + this.indexType = indexType == null ? new Int(32, true) : indexType; + } + + public long getId() { + return id; + } + + public boolean isOrdered() { + return ordered; + } + + public Int getIndexType() { + return indexType; + } + + @Override + public String toString() { + return "DictionaryEncoding[id=" + id + ",ordered=" + ordered + ",indexType=" + indexType + "]"; + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java index 2d528e4141907..b8db46f346ad1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/Field.java @@ -24,23 +24,27 @@ import java.util.List; import java.util.Objects; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; -import org.apache.arrow.flatbuf.DictionaryEncoding; -import org.apache.arrow.vector.schema.TypeLayout; -import org.apache.arrow.vector.schema.VectorLayout; - -import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.flatbuffers.FlatBufferBuilder; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.schema.TypeLayout; +import org.apache.arrow.vector.schema.VectorLayout; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType.Int; + public class Field { private final String name; private final boolean nullable; private final ArrowType type; - private final Long dictionary; + private final DictionaryEncoding dictionary; private final List children; private final TypeLayout typeLayout; @@ -49,7 +53,7 @@ private Field( @JsonProperty("name") String name, @JsonProperty("nullable") boolean nullable, @JsonProperty("type") ArrowType type, - @JsonProperty("dictionary") Long dictionary, + @JsonProperty("dictionary") DictionaryEncoding dictionary, @JsonProperty("children") List children, @JsonProperty("typeLayout") TypeLayout typeLayout) { this.name = name; @@ -68,18 +72,30 @@ public Field(String name, boolean nullable, ArrowType type, List children this(name, nullable, type, null, children, TypeLayout.getTypeLayout(checkNotNull(type))); } - public Field(String name, boolean nullable, ArrowType type, Long dictionary, List children) { + public Field(String name, boolean nullable, ArrowType type, DictionaryEncoding dictionary, List children) { this(name, nullable, type, dictionary, children, TypeLayout.getTypeLayout(checkNotNull(type))); } + public FieldVector createVector(BufferAllocator allocator) { + MinorType minorType = Types.getMinorTypeForArrowType(type); + FieldVector vector = minorType.getNewVector(name, allocator, dictionary, null); + vector.initializeChildrenFromFields(children); + return vector; + } + public static Field convertField(org.apache.arrow.flatbuf.Field field) { String name = field.name(); boolean nullable = field.nullable(); ArrowType type = getTypeForField(field); - DictionaryEncoding dictionaryEncoding = field.dictionary(); - Long dictionary = null; - if (dictionaryEncoding != null) { - dictionary = dictionaryEncoding.id(); + DictionaryEncoding dictionary = null; + org.apache.arrow.flatbuf.DictionaryEncoding dictionaryFB = field.dictionary(); + if (dictionaryFB != null) { + Int indexType = null; + org.apache.arrow.flatbuf.Int indexTypeFB = dictionaryFB.indexType(); + if (indexTypeFB != null) { + indexType = new Int(indexTypeFB.bitWidth(), indexTypeFB.isSigned()); + } + dictionary = new DictionaryEncoding(dictionaryFB.id(), dictionaryFB.isOrdered(), indexType); } ImmutableList.Builder layout = ImmutableList.builder(); for (int i = 0; i < field.layoutLength(); ++i) { @@ -105,8 +121,11 @@ public int getField(FlatBufferBuilder builder) { int typeOffset = type.getType(builder); int dictionaryOffset = -1; if (dictionary != null) { - builder.addLong(dictionary); - dictionaryOffset = builder.offset(); + // TODO encode dictionary type - currently type is only signed 32 bit int (default null) + org.apache.arrow.flatbuf.DictionaryEncoding.startDictionaryEncoding(builder); + org.apache.arrow.flatbuf.DictionaryEncoding.addId(builder, dictionary.getId()); + org.apache.arrow.flatbuf.DictionaryEncoding.addIsOrdered(builder, dictionary.isOrdered()); + dictionaryOffset = org.apache.arrow.flatbuf.DictionaryEncoding.endDictionaryEncoding(builder); } int[] childrenData = new int[children.size()]; for (int i = 0; i < children.size(); i++) { @@ -126,11 +145,11 @@ public int getField(FlatBufferBuilder builder) { org.apache.arrow.flatbuf.Field.addNullable(builder, nullable); org.apache.arrow.flatbuf.Field.addTypeType(builder, type.getTypeID().getFlatbufID()); org.apache.arrow.flatbuf.Field.addType(builder, typeOffset); + org.apache.arrow.flatbuf.Field.addChildren(builder, childrenOffset); + org.apache.arrow.flatbuf.Field.addLayout(builder, layoutOffset); if (dictionary != null) { org.apache.arrow.flatbuf.Field.addDictionary(builder, dictionaryOffset); } - org.apache.arrow.flatbuf.Field.addChildren(builder, childrenOffset); - org.apache.arrow.flatbuf.Field.addLayout(builder, layoutOffset); return org.apache.arrow.flatbuf.Field.endField(builder); } @@ -147,7 +166,7 @@ public ArrowType getType() { } @JsonInclude(Include.NON_NULL) - public Long getDictionary() { return dictionary; } + public DictionaryEncoding getDictionary() { return dictionary; } public List getChildren() { return children; @@ -168,8 +187,8 @@ public boolean equals(Object obj) { Objects.equals(this.type, that.type) && Objects.equals(this.dictionary, that.dictionary) && (Objects.equals(this.children, that.children) || - (this.children == null && that.children.size() == 0) || - (this.children.size() == 0 && that.children == null)); + (this.children == null || this.children.size() == 0) && + (that.children == null || that.children.size() == 0)); } @Override @@ -180,7 +199,7 @@ public String toString() { } sb.append(type); if (dictionary != null) { - sb.append("[dictionary: ").append(dictionary).append("]"); + sb.append("[dictionary: ").append(dictionary.getId()).append("]"); } if (!children.isEmpty()) { sb.append("<").append(Joiner.on(", ").join(children)).append(">"); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java index cca35e44a215d..20f4aa8cf643d 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java @@ -44,7 +44,7 @@ public class TestDecimalVector { @Test public void test() { BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - NullableDecimalVector decimalVector = new NullableDecimalVector("decimal", allocator, 10, scale); + NullableDecimalVector decimalVector = new NullableDecimalVector("decimal", allocator, null, 10, scale); decimalVector.allocateNew(); BigDecimal[] values = new BigDecimal[intValues.length]; for (int i = 0; i < intValues.length; i++) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index 962950abec87a..e3087ef8c95cc 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -18,16 +18,16 @@ package org.apache.arrow.vector; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.complex.DictionaryVector; -import org.apache.arrow.vector.types.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.junit.After; import org.junit.Before; import org.junit.Test; import java.nio.charset.StandardCharsets; -import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; public class TestDictionaryVector { @@ -49,65 +49,10 @@ public void terminate() throws Exception { } @Test - public void testEncodeStringsWithGeneratedDictionary() { + public void testEncodeStrings() { // Create a new value vector - try (final NullableVarCharVector vector = (NullableVarCharVector) MinorType.VARCHAR.getNewVector("foo", allocator, null)) { - final NullableVarCharVector.Mutator m = vector.getMutator(); - vector.allocateNew(512, 5); - - // set some values - m.setSafe(0, zero, 0, zero.length); - m.setSafe(1, one, 0, one.length); - m.setSafe(2, one, 0, one.length); - m.setSafe(3, two, 0, two.length); - m.setSafe(4, zero, 0, zero.length); - m.setValueCount(5); - - DictionaryVector encoded = DictionaryVector.encode(vector); - - try { - // verify values in the dictionary - ValueVector dictionary = encoded.getDictionaryVector(); - assertEquals(vector.getClass(), dictionary.getClass()); - - NullableVarCharVector.Accessor dictionaryAccessor = ((NullableVarCharVector) dictionary).getAccessor(); - assertEquals(3, dictionaryAccessor.getValueCount()); - assertArrayEquals(zero, dictionaryAccessor.get(0)); - assertArrayEquals(one, dictionaryAccessor.get(1)); - assertArrayEquals(two, dictionaryAccessor.get(2)); - - // verify indices - ValueVector indices = encoded.getIndexVector(); - assertEquals(NullableIntVector.class, indices.getClass()); - - NullableIntVector.Accessor indexAccessor = ((NullableIntVector) indices).getAccessor(); - assertEquals(5, indexAccessor.getValueCount()); - assertEquals(0, indexAccessor.get(0)); - assertEquals(1, indexAccessor.get(1)); - assertEquals(1, indexAccessor.get(2)); - assertEquals(2, indexAccessor.get(3)); - assertEquals(0, indexAccessor.get(4)); - - // now run through the decoder and verify we get the original back - try (ValueVector decoded = DictionaryVector.decode(indices, encoded.getDictionary())) { - assertEquals(vector.getClass(), decoded.getClass()); - assertEquals(vector.getAccessor().getValueCount(), decoded.getAccessor().getValueCount()); - for (int i = 0; i < 5; i++) { - assertEquals(vector.getAccessor().getObject(i), decoded.getAccessor().getObject(i)); - } - } - } finally { - encoded.getDictionaryVector().close(); - encoded.getIndexVector().close(); - } - } - } - - @Test - public void testEncodeStringsWithProvidedDictionary() { - // Create a new value vector - try (final NullableVarCharVector vector = (NullableVarCharVector) MinorType.VARCHAR.getNewVector("foo", allocator, null); - final NullableVarCharVector dictionary = (NullableVarCharVector) MinorType.VARCHAR.getNewVector("dict", allocator, null)) { + try (final NullableVarCharVector vector = (NullableVarCharVector) MinorType.VARCHAR.getNewVector("foo", allocator, null, null); + final NullableVarCharVector dictionaryVector = (NullableVarCharVector) MinorType.VARCHAR.getNewVector("dict", allocator, null, null)) { final NullableVarCharVector.Mutator m = vector.getMutator(); vector.allocateNew(512, 5); @@ -120,19 +65,20 @@ public void testEncodeStringsWithProvidedDictionary() { m.setValueCount(5); // set some dictionary values - final NullableVarCharVector.Mutator m2 = dictionary.getMutator(); - dictionary.allocateNew(512, 3); + final NullableVarCharVector.Mutator m2 = dictionaryVector.getMutator(); + dictionaryVector.allocateNew(512, 3); m2.setSafe(0, zero, 0, zero.length); m2.setSafe(1, one, 0, one.length); m2.setSafe(2, two, 0, two.length); m2.setValueCount(3); - try(final DictionaryVector encoded = DictionaryVector.encode(vector, new Dictionary(dictionary, false))) { + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + + try(final ValueVector encoded = (FieldVector) DictionaryEncoder.encode(vector, dictionary)) { // verify indices - ValueVector indices = encoded.getIndexVector(); - assertEquals(NullableIntVector.class, indices.getClass()); + assertEquals(NullableIntVector.class, encoded.getClass()); - NullableIntVector.Accessor indexAccessor = ((NullableIntVector) indices).getAccessor(); + NullableIntVector.Accessor indexAccessor = ((NullableIntVector) encoded).getAccessor(); assertEquals(5, indexAccessor.getValueCount()); assertEquals(0, indexAccessor.get(0)); assertEquals(1, indexAccessor.get(1)); @@ -141,7 +87,7 @@ public void testEncodeStringsWithProvidedDictionary() { assertEquals(0, indexAccessor.get(4)); // now run through the decoder and verify we get the original back - try (ValueVector decoded = DictionaryVector.decode(indices, encoded.getDictionary())) { + try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { assertEquals(vector.getClass(), decoded.getClass()); assertEquals(vector.getAccessor().getValueCount(), decoded.getAccessor().getValueCount()); for (int i = 0; i < 5; i++) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java index 1f0baaed776a1..18d93b6401e39 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java @@ -42,8 +42,8 @@ public void terminate() throws Exception { @Test public void testCopyFrom() throws Exception { - try (ListVector inVector = new ListVector("input", allocator, null); - ListVector outVector = new ListVector("output", allocator, null)) { + try (ListVector inVector = new ListVector("input", allocator, null, null); + ListVector outVector = new ListVector("output", allocator, null, null)) { UnionListWriter writer = inVector.getWriter(); writer.allocate(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java index 774b59e3683e3..6917638d74e4d 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java @@ -86,7 +86,7 @@ public void testFixedType() { public void testNullableVarLen2() { // Create a new value vector for 1024 integers. - try (final NullableVarCharVector vector = new NullableVarCharVector(EMPTY_SCHEMA_PATH, allocator)) { + try (final NullableVarCharVector vector = new NullableVarCharVector(EMPTY_SCHEMA_PATH, allocator, null)) { final NullableVarCharVector.Mutator m = vector.getMutator(); vector.allocateNew(1024 * 10, 1024); @@ -116,7 +116,7 @@ public void testNullableVarLen2() { public void testNullableFixedType() { // Create a new value vector for 1024 integers. - try (final NullableUInt4Vector vector = new NullableUInt4Vector(EMPTY_SCHEMA_PATH, allocator)) { + try (final NullableUInt4Vector vector = new NullableUInt4Vector(EMPTY_SCHEMA_PATH, allocator, null)) { final NullableUInt4Vector.Mutator m = vector.getMutator(); vector.allocateNew(1024); @@ -186,7 +186,7 @@ public void testNullableFixedType() { @Test public void testNullableFloat() { // Create a new value vector for 1024 integers - try (final NullableFloat4Vector vector = (NullableFloat4Vector) MinorType.FLOAT4.getNewVector(EMPTY_SCHEMA_PATH, allocator, null)) { + try (final NullableFloat4Vector vector = (NullableFloat4Vector) MinorType.FLOAT4.getNewVector(EMPTY_SCHEMA_PATH, allocator, null, null)) { final NullableFloat4Vector.Mutator m = vector.getMutator(); vector.allocateNew(1024); @@ -233,7 +233,7 @@ public void testNullableFloat() { @Test public void testNullableInt() { // Create a new value vector for 1024 integers - try (final NullableIntVector vector = (NullableIntVector) MinorType.INT.getNewVector(EMPTY_SCHEMA_PATH, allocator, null)) { + try (final NullableIntVector vector = (NullableIntVector) MinorType.INT.getNewVector(EMPTY_SCHEMA_PATH, allocator, null, null)) { final NullableIntVector.Mutator m = vector.getMutator(); vector.allocateNew(1024); @@ -403,7 +403,7 @@ private void validateRange(int length, int start, int count) { @Test public void testReAllocNullableFixedWidthVector() { // Create a new value vector for 1024 integers - try (final NullableFloat4Vector vector = (NullableFloat4Vector) MinorType.FLOAT4.getNewVector(EMPTY_SCHEMA_PATH, allocator, null)) { + try (final NullableFloat4Vector vector = (NullableFloat4Vector) MinorType.FLOAT4.getNewVector(EMPTY_SCHEMA_PATH, allocator, null, null)) { final NullableFloat4Vector.Mutator m = vector.getMutator(); vector.allocateNew(1024); @@ -436,7 +436,7 @@ public void testReAllocNullableFixedWidthVector() { @Test public void testReAllocNullableVariableWidthVector() { // Create a new value vector for 1024 integers - try (final NullableVarCharVector vector = (NullableVarCharVector) MinorType.VARCHAR.getNewVector(EMPTY_SCHEMA_PATH, allocator, null)) { + try (final NullableVarCharVector vector = (NullableVarCharVector) MinorType.VARCHAR.getNewVector(EMPTY_SCHEMA_PATH, allocator, null, null)) { final NullableVarCharVector.Mutator m = vector.getMutator(); vector.allocateNew(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorUnloadLoad.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorUnloadLoad.java index 79c9d5046acd6..372bcf0da6e9a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorUnloadLoad.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorUnloadLoad.java @@ -27,6 +27,7 @@ import java.util.Collections; import java.util.List; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.complex.MapVector; @@ -46,8 +47,6 @@ import org.junit.Assert; import org.junit.Test; -import io.netty.buffer.ArrowBuf; - public class TestVectorUnloadLoad { static final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); @@ -81,8 +80,8 @@ public void testUnloadLoad() throws IOException { try ( ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch(); BufferAllocator finalVectorsAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - VectorSchemaRoot newRoot = new VectorSchemaRoot(schema, finalVectorsAllocator); - ) { + VectorSchemaRoot newRoot = VectorSchemaRoot.create(schema, finalVectorsAllocator); + ) { // load it VectorLoader vectorLoader = new VectorLoader(newRoot); @@ -131,8 +130,8 @@ public void testUnloadLoadAddPadding() throws IOException { try ( ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch(); BufferAllocator finalVectorsAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - VectorSchemaRoot newRoot = new VectorSchemaRoot(schema, finalVectorsAllocator); - ) { + VectorSchemaRoot newRoot = VectorSchemaRoot.create(schema, finalVectorsAllocator); + ) { List oldBuffers = recordBatch.getBuffers(); List newBuffers = new ArrayList<>(); for (ArrowBuf oldBuffer : oldBuffers) { @@ -185,7 +184,7 @@ public void testLoadEmptyValidityBuffer() throws IOException { Schema schema = new Schema(asList( new Field("intDefined", true, new ArrowType.Int(32, true), Collections.emptyList()), new Field("intNull", true, new ArrowType.Int(32, true), Collections.emptyList()) - )); + )); int count = 10; ArrowBuf validity = allocator.buffer(10).slice(0, 0); ArrowBuf[] values = new ArrowBuf[2]; @@ -200,8 +199,8 @@ public void testLoadEmptyValidityBuffer() throws IOException { try ( ArrowRecordBatch recordBatch = new ArrowRecordBatch(count, asList(new ArrowFieldNode(count, 0), new ArrowFieldNode(count, count)), asList(validity, values[0], validity, values[1])); BufferAllocator finalVectorsAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - VectorSchemaRoot newRoot = new VectorSchemaRoot(schema, finalVectorsAllocator); - ) { + VectorSchemaRoot newRoot = VectorSchemaRoot.create(schema, finalVectorsAllocator); + ) { // load it VectorLoader vectorLoader = new VectorLoader(newRoot); @@ -244,11 +243,12 @@ public static VectorUnloader newVectorUnloader(FieldVector root) { Schema schema = new Schema(root.getField().getChildren()); int valueCount = root.getAccessor().getValueCount(); List fields = root.getChildrenFromFields(); - return new VectorUnloader(schema, valueCount, fields); + VectorSchemaRoot vsr = new VectorSchemaRoot(schema.getFields(), fields, valueCount); + return new VectorUnloader(vsr); } @AfterClass public static void afterClass() { allocator.close(); } -} +} \ No newline at end of file diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 58312b3f9ff9c..2b49d8ed4b582 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -53,7 +53,7 @@ public void terminate() throws Exception { public void testPromoteToUnion() throws Exception { try (final MapVector container = new MapVector(EMPTY_SCHEMA_PATH, allocator, null); - final NullableMapVector v = container.addOrGet("test", MinorType.MAP, NullableMapVector.class); + final NullableMapVector v = container.addOrGet("test", MinorType.MAP, NullableMapVector.class, null); final PromotableWriter writer = new PromotableWriter(v, container)) { container.allocateNew(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index 7a2d416241b78..a8a2d512c09ec 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -181,7 +181,7 @@ public void testList() { @Test public void listScalarType() { - ListVector listVector = new ListVector("list", allocator, null); + ListVector listVector = new ListVector("list", allocator, null, null); listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); for (int i = 0; i < COUNT; i++) { @@ -204,7 +204,7 @@ public void listScalarType() { @Test public void listScalarTypeNullable() { - ListVector listVector = new ListVector("list", allocator, null); + ListVector listVector = new ListVector("list", allocator, null, null); listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); for (int i = 0; i < COUNT; i++) { @@ -233,7 +233,7 @@ public void listScalarTypeNullable() { @Test public void listMapType() { - ListVector listVector = new ListVector("list", allocator, null); + ListVector listVector = new ListVector("list", allocator, null, null); listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); MapWriter mapWriter = listWriter.map(); @@ -261,7 +261,7 @@ public void listMapType() { @Test public void listListType() { - try (ListVector listVector = new ListVector("list", allocator, null)) { + try (ListVector listVector = new ListVector("list", allocator, null, null)) { listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); for (int i = 0; i < COUNT; i++) { @@ -286,7 +286,7 @@ public void listListType() { */ @Test public void listListType2() { - try (ListVector listVector = new ListVector("list", allocator, null)) { + try (ListVector listVector = new ListVector("list", allocator, null, null)) { listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); ListWriter innerListWriter = listWriter.list(); @@ -324,7 +324,7 @@ private void checkListOfLists(final ListVector listVector) { @Test public void unionListListType() { - try (ListVector listVector = new ListVector("list", allocator, null)) { + try (ListVector listVector = new ListVector("list", allocator, null, null)) { listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); for (int i = 0; i < COUNT; i++) { @@ -353,7 +353,7 @@ public void unionListListType() { */ @Test public void unionListListType2() { - try (ListVector listVector = new ListVector("list", allocator, null)) { + try (ListVector listVector = new ListVector("list", allocator, null, null)) { listVector.allocateNew(); UnionListWriter listWriter = new UnionListWriter(listVector); ListWriter innerListWriter = listWriter.list(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java index a83a2833c88bf..75e5d2d6e5c98 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java @@ -17,31 +17,44 @@ */ package org.apache.arrow.vector.file; -import static org.apache.arrow.vector.TestVectorUnloadLoad.newVectorUnloader; -import static org.junit.Assert.assertTrue; - import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileInputStream; -import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.List; +import com.google.common.collect.ImmutableList; + import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.NullableTinyIntVector; +import org.apache.arrow.vector.NullableVarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.NullableMapVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.schema.ArrowBuffer; +import org.apache.arrow.vector.schema.ArrowMessage; import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.stream.ArrowStreamReader; import org.apache.arrow.vector.stream.ArrowStreamWriter; +import org.apache.arrow.vector.stream.MessageSerializerTest; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType.Int; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; import org.junit.Assert; import org.junit.Test; import org.slf4j.Logger; @@ -68,7 +81,7 @@ public void testWriteComplex() throws IOException { int count = COUNT; try ( BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null)) { + NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null, null)) { writeComplexData(count, parent); FieldVector root = parent.getChild("root"); validateComplexContent(count, new VectorSchemaRoot(root)); @@ -83,71 +96,63 @@ public void testWriteRead() throws IOException { int count = COUNT; // write - try ( - BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", originalVectorAllocator, null)) { + try (BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); + MapVector parent = new MapVector("parent", originalVectorAllocator, null)) { writeData(count, parent); write(parent.getChild("root"), file, stream); } // read - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", vectorAllocator, null) - ) { - ArrowFooter footer = arrowReader.readFooter(); - Schema schema = footer.getSchema(); - LOGGER.debug("reading schema: " + schema); - - // initialize vectors - - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { - VectorLoader vectorLoader = new VectorLoader(root); - - List recordBatches = footer.getRecordBatches(); - for (ArrowBlock rbBlock : recordBatches) { - try (ArrowRecordBatch recordBatch = arrowReader.readRecordBatch(rbBlock)) { - List buffersLayout = recordBatch.getBuffersLayout(); - for (ArrowBuffer arrowBuffer : buffersLayout) { - Assert.assertEquals(0, arrowBuffer.getOffset() % 8); + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator){ + @Override + protected ArrowMessage readMessage(SeekableReadChannel in, BufferAllocator allocator) throws IOException { + ArrowMessage message = super.readMessage(in, allocator); + if (message != null) { + ArrowRecordBatch batch = (ArrowRecordBatch) message; + List buffersLayout = batch.getBuffersLayout(); + for (ArrowBuffer arrowBuffer : buffersLayout) { + Assert.assertEquals(0, arrowBuffer.getOffset() % 8); + } + } + return message; } - vectorLoader.load(recordBatch); - } - - validateContent(count, root); - } + }) { + Schema schema = arrowReader.getVectorSchemaRoot().getSchema(); + LOGGER.debug("reading schema: " + schema); + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { + arrowReader.loadRecordBatch(rbBlock); + Assert.assertEquals(count, root.getRowCount()); + validateContent(count, root); } } // Read from stream. - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", vectorAllocator, null) - ) { - arrowReader.init(); - Schema schema = arrowReader.getSchema(); + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator){ + @Override + protected ArrowMessage readMessage(ReadChannel in, BufferAllocator allocator) throws IOException { + ArrowMessage message = super.readMessage(in, allocator); + if (message != null) { + ArrowRecordBatch batch = (ArrowRecordBatch) message; + List buffersLayout = batch.getBuffersLayout(); + for (ArrowBuffer arrowBuffer : buffersLayout) { + Assert.assertEquals(0, arrowBuffer.getOffset() % 8); + } + } + return message; + } + }) { + + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { - VectorLoader vectorLoader = new VectorLoader(root); - while (true) { - try (ArrowRecordBatch recordBatch = arrowReader.nextRecordBatch()) { - if (recordBatch == null) break; - List buffersLayout = recordBatch.getBuffersLayout(); - for (ArrowBuffer arrowBuffer : buffersLayout) { - Assert.assertEquals(0, arrowBuffer.getOffset() % 8); - } - vectorLoader.load(recordBatch); - } - } - validateContent(count, root); - } + arrowReader.loadNextBatch(); + Assert.assertEquals(count, root.getRowCount()); + validateContent(count, root); } } @@ -158,61 +163,37 @@ public void testWriteReadComplex() throws IOException { int count = COUNT; // write - try ( - BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", originalVectorAllocator, null)) { + try (BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); + MapVector parent = new MapVector("parent", originalVectorAllocator, null)) { writeComplexData(count, parent); write(parent.getChild("root"), file, stream); } // read - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null) - ) { - ArrowFooter footer = arrowReader.readFooter(); - Schema schema = footer.getSchema(); + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - // initialize vectors - - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { - VectorLoader vectorLoader = new VectorLoader(root); - List recordBatches = footer.getRecordBatches(); - for (ArrowBlock rbBlock : recordBatches) { - try (ArrowRecordBatch recordBatch = arrowReader.readRecordBatch(rbBlock)) { - vectorLoader.load(recordBatch); - } - validateComplexContent(count, root); - } + for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { + arrowReader.loadRecordBatch(rbBlock); + Assert.assertEquals(count, root.getRowCount()); + validateComplexContent(count, root); } } // Read from stream. - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", vectorAllocator, null) - ) { - arrowReader.init(); - Schema schema = arrowReader.getSchema(); + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { - VectorLoader vectorLoader = new VectorLoader(root); - while (true) { - try (ArrowRecordBatch recordBatch = arrowReader.nextRecordBatch()) { - if (recordBatch == null) break; - vectorLoader.load(recordBatch); - } - } - validateComplexContent(count, root); - } + arrowReader.loadNextBatch(); + Assert.assertEquals(count, root.getRowCount()); + validateComplexContent(count, root); } } @@ -223,94 +204,70 @@ public void testWriteReadMultipleRBs() throws IOException { int[] counts = { 10, 5 }; // write - try ( - BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", originalVectorAllocator, null); - FileOutputStream fileOutputStream = new FileOutputStream(file);) { + try (BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); + MapVector parent = new MapVector("parent", originalVectorAllocator, null); + FileOutputStream fileOutputStream = new FileOutputStream(file)){ writeData(counts[0], parent); - VectorUnloader vectorUnloader0 = newVectorUnloader(parent.getChild("root")); - Schema schema = vectorUnloader0.getSchema(); - Assert.assertEquals(2, schema.getFields().size()); - try (ArrowWriter arrowWriter = new ArrowWriter(fileOutputStream.getChannel(), schema); - ArrowStreamWriter streamWriter = new ArrowStreamWriter(stream, schema)) { - try (ArrowRecordBatch recordBatch = vectorUnloader0.getRecordBatch()) { - Assert.assertEquals("RB #0", counts[0], recordBatch.getLength()); - arrowWriter.writeRecordBatch(recordBatch); - streamWriter.writeRecordBatch(recordBatch); - } + VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")); + + try(ArrowFileWriter fileWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel()); + ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, null, stream)) { + fileWriter.start(); + streamWriter.start(); + + fileWriter.writeBatch(); + streamWriter.writeBatch(); + parent.allocateNew(); writeData(counts[1], parent); // if we write the same data we don't catch that the metadata is stored in the wrong order. - VectorUnloader vectorUnloader1 = newVectorUnloader(parent.getChild("root")); - try (ArrowRecordBatch recordBatch = vectorUnloader1.getRecordBatch()) { - Assert.assertEquals("RB #1", counts[1], recordBatch.getLength()); - arrowWriter.writeRecordBatch(recordBatch); - streamWriter.writeRecordBatch(recordBatch); - } + root.setRowCount(counts[1]); + + fileWriter.writeBatch(); + streamWriter.writeBatch(); + + fileWriter.end(); + streamWriter.end(); } } - // read - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", vectorAllocator, null); - ) { - ArrowFooter footer = arrowReader.readFooter(); - Schema schema = footer.getSchema(); + // read file + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); int i = 0; - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator);) { - VectorLoader vectorLoader = new VectorLoader(root); - List recordBatches = footer.getRecordBatches(); - Assert.assertEquals(2, recordBatches.size()); - long previousOffset = 0; - for (ArrowBlock rbBlock : recordBatches) { - Assert.assertTrue(rbBlock.getOffset() + " > " + previousOffset, rbBlock.getOffset() > previousOffset); - previousOffset = rbBlock.getOffset(); - try (ArrowRecordBatch recordBatch = arrowReader.readRecordBatch(rbBlock)) { - Assert.assertEquals("RB #" + i, counts[i], recordBatch.getLength()); - List buffersLayout = recordBatch.getBuffersLayout(); - for (ArrowBuffer arrowBuffer : buffersLayout) { - Assert.assertEquals(0, arrowBuffer.getOffset() % 8); - } - vectorLoader.load(recordBatch); - validateContent(counts[i], root); - } - ++i; - } + List recordBatches = arrowReader.getRecordBlocks(); + Assert.assertEquals(2, recordBatches.size()); + long previousOffset = 0; + for (ArrowBlock rbBlock : recordBatches) { + Assert.assertTrue(rbBlock.getOffset() + " > " + previousOffset, rbBlock.getOffset() > previousOffset); + previousOffset = rbBlock.getOffset(); + arrowReader.loadRecordBatch(rbBlock); + Assert.assertEquals("RB #" + i, counts[i], root.getRowCount()); + validateContent(counts[i], root); + ++i; } } // read stream - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", vectorAllocator, null) - ) { - arrowReader.init(); - Schema schema = arrowReader.getSchema(); + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); int i = 0; - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator);) { - VectorLoader vectorLoader = new VectorLoader(root); - for (int n = 0; n < 2; n++) { - try (ArrowRecordBatch recordBatch = arrowReader.nextRecordBatch()) { - assertTrue(recordBatch != null); - Assert.assertEquals("RB #" + i, counts[i], recordBatch.getLength()); - List buffersLayout = recordBatch.getBuffersLayout(); - for (ArrowBuffer arrowBuffer : buffersLayout) { - Assert.assertEquals(0, arrowBuffer.getOffset() % 8); - } - vectorLoader.load(recordBatch); - validateContent(counts[i], root); - } - ++i; - } + + for (int n = 0; n < 2; n++) { + arrowReader.loadNextBatch(); + Assert.assertEquals("RB #" + i, counts[i], root.getRowCount()); + validateContent(counts[i], root); + ++i; } + arrowReader.loadNextBatch(); + Assert.assertEquals(0, root.getRowCount()); } } @@ -319,90 +276,326 @@ public void testWriteReadUnion() throws IOException { File file = new File("target/mytest_write_union.arrow"); ByteArrayOutputStream stream = new ByteArrayOutputStream(); int count = COUNT; - try ( - BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null)) { + // write + try (BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); + NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null, null)) { writeUnionData(count, parent); - - printVectors(parent.getChildrenFromFields()); - validateUnionData(count, new VectorSchemaRoot(parent.getChild("root"))); - write(parent.getChild("root"), file, stream); } - // read - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - FileInputStream fileInputStream = new FileInputStream(file); - ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - ) { - ArrowFooter footer = arrowReader.readFooter(); - Schema schema = footer.getSchema(); + + // read file + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); + arrowReader.loadNextBatch(); + validateUnionData(count, root); + } + + // Read from stream. + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); + LOGGER.debug("reading schema: " + schema); + arrowReader.loadNextBatch(); + validateUnionData(count, root); + } + } - // initialize vectors - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator);) { - VectorLoader vectorLoader = new VectorLoader(root); - List recordBatches = footer.getRecordBatches(); - for (ArrowBlock rbBlock : recordBatches) { - try (ArrowRecordBatch recordBatch = arrowReader.readRecordBatch(rbBlock)) { - vectorLoader.load(recordBatch); - } - validateUnionData(count, root); - } + @Test + public void testWriteReadTiny() throws IOException { + File file = new File("target/mytest_write_tiny.arrow"); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + + try (VectorSchemaRoot root = VectorSchemaRoot.create(MessageSerializerTest.testSchema(), allocator)) { + root.getFieldVectors().get(0).allocateNew(); + NullableTinyIntVector.Mutator mutator = (NullableTinyIntVector.Mutator) root.getFieldVectors().get(0).getMutator(); + for (int i = 0; i < 16; i++) { + mutator.set(i, i < 8 ? 1 : 0, (byte)(i + 1)); + } + mutator.setValueCount(16); + root.setRowCount(16); + + // write file + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel())) { + LOGGER.debug("writing schema: " + root.getSchema()); + arrowWriter.start(); + arrowWriter.writeBatch(); + arrowWriter.end(); + } + // write stream + try (ArrowStreamWriter arrowWriter = new ArrowStreamWriter(root, null, stream)) { + arrowWriter.start(); + arrowWriter.writeBatch(); + arrowWriter.end(); } } + // read file + try (BufferAllocator readerAllocator = allocator.newChildAllocator("fileReader", 0, Integer.MAX_VALUE); + FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); + LOGGER.debug("reading schema: " + schema); + arrowReader.loadNextBatch(); + validateTinyData(root); + } + // Read from stream. - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); - ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - MapVector parent = new MapVector("parent", vectorAllocator, null) - ) { - arrowReader.init(); - Schema schema = arrowReader.getSchema(); + try (BufferAllocator readerAllocator = allocator.newChildAllocator("streamReader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); + LOGGER.debug("reading schema: " + schema); + arrowReader.loadNextBatch(); + validateTinyData(root); + } + } + + private void validateTinyData(VectorSchemaRoot root) { + Assert.assertEquals(16, root.getRowCount()); + NullableTinyIntVector vector = (NullableTinyIntVector) root.getFieldVectors().get(0); + for (int i = 0; i < 16; i++) { + if (i < 8) { + Assert.assertEquals((byte)(i + 1), vector.getAccessor().get(i)); + } else { + Assert.assertTrue(vector.getAccessor().isNull(i)); + } + } + } + + @Test + public void testWriteReadDictionary() throws IOException { + File file = new File("target/mytest_dict.arrow"); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + + // write + try (BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); + NullableVarCharVector vector = new NullableVarCharVector("varchar", originalVectorAllocator, null); + NullableVarCharVector dictionaryVector = new NullableVarCharVector("dict", originalVectorAllocator, null)) { + vector.allocateNewSafe(); + NullableVarCharVector.Mutator mutator = vector.getMutator(); + mutator.set(0, "foo".getBytes(StandardCharsets.UTF_8)); + mutator.set(1, "bar".getBytes(StandardCharsets.UTF_8)); + mutator.set(3, "baz".getBytes(StandardCharsets.UTF_8)); + mutator.set(4, "bar".getBytes(StandardCharsets.UTF_8)); + mutator.set(5, "baz".getBytes(StandardCharsets.UTF_8)); + mutator.setValueCount(6); + + dictionaryVector.allocateNewSafe(); + mutator = dictionaryVector.getMutator(); + mutator.set(0, "foo".getBytes(StandardCharsets.UTF_8)); + mutator.set(1, "bar".getBytes(StandardCharsets.UTF_8)); + mutator.set(2, "baz".getBytes(StandardCharsets.UTF_8)); + mutator.setValueCount(3); + + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + MapDictionaryProvider provider = new MapDictionaryProvider(); + provider.put(dictionary); + + FieldVector encodedVector = (FieldVector) DictionaryEncoder.encode(vector, dictionary); + + List fields = ImmutableList.of(encodedVector.getField()); + List vectors = ImmutableList.of(encodedVector); + VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, 6); + + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter fileWriter = new ArrowFileWriter(root, provider, fileOutputStream.getChannel()); + ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, provider, stream)) { + LOGGER.debug("writing schema: " + root.getSchema()); + fileWriter.start(); + streamWriter.start(); + fileWriter.writeBatch(); + streamWriter.writeBatch(); + fileWriter.end(); + streamWriter.end(); + } + + dictionaryVector.close(); + encodedVector.close(); + } + + // read from file + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); + arrowReader.loadNextBatch(); + validateFlatDictionary(root.getFieldVectors().get(0), arrowReader); + } + + // Read from stream + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); + LOGGER.debug("reading schema: " + schema); + arrowReader.loadNextBatch(); + validateFlatDictionary(root.getFieldVectors().get(0), arrowReader); + } + } + + private void validateFlatDictionary(FieldVector vector, DictionaryProvider provider) { + Assert.assertNotNull(vector); + + DictionaryEncoding encoding = vector.getField().getDictionary(); + Assert.assertNotNull(encoding); + Assert.assertEquals(1L, encoding.getId()); + + FieldVector.Accessor accessor = vector.getAccessor(); + Assert.assertEquals(6, accessor.getValueCount()); + Assert.assertEquals(0, accessor.getObject(0)); + Assert.assertEquals(1, accessor.getObject(1)); + Assert.assertEquals(null, accessor.getObject(2)); + Assert.assertEquals(2, accessor.getObject(3)); + Assert.assertEquals(1, accessor.getObject(4)); + Assert.assertEquals(2, accessor.getObject(5)); + + Dictionary dictionary = provider.lookup(1L); + Assert.assertNotNull(dictionary); + NullableVarCharVector.Accessor dictionaryAccessor = ((NullableVarCharVector) dictionary.getVector()).getAccessor(); + Assert.assertEquals(3, dictionaryAccessor.getValueCount()); + Assert.assertEquals(new Text("foo"), dictionaryAccessor.getObject(0)); + Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); + Assert.assertEquals(new Text("baz"), dictionaryAccessor.getObject(2)); + } - try (VectorSchemaRoot root = new VectorSchemaRoot(schema, vectorAllocator)) { - VectorLoader vectorLoader = new VectorLoader(root); - while (true) { - try (ArrowRecordBatch recordBatch = arrowReader.nextRecordBatch()) { - if (recordBatch == null) break; - vectorLoader.load(recordBatch); - } - } - validateUnionData(count, root); + @Test + public void testWriteReadNestedDictionary() throws IOException { + File file = new File("target/mytest_dict_nested.arrow"); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + + DictionaryEncoding encoding = new DictionaryEncoding(2L, false, null); + + // data being written: + // [['foo', 'bar'], ['foo'], ['bar']] -> [[0, 1], [0], [1]] + + // write + try (NullableVarCharVector dictionaryVector = new NullableVarCharVector("dictionary", allocator, null); + ListVector listVector = new ListVector("list", allocator, null, null)) { + + Dictionary dictionary = new Dictionary(dictionaryVector, encoding); + MapDictionaryProvider provider = new MapDictionaryProvider(); + provider.put(dictionary); + + dictionaryVector.allocateNew(); + dictionaryVector.getMutator().set(0, "foo".getBytes(StandardCharsets.UTF_8)); + dictionaryVector.getMutator().set(1, "bar".getBytes(StandardCharsets.UTF_8)); + dictionaryVector.getMutator().setValueCount(2); + + listVector.addOrGetVector(MinorType.INT, encoding); + listVector.allocateNew(); + UnionListWriter listWriter = new UnionListWriter(listVector); + listWriter.startList(); + listWriter.writeInt(0); + listWriter.writeInt(1); + listWriter.endList(); + listWriter.startList(); + listWriter.writeInt(0); + listWriter.endList(); + listWriter.startList(); + listWriter.writeInt(1); + listWriter.endList(); + listWriter.setValueCount(3); + + List fields = ImmutableList.of(listVector.getField()); + List vectors = ImmutableList.of((FieldVector) listVector); + VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, 3); + + try ( + FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter fileWriter = new ArrowFileWriter(root, provider, fileOutputStream.getChannel()); + ArrowStreamWriter streamWriter = new ArrowStreamWriter(root, provider, stream)) { + LOGGER.debug("writing schema: " + root.getSchema()); + fileWriter.start(); + streamWriter.start(); + fileWriter.writeBatch(); + streamWriter.writeBatch(); + fileWriter.end(); + streamWriter.end(); } } + + // read from file + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + FileInputStream fileInputStream = new FileInputStream(file); + ArrowFileReader arrowReader = new ArrowFileReader(fileInputStream.getChannel(), readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); + LOGGER.debug("reading schema: " + schema); + arrowReader.loadNextBatch(); + validateNestedDictionary((ListVector) root.getFieldVectors().get(0), arrowReader); + } + + // Read from stream + try (BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ByteArrayInputStream input = new ByteArrayInputStream(stream.toByteArray()); + ArrowStreamReader arrowReader = new ArrowStreamReader(input, readerAllocator)) { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + Schema schema = root.getSchema(); + LOGGER.debug("reading schema: " + schema); + arrowReader.loadNextBatch(); + validateNestedDictionary((ListVector) root.getFieldVectors().get(0), arrowReader); + } + } + + private void validateNestedDictionary(ListVector vector, DictionaryProvider provider) { + Assert.assertNotNull(vector); + Assert.assertNull(vector.getField().getDictionary()); + Field nestedField = vector.getField().getChildren().get(0); + + DictionaryEncoding encoding = nestedField.getDictionary(); + Assert.assertNotNull(encoding); + Assert.assertEquals(2L, encoding.getId()); + Assert.assertEquals(new Int(32, true), encoding.getIndexType()); + + ListVector.Accessor accessor = vector.getAccessor(); + Assert.assertEquals(3, accessor.getValueCount()); + Assert.assertEquals(Arrays.asList(0, 1), accessor.getObject(0)); + Assert.assertEquals(Arrays.asList(0), accessor.getObject(1)); + Assert.assertEquals(Arrays.asList(1), accessor.getObject(2)); + + Dictionary dictionary = provider.lookup(2L); + Assert.assertNotNull(dictionary); + NullableVarCharVector.Accessor dictionaryAccessor = ((NullableVarCharVector) dictionary.getVector()).getAccessor(); + Assert.assertEquals(2, dictionaryAccessor.getValueCount()); + Assert.assertEquals(new Text("foo"), dictionaryAccessor.getObject(0)); + Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); } /** * Writes the contents of parents to file. If outStream is non-null, also writes it * to outStream in the streaming serialized format. */ - private void write(FieldVector parent, File file, OutputStream outStream) throws FileNotFoundException, IOException { - VectorUnloader vectorUnloader = newVectorUnloader(parent); - Schema schema = vectorUnloader.getSchema(); - LOGGER.debug("writing schema: " + schema); - try ( - FileOutputStream fileOutputStream = new FileOutputStream(file); - ArrowWriter arrowWriter = new ArrowWriter(fileOutputStream.getChannel(), schema); - ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch(); - ) { - arrowWriter.writeRecordBatch(recordBatch); + private void write(FieldVector parent, File file, OutputStream outStream) throws IOException { + VectorSchemaRoot root = new VectorSchemaRoot(parent); + + try (FileOutputStream fileOutputStream = new FileOutputStream(file); + ArrowFileWriter arrowWriter = new ArrowFileWriter(root, null, fileOutputStream.getChannel());) { + LOGGER.debug("writing schema: " + root.getSchema()); + arrowWriter.start(); + arrowWriter.writeBatch(); + arrowWriter.end(); } // Also try serializing to the stream writer. if (outStream != null) { - try ( - ArrowStreamWriter arrowWriter = new ArrowStreamWriter(outStream, schema); - ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch(); - ) { - arrowWriter.writeRecordBatch(recordBatch); + try (ArrowStreamWriter arrowWriter = new ArrowStreamWriter(root, null, outStream)) { + arrowWriter.start(); + arrowWriter.writeBatch(); + arrowWriter.end(); } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowReaderWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowReaderWriter.java index 96bcbb1dae71c..27b0f6bcc502b 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowReaderWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowReaderWriter.java @@ -17,12 +17,15 @@ */ package org.apache.arrow.vector.file; +import static java.nio.channels.Channels.newChannel; import static java.util.Arrays.asList; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.Channels; @@ -34,8 +37,14 @@ import org.apache.arrow.flatbuf.RecordBatch; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.NullableIntVector; +import org.apache.arrow.vector.NullableTinyIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.schema.ArrowFieldNode; import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -68,12 +77,17 @@ byte[] array(ArrowBuf buf) { @Test public void test() throws IOException { Schema schema = new Schema(asList(new Field("testField", true, new ArrowType.Int(8, true), Collections.emptyList()))); - byte[] validity = new byte[] { (byte)255, 0}; + MinorType minorType = Types.getMinorTypeForArrowType(schema.getFields().get(0).getType()); + FieldVector vector = minorType.getNewVector("testField", allocator, null,null); + vector.initializeChildrenFromFields(schema.getFields().get(0).getChildren()); + + byte[] validity = new byte[] { (byte) 255, 0}; // second half is "undefined" byte[] values = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; ByteArrayOutputStream out = new ByteArrayOutputStream(); - try (ArrowWriter writer = new ArrowWriter(Channels.newChannel(out), schema)) { + try (VectorSchemaRoot root = new VectorSchemaRoot(schema.getFields(), asList(vector), 16); + ArrowFileWriter writer = new ArrowFileWriter(root, null, newChannel(out))) { ArrowBuf validityb = buf(validity); ArrowBuf valuesb = buf(values); writer.writeRecordBatch(new ArrowRecordBatch(16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb))); @@ -81,15 +95,15 @@ public void test() throws IOException { byte[] byteArray = out.toByteArray(); - try (ArrowReader reader = new ArrowReader(new ByteArrayReadableSeekableByteChannel(byteArray), allocator)) { - ArrowFooter footer = reader.readFooter(); - Schema readSchema = footer.getSchema(); + SeekableReadChannel channel = new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(byteArray)); + try (ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { + Schema readSchema = reader.getVectorSchemaRoot().getSchema(); assertEquals(schema, readSchema); assertTrue(readSchema.getFields().get(0).getTypeLayout().getVectorTypes().toString(), readSchema.getFields().get(0).getTypeLayout().getVectors().size() > 0); // TODO: dictionaries - List recordBatches = footer.getRecordBatches(); + List recordBatches = reader.getRecordBlocks(); assertEquals(1, recordBatches.size()); - ArrowRecordBatch recordBatch = reader.readRecordBatch(recordBatches.get(0)); + ArrowRecordBatch recordBatch = (ArrowRecordBatch) reader.readMessage(channel, allocator); List nodes = recordBatch.getNodes(); assertEquals(1, nodes.size()); ArrowFieldNode node = nodes.get(0); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStream.java new file mode 100644 index 0000000000000..e7cdf3fea4b8b --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStream.java @@ -0,0 +1,102 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import io.netty.buffer.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.NullableTinyIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.schema.ArrowMessage; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.stream.ArrowStreamReader; +import org.apache.arrow.vector.stream.ArrowStreamWriter; +import org.apache.arrow.vector.stream.MessageSerializerTest; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +public class TestArrowStream extends BaseFileTest { + @Test + public void testEmptyStream() throws IOException { + Schema schema = MessageSerializerTest.testSchema(); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + + // Write the stream. + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + } + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { + assertEquals(schema, reader.getVectorSchemaRoot().getSchema()); + // Empty should return nothing. Can be called repeatedly. + reader.loadNextBatch(); + assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); + reader.loadNextBatch(); + assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); + } + } + + @Test + public void testReadWrite() throws IOException { + Schema schema = MessageSerializerTest.testSchema(); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + int numBatches = 1; + + root.getFieldVectors().get(0).allocateNew(); + NullableTinyIntVector.Mutator mutator = (NullableTinyIntVector.Mutator) root.getFieldVectors().get(0).getMutator(); + for (int i = 0; i < 16; i++) { + mutator.set(i, i < 8 ? 1 : 0, (byte)(i + 1)); + } + mutator.setValueCount(16); + root.setRowCount(16); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long bytesWritten = 0; + try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + writer.start(); + for (int i = 0; i < numBatches; i++) { + writer.writeBatch(); + } + writer.end(); + bytesWritten = writer.bytesWritten(); + } + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { + Schema readSchema = reader.getVectorSchemaRoot().getSchema(); + assertEquals(schema, readSchema); + for (int i = 0; i < numBatches; i++) { + reader.loadNextBatch(); + } + // TODO figure out why reader isn't getting padding bytes + assertEquals(bytesWritten, reader.bytesRead() + 4); + reader.loadNextBatch(); + assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); + } + } + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStreamPipe.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStreamPipe.java new file mode 100644 index 0000000000000..46d46794bbefa --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStreamPipe.java @@ -0,0 +1,163 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.nio.channels.Pipe; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.NullableTinyIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.schema.ArrowMessage; +import org.apache.arrow.vector.stream.ArrowStreamReader; +import org.apache.arrow.vector.stream.ArrowStreamWriter; +import org.apache.arrow.vector.stream.MessageSerializerTest; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Test; + +public class TestArrowStreamPipe { + Schema schema = MessageSerializerTest.testSchema(); + BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + + private final class WriterThread extends Thread { + + private final int numBatches; + private final ArrowStreamWriter writer; + private final VectorSchemaRoot root; + + public WriterThread(int numBatches, WritableByteChannel sinkChannel) + throws IOException { + this.numBatches = numBatches; + BufferAllocator allocator = alloc.newChildAllocator("writer thread", 0, Integer.MAX_VALUE); + root = VectorSchemaRoot.create(schema, allocator); + writer = new ArrowStreamWriter(root, null, sinkChannel); + } + + @Override + public void run() { + try { + writer.start(); + for (int j = 0; j < numBatches; j++) { + root.getFieldVectors().get(0).allocateNew(); + NullableTinyIntVector.Mutator mutator = (NullableTinyIntVector.Mutator) root.getFieldVectors().get(0).getMutator(); + // Send a changing batch id first + mutator.set(0, j); + for (int i = 1; i < 16; i++) { + mutator.set(i, i < 8 ? 1 : 0, (byte)(i + 1)); + } + mutator.setValueCount(16); + root.setRowCount(16); + + writer.writeBatch(); + } + writer.close(); + root.close(); + } catch (IOException e) { + e.printStackTrace(); + Assert.fail(e.toString()); // have to explicitly fail since we're in a separate thread + } + } + + public long bytesWritten() { return writer.bytesWritten(); } + } + + private final class ReaderThread extends Thread { + private int batchesRead = 0; + private final ArrowStreamReader reader; + private final BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); + private boolean done = false; + + public ReaderThread(ReadableByteChannel sourceChannel) + throws IOException { + reader = new ArrowStreamReader(sourceChannel, alloc) { + @Override + protected ArrowMessage readMessage(ReadChannel in, BufferAllocator allocator) throws IOException { + // Read all the batches. Each batch contains an incrementing id and then some + // constant data. Verify both. + ArrowMessage message = super.readMessage(in, allocator); + if (message == null) { + done = true; + } else { + batchesRead++; + } + return message; + } + @Override + public void loadNextBatch() throws IOException { + super.loadNextBatch(); + if (!done) { + VectorSchemaRoot root = getVectorSchemaRoot(); + Assert.assertEquals(16, root.getRowCount()); + NullableTinyIntVector vector = (NullableTinyIntVector) root.getFieldVectors().get(0); + Assert.assertEquals((byte)(batchesRead - 1), vector.getAccessor().get(0)); + for (int i = 1; i < 16; i++) { + if (i < 8) { + Assert.assertEquals((byte)(i + 1), vector.getAccessor().get(i)); + } else { + Assert.assertTrue(vector.getAccessor().isNull(i)); + } + } + } + } + }; + } + + @Override + public void run() { + try { + assertEquals(schema, reader.getVectorSchemaRoot().getSchema()); + assertTrue( + reader.getVectorSchemaRoot().getSchema().getFields().get(0).getTypeLayout().getVectorTypes().toString(), + reader.getVectorSchemaRoot().getSchema().getFields().get(0).getTypeLayout().getVectors().size() > 0); + while (!done) { + reader.loadNextBatch(); + } + } catch (IOException e) { + e.printStackTrace(); + Assert.fail(e.toString()); // have to explicitly fail since we're in a separate thread + } + } + + public int getBatchesRead() { return batchesRead; } + public long bytesRead() { return reader.bytesRead(); } + } + + // Starts up a producer and consumer thread to read/write batches. + @Test + public void pipeTest() throws IOException, InterruptedException { + int NUM_BATCHES = 10; + Pipe pipe = Pipe.open(); + WriterThread writer = new WriterThread(NUM_BATCHES, pipe.sink()); + ReaderThread reader = new ReaderThread(pipe.source()); + + writer.start(); + reader.start(); + reader.join(); + writer.join(); + + assertEquals(NUM_BATCHES, reader.getBatchesRead()); + assertEquals(writer.bytesWritten(), reader.bytesRead()); + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java b/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java index 3720a13b0fce5..c88958cbf2c9c 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java @@ -70,7 +70,7 @@ public void testWriteComplexJSON() throws IOException { int count = COUNT; try ( BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null)) { + NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null, null)) { writeComplexData(count, parent); VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")); validateComplexContent(root.getRowCount(), root); @@ -92,7 +92,7 @@ public void testWriteReadUnionJSON() throws IOException { int count = COUNT; try ( BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null)) { + NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null, null)) { writeUnionData(count, parent); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/stream/MessageSerializerTest.java b/java/vector/src/test/java/org/apache/arrow/vector/stream/MessageSerializerTest.java index 7b4de80ee03ea..bb2ccf8cbb5f6 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/stream/MessageSerializerTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/stream/MessageSerializerTest.java @@ -34,6 +34,7 @@ import org.apache.arrow.vector.file.ReadChannel; import org.apache.arrow.vector.file.WriteChannel; import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.schema.ArrowMessage; import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -88,9 +89,10 @@ public void testSerializeRecordBatch() throws IOException { MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - ArrowRecordBatch deserialized = MessageSerializer.deserializeRecordBatch( - new ReadChannel(Channels.newChannel(in)), alloc); - verifyBatch(deserialized, validity, values); + ReadChannel channel = new ReadChannel(Channels.newChannel(in)); + ArrowMessage deserialized = MessageSerializer.deserializeMessageBatch(channel, alloc); + assertEquals(ArrowRecordBatch.class, deserialized.getClass()); + verifyBatch((ArrowRecordBatch) deserialized, validity, values); } public static Schema testSchema() { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStream.java deleted file mode 100644 index 725272a0f072e..0000000000000 --- a/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStream.java +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.vector.stream; - -import static java.util.Arrays.asList; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; - -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.file.BaseFileTest; -import org.apache.arrow.vector.schema.ArrowFieldNode; -import org.apache.arrow.vector.schema.ArrowRecordBatch; -import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Test; - -import io.netty.buffer.ArrowBuf; - -public class TestArrowStream extends BaseFileTest { - @Test - public void testEmptyStream() throws IOException { - Schema schema = MessageSerializerTest.testSchema(); - - // Write the stream. - ByteArrayOutputStream out = new ByteArrayOutputStream(); - try (ArrowStreamWriter writer = new ArrowStreamWriter(out, schema)) { - } - - ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { - reader.init(); - assertEquals(schema, reader.getSchema()); - // Empty should return null. Can be called repeatedly. - assertTrue(reader.nextRecordBatch() == null); - assertTrue(reader.nextRecordBatch() == null); - } - } - - @Test - public void testReadWrite() throws IOException { - Schema schema = MessageSerializerTest.testSchema(); - byte[] validity = new byte[] { (byte)255, 0}; - // second half is "undefined" - byte[] values = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - - int numBatches = 5; - BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - long bytesWritten = 0; - try (ArrowStreamWriter writer = new ArrowStreamWriter(out, schema)) { - ArrowBuf validityb = MessageSerializerTest.buf(alloc, validity); - ArrowBuf valuesb = MessageSerializerTest.buf(alloc, values); - for (int i = 0; i < numBatches; i++) { - writer.writeRecordBatch(new ArrowRecordBatch( - 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb))); - } - bytesWritten = writer.bytesWritten(); - } - - ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - try (ArrowStreamReader reader = new ArrowStreamReader(in, alloc)) { - reader.init(); - Schema readSchema = reader.getSchema(); - for (int i = 0; i < numBatches; i++) { - assertEquals(schema, readSchema); - assertTrue( - readSchema.getFields().get(0).getTypeLayout().getVectorTypes().toString(), - readSchema.getFields().get(0).getTypeLayout().getVectors().size() > 0); - ArrowRecordBatch recordBatch = reader.nextRecordBatch(); - MessageSerializerTest.verifyBatch(recordBatch, validity, values); - assertTrue(recordBatch != null); - } - assertTrue(reader.nextRecordBatch() == null); - assertEquals(bytesWritten, reader.bytesRead()); - } - } -} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStreamPipe.java b/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStreamPipe.java deleted file mode 100644 index aa0b77e46a392..0000000000000 --- a/java/vector/src/test/java/org/apache/arrow/vector/stream/TestArrowStreamPipe.java +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.vector.stream; - -import static java.util.Arrays.asList; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import java.io.IOException; -import java.nio.channels.Pipe; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.WritableByteChannel; - -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.schema.ArrowFieldNode; -import org.apache.arrow.vector.schema.ArrowRecordBatch; -import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Test; - -import io.netty.buffer.ArrowBuf; - -public class TestArrowStreamPipe { - Schema schema = MessageSerializerTest.testSchema(); - // second half is "undefined" - byte[] values = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - - private final class WriterThread extends Thread { - private final int numBatches; - private final ArrowStreamWriter writer; - - public WriterThread(int numBatches, WritableByteChannel sinkChannel) - throws IOException { - this.numBatches = numBatches; - writer = new ArrowStreamWriter(sinkChannel, schema); - } - - @Override - public void run() { - BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); - try { - ArrowBuf valuesb = MessageSerializerTest.buf(alloc, values); - for (int i = 0; i < numBatches; i++) { - // Send a changing byte id first. - byte[] validity = new byte[] { (byte)i, 0}; - ArrowBuf validityb = MessageSerializerTest.buf(alloc, validity); - writer.writeRecordBatch(new ArrowRecordBatch( - 16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb))); - } - writer.close(); - } catch (IOException e) { - e.printStackTrace(); - assertTrue(false); - } - } - - public long bytesWritten() { return writer.bytesWritten(); } - } - - private final class ReaderThread extends Thread { - private int batchesRead = 0; - private final ArrowStreamReader reader; - private final BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE); - - public ReaderThread(ReadableByteChannel sourceChannel) - throws IOException { - reader = new ArrowStreamReader(sourceChannel, alloc); - } - - @Override - public void run() { - try { - reader.init(); - assertEquals(schema, reader.getSchema()); - assertTrue( - reader.getSchema().getFields().get(0).getTypeLayout().getVectorTypes().toString(), - reader.getSchema().getFields().get(0).getTypeLayout().getVectors().size() > 0); - - // Read all the batches. Each batch contains an incrementing id and then some - // constant data. Verify both. - while (true) { - ArrowRecordBatch batch = reader.nextRecordBatch(); - if (batch == null) break; - byte[] validity = new byte[] { (byte)batchesRead, 0}; - MessageSerializerTest.verifyBatch(batch, validity, values); - batchesRead++; - } - } catch (IOException e) { - e.printStackTrace(); - assertTrue(false); - } - } - - public int getBatchesRead() { return batchesRead; } - public long bytesRead() { return reader.bytesRead(); } - } - - // Starts up a producer and consumer thread to read/write batches. - @Test - public void pipeTest() throws IOException, InterruptedException { - int NUM_BATCHES = 10; - Pipe pipe = Pipe.open(); - WriterThread writer = new WriterThread(NUM_BATCHES, pipe.sink()); - ReaderThread reader = new ReaderThread(pipe.source()); - - writer.start(); - reader.start(); - reader.join(); - writer.join(); - - assertEquals(NUM_BATCHES, reader.getBatchesRead()); - assertEquals(writer.bytesWritten(), reader.bytesRead()); - } -}