diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs
index 552e3f1615c7..f26cde05b7ab 100644
--- a/arrow-buffer/src/buffer/immutable.rs
+++ b/arrow-buffer/src/buffer/immutable.rs
@@ -171,23 +171,33 @@ impl Buffer {
     /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`.
     /// Doing so allows the same memory region to be shared between buffers.
+    ///
     /// # Panics
+    ///
     /// Panics iff `offset` is larger than `len`.
     pub fn slice(&self, offset: usize) -> Self {
+        let mut s = self.clone();
+        s.advance(offset);
+        s
+    }
+
+    /// Increases the offset of this buffer by `offset`
+    ///
+    /// # Panics
+    ///
+    /// Panics iff `offset` is larger than `len`.
+    #[inline]
+    pub fn advance(&mut self, offset: usize) {
         assert!(
             offset <= self.length,
             "the offset of the new Buffer cannot exceed the existing length"
         );
+        self.length -= offset;
         // Safety:
         // This cannot overflow as
         // `self.offset + self.length < self.data.len()`
         // `offset < self.length`
-        let ptr = unsafe { self.ptr.add(offset) };
-        Self {
-            data: self.data.clone(),
-            length: self.length - offset,
-            ptr,
-        }
+        self.ptr = unsafe { self.ptr.add(offset) };
     }

     /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`,
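The new `advance` mutates a buffer in place, and `slice` becomes clone-then-advance over the same shared allocation. A minimal sketch of the resulting semantics, assuming the `arrow_buffer` API in the hunk above:

```rust
use arrow_buffer::Buffer;

fn main() {
    let mut buffer = Buffer::from_vec(vec![1u8, 2, 3, 4, 5]);

    // Advance re-slices in place: no new allocation, the underlying
    // memory region stays shared with any clones
    buffer.advance(2);
    assert_eq!(buffer.as_slice(), &[3, 4, 5]);

    // `slice` is now clone + advance, so the original is unchanged
    let sliced = buffer.slice(1);
    assert_eq!(sliced.as_slice(), &[4, 5]);
    assert_eq!(buffer.len(), 3);

    // Per the documented contract, advancing past the end panics
    // buffer.advance(4); // would panic: offset larger than len
}
```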
diff --git a/arrow-integration-testing/tests/ipc_reader.rs b/arrow-integration-testing/tests/ipc_reader.rs
index 88cdad64f92f..a683075990c7 100644
--- a/arrow-integration-testing/tests/ipc_reader.rs
+++ b/arrow-integration-testing/tests/ipc_reader.rs
@@ -19,10 +19,12 @@
 //! in `testing/arrow-ipc-stream/integration/...`

 use arrow::error::ArrowError;
-use arrow::ipc::reader::{FileReader, StreamReader};
+use arrow::ipc::reader::{FileReader, StreamDecoder, StreamReader};
 use arrow::util::test_util::arrow_test_data;
+use arrow_buffer::Buffer;
 use arrow_integration_testing::read_gzip_json;
 use std::fs::File;
+use std::io::Read;

 #[test]
 fn read_0_1_4() {
@@ -182,18 +184,45 @@ fn verify_arrow_stream(testdata: &str, version: &str, path: &str) {
     let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream");
     println!("Verifying {filename}");

+    // read expected JSON output
+    let arrow_json = read_gzip_json(version, path);
+
     // Compare contents to the expected output format in JSON
     {
         println!("  verifying content");
         let file = File::open(&filename).unwrap();
         let mut reader = StreamReader::try_new(file, None).unwrap();

-        // read expected JSON output
-        let arrow_json = read_gzip_json(version, path);
         assert!(arrow_json.equals_reader(&mut reader).unwrap());
         // the next batch must be empty
         assert!(reader.next().is_none());
         // the stream must indicate that it's finished
         assert!(reader.is_finished());
     }
+
+    // Test stream decoder
+    let expected = arrow_json.get_record_batches().unwrap();
+    for chunk_size in [1, 2, 8, 123] {
+        let mut decoder = StreamDecoder::new();
+        let stream = chunked_file(&filename, chunk_size);
+        let mut actual = Vec::with_capacity(expected.len());
+        for mut x in stream {
+            while !x.is_empty() {
+                if let Some(x) = decoder.decode(&mut x).unwrap() {
+                    actual.push(x);
+                }
+            }
+        }
+        decoder.finish().unwrap();
+        assert_eq!(expected, actual);
+    }
+}
+
+fn chunked_file(filename: &str, chunk_size: u64) -> impl Iterator<Item = Buffer> {
+    let mut file = File::open(filename).unwrap();
+    std::iter::from_fn(move || {
+        let mut buf = vec![];
+        let read = (&mut file).take(chunk_size).read_to_end(&mut buf).unwrap();
+        (read != 0).then(|| Buffer::from_vec(buf))
+    })
 }
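The test above drives the decoder with fixed-size chunks from a file; the same pattern generalizes to any `impl Read`. A sketch under that assumption (the helper name `read_stream` is hypothetical, not part of this change):

```rust
use std::io::Read;

use arrow_array::RecordBatch;
use arrow_buffer::Buffer;
use arrow_ipc::reader::StreamDecoder;
use arrow_schema::ArrowError;

/// Collect all record batches from an arbitrary byte source,
/// feeding the decoder `chunk_size` bytes at a time
fn read_stream<R: Read>(mut src: R, chunk_size: u64) -> Result<Vec<RecordBatch>, ArrowError> {
    let mut decoder = StreamDecoder::new();
    let mut batches = vec![];
    loop {
        let mut chunk = vec![];
        (&mut src).take(chunk_size).read_to_end(&mut chunk)?;
        if chunk.is_empty() {
            break; // EOF
        }
        let mut buffer = Buffer::from_vec(chunk);
        // A single chunk may contain zero, one, or several messages
        while !buffer.is_empty() {
            if let Some(batch) = decoder.decode(&mut buffer)? {
                batches.push(batch);
            }
        }
    }
    decoder.finish()?; // errors if the stream ended mid-message
    Ok(batches)
}
```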
diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs
index b2e580241adc..51e54215ea7f 100644
--- a/arrow-ipc/src/convert.rs
+++ b/arrow-ipc/src/convert.rs
@@ -17,12 +17,17 @@

 //! Utilities for converting between IPC types and native Arrow types

+use arrow_buffer::Buffer;
 use arrow_schema::*;
-use flatbuffers::{FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, WIPOffset};
+use flatbuffers::{
+    FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, Verifiable, Verifier,
+    VerifierOptions, WIPOffset,
+};
 use std::collections::HashMap;
+use std::fmt::{Debug, Formatter};
 use std::sync::Arc;

-use crate::{size_prefixed_root_as_message, KeyValue, CONTINUATION_MARKER};
+use crate::{size_prefixed_root_as_message, KeyValue, Message, CONTINUATION_MARKER};
 use DataType::*;

 /// Serialize a schema in IPC format
@@ -806,6 +811,45 @@ pub(crate) fn get_fb_dictionary<'a>(
     builder.finish()
 }

+/// An owned container for a validated [`Message`]
+///
+/// Safely decoding a flatbuffer requires validating the various embedded offsets,
+/// see [`Verifier`]. This is a potentially expensive operation, and it is therefore desirable
+/// to only do this once. [`crate::root_as_message`] performs this validation on construction,
+/// however, it returns a [`Message`] borrowing the provided byte slice. This prevents
+/// storing this [`Message`] in the same data structure that owns the buffer, as this
+/// would require self-referential borrows.
+///
+/// [`MessageBuffer`] solves this problem by providing a safe API for a [`Message`]
+/// without a lifetime bound.
+#[derive(Clone)]
+pub struct MessageBuffer(Buffer);
+
+impl Debug for MessageBuffer {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        self.as_ref().fmt(f)
+    }
+}
+
+impl MessageBuffer {
+    /// Try to create a [`MessageBuffer`] from the provided [`Buffer`]
+    pub fn try_new(buf: Buffer) -> Result<Self, ArrowError> {
+        let opts = VerifierOptions::default();
+        let mut v = Verifier::new(&opts, &buf);
+        <ForwardsUOffset<Message>>::run_verifier(&mut v, 0).map_err(|err| {
+            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
+        })?;
+        Ok(Self(buf))
+    }
+
+    /// Return the [`Message`]
+    #[inline]
+    pub fn as_ref(&self) -> Message<'_> {
+        // SAFETY: Run verifier on construction
+        unsafe { crate::root_as_message_unchecked(&self.0) }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
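To make the self-referential-borrow motivation concrete, here is a hypothetical wrapper (`OwnedMessage` is illustration only, not part of this change) that could not be written against the borrowing `root_as_message` API, but is straightforward with `MessageBuffer`:

```rust
use arrow_buffer::Buffer;
use arrow_ipc::convert::MessageBuffer;
use arrow_schema::ArrowError;

/// Owns the encoded bytes and hands out decoded `Message` views on
/// demand; with a borrowing API this would be a self-referential struct
struct OwnedMessage {
    message: MessageBuffer,
}

impl OwnedMessage {
    /// Verify the flatbuffer once, on construction
    fn try_new(buf: Buffer) -> Result<Self, ArrowError> {
        Ok(Self {
            message: MessageBuffer::try_new(buf)?,
        })
    }

    /// Every subsequent access skips verification
    fn body_length(&self) -> i64 {
        self.message.as_ref().bodyLength()
    }
}
```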
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index f015674d6813..dd0365da4bc7 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -20,6 +20,10 @@
 //! The `FileReader` and `StreamReader` have similar interfaces,
 //! however the `FileReader` expects a reader that supports `Seek`ing

+mod stream;
+
+pub use stream::*;
+
 use flatbuffers::{VectorIter, VerifierOptions};
 use std::collections::HashMap;
 use std::fmt;
diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs
new file mode 100644
index 000000000000..7807228175ac
--- /dev/null
+++ b/arrow-ipc/src/reader/stream.rs
@@ -0,0 +1,297 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_buffer::{Buffer, MutableBuffer};
+use arrow_schema::{ArrowError, SchemaRef};
+
+use crate::convert::MessageBuffer;
+use crate::reader::{read_dictionary, read_record_batch};
+use crate::{MessageHeader, CONTINUATION_MARKER};
+
+/// A low-level interface for reading [`RecordBatch`] data from a stream of bytes
+///
+/// See [StreamReader](crate::reader::StreamReader) for a higher-level interface
+#[derive(Debug, Default)]
+pub struct StreamDecoder {
+    /// The schema of this decoder, if read
+    schema: Option<SchemaRef>,
+    /// Lookup table for dictionaries by ID
+    dictionaries: HashMap<i64, ArrayRef>,
+    /// The decoder state
+    state: DecoderState,
+    /// A scratch buffer when a read is split across multiple `Buffer`s
+    buf: MutableBuffer,
+}
+
+#[derive(Debug)]
+enum DecoderState {
+    /// Decoding the message header
+    Header {
+        /// Temporary buffer
+        buf: [u8; 4],
+        /// Number of bytes read into buf
+        read: u8,
+        /// If we have read a continuation token
+        continuation: bool,
+    },
+    /// Decoding the message flatbuffer
+    Message {
+        /// The size of the message flatbuffer
+        size: u32,
+    },
+    /// Decoding the message body
+    Body {
+        /// The message flatbuffer
+        message: MessageBuffer,
+    },
+    /// Reached the end of the stream
+    Finished,
+}
+
+impl Default for DecoderState {
+    fn default() -> Self {
+        Self::Header {
+            buf: [0; 4],
+            read: 0,
+            continuation: false,
+        }
+    }
+}
+
+impl StreamDecoder {
+    /// Create a new [`StreamDecoder`]
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
+    ///
+    /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
+    ///
+    /// The push-based interface facilitates integration with sources that yield arbitrarily
+    /// delimited byte ranges, such as a chunked byte stream received from object storage
+    ///
+    /// ```
+    /// # use arrow_array::RecordBatch;
+    /// # use arrow_buffer::Buffer;
+    /// # use arrow_ipc::reader::StreamDecoder;
+    /// # use arrow_schema::ArrowError;
+    /// #
+    /// fn print_stream(src: impl Iterator<Item = Buffer>) -> Result<(), ArrowError> {
+    ///     let mut decoder = StreamDecoder::new();
+    ///     for mut x in src {
+    ///         while !x.is_empty() {
+    ///             if let Some(x) = decoder.decode(&mut x)? {
+    ///                 println!("{x:?}");
+    ///             }
+    ///         }
+    ///     }
+    ///     decoder.finish()?;
+    ///     Ok(())
+    /// }
+    /// ```
+    pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
+        while !buffer.is_empty() {
+            match &mut self.state {
+                DecoderState::Header {
+                    buf,
+                    read,
+                    continuation,
+                } => {
+                    let offset_buf = &mut buf[*read as usize..];
+                    let to_read = buffer.len().min(offset_buf.len());
+                    offset_buf[..to_read].copy_from_slice(&buffer[..to_read]);
+                    *read += to_read as u8;
+                    buffer.advance(to_read);
+                    if *read == 4 {
+                        if !*continuation && buf == &CONTINUATION_MARKER {
+                            *continuation = true;
+                            *read = 0;
+                            continue;
+                        }
+                        let size = u32::from_le_bytes(*buf);
+
+                        if size == 0 {
+                            self.state = DecoderState::Finished;
+                            continue;
+                        }
+                        self.state = DecoderState::Message { size };
+                    }
+                }
+                DecoderState::Message { size } => {
+                    let len = *size as usize;
+                    if self.buf.is_empty() && buffer.len() > len {
+                        let message = MessageBuffer::try_new(buffer.slice_with_length(0, len))?;
+                        self.state = DecoderState::Body { message };
+                        buffer.advance(len);
+                        continue;
+                    }
+
+                    let to_read = buffer.len().min(len - self.buf.len());
+                    self.buf.extend_from_slice(&buffer[..to_read]);
+                    buffer.advance(to_read);
+                    if self.buf.len() == len {
+                        let message = MessageBuffer::try_new(std::mem::take(&mut self.buf).into())?;
+                        self.state = DecoderState::Body { message };
+                    }
+                }
+                DecoderState::Body { message } => {
+                    let message = message.as_ref();
+                    let body_length = message.bodyLength() as usize;
+
+                    let body = if self.buf.is_empty() && buffer.len() >= body_length {
+                        let body = buffer.slice_with_length(0, body_length);
+                        buffer.advance(body_length);
+                        body
+                    } else {
+                        let to_read = buffer.len().min(body_length - self.buf.len());
+                        self.buf.extend_from_slice(&buffer[..to_read]);
+                        buffer.advance(to_read);
+
+                        if self.buf.len() != body_length {
+                            continue;
+                        }
+                        std::mem::take(&mut self.buf).into()
+                    };
+
+                    let version = message.version();
+                    match message.header_type() {
+                        MessageHeader::Schema => {
+                            if self.schema.is_some() {
+                                return Err(ArrowError::IpcError(
+                                    "Not expecting a schema when messages are read".to_string(),
+                                ));
+                            }
+
+                            let ipc_schema = message.header_as_schema().unwrap();
+                            let schema = crate::convert::fb_to_schema(ipc_schema);
+                            self.state = DecoderState::default();
+                            self.schema = Some(Arc::new(schema));
+                        }
+                        MessageHeader::RecordBatch => {
+                            let batch = message.header_as_record_batch().unwrap();
+                            let schema = self.schema.clone().ok_or_else(|| {
+                                ArrowError::IpcError("Missing schema".to_string())
+                            })?;
+                            let batch = read_record_batch(
+                                &body,
+                                batch,
+                                schema,
+                                &self.dictionaries,
+                                None,
+                                &version,
+                            )?;
+                            self.state = DecoderState::default();
+                            return Ok(Some(batch));
+                        }
+                        MessageHeader::DictionaryBatch => {
+                            let dictionary = message.header_as_dictionary_batch().unwrap();
+                            let schema = self.schema.as_deref().ok_or_else(|| {
+                                ArrowError::IpcError("Missing schema".to_string())
+                            })?;
+                            read_dictionary(
+                                &body,
+                                dictionary,
+                                schema,
+                                &mut self.dictionaries,
+                                &version,
+                            )?;
+                            self.state = DecoderState::default();
+                        }
+                        MessageHeader::NONE => {
+                            self.state = DecoderState::default();
+                        }
+                        t => {
+                            return Err(ArrowError::IpcError(format!(
+                                "Message type unsupported by StreamDecoder: {t:?}"
+                            )))
+                        }
+                    }
+                }
+                DecoderState::Finished => {
+                    return Err(ArrowError::IpcError("Unexpected EOS".to_string()))
+                }
+            }
+        }
+        Ok(None)
+    }
+
+    /// Signal the end of the stream
+    ///
+    /// Returns an error if any partial data remains in the stream
+    pub fn finish(&mut self) -> Result<(), ArrowError> {
+        match self.state {
+            DecoderState::Finished
+            | DecoderState::Header {
+                read: 0,
+                continuation: false,
+                ..
+            } => Ok(()),
+            _ => Err(ArrowError::IpcError("Unexpected End of Stream".to_string())),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::writer::StreamWriter;
+    use arrow_array::{Int32Array, Int64Array, RecordBatch};
+    use arrow_schema::{DataType, Field, Schema};
+
+    // Further tests in arrow-integration-testing/tests/ipc_reader.rs
+
+    #[test]
+    fn test_eos() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("int32", DataType::Int32, false),
+            Field::new("int64", DataType::Int64, false),
+        ]));
+
+        let input = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
+                Arc::new(Int64Array::from(vec![1, 2, 3])) as _,
+            ],
+        )
+        .unwrap();
+
+        let mut buf = Vec::with_capacity(1024);
+        let mut s = StreamWriter::try_new(&mut buf, &schema).unwrap();
+        s.write(&input).unwrap();
+        s.finish().unwrap();
+        drop(s);
+
+        let buffer = Buffer::from_vec(buf);
+
+        let mut b = buffer.slice_with_length(0, buffer.len() - 1);
+        let mut decoder = StreamDecoder::new();
+        let output = decoder.decode(&mut b).unwrap().unwrap();
+        assert_eq!(output, input);
+        assert_eq!(b.len(), 7); // 8 byte EOS truncated by 1 byte
+        assert!(decoder.decode(&mut b).unwrap().is_none());
+
+        let err = decoder.finish().unwrap_err().to_string();
+        assert_eq!(err, "Ipc error: Unexpected End of Stream");
+    }
+}
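Because the state machine above accumulates partial headers, messages, and bodies in the internal scratch buffer, correctness should not depend on where chunk boundaries fall. A sketch of the degenerate one-byte-at-a-time case, mirroring the chunk-size-1 integration test (the helper name is hypothetical):

```rust
use arrow_array::RecordBatch;
use arrow_buffer::Buffer;
use arrow_ipc::reader::StreamDecoder;
use arrow_schema::ArrowError;

/// Decode a complete IPC stream one byte at a time
fn decode_byte_at_a_time(stream: &[u8]) -> Result<Vec<RecordBatch>, ArrowError> {
    let mut decoder = StreamDecoder::new();
    let mut batches = vec![];
    for byte in stream {
        // Each call consumes the single byte, accumulating it in the
        // decoder's internal scratch buffer until a message completes
        let mut buffer = Buffer::from_vec(vec![*byte]);
        if let Some(batch) = decoder.decode(&mut buffer)? {
            batches.push(batch);
        }
    }
    decoder.finish()?;
    Ok(batches)
}
```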
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index 4e32b04b0fba..22edfbc2454d 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -677,6 +677,7 @@ impl DictionaryTracker {
     }
 }

+/// Writer for an IPC file
 pub struct FileWriter<W: Write> {
     /// The object to write to
     writer: BufWriter<W>,
@@ -701,13 +702,13 @@
 }

 impl<W: Write> FileWriter<W> {
-    /// Try create a new writer, with the schema written as part of the header
+    /// Try to create a new writer, with the schema written as part of the header
     pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
         let write_options = IpcWriteOptions::default();
         Self::try_new_with_options(writer, schema, write_options)
     }

-    /// Try create a new writer with IpcWriteOptions
+    /// Try to create a new writer with IpcWriteOptions
     pub fn try_new_with_options(
         writer: W,
         schema: &Schema,
@@ -857,6 +858,7 @@ impl<W: Write> RecordBatchWriter for FileWriter<W> {
     }
 }

+/// Writer for an IPC stream
 pub struct StreamWriter<W: Write> {
     /// The object to write to
     writer: BufWriter<W>,
@@ -871,7 +873,7 @@ pub struct StreamWriter<W: Write> {
 }

 impl<W: Write> StreamWriter<W> {
-    /// Try create a new writer, with the schema written as part of the header
+    /// Try to create a new writer, with the schema written as part of the header
     pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
         let write_options = IpcWriteOptions::default();
         Self::try_new_with_options(writer, schema, write_options)
diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs
index 99055573345a..628e5c96693d 100644
--- a/arrow-json/src/reader/mod.rs
+++ b/arrow-json/src/reader/mod.rs
@@ -416,7 +416,7 @@ impl Decoder {
     /// should be included in the next call to [`Self::decode`]
     ///
     /// There is no requirement that `buf` contains a whole number of records, facilitating
-    /// integration with arbitrary byte streams, such as that yielded by [`BufRead`]
+    /// integration with arbitrary byte streams, such as those yielded by [`BufRead`]
     pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
         self.tape_decoder.decode(buf)
     }
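The doc fix above refers to the analogous push-based interface in arrow-json, which consumes raw byte slices and reports how many bytes it consumed. A minimal sketch of driving it, assuming `ReaderBuilder::build_decoder` as present in current arrow-json releases:

```rust
use std::sync::Arc;

use arrow_json::ReaderBuilder;
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let mut decoder = ReaderBuilder::new(schema).build_decoder().unwrap();

    // Feed arbitrarily delimited byte chunks; decode returns how many
    // bytes were consumed, and the remainder belongs to the next call
    let json = b"{\"a\": 1}\n{\"a\": 2}\n";
    let mut offset = 0;
    while offset < json.len() {
        offset += decoder.decode(&json[offset..]).unwrap();
    }

    // Flush any buffered rows into a RecordBatch
    let batch = decoder.flush().unwrap().unwrap();
    assert_eq!(batch.num_rows(), 2);
}
```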