Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize BufWriter #79930

Merged
merged 7 commits into from
May 6, 2021
Merged
217 changes: 178 additions & 39 deletions library/std/src/io/buffered/bufwriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::io::{
self, Error, ErrorKind, IntoInnerError, IoSlice, Seek, SeekFrom, Write, DEFAULT_BUF_SIZE,
};
use crate::mem;
use crate::ptr;

/// Wraps a writer and buffers its output.
///
Expand Down Expand Up @@ -68,6 +69,10 @@ use crate::mem;
#[stable(feature = "rust1", since = "1.0.0")]
pub struct BufWriter<W: Write> {
inner: Option<W>,
// The buffer. Avoid using this like a normal `Vec` in common code paths.
// That is, don't use `buf.push`, `buf.extend_from_slice`, or any other
// methods that require bounds checking or the like. This makes an enormous
// difference to performance (we may want to stop using a `Vec` entirely).
buf: Vec<u8>,
// #30888: If the inner writer panics in a call to write, we don't want to
// write the buffered data a second time in BufWriter's destructor. This
Expand Down Expand Up @@ -181,9 +186,14 @@ impl<W: Write> BufWriter<W> {
/// data. Writes as much as possible without exceeding capacity. Returns
/// the number of bytes written.
pub(super) fn write_to_buf(&mut self, buf: &[u8]) -> usize {
let available = self.buf.capacity() - self.buf.len();
let available = self.spare_capacity();
let amt_to_buffer = available.min(buf.len());
self.buf.extend_from_slice(&buf[..amt_to_buffer]);

// SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction.
unsafe {
self.write_to_buffer_unchecked(&buf[..amt_to_buffer]);
}

amt_to_buffer
}

Expand Down Expand Up @@ -331,6 +341,103 @@ impl<W: Write> BufWriter<W> {
let buf = if !self.panicked { Ok(buf) } else { Err(WriterPanicked { buf }) };
(self.inner.take().unwrap(), buf)
}

// Ensure this function does not get inlined into `write`, so that it
// remains inlineable and its common path remains as short as possible.
// If this function ends up being called frequently relative to `write`,
// it's likely a sign that the client is using an improperly sized buffer
// or their write patterns are somewhat pathological.
#[cold]
#[inline(never)]
fn write_cold(&mut self, buf: &[u8]) -> io::Result<usize> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it's assumed this function is only called for buf.len() >= self.spare_capacity(). It'd make it easier to read and verify if that assumption was written down. Perhaps as a debug_assert!(). (Same below.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's only called in that case, but it still works if that assumption doesn't hold, if I'm not mistaken, so I didn't add a debug_assert!. Let me know if you still think I should add it and I'll go ahead.

if buf.len() > self.spare_capacity() {
self.flush_buf()?;
}

// Why not len > capacity? To avoid a needless trip through the buffer when the input
// exactly fills it. We'd just need to flush it to the underlying writer anyway.
if buf.len() >= self.buf.capacity() {
self.panicked = true;
let r = self.get_mut().write(buf);
self.panicked = false;
r
} else {
// Write to the buffer. In this case, we write to the buffer even if it fills it
// exactly. Doing otherwise would mean flushing the buffer, then writing this
// input to the inner writer, which in many cases would be a worse strategy.

// SAFETY: There was either enough spare capacity already, or there wasn't and we
// flushed the buffer to ensure that there is. In the latter case, we know that there
// is because flushing ensured that our entire buffer is spare capacity, and we entered
// this block because the input buffer length is less than that capacity. In either
// case, it's safe to write the input buffer to our buffer.
unsafe {
self.write_to_buffer_unchecked(buf);
}

Ok(buf.len())
}
}

// Ensure this function does not get inlined into `write_all`, so that it
// remains inlineable and its common path remains as short as possible.
// If this function ends up being called frequently relative to `write_all`,
// it's likely a sign that the client is using an improperly sized buffer
// or their write patterns are somewhat pathological.
#[cold]
#[inline(never)]
fn write_all_cold(&mut self, buf: &[u8]) -> io::Result<()> {
// Normally, `write_all` just calls `write` in a loop. We can do better
// by calling `self.get_mut().write_all()` directly, which avoids
// round trips through the buffer in the event of a series of partial
// writes in some circumstances.

if buf.len() > self.spare_capacity() {
self.flush_buf()?;
}

// Why not len > capacity? To avoid a needless trip through the buffer when the input
// exactly fills it. We'd just need to flush it to the underlying writer anyway.
if buf.len() >= self.buf.capacity() {
self.panicked = true;
let r = self.get_mut().write_all(buf);
self.panicked = false;
r
} else {
// Write to the buffer. In this case, we write to the buffer even if it fills it
// exactly. Doing otherwise would mean flushing the buffer, then writing this
// input to the inner writer, which in many cases would be a worse strategy.

// SAFETY: There was either enough spare capacity already, or there wasn't and we
// flushed the buffer to ensure that there is. In the latter case, we know that there
// is because flushing ensured that our entire buffer is spare capacity, and we entered
// this block because the input buffer length is less than that capacity. In either
// case, it's safe to write the input buffer to our buffer.
unsafe {
self.write_to_buffer_unchecked(buf);
}

Ok(())
}
}

// SAFETY: Requires `buf.len() <= self.buf.capacity() - self.buf.len()`,
// i.e., that input buffer length is less than or equal to spare capacity.
#[inline]
unsafe fn write_to_buffer_unchecked(&mut self, buf: &[u8]) {
m-ou-se marked this conversation as resolved.
Show resolved Hide resolved
debug_assert!(buf.len() <= self.spare_capacity());
let old_len = self.buf.len();
let buf_len = buf.len();
let src = buf.as_ptr();
let dst = self.buf.as_mut_ptr().add(old_len);
ptr::copy_nonoverlapping(src, dst, buf_len);
self.buf.set_len(old_len + buf_len);
Comment on lines +429 to +434
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Note that you could also implement this with self.buf.spare_capacity_mut() and MaybeUninit::write_slice, which are both still unstable.)

}

#[inline]
fn spare_capacity(&self) -> usize {
self.buf.capacity() - self.buf.len()
}
}

