Skip to content

Commit

Permalink
Async writer tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Mar 28, 2023
1 parent b05522f commit 8660dd4
Showing 1 changed file with 32 additions and 28 deletions.
60 changes: 32 additions & 28 deletions parquet/src/arrow/async_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@
//! # }
//! ```

use std::{
io::Write,
sync::{Arc, Mutex},
};
use std::{io::Write, sync::Arc};

use crate::{
arrow::ArrowWriter,
Expand Down Expand Up @@ -88,22 +85,25 @@ pub struct AsyncArrowWriter<W> {
impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
/// Try to create a new Async Arrow Writer.
///
/// `buffer_flush_threshold` will be used to trigger flush of the inner buffer.
/// `buffer_size` determines the size of the intermediate buffer
///
/// Flush will automatically be called by [`Self::write`] if
/// the buffer is at least half full
pub fn try_new(
writer: W,
arrow_schema: SchemaRef,
buffer_flush_threshold: usize,
buffer_size: usize,
props: Option<WriterProperties>,
) -> Result<Self> {
let shared_buffer = SharedBuffer::default();
let shared_buffer = SharedBuffer::new(buffer_size);
let sync_writer =
ArrowWriter::try_new(shared_buffer.clone(), arrow_schema, props)?;

Ok(Self {
sync_writer,
async_writer: writer,
shared_buffer,
buffer_flush_threshold,
buffer_flush_threshold: buffer_size / 2,
})
}

Expand All @@ -114,7 +114,7 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> {
self.sync_writer.write(batch)?;
Self::try_flush(
&self.shared_buffer,
&mut self.shared_buffer,
&mut self.async_writer,
self.buffer_flush_threshold,
)
Expand All @@ -135,64 +135,68 @@ impl<W: AsyncWrite + Unpin + Send> AsyncArrowWriter<W> {
let metadata = self.sync_writer.close()?;

// Force to flush the remaining data.
Self::try_flush(&self.shared_buffer, &mut self.async_writer, 0).await?;
Self::try_flush(&mut self.shared_buffer, &mut self.async_writer, 0).await?;

Ok(metadata)
}

/// Flush the data in the [`SharedBuffer`] into the `async_writer` if its size
/// exceeds the threshold.
async fn try_flush(
shared_buffer: &SharedBuffer,
shared_buffer: &mut SharedBuffer,
async_writer: &mut W,
threshold: usize,
) -> Result<()> {
let mut buffer = {
let mut buffer = shared_buffer.buffer.lock().unwrap();

if buffer.is_empty() || buffer.len() < threshold {
// no need to flush
return Ok(());
}
std::mem::take(&mut *buffer)
};
let mut buffer = shared_buffer.buffer.try_lock().unwrap();
if buffer.is_empty() || buffer.len() < threshold {
// no need to flush
return Ok(());
}

async_writer
.write(&buffer)
.write(buffer.as_slice())
.await
.map_err(|e| ParquetError::External(Box::new(e)))?;

async_writer
.flush()
.await
.map_err(|e| ParquetError::External(Box::new(e)))?;

// reuse the buffer.
buffer.clear();
*shared_buffer.buffer.lock().unwrap() = buffer;

Ok(())
}
}

/// A buffer with interior mutability shared by the [`ArrowWriter`] and
/// [`AsyncArrowWriter`].
#[derive(Clone, Default)]
#[derive(Clone)]
struct SharedBuffer {
/// The inner buffer for reading and writing
///
/// The lock is used to obtain internal mutability, so no worry about the
/// lock contention.
buffer: Arc<Mutex<Vec<u8>>>,
buffer: Arc<futures::lock::Mutex<Vec<u8>>>,
}

impl SharedBuffer {
pub fn new(capacity: usize) -> Self {
Self {
buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))),
}
}
}

impl Write for SharedBuffer {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut buffer = self.buffer.lock().unwrap();
let mut buffer = self.buffer.try_lock().unwrap();
Write::write(&mut *buffer, buf)
}

fn flush(&mut self) -> std::io::Result<()> {
let mut buffer = self.buffer.lock().unwrap();
let mut buffer = self.buffer.try_lock().unwrap();
Write::flush(&mut *buffer)
}
}
Expand Down Expand Up @@ -342,7 +346,7 @@ mod tests {
};

let test_buffer_flush_thresholds =
vec![0, 1024, 40 * 1024, 50 * 1024, 100 * 1024, usize::MAX];
vec![0, 1024, 40 * 1024, 50 * 1024, 100 * 1024];

for buffer_flush_threshold in test_buffer_flush_thresholds {
let reader = get_test_reader();
Expand All @@ -354,7 +358,7 @@ mod tests {
let mut async_writer = AsyncArrowWriter::try_new(
&mut test_async_sink,
reader.schema(),
buffer_flush_threshold,
buffer_flush_threshold * 2,
Some(write_props.clone()),
)
.unwrap();
Expand Down

0 comments on commit 8660dd4

Please sign in to comment.