From 892a1bcebb943d62e21918091c3575a4046829ef Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Tue, 5 Mar 2024 15:23:20 -0800 Subject: [PATCH 01/19] Take initial stab at upload minimum throughput --- aws/sdk/integration-tests/s3/Cargo.toml | 3 + .../s3/tests/stalled-stream-protection.rs | 114 ++++-- ...lledStreamProtectionConfigCustomization.kt | 2 +- .../client/http/body/minimum_throughput.rs | 373 +++++++++++++++++- .../minimum_throughput/http_body_0_4_x.rs | 43 ++ .../body/minimum_throughput/throughput.rs | 91 +++-- .../src/client/orchestrator.rs | 12 +- .../src/client/stalled_stream_protection.rs | 65 ++- .../src/test_util/capture_test_logs.rs | 19 +- 9 files changed, 620 insertions(+), 102 deletions(-) diff --git a/aws/sdk/integration-tests/s3/Cargo.toml b/aws/sdk/integration-tests/s3/Cargo.toml index 50ce1ae5c0..0d8ec0a9bc 100644 --- a/aws/sdk/integration-tests/s3/Cargo.toml +++ b/aws/sdk/integration-tests/s3/Cargo.toml @@ -48,3 +48,6 @@ tracing-subscriber = { version = "0.3.15", features = ["env-filter", "json"] } # If you're writing a test with this, take heed! `no-env-filter` means you'll be capturing # logs from everything that speaks, so be specific with your asserts. tracing-test = { version = "0.2.4", features = ["no-env-filter"] } + +[dependencies] +pin-project-lite = "0.2.13" diff --git a/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs b/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs index 25008a415e..d70c424d77 100644 --- a/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs +++ b/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs @@ -4,27 +4,96 @@ */ use aws_credential_types::Credentials; -use aws_sdk_s3::config::{Region, StalledStreamProtectionConfig}; -use aws_sdk_s3::primitives::ByteStream; +use aws_sdk_s3::{ + config::{Region, StalledStreamProtectionConfig}, + error::BoxError, +}; +use aws_sdk_s3::{error::DisplayErrorContext, primitives::ByteStream}; use aws_sdk_s3::{Client, Config}; -use bytes::BytesMut; +use aws_smithy_runtime::{assert_str_contains, test_util::capture_test_logs::capture_test_logs}; +use aws_smithy_types::body::SdkBody; +use bytes::{Bytes, BytesMut}; +use http_body::Body; use std::error::Error; -use std::future::Future; -use std::net::SocketAddr; use std::time::Duration; +use std::{future::Future, task::Poll}; +use std::{net::SocketAddr, pin::Pin, task::Context}; +use tokio::{ + net::{TcpListener, TcpStream}, + time::sleep, +}; use tracing::debug; +enum SlowBodyState { + Wait(Pin + Send + Sync + 'static>>), + Send, + Taken, +} + +struct SlowBody { + state: SlowBodyState, +} + +impl SlowBody { + fn new() -> Self { + Self { + state: SlowBodyState::Send, + } + } +} + +impl Body for SlowBody { + type Data = Bytes; + type Error = BoxError; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + let mut state = SlowBodyState::Taken; + std::mem::swap(&mut state, &mut self.state); + match state { + SlowBodyState::Wait(mut fut) => match fut.as_mut().poll(cx) { + Poll::Ready(_) => self.state = SlowBodyState::Send, + Poll::Pending => { + self.state = SlowBodyState::Wait(fut); + return Poll::Pending; + } + }, + SlowBodyState::Send => { + self.state = SlowBodyState::Wait(Box::pin(sleep(Duration::from_micros(100)))); + return Poll::Ready(Some(Ok(Bytes::from_static( + b"data_data_data_data_data_data_data_data_data_data_data_data_\ + data_data_data_data_data_data_data_data_data_data_data_data_\ + data_data_data_data_data_data_data_data_data_data_data_data_\ + 
data_data_data_data_data_data_data_data_data_data_data_data_", + )))); + } + SlowBodyState::Taken => unreachable!(), + } + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(None)) + } +} + // This test doesn't work because we can't count on `hyper` to poll the body, // regardless of whether we schedule a wake. To make this functionality work, // we'd have to integrate more closely with the orchestrator. // // I'll leave this test here because we do eventually want to support stalled // stream protection for uploads. -#[ignore] #[tokio::test] async fn test_stalled_stream_protection_defaults_for_upload() { - // We spawn a faulty server that will close the connection after - // writing half of the response body. + let _logs = capture_test_logs(); + + // We spawn a faulty server that will stop all request processing after reading half of the request body. let (server, server_addr) = start_faulty_upload_server().await; let _ = tokio::spawn(server); @@ -32,7 +101,7 @@ async fn test_stalled_stream_protection_defaults_for_upload() { .credentials_provider(Credentials::for_tests()) .region(Region::new("us-east-1")) .endpoint_url(format!("http://{server_addr}")) - // .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) + .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) .build(); let client = Client::from_conf(conf); @@ -40,22 +109,19 @@ async fn test_stalled_stream_protection_defaults_for_upload() { .put_object() .bucket("a-test-bucket") .key("stalled-stream-test.txt") - .body(ByteStream::from_static(b"Hello")) + .body(ByteStream::new(SdkBody::from_body_0_4(SlowBody::new()))) .send() .await .expect_err("upload stream stalled out"); - let err = err.source().expect("inner error exists"); - assert_eq!( - err.to_string(), + let err_msg = DisplayErrorContext(&err).to_string(); + assert_str_contains!( + err_msg, "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" ); } async fn start_faulty_upload_server() -> (impl Future, SocketAddr) { - use tokio::net::{TcpListener, TcpStream}; - use tokio::time::sleep; - let listener = TcpListener::bind("0.0.0.0:0") .await .expect("socket is free"); @@ -65,12 +131,7 @@ async fn start_faulty_upload_server() -> (impl Future, SocketAddr) let mut buf = BytesMut::new(); let mut time_to_stall = false; - loop { - if time_to_stall { - debug!("faulty server has read partial request, now getting stuck"); - break; - } - + while !time_to_stall { match socket.try_read_buf(&mut buf) { Ok(0) => { unreachable!( @@ -79,12 +140,7 @@ async fn start_faulty_upload_server() -> (impl Future, SocketAddr) } Ok(n) => { debug!("read {n} bytes from the socket"); - - // Check to see if we've received some headers if buf.len() >= 128 { - let s = String::from_utf8_lossy(&buf); - debug!("{s}"); - time_to_stall = true; } } @@ -98,6 +154,7 @@ async fn start_faulty_upload_server() -> (impl Future, SocketAddr) } } + debug!("faulty server has read partial request, now getting stuck"); loop { tokio::task::yield_now().await } @@ -123,6 +180,7 @@ async fn start_faulty_upload_server() -> (impl Future, SocketAddr) } #[tokio::test] +#[ignore] async fn test_explicitly_configured_stalled_stream_protection_for_downloads() { // We spawn a faulty server that will close the connection after // writing half of the response body. 
@@ -163,6 +221,7 @@ async fn test_explicitly_configured_stalled_stream_protection_for_downloads() { } #[tokio::test] +#[ignore] async fn test_stalled_stream_protection_for_downloads_can_be_disabled() { // We spawn a faulty server that will close the connection after // writing half of the response body. @@ -195,6 +254,7 @@ async fn test_stalled_stream_protection_for_downloads_can_be_disabled() { // This test will always take as long as whatever grace period is set by default. #[tokio::test] +#[ignore] async fn test_stalled_stream_protection_for_downloads_is_enabled_by_default() { // We spawn a faulty server that will close the connection after // writing half of the response body. diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt index 83c3b6dd6b..3faeccff93 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt @@ -124,7 +124,7 @@ class StalledStreamProtectionOperationCustomization( // we can't count on hyper to poll a request body on wake. rustTemplate( """ - #{StalledStreamProtectionInterceptor}::new(#{Kind}::ResponseBody) + #{StalledStreamProtectionInterceptor}::new(#{Kind}::RequestAndResponseBody) """, *preludeScope, "StalledStreamProtectionInterceptor" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptor"), diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs index c576a34afa..96aa9517bf 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs @@ -18,11 +18,26 @@ mod throughput; use aws_smithy_async::rt::sleep::Sleep; use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep}; use aws_smithy_async::time::{SharedTimeSource, TimeSource}; -use aws_smithy_runtime_api::box_error::BoxError; -use aws_smithy_runtime_api::shared::IntoShared; +use aws_smithy_runtime_api::{ + box_error::BoxError, + client::{ + http::HttpConnectorFuture, result::ConnectorError, runtime_components::RuntimeComponents, + stalled_stream_protection::StalledStreamProtectionConfig, + }, +}; +use aws_smithy_runtime_api::{client::orchestrator::HttpResponse, shared::IntoShared}; +use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace}; use options::MinimumThroughputBodyOptions; -use std::fmt; -use std::time::SystemTime; +use std::{ + fmt, + sync::{Arc, Mutex}, + task::Poll, +}; +use std::{future::Future, pin::Pin}; +use std::{ + task::Context, + time::{Duration, SystemTime}, +}; use throughput::ThroughputLogs; pin_project_lite::pin_project! { @@ -93,4 +108,352 @@ impl fmt::Display for Error { impl std::error::Error for Error {} -// Tests are implemented per HTTP body type. +/// Used to store the upload throughput in the interceptor context. 
+#[derive(Clone, Debug)] +pub(crate) struct UploadThroughput { + log: Arc>, +} + +impl UploadThroughput { + pub(crate) fn new() -> Self { + Self { + log: Arc::new(Mutex::new(ThroughputLogs::new( + // Never keep more than 10KB of logs in memory. This currently + // equates to 426 logs. + (NUMBER_OF_LOGS_IN_ONE_KB * 10.0) as usize, + ))), + } + } + + pub(crate) fn push(&self, now: SystemTime, bytes: u64) { + self.log.lock().unwrap().push((now, bytes)); + } + + pub(crate) fn calculate_throughput( + &self, + now: SystemTime, + time_window: Duration, + ) -> Option { + self.log + .lock() + .unwrap() + .calculate_throughput(now, time_window) + } +} + +impl Storable for UploadThroughput { + type Storer = StoreReplace; +} + +pin_project_lite::pin_project! { + pub(crate) struct ThroughputReadingBody { + time_source: SharedTimeSource, + throughput: UploadThroughput, + #[pin] + inner: B, + } +} + +impl ThroughputReadingBody { + pub(crate) fn new( + time_source: SharedTimeSource, + throughput: UploadThroughput, + body: B, + ) -> Self { + Self { + time_source, + throughput, + inner: body, + } + } +} + +pin_project_lite::pin_project! { + struct ThroughputCheckFuture { + #[pin] + response: HttpConnectorFuture, + #[pin] + check_interval: Option, + #[pin] + grace_period: Option, + + time_source: SharedTimeSource, + sleep_impl: SharedAsyncSleep, + upload_throughput: UploadThroughput, + minimum_throughput: Throughput, + time_window: Duration, + grace_time: Duration, + + failing_throughput: Option, + } +} + +impl ThroughputCheckFuture { + fn new( + response: HttpConnectorFuture, + time_source: SharedTimeSource, + sleep_impl: SharedAsyncSleep, + upload_throughput: UploadThroughput, + minimum_throughput: Throughput, + time_window: Duration, + grace_time: Duration, + ) -> Self { + Self { + response, + check_interval: Some(sleep_impl.sleep(time_window)), + grace_period: None, + time_source, + sleep_impl, + upload_throughput, + minimum_throughput, + time_window, + grace_time, + failing_throughput: None, + } + } +} + +impl Future for ThroughputCheckFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + if let Poll::Ready(output) = this.response.poll(cx) { + return Poll::Ready(output); + } else { + let mut below_minimum_throughput = false; + let check_interval_expired = this + .check_interval + .as_mut() + .as_pin_mut() + .expect("always set") + .poll(cx) + .is_ready(); + if check_interval_expired { + // Set up the next check interval + *this.check_interval = Some(this.sleep_impl.sleep(*this.time_window)); + + // Wake so that the check interval future gets polled + // next time this poll method is called. If it never gets polled, + // then this task won't be woken to check again. 
+ cx.waker().wake_by_ref(); + } + + let should_check = check_interval_expired || this.grace_period.is_some(); + if should_check { + let now = this.time_source.now(); + let current_throughput = this + .upload_throughput + .calculate_throughput(now, *this.time_window); + below_minimum_throughput = current_throughput + .as_ref() + .map(|tp| tp < this.minimum_throughput) + .unwrap_or_default(); + tracing::debug!("current throughput: {current_throughput:?}, below minimum: {below_minimum_throughput}"); + if below_minimum_throughput && !this.failing_throughput.is_some() { + *this.failing_throughput = current_throughput; + } else if !below_minimum_throughput { + *this.failing_throughput = None; + } + } + + // If we kicked off a grace period and are now satisfied, clear out the grace period + if !below_minimum_throughput && this.grace_period.is_some() { + tracing::debug!("upload minimum throughput recovered during grace period"); + *this.grace_period = None; + } + if below_minimum_throughput { + // Start a grace period if below minimum throughput + if this.grace_period.is_none() { + tracing::debug!( + grace_period=?*this.grace_time, + "upload minimum throughput below configured minimum; starting grace period" + ); + *this.grace_period = Some(this.sleep_impl.sleep(*this.grace_time)); + } + // Check the grace period if one is already set and we're not satisfied + if let Some(grace_period) = this.grace_period.as_pin_mut() { + if grace_period.poll(cx).is_ready() { + tracing::debug!("grace period ended; timing out request"); + return Poll::Ready(Err(ConnectorError::timeout( + Error::ThroughputBelowMinimum { + expected: *this.minimum_throughput, + actual: this + .failing_throughput + .expect("always set if there's a grace period"), + } + .into(), + ))); + } + } + } + } + Poll::Pending + } +} + +pin_project_lite::pin_project! 
{ + #[project = EnumProj] + pub(crate) enum MaybeThroughputCheckFuture { + Direct { #[pin] future: HttpConnectorFuture }, + Checked { #[pin] future: ThroughputCheckFuture }, + } +} + +impl MaybeThroughputCheckFuture { + pub(crate) fn new( + cfg: &mut ConfigBag, + components: &RuntimeComponents, + connector_future: HttpConnectorFuture, + ) -> Self { + if let Some(sspcfg) = cfg.load::().cloned() { + if sspcfg.is_enabled() { + let options = MinimumThroughputBodyOptions::from(sspcfg); + return Self::new_inner( + connector_future, + components.time_source(), + components.sleep_impl(), + cfg.interceptor_state().load::().cloned(), + Some(options), + ); + } + } + tracing::debug!("no minimum upload throughput checks"); + Self::new_inner(connector_future, None, None, None, None) + } + + fn new_inner( + response: HttpConnectorFuture, + time_source: Option, + sleep_impl: Option, + upload_throughput: Option, + options: Option, + ) -> Self { + match (time_source, sleep_impl, upload_throughput, options) { + (Some(time_source), Some(sleep_impl), Some(upload_throughput), Some(options)) => { + tracing::debug!(options=?options, "applying minimum upload throughput check future"); + Self::Checked { + future: ThroughputCheckFuture::new( + response, + time_source, + sleep_impl, + upload_throughput, + options.minimum_throughput(), + options.check_window(), + options.grace_period(), + ), + } + } + _ => Self::Direct { future: response }, + } + } +} + +impl Future for MaybeThroughputCheckFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + EnumProj::Direct { future } => future.poll(cx), + EnumProj::Checked { future } => future.poll(cx), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{assert_str_contains, test_util::capture_test_logs::capture_test_logs}; + use aws_smithy_async::test_util::tick_advance_sleep::tick_advance_time_and_sleep; + use aws_smithy_types::{body::SdkBody, error::display::DisplayErrorContext}; + use std::future::IntoFuture; + + const TEST_TIME_WINDOW: Duration = Duration::from_secs(1); + + #[tokio::test] + async fn throughput_check_constant_rate_success() { + let minimum_throughput = Throughput::new_bytes_per_second(990); + let grace_time = Duration::from_secs(1); + let actual_throughput_bps = 1000.0; + let transfer_fn = |_| actual_throughput_bps * TEST_TIME_WINDOW.as_secs_f64(); + let result = throughput_check_test(minimum_throughput, grace_time, transfer_fn).await; + let response = result.expect("no timeout"); + assert_eq!(200, response.status().as_u16()); + } + + #[tokio::test] + async fn throughput_check_constant_rate_timeout() { + let minimum_throughput = Throughput::new_bytes_per_second(1100); + let grace_time = Duration::from_secs(1); + let actual_throughput_bps = 1000.0; + let transfer_fn = |_| actual_throughput_bps * TEST_TIME_WINDOW.as_secs_f64(); + let result = throughput_check_test(minimum_throughput, grace_time, transfer_fn).await; + let error = result.err().expect("times out"); + assert_str_contains!( + DisplayErrorContext(&error).to_string(), + "minimum throughput was specified at 1100 B/s, but throughput of 1000 B/s was observed" + ); + } + + #[tokio::test] + async fn throughput_check_grace_time_recovery() { + let minimum_throughput = Throughput::new_bytes_per_second(1000); + let grace_time = Duration::from_secs(3); + let actual_throughput_bps = 1000.0; + let transfer_fn = |window| { + if window <= 5 || window > 7 { + actual_throughput_bps * TEST_TIME_WINDOW.as_secs_f64() + } else { + 0.0 
+ } + }; + let result = throughput_check_test(minimum_throughput, grace_time, transfer_fn).await; + let response = result.expect("no timeout"); + assert_eq!(200, response.status().as_u16()); + } + + async fn throughput_check_test( + minimum_throughput: Throughput, + grace_time: Duration, + transfer_fn: F, + ) -> Result + where + F: Fn(u64) -> f64, + { + let _logs = capture_test_logs(); + let (time_source, sleep_impl) = tick_advance_time_and_sleep(); + + let response = HttpResponse::try_from( + http::Response::builder() + .status(200) + .body(SdkBody::empty()) + .unwrap(), + ) + .unwrap(); + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + + let upload_throughput = UploadThroughput::new(); + let check_task = tokio::spawn(ThroughputCheckFuture::new( + HttpConnectorFuture::new(async move { Ok(response_rx.into_future().await.unwrap()) }), + time_source.clone().into_shared(), + sleep_impl.into_shared(), + upload_throughput.clone(), + minimum_throughput, + TEST_TIME_WINDOW, + grace_time, + )); + + // simulate 20 check time windows at `actual_throughput` bytes/sec + for window in 0..20 { + let bytes = (transfer_fn)(window); + upload_throughput.push(time_source.now(), (bytes + 0.5) as u64); + time_source.tick(TEST_TIME_WINDOW).await; + println!("window {window}"); + } + let _ = response_tx.send(response); + println!("upload finished"); + + check_task.await.expect("no panic") + } +} diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs index 075ef39d63..85b73d062d 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs @@ -4,6 +4,7 @@ */ use super::{BoxError, Error, MinimumThroughputBody}; +use crate::client::http::body::minimum_throughput::ThroughputReadingBody; use aws_smithy_async::rt::sleep::AsyncSleep; use http_body_0_4::Body; use std::future::Future; @@ -114,6 +115,48 @@ where } } +impl Body for ThroughputReadingBody +where + B: Body, +{ + type Data = bytes::Bytes; + type Error = BoxError; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + // this code is called quite frequently in production—one every millisecond or so when downloading + // a stream. However, SystemTime::now is on the order of nanoseconds + let now = self.time_source.now(); + // Attempt to read the data from the inner body, then update the + // throughput logs. + let this = self.as_mut().project(); + match this.inner.poll_data(cx) { + Poll::Ready(Some(Ok(bytes))) => { + tracing::trace!("received data: {}", bytes.len()); + this.throughput.push(now, bytes.len() as u64); + Poll::Ready(Some(Ok(bytes))) + } + Poll::Pending => { + tracing::trace!("received poll pending"); + this.throughput.push(now, 0); + Poll::Pending + } + // If we've read all the data or an error occurred, then return that result. 
+ res => res, + } + } + + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + let this = self.as_mut().project(); + this.inner.poll_trailers(cx) + } +} + // These tests use `hyper::body::Body::wrap_stream` #[cfg(all(test, feature = "connector-hyper-0-14-x", feature = "test-util"))] mod test { diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs index e2a9b294e6..f5c759caa4 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs @@ -97,7 +97,7 @@ impl From<(u64, Duration)> for Throughput { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub(super) struct ThroughputLogs { max_length: usize, inner: VecDeque<(SystemTime, u64)>, @@ -118,45 +118,38 @@ impl ThroughputLogs { if self.inner.len() == self.max_length { self.bytes_processed -= self.inner.pop_front().map(|(_, sz)| sz).unwrap_or_default(); } - debug_assert!(self.inner.capacity() > self.inner.len()); self.bytes_processed += throughput.1; self.inner.push_back(throughput); } - fn buffer_full(&self) -> bool { - self.inner.len() == self.max_length - } - pub(super) fn calculate_throughput( &self, now: SystemTime, time_window: Duration, ) -> Option { - // There are a lot of pathological cases that are 0 throughput. These cases largely shouldn't - // happen, because the check interval MUST be less than the check window - let total_length = self - .inner - .iter() - .last()? - .0 - .duration_since(self.inner.front()?.0) - .ok()?; - // during a "healthy" request we'll only have a few milliseconds of logs (shorter than the check window) - if total_length < time_window { - // if we haven't hit our requested time window & the buffer still isn't full, then - // return `None` — this is the "startup grace period" - return if !self.buffer_full() { - None - } else { - // Otherwise, if the entire buffer fits in the timewindow, we can the shortcut to - // avoid recomputing all the data - Some(Throughput { + if self.inner.is_empty() { + return None; + } + if let Some(first_time) = self.inner.front().map(|e| e.0) { + if first_time + time_window >= now && first_time < now { + // If the first logged time fits within the time window, then short-circuit + return Some(Throughput { bytes_read: self.bytes_processed, - per_time_elapsed: total_length, - }) - }; + per_time_elapsed: now.duration_since(first_time).expect("checked above"), + }); + } } + if let Some(last_time) = self.inner.back().map(|e| e.0) { + if last_time + time_window < now { + // If we have no log entries at all within the time window, then short-circuit to 0 + return Some(Throughput { + bytes_read: 0, + per_time_elapsed: time_window, + }); + } + } + let minimum_ts = now - time_window; let first_item = self.inner.iter().find(|(ts, _)| *ts >= minimum_ts)?.0; @@ -166,7 +159,7 @@ impl ThroughputLogs { .inner .iter() .rev() - .take_while(|(ts, _)| *ts > minimum_ts) + .take_while(|(ts, _)| *ts >= minimum_ts) .map(|t| t.1) .sum::(); @@ -280,4 +273,44 @@ mod test { .unwrap(); assert_eq!(108.0, throughput.bytes_per_second()); } + + // If the time since the last log entry is greater than the window, then the throughput should be zero + #[test] + fn test_throughput_log_calculate_throughput_long_after_last_log() { + let (throughput_logs, now) = build_throughput_log(1000, Duration::from_millis(100), 
12); + + let throughput = throughput_logs + .calculate_throughput(now + Duration::from_secs(5), Duration::from_secs(1)) + .unwrap(); + let expected_throughput = 0.0; + + assert_eq!(expected_throughput, throughput.bytes_per_second()); + } + + // If the throughput log is empty, it should return None for the calculated throughput + #[test] + fn test_throughput_log_calculate_throughput_empty_log() { + let throughput_logs = ThroughputLogs::new(1000); + assert!(throughput_logs + .calculate_throughput(UNIX_EPOCH, Duration::from_secs(1)) + .is_none()); + } + + // Verify things work as expected when everything occurs exactly on the time window boundary + #[test] + fn test_boundary_conditions() { + let mut logs = ThroughputLogs::new(1000); + logs.bytes_processed = 2000; + logs.inner.push_back((SystemTime::UNIX_EPOCH, 1000)); + logs.inner + .push_back((SystemTime::UNIX_EPOCH + Duration::from_secs(1), 1000)); + + let throughput = logs + .calculate_throughput( + SystemTime::UNIX_EPOCH + Duration::from_secs(2), + Duration::from_secs(1), + ) + .unwrap(); + assert_eq!(Throughput::new_bytes_per_second(1000), throughput); + } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index f8bbc2c05c..0a8088a01a 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -5,9 +5,12 @@ use self::auth::orchestrate_auth; use crate::client::interceptors::Interceptors; -use crate::client::orchestrator::endpoints::orchestrate_endpoint; use crate::client::orchestrator::http::{log_response_body, read_body}; use crate::client::timeout::{MaybeTimeout, MaybeTimeoutConfig, TimeoutKind}; +use crate::client::{ + http::body::minimum_throughput::MaybeThroughputCheckFuture, + orchestrator::endpoints::orchestrate_endpoint, +}; use aws_smithy_async::rt::sleep::AsyncSleep; use aws_smithy_runtime_api::box_error::BoxError; use aws_smithy_runtime_api::client::http::{HttpClient, HttpConnector, HttpConnectorSettings}; @@ -385,7 +388,12 @@ async fn try_attempt( builder.build() }; let connector = http_client.http_connector(&settings, runtime_components); - connector.call(request).await.map_err(OrchestratorError::connector) + let response_future = MaybeThroughputCheckFuture::new( + cfg, + runtime_components, + connector.call(request), + ); + response_future.await.map_err(OrchestratorError::connector) }); trace!(response = ?response, "received response from service"); ctx.set_response(response); diff --git a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs index 3e07b3f0b8..89f08c8bde 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs @@ -3,7 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::client::http::body::minimum_throughput::MinimumThroughputBody; +use crate::client::http::body::minimum_throughput::{ + MinimumThroughputBody, ThroughputReadingBody, UploadThroughput, +}; use aws_smithy_async::rt::sleep::SharedAsyncSleep; use aws_smithy_async::time::SharedTimeSource; use aws_smithy_runtime_api::box_error::BoxError; @@ -64,17 +66,25 @@ impl Intercept for StalledStreamProtectionInterceptor { cfg: &mut ConfigBag, ) -> Result<(), BoxError> { if self.enable_for_request_body { - if let Some(cfg) = cfg.load::() { - if cfg.is_enabled() { - let (async_sleep, time_source) = 
+ if let Some(sspcfg) = cfg.load::().cloned() { + if sspcfg.is_enabled() { + let throughput = UploadThroughput::new(); + cfg.interceptor_state().store_put(throughput.clone()); + + let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; tracing::trace!("adding stalled stream protection to request body"); - add_stalled_stream_protection_to_body( - context.request_mut().body_mut(), - cfg, - async_sleep, - time_source, - ); + let sspcfg = sspcfg.clone(); // TODO XXX use the config? + let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken()); + let it = it.map_preserve_contents(move |body| { + let time_source = time_source.clone(); + SdkBody::from_body_0_4(ThroughputReadingBody::new( + time_source, + throughput.clone(), + body, + )) + }); + let _ = mem::replace(context.request_mut().body_mut(), it); } } } @@ -94,12 +104,17 @@ impl Intercept for StalledStreamProtectionInterceptor { let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; tracing::trace!("adding stalled stream protection to response body"); - add_stalled_stream_protection_to_body( - context.response_mut().body_mut(), - cfg, - async_sleep, - time_source, - ); + let cfg = cfg.clone(); + let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken()); + let it = it.map_preserve_contents(move |body| { + let cfg = cfg.clone(); + let async_sleep = async_sleep.clone(); + let time_source = time_source.clone(); + let mtb = + MinimumThroughputBody::new(time_source, async_sleep, body, cfg.into()); + SdkBody::from_body_0_4(mtb) + }); + let _ = mem::replace(context.response_mut().body_mut(), it); } } } @@ -118,21 +133,3 @@ fn get_runtime_component_deps( .ok_or("A time source is required when stalled stream protection is enabled")?; Ok((async_sleep, time_source)) } - -fn add_stalled_stream_protection_to_body( - body: &mut SdkBody, - cfg: &StalledStreamProtectionConfig, - async_sleep: SharedAsyncSleep, - time_source: SharedTimeSource, -) { - let cfg = cfg.clone(); - let it = mem::replace(body, SdkBody::taken()); - let it = it.map_preserve_contents(move |body| { - let cfg = cfg.clone(); - let async_sleep = async_sleep.clone(); - let time_source = time_source.clone(); - let mtb = MinimumThroughputBody::new(time_source, async_sleep, body, cfg.into()); - SdkBody::from_body_0_4(mtb) - }); - let _ = mem::replace(body, it); -} diff --git a/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs b/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs index 92b450c115..b55c0c68ac 100644 --- a/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs +++ b/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs @@ -24,14 +24,25 @@ pub struct LogCaptureGuard(#[allow(dead_code)] DefaultGuard); pub fn capture_test_logs() -> (LogCaptureGuard, Rx) { // it may be helpful to upstream this at some point let (mut writer, rx) = Tee::stdout(); - if env::var("VERBOSE_TEST_LOGS").is_ok() { - eprintln!("Enabled verbose test logging."); + let (enabled, level) = match env::var("VERBOSE_TEST_LOGS").ok().as_deref() { + Some("debug") => (true, Level::DEBUG), + Some("error") => (true, Level::ERROR), + Some("info") => (true, Level::INFO), + Some("warn") => (true, Level::WARN), + Some("trace") | Some(_) => (true, Level::TRACE), + None => (false, Level::TRACE), + }; + if enabled { + eprintln!("Enabled verbose test logging at {level:?}."); writer.loud(); } else { - eprintln!("To see full logs from this test set VERBOSE_TEST_LOGS=true"); + eprintln!( + "To 
see full logs from this test set VERBOSE_TEST_LOGS=true \ + (or to a log level, e.g., trace, debug, info, etc)" + ); } let subscriber = tracing_subscriber::fmt() - .with_max_level(Level::TRACE) + .with_max_level(level) .with_writer(Mutex::new(writer)) .finish(); let guard = tracing::subscriber::set_default(subscriber); From 38aa76eeb6623feb24e67059a5a4876250ed4ceb Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Fri, 8 Mar 2024 14:12:59 -0800 Subject: [PATCH 02/19] Add comprehensive upload stream protection integration tests --- .../src/client/orchestrator/operation.rs | 13 +- .../tests/stalled_stream_upload.rs | 424 ++++++++++++++++++ 2 files changed, 436 insertions(+), 1 deletion(-) create mode 100644 rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs index bd875c72e6..7ad48ec585 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs @@ -12,7 +12,6 @@ use crate::client::orchestrator::endpoints::StaticUriEndpointResolver; use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy}; use aws_smithy_async::rt::sleep::AsyncSleep; use aws_smithy_async::time::TimeSource; -use aws_smithy_runtime_api::box_error::BoxError; use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver; use aws_smithy_runtime_api::client::auth::{ AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver, @@ -35,6 +34,9 @@ use aws_smithy_runtime_api::client::ser_de::{ DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer, }; use aws_smithy_runtime_api::shared::IntoShared; +use aws_smithy_runtime_api::{ + box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig, +}; use aws_smithy_types::config_bag::{ConfigBag, Layer}; use aws_smithy_types::retry::RetryConfig; use aws_smithy_types::timeout::TimeoutConfig; @@ -293,6 +295,15 @@ impl OperationBuilder { self } + /// Configures stalled stream protection with the given config. + pub fn stalled_stream_protection( + mut self, + stalled_stream_protection: StalledStreamProtectionConfig, + ) -> Self { + self.config.store_put(stalled_stream_protection); + self + } + /// Configures the serializer for the builder. pub fn serializer( mut self, diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs new file mode 100644 index 0000000000..0033a8828f --- /dev/null +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs @@ -0,0 +1,424 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_async::time::TimeSource; +use aws_smithy_runtime::{assert_str_contains, test_util::capture_test_logs::capture_test_logs}; +use aws_smithy_types::error::display::DisplayErrorContext; +use bytes::Bytes; +use std::time::Duration; +use tracing::info; + +/// No really, it's 42 bytes long... super neat +const NEAT_DATA: Bytes = Bytes::from_static(b"some really neat data"); + +/// Ticks time forward by the given duration, and logs the current time for debugging. +macro_rules! 
tick { + ($ticker:ident, $duration:expr) => { + $ticker.tick($duration).await; + let now = $ticker + .now() + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap(); + tracing::info!("ticked {:?}, now at {:?}", $duration, now); + }; +} + +/// Scenario: Successful upload at a rate above the minimum throughput. +/// Expected: MUST NOT timeout. +#[tokio::test] +async fn upload_success() { + let _logs = capture_test_logs(); + + let (server, time, sleep) = eager_server(true); + let op = operation(server, time, sleep); + + let (body, body_sender) = channel_body(); + let result = tokio::spawn(async move { op.invoke(body).await }); + + for _ in 0..100 { + body_sender.send(NEAT_DATA).await.unwrap(); + } + drop(body_sender); + + assert_eq!(200, result.await.unwrap().expect("success").as_u16()); +} + +/// Scenario: Upload takes some time to start, but then goes normally. +/// Expected: MUST NOT timeout. +#[tokio::test] +async fn upload_slow_start() { + let _logs = capture_test_logs(); + + let (server, time, sleep) = eager_server(false); + let op = operation(server, time.clone(), sleep); + + let (body, body_sender) = channel_body(); + let result = tokio::spawn(async move { op.invoke(body).await }); + + let _streamer = tokio::spawn(async move { + // Advance longer than the grace period. This shouldn't fail since + // it is the customer's side that hasn't produced data yet, not a server issue. + time.tick(Duration::from_secs(10)).await; + + for _ in 0..100 { + body_sender.send(NEAT_DATA).await.unwrap(); + time.tick(Duration::from_secs(1)).await; + } + drop(body_sender); + time.tick(Duration::from_secs(1)).await; + }); + + assert_eq!(200, result.await.unwrap().expect("success").as_u16()); +} + +/// Scenario: The upload is going fine, but falls below the minimum throughput. +/// Expected: MUST timeout. +#[tokio::test] +async fn upload_too_slow() { + let _logs = capture_test_logs(); + + // Server that starts off fast enough, but gets slower over time until it should timeout. + let (server, time, sleep) = time_sequence_server([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let op = operation(server, time, sleep); + + let (body, body_sender) = channel_body(); + let result = tokio::spawn(async move { op.invoke(body).await }); + + let _streamer = tokio::spawn(async move { + for send in 0..100 { + info!("send {send}"); + body_sender.send(NEAT_DATA).await.unwrap(); + } + drop(body_sender); + }); + + let err = result + .await + .expect("no panics") + .expect_err("should have timed out"); + assert_str_contains!( + DisplayErrorContext(&err).to_string(), + "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" + ); +} + +/// Scenario: The server stops asking for data, the client maxes out its send buffer, +/// and the request stream stops being polled. +/// Expected: MUST timeout after the grace period completes. 
+#[tokio::test] +async fn upload_stalls() { + let _logs = capture_test_logs(); + + let (server, time, sleep) = stalling_server(); + let op = operation(server, time.clone(), sleep); + + let (body, body_sender) = channel_body(); + let result = tokio::spawn(async move { op.invoke(body).await }); + + let _streamer = tokio::spawn(async move { + for send in 1..=100 { + info!("send {send}"); + body_sender.send(NEAT_DATA).await.unwrap(); + tick!(time, Duration::from_secs(1)); + } + drop(body_sender); + time.tick(Duration::from_secs(1)).await; + }); + + let err = result + .await + .expect("no panics") + .expect_err("should have timed out"); + assert_str_contains!( + DisplayErrorContext(&err).to_string(), + "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" + ); +} + +// Scenario: The server stops asking for data, the client maxes out its send buffer, +// and the request stream stops being polled. However, before the grace period +// is over, the server recovers and starts asking for data again. +// Expected: MUST NOT timeout. +#[tokio::test] +async fn upload_stall_recovery_in_grace_period() { + let _logs = capture_test_logs(); + + // Server starts off fast enough, but then slows down almost up to + // the grace period, and then recovers. + let (server, time, sleep) = time_sequence_server([1, 4, 1]); + let op = operation(server, time, sleep); + + let (body, body_sender) = channel_body(); + let result = tokio::spawn(async move { op.invoke(body).await }); + + let _streamer = tokio::spawn(async move { + for send in 0..100 { + info!("send {send}"); + body_sender.send(NEAT_DATA).await.unwrap(); + } + drop(body_sender); + }); + + assert_eq!(200, result.await.unwrap().expect("success").as_u16()); +} + +// Scenario: The customer isn't providing data on the stream fast enough to satisfy +// the minimum throughput. This shouldn't be considered a stall since the +// server is asking for more data and could handle it if it were available. +// Expected: MUST NOT timeout. +#[tokio::test] +async fn user_provides_data_too_slowly() { + let _logs = capture_test_logs(); + + let (server, time, sleep) = eager_server(false); + let op = operation(server, time.clone(), sleep.clone()); + + let (body, body_sender) = channel_body(); + let result = tokio::spawn(async move { op.invoke(body).await }); + + let _streamer = tokio::spawn(async move { + body_sender.send(NEAT_DATA).await.unwrap(); + tick!(time, Duration::from_secs(1)); + body_sender.send(NEAT_DATA).await.unwrap(); + + // Now advance 10 seconds before sending more data, simulating a + // customer taking time to produce more data to stream. 
+ tick!(time, Duration::from_secs(10)); + body_sender.send(NEAT_DATA).await.unwrap(); + drop(body_sender); + tick!(time, Duration::from_secs(1)); + }); + + assert_eq!(200, result.await.unwrap().expect("success").as_u16()); +} + +use test_tools::*; +mod test_tools { + use aws_smithy_async::test_util::tick_advance_sleep::{ + tick_advance_time_and_sleep, TickAdvanceSleep, TickAdvanceTime, + }; + use aws_smithy_async::time::TimeSource; + use aws_smithy_runtime::client::{ + orchestrator::operation::Operation, + stalled_stream_protection::{ + StalledStreamProtectionInterceptor, StalledStreamProtectionInterceptorKind, + }, + }; + use aws_smithy_runtime_api::{ + client::{ + http::{ + HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, + SharedHttpConnector, + }, + orchestrator::{HttpRequest, HttpResponse}, + runtime_components::RuntimeComponents, + stalled_stream_protection::StalledStreamProtectionConfig, + }, + http::StatusCode, + shared::IntoShared, + }; + use aws_smithy_types::{body::SdkBody, timeout::TimeoutConfig}; + use bytes::Bytes; + use http_body_0_4::Body; + use pin_utils::pin_mut; + use std::{ + collections::VecDeque, + convert::Infallible, + future::poll_fn, + mem, + pin::Pin, + task::{Context, Poll}, + time::Duration, + }; + use tracing::instrument::Instrument; + + #[derive(Debug)] + struct FakeServer(SharedHttpConnector); + + impl HttpClient for FakeServer { + fn http_connector( + &self, + _settings: &HttpConnectorSettings, + _components: &RuntimeComponents, + ) -> SharedHttpConnector { + self.0.clone() + } + } + + struct ChannelBody { + receiver: tokio::sync::mpsc::Receiver, + } + impl http_body_0_4::Body for ChannelBody { + type Data = Bytes; + type Error = Infallible; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.receiver.poll_recv(cx) { + Poll::Ready(value) => Poll::Ready(value.map(|v| Ok(v))), + Poll::Pending => Poll::Pending, + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + unreachable!() + } + } + + pub fn channel_body() -> (SdkBody, tokio::sync::mpsc::Sender) { + let (sender, receiver) = tokio::sync::mpsc::channel(1000); + (SdkBody::from_body_0_4(ChannelBody { receiver }), sender) + } + + pub fn successful_response() -> HttpResponse { + HttpResponse::try_from( + http::Response::builder() + .status(200) + .body(SdkBody::empty()) + .unwrap(), + ) + .unwrap() + } + + pub fn operation( + http_connector: impl HttpConnector + 'static, + time: TickAdvanceTime, + sleep: TickAdvanceSleep, + ) -> Operation { + let operation = Operation::builder() + .service_name("test") + .operation_name("test") + .http_client(FakeServer(http_connector.into_shared())) + .endpoint_url("http://localhost:1234/doesntmatter") + .no_auth() + .no_retry() + .timeout_config(TimeoutConfig::disabled()) + .serializer(|body: SdkBody| Ok(HttpRequest::new(body))) + .deserializer::<_, Infallible>(|response| Ok(response.status())) + .stalled_stream_protection( + StalledStreamProtectionConfig::enabled() + .grace_period(Duration::from_secs(5)) + .build(), + ) + .interceptor(StalledStreamProtectionInterceptor::new( + StalledStreamProtectionInterceptorKind::RequestAndResponseBody, + )) + .sleep_impl(sleep) + .time_source(time) + .build(); + operation + } + + /// Creates a fake HttpConnector implementation that calls the given async $body_fn + /// to get the response body. This $body_fn is given a request body, time, and sleep. + macro_rules! 
fake_server { + ($name:ident, $body_fn:expr) => { + fake_server!($name, $body_fn, (), ()) + }; + ($name:ident, $body_fn:expr, $params_ty:ty, $params:expr) => {{ + #[derive(Debug)] + struct $name(TickAdvanceTime, TickAdvanceSleep, $params_ty); + impl HttpConnector for $name { + fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture { + let time = self.0.clone(); + let sleep = self.1.clone(); + let params = self.2.clone(); + let span = tracing::span!(tracing::Level::INFO, "FAKE SERVER"); + HttpConnectorFuture::new( + async move { + let mut body = SdkBody::taken(); + mem::swap(request.body_mut(), &mut body); + pin_mut!(body); + + Ok($body_fn(body, time, sleep, params).await) + } + .instrument(span), + ) + } + } + let (time, sleep) = tick_advance_time_and_sleep(); + ( + $name(time.clone(), sleep.clone(), $params).into_shared(), + time, + sleep, + ) + }}; + } + + /// Fake server/connector that immediately reads all incoming data with an + /// optional 1 second gap in between polls. + pub fn eager_server( + advance_time: bool, + ) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { + async fn fake_server( + mut body: Pin<&mut SdkBody>, + time: TickAdvanceTime, + _: TickAdvanceSleep, + advance_time: bool, + ) -> HttpResponse { + while poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { + if advance_time { + tick!(time, Duration::from_secs(1)); + } + } + successful_response() + } + fake_server!(FakeServerConnector, fake_server, bool, advance_time) + } + + /// Fake server/connector that reads some data, and then stalls. + pub fn stalling_server() -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { + async fn fake_server( + mut body: Pin<&mut SdkBody>, + _time: TickAdvanceTime, + _sleep: TickAdvanceSleep, + _: (), + ) -> HttpResponse { + let mut times = 5; + while poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { + times -= 1; + if times <= 0 { + // never awake after this + tracing::info!("stalling indefinitely"); + std::future::pending().await + } + } + unreachable!() + } + fake_server!(FakeServerConnector, fake_server) + } + + /// Fake server/connector that polls data after each period of time in the given + /// sequence. Once the sequence completes, it will delay 1 second after each poll. 
+ pub fn time_sequence_server( + time_sequence: impl IntoIterator, + ) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { + async fn fake_server( + mut body: Pin<&mut SdkBody>, + time: TickAdvanceTime, + _sleep: TickAdvanceSleep, + time_sequence: Vec, + ) -> HttpResponse { + let mut time_sequence: VecDeque = + time_sequence.into_iter().map(Duration::from_secs).collect(); + while poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { + let next_time = time_sequence.pop_front().unwrap_or(Duration::from_secs(1)); + tick!(time, next_time); + } + successful_response() + } + fake_server!( + FakeServerConnector, + fake_server, + Vec, + time_sequence.into_iter().collect() + ) + } +} From 186c51b81bbd9051f6d7ad64818146402fd3b194 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Fri, 8 Mar 2024 15:10:36 -0800 Subject: [PATCH 03/19] Add comprehensive download stream protection integration tests --- .../src/client/orchestrator/operation.rs | 22 + .../tests/stalled_stream_download.rs | 388 ++++++++++++++++++ 2 files changed, 410 insertions(+) create mode 100644 rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs index 7ad48ec585..e761cf601f 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs @@ -350,6 +350,28 @@ impl OperationBuilder { } } + /// Configures the a deserializer implementation for the builder. + pub fn deserializer_impl( + mut self, + deserializer: impl DeserializeResponse + Send + Sync + 'static, + ) -> OperationBuilder + where + O2: fmt::Debug + Send + Sync + 'static, + E2: std::error::Error + fmt::Debug + Send + Sync + 'static, + { + let deserializer: SharedResponseDeserializer = deserializer.into_shared(); + self.config.store_put(deserializer); + + OperationBuilder { + service_name: self.service_name, + operation_name: self.operation_name, + config: self.config, + runtime_components: self.runtime_components, + runtime_plugins: self.runtime_plugins, + _phantom: Default::default(), + } + } + /// Creates an `Operation` from the builder. pub fn build(self) -> Operation { let service_name = self.service_name.expect("service_name required"); diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs new file mode 100644 index 0000000000..8cc160f011 --- /dev/null +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs @@ -0,0 +1,388 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_async::test_util::tick_advance_sleep::tick_advance_time_and_sleep; +use aws_smithy_async::time::TimeSource; +use aws_smithy_runtime::{assert_str_contains, test_util::capture_test_logs::capture_test_logs}; +use aws_smithy_types::error::display::DisplayErrorContext; +use bytes::Bytes; +use std::time::Duration; + +/// No really, it's 42 bytes long... super neat +const NEAT_DATA: Bytes = Bytes::from_static(b"some really neat data"); + +/// Ticks time forward by the given duration, and logs the current time for debugging. +macro_rules! 
tick { + ($ticker:ident, $duration:expr) => { + $ticker.tick($duration).await; + let now = $ticker + .now() + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap(); + tracing::info!("ticked {:?}, now at {:?}", $duration, now); + }; +} + +/// Scenario: Successfully download at a rate above the minimum throughput. +/// Expected: MUST NOT timeout. +#[tokio::test] +async fn download_success() { + let _logs = capture_test_logs(); + + let (time, sleep) = tick_advance_time_and_sleep(); + let (server, response_sender) = channel_server(); + let op = operation(server, time.clone(), sleep); + + let server = tokio::spawn(async move { + for _ in 1..100 { + response_sender.send(NEAT_DATA).await.unwrap(); + tick!(time, Duration::from_secs(1)); + } + drop(response_sender); + tick!(time, Duration::from_secs(1)); + }); + + let response_body = op.invoke(()).await.expect("initial success"); + let result = eagerly_consume(response_body).await; + server.await.unwrap(); + + result.ok().expect("response MUST NOT timeout"); +} + +/// Scenario: Download takes a some time to start, but then goes normally. +/// Expected: MUT NOT timeout. +#[tokio::test] +async fn download_slow_start() { + let _logs = capture_test_logs(); + + let (time, sleep) = tick_advance_time_and_sleep(); + let (server, response_sender) = channel_server(); + let op = operation(server, time.clone(), sleep); + + let server = tokio::spawn(async move { + // Delay almost to the end of the grace period before sending anything + tick!(time, Duration::from_secs(4)); + for _ in 1..100 { + response_sender.send(NEAT_DATA).await.unwrap(); + tick!(time, Duration::from_secs(1)); + } + drop(response_sender); + tick!(time, Duration::from_secs(1)); + }); + + let response_body = op.invoke(()).await.expect("initial success"); + let result = eagerly_consume(response_body).await; + server.await.unwrap(); + + result.ok().expect("response MUST NOT timeout"); +} + +/// Scenario: Download starts fine, and then slowly falls below minimum throughput. +/// Expected: MUST timeout. +#[tokio::test] +async fn download_too_slow() { + let _logs = capture_test_logs(); + + let (time, sleep) = tick_advance_time_and_sleep(); + let (server, response_sender) = channel_server(); + let op = operation(server, time.clone(), sleep); + + let server = tokio::spawn(async move { + // Get slower with every poll + for delay in 1..100 { + let _ = response_sender.send(NEAT_DATA).await; + tick!(time, Duration::from_secs(delay)); + } + drop(response_sender); + tick!(time, Duration::from_secs(1)); + }); + + let response_body = op.invoke(()).await.expect("initial success"); + let result = eagerly_consume(response_body).await; + server.await.unwrap(); + + let err = result.expect_err("should have timed out"); + assert_str_contains!( + DisplayErrorContext(err.as_ref()).to_string(), + "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" + ); +} + +/// Scenario: Download starts fine, and then the server stalls and stops sending data. +/// Expected: MUST timeout. 
+#[tokio::test] +async fn download_stalls() { + let _logs = capture_test_logs(); + + let (time, sleep) = tick_advance_time_and_sleep(); + let (server, response_sender) = channel_server(); + let op = operation(server, time.clone(), sleep); + + let server = tokio::spawn(async move { + for _ in 1..10 { + response_sender.send(NEAT_DATA).await.unwrap(); + tick!(time, Duration::from_secs(1)); + } + tick!(time, Duration::from_secs(10)); + }); + + let response_body = op.invoke(()).await.expect("initial success"); + let result = eagerly_consume(response_body).await; + server.await.unwrap(); + + let err = result.expect_err("should have timed out"); + assert_str_contains!( + DisplayErrorContext(err.as_ref()).to_string(), + "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" + ); +} + +/// Scenario: Download starts fine, but then the server stalls for a time within the +/// grace period. Following that, it starts sending data again. +/// Expected: MUST NOT timeout. +#[tokio::test] +async fn download_stall_recovery_in_grace_period() { + let _logs = capture_test_logs(); + + let (time, sleep) = tick_advance_time_and_sleep(); + let (server, response_sender) = channel_server(); + let op = operation(server, time.clone(), sleep); + + let server = tokio::spawn(async move { + for _ in 1..10 { + response_sender.send(NEAT_DATA).await.unwrap(); + tick!(time, Duration::from_secs(1)); + } + // Delay almost to the end of the grace period + tick!(time, Duration::from_secs(4)); + // And now recover + for _ in 1..10 { + response_sender.send(NEAT_DATA).await.unwrap(); + tick!(time, Duration::from_secs(1)); + } + drop(response_sender); + tick!(time, Duration::from_secs(1)); + }); + + let response_body = op.invoke(()).await.expect("initial success"); + let result = eagerly_consume(response_body).await; + server.await.unwrap(); + + result.ok().expect("response MUST NOT timeout"); +} + +/// Scenario: The server sends data fast enough, but the customer doesn't consume the +/// data fast enough. +/// Expected: MUST NOT timeout. 
+#[tokio::test] +async fn user_downloads_data_too_slowly() { + let _logs = capture_test_logs(); + + let (time, sleep) = tick_advance_time_and_sleep(); + let (server, response_sender) = channel_server(); + let op = operation(server, time.clone(), sleep); + + let server = tokio::spawn(async move { + for _ in 1..100 { + response_sender.send(NEAT_DATA).await.unwrap(); + } + drop(response_sender); + }); + + let response_body = op.invoke(()).await.expect("initial success"); + let result = slowly_consume(time, response_body).await; + server.await.unwrap(); + + result.ok().expect("response MUST NOT timeout"); +} + +use test_tools::*; +mod test_tools { + use aws_smithy_async::test_util::tick_advance_sleep::{TickAdvanceSleep, TickAdvanceTime}; + use aws_smithy_async::time::TimeSource; + use aws_smithy_runtime::client::{ + orchestrator::operation::Operation, + stalled_stream_protection::{ + StalledStreamProtectionInterceptor, StalledStreamProtectionInterceptorKind, + }, + }; + use aws_smithy_runtime_api::{ + box_error::BoxError, + client::{ + http::{ + HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, + SharedHttpConnector, + }, + interceptors::context::{Error, Output}, + orchestrator::{HttpRequest, HttpResponse, OrchestratorError}, + runtime_components::RuntimeComponents, + ser_de::DeserializeResponse, + stalled_stream_protection::StalledStreamProtectionConfig, + }, + shared::IntoShared, + }; + use aws_smithy_types::{body::SdkBody, timeout::TimeoutConfig}; + use bytes::Bytes; + use http_body_0_4::Body; + use pin_utils::pin_mut; + use std::{ + convert::Infallible, + future::poll_fn, + mem, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + time::Duration, + }; + + #[derive(Debug)] + struct FakeServer(SharedHttpConnector); + + impl HttpClient for FakeServer { + fn http_connector( + &self, + _settings: &HttpConnectorSettings, + _components: &RuntimeComponents, + ) -> SharedHttpConnector { + self.0.clone() + } + } + + struct ChannelBody { + receiver: tokio::sync::mpsc::Receiver, + } + impl http_body_0_4::Body for ChannelBody { + type Data = Bytes; + type Error = Infallible; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.receiver.poll_recv(cx) { + Poll::Ready(value) => Poll::Ready(value.map(|v| Ok(v))), + Poll::Pending => Poll::Pending, + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + unreachable!() + } + } + + pub fn channel_body() -> (SdkBody, tokio::sync::mpsc::Sender) { + let (sender, receiver) = tokio::sync::mpsc::channel(1000); + (SdkBody::from_body_0_4(ChannelBody { receiver }), sender) + } + + fn response(body: SdkBody) -> HttpResponse { + HttpResponse::try_from(http::Response::builder().status(200).body(body).unwrap()).unwrap() + } + + pub fn operation( + http_connector: impl HttpConnector + 'static, + time: TickAdvanceTime, + sleep: TickAdvanceSleep, + ) -> Operation<(), SdkBody, Infallible> { + #[derive(Debug)] + struct Deserializer; + impl DeserializeResponse for Deserializer { + fn deserialize_streaming( + &self, + response: &mut HttpResponse, + ) -> Option>> { + let mut body = SdkBody::taken(); + mem::swap(response.body_mut(), &mut body); + Some(Ok(Output::erase(body))) + } + + fn deserialize_nonstreaming( + &self, + _: &HttpResponse, + ) -> Result> { + unreachable!() + } + } + + let operation = Operation::builder() + .service_name("test") + .operation_name("test") + .http_client(FakeServer(http_connector.into_shared())) + 
.endpoint_url("http://localhost:1234/doesntmatter") + .no_auth() + .no_retry() + .timeout_config(TimeoutConfig::disabled()) + .serializer(|_body: ()| Ok(HttpRequest::new(SdkBody::empty()))) + .deserializer_impl(Deserializer) + .stalled_stream_protection( + StalledStreamProtectionConfig::enabled() + .grace_period(Duration::from_secs(5)) + .build(), + ) + .interceptor(StalledStreamProtectionInterceptor::new( + StalledStreamProtectionInterceptorKind::RequestAndResponseBody, + )) + .sleep_impl(sleep) + .time_source(time) + .build(); + operation + } + + /// Fake server/connector that responds with a channel body. + pub fn channel_server() -> (SharedHttpConnector, tokio::sync::mpsc::Sender) { + #[derive(Debug)] + struct FakeServerConnector { + body: Arc>>, + } + impl HttpConnector for FakeServerConnector { + fn call(&self, _request: HttpRequest) -> HttpConnectorFuture { + let body = self.body.lock().unwrap().take().unwrap(); + HttpConnectorFuture::new(async move { Ok(response(body)) }) + } + } + + let (body, body_sender) = channel_body(); + ( + FakeServerConnector { + body: Arc::new(Mutex::new(Some(body))), + } + .into_shared(), + body_sender, + ) + } + + /// Simulate a client eagerly consuming all the data sent to it from the server. + pub async fn eagerly_consume(body: SdkBody) -> Result<(), BoxError> { + pin_mut!(body); + while let Some(result) = poll_fn(|cx| body.as_mut().poll_data(cx)).await { + if let Err(err) = result { + return Err(err); + } else { + tracing::info!("consumed bytes from the response body"); + } + } + Ok(()) + } + + /// Simulate a client very slowly consuming data with an eager server. + /// + /// This implementation will take longer than the grace period to consume + /// the next piece of data. + pub async fn slowly_consume(time: TickAdvanceTime, body: SdkBody) -> Result<(), BoxError> { + pin_mut!(body); + while let Some(result) = poll_fn(|cx| body.as_mut().poll_data(cx)).await { + if let Err(err) = result { + return Err(err); + } else { + tracing::info!("consumed bytes from the response body"); + tick!(time, Duration::from_secs(10)); + } + } + Ok(()) + } +} From dd8d351d35cfd12ecc22a8aa761e48055e1f2863 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Fri, 8 Mar 2024 15:39:51 -0800 Subject: [PATCH 04/19] Commonize code between upload/download tests --- .../tests/stalled_stream_common.rs | 113 +++++++++++++++ .../tests/stalled_stream_download.rs | 106 +------------- .../tests/stalled_stream_upload.rs | 135 +++--------------- 3 files changed, 135 insertions(+), 219 deletions(-) create mode 100644 rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs new file mode 100644 index 0000000000..99d0850b52 --- /dev/null +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs @@ -0,0 +1,113 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +pub use aws_smithy_async::{ + test_util::tick_advance_sleep::{ + tick_advance_time_and_sleep, TickAdvanceSleep, TickAdvanceTime, + }, + time::TimeSource, +}; +pub use aws_smithy_runtime::{ + assert_str_contains, + client::{ + orchestrator::operation::Operation, + stalled_stream_protection::{ + StalledStreamProtectionInterceptor, StalledStreamProtectionInterceptorKind, + }, + }, + test_util::capture_test_logs::capture_test_logs, +}; +pub use aws_smithy_runtime_api::{ + box_error::BoxError, + client::{ + http::{ + HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, + SharedHttpConnector, + }, + interceptors::context::{Error, Output}, + orchestrator::{HttpRequest, HttpResponse, OrchestratorError}, + result::SdkError, + runtime_components::RuntimeComponents, + ser_de::DeserializeResponse, + stalled_stream_protection::StalledStreamProtectionConfig, + }, + http::{Response, StatusCode}, + shared::IntoShared, +}; +pub use aws_smithy_types::{ + body::SdkBody, error::display::DisplayErrorContext, timeout::TimeoutConfig, +}; +pub use bytes::Bytes; +pub use http_body_0_4::Body; +pub use pin_utils::pin_mut; +pub use std::{ + collections::VecDeque, + convert::Infallible, + future::poll_fn, + mem, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + time::Duration, +}; +pub use tracing::{info, Instrument as _}; + +/// No really, it's 42 bytes long... super neat +pub const NEAT_DATA: Bytes = Bytes::from_static(b"some really neat data"); + +/// Ticks time forward by the given duration, and logs the current time for debugging. +#[macro_export] +macro_rules! tick { + ($ticker:ident, $duration:expr) => { + $ticker.tick($duration).await; + let now = $ticker + .now() + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap(); + tracing::info!("ticked {:?}, now at {:?}", $duration, now); + }; +} + +#[derive(Debug)] +pub struct FakeServer(pub SharedHttpConnector); +impl HttpClient for FakeServer { + fn http_connector( + &self, + _settings: &HttpConnectorSettings, + _components: &RuntimeComponents, + ) -> SharedHttpConnector { + self.0.clone() + } +} + +struct ChannelBody { + receiver: tokio::sync::mpsc::Receiver, +} +impl http_body_0_4::Body for ChannelBody { + type Data = Bytes; + type Error = Infallible; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.receiver.poll_recv(cx) { + Poll::Ready(value) => Poll::Ready(value.map(|v| Ok(v))), + Poll::Pending => Poll::Pending, + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + unreachable!() + } +} + +pub fn channel_body() -> (SdkBody, tokio::sync::mpsc::Sender) { + let (sender, receiver) = tokio::sync::mpsc::channel(1000); + (SdkBody::from_body_0_4(ChannelBody { receiver }), sender) +} diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs index 8cc160f011..cd07ed586f 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs @@ -3,27 +3,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_async::test_util::tick_advance_sleep::tick_advance_time_and_sleep; -use aws_smithy_async::time::TimeSource; -use aws_smithy_runtime::{assert_str_contains, test_util::capture_test_logs::capture_test_logs}; -use aws_smithy_types::error::display::DisplayErrorContext; -use bytes::Bytes; use std::time::Duration; 
-/// No really, it's 42 bytes long... super neat -const NEAT_DATA: Bytes = Bytes::from_static(b"some really neat data"); - -/// Ticks time forward by the given duration, and logs the current time for debugging. -macro_rules! tick { - ($ticker:ident, $duration:expr) => { - $ticker.tick($duration).await; - let now = $ticker - .now() - .duration_since(std::time::SystemTime::UNIX_EPOCH) - .unwrap(); - tracing::info!("ticked {:?}, now at {:?}", $duration, now); - }; -} +#[macro_use] +mod stalled_stream_common; +use stalled_stream_common::*; /// Scenario: Successfully download at a rate above the minimum throughput. /// Expected: MUST NOT timeout. @@ -198,87 +182,9 @@ async fn user_downloads_data_too_slowly() { result.ok().expect("response MUST NOT timeout"); } -use test_tools::*; -mod test_tools { - use aws_smithy_async::test_util::tick_advance_sleep::{TickAdvanceSleep, TickAdvanceTime}; - use aws_smithy_async::time::TimeSource; - use aws_smithy_runtime::client::{ - orchestrator::operation::Operation, - stalled_stream_protection::{ - StalledStreamProtectionInterceptor, StalledStreamProtectionInterceptorKind, - }, - }; - use aws_smithy_runtime_api::{ - box_error::BoxError, - client::{ - http::{ - HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, - SharedHttpConnector, - }, - interceptors::context::{Error, Output}, - orchestrator::{HttpRequest, HttpResponse, OrchestratorError}, - runtime_components::RuntimeComponents, - ser_de::DeserializeResponse, - stalled_stream_protection::StalledStreamProtectionConfig, - }, - shared::IntoShared, - }; - use aws_smithy_types::{body::SdkBody, timeout::TimeoutConfig}; - use bytes::Bytes; - use http_body_0_4::Body; - use pin_utils::pin_mut; - use std::{ - convert::Infallible, - future::poll_fn, - mem, - pin::Pin, - sync::{Arc, Mutex}, - task::{Context, Poll}, - time::Duration, - }; - - #[derive(Debug)] - struct FakeServer(SharedHttpConnector); - - impl HttpClient for FakeServer { - fn http_connector( - &self, - _settings: &HttpConnectorSettings, - _components: &RuntimeComponents, - ) -> SharedHttpConnector { - self.0.clone() - } - } - - struct ChannelBody { - receiver: tokio::sync::mpsc::Receiver, - } - impl http_body_0_4::Body for ChannelBody { - type Data = Bytes; - type Error = Infallible; - - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.receiver.poll_recv(cx) { - Poll::Ready(value) => Poll::Ready(value.map(|v| Ok(v))), - Poll::Pending => Poll::Pending, - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - unreachable!() - } - } - - pub fn channel_body() -> (SdkBody, tokio::sync::mpsc::Sender) { - let (sender, receiver) = tokio::sync::mpsc::channel(1000); - (SdkBody::from_body_0_4(ChannelBody { receiver }), sender) - } +use download_test_tools::*; +mod download_test_tools { + use crate::stalled_stream_common::*; fn response(body: SdkBody) -> HttpResponse { HttpResponse::try_from(http::Response::builder().status(200).body(body).unwrap()).unwrap() diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs index 0033a8828f..44309517fd 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs @@ -3,27 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_async::time::TimeSource; -use aws_smithy_runtime::{assert_str_contains, 
test_util::capture_test_logs::capture_test_logs}; -use aws_smithy_types::error::display::DisplayErrorContext; -use bytes::Bytes; -use std::time::Duration; -use tracing::info; - -/// No really, it's 42 bytes long... super neat -const NEAT_DATA: Bytes = Bytes::from_static(b"some really neat data"); - -/// Ticks time forward by the given duration, and logs the current time for debugging. -macro_rules! tick { - ($ticker:ident, $duration:expr) => { - $ticker.tick($duration).await; - let now = $ticker - .now() - .duration_since(std::time::SystemTime::UNIX_EPOCH) - .unwrap(); - tracing::info!("ticked {:?}, now at {:?}", $duration, now); - }; -} +#[macro_use] +mod stalled_stream_common; +use stalled_stream_common::*; /// Scenario: Successful upload at a rate above the minimum throughput. /// Expected: MUST NOT timeout. @@ -94,14 +76,7 @@ async fn upload_too_slow() { drop(body_sender); }); - let err = result - .await - .expect("no panics") - .expect_err("should have timed out"); - assert_str_contains!( - DisplayErrorContext(&err).to_string(), - "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" - ); + expect_timeout(result.await.expect("no panics")); } /// Scenario: The server stops asking for data, the client maxes out its send buffer, @@ -127,14 +102,7 @@ async fn upload_stalls() { time.tick(Duration::from_secs(1)).await; }); - let err = result - .await - .expect("no panics") - .expect_err("should have timed out"); - assert_str_contains!( - DisplayErrorContext(&err).to_string(), - "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" - ); + expect_timeout(result.await.expect("no panics")); } // Scenario: The server stops asking for data, the client maxes out its send buffer, @@ -194,88 +162,9 @@ async fn user_provides_data_too_slowly() { assert_eq!(200, result.await.unwrap().expect("success").as_u16()); } -use test_tools::*; -mod test_tools { - use aws_smithy_async::test_util::tick_advance_sleep::{ - tick_advance_time_and_sleep, TickAdvanceSleep, TickAdvanceTime, - }; - use aws_smithy_async::time::TimeSource; - use aws_smithy_runtime::client::{ - orchestrator::operation::Operation, - stalled_stream_protection::{ - StalledStreamProtectionInterceptor, StalledStreamProtectionInterceptorKind, - }, - }; - use aws_smithy_runtime_api::{ - client::{ - http::{ - HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, - SharedHttpConnector, - }, - orchestrator::{HttpRequest, HttpResponse}, - runtime_components::RuntimeComponents, - stalled_stream_protection::StalledStreamProtectionConfig, - }, - http::StatusCode, - shared::IntoShared, - }; - use aws_smithy_types::{body::SdkBody, timeout::TimeoutConfig}; - use bytes::Bytes; - use http_body_0_4::Body; - use pin_utils::pin_mut; - use std::{ - collections::VecDeque, - convert::Infallible, - future::poll_fn, - mem, - pin::Pin, - task::{Context, Poll}, - time::Duration, - }; - use tracing::instrument::Instrument; - - #[derive(Debug)] - struct FakeServer(SharedHttpConnector); - - impl HttpClient for FakeServer { - fn http_connector( - &self, - _settings: &HttpConnectorSettings, - _components: &RuntimeComponents, - ) -> SharedHttpConnector { - self.0.clone() - } - } - - struct ChannelBody { - receiver: tokio::sync::mpsc::Receiver, - } - impl http_body_0_4::Body for ChannelBody { - type Data = Bytes; - type Error = Infallible; - - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.receiver.poll_recv(cx) { - Poll::Ready(value) => 
Poll::Ready(value.map(|v| Ok(v))), - Poll::Pending => Poll::Pending, - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - unreachable!() - } - } - - pub fn channel_body() -> (SdkBody, tokio::sync::mpsc::Sender) { - let (sender, receiver) = tokio::sync::mpsc::channel(1000); - (SdkBody::from_body_0_4(ChannelBody { receiver }), sender) - } +use upload_test_tools::*; +mod upload_test_tools { + use crate::stalled_stream_common::*; pub fn successful_response() -> HttpResponse { HttpResponse::try_from( @@ -421,4 +310,12 @@ mod test_tools { time_sequence.into_iter().collect() ) } + + pub fn expect_timeout(result: Result>>) { + let err = result.expect_err("should have timed out"); + assert_str_contains!( + DisplayErrorContext(&err).to_string(), + "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" + ); + } } From d57f7df4d9bc52578ffdff3ac8371b8af8afd702 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Mon, 11 Mar 2024 16:50:34 -0700 Subject: [PATCH 05/19] Rework upload/download stalled stream protection --- .../client/http/body/minimum_throughput.rs | 242 +++------ .../minimum_throughput/http_body_0_4_x.rs | 385 +++---------- .../http/body/minimum_throughput/options.rs | 48 +- .../body/minimum_throughput/throughput.rs | 514 ++++++++++++------ .../src/client/stalled_stream_protection.rs | 29 +- .../tests/stalled_stream_download.rs | 7 +- .../tests/stalled_stream_performance.rs | 4 +- .../tests/stalled_stream_upload.rs | 33 +- 8 files changed, 583 insertions(+), 679 deletions(-) diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs index 96aa9517bf..05973c336b 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs @@ -15,6 +15,7 @@ pub mod options; pub use throughput::Throughput; mod throughput; +use crate::client::http::body::minimum_throughput::throughput::ThroughputReport; use aws_smithy_async::rt::sleep::Sleep; use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep}; use aws_smithy_async::time::{SharedTimeSource, TimeSource}; @@ -40,15 +41,20 @@ use std::{ }; use throughput::ThroughputLogs; +/// Use [`MinimumThroughputDownloadBody`] instead. +#[deprecated(note = "Renamed to MinimumThroughputDownloadBody since it doesn't work for uploads")] +pub type MinimumThroughputBody = MinimumThroughputDownloadBody; + pin_project_lite::pin_project! { /// A body-wrapping type that ensures data is being streamed faster than some lower limit. /// /// If data is being streamed too slowly, this body type will emit an error next time it's polled. - pub struct MinimumThroughputBody { + pub struct MinimumThroughputDownloadBody { async_sleep: SharedAsyncSleep, time_source: SharedTimeSource, options: MinimumThroughputBodyOptions, throughput_logs: ThroughputLogs, + resolution: Duration, #[pin] sleep_fut: Option, #[pin] @@ -58,10 +64,7 @@ pin_project_lite::pin_project! { } } -const SIZE_OF_ONE_LOG: usize = std::mem::size_of::<(SystemTime, u64)>(); // 24 bytes per log -const NUMBER_OF_LOGS_IN_ONE_KB: f64 = 1024.0 / SIZE_OF_ONE_LOG as f64; - -impl MinimumThroughputBody { +impl MinimumThroughputDownloadBody { /// Create a new minimum throughput body. 
pub fn new( time_source: impl TimeSource + 'static, @@ -69,14 +72,15 @@ impl MinimumThroughputBody { body: B, options: MinimumThroughputBodyOptions, ) -> Self { + let time_source: SharedTimeSource = time_source.into_shared(); + let now = time_source.now(); + let throughput_logs = ThroughputLogs::new(options.check_window(), now); + let resolution = throughput_logs.resolution(); Self { - throughput_logs: ThroughputLogs::new( - // Never keep more than 10KB of logs in memory. This currently - // equates to 426 logs. - (NUMBER_OF_LOGS_IN_ONE_KB * 10.0) as usize, - ), + throughput_logs, + resolution, async_sleep: async_sleep.into_shared(), - time_source: time_source.into_shared(), + time_source, inner: body, sleep_fut: None, grace_period_fut: None, @@ -111,33 +115,29 @@ impl std::error::Error for Error {} /// Used to store the upload throughput in the interceptor context. #[derive(Clone, Debug)] pub(crate) struct UploadThroughput { - log: Arc>, + logs: Arc>, } impl UploadThroughput { - pub(crate) fn new() -> Self { + pub(crate) fn new(time_window: Duration, now: SystemTime) -> Self { Self { - log: Arc::new(Mutex::new(ThroughputLogs::new( - // Never keep more than 10KB of logs in memory. This currently - // equates to 426 logs. - (NUMBER_OF_LOGS_IN_ONE_KB * 10.0) as usize, - ))), + logs: Arc::new(Mutex::new(ThroughputLogs::new(time_window, now))), } } - pub(crate) fn push(&self, now: SystemTime, bytes: u64) { - self.log.lock().unwrap().push((now, bytes)); + pub(crate) fn resolution(&self) -> Duration { + self.logs.lock().unwrap().resolution() } - pub(crate) fn calculate_throughput( - &self, - now: SystemTime, - time_window: Duration, - ) -> Option { - self.log - .lock() - .unwrap() - .calculate_throughput(now, time_window) + pub(crate) fn push_pending(&self, now: SystemTime) { + self.logs.lock().unwrap().push_pending(now); + } + pub(crate) fn push_bytes_transferred(&self, now: SystemTime, bytes: u64) { + self.logs.lock().unwrap().push_bytes_transferred(now, bytes); + } + + pub(crate) fn report(&self, now: SystemTime) -> ThroughputReport { + self.logs.lock().unwrap().report(now) } } @@ -168,7 +168,51 @@ impl ThroughputReadingBody { } } +const ZERO_THROUGHPUT: Throughput = Throughput::new_bytes_per_second(0); + +// Helper trait for interpretting the throughput report. +trait UploadReport { + fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput); +} +impl UploadReport for ThroughputReport { + fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) { + let throughput = match self { + // If the report is incomplete, then we don't have enough data yet to + // decide if minimum throughput was violated. + ThroughputReport::Incomplete => { + tracing::trace!( + "not enough data to decide if minimum throughput has been violated" + ); + return (false, ZERO_THROUGHPUT); + } + // If most of the datapoints are Poll::Pending, then the user has stalled. + // In this case, we don't want to say minimum throughput was violated. + ThroughputReport::Pending => { + tracing::debug!( + "the user has stalled; this will not become a minimum throughput violation" + ); + return (false, ZERO_THROUGHPUT); + } + // If there has been no polling, then the server has stalled. Alternatively, + // if we're transferring data, but it's too slow, then we also want to say + // that the minimum throughput has been violated. 
+ ThroughputReport::NoPolling => ZERO_THROUGHPUT, + ThroughputReport::Transferred(tp) => tp, + }; + if throughput < minimum_throughput { + tracing::debug!( + "current throughput: {throughput} is below minimum: {minimum_throughput}" + ); + (true, throughput) + } else { + (false, throughput) + } + } +} + pin_project_lite::pin_project! { + /// Future that pairs with [`UploadThroughput`] to add a minimum throughput + /// requirement to a request upload stream. struct ThroughputCheckFuture { #[pin] response: HttpConnectorFuture, @@ -180,9 +224,8 @@ pin_project_lite::pin_project! { time_source: SharedTimeSource, sleep_impl: SharedAsyncSleep, upload_throughput: UploadThroughput, - minimum_throughput: Throughput, - time_window: Duration, - grace_time: Duration, + resolution: Duration, + options: MinimumThroughputBodyOptions, failing_throughput: Option, } @@ -194,20 +237,18 @@ impl ThroughputCheckFuture { time_source: SharedTimeSource, sleep_impl: SharedAsyncSleep, upload_throughput: UploadThroughput, - minimum_throughput: Throughput, - time_window: Duration, - grace_time: Duration, + options: MinimumThroughputBodyOptions, ) -> Self { + let resolution = upload_throughput.resolution(); Self { response, - check_interval: Some(sleep_impl.sleep(time_window)), + check_interval: Some(sleep_impl.sleep(resolution)), grace_period: None, time_source, sleep_impl, upload_throughput, - minimum_throughput, - time_window, - grace_time, + resolution, + options, failing_throughput: None, } } @@ -232,7 +273,7 @@ impl Future for ThroughputCheckFuture { .is_ready(); if check_interval_expired { // Set up the next check interval - *this.check_interval = Some(this.sleep_impl.sleep(*this.time_window)); + *this.check_interval = Some(this.sleep_impl.sleep(*this.resolution)); // Wake so that the check interval future gets polled // next time this poll method is called. 
If it never gets polled, @@ -243,16 +284,12 @@ impl Future for ThroughputCheckFuture { let should_check = check_interval_expired || this.grace_period.is_some(); if should_check { let now = this.time_source.now(); - let current_throughput = this - .upload_throughput - .calculate_throughput(now, *this.time_window); - below_minimum_throughput = current_throughput - .as_ref() - .map(|tp| tp < this.minimum_throughput) - .unwrap_or_default(); - tracing::debug!("current throughput: {current_throughput:?}, below minimum: {below_minimum_throughput}"); + let report = this.upload_throughput.report(now); + let (violated, current_throughput) = + report.minimum_throughput_violated(this.options.minimum_throughput()); + below_minimum_throughput = violated; if below_minimum_throughput && !this.failing_throughput.is_some() { - *this.failing_throughput = current_throughput; + *this.failing_throughput = Some(current_throughput); } else if !below_minimum_throughput { *this.failing_throughput = None; } @@ -267,10 +304,10 @@ impl Future for ThroughputCheckFuture { // Start a grace period if below minimum throughput if this.grace_period.is_none() { tracing::debug!( - grace_period=?*this.grace_time, + grace_period=?this.options.grace_period(), "upload minimum throughput below configured minimum; starting grace period" ); - *this.grace_period = Some(this.sleep_impl.sleep(*this.grace_time)); + *this.grace_period = Some(this.sleep_impl.sleep(this.options.grace_period())); } // Check the grace period if one is already set and we're not satisfied if let Some(grace_period) = this.grace_period.as_pin_mut() { @@ -278,7 +315,7 @@ impl Future for ThroughputCheckFuture { tracing::debug!("grace period ended; timing out request"); return Poll::Ready(Err(ConnectorError::timeout( Error::ThroughputBelowMinimum { - expected: *this.minimum_throughput, + expected: this.options.minimum_throughput(), actual: this .failing_throughput .expect("always set if there's a grace period"), @@ -339,9 +376,7 @@ impl MaybeThroughputCheckFuture { time_source, sleep_impl, upload_throughput, - options.minimum_throughput(), - options.check_window(), - options.grace_period(), + options, ), } } @@ -360,100 +395,3 @@ impl Future for MaybeThroughputCheckFuture { } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::{assert_str_contains, test_util::capture_test_logs::capture_test_logs}; - use aws_smithy_async::test_util::tick_advance_sleep::tick_advance_time_and_sleep; - use aws_smithy_types::{body::SdkBody, error::display::DisplayErrorContext}; - use std::future::IntoFuture; - - const TEST_TIME_WINDOW: Duration = Duration::from_secs(1); - - #[tokio::test] - async fn throughput_check_constant_rate_success() { - let minimum_throughput = Throughput::new_bytes_per_second(990); - let grace_time = Duration::from_secs(1); - let actual_throughput_bps = 1000.0; - let transfer_fn = |_| actual_throughput_bps * TEST_TIME_WINDOW.as_secs_f64(); - let result = throughput_check_test(minimum_throughput, grace_time, transfer_fn).await; - let response = result.expect("no timeout"); - assert_eq!(200, response.status().as_u16()); - } - - #[tokio::test] - async fn throughput_check_constant_rate_timeout() { - let minimum_throughput = Throughput::new_bytes_per_second(1100); - let grace_time = Duration::from_secs(1); - let actual_throughput_bps = 1000.0; - let transfer_fn = |_| actual_throughput_bps * TEST_TIME_WINDOW.as_secs_f64(); - let result = throughput_check_test(minimum_throughput, grace_time, transfer_fn).await; - let error = result.err().expect("times out"); - 
assert_str_contains!( - DisplayErrorContext(&error).to_string(), - "minimum throughput was specified at 1100 B/s, but throughput of 1000 B/s was observed" - ); - } - - #[tokio::test] - async fn throughput_check_grace_time_recovery() { - let minimum_throughput = Throughput::new_bytes_per_second(1000); - let grace_time = Duration::from_secs(3); - let actual_throughput_bps = 1000.0; - let transfer_fn = |window| { - if window <= 5 || window > 7 { - actual_throughput_bps * TEST_TIME_WINDOW.as_secs_f64() - } else { - 0.0 - } - }; - let result = throughput_check_test(minimum_throughput, grace_time, transfer_fn).await; - let response = result.expect("no timeout"); - assert_eq!(200, response.status().as_u16()); - } - - async fn throughput_check_test( - minimum_throughput: Throughput, - grace_time: Duration, - transfer_fn: F, - ) -> Result - where - F: Fn(u64) -> f64, - { - let _logs = capture_test_logs(); - let (time_source, sleep_impl) = tick_advance_time_and_sleep(); - - let response = HttpResponse::try_from( - http::Response::builder() - .status(200) - .body(SdkBody::empty()) - .unwrap(), - ) - .unwrap(); - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - - let upload_throughput = UploadThroughput::new(); - let check_task = tokio::spawn(ThroughputCheckFuture::new( - HttpConnectorFuture::new(async move { Ok(response_rx.into_future().await.unwrap()) }), - time_source.clone().into_shared(), - sleep_impl.into_shared(), - upload_throughput.clone(), - minimum_throughput, - TEST_TIME_WINDOW, - grace_time, - )); - - // simulate 20 check time windows at `actual_throughput` bytes/sec - for window in 0..20 { - let bytes = (transfer_fn)(window); - upload_throughput.push(time_source.now(), (bytes + 0.5) as u64); - time_source.tick(TEST_TIME_WINDOW).await; - println!("window {window}"); - } - let _ = response_tx.send(response); - println!("upload finished"); - - check_task.await.expect("no panic") - } -} diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs index 85b73d062d..f89b009eb8 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs @@ -3,15 +3,58 @@ * SPDX-License-Identifier: Apache-2.0 */ -use super::{BoxError, Error, MinimumThroughputBody}; -use crate::client::http::body::minimum_throughput::ThroughputReadingBody; +use super::{BoxError, Error, MinimumThroughputDownloadBody}; +use crate::client::http::body::minimum_throughput::{ + throughput::ThroughputReport, Throughput, ThroughputReadingBody, +}; use aws_smithy_async::rt::sleep::AsyncSleep; use http_body_0_4::Body; use std::future::Future; use std::pin::{pin, Pin}; use std::task::{Context, Poll}; -impl Body for MinimumThroughputBody +const ZERO_THROUGHPUT: Throughput = Throughput::new_bytes_per_second(0); + +// Helper trait for interpretting the throughput report. +trait DownloadReport { + fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput); +} +impl DownloadReport for ThroughputReport { + fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) { + let throughput = match self { + // If the report is incomplete, then we don't have enough data yet to + // decide if minimum throughput was violated. 
+ ThroughputReport::Incomplete => { + tracing::trace!( + "not enough data to decide if minimum throughput has been violated" + ); + return (false, ZERO_THROUGHPUT); + } + // If no polling is taking place, then the user has stalled. + // In this case, we don't want to say minimum throughput was violated. + ThroughputReport::NoPolling => { + tracing::debug!( + "the user has stalled; this will not become a minimum throughput violation" + ); + return (false, ZERO_THROUGHPUT); + } + // If we're stuck in Poll::Pending, then the server has stalled. Alternatively, + // if we're transferring data, but it's too slow, then we also want to say + // that the minimum throughput has been violated. + ThroughputReport::Pending => ZERO_THROUGHPUT, + ThroughputReport::Transferred(tp) => tp, + }; + let violated = throughput < minimum_throughput; + if violated { + tracing::debug!( + "current throughput: {throughput} is below minimum: {minimum_throughput}" + ); + } + (violated, throughput) + } +} + +impl Body for MinimumThroughputDownloadBody where B: Body, { @@ -31,12 +74,13 @@ where let poll_res = match this.inner.poll_data(cx) { Poll::Ready(Some(Ok(bytes))) => { tracing::trace!("received data: {}", bytes.len()); - this.throughput_logs.push((now, bytes.len() as u64)); + this.throughput_logs + .push_bytes_transferred(now, bytes.len() as u64); Poll::Ready(Some(Ok(bytes))) } Poll::Pending => { tracing::trace!("received poll pending"); - this.throughput_logs.push((now, 0)); + this.throughput_logs.push_pending(now); Poll::Pending } // If we've read all the data or an error occurred, then return that result. @@ -47,44 +91,27 @@ where let mut sleep_fut = this .sleep_fut .take() - .unwrap_or_else(|| this.async_sleep.sleep(this.options.check_interval())); + .unwrap_or_else(|| this.async_sleep.sleep(*this.resolution)); if let Poll::Ready(()) = pin!(&mut sleep_fut).poll(cx) { tracing::trace!("sleep future triggered—triggering a wakeup"); // Whenever the sleep future expires, we replace it. - sleep_fut = this.async_sleep.sleep(this.options.check_interval()); + sleep_fut = this.async_sleep.sleep(*this.resolution); // We also schedule a wake up for current task to ensure that // it gets polled at least one more time. cx.waker().wake_by_ref(); }; this.sleep_fut.replace(sleep_fut); - let calculated_tpt = match this - .throughput_logs - .calculate_throughput(now, this.options.check_window()) - { - Some(tpt) => tpt, - None => { - tracing::trace!("calculated throughput is None!"); - return poll_res; - } - }; - tracing::trace!( - "calculated throughput {:?} (window: {:?})", - calculated_tpt, - this.options.check_window() - ); // Calculate the current throughput and emit an error if it's too low and // the grace period has elapsed. - let is_below_minimum_throughput = calculated_tpt <= this.options.minimum_throughput(); - if is_below_minimum_throughput { - // Check the grace period future to see if it needs creating. - tracing::trace!( - in_grace_period = this.grace_period_fut.is_some(), - observed_throughput = ?calculated_tpt, - minimum_throughput = ?this.options.minimum_throughput(), - "below minimum throughput" - ); + let report = this.throughput_logs.report(now); + let (violated, current_throughput) = + report.minimum_throughput_violated(this.options.minimum_throughput()); + if violated { + if this.grace_period_fut.is_none() { + tracing::debug!("entering minimum throughput grace period"); + } let mut grace_period_fut = this .grace_period_fut .take() @@ -93,13 +120,16 @@ where // The grace period has ended! 
return Poll::Ready(Some(Err(Box::new(Error::ThroughputBelowMinimum { expected: self.options.minimum_throughput(), - actual: calculated_tpt, + actual: current_throughput, })))); }; this.grace_period_fut.replace(grace_period_fut); } else { // Ensure we don't have an active grace period future if we're not // currently below the minimum throughput. + if this.grace_period_fut.is_some() { + tracing::debug!("throughput recovered; exiting grace period"); + } let _ = this.grace_period_fut.take(); } @@ -135,12 +165,13 @@ where match this.inner.poll_data(cx) { Poll::Ready(Some(Ok(bytes))) => { tracing::trace!("received data: {}", bytes.len()); - this.throughput.push(now, bytes.len() as u64); + this.throughput + .push_bytes_transferred(now, bytes.len() as u64); Poll::Ready(Some(Ok(bytes))) } Poll::Pending => { tracing::trace!("received poll pending"); - this.throughput.push(now, 0); + this.throughput.push_pending(now); Poll::Pending } // If we've read all the data or an error occurred, then return that result. @@ -156,289 +187,3 @@ where this.inner.poll_trailers(cx) } } - -// These tests use `hyper::body::Body::wrap_stream` -#[cfg(all(test, feature = "connector-hyper-0-14-x", feature = "test-util"))] -mod test { - use super::{super::Throughput, Error, MinimumThroughputBody}; - use crate::client::http::body::minimum_throughput::options::MinimumThroughputBodyOptions; - use crate::test_util::capture_test_logs::capture_test_logs; - use aws_smithy_async::rt::sleep::AsyncSleep; - use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep, ManualTimeSource}; - use aws_smithy_types::body::SdkBody; - use aws_smithy_types::byte_stream::{AggregatedBytes, ByteStream}; - use aws_smithy_types::error::display::DisplayErrorContext; - use bytes::{BufMut, Bytes, BytesMut}; - use http::HeaderMap; - use http_body_0_4::Body; - use once_cell::sync::Lazy; - use pretty_assertions::assert_eq; - use std::convert::Infallible; - use std::error::Error as StdError; - use std::future::{poll_fn, Future}; - use std::pin::{pin, Pin}; - use std::task::{Context, Poll}; - use std::time::{Duration, UNIX_EPOCH}; - - struct NeverBody; - - impl Body for NeverBody { - type Data = Bytes; - type Error = Box<(dyn StdError + Send + Sync + 'static)>; - - fn poll_data( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll>> { - Poll::Pending - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - unreachable!("body can't be read, so this won't be called") - } - } - - #[tokio::test()] - async fn test_self_waking() { - let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH); - let mut body = MinimumThroughputBody::new( - time_source.clone(), - async_sleep.clone(), - NeverBody, - Default::default(), - ); - time_source.advance(Duration::from_secs(1)); - let actual_err = body.data().await.expect("next chunk exists").unwrap_err(); - let expected_err = Error::ThroughputBelowMinimum { - expected: (1, Duration::from_secs(1)).into(), - actual: (0, Duration::from_secs(1)).into(), - }; - - assert_eq!(expected_err.to_string(), actual_err.to_string()); - } - - fn create_test_stream( - async_sleep: impl AsyncSleep + Clone, - ) -> impl futures_util::Stream> { - futures_util::stream::unfold(1, move |state| { - let async_sleep = async_sleep.clone(); - async move { - if state > 255 { - None - } else { - async_sleep.sleep(Duration::from_secs(1)).await; - Some(( - Result::<_, Infallible>::Ok(Bytes::from_static(b"00000000")), - state + 1, - )) - } - } - }) - } - - static EXPECTED_BYTES: 
Lazy> = - Lazy::new(|| (1..=255).flat_map(|_| b"00000000").copied().collect()); - - fn eight_byte_per_second_stream_with_minimum_throughput_timeout( - minimum_throughput: Throughput, - ) -> ( - impl Future>, - ManualTimeSource, - InstantSleep, - ) { - let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH); - let time_clone = time_source.clone(); - - // Will send ~8 bytes per second. - let stream = create_test_stream(async_sleep.clone()); - let body = ByteStream::new(SdkBody::from_body_0_4(hyper_0_14::body::Body::wrap_stream( - stream, - ))); - let body = body.map(move |body| { - let time_source = time_clone.clone(); - // We don't want to log these sleeps because it would duplicate - // the `sleep` calls being logged by the MTB - let async_sleep = InstantSleep::unlogged(); - SdkBody::from_body_0_4(MinimumThroughputBody::new( - time_source, - async_sleep, - body, - MinimumThroughputBodyOptions::builder() - .minimum_throughput(minimum_throughput) - .build(), - )) - }); - - (body.collect(), time_source, async_sleep) - } - - async fn expect_error(minimum_throughput: Throughput) { - let (res, ..) = - eight_byte_per_second_stream_with_minimum_throughput_timeout(minimum_throughput); - let expected_err = Error::ThroughputBelowMinimum { - expected: minimum_throughput, - actual: Throughput::new(8, Duration::from_secs(1)), - }; - match res.await { - Ok(_) => { - panic!( - "response succeeded instead of returning the expected error '{expected_err}'" - ) - } - Err(actual_err) => { - assert_eq!( - expected_err.to_string(), - // We need to source this so that we don't get the streaming error it's wrapped in. - actual_err.source().unwrap().to_string() - ); - } - } - } - - #[tokio::test] - async fn test_throughput_timeout_less_than() { - let minimum_throughput = Throughput::new_bytes_per_second(9); - expect_error(minimum_throughput).await; - } - - async fn expect_success(minimum_throughput: Throughput) { - let (res, time_source, async_sleep) = - eight_byte_per_second_stream_with_minimum_throughput_timeout(minimum_throughput); - match res.await { - Ok(res) => { - assert_eq!(255.0, time_source.seconds_since_unix_epoch()); - assert_eq!(Duration::from_secs(255), async_sleep.total_duration()); - assert_eq!(*EXPECTED_BYTES, res.to_vec()); - } - Err(err) => panic!("{}", DisplayErrorContext(err.source().unwrap())), - } - } - - #[tokio::test] - async fn test_throughput_timeout_equal_to() { - let (_guard, _) = capture_test_logs(); - // a tiny bit less. To capture 0-throughput properly, we need to allow 0 to be 0 - let minimum_throughput = Throughput::new(31, Duration::from_secs(4)); - expect_success(minimum_throughput).await; - } - - #[tokio::test] - async fn test_throughput_timeout_greater_than() { - let minimum_throughput = Throughput::new(20, Duration::from_secs(3)); - expect_success(minimum_throughput).await; - } - - // A multiplier for the sine wave amplitude; Chosen arbitrarily. 
- const BYTE_COUNT_UPPER_LIMIT: u64 = 1000; - - /// emits 1000B/S for 5 seconds then suddenly stops - fn sudden_stop( - async_sleep: impl AsyncSleep + Clone, - ) -> impl futures_util::Stream> { - let sleep_dur = Duration::from_millis(50); - fastrand::seed(0); - futures_util::stream::unfold(1, move |i| { - let async_sleep = async_sleep.clone(); - async move { - let number_seconds = (i * sleep_dur).as_secs_f64(); - async_sleep.sleep(sleep_dur).await; - if number_seconds > 5.0 { - Some((Result::::Ok(Bytes::new()), i + 1)) - } else { - let mut bytes = BytesMut::new(); - let bytes_per_segment = - (BYTE_COUNT_UPPER_LIMIT as f64) * sleep_dur.as_secs_f64(); - for _ in 0..bytes_per_segment as usize { - bytes.put_u8(0) - } - - Some((Result::::Ok(bytes.into()), i + 1)) - } - } - }) - } - - #[tokio::test] - async fn test_stalled_stream_detection() { - test_suddenly_stopping_stream(0, Duration::from_secs(6)).await - } - - #[tokio::test] - async fn test_slow_stream_detection() { - test_suddenly_stopping_stream(BYTE_COUNT_UPPER_LIMIT / 2, Duration::from_secs_f64(5.50)) - .await - } - - #[tokio::test] - async fn test_check_interval() { - let (_guard, _) = capture_test_logs(); - let (ts, sleep) = instant_time_and_sleep(UNIX_EPOCH); - let mut body = MinimumThroughputBody::new( - ts, - sleep.clone(), - NeverBody, - MinimumThroughputBodyOptions::builder() - .check_interval(Duration::from_millis(1234)) - .grace_period(Duration::from_millis(456)) - .build(), - ); - let mut body = pin!(body); - let _ = poll_fn(|cx| body.as_mut().poll_data(cx)).await; - assert_eq!( - sleep.logs(), - vec![ - // sleep, by second sleep we know we have no data, then the grace period - Duration::from_millis(1234), - Duration::from_millis(1234), - Duration::from_millis(456) - ] - ); - } - - async fn test_suddenly_stopping_stream(throughput_limit: u64, time_until_timeout: Duration) { - let (_guard, _) = capture_test_logs(); - let options = MinimumThroughputBodyOptions::builder() - // Minimum throughput per second will be approx. half of the BYTE_COUNT_UPPER_LIMIT. 
- .minimum_throughput(Throughput::new_bytes_per_second(throughput_limit)) - .build(); - let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH); - let time_clone = time_source.clone(); - - let stream = sudden_stop(async_sleep.clone()); - let body = ByteStream::new(SdkBody::from_body_0_4(hyper_0_14::body::Body::wrap_stream( - stream, - ))); - let res = body - .map(move |body| { - let time_source = time_clone.clone(); - // We don't want to log these sleeps because it would duplicate - // the `sleep` calls being logged by the MTB - let async_sleep = InstantSleep::unlogged(); - SdkBody::from_body_0_4(MinimumThroughputBody::new( - time_source, - async_sleep, - body, - options.clone(), - )) - }) - .collect(); - - match res.await { - Ok(_res) => { - panic!("stream should have timed out"); - } - Err(err) => { - dbg!(err); - assert_eq!( - async_sleep.total_duration(), - time_until_timeout, - "With throughput limit {:?} expected timeout after {:?} (stream starts sending 0's at 5 seconds.", - throughput_limit, time_until_timeout - ); - } - } - } -} diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs index 4c8fc1177b..79d2ec5063 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs @@ -12,6 +12,7 @@ use std::time::Duration; pub struct MinimumThroughputBodyOptions { /// The minimum throughput that is acceptable. minimum_throughput: Throughput, + /// The 'grace period' after which the minimum throughput will be enforced. /// /// If this is set to 0, the minimum throughput will be enforced immediately. @@ -24,9 +25,6 @@ pub struct MinimumThroughputBodyOptions { /// stream-startup. grace_period: Duration, - /// The interval at which the throughput is checked. - check_interval: Duration, - /// The period of time to consider when computing the throughput /// /// This SHOULD be longer than the check interval, or stuck-streams may evade detection. @@ -44,7 +42,6 @@ impl MinimumThroughputBodyOptions { MinimumThroughputBodyOptionsBuilder::new() .minimum_throughput(self.minimum_throughput) .grace_period(self.grace_period) - .check_interval(self.check_interval) } /// The throughput check grace period. @@ -65,12 +62,10 @@ impl MinimumThroughputBodyOptions { self.check_window } - /// The rate at which the throughput is checked. - /// - /// The actual rate throughput is checked may be higher than this value, - /// but it will never be lower. + /// Not used. Always returns `Duration::from_millis(500)`. + #[deprecated(note = "No longer used. 
Always returns Duration::from_millis(500)")] pub fn check_interval(&self) -> Duration { - self.check_interval + Duration::from_millis(500) } } @@ -79,7 +74,6 @@ impl Default for MinimumThroughputBodyOptions { Self { minimum_throughput: DEFAULT_MINIMUM_THROUGHPUT, grace_period: DEFAULT_GRACE_PERIOD, - check_interval: DEFAULT_CHECK_INTERVAL, check_window: DEFAULT_CHECK_WINDOW, } } @@ -89,11 +83,10 @@ impl Default for MinimumThroughputBodyOptions { #[derive(Debug, Default, Clone)] pub struct MinimumThroughputBodyOptionsBuilder { minimum_throughput: Option, - check_interval: Option, + check_window: Option, grace_period: Option, } -const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_millis(500); const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(0); const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput { bytes_read: 1, @@ -136,19 +129,26 @@ impl MinimumThroughputBodyOptionsBuilder { self } - /// Set the rate at which throughput is checked. - /// - /// Defaults to 1 second. - pub fn check_interval(mut self, check_interval: Duration) -> Self { - self.set_check_interval(Some(check_interval)); + /// No longer used. + #[deprecated(note = "No longer used.")] + pub fn check_interval(self, _check_interval: Duration) -> Self { self } - /// Set the rate at which throughput is checked. - /// - /// Defaults to 1 second. - pub fn set_check_interval(&mut self, check_interval: Option) -> &mut Self { - self.check_interval = check_interval; + /// No longer used. + #[deprecated(note = "No longer used.")] + pub fn set_check_interval(&mut self, _check_interval: Option) -> &mut Self { + self + } + + #[allow(unused)] + pub(crate) fn check_window(mut self, check_window: Duration) -> Self { + self.set_check_window(Some(check_window)); + self + } + #[allow(unused)] + pub(crate) fn set_check_window(&mut self, check_window: Option) -> &mut Self { + self.check_window = check_window; self } @@ -161,8 +161,7 @@ impl MinimumThroughputBodyOptionsBuilder { minimum_throughput: self .minimum_throughput .unwrap_or(DEFAULT_MINIMUM_THROUGHPUT), - check_interval: self.check_interval.unwrap_or(DEFAULT_CHECK_INTERVAL), - check_window: DEFAULT_CHECK_WINDOW, + check_window: self.check_window.unwrap_or(DEFAULT_CHECK_WINDOW), } } } @@ -172,7 +171,6 @@ impl From for MinimumThroughputBodyOptions { MinimumThroughputBodyOptions { grace_period: value.grace_period(), minimum_throughput: DEFAULT_MINIMUM_THROUGHPUT, - check_interval: DEFAULT_CHECK_INTERVAL, check_window: DEFAULT_CHECK_WINDOW, } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs index f5c759caa4..cd4f5f5a61 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs @@ -3,12 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -use std::collections::VecDeque; use std::fmt; use std::time::{Duration, SystemTime}; /// Throughput representation for use when configuring [`super::MinimumThroughputBody`] #[derive(Debug, Clone, Copy)] +#[cfg_attr(test, derive(Eq))] pub struct Throughput { pub(super) bytes_read: u64, pub(super) per_time_elapsed: Duration, @@ -29,7 +29,7 @@ impl Throughput { } /// Create a new throughput in bytes per second. 
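A side note on the constructor changes in this hunk: `new_bytes_per_second`, `new_kilobytes_per_second`, and `new_megabytes_per_second` become `const fn`, which is what lets the report helpers elsewhere in this patch declare `const ZERO_THROUGHPUT: Throughput = Throughput::new_bytes_per_second(0);`. A minimal sketch of that pattern, assuming `Throughput` and its `PartialOrd` impl are visible to the caller; the `FLOOR` constant and `is_too_slow` helper below are invented purely for illustration:

use std::time::Duration;

// A throughput floor declared at compile time, possible because the
// per-unit constructors are `const fn` after this patch.
const FLOOR: Throughput = Throughput::new_kilobytes_per_second(1);

// Throughput comparisons are rate-based (1 B over 1 s equals 25 B over 25 s),
// so an observed value can be checked directly against the floor.
fn is_too_slow(observed: Throughput) -> bool {
    observed < FLOOR
}

fn demo() {
    // 100 B/s is below the 1 KB/s floor.
    assert!(is_too_slow(Throughput::new(100, Duration::from_secs(1))));
    // 1 MB/s is comfortably above it.
    assert!(!is_too_slow(Throughput::new_megabytes_per_second(1)));
}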
- pub fn new_bytes_per_second(bytes: u64) -> Self { + pub const fn new_bytes_per_second(bytes: u64) -> Self { Self { bytes_read: bytes, per_time_elapsed: Duration::from_secs(1), @@ -37,7 +37,7 @@ impl Throughput { } /// Create a new throughput in kilobytes per second. - pub fn new_kilobytes_per_second(kilobytes: u64) -> Self { + pub const fn new_kilobytes_per_second(kilobytes: u64) -> Self { Self { bytes_read: kilobytes * 1000, per_time_elapsed: Duration::from_secs(1), @@ -45,7 +45,7 @@ impl Throughput { } /// Create a new throughput in megabytes per second. - pub fn new_megabytes_per_second(megabytes: u64) -> Self { + pub const fn new_megabytes_per_second(megabytes: u64) -> Self { Self { bytes_read: megabytes * 1000 * 1000, per_time_elapsed: Duration::from_secs(1), @@ -97,83 +97,264 @@ impl From<(u64, Duration)> for Throughput { } } -#[derive(Clone, Debug)] -pub(super) struct ThroughputLogs { - max_length: usize, - inner: VecDeque<(SystemTime, u64)>, - bytes_processed: u64, +/// Cell in a linear grid that represents a small chunk of time. +#[derive(Copy, Clone, Debug)] +enum Cell { + /// There is no data in this cell. + Empty, + + /// No polling took place during this cell. + NoPolling, + + /// This many bytes were transferred during this cell. + TransferredBytes(u64), + + /// The user/remote was not providing/consuming data fast enough during this cell. + /// + /// The number is the number of bytes transferred, if this replaced TransferredBytes. + Pending(u64), } +impl Cell { + fn merge(&mut self, other: Cell) { + use Cell::*; + // Assign values based on this priority order (highest priority higher up): + // 1. Pending + // 2. TransferredBytes + // 3. NoPolling + // 4. Empty + *self = match (other, *self) { + (Pending(other), this) => Pending(other + this.bytes()), + (TransferredBytes(other), this) => TransferredBytes(other + this.bytes()), + (other, NoPolling) => other, + (NoPolling, _) => panic!("can't merge NoPolling into cell"), + (other, Empty) => other, + (Empty, _) => panic!("can't merge Empty into cell"), + }; + } -impl ThroughputLogs { - pub(super) fn new(max_length: usize) -> Self { - Self { - inner: VecDeque::with_capacity(max_length), - max_length, - bytes_processed: 0, + /// Number of bytes transferred during this cell + fn bytes(&self) -> u64 { + match *self { + Cell::Empty | Cell::NoPolling => 0, + Cell::TransferredBytes(bytes) | Cell::Pending(bytes) => bytes, } } +} + +#[derive(Copy, Clone, Debug, Default)] +struct CellCounts { + /// Number of cells with no data. + empty: usize, + /// Number of "no polling" cells. + no_polling: usize, + /// Number of "bytes transferred" cells. + transferred: usize, + /// Number of "pending" cells. + pending: usize, +} - pub(super) fn push(&mut self, throughput: (SystemTime, u64)) { - // When the number of logs exceeds the max length, toss the oldest log. - if self.inner.len() == self.max_length { - self.bytes_processed -= self.inner.pop_front().map(|(_, sz)| sz).unwrap_or_default(); +/// Underlying stack-allocated linear grid buffer for tracking +/// throughput events for [`ThroughputLogs`]. +#[derive(Copy, Clone, Debug)] +struct LogBuffer { + entries: [Cell; N], + // The length only needs to exist so that the `fill_gaps` function + // can differentiate between `Empty` due to there not having been enough + // time to establish a full buffer worth of data vs. `Empty` due to a + // polling gap. Once the length reaches N, it will never change again. 
+ length: usize, +} +impl LogBuffer { + fn new() -> Self { + Self { + entries: [Cell::Empty; N], + length: 0, } - debug_assert!(self.inner.capacity() > self.inner.len()); - self.bytes_processed += throughput.1; - self.inner.push_back(throughput); - } - - pub(super) fn calculate_throughput( - &self, - now: SystemTime, - time_window: Duration, - ) -> Option { - if self.inner.is_empty() { - return None; + } + + /// Mutably returns the tail of the buffer. + /// + /// The buffer MUST have at least one cell it it before this is called. + fn tail_mut(&mut self) -> &mut Cell { + debug_assert!(self.length > 0); + &mut self.entries[self.length - 1] + } + + /// Pushes a cell into the buffer. If the buffer is already full, + /// then this will rotate the entire buffer to the left. + fn push(&mut self, cell: Cell) { + if self.filled() { + self.entries.rotate_left(1); + self.entries[N - 1] = cell; + } else { + self.entries[self.length] = cell; + self.length += 1; } - if let Some(first_time) = self.inner.front().map(|e| e.0) { - if first_time + time_window >= now && first_time < now { - // If the first logged time fits within the time window, then short-circuit - return Some(Throughput { - bytes_read: self.bytes_processed, - per_time_elapsed: now.duration_since(first_time).expect("checked above"), - }); + } + + /// Returns the total number of bytes transferred within the time window. + fn bytes_transferred(&self) -> u64 { + self.entries.iter().take(self.length).map(Cell::bytes).sum() + } + + #[inline] + fn filled(&self) -> bool { + self.length == N + } + + /// Fills in missing NoData entries. + /// + /// We want NoData entries to represent when a future hasn't been polled. + /// Since the future is in charge of logging in the first place, the only + /// way we can know about these is by examining gaps in time. + fn fill_gaps(&mut self) { + for entry in self.entries.iter_mut().take(self.length) { + if matches!(entry, Cell::Empty) { + *entry = Cell::NoPolling; } } - if let Some(last_time) = self.inner.back().map(|e| e.0) { - if last_time + time_window < now { - // If we have no log entries at all within the time window, then short-circuit to 0 - return Some(Throughput { - bytes_read: 0, - per_time_elapsed: time_window, - }); + } + + /// Returns the counts of each cell type in the buffer. + fn counts(&self) -> CellCounts { + let mut counts = CellCounts::default(); + for entry in &self.entries { + match entry { + Cell::Empty => counts.empty += 1, + Cell::NoPolling => counts.no_polling += 1, + Cell::TransferredBytes(_) => counts.transferred += 1, + Cell::Pending(_) => counts.pending += 1, } } + counts + } +} + +/// Report/summary of all the events in a time window. +#[cfg_attr(test, derive(Debug, Eq, PartialEq))] +pub(crate) enum ThroughputReport { + /// Not enough data to draw any conclusions. This happens early in a request/response. + Incomplete, + /// The stream hasn't been polled for most of this time window. + NoPolling, + /// The stream has been waiting for most of the time window. + Pending, + /// The stream transferred this amount of throughput during the time window. + Transferred(Throughput), +} + +const BUFFER_SIZE: usize = 10; + +/// Log of throughput in a request or response stream. +/// +/// Used to determine if a configured minimum throughput is being met or not +/// so that a request or response stream can be timed out in the event of a +/// stall. +/// +/// Request/response streams push data transfer or pending events to this log +/// based on what's going on in their poll functions. 
The log tracks three kinds +/// of events despite only receiving two: the third is "no polling". The poll +/// functions cannot know when they're not being polled, so the log examines gaps +/// in the event history to know when no polling took place. +/// +/// The event logging is simplified down to a linear 10-cell grid, which each +/// cell representing 1/10th the total time window. When an event is pushed, +/// it is either merged into the current tail cell, or all the cells are rotated +/// left to create a new empty tail cell, and then it is merged into that one. +#[derive(Clone, Debug)] +pub(super) struct ThroughputLogs { + resolution: Duration, + current_tail: SystemTime, + buffer: LogBuffer, +} + +impl ThroughputLogs { + /// Creates a new log starting at `now` with the given `time_window`. + /// + /// Note: the `time_window` gets divided by 10 to create smaller sub-windows + /// to track throughput. The time window should configured to be large enough + /// so that these sub-windows aren't too small for network-based events. + /// A time window of 10ms probably won't work, but 500ms might. The default + /// is one second. + pub(super) fn new(time_window: Duration, now: SystemTime) -> Self { + assert!(!time_window.is_zero()); + let resolution = time_window.div_f64(BUFFER_SIZE as f64); + Self { + resolution, + current_tail: now, + buffer: LogBuffer::new(), + } + } - let minimum_ts = now - time_window; - let first_item = self.inner.iter().find(|(ts, _)| *ts >= minimum_ts)?.0; + /// Returns the resolution at which events are logged at. + pub(super) fn resolution(&self) -> Duration { + self.resolution + } - let time_elapsed = now.duration_since(first_item).unwrap_or_default(); + /// Pushes a "pending" event. + /// + /// Pending indicates the streaming future is waiting for something. + /// In an upload, it is waiting for data from the user, and in a download, + /// it is waiting for data from the server. + pub(super) fn push_pending(&mut self, time: SystemTime) { + self.push(time, Cell::Pending(0)); + } - let total_bytes_logged = self - .inner - .iter() - .rev() - .take_while(|(ts, _)| *ts >= minimum_ts) - .map(|t| t.1) - .sum::(); + /// Pushes a data transferred event. + /// + /// Indicates that this number of bytes were transferred at this time. + pub(super) fn push_bytes_transferred(&mut self, time: SystemTime, bytes: u64) { + self.push(time, Cell::TransferredBytes(bytes)); + } - Some(Throughput { - bytes_read: total_bytes_logged, - per_time_elapsed: time_elapsed, - }) + fn push(&mut self, now: SystemTime, value: Cell) { + self.catch_up(now); + self.buffer.tail_mut().merge(value); + self.buffer.fill_gaps(); + } + + /// Pushes empty cells until `current_tail` is caught up to `now`. + fn catch_up(&mut self, now: SystemTime) { + while now >= self.current_tail { + self.current_tail += self.resolution; + self.buffer.push(Cell::Empty); + } + assert!(self.current_tail >= now); + } + + /// Generates an overall report of the time window. 
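To make the grid mechanics described above concrete, here is a unit-test-style sketch in the spirit of the tests at the bottom of this file. It assumes module-internal access (the log type is `pub(super)`) and the test name is invented: after a healthy second of transfers the server goes quiet, every poll logs a pending event, and once the one-second window holds only pending cells the report comes back as `Pending`, which the download-side interpretation treats as a stalled server.

#[test]
fn server_goes_quiet_reads_as_pending() {
    use std::time::{Duration, SystemTime};

    let start = SystemTime::UNIX_EPOCH;
    let mut logs = ThroughputLogs::new(Duration::from_secs(1), start);
    let tick = logs.resolution(); // 1 s window / 10 cells = 100 ms per cell

    // First second: data flows, so one cell per tick is marked as transferred.
    let mut now = start;
    for _ in 0..10 {
        logs.push_bytes_transferred(now, 1024);
        now += tick;
    }

    // Second second: the server stops sending, so every poll records Pending.
    // Each push rotates one of the transferred cells out of the full buffer.
    for _ in 0..10 {
        logs.push_pending(now);
        now += tick;
    }

    // The window now holds only pending cells (plus at most one missed-polling
    // cell from catching up), so the report is Pending rather than Transferred.
    assert_eq!(ThroughputReport::Pending, logs.report(now));
}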
+ pub(super) fn report(&mut self, now: SystemTime) -> ThroughputReport { + self.catch_up(now); + self.buffer.fill_gaps(); + + let CellCounts { + empty, + no_polling, + transferred, + pending, + } = self.buffer.counts(); + + let bytes = self.buffer.bytes_transferred(); + let time = self.resolution * (BUFFER_SIZE - empty) as u32; + let throughput = Throughput::new(bytes, time); + + let half = BUFFER_SIZE / 2; + if empty >= half { + return ThroughputReport::Incomplete; + } + match (transferred > 0, no_polling >= half, pending >= half) { + (true, _, _) => ThroughputReport::Transferred(throughput), + (_, true, _) => ThroughputReport::NoPolling, + (_, _, true) => ThroughputReport::Pending, + _ => ThroughputReport::Incomplete, + } } } #[cfg(test)] mod test { - use super::{Throughput, ThroughputLogs}; - use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use super::*; + use std::time::Duration; #[test] fn test_throughput_eq() { @@ -185,132 +366,143 @@ mod test { assert_eq!(t2, t3); } - fn build_throughput_log( - length: u32, - tick_duration: Duration, - rate: u64, - ) -> (ThroughputLogs, SystemTime) { - let mut throughput_logs = ThroughputLogs::new(length as usize); - for i in 1..=length { - throughput_logs.push((UNIX_EPOCH + (tick_duration * i), rate)); - } - - assert_eq!(length as usize, throughput_logs.inner.len()); - (throughput_logs, UNIX_EPOCH + (tick_duration * length)) + #[test] + fn incomplete_no_entries() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + let report = logs.report(start); + assert_eq!(ThroughputReport::Incomplete, report); } - const EPSILON: f64 = 0.001; - macro_rules! assert_delta { - ($x:expr, $y:expr, $d:expr) => { - if !(($x as f64) - $y < $d || $y - ($x as f64) < $d) { - panic!(); - } - }; + #[test] + fn incomplete_with_entries() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + logs.push_pending(start); + + let report = logs.report(start + Duration::from_millis(300)); + assert_eq!(ThroughputReport::Incomplete, report); } #[test] - fn test_throughput_log_calculate_throughput_1() { - let (throughput_logs, now) = build_throughput_log(1000, Duration::from_secs(1), 1); - - for dur in [10, 100, 100] { - let throughput = throughput_logs - .calculate_throughput(now, Duration::from_secs(dur)) - .unwrap(); - assert_eq!(1.0, throughput.bytes_per_second()); - } - let throughput = throughput_logs - .calculate_throughput(now, Duration::from_secs_f64(101.5)) - .unwrap(); - assert_delta!(1, throughput.bytes_per_second(), EPSILON); + fn incomplete_with_transferred() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + logs.push_pending(start); + logs.push_bytes_transferred(start + Duration::from_millis(100), 10); + + let report = logs.report(start + Duration::from_millis(300)); + assert_eq!(ThroughputReport::Incomplete, report); } #[test] - fn test_throughput_log_calculate_throughput_2() { - let (throughput_logs, now) = build_throughput_log(1000, Duration::from_secs(5), 5); + fn push_pending_at_the_beginning_of_each_tick() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + + let mut now = start; + for i in 1..=BUFFER_SIZE { + logs.push_pending(now); + now += logs.resolution(); + + assert_eq!(i, logs.buffer.counts().pending); + } - let throughput = throughput_logs - .calculate_throughput(now, Duration::from_secs(1000)) - .unwrap(); - 
assert_eq!(1.0, throughput.bytes_per_second()); + let report = dbg!(&mut logs).report(now); + assert_eq!(ThroughputReport::Pending, report); } #[test] - fn test_throughput_log_calculate_throughput_3() { - let (throughput_logs, now) = build_throughput_log(1000, Duration::from_millis(200), 1024); + fn push_pending_at_the_end_of_each_tick() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + + let mut now = start; + for i in 1..BUFFER_SIZE { + now += logs.resolution(); + logs.push_pending(now); + + assert_eq!(i, dbg!(&logs).buffer.counts().pending); + assert_eq!(0, logs.buffer.counts().transferred); + assert_eq!(1, logs.buffer.counts().no_polling); + } + // This should replace the initial "no polling" cell + now += logs.resolution(); + logs.push_pending(now); + assert_eq!(0, logs.buffer.counts().no_polling); - let throughput = throughput_logs - .calculate_throughput(now, Duration::from_secs(5)) - .unwrap(); - let expected_throughput = 1024.0 * 5.0; - assert_eq!(expected_throughput, throughput.bytes_per_second()); + let report = dbg!(&mut logs).report(now); + assert_eq!(ThroughputReport::Pending, report); } #[test] - fn test_throughput_log_calculate_throughput_4() { - let (throughput_logs, now) = build_throughput_log(1000, Duration::from_millis(100), 12); + fn push_transferred_at_the_beginning_of_each_tick() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + + let mut now = start; + for i in 1..=BUFFER_SIZE { + logs.push_bytes_transferred(now, 10); + if i != BUFFER_SIZE { + now += logs.resolution(); + } - let throughput = throughput_logs - .calculate_throughput(now, Duration::from_secs(1)) - .unwrap(); - let expected_throughput = 12.0 * 10.0; + assert_eq!(i, logs.buffer.counts().transferred); + assert_eq!(0, logs.buffer.counts().pending); + assert_eq!(0, logs.buffer.counts().no_polling); + } - assert_eq!(expected_throughput, throughput.bytes_per_second()); + let report = dbg!(&mut logs).report(now); + assert_eq!( + ThroughputReport::Transferred(Throughput::new(100, Duration::from_secs(1))), + report + ); } #[test] - fn test_throughput_followed_by_0() { - let tick = Duration::from_millis(100); - let (mut throughput_logs, now) = build_throughput_log(1000, tick, 12); - let throughput = throughput_logs - .calculate_throughput(now, Duration::from_secs(1)) - .unwrap(); - let expected_throughput = 12.0 * 10.0; - - assert_eq!(expected_throughput, throughput.bytes_per_second()); - throughput_logs.push((now + tick, 0)); - let throughput = throughput_logs - .calculate_throughput(now + tick, Duration::from_secs(1)) - .unwrap(); - assert_eq!(108.0, throughput.bytes_per_second()); - } - - // If the time since the last log entry is greater than the window, then the throughput should be zero - #[test] - fn test_throughput_log_calculate_throughput_long_after_last_log() { - let (throughput_logs, now) = build_throughput_log(1000, Duration::from_millis(100), 12); - - let throughput = throughput_logs - .calculate_throughput(now + Duration::from_secs(5), Duration::from_secs(1)) - .unwrap(); - let expected_throughput = 0.0; - - assert_eq!(expected_throughput, throughput.bytes_per_second()); + fn no_polling() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + let report = logs.report(start + Duration::from_secs(2)); + assert_eq!(ThroughputReport::NoPolling, report); } - // If the throughput log is empty, it should return None for the 
calculated throughput + // Transferred bytes MUST take priority over pending #[test] - fn test_throughput_log_calculate_throughput_empty_log() { - let throughput_logs = ThroughputLogs::new(1000); - assert!(throughput_logs - .calculate_throughput(UNIX_EPOCH, Duration::from_secs(1)) - .is_none()); + fn mixed_bag_mostly_pending() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + + logs.push_bytes_transferred(start + Duration::from_millis(50), 10); + logs.push_pending(start + Duration::from_millis(150)); + logs.push_pending(start + Duration::from_millis(250)); + logs.push_bytes_transferred(start + Duration::from_millis(350), 10); + logs.push_pending(start + Duration::from_millis(450)); + logs.push_pending(start + Duration::from_millis(650)); + logs.push_pending(start + Duration::from_millis(750)); + logs.push_pending(start + Duration::from_millis(850)); + + let report = logs.report(start + Duration::from_millis(999)); + assert_eq!( + ThroughputReport::Transferred(Throughput::new_bytes_per_second(20)), + report + ); } - // Verify things work as expected when everything occurs exactly on the time window boundary #[test] - fn test_boundary_conditions() { - let mut logs = ThroughputLogs::new(1000); - logs.bytes_processed = 2000; - logs.inner.push_back((SystemTime::UNIX_EPOCH, 1000)); - logs.inner - .push_back((SystemTime::UNIX_EPOCH + Duration::from_secs(1), 1000)); - - let throughput = logs - .calculate_throughput( - SystemTime::UNIX_EPOCH + Duration::from_secs(2), - Duration::from_secs(1), - ) - .unwrap(); - assert_eq!(Throughput::new_bytes_per_second(1000), throughput); + fn mixed_bag_mostly_pending_no_transferred() { + let start = SystemTime::UNIX_EPOCH; + let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); + + logs.push_pending(start + Duration::from_millis(50)); + logs.push_pending(start + Duration::from_millis(150)); + logs.push_pending(start + Duration::from_millis(250)); + logs.push_pending(start + Duration::from_millis(450)); + logs.push_pending(start + Duration::from_millis(650)); + logs.push_pending(start + Duration::from_millis(750)); + logs.push_pending(start + Duration::from_millis(850)); + + let report = logs.report(start + Duration::from_millis(999)); + assert_eq!(ThroughputReport::Pending, report); } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs index 89f08c8bde..97a023c55f 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs @@ -4,7 +4,8 @@ */ use crate::client::http::body::minimum_throughput::{ - MinimumThroughputBody, ThroughputReadingBody, UploadThroughput, + options::MinimumThroughputBodyOptions, MinimumThroughputDownloadBody, ThroughputReadingBody, + UploadThroughput, }; use aws_smithy_async::rt::sleep::SharedAsyncSleep; use aws_smithy_async::time::SharedTimeSource; @@ -68,13 +69,15 @@ impl Intercept for StalledStreamProtectionInterceptor { if self.enable_for_request_body { if let Some(sspcfg) = cfg.load::().cloned() { if sspcfg.is_enabled() { - let throughput = UploadThroughput::new(); - cfg.interceptor_state().store_put(throughput.clone()); - let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; + let now = time_source.now(); + + let options: MinimumThroughputBodyOptions = sspcfg.into(); + let throughput = 
UploadThroughput::new(options.check_window(), now); + cfg.interceptor_state().store_put(throughput.clone()); + tracing::trace!("adding stalled stream protection to request body"); - let sspcfg = sspcfg.clone(); // TODO XXX use the config? let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken()); let it = it.map_preserve_contents(move |body| { let time_source = time_source.clone(); @@ -99,19 +102,23 @@ impl Intercept for StalledStreamProtectionInterceptor { cfg: &mut ConfigBag, ) -> Result<(), BoxError> { if self.enable_for_response_body { - if let Some(cfg) = cfg.load::() { - if cfg.is_enabled() { + if let Some(sspcfg) = cfg.load::() { + if sspcfg.is_enabled() { let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; tracing::trace!("adding stalled stream protection to response body"); - let cfg = cfg.clone(); + let sspcfg = sspcfg.clone(); let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken()); let it = it.map_preserve_contents(move |body| { - let cfg = cfg.clone(); + let sspcfg = sspcfg.clone(); let async_sleep = async_sleep.clone(); let time_source = time_source.clone(); - let mtb = - MinimumThroughputBody::new(time_source, async_sleep, body, cfg.into()); + let mtb = MinimumThroughputDownloadBody::new( + time_source, + async_sleep, + body, + sspcfg.into(), + ); SdkBody::from_body_0_4(mtb) }); let _ = mem::replace(context.response_mut().body_mut(), it); diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs index cd07ed586f..98b77cdd37 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs @@ -113,10 +113,13 @@ async fn download_stalls() { }); let response_body = op.invoke(()).await.expect("initial success"); - let result = eagerly_consume(response_body).await; + let result = tokio::spawn(eagerly_consume(response_body)); server.await.unwrap(); - let err = result.expect_err("should have timed out"); + let err = result + .await + .expect("no panics") + .expect_err("should have timed out"); assert_str_contains!( DisplayErrorContext(err.as_ref()).to_string(), "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_performance.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_performance.rs index 70211cfe52..f1ed0f779a 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_performance.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_performance.rs @@ -7,7 +7,7 @@ use aws_smithy_async::rt::sleep::TokioSleep; use aws_smithy_async::time::{SystemTimeSource, TimeSource}; -use aws_smithy_runtime::client::http::body::minimum_throughput::MinimumThroughputBody; +use aws_smithy_runtime::client::http::body::minimum_throughput::MinimumThroughputDownloadBody; use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; use aws_smithy_types::body::SdkBody; use aws_smithy_types::byte_stream::ByteStream; @@ -92,7 +92,7 @@ async fn make_request(address: &str, wrap_body: bool) -> Duration { let time_source = SystemTimeSource::new(); let sleep = TokioSleep::new(); let opts = StalledStreamProtectionConfig::enabled().build(); - let mtb = MinimumThroughputBody::new(time_source, sleep, body, opts.into()); + let mtb = MinimumThroughputDownloadBody::new(time_source, sleep, body, opts.into()); SdkBody::from_body_0_4(mtb) 
}); } diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs index 44309517fd..887fbd21e7 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs @@ -105,6 +105,29 @@ async fn upload_stalls() { expect_timeout(result.await.expect("no panics")); } +/// Scenario: All the request data is either uploaded to the server or buffered in the +/// HTTP client, but the response doesn't start coming through within the grace period. +/// Expected: MUST timeout after the grace period completes. +#[tokio::test] +async fn complete_upload_no_response() { + let _logs = capture_test_logs(); + + let (server, time, sleep) = stalling_server(); + let op = operation(server, time.clone(), sleep); + + let (body, body_sender) = channel_body(); + let result = tokio::spawn(async move { op.invoke(body).await }); + + let _streamer = tokio::spawn(async move { + body_sender.send(NEAT_DATA).await.unwrap(); + tick!(time, Duration::from_secs(1)); + drop(body_sender); + time.tick(Duration::from_secs(6)).await; + }); + + expect_timeout(result.await.expect("no panics")); +} + // Scenario: The server stops asking for data, the client maxes out its send buffer, // and the request stream stops being polled. However, before the grace period // is over, the server recovers and starts asking for data again. @@ -271,14 +294,12 @@ mod upload_test_tools { _: (), ) -> HttpResponse { let mut times = 5; - while poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { + while times > 0 && poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { times -= 1; - if times <= 0 { - // never awake after this - tracing::info!("stalling indefinitely"); - std::future::pending().await - } } + // never awake after this + tracing::info!("stalling indefinitely"); + std::future::pending::<()>().await; unreachable!() } fake_server!(FakeServerConnector, fake_server) From 38cd2a3699bf380122c8d040b0655946a256d0ee Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Tue, 12 Mar 2024 18:10:53 -0700 Subject: [PATCH 06/19] CI fixes --- .../s3/tests/stalled-stream-protection.rs | 16 ++-------------- rust-runtime/aws-smithy-runtime/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs b/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs index d70c424d77..6a8d0804d4 100644 --- a/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs +++ b/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs @@ -83,12 +83,6 @@ impl Body for SlowBody { } } -// This test doesn't work because we can't count on `hyper` to poll the body, -// regardless of whether we schedule a wake. To make this functionality work, -// we'd have to integrate more closely with the orchestrator. -// -// I'll leave this test here because we do eventually want to support stalled -// stream protection for uploads. #[tokio::test] async fn test_stalled_stream_protection_defaults_for_upload() { let _logs = capture_test_logs(); @@ -180,7 +174,6 @@ async fn start_faulty_upload_server() -> (impl Future, SocketAddr) } #[tokio::test] -#[ignore] async fn test_explicitly_configured_stalled_stream_protection_for_downloads() { // We spawn a faulty server that will close the connection after // writing half of the response body. 
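// A hedged, standalone sketch of the "read a little, then stall forever" pattern the
// fake servers in these tests rely on: consume a bounded number of chunks, then park
// on a future that never resolves so the client's stall detection has to fire. The
// names and the chunk iterator are invented; only std::future::pending mirrors the
// real test code, and tokio is assumed as a dev-dependency (as it already is here).
use std::time::Duration;

async fn take_then_stall<I: Iterator<Item = Vec<u8>>>(mut chunks: I, limit: usize) {
    let mut taken = 0;
    while taken < limit {
        match chunks.next() {
            Some(_chunk) => taken += 1, // a real server would buffer or discard the chunk
            None => break,
        }
    }
    // Never wake again; the caller is expected to time out.
    std::future::pending::<()>().await;
}

#[tokio::main]
async fn main() {
    let chunks = vec![vec![1u8], vec![2], vec![3]].into_iter();
    // Bound the stalled future with a timeout so this example terminates.
    let result = tokio::time::timeout(Duration::from_millis(50), take_then_stall(chunks, 2)).await;
    assert!(result.is_err(), "the stalled server future never completes");
}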
@@ -221,7 +214,6 @@ async fn test_explicitly_configured_stalled_stream_protection_for_downloads() { } #[tokio::test] -#[ignore] async fn test_stalled_stream_protection_for_downloads_can_be_disabled() { // We spawn a faulty server that will close the connection after // writing half of the response body. @@ -254,7 +246,6 @@ async fn test_stalled_stream_protection_for_downloads_can_be_disabled() { // This test will always take as long as whatever grace period is set by default. #[tokio::test] -#[ignore] async fn test_stalled_stream_protection_for_downloads_is_enabled_by_default() { // We spawn a faulty server that will close the connection after // writing half of the response body. @@ -289,14 +280,11 @@ async fn test_stalled_stream_protection_for_downloads_is_enabled_by_default() { err.to_string(), "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" ); - // 1s check interval + 5s grace period - assert_eq!(start.elapsed().as_secs(), 6); + // the 1s check interval is included in the 5s grace period + assert_eq!(start.elapsed().as_secs(), 5); } async fn start_faulty_download_server() -> (impl Future, SocketAddr) { - use tokio::net::{TcpListener, TcpStream}; - use tokio::time::sleep; - let listener = TcpListener::bind("0.0.0.0:0") .await .expect("socket is free"); diff --git a/rust-runtime/aws-smithy-runtime/Cargo.toml b/rust-runtime/aws-smithy-runtime/Cargo.toml index c8224745e1..6b4bedc356 100644 --- a/rust-runtime/aws-smithy-runtime/Cargo.toml +++ b/rust-runtime/aws-smithy-runtime/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-runtime" -version = "1.1.8" +version = "1.2.0" authors = ["AWS Rust SDK Team ", "Zelda Hessler "] description = "The new smithy runtime crate" edition = "2021" From 6fd19ec02377f0180157e1757723bdd1afad57ce Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Thu, 14 Mar 2024 13:28:33 -0700 Subject: [PATCH 07/19] Rename "cell" to "bin" --- .../body/minimum_throughput/throughput.rs | 96 +++++++++---------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs index cd4f5f5a61..dc6bb46cac 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs @@ -97,26 +97,26 @@ impl From<(u64, Duration)> for Throughput { } } -/// Cell in a linear grid that represents a small chunk of time. +/// Represents a bin (or a cell) in a linear grid that represents a small chunk of time. #[derive(Copy, Clone, Debug)] -enum Cell { - /// There is no data in this cell. +enum Bin { + /// There is no data in this bin. Empty, - /// No polling took place during this cell. + /// No polling took place during this bin. NoPolling, - /// This many bytes were transferred during this cell. + /// This many bytes were transferred during this bin. TransferredBytes(u64), - /// The user/remote was not providing/consuming data fast enough during this cell. + /// The user/remote was not providing/consuming data fast enough during this bin. /// /// The number is the number of bytes transferred, if this replaced TransferredBytes. Pending(u64), } -impl Cell { - fn merge(&mut self, other: Cell) { - use Cell::*; +impl Bin { + fn merge(&mut self, other: Bin) { + use Bin::*; // Assign values based on this priority order (highest priority higher up): // 1. Pending // 2. 
TransferredBytes @@ -126,30 +126,30 @@ impl Cell { (Pending(other), this) => Pending(other + this.bytes()), (TransferredBytes(other), this) => TransferredBytes(other + this.bytes()), (other, NoPolling) => other, - (NoPolling, _) => panic!("can't merge NoPolling into cell"), + (NoPolling, _) => panic!("can't merge NoPolling into bin"), (other, Empty) => other, - (Empty, _) => panic!("can't merge Empty into cell"), + (Empty, _) => panic!("can't merge Empty into bin"), }; } - /// Number of bytes transferred during this cell + /// Number of bytes transferred during this bin fn bytes(&self) -> u64 { match *self { - Cell::Empty | Cell::NoPolling => 0, - Cell::TransferredBytes(bytes) | Cell::Pending(bytes) => bytes, + Bin::Empty | Bin::NoPolling => 0, + Bin::TransferredBytes(bytes) | Bin::Pending(bytes) => bytes, } } } #[derive(Copy, Clone, Debug, Default)] -struct CellCounts { - /// Number of cells with no data. +struct BinCounts { + /// Number of bins with no data. empty: usize, - /// Number of "no polling" cells. + /// Number of "no polling" bins. no_polling: usize, - /// Number of "bytes transferred" cells. + /// Number of "bytes transferred" bins. transferred: usize, - /// Number of "pending" cells. + /// Number of "pending" bins. pending: usize, } @@ -157,7 +157,7 @@ struct CellCounts { /// throughput events for [`ThroughputLogs`]. #[derive(Copy, Clone, Debug)] struct LogBuffer { - entries: [Cell; N], + entries: [Bin; N], // The length only needs to exist so that the `fill_gaps` function // can differentiate between `Empty` due to there not having been enough // time to establish a full buffer worth of data vs. `Empty` due to a @@ -167,34 +167,34 @@ struct LogBuffer { impl LogBuffer { fn new() -> Self { Self { - entries: [Cell::Empty; N], + entries: [Bin::Empty; N], length: 0, } } /// Mutably returns the tail of the buffer. /// - /// The buffer MUST have at least one cell it it before this is called. - fn tail_mut(&mut self) -> &mut Cell { + /// The buffer MUST have at least one bin it it before this is called. + fn tail_mut(&mut self) -> &mut Bin { debug_assert!(self.length > 0); &mut self.entries[self.length - 1] } - /// Pushes a cell into the buffer. If the buffer is already full, + /// Pushes a bin into the buffer. If the buffer is already full, /// then this will rotate the entire buffer to the left. - fn push(&mut self, cell: Cell) { + fn push(&mut self, bin: Bin) { if self.filled() { self.entries.rotate_left(1); - self.entries[N - 1] = cell; + self.entries[N - 1] = bin; } else { - self.entries[self.length] = cell; + self.entries[self.length] = bin; self.length += 1; } } /// Returns the total number of bytes transferred within the time window. fn bytes_transferred(&self) -> u64 { - self.entries.iter().take(self.length).map(Cell::bytes).sum() + self.entries.iter().take(self.length).map(Bin::bytes).sum() } #[inline] @@ -209,21 +209,21 @@ impl LogBuffer { /// way we can know about these is by examining gaps in time. fn fill_gaps(&mut self) { for entry in self.entries.iter_mut().take(self.length) { - if matches!(entry, Cell::Empty) { - *entry = Cell::NoPolling; + if matches!(entry, Bin::Empty) { + *entry = Bin::NoPolling; } } } - /// Returns the counts of each cell type in the buffer. - fn counts(&self) -> CellCounts { - let mut counts = CellCounts::default(); + /// Returns the counts of each bin type in the buffer. 
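// Toy model (standalone, invented names) of the merge semantics above: the higher
// priority label wins, but byte counts always accumulate, so a bin that saw both a
// transfer and a pending wait is labeled pending while its bytes still count toward
// the window total reported by bytes_transferred.
#[derive(Copy, Clone, Debug, PartialEq)]
enum ToyLabel {
    Transferred,
    Pending,
}

#[derive(Copy, Clone, Debug, PartialEq)]
struct ToyBin {
    label: ToyLabel,
    bytes: u64,
}

fn merge(a: ToyBin, b: ToyBin) -> ToyBin {
    ToyBin {
        // Pending outranks Transferred, matching the priority list above.
        label: if a.label == ToyLabel::Pending || b.label == ToyLabel::Pending {
            ToyLabel::Pending
        } else {
            ToyLabel::Transferred
        },
        bytes: a.bytes + b.bytes,
    }
}

fn main() {
    let transferred = ToyBin { label: ToyLabel::Transferred, bytes: 512 };
    let pending = ToyBin { label: ToyLabel::Pending, bytes: 0 };
    assert_eq!(
        ToyBin { label: ToyLabel::Pending, bytes: 512 },
        merge(transferred, pending)
    );
}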
+ fn counts(&self) -> BinCounts { + let mut counts = BinCounts::default(); for entry in &self.entries { match entry { - Cell::Empty => counts.empty += 1, - Cell::NoPolling => counts.no_polling += 1, - Cell::TransferredBytes(_) => counts.transferred += 1, - Cell::Pending(_) => counts.pending += 1, + Bin::Empty => counts.empty += 1, + Bin::NoPolling => counts.no_polling += 1, + Bin::TransferredBytes(_) => counts.transferred += 1, + Bin::Pending(_) => counts.pending += 1, } } counts @@ -257,10 +257,10 @@ const BUFFER_SIZE: usize = 10; /// functions cannot know when they're not being polled, so the log examines gaps /// in the event history to know when no polling took place. /// -/// The event logging is simplified down to a linear 10-cell grid, which each -/// cell representing 1/10th the total time window. When an event is pushed, -/// it is either merged into the current tail cell, or all the cells are rotated -/// left to create a new empty tail cell, and then it is merged into that one. +/// The event logging is simplified down to a linear grid consisting of 10 "bins", +/// with each bin representing 1/10th the total time window. When an event is pushed, +/// it is either merged into the current tail bin, or all the bins are rotated +/// left to create a new empty tail bin, and then it is merged into that one. #[derive(Clone, Debug)] pub(super) struct ThroughputLogs { resolution: Duration, @@ -297,27 +297,27 @@ impl ThroughputLogs { /// In an upload, it is waiting for data from the user, and in a download, /// it is waiting for data from the server. pub(super) fn push_pending(&mut self, time: SystemTime) { - self.push(time, Cell::Pending(0)); + self.push(time, Bin::Pending(0)); } /// Pushes a data transferred event. /// /// Indicates that this number of bytes were transferred at this time. pub(super) fn push_bytes_transferred(&mut self, time: SystemTime, bytes: u64) { - self.push(time, Cell::TransferredBytes(bytes)); + self.push(time, Bin::TransferredBytes(bytes)); } - fn push(&mut self, now: SystemTime, value: Cell) { + fn push(&mut self, now: SystemTime, value: Bin) { self.catch_up(now); self.buffer.tail_mut().merge(value); self.buffer.fill_gaps(); } - /// Pushes empty cells until `current_tail` is caught up to `now`. + /// Pushes empty bins until `current_tail` is caught up to `now`. 
fn catch_up(&mut self, now: SystemTime) { while now >= self.current_tail { self.current_tail += self.resolution; - self.buffer.push(Cell::Empty); + self.buffer.push(Bin::Empty); } assert!(self.current_tail >= now); } @@ -327,7 +327,7 @@ impl ThroughputLogs { self.catch_up(now); self.buffer.fill_gaps(); - let CellCounts { + let BinCounts { empty, no_polling, transferred, @@ -426,7 +426,7 @@ mod test { assert_eq!(0, logs.buffer.counts().transferred); assert_eq!(1, logs.buffer.counts().no_polling); } - // This should replace the initial "no polling" cell + // This should replace the initial "no polling" bin now += logs.resolution(); logs.push_pending(now); assert_eq!(0, logs.buffer.counts().no_polling); From 8ef35d6636789974f2ecb932aaeb77623bb6ad09 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Thu, 14 Mar 2024 13:48:45 -0700 Subject: [PATCH 08/19] Fix empty cell check --- .../http/body/minimum_throughput/throughput.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs index dc6bb46cac..0c31a079cc 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs @@ -174,6 +174,8 @@ impl LogBuffer { /// Mutably returns the tail of the buffer. /// + /// ## Panics + /// /// The buffer MUST have at least one bin it it before this is called. fn tail_mut(&mut self) -> &mut Bin { debug_assert!(self.length > 0); @@ -334,14 +336,17 @@ impl ThroughputLogs { pending, } = self.buffer.counts(); + // If there are any empty cells at all, then we haven't been tracking + // long enough to make any judgements about the stream's progress. 
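// Standalone sketch (invented names; the real report also attaches the measured
// Throughput to the transferred case) of the verdict computed below: any bin that
// was never created means the window is incomplete, otherwise any transferred bin
// wins, then a majority of "no polling" bins, then a majority of pending bins, and
// anything else is still reported as incomplete.
#[derive(Debug, PartialEq)]
enum Verdict {
    Incomplete,
    NoPolling,
    Pending,
    Transferred,
}

fn classify(empty: usize, no_polling: usize, transferred: usize, pending: usize, bins: usize) -> Verdict {
    if empty > 0 {
        return Verdict::Incomplete;
    }
    let half = bins / 2;
    match (transferred > 0, no_polling >= half, pending >= half) {
        (true, _, _) => Verdict::Transferred,
        (_, true, _) => Verdict::NoPolling,
        (_, _, true) => Verdict::Pending,
        _ => Verdict::Incomplete,
    }
}

fn main() {
    // Three of ten bins never created yet: too early to judge the stream.
    assert_eq!(Verdict::Incomplete, classify(3, 0, 0, 7, 10));
    // Full window with at least one transfer bin: report the transfer.
    assert_eq!(Verdict::Transferred, classify(0, 2, 3, 5, 10));
    // Full window, nothing transferred, mostly pending: the peer is slow.
    assert_eq!(Verdict::Pending, classify(0, 1, 0, 9, 10));
}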
+ if empty > 0 { + return ThroughputReport::Incomplete; + } + let bytes = self.buffer.bytes_transferred(); let time = self.resolution * (BUFFER_SIZE - empty) as u32; let throughput = Throughput::new(bytes, time); let half = BUFFER_SIZE / 2; - if empty >= half { - return ThroughputReport::Incomplete; - } match (transferred > 0, no_polling >= half, pending >= half) { (true, _, _) => ThroughputReport::Transferred(throughput), (_, true, _) => ThroughputReport::NoPolling, @@ -478,6 +483,7 @@ mod test { logs.push_pending(start + Duration::from_millis(250)); logs.push_bytes_transferred(start + Duration::from_millis(350), 10); logs.push_pending(start + Duration::from_millis(450)); + // skip 550 logs.push_pending(start + Duration::from_millis(650)); logs.push_pending(start + Duration::from_millis(750)); logs.push_pending(start + Duration::from_millis(850)); @@ -497,7 +503,9 @@ mod test { logs.push_pending(start + Duration::from_millis(50)); logs.push_pending(start + Duration::from_millis(150)); logs.push_pending(start + Duration::from_millis(250)); + // skip 350 logs.push_pending(start + Duration::from_millis(450)); + // skip 550 logs.push_pending(start + Duration::from_millis(650)); logs.push_pending(start + Duration::from_millis(750)); logs.push_pending(start + Duration::from_millis(850)); From ac15324f2f6cf96a0df2619442dab7dccc0c4aa2 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Thu, 14 Mar 2024 13:57:05 -0700 Subject: [PATCH 09/19] Update changelog --- CHANGELOG.next.toml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 1a69497fa7..a7b1f6c6a1 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -22,3 +22,15 @@ message = "`DefaultS3ExpressIdentityProvider` now uses `BehaviorVersion` threade references = ["smithy-rs#3478"] meta = { "breaking" = false, "bug" = true, "tada" = false } author = "ysaito1001" + +[[smithy-rs]] +message = "Stalled stream protection now supports request upload streams, and existing download stream protection no longer triggers when user-code is doing something other than downloading in the middle of the stream (for example, making a HTTP request to another service for each row in CSV file download from S3)." +references = ["smithy-rs#3485"] +meta = { "breaking" = false, "tada" = true, "bug" = true } +authors = ["jdisanti"] + +[[aws-sdk-rust]] +message = "Stalled stream protection now supports request upload streams, and existing download stream protection no longer triggers when user-code is doing something other than downloading in the middle of the stream (for example, making a HTTP request to another service for each row in CSV file download from S3)." 
+references = ["smithy-rs#3485"] +meta = { "breaking" = false, "tada" = true, "bug" = true } +author = "jdisanti" From 510bedbed641e096689f8e6ff20cfca801e072ff Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Thu, 14 Mar 2024 17:29:46 -0700 Subject: [PATCH 10/19] Fix aws-smithy-runtime no-default-features build --- rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs | 2 ++ .../aws-smithy-runtime/tests/stalled_stream_download.rs | 2 ++ rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs | 2 ++ 3 files changed, 6 insertions(+) diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs index 99d0850b52..69c201eb0c 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![cfg(all(feature = "client", feature = "test-util"))] + pub use aws_smithy_async::{ test_util::tick_advance_sleep::{ tick_advance_time_and_sleep, TickAdvanceSleep, TickAdvanceTime, diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs index 98b77cdd37..d538054e6c 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![cfg(all(feature = "client", feature = "test-util"))] + use std::time::Duration; #[macro_use] diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs index 887fbd21e7..29eb12c62f 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![cfg(all(feature = "client", feature = "test-util"))] + #[macro_use] mod stalled_stream_common; use stalled_stream_common::*; From 73635049dc5bc5fabb603292fec92d5d803e9706 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Mon, 25 Mar 2024 13:56:28 -0700 Subject: [PATCH 11/19] Fix size hints --- .../s3/tests/body_size_hint.rs | 122 ++++++++++++++++++ .../minimum_throughput/http_body_0_4_x.rs | 16 +++ 2 files changed, 138 insertions(+) create mode 100644 aws/sdk/integration-tests/s3/tests/body_size_hint.rs diff --git a/aws/sdk/integration-tests/s3/tests/body_size_hint.rs b/aws/sdk/integration-tests/s3/tests/body_size_hint.rs new file mode 100644 index 0000000000..270a87f902 --- /dev/null +++ b/aws/sdk/integration-tests/s3/tests/body_size_hint.rs @@ -0,0 +1,122 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! 
Body wrappers must pass through size_hint + +use aws_config::SdkConfig; +use aws_sdk_s3::{ + config::{Credentials, HttpClient, Region, RuntimeComponents, SharedCredentialsProvider}, + primitives::{ByteStream, SdkBody}, + Client, +}; +use aws_smithy_runtime_api::{ + client::{ + http::{HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector}, + orchestrator::HttpRequest, + }, + http::{Response, StatusCode}, +}; +use http_body::Body; +use std::sync::{Arc, Mutex}; + +#[derive(Clone, Debug, Default)] +struct TestClient { + response_body: Arc>>, + captured_body: Arc>>, +} +impl HttpConnector for TestClient { + fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture { + *self.captured_body.lock().unwrap() = Some(request.take_body()); + let body = self + .response_body + .lock() + .unwrap() + .take() + .unwrap_or_else(SdkBody::empty); + HttpConnectorFuture::ready(Ok(Response::new(StatusCode::try_from(200).unwrap(), body))) + } +} +impl HttpClient for TestClient { + fn http_connector( + &self, + _settings: &HttpConnectorSettings, + _components: &RuntimeComponents, + ) -> SharedHttpConnector { + SharedHttpConnector::new(self.clone()) + } +} + +#[tokio::test] +async fn download_body_size_hint_check() { + let test_body_content = b"hello"; + let test_body = SdkBody::from(&test_body_content[..]); + assert_eq!( + Some(test_body_content.len() as u64), + test_body.size_hint().exact(), + "pre-condition check" + ); + + let http_client = TestClient { + response_body: Arc::new(Mutex::new(Some(test_body))), + ..Default::default() + }; + let sdk_config = SdkConfig::builder() + .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) + .region(Region::new("us-east-1")) + .http_client(http_client) + .build(); + let client = Client::new(&sdk_config); + let response = client + .get_object() + .bucket("foo") + .key("foo") + .send() + .await + .unwrap(); + assert_eq!( + ( + test_body_content.len() as u64, + Some(test_body_content.len() as u64), + ), + response.body.size_hint(), + "the size hint should be passed through all the default body wrappers" + ); +} + +#[tokio::test] +async fn upload_body_size_hint_check() { + let test_body_content = b"hello"; + + let http_client = TestClient::default(); + let sdk_config = SdkConfig::builder() + .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) + .region(Region::new("us-east-1")) + .http_client(http_client.clone()) + .build(); + let client = Client::new(&sdk_config); + let body = ByteStream::from_static(test_body_content); + assert_eq!( + ( + test_body_content.len() as u64, + Some(test_body_content.len() as u64), + ), + body.size_hint(), + "pre-condition check" + ); + let _response = client + .put_object() + .bucket("foo") + .key("foo") + .body(body) + .send() + .await + .unwrap(); + let captured_body = http_client.captured_body.lock().unwrap().take().unwrap(); + assert_eq!( + Some(test_body_content.len() as u64), + captured_body.size_hint().exact(), + "the size hint should be passed through all the default body wrappers" + ); +} diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs index f89b009eb8..a8f2fe9c4b 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs @@ -143,6 +143,14 @@ where let this = 
self.as_mut().project(); this.inner.poll_trailers(cx) } + + fn size_hint(&self) -> http_body_0_4::SizeHint { + self.inner.size_hint() + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } } impl Body for ThroughputReadingBody @@ -186,4 +194,12 @@ where let this = self.as_mut().project(); this.inner.poll_trailers(cx) } + + fn size_hint(&self) -> http_body_0_4::SizeHint { + self.inner.size_hint() + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } } From 79b6010f2ebf9e0e2370a4279f84c664369f6a65 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Mon, 25 Mar 2024 14:05:41 -0700 Subject: [PATCH 12/19] Disable upload protection in current BMV --- .../s3/tests/stalled-stream-protection.rs | 1 + ...lledStreamProtectionConfigCustomization.kt | 5 +- .../src/client/stalled_stream_protection.rs | 71 +++++++++--- .../aws-smithy-runtime/src/client/defaults.rs | 17 ++- .../src/client/stalled_stream_protection.rs | 109 ++++++++---------- .../tests/stalled_stream_common.rs | 4 +- .../tests/stalled_stream_download.rs | 4 +- .../tests/stalled_stream_upload.rs | 4 +- 8 files changed, 127 insertions(+), 88 deletions(-) diff --git a/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs b/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs index 6a8d0804d4..21a224adfa 100644 --- a/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs +++ b/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs @@ -95,6 +95,7 @@ async fn test_stalled_stream_protection_defaults_for_upload() { .credentials_provider(Credentials::for_tests()) .region(Region::new("us-east-1")) .endpoint_url(format!("http://{server_addr}")) + // TODO(https://github.com/smithy-lang/smithy-rs/issues/3510): make stalled stream protection enabled by default with BMV and remove this line .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) .build(); let client = Client::from_conf(conf); diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt index 3faeccff93..8304efc2c4 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt @@ -120,15 +120,12 @@ class StalledStreamProtectionOperationCustomization( is OperationSection.AdditionalInterceptors -> { val stalledStreamProtectionModule = RuntimeType.smithyRuntime(rc).resolve("client::stalled_stream_protection") section.registerInterceptor(rc, this) { - // Currently, only response bodies are protected/supported because - // we can't count on hyper to poll a request body on wake. 
rustTemplate( """ - #{StalledStreamProtectionInterceptor}::new(#{Kind}::RequestAndResponseBody) + #{StalledStreamProtectionInterceptor}::default() """, *preludeScope, "StalledStreamProtectionInterceptor" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptor"), - "Kind" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptorKind"), ) } } diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs index 25c9c5c67d..f90f886592 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs @@ -20,15 +20,17 @@ const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(5); /// When enabled, download streams that stall out will be cancelled. #[derive(Clone, Debug)] pub struct StalledStreamProtectionConfig { - is_enabled: bool, + upload_enabled: bool, + download_enabled: bool, grace_period: Duration, } impl StalledStreamProtectionConfig { - /// Create a new config that enables stalled stream protection. + /// Create a new config that enables stalled stream protection for both uploads and downloads. pub fn enabled() -> Builder { Builder { - is_enabled: Some(true), + upload_enabled: Some(true), + download_enabled: Some(true), grace_period: None, } } @@ -36,14 +38,25 @@ impl StalledStreamProtectionConfig { /// Create a new config that disables stalled stream protection. pub fn disabled() -> Self { Self { - is_enabled: false, + upload_enabled: false, + download_enabled: false, grace_period: DEFAULT_GRACE_PERIOD, } } - /// Return whether stalled stream protection is enabled. + /// Return whether stalled stream protection is enabled for either uploads or downloads. pub fn is_enabled(&self) -> bool { - self.is_enabled + self.upload_enabled || self.download_enabled + } + + /// True if stalled stream protection is enabled for upload streams. + pub fn upload_enabled(&self) -> bool { + self.upload_enabled + } + + /// True if stalled stream protection is enabled for download streams. + pub fn download_enabled(&self) -> bool { + self.download_enabled } /// Return the grace period for stalled stream protection. @@ -57,7 +70,8 @@ impl StalledStreamProtectionConfig { #[derive(Clone, Debug)] pub struct Builder { - is_enabled: Option, + upload_enabled: Option, + download_enabled: Option, grace_period: Option, } @@ -74,22 +88,48 @@ impl Builder { self } - /// Set whether stalled stream protection is enabled. - pub fn is_enabled(mut self, is_enabled: bool) -> Self { - self.is_enabled = Some(is_enabled); + /// Set whether stalled stream protection is enabled for both uploads and downloads. + pub fn is_enabled(mut self, enabled: bool) -> Self { + self.set_is_enabled(Some(enabled)); + self + } + + /// Set whether stalled stream protection is enabled for both uploads and downloads. + pub fn set_is_enabled(&mut self, enabled: Option) -> &mut Self { + self.set_upload_enabled(enabled); + self.set_download_enabled(enabled); + self + } + + /// Set whether stalled stream protection is enabled for upload streams. + pub fn upload_enabled(mut self, enabled: bool) -> Self { + self.set_upload_enabled(Some(enabled)); + self + } + + /// Set whether stalled stream protection is enabled for upload streams. 
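// Assumed usage sketch built only from the builder methods added in this diff, plus
// the grace_period setter already used by the tests elsewhere in this series: enable
// protection, then opt uploads back out while keeping download protection on.
use std::time::Duration;
use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;

fn download_only() -> StalledStreamProtectionConfig {
    StalledStreamProtectionConfig::enabled()
        .upload_enabled(false)
        .grace_period(Duration::from_secs(10))
        .build()
}

fn main() {
    let config = download_only();
    assert!(config.is_enabled());
    assert!(config.download_enabled());
    assert!(!config.upload_enabled());
}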
+ pub fn set_upload_enabled(&mut self, enabled: Option) -> &mut Self { + self.upload_enabled = enabled; + self + } + + /// Set whether stalled stream protection is enabled for download streams. + pub fn download_enabled(mut self, enabled: bool) -> Self { + self.set_download_enabled(Some(enabled)); self } - /// Set whether stalled stream protection is enabled. - pub fn set_is_enabled(&mut self, is_enabled: Option) -> &mut Self { - self.is_enabled = is_enabled; + /// Set whether stalled stream protection is enabled for download streams. + pub fn set_download_enabled(&mut self, enabled: Option) -> &mut Self { + self.download_enabled = enabled; self } /// Build the config. pub fn build(self) -> StalledStreamProtectionConfig { StalledStreamProtectionConfig { - is_enabled: self.is_enabled.unwrap_or_default(), + upload_enabled: self.upload_enabled.unwrap_or_default(), + download_enabled: self.download_enabled.unwrap_or_default(), grace_period: self.grace_period.unwrap_or(DEFAULT_GRACE_PERIOD), } } @@ -98,7 +138,8 @@ impl Builder { impl From for Builder { fn from(config: StalledStreamProtectionConfig) -> Self { Builder { - is_enabled: Some(config.is_enabled), + upload_enabled: Some(config.upload_enabled), + download_enabled: Some(config.download_enabled), grace_period: Some(config.grace_period), } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/defaults.rs b/rust-runtime/aws-smithy-runtime/src/client/defaults.rs index ca2d06a387..f13433ac0b 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/defaults.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/defaults.rs @@ -170,7 +170,16 @@ pub fn default_identity_cache_plugin() -> Option { /// /// By default, when throughput falls below 1/Bs for more than 5 seconds, the /// stream is cancelled. +#[deprecated( + since = "1.2.0", + note = "This function wasn't intended to be public, and didn't take the behavior major version as an argument, so it couldn't be evolved over time." +)] pub fn default_stalled_stream_protection_config_plugin() -> Option { + default_stalled_stream_protection_config_plugin_v2(BehaviorVersion::v2023_11_09()) +} +fn default_stalled_stream_protection_config_plugin_v2( + _behavior_version: BehaviorVersion, +) -> Option { Some( default_plugin( "default_stalled_stream_protection_config_plugin", @@ -183,6 +192,8 @@ pub fn default_stalled_stream_protection_config_plugin() -> Option impl IntoIterator { + let behavior_version = params + .behavior_version + .unwrap_or_else(BehaviorVersion::latest); + [ default_http_client_plugin(), default_identity_cache_plugin(), @@ -263,7 +278,7 @@ pub fn default_plugins( default_sleep_impl_plugin(), default_time_source_plugin(), default_timeout_config_plugin(), - default_stalled_stream_protection_config_plugin(), + default_stalled_stream_protection_config_plugin_v2(behavior_version), ] .into_iter() .flatten() diff --git a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs index 97a023c55f..83cfb64752 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs @@ -21,14 +21,16 @@ use aws_smithy_types::config_bag::ConfigBag; use std::mem; /// Adds stalled stream protection when sending requests and/or receiving responses. 
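// Simplified standalone sketch (invented names) of the decision the interceptor
// hooks below make: load the protection config and wrap the body only for the
// direction that is actually enabled, so upload and download protection can be
// toggled independently.
struct ProtectionFlags {
    upload_enabled: bool,
    download_enabled: bool,
}

enum Direction {
    Upload,
    Download,
}

fn should_wrap(flags: Option<&ProtectionFlags>, direction: Direction) -> bool {
    match (flags, direction) {
        (Some(f), Direction::Upload) => f.upload_enabled,
        (Some(f), Direction::Download) => f.download_enabled,
        (None, _) => false,
    }
}

fn main() {
    let flags = ProtectionFlags { upload_enabled: false, download_enabled: true };
    assert!(!should_wrap(Some(&flags), Direction::Upload));
    assert!(should_wrap(Some(&flags), Direction::Download));
    assert!(!should_wrap(None, Direction::Upload));
}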
-#[derive(Debug)] -pub struct StalledStreamProtectionInterceptor { - enable_for_request_body: bool, - enable_for_response_body: bool, -} +#[derive(Debug, Default)] +#[non_exhaustive] +pub struct StalledStreamProtectionInterceptor; /// Stalled stream protection can be enable for request bodies, response bodies, /// or both. +#[deprecated( + since = "1.2.0", + note = "This kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag." +)] pub enum StalledStreamProtectionInterceptorKind { /// Enable stalled stream protection for request bodies. RequestBody, @@ -40,18 +42,13 @@ pub enum StalledStreamProtectionInterceptorKind { impl StalledStreamProtectionInterceptor { /// Create a new stalled stream protection interceptor. - pub fn new(kind: StalledStreamProtectionInterceptorKind) -> Self { - use StalledStreamProtectionInterceptorKind::*; - let (enable_for_request_body, enable_for_response_body) = match kind { - RequestBody => (true, false), - ResponseBody => (false, true), - RequestAndResponseBody => (true, true), - }; - - Self { - enable_for_request_body, - enable_for_response_body, - } + #[deprecated( + since = "1.2.0", + note = "The kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag. Construct the interceptor using Default." + )] + #[allow(deprecated)] + pub fn new(_kind: StalledStreamProtectionInterceptorKind) -> Self { + Default::default() } } @@ -66,29 +63,26 @@ impl Intercept for StalledStreamProtectionInterceptor { runtime_components: &RuntimeComponents, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - if self.enable_for_request_body { - if let Some(sspcfg) = cfg.load::().cloned() { - if sspcfg.is_enabled() { - let (_async_sleep, time_source) = - get_runtime_component_deps(runtime_components)?; - let now = time_source.now(); + if let Some(sspcfg) = cfg.load::().cloned() { + if sspcfg.upload_enabled() { + let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; + let now = time_source.now(); - let options: MinimumThroughputBodyOptions = sspcfg.into(); - let throughput = UploadThroughput::new(options.check_window(), now); - cfg.interceptor_state().store_put(throughput.clone()); + let options: MinimumThroughputBodyOptions = sspcfg.into(); + let throughput = UploadThroughput::new(options.check_window(), now); + cfg.interceptor_state().store_put(throughput.clone()); - tracing::trace!("adding stalled stream protection to request body"); - let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken()); - let it = it.map_preserve_contents(move |body| { - let time_source = time_source.clone(); - SdkBody::from_body_0_4(ThroughputReadingBody::new( - time_source, - throughput.clone(), - body, - )) - }); - let _ = mem::replace(context.request_mut().body_mut(), it); - } + tracing::trace!("adding stalled stream protection to request body"); + let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken()); + let it = it.map_preserve_contents(move |body| { + let time_source = time_source.clone(); + SdkBody::from_body_0_4(ThroughputReadingBody::new( + time_source, + throughput.clone(), + body, + )) + }); + let _ = mem::replace(context.request_mut().body_mut(), it); } } @@ -101,28 +95,25 @@ impl Intercept for StalledStreamProtectionInterceptor { runtime_components: &RuntimeComponents, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - if self.enable_for_response_body { - if let Some(sspcfg) = cfg.load::() { - if sspcfg.is_enabled() { - let (async_sleep, 
time_source) = - get_runtime_component_deps(runtime_components)?; - tracing::trace!("adding stalled stream protection to response body"); + if let Some(sspcfg) = cfg.load::() { + if sspcfg.download_enabled() { + let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; + tracing::trace!("adding stalled stream protection to response body"); + let sspcfg = sspcfg.clone(); + let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken()); + let it = it.map_preserve_contents(move |body| { let sspcfg = sspcfg.clone(); - let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken()); - let it = it.map_preserve_contents(move |body| { - let sspcfg = sspcfg.clone(); - let async_sleep = async_sleep.clone(); - let time_source = time_source.clone(); - let mtb = MinimumThroughputDownloadBody::new( - time_source, - async_sleep, - body, - sspcfg.into(), - ); - SdkBody::from_body_0_4(mtb) - }); - let _ = mem::replace(context.response_mut().body_mut(), it); - } + let async_sleep = async_sleep.clone(); + let time_source = time_source.clone(); + let mtb = MinimumThroughputDownloadBody::new( + time_source, + async_sleep, + body, + sspcfg.into(), + ); + SdkBody::from_body_0_4(mtb) + }); + let _ = mem::replace(context.response_mut().body_mut(), it); } } Ok(()) diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs index 69c201eb0c..5559d382cf 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs @@ -15,9 +15,7 @@ pub use aws_smithy_runtime::{ assert_str_contains, client::{ orchestrator::operation::Operation, - stalled_stream_protection::{ - StalledStreamProtectionInterceptor, StalledStreamProtectionInterceptorKind, - }, + stalled_stream_protection::StalledStreamProtectionInterceptor, }, test_util::capture_test_logs::capture_test_logs, }; diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs index d538054e6c..01af3b51ce 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs @@ -235,9 +235,7 @@ mod download_test_tools { .grace_period(Duration::from_secs(5)) .build(), ) - .interceptor(StalledStreamProtectionInterceptor::new( - StalledStreamProtectionInterceptorKind::RequestAndResponseBody, - )) + .interceptor(StalledStreamProtectionInterceptor::default()) .sleep_impl(sleep) .time_source(time) .build(); diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs index 29eb12c62f..0fb4533e78 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs @@ -221,9 +221,7 @@ mod upload_test_tools { .grace_period(Duration::from_secs(5)) .build(), ) - .interceptor(StalledStreamProtectionInterceptor::new( - StalledStreamProtectionInterceptorKind::RequestAndResponseBody, - )) + .interceptor(StalledStreamProtectionInterceptor::default()) .sleep_impl(sleep) .time_source(time) .build(); From 4c670205c9ff711724cc53bc5de3125d070546ff Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Mon, 25 Mar 2024 15:04:57 -0700 Subject: [PATCH 13/19] Version bump runtime-api --- rust-runtime/aws-smithy-runtime-api/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
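The new `upload_enabled`/`download_enabled` builder methods are purely additive, which is why `aws-smithy-runtime-api` only needs the minor bump below. The `defaults.rs` change in the previous patch follows a related compatibility pattern: the old public entry point becomes a deprecated shim that pins the old behavior version, while a private successor takes the version explicitly so the default can evolve later. A hedged sketch of that pattern, with every name invented for illustration:

#[derive(Clone, Copy, Debug)]
enum BehaviorVersionLike {
    V2023,
    Latest,
}

#[allow(dead_code)]
#[deprecated(note = "didn't take a behavior version; call the versioned form instead")]
fn default_upload_protection() -> bool {
    // Pin the old behavior so existing callers observe no change.
    default_upload_protection_versioned(BehaviorVersionLike::V2023)
}

fn default_upload_protection_versioned(version: BehaviorVersionLike) -> bool {
    match version {
        // Upload protection currently stays opt-in for every version; a future
        // behavior version could flip this default without breaking the shim.
        BehaviorVersionLike::V2023 | BehaviorVersionLike::Latest => false,
    }
}

fn main() {
    assert!(!default_upload_protection_versioned(BehaviorVersionLike::Latest));
}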
diff --git a/rust-runtime/aws-smithy-runtime-api/Cargo.toml b/rust-runtime/aws-smithy-runtime-api/Cargo.toml index b0ee8e1e6b..ea9b6e970e 100644 --- a/rust-runtime/aws-smithy-runtime-api/Cargo.toml +++ b/rust-runtime/aws-smithy-runtime-api/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-runtime-api" -version = "1.2.0" +version = "1.3.0" authors = ["AWS Rust SDK Team ", "Zelda Hessler "] description = "Smithy runtime types." edition = "2021" From e25e049f9f139ca2eafdf2c868d2bff057138682 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Mon, 25 Mar 2024 15:12:24 -0700 Subject: [PATCH 14/19] Rename upload check futures --- .../src/client/http/body/minimum_throughput.rs | 16 ++++++++-------- .../src/client/orchestrator.rs | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs index 05973c336b..59c8a3c64c 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs @@ -213,7 +213,7 @@ impl UploadReport for ThroughputReport { pin_project_lite::pin_project! { /// Future that pairs with [`UploadThroughput`] to add a minimum throughput /// requirement to a request upload stream. - struct ThroughputCheckFuture { + struct UploadThroughputCheckFuture { #[pin] response: HttpConnectorFuture, #[pin] @@ -231,7 +231,7 @@ pin_project_lite::pin_project! { } } -impl ThroughputCheckFuture { +impl UploadThroughputCheckFuture { fn new( response: HttpConnectorFuture, time_source: SharedTimeSource, @@ -254,7 +254,7 @@ impl ThroughputCheckFuture { } } -impl Future for ThroughputCheckFuture { +impl Future for UploadThroughputCheckFuture { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -332,13 +332,13 @@ impl Future for ThroughputCheckFuture { pin_project_lite::pin_project! 
{ #[project = EnumProj] - pub(crate) enum MaybeThroughputCheckFuture { + pub(crate) enum MaybeUploadThroughputCheckFuture { Direct { #[pin] future: HttpConnectorFuture }, - Checked { #[pin] future: ThroughputCheckFuture }, + Checked { #[pin] future: UploadThroughputCheckFuture }, } } -impl MaybeThroughputCheckFuture { +impl MaybeUploadThroughputCheckFuture { pub(crate) fn new( cfg: &mut ConfigBag, components: &RuntimeComponents, @@ -371,7 +371,7 @@ impl MaybeThroughputCheckFuture { (Some(time_source), Some(sleep_impl), Some(upload_throughput), Some(options)) => { tracing::debug!(options=?options, "applying minimum upload throughput check future"); Self::Checked { - future: ThroughputCheckFuture::new( + future: UploadThroughputCheckFuture::new( response, time_source, sleep_impl, @@ -385,7 +385,7 @@ impl MaybeThroughputCheckFuture { } } -impl Future for MaybeThroughputCheckFuture { +impl Future for MaybeUploadThroughputCheckFuture { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index 0a8088a01a..112fbe85ba 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -8,7 +8,7 @@ use crate::client::interceptors::Interceptors; use crate::client::orchestrator::http::{log_response_body, read_body}; use crate::client::timeout::{MaybeTimeout, MaybeTimeoutConfig, TimeoutKind}; use crate::client::{ - http::body::minimum_throughput::MaybeThroughputCheckFuture, + http::body::minimum_throughput::MaybeUploadThroughputCheckFuture, orchestrator::endpoints::orchestrate_endpoint, }; use aws_smithy_async::rt::sleep::AsyncSleep; @@ -388,7 +388,7 @@ async fn try_attempt( builder.build() }; let connector = http_client.http_connector(&settings, runtime_components); - let response_future = MaybeThroughputCheckFuture::new( + let response_future = MaybeUploadThroughputCheckFuture::new( cfg, runtime_components, connector.call(request), From 3020312a54699c62c5d2c2b3b02b248d86e52376 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Mon, 25 Mar 2024 16:29:12 -0700 Subject: [PATCH 15/19] Simplify bin merging logic --- .../http/body/minimum_throughput/options.rs | 12 ++- .../body/minimum_throughput/throughput.rs | 99 +++++++++++-------- 2 files changed, 67 insertions(+), 44 deletions(-) diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs index 79d2ec5063..113461a31e 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs @@ -129,14 +129,18 @@ impl MinimumThroughputBodyOptionsBuilder { self } - /// No longer used. - #[deprecated(note = "No longer used.")] + /// No longer used. The check interval is now based on the check window (not currently configurable). + #[deprecated( + note = "No longer used. The check interval is now based on the check window (not currently configurable). Open an issue if you need to configure the check window." + )] pub fn check_interval(self, _check_interval: Duration) -> Self { self } - /// No longer used. - #[deprecated(note = "No longer used.")] + /// No longer used. The check interval is now based on the check window (not currently configurable). + #[deprecated( + note = "No longer used. 
The check interval is now based on the check window (not currently configurable). Open an issue if you need to configure the check window." + )] pub fn set_check_interval(&mut self, _check_interval: Option) -> &mut Self { self } diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs index 0c31a079cc..57ea3318e7 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs @@ -97,9 +97,12 @@ impl From<(u64, Duration)> for Throughput { } } -/// Represents a bin (or a cell) in a linear grid that represents a small chunk of time. -#[derive(Copy, Clone, Debug)] -enum Bin { +/// Overall label for a given bin. +#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] +enum BinLabel { + // IMPORTANT: The order of these enums matters since it represents their priority: + // Pending > TransferredBytes > NoPolling > Empty + // /// There is no data in this bin. Empty, @@ -107,37 +110,51 @@ enum Bin { NoPolling, /// This many bytes were transferred during this bin. - TransferredBytes(u64), + TransferredBytes, /// The user/remote was not providing/consuming data fast enough during this bin. /// /// The number is the number of bytes transferred, if this replaced TransferredBytes. - Pending(u64), + Pending, +} + +/// Represents a bin (or a cell) in a linear grid that represents a small chunk of time. +#[derive(Copy, Clone, Debug)] +struct Bin { + label: BinLabel, + bytes: u64, } + impl Bin { - fn merge(&mut self, other: Bin) { - use Bin::*; + const fn new(label: BinLabel, bytes: u64) -> Self { + Self { label, bytes } + } + const fn empty() -> Self { + Self::new(BinLabel::Empty, 0) + } + + fn is_empty(&self) -> bool { + matches!(self.label, BinLabel::Empty) + } + + fn merge(&mut self, other: Bin) -> &mut Self { // Assign values based on this priority order (highest priority higher up): // 1. Pending // 2. TransferredBytes // 3. NoPolling // 4. Empty - *self = match (other, *self) { - (Pending(other), this) => Pending(other + this.bytes()), - (TransferredBytes(other), this) => TransferredBytes(other + this.bytes()), - (other, NoPolling) => other, - (NoPolling, _) => panic!("can't merge NoPolling into bin"), - (other, Empty) => other, - (Empty, _) => panic!("can't merge Empty into bin"), + self.label = if other.label > self.label { + other.label + } else { + self.label }; + self.bytes += other.bytes; + self } /// Number of bytes transferred during this bin fn bytes(&self) -> u64 { - match *self { - Bin::Empty | Bin::NoPolling => 0, - Bin::TransferredBytes(bytes) | Bin::Pending(bytes) => bytes, - } + self.bytes } } @@ -167,7 +184,7 @@ struct LogBuffer { impl LogBuffer { fn new() -> Self { Self { - entries: [Bin::Empty; N], + entries: [Bin::empty(); N], length: 0, } } @@ -176,7 +193,7 @@ impl LogBuffer { /// /// ## Panics /// - /// The buffer MUST have at least one bin it it before this is called. + /// The buffer MUST have at least one bin in it before this is called. fn tail_mut(&mut self) -> &mut Bin { debug_assert!(self.length > 0); &mut self.entries[self.length - 1] @@ -211,8 +228,8 @@ impl LogBuffer { /// way we can know about these is by examining gaps in time. 
fn fill_gaps(&mut self) { for entry in self.entries.iter_mut().take(self.length) { - if matches!(entry, Bin::Empty) { - *entry = Bin::NoPolling; + if entry.is_empty() { + *entry = Bin::new(BinLabel::NoPolling, 0); } } } @@ -221,11 +238,11 @@ impl LogBuffer { fn counts(&self) -> BinCounts { let mut counts = BinCounts::default(); for entry in &self.entries { - match entry { - Bin::Empty => counts.empty += 1, - Bin::NoPolling => counts.no_polling += 1, - Bin::TransferredBytes(_) => counts.transferred += 1, - Bin::Pending(_) => counts.pending += 1, + match entry.label { + BinLabel::Empty => counts.empty += 1, + BinLabel::NoPolling => counts.no_polling += 1, + BinLabel::TransferredBytes => counts.transferred += 1, + BinLabel::Pending => counts.pending += 1, } } counts @@ -245,7 +262,7 @@ pub(crate) enum ThroughputReport { Transferred(Throughput), } -const BUFFER_SIZE: usize = 10; +const BIN_COUNT: usize = 10; /// Log of throughput in a request or response stream. /// @@ -267,20 +284,20 @@ const BUFFER_SIZE: usize = 10; pub(super) struct ThroughputLogs { resolution: Duration, current_tail: SystemTime, - buffer: LogBuffer, + buffer: LogBuffer, } impl ThroughputLogs { /// Creates a new log starting at `now` with the given `time_window`. /// /// Note: the `time_window` gets divided by 10 to create smaller sub-windows - /// to track throughput. The time window should configured to be large enough + /// to track throughput. The time window should be configured to be large enough /// so that these sub-windows aren't too small for network-based events. /// A time window of 10ms probably won't work, but 500ms might. The default /// is one second. pub(super) fn new(time_window: Duration, now: SystemTime) -> Self { assert!(!time_window.is_zero()); - let resolution = time_window.div_f64(BUFFER_SIZE as f64); + let resolution = time_window.div_f64(BIN_COUNT as f64); Self { resolution, current_tail: now, @@ -289,6 +306,8 @@ impl ThroughputLogs { } /// Returns the resolution at which events are logged at. + /// + /// The resolution is the number of bins in the time window. pub(super) fn resolution(&self) -> Duration { self.resolution } @@ -299,14 +318,14 @@ impl ThroughputLogs { /// In an upload, it is waiting for data from the user, and in a download, /// it is waiting for data from the server. pub(super) fn push_pending(&mut self, time: SystemTime) { - self.push(time, Bin::Pending(0)); + self.push(time, Bin::new(BinLabel::Pending, 0)); } /// Pushes a data transferred event. /// /// Indicates that this number of bytes were transferred at this time. 
pub(super) fn push_bytes_transferred(&mut self, time: SystemTime, bytes: u64) { - self.push(time, Bin::TransferredBytes(bytes)); + self.push(time, Bin::new(BinLabel::TransferredBytes, bytes)); } fn push(&mut self, now: SystemTime, value: Bin) { @@ -319,7 +338,7 @@ impl ThroughputLogs { fn catch_up(&mut self, now: SystemTime) { while now >= self.current_tail { self.current_tail += self.resolution; - self.buffer.push(Bin::Empty); + self.buffer.push(Bin::empty()); } assert!(self.current_tail >= now); } @@ -343,10 +362,10 @@ impl ThroughputLogs { } let bytes = self.buffer.bytes_transferred(); - let time = self.resolution * (BUFFER_SIZE - empty) as u32; + let time = self.resolution * (BIN_COUNT - empty) as u32; let throughput = Throughput::new(bytes, time); - let half = BUFFER_SIZE / 2; + let half = BIN_COUNT / 2; match (transferred > 0, no_polling >= half, pending >= half) { (true, _, _) => ThroughputReport::Transferred(throughput), (_, true, _) => ThroughputReport::NoPolling, @@ -406,7 +425,7 @@ mod test { let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); let mut now = start; - for i in 1..=BUFFER_SIZE { + for i in 1..=BIN_COUNT { logs.push_pending(now); now += logs.resolution(); @@ -423,7 +442,7 @@ mod test { let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); let mut now = start; - for i in 1..BUFFER_SIZE { + for i in 1..BIN_COUNT { now += logs.resolution(); logs.push_pending(now); @@ -446,9 +465,9 @@ mod test { let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); let mut now = start; - for i in 1..=BUFFER_SIZE { + for i in 1..=BIN_COUNT { logs.push_bytes_transferred(now, 10); - if i != BUFFER_SIZE { + if i != BIN_COUNT { now += logs.resolution(); } From adcb87f5aefc93aae25b5c82d863df4e3251bcac Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Mon, 25 Mar 2024 16:48:17 -0700 Subject: [PATCH 16/19] Create a new `show_test_logs` fn --- .../src/test_util/capture_test_logs.rs | 53 +++++++++++++------ .../tests/stalled_stream_common.rs | 2 +- .../tests/stalled_stream_download.rs | 12 ++--- .../tests/stalled_stream_upload.rs | 14 ++--- 4 files changed, 50 insertions(+), 31 deletions(-) diff --git a/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs b/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs index b55c0c68ac..2046b8c3a7 100644 --- a/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs +++ b/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs @@ -12,7 +12,37 @@ use tracing_subscriber::fmt::TestWriter; /// A guard that resets log capturing upon being dropped. #[derive(Debug)] -pub struct LogCaptureGuard(#[allow(dead_code)] DefaultGuard); +pub struct LogCaptureGuard(#[allow(dead_code)] Option); + +/// Enables output of test logs to stdout. +/// +/// The `VERBOSE_TEST_LOGS` environment variable acts as a +/// tracing_subscriber fmt env filter. You can give it full env filter +/// expressions, or just simply give it a log level (e.g., tracing, debug, info, etc). +/// Setting it to "1" or "true" will enable trace logging. 
+#[must_use] +pub fn show_test_logs() -> LogCaptureGuard { + let (mut writer, _rx) = Tee::stdout(); + let env_var = env::var("VERBOSE_TEST_LOGS").ok(); + let env_filter = match env_var.as_deref() { + Some("true") | Some("1") => Some("trace"), + Some(filter) => Some(filter), + None => None, + }; + if let Some(env_filter) = env_filter { + eprintln!("Enabled verbose test logging with env filter {env_filter:?}."); + writer.loud(); + + let subscriber = tracing_subscriber::fmt() + .with_env_filter(env_filter) + .with_writer(Mutex::new(writer)) + .finish(); + let guard = tracing::subscriber::set_default(subscriber); + LogCaptureGuard(Some(guard)) + } else { + LogCaptureGuard(None) + } +} /// Capture logs from this test. /// @@ -24,29 +54,18 @@ pub struct LogCaptureGuard(#[allow(dead_code)] DefaultGuard); pub fn capture_test_logs() -> (LogCaptureGuard, Rx) { // it may be helpful to upstream this at some point let (mut writer, rx) = Tee::stdout(); - let (enabled, level) = match env::var("VERBOSE_TEST_LOGS").ok().as_deref() { - Some("debug") => (true, Level::DEBUG), - Some("error") => (true, Level::ERROR), - Some("info") => (true, Level::INFO), - Some("warn") => (true, Level::WARN), - Some("trace") | Some(_) => (true, Level::TRACE), - None => (false, Level::TRACE), - }; - if enabled { - eprintln!("Enabled verbose test logging at {level:?}."); + if env::var("VERBOSE_TEST_LOGS").is_ok() { + eprintln!("Enabled verbose test logging."); writer.loud(); } else { - eprintln!( - "To see full logs from this test set VERBOSE_TEST_LOGS=true \ - (or to a log level, e.g., trace, debug, info, etc)" - ); + eprintln!("To see full logs from this test set VERBOSE_TEST_LOGS=true"); } let subscriber = tracing_subscriber::fmt() - .with_max_level(level) + .with_max_level(Level::TRACE) .with_writer(Mutex::new(writer)) .finish(); let guard = tracing::subscriber::set_default(subscriber); - (LogCaptureGuard(guard), rx) + (LogCaptureGuard(Some(guard)), rx) } /// Receiver for the captured logs. diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs index 5559d382cf..3596fa2e38 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs @@ -17,7 +17,7 @@ pub use aws_smithy_runtime::{ orchestrator::operation::Operation, stalled_stream_protection::StalledStreamProtectionInterceptor, }, - test_util::capture_test_logs::capture_test_logs, + test_util::capture_test_logs::show_test_logs, }; pub use aws_smithy_runtime_api::{ box_error::BoxError, diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs index 01af3b51ce..54e953322c 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs @@ -15,7 +15,7 @@ use stalled_stream_common::*; /// Expected: MUST NOT timeout. #[tokio::test] async fn download_success() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (time, sleep) = tick_advance_time_and_sleep(); let (server, response_sender) = channel_server(); @@ -41,7 +41,7 @@ async fn download_success() { /// Expected: MUT NOT timeout. 
#[tokio::test] async fn download_slow_start() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (time, sleep) = tick_advance_time_and_sleep(); let (server, response_sender) = channel_server(); @@ -69,7 +69,7 @@ async fn download_slow_start() { /// Expected: MUST timeout. #[tokio::test] async fn download_too_slow() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (time, sleep) = tick_advance_time_and_sleep(); let (server, response_sender) = channel_server(); @@ -100,7 +100,7 @@ async fn download_too_slow() { /// Expected: MUST timeout. #[tokio::test] async fn download_stalls() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (time, sleep) = tick_advance_time_and_sleep(); let (server, response_sender) = channel_server(); @@ -133,7 +133,7 @@ async fn download_stalls() { /// Expected: MUST NOT timeout. #[tokio::test] async fn download_stall_recovery_in_grace_period() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (time, sleep) = tick_advance_time_and_sleep(); let (server, response_sender) = channel_server(); @@ -167,7 +167,7 @@ async fn download_stall_recovery_in_grace_period() { /// Expected: MUST NOT timeout. #[tokio::test] async fn user_downloads_data_too_slowly() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (time, sleep) = tick_advance_time_and_sleep(); let (server, response_sender) = channel_server(); diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs index 0fb4533e78..f64fa321b2 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs @@ -13,7 +13,7 @@ use stalled_stream_common::*; /// Expected: MUST NOT timeout. #[tokio::test] async fn upload_success() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (server, time, sleep) = eager_server(true); let op = operation(server, time, sleep); @@ -33,7 +33,7 @@ async fn upload_success() { /// Expected: MUST NOT timeout. #[tokio::test] async fn upload_slow_start() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (server, time, sleep) = eager_server(false); let op = operation(server, time.clone(), sleep); @@ -61,7 +61,7 @@ async fn upload_slow_start() { /// Expected: MUST timeout. #[tokio::test] async fn upload_too_slow() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); // Server that starts off fast enough, but gets slower over time until it should timeout. let (server, time, sleep) = time_sequence_server([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); @@ -86,7 +86,7 @@ async fn upload_too_slow() { /// Expected: MUST timeout after the grace period completes. #[tokio::test] async fn upload_stalls() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (server, time, sleep) = stalling_server(); let op = operation(server, time.clone(), sleep); @@ -112,7 +112,7 @@ async fn upload_stalls() { /// Expected: MUST timeout after the grace period completes. #[tokio::test] async fn complete_upload_no_response() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (server, time, sleep) = stalling_server(); let op = operation(server, time.clone(), sleep); @@ -136,7 +136,7 @@ async fn complete_upload_no_response() { // Expected: MUST NOT timeout. 
#[tokio::test] async fn upload_stall_recovery_in_grace_period() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); // Server starts off fast enough, but then slows down almost up to // the grace period, and then recovers. @@ -163,7 +163,7 @@ async fn upload_stall_recovery_in_grace_period() { // Expected: MUST NOT timeout. #[tokio::test] async fn user_provides_data_too_slowly() { - let _logs = capture_test_logs(); + let _logs = show_test_logs(); let (server, time, sleep) = eager_server(false); let op = operation(server, time.clone(), sleep.clone()); From d960f504f45323e7713e255851764571d28311b2 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Mon, 25 Mar 2024 16:54:21 -0700 Subject: [PATCH 17/19] Improve the changelog entries --- CHANGELOG.next.toml | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 560279ffd9..468aed73c5 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -12,13 +12,25 @@ # author = "rcoh" [[smithy-rs]] -message = "Stalled stream protection now supports request upload streams, and existing download stream protection no longer triggers when user-code is doing something other than downloading in the middle of the stream (for example, making a HTTP request to another service for each row in CSV file download from S3)." +message = "Stalled stream protection now supports request upload streams." references = ["smithy-rs#3485"] -meta = { "breaking" = false, "tada" = true, "bug" = true } +meta = { "breaking" = false, "tada" = true, "bug" = false } authors = ["jdisanti"] [[aws-sdk-rust]] -message = "Stalled stream protection now supports request upload streams, and existing download stream protection no longer triggers when user-code is doing something other than downloading in the middle of the stream (for example, making a HTTP request to another service for each row in CSV file download from S3)." +message = "Stalled stream protection now supports request upload streams." references = ["smithy-rs#3485"] -meta = { "breaking" = false, "tada" = true, "bug" = true } +meta = { "breaking" = false, "tada" = true, "bug" = false } +author = "jdisanti" + +[[smithy-rs]] +message = "Stalled stream protection for download streams no longer triggers when user-code is taking a significant amount of time doing something other than downloading in the middle of the stream. As an example, previously, making a HTTP request to another service for each row in CSV file download from S3 would result in a stalled stream timeout." +references = ["smithy-rs#3485"] +meta = { "breaking" = false, "tada" = false, "bug" = true } +authors = ["jdisanti"] + +[[aws-sdk-rust]] +message = "Stalled stream protection for download streams no longer triggers when user-code is taking a significant amount of time doing something other than downloading in the middle of the stream." 
+references = ["smithy-rs#3485"] +meta = { "breaking" = false, "tada" = false, "bug" = true } author = "jdisanti" From 3ca765afc82e63e4142bb3675157bc3adaf571d4 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Wed, 27 Mar 2024 11:04:31 -0700 Subject: [PATCH 18/19] Incorporate feedback --- CHANGELOG.next.toml | 26 ++++++-- .../s3/tests/body_size_hint.rs | 63 +++++-------------- .../src/test_util/capture_test_logs.rs | 43 ++++++------- 3 files changed, 55 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 669e64e715..545660b853 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -30,25 +30,43 @@ meta = { "breaking" = true, "tada" = false, "bug" = true, "target" = "client" } author = "Ten0" [[smithy-rs]] -message = "Stalled stream protection now supports request upload streams." +message = """ +Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following: + +```rust +let config = my_service::Config::builder() + .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) + // ... + .build(); +``` +""" references = ["smithy-rs#3485"] meta = { "breaking" = false, "tada" = true, "bug" = false } authors = ["jdisanti"] [[aws-sdk-rust]] -message = "Stalled stream protection now supports request upload streams." +message = """ +Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following: + +```rust +let config = aws_config::defaults(BehaviorVersion::latest()) + .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) + .load() + .await; +``` +""" references = ["smithy-rs#3485"] meta = { "breaking" = false, "tada" = true, "bug" = false } author = "jdisanti" [[smithy-rs]] -message = "Stalled stream protection for download streams no longer triggers when user-code is taking a significant amount of time doing something other than downloading in the middle of the stream. As an example, previously, making a HTTP request to another service for each row in CSV file download from S3 would result in a stalled stream timeout." +message = "Stalled stream protection on downloads will now only trigger if the upstream source is too slow. Previously, stalled stream protection could be erroneously triggered if the user was slowly consuming the stream slower than the minimum speed limit." references = ["smithy-rs#3485"] meta = { "breaking" = false, "tada" = false, "bug" = true } authors = ["jdisanti"] [[aws-sdk-rust]] -message = "Stalled stream protection for download streams no longer triggers when user-code is taking a significant amount of time doing something other than downloading in the middle of the stream." +message = "Stalled stream protection on downloads will now only trigger if the upstream source is too slow. Previously, stalled stream protection could be erroneously triggered if the user was slowly consuming the stream slower than the minimum speed limit." 
references = ["smithy-rs#3485"] meta = { "breaking" = false, "tada" = false, "bug" = true } author = "jdisanti" diff --git a/aws/sdk/integration-tests/s3/tests/body_size_hint.rs b/aws/sdk/integration-tests/s3/tests/body_size_hint.rs index 270a87f902..97e9ac7234 100644 --- a/aws/sdk/integration-tests/s3/tests/body_size_hint.rs +++ b/aws/sdk/integration-tests/s3/tests/body_size_hint.rs @@ -7,61 +7,29 @@ use aws_config::SdkConfig; use aws_sdk_s3::{ - config::{Credentials, HttpClient, Region, RuntimeComponents, SharedCredentialsProvider}, + config::{Credentials, Region, SharedCredentialsProvider}, primitives::{ByteStream, SdkBody}, Client, }; -use aws_smithy_runtime_api::{ - client::{ - http::{HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector}, - orchestrator::HttpRequest, - }, - http::{Response, StatusCode}, -}; +use aws_smithy_runtime::client::http::test_util::{capture_request, infallible_client_fn}; use http_body::Body; -use std::sync::{Arc, Mutex}; - -#[derive(Clone, Debug, Default)] -struct TestClient { - response_body: Arc>>, - captured_body: Arc>>, -} -impl HttpConnector for TestClient { - fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture { - *self.captured_body.lock().unwrap() = Some(request.take_body()); - let body = self - .response_body - .lock() - .unwrap() - .take() - .unwrap_or_else(SdkBody::empty); - HttpConnectorFuture::ready(Ok(Response::new(StatusCode::try_from(200).unwrap(), body))) - } -} -impl HttpClient for TestClient { - fn http_connector( - &self, - _settings: &HttpConnectorSettings, - _components: &RuntimeComponents, - ) -> SharedHttpConnector { - SharedHttpConnector::new(self.clone()) - } -} #[tokio::test] async fn download_body_size_hint_check() { let test_body_content = b"hello"; - let test_body = SdkBody::from(&test_body_content[..]); + let test_body = || SdkBody::from(&test_body_content[..]); assert_eq!( Some(test_body_content.len() as u64), - test_body.size_hint().exact(), + (test_body)().size_hint().exact(), "pre-condition check" ); - let http_client = TestClient { - response_body: Arc::new(Mutex::new(Some(test_body))), - ..Default::default() - }; + let http_client = infallible_client_fn(move |_| { + http::Response::builder() + .status(200) + .body((test_body)()) + .unwrap() + }); let sdk_config = SdkConfig::builder() .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) .region(Region::new("us-east-1")) @@ -89,11 +57,11 @@ async fn download_body_size_hint_check() { async fn upload_body_size_hint_check() { let test_body_content = b"hello"; - let http_client = TestClient::default(); + let (http_client, rx) = capture_request(None); let sdk_config = SdkConfig::builder() .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) .region(Region::new("us-east-1")) - .http_client(http_client.clone()) + .http_client(http_client) .build(); let client = Client::new(&sdk_config); let body = ByteStream::from_static(test_body_content); @@ -111,12 +79,11 @@ async fn upload_body_size_hint_check() { .key("foo") .body(body) .send() - .await - .unwrap(); - let captured_body = http_client.captured_body.lock().unwrap().take().unwrap(); + .await; + let captured_request = rx.expect_request(); assert_eq!( Some(test_body_content.len() as u64), - captured_body.size_hint().exact(), + captured_request.body().size_hint().exact(), "the size hint should be passed through all the default body wrappers" ); } diff --git a/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs 
b/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs index 2046b8c3a7..d5447c98ab 100644 --- a/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs +++ b/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs @@ -12,36 +12,29 @@ use tracing_subscriber::fmt::TestWriter; /// A guard that resets log capturing upon being dropped. #[derive(Debug)] -pub struct LogCaptureGuard(#[allow(dead_code)] Option); +pub struct LogCaptureGuard(#[allow(dead_code)] DefaultGuard); -/// Enables output of test logs to stdout. +/// Enables output of test logs to stdout at trace level by default. /// -/// The `VERBOSE_TEST_LOGS` environment variable acts as a -/// tracing_subscriber fmt env filter. You can give it full env filter -/// expressions, or just simply give it a log level (e.g., tracing, debug, info, etc). -/// Setting it to "1" or "true" will enable trace logging. +/// The env filter can be changed with the `RUST_LOG` environment variable. #[must_use] pub fn show_test_logs() -> LogCaptureGuard { let (mut writer, _rx) = Tee::stdout(); - let env_var = env::var("VERBOSE_TEST_LOGS").ok(); - let env_filter = match env_var.as_deref() { - Some("true") | Some("1") => Some("trace"), - Some(filter) => Some(filter), - None => None, - }; - if let Some(env_filter) = env_filter { - eprintln!("Enabled verbose test logging with env filter {env_filter:?}."); - writer.loud(); + writer.loud(); - let subscriber = tracing_subscriber::fmt() - .with_env_filter(env_filter) - .with_writer(Mutex::new(writer)) - .finish(); - let guard = tracing::subscriber::set_default(subscriber); - LogCaptureGuard(Some(guard)) - } else { - LogCaptureGuard(None) - } + let env_var = env::var("RUST_LOG").ok(); + let env_filter = env_var.as_deref().unwrap_or("trace"); + eprintln!( + "Enabled verbose test logging with env filter {env_filter:?}. \ + You can change the env filter with the RUST_LOG environment variable." + ); + + let subscriber = tracing_subscriber::fmt() + .with_env_filter(env_filter) + .with_writer(Mutex::new(writer)) + .finish(); + let guard = tracing::subscriber::set_default(subscriber); + LogCaptureGuard(guard) } /// Capture logs from this test. @@ -65,7 +58,7 @@ pub fn capture_test_logs() -> (LogCaptureGuard, Rx) { .with_writer(Mutex::new(writer)) .finish(); let guard = tracing::subscriber::set_default(subscriber); - (LogCaptureGuard(Some(guard)), rx) + (LogCaptureGuard(guard), rx) } /// Receiver for the captured logs. From e7d7f442d2fd398e984aa2afec8cd8a77dfb2209 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Wed, 27 Mar 2024 11:15:17 -0700 Subject: [PATCH 19/19] Fix feature issue --- rust-runtime/aws-smithy-runtime/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-runtime/Cargo.toml b/rust-runtime/aws-smithy-runtime/Cargo.toml index 25b051cd15..457dbf4d25 100644 --- a/rust-runtime/aws-smithy-runtime/Cargo.toml +++ b/rust-runtime/aws-smithy-runtime/Cargo.toml @@ -43,7 +43,7 @@ serde_json = { version = "1", features = ["preserve_order"], optional = true } indexmap = { version = "2", optional = true, features = ["serde"] } tokio = { version = "1.25", features = [] } tracing = "0.1.37" -tracing-subscriber = { version = "0.3.16", optional = true, features = ["fmt", "json"] } +tracing-subscriber = { version = "0.3.16", optional = true, features = ["env-filter", "fmt", "json"] } [dev-dependencies] approx = "0.5.1"
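
Note on the bin-merging change in PATCH 15 ("Simplify bin merging logic"): the rewritten `Bin::merge` leans entirely on the declaration order of `BinLabel` (Empty < NoPolling < TransferredBytes < Pending) via the derived `Ord`. Below is a minimal standalone sketch of that idea — the types are redefined here purely for illustration and are not part of the patches, and the patch itself uses an explicit `if` comparison rather than `Ord::max`:

```rust
// Standalone sketch: priority-based bin merging, mirroring the idea in PATCH 15.
// `BinLabel` and `Bin` are redefined here for illustration only.

#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
enum BinLabel {
    // Declaration order encodes priority: Pending > TransferredBytes > NoPolling > Empty.
    Empty,
    NoPolling,
    TransferredBytes,
    Pending,
}

#[derive(Copy, Clone, Debug)]
struct Bin {
    label: BinLabel,
    bytes: u64,
}

impl Bin {
    fn merge(&mut self, other: Bin) -> &mut Self {
        // The derived `Ord` follows declaration order, so the higher-priority
        // label wins and the byte counts simply accumulate.
        self.label = self.label.max(other.label);
        self.bytes += other.bytes;
        self
    }
}

fn main() {
    let mut bin = Bin { label: BinLabel::TransferredBytes, bytes: 128 };
    bin.merge(Bin { label: BinLabel::Pending, bytes: 0 });
    assert_eq!(BinLabel::Pending, bin.label); // Pending outranks TransferredBytes
    assert_eq!(128, bin.bytes); // byte counts are preserved across the merge
    println!("{bin:?}");
}
```

Compared with the earlier match-based merge, this ordering-driven approach removes the panicking arms for merging `NoPolling`/`Empty` into a bin and makes the byte accumulation unconditional.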