Skip to content

Commit

Permalink
io: add AsyncReadExt::{chain, take} (#1484)
Browse files Browse the repository at this point in the history
  • Loading branch information
taiki-e authored and carllerche committed Aug 21, 2019
1 parent a791f4a commit 24fb33e
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tokio-io/src/io/async_read_ext.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
use crate::io::chain::{chain, Chain};
use crate::io::copy::{copy, Copy};
use crate::io::read::{read, Read};
use crate::io::read_exact::{read_exact, ReadExact};
use crate::io::read_to_end::{read_to_end, ReadToEnd};
use crate::io::read_to_string::{read_to_string, ReadToString};
use crate::io::take::{take, Take};
use crate::{AsyncRead, AsyncWrite};

/// An extension trait which adds utility methods to `AsyncRead` types.
pub trait AsyncReadExt: AsyncRead {
/// Creates an adaptor which will chain this stream with another.
///
/// The returned `AsyncRead` instance will first read all bytes from this object
/// until EOF is encountered. Afterwards the output is equivalent to the
/// output of `next`.
fn chain<R>(self, next: R) -> Chain<Self, R>
where
Self: Sized,
R: AsyncRead,
{
chain(self, next)
}

/// Copy all data from `self` into the provided `AsyncWrite`.
///
/// The returned future will copy all the bytes read from `reader` into the
Expand Down Expand Up @@ -63,6 +78,15 @@ pub trait AsyncReadExt: AsyncRead {
{
read_to_string(self, dst)
}

/// Creates an AsyncRead adapter which will read at most `limit` bytes
/// from the underlying reader.
fn take(self, limit: u64) -> Take<Self>
where
Self: Sized,
{
take(self, limit)
}
}

impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
142 changes: 142 additions & 0 deletions tokio-io/src/io/chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
use crate::{AsyncBufRead, AsyncRead};
use futures_core::ready;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::fmt;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
#[must_use = "streams do nothing unless polled"]
pub struct Chain<T, U> {
first: T,
second: U,
done_first: bool,
}

impl<T, U> Unpin for Chain<T, U>
where
T: Unpin,
U: Unpin,
{
}

pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
Chain {
first,
second,
done_first: false,
}
}

impl<T, U> Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
unsafe_pinned!(first: T);
unsafe_pinned!(second: U);
unsafe_unpinned!(done_first: bool);

/// Gets references to the underlying readers in this `Chain`.
pub fn get_ref(&self) -> (&T, &U) {
(&self.first, &self.second)
}

/// Gets mutable references to the underlying readers in this `Chain`.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying readers as doing so may corrupt the internal state of this
/// `Chain`.
pub fn get_mut(&mut self) -> (&mut T, &mut U) {
(&mut self.first, &mut self.second)
}

/// Gets pinned mutable references to the underlying readers in this `Chain`.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying readers as doing so may corrupt the internal state of this
/// `Chain`.
pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
unsafe {
let Self { first, second, .. } = self.get_unchecked_mut();
(Pin::new_unchecked(first), Pin::new_unchecked(second))
}
}

/// Consumes the `Chain`, returning the wrapped readers.
pub fn into_inner(self) -> (T, U) {
(self.first, self.second)
}
}

impl<T, U> fmt::Debug for Chain<T, U>
where
T: fmt::Debug,
U: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Chain")
.field("t", &self.first)
.field("u", &self.second)
.finish()
}
}

impl<T, U> AsyncRead for Chain<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
if !self.done_first {
match ready!(self.as_mut().first().poll_read(cx, buf)?) {
0 if !buf.is_empty() => *self.as_mut().done_first() = true,
n => return Poll::Ready(Ok(n)),
}
}
self.second().poll_read(cx, buf)
}
}

impl<T, U> AsyncBufRead for Chain<T, U>
where
T: AsyncBufRead,
U: AsyncBufRead,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let Self {
first,
second,
done_first,
} = unsafe { self.get_unchecked_mut() };
let first = unsafe { Pin::new_unchecked(first) };
let second = unsafe { Pin::new_unchecked(second) };

if !*done_first {
match ready!(first.poll_fill_buf(cx)?) {
buf if buf.is_empty() => {
*done_first = true;
}
buf => return Poll::Ready(Ok(buf)),
}
}
second.poll_fill_buf(cx)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
if !self.done_first {
self.first().consume(amt)
} else {
self.second().consume(amt)
}
}
}
2 changes: 2 additions & 0 deletions tokio-io/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod async_read_ext;
mod async_write_ext;
mod buf_reader;
mod buf_writer;
mod chain;
mod copy;
mod flush;
mod lines;
Expand All @@ -13,6 +14,7 @@ mod read_to_end;
mod read_to_string;
mod read_until;
mod shutdown;
mod take;
mod write;
mod write_all;

