From 24cf82ef83269979d7c459d424a8198aa072f14e Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Thu, 4 Feb 2021 22:13:10 +0100 Subject: [PATCH] Expose iterator-like read API --- quinn-proto/src/connection/mod.rs | 42 +--- quinn-proto/src/connection/streams.rs | 76 ++----- quinn-proto/src/connection/streams/recv.rs | 228 ++++++++++++-------- quinn-proto/src/lib.rs | 6 +- quinn-proto/src/tests/mod.rs | 237 ++++++++------------- quinn/src/streams.rs | 150 ++++++++++--- 6 files changed, 379 insertions(+), 360 deletions(-) diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 727bbb398..8abb7e732 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -53,7 +53,10 @@ pub use stats::ConnectionStats; mod streams; pub use streams::Streams; -pub use streams::{FinishError, ReadError, ShouldTransmit, StreamEvent, UnknownStream, WriteError}; +pub use streams::{ + Chunks, FinishError, ReadError, ReadableError, ShouldTransmit, StreamEvent, UnknownStream, + WriteError, +}; mod timer; use timer::{Timer, TimerTable}; @@ -1088,40 +1091,9 @@ where /// control window is filled. On any given stream, you can switch from ordered to unordered /// reads, but ordered reads on streams that have seen previous unordered reads will return /// `ReadError::IllegalOrderedRead`. - pub fn read( - &mut self, - id: StreamId, - max_length: usize, - ordered: bool, - ) -> Result, ReadError> { - let result = self.streams.read(id, max_length, ordered); - self.post_read(id, &result); - Ok(result?.map(|x| x.result)) - } - - /// Read the next ordered chunks from the given recv stream - pub fn read_chunks( - &mut self, - id: StreamId, - bufs: &mut [Bytes], - ) -> Result, ReadError> { - let result = self.streams.read_chunks(id, bufs); - self.post_read(id, &result); - Ok(result?.map(|x| x.result.bufs)) - } - - fn post_read(&mut self, id: StreamId, result: &streams::ReadResult) { - let (did_read, max_data, max_stream_data) = match result { - Ok(Some(did)) => (true, did.max_data, did.max_stream_data), - _ => (false, ShouldTransmit::default(), ShouldTransmit::default()), - }; - - let max_dirty = self.streams.take_max_streams_dirty(id.dir()); - if did_read || max_dirty { - self.spaces[SpaceId::Data] - .pending - .post_read(id, max_data, max_stream_data, max_dirty); - } + pub fn read(&mut self, id: StreamId, ordered: bool) -> Result { + self.streams + .read(id, ordered, &mut self.spaces[SpaceId::Data].pending) } /// Send data on the given stream diff --git a/quinn-proto/src/connection/streams.rs b/quinn-proto/src/connection/streams.rs index 4cedd7325..4c1f0e72e 100644 --- a/quinn-proto/src/connection/streams.rs +++ b/quinn-proto/src/connection/streams.rs @@ -5,11 +5,10 @@ use std::{ mem, }; -use bytes::{BufMut, Bytes}; +use bytes::BufMut; use thiserror::Error; use tracing::{debug, trace}; -use super::assembler::Chunk; use super::spaces::{Retransmits, ThinRetransmits}; use crate::{ coding::BufMutExt, @@ -20,9 +19,8 @@ use crate::{ }; mod recv; -pub use recv::ReadError; -use recv::{BytesRead, ReadChunks, Recv, StreamReadResult}; -pub(super) use recv::{DidRead, ReadResult}; +use recv::Recv; +pub use recv::{Chunks, ReadError, ReadableError}; mod send; pub use send::{FinishError, WriteError}; @@ -202,55 +200,13 @@ impl Streams { self.connection_blocked.clear(); } - pub(crate) fn read( - &mut self, + pub(crate) fn read<'a>( + &'a mut self, id: StreamId, - max_length: usize, ordered: bool, - ) -> ReadResult { - self.try_read(id, |rs| rs.read(max_length, ordered)) - } - - pub(crate) fn read_chunks( - &mut self, - id: StreamId, - bufs: &mut [Bytes], - ) -> ReadResult { - self.try_read(id, |rs| rs.read_chunks(bufs)) - } - - fn try_read(&mut self, id: StreamId, mut read: F) -> ReadResult - where - F: FnMut(&mut Recv) -> StreamReadResult, - O: BytesRead, - { - let mut entry = match self.recv.entry(id) { - hash_map::Entry::Vacant(_) => return Err(ReadError::UnknownStream), - hash_map::Entry::Occupied(e) => e, - }; - let rs = entry.get_mut(); - match read(rs) { - Ok(Some(out)) => { - let (_, max_stream_data) = rs.max_stream_data(self.stream_receive_window); - let max_data = self.add_read_credits(out.bytes_read()); - Ok(Some(DidRead { - result: out, - max_stream_data, - max_data, - })) - } - Ok(None) => { - entry.remove_entry(); - self.stream_freed(id, StreamHalf::Recv); - Ok(None) - } - Err(e @ ReadError::Reset { .. }) => { - entry.remove_entry(); - self.stream_freed(id, StreamHalf::Recv); - Err(e) - } - Err(e) => Err(e), - } + pending: &'a mut Retransmits, + ) -> Result, ReadableError> { + Chunks::new(id, ordered, self, pending) } /// Queue `data` to be written for `stream` @@ -1077,6 +1033,7 @@ enum StreamHalf { mod tests { use super::*; use crate::TransportErrorCode; + use bytes::Bytes; fn make(side: Side) -> Streams { Streams::new( @@ -1094,6 +1051,7 @@ mod tests { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); let initial_max = client.local_max_data; + let mut pending = Retransmits::default(); assert_eq!( client .received( @@ -1110,7 +1068,9 @@ mod tests { ); assert_eq!(client.data_recvd, 2048); assert_eq!(client.local_max_data - initial_max, 0); - client.read(id, 1024, true).unwrap(); + let mut chunks = client.read(id, true, &mut pending).unwrap(); + chunks.next(1024).unwrap(); + let _ = chunks.finalize(); assert_eq!(client.local_max_data - initial_max, 1024); assert_eq!( client @@ -1193,6 +1153,7 @@ mod tests { fn recv_stopped() { let mut client = make(Side::Client); let id = StreamId::new(Side::Server, Dir::Uni, 0); + let mut pending = Retransmits::default(); let initial_max = client.local_max_data; assert_eq!( client @@ -1217,10 +1178,13 @@ mod tests { } ); assert!(client.stop(id).is_err()); - assert_eq!(client.read(id, 0, true), Err(ReadError::UnknownStream)); assert_eq!( - client.read(id, usize::MAX, false), - Err(ReadError::UnknownStream) + client.read(id, true, &mut pending).err(), + Some(ReadableError::UnknownStream) + ); + assert_eq!( + client.read(id, false, &mut pending).err(), + Some(ReadableError::UnknownStream) ); assert_eq!(client.local_max_data - initial_max, 32); assert_eq!( diff --git a/quinn-proto/src/connection/streams/recv.rs b/quinn-proto/src/connection/streams/recv.rs index 56a8f658b..dd72f80b0 100644 --- a/quinn-proto/src/connection/streams/recv.rs +++ b/quinn-proto/src/connection/streams/recv.rs @@ -1,8 +1,10 @@ -use bytes::Bytes; +use std::collections::hash_map::Entry; +use std::mem; + use thiserror::Error; use tracing::debug; -use super::{ShouldTransmit, UnknownStream}; +use super::{Retransmits, ShouldTransmit, StreamHalf, StreamId, Streams, UnknownStream}; use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead}; use crate::{frame, TransportError, VarInt}; @@ -70,49 +72,6 @@ impl Recv { Ok((new_bytes, frame.fin && self.stopped)) } - pub(super) fn read(&mut self, max_length: usize, ordered: bool) -> StreamReadResult { - if self.stopped { - return Err(ReadError::UnknownStream); - } - - self.assembler.ensure_ordering(ordered)?; - match self.assembler.read(max_length, ordered) { - Some(chunk) => Ok(Some(chunk)), - None => self.read_blocked().map(|()| None), - } - } - - pub(super) fn read_chunks( - &mut self, - chunks: &mut [Bytes], - ) -> Result, ReadError> { - if self.stopped { - return Err(ReadError::UnknownStream); - } - - let mut out = ReadChunks { bufs: 0, read: 0 }; - if chunks.is_empty() { - return Ok(Some(out)); - } - - self.assembler.ensure_ordering(true)?; - while let Some(chunk) = self.assembler.read(usize::MAX, true) { - chunks[out.bufs] = chunk.bytes; - out.read += chunks[out.bufs].len(); - out.bufs += 1; - - if out.bufs >= chunks.len() { - return Ok(Some(out)); - } - } - - if out.bufs > 0 { - return Ok(Some(out)); - } - - self.read_blocked().map(|()| None) - } - pub(super) fn stop(&mut self) -> Result<(u64, ShouldTransmit), UnknownStream> { if self.stopped { return Err(UnknownStream { _private: () }); @@ -129,21 +88,6 @@ impl Recv { Ok((read_credits, ShouldTransmit(self.is_receiving()))) } - fn read_blocked(&mut self) -> Result<(), ReadError> { - match self.state { - RecvState::ResetRecvd { error_code, .. } => { - Err(ReadError::Reset(error_code)) - } - RecvState::Recv { size } => { - if size == Some(self.end) && self.assembler.bytes_read() == self.end { - Ok(()) - } else { - Err(ReadError::Blocked) - } - } - } - } - /// Returns the window that should be advertised in a `MAX_STREAM_DATA` frame /// /// The method returns a tuple which consists of the window that should be @@ -254,40 +198,149 @@ impl Recv { } } -pub(crate) type ReadResult = Result>, ReadError>; - -/// Result of a `Streams::read` call in case the stream had not ended yet -#[derive(Debug, Eq, PartialEq, Copy, Clone)] -#[must_use = "A frame might need to be enqueued"] -pub(crate) struct DidRead { - pub result: T, - pub max_stream_data: ShouldTransmit, - pub max_data: ShouldTransmit, +/// Chunks +pub struct Chunks<'a> { + id: StreamId, + ordered: bool, + streams: &'a mut Streams, + pending: &'a mut Retransmits, + state: ChunksState, + read: u64, } -pub(super) type StreamReadResult = Result, ReadError>; +impl<'a> Chunks<'a> { + pub(super) fn new( + id: StreamId, + ordered: bool, + streams: &'a mut Streams, + pending: &'a mut Retransmits, + ) -> Result { + let entry = match streams.recv.entry(id) { + Entry::Occupied(entry) => entry, + Entry::Vacant(_) => return Err(ReadableError::UnknownStream), + }; + + let mut recv = match entry.get().stopped { + true => return Err(ReadableError::UnknownStream), + false => entry.remove(), + }; -pub(crate) trait BytesRead { - fn bytes_read(&self) -> u64; -} + recv.assembler.ensure_ordering(ordered)?; + Ok(Self { + id, + ordered, + streams, + pending, + state: ChunksState::Readable(recv), + read: 0, + }) + } + + /// Next + /// + /// Should call finalize() when done calling this. + pub fn next(&mut self, max_length: usize) -> Result, ReadError> { + let mut rs = match mem::replace(&mut self.state, ChunksState::Finalized) { + ChunksState::Readable(rs) => rs, + ChunksState::Error(e, st) => { + self.state = ChunksState::Error(e.clone(), st); + return Err(e); + } + ChunksState::Finished(st) => { + self.state = ChunksState::Finished(st); + return Ok(None); + } + ChunksState::Finalized => panic!("must not call next() after finalize()"), + }; + + if let Some(chunk) = rs.assembler.read(max_length, self.ordered) { + self.read += chunk.bytes.len() as u64; + self.state = ChunksState::Readable(rs); + return Ok(Some(chunk)); + } + + match rs.state { + RecvState::ResetRecvd { error_code, .. } => { + self.streams.stream_freed(self.id, StreamHalf::Recv); + self.pending + .post_read(self.id, ShouldTransmit(false), ShouldTransmit(false), true); + + let err = ReadError::Reset(error_code); + self.state = ChunksState::Error(err.clone(), ShouldTransmit(true)); + Err(err) + } + RecvState::Recv { size } => { + if size == Some(rs.end) && rs.assembler.bytes_read() == rs.end { + self.streams.stream_freed(self.id, StreamHalf::Recv); + self.pending.post_read( + self.id, + ShouldTransmit(false), + ShouldTransmit(false), + true, + ); + self.state = ChunksState::Finished(ShouldTransmit(true)); + Ok(None) + } else { + let should_transmit = + Self::done(&mut rs, self.read, self.id, self.streams, self.pending); + self.state = ChunksState::Error(ReadError::Blocked, should_transmit); + self.streams.recv.insert(self.id, rs); + Err(ReadError::Blocked) + } + } + } + } -impl BytesRead for Chunk { - fn bytes_read(&self) -> u64 { - self.bytes.len() as u64 + /// Finalize + pub fn finalize(mut self) -> ShouldTransmit { + self.finalize_inner(false) } -} -pub(crate) struct ReadChunks { - pub bufs: usize, - pub read: usize, + fn finalize_inner(&mut self, drop: bool) -> ShouldTransmit { + let state = mem::replace(&mut self.state, ChunksState::Finalized); + match state { + ChunksState::Readable(mut rs) => { + debug_assert!(!drop); + let should_transmit = + Self::done(&mut rs, self.read, self.id, self.streams, self.pending); + self.streams.recv.insert(self.id, rs); + should_transmit + } + ChunksState::Finished(should_transmit) | ChunksState::Error(_, should_transmit) => { + debug_assert!(!drop); + should_transmit + } + ChunksState::Finalized => ShouldTransmit(false), + } + } + + fn done( + rs: &mut Recv, + read: u64, + id: StreamId, + streams: &mut Streams, + pending: &mut Retransmits, + ) -> ShouldTransmit { + let (_, max_stream_data) = rs.max_stream_data(streams.stream_receive_window); + let max_data = streams.add_read_credits(read); + pending.post_read(id, max_data, max_stream_data, false); + ShouldTransmit(max_stream_data.0 | max_data.0) + } } -impl BytesRead for ReadChunks { - fn bytes_read(&self) -> u64 { - self.read as u64 +impl<'a> Drop for Chunks<'a> { + fn drop(&mut self) { + let _ = self.finalize_inner(true); } } +enum ChunksState { + Readable(Recv), + Error(ReadError, ShouldTransmit), + Finished(ShouldTransmit), + Finalized, +} + /// Errors triggered when reading from a recv stream #[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum ReadError { @@ -302,6 +355,11 @@ pub enum ReadError { /// Carries an application-defined error code. #[error("reset by peer: code {0}")] Reset(VarInt), +} + +/// Errors triggered when opening a recv stream for reading +#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub enum ReadableError { /// The stream has not been opened or was already stopped, finished, or reset #[error("unknown stream")] UnknownStream, @@ -313,9 +371,9 @@ pub enum ReadError { IllegalOrderedRead, } -impl From for ReadError { +impl From for ReadableError { fn from(_: IllegalOrderedRead) -> Self { - ReadError::IllegalOrderedRead + ReadableError::IllegalOrderedRead } } diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index c75fc115b..733975b3e 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -40,8 +40,10 @@ mod varint; pub use varint::{VarInt, VarIntBoundsExceeded}; mod connection; -pub use crate::connection::{Chunk, ConnectionError, ConnectionStats, Event, SendDatagramError}; -pub use crate::connection::{FinishError, ReadError, StreamEvent, UnknownStream, WriteError}; +pub use crate::connection::{ + Chunk, Chunks, ConnectionError, ConnectionStats, Event, FinishError, ReadError, ReadableError, + SendDatagramError, StreamEvent, UnknownStream, WriteError, +}; mod config; pub use config::{ConfigError, TransportConfig}; diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 4a9b548e9..9d3e666bd 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -238,14 +238,14 @@ fn finish_stream_simple() { assert_eq!(pair.server_conn_mut(client_ch).send_streams(), 0); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); + + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); - assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Ok(None) - ); + assert_matches!(chunks.next(usize::MAX), Ok(None)); + let _ = chunks.finalize(); } #[test] @@ -270,10 +270,9 @@ fn reset_stream() { Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s); - assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Err(ReadError::Reset(ERROR)) - ); + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); + assert_matches!(chunks.next(usize::MAX), Err(ReadError::Reset(ERROR))); + let _ = chunks.finalize(); assert_matches!(pair.client_conn_mut(client_ch).poll(), None); } @@ -432,10 +431,13 @@ fn zero_rtt_happypath() { pair.drive(); assert!(pair.client_conn_mut(client_ch).accepted_0rtt()); let server_ch = pair.server.assert_accept(); + + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); + let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); } @@ -503,11 +505,10 @@ fn zero_rtt_rejection() { assert_matches!(pair.server_conn_mut(server_conn).poll(), None); let s2 = pair.client_conn_mut(client_ch).open(Dir::Uni).unwrap(); assert_eq!(s, s2); - assert_eq!( - pair.server_conn_mut(server_conn) - .read(s2, usize::MAX, false), - Err(ReadError::Blocked) - ); + + let mut chunks = pair.server_conn_mut(server_conn).read(s2, false).unwrap(); + assert_eq!(chunks.next(usize::MAX), Err(ReadError::Blocked)); + let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); } @@ -643,14 +644,15 @@ fn stream_id_limit() { Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s); + + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); - assert_eq!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Ok(None) - ); + assert_eq!(chunks.next(usize::MAX), Ok(None)); + let _ = chunks.finalize(); + // Server will only send MAX_STREAM_ID now that the application's been notified pair.drive(); assert_matches!( @@ -676,10 +678,10 @@ fn stream_id_limit() { ); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); - assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Ok(None) - ); + + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); + assert_matches!(chunks.next(usize::MAX), Ok(None)); + let _ = chunks.finalize(); } #[test] @@ -705,10 +707,12 @@ fn key_update_simple() { ); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Bi), Some(stream) if stream == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG1 ); + let _ = chunks.finalize(); info!("initiating key update"); pair.client_conn_mut(client_ch).initiate_key_update(); @@ -719,10 +723,12 @@ fn key_update_simple() { assert_matches!(pair.server_conn_mut(server_ch).poll(), Some(Event::Stream(StreamEvent::Readable { id })) if id == s); assert_matches!(pair.server_conn_mut(server_ch).poll(), None); + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 6 && chunk.bytes == MSG2 ); + let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); assert_eq!(pair.server_conn_mut(server_ch).lost_packets(), 0); @@ -763,18 +769,12 @@ fn key_update_reordered() { ); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Bi), Some(stream) if stream == s); - let buf1 = pair - .server_conn_mut(server_ch) - .read(s, usize::MAX, true) - .unwrap() - .unwrap(); + let mut chunks = pair.server_conn_mut(server_ch).read(s, true).unwrap(); + let buf1 = chunks.next(usize::MAX).unwrap().unwrap(); assert_matches!(&*buf1.bytes, MSG1); - let buf2 = pair - .server_conn_mut(server_ch) - .read(s, usize::MAX, true) - .unwrap() - .unwrap(); + let buf2 = chunks.next(usize::MAX).unwrap().unwrap(); assert_eq!(buf2.bytes, MSG2); + let _ = chunks.finalize(); assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); assert_eq!(pair.server_conn_mut(server_ch).lost_packets(), 0); @@ -987,10 +987,13 @@ fn test_flow_control(config: TransportConfig, window_size: usize) { .reset(s, VarInt(42)) .unwrap(); pair.drive(); + + let mut chunks = pair.server_conn_mut(server_conn).read(s, true).unwrap(); assert_eq!( - pair.server_conn_mut(server_conn).read(s, usize::MAX, true), - Err(ReadError::Reset(VarInt(42))) + chunks.next(usize::MAX).err(), + Some(ReadError::Reset(VarInt(42))) ); + let _ = chunks.finalize(); // Happy path let s = pair.client_conn_mut(client_conn).open(Dir::Uni).unwrap(); @@ -1006,8 +1009,9 @@ fn test_flow_control(config: TransportConfig, window_size: usize) { pair.drive(); let mut cursor = 0; + let mut chunks = pair.server_conn_mut(server_conn).read(s, true).unwrap(); loop { - match pair.server_conn_mut(server_conn).read(s, usize::MAX, true) { + match chunks.next(usize::MAX) { Ok(Some(chunk)) => { cursor += chunk.bytes.len(); } @@ -1022,6 +1026,7 @@ fn test_flow_control(config: TransportConfig, window_size: usize) { } } } + let _ = chunks.finalize(); assert_eq!(cursor, window_size); pair.drive(); @@ -1037,8 +1042,9 @@ fn test_flow_control(config: TransportConfig, window_size: usize) { pair.drive(); let mut cursor = 0; + let mut chunks = pair.server_conn_mut(server_conn).read(s, true).unwrap(); loop { - match pair.server_conn_mut(server_conn).read(s, usize::MAX, true) { + match chunks.next(usize::MAX) { Ok(Some(chunk)) => { cursor += chunk.bytes.len(); } @@ -1054,6 +1060,7 @@ fn test_flow_control(config: TransportConfig, window_size: usize) { } } assert_eq!(cursor, window_size); + let _ = chunks.finalize(); } #[test] @@ -1102,10 +1109,11 @@ fn stop_opens_bidi() { assert_eq!(pair.server_conn_mut(client_conn).send_streams(), 0); assert_matches!(pair.server_conn_mut(server_conn).accept(Dir::Bi), Some(stream) if stream == s); assert_eq!(pair.server_conn_mut(client_conn).send_streams(), 1); - assert_matches!( - pair.server_conn_mut(server_conn).read(s, usize::MAX, false), - Err(ReadError::Blocked) - ); + + let mut chunks = pair.server_conn_mut(server_conn).read(s, false).unwrap(); + assert_matches!(chunks.next(usize::MAX), Err(ReadError::Blocked)); + let _ = chunks.finalize(); + assert_matches!( pair.server_conn_mut(server_conn).write(s, b"foo"), Err(WriteError::Stopped(ERROR)) @@ -1302,10 +1310,13 @@ fn finish_stream_flow_control_reordered() { pair.server.drive(pair.time, pair.client.addr); // Receive // Issue flow control credit + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); + let _ = chunks.finalize(); + pair.server.drive(pair.time, pair.client.addr); pair.server.delay_outbound(); // Delay it @@ -1325,10 +1336,10 @@ fn finish_stream_flow_control_reordered() { Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) ); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s); - assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Ok(None) - ); + + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); + assert_matches!(chunks.next(usize::MAX), Ok(None)); + let _ = chunks.finalize(); } #[test] @@ -1357,10 +1368,12 @@ fn handshake_1rtt_handling() { pair.drive(); assert!(pair.client_conn_mut(client_ch).lost_packets() != 0); + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); + let _ = chunks.finalize(); } #[test] @@ -1603,14 +1616,14 @@ fn finish_acked() { assert_matches!(pair.server_conn_mut(server_ch).poll(), None); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s); + + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); - assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Err(ReadError::Blocked) - ); + assert_matches!(chunks.next(usize::MAX), Err(ReadError::Blocked)); + let _ = chunks.finalize(); // Finish before receiving data ack pair.client_conn_mut(client_ch).finish(s).unwrap(); @@ -1626,10 +1639,10 @@ fn finish_acked() { pair.client_conn_mut(client_ch).poll(), Some(Event::Stream(StreamEvent::Finished { id })) if id == s ); - assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Ok(None) - ); + + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); + assert_matches!(chunks.next(usize::MAX), Ok(None)); + let _ = chunks.finalize(); } #[test] @@ -1668,14 +1681,14 @@ fn finish_retransmit() { ); assert_matches!(pair.server_conn_mut(server_ch).accept(Dir::Uni), Some(stream) if stream == s); + + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG ); - assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Ok(None) - ); + assert_matches!(chunks.next(usize::MAX), Ok(None)); + let _ = chunks.finalize(); } /// Ensures that exchanging data on a client-initiated bidirectional stream works past the initial @@ -1703,103 +1716,25 @@ fn repeated_request_response() { pair.drive(); assert_eq!(pair.server_conn_mut(server_ch).accept(Dir::Bi), Some(s)); + let mut chunks = pair.server_conn_mut(server_ch).read(s, false).unwrap(); assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == REQUEST ); - assert_matches!( - pair.server_conn_mut(server_ch).read(s, usize::MAX, false), - Ok(None) - ); + assert_matches!(chunks.next(usize::MAX), Ok(None)); + let _ = chunks.finalize(); + pair.server_conn_mut(server_ch).write(s, RESPONSE).unwrap(); pair.server_conn_mut(server_ch).finish(s).unwrap(); pair.drive(); + let mut chunks = pair.client_conn_mut(client_ch).read(s, false).unwrap(); assert_matches!( - pair.client_conn_mut(client_ch).read(s, usize::MAX, false), + chunks.next(usize::MAX), Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == RESPONSE ); - assert_matches!( - pair.client_conn_mut(client_ch).read(s, usize::MAX, false), - Ok(None) - ); - } -} - -#[test] -fn read_chunks() { - let _guard = subscribe(); - let server = ServerConfig { - transport: Arc::new(TransportConfig { - max_concurrent_bidi_streams: 3u32.into(), - ..TransportConfig::default() - }), - ..server_config() - }; - let mut pair = Pair::new(Default::default(), server); - let (client_ch, server_ch) = pair.connect(); - let mut empty = vec![]; - let mut chunks = vec![Bytes::new(), Bytes::new()]; - const ONE: &[u8] = b"ONE"; - const TWO: &[u8] = b"TWO"; - const THREE: &[u8] = b"THREE"; - for _ in 0..3 { - let s = pair.client_conn_mut(client_ch).open(Dir::Bi).unwrap(); - - pair.client_conn_mut(client_ch).write(s, ONE).unwrap(); - pair.drive(); - pair.client_conn_mut(client_ch).write(s, TWO).unwrap(); - pair.drive(); - pair.client_conn_mut(client_ch).write(s, THREE).unwrap(); - - pair.drive(); - - assert_eq!(pair.server_conn_mut(server_ch).accept(Dir::Bi), Some(s)); - - // Read into an empty slice can't do much you, but doesn't crash - assert_eq!( - pair.server_conn_mut(server_ch).read_chunks(s, &mut empty), - Ok(Some(0)) - ); - - // Read until `chunks` is filled - assert_eq!( - pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks), - Ok(Some(2)) - ); - assert_eq!(&chunks, &[ONE, TWO]); - - // Read the rest - assert_eq!( - pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks), - Ok(Some(1)) - ); - assert_eq!(&chunks[..1], &[THREE]); - - // We've read everything, stream is now blocked - assert_eq!( - pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks), - Err(ReadError::Blocked) - ); - - // Read a new chunk after we've been blocked - pair.client_conn_mut(client_ch).write(s, ONE).unwrap(); - pair.drive(); - assert_eq!( - pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks), - Ok(Some(1)) - ); - assert_eq!(&chunks[..1], &[ONE]); - - // Stream finishes by yeilding `Ok(None)` - pair.client_conn_mut(client_ch).finish(s).unwrap(); - pair.drive(); - assert_matches!( - pair.server_conn_mut(server_ch).read_chunks(s, &mut chunks), - Ok(None) - ); - - pair.drive(); + assert_matches!(chunks.next(usize::MAX), Ok(None)); + let _ = chunks.finalize(); } } diff --git a/quinn/src/streams.rs b/quinn/src/streams.rs index 717780937..e3d74580d 100644 --- a/quinn/src/streams.rs +++ b/quinn/src/streams.rs @@ -11,7 +11,7 @@ use futures::{ io::{AsyncRead, AsyncWrite}, ready, FutureExt, }; -use proto::{Chunk, ConnectionError, FinishError, StreamId}; +use proto::{Chunk, Chunks, ConnectionError, FinishError, ReadableError, StreamId}; use thiserror::Error; use tokio::io::ReadBuf; @@ -319,6 +319,7 @@ where stream: StreamId, is_0rtt: bool, all_data_read: bool, + reset: Option, } impl RecvStream @@ -331,6 +332,7 @@ where stream, is_0rtt, all_data_read: false, + reset: None, } } @@ -371,10 +373,26 @@ where cx: &mut Context, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.poll_read_generic(cx, |conn, stream| { - conn.inner - .read(stream, buf.remaining(), true) - .map(|val| val.map(|chunk| buf.put_slice(&chunk.bytes))) + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + self.poll_read_generic(cx, true, |chunks| { + let mut read = false; + loop { + if buf.remaining() == 0 { + // We know `read` is `true` because `buf.remaining()` was not 0 before + return ReadStatus::Readable(()); + } + + match chunks.next(buf.remaining()) { + Ok(Some(chunk)) => { + buf.put_slice(&chunk.bytes); + read = true; + } + res => return (if read { Some(()) } else { None }, res.err()).into(), + } + } }) .map(|res| res.map(|_| ())) } @@ -405,8 +423,9 @@ where max_length: usize, ordered: bool, ) -> Poll, ReadError>> { - self.poll_read_generic(cx, |conn, stream| { - conn.inner.read(stream, max_length, ordered) + self.poll_read_generic(cx, ordered, |chunks| match chunks.next(max_length) { + Ok(Some(chunk)) => ReadStatus::Readable(chunk), + res => (None, res.err()).into(), }) } @@ -428,7 +447,27 @@ where cx: &mut Context, bufs: &mut [Bytes], ) -> Poll, ReadError>> { - self.poll_read_generic(cx, |conn, stream| conn.inner.read_chunks(stream, bufs)) + if bufs.is_empty() { + return Poll::Ready(Ok(Some(0))); + } + + self.poll_read_generic(cx, true, |chunks| { + let mut read = 0; + loop { + if read >= bufs.len() { + // We know `read > 0` because `bufs` cannot be empty here + return ReadStatus::Readable(read); + } + + match chunks.next(usize::MAX) { + Ok(Some(chunk)) => { + bufs[read] = chunk.bytes; + read += 1; + } + res => return (if read == 0 { None } else { Some(read) }, res.err()).into(), + } + } + }) } /// Convenience method to read all remaining data into a buffer @@ -479,46 +518,86 @@ where self.stream } + /// Handle common logic related to reading out of a receive stream + /// + /// This takes an `FnMut` closure that takes care of the actual reading process, matching + /// the detailed read semantics for the calling function with a particular return type. + /// The closure can read from the passed `&mut Chunks` and has to return the status after + /// reading: the amount of data read, and the status after the final read call. fn poll_read_generic( &mut self, cx: &mut Context, + ordered: bool, mut read_fn: T, ) -> Poll, ReadError>> where - T: FnMut( - &mut crate::connection::ConnectionInner, - StreamId, - ) -> Result, proto::ReadError>, + T: FnMut(&mut Chunks) -> ReadStatus, { use proto::ReadError::*; + if self.all_data_read { + return Poll::Ready(Ok(None)); + } + let mut conn = self.conn.lock("RecvStream::poll_read"); if self.is_0rtt { conn.check_0rtt().map_err(|()| ReadError::ZeroRttRejected)?; } - match read_fn(&mut conn, self.stream) { - Ok(Some(u)) => { - if conn.inner.has_pending_retransmits() { - conn.wake() + + // If we stored an error during a previous call, return it now. This can happen if a + // `read_fn` both wants to return data and also returns an error in its final stream status. + let status = match self.reset.take() { + Some(code) => ReadStatus::Failed(None, Reset(code)), + None => { + let mut chunks = conn.inner.read(self.stream, ordered)?; + let status = read_fn(&mut chunks); + if chunks.finalize().should_transmit() { + conn.wake(); } - Poll::Ready(Ok(Some(u))) + status } - Ok(None) => { + }; + + match status { + ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))), + ReadStatus::Finished(read) => { self.all_data_read = true; - Poll::Ready(Ok(None)) + Poll::Ready(Ok(read)) } - Err(Blocked) => { - if let Some(ref x) = conn.error { - return Poll::Ready(Err(ReadError::ConnectionClosed(x.clone()))); + ReadStatus::Failed(read, Blocked) => match read { + Some(val) => Poll::Ready(Ok(Some(val))), + None => { + if let Some(ref x) = conn.error { + return Poll::Ready(Err(ReadError::ConnectionClosed(x.clone()))); + } + conn.blocked_readers.insert(self.stream, cx.waker().clone()); + Poll::Pending } - conn.blocked_readers.insert(self.stream, cx.waker().clone()); - Poll::Pending - } - Err(Reset(error_code)) => { - self.all_data_read = true; - Poll::Ready(Err(ReadError::Reset(error_code))) - } - Err(UnknownStream) => Poll::Ready(Err(ReadError::UnknownStream)), - Err(IllegalOrderedRead) => Poll::Ready(Err(ReadError::IllegalOrderedRead)), + }, + ReadStatus::Failed(read, Reset(error_code)) => match read { + None => { + self.all_data_read = true; + Poll::Ready(Err(ReadError::Reset(error_code))) + } + done => { + self.reset = Some(error_code); + Poll::Ready(Ok(done)) + } + }, + } + } +} + +enum ReadStatus { + Readable(T), + Finished(Option), + Failed(Option, proto::ReadError), +} + +impl From<(Option, Option)> for ReadStatus { + fn from(status: (Option, Option)) -> Self { + match status { + (read, None) => ReadStatus::Finished(read), + (read, Some(e)) => ReadStatus::Failed(read, e), } } } @@ -661,6 +740,15 @@ pub enum ReadError { ZeroRttRejected, } +impl From for ReadError { + fn from(e: ReadableError) -> Self { + match e { + ReadableError::UnknownStream => ReadError::UnknownStream, + ReadableError::IllegalOrderedRead => ReadError::IllegalOrderedRead, + } + } +} + impl From for io::Error { fn from(x: ReadError) -> Self { use self::ReadError::*;