#[unstable(feature = "bufwriter_into_raw_parts", issue = "80690")]
Expand Down Expand Up @@ -402,63 +509,82 @@ impl fmt::Debug for WriterPanicked {

#[stable(feature = "rust1", since = "1.0.0")]
impl<W: Write> Write for BufWriter<W> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.buf.len() + buf.len() > self.buf.capacity() {
self.flush_buf()?;
}
// FIXME: Why no len > capacity? Why not buffer len == capacity? #72919
if buf.len() >= self.buf.capacity() {
self.panicked = true;
let r = self.get_mut().write(buf);
self.panicked = false;
r
} else {
self.buf.extend_from_slice(buf);
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
// See `write_cold` for details.
if buf.len() < self.spare_capacity() {
// SAFETY: safe by above conditional.
unsafe {
self.write_to_buffer_unchecked(buf);
}

Ok(buf.len())
} else {
self.write_cold(buf)
}
}

#[inline]
fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
// Normally, `write_all` just calls `write` in a loop. We can do better
// by calling `self.get_mut().write_all()` directly, which avoids
// round trips through the buffer in the event of a series of partial
// writes in some circumstances.
if self.buf.len() + buf.len() > self.buf.capacity() {
self.flush_buf()?;
}
// FIXME: Why no len > capacity? Why not buffer len == capacity? #72919
if buf.len() >= self.buf.capacity() {
self.panicked = true;
let r = self.get_mut().write_all(buf);
self.panicked = false;
r
} else {
self.buf.extend_from_slice(buf);
// Use < instead of <= to avoid a needless trip through the buffer in some cases.
tgnottingham marked this conversation as resolved.
Show resolved Hide resolved
// See `write_all_cold` for details.
if buf.len() < self.spare_capacity() {
// SAFETY: safe by above conditional.
unsafe {
self.write_to_buffer_unchecked(buf);
}

Ok(())
} else {
self.write_all_cold(buf)
}
}

fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
// FIXME: Consider applying `#[inline]` / `#[inline(never)]` optimizations already applied
// to `write` and `write_all`. The performance benefits can be significant. See #79930.
if self.get_ref().is_write_vectored() {
let total_len = bufs.iter().map(|b| b.len()).sum::<usize>();
if self.buf.len() + total_len > self.buf.capacity() {
// We have to handle the possibility that the total length of the buffers overflows
// `usize` (even though this can only happen if multiple `IoSlice`s reference the
// same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the
// computation overflows, then surely the input cannot fit in our buffer, so we forward
// to the inner writer's `write_vectored` method to let it handle it appropriately.
let saturated_total_len =
bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len()));

if saturated_total_len > self.spare_capacity() {
// Flush if the total length of the input exceeds our buffer's spare capacity.
// If we would have overflowed, this condition also holds, and we need to flush.
self.flush_buf()?;
}
if total_len >= self.buf.capacity() {

if saturated_total_len >= self.buf.capacity() {
// Forward to our inner writer if the total length of the input is greater than or
// equal to our buffer capacity. If we would have overflowed, this condition also
// holds, and we punt to the inner writer.
self.panicked = true;
let r = self.get_mut().write_vectored(bufs);
self.panicked = false;
r
} else {
bufs.iter().for_each(|b| self.buf.extend_from_slice(b));
Ok(total_len)
// `saturated_total_len < self.buf.capacity()` implies that we did not saturate.

// SAFETY: We checked whether or not the spare capacity was large enough above. If
// it was, then we're safe already. If it wasn't, we flushed, making sufficient
// room for any input <= the buffer size, which includes this input.
unsafe {
bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
};
tgnottingham marked this conversation as resolved.
Show resolved Hide resolved

Ok(saturated_total_len)
}
} else {
let mut iter = bufs.iter();
let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) {
// This is the first non-empty slice to write, so if it does
// not fit in the buffer, we still get to flush and proceed.
if self.buf.len() + buf.len() > self.buf.capacity() {
if buf.len() > self.spare_capacity() {
self.flush_buf()?;
}
if buf.len() >= self.buf.capacity() {
Expand All @@ -469,19 +595,32 @@ impl<W: Write> Write for BufWriter<W> {
self.panicked = false;
return r;
} else {
self.buf.extend_from_slice(buf);
// SAFETY: We checked whether or not the spare capacity was large enough above.
// If it was, then we're safe already. If it wasn't, we flushed, making
// sufficient room for any input <= the buffer size, which includes this input.
unsafe {
self.write_to_buffer_unchecked(buf);
}

buf.len()
}
} else {
return Ok(0);
};
debug_assert!(total_written != 0);
for buf in iter {
if self.buf.len() + buf.len() > self.buf.capacity() {
break;
} else {
self.buf.extend_from_slice(buf);
if buf.len() <= self.spare_capacity() {
// SAFETY: safe by above conditional.
unsafe {
self.write_to_buffer_unchecked(buf);
}

// This cannot overflow `usize`. If we are here, we've written all of the bytes
// so far to our buffer, and we've ensured that we never exceed the buffer's
// capacity. Therefore, `total_written` <= `self.buf.capacity()` <= `usize::MAX`.
total_written += buf.len();
} else {
break;
}
}
Ok(total_written)
Expand Down