Skip to content

Commit

Permalink
roundtrip tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nevi-me committed Jul 3, 2021
1 parent 4885c32 commit 9f7d2f9
Showing 1 changed file with 35 additions and 87 deletions.
122 changes: 35 additions & 87 deletions parquet/src/arrow/arrow_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,8 @@ fn get_fsb_array_slice(
mod tests {
use super::*;

use std::fs::File;
use std::sync::Arc;
use std::{fs::File, io::Seek};

use arrow::datatypes::ToByteSlice;
use arrow::datatypes::{DataType, Field, Schema, UInt32Type, UInt8Type};
Expand Down Expand Up @@ -592,16 +592,11 @@ mod tests {
let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]);

// build a record batch
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(a), Arc::new(b)],
)
.unwrap();
let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
.unwrap();

let file = get_temp_file("test_arrow_writer.parquet", &[]);
let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
roundtrip("test_arrow_write.parquet", batch, Some(SMALL_SIZE / 2));
}

#[test]
Expand Down Expand Up @@ -660,22 +655,22 @@ mod tests {
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);

// build a record batch
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)]).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

let file = get_temp_file("test_arrow_writer_non_null.parquet", &[]);
let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
roundtrip(
"test_arrow_writer_non_null.parquet",
batch,
Some(SMALL_SIZE / 2),
);
}

#[test]
fn arrow_writer_list() {
// define schema
let schema = Schema::new(vec![Field::new(
"a",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
false,
DataType::List(Box::new(Field::new("item", DataType::Int32, false))),
true,
)]);

// create some data
Expand All @@ -690,7 +685,7 @@ mod tests {
let a_list_data = ArrayData::builder(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
true,
false,
))))
.len(5)
.add_buffer(a_value_offsets)
Expand All @@ -700,15 +695,13 @@ mod tests {
let a = ListArray::from(a_list_data);

// build a record batch
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)]).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

assert_eq!(batch.column(0).data().null_count(), 1);

let file = get_temp_file("test_arrow_writer_list.parquet", &[]);
let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
// This test fails if the max row group size is less than the batch's length
// see https://github.com/apache/arrow-rs/issues/518
roundtrip("test_arrow_writer_list.parquet", batch, None);
}

#[test]
Expand Down Expand Up @@ -741,15 +734,13 @@ mod tests {
let a = ListArray::from(a_list_data);

// build a record batch
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)]).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();

// This test fails if the max row group size is less than the batch's length
// see https://github.com/apache/arrow-rs/issues/518
assert_eq!(batch.column(0).data().null_count(), 0);

let file = get_temp_file("test_arrow_writer_list_non_null.parquet", &[]);
let mut writer = ArrowWriter::try_new(file, Arc::new(schema), None).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
roundtrip("test_arrow_writer_list_non_null.parquet", batch, None);
}

#[test]
Expand All @@ -773,39 +764,16 @@ mod tests {
let string_values = StringArray::from(raw_string_values.clone());
let binary_values = BinaryArray::from(raw_binary_value_refs);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
Arc::new(schema),
vec![Arc::new(string_values), Arc::new(binary_values)],
)
.unwrap();

let mut file = get_temp_file("test_arrow_writer_binary.parquet", &[]);
let mut writer =
ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema), None)
.unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();

file.seek(std::io::SeekFrom::Start(0)).unwrap();
let file_reader = SerializedFileReader::new(file).unwrap();
let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader));
let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap();

let batch = record_batch_reader.next().unwrap().unwrap();
let string_col = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let binary_col = batch
.column(1)
.as_any()
.downcast_ref::<BinaryArray>()
.unwrap();

for i in 0..batch.num_rows() {
assert_eq!(string_col.value(i), raw_string_values[i]);
assert_eq!(binary_col.value(i), raw_binary_values[i].as_slice());
}
roundtrip(
"test_arrow_writer_binary.parquet",
batch,
Some(SMALL_SIZE / 2),
);
}

#[test]
Expand All @@ -819,36 +787,16 @@ mod tests {
dec_builder.append_value(0).unwrap();
dec_builder.append_value(-100).unwrap();

let raw_decimal_i128_values: Vec<i128> = vec![10_000, 50_000, 0, -100];
let decimal_values = dec_builder.finish();
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(decimal_values)],
)
.unwrap();

let mut file = get_temp_file("test_arrow_writer_decimal.parquet", &[]);
let mut writer =
ArrowWriter::try_new(file.try_clone().unwrap(), Arc::new(schema), None)
let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(decimal_values)])
.unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();

file.seek(std::io::SeekFrom::Start(0)).unwrap();
let file_reader = SerializedFileReader::new(file).unwrap();
let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader));
let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap();

let batch = record_batch_reader.next().unwrap().unwrap();
let decimal_col = batch
.column(0)
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();

for i in 0..batch.num_rows() {
assert_eq!(decimal_col.value(i), raw_decimal_i128_values[i]);
}
roundtrip(
"test_arrow_writer_decimal.parquet",
batch,
Some(SMALL_SIZE / 2),
);
}

#[test]
Expand Down

0 comments on commit 9f7d2f9

Please sign in to comment.