diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs
index 727bbb398..8abb7e732 100644
--- a/quinn-proto/src/connection/mod.rs
+++ b/quinn-proto/src/connection/mod.rs
@@ -53,7 +53,10 @@ pub use stats::ConnectionStats;
mod streams;
pub use streams::Streams;
-pub use streams::{FinishError, ReadError, ShouldTransmit, StreamEvent, UnknownStream, WriteError};
+pub use streams::{
+ Chunks, FinishError, ReadError, ReadableError, ShouldTransmit, StreamEvent, UnknownStream,
+ WriteError,
+};
mod timer;
use timer::{Timer, TimerTable};
@@ -1088,40 +1091,9 @@ where
/// control window is filled. On any given stream, you can switch from ordered to unordered
/// reads, but ordered reads on streams that have seen previous unordered reads will return
/// `ReadError::IllegalOrderedRead`.
- pub fn read(
- &mut self,
- id: StreamId,
- max_length: usize,
- ordered: bool,
- ) -> Result, ReadError> {
- let result = self.streams.read(id, max_length, ordered);
- self.post_read(id, &result);
- Ok(result?.map(|x| x.result))
- }
-
- /// Read the next ordered chunks from the given recv stream
- pub fn read_chunks(
- &mut self,
- id: StreamId,
- bufs: &mut [Bytes],
- ) -> Result , ReadError> {
- let result = self.streams.read_chunks(id, bufs);
- self.post_read(id, &result);
- Ok(result?.map(|x| x.result.bufs))
- }
-
- fn post_read(&mut self, id: StreamId, result: &streams::ReadResult) {
- let (did_read, max_data, max_stream_data) = match result {
- Ok(Some(did)) => (true, did.max_data, did.max_stream_data),
- _ => (false, ShouldTransmit::default(), ShouldTransmit::default()),
- };
-
- let max_dirty = self.streams.take_max_streams_dirty(id.dir());
- if did_read || max_dirty {
- self.spaces[SpaceId::Data]
- .pending
- .post_read(id, max_data, max_stream_data, max_dirty);
- }
+ pub fn read(&mut self, id: StreamId, ordered: bool) -> Result {
+ self.streams
+ .read(id, ordered, &mut self.spaces[SpaceId::Data].pending)
}
/// Send data on the given stream
diff --git a/quinn-proto/src/connection/streams.rs b/quinn-proto/src/connection/streams.rs
index 4cedd7325..4c1f0e72e 100644
--- a/quinn-proto/src/connection/streams.rs
+++ b/quinn-proto/src/connection/streams.rs
@@ -5,11 +5,10 @@ use std::{
mem,
};
-use bytes::{BufMut, Bytes};
+use bytes::BufMut;
use thiserror::Error;
use tracing::{debug, trace};
-use super::assembler::Chunk;
use super::spaces::{Retransmits, ThinRetransmits};
use crate::{
coding::BufMutExt,
@@ -20,9 +19,8 @@ use crate::{
};
mod recv;
-pub use recv::ReadError;
-use recv::{BytesRead, ReadChunks, Recv, StreamReadResult};
-pub(super) use recv::{DidRead, ReadResult};
+use recv::Recv;
+pub use recv::{Chunks, ReadError, ReadableError};
mod send;
pub use send::{FinishError, WriteError};
@@ -202,55 +200,13 @@ impl Streams {
self.connection_blocked.clear();
}
- pub(crate) fn read(
- &mut self,
+ pub(crate) fn read<'a>(
+ &'a mut self,
id: StreamId,
- max_length: usize,
ordered: bool,
- ) -> ReadResult {
- self.try_read(id, |rs| rs.read(max_length, ordered))
- }
-
- pub(crate) fn read_chunks(
- &mut self,
- id: StreamId,
- bufs: &mut [Bytes],
- ) -> ReadResult {
- self.try_read(id, |rs| rs.read_chunks(bufs))
- }
-
- fn try_read(&mut self, id: StreamId, mut read: F) -> ReadResult
- where
- F: FnMut(&mut Recv) -> StreamReadResult,
- O: BytesRead,
- {
- let mut entry = match self.recv.entry(id) {
- hash_map::Entry::Vacant(_) => return Err(ReadError::UnknownStream),
- hash_map::Entry::Occupied(e) => e,
- };
- let rs = entry.get_mut();
- match read(rs) {
- Ok(Some(out)) => {
- let (_, max_stream_data) = rs.max_stream_data(self.stream_receive_window);
- let max_data = self.add_read_credits(out.bytes_read());
- Ok(Some(DidRead {
- result: out,
- max_stream_data,
- max_data,
- }))
- }
- Ok(None) => {
- entry.remove_entry();
- self.stream_freed(id, StreamHalf::Recv);
- Ok(None)
- }
- Err(e @ ReadError::Reset { .. }) => {
- entry.remove_entry();
- self.stream_freed(id, StreamHalf::Recv);
- Err(e)
- }
- Err(e) => Err(e),
- }
+ pending: &'a mut Retransmits,
+ ) -> Result, ReadableError> {
+ Chunks::new(id, ordered, self, pending)
}
/// Queue `data` to be written for `stream`
@@ -1077,6 +1033,7 @@ enum StreamHalf {
mod tests {
use super::*;
use crate::TransportErrorCode;
+ use bytes::Bytes;
fn make(side: Side) -> Streams {
Streams::new(
@@ -1094,6 +1051,7 @@ mod tests {
let mut client = make(Side::Client);
let id = StreamId::new(Side::Server, Dir::Uni, 0);
let initial_max = client.local_max_data;
+ let mut pending = Retransmits::default();
assert_eq!(
client
.received(
@@ -1110,7 +1068,9 @@ mod tests {
);
assert_eq!(client.data_recvd, 2048);
assert_eq!(client.local_max_data - initial_max, 0);
- client.read(id, 1024, true).unwrap();
+ let mut chunks = client.read(id, true, &mut pending).unwrap();
+ chunks.next(1024).unwrap();
+ let _ = chunks.finalize();
assert_eq!(client.local_max_data - initial_max, 1024);
assert_eq!(
client
@@ -1193,6 +1153,7 @@ mod tests {
fn recv_stopped() {
let mut client = make(Side::Client);
let id = StreamId::new(Side::Server, Dir::Uni, 0);
+ let mut pending = Retransmits::default();
let initial_max = client.local_max_data;
assert_eq!(
client
@@ -1217,10 +1178,13 @@ mod tests {
}
);
assert!(client.stop(id).is_err());
- assert_eq!(client.read(id, 0, true), Err(ReadError::UnknownStream));
assert_eq!(
- client.read(id, usize::MAX, false),
- Err(ReadError::UnknownStream)
+ client.read(id, true, &mut pending).err(),
+ Some(ReadableError::UnknownStream)
+ );
+ assert_eq!(
+ client.read(id, false, &mut pending).err(),
+ Some(ReadableError::UnknownStream)
);
assert_eq!(client.local_max_data - initial_max, 32);
assert_eq!(
diff --git a/quinn-proto/src/connection/streams/recv.rs b/quinn-proto/src/connection/streams/recv.rs
index 56a8f658b..dd72f80b0 100644
--- a/quinn-proto/src/connection/streams/recv.rs
+++ b/quinn-proto/src/connection/streams/recv.rs
@@ -1,8 +1,10 @@
-use bytes::Bytes;
+use std::collections::hash_map::Entry;
+use std::mem;
+
use thiserror::Error;
use tracing::debug;
-use super::{ShouldTransmit, UnknownStream};
+use super::{Retransmits, ShouldTransmit, StreamHalf, StreamId, Streams, UnknownStream};
use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead};
use crate::{frame, TransportError, VarInt};
@@ -70,49 +72,6 @@ impl Recv {
Ok((new_bytes, frame.fin && self.stopped))
}
- pub(super) fn read(&mut self, max_length: usize, ordered: bool) -> StreamReadResult {
- if self.stopped {
- return Err(ReadError::UnknownStream);
- }
-
- self.assembler.ensure_ordering(ordered)?;
- match self.assembler.read(max_length, ordered) {
- Some(chunk) => Ok(Some(chunk)),
- None => self.read_blocked().map(|()| None),
- }
- }
-
- pub(super) fn read_chunks(
- &mut self,
- chunks: &mut [Bytes],
- ) -> Result, ReadError> {
- if self.stopped {
- return Err(ReadError::UnknownStream);
- }
-
- let mut out = ReadChunks { bufs: 0, read: 0 };
- if chunks.is_empty() {
- return Ok(Some(out));
- }
-
- self.assembler.ensure_ordering(true)?;
- while let Some(chunk) = self.assembler.read(usize::MAX, true) {
- chunks[out.bufs] = chunk.bytes;
- out.read += chunks[out.bufs].len();
- out.bufs += 1;
-
- if out.bufs >= chunks.len() {
- return Ok(Some(out));
- }
- }
-
- if out.bufs > 0 {
- return Ok(Some(out));
- }
-
- self.read_blocked().map(|()| None)
- }
-
pub(super) fn stop(&mut self) -> Result<(u64, ShouldTransmit), UnknownStream> {
if self.stopped {
return Err(UnknownStream { _private: () });
@@ -129,21 +88,6 @@ impl Recv {
Ok((read_credits, ShouldTransmit(self.is_receiving())))
}
- fn read_blocked(&mut self) -> Result<(), ReadError> {
- match self.state {
- RecvState::ResetRecvd { error_code, .. } => {
- Err(ReadError::Reset(error_code))
- }
- RecvState::Recv { size } => {
- if size == Some(self.end) && self.assembler.bytes_read() == self.end {
- Ok(())
- } else {
- Err(ReadError::Blocked)
- }
- }
- }
- }
-
/// Returns the window that should be advertised in a `MAX_STREAM_DATA` frame
///
/// The method returns a tuple which consists of the window that should be
@@ -254,40 +198,149 @@ impl Recv {
}
}
-pub(crate) type ReadResult = Result>, ReadError>;
-
-/// Result of a `Streams::read` call in case the stream had not ended yet
-#[derive(Debug, Eq, PartialEq, Copy, Clone)]
-#[must_use = "A frame might need to be enqueued"]
-pub(crate) struct DidRead {
- pub result: T,
- pub max_stream_data: ShouldTransmit,
- pub max_data: ShouldTransmit,
+/// Chunks
+pub struct Chunks<'a> {
+ id: StreamId,
+ ordered: bool,
+ streams: &'a mut Streams,
+ pending: &'a mut Retransmits,
+ state: ChunksState,
+ read: u64,
}
-pub(super) type StreamReadResult = Result, ReadError>;
+impl<'a> Chunks<'a> {
+ pub(super) fn new(
+ id: StreamId,
+ ordered: bool,
+ streams: &'a mut Streams,
+ pending: &'a mut Retransmits,
+ ) -> Result {
+ let entry = match streams.recv.entry(id) {
+ Entry::Occupied(entry) => entry,
+ Entry::Vacant(_) => return Err(ReadableError::UnknownStream),
+ };
+
+ let mut recv = match entry.get().stopped {
+ true => return Err(ReadableError::UnknownStream),
+ false => entry.remove(),
+ };
-pub(crate) trait BytesRead {
- fn bytes_read(&self) -> u64;
-}
+ recv.assembler.ensure_ordering(ordered)?;
+ Ok(Self {
+ id,
+ ordered,
+ streams,
+ pending,
+ state: ChunksState::Readable(recv),
+ read: 0,
+ })
+ }
+
+ /// Next
+ ///
+ /// Should call finalize() when done calling this.
+ pub fn next(&mut self, max_length: usize) -> Result, ReadError> {
+ let mut rs = match mem::replace(&mut self.state, ChunksState::Finalized) {
+ ChunksState::Readable(rs) => rs,
+ ChunksState::Error(e, st) => {
+ self.state = ChunksState::Error(e.clone(), st);
+ return Err(e);
+ }
+ ChunksState::Finished(st) => {
+ self.state = ChunksState::Finished(st);
+ return Ok(None);
+ }
+ ChunksState::Finalized => panic!("must not call next() after finalize()"),
+ };
+
+ if let Some(chunk) = rs.assembler.read(max_length, self.ordered) {
+ self.read += chunk.bytes.len() as u64;
+ self.state = ChunksState::Readable(rs);
+ return Ok(Some(chunk));
+ }
+
+ match rs.state {
+ RecvState::ResetRecvd { error_code, .. } => {
+ self.streams.stream_freed(self.id, StreamHalf::Recv);
+ self.pending
+ .post_read(self.id, ShouldTransmit(false), ShouldTransmit(false), true);
+
+ let err = ReadError::Reset(error_code);
+ self.state = ChunksState::Error(err.clone(), ShouldTransmit(true));
+ Err(err)
+ }
+ RecvState::Recv { size } => {
+ if size == Some(rs.end) && rs.assembler.bytes_read() == rs.end {
+ self.streams.stream_freed(self.id, StreamHalf::Recv);
+ self.pending.post_read(
+ self.id,
+ ShouldTransmit(false),
+ ShouldTransmit(false),
+ true,
+ );
+ self.state = ChunksState::Finished(ShouldTransmit(true));
+ Ok(None)
+ } else {
+ let should_transmit =
+ Self::done(&mut rs, self.read, self.id, self.streams, self.pending);
+ self.state = ChunksState::Error(ReadError::Blocked, should_transmit);
+ self.streams.recv.insert(self.id, rs);
+ Err(ReadError::Blocked)
+ }
+ }
+ }
+ }
-impl BytesRead for Chunk {
- fn bytes_read(&self) -> u64 {
- self.bytes.len() as u64
+ /// Finalize
+ pub fn finalize(mut self) -> ShouldTransmit {
+ self.finalize_inner(false)
}
-}
-pub(crate) struct ReadChunks {
- pub bufs: usize,
- pub read: usize,
+ fn finalize_inner(&mut self, drop: bool) -> ShouldTransmit {
+ let state = mem::replace(&mut self.state, ChunksState::Finalized);
+ match state {
+ ChunksState::Readable(mut rs) => {
+ debug_assert!(!drop);
+ let should_transmit =
+ Self::done(&mut rs, self.read, self.id, self.streams, self.pending);
+ self.streams.recv.insert(self.id, rs);
+ should_transmit
+ }
+ ChunksState::Finished(should_transmit) | ChunksState::Error(_, should_transmit) => {
+ debug_assert!(!drop);
+ should_transmit
+ }
+ ChunksState::Finalized => ShouldTransmit(false),
+ }
+ }
+
+ fn done(
+ rs: &mut Recv,
+ read: u64,
+ id: StreamId,
+ streams: &mut Streams,
+ pending: &mut Retransmits,
+ ) -> ShouldTransmit {
+ let (_, max_stream_data) = rs.max_stream_data(streams.stream_receive_window);
+ let max_data = streams.add_read_credits(read);
+ pending.post_read(id, max_data, max_stream_data, false);
+ ShouldTransmit(max_stream_data.0 | max_data.0)
+ }
}
-impl BytesRead for ReadChunks {
- fn bytes_read(&self) -> u64 {
- self.read as u64
+impl<'a> Drop for Chunks<'a> {
+ fn drop(&mut self) {
+ let _ = self.finalize_inner(true);
}
}
+enum ChunksState {
+ Readable(Recv),
+ Error(ReadError, ShouldTransmit),
+ Finished(ShouldTransmit),
+ Finalized,
+}
+
/// Errors triggered when reading from a recv stream
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum ReadError {
@@ -302,6 +355,11 @@ pub enum ReadError {
/// Carries an application-defined error code.
#[error("reset by peer: code {0}")]
Reset(VarInt),
+}
+
+/// Errors triggered when opening a recv stream for reading
+#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub enum ReadableError {
/// The stream has not been opened or was already stopped, finished, or reset
#[error("unknown stream")]
UnknownStream,
@@ -313,9 +371,9 @@ pub enum ReadError {
IllegalOrderedRead,
}
-impl From for ReadError {
+impl From for ReadableError {
fn from(_: IllegalOrderedRead) -> Self {
- ReadError::IllegalOrderedRead
+ ReadableError::IllegalOrderedRead
}
}
diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs
index c75fc115b..733975b3e 100644
--- a/quinn-proto/src/lib.rs
+++ b/quinn-proto/src/lib.rs
@@ -40,8 +40,10 @@ mod varint;
pub use varint::{VarInt, VarIntBoundsExceeded};
mod connection;
-pub use crate::connection::{Chunk, ConnectionError, ConnectionStats, Event, SendDatagramError};
-pub use crate::connection::{FinishError, ReadError, StreamEvent, UnknownStream, WriteError};
+pub use crate::connection::{
+ Chunk, Chunks, ConnectionError, ConnectionStats, Event, FinishError, ReadError, ReadableError,
+ SendDatagramError, StreamEvent, UnknownStream, WriteError,
+};
mod config;
pub use config::{ConfigError, TransportConfig};
diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs
index 4a9b548e9..9d3e666bd 100644
--- a/quinn-proto/src/tests/mod.rs
+++ b/quinn-proto/src/tests/mod.rs
@@ -238,14 +238,14 @@ fn finish_stream_simple() {
assert_eq!(pair.server_conn_mut(client_ch).send_streams(), 0);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s);
assert_matches!(pair.server_conn_mut(server_ch).poll(), None);
+
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG
);
- assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Ok(None)
- );
+ assert_matches!(chunks.next(usize::MAX), Ok(None));
+ let _ = chunks.finalize();
}
#[test]
@@ -270,10 +270,9 @@ fn reset_stream() {
Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni }))
);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s);
- assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Err(ReadError::Reset(ERROR))
- );
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
+ assert_matches!(chunks.next(usize::MAX), Err(ReadError::Reset(ERROR)));
+ let _ = chunks.finalize();
assert_matches!(pair.client_conn_mut(client_ch).poll(), None);
}
@@ -432,10 +431,13 @@ fn zero_rtt_happypath() {
pair.drive();
assert!(pair.client_conn_mut(client_ch).accepted_0rtt());
let server_ch = pair.server.assert_accept();
+
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG
);
+ let _ = chunks.finalize();
assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0);
}
@@ -503,11 +505,10 @@ fn zero_rtt_rejection() {
assert_matches!(pair.server_conn_mut(server_conn).poll(), None);
let s2 = pair.client_conn_mut(client_ch).open(Dir::Uni).unwrap();
assert_eq!(s, s2);
- assert_eq!(
- pair.server_conn_mut(server_conn)
- .read(s2, usize::MAX, false),
- Err(ReadError::Blocked)
- );
+
+ let mut chunks = pair.server_conn_mut(server_conn).read(s2, false).unwrap();
+ assert_eq!(chunks.next(usize::MAX), Err(ReadError::Blocked));
+ let _ = chunks.finalize();
assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0);
}
@@ -643,14 +644,15 @@ fn stream_id_limit() {
Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni }))
);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s);
+
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG
);
- assert_eq!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Ok(None)
- );
+ assert_eq!(chunks.next(usize::MAX), Ok(None));
+ let _ = chunks.finalize();
+
// Server will only send MAX_STREAM_ID now that the application's been notified
pair.drive();
assert_matches!(
@@ -676,10 +678,10 @@ fn stream_id_limit() {
);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s);
assert_matches!(pair.server_conn_mut(server_ch).poll(), None);
- assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Ok(None)
- );
+
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
+ assert_matches!(chunks.next(usize::MAX), Ok(None));
+ let _ = chunks.finalize();
}
#[test]
@@ -705,10 +707,12 @@ fn key_update_simple() {
);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Bi), Some(stream) if stream == s);
assert_matches!(pair.server_conn_mut(server_ch).poll(), None);
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG1
);
+ let _ = chunks.finalize();
info!("initiating key update");
pair.client_conn_mut(client_ch).initiate_key_update();
@@ -719,10 +723,12 @@ fn key_update_simple() {
assert_matches!(pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Readable { id })) if id == s);
assert_matches!(pair.server_conn_mut(server_ch).poll(), None);
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 6 && chunk.bytes == MSG2
);
+ let _ = chunks.finalize();
assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0);
assert_eq!(pair.server_conn_mut(server_ch).lost_packets(), 0);
@@ -763,18 +769,12 @@ fn key_update_reordered() {
);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Bi), Some(stream) if stream == s);
- let buf1 = pair
- .server_conn_mut(server_ch)
- .read(s, usize::MAX, true)
- .unwrap()
- .unwrap();
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, true).unwrap();
+ let buf1 = chunks.next(usize::MAX).unwrap().unwrap();
assert_matches!(&*buf1.bytes, MSG1);
- let buf2 = pair
- .server_conn_mut(server_ch)
- .read(s, usize::MAX, true)
- .unwrap()
- .unwrap();
+ let buf2 = chunks.next(usize::MAX).unwrap().unwrap();
assert_eq!(buf2.bytes, MSG2);
+ let _ = chunks.finalize();
assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0);
assert_eq!(pair.server_conn_mut(server_ch).lost_packets(), 0);
@@ -987,10 +987,13 @@ fn test_flow_control(config: TransportConfig, window_size: usize) {
.reset(s, VarInt(42))
.unwrap();
pair.drive();
+
+ let mut chunks = pair.server_conn_mut(server_conn).read(s, true).unwrap();
assert_eq!(
- pair.server_conn_mut(server_conn).read(s, usize::MAX, true),
- Err(ReadError::Reset(VarInt(42)))
+ chunks.next(usize::MAX).err(),
+ Some(ReadError::Reset(VarInt(42)))
);
+ let _ = chunks.finalize();
// Happy path
let s = pair.client_conn_mut(client_conn).open(Dir::Uni).unwrap();
@@ -1006,8 +1009,9 @@ fn test_flow_control(config: TransportConfig, window_size: usize) {
pair.drive();
let mut cursor = 0;
+ let mut chunks = pair.server_conn_mut(server_conn).read(s, true).unwrap();
loop {
- match pair.server_conn_mut(server_conn).read(s, usize::MAX, true) {
+ match chunks.next(usize::MAX) {
Ok(Some(chunk)) => {
cursor += chunk.bytes.len();
}
@@ -1022,6 +1026,7 @@ fn test_flow_control(config: TransportConfig, window_size: usize) {
}
}
}
+ let _ = chunks.finalize();
assert_eq!(cursor, window_size);
pair.drive();
@@ -1037,8 +1042,9 @@ fn test_flow_control(config: TransportConfig, window_size: usize) {
pair.drive();
let mut cursor = 0;
+ let mut chunks = pair.server_conn_mut(server_conn).read(s, true).unwrap();
loop {
- match pair.server_conn_mut(server_conn).read(s, usize::MAX, true) {
+ match chunks.next(usize::MAX) {
Ok(Some(chunk)) => {
cursor += chunk.bytes.len();
}
@@ -1054,6 +1060,7 @@ fn test_flow_control(config: TransportConfig, window_size: usize) {
}
}
assert_eq!(cursor, window_size);
+ let _ = chunks.finalize();
}
#[test]
@@ -1102,10 +1109,11 @@ fn stop_opens_bidi() {
assert_eq!(pair.server_conn_mut(client_conn).send_streams(), 0);
assert_matches!(pair.server_conn_mut(server_conn).accept(Dir::Bi), Some(stream) if stream == s);
assert_eq!(pair.server_conn_mut(client_conn).send_streams(), 1);
- assert_matches!(
- pair.server_conn_mut(server_conn).read(s, usize::MAX, false),
- Err(ReadError::Blocked)
- );
+
+ let mut chunks = pair.server_conn_mut(server_conn).read(s, false).unwrap();
+ assert_matches!(chunks.next(usize::MAX), Err(ReadError::Blocked));
+ let _ = chunks.finalize();
+
assert_matches!(
pair.server_conn_mut(server_conn).write(s, b"foo"),
Err(WriteError::Stopped(ERROR))
@@ -1302,10 +1310,13 @@ fn finish_stream_flow_control_reordered() {
pair.server.drive(pair.time, pair.client.addr); // Receive
// Issue flow control credit
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG
);
+ let _ = chunks.finalize();
+
pair.server.drive(pair.time, pair.client.addr);
pair.server.delay_outbound(); // Delay it
@@ -1325,10 +1336,10 @@ fn finish_stream_flow_control_reordered() {
Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni }))
);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s);
- assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Ok(None)
- );
+
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
+ assert_matches!(chunks.next(usize::MAX), Ok(None));
+ let _ = chunks.finalize();
}
#[test]
@@ -1357,10 +1368,12 @@ fn handshake_1rtt_handling() {
pair.drive();
assert!(pair.client_conn_mut(client_ch).lost_packets() != 0);
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG
);
+ let _ = chunks.finalize();
}
#[test]
@@ -1603,14 +1616,14 @@ fn finish_acked() {
assert_matches!(pair.server_conn_mut(server_ch).poll(), None);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s);
+
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG
);
- assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Err(ReadError::Blocked)
- );
+ assert_matches!(chunks.next(usize::MAX), Err(ReadError::Blocked));
+ let _ = chunks.finalize();
// Finish before receiving data ack
pair.client_conn_mut(client_ch).finish(s).unwrap();
@@ -1626,10 +1639,10 @@ fn finish_acked() {
pair.client_conn_mut(client_ch).poll(),
Some(Event::Stream(StreamEvent::Finished { id })) if id == s
);
- assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Ok(None)
- );
+
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
+ assert_matches!(chunks.next(usize::MAX), Ok(None));
+ let _ = chunks.finalize();
}
#[test]
@@ -1668,14 +1681,14 @@ fn finish_retransmit() {
);
assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s);
+
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG
);
- assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Ok(None)
- );
+ assert_matches!(chunks.next(usize::MAX), Ok(None));
+ let _ = chunks.finalize();
}
/// Ensures that exchanging data on a client-initiated bidirectional stream works past the initial
@@ -1703,103 +1716,25 @@ fn repeated_request_response() {
pair.drive();
assert_eq!(pair.server_conn_mut(server_ch).accept(Dir::Bi), Some(s));
+ let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap();
assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == REQUEST
);
- assert_matches!(
- pair.server_conn_mut(server_ch).read(s, usize::MAX, false),
- Ok(None)
- );
+ assert_matches!(chunks.next(usize::MAX), Ok(None));
+ let _ = chunks.finalize();
+
pair.server_conn_mut(server_ch).write(s, RESPONSE).unwrap();
pair.server_conn_mut(server_ch).finish(s).unwrap();
pair.drive();
+ let mut chunks = pair.client_conn_mut(client_ch).read(s, false).unwrap();
assert_matches!(
- pair.client_conn_mut(client_ch).read(s, usize::MAX, false),
+ chunks.next(usize::MAX),
Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == RESPONSE
);
- assert_matches!(
- pair.client_conn_mut(client_ch).read(s, usize::MAX, false),
- Ok(None)
- );
- }
-}
-
-#[test]
-fn read_chunks() {
- let _guard = subscribe();
- let server = ServerConfig {
- transport: Arc::new(TransportConfig {
- max_concurrent_bidi_streams: 3u32.into(),
- ..TransportConfig::default()
- }),
- ..server_config()
- };
- let mut pair = Pair::new(Default::default(), server);
- let (client_ch, server_ch) = pair.connect();
- let mut empty = vec![];
- let mut chunks = vec![Bytes::new(), Bytes::new()];
- const ONE: &[u8] = b"ONE";
- const TWO: &[u8] = b"TWO";
- const THREE: &[u8] = b"THREE";
- for _ in 0..3 {
- let s = pair.client_conn_mut(client_ch).open(Dir::Bi).unwrap();
-
- pair.client_conn_mut(client_ch).write(s, ONE).unwrap();
- pair.drive();
- pair.client_conn_mut(client_ch).write(s, TWO).unwrap();
- pair.drive();
- pair.client_conn_mut(client_ch).write(s, THREE).unwrap();
-
- pair.drive();
-
- assert_eq!(pair.server_conn_mut(server_ch).accept(Dir::Bi), Some(s));
-
- // Read into an empty slice can't do much you, but doesn't crash
- assert_eq!(
- pair.server_conn_mut(server_ch).read_chunks(s, &mut empty),
- Ok(Some(0))
- );
-
- // Read until `chunks` is filled
- assert_eq!(
- pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks),
- Ok(Some(2))
- );
- assert_eq!(&chunks, &[ONE, TWO]);
-
- // Read the rest
- assert_eq!(
- pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks),
- Ok(Some(1))
- );
- assert_eq!(&chunks[..1], &[THREE]);
-
- // We've read everything, stream is now blocked
- assert_eq!(
- pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks),
- Err(ReadError::Blocked)
- );
-
- // Read a new chunk after we've been blocked
- pair.client_conn_mut(client_ch).write(s, ONE).unwrap();
- pair.drive();
- assert_eq!(
- pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks),
- Ok(Some(1))
- );
- assert_eq!(&chunks[..1], &[ONE]);
-
- // Stream finishes by yeilding `Ok(None)`
- pair.client_conn_mut(client_ch).finish(s).unwrap();
- pair.drive();
- assert_matches!(
- pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks),
- Ok(None)
- );
-
- pair.drive();
+ assert_matches!(chunks.next(usize::MAX), Ok(None));
+ let _ = chunks.finalize();
}
}
diff --git a/quinn/src/streams.rs b/quinn/src/streams.rs
index 717780937..e3d74580d 100644
--- a/quinn/src/streams.rs
+++ b/quinn/src/streams.rs
@@ -11,7 +11,7 @@ use futures::{
io::{AsyncRead, AsyncWrite},
ready, FutureExt,
};
-use proto::{Chunk, ConnectionError, FinishError, StreamId};
+use proto::{Chunk, Chunks, ConnectionError, FinishError, ReadableError, StreamId};
use thiserror::Error;
use tokio::io::ReadBuf;
@@ -319,6 +319,7 @@ where
stream: StreamId,
is_0rtt: bool,
all_data_read: bool,
+ reset: Option,
}
impl RecvStream
@@ -331,6 +332,7 @@ where
stream,
is_0rtt,
all_data_read: false,
+ reset: None,
}
}
@@ -371,10 +373,26 @@ where
cx: &mut Context,
buf: &mut ReadBuf<'_>,
) -> Poll> {
- self.poll_read_generic(cx, |conn, stream| {
- conn.inner
- .read(stream, buf.remaining(), true)
- .map(|val| val.map(|chunk| buf.put_slice(&chunk.bytes)))
+ if buf.remaining() == 0 {
+ return Poll::Ready(Ok(()));
+ }
+
+ self.poll_read_generic(cx, true, |chunks| {
+ let mut read = false;
+ loop {
+ if buf.remaining() == 0 {
+ // We know `read` is `true` because `buf.remaining()` was not 0 before
+ return ReadStatus::Readable(());
+ }
+
+ match chunks.next(buf.remaining()) {
+ Ok(Some(chunk)) => {
+ buf.put_slice(&chunk.bytes);
+ read = true;
+ }
+ res => return (if read { Some(()) } else { None }, res.err()).into(),
+ }
+ }
})
.map(|res| res.map(|_| ()))
}
@@ -405,8 +423,9 @@ where
max_length: usize,
ordered: bool,
) -> Poll, ReadError>> {
- self.poll_read_generic(cx, |conn, stream| {
- conn.inner.read(stream, max_length, ordered)
+ self.poll_read_generic(cx, ordered, |chunks| match chunks.next(max_length) {
+ Ok(Some(chunk)) => ReadStatus::Readable(chunk),
+ res => (None, res.err()).into(),
})
}
@@ -428,7 +447,27 @@ where
cx: &mut Context,
bufs: &mut [Bytes],
) -> Poll, ReadError>> {
- self.poll_read_generic(cx, |conn, stream| conn.inner.read_chunks(stream, bufs))
+ if bufs.is_empty() {
+ return Poll::Ready(Ok(Some(0)));
+ }
+
+ self.poll_read_generic(cx, true, |chunks| {
+ let mut read = 0;
+ loop {
+ if read >= bufs.len() {
+ // We know `read > 0` because `bufs` cannot be empty here
+ return ReadStatus::Readable(read);
+ }
+
+ match chunks.next(usize::MAX) {
+ Ok(Some(chunk)) => {
+ bufs[read] = chunk.bytes;
+ read += 1;
+ }
+ res => return (if read == 0 { None } else { Some(read) }, res.err()).into(),
+ }
+ }
+ })
}
/// Convenience method to read all remaining data into a buffer
@@ -479,46 +518,86 @@ where
self.stream
}
+ /// Handle common logic related to reading out of a receive stream
+ ///
+ /// This takes an `FnMut` closure that takes care of the actual reading process, matching
+ /// the detailed read semantics for the calling function with a particular return type.
+ /// The closure can read from the passed `&mut Chunks` and has to return the status after
+ /// reading: the amount of data read, and the status after the final read call.
fn poll_read_generic(
&mut self,
cx: &mut Context,
+ ordered: bool,
mut read_fn: T,
) -> Poll, ReadError>>
where
- T: FnMut(
- &mut crate::connection::ConnectionInner,
- StreamId,
- ) -> Result, proto::ReadError>,
+ T: FnMut(&mut Chunks) -> ReadStatus,
{
use proto::ReadError::*;
+ if self.all_data_read {
+ return Poll::Ready(Ok(None));
+ }
+
let mut conn = self.conn.lock("RecvStream::poll_read");
if self.is_0rtt {
conn.check_0rtt().map_err(|()| ReadError::ZeroRttRejected)?;
}
- match read_fn(&mut conn, self.stream) {
- Ok(Some(u)) => {
- if conn.inner.has_pending_retransmits() {
- conn.wake()
+
+ // If we stored an error during a previous call, return it now. This can happen if a
+ // `read_fn` both wants to return data and also returns an error in its final stream status.
+ let status = match self.reset.take() {
+ Some(code) => ReadStatus::Failed(None, Reset(code)),
+ None => {
+ let mut chunks = conn.inner.read(self.stream, ordered)?;
+ let status = read_fn(&mut chunks);
+ if chunks.finalize().should_transmit() {
+ conn.wake();
}
- Poll::Ready(Ok(Some(u)))
+ status
}
- Ok(None) => {
+ };
+
+ match status {
+ ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
+ ReadStatus::Finished(read) => {
self.all_data_read = true;
- Poll::Ready(Ok(None))
+ Poll::Ready(Ok(read))
}
- Err(Blocked) => {
- if let Some(ref x) = conn.error {
- return Poll::Ready(Err(ReadError::ConnectionClosed(x.clone())));
+ ReadStatus::Failed(read, Blocked) => match read {
+ Some(val) => Poll::Ready(Ok(Some(val))),
+ None => {
+ if let Some(ref x) = conn.error {
+ return Poll::Ready(Err(ReadError::ConnectionClosed(x.clone())));
+ }
+ conn.blocked_readers.insert(self.stream, cx.waker().clone());
+ Poll::Pending
}
- conn.blocked_readers.insert(self.stream, cx.waker().clone());
- Poll::Pending
- }
- Err(Reset(error_code)) => {
- self.all_data_read = true;
- Poll::Ready(Err(ReadError::Reset(error_code)))
- }
- Err(UnknownStream) => Poll::Ready(Err(ReadError::UnknownStream)),
- Err(IllegalOrderedRead) => Poll::Ready(Err(ReadError::IllegalOrderedRead)),
+ },
+ ReadStatus::Failed(read, Reset(error_code)) => match read {
+ None => {
+ self.all_data_read = true;
+ Poll::Ready(Err(ReadError::Reset(error_code)))
+ }
+ done => {
+ self.reset = Some(error_code);
+ Poll::Ready(Ok(done))
+ }
+ },
+ }
+ }
+}
+
+enum ReadStatus {
+ Readable(T),
+ Finished(Option),
+ Failed(Option, proto::ReadError),
+}
+
+impl From<(Option, Option)> for ReadStatus {
+ fn from(status: (Option, Option)) -> Self {
+ match status {
+ (read, None) => ReadStatus::Finished(read),
+ (read, Some(e)) => ReadStatus::Failed(read, e),
}
}
}
@@ -661,6 +740,15 @@ pub enum ReadError {
ZeroRttRejected,
}
+impl From for ReadError {
+ fn from(e: ReadableError) -> Self {
+ match e {
+ ReadableError::UnknownStream => ReadError::UnknownStream,
+ ReadableError::IllegalOrderedRead => ReadError::IllegalOrderedRead,
+ }
+ }
+}
+
impl From for io::Error {
fn from(x: ReadError) -> Self {
use self::ReadError::*;