Expand Down
120 changes: 120 additions & 0 deletions tokio-io/src/io/take.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use crate::{AsyncBufRead, AsyncRead};
use futures_core::ready;
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{cmp, io};

/// Stream for the [`take`](super::AsyncReadExt::take) method.
#[derive(Debug)]
#[must_use = "streams do nothing unless you `.await` or poll them"]
pub struct Take<R> {
inner: R,
// Add '_' to avoid conflicts with `limit` method.
limit_: u64,
}

impl<R: Unpin> Unpin for Take<R> {}

pub(super) fn take<R: AsyncRead>(inner: R, limit: u64) -> Take<R> {
Take {
inner,
limit_: limit,
}
}

impl<R: AsyncRead> Take<R> {
unsafe_pinned!(inner: R);
unsafe_unpinned!(limit_: u64);

/// Returns the remaining number of bytes that can be
/// read before this instance will return EOF.
///
/// # Note
///
/// This instance may reach `EOF` after reading fewer bytes than indicated by
/// this method if the underlying [`AsyncRead`] instance reaches EOF.
pub fn limit(&self) -> u64 {
self.limit_
}

/// Sets the number of bytes that can be read before this instance will
/// return EOF. This is the same as constructing a new `Take` instance, so
/// the amount of bytes read and the previous limit value don't matter when
/// calling this method.
pub fn set_limit(&mut self, limit: u64) {
self.limit_ = limit
}

/// Gets a reference to the underlying reader.
pub fn get_ref(&self) -> &R {
&self.inner
}

/// Gets a mutable reference to the underlying reader.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying reader as doing so may corrupt the internal limit of this
/// `Take`.
pub fn get_mut(&mut self) -> &mut R {
&mut self.inner
}

/// Gets a pinned mutable reference to the underlying reader.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying reader as doing so may corrupt the internal limit of this
/// `Take`.
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
self.inner()
}

/// Consumes the `Take`, returning the wrapped reader.
pub fn into_inner(self) -> R {
self.inner
}
}

impl<R: AsyncRead> AsyncRead for Take<R> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}

fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, io::Error>> {
if self.limit_ == 0 {
return Poll::Ready(Ok(0));
}

let max = std::cmp::min(buf.len() as u64, self.limit_) as usize;
let n = ready!(self.as_mut().inner().poll_read(cx, &mut buf[..max]))?;
*self.as_mut().limit_() -= n as u64;
Poll::Ready(Ok(n))
}
}

impl<R: AsyncBufRead> AsyncBufRead for Take<R> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let Self { inner, limit_ } = unsafe { self.get_unchecked_mut() };
let inner = unsafe { Pin::new_unchecked(inner) };

// Don't call into inner reader at all at EOF because it may still block
if *limit_ == 0 {
return Poll::Ready(Ok(&[]));
}

let buf = ready!(inner.poll_fill_buf(cx)?);
let cap = cmp::min(buf.len() as u64, *limit_) as usize;
Poll::Ready(Ok(&buf[..cap]))
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
// Don't let callers reset the limit by passing an overlarge value
let amt = cmp::min(amt as u64, self.limit_) as usize;
*self.as_mut().limit_() -= amt as u64;
self.inner().consume(amt);
}
}
16 changes: 16 additions & 0 deletions tokio-io/tests/chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#![warn(rust_2018_idioms)]
#![feature(async_await)]

use tokio_io::AsyncReadExt;
use tokio_test::assert_ok;

#[tokio::test]
async fn chain() {
let mut buf = Vec::new();
let rd1: &[u8] = b"hello ";
let rd2: &[u8] = b"world";

let mut rd = rd1.chain(rd2);
assert_ok!(rd.read_to_end(&mut buf).await);
assert_eq!(buf, b"hello world");
}
16 changes: 16 additions & 0 deletions tokio-io/tests/take.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#![warn(rust_2018_idioms)]
#![feature(async_await)]

use tokio_io::AsyncReadExt;
use tokio_test::assert_ok;

#[tokio::test]
async fn take() {
let mut buf = [0; 6];
let rd: &[u8] = b"hello world";

let mut rd = rd.take(4);
let n = assert_ok!(rd.read(&mut buf).await);
assert_eq!(n, 4);
assert_eq!(&buf, &b"hell\0\0"[..]);
}

0 comments on commit 24fb33e

Please sign in to comment.