From f0ddc666d075a017f538e0c8dccfa2da2f19cd8d Mon Sep 17 00:00:00 2001 From: Aaron Todd Date: Fri, 17 May 2024 16:17:22 -0400 Subject: [PATCH] disable stalled stream protection on empty bodies and after read complete (#3644) ## Motivation and Context * https://github.com/awslabs/aws-sdk-rust/issues/1141 * https://github.com/awslabs/aws-sdk-rust/issues/1146 * https://github.com/awslabs/aws-sdk-rust/issues/1148 ## Description * Disables stalled stream upload protection for requests with an empty/zero length body. * Disables stalled stream upload throughput checking once the request body has been read and handed off to the HTTP layer. ## Testing Additional integration tests added covering empty bodies and completed uploads. Tested SQS issue against latest runtime and can see it works now. The S3 `CopyObject` issue is related to downloads and will need a different solution. ## Checklist - [x] I have updated `CHANGELOG.next.toml` if I made changes to the smithy-rs codegen or runtime crates - [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS SDK, generated SDK code, or SDK runtime crates ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --------- Co-authored-by: Zelda Hessler Co-authored-by: ysaito1001 --- CHANGELOG.next.toml | 12 ++ .../aws-smithy-runtime-api/Cargo.toml | 2 +- .../src/client/stalled_stream_protection.rs | 6 +- rust-runtime/aws-smithy-runtime/Cargo.toml | 2 +- .../client/http/body/minimum_throughput.rs | 6 + .../minimum_throughput/http_body_0_4_x.rs | 20 ++- .../http/body/minimum_throughput/options.rs | 21 +-- .../body/minimum_throughput/throughput.rs | 20 +++ .../src/client/stalled_stream_protection.rs | 6 + .../tests/stalled_stream_download.rs | 12 +- .../tests/stalled_stream_upload.rs | 161 ++++++++++++++++-- rust-runtime/aws-smithy-types/Cargo.toml | 3 +- rust-runtime/aws-smithy-types/src/body.rs | 2 + .../aws-smithy-types/src/byte_stream.rs | 48 +++++- .../src/byte_stream/bytestream_util.rs | 67 ++++++-- 15 files changed, 339 insertions(+), 49 deletions(-) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 22094149df..5a9334b9af 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -36,6 +36,18 @@ references = ["aws-sdk-rust#1079"] meta = { "breaking" = false, "bug" = true, "tada" = false } author = "rcoh" +[[aws-sdk-rust]] +message = "Fixes stalled upload stream protection to not apply to empty request bodies and to stop checking for violations once the request body has been read." +references = ["aws-sdk-rust#1141", "aws-sdk-rust#1146", "aws-sdk-rust#1148"] +meta = { "breaking" = false, "tada" = false, "bug" = true } +authors = ["aajtodd", "Velfi"] + +[[smithy-rs]] +message = "Fixes stalled upload stream protection to not apply to empty request bodies and to stop checking for violations once the request body has been read." +references = ["aws-sdk-rust#1141", "aws-sdk-rust#1146", "aws-sdk-rust#1148"] +meta = { "breaking" = false, "tada" = false, "bug" = true } +authors = ["aajtodd", "Velfi"] + [[aws-sdk-rust]] message = "Updating the documentation for the `app_name` method on `ConfigLoader` to indicate the order of precedence for the sources of the `AppName`." references = ["smithy-rs#3645"] diff --git a/rust-runtime/aws-smithy-runtime-api/Cargo.toml b/rust-runtime/aws-smithy-runtime-api/Cargo.toml index 4db33bd48b..f4c5937f0e 100644 --- a/rust-runtime/aws-smithy-runtime-api/Cargo.toml +++ b/rust-runtime/aws-smithy-runtime-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-runtime-api" -version = "1.6.0" +version = "1.6.1" authors = ["AWS Rust SDK Team ", "Zelda Hessler "] description = "Smithy runtime types." edition = "2021" diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs index f90f886592..d2af8a4342 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs @@ -13,7 +13,11 @@ use aws_smithy_types::config_bag::{Storable, StoreReplace}; use std::time::Duration; -const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(5); +/// The default grace period for stalled stream protection. +/// +/// When a stream stalls for longer than this grace period, the stream will +/// return an error. +pub const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(20); /// Configuration for stalled stream protection. /// diff --git a/rust-runtime/aws-smithy-runtime/Cargo.toml b/rust-runtime/aws-smithy-runtime/Cargo.toml index a11496cf3d..38836b6f8a 100644 --- a/rust-runtime/aws-smithy-runtime/Cargo.toml +++ b/rust-runtime/aws-smithy-runtime/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-runtime" -version = "1.5.2" +version = "1.5.3" authors = ["AWS Rust SDK Team ", "Zelda Hessler "] description = "The new smithy runtime crate" edition = "2021" diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs index 59c8a3c64c..5f2a5f6e6a 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs @@ -136,6 +136,10 @@ impl UploadThroughput { self.logs.lock().unwrap().push_bytes_transferred(now, bytes); } + pub(crate) fn mark_complete(&self) -> bool { + self.logs.lock().unwrap().mark_complete() + } + pub(crate) fn report(&self, now: SystemTime) -> ThroughputReport { self.logs.lock().unwrap().report(now) } @@ -177,6 +181,8 @@ trait UploadReport { impl UploadReport for ThroughputReport { fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) { let throughput = match self { + // stream has been exhausted, stop tracking violations + ThroughputReport::Complete => return (false, ZERO_THROUGHPUT), // If the report is incomplete, then we don't have enough data yet to // decide if minimum throughput was violated. ThroughputReport::Incomplete => { diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs index a8f2fe9c4b..fae7d1c53c 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs @@ -22,6 +22,7 @@ trait DownloadReport { impl DownloadReport for ThroughputReport { fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) { let throughput = match self { + ThroughputReport::Complete => return (false, ZERO_THROUGHPUT), // If the report is incomplete, then we don't have enough data yet to // decide if minimum throughput was violated. ThroughputReport::Incomplete => { @@ -175,6 +176,18 @@ where tracing::trace!("received data: {}", bytes.len()); this.throughput .push_bytes_transferred(now, bytes.len() as u64); + + // hyper will optimistically stop polling when end of stream is reported + // (e.g. when content-length amount of data has been consumed) which means + // we may never get to `Poll:Ready(None)`. Check for same condition and + // attempt to stop checking throughput violations _now_ as we may never + // get polled again. The caveat here is that it depends on `Body` implementations + // implementing `is_end_stream()` correctly. Users can also disable SSP as an + // alternative for such fringe use cases. + if self.is_end_stream() { + tracing::trace!("stream reported end of stream before Poll::Ready(None) reached; marking stream complete"); + self.throughput.mark_complete(); + } Poll::Ready(Some(Ok(bytes))) } Poll::Pending => { @@ -183,7 +196,12 @@ where Poll::Pending } // If we've read all the data or an error occurred, then return that result. - res => res, + res => { + if this.throughput.mark_complete() { + tracing::trace!("stream completed: {:?}", res); + } + res + } } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs index 113461a31e..565b7e187a 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs @@ -4,10 +4,12 @@ */ use super::Throughput; -use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; +use aws_smithy_runtime_api::client::stalled_stream_protection::{ + StalledStreamProtectionConfig, DEFAULT_GRACE_PERIOD, +}; use std::time::Duration; -/// A collection of options for configuring a [`MinimumThroughputBody`](super::MinimumThroughputBody). +/// A collection of options for configuring a [`MinimumThroughputBody`](super::MinimumThroughputDownloadBody). #[derive(Debug, Clone)] pub struct MinimumThroughputBodyOptions { /// The minimum throughput that is acceptable. @@ -69,6 +71,13 @@ impl MinimumThroughputBodyOptions { } } +const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput { + bytes_read: 1, + per_time_elapsed: Duration::from_secs(1), +}; + +const DEFAULT_CHECK_WINDOW: Duration = Duration::from_secs(1); + impl Default for MinimumThroughputBodyOptions { fn default() -> Self { Self { @@ -87,14 +96,6 @@ pub struct MinimumThroughputBodyOptionsBuilder { grace_period: Option, } -const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(0); -const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput { - bytes_read: 1, - per_time_elapsed: Duration::from_secs(1), -}; - -const DEFAULT_CHECK_WINDOW: Duration = Duration::from_secs(1); - impl MinimumThroughputBodyOptionsBuilder { /// Create a new `MinimumThroughputBodyOptionsBuilder`. pub fn new() -> Self { diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs index 57ea3318e7..83a2e4ca77 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs @@ -260,6 +260,8 @@ pub(crate) enum ThroughputReport { Pending, /// The stream transferred this amount of throughput during the time window. Transferred(Throughput), + /// The stream has completed, no more data is expected. + Complete, } const BIN_COUNT: usize = 10; @@ -285,6 +287,7 @@ pub(super) struct ThroughputLogs { resolution: Duration, current_tail: SystemTime, buffer: LogBuffer, + stream_complete: bool, } impl ThroughputLogs { @@ -302,6 +305,7 @@ impl ThroughputLogs { resolution, current_tail: now, buffer: LogBuffer::new(), + stream_complete: false, } } @@ -343,8 +347,24 @@ impl ThroughputLogs { assert!(self.current_tail >= now); } + /// Mark the stream complete indicating no more data is expected. This is an + /// idempotent operation -- subsequent invocations of this function have no effect + /// and return false. + /// + /// After marking a stream complete [report](#method.report) will forever more return + /// [ThroughputReport::Complete] + pub(super) fn mark_complete(&mut self) -> bool { + let prev = self.stream_complete; + self.stream_complete = true; + !prev + } + /// Generates an overall report of the time window. pub(super) fn report(&mut self, now: SystemTime) -> ThroughputReport { + if self.stream_complete { + return ThroughputReport::Complete; + } + self.catch_up(now); self.buffer.fill_gaps(); diff --git a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs index 83cfb64752..071ca32c69 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs @@ -65,6 +65,12 @@ impl Intercept for StalledStreamProtectionInterceptor { ) -> Result<(), BoxError> { if let Some(sspcfg) = cfg.load::().cloned() { if sspcfg.upload_enabled() { + if let Some(0) = context.request().body().content_length() { + tracing::trace!( + "skipping stalled stream protection for zero length request body" + ); + return Ok(()); + } let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; let now = time_source.now(); diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs index 54e953322c..d8ab3327a4 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs @@ -105,9 +105,13 @@ async fn download_stalls() { let (time, sleep) = tick_advance_time_and_sleep(); let (server, response_sender) = channel_server(); let op = operation(server, time.clone(), sleep); + let barrier = Arc::new(Barrier::new(2)); + let c = barrier.clone(); let server = tokio::spawn(async move { - for _ in 1..10 { + c.wait().await; + for i in 1..10 { + tracing::debug!("send {i}"); response_sender.send(NEAT_DATA).await.unwrap(); tick!(time, Duration::from_secs(1)); } @@ -115,7 +119,10 @@ async fn download_stalls() { }); let response_body = op.invoke(()).await.expect("initial success"); - let result = tokio::spawn(eagerly_consume(response_body)); + let result = tokio::spawn(async move { + barrier.wait().await; + eagerly_consume(response_body).await + }); server.await.unwrap(); let err = result @@ -188,6 +195,7 @@ async fn user_downloads_data_too_slowly() { } use download_test_tools::*; +use tokio::sync::Barrier; mod download_test_tools { use crate::stalled_stream_common::*; diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs index f64fa321b2..47cb47edfb 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs @@ -7,6 +7,8 @@ #[macro_use] mod stalled_stream_common; + +use aws_smithy_runtime_api::client::stalled_stream_protection::DEFAULT_GRACE_PERIOD; use stalled_stream_common::*; /// Scenario: Successful upload at a rate above the minimum throughput. @@ -88,7 +90,7 @@ async fn upload_too_slow() { async fn upload_stalls() { let _logs = show_test_logs(); - let (server, time, sleep) = stalling_server(); + let (server, time, sleep) = stalling_server(None); let op = operation(server, time.clone(), sleep); let (body, body_sender) = channel_body(); @@ -107,27 +109,84 @@ async fn upload_stalls() { expect_timeout(result.await.expect("no panics")); } +/// Scenario: Request does not have a body. Server response doesn't start coming through +/// until after the grace period. +/// Expected: MUST NOT timeout. +#[tokio::test] +async fn empty_request_body_delayed_response() { + let _logs = show_test_logs(); + + let (server, time, sleep) = stalling_server(Some(Duration::from_secs(6))); + let op = operation(server, time.clone(), sleep); + + let result = tokio::spawn(async move { op.invoke(SdkBody::empty()).await }); + + let _advance = tokio::spawn(async move { + for _ in 0..6 { + tick!(time, Duration::from_secs(1)); + } + }); + + assert_eq!(200, result.await.unwrap().expect("success").as_u16()); +} + /// Scenario: All the request data is either uploaded to the server or buffered in the /// HTTP client, but the response doesn't start coming through within the grace period. -/// Expected: MUST timeout after the grace period completes. +/// Expected: MUST NOT timeout, upload throughput should only apply up until the request body has +/// been read completely and handed off to the HTTP client. #[tokio::test] -async fn complete_upload_no_response() { +async fn complete_upload_delayed_response() { let _logs = show_test_logs(); - let (server, time, sleep) = stalling_server(); + let (server, time, sleep) = stalling_server(Some(Duration::from_secs(6))); let op = operation(server, time.clone(), sleep); let (body, body_sender) = channel_body(); let result = tokio::spawn(async move { op.invoke(body).await }); let _streamer = tokio::spawn(async move { + info!("send data"); body_sender.send(NEAT_DATA).await.unwrap(); tick!(time, Duration::from_secs(1)); + info!("body send complete; dropping"); drop(body_sender); - time.tick(Duration::from_secs(6)).await; + tick!(time, DEFAULT_GRACE_PERIOD); + info!("body stream task complete"); + // advance to unblock the stalled server + tick!(time, Duration::from_secs(2)); }); - expect_timeout(result.await.expect("no panics")); + assert_eq!(200, result.await.unwrap().expect("success").as_u16()); +} + +/// Scenario: Upload all request data and never poll again once content-length has +/// been reached. Hyper will stop polling once it detects end of stream so we can't rely +/// on reaching `Poll:Ready(None)` to detect end of stream. +/// +/// ref: https://github.com/hyperium/hyper/issues/1545 +/// ref: https://github.com/hyperium/hyper/issues/1521 +/// +/// Expected: MUST NOT timeout, upload throughput should only apply up until the request body has +/// been read completely. Once no more data is expected we should stop checking for throughput +/// violations. +#[tokio::test] +async fn complete_upload_stop_polling() { + let _logs = show_test_logs(); + + let (server, time, sleep) = limited_read_server(NEAT_DATA.len(), Some(Duration::from_secs(7))); + let op = operation(server, time.clone(), sleep.clone()); + + let body = SdkBody::from(NEAT_DATA); + let result = tokio::spawn(async move { op.invoke(body).await }); + + tokio::spawn(async move { + // advance past the grace period + tick!(time, DEFAULT_GRACE_PERIOD + Duration::from_secs(1)); + // unblock server + tick!(time, Duration::from_secs(2)); + }); + + assert_eq!(200, result.await.unwrap().expect("success").as_u16()); } // Scenario: The server stops asking for data, the client maxes out its send buffer, @@ -189,6 +248,8 @@ async fn user_provides_data_too_slowly() { use upload_test_tools::*; mod upload_test_tools { + use aws_smithy_async::rt::sleep::AsyncSleep; + use crate::stalled_stream_common::*; pub fn successful_response() -> HttpResponse { @@ -285,24 +346,43 @@ mod upload_test_tools { fake_server!(FakeServerConnector, fake_server, bool, advance_time) } - /// Fake server/connector that reads some data, and then stalls. - pub fn stalling_server() -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { + /// Fake server/connector that reads some data, and then stalls for the given time before + /// returning a response. If `None` is given the server will stall indefinitely. + pub fn stalling_server( + respond_after: Option, + ) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { async fn fake_server( mut body: Pin<&mut SdkBody>, _time: TickAdvanceTime, - _sleep: TickAdvanceSleep, - _: (), + sleep: TickAdvanceSleep, + respond_after: Option, ) -> HttpResponse { let mut times = 5; while times > 0 && poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { times -= 1; } - // never awake after this - tracing::info!("stalling indefinitely"); - std::future::pending::<()>().await; - unreachable!() + + match respond_after { + Some(delay) => { + tracing::info!("stalling for {} seconds", delay.as_secs()); + sleep.sleep(delay).await; + tracing::info!("returning delayed response"); + successful_response() + } + None => { + // never awake after this + tracing::info!("stalling indefinitely"); + std::future::pending::<()>().await; + unreachable!() + } + } } - fake_server!(FakeServerConnector, fake_server) + fake_server!( + FakeServerConnector, + fake_server, + Option, + respond_after + ) } /// Fake server/connector that polls data after each period of time in the given @@ -332,6 +412,57 @@ mod upload_test_tools { ) } + /// Fake server/connector that polls data only up to the content-length. Optionally delays + /// sending the response by the given duration. + pub fn limited_read_server( + content_len: usize, + respond_after: Option, + ) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { + async fn fake_server( + mut body: Pin<&mut SdkBody>, + _time: TickAdvanceTime, + sleep: TickAdvanceSleep, + params: (usize, Option), + ) -> HttpResponse { + let mut remaining = params.0; + loop { + match poll_fn(|cx| body.as_mut().poll_data(cx)).await { + Some(res) => { + let rc = res.unwrap().len(); + remaining -= rc; + tracing::info!("read {rc} bytes; remaining: {remaining}"); + if remaining == 0 { + tracing::info!("read reported content-length data, stopping polling"); + break; + }; + } + None => { + tracing::info!( + "read until poll_data() returned None, no data left, stopping polling" + ); + break; + } + } + } + + let respond_after = params.1; + if let Some(delay) = respond_after { + tracing::info!("stalling for {} seconds", delay.as_secs()); + sleep.sleep(delay).await; + tracing::info!("returning delayed response"); + } + + successful_response() + } + + fake_server!( + FakeServerConnector, + fake_server, + (usize, Option), + (content_len, respond_after) + ) + } + pub fn expect_timeout(result: Result>>) { let err = result.expect_err("should have timed out"); assert_str_contains!( diff --git a/rust-runtime/aws-smithy-types/Cargo.toml b/rust-runtime/aws-smithy-types/Cargo.toml index c16bf6615a..e1590718a8 100644 --- a/rust-runtime/aws-smithy-types/Cargo.toml +++ b/rust-runtime/aws-smithy-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-types" -version = "1.1.9" +version = "1.1.10" authors = [ "AWS Rust SDK Team ", "Russell Cohen ", @@ -67,6 +67,7 @@ tokio = { version = "1.23.1", features = [ "fs", "io-util", ] } +# This is used in a doctest, don't listen to udeps. tokio-stream = "0.1.5" tempfile = "3.2.0" diff --git a/rust-runtime/aws-smithy-types/src/body.rs b/rust-runtime/aws-smithy-types/src/body.rs index 4fe75dcf39..cde0d9d840 100644 --- a/rust-runtime/aws-smithy-types/src/body.rs +++ b/rust-runtime/aws-smithy-types/src/body.rs @@ -376,10 +376,12 @@ mod test { async fn http_body_consumes_data() { let mut body = SdkBody::from("hello!"); let mut body = Pin::new(&mut body); + assert!(!body.is_end_stream()); let data = body.next().await; assert!(data.is_some()); let data = body.next().await; assert!(data.is_none()); + assert!(body.is_end_stream()); } #[tokio::test] diff --git a/rust-runtime/aws-smithy-types/src/byte_stream.rs b/rust-runtime/aws-smithy-types/src/byte_stream.rs index 108a8bc90c..82d0bddf0e 100644 --- a/rust-runtime/aws-smithy-types/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-types/src/byte_stream.rs @@ -579,11 +579,13 @@ impl Inner { } } -#[cfg(test)] +#[cfg(all(test, feature = "rt-tokio"))] mod tests { + use super::{ByteStream, Inner}; use crate::body::SdkBody; - use crate::byte_stream::Inner; use bytes::Bytes; + use std::io::Write; + use tempfile::NamedTempFile; #[tokio::test] async fn read_from_string_body() { @@ -598,10 +600,8 @@ mod tests { ); } - #[cfg(feature = "rt-tokio")] #[tokio::test] async fn bytestream_into_async_read() { - use super::ByteStream; use tokio::io::AsyncBufReadExt; let byte_stream = ByteStream::from_static(b"data 1\ndata 2\ndata 3"); @@ -614,4 +614,44 @@ mod tests { assert_eq!(lines.next_line().await.unwrap(), Some("data 3".to_owned())); assert_eq!(lines.next_line().await.unwrap(), None); } + + #[tokio::test] + async fn valid_size_hint() { + assert_eq!(ByteStream::from_static(b"hello").size_hint().1, Some(5)); + assert_eq!(ByteStream::from_static(b"").size_hint().1, Some(0)); + + let mut f = NamedTempFile::new().unwrap(); + f.write_all(b"hello").unwrap(); + let body = ByteStream::from_path(f.path()).await.unwrap(); + assert_eq!(body.inner.size_hint().1, Some(5)); + + let mut f = NamedTempFile::new().unwrap(); + f.write_all(b"").unwrap(); + let body = ByteStream::from_path(f.path()).await.unwrap(); + assert_eq!(body.inner.size_hint().1, Some(0)); + } + + #[allow(clippy::bool_assert_comparison)] + #[tokio::test] + async fn valid_eos() { + assert_eq!( + ByteStream::from_static(b"hello").inner.body.is_end_stream(), + false + ); + let mut f = NamedTempFile::new().unwrap(); + f.write_all(b"hello").unwrap(); + let body = ByteStream::from_path(f.path()).await.unwrap(); + assert_eq!(body.inner.body.content_length(), Some(5)); + assert!(!body.inner.body.is_end_stream()); + + assert_eq!( + ByteStream::from_static(b"").inner.body.is_end_stream(), + true + ); + let mut f = NamedTempFile::new().unwrap(); + f.write_all(b"").unwrap(); + let body = ByteStream::from_path(f.path()).await.unwrap(); + assert_eq!(body.inner.body.content_length(), Some(0)); + assert!(body.inner.body.is_end_stream()); + } } diff --git a/rust-runtime/aws-smithy-types/src/byte_stream/bytestream_util.rs b/rust-runtime/aws-smithy-types/src/byte_stream/bytestream_util.rs index 8a8feda0c2..4fd8a08d00 100644 --- a/rust-runtime/aws-smithy-types/src/byte_stream/bytestream_util.rs +++ b/rust-runtime/aws-smithy-types/src/byte_stream/bytestream_util.rs @@ -44,7 +44,10 @@ impl PathBody { fn from_file(file: File, length: u64, buffer_size: usize) -> Self { PathBody { - state: State::Loaded(ReaderStream::with_capacity(file.take(length), buffer_size)), + state: State::Loaded { + stream: ReaderStream::with_capacity(file.take(length), buffer_size), + bytes_left: length, + }, length, buffer_size, // The file used to create this `PathBody` should have already had an offset applied @@ -230,7 +233,10 @@ impl FsBuilder { enum State { Unloaded(PathBuf), Loading(Pin> + Send + Sync + 'static>>), - Loaded(ReaderStream>), + Loaded { + stream: ReaderStream>, + bytes_left: u64, + }, } impl http_body_0_4::Body for PathBody { @@ -238,7 +244,7 @@ impl http_body_0_4::Body for PathBody { type Error = Box; fn poll_data( - mut self: std::pin::Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll>> { use std::task::Poll; @@ -260,18 +266,27 @@ impl http_body_0_4::Body for PathBody { State::Loading(ref mut future) => { match futures_core::ready!(Pin::new(future).poll(cx)) { Ok(file) => { - self.state = State::Loaded(ReaderStream::with_capacity( - file.take(self.length), - self.buffer_size, - )); + self.state = State::Loaded { + stream: ReaderStream::with_capacity( + file.take(self.length), + self.buffer_size, + ), + bytes_left: self.length, + }; } Err(e) => return Poll::Ready(Some(Err(e.into()))), }; } - State::Loaded(ref mut stream) => { + State::Loaded { + ref mut stream, + ref mut bytes_left, + } => { use futures_core::Stream; - return match futures_core::ready!(std::pin::Pin::new(stream).poll_next(cx)) { - Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes))), + return match futures_core::ready!(Pin::new(stream).poll_next(cx)) { + Some(Ok(bytes)) => { + *bytes_left -= bytes.len() as u64; + Poll::Ready(Some(Ok(bytes))) + } None => Poll::Ready(None), Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), }; @@ -281,15 +296,17 @@ impl http_body_0_4::Body for PathBody { } fn poll_trailers( - self: std::pin::Pin<&mut Self>, + self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll, Self::Error>> { std::task::Poll::Ready(Ok(None)) } fn is_end_stream(&self) -> bool { - // fast path end-stream for empty streams - self.length == 0 + match self.state { + State::Unloaded(_) | State::Loading(_) => self.length == 0, + State::Loaded { bytes_left, .. } => bytes_left == 0, + } } fn size_hint(&self) -> http_body_0_4::SizeHint { @@ -303,6 +320,7 @@ mod test { use super::FsBuilder; use crate::byte_stream::{ByteStream, Length}; use bytes::Buf; + use http_body_0_4::Body; use std::io::Write; use tempfile::NamedTempFile; @@ -370,6 +388,29 @@ mod test { assert_eq!(body.content_length(), Some(1)); } + #[tokio::test] + async fn fsbuilder_is_end_stream() { + let sentence = "A very long sentence that's clearly longer than a single byte."; + let mut file = NamedTempFile::new().unwrap(); + file.write_all(sentence.as_bytes()).unwrap(); + // Ensure that the file was written to + file.flush().expect("flushing is OK"); + + let mut body = FsBuilder::new() + .path(&file) + .build() + .await + .unwrap() + .into_inner(); + + assert!(!body.is_end_stream()); + assert_eq!(body.content_length(), Some(sentence.len() as u64)); + + let data = body.data().await.unwrap().unwrap(); + assert_eq!(data.len(), sentence.len()); + assert!(body.is_end_stream()); + } + #[tokio::test] async fn fsbuilder_respects_length() { let mut file = NamedTempFile::new().unwrap();