Change parquet writers to use standard `std::io::Write` rather than the custom `ParquetWriter` trait (#1717) (#1163) (#1719)

* Rustify parquet writer (#1717) (#1163)

* Fix parquet_derive

* Fix benches

* Fix parquet_derive tests

* Use raw vec instead of Cursor

* Review feedback

* Fix unnecessary unwrap
tustvold authored May 25, 2022
1 parent 5cf06bf commit 722fcfc
Showing 18 changed files with 448 additions and 604 deletions.
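The practical upshot of the change: the writers are now generic over any `std::io::Write` sink, so a plain `Vec<u8>`, a `File`, or a `&mut` reference works directly, and `InMemoryWriteableCursor` is no longer needed. Below is a minimal sketch of the new usage, pieced together from the call patterns in the diffs that follow; the schema and values are illustrative only:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::{arrow::ArrowWriter, errors::Result};

fn write_to_vec() -> Result<Vec<u8>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )
    .expect("batch matches schema");

    // Vec<u8> implements std::io::Write, so it can serve as the sink directly
    let mut buffer = Vec::new();
    let mut writer = ArrowWriter::try_new(&mut buffer, schema, None)?;
    writer.write(&batch)?;
    writer.close()?; // consumes the writer and writes the parquet footer
    Ok(buffer)
}
```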
8 changes: 3 additions & 5 deletions parquet/benches/arrow_writer.rs
@@ -26,9 +26,7 @@ use std::sync::Arc;
 
 use arrow::datatypes::*;
 use arrow::{record_batch::RecordBatch, util::data_gen::*};
-use parquet::{
-    arrow::ArrowWriter, errors::Result, file::writer::InMemoryWriteableCursor,
-};
+use parquet::{arrow::ArrowWriter, errors::Result};
 
 fn create_primitive_bench_batch(
     size: usize,
@@ -278,8 +276,8 @@ fn _create_nested_bench_batch(
 #[inline]
 fn write_batch(batch: &RecordBatch) -> Result<()> {
     // Write batch to an in-memory writer
-    let cursor = InMemoryWriteableCursor::default();
-    let mut writer = ArrowWriter::try_new(cursor, batch.schema(), None)?;
+    let buffer = vec![];
+    let mut writer = ArrowWriter::try_new(buffer, batch.schema(), None)?;
 
     writer.write(batch)?;
     writer.close()?;
2 changes: 1 addition & 1 deletion parquet/src/arrow/array_reader/list_array.rs
@@ -564,7 +564,7 @@ mod tests {
             .set_max_row_group_size(200)
             .build();
 
-        let mut writer = ArrowWriter::try_new(
+        let writer = ArrowWriter::try_new(
             file.try_clone().unwrap(),
             Arc::new(arrow_schema),
             Some(props),
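Since `std::fs::File` implements `std::io::Write`, file-backed tests like this one keep handing the `File` straight to the writer; the `mut` on the binding can be dropped, presumably because the remainder of the test only moves the writer into `close()`, which now takes `self` by value. A hedged sketch of that pattern (the schema and helper are illustrative):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use parquet::{arrow::ArrowWriter, errors::Result};

// Sketch only: a File is already an io::Write, so no cursor wrapper is needed.
fn write_empty(file: std::fs::File) -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));

    // No `mut` required when the writer is only constructed and closed
    let writer = ArrowWriter::try_new(file, schema, None)?;
    writer.close()?;
    Ok(())
}
```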
39 changes: 22 additions & 17 deletions parquet/src/arrow/arrow_reader.rs
@@ -263,15 +263,14 @@ mod tests {
     use crate::arrow::schema::add_encoded_arrow_schema_to_metadata;
     use crate::arrow::{ArrowWriter, ProjectionMask};
     use crate::basic::{ConvertedType, Encoding, Repetition, Type as PhysicalType};
-    use crate::column::writer::get_typed_column_writer_mut;
     use crate::data_type::{
         BoolType, ByteArray, ByteArrayType, DataType, FixedLenByteArray,
         FixedLenByteArrayType, Int32Type, Int64Type,
     };
     use crate::errors::Result;
     use crate::file::properties::{WriterProperties, WriterVersion};
     use crate::file::reader::{FileReader, SerializedFileReader};
-    use crate::file::writer::{FileWriter, SerializedFileWriter};
+    use crate::file::writer::SerializedFileWriter;
     use crate::schema::parser::parse_message_type;
     use crate::schema::types::{Type, TypePtr};
     use crate::util::cursor::SliceableCursor;
@@ -936,21 +935,24 @@ mod tests {
         for (idx, v) in values.iter().enumerate() {
             let def_levels = def_levels.map(|d| d[idx].as_slice());
             let mut row_group_writer = writer.next_row_group()?;
-            let mut column_writer = row_group_writer
-                .next_column()?
-                .expect("Column writer is none!");
+            {
+                let mut column_writer = row_group_writer
+                    .next_column()?
+                    .expect("Column writer is none!");
 
-            get_typed_column_writer_mut::<T>(&mut column_writer)
-                .write_batch(v, def_levels, None)?;
+                column_writer
+                    .typed::<T>()
+                    .write_batch(v, def_levels, None)?;
 
-            row_group_writer.close_column(column_writer)?;
-            writer.close_row_group(row_group_writer)?
+                column_writer.close()?;
+            }
+            row_group_writer.close()?;
         }
 
         writer.close()
     }
 
-    fn get_test_reader(file_name: &str) -> Arc<dyn FileReader> {
+    fn get_test_reader(file_name: &str) -> Arc<SerializedFileReader<File>> {
         let file = get_test_file(file_name);
 
         let reader =
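Here `next_column()` now hands back a `SerializedColumnWriter` whose `typed::<T>()` method replaces the old free function `get_typed_column_writer_mut`, and which is finalized with its own `close()` instead of `RowGroupWriter::close_column`. A self-contained sketch of the new write loop, based on the API surface visible in this diff; the one-column required `INT32` message type and buffer sink are illustrative:

```rust
use std::sync::Arc;

use parquet::data_type::Int32Type;
use parquet::errors::Result;
use parquet::file::{properties::WriterProperties, writer::SerializedFileWriter};
use parquet::schema::parser::parse_message_type;

fn write_one_column(values: &[i32]) -> Result<Vec<u8>> {
    let schema = Arc::new(parse_message_type("message schema { REQUIRED INT32 v; }")?);
    let props = Arc::new(WriterProperties::builder().build());

    let mut buffer = Vec::new();
    {
        // &mut Vec<u8> implements io::Write, so the buffer can be reclaimed below
        let mut writer = SerializedFileWriter::new(&mut buffer, schema, props)?;
        let mut row_group_writer = writer.next_row_group()?;
        while let Some(mut column_writer) = row_group_writer.next_column()? {
            // typed::<T>() replaces get_typed_column_writer_mut::<T>(&mut ...)
            column_writer
                .typed::<Int32Type>()
                .write_batch(values, None, None)?;
            // close() replaces RowGroupWriter::close_column(...)
            column_writer.close()?;
        }
        row_group_writer.close()?;
        writer.close()?;
    }
    Ok(buffer)
}
```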
@@ -1094,15 +1096,18 @@
         )
         .unwrap();
 
-        let mut row_group_writer = writer.next_row_group().unwrap();
-        let mut column_writer = row_group_writer.next_column().unwrap().unwrap();
+        {
+            let mut row_group_writer = writer.next_row_group().unwrap();
+            let mut column_writer = row_group_writer.next_column().unwrap().unwrap();
 
-        get_typed_column_writer_mut::<Int32Type>(&mut column_writer)
-            .write_batch(&[34, 76], Some(&[0, 1, 0, 1]), None)
-            .unwrap();
+            column_writer
+                .typed::<Int32Type>()
+                .write_batch(&[34, 76], Some(&[0, 1, 0, 1]), None)
+                .unwrap();
 
-        row_group_writer.close_column(column_writer).unwrap();
-        writer.close_row_group(row_group_writer).unwrap();
+            column_writer.close().unwrap();
+            row_group_writer.close().unwrap();
+        }
 
         writer.close().unwrap();
     }
44 changes: 21 additions & 23 deletions parquet/src/arrow/arrow_writer.rs
@@ -18,6 +18,7 @@
 //! Contains writer which writes arrow data into parquet data.
 
 use std::collections::VecDeque;
+use std::io::Write;
 use std::sync::Arc;
 
 use arrow::array as arrow_array;
@@ -35,18 +36,16 @@ use super::schema::{
 use crate::column::writer::ColumnWriter;
 use crate::errors::{ParquetError, Result};
 use crate::file::properties::WriterProperties;
-use crate::{
-    data_type::*,
-    file::writer::{FileWriter, ParquetWriter, RowGroupWriter, SerializedFileWriter},
-};
+use crate::file::writer::{SerializedColumnWriter, SerializedRowGroupWriter};
+use crate::{data_type::*, file::writer::SerializedFileWriter};
 
 /// Arrow writer
 ///
 /// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order
 /// to produce row groups with `max_row_group_size` rows. Any remaining rows will be
 /// flushed on close, leading the final row group in the output file to potentially
 /// contain fewer than `max_row_group_size` rows
-pub struct ArrowWriter<W: ParquetWriter> {
+pub struct ArrowWriter<W: Write> {
     /// Underlying Parquet writer
     writer: SerializedFileWriter<W>,
@@ -65,7 +64,7 @@ pub struct ArrowWriter<W: ParquetWriter> {
     max_row_group_size: usize,
 }
 
-impl<W: 'static + ParquetWriter> ArrowWriter<W> {
+impl<W: Write> ArrowWriter<W> {
     /// Try to create a new Arrow writer
     ///
     /// The writer will fail if:
@@ -185,33 +184,35 @@ impl<W: 'static + ParquetWriter> ArrowWriter<W> {
                 })
                 .collect();
 
-            write_leaves(row_group_writer.as_mut(), &arrays, &mut levels)?;
+            write_leaves(&mut row_group_writer, &arrays, &mut levels)?;
         }
 
-        self.writer.close_row_group(row_group_writer)?;
+        row_group_writer.close()?;
         self.buffered_rows -= num_rows;
 
         Ok(())
    }
 
     /// Close and finalize the underlying Parquet writer
-    pub fn close(&mut self) -> Result<parquet_format::FileMetaData> {
+    pub fn close(mut self) -> Result<parquet_format::FileMetaData> {
         self.flush()?;
         self.writer.close()
     }
 }
 
 /// Convenience method to get the next ColumnWriter from the RowGroupWriter
 #[inline]
-fn get_col_writer(row_group_writer: &mut dyn RowGroupWriter) -> Result<ColumnWriter> {
+fn get_col_writer<'a, W: Write>(
+    row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
+) -> Result<SerializedColumnWriter<'a>> {
     let col_writer = row_group_writer
         .next_column()?
         .expect("Unable to get column writer");
     Ok(col_writer)
 }
 
-fn write_leaves(
-    row_group_writer: &mut dyn RowGroupWriter,
+fn write_leaves<W: Write>(
+    row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
     arrays: &[ArrayRef],
     levels: &mut [Vec<LevelInfo>],
 ) -> Result<()> {
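Two details worth noting in this hunk: `flush` now closes row groups through `SerializedRowGroupWriter::close`, and `close` takes `mut self` rather than `&mut self`, so use-after-close becomes a compile-time error instead of a runtime concern. A tiny illustration of the latter (the helper function is hypothetical):

```rust
use parquet::{arrow::ArrowWriter, errors::Result};

// Hypothetical helper: after close(), the writer value is gone for good.
fn finish(writer: ArrowWriter<Vec<u8>>) -> Result<()> {
    writer.close()?; // moves `writer`; buffered rows are flushed first
    // writer.write(&batch)?; // would no longer compile: use of moved value
    Ok(())
}
```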
@@ -250,12 +251,12 @@ fn write_leaves(
             let mut col_writer = get_col_writer(row_group_writer)?;
             for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
                 write_leaf(
-                    &mut col_writer,
+                    col_writer.untyped(),
                     array,
                     levels.pop().expect("Levels exhausted"),
                 )?;
             }
-            row_group_writer.close_column(col_writer)?;
+            col_writer.close()?;
             Ok(())
         }
         ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
@@ -313,12 +314,12 @@
                 // cast dictionary to a primitive
                 let array = arrow::compute::cast(array, value_type)?;
                 write_leaf(
-                    &mut col_writer,
+                    col_writer.untyped(),
                     &array,
                     levels.pop().expect("Levels exhausted"),
                 )?;
             }
-            row_group_writer.close_column(col_writer)?;
+            col_writer.close()?;
             Ok(())
         }
         ArrowDataType::Float16 => Err(ParquetError::ArrowError(
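`SerializedColumnWriter` offers two views of the same column: `untyped()` yields the `ColumnWriter` enum (used above so `write_leaf` can match on the physical type), while `typed::<T>()` is the statically-typed accessor seen in the test diffs. A sketch of both access paths, assuming the row group's next column is a required `INT32`:

```rust
use std::io::Write;

use parquet::column::writer::ColumnWriter;
use parquet::data_type::Int32Type;
use parquet::errors::Result;
use parquet::file::writer::SerializedRowGroupWriter;

// Assumes the next column is a required INT32; both paths write the same
// values, one via enum dispatch and one via the typed accessor.
fn write_next_column<W: Write>(
    row_group: &mut SerializedRowGroupWriter<'_, W>,
    use_enum: bool,
) -> Result<()> {
    if let Some(mut col) = row_group.next_column()? {
        if use_enum {
            match col.untyped() {
                ColumnWriter::Int32ColumnWriter(typed) => {
                    typed.write_batch(&[1, 2, 3], None, None)?;
                }
                _ => unreachable!("example assumes an INT32 column"),
            }
        } else {
            col.typed::<Int32Type>().write_batch(&[1, 2, 3], None, None)?;
        }
        col.close()?;
    }
    Ok(())
}
```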
Expand All @@ -336,8 +337,8 @@ fn write_leaves(
}

fn write_leaf(
writer: &mut ColumnWriter,
column: &arrow_array::ArrayRef,
writer: &mut ColumnWriter<'_>,
column: &ArrayRef,
levels: LevelInfo,
) -> Result<i64> {
let indices = levels.filter_array_indices();
@@ -705,7 +706,6 @@ mod tests {
     use crate::file::{
         reader::{FileReader, SerializedFileReader},
         statistics::Statistics,
-        writer::InMemoryWriteableCursor,
     };
 
     #[test]
@@ -744,16 +744,14 @@
         let expected_batch =
             RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap();
 
-        let cursor = InMemoryWriteableCursor::default();
+        let mut buffer = vec![];
 
         {
-            let mut writer = ArrowWriter::try_new(cursor.clone(), schema, None).unwrap();
+            let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
             writer.write(&expected_batch).unwrap();
             writer.close().unwrap();
         }
 
-        let buffer = cursor.into_inner().unwrap();
-
         let cursor = crate::file::serialized_reader::SliceableCursor::new(buffer);
         let reader = SerializedFileReader::new(cursor).unwrap();
         let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(reader));
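Where the old test recovered the bytes with `cursor.into_inner()`, the new one hands the writer `&mut buffer` and simply keeps using `buffer` after the writer is closed; this works because `Write` is implemented for `&mut W` whenever `W: Write`. A minimal sketch of that borrow pattern in isolation:

```rust
use std::io::Write;

// Stand-in for any writer that takes its sink by value.
fn write_magic<W: Write>(mut sink: W) -> std::io::Result<()> {
    sink.write_all(b"PAR1") // placeholder bytes, not a real parquet file
}

fn roundtrip() -> std::io::Result<Vec<u8>> {
    let mut buffer = Vec::new();
    write_magic(&mut buffer)?; // the writer owns only the borrow...
    Ok(buffer) // ...so the caller keeps the allocation and the bytes
}
```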
4 changes: 2 additions & 2 deletions parquet/src/arrow/schema.rs
@@ -1591,7 +1591,7 @@ mod tests {
 
         // write to an empty parquet file so that schema is serialized
         let file = tempfile::tempfile().unwrap();
-        let mut writer = ArrowWriter::try_new(
+        let writer = ArrowWriter::try_new(
             file.try_clone().unwrap(),
             Arc::new(schema.clone()),
             None,
@@ -1660,7 +1660,7 @@
 
         // write to an empty parquet file so that schema is serialized
         let file = tempfile::tempfile().unwrap();
-        let mut writer = ArrowWriter::try_new(
+        let writer = ArrowWriter::try_new(
             file.try_clone().unwrap(),
             Arc::new(schema.clone()),
             None,
22 changes: 10 additions & 12 deletions parquet/src/column/mod.rs
@@ -40,10 +40,11 @@
 //!
 //! use parquet::{
 //!     column::{reader::ColumnReader, writer::ColumnWriter},
+//!     data_type::Int32Type,
 //!     file::{
 //!         properties::WriterProperties,
 //!         reader::{FileReader, SerializedFileReader},
-//!         writer::{FileWriter, SerializedFileWriter},
+//!         writer::SerializedFileWriter,
 //!     },
 //!     schema::parser::parse_message_type,
 //! };
@@ -65,20 +66,17 @@
 //! let props = Arc::new(WriterProperties::builder().build());
 //! let file = fs::File::create(path).unwrap();
 //! let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
 //!
 //! let mut row_group_writer = writer.next_row_group().unwrap();
 //! while let Some(mut col_writer) = row_group_writer.next_column().unwrap() {
-//!     match col_writer {
-//!         // You can also use `get_typed_column_writer` method to extract typed writer.
-//!         ColumnWriter::Int32ColumnWriter(ref mut typed_writer) => {
-//!             typed_writer
-//!                 .write_batch(&[1, 2, 3], Some(&[3, 3, 3, 2, 2]), Some(&[0, 1, 0, 1, 1]))
-//!                 .unwrap();
-//!         }
-//!         _ => {}
-//!     }
-//!     row_group_writer.close_column(col_writer).unwrap();
+//!     col_writer
+//!         .typed::<Int32Type>()
+//!         .write_batch(&[1, 2, 3], Some(&[3, 3, 3, 2, 2]), Some(&[0, 1, 0, 1, 1]))
+//!         .unwrap();
+//!     col_writer.close().unwrap();
 //! }
-//! writer.close_row_group(row_group_writer).unwrap();
+//! row_group_writer.close().unwrap();
 //!
 //! writer.close().unwrap();
 //!
 //! // Reading data using column reader API.
(Diffs for the remaining 12 changed files are not shown.)
