From 6b308e436de27b0770b3b736361d4bd91bd7742d Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Tue, 8 Aug 2023 17:17:29 -0500 Subject: [PATCH 01/36] Remove `futures_core::stream::Stream` from `smithy-async` --- rust-runtime/aws-smithy-async/Cargo.toml | 2 +- .../aws-smithy-async/external-types.toml | 3 - .../aws-smithy-async/src/future/fn_stream.rs | 202 +++++++++++------- .../src/future/fn_stream/collect.rs | 83 +++++++ .../aws-smithy-async/src/future/rendezvous.rs | 20 +- 5 files changed, 218 insertions(+), 92 deletions(-) create mode 100644 rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs diff --git a/rust-runtime/aws-smithy-async/Cargo.toml b/rust-runtime/aws-smithy-async/Cargo.toml index c95862d9ff..6b3bca322d 100644 --- a/rust-runtime/aws-smithy-async/Cargo.toml +++ b/rust-runtime/aws-smithy-async/Cargo.toml @@ -14,10 +14,10 @@ test-util = [] [dependencies] pin-project-lite = "0.2" tokio = { version = "1.23.1", features = ["sync"] } -tokio-stream = { version = "0.1.5", default-features = false } futures-util = { version = "0.3.16", default-features = false } [dev-dependencies] +pin-utils = "0.1" tokio = { version = "1.23.1", features = ["rt", "macros", "test-util"] } tokio-test = "0.4.2" diff --git a/rust-runtime/aws-smithy-async/external-types.toml b/rust-runtime/aws-smithy-async/external-types.toml index 424f7dc1db..464456a2dc 100644 --- a/rust-runtime/aws-smithy-async/external-types.toml +++ b/rust-runtime/aws-smithy-async/external-types.toml @@ -2,7 +2,4 @@ allowed_external_types = [ "aws_smithy_types::config_bag::storable::Storable", "aws_smithy_types::config_bag::storable::StoreReplace", "aws_smithy_types::config_bag::storable::Storer", - - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Switch to AsyncIterator once standardized - "futures_core::stream::Stream", ] diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs index 804b08f6bb..75a4b375cd 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs @@ -6,12 +6,14 @@ //! Utility to drive a stream with an async function and a channel. use crate::future::rendezvous; -use futures_util::StreamExt; use pin_project_lite::pin_project; +use std::fmt; +use std::future::poll_fn; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio_stream::{Iter, Once, Stream}; + +pub mod collect; pin_project! { /// Utility to drive a stream with an async function and a channel. @@ -26,10 +28,9 @@ pin_project! { /// /// # Examples /// ```no_run - /// use tokio_stream::StreamExt; /// # async fn docs() { /// use aws_smithy_async::future::fn_stream::FnStream; - /// let stream = FnStream::new(|tx| Box::pin(async move { + /// let mut stream = FnStream::new(|tx| Box::pin(async move { /// if let Err(_) = tx.send("Hello!").await { /// return; /// } @@ -39,52 +40,86 @@ pin_project! { /// })); /// assert_eq!(stream.collect::>().await, vec!["Hello!", "Goodbye!"]); /// # } - pub struct FnStream { + pub struct FnStream { #[pin] rx: rendezvous::Receiver, - #[pin] - generator: Option, + generator: Option + Send + 'static>>>, + } +} + +impl fmt::Debug for FnStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let item_typename = std::any::type_name::(); + write!(f, "FnStream<{item_typename}>") } } -impl FnStream { +impl FnStream { /// Creates a new function based stream driven by `generator`. /// /// For examples, see the documentation for [`FnStream`] pub fn new(generator: T) -> Self where - T: FnOnce(rendezvous::Sender) -> F, + T: FnOnce(rendezvous::Sender) -> Pin + Send + 'static>>, { let (tx, rx) = rendezvous::channel::(); Self { rx, - generator: Some(generator(tx)), + generator: Some(Box::pin(generator(tx))), } } -} -impl Stream for FnStream -where - F: Future, -{ - type Item = Item; + /// Creates unreadable `FnStream` but useful to pass to `std::mem::swap` when extracting an + /// owned `FnStream` from a mutable reference. + pub fn taken() -> Self { + Self::new(|_tx| Box::pin(async move {})) + } + + /// Consumes and returns the next `Item` from this stream. + pub async fn next(&mut self) -> Option + where + Self: Unpin, + { + let mut me = Pin::new(self); + poll_fn(|cx| me.as_mut().poll_next(cx)).await + } - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + /// Attempts to pull out the next value of this stream, returning `None` if the stream is + /// exhausted. + pub fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut me = self.project(); match me.rx.poll_recv(cx) { Poll::Ready(item) => Poll::Ready(item), Poll::Pending => { - if let Some(generator) = me.generator.as_mut().as_pin_mut() { - if generator.poll(cx).is_ready() { + if let Some(generator) = me.generator { + if generator.as_mut().poll(cx).is_ready() { // if the generator returned ready we MUST NOT poll it again—doing so // will cause a panic. - me.generator.set(None); + *me.generator = None; } } Poll::Pending } } } + + /// Consumes this stream and gathers elements into a collection. + pub async fn collect>(mut self) -> T { + let mut collection = T::initialize(); + while let Some(item) = self.next().await { + if !T::extend(&mut collection, item) { + break; + } + } + T::finalize(&mut collection) + } +} + +impl FnStream> { + /// Yields the next item in the stream or returns an error if an error is encountered. + pub async fn try_next(&mut self) -> Result, E> { + self.next().await.transpose() + } } /// Utility wrapper to flatten paginated results @@ -93,62 +128,50 @@ where /// is present in each item. This provides `items()` which can wrap an stream of `Result` /// and produce a stream of `Result`. #[derive(Debug)] -pub struct TryFlatMap(I); +pub struct TryFlatMap(FnStream>); -impl TryFlatMap { - /// Create a `TryFlatMap` that wraps the input - pub fn new(i: I) -> Self { - Self(i) +impl TryFlatMap { + /// Creates a `TryFlatMap` that wraps the input. + pub fn new(stream: FnStream>) -> Self { + Self(stream) } - /// Produce a new [`Stream`] by mapping this stream with `map` then flattening the result - pub fn flat_map(self, map: M) -> impl Stream> + /// Produces a new [`FnStream`] by mapping this stream with `map` then flattening the result. + pub fn flat_map(mut self, map: M) -> FnStream> where - I: Stream>, - M: Fn(Page) -> Iter, - Iter: IntoIterator, + Page: Send + 'static, + Err: Send + 'static, + M: Fn(Page) -> Iter + Send + 'static, + Item: Send + 'static, + Iter: IntoIterator + Send, + ::IntoIter: Send, { - self.0.flat_map(move |page| match page { - Ok(page) => OnceOrMany::Many { - many: tokio_stream::iter(map(page).into_iter().map(Ok)), - }, - Err(e) => OnceOrMany::Once { - once: tokio_stream::once(Err(e)), - }, + FnStream::new(|tx| { + Box::pin(async move { + while let Some(page) = self.0.next().await { + match page { + Ok(page) => { + let mapped = map(page); + for item in mapped.into_iter() { + let _ = tx.send(Ok(item)).await; + } + } + Err(e) => { + let _ = tx.send(Err(e)).await; + break; + } + } + } + }) as Pin + Send>> }) } } -pin_project! { - /// Helper enum to to support returning `Once` and `Iter` from `Items::items` - #[project = OnceOrManyProj] - enum OnceOrMany { - Many { #[pin] many: Iter }, - Once { #[pin] once: Once }, - } -} - -impl Stream for OnceOrMany -where - Iter: Iterator, -{ - type Item = Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let me = self.project(); - match me { - OnceOrManyProj::Many { many } => many.poll_next(cx), - OnceOrManyProj::Once { once } => once.poll_next(cx), - } - } -} - #[cfg(test)] mod test { use crate::future::fn_stream::{FnStream, TryFlatMap}; use std::sync::{Arc, Mutex}; use std::time::Duration; - use tokio_stream::StreamExt; /// basic test of FnStream functionality #[tokio::test] @@ -168,7 +191,24 @@ mod test { while let Some(value) = stream.next().await { out.push(value); } - assert_eq!(out, vec!["1", "2", "3"]); + assert_eq!(vec!["1", "2", "3"], out); + } + + #[tokio::test] + async fn fn_stream_try_next() { + tokio::time::pause(); + let mut stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(1)).await.unwrap(); + tx.send(Ok(2)).await.unwrap(); + tx.send(Err("err")).await.unwrap(); + }) + }); + let mut out = vec![]; + while let Ok(value) = stream.try_next().await { + out.push(value); + } + assert_eq!(vec![Some(1), Some(2)], out); } // smithy-rs#1902: there was a bug where we could continue to poll the generator after it @@ -183,10 +223,16 @@ mod test { Box::leak(Box::new(tx)); }) }); - assert_eq!(stream.next().await, Some("blah")); + assert_eq!(Some("blah"), stream.next().await); let mut test_stream = tokio_test::task::spawn(stream); - assert!(test_stream.poll_next().is_pending()); - assert!(test_stream.poll_next().is_pending()); + let _ = test_stream.enter(|ctx, pin| { + let polled = pin.poll_next(ctx); + assert!(polled.is_pending()); + }); + let _ = test_stream.enter(|ctx, pin| { + let polled = pin.poll_next(ctx); + assert!(polled.is_pending()); + }); } /// Tests that the generator will not advance until demand exists @@ -209,13 +255,13 @@ mod test { stream.next().await.expect("ready"); assert_eq!(*progress.lock().unwrap(), 1); - assert_eq!(stream.next().await.expect("ready"), "2"); - assert_eq!(*progress.lock().unwrap(), 2); + assert_eq!("2", stream.next().await.expect("ready")); + assert_eq!(2, *progress.lock().unwrap()); let _ = stream.next().await.expect("ready"); - assert_eq!(*progress.lock().unwrap(), 3); - assert_eq!(stream.next().await, None); - assert_eq!(*progress.lock().unwrap(), 4); + assert_eq!(3, *progress.lock().unwrap()); + assert_eq!(None, stream.next().await); + assert_eq!(4, *progress.lock().unwrap()); } #[tokio::test] @@ -238,7 +284,7 @@ mod test { while let Some(Ok(value)) = stream.next().await { out.push(value); } - assert_eq!(out, vec![0, 1]); + assert_eq!(vec![0, 1], out); } #[tokio::test] @@ -262,12 +308,12 @@ mod test { }) }); assert_eq!( - TryFlatMap(stream) + Ok(vec![1, 2, 3, 4, 5, 6]), + TryFlatMap::new(stream) .flat_map(|output| output.items.into_iter()) .collect::, &str>>() .await, - Ok(vec![1, 2, 3, 4, 5, 6]) - ) + ); } #[tokio::test] @@ -287,11 +333,11 @@ mod test { }) }); assert_eq!( - TryFlatMap(stream) + Err("bummer"), + TryFlatMap::new(stream) .flat_map(|output| output.items.into_iter()) .collect::, &str>>() - .await, - Err("bummer") + .await ) } } diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs new file mode 100644 index 0000000000..6027ba8f8a --- /dev/null +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs @@ -0,0 +1,83 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Module to extend the functionality of `FnStream` to allow for collecting elements of the stream +//! into collection. +//! +//! Majority of the code is borrowed from +//! https://github.com/tokio-rs/tokio/blob/fc9518b62714daac9a38b46c698b94ac5d5b1ca2/tokio-stream/src/stream_ext/collect.rs + +/// A trait that signifies that elements can be collected into `T`. +/// +/// Currently the trait may not be implemented by clients so we can make changes in the future +/// without breaking code depending on it. +pub trait Collectable: sealed::CollectablePrivate {} + +pub(crate) mod sealed { + #[doc(hidden)] + pub trait CollectablePrivate { + type Collection; + + fn initialize() -> Self::Collection; + + fn extend(collection: &mut Self::Collection, item: T) -> bool; + + fn finalize(collection: &mut Self::Collection) -> Self; + } +} + +impl Collectable for Vec {} + +impl sealed::CollectablePrivate for Vec { + type Collection = Self; + + fn initialize() -> Self::Collection { + Vec::default() + } + + fn extend(collection: &mut Self::Collection, item: T) -> bool { + collection.push(item); + true + } + + fn finalize(collection: &mut Self::Collection) -> Self { + std::mem::take(collection) + } +} + +impl Collectable> for Result where U: Collectable {} + +impl sealed::CollectablePrivate> for Result +where + U: Collectable, +{ + type Collection = Result; + + fn initialize() -> Self::Collection { + Ok(U::initialize()) + } + + fn extend(collection: &mut Self::Collection, item: Result) -> bool { + match item { + Ok(item) => { + let collection = collection.as_mut().ok().expect("invalid state"); + U::extend(collection, item) + } + Err(e) => { + *collection = Err(e); + false + } + } + } + + fn finalize(collection: &mut Self::Collection) -> Self { + if let Ok(collection) = collection.as_mut() { + Ok(U::finalize(collection)) + } else { + let res = std::mem::replace(collection, Ok(U::initialize())); + Err(res.map(drop).unwrap_err()) + } + } +} diff --git a/rust-runtime/aws-smithy-async/src/future/rendezvous.rs b/rust-runtime/aws-smithy-async/src/future/rendezvous.rs index 16456f123e..f2342543f9 100644 --- a/rust-runtime/aws-smithy-async/src/future/rendezvous.rs +++ b/rust-runtime/aws-smithy-async/src/future/rendezvous.rs @@ -12,6 +12,7 @@ //! Rendezvous channels should be used with care—it's inherently easy to deadlock unless they're being //! used from separate tasks or an a coroutine setup (e.g. [`crate::future::fn_stream::FnStream`]) +use std::future::poll_fn; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::sync::Semaphore; @@ -104,7 +105,11 @@ pub struct Receiver { impl Receiver { /// Polls to receive an item from the channel - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + pub async fn recv(&mut self) -> Option { + poll_fn(|cx| self.poll_recv(cx)).await + } + + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { // This uses `needs_permit` to track whether this is the first poll since we last returned an item. // If it is, we will grant a permit to the semaphore. Otherwise, we'll just forward the response through. let resp = self.chan.poll_recv(cx); @@ -124,13 +129,8 @@ impl Receiver { #[cfg(test)] mod test { - use crate::future::rendezvous::{channel, Receiver}; + use crate::future::rendezvous::channel; use std::sync::{Arc, Mutex}; - use tokio::macros::support::poll_fn; - - async fn recv(rx: &mut Receiver) -> Option { - poll_fn(|cx| rx.poll_recv(cx)).await - } #[tokio::test] async fn send_blocks_caller() { @@ -145,11 +145,11 @@ mod test { *idone.lock().unwrap() = 3; }); assert_eq!(*done.lock().unwrap(), 0); - assert_eq!(recv(&mut rx).await, Some(0)); + assert_eq!(rx.recv().await, Some(0)); assert_eq!(*done.lock().unwrap(), 1); - assert_eq!(recv(&mut rx).await, Some(1)); + assert_eq!(rx.recv().await, Some(1)); assert_eq!(*done.lock().unwrap(), 2); - assert_eq!(recv(&mut rx).await, None); + assert_eq!(rx.recv().await, None); assert_eq!(*done.lock().unwrap(), 3); let _ = send.await; } From 9a270750f3bb13fa50b961da36f8a08a5d6d7f12 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Tue, 8 Aug 2023 18:32:37 -0500 Subject: [PATCH 02/36] Remove `futures_core::stream::Stream` from `smithy-http` --- rust-runtime/aws-smithy-http/Cargo.toml | 2 +- .../aws-smithy-http/external-types.toml | 4 +- .../aws-smithy-http/src/byte_stream.rs | 113 +++++++++++----- .../src/event_stream/sender.rs | 126 +++++++++--------- 4 files changed, 139 insertions(+), 106 deletions(-) diff --git a/rust-runtime/aws-smithy-http/Cargo.toml b/rust-runtime/aws-smithy-http/Cargo.toml index 0238f5dfc6..72d03493eb 100644 --- a/rust-runtime/aws-smithy-http/Cargo.toml +++ b/rust-runtime/aws-smithy-http/Cargo.toml @@ -15,6 +15,7 @@ rt-tokio = ["dep:tokio-util", "dep:tokio", "tokio?/rt", "tokio?/fs", "tokio?/io- event-stream = ["aws-smithy-eventstream"] [dependencies] +aws-smithy-async = { path = "../aws-smithy-async" } aws-smithy-eventstream = { path = "../aws-smithy-eventstream", optional = true } aws-smithy-types = { path = "../aws-smithy-types" } bytes = "1" @@ -36,7 +37,6 @@ tokio = { version = "1.23.1", optional = true } tokio-util = { version = "0.7", optional = true } [dev-dependencies] -async-stream = "0.3" futures-util = { version = "0.3.16", default-features = false } hyper = { version = "0.14.26", features = ["stream"] } pretty_assertions = "1.3" diff --git a/rust-runtime/aws-smithy-http/external-types.toml b/rust-runtime/aws-smithy-http/external-types.toml index a228978c9a..4f76051c53 100644 --- a/rust-runtime/aws-smithy-http/external-types.toml +++ b/rust-runtime/aws-smithy-http/external-types.toml @@ -20,9 +20,6 @@ allowed_external_types = [ # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Feature gate Tokio `AsyncRead` "tokio::io::async_read::AsyncRead", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Switch to AsyncIterator once standardized - "futures_core::stream::Stream", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Feature gate references to Tokio `File` "tokio::fs::file::File", @@ -31,4 +28,5 @@ allowed_external_types = [ # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `event-stream` feature "aws_smithy_eventstream::*", + "aws_smithy_async::*", ] diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index e067018a9d..e48cd673da 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -48,7 +48,8 @@ //! //! ### Stream a ByteStream into a file //! The previous example is recommended in cases where loading the entire file into memory first is desirable. For extremely large -//! files, you may wish to stream the data directly to the file system, chunk by chunk. This is posible using the `futures::Stream` implementation. +//! files, you may wish to stream the data directly to the file system, chunk by chunk. +//! This is possible using the [`.next()`](crate::byte_stream::ByteStream::next). //! //! ```no_run //! use bytes::{Buf, Bytes}; @@ -128,6 +129,7 @@ use bytes::Bytes; use bytes_utils::SegmentedBuf; use http_body::Body; use pin_project_lite::pin_project; +use std::future::poll_fn; use std::io::IoSlice; use std::pin::Pin; use std::task::{Context, Poll}; @@ -166,9 +168,7 @@ pin_project! { /// println!("first chunk: {:?}", data.chunk()); /// } /// ``` - /// 2. Via [`impl Stream`](futures_core::Stream): - /// - /// _Note: An import of `StreamExt` is required to use `.try_next()`._ + /// 2. Via [`.next()`](crate::byte_stream::ByteStream::next) or [`.try_next()`](crate::byte_stream::ByteStream::try_next): /// /// For use-cases where holding the entire ByteStream in memory is unnecessary, use the /// `Stream` implementation: @@ -183,7 +183,6 @@ pin_project! { /// # } /// use aws_smithy_http::byte_stream::{ByteStream, AggregatedBytes, error::Error}; /// use aws_smithy_http::body::SdkBody; - /// use tokio_stream::StreamExt; /// /// async fn example() -> Result<(), Error> { /// let mut stream = ByteStream::from(vec![1, 2, 3, 4, 5, 99]); @@ -276,7 +275,7 @@ impl ByteStream { } } - /// Consumes the ByteStream, returning the wrapped SdkBody + /// Consume the `ByteStream`, returning the wrapped SdkBody. // Backwards compatibility note: Because SdkBody has a dyn variant, // we will always be able to implement this method, even if we stop using // SdkBody as the internal representation @@ -284,6 +283,26 @@ impl ByteStream { self.inner.body } + /// Return the next item in the `ByteStream`. + pub async fn next(&mut self) -> Option> { + Some(self.inner.next().await?.map_err(Error::streaming)) + } + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { + self.project().inner.poll_next(cx).map_err(Error::streaming) + } + + /// Consume and return the next item in the `ByteStream` or return an error if an error is + /// encountered. + pub async fn try_next(&mut self) -> Result, Error> { + self.next().await.transpose() + } + + /// Return the bounds on the remaining length of the `ByteStream`. + pub fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } + /// Read all the data from this `ByteStream` into memory /// /// If an error in the underlying stream is encountered, `ByteStreamError` is returned. @@ -393,7 +412,7 @@ impl ByteStream { /// # } /// ``` pub fn into_async_read(self) -> impl tokio::io::AsyncRead { - tokio_util::io::StreamReader::new(self) + tokio_util::io::StreamReader::new(StreamWrapper { byte_stream: self }) } /// Given a function to modify an [`SdkBody`], run it on the `SdkBody` inside this `Bytestream`. @@ -403,6 +422,29 @@ impl ByteStream { } } +pin_project! { + // A new-type wrapper around `ByteStream` so we can pass it to `tokio_util::io::StreamReader`. + struct StreamWrapper { + #[pin] + byte_stream: ByteStream, + } +} + +impl futures_core::stream::Stream for StreamWrapper { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .byte_stream + .poll_next(cx) + .map_err(Error::streaming) + } + + fn size_hint(&self) -> (usize, Option) { + self.byte_stream.size_hint() + } +} + impl Default for ByteStream { fn default() -> Self { Self { @@ -442,18 +484,6 @@ impl From for ByteStream { } } -impl futures_core::stream::Stream for ByteStream { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_next(cx).map_err(Error::streaming) - } - - fn size_hint(&self) -> (usize, Option) { - self.inner.size_hint() - } -} - /// Non-contiguous Binary Data Storage /// /// When data is read from the network, it is read in a sequence of chunks that are not in @@ -524,6 +554,25 @@ impl Inner { Self { body } } + async fn next(&mut self) -> Option> + where + Self: Unpin, + B: http_body::Body, + { + let mut me = Pin::new(self); + poll_fn(|cx| me.as_mut().poll_next(cx)).await + } + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> + where + B: http_body::Body, + { + self.project().body.poll_data(cx) + } + async fn collect(self) -> Result where B: http_body::Body, @@ -536,24 +585,11 @@ impl Inner { } Ok(AggregatedBytes(output)) } -} -const SIZE_HINT_32_BIT_PANIC_MESSAGE: &str = r#" -You're running a 32-bit system and this stream's length is too large to be represented with a usize. -Please limit stream length to less than 4.294Gb or run this program on a 64-bit computer architecture. -"#; - -impl futures_core::stream::Stream for Inner -where - B: http_body::Body, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().body.poll_data(cx) - } - - fn size_hint(&self) -> (usize, Option) { + fn size_hint(&self) -> (usize, Option) + where + B: http_body::Body, + { let size_hint = http_body::Body::size_hint(&self.body); let lower = size_hint.lower().try_into(); let upper = size_hint.upper().map(|u| u.try_into()).transpose(); @@ -567,6 +603,11 @@ where } } +const SIZE_HINT_32_BIT_PANIC_MESSAGE: &str = r#" +You're running a 32-bit system and this stream's length is too large to be represented with a usize. +Please limit stream length to less than 4.294Gb or run this program on a 64-bit computer architecture. +"#; + #[cfg(test)] mod tests { use crate::byte_stream::Inner; diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index d19690e727..825c93e1f9 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -4,20 +4,25 @@ */ use crate::result::SdkError; +use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_eventstream::frame::{MarshallMessage, SignMessage}; use bytes::Bytes; -use futures_core::Stream; use std::error::Error as StdError; use std::fmt; use std::fmt::Debug; +use std::future::poll_fn; use std::marker::PhantomData; use std::pin::Pin; +use std::sync::Mutex; use std::task::{Context, Poll}; use tracing::trace; /// Input type for Event Streams. pub struct EventStreamSender { - input_stream: Pin> + Send + Sync>>, + // `FnStream` does not have a `Sync` bound but this struct needs to be `Sync` + // as demonstrated by a unit test `event_stream_sender_send`. + // Wrapping `input_stream` with a `Mutex` will make `EventStreamSender` `Sync`. + input_stream: Mutex>>, } impl Debug for EventStreamSender { @@ -36,17 +41,19 @@ impl EventStreamSender { error_marshaller: impl MarshallMessage + Send + Sync + 'static, signer: impl SignMessage + Send + Sync + 'static, ) -> MessageStreamAdapter { - MessageStreamAdapter::new(marshaller, error_marshaller, signer, self.input_stream) + MessageStreamAdapter::new( + marshaller, + error_marshaller, + signer, + std::mem::replace(&mut *self.input_stream.lock().unwrap(), FnStream::taken()), + ) } } -impl From for EventStreamSender -where - S: Stream> + Send + Sync + 'static, -{ - fn from(stream: S) -> Self { +impl From>> for EventStreamSender { + fn from(stream: FnStream>) -> Self { EventStreamSender { - input_stream: Box::pin(stream), + input_stream: Mutex::new(stream), } } } @@ -109,24 +116,24 @@ impl fmt::Display for MessageStreamError { /// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be /// marshalled into an Event Stream frame, (e.g., if the message payload was too large). #[allow(missing_debug_implementations)] -pub struct MessageStreamAdapter { +pub struct MessageStreamAdapter { marshaller: Box + Send + Sync>, error_marshaller: Box + Send + Sync>, signer: Box, - stream: Pin> + Send>>, + stream: FnStream>, end_signal_sent: bool, _phantom: PhantomData, } impl Unpin for MessageStreamAdapter {} -impl MessageStreamAdapter { +impl MessageStreamAdapter { /// Create a new `MessageStreamAdapter`. pub fn new( marshaller: impl MarshallMessage + Send + Sync + 'static, error_marshaller: impl MarshallMessage + Send + Sync + 'static, signer: impl SignMessage + Send + Sync + 'static, - stream: Pin> + Send>>, + stream: FnStream>, ) -> Self { MessageStreamAdapter { marshaller: Box::new(marshaller), @@ -139,11 +146,20 @@ impl MessageStreamAdapter { } } -impl Stream for MessageStreamAdapter { - type Item = Result>; +impl MessageStreamAdapter { + /// Consumes and returns the next item from this stream. + pub async fn next(&mut self) -> Option>> { + let mut me = Pin::new(self); + poll_fn(|cx| me.as_mut().poll_next(cx)).await + } - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.stream.as_mut().poll_next(cx) { + /// Attempts to pull out the next value of this stream, returning `None` if the stream is + /// exhausted. + pub fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>>> { + match Pin::new(&mut self.stream).as_mut().poll_next(cx) { Poll::Ready(message_option) => { if let Some(message_result) = message_option { let message = match message_result { @@ -196,14 +212,11 @@ mod tests { use super::MarshallMessage; use crate::event_stream::{EventStreamSender, MessageStreamAdapter}; use crate::result::SdkError; - use async_stream::stream; + use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_eventstream::error::Error as EventStreamError; use aws_smithy_eventstream::frame::{ Header, HeaderValue, Message, NoOpSigner, SignMessage, SignMessageError, }; - use bytes::Bytes; - use futures_core::Stream; - use futures_util::stream::StreamExt; use std::error::Error as StdError; #[derive(Debug)] @@ -267,35 +280,29 @@ mod tests { } #[test] - fn event_stream_sender_send_sync() { - check_send_sync(EventStreamSender::from(stream! { - yield Result::<_, SignMessageError>::Ok(TestMessage("test".into())); - })); - } - - fn check_compatible_with_hyper_wrap_stream(stream: S) -> S - where - S: Stream> + Send + 'static, - O: Into + 'static, - E: Into> + 'static, - { - stream + fn event_stream_sender_send() { + check_send_sync(EventStreamSender::from(FnStream::new(|tx| { + Box::pin(async move { + let message = Result::<_, TestServiceError>::Ok(TestMessage("test".into())); + tx.send(message).await.expect("failed to send"); + }) + }))); } #[tokio::test] async fn message_stream_adapter_success() { - let stream = stream! { - yield Ok(TestMessage("test".into())); - }; - let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( + let stream = FnStream::new(|tx| { + Box::pin(async move { + let message = Ok(TestMessage("test".into())); + tx.send(message).await.expect("failed to send"); + }) + }); + let mut adapter = MessageStreamAdapter::::new( Marshaller, ErrorMarshaller, TestSigner, - Box::pin(stream), - )); + stream, + ); let mut sent_bytes = adapter.next().await.unwrap().unwrap(); let sent = Message::read_from(&mut sent_bytes).unwrap(); @@ -313,18 +320,19 @@ mod tests { #[tokio::test] async fn message_stream_adapter_construction_failure() { - let stream = stream! { - yield Err(TestServiceError); - }; - let mut adapter = check_compatible_with_hyper_wrap_stream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( + let stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Err(TestServiceError)) + .await + .expect("failed to send"); + }) + }); + let mut adapter = MessageStreamAdapter::::new( Marshaller, ErrorMarshaller, NoOpSigner {}, - Box::pin(stream), - )); + stream, + ); let result = adapter.next().await.unwrap(); assert!(result.is_err()); @@ -333,18 +341,4 @@ mod tests { SdkError::ConstructionFailure(_) )); } - - // Verify the developer experience for this compiles - #[allow(unused)] - fn event_stream_input_ergonomics() { - fn check(input: impl Into>) { - let _: EventStreamSender = input.into(); - } - check(stream! { - yield Ok(TestMessage("test".into())); - }); - check(stream! { - yield Err(TestServiceError); - }); - } } From f5b440c5b41179e5f34a64dd43b107f8e78a8e5d Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Tue, 8 Aug 2023 19:11:54 -0500 Subject: [PATCH 03/36] Remove `futures_core::stream::Stream` from `Paginator` --- .../smithy/generators/PaginatorGenerator.kt | 14 +-- .../protocols/HttpBoundProtocolGenerator.kt | 12 ++- .../codegen/core/rustlang/CargoDependency.kt | 3 +- rust-runtime/inlineable/Cargo.toml | 5 +- .../inlineable/src/hyper_body_wrap_stream.rs | 99 +++++++++++++++++++ rust-runtime/inlineable/src/lib.rs | 2 + 6 files changed, 125 insertions(+), 10 deletions(-) create mode 100644 rust-runtime/inlineable/src/hyper_body_wrap_stream.rs diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt index f68aed2dec..27e6c32018 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt @@ -159,9 +159,9 @@ class PaginatorGenerator private constructor( /// Create the pagination stream /// - /// _Note:_ No requests will be dispatched until the stream is used (eg. with [`.next().await`](tokio_stream::StreamExt::next)). - pub fn send(self) -> impl #{Stream} + #{Unpin} - #{send_bounds:W} { + /// _Note:_ No requests will be dispatched until the stream is used + /// (e.g. with [`.next().await`](aws_smithy_async::future::fn_stream::FnStream::next)). + pub fn send(self) -> #{fn_stream}::FnStream<#{item_type}> { // Move individual fields out of self for the borrow checker let builder = self.builder; let handle = self.handle; @@ -302,11 +302,11 @@ class PaginatorGenerator private constructor( impl ${generics.inst} ${paginatorName}Items${generics.inst} #{bounds:W} { /// Create the pagination stream /// - /// _Note: No requests will be dispatched until the stream is used (eg. with [`.next().await`](tokio_stream::StreamExt::next))._ + /// _Note_: No requests will be dispatched until the stream is used + /// (e.g. with [`.next().await`](aws_smithy_async::future::fn_stream::FnStream::next)). /// - /// To read the entirety of the paginator, use [`.collect::, _>()`](tokio_stream::StreamExt::collect). - pub fn send(self) -> impl #{Stream} + #{Unpin} - #{send_bounds:W} { + /// To read the entirety of the paginator, use [`.collect::, _>()`](aws_smithy_async::future::fn_stream::FnStream::collect). + pub fn send(self) -> #{fn_stream}::FnStream<#{item_type}> { #{fn_stream}::TryFlatMap::new(self.0.send()).flat_map(|page| #{extract_items}(page).unwrap_or_default().into_iter()) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index d8c8bf9dc2..064ff150c2 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection @@ -16,10 +17,13 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Mak import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolParserGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.toType import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope @@ -73,7 +77,7 @@ class ClientHttpBoundProtocolPayloadGenerator( #{insert_into_config} let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(#{HyperBodyWrapStreamCompat}::new(adapter)).into(); body } """, @@ -81,6 +85,12 @@ class ClientHttpBoundProtocolPayloadGenerator( "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"), + "HyperBodyWrapStreamCompat" to InlineDependency.forRustFile( + RustModule.pubCrate("stream_compat", parent = ClientRustModule.root), + "/inlineable/src/hyper_body_wrap_stream.rs", + CargoDependency.smithyHttpEventStream(codegenContext.runtimeConfig), + CargoDependency.FuturesCore, + ).toType().resolve("HyperBodyWrapStreamCompat"), "marshallerConstructorFn" to params.marshallerConstructorFn, "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, "insert_into_config" to writable { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 73a5e9ebbe..6367d110a1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -227,6 +227,7 @@ data class CargoDependency( val Bytes: CargoDependency = CargoDependency("bytes", CratesIo("1.0.0")) val BytesUtils: CargoDependency = CargoDependency("bytes-utils", CratesIo("0.1.0")) val FastRand: CargoDependency = CargoDependency("fastrand", CratesIo("2.0.0")) + val FuturesCore: CargoDependency = CargoDependency("futures-core", CratesIo("0.3.25")) val Hex: CargoDependency = CargoDependency("hex", CratesIo("0.4.3")) val Http: CargoDependency = CargoDependency("http", CratesIo("0.2.9")) val HttpBody: CargoDependency = CargoDependency("http-body", CratesIo("0.4.4")) @@ -246,7 +247,6 @@ data class CargoDependency( val AsyncStd: CargoDependency = CargoDependency("async-std", CratesIo("1.12.0"), DependencyScope.Dev) val AsyncStream: CargoDependency = CargoDependency("async-stream", CratesIo("0.3.0"), DependencyScope.Dev) val Criterion: CargoDependency = CargoDependency("criterion", CratesIo("0.4.0"), DependencyScope.Dev) - val FuturesCore: CargoDependency = CargoDependency("futures-core", CratesIo("0.3.25"), DependencyScope.Dev) val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3.25"), DependencyScope.Dev, defaultFeatures = false) val HdrHistogram: CargoDependency = CargoDependency("hdrhistogram", CratesIo("7.5.2"), DependencyScope.Dev) @@ -289,6 +289,7 @@ data class CargoDependency( fun smithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-eventstream") fun smithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http") + fun smithyHttpEventStream(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).withFeature("event-stream") fun smithyHttpAuth(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-auth") fun smithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-tower") fun smithyJson(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-json") diff --git a/rust-runtime/inlineable/Cargo.toml b/rust-runtime/inlineable/Cargo.toml index 6d0c099429..97cd27ea85 100644 --- a/rust-runtime/inlineable/Cargo.toml +++ b/rust-runtime/inlineable/Cargo.toml @@ -19,7 +19,7 @@ default = ["gated-tests"] [dependencies] async-trait = "0.1" -aws-smithy-http = { path = "../aws-smithy-http" } +aws-smithy-http = { path = "../aws-smithy-http", features = ["event-stream"] } aws-smithy-http-server = { path = "../aws-smithy-http-server" } aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["client"] } @@ -27,6 +27,7 @@ aws-smithy-types = { path = "../aws-smithy-types" } aws-smithy-xml = { path = "../aws-smithy-xml" } bytes = "1" fastrand = "2.0.0" +futures-core = "0.3" futures-util = "0.3" http = "0.2.1" md-5 = "0.10.0" @@ -39,6 +40,8 @@ url = "2.2.2" [dev-dependencies] proptest = "1" +aws-smithy-async = { path = "../aws-smithy-async" } +aws-smithy-eventstream = { path = "../aws-smithy-eventstream" } [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs b/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs new file mode 100644 index 0000000000..69ad994876 --- /dev/null +++ b/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs @@ -0,0 +1,99 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_http::event_stream::MessageStreamAdapter; +use aws_smithy_http::result::SdkError; +use bytes::Bytes; +use futures_core::stream::Stream; +use std::error::Error as StdError; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) struct HyperBodyWrapStreamCompat(MessageStreamAdapter); + +impl HyperBodyWrapStreamCompat { + pub(crate) fn new(adapter: MessageStreamAdapter) -> Self { + Self(adapter) + } +} + +impl Unpin for HyperBodyWrapStreamCompat {} + +impl Stream for HyperBodyWrapStreamCompat { + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use aws_smithy_async::future::fn_stream::FnStream; + use aws_smithy_eventstream::error::Error; + use aws_smithy_eventstream::frame::MarshallMessage; + use aws_smithy_eventstream::frame::{Message, NoOpSigner}; + use futures_core::stream::Stream; + + #[derive(Debug, Eq, PartialEq)] + struct TestMessage(String); + + #[derive(Debug)] + struct Marshaller; + impl MarshallMessage for Marshaller { + type Input = TestMessage; + + fn marshall(&self, input: Self::Input) -> Result { + Ok(Message::new(input.0.as_bytes().to_vec())) + } + } + #[derive(Debug)] + struct ErrorMarshaller; + impl MarshallMessage for ErrorMarshaller { + type Input = TestServiceError; + + fn marshall(&self, _input: Self::Input) -> Result { + Err(Message::read_from(&b""[..]).expect_err("this should always fail")) + } + } + + #[derive(Debug)] + struct TestServiceError; + impl std::fmt::Display for TestServiceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestServiceError") + } + } + impl StdError for TestServiceError {} + + fn check_compatible_with_hyper_wrap_stream(stream: S) -> S + where + S: Stream> + Send + 'static, + O: Into + 'static, + E: Into> + 'static, + { + stream + } + + #[test] + fn test_message_adapter_stream_is_compatible_with_hyper_wrap_stream() { + let stream = FnStream::new(|tx| { + Box::pin(async move { + let message = Ok(TestMessage("test".into())); + tx.send(message).await.expect("failed to send"); + }) + }); + check_compatible_with_hyper_wrap_stream(HyperBodyWrapStreamCompat(MessageStreamAdapter::< + TestMessage, + TestServiceError, + >::new( + Marshaller, + ErrorMarshaller, + NoOpSigner {}, + stream, + ))); + } +} diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index b672eeef9d..fd40a4aca7 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -13,6 +13,8 @@ mod client_idempotency_token; mod constrained; #[allow(dead_code)] mod ec2_query_errors; +#[allow(unused)] +mod hyper_body_wrap_stream; #[allow(dead_code)] mod idempotency_token; #[allow(dead_code)] From 2bcf9b5df6b176d675c49938f824e2f9fef2edcf Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Tue, 8 Aug 2023 19:12:14 -0500 Subject: [PATCH 04/36] Remove `futures_core::stream::Stream` from integration tests --- .../rustsdk/IntegrationTestDependencies.kt | 2 +- .../dynamodb/tests/paginators.rs | 2 - .../integration-tests/ec2/tests/paginators.rs | 2 - .../transcribestreaming/Cargo.toml | 2 +- .../transcribestreaming/tests/test.rs | 41 ++++++++++++------- aws/sdk/sdk-external-types.toml | 8 ---- 6 files changed, 28 insertions(+), 29 deletions(-) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt index 3cee565e64..7afd02a012 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt @@ -126,7 +126,7 @@ class TranscribeTestDependencies : LibRsCustomization() { override fun section(section: LibRsSection): Writable = writable { addDependency(AsyncStream) - addDependency(FuturesCore) + addDependency(FuturesCore.toDevDependency()) addDependency(Hound) } } diff --git a/aws/sdk/integration-tests/dynamodb/tests/paginators.rs b/aws/sdk/integration-tests/dynamodb/tests/paginators.rs index 807a11890d..a3d0c62473 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/paginators.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/paginators.rs @@ -6,8 +6,6 @@ use std::collections::HashMap; use std::iter::FromIterator; -use tokio_stream::StreamExt; - use aws_credential_types::Credentials; use aws_sdk_dynamodb::types::AttributeValue; use aws_sdk_dynamodb::{Client, Config}; diff --git a/aws/sdk/integration-tests/ec2/tests/paginators.rs b/aws/sdk/integration-tests/ec2/tests/paginators.rs index 83528f2075..d070971a4f 100644 --- a/aws/sdk/integration-tests/ec2/tests/paginators.rs +++ b/aws/sdk/integration-tests/ec2/tests/paginators.rs @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -use tokio_stream::StreamExt; - use aws_sdk_ec2::{config::Credentials, config::Region, types::InstanceType, Client, Config}; use aws_smithy_client::http_connector::HttpConnector; use aws_smithy_client::test_connection::TestConnection; diff --git a/aws/sdk/integration-tests/transcribestreaming/Cargo.toml b/aws/sdk/integration-tests/transcribestreaming/Cargo.toml index 181ba493cb..d214ae56ed 100644 --- a/aws/sdk/integration-tests/transcribestreaming/Cargo.toml +++ b/aws/sdk/integration-tests/transcribestreaming/Cargo.toml @@ -9,10 +9,10 @@ repository = "https://github.com/awslabs/smithy-rs" publish = false [dev-dependencies] -async-stream = "0.3.0" aws-credential-types = { path = "../../build/aws-sdk/sdk/aws-credential-types", features = ["test-util"] } aws-http = { path = "../../build/aws-sdk/sdk/aws-http" } aws-sdk-transcribestreaming = { path = "../../build/aws-sdk/sdk/transcribestreaming" } +aws-smithy-async = { path = "../../build/aws-sdk/sdk/aws-smithy-async" } aws-smithy-client = { path = "../../build/aws-sdk/sdk/aws-smithy-client", features = ["test-util", "rustls"] } aws-smithy-eventstream = { path = "../../build/aws-sdk/sdk/aws-smithy-eventstream" } aws-smithy-http = { path = "../../build/aws-sdk/sdk/aws-smithy-http" } diff --git a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs index 62654ebd82..29c963d6c5 100644 --- a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs +++ b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -use async_stream::stream; use aws_sdk_transcribestreaming::config::{Credentials, Region}; use aws_sdk_transcribestreaming::error::SdkError; use aws_sdk_transcribestreaming::operation::start_stream_transcription::StartStreamTranscriptionOutput; @@ -13,10 +12,10 @@ use aws_sdk_transcribestreaming::types::{ AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, }; use aws_sdk_transcribestreaming::{Client, Config}; +use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_client::dvr::{Event, ReplayingConnection}; use aws_smithy_eventstream::frame::{DecodedFrame, HeaderValue, Message, MessageFrameDecoder}; use bytes::BufMut; -use futures_core::Stream; use std::collections::{BTreeMap, BTreeSet}; use std::error::Error as StdError; @@ -24,12 +23,18 @@ const CHUNK_SIZE: usize = 8192; #[tokio::test] async fn test_success() { - let input_stream = stream! { - let pcm = pcm_data(); - for chunk in pcm.chunks(CHUNK_SIZE) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + let pcm = pcm_data(); + for chunk in pcm.chunks(CHUNK_SIZE) { + tx.send(Ok(AudioStream::AudioEvent( + AudioEvent::builder().audio_chunk(Blob::new(chunk)).build(), + ))) + .await + .expect("send should succeed"); + } + }) + }); let (replayer, mut output) = start_request("us-west-2", include_str!("success.json"), input_stream).await; @@ -65,12 +70,18 @@ async fn test_success() { #[tokio::test] async fn test_error() { - let input_stream = stream! { - let pcm = pcm_data(); - for chunk in pcm.chunks(CHUNK_SIZE).take(1) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + let pcm = pcm_data(); + for chunk in pcm.chunks(CHUNK_SIZE).take(1) { + tx.send(Ok(AudioStream::AudioEvent( + AudioEvent::builder().audio_chunk(Blob::new(chunk)).build(), + ))) + .await + .expect("send should succeed"); + } + }) + }); let (replayer, mut output) = start_request("us-east-1", include_str!("error.json"), input_stream).await; @@ -97,7 +108,7 @@ async fn test_error() { async fn start_request( region: &'static str, events_json: &str, - input_stream: impl Stream> + Send + Sync + 'static, + input_stream: FnStream>, ) -> (ReplayingConnection, StartStreamTranscriptionOutput) { let events: Vec = serde_json::from_str(events_json).unwrap(); let replayer = ReplayingConnection::new(events); diff --git a/aws/sdk/sdk-external-types.toml b/aws/sdk/sdk-external-types.toml index b484544c27..b623edd4a1 100644 --- a/aws/sdk/sdk-external-types.toml +++ b/aws/sdk/sdk-external-types.toml @@ -19,14 +19,6 @@ allowed_external_types = [ "http::uri::Uri", "http::method::Method", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Switch to AsyncIterator once standardized - "futures_core::stream::Stream", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `event-stream` feature "aws_smithy_eventstream::*", - - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Decide if we want to continue exposing tower_layer - "tower_layer::Layer", - "tower_layer::identity::Identity", - "tower_layer::stack::Stack", ] From 181440dbf4c95e50d48dc7a4470b996d83e4f67b Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Tue, 8 Aug 2023 20:11:46 -0500 Subject: [PATCH 05/36] Fix links in docs --- .../client/smithy/generators/client/FluentClientGenerator.kt | 2 +- rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt index d7fb4560b5..ef88db3ac3 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt @@ -570,7 +570,7 @@ class FluentClientGenerator( """ /// Create a paginator for this request /// - /// Paginators are used by calling [`send().await`](#{Paginator}::send) which returns a `Stream`. + /// Paginators are used by calling [`send().await`](#{Paginator}::send) which returns an [`FnStream`](aws_smithy_async::future::fn_stream::FnStream). pub fn into_paginator(self) -> #{Paginator}${generics.inst} { #{Paginator}::new(self.handle, self.inner) } diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs index 6027ba8f8a..f909fd730e 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs @@ -7,7 +7,7 @@ //! into collection. //! //! Majority of the code is borrowed from -//! https://github.com/tokio-rs/tokio/blob/fc9518b62714daac9a38b46c698b94ac5d5b1ca2/tokio-stream/src/stream_ext/collect.rs +//! /// A trait that signifies that elements can be collected into `T`. /// From f00d605a93afb70b725cedcacc008034255763a3 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 9 Aug 2023 15:52:16 -0500 Subject: [PATCH 06/36] Fix codegen-server-test:test --- .../ServerHttpBoundProtocolGenerator.kt | 124 +++++++++++++++--- .../aws-smithy-http/src/byte_stream.rs | 7 +- .../inlineable/src/hyper_body_wrap_stream.rs | 40 +++++- 3 files changed, 144 insertions(+), 27 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index c579d64415..6ff7c647cc 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -28,6 +28,9 @@ import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -38,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.core.rustlang.toType import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -68,10 +72,12 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.core.util.isEventStream import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerRequestBindingGenerator @@ -85,7 +91,8 @@ import java.util.logging.Logger * Class describing a ServerHttpBoundProtocol section that can be used in a customization. */ sealed class ServerHttpBoundProtocolSection(name: String) : Section(name) { - data class AfterTimestampDeserializedMember(val shape: MemberShape) : ServerHttpBoundProtocolSection("AfterTimestampDeserializedMember") + data class AfterTimestampDeserializedMember(val shape: MemberShape) : + ServerHttpBoundProtocolSection("AfterTimestampDeserializedMember") } /** @@ -105,7 +112,12 @@ class ServerHttpBoundProtocolGenerator( additionalHttpBindingCustomizations: List = listOf(), ) : ServerProtocolGenerator( protocol, - ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations), + ServerHttpBoundProtocolTraitImplGenerator( + codegenContext, + protocol, + customizations, + additionalHttpBindingCustomizations, + ), ) { // Define suffixes for operation input / output / error wrappers companion object { @@ -513,6 +525,25 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } + /** + * Return a Writable that renders a new-type wrapper implementing "futures_core::stream::Stream" + * on behalf of the inner type. + */ + private fun futuresStreamCompatible(operationShape: OperationShape): RuntimeType { + val inlineable = InlineDependency.forRustFile( + RustModule.pubCrate("stream_compat", parent = ServerRustModule.root), + "/inlineable/src/hyper_body_wrap_stream.rs", + CargoDependency.smithyHttpEventStream(codegenContext.runtimeConfig), + CargoDependency.smithyAsync(codegenContext.runtimeConfig).toDevDependency(), + CargoDependency.smithyEventStream(codegenContext.runtimeConfig).toDevDependency(), + CargoDependency.FuturesCore, + ).toType() + if (operationShape.isEventStream(model)) { + return inlineable.resolve("HyperBodyWrapEventStream") + } + return inlineable.resolve("HyperBodyWrapByteStream") + } + /** * Render an HTTP response (headers, response code, body) for an operation's output and the given [bindings]. */ @@ -539,7 +570,12 @@ class ServerHttpBoundProtocolTraitImplGenerator( operationShape.outputShape(model).findStreamingMember(model)?.let { val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol) - withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) { + withBlockTemplate( + "let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(#{futures_stream_compatible}::new(", + ")));", + *codegenScope, + "futures_stream_compatible" to futuresStreamCompatible(operationShape), + ) { payloadGenerator.generatePayload(this, "output", operationShape) } } ?: run { @@ -576,7 +612,10 @@ class ServerHttpBoundProtocolTraitImplGenerator( * 2. The protocol-specific `Content-Type` header for the operation. * 3. Additional protocol-specific headers for errors, if [errorShape] is non-null. */ - private fun RustWriter.serverRenderResponseHeaders(operationShape: OperationShape, errorShape: StructureShape? = null) { + private fun RustWriter.serverRenderResponseHeaders( + operationShape: OperationShape, + errorShape: StructureShape? = null, + ) { val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape) val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape) if (addHeadersFn != null) { @@ -686,7 +725,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( inputShape: StructureShape, bindings: List, ) { - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) + val httpBindingGenerator = + ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val structuredDataParser = protocol.structuredDataParser() Attribute.AllowUnusedMut.render(this) rust( @@ -696,7 +736,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( Attribute.AllowUnusedVariables.render(this) rust("let (parts, body) = request.into_parts();") val parser = structuredDataParser.serverInputParser(operationShape) - val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null + val noInputs = + model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { // `null` is only returned by Smithy when there are no members, but we know there's at least one, since @@ -707,7 +748,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( rustTemplate( """ #{SmithyHttpServer}::protocol::content_type_header_classifier( - &parts.headers, + &parts.headers, Some("$expectedRequestContentType"), )?; input = #{parser}(bytes.as_ref(), input)?; @@ -719,7 +760,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( } for (binding in bindings) { val member = binding.member - val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) + val parsedValue = + serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) if (parsedValue != null) { rust("if let Some(value) = ") parsedValue(this) @@ -809,10 +851,12 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } } + HttpLocation.DOCUMENT, HttpLocation.LABEL, HttpLocation.QUERY, HttpLocation.QUERY_PARAMS -> { // All of these are handled separately. null } + else -> { logger.warning("[rust-server-codegen] ${operationShape.id}: request parsing does not currently support ${binding.location} bindings") null @@ -838,15 +882,24 @@ class ServerHttpBoundProtocolTraitImplGenerator( } val restAfterGreedyLabel = if (greedyLabelIndex >= 0) { - httpTrait.uri.segments.slice((greedyLabelIndex + 1) until httpTrait.uri.segments.size).joinToString(prefix = "/", separator = "/") + httpTrait.uri.segments.slice((greedyLabelIndex + 1) until httpTrait.uri.segments.size) + .joinToString(prefix = "/", separator = "/") } else { "" } val labeledNames = segments .mapIndexed { index, segment -> - if (segment.isLabel) { "m$index" } else { "_" } + if (segment.isLabel) { + "m$index" + } else { + "_" + } } - .joinToString(prefix = (if (segments.size > 1) "(" else ""), separator = ",", postfix = (if (segments.size > 1) ")" else "")) + .joinToString( + prefix = (if (segments.size > 1) "(" else ""), + separator = ",", + postfix = (if (segments.size > 1) ")" else ""), + ) val nomParser = segments .map { segment -> if (segment.isGreedyLabel) { @@ -1011,6 +1064,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( rust("let v = v.into_owned();") } } + memberShape.isTimestampShape -> { val index = HttpBindingIndex.of(model) val timestampFormat = @@ -1019,7 +1073,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( it.location, protocol.defaultTimestampFormat, ) - val timestampFormatType = RuntimeType.parseTimestampFormat(CodegenTarget.SERVER, runtimeConfig, timestampFormat) + val timestampFormatType = RuntimeType.parseTimestampFormat( + CodegenTarget.SERVER, + runtimeConfig, + timestampFormat, + ) rustTemplate( """ let v = #{DateTime}::from_str(&v, #{format})? @@ -1028,10 +1086,15 @@ class ServerHttpBoundProtocolTraitImplGenerator( "format" to timestampFormatType, ) for (customization in customizations) { - customization.section(ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember(it.member))(this) + customization.section( + ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember( + it.member, + ), + )(this) } rust(";") } + else -> { // Number or boolean. rust( """ @@ -1054,9 +1117,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( QueryParamsTargetMapValueType.STRING -> { rust("query_params.${if (hasConstrainedTarget) "0." else ""}entry(String::from(k)).or_insert_with(|| String::from(v));") } + QueryParamsTargetMapValueType.LIST, QueryParamsTargetMapValueType.SET -> { if (hasConstrainedTarget) { - val collectionShape = model.expectShape(target.value.target, CollectionShape::class.java) + val collectionShape = + model.expectShape(target.value.target, CollectionShape::class.java) val collectionSymbol = unconstrainedShapeSymbolProvider.toSymbol(collectionShape) rust( // `or_insert_with` instead of `or_insert` to avoid the allocation when the entry is @@ -1091,7 +1156,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( // TODO(https://github.com/awslabs/smithy-rs/issues/1401) Constraint traits on member shapes are not // implemented yet. val hasConstrainedTarget = - model.expectShape(binding.member.target, CollectionShape::class.java).canReachConstrainedShape(model, symbolProvider) + model.expectShape(binding.member.target, CollectionShape::class.java) + .canReachConstrainedShape(model, symbolProvider) val memberName = unconstrainedShapeSymbolProvider.toMemberName(binding.member) val isOptional = unconstrainedShapeSymbolProvider.toSymbol(binding.member).isOptional() rustBlock("if !$memberName.is_empty()") { @@ -1119,8 +1185,13 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } - private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { - val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) + private fun serverRenderHeaderParser( + writer: RustWriter, + binding: HttpBindingDescriptor, + operationShape: OperationShape, + ) { + val httpBindingGenerator = + ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations) val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding) writer.rustTemplate( """ @@ -1131,7 +1202,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } - private fun serverRenderPrefixHeadersParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) { + private fun serverRenderPrefixHeadersParser( + writer: RustWriter, + binding: HttpBindingDescriptor, + operationShape: OperationShape, + ) { check(binding.location == HttpLocation.PREFIX_HEADERS) val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) @@ -1168,6 +1243,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( rust("let value = value.to_owned();") } } + target.isTimestampShape -> { val index = HttpBindingIndex.of(model) val timestampFormat = @@ -1176,7 +1252,8 @@ class ServerHttpBoundProtocolTraitImplGenerator( binding.location, protocol.defaultTimestampFormat, ) - val timestampFormatType = RuntimeType.parseTimestampFormat(CodegenTarget.SERVER, runtimeConfig, timestampFormat) + val timestampFormatType = + RuntimeType.parseTimestampFormat(CodegenTarget.SERVER, runtimeConfig, timestampFormat) if (percentDecoding) { rustTemplate( @@ -1197,10 +1274,15 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } for (customization in customizations) { - customization.section(ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember(binding.member))(this) + customization.section( + ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember( + binding.member, + ), + )(this) } rust(";") } + else -> { check(target is NumberShape || target is BooleanShape) rustTemplate( @@ -1230,9 +1312,11 @@ class ServerHttpBoundProtocolTraitImplGenerator( RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> { RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError").toSymbol() } + RestXmlTrait.ID -> { RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError").toSymbol() } + else -> { TODO("Protocol ${codegenContext.protocol} not supported yet") } diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index e48cd673da..28cf660969 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -288,7 +288,12 @@ impl ByteStream { Some(self.inner.next().await?.map_err(Error::streaming)) } - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { + /// Attempts to pull out the next value of this stream, returning `None` if the stream is + /// exhausted. + pub fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { self.project().inner.poll_next(cx).map_err(Error::streaming) } diff --git a/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs b/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs index 69ad994876..6e5c669e63 100644 --- a/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs +++ b/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_http::byte_stream::error::Error as ByteStreamError; +use aws_smithy_http::byte_stream::ByteStream; use aws_smithy_http::event_stream::MessageStreamAdapter; use aws_smithy_http::result::SdkError; use bytes::Bytes; @@ -11,17 +13,18 @@ use std::error::Error as StdError; use std::pin::Pin; use std::task::{Context, Poll}; -pub(crate) struct HyperBodyWrapStreamCompat(MessageStreamAdapter); +pub(crate) struct HyperBodyWrapEventStream(MessageStreamAdapter); -impl HyperBodyWrapStreamCompat { +impl HyperBodyWrapEventStream { + #[allow(dead_code)] pub(crate) fn new(adapter: MessageStreamAdapter) -> Self { Self(adapter) } } -impl Unpin for HyperBodyWrapStreamCompat {} +impl Unpin for HyperBodyWrapEventStream {} -impl Stream for HyperBodyWrapStreamCompat { +impl Stream for HyperBodyWrapEventStream { type Item = Result>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -29,6 +32,25 @@ impl Stream for HyperBodyWrapStreamCompa } } +pub(crate) struct HyperBodyWrapByteStream(ByteStream); + +impl HyperBodyWrapByteStream { + #[allow(dead_code)] + pub(crate) fn new(stream: ByteStream) -> Self { + Self(stream) + } +} + +impl Unpin for HyperBodyWrapByteStream {} + +impl Stream for HyperBodyWrapByteStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) + } +} + #[cfg(test)] mod tests { use super::*; @@ -79,14 +101,14 @@ mod tests { } #[test] - fn test_message_adapter_stream_is_compatible_with_hyper_wrap_stream() { + fn test_message_adapter_stream_can_be_made_compatible_with_hyper_wrap_stream() { let stream = FnStream::new(|tx| { Box::pin(async move { let message = Ok(TestMessage("test".into())); tx.send(message).await.expect("failed to send"); }) }); - check_compatible_with_hyper_wrap_stream(HyperBodyWrapStreamCompat(MessageStreamAdapter::< + check_compatible_with_hyper_wrap_stream(HyperBodyWrapEventStream(MessageStreamAdapter::< TestMessage, TestServiceError, >::new( @@ -96,4 +118,10 @@ mod tests { stream, ))); } + + #[test] + fn test_byte_stream_stream_can_be_made_compatible_with_hyper_wrap_stream() { + let stream = ByteStream::from_static(b"Hello world"); + check_compatible_with_hyper_wrap_stream(HyperBodyWrapByteStream::new(stream)); + } } From 3e05c47c3afa718a53492a1ef8529ffcd5264baf Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Thu, 10 Aug 2023 22:06:46 -0500 Subject: [PATCH 07/36] Fix tests in CI --- .../aws-inlineable/src/glacier_checksums.rs | 1 - .../protocols/HttpBoundProtocolGenerator.kt | 19 +++++++------------ .../codegen/core/rustlang/CargoDependency.kt | 9 ++++++++- .../rust/codegen/core/smithy/RuntimeType.kt | 3 ++- .../ServerHttpBoundProtocolGenerator.kt | 18 +++--------------- .../src/types.rs | 1 - 6 files changed, 20 insertions(+), 31 deletions(-) diff --git a/aws/rust-runtime/aws-inlineable/src/glacier_checksums.rs b/aws/rust-runtime/aws-inlineable/src/glacier_checksums.rs index 18f1d9219e..bf95910e00 100644 --- a/aws/rust-runtime/aws-inlineable/src/glacier_checksums.rs +++ b/aws/rust-runtime/aws-inlineable/src/glacier_checksums.rs @@ -14,7 +14,6 @@ use bytes::Buf; use bytes_utils::SegmentedBuf; use http::header::HeaderName; use ring::digest::{Context, Digest, SHA256}; -use tokio_stream::StreamExt; const TREE_HASH_HEADER: &str = "x-amz-sha256-tree-hash"; const X_AMZ_CONTENT_SHA256: &str = "x-amz-content-sha256"; diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 064ff150c2..d17f8a97bd 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection @@ -17,8 +16,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Mak import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolParserGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate @@ -67,7 +64,8 @@ class ClientHttpBoundProtocolPayloadGenerator( ) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( codegenContext, protocol, HttpMessageType.REQUEST, renderEventStreamBody = { writer, params -> - val propertyBagAvailable = (params.additionalPayloadContext as ClientAdditionalPayloadContext).propertyBagAvailable + val propertyBagAvailable = + (params.additionalPayloadContext as ClientAdditionalPayloadContext).propertyBagAvailable writer.rustTemplate( """ { @@ -77,20 +75,17 @@ class ClientHttpBoundProtocolPayloadGenerator( #{insert_into_config} let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(#{HyperBodyWrapStreamCompat}::new(adapter)).into(); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(#{HyperBodyWrapEventStream}::new(adapter)).into(); body } """, "hyper" to CargoDependency.HyperWithStream.toType(), "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), - "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig).resolve("frame::DeferredSigner"), - "HyperBodyWrapStreamCompat" to InlineDependency.forRustFile( - RustModule.pubCrate("stream_compat", parent = ClientRustModule.root), - "/inlineable/src/hyper_body_wrap_stream.rs", - CargoDependency.smithyHttpEventStream(codegenContext.runtimeConfig), - CargoDependency.FuturesCore, - ).toType().resolve("HyperBodyWrapStreamCompat"), + "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig) + .resolve("frame::DeferredSigner"), + "HyperBodyWrapEventStream" to RuntimeType.hyperBodyWrapStream(codegenContext.runtimeConfig) + .resolve("HyperBodyWrapEventStream"), "marshallerConstructorFn" to params.marshallerConstructorFn, "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, "insert_into_config" to writable { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 6367d110a1..beede3bd75 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -127,6 +127,14 @@ class InlineDependency( CargoDependency.smithyTypes(runtimeConfig), ) + fun hyperBodyWrapStream(runtimeConfig: RuntimeConfig): InlineDependency = forInlineableRustFile( + "hyper_body_wrap_stream", + CargoDependency.smithyHttp(runtimeConfig).withFeature("event-stream"), + CargoDependency.FuturesCore, + CargoDependency.smithyAsync(runtimeConfig).toDevDependency(), + CargoDependency.smithyEventStream(runtimeConfig).toDevDependency(), + ) + fun constrained(): InlineDependency = InlineDependency.forRustFile(ConstrainedModule, "/inlineable/src/constrained.rs") } @@ -289,7 +297,6 @@ data class CargoDependency( fun smithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-eventstream") fun smithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http") - fun smithyHttpEventStream(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).withFeature("event-stream") fun smithyHttpAuth(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-auth") fun smithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-tower") fun smithyJson(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-json") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index e24edc7f86..680366fee2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -383,9 +383,10 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) fun retryErrorKind(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("retry::ErrorKind") fun eventStreamReceiver(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("event_stream::Receiver") - fun eventStreamSender(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("event_stream::EventStreamSender") + fun hyperBodyWrapStream(runtimeConfig: RuntimeConfig): RuntimeType = + forInlineDependency(InlineDependency.hyperBodyWrapStream(runtimeConfig)) fun errorMetadata(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::ErrorMetadata") fun errorMetadataBuilder(runtimeConfig: RuntimeConfig) = diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 6ff7c647cc..609575c540 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -28,9 +28,6 @@ import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -41,13 +38,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter -import software.amazon.smithy.rust.codegen.core.rustlang.toType import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.hyperBodyWrapStream import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization @@ -77,7 +74,6 @@ import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerRequestBindingGenerator @@ -530,18 +526,10 @@ class ServerHttpBoundProtocolTraitImplGenerator( * on behalf of the inner type. */ private fun futuresStreamCompatible(operationShape: OperationShape): RuntimeType { - val inlineable = InlineDependency.forRustFile( - RustModule.pubCrate("stream_compat", parent = ServerRustModule.root), - "/inlineable/src/hyper_body_wrap_stream.rs", - CargoDependency.smithyHttpEventStream(codegenContext.runtimeConfig), - CargoDependency.smithyAsync(codegenContext.runtimeConfig).toDevDependency(), - CargoDependency.smithyEventStream(codegenContext.runtimeConfig).toDevDependency(), - CargoDependency.FuturesCore, - ).toType() if (operationShape.isEventStream(model)) { - return inlineable.resolve("HyperBodyWrapEventStream") + return hyperBodyWrapStream(codegenContext.runtimeConfig).resolve("HyperBodyWrapEventStream") } - return inlineable.resolve("HyperBodyWrapByteStream") + return hyperBodyWrapStream(codegenContext.runtimeConfig).resolve("HyperBodyWrapByteStream") } /** diff --git a/rust-runtime/aws-smithy-http-server-python/src/types.rs b/rust-runtime/aws-smithy-http-server-python/src/types.rs index a2fa308512..a274efe086 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/types.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/types.rs @@ -31,7 +31,6 @@ use pyo3::{ prelude::*, }; use tokio::{runtime::Handle, sync::Mutex}; -use tokio_stream::StreamExt; use crate::PyError; From 8e90987f5086c01fef991a278bcd1790c0a9d155 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Thu, 10 Aug 2023 22:56:04 -0500 Subject: [PATCH 08/36] Update pokemon example to use `FnStream` --- examples/pokemon-service-common/Cargo.toml | 2 +- examples/pokemon-service-common/src/lib.rs | 94 +++++++++++----------- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/examples/pokemon-service-common/Cargo.toml b/examples/pokemon-service-common/Cargo.toml index f2c86eee0e..c0a1aa676c 100644 --- a/examples/pokemon-service-common/Cargo.toml +++ b/examples/pokemon-service-common/Cargo.toml @@ -7,7 +7,6 @@ authors = ["Smithy-rs Server Team "] description = "A smithy Rust service to retrieve information about Pokémon." [dependencies] -async-stream = "0.3" http = "0.2.9" rand = "0.8" tracing = "0.1" @@ -16,6 +15,7 @@ tokio = { version = "1", default-features = false, features = ["time"] } tower = "0.4" # Local paths +aws-smithy-async = { path = "../../rust-runtime/aws-smithy-async" } aws-smithy-client = { path = "../../rust-runtime/aws-smithy-client" } aws-smithy-http = { path = "../../rust-runtime/aws-smithy-http" } aws-smithy-http-server = { path = "../../rust-runtime/aws-smithy-http-server" } diff --git a/examples/pokemon-service-common/src/lib.rs b/examples/pokemon-service-common/src/lib.rs index cb5c4bd604..08d96189c8 100644 --- a/examples/pokemon-service-common/src/lib.rs +++ b/examples/pokemon-service-common/src/lib.rs @@ -14,7 +14,7 @@ use std::{ sync::{atomic::AtomicUsize, Arc}, }; -use async_stream::stream; +use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_client::{conns, hyper_ext::Adapter}; use aws_smithy_http::{body::SdkBody, byte_stream::ByteStream}; use aws_smithy_http_server::Extension; @@ -242,59 +242,63 @@ pub async fn capture_pokemon( }, )); } - let output_stream = stream! { - loop { - use std::time::Duration; - match input.events.recv().await { - Ok(maybe_event) => match maybe_event { - Some(event) => { - let capturing_event = event.as_event(); - if let Ok(attempt) = capturing_event { - let payload = attempt.payload.clone().unwrap_or_else(|| CapturingPayload::builder().build()); - let pokeball = payload.pokeball().unwrap_or(""); - if ! matches!(pokeball, "Master Ball" | "Great Ball" | "Fast Ball") { - yield Err( + let output_stream = FnStream::new(|tx| { + Box::pin(async move { + loop { + use std::time::Duration; + match input.events.recv().await { + Ok(maybe_event) => match maybe_event { + Some(event) => { + let capturing_event = event.as_event(); + if let Ok(attempt) = capturing_event { + let payload = attempt + .payload + .clone() + .unwrap_or_else(|| CapturingPayload::builder().build()); + let pokeball = payload.pokeball().unwrap_or(""); + if !matches!(pokeball, "Master Ball" | "Great Ball" | "Fast Ball") { + tx.send(Err( crate::error::CapturePokemonEventsError::InvalidPokeballError( crate::error::InvalidPokeballError { pokeball: pokeball.to_owned() } ) - ); - } else { - let captured = match pokeball { - "Master Ball" => true, - "Great Ball" => rand::thread_rng().gen_range(0..100) > 33, - "Fast Ball" => rand::thread_rng().gen_range(0..100) > 66, - _ => unreachable!("invalid pokeball"), - }; - // Only support Kanto - tokio::time::sleep(Duration::from_millis(1000)).await; - // Will it capture the Pokémon? - if captured { - let shiny = rand::thread_rng().gen_range(0..4096) == 0; - let pokemon = payload - .name() - .unwrap_or("") - .to_string(); - let pokedex: Vec = (0..255).collect(); - yield Ok(crate::model::CapturePokemonEvents::Event( - crate::model::CaptureEvent { - name: Some(pokemon), - shiny: Some(shiny), - pokedex_update: Some(Blob::new(pokedex)), - captured: Some(true), - } - )); + )).await.expect("send should succeed"); + } else { + let captured = match pokeball { + "Master Ball" => true, + "Great Ball" => rand::thread_rng().gen_range(0..100) > 33, + "Fast Ball" => rand::thread_rng().gen_range(0..100) > 66, + _ => unreachable!("invalid pokeball"), + }; + // Only support Kanto + tokio::time::sleep(Duration::from_millis(1000)).await; + // Will it capture the Pokémon? + if captured { + let shiny = rand::thread_rng().gen_range(0..4096) == 0; + let pokemon = payload.name().unwrap_or("").to_string(); + let pokedex: Vec = (0..255).collect(); + tx.send(Ok(crate::model::CapturePokemonEvents::Event( + crate::model::CaptureEvent { + name: Some(pokemon), + shiny: Some(shiny), + pokedex_update: Some(Blob::new(pokedex)), + captured: Some(true), + }, + ))) + .await + .expect("send should succeed"); + } } } } - } - None => break, - }, - Err(e) => println!("{:?}", e), + None => break, + }, + Err(e) => println!("{:?}", e), + } } - } - }; + }) + }); Ok(output::CapturePokemonOutput::builder() .events(output_stream.into()) .build() From 492e8bf246779399ac053b2b111e7a722a0680aa Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 13:45:01 -0500 Subject: [PATCH 09/36] Fix more server tests --- .../HttpBoundProtocolPayloadGenerator.kt | 160 ++++++++++++++---- ...PythonServerEventStreamWrapperGenerator.kt | 13 +- .../ServerHttpBoundProtocolGenerator.kt | 18 +- 3 files changed, 141 insertions(+), 50 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 2c6ffc662c..61ef4bae70 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -18,11 +18,14 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.AdditionalPayloadContext @@ -63,11 +66,16 @@ class HttpBoundProtocolPayloadGenerator( private val httpBindingResolver = protocol.httpBindingResolver private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) private val codegenScope = arrayOf( - "hyper" to CargoDependency.HyperWithStream.toType(), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + *preludeScope, "BuildError" to runtimeConfig.operationBuildError(), - "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), + "Bytes" to RuntimeType.Bytes, + "ByteStreamError" to RuntimeType.smithyHttp(runtimeConfig).resolve("byte_stream::error::Error"), + "HyperBodyWrapEventStream" to RuntimeType.hyperBodyWrapStream(runtimeConfig).resolve("HyperBodyWrapEventStream"), "NoOpSigner" to smithyEventStream.resolve("frame::NoOpSigner"), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), + "Stream" to CargoDependency.FuturesCore.toType().resolve("stream::Stream"), + "hyper" to CargoDependency.HyperWithStream.toType(), ) private val protocolFunctions = ProtocolFunctions(codegenContext) @@ -78,6 +86,7 @@ class HttpBoundProtocolPayloadGenerator( val (shape, payloadMemberName) = when (httpMessageType) { HttpMessageType.RESPONSE -> operationShape.outputShape(model) to httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + HttpMessageType.REQUEST -> operationShape.inputShape(model) to httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName } @@ -97,6 +106,7 @@ class HttpBoundProtocolPayloadGenerator( is DocumentShape, is StructureShape, is UnionShape -> ProtocolPayloadGenerator.PayloadMetadata( takesOwnership = false, ) + is StringShape, is BlobShape -> ProtocolPayloadGenerator.PayloadMetadata(takesOwnership = true) else -> UNREACHABLE("Unexpected payload target type: $type") } @@ -110,8 +120,19 @@ class HttpBoundProtocolPayloadGenerator( additionalPayloadContext: AdditionalPayloadContext, ) { when (httpMessageType) { - HttpMessageType.RESPONSE -> generateResponsePayload(writer, shapeName, operationShape, additionalPayloadContext) - HttpMessageType.REQUEST -> generateRequestPayload(writer, shapeName, operationShape, additionalPayloadContext) + HttpMessageType.RESPONSE -> generateResponsePayload( + writer, + shapeName, + operationShape, + additionalPayloadContext, + ) + + HttpMessageType.REQUEST -> generateRequestPayload( + writer, + shapeName, + operationShape, + additionalPayloadContext, + ) } } @@ -119,13 +140,20 @@ class HttpBoundProtocolPayloadGenerator( writer: RustWriter, shapeName: String, operationShape: OperationShape, additionalPayloadContext: AdditionalPayloadContext, ) { - val payloadMemberName = httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + val payloadMemberName = + httpBindingResolver.requestMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() generateStructureSerializer(writer, shapeName, serializerGenerator.operationInputSerializer(operationShape)) } else { - generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext) + generatePayloadMemberSerializer( + writer, + shapeName, + operationShape, + payloadMemberName, + additionalPayloadContext, + ) } } @@ -133,13 +161,24 @@ class HttpBoundProtocolPayloadGenerator( writer: RustWriter, shapeName: String, operationShape: OperationShape, additionalPayloadContext: AdditionalPayloadContext, ) { - val payloadMemberName = httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName + val payloadMemberName = + httpBindingResolver.responseMembers(operationShape, HttpLocation.PAYLOAD).firstOrNull()?.memberName if (payloadMemberName == null) { val serializerGenerator = protocol.structuredDataSerializer() - generateStructureSerializer(writer, shapeName, serializerGenerator.operationOutputSerializer(operationShape)) + generateStructureSerializer( + writer, + shapeName, + serializerGenerator.operationOutputSerializer(operationShape), + ) } else { - generatePayloadMemberSerializer(writer, shapeName, operationShape, payloadMemberName, additionalPayloadContext) + generatePayloadMemberSerializer( + writer, + shapeName, + operationShape, + payloadMemberName, + additionalPayloadContext, + ) } } @@ -155,10 +194,22 @@ class HttpBoundProtocolPayloadGenerator( if (operationShape.isEventStream(model)) { if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, shapeName, additionalPayloadContext) + writer.serializeViaEventStream( + operationShape, + payloadMember, + serializerGenerator, + shapeName, + additionalPayloadContext, + ) } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) { val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output", additionalPayloadContext) + writer.serializeViaEventStream( + operationShape, + payloadMember, + serializerGenerator, + "output", + additionalPayloadContext, + ) } else { throw CodegenException("Payload serializer for event streams with an invalid configuration") } @@ -216,18 +267,41 @@ class HttpBoundProtocolPayloadGenerator( contentType, ).render() - // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the - // parameters that are not `@eventHeader` or `@eventPayload`. - renderEventStreamBody( - this, - EventStreamBodyParams( - outerName, - memberName, - marshallerConstructorFn, - errorMarshallerConstructorFn, - additionalPayloadContext, - ), - ) + if (target == CodegenTarget.CLIENT) { + // No need to wrap it with `HyperBodyWrapEventStream` for the client since wrapping takes place + // within `renderEventStreamBody` provided by `ClientHttpBoundProtocolPayloadGenerator`. + + // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the + // parameters that are not `@eventHeader` or `@eventPayload`. + renderEventStreamBody( + this, + EventStreamBodyParams( + outerName, + memberName, + marshallerConstructorFn, + errorMarshallerConstructorFn, + additionalPayloadContext, + ), + ) + } else { + withBlockTemplate( + "#{HyperBodyWrapEventStream}::new(", ")", + *codegenScope, + ) { + // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the + // parameters that are not `@eventHeader` or `@eventPayload`. + renderEventStreamBody( + this, + EventStreamBodyParams( + outerName, + memberName, + marshallerConstructorFn, + errorMarshallerConstructorFn, + additionalPayloadContext, + ), + ) + } + } } private fun RustWriter.serializeViaPayload( @@ -238,15 +312,16 @@ class HttpBoundProtocolPayloadGenerator( ) { val ref = if (payloadMetadata.takesOwnership) "" else "&" val serializer = protocolFunctions.serializeFn(member, fnNameSuffix = "http_payload") { fnName -> - val outputT = if (member.isStreaming(model)) { - symbolProvider.toSymbol(member) - } else { - RuntimeType.ByteSlab.toSymbol() - } rustBlockTemplate( - "pub fn $fnName(payload: $ref#{Member}) -> Result<#{outputT}, #{BuildError}>", + "pub(crate) fn $fnName(payload: $ref#{Member}) -> Result<#{outputT:W}, #{BuildError}>", "Member" to symbolProvider.toSymbol(member), - "outputT" to outputT, + "outputT" to writable { + if (member.isStreaming(model)) { + rustTemplate("impl #{Stream}>", *codegenScope) + } else { + rust("${RuntimeType.ByteSlab.toSymbol()}") + } + }, *codegenScope, ) { val asRef = if (payloadMetadata.takesOwnership) "" else ".as_ref()" @@ -268,6 +343,7 @@ class HttpBoundProtocolPayloadGenerator( Vec::new() """, ) + is StructureShape -> rust("#T()", serializerGenerator.unsetStructure(targetShape)) is UnionShape -> throw CodegenException("Currently unsupported. Tracking issue: https://github.com/awslabs/smithy-rs/issues/1896") else -> throw CodegenException("`httpPayload` on member shapes targeting shapes of type ${targetShape.type} is unsupported") @@ -303,13 +379,26 @@ class HttpBoundProtocolPayloadGenerator( is BlobShape -> { // Write the raw blob to the payload. if (member.isStreaming(model)) { - // Return the `ByteStream`. - rust(payloadName) + if (fromPythonServerRuntime(member)) { + // `aws_smithy_http_server_python::types::ByteStream` already implements + // `futures::stream::Stream`, so no need to wrap it in a futures' stream-compatible + // wrapper. + rust(payloadName) + } else { + // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures::stream::Stream` + // so wrap it in a new-type to enable the trait. + rustTemplate( + "#{HyperBodyWrapByteStream}::new($payloadName)", + "HyperBodyWrapByteStream" to RuntimeType.hyperBodyWrapStream(runtimeConfig) + .resolve("HyperBodyWrapByteStream"), + ) + } } else { // Convert the `Blob` into a `Vec` and return it. rust("$payloadName.into_inner()") } } + is StructureShape, is UnionShape -> { check( !((targetShape as? UnionShape)?.isEventStream() ?: false), @@ -320,13 +409,18 @@ class HttpBoundProtocolPayloadGenerator( serializer.payloadSerializer(member), ) } + is DocumentShape -> { rust( "#T($payloadName)", serializer.documentSerializer(), ) } + else -> PANIC("Unexpected payload target type: $targetShape") } } + + private fun fromPythonServerRuntime(member: MemberShape) = + symbolProvider.toSymbol(member).namespace.contains("aws_smithy_http_server_python") } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamWrapperGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamWrapperGenerator.kt index 8594d703c1..0832c38d79 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamWrapperGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEventStreamWrapperGenerator.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream import software.amazon.smithy.rust.codegen.core.util.toPascalCase @@ -66,6 +67,7 @@ class PythonServerEventStreamWrapperGenerator( private val pyO3 = PythonServerCargoDependency.PyO3.toType() private val codegenScope = arrayOf( + *preludeScope, "Inner" to innerT, "Error" to errorT, "SmithyPython" to PythonServerCargoDependency.smithyHttpServerPython(runtimeConfig).toType(), @@ -81,6 +83,7 @@ class PythonServerEventStreamWrapperGenerator( "Option" to RuntimeType.Option, "Arc" to RuntimeType.Arc, "Body" to RuntimeType.sdkBody(runtimeConfig), + "FnStream" to RuntimeType.smithyAsync(runtimeConfig).resolve("future::fn_stream::FnStream"), "UnmarshallMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::UnmarshallMessage"), "MarshallMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::MarshallMessage"), "SignMessage" to RuntimeType.smithyEventStream(runtimeConfig).resolve("frame::SignMessage"), @@ -137,7 +140,7 @@ class PythonServerEventStreamWrapperGenerator( fn extract(obj: &'source #{PyO3}::PyAny) -> #{PyO3}::PyResult { use #{TokioStream}::StreamExt; let stream = #{PyO3Asyncio}::tokio::into_stream_v1(obj)?; - let stream = stream.filter_map(|res| { + let mut stream = stream.filter_map(|res| { #{PyO3}::Python::with_gil(|py| { // TODO(EventStreamImprovements): Add `InternalServerError` variant to all event streaming // errors and return that variant in case of errors here? @@ -166,6 +169,14 @@ class PythonServerEventStreamWrapperGenerator( }) }); + let stream = #{FnStream}::new(|tx| { + Box::pin(async move { + while let #{Some}(item) = stream.next().await { + tx.send(item).await.expect("send should succeed"); + } + }) + }); + Ok($name { inner: #{Arc}::new(#{Mutex}::new(Some(stream.into()))) }) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 609575c540..3a592f1966 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -44,7 +44,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.hyperBodyWrapStream import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.Section import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingCustomization @@ -69,7 +68,6 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape -import software.amazon.smithy.rust.codegen.core.util.isEventStream import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency @@ -521,17 +519,6 @@ class ServerHttpBoundProtocolTraitImplGenerator( } } - /** - * Return a Writable that renders a new-type wrapper implementing "futures_core::stream::Stream" - * on behalf of the inner type. - */ - private fun futuresStreamCompatible(operationShape: OperationShape): RuntimeType { - if (operationShape.isEventStream(model)) { - return hyperBodyWrapStream(codegenContext.runtimeConfig).resolve("HyperBodyWrapEventStream") - } - return hyperBodyWrapStream(codegenContext.runtimeConfig).resolve("HyperBodyWrapByteStream") - } - /** * Render an HTTP response (headers, response code, body) for an operation's output and the given [bindings]. */ @@ -559,10 +546,9 @@ class ServerHttpBoundProtocolTraitImplGenerator( operationShape.outputShape(model).findStreamingMember(model)?.let { val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol) withBlockTemplate( - "let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(#{futures_stream_compatible}::new(", - ")));", + "let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", + "));", *codegenScope, - "futures_stream_compatible" to futuresStreamCompatible(operationShape), ) { payloadGenerator.generatePayload(this, "output", operationShape) } From 2c7541bf69f41a43660d127d63831c181a752555 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 13:50:59 -0500 Subject: [PATCH 10/36] Fix test name that checks both `Send` and `Sync` --- rust-runtime/aws-smithy-http/src/event_stream/sender.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index 825c93e1f9..5fe3312be3 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -20,7 +20,7 @@ use tracing::trace; /// Input type for Event Streams. pub struct EventStreamSender { // `FnStream` does not have a `Sync` bound but this struct needs to be `Sync` - // as demonstrated by a unit test `event_stream_sender_send`. + // as demonstrated by a unit test `event_stream_sender_send_sync`. // Wrapping `input_stream` with a `Mutex` will make `EventStreamSender` `Sync`. input_stream: Mutex>>, } @@ -280,7 +280,7 @@ mod tests { } #[test] - fn event_stream_sender_send() { + fn event_stream_sender_send_sync() { check_send_sync(EventStreamSender::from(FnStream::new(|tx| { Box::pin(async move { let message = Result::<_, TestServiceError>::Ok(TestMessage("test".into())); From f7e8a69bf6786be62f6077bb2b11650f577bdc14 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 14:16:25 -0500 Subject: [PATCH 11/36] Fix event-streaming tests in pokemon-service --- examples/pokemon-service/Cargo.toml | 2 +- .../pokemon-service/tests/event_streaming.rs | 107 +++++++++++------- 2 files changed, 68 insertions(+), 41 deletions(-) diff --git a/examples/pokemon-service/Cargo.toml b/examples/pokemon-service/Cargo.toml index d3bc81ea0b..23c3704392 100644 --- a/examples/pokemon-service/Cargo.toml +++ b/examples/pokemon-service/Cargo.toml @@ -20,7 +20,6 @@ pokemon-service-common = { path = "../pokemon-service-common/" } [dev-dependencies] assert_cmd = "2.0" -async-stream = "0.3" rand = "0.8.5" serial_test = "1.0.0" @@ -31,6 +30,7 @@ hyper = { version = "0.14.26", features = ["server", "client"] } hyper-rustls = { version = "0.24", features = ["http2"] } # Local paths +aws-smithy-async = { path = "../../rust-runtime/aws-smithy-async/" } aws-smithy-client = { path = "../../rust-runtime/aws-smithy-client/", features = ["rustls"] } aws-smithy-http = { path = "../../rust-runtime/aws-smithy-http/" } aws-smithy-types = { path = "../../rust-runtime/aws-smithy-types/" } diff --git a/examples/pokemon-service/tests/event_streaming.rs b/examples/pokemon-service/tests/event_streaming.rs index 664827620b..9b86596c4e 100644 --- a/examples/pokemon-service/tests/event_streaming.rs +++ b/examples/pokemon-service/tests/event_streaming.rs @@ -5,7 +5,7 @@ pub mod common; -use async_stream::stream; +use aws_smithy_async::future::fn_stream::FnStream; use rand::Rng; use serial_test::serial; @@ -40,35 +40,56 @@ async fn event_stream_test() { let client = common::client(); let mut team = vec![]; - let input_stream = stream! { - // Always Pikachu - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Pikachu") - .pokeball("Master Ball") - .build()) - .build() - )); - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Regieleki") - .pokeball("Fast Ball") - .build()) - .build() - )); - yield Err(AttemptCapturingPokemonEventError::MasterBallUnsuccessful(MasterBallUnsuccessful::builder().build())); - // The next event should not happen - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Charizard") - .pokeball("Great Ball") - .build()) - .build() - )); - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + // Always Pikachu + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Pikachu") + .pokeball("Master Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Regieleki") + .pokeball("Fast Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + + tx.send(Err( + AttemptCapturingPokemonEventError::MasterBallUnsuccessful( + MasterBallUnsuccessful::builder().build(), + ), + )) + .await + .expect("send should succeed"); + // The next event should not happen + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Charizard") + .pokeball("Great Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + }) + }); // Throw many! let mut output = common::client() @@ -112,16 +133,22 @@ async fn event_stream_test() { while team.len() < 6 { let pokeball = get_pokeball(); let pokemon = get_pokemon_to_capture(); - let input_stream = stream! { - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name(pokemon) - .pokeball(pokeball) - .build()) - .build() - )) - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name(pokemon) + .pokeball(pokeball) + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + }) + }); let mut output = client .capture_pokemon() .region("Kanto") From f3dae469cc78b8019c36954c6d4a53d695d1d6ac Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 14:44:53 -0500 Subject: [PATCH 12/36] Let ser_* return concreate types instead of `impl Stream` --- .../HttpBoundProtocolPayloadGenerator.kt | 34 +++++++++---------- .../inlineable/src/hyper_body_wrap_stream.rs | 6 ++++ 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 61ef4bae70..0c24e26115 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -21,7 +21,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -68,13 +67,10 @@ class HttpBoundProtocolPayloadGenerator( private val codegenScope = arrayOf( *preludeScope, "BuildError" to runtimeConfig.operationBuildError(), - "Bytes" to RuntimeType.Bytes, - "ByteStreamError" to RuntimeType.smithyHttp(runtimeConfig).resolve("byte_stream::error::Error"), "HyperBodyWrapEventStream" to RuntimeType.hyperBodyWrapStream(runtimeConfig).resolve("HyperBodyWrapEventStream"), "NoOpSigner" to smithyEventStream.resolve("frame::NoOpSigner"), "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), - "Stream" to CargoDependency.FuturesCore.toType().resolve("stream::Stream"), "hyper" to CargoDependency.HyperWithStream.toType(), ) private val protocolFunctions = ProtocolFunctions(codegenContext) @@ -312,16 +308,25 @@ class HttpBoundProtocolPayloadGenerator( ) { val ref = if (payloadMetadata.takesOwnership) "" else "&" val serializer = protocolFunctions.serializeFn(member, fnNameSuffix = "http_payload") { fnName -> + val outputT = if (member.isStreaming(model)) { + if (fromPythonServerRuntime(member)) { + // `aws_smithy_http_server_python::types::ByteStream` already implements + // `futures::stream::Stream`, so no need to wrap it in a futures' stream-compatible + // wrapper. + symbolProvider.toSymbol(member) + } else { + // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures::stream::Stream` + // so wrap it in a new-type to enable the trait. + RuntimeType.hyperBodyWrapStream(runtimeConfig) + .resolve("HyperBodyWrapByteStream").toSymbol() + } + } else { + RuntimeType.ByteSlab.toSymbol() + } rustBlockTemplate( - "pub(crate) fn $fnName(payload: $ref#{Member}) -> Result<#{outputT:W}, #{BuildError}>", + "pub(crate) fn $fnName(payload: $ref#{Member}) -> Result<#{outputT}, #{BuildError}>", "Member" to symbolProvider.toSymbol(member), - "outputT" to writable { - if (member.isStreaming(model)) { - rustTemplate("impl #{Stream}>", *codegenScope) - } else { - rust("${RuntimeType.ByteSlab.toSymbol()}") - } - }, + "outputT" to outputT, *codegenScope, ) { val asRef = if (payloadMetadata.takesOwnership) "" else ".as_ref()" @@ -380,13 +385,8 @@ class HttpBoundProtocolPayloadGenerator( // Write the raw blob to the payload. if (member.isStreaming(model)) { if (fromPythonServerRuntime(member)) { - // `aws_smithy_http_server_python::types::ByteStream` already implements - // `futures::stream::Stream`, so no need to wrap it in a futures' stream-compatible - // wrapper. rust(payloadName) } else { - // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures::stream::Stream` - // so wrap it in a new-type to enable the trait. rustTemplate( "#{HyperBodyWrapByteStream}::new($payloadName)", "HyperBodyWrapByteStream" to RuntimeType.hyperBodyWrapStream(runtimeConfig) diff --git a/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs b/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs index 6e5c669e63..1dc6c2be33 100644 --- a/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs +++ b/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_http::body::SdkBody; use aws_smithy_http::byte_stream::error::Error as ByteStreamError; use aws_smithy_http::byte_stream::ByteStream; use aws_smithy_http::event_stream::MessageStreamAdapter; @@ -39,6 +40,11 @@ impl HyperBodyWrapByteStream { pub(crate) fn new(stream: ByteStream) -> Self { Self(stream) } + + #[allow(dead_code)] + pub(crate) fn into_inner(self) -> SdkBody { + self.0.into_inner() + } } impl Unpin for HyperBodyWrapByteStream {} From 3aae47016f3c3c8a70ba2ff6eb0828d32c0d1861 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 14:53:13 -0500 Subject: [PATCH 13/36] Port `event_stream_input_ergonomics` test --- .../aws-smithy-http/src/event_stream/sender.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index 5fe3312be3..4274e210af 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -341,4 +341,22 @@ mod tests { SdkError::ConstructionFailure(_) )); } + + // Verify the developer experience for this compiles + #[allow(unused)] + fn event_stream_input_ergonomics() { + fn check(input: impl Into>) { + let _: EventStreamSender = input.into(); + } + check(FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(TestMessage("test".into()))).await.unwrap(); + }) + })); + check(FnStream::new(|tx| { + Box::pin(async move { + tx.send(Err(TestServiceError)).await.unwrap(); + }) + })); + } } From 3944bb3504bb25d81cecf3aa86436ac4243bffc8 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 15:46:26 -0500 Subject: [PATCH 14/36] Remove unused dependency from `aws-smithy-async` --- rust-runtime/aws-smithy-async/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/rust-runtime/aws-smithy-async/Cargo.toml b/rust-runtime/aws-smithy-async/Cargo.toml index 6b3bca322d..55907cb079 100644 --- a/rust-runtime/aws-smithy-async/Cargo.toml +++ b/rust-runtime/aws-smithy-async/Cargo.toml @@ -14,7 +14,6 @@ test-util = [] [dependencies] pin-project-lite = "0.2" tokio = { version = "1.23.1", features = ["sync"] } -futures-util = { version = "0.3.16", default-features = false } [dev-dependencies] pin-utils = "0.1" From fd70e49426010b7eb7142cd7aa92cb68b8548db6 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 15:49:28 -0500 Subject: [PATCH 15/36] Update `pokemon-service-test` to use `FnStream` --- .../python/pokemon-service-test/Cargo.toml | 2 +- .../tests/simple_integration_test.rs | 106 +++++++++++------- 2 files changed, 67 insertions(+), 41 deletions(-) diff --git a/examples/python/pokemon-service-test/Cargo.toml b/examples/python/pokemon-service-test/Cargo.toml index b4084185c2..2821d8c0b8 100644 --- a/examples/python/pokemon-service-test/Cargo.toml +++ b/examples/python/pokemon-service-test/Cargo.toml @@ -8,7 +8,6 @@ description = "Run tests against the Python server implementation" [dev-dependencies] rand = "0.8" -async-stream = "0.3" command-group = "2.1.0" tokio = { version = "1.20.1", features = ["full"] } serial_test = "2.0.0" @@ -17,6 +16,7 @@ tokio-rustls = "0.24.0" hyper-rustls = { version = "0.24", features = ["http2"] } # Local paths +aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async/" } aws-smithy-client = { path = "../../../rust-runtime/aws-smithy-client/", features = ["rustls"] } aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http/" } aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types/" } diff --git a/examples/python/pokemon-service-test/tests/simple_integration_test.rs b/examples/python/pokemon-service-test/tests/simple_integration_test.rs index 39f979e9a3..a5ed604679 100644 --- a/examples/python/pokemon-service-test/tests/simple_integration_test.rs +++ b/examples/python/pokemon-service-test/tests/simple_integration_test.rs @@ -7,7 +7,7 @@ // These tests only have access to your crate's public API. // See: https://doc.rust-lang.org/book/ch11-03-test-organization.html#integration-tests -use async_stream::stream; +use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_types::error::display::DisplayErrorContext; use rand::Rng; use serial_test::serial; @@ -75,35 +75,55 @@ async fn event_stream_test() { let _program = PokemonService::run().await; let mut team = vec![]; - let input_stream = stream! { - // Always Pikachu - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Pikachu") - .pokeball("Master Ball") - .build()) - .build() - )); - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Regieleki") - .pokeball("Fast Ball") - .build()) - .build() - )); - yield Err(AttemptCapturingPokemonEventError::MasterBallUnsuccessful(MasterBallUnsuccessful::builder().build())); - // The next event should not happen - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name("Charizard") - .pokeball("Great Ball") - .build()) - .build() - )); - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + // Always Pikachu + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Pikachu") + .pokeball("Master Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Regieleki") + .pokeball("Fast Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + tx.send(Err( + AttemptCapturingPokemonEventError::MasterBallUnsuccessful( + MasterBallUnsuccessful::builder().build(), + ), + )) + .await + .expect("send should succeed"); + // The next event should not happen + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name("Charizard") + .pokeball("Great Ball") + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + }) + }); // Throw many! let mut output = client() @@ -147,16 +167,22 @@ async fn event_stream_test() { while team.len() < 6 { let pokeball = get_pokeball(); let pokemon = get_pokemon_to_capture(); - let input_stream = stream! { - yield Ok(AttemptCapturingPokemonEvent::Event( - CapturingEvent::builder() - .payload(CapturingPayload::builder() - .name(pokemon) - .pokeball(pokeball) - .build()) - .build() - )) - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder() + .payload( + CapturingPayload::builder() + .name(pokemon) + .pokeball(pokeball) + .build(), + ) + .build(), + ))) + .await + .expect("send should succeed"); + }) + }); let mut output = client() .capture_pokemon() .region("Kanto") From ae047e5b4afcbe9eb9603b7ae867ba8f3e95de01 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 15:50:40 -0500 Subject: [PATCH 16/36] Remove unused `tokio_stream::StreamExt` from `canary-lambda` --- tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs | 1 - .../canary-lambda/src/release_2023_01_26/paginator_canary.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs b/tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs index d50c4f2be8..11914660ad 100644 --- a/tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/latest/paginator_canary.rs @@ -10,7 +10,6 @@ use aws_sdk_ec2 as ec2; use aws_sdk_ec2::types::InstanceType; use crate::CanaryEnv; -use tokio_stream::StreamExt; mk_canary!( "ec2_paginator", diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs index 72c9b40ed0..727c5c4c61 100644 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs @@ -10,7 +10,6 @@ use aws_sdk_ec2 as ec2; use aws_sdk_ec2::model::InstanceType; use crate::CanaryEnv; -use tokio_stream::StreamExt; mk_canary!( "ec2_paginator", From 50fe74cdd9ea6ca6588938ced46bcf5ef8830699 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 11 Aug 2023 16:03:32 -0500 Subject: [PATCH 17/36] Tell udeps `futures_util` is used --- rust-runtime/aws-smithy-async/Cargo.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/rust-runtime/aws-smithy-async/Cargo.toml b/rust-runtime/aws-smithy-async/Cargo.toml index 55907cb079..fd51b4fb1e 100644 --- a/rust-runtime/aws-smithy-async/Cargo.toml +++ b/rust-runtime/aws-smithy-async/Cargo.toml @@ -14,12 +14,18 @@ test-util = [] [dependencies] pin-project-lite = "0.2" tokio = { version = "1.23.1", features = ["sync"] } +futures-util = { version = "0.3.16", default-features = false } [dev-dependencies] pin-utils = "0.1" tokio = { version = "1.23.1", features = ["rt", "macros", "test-util"] } tokio-test = "0.4.2" +# futures-util is used by `now_or_later`, for instance, but the tooling +# reports a false positive, saying it is unused. +[package.metadata.cargo-udeps.ignore] +normal = ["futures-util"] + [package.metadata.docs.rs] all-features = true targets = ["x86_64-unknown-linux-gnu"] From ccf9d06d77d46bc3a3d0ec63bcc7c16b27f730f9 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Tue, 15 Aug 2023 18:51:43 -0500 Subject: [PATCH 18/36] Update latest for `canary-lambda` --- .../src/latest/transcribe_canary.rs | 20 ++- tools/ci-cdk/canary-lambda/src/main.rs | 6 + .../release_2023_01_26/paginator_canary.rs | 1 + .../canary-lambda/src/release_2023_08_03.rs | 8 + .../release_2023_08_03/paginator_canary.rs | 71 +++++++++ .../src/release_2023_08_03/s3_canary.rs | 140 ++++++++++++++++++ .../release_2023_08_03/transcribe_canary.rs | 92 ++++++++++++ 7 files changed, 331 insertions(+), 7 deletions(-) create mode 100644 tools/ci-cdk/canary-lambda/src/release_2023_08_03.rs create mode 100644 tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs create mode 100644 tools/ci-cdk/canary-lambda/src/release_2023_08_03/s3_canary.rs create mode 100644 tools/ci-cdk/canary-lambda/src/release_2023_08_03/transcribe_canary.rs diff --git a/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs b/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs index 8f6420fc1b..903329e555 100644 --- a/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs @@ -5,9 +5,9 @@ use crate::canary::CanaryError; use crate::mk_canary; -use async_stream::stream; use aws_config::SdkConfig; use aws_sdk_transcribestreaming as transcribe; +use aws_smithy_async::future::fn_stream::FnStream; use bytes::BufMut; use transcribe::primitives::Blob; use transcribe::types::{ @@ -31,12 +31,18 @@ pub async fn transcribe_canary( client: transcribe::Client, expected_transcribe_result: String, ) -> anyhow::Result<()> { - let input_stream = stream! { - let pcm = pcm_data(); - for chunk in pcm.chunks(CHUNK_SIZE) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - }; + let input_stream = FnStream::new(|tx| { + Box::pin(async move { + let pcm = pcm_data(); + for chunk in pcm.chunks(CHUNK_SIZE) { + tx.send(Ok(AudioStream::AudioEvent( + AudioEvent::builder().audio_chunk(Blob::new(chunk)).build(), + ))) + .await + .expect("send should succeed"); + } + }) + }); let mut output = client .start_stream_transcription() diff --git a/tools/ci-cdk/canary-lambda/src/main.rs b/tools/ci-cdk/canary-lambda/src/main.rs index 688462031d..de564d2a2d 100644 --- a/tools/ci-cdk/canary-lambda/src/main.rs +++ b/tools/ci-cdk/canary-lambda/src/main.rs @@ -26,6 +26,12 @@ mod latest; #[cfg(feature = "latest")] pub(crate) use latest as current_canary; +// NOTE: This module can be deleted 3 releases after release-2023-08-03 +#[cfg(feature = "release-2023-08-03")] +mod release_2023_08_03; +#[cfg(feature = "release-2023-08-03")] +pub(crate) use release_2023_08_03 as current_canary; + // NOTE: This module can be deleted 3 releases after release-2023-01-26 #[cfg(feature = "release-2023-01-26")] mod release_2023_01_26; diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs index 727c5c4c61..72c9b40ed0 100644 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs @@ -10,6 +10,7 @@ use aws_sdk_ec2 as ec2; use aws_sdk_ec2::model::InstanceType; use crate::CanaryEnv; +use tokio_stream::StreamExt; mk_canary!( "ec2_paginator", diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_08_03.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03.rs new file mode 100644 index 0000000000..238c361166 --- /dev/null +++ b/tools/ci-cdk/canary-lambda/src/release_2023_08_03.rs @@ -0,0 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +pub(crate) mod paginator_canary; +pub(crate) mod s3_canary; +pub(crate) mod transcribe_canary; diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs new file mode 100644 index 0000000000..d50c4f2be8 --- /dev/null +++ b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs @@ -0,0 +1,71 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::mk_canary; +use anyhow::bail; + +use aws_sdk_ec2 as ec2; +use aws_sdk_ec2::types::InstanceType; + +use crate::CanaryEnv; +use tokio_stream::StreamExt; + +mk_canary!( + "ec2_paginator", + |sdk_config: &aws_config::SdkConfig, env: &CanaryEnv| { + paginator_canary(ec2::Client::new(sdk_config), env.page_size) + } +); + +pub async fn paginator_canary(client: ec2::Client, page_size: usize) -> anyhow::Result<()> { + let mut history = client + .describe_spot_price_history() + .instance_types(InstanceType::M1Medium) + .into_paginator() + .page_size(page_size as i32) + .send(); + + let mut num_pages = 0; + while let Some(page) = history.try_next().await? { + let items_in_page = page.spot_price_history.unwrap_or_default().len(); + if items_in_page > page_size { + bail!( + "failed to retrieve results of correct page size (expected {}, got {})", + page_size, + items_in_page + ) + } + num_pages += 1; + } + if dbg!(num_pages) < 2 { + bail!( + "expected 3+ pages containing ~60 results but got {} pages", + num_pages + ) + } + + // https://github.com/awslabs/aws-sdk-rust/issues/405 + let _ = client + .describe_vpcs() + .into_paginator() + .items() + .send() + .collect::, _>>() + .await?; + + Ok(()) +} + +#[cfg(test)] +mod test { + use crate::latest::paginator_canary::paginator_canary; + + #[tokio::test] + async fn test_paginator() { + let conf = aws_config::load_from_env().await; + let client = aws_sdk_ec2::Client::new(&conf); + paginator_canary(client, 20).await.unwrap() + } +} diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_08_03/s3_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/s3_canary.rs new file mode 100644 index 0000000000..fbcba976d8 --- /dev/null +++ b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/s3_canary.rs @@ -0,0 +1,140 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::canary::CanaryError; +use crate::{mk_canary, CanaryEnv}; +use anyhow::Context; +use aws_config::SdkConfig; +use aws_sdk_s3 as s3; +use s3::presigning::PresigningConfig; +use s3::primitives::ByteStream; +use std::time::Duration; +use uuid::Uuid; + +const METADATA_TEST_VALUE: &str = "some value"; + +mk_canary!("s3", |sdk_config: &SdkConfig, env: &CanaryEnv| s3_canary( + s3::Client::new(sdk_config), + env.s3_bucket_name.clone() +)); + +pub async fn s3_canary(client: s3::Client, s3_bucket_name: String) -> anyhow::Result<()> { + let test_key = Uuid::new_v4().as_u128().to_string(); + + // Look for the test object and expect that it doesn't exist + match client + .get_object() + .bucket(&s3_bucket_name) + .key(&test_key) + .send() + .await + { + Ok(_) => { + return Err( + CanaryError(format!("Expected object {} to not exist in S3", test_key)).into(), + ); + } + Err(err) => { + let err = err.into_service_error(); + // If we get anything other than "No such key", we have a problem + if !err.is_no_such_key() { + return Err(err).context("unexpected s3::GetObject failure"); + } + } + } + + // Put the test object + client + .put_object() + .bucket(&s3_bucket_name) + .key(&test_key) + .body(ByteStream::from_static(b"test")) + .metadata("something", METADATA_TEST_VALUE) + .send() + .await + .context("s3::PutObject")?; + + // Get the test object and verify it looks correct + let output = client + .get_object() + .bucket(&s3_bucket_name) + .key(&test_key) + .send() + .await + .context("s3::GetObject[2]")?; + + // repeat the test with a presigned url + let uri = client + .get_object() + .bucket(&s3_bucket_name) + .key(&test_key) + .presigned(PresigningConfig::expires_in(Duration::from_secs(120)).unwrap()) + .await + .unwrap(); + let response = reqwest::get(uri.uri().to_string()) + .await + .context("s3::presigned")? + .text() + .await?; + if response != "test" { + return Err(CanaryError(format!("presigned URL returned bad data: {:?}", response)).into()); + } + + let mut result = Ok(()); + match output.metadata() { + Some(map) => { + // Option::as_deref doesn't work here since the deref of &String is String + let value = map.get("something").map(|s| s.as_str()).unwrap_or(""); + if value != METADATA_TEST_VALUE { + result = Err(CanaryError(format!( + "S3 metadata was incorrect. Expected `{}` but got `{}`.", + METADATA_TEST_VALUE, value + )) + .into()); + } + } + None => { + result = Err(CanaryError("S3 metadata was missing".into()).into()); + } + } + + let payload = output + .body + .collect() + .await + .context("download s3::GetObject[2] body")? + .into_bytes(); + if std::str::from_utf8(payload.as_ref()).context("s3 payload")? != "test" { + result = Err(CanaryError("S3 object body didn't match what was put there".into()).into()); + } + + // Delete the test object + client + .delete_object() + .bucket(&s3_bucket_name) + .key(&test_key) + .send() + .await + .context("s3::DeleteObject")?; + + result +} + +// This test runs against an actual AWS account. Comment out the `ignore` to run it. +// Be sure to set the `TEST_S3_BUCKET` environment variable to the S3 bucket to use, +// and also make sure the credential profile sets the region (or set `AWS_DEFAULT_PROFILE`). +#[ignore] +#[cfg(test)] +#[tokio::test] +async fn test_s3_canary() { + let config = aws_config::load_from_env().await; + let client = s3::Client::new(&config); + s3_canary( + client, + std::env::var("TEST_S3_BUCKET").expect("TEST_S3_BUCKET must be set"), + ) + .await + .expect("success"); +} diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_08_03/transcribe_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/transcribe_canary.rs new file mode 100644 index 0000000000..8f6420fc1b --- /dev/null +++ b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/transcribe_canary.rs @@ -0,0 +1,92 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::canary::CanaryError; +use crate::mk_canary; +use async_stream::stream; +use aws_config::SdkConfig; +use aws_sdk_transcribestreaming as transcribe; +use bytes::BufMut; +use transcribe::primitives::Blob; +use transcribe::types::{ + AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, +}; + +const CHUNK_SIZE: usize = 8192; +use crate::canary::CanaryEnv; + +mk_canary!( + "transcribe_canary", + |sdk_config: &SdkConfig, env: &CanaryEnv| { + transcribe_canary( + transcribe::Client::new(sdk_config), + env.expected_transcribe_result.clone(), + ) + } +); + +pub async fn transcribe_canary( + client: transcribe::Client, + expected_transcribe_result: String, +) -> anyhow::Result<()> { + let input_stream = stream! { + let pcm = pcm_data(); + for chunk in pcm.chunks(CHUNK_SIZE) { + yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); + } + }; + + let mut output = client + .start_stream_transcription() + .language_code(LanguageCode::EnGb) + .media_sample_rate_hertz(8000) + .media_encoding(MediaEncoding::Pcm) + .audio_stream(input_stream.into()) + .send() + .await?; + + let mut full_message = String::new(); + while let Some(event) = output.transcript_result_stream.recv().await? { + match event { + TranscriptResultStream::TranscriptEvent(transcript_event) => { + let transcript = transcript_event.transcript.unwrap(); + for result in transcript.results.unwrap_or_default() { + if !result.is_partial { + let first_alternative = &result.alternatives.as_ref().unwrap()[0]; + full_message += first_alternative.transcript.as_ref().unwrap(); + full_message.push(' '); + } + } + } + otherwise => panic!("received unexpected event type: {:?}", otherwise), + } + } + + if expected_transcribe_result != full_message.trim() { + Err(CanaryError(format!( + "Transcription from Transcribe doesn't look right:\n\ + Expected: `{}`\n\ + Actual: `{}`\n", + expected_transcribe_result, + full_message.trim() + )) + .into()) + } else { + Ok(()) + } +} + +fn pcm_data() -> Vec { + let reader = + hound::WavReader::new(&include_bytes!("../../audio/hello-transcribe-8000.wav")[..]) + .expect("valid wav data"); + let samples_result: hound::Result> = reader.into_samples::().collect(); + + let mut pcm: Vec = Vec::new(); + for sample in samples_result.unwrap() { + pcm.put_i16_le(sample); + } + pcm +} From d31669b4c79bdf0bed28cb5841f163525f690df7 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Tue, 15 Aug 2023 21:19:38 -0500 Subject: [PATCH 19/36] Append `release-2023-08-03` to `NOTABLE_SDK_RELEASE_TAGS` This commit appends `release-2023-08-03` to `NOTABLE_SDK_RELEASE_TAGS`. Elements in that `Vec` should be sorted in an ascending order for a function `enabled_feature` to work correctly. The change has been verified by the following (each executed from canary-runner directory) ``` cargo run -- build-bundle \ --sdk-release-tag release-2023-08-03 \ --canary-path ../canary-lambda \ --manifest-only --musl && \ cd ../canary-lambda && \ cargo check ``` ``` cargo run -- build-bundle \ --sdk-release-tag release-2023-05-24 \ --canary-path ../canary-lambda \ --manifest-only --musl && \ cd ../canary-lambda && \ cargo check ``` ``` cargo run -- build-bundle \ --sdk-release-tag release-2023-12-14 \ --canary-path ../canary-lambda \ --manifest-only --musl && \ cd ../canary-lambda && \ cargo check --- .../ci-cdk/canary-runner/src/build_bundle.rs | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/tools/ci-cdk/canary-runner/src/build_bundle.rs b/tools/ci-cdk/canary-runner/src/build_bundle.rs index 464ee2e4ad..a93ebcda5b 100644 --- a/tools/ci-cdk/canary-runner/src/build_bundle.rs +++ b/tools/ci-cdk/canary-runner/src/build_bundle.rs @@ -63,9 +63,11 @@ const REQUIRED_SDK_CRATES: &[&str] = &[ "aws-sdk-transcribestreaming", ]; +// The elements in this `Vec` should be sorted in an ascending order by the release date. lazy_static! { static ref NOTABLE_SDK_RELEASE_TAGS: Vec = vec![ ReleaseTag::from_str("release-2023-01-26").unwrap(), // last version before the crate reorg + ReleaseTag::from_str("release-2023-08-03").unwrap(), // last version before `Stream` trait removal ]; } @@ -112,23 +114,18 @@ enum CrateSource { }, } -fn enabled_features(crate_source: &CrateSource) -> Vec { - let mut enabled = Vec::new(); +fn enabled_feature(crate_source: &CrateSource) -> String { if let CrateSource::VersionsManifest { release_tag, .. } = crate_source { - // we want to select the newest module specified after this release + // we want to select the oldest module specified after this release for notable in NOTABLE_SDK_RELEASE_TAGS.iter() { tracing::debug!(release_tag = ?release_tag, notable = ?notable, "considering if release tag came before notable release"); if release_tag <= notable { tracing::debug!("selecting {} as chosen release", notable); - enabled.push(notable.as_str().into()); - break; + return notable.as_str().into(); } } } - if enabled.is_empty() { - enabled.push("latest".into()); - } - enabled + "latest".into() } fn generate_crate_manifest(crate_source: CrateSource) -> Result { @@ -176,12 +173,8 @@ fn generate_crate_manifest(crate_source: CrateSource) -> Result { } writeln!( output, - "default = [{enabled}]", - enabled = enabled_features(&crate_source) - .into_iter() - .map(|f| format!("\"{f}\"")) - .collect::>() - .join(", ") + "default = [\"{enabled}\"]", + enabled = enabled_feature(&crate_source) ) .unwrap(); Ok(output) @@ -442,6 +435,7 @@ aws-sdk-transcribestreaming = { path = "some/sdk/path/transcribestreaming" } [features] latest = [] "release-2023-01-26" = [] +"release-2023-08-03" = [] default = ["latest"] "#, generate_crate_manifest(CrateSource::Path("some/sdk/path".into())).expect("success") @@ -506,6 +500,7 @@ aws-sdk-transcribestreaming = "0.16.0" [features] latest = [] "release-2023-01-26" = [] +"release-2023-08-03" = [] default = ["latest"] "#, generate_crate_manifest(CrateSource::VersionsManifest { @@ -523,7 +518,7 @@ default = ["latest"] .collect(), release: None, }, - release_tag: ReleaseTag::from_str("release-2023-05-26").unwrap(), + release_tag: ReleaseTag::from_str("release-2023-08-26").unwrap(), }) .expect("success") ); @@ -577,26 +572,32 @@ default = ["latest"] release: None, }; assert_eq!( - enabled_features(&CrateSource::VersionsManifest { + "latest".to_string(), + enabled_feature(&CrateSource::VersionsManifest { + versions: versions.clone(), + release_tag: "release-9999-12-31".parse().unwrap(), + }), + ); + assert_eq!( + "release-2023-08-03".to_string(), + enabled_feature(&CrateSource::VersionsManifest { versions: versions.clone(), release_tag: "release-2023-02-23".parse().unwrap(), }), - vec!["latest".to_string()] ); - assert_eq!( - enabled_features(&CrateSource::VersionsManifest { + "release-2023-01-26".to_string(), + enabled_feature(&CrateSource::VersionsManifest { versions: versions.clone(), release_tag: "release-2023-01-26".parse().unwrap(), }), - vec!["release-2023-01-26".to_string()] ); assert_eq!( - enabled_features(&CrateSource::VersionsManifest { + "release-2023-01-26".to_string(), + enabled_feature(&CrateSource::VersionsManifest { versions, release_tag: "release-2023-01-13".parse().unwrap(), }), - vec!["release-2023-01-26".to_string()] ); } } From f8c0c40e11e021fca5edf3c2697ef787ce077065 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 25 Aug 2023 12:55:08 -0500 Subject: [PATCH 20/36] Add dependency on `aws-smithy-async` --- .../ci-cdk/canary-runner/src/build_bundle.rs | 63 ++++++++++++------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/tools/ci-cdk/canary-runner/src/build_bundle.rs b/tools/ci-cdk/canary-runner/src/build_bundle.rs index a93ebcda5b..71e6cd3817 100644 --- a/tools/ci-cdk/canary-runner/src/build_bundle.rs +++ b/tools/ci-cdk/canary-runner/src/build_bundle.rs @@ -56,6 +56,8 @@ tracing-texray = "0.1.1" reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } "#; +const REQUIRED_SMITHY_CRATES: &[&str] = &["aws-smithy-async"]; + const REQUIRED_SDK_CRATES: &[&str] = &[ "aws-config", "aws-sdk-s3", @@ -130,17 +132,43 @@ fn enabled_feature(crate_source: &CrateSource) -> String { fn generate_crate_manifest(crate_source: CrateSource) -> Result { let mut output = BASE_MANIFEST.to_string(); - for &sdk_crate in REQUIRED_SDK_CRATES { + write_dependencies(REQUIRED_SMITHY_CRATES, &mut output, &crate_source)?; + write_dependencies(REQUIRED_SDK_CRATES, &mut output, &crate_source)?; + write!(output, "\n[features]\n").unwrap(); + writeln!(output, "latest = []").unwrap(); + for release_tag in NOTABLE_SDK_RELEASE_TAGS.iter() { + writeln!( + output, + "\"{release_tag}\" = []", + release_tag = release_tag.as_str() + ) + .unwrap(); + } + writeln!( + output, + "default = [\"{enabled}\"]", + enabled = enabled_feature(&crate_source) + ) + .unwrap(); + Ok(output) +} + +fn write_dependencies( + required_crates: &[&str], + output: &mut String, + crate_source: &CrateSource, +) -> Result<()> { + for &required_crate in required_crates { match &crate_source { CrateSource::Path(path) => { - let path_name = match sdk_crate.strip_prefix("aws-sdk-") { + let path_name = match required_crate.strip_prefix("aws-sdk-") { Some(path) => path, - None => sdk_crate, + None => required_crate, }; let crate_path = path.join(path_name); writeln!( output, - r#"{sdk_crate} = {{ path = "{path}" }}"#, + r#"{required_crate} = {{ path = "{path}" }}"#, path = crate_path.to_string_lossy() ) .unwrap() @@ -148,36 +176,20 @@ fn generate_crate_manifest(crate_source: CrateSource) -> Result { CrateSource::VersionsManifest { versions, release_tag, - } => match versions.crates.get(sdk_crate) { + } => match versions.crates.get(required_crate) { Some(version) => writeln!( output, - r#"{sdk_crate} = "{version}""#, + r#"{required_crate} = "{version}""#, version = version.version ) .unwrap(), None => { - bail!("Couldn't find `{sdk_crate}` in versions.toml for `{release_tag}`") + bail!("Couldn't find `{required_crate}` in versions.toml for `{release_tag}`") } }, } } - write!(output, "\n[features]\n").unwrap(); - writeln!(output, "latest = []").unwrap(); - for release_tag in NOTABLE_SDK_RELEASE_TAGS.iter() { - writeln!( - output, - "\"{release_tag}\" = []", - release_tag = release_tag.as_str() - ) - .unwrap(); - } - writeln!( - output, - "default = [\"{enabled}\"]", - enabled = enabled_feature(&crate_source) - ) - .unwrap(); - Ok(output) + Ok(()) } fn sha1_file(path: &Path) -> Result { @@ -427,6 +439,7 @@ uuid = { version = "0.8", features = ["v4"] } tokio-stream = "0" tracing-texray = "0.1.1" reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } +aws-smithy-async = { path = "some/sdk/path/aws-smithy-async" } aws-config = { path = "some/sdk/path/aws-config" } aws-sdk-s3 = { path = "some/sdk/path/s3" } aws-sdk-ec2 = { path = "some/sdk/path/ec2" } @@ -492,6 +505,7 @@ uuid = { version = "0.8", features = ["v4"] } tokio-stream = "0" tracing-texray = "0.1.1" reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } +aws-smithy-async = "0.46.0" aws-config = "0.46.0" aws-sdk-s3 = "0.20.0" aws-sdk-ec2 = "0.19.0" @@ -509,6 +523,7 @@ default = ["latest"] aws_doc_sdk_examples_revision: "some-revision-docs".into(), manual_interventions: Default::default(), crates: [ + crate_version("aws-smithy-async", "0.46.0"), crate_version("aws-config", "0.46.0"), crate_version("aws-sdk-s3", "0.20.0"), crate_version("aws-sdk-ec2", "0.19.0"), From d853cb13639f5f1979ec7c8ffc4555e413d38ff2 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 25 Aug 2023 15:03:53 -0500 Subject: [PATCH 21/36] Remove 2023_01_26 that is more than 3 releases ago --- tools/ci-cdk/canary-lambda/src/main.rs | 6 - .../canary-lambda/src/release_2023_01_26.rs | 8 - .../release_2023_01_26/paginator_canary.rs | 71 --------- .../src/release_2023_01_26/s3_canary.rs | 140 ------------------ .../release_2023_01_26/transcribe_canary.rs | 92 ------------ .../release_2023_08_03/paginator_canary.rs | 2 +- .../ci-cdk/canary-runner/src/build_bundle.rs | 14 +- 7 files changed, 3 insertions(+), 330 deletions(-) delete mode 100644 tools/ci-cdk/canary-lambda/src/release_2023_01_26.rs delete mode 100644 tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs delete mode 100644 tools/ci-cdk/canary-lambda/src/release_2023_01_26/s3_canary.rs delete mode 100644 tools/ci-cdk/canary-lambda/src/release_2023_01_26/transcribe_canary.rs diff --git a/tools/ci-cdk/canary-lambda/src/main.rs b/tools/ci-cdk/canary-lambda/src/main.rs index de564d2a2d..d42fb84f53 100644 --- a/tools/ci-cdk/canary-lambda/src/main.rs +++ b/tools/ci-cdk/canary-lambda/src/main.rs @@ -32,12 +32,6 @@ mod release_2023_08_03; #[cfg(feature = "release-2023-08-03")] pub(crate) use release_2023_08_03 as current_canary; -// NOTE: This module can be deleted 3 releases after release-2023-01-26 -#[cfg(feature = "release-2023-01-26")] -mod release_2023_01_26; -#[cfg(feature = "release-2023-01-26")] -pub(crate) use release_2023_01_26 as current_canary; - #[tokio::main] async fn main() -> Result<(), Error> { let subscriber = tracing_subscriber::registry() diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26.rs b/tools/ci-cdk/canary-lambda/src/release_2023_01_26.rs deleted file mode 100644 index 238c361166..0000000000 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26.rs +++ /dev/null @@ -1,8 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -pub(crate) mod paginator_canary; -pub(crate) mod s3_canary; -pub(crate) mod transcribe_canary; diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs deleted file mode 100644 index 72c9b40ed0..0000000000 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/paginator_canary.rs +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use crate::mk_canary; -use anyhow::bail; - -use aws_sdk_ec2 as ec2; -use aws_sdk_ec2::model::InstanceType; - -use crate::CanaryEnv; -use tokio_stream::StreamExt; - -mk_canary!( - "ec2_paginator", - |sdk_config: &aws_config::SdkConfig, env: &CanaryEnv| { - paginator_canary(ec2::Client::new(sdk_config), env.page_size) - } -); - -pub async fn paginator_canary(client: ec2::Client, page_size: usize) -> anyhow::Result<()> { - let mut history = client - .describe_spot_price_history() - .instance_types(InstanceType::M1Medium) - .into_paginator() - .page_size(page_size as i32) - .send(); - - let mut num_pages = 0; - while let Some(page) = history.try_next().await? { - let items_in_page = page.spot_price_history.unwrap_or_default().len(); - if items_in_page > page_size as usize { - bail!( - "failed to retrieve results of correct page size (expected {}, got {})", - page_size, - items_in_page - ) - } - num_pages += 1; - } - if dbg!(num_pages) < 2 { - bail!( - "expected 3+ pages containing ~60 results but got {} pages", - num_pages - ) - } - - // https://github.com/awslabs/aws-sdk-rust/issues/405 - let _ = client - .describe_vpcs() - .into_paginator() - .items() - .send() - .collect::, _>>() - .await?; - - Ok(()) -} - -#[cfg(test)] -mod test { - use crate::paginator_canary::paginator_canary; - - #[tokio::test] - async fn test_paginator() { - let conf = aws_config::load_from_env().await; - let client = aws_sdk_ec2::Client::new(&conf); - paginator_canary(client, 20).await.unwrap() - } -} diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/s3_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_01_26/s3_canary.rs deleted file mode 100644 index 70e3d18c55..0000000000 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/s3_canary.rs +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use crate::canary::CanaryError; -use crate::{mk_canary, CanaryEnv}; -use anyhow::Context; -use aws_config::SdkConfig; -use aws_sdk_s3 as s3; -use aws_sdk_s3::presigning::config::PresigningConfig; -use s3::types::ByteStream; -use std::time::Duration; -use uuid::Uuid; - -const METADATA_TEST_VALUE: &str = "some value"; - -mk_canary!("s3", |sdk_config: &SdkConfig, env: &CanaryEnv| s3_canary( - s3::Client::new(sdk_config), - env.s3_bucket_name.clone() -)); - -pub async fn s3_canary(client: s3::Client, s3_bucket_name: String) -> anyhow::Result<()> { - let test_key = Uuid::new_v4().as_u128().to_string(); - - // Look for the test object and expect that it doesn't exist - match client - .get_object() - .bucket(&s3_bucket_name) - .key(&test_key) - .send() - .await - { - Ok(_) => { - return Err( - CanaryError(format!("Expected object {} to not exist in S3", test_key)).into(), - ); - } - Err(err) => { - let err = err.into_service_error(); - // If we get anything other than "No such key", we have a problem - if !err.is_no_such_key() { - return Err(err).context("unexpected s3::GetObject failure"); - } - } - } - - // Put the test object - client - .put_object() - .bucket(&s3_bucket_name) - .key(&test_key) - .body(ByteStream::from_static(b"test")) - .metadata("something", METADATA_TEST_VALUE) - .send() - .await - .context("s3::PutObject")?; - - // Get the test object and verify it looks correct - let output = client - .get_object() - .bucket(&s3_bucket_name) - .key(&test_key) - .send() - .await - .context("s3::GetObject[2]")?; - - // repeat the test with a presigned url - let uri = client - .get_object() - .bucket(&s3_bucket_name) - .key(&test_key) - .presigned(PresigningConfig::expires_in(Duration::from_secs(120)).unwrap()) - .await - .unwrap(); - let response = reqwest::get(uri.uri().to_string()) - .await - .context("s3::presigned")? - .text() - .await?; - if response != "test" { - return Err(CanaryError(format!("presigned URL returned bad data: {:?}", response)).into()); - } - - let mut result = Ok(()); - match output.metadata() { - Some(map) => { - // Option::as_deref doesn't work here since the deref of &String is String - let value = map.get("something").map(|s| s.as_str()).unwrap_or(""); - if value != METADATA_TEST_VALUE { - result = Err(CanaryError(format!( - "S3 metadata was incorrect. Expected `{}` but got `{}`.", - METADATA_TEST_VALUE, value - )) - .into()); - } - } - None => { - result = Err(CanaryError("S3 metadata was missing".into()).into()); - } - } - - let payload = output - .body - .collect() - .await - .context("download s3::GetObject[2] body")? - .into_bytes(); - if std::str::from_utf8(payload.as_ref()).context("s3 payload")? != "test" { - result = Err(CanaryError("S3 object body didn't match what was put there".into()).into()); - } - - // Delete the test object - client - .delete_object() - .bucket(&s3_bucket_name) - .key(&test_key) - .send() - .await - .context("s3::DeleteObject")?; - - result -} - -// This test runs against an actual AWS account. Comment out the `ignore` to run it. -// Be sure to set the `TEST_S3_BUCKET` environment variable to the S3 bucket to use, -// and also make sure the credential profile sets the region (or set `AWS_DEFAULT_PROFILE`). -#[ignore] -#[cfg(test)] -#[tokio::test] -async fn test_s3_canary() { - let config = aws_config::load_from_env().await; - let client = s3::Client::new(&config); - s3_canary( - client, - std::env::var("TEST_S3_BUCKET").expect("TEST_S3_BUCKET must be set"), - ) - .await - .expect("success"); -} diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/transcribe_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_01_26/transcribe_canary.rs deleted file mode 100644 index 554f4c3ddf..0000000000 --- a/tools/ci-cdk/canary-lambda/src/release_2023_01_26/transcribe_canary.rs +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use crate::canary::CanaryError; -use crate::mk_canary; -use async_stream::stream; -use aws_config::SdkConfig; -use aws_sdk_transcribestreaming as transcribe; -use bytes::BufMut; -use transcribe::model::{ - AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, -}; -use transcribe::types::Blob; - -const CHUNK_SIZE: usize = 8192; -use crate::canary::CanaryEnv; - -mk_canary!( - "transcribe_canary", - |sdk_config: &SdkConfig, env: &CanaryEnv| { - transcribe_canary( - transcribe::Client::new(sdk_config), - env.expected_transcribe_result.clone(), - ) - } -); - -pub async fn transcribe_canary( - client: transcribe::Client, - expected_transcribe_result: String, -) -> anyhow::Result<()> { - let input_stream = stream! { - let pcm = pcm_data(); - for chunk in pcm.chunks(CHUNK_SIZE) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - }; - - let mut output = client - .start_stream_transcription() - .language_code(LanguageCode::EnGb) - .media_sample_rate_hertz(8000) - .media_encoding(MediaEncoding::Pcm) - .audio_stream(input_stream.into()) - .send() - .await?; - - let mut full_message = String::new(); - while let Some(event) = output.transcript_result_stream.recv().await? { - match event { - TranscriptResultStream::TranscriptEvent(transcript_event) => { - let transcript = transcript_event.transcript.unwrap(); - for result in transcript.results.unwrap_or_default() { - if !result.is_partial { - let first_alternative = &result.alternatives.as_ref().unwrap()[0]; - full_message += first_alternative.transcript.as_ref().unwrap(); - full_message.push(' '); - } - } - } - otherwise => panic!("received unexpected event type: {:?}", otherwise), - } - } - - if expected_transcribe_result != full_message.trim() { - Err(CanaryError(format!( - "Transcription from Transcribe doesn't look right:\n\ - Expected: `{}`\n\ - Actual: `{}`\n", - expected_transcribe_result, - full_message.trim() - )) - .into()) - } else { - Ok(()) - } -} - -fn pcm_data() -> Vec { - let reader = - hound::WavReader::new(&include_bytes!("../../audio/hello-transcribe-8000.wav")[..]) - .expect("valid wav data"); - let samples_result: hound::Result> = reader.into_samples::().collect(); - - let mut pcm: Vec = Vec::new(); - for sample in samples_result.unwrap() { - pcm.put_i16_le(sample); - } - pcm -} diff --git a/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs index d50c4f2be8..66df5a03e4 100644 --- a/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/release_2023_08_03/paginator_canary.rs @@ -60,7 +60,7 @@ pub async fn paginator_canary(client: ec2::Client, page_size: usize) -> anyhow:: #[cfg(test)] mod test { - use crate::latest::paginator_canary::paginator_canary; + use crate::current_canary::paginator_canary::paginator_canary; #[tokio::test] async fn test_paginator() { diff --git a/tools/ci-cdk/canary-runner/src/build_bundle.rs b/tools/ci-cdk/canary-runner/src/build_bundle.rs index 71e6cd3817..e82b00ddd7 100644 --- a/tools/ci-cdk/canary-runner/src/build_bundle.rs +++ b/tools/ci-cdk/canary-runner/src/build_bundle.rs @@ -68,7 +68,6 @@ const REQUIRED_SDK_CRATES: &[&str] = &[ // The elements in this `Vec` should be sorted in an ascending order by the release date. lazy_static! { static ref NOTABLE_SDK_RELEASE_TAGS: Vec = vec![ - ReleaseTag::from_str("release-2023-01-26").unwrap(), // last version before the crate reorg ReleaseTag::from_str("release-2023-08-03").unwrap(), // last version before `Stream` trait removal ]; } @@ -447,7 +446,6 @@ aws-sdk-transcribestreaming = { path = "some/sdk/path/transcribestreaming" } [features] latest = [] -"release-2023-01-26" = [] "release-2023-08-03" = [] default = ["latest"] "#, @@ -513,7 +511,6 @@ aws-sdk-transcribestreaming = "0.16.0" [features] latest = [] -"release-2023-01-26" = [] "release-2023-08-03" = [] default = ["latest"] "#, @@ -597,18 +594,11 @@ default = ["latest"] "release-2023-08-03".to_string(), enabled_feature(&CrateSource::VersionsManifest { versions: versions.clone(), - release_tag: "release-2023-02-23".parse().unwrap(), + release_tag: "release-2023-08-03".parse().unwrap(), }), ); assert_eq!( - "release-2023-01-26".to_string(), - enabled_feature(&CrateSource::VersionsManifest { - versions: versions.clone(), - release_tag: "release-2023-01-26".parse().unwrap(), - }), - ); - assert_eq!( - "release-2023-01-26".to_string(), + "release-2023-08-03".to_string(), enabled_feature(&CrateSource::VersionsManifest { versions, release_tag: "release-2023-01-13".parse().unwrap(), From 5f65dffa67279bb06f6b1a32e706b94f11dee3c4 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Fri, 25 Aug 2023 17:20:33 -0500 Subject: [PATCH 22/36] Update CHANGELOG.next.toml --- CHANGELOG.next.toml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 8f11accc1e..4dfa059ce7 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -54,3 +54,21 @@ message = "Update MSRV to Rust 1.70.0" references = ["smithy-rs#2948"] meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" } author = "Velfi" + +[[aws-sdk-rust]] +message = """ +The `futures_core::stream::Stream` trait has been removed from public API. It should not affect usual SDK use cases, but it does require code upgrade for a small number of cases. The notable example is Transcribe streaming when streaming data is created via a `stream!` macro from the `async-stream` crate. The use of that macro needs to be replaced with `aws_smithy_async::future::fn_stream::FnStream`. See https://github.com/awslabs/smithy-rs/discussions/2952 for more details. +""" +references = ["smithy-rs#2910"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "ysaito1001" + +[[smithy-rs]] +message = """ +The `futures_core::stream::Stream` trait has been removed from public API. The methods that were made available through the `Stream` trait have been removed from these types. However, we have preserved `.next()` and `.collect()` to continue supporting existing call sites in `smithy-rs` and `aws-sdk-rust`, including tests and rustdocs. If we need to support missing stream operations, we are planning to do so in an additive, backward compatible manner. + +If your code uses a `stream!` macro from the `async_stream` crate to generate stream data, it needs to be replaced by `aws_smithy_async::future::fn_steram::FnStream`. See https://github.com/awslabs/smithy-rs/discussions/2952 for more details. +""" +references = ["smithy-rs#2910"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" } +author = "ysaito1001" From 6abb51948532755ea7ccad8759aafe6521675944 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Mon, 28 Aug 2023 12:32:03 -0500 Subject: [PATCH 23/36] Remove code duplication in calling `renderEventStreamBody` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1306203075 --- .../HttpBoundProtocolPayloadGenerator.kt | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 0c24e26115..749e296554 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -263,10 +264,7 @@ class HttpBoundProtocolPayloadGenerator( contentType, ).render() - if (target == CodegenTarget.CLIENT) { - // No need to wrap it with `HyperBodyWrapEventStream` for the client since wrapping takes place - // within `renderEventStreamBody` provided by `ClientHttpBoundProtocolPayloadGenerator`. - + val renderEventStreamBody = writable { // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the // parameters that are not `@eventHeader` or `@eventPayload`. renderEventStreamBody( @@ -279,23 +277,18 @@ class HttpBoundProtocolPayloadGenerator( additionalPayloadContext, ), ) + } + + if (target == CodegenTarget.CLIENT) { + // No need to wrap it with `HyperBodyWrapEventStream` for the client since wrapping takes place + // within `renderEventStreamBody` provided by `ClientHttpBoundProtocolPayloadGenerator`. + renderEventStreamBody() } else { withBlockTemplate( "#{HyperBodyWrapEventStream}::new(", ")", *codegenScope, ) { - // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the - // parameters that are not `@eventHeader` or `@eventPayload`. - renderEventStreamBody( - this, - EventStreamBodyParams( - outerName, - memberName, - marshallerConstructorFn, - errorMarshallerConstructorFn, - additionalPayloadContext, - ), - ) + renderEventStreamBody() } } } From 04912210dc91461b69110827488783cbe8841d83 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Tue, 29 Aug 2023 16:21:57 -0500 Subject: [PATCH 24/36] Remove uncecessary `CollectablePrivate` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1306210201 --- .../aws-smithy-async/src/future/fn_stream.rs | 2 +- .../src/future/fn_stream/collect.rs | 22 +++++++------------ 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs index 75a4b375cd..983df28f1f 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs @@ -104,7 +104,7 @@ impl FnStream { } /// Consumes this stream and gathers elements into a collection. - pub async fn collect>(mut self) -> T { + pub async fn collect>(mut self) -> T { let mut collection = T::initialize(); while let Some(item) = self.next().await { if !T::extend(&mut collection, item) { diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs index f909fd730e..a6c3b59e3d 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs @@ -9,15 +9,13 @@ //! Majority of the code is borrowed from //! -/// A trait that signifies that elements can be collected into `T`. -/// -/// Currently the trait may not be implemented by clients so we can make changes in the future -/// without breaking code depending on it. -pub trait Collectable: sealed::CollectablePrivate {} - pub(crate) mod sealed { + /// A trait that signifies that elements can be collected into `T`. + /// + /// Currently the trait may not be implemented by clients so we can make changes in the future + /// without breaking code depending on it. #[doc(hidden)] - pub trait CollectablePrivate { + pub trait Collectable { type Collection; fn initialize() -> Self::Collection; @@ -28,9 +26,7 @@ pub(crate) mod sealed { } } -impl Collectable for Vec {} - -impl sealed::CollectablePrivate for Vec { +impl sealed::Collectable for Vec { type Collection = Self; fn initialize() -> Self::Collection { @@ -47,11 +43,9 @@ impl sealed::CollectablePrivate for Vec { } } -impl Collectable> for Result where U: Collectable {} - -impl sealed::CollectablePrivate> for Result +impl sealed::Collectable> for Result where - U: Collectable, + U: sealed::Collectable, { type Collection = Result; From bedfb0485332833b3ddabf695d48745be7000443 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 12:14:59 -0500 Subject: [PATCH 25/36] Use customization to wrap stream payload in new-type This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1306203927 --- .../protocols/HttpBoundProtocolGenerator.kt | 18 +++++ .../HttpBoundProtocolPayloadGenerator.kt | 65 ++++++++++--------- .../protocols/PythonServerProtocolLoader.kt | 26 ++++++++ .../ServerHttpBoundProtocolGenerator.kt | 31 ++++++++- .../smithy/protocols/ServerProtocolLoader.kt | 51 +++++++++++++-- .../smithy/protocols/ServerRestXmlFactory.kt | 10 ++- 6 files changed, 162 insertions(+), 39 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 876bc3ab51..2753e151c7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -7,12 +7,14 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializer class ClientHttpBoundProtocolPayloadGenerator( codegenContext: ClientCodegenContext, @@ -44,4 +46,20 @@ class ClientHttpBoundProtocolPayloadGenerator( "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, ) }, + streamPayloadSerializer = StreamPayloadSerializer( + { writer, params -> + // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures::stream::Stream` + // so wrap it in a new-type to enable the trait. + writer.rust( + "#T", + RuntimeType.hyperBodyWrapStream(params.runtimeConfig).resolve("HyperBodyWrapByteStream").toSymbol(), + ) + }, + { writer, params -> + writer.rust( + "#T::new(${params.payloadName!!})", + RuntimeType.hyperBodyWrapStream(params.runtimeConfig).resolve("HyperBodyWrapByteStream").toSymbol(), + ) + }, + ), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 749e296554..2c7bf67a8f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.DocumentShape import software.amazon.smithy.model.shapes.MemberShape @@ -17,13 +18,13 @@ import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType @@ -53,11 +54,24 @@ data class EventStreamBodyParams( val additionalPayloadContext: AdditionalPayloadContext, ) +data class StreamPayloadSerializerParams( + val symbolProvider: SymbolProvider, + val runtimeConfig: RuntimeConfig, + val member: MemberShape, + val payloadName: String?, +) + +data class StreamPayloadSerializer( + val outputT: (RustWriter, StreamPayloadSerializerParams) -> Unit, + val renderPayload: (RustWriter, StreamPayloadSerializerParams) -> Unit, +) + class HttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, private val protocol: Protocol, private val httpMessageType: HttpMessageType = HttpMessageType.REQUEST, private val renderEventStreamBody: (RustWriter, EventStreamBodyParams) -> Unit, + private val streamPayloadSerializer: StreamPayloadSerializer, ) : ProtocolPayloadGenerator { private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model @@ -301,27 +315,22 @@ class HttpBoundProtocolPayloadGenerator( ) { val ref = if (payloadMetadata.takesOwnership) "" else "&" val serializer = protocolFunctions.serializeFn(member, fnNameSuffix = "http_payload") { fnName -> - val outputT = if (member.isStreaming(model)) { - if (fromPythonServerRuntime(member)) { - // `aws_smithy_http_server_python::types::ByteStream` already implements - // `futures::stream::Stream`, so no need to wrap it in a futures' stream-compatible - // wrapper. - symbolProvider.toSymbol(member) - } else { - // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures::stream::Stream` - // so wrap it in a new-type to enable the trait. - RuntimeType.hyperBodyWrapStream(runtimeConfig) - .resolve("HyperBodyWrapByteStream").toSymbol() - } - } else { - RuntimeType.ByteSlab.toSymbol() - } - rustBlockTemplate( - "pub(crate) fn $fnName(payload: $ref#{Member}) -> Result<#{outputT}, #{BuildError}>", + rustTemplate( + "pub(crate) fn $fnName(payload: $ref#{Member}) -> #{Result}<", "Member" to symbolProvider.toSymbol(member), - "outputT" to outputT, *codegenScope, - ) { + ) + if (member.isStreaming(model)) { + streamPayloadSerializer.outputT( + this, + StreamPayloadSerializerParams(symbolProvider, runtimeConfig, member, null), + ) + } else { + rust("#T", RuntimeType.ByteSlab.toSymbol()) + } + rustTemplate(", #{BuildError}>", *codegenScope) + + withBlockTemplate("{", "}", *codegenScope) { val asRef = if (payloadMetadata.takesOwnership) "" else ".as_ref()" if (symbolProvider.toSymbol(member).isOptional()) { @@ -377,15 +386,10 @@ class HttpBoundProtocolPayloadGenerator( is BlobShape -> { // Write the raw blob to the payload. if (member.isStreaming(model)) { - if (fromPythonServerRuntime(member)) { - rust(payloadName) - } else { - rustTemplate( - "#{HyperBodyWrapByteStream}::new($payloadName)", - "HyperBodyWrapByteStream" to RuntimeType.hyperBodyWrapStream(runtimeConfig) - .resolve("HyperBodyWrapByteStream"), - ) - } + streamPayloadSerializer.renderPayload( + this, + StreamPayloadSerializerParams(symbolProvider, runtimeConfig, member, payloadName), + ) } else { // Convert the `Blob` into a `Vec` and return it. rust("$payloadName.into_inner()") @@ -413,7 +417,4 @@ class HttpBoundProtocolPayloadGenerator( else -> PANIC("Unexpected payload target type: $targetShape") } } - - private fun fromPythonServerRuntime(member: MemberShape) = - symbolProvider.toSymbol(member).namespace.contains("aws_smithy_http_server_python") } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt index f84d35e7cb..df9ae3ab97 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt @@ -61,6 +61,29 @@ class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : is ServerHttpBoundProtocolSection.AfterTimestampDeserializedMember -> writable { rust(".into()") } + + else -> emptySection + } +} + +/** + * Customization class used to determine the type of serialized stream payload and how it should be wrapped in a + * new-type wrapper to enable `futures_core::stream::Stream` trait. + */ +class PythonServerStreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { + override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { + is ServerHttpBoundProtocolSection.TypeOfSerializedPayloadStream -> writable { + // `aws_smithy_http_server_python::types::ByteStream` already implements + // `futures::stream::Stream`, so no need to wrap it in a futures' stream-compatible + // wrapper. + rust("#T", section.params.symbolProvider.toSymbol(section.params.member)) + } + + is ServerHttpBoundProtocolSection.WrapStreamAfterPayloadGenerated -> writable { + // payloadName is always non-null within WrapStreamAfterPayloadGenerated + rust(section.params.payloadName!!) + } + else -> emptySection } } @@ -91,6 +114,7 @@ class PythonServerProtocolLoader( ), additionalServerHttpBoundProtocolCustomizations = listOf( PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), ), additionalHttpBindingCustomizations = listOf( PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), @@ -103,6 +127,7 @@ class PythonServerProtocolLoader( ), additionalServerHttpBoundProtocolCustomizations = listOf( PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), ), additionalHttpBindingCustomizations = listOf( PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), @@ -115,6 +140,7 @@ class PythonServerProtocolLoader( ), additionalServerHttpBoundProtocolCustomizations = listOf( PythonServerAfterDeserializedMemberServerHttpBoundCustomization(), + PythonServerStreamPayloadSerializerCustomization(), ), additionalHttpBindingCustomizations = listOf( PythonServerAfterDeserializedMemberHttpBindingCustomization(runtimeConfig), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 3a592f1966..2052a25652 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -57,6 +57,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtoc import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializer +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerParams import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors @@ -87,6 +89,12 @@ import java.util.logging.Logger sealed class ServerHttpBoundProtocolSection(name: String) : Section(name) { data class AfterTimestampDeserializedMember(val shape: MemberShape) : ServerHttpBoundProtocolSection("AfterTimestampDeserializedMember") + + data class TypeOfSerializedPayloadStream(val params: StreamPayloadSerializerParams) : + ServerHttpBoundProtocolSection("TypeOfSerializedPayloadStream") + + data class WrapStreamAfterPayloadGenerated(val params: StreamPayloadSerializerParams) : + ServerHttpBoundProtocolSection("WrapStreamAfterPayloadGenerated") } /** @@ -123,6 +131,7 @@ class ServerHttpBoundProtocolGenerator( class ServerHttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, protocol: Protocol, + customizations: List = listOf(), ) : ProtocolPayloadGenerator by HttpBoundProtocolPayloadGenerator( codegenContext, protocol, HttpMessageType.RESPONSE, renderEventStreamBody = { writer, params -> @@ -143,6 +152,26 @@ class ServerHttpBoundProtocolPayloadGenerator( "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, ) }, + streamPayloadSerializer = StreamPayloadSerializer( + { writer, params -> + for (customization in customizations) { + customization.section( + ServerHttpBoundProtocolSection.TypeOfSerializedPayloadStream( + params, + ), + )(writer) + } + }, + { writer, params -> + for (customization in customizations) { + customization.section( + ServerHttpBoundProtocolSection.WrapStreamAfterPayloadGenerated( + params, + ), + )(writer) + } + }, + ), ) /* @@ -544,7 +573,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( ?: serverRenderHttpResponseCode(httpTraitStatusCode)(this) operationShape.outputShape(model).findStreamingMember(model)?.let { - val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol) + val payloadGenerator = ServerHttpBoundProtocolPayloadGenerator(codegenContext, protocol, customizations) withBlockTemplate( "let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index e52d9e3a3b..5a175edb33 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -9,21 +9,64 @@ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator +class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { + override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { + is ServerHttpBoundProtocolSection.TypeOfSerializedPayloadStream -> writable { + // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures_core::stream::Stream` + // so wrap it in a new-type to enable the trait. + rust( + "#T", + RuntimeType.hyperBodyWrapStream(section.params.runtimeConfig) + .resolve("HyperBodyWrapByteStream").toSymbol(), + ) + } + + is ServerHttpBoundProtocolSection.WrapStreamAfterPayloadGenerated -> writable { + rustTemplate( + "#{HyperBodyWrapByteStream}::new(${section.params.payloadName!!})", + "HyperBodyWrapByteStream" to RuntimeType.hyperBodyWrapStream(section.params.runtimeConfig) + .resolve("HyperBodyWrapByteStream"), + ) + } + + else -> emptySection + } +} + class ServerProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { companion object { val DefaultProtocols = mapOf( - RestJson1Trait.ID to ServerRestJsonFactory(), - RestXmlTrait.ID to ServerRestXmlFactory(), - AwsJson1_0Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json10), - AwsJson1_1Trait.ID to ServerAwsJsonFactory(AwsJsonVersion.Json11), + RestJson1Trait.ID to ServerRestJsonFactory( + additionalServerHttpBoundProtocolCustomizations = listOf( + StreamPayloadSerializerCustomization(), + ), + ), + RestXmlTrait.ID to ServerRestXmlFactory( + additionalServerHttpBoundProtocolCustomizations = listOf( + StreamPayloadSerializerCustomization(), + ), + ), + AwsJson1_0Trait.ID to ServerAwsJsonFactory( + AwsJsonVersion.Json10, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + ), + AwsJson1_1Trait.ID to ServerAwsJsonFactory( + AwsJsonVersion.Json11, + additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()), + ), ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt index f5b3be454f..9207c56046 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestXmlFactory.kt @@ -15,11 +15,17 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser * RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator] * with RestXml specific configurations. */ -class ServerRestXmlFactory : ProtocolGeneratorFactory { +class ServerRestXmlFactory( + private val additionalServerHttpBoundProtocolCustomizations: List = listOf(), +) : ProtocolGeneratorFactory { override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerRestXmlProtocol(codegenContext) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = - ServerHttpBoundProtocolGenerator(codegenContext, ServerRestXmlProtocol(codegenContext)) + ServerHttpBoundProtocolGenerator( + codegenContext, + ServerRestXmlProtocol(codegenContext), + additionalServerHttpBoundProtocolCustomizations, + ) override fun support(): ProtocolSupport { return ProtocolSupport( From 5698ee3ebd0594b082022db85074141d49d44050 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 12:31:00 -0500 Subject: [PATCH 26/36] Let `finalize` take an owned associated collection This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1306210966 --- .../aws-smithy-async/src/future/fn_stream.rs | 2 +- .../src/future/fn_stream/collect.rs | 16 +++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs index 983df28f1f..7d09ae3175 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs @@ -111,7 +111,7 @@ impl FnStream { break; } } - T::finalize(&mut collection) + T::finalize(collection) } } diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs index a6c3b59e3d..a07909b999 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream/collect.rs @@ -22,7 +22,7 @@ pub(crate) mod sealed { fn extend(collection: &mut Self::Collection, item: T) -> bool; - fn finalize(collection: &mut Self::Collection) -> Self; + fn finalize(collection: Self::Collection) -> Self; } } @@ -38,8 +38,8 @@ impl sealed::Collectable for Vec { true } - fn finalize(collection: &mut Self::Collection) -> Self { - std::mem::take(collection) + fn finalize(collection: Self::Collection) -> Self { + collection } } @@ -66,12 +66,10 @@ where } } - fn finalize(collection: &mut Self::Collection) -> Self { - if let Ok(collection) = collection.as_mut() { - Ok(U::finalize(collection)) - } else { - let res = std::mem::replace(collection, Ok(U::initialize())); - Err(res.map(drop).unwrap_err()) + fn finalize(collection: Self::Collection) -> Self { + match collection { + Ok(collection) => Ok(U::finalize(collection)), + err @ Err(_) => Err(err.map(drop).unwrap_err()), } } } From d9052d5bc30f256d0001480a93b480f7815b757e Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 13:10:44 -0500 Subject: [PATCH 27/36] Reexport `FnStream` and remove dependency on `aws-smithy-async` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1306220065 --- .../customizations/SmithyTypesPubUseExtra.kt | 16 ++++++++++++++++ .../src/latest/transcribe_canary.rs | 3 +-- tools/ci-cdk/canary-runner/src/build_bundle.rs | 6 ------ 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt index 14d081b2cb..57af1a8547 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt @@ -26,6 +26,14 @@ private fun hasStreamingOperations(model: Model): Boolean { } } +/** Returns true if the model has event streaming operations */ +private fun hasEventStreamOperations(model: Model): Boolean = + model.operationShapes.any { operation -> + val input = model.expectShape(operation.inputShape, StructureShape::class.java) + val output = model.expectShape(operation.outputShape, StructureShape::class.java) + input.hasEventStreamMember(model) || output.hasEventStreamMember(model) + } + // TODO(https://github.com/awslabs/smithy-rs/issues/2111): Fix this logic to consider collection/map shapes private fun structUnionMembersMatchPredicate(model: Model, predicate: (Shape) -> Boolean): Boolean = model.structureShapes.any { structure -> @@ -70,4 +78,12 @@ fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model): Writab "SdkBody" to RuntimeType.smithyHttp(rc).resolve("body::SdkBody"), ) } + if (hasEventStreamOperations(model)) { + rustTemplate( + """ + pub use #{FnStream}; + """, + "FnStream" to RuntimeType.smithyAsync(rc).resolve("future::fn_stream::FnStream"), + ) + } } diff --git a/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs b/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs index 903329e555..0857f54c94 100644 --- a/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/latest/transcribe_canary.rs @@ -7,7 +7,6 @@ use crate::canary::CanaryError; use crate::mk_canary; use aws_config::SdkConfig; use aws_sdk_transcribestreaming as transcribe; -use aws_smithy_async::future::fn_stream::FnStream; use bytes::BufMut; use transcribe::primitives::Blob; use transcribe::types::{ @@ -31,7 +30,7 @@ pub async fn transcribe_canary( client: transcribe::Client, expected_transcribe_result: String, ) -> anyhow::Result<()> { - let input_stream = FnStream::new(|tx| { + let input_stream = transcribe::primitives::FnStream::new(|tx| { Box::pin(async move { let pcm = pcm_data(); for chunk in pcm.chunks(CHUNK_SIZE) { diff --git a/tools/ci-cdk/canary-runner/src/build_bundle.rs b/tools/ci-cdk/canary-runner/src/build_bundle.rs index e82b00ddd7..635cec4e7d 100644 --- a/tools/ci-cdk/canary-runner/src/build_bundle.rs +++ b/tools/ci-cdk/canary-runner/src/build_bundle.rs @@ -56,8 +56,6 @@ tracing-texray = "0.1.1" reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } "#; -const REQUIRED_SMITHY_CRATES: &[&str] = &["aws-smithy-async"]; - const REQUIRED_SDK_CRATES: &[&str] = &[ "aws-config", "aws-sdk-s3", @@ -131,7 +129,6 @@ fn enabled_feature(crate_source: &CrateSource) -> String { fn generate_crate_manifest(crate_source: CrateSource) -> Result { let mut output = BASE_MANIFEST.to_string(); - write_dependencies(REQUIRED_SMITHY_CRATES, &mut output, &crate_source)?; write_dependencies(REQUIRED_SDK_CRATES, &mut output, &crate_source)?; write!(output, "\n[features]\n").unwrap(); writeln!(output, "latest = []").unwrap(); @@ -438,7 +435,6 @@ uuid = { version = "0.8", features = ["v4"] } tokio-stream = "0" tracing-texray = "0.1.1" reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } -aws-smithy-async = { path = "some/sdk/path/aws-smithy-async" } aws-config = { path = "some/sdk/path/aws-config" } aws-sdk-s3 = { path = "some/sdk/path/s3" } aws-sdk-ec2 = { path = "some/sdk/path/ec2" } @@ -503,7 +499,6 @@ uuid = { version = "0.8", features = ["v4"] } tokio-stream = "0" tracing-texray = "0.1.1" reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } -aws-smithy-async = "0.46.0" aws-config = "0.46.0" aws-sdk-s3 = "0.20.0" aws-sdk-ec2 = "0.19.0" @@ -520,7 +515,6 @@ default = ["latest"] aws_doc_sdk_examples_revision: "some-revision-docs".into(), manual_interventions: Default::default(), crates: [ - crate_version("aws-smithy-async", "0.46.0"), crate_version("aws-config", "0.46.0"), crate_version("aws-sdk-s3", "0.20.0"), crate_version("aws-sdk-ec2", "0.19.0"), From b281aee127fcf5a7db1b7676654dbf66df328b6b Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 13:16:45 -0500 Subject: [PATCH 28/36] Remove dependency on `aws-smithy-async` and use reexported `FnStream` --- aws/sdk/integration-tests/transcribestreaming/Cargo.toml | 1 - aws/sdk/integration-tests/transcribestreaming/tests/test.rs | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/aws/sdk/integration-tests/transcribestreaming/Cargo.toml b/aws/sdk/integration-tests/transcribestreaming/Cargo.toml index d214ae56ed..33dfc393d4 100644 --- a/aws/sdk/integration-tests/transcribestreaming/Cargo.toml +++ b/aws/sdk/integration-tests/transcribestreaming/Cargo.toml @@ -12,7 +12,6 @@ publish = false aws-credential-types = { path = "../../build/aws-sdk/sdk/aws-credential-types", features = ["test-util"] } aws-http = { path = "../../build/aws-sdk/sdk/aws-http" } aws-sdk-transcribestreaming = { path = "../../build/aws-sdk/sdk/transcribestreaming" } -aws-smithy-async = { path = "../../build/aws-sdk/sdk/aws-smithy-async" } aws-smithy-client = { path = "../../build/aws-sdk/sdk/aws-smithy-client", features = ["test-util", "rustls"] } aws-smithy-eventstream = { path = "../../build/aws-sdk/sdk/aws-smithy-eventstream" } aws-smithy-http = { path = "../../build/aws-sdk/sdk/aws-smithy-http" } diff --git a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs index 29c963d6c5..333ed88b95 100644 --- a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs +++ b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs @@ -6,13 +6,12 @@ use aws_sdk_transcribestreaming::config::{Credentials, Region}; use aws_sdk_transcribestreaming::error::SdkError; use aws_sdk_transcribestreaming::operation::start_stream_transcription::StartStreamTranscriptionOutput; -use aws_sdk_transcribestreaming::primitives::Blob; +use aws_sdk_transcribestreaming::primitives::{Blob, FnStream}; use aws_sdk_transcribestreaming::types::error::{AudioStreamError, TranscriptResultStreamError}; use aws_sdk_transcribestreaming::types::{ AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, }; use aws_sdk_transcribestreaming::{Client, Config}; -use aws_smithy_async::future::fn_stream::FnStream; use aws_smithy_client::dvr::{Event, ReplayingConnection}; use aws_smithy_eventstream::frame::{DecodedFrame, HeaderValue, Message, MessageFrameDecoder}; use bytes::BufMut; From 0f5f6be07370bcf54cea84b02e680d1bd0d0f292 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 15:12:41 -0500 Subject: [PATCH 29/36] Return size hint as is by `http_body::Body::size_hint` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1306212995 --- .../aws-smithy-http/src/byte_stream.rs | 23 +++---------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index 28cf660969..a369bf04c5 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -304,7 +304,7 @@ impl ByteStream { } /// Return the bounds on the remaining length of the `ByteStream`. - pub fn size_hint(&self) -> (usize, Option) { + pub fn size_hint(&self) -> (u64, Option) { self.inner.size_hint() } @@ -444,10 +444,6 @@ impl futures_core::stream::Stream for StreamWrapper { .poll_next(cx) .map_err(Error::streaming) } - - fn size_hint(&self) -> (usize, Option) { - self.byte_stream.size_hint() - } } impl Default for ByteStream { @@ -591,28 +587,15 @@ impl Inner { Ok(AggregatedBytes(output)) } - fn size_hint(&self) -> (usize, Option) + fn size_hint(&self) -> (u64, Option) where B: http_body::Body, { let size_hint = http_body::Body::size_hint(&self.body); - let lower = size_hint.lower().try_into(); - let upper = size_hint.upper().map(|u| u.try_into()).transpose(); - - match (lower, upper) { - (Ok(lower), Ok(upper)) => (lower, upper), - (Err(_), _) | (_, Err(_)) => { - panic!("{}", SIZE_HINT_32_BIT_PANIC_MESSAGE) - } - } + (size_hint.lower(), size_hint.upper()) } } -const SIZE_HINT_32_BIT_PANIC_MESSAGE: &str = r#" -You're running a 32-bit system and this stream's length is too large to be represented with a usize. -Please limit stream length to less than 4.294Gb or run this program on a 64-bit computer architecture. -"#; - #[cfg(test)] mod tests { use crate::byte_stream::Inner; From d2e44118b4c8632ff90d427da73c9ca561703d61 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 15:33:20 -0500 Subject: [PATCH 30/36] Add an comment `FnStream` is `Send` not `Sync` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1306217194 --- rust-runtime/aws-smithy-async/src/future/fn_stream.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs index 7d09ae3175..65d89584ad 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs @@ -26,6 +26,9 @@ pin_project! { /// /// If `tx.send` returns an error, the function MUST return immediately. /// + /// Note `FnStream` is only `Send` but not `Sync` because `generator` is a boxed future that + /// is `Send` and returns `()` as output when it is done. + /// /// # Examples /// ```no_run /// # async fn docs() { From a778886fdcce4f8dcebd5bc3a6f97fdbe937f45e Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 15:40:15 -0500 Subject: [PATCH 31/36] Add more comments as to why `generator` can be set to `None` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1306209330 --- rust-runtime/aws-smithy-async/src/future/fn_stream.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs index 65d89584ad..f236c33775 100644 --- a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs @@ -96,8 +96,10 @@ impl FnStream { Poll::Pending => { if let Some(generator) = me.generator { if generator.as_mut().poll(cx).is_ready() { - // if the generator returned ready we MUST NOT poll it again—doing so - // will cause a panic. + // `generator` keeps writing items to `tx` and will not be `Poll::Ready` + // until it is done writing to `tx`. Once it is done, it returns `()` + // as output and is `Poll::Ready`, at which point we MUST NOT poll it again + // since doing so will cause a panic. *me.generator = None; } } From ab57dce62552b043e77b69438e8263c4fd25654b Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 22:31:46 -0500 Subject: [PATCH 32/36] Convert `StreamPayloadSerializer` to interface This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1310835743 --- .../protocols/HttpBoundProtocolGenerator.kt | 39 ++++++++------ .../HttpBoundProtocolPayloadGenerator.kt | 28 +++++++--- .../ServerHttpBoundProtocolGenerator.kt | 53 +++++++++++-------- 3 files changed, 75 insertions(+), 45 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 2753e151c7..c449d75d70 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -14,7 +15,26 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessa import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializer +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerParams +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerRenderer + +private class ClientStreamPayloadSerializerRenderer : StreamPayloadSerializerRenderer { + override fun renderOutputType(writer: RustWriter, params: StreamPayloadSerializerParams) { + writer.rust( + "#T", + RuntimeType.hyperBodyWrapStream(params.runtimeConfig).resolve("HyperBodyWrapByteStream").toSymbol(), + ) + } + + override fun renderPayload(writer: RustWriter, params: StreamPayloadSerializerParams) { + // Payload is `aws_smithy_http::byte_stream::ByteStream` but it no longer implements `futures::stream::Stream`, + // so we wrap it in a new-type to enable the trait. + writer.rust( + "#T::new(${params.payloadName!!})", + RuntimeType.hyperBodyWrapStream(params.runtimeConfig).resolve("HyperBodyWrapByteStream").toSymbol(), + ) + } +} class ClientHttpBoundProtocolPayloadGenerator( codegenContext: ClientCodegenContext, @@ -46,20 +66,5 @@ class ClientHttpBoundProtocolPayloadGenerator( "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, ) }, - streamPayloadSerializer = StreamPayloadSerializer( - { writer, params -> - // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures::stream::Stream` - // so wrap it in a new-type to enable the trait. - writer.rust( - "#T", - RuntimeType.hyperBodyWrapStream(params.runtimeConfig).resolve("HyperBodyWrapByteStream").toSymbol(), - ) - }, - { writer, params -> - writer.rust( - "#T::new(${params.payloadName!!})", - RuntimeType.hyperBodyWrapStream(params.runtimeConfig).resolve("HyperBodyWrapByteStream").toSymbol(), - ) - }, - ), + ClientStreamPayloadSerializerRenderer(), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 2c7bf67a8f..ff44f67269 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -61,17 +61,31 @@ data class StreamPayloadSerializerParams( val payloadName: String?, ) -data class StreamPayloadSerializer( - val outputT: (RustWriter, StreamPayloadSerializerParams) -> Unit, - val renderPayload: (RustWriter, StreamPayloadSerializerParams) -> Unit, -) +/** + * An interface to help customize how to render a stream payload serializer. + * + * When the output of the serializer is passed to `hyper::body::Body::wrap_stream`, + * it requires what's passed to implement `futures_core::stream::Stream` trait. + * However, a certain type, such as `aws_smithy_http::byte_stream::ByteStream` does not + * implement the trait, so we need to wrap it with a new-type that does implement the trait. + * + * Each implementing type of the interface can choose whether the payload should be wrapped + * with such a new-type or should simply be used as-is. + */ +interface StreamPayloadSerializerRenderer { + /** Renders the return type of stream payload serializer **/ + fun renderOutputType(writer: RustWriter, params: StreamPayloadSerializerParams) + + /** Renders the stream payload **/ + fun renderPayload(writer: RustWriter, params: StreamPayloadSerializerParams) +} class HttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, private val protocol: Protocol, private val httpMessageType: HttpMessageType = HttpMessageType.REQUEST, private val renderEventStreamBody: (RustWriter, EventStreamBodyParams) -> Unit, - private val streamPayloadSerializer: StreamPayloadSerializer, + private val streamPayloadSerializerRenderer: StreamPayloadSerializerRenderer, ) : ProtocolPayloadGenerator { private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model @@ -321,7 +335,7 @@ class HttpBoundProtocolPayloadGenerator( *codegenScope, ) if (member.isStreaming(model)) { - streamPayloadSerializer.outputT( + streamPayloadSerializerRenderer.renderOutputType( this, StreamPayloadSerializerParams(symbolProvider, runtimeConfig, member, null), ) @@ -386,7 +400,7 @@ class HttpBoundProtocolPayloadGenerator( is BlobShape -> { // Write the raw blob to the payload. if (member.isStreaming(model)) { - streamPayloadSerializer.renderPayload( + streamPayloadSerializerRenderer.renderPayload( this, StreamPayloadSerializerParams(symbolProvider, runtimeConfig, member, payloadName), ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 2052a25652..8db95c0657 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -57,8 +57,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtoc import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializer import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerParams +import software.amazon.smithy.rust.codegen.core.smithy.protocols.StreamPayloadSerializerRenderer import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors @@ -128,6 +128,36 @@ class ServerHttpBoundProtocolGenerator( } } +/** + * Server implementation of the [StreamPayloadSerializerRenderer] interface. + * + * The implementation of each method is delegated to [customizations]. Regular server codegen and python server + * have different requirements for how to render stream payload serializers, and they express their requirements + * through customizations, specifically with [TypeOfSerializedPayloadStream] and [WrapStreamAfterPayloadGenerated]. + */ +private class ServerStreamPayloadSerializerRenderer(private val customizations: List) : + StreamPayloadSerializerRenderer { + override fun renderOutputType(writer: RustWriter, params: StreamPayloadSerializerParams) { + for (customization in customizations) { + customization.section( + ServerHttpBoundProtocolSection.TypeOfSerializedPayloadStream( + params, + ), + )(writer) + } + } + + override fun renderPayload(writer: RustWriter, params: StreamPayloadSerializerParams) { + for (customization in customizations) { + customization.section( + ServerHttpBoundProtocolSection.WrapStreamAfterPayloadGenerated( + params, + ), + )(writer) + } + } +} + class ServerHttpBoundProtocolPayloadGenerator( codegenContext: CodegenContext, protocol: Protocol, @@ -152,26 +182,7 @@ class ServerHttpBoundProtocolPayloadGenerator( "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, ) }, - streamPayloadSerializer = StreamPayloadSerializer( - { writer, params -> - for (customization in customizations) { - customization.section( - ServerHttpBoundProtocolSection.TypeOfSerializedPayloadStream( - params, - ), - )(writer) - } - }, - { writer, params -> - for (customization in customizations) { - customization.section( - ServerHttpBoundProtocolSection.WrapStreamAfterPayloadGenerated( - params, - ), - )(writer) - } - }, - ), + ServerStreamPayloadSerializerRenderer(customizations), ) /* From 3db05256633d262b59affcc4e5609afa2b142434 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 22:53:43 -0500 Subject: [PATCH 33/36] Add comments to data class for `ServerHttpBoundProtocolSection` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1310840452 --- .../protocols/PythonServerProtocolLoader.kt | 4 +-- .../ServerHttpBoundProtocolGenerator.kt | 25 +++++++++++++------ .../smithy/protocols/ServerProtocolLoader.kt | 4 +-- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt index df9ae3ab97..5569c92526 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/protocols/PythonServerProtocolLoader.kt @@ -72,14 +72,14 @@ class PythonServerAfterDeserializedMemberServerHttpBoundCustomization() : */ class PythonServerStreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { - is ServerHttpBoundProtocolSection.TypeOfSerializedPayloadStream -> writable { + is ServerHttpBoundProtocolSection.TypeOfSerializedStreamPayload -> writable { // `aws_smithy_http_server_python::types::ByteStream` already implements // `futures::stream::Stream`, so no need to wrap it in a futures' stream-compatible // wrapper. rust("#T", section.params.symbolProvider.toSymbol(section.params.member)) } - is ServerHttpBoundProtocolSection.WrapStreamAfterPayloadGenerated -> writable { + is ServerHttpBoundProtocolSection.WrapStreamPayload -> writable { // payloadName is always non-null within WrapStreamAfterPayloadGenerated rust(section.params.payloadName!!) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 8db95c0657..bc6fa6ae01 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -90,11 +90,22 @@ sealed class ServerHttpBoundProtocolSection(name: String) : Section(name) { data class AfterTimestampDeserializedMember(val shape: MemberShape) : ServerHttpBoundProtocolSection("AfterTimestampDeserializedMember") - data class TypeOfSerializedPayloadStream(val params: StreamPayloadSerializerParams) : - ServerHttpBoundProtocolSection("TypeOfSerializedPayloadStream") + /** + * Represent a section for rendering the return type of serialized stream payload. + * + * When overriding the `section` method, this should render [Symbol] for that return type. + */ + data class TypeOfSerializedStreamPayload(val params: StreamPayloadSerializerParams) : + ServerHttpBoundProtocolSection("TypeOfSerializedStreamPayload") - data class WrapStreamAfterPayloadGenerated(val params: StreamPayloadSerializerParams) : - ServerHttpBoundProtocolSection("WrapStreamAfterPayloadGenerated") + /** + * Represent a section for rendering the serialized stream payload. + * + * When overriding the `section` method, this should render either the payload as-is or the payload wrapped + * with a new-type that implements the `futures_core::stream::Stream` trait. + */ + data class WrapStreamPayload(val params: StreamPayloadSerializerParams) : + ServerHttpBoundProtocolSection("WrapStreamPayload") } /** @@ -133,14 +144,14 @@ class ServerHttpBoundProtocolGenerator( * * The implementation of each method is delegated to [customizations]. Regular server codegen and python server * have different requirements for how to render stream payload serializers, and they express their requirements - * through customizations, specifically with [TypeOfSerializedPayloadStream] and [WrapStreamAfterPayloadGenerated]. + * through customizations, specifically with [TypeOfSerializedStreamPayload] and [WrapStreamPayload]. */ private class ServerStreamPayloadSerializerRenderer(private val customizations: List) : StreamPayloadSerializerRenderer { override fun renderOutputType(writer: RustWriter, params: StreamPayloadSerializerParams) { for (customization in customizations) { customization.section( - ServerHttpBoundProtocolSection.TypeOfSerializedPayloadStream( + ServerHttpBoundProtocolSection.TypeOfSerializedStreamPayload( params, ), )(writer) @@ -150,7 +161,7 @@ private class ServerStreamPayloadSerializerRenderer(private val customizations: override fun renderPayload(writer: RustWriter, params: StreamPayloadSerializerParams) { for (customization in customizations) { customization.section( - ServerHttpBoundProtocolSection.WrapStreamAfterPayloadGenerated( + ServerHttpBoundProtocolSection.WrapStreamPayload( params, ), )(writer) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index 5a175edb33..ed9db48b86 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -22,7 +22,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { - is ServerHttpBoundProtocolSection.TypeOfSerializedPayloadStream -> writable { + is ServerHttpBoundProtocolSection.TypeOfSerializedStreamPayload -> writable { // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures_core::stream::Stream` // so wrap it in a new-type to enable the trait. rust( @@ -32,7 +32,7 @@ class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomizat ) } - is ServerHttpBoundProtocolSection.WrapStreamAfterPayloadGenerated -> writable { + is ServerHttpBoundProtocolSection.WrapStreamPayload -> writable { rustTemplate( "#{HyperBodyWrapByteStream}::new(${section.params.payloadName!!})", "HyperBodyWrapByteStream" to RuntimeType.hyperBodyWrapStream(section.params.runtimeConfig) From 91163dead231d9b954638a7f92d0704f363e22a5 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Wed, 30 Aug 2023 23:13:36 -0500 Subject: [PATCH 34/36] Use `ServiceShape.hasEventStreamOperations` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1310842380 --- .../smithy/customizations/SmithyTypesPubUseExtra.kt | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt index 57af1a8547..22b9b0eb1e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember +import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember /** Returns true if the model has normal streaming operations (excluding event streams) */ @@ -26,14 +27,6 @@ private fun hasStreamingOperations(model: Model): Boolean { } } -/** Returns true if the model has event streaming operations */ -private fun hasEventStreamOperations(model: Model): Boolean = - model.operationShapes.any { operation -> - val input = model.expectShape(operation.inputShape, StructureShape::class.java) - val output = model.expectShape(operation.outputShape, StructureShape::class.java) - input.hasEventStreamMember(model) || output.hasEventStreamMember(model) - } - // TODO(https://github.com/awslabs/smithy-rs/issues/2111): Fix this logic to consider collection/map shapes private fun structUnionMembersMatchPredicate(model: Model, predicate: (Shape) -> Boolean): Boolean = model.structureShapes.any { structure -> @@ -78,7 +71,7 @@ fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model): Writab "SdkBody" to RuntimeType.smithyHttp(rc).resolve("body::SdkBody"), ) } - if (hasEventStreamOperations(model)) { + if (codegenContext.serviceShape.hasEventStreamOperations(model)) { rustTemplate( """ pub use #{FnStream}; From f5c8a96fc47331d219ed4aa34e492a776b1b2190 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Thu, 7 Sep 2023 14:30:13 -0500 Subject: [PATCH 35/36] Move `hyper_body_wrap_stream` from inlineable to `aws-smithy-http` This commit addresses https://github.com/awslabs/smithy-rs/pull/2910#discussion_r1317126345. Now that `HyperBodyWrapByteStream` is also used outside the context of `hyper::body::Body::wrap_stream` (passed to `tokio_util::io::StreamReader`), the module has been renamed to `futures_stream_adapter`. Furthermore, `HyperBodyWrapByteStream` and `HyperBodyWrapEventStream` have been renamed to `FuturesStreamCompatByteStream` and `FuturesStreamCompatEventStream` respectively. --- .../protocols/HttpBoundProtocolGenerator.kt | 11 +-- .../codegen/core/rustlang/CargoDependency.kt | 8 -- .../rust/codegen/core/smithy/RuntimeType.kt | 6 +- .../HttpBoundProtocolPayloadGenerator.kt | 6 +- .../smithy/protocols/ServerProtocolLoader.kt | 10 +- .../aws-smithy-http/src/byte_stream.rs | 22 +---- .../src/futures_stream_adapter.rs} | 95 +++++++++++-------- rust-runtime/aws-smithy-http/src/lib.rs | 2 + rust-runtime/inlineable/Cargo.toml | 5 +- rust-runtime/inlineable/src/lib.rs | 2 - 10 files changed, 72 insertions(+), 95 deletions(-) rename rust-runtime/{inlineable/src/hyper_body_wrap_stream.rs => aws-smithy-http/src/futures_stream_adapter.rs} (51%) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index c449d75d70..e2283054ec 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -22,16 +22,14 @@ private class ClientStreamPayloadSerializerRenderer : StreamPayloadSerializerRen override fun renderOutputType(writer: RustWriter, params: StreamPayloadSerializerParams) { writer.rust( "#T", - RuntimeType.hyperBodyWrapStream(params.runtimeConfig).resolve("HyperBodyWrapByteStream").toSymbol(), + RuntimeType.futuresStreamCompatByteStream(params.runtimeConfig).toSymbol(), ) } override fun renderPayload(writer: RustWriter, params: StreamPayloadSerializerParams) { - // Payload is `aws_smithy_http::byte_stream::ByteStream` but it no longer implements `futures::stream::Stream`, - // so we wrap it in a new-type to enable the trait. writer.rust( "#T::new(${params.payloadName!!})", - RuntimeType.hyperBodyWrapStream(params.runtimeConfig).resolve("HyperBodyWrapByteStream").toSymbol(), + RuntimeType.futuresStreamCompatByteStream(params.runtimeConfig), ) } } @@ -51,7 +49,7 @@ class ClientHttpBoundProtocolPayloadGenerator( _cfg.interceptor_state().store_put(signer_sender); let adapter: #{aws_smithy_http}::event_stream::MessageStreamAdapter<_, _> = ${params.outerName}.${params.memberName}.into_body_stream(marshaller, error_marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(#{HyperBodyWrapEventStream}::new(adapter)).into(); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(#{FuturesStreamCompatEventStream}::new(adapter)).into(); body } """, @@ -60,8 +58,7 @@ class ClientHttpBoundProtocolPayloadGenerator( "aws_smithy_http" to RuntimeType.smithyHttp(codegenContext.runtimeConfig), "DeferredSigner" to RuntimeType.smithyEventStream(codegenContext.runtimeConfig) .resolve("frame::DeferredSigner"), - "HyperBodyWrapEventStream" to RuntimeType.hyperBodyWrapStream(codegenContext.runtimeConfig) - .resolve("HyperBodyWrapEventStream"), + "FuturesStreamCompatEventStream" to RuntimeType.futuresStreamCompatEventStream(codegenContext.runtimeConfig), "marshallerConstructorFn" to params.marshallerConstructorFn, "errorMarshallerConstructorFn" to params.errorMarshallerConstructorFn, ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index beede3bd75..d3f3819b94 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -127,14 +127,6 @@ class InlineDependency( CargoDependency.smithyTypes(runtimeConfig), ) - fun hyperBodyWrapStream(runtimeConfig: RuntimeConfig): InlineDependency = forInlineableRustFile( - "hyper_body_wrap_stream", - CargoDependency.smithyHttp(runtimeConfig).withFeature("event-stream"), - CargoDependency.FuturesCore, - CargoDependency.smithyAsync(runtimeConfig).toDevDependency(), - CargoDependency.smithyEventStream(runtimeConfig).toDevDependency(), - ) - fun constrained(): InlineDependency = InlineDependency.forRustFile(ConstrainedModule, "/inlineable/src/constrained.rs") } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index 0474bd450b..7aba2a5aee 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -408,8 +408,10 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) smithyHttp(runtimeConfig).resolve("event_stream::Receiver") fun eventStreamSender(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("event_stream::EventStreamSender") - fun hyperBodyWrapStream(runtimeConfig: RuntimeConfig): RuntimeType = - forInlineDependency(InlineDependency.hyperBodyWrapStream(runtimeConfig)) + fun futuresStreamCompatByteStream(runtimeConfig: RuntimeConfig): RuntimeType = + smithyHttp(runtimeConfig).resolve("futures_stream_adapter::FuturesStreamCompatByteStream") + fun futuresStreamCompatEventStream(runtimeConfig: RuntimeConfig): RuntimeType = + smithyHttp(runtimeConfig).resolve("futures_stream_adapter::FuturesStreamCompatEventStream") fun errorMetadata(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::ErrorMetadata") fun errorMetadataBuilder(runtimeConfig: RuntimeConfig) = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index ff44f67269..68e4de4403 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -96,7 +96,7 @@ class HttpBoundProtocolPayloadGenerator( private val codegenScope = arrayOf( *preludeScope, "BuildError" to runtimeConfig.operationBuildError(), - "HyperBodyWrapEventStream" to RuntimeType.hyperBodyWrapStream(runtimeConfig).resolve("HyperBodyWrapEventStream"), + "FuturesStreamCompatEventStream" to RuntimeType.futuresStreamCompatEventStream(runtimeConfig), "NoOpSigner" to smithyEventStream.resolve("frame::NoOpSigner"), "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "SmithyHttp" to RuntimeType.smithyHttp(runtimeConfig), @@ -308,12 +308,12 @@ class HttpBoundProtocolPayloadGenerator( } if (target == CodegenTarget.CLIENT) { - // No need to wrap it with `HyperBodyWrapEventStream` for the client since wrapping takes place + // No need to wrap it with `FuturesStreamCompatEventStream` for the client since wrapping takes place // within `renderEventStreamBody` provided by `ClientHttpBoundProtocolPayloadGenerator`. renderEventStreamBody() } else { withBlockTemplate( - "#{HyperBodyWrapEventStream}::new(", ")", + "#{FuturesStreamCompatEventStream}::new(", ")", *codegenScope, ) { renderEventStreamBody() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt index ed9db48b86..e13220687d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerProtocolLoader.kt @@ -23,20 +23,16 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser class StreamPayloadSerializerCustomization() : ServerHttpBoundProtocolCustomization() { override fun section(section: ServerHttpBoundProtocolSection): Writable = when (section) { is ServerHttpBoundProtocolSection.TypeOfSerializedStreamPayload -> writable { - // `aws_smithy_http::byte_stream::ByteStream` no longer implements `futures_core::stream::Stream` - // so wrap it in a new-type to enable the trait. rust( "#T", - RuntimeType.hyperBodyWrapStream(section.params.runtimeConfig) - .resolve("HyperBodyWrapByteStream").toSymbol(), + RuntimeType.futuresStreamCompatByteStream(section.params.runtimeConfig).toSymbol(), ) } is ServerHttpBoundProtocolSection.WrapStreamPayload -> writable { rustTemplate( - "#{HyperBodyWrapByteStream}::new(${section.params.payloadName!!})", - "HyperBodyWrapByteStream" to RuntimeType.hyperBodyWrapStream(section.params.runtimeConfig) - .resolve("HyperBodyWrapByteStream"), + "#{FuturesStreamCompatByteStream}::new(${section.params.payloadName!!})", + "FuturesStreamCompatByteStream" to RuntimeType.futuresStreamCompatByteStream(section.params.runtimeConfig), ) } diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index a369bf04c5..047be61101 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -136,6 +136,7 @@ use std::task::{Context, Poll}; #[cfg(feature = "rt-tokio")] mod bytestream_util; +use crate::futures_stream_adapter::FuturesStreamCompatByteStream; #[cfg(feature = "rt-tokio")] pub use bytestream_util::Length; @@ -417,7 +418,7 @@ impl ByteStream { /// # } /// ``` pub fn into_async_read(self) -> impl tokio::io::AsyncRead { - tokio_util::io::StreamReader::new(StreamWrapper { byte_stream: self }) + tokio_util::io::StreamReader::new(FuturesStreamCompatByteStream::new(self)) } /// Given a function to modify an [`SdkBody`], run it on the `SdkBody` inside this `Bytestream`. @@ -427,25 +428,6 @@ impl ByteStream { } } -pin_project! { - // A new-type wrapper around `ByteStream` so we can pass it to `tokio_util::io::StreamReader`. - struct StreamWrapper { - #[pin] - byte_stream: ByteStream, - } -} - -impl futures_core::stream::Stream for StreamWrapper { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project() - .byte_stream - .poll_next(cx) - .map_err(Error::streaming) - } -} - impl Default for ByteStream { fn default() -> Self { Self { diff --git a/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs b/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs similarity index 51% rename from rust-runtime/inlineable/src/hyper_body_wrap_stream.rs rename to rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs index 1dc6c2be33..6c696b7138 100644 --- a/rust-runtime/inlineable/src/hyper_body_wrap_stream.rs +++ b/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs @@ -1,56 +1,67 @@ +// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT. /* * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_http::body::SdkBody; -use aws_smithy_http::byte_stream::error::Error as ByteStreamError; -use aws_smithy_http::byte_stream::ByteStream; -use aws_smithy_http::event_stream::MessageStreamAdapter; -use aws_smithy_http::result::SdkError; +use crate::body::SdkBody; +use crate::byte_stream::error::Error as ByteStreamError; +use crate::byte_stream::ByteStream; use bytes::Bytes; use futures_core::stream::Stream; -use std::error::Error as StdError; use std::pin::Pin; use std::task::{Context, Poll}; -pub(crate) struct HyperBodyWrapEventStream(MessageStreamAdapter); +/// New-type wrapper to enable the impl of the `futures_core::stream::Stream` trait +/// +/// [`ByteStream`] no longer implements `futures_core::stream::Stream` so we wrap it in the +/// new-type to enable the trait when it is required. +/// +/// This is meant to be used by codegen code, and users should not need to use it directly. +pub struct FuturesStreamCompatByteStream(ByteStream); + +impl FuturesStreamCompatByteStream { + /// Creates a new `FuturesStreamCompatByteStream` by wrapping `stream`. + pub fn new(stream: ByteStream) -> Self { + Self(stream) + } -impl HyperBodyWrapEventStream { - #[allow(dead_code)] - pub(crate) fn new(adapter: MessageStreamAdapter) -> Self { - Self(adapter) + /// Returns [`SdkBody`] of the wrapped [`ByteStream`]. + pub fn into_inner(self) -> SdkBody { + self.0.into_inner() } } -impl Unpin for HyperBodyWrapEventStream {} - -impl Stream for HyperBodyWrapEventStream { - type Item = Result>; +impl Stream for FuturesStreamCompatByteStream { + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_next(cx) } } -pub(crate) struct HyperBodyWrapByteStream(ByteStream); - -impl HyperBodyWrapByteStream { - #[allow(dead_code)] - pub(crate) fn new(stream: ByteStream) -> Self { - Self(stream) - } - - #[allow(dead_code)] - pub(crate) fn into_inner(self) -> SdkBody { - self.0.into_inner() +#[cfg(feature = "event-stream")] +/// New-type wrapper to enable the impl of the `futures_core::stream::Stream` trait +/// +/// [`crate::event_stream::MessageStreamAdapter`] no longer implements `futures_core::stream::Stream` +/// so we wrap it in the new-type to enable the trait when it is required. +/// +/// This is meant to be used by codegen code, and users should not need to use it directly. +pub struct FuturesStreamCompatEventStream(crate::event_stream::MessageStreamAdapter); + +#[cfg(feature = "event-stream")] +impl FuturesStreamCompatEventStream { + /// Creates a new `FuturesStreamCompatEventStream` by wrapping `adapter`. + pub fn new(adapter: crate::event_stream::MessageStreamAdapter) -> Self { + Self(adapter) } } -impl Unpin for HyperBodyWrapByteStream {} - -impl Stream for HyperBodyWrapByteStream { - type Item = Result; +#[cfg(feature = "event-stream")] +impl Stream + for FuturesStreamCompatEventStream +{ + type Item = Result>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_next(cx) @@ -95,17 +106,18 @@ mod tests { write!(f, "TestServiceError") } } - impl StdError for TestServiceError {} + impl std::error::Error for TestServiceError {} fn check_compatible_with_hyper_wrap_stream(stream: S) -> S where S: Stream> + Send + 'static, O: Into + 'static, - E: Into> + 'static, + E: Into> + 'static, { stream } + #[cfg(feature = "event-stream")] #[test] fn test_message_adapter_stream_can_be_made_compatible_with_hyper_wrap_stream() { let stream = FnStream::new(|tx| { @@ -114,20 +126,19 @@ mod tests { tx.send(message).await.expect("failed to send"); }) }); - check_compatible_with_hyper_wrap_stream(HyperBodyWrapEventStream(MessageStreamAdapter::< - TestMessage, - TestServiceError, - >::new( - Marshaller, - ErrorMarshaller, - NoOpSigner {}, - stream, - ))); + check_compatible_with_hyper_wrap_stream(FuturesStreamCompatEventStream( + crate::event_stream::MessageStreamAdapter::::new( + Marshaller, + ErrorMarshaller, + NoOpSigner {}, + stream, + ), + )); } #[test] fn test_byte_stream_stream_can_be_made_compatible_with_hyper_wrap_stream() { let stream = ByteStream::from_static(b"Hello world"); - check_compatible_with_hyper_wrap_stream(HyperBodyWrapByteStream::new(stream)); + check_compatible_with_hyper_wrap_stream(FuturesStreamCompatByteStream::new(stream)); } } diff --git a/rust-runtime/aws-smithy-http/src/lib.rs b/rust-runtime/aws-smithy-http/src/lib.rs index ad53e37f9c..abad71c2c9 100644 --- a/rust-runtime/aws-smithy-http/src/lib.rs +++ b/rust-runtime/aws-smithy-http/src/lib.rs @@ -27,6 +27,8 @@ pub mod body; pub mod endpoint; +#[doc(hidden)] +pub mod futures_stream_adapter; pub mod header; pub mod http; pub mod http_versions; diff --git a/rust-runtime/inlineable/Cargo.toml b/rust-runtime/inlineable/Cargo.toml index 97cd27ea85..6d0c099429 100644 --- a/rust-runtime/inlineable/Cargo.toml +++ b/rust-runtime/inlineable/Cargo.toml @@ -19,7 +19,7 @@ default = ["gated-tests"] [dependencies] async-trait = "0.1" -aws-smithy-http = { path = "../aws-smithy-http", features = ["event-stream"] } +aws-smithy-http = { path = "../aws-smithy-http" } aws-smithy-http-server = { path = "../aws-smithy-http-server" } aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-runtime-api = { path = "../aws-smithy-runtime-api", features = ["client"] } @@ -27,7 +27,6 @@ aws-smithy-types = { path = "../aws-smithy-types" } aws-smithy-xml = { path = "../aws-smithy-xml" } bytes = "1" fastrand = "2.0.0" -futures-core = "0.3" futures-util = "0.3" http = "0.2.1" md-5 = "0.10.0" @@ -40,8 +39,6 @@ url = "2.2.2" [dev-dependencies] proptest = "1" -aws-smithy-async = { path = "../aws-smithy-async" } -aws-smithy-eventstream = { path = "../aws-smithy-eventstream" } [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index fd40a4aca7..b672eeef9d 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -13,8 +13,6 @@ mod client_idempotency_token; mod constrained; #[allow(dead_code)] mod ec2_query_errors; -#[allow(unused)] -mod hyper_body_wrap_stream; #[allow(dead_code)] mod idempotency_token; #[allow(dead_code)] From 852211fec0c8df5457f44a58bc9d90c601b60bc4 Mon Sep 17 00:00:00 2001 From: ysaito1001 Date: Thu, 7 Sep 2023 22:31:07 -0500 Subject: [PATCH 36/36] Fix undeclared crate or module for feature-gated types --- .../aws-smithy-http/src/byte_stream.rs | 5 +- .../src/futures_stream_adapter.rs | 87 ++++++++++--------- 2 files changed, 49 insertions(+), 43 deletions(-) diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index 047be61101..ed35149621 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -136,7 +136,6 @@ use std::task::{Context, Poll}; #[cfg(feature = "rt-tokio")] mod bytestream_util; -use crate::futures_stream_adapter::FuturesStreamCompatByteStream; #[cfg(feature = "rt-tokio")] pub use bytestream_util::Length; @@ -418,7 +417,9 @@ impl ByteStream { /// # } /// ``` pub fn into_async_read(self) -> impl tokio::io::AsyncRead { - tokio_util::io::StreamReader::new(FuturesStreamCompatByteStream::new(self)) + tokio_util::io::StreamReader::new( + crate::futures_stream_adapter::FuturesStreamCompatByteStream::new(self), + ) } /// Given a function to modify an [`SdkBody`], run it on the `SdkBody` inside this `Bytestream`. diff --git a/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs b/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs index 6c696b7138..96551714f5 100644 --- a/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs +++ b/rust-runtime/aws-smithy-http/src/futures_stream_adapter.rs @@ -72,42 +72,8 @@ impl Stream mod tests { use super::*; use aws_smithy_async::future::fn_stream::FnStream; - use aws_smithy_eventstream::error::Error; - use aws_smithy_eventstream::frame::MarshallMessage; - use aws_smithy_eventstream::frame::{Message, NoOpSigner}; use futures_core::stream::Stream; - #[derive(Debug, Eq, PartialEq)] - struct TestMessage(String); - - #[derive(Debug)] - struct Marshaller; - impl MarshallMessage for Marshaller { - type Input = TestMessage; - - fn marshall(&self, input: Self::Input) -> Result { - Ok(Message::new(input.0.as_bytes().to_vec())) - } - } - #[derive(Debug)] - struct ErrorMarshaller; - impl MarshallMessage for ErrorMarshaller { - type Input = TestServiceError; - - fn marshall(&self, _input: Self::Input) -> Result { - Err(Message::read_from(&b""[..]).expect_err("this should always fail")) - } - } - - #[derive(Debug)] - struct TestServiceError; - impl std::fmt::Display for TestServiceError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "TestServiceError") - } - } - impl std::error::Error for TestServiceError {} - fn check_compatible_with_hyper_wrap_stream(stream: S) -> S where S: Stream> + Send + 'static, @@ -117,8 +83,53 @@ mod tests { stream } + #[test] + fn test_byte_stream_stream_can_be_made_compatible_with_hyper_wrap_stream() { + let stream = ByteStream::from_static(b"Hello world"); + check_compatible_with_hyper_wrap_stream(FuturesStreamCompatByteStream::new(stream)); + } + #[cfg(feature = "event-stream")] + mod tests_event_stream { + use aws_smithy_eventstream::error::Error; + use aws_smithy_eventstream::frame::MarshallMessage; + use aws_smithy_eventstream::frame::Message; + + #[derive(Debug, Eq, PartialEq)] + pub(crate) struct TestMessage(pub(crate) String); + + #[derive(Debug)] + pub(crate) struct Marshaller; + impl MarshallMessage for Marshaller { + type Input = TestMessage; + + fn marshall(&self, input: Self::Input) -> Result { + Ok(Message::new(input.0.as_bytes().to_vec())) + } + } + #[derive(Debug)] + pub(crate) struct ErrorMarshaller; + impl MarshallMessage for ErrorMarshaller { + type Input = TestServiceError; + + fn marshall(&self, _input: Self::Input) -> Result { + Err(Message::read_from(&b""[..]).expect_err("this should always fail")) + } + } + + #[derive(Debug)] + pub(crate) struct TestServiceError; + impl std::fmt::Display for TestServiceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestServiceError") + } + } + impl std::error::Error for TestServiceError {} + } + use tests_event_stream::*; + #[test] + #[cfg(feature = "event-stream")] fn test_message_adapter_stream_can_be_made_compatible_with_hyper_wrap_stream() { let stream = FnStream::new(|tx| { Box::pin(async move { @@ -130,15 +141,9 @@ mod tests { crate::event_stream::MessageStreamAdapter::::new( Marshaller, ErrorMarshaller, - NoOpSigner {}, + aws_smithy_eventstream::frame::NoOpSigner {}, stream, ), )); } - - #[test] - fn test_byte_stream_stream_can_be_made_compatible_with_hyper_wrap_stream() { - let stream = ByteStream::from_static(b"Hello world"); - check_compatible_with_hyper_wrap_stream(FuturesStreamCompatByteStream::new(stream)); - } }