diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 16c847df7fb..c18cf58b6cb 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -11,6 +11,42 @@ # meta = { "breaking" = false, "tada" = false, "bug" = false } # author = "rcoh" +[[smithy-rs]] +message = "Rename EventStreamInput<> to EventStreamSender<>" +references = ["smithy-rs#1157"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "82marbag" + +[[smithy-rs]] +message = "Add ability to sign a request with all headers, or to change which headers are excluded from signing" +references = ["smithy-rs#1381"] +meta = { "breaking" = false, "tada" = true, "bug" = false } +author = "alonlud" + + [[aws-sdk-rust]] + message = "Add method `ByteStream::into_async_read`. This makes it easy to convert `ByteStream`s into a struct implementing `tokio:io::AsyncRead`. Available on **crate feature** `rt-tokio` only." + references = ["smithy-rs#1390"] + meta = { "breaking" = false, "tada" = true, "bug" = false } + author = "Velfi" + + [[smithy-rs]] + message = "Add method `ByteStream::into_async_read`. This makes it easy to convert `ByteStream`s into a struct implementing `tokio:io::AsyncRead`. Available on **crate feature** `rt-tokio` only." + references = ["smithy-rs#1390"] + meta = { "breaking" = false, "tada" = true, "bug" = false } + author = "Velfi" + +[[smithy-rs]] +message = "Add ability to specify a different rust crate name than the one derived from the package name" +references = ["smithy-rs#1404"] +meta = { "breaking" = false, "tada" = false, "bug" = false } +author = "petrosagg" + +[[smithy-rs]] +message = "Switch to [RustCrypto](https://github.com/RustCrypto)'s implementation of MD5." +references = ["smithy-rs#1404"] +meta = { "breaking" = false, "tada" = false, "bug" = false } +author = "petrosagg" + [[aws-sdk-rust]] message = "Fix bug in profile file credential provider where a missing `default` profile lead to an unintended error." references = ["aws-sdk-rust#547", "smithy-rs#1458"] @@ -21,4 +57,4 @@ author = "rcoh" message = "Add `Debug` implementation to several types in `aws-config`" references = ["smithy-rs#1421"] meta = { "breaking" = false, "tada" = false, "bug" = false } -author = "jdisanti" \ No newline at end of file +author = "jdisanti" diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt index 9c06c6c2b37..1b19d9e8061 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.rust.codegen.smithy.customizations.RetryConfigDeco import software.amazon.smithy.rust.codegen.smithy.customizations.SleepImplDecorator import software.amazon.smithy.rust.codegen.smithy.customizations.TimeoutConfigDecorator import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecorator +import software.amazon.smithy.rust.codegen.smithy.customize.EventStreamDecorator import software.amazon.smithy.rustsdk.customize.apigateway.ApiGatewayDecorator import software.amazon.smithy.rustsdk.customize.auth.DisabledAuthDecorator import software.amazon.smithy.rustsdk.customize.ec2.Ec2Decorator @@ -24,7 +25,7 @@ val DECORATORS = listOf( RegionDecorator(), AwsEndpointDecorator(), UserAgentDecorator(), - SigV4SigningDecorator(), + EventStreamDecorator(listOf(SigV4SigningDecorator())), RetryPolicyDecorator(), IntegrationTestDecorator(), AwsFluentClientDecorator(), diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt index 935789068af..dc18d4ee500 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SigV4SigningDecorator.kt @@ -22,9 +22,9 @@ import software.amazon.smithy.rust.codegen.rustlang.writable import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.customize.EventStreamDecorator import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection -import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.smithy.letIf @@ -42,11 +42,11 @@ import software.amazon.smithy.rust.codegen.util.isInputEventStream * - sets a default `OperationSigningConfig` A future enhancement will customize this for specific services that need * different behavior. */ -class SigV4SigningDecorator : RustCodegenDecorator { +class SigV4SigningDecorator : EventStreamDecorator(listOf()) { override val name: String = "SigV4Signing" override val order: Byte = 0 - private fun applies(codegenContext: CodegenContext): Boolean = codegenContext.serviceShape.hasTrait() + override fun applies(codegenContext: CodegenContext): Boolean = codegenContext.serviceShape.hasTrait() override fun configCustomizations( codegenContext: CodegenContext, diff --git a/codegen-server-test/model/pokemon.smithy b/codegen-server-test/model/pokemon.smithy index 4c9b3c626b5..a0b9d435127 100644 --- a/codegen-server-test/model/pokemon.smithy +++ b/codegen-server-test/model/pokemon.smithy @@ -10,7 +10,7 @@ use aws.protocols#restJson1 service PokemonService { version: "2021-12-01", resources: [PokemonSpecies], - operations: [GetServerStatistics, EmptyOperation], + operations: [GetServerStatistics, EmptyOperation, CapturePokemonOperation], } /// A Pokémon species forms the basis for at least one Pokémon. @@ -22,6 +22,44 @@ resource PokemonSpecies { read: GetPokemonSpecies, } +/// Capture Pokémons via event streams +@http(uri: "/capture-pokemon-event", method: "POST") +operation CapturePokemonOperation { + input: CapturePokemonOperationEventsInput, + output: CapturePokemonOperationEventsOutput, +} + +@input +structure CapturePokemonOperationEventsInput { + @httpPayload + events: AttemptCapturingPokemonEvent, +} + +@output +structure CapturePokemonOperationEventsOutput { + @httpPayload + events: CapturePokemonEvents, +} + +@streaming +union AttemptCapturingPokemonEvent { + event: CapturingEvent, +} + +structure CapturingEvent { + name: String, + pokeball: String, +} + +@streaming +union CapturePokemonEvents { + event: CaptureEvent, +} + +structure CaptureEvent { + name: String, +} + /// Retrieve information about a Pokémon species. @readonly @http(uri: "/pokemon-species/{name}", method: "GET") diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt index 58bbc744829..fcc2377afe3 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt @@ -36,7 +36,7 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin { override fun execute(context: PluginContext) { // Suppress extremely noisy logs about reserved words Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF - // Discover [RustCodegenDecorators] on the classpath. [RustCodegenDectorator] return different types of + // Discover [RustCodegenDecorators] on the classpath. [RustCodegenDecorator] return different types of // customization. A customization is a function of: // - location (e.g. the mutate section of an operation) // - context (e.g. the of the operation) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt index 435c51b64ef..8560d100e06 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.smithy.StreamingShapeSymbolProvider import software.amazon.smithy.rust.codegen.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecorator +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import java.util.logging.Level import java.util.logging.Logger @@ -62,7 +63,7 @@ class RustCodegenServerPlugin : SmithyBuildPlugin { SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) // Generate different types for EventStream shapes (e.g. transcribe streaming) .let { - EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model) + EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) .let { StreamingShapeSymbolProvider(it, model) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 9ad961d399c..96b8e535516 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -505,13 +505,10 @@ private class ServerHttpBoundProtocolTraitImplGenerator( // Fallback to the default code of `http::response::Builder`, 200. operationShape.outputShape(model).findStreamingMember(model)?.let { - val memberName = symbolProvider.toMemberName(it) - rustTemplate( - """ - let body = #{SmithyHttpServer}::body::to_boxed(#{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName)); - """, - *codegenScope, - ) + val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE) + withBlockTemplate("let body = #{SmithyHttpServer}::body::boxed(#{SmithyHttpServer}::body::Body::wrap_stream(", "));", *codegenScope) { + payloadGenerator.generatePayload(this, "output", operationShape) + } } ?: run { val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE) withBlockTemplate("let body = #{SmithyHttpServer}::body::to_boxed(", ");", *codegenScope) { @@ -682,29 +679,29 @@ private class ServerHttpBoundProtocolTraitImplGenerator( HttpLocation.HEADER -> writable { serverRenderHeaderParser(this, binding, operationShape) } HttpLocation.PREFIX_HEADERS -> writable { serverRenderPrefixHeadersParser(this, binding, operationShape) } HttpLocation.PAYLOAD -> { - return if (binding.member.isStreaming(model)) { - writable { + val structureShapeHandler: RustWriter.(String) -> Unit = { body -> + rust("#T($body)", structuredDataParser.payloadParser(binding.member)) + } + val errorSymbol = getDeserializePayloadErrorSymbol(binding) + val deserializer = httpBindingGenerator.generateDeserializePayloadFn( + binding, + errorSymbol, + structuredHandler = structureShapeHandler + ) + return writable { + if (binding.member.isStreaming(model)) { rustTemplate( """ { let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; - Some(body.into()) + let bytes = #{Hyper}::body::to_bytes(body).await?; + Some(#{Deserializer}(&mut bytes.into())?) } - """.trimIndent(), + """, + "Deserializer" to deserializer, *codegenScope ) - } - } else { - val structureShapeHandler: RustWriter.(String) -> Unit = { body -> - rust("#T($body)", structuredDataParser.payloadParser(binding.member)) - } - val errorSymbol = getDeserializePayloadErrorSymbol(binding) - val deserializer = httpBindingGenerator.generateDeserializePayloadFn( - binding, - errorSymbol, - structuredHandler = structureShapeHandler - ) - writable { + } else { rustTemplate( """ { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProvider.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProvider.kt index bb92cebf6cb..2a5b0b7970b 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProvider.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProvider.kt @@ -14,12 +14,14 @@ import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.render import software.amazon.smithy.rust.codegen.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.isEventStream import software.amazon.smithy.rust.codegen.util.isInputEventStream +import software.amazon.smithy.rust.codegen.util.isOutputEventStream /** * Wrapping symbol provider to wrap modeled types with the aws-smithy-http Event Stream send/receive types. @@ -27,8 +29,10 @@ import software.amazon.smithy.rust.codegen.util.isInputEventStream class EventStreamSymbolProvider( private val runtimeConfig: RuntimeConfig, base: RustSymbolProvider, - private val model: Model + private val model: Model, + private val target: CodegenTarget, ) : WrappingSymbolProvider(base) { + private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) override fun toSymbol(shape: Shape): Symbol { val initial = super.toSymbol(shape) @@ -42,21 +46,30 @@ class EventStreamSymbolProvider( } // If we find an operation shape, then we can wrap the type if (operationShape != null) { - val error = operationShape.errorSymbol(this).toSymbol() + val error = if (operationShape.errors.isNotEmpty()) { + operationShape.errorSymbol(this).toSymbol() + } else { + RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream") + .toSymbol() + } val errorFmt = error.rustType().render(fullyQualified = true) val innerFmt = initial.rustType().stripOuter().render(fullyQualified = true) - val outer = when (shape.isInputEventStream(model)) { - true -> "EventStreamInput<$innerFmt>" + val isSender = (shape.isInputEventStream(model) && target == CodegenTarget.CLIENT) || + (shape.isOutputEventStream(model) && target == CodegenTarget.SERVER) + val outer = when (isSender) { + true -> "EventStreamSender<$innerFmt>" else -> "Receiver<$innerFmt, $errorFmt>" } val rustType = RustType.Opaque(outer, "aws_smithy_http::event_stream") - return initial.toBuilder() + val symbol = initial.toBuilder() .name(rustType.name) .rustType(rustType) - .addReference(error) .addReference(initial) .addDependency(CargoDependency.SmithyHttp(runtimeConfig).withFeature("event-stream")) - .build() + if (operationShape.errors.isNotEmpty()) { + symbol.addReference(error) + } + return symbol.build() } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustCodegenPlugin.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustCodegenPlugin.kt index ceecac0b30f..88e4f097ec5 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustCodegenPlugin.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustCodegenPlugin.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.rustlang.Attribute.Companion.NonExhau import software.amazon.smithy.rust.codegen.rustlang.RustReservedWordSymbolProvider import software.amazon.smithy.rust.codegen.smithy.customizations.ClientCustomizations import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecorator +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import java.util.logging.Level import java.util.logging.Logger @@ -49,7 +50,7 @@ class RustCodegenPlugin : SmithyBuildPlugin { fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig = DefaultConfig) = SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model) } + .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) } // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) .let { StreamingShapeSymbolProvider(it, model) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/EventStreamDecorator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/EventStreamDecorator.kt new file mode 100644 index 00000000000..aff6d481dd2 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/EventStreamDecorator.kt @@ -0,0 +1,88 @@ +package software.amazon.smithy.rust.codegen.smithy.customize + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.Writable +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization +import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig +import software.amazon.smithy.rust.codegen.util.hasEventStreamOperations + +/** + * The EventStreamDecorator: + * - adds a `new_event_stream_signer()` method to `config` to create an Event Stream NoOp signer + * - can be customized by subclassing, see SigV4SigningDecorator + */ +open class EventStreamDecorator( + private val decorators: List +) : RustCodegenDecorator { + override val name: String = "EventStreamDecorator" + override val order: Byte = 0 + + open fun applies(codegenContext: CodegenContext): Boolean = true + + override fun configCustomizations( + codegenContext: CodegenContext, + baseCustomizations: List + ): List { + decorators.forEach { + if (it.applies(codegenContext)) return it.configCustomizations(codegenContext, baseCustomizations) + } + return baseCustomizations + EventStreamSignConfig( + codegenContext.runtimeConfig, + codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model), + ) + } + + override fun operationCustomizations( + codegenContext: CodegenContext, + operation: OperationShape, + baseCustomizations: List + ): List { + decorators.forEach { + if (it.applies(codegenContext)) return it.operationCustomizations(codegenContext, operation, baseCustomizations) + } + return baseCustomizations + } +} + +class EventStreamSignConfig( + runtimeConfig: RuntimeConfig, + private val serviceHasEventStream: Boolean, +) : ConfigCustomization() { + private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) + private val codegenScope = arrayOf( + "NoOpSigner" to RuntimeType("NoOpSigner", smithyEventStream, "aws_smithy_eventstream::frame"), + "SharedPropertyBag" to RuntimeType( + "SharedPropertyBag", + CargoDependency.SmithyHttp(runtimeConfig), + "aws_smithy_http::property_bag" + ) + ) + + override fun section(section: ServiceConfig): Writable { + return when (section) { + is ServiceConfig.ConfigImpl -> writable { + if (serviceHasEventStream) { + rustTemplate( + """ + /// Creates a new Event Stream `SignMessage` implementor. + pub fn new_event_stream_signer( + &self, + _properties: #{SharedPropertyBag} + ) -> #{NoOpSigner} { + #{NoOpSigner}{} + } + """, + *codegenScope + ) + } + } + else -> emptySection + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RustCodegenDecorator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RustCodegenDecorator.kt index c17e765aa06..56870ebdd59 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RustCodegenDecorator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/RustCodegenDecorator.kt @@ -155,7 +155,7 @@ open class CombinedCodegenDecorator(decorators: List) : Ru .onEach { logger.info("Adding Codegen Decorator: ${it.javaClass.name}") }.toList() - return CombinedCodegenDecorator(decorators + RequiredCustomizations() + FluentClientDecorator() + extras) + return CombinedCodegenDecorator(decorators + RequiredCustomizations() + EventStreamDecorator(listOf()) + FluentClientDecorator() + extras) } } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index 6e6dee4864c..d1fd44ec845 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError @@ -39,6 +40,7 @@ import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.isEventStream import software.amazon.smithy.rust.codegen.util.isInputEventStream +import software.amazon.smithy.rust.codegen.util.isOutputEventStream import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase @@ -56,13 +58,15 @@ class HttpBoundProtocolPayloadGenerator( private val operationSerModule = RustModule.private("operation_ser") + private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) private val codegenScope = arrayOf( "hyper" to CargoDependency.HyperWithStream.asType(), "ByteStream" to RuntimeType.byteStream(runtimeConfig), "ByteSlab" to RuntimeType.ByteSlab, "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "BuildError" to runtimeConfig.operationBuildError(), - "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType() + "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(), + "NoOpSigner" to RuntimeType("NoOpSigner", smithyEventStream, "aws_smithy_eventstream::frame"), ) override fun payloadMetadata(operationShape: OperationShape): ProtocolPayloadGenerator.PayloadMetadata { @@ -108,8 +112,7 @@ class HttpBoundProtocolPayloadGenerator( val serializerGenerator = protocol.structuredDataSerializer(operationShape) generateStructureSerializer(writer, self, serializerGenerator.operationInputSerializer(operationShape)) } else { - val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) - generatePayloadMemberSerializer(writer, self, operationShape, payloadMember) + generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) } } @@ -120,8 +123,7 @@ class HttpBoundProtocolPayloadGenerator( val serializerGenerator = protocol.structuredDataSerializer(operationShape) generateStructureSerializer(writer, self, serializerGenerator.operationOutputSerializer(operationShape)) } else { - val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) - generatePayloadMemberSerializer(writer, self, operationShape, payloadMember) + generatePayloadMemberSerializer(writer, self, operationShape, payloadMemberName) } } @@ -129,15 +131,24 @@ class HttpBoundProtocolPayloadGenerator( writer: RustWriter, self: String, operationShape: OperationShape, - payloadMember: MemberShape + payloadMemberName: String, ) { val serializerGenerator = protocol.structuredDataSerializer(operationShape) - // TODO(https://github.com/awslabs/smithy-rs/issues/1157) Add support for server event streams. - if (operationShape.isInputEventStream(model)) { - writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator) + if (operationShape.isEventStream(model)) { + if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { + val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) + writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "self") + } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) { + val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) + writer.serializeViaEventStream(operationShape, payloadMember, serializerGenerator, "output") + } } else { val bodyMetadata = payloadMetadata(operationShape) + val payloadMember = when (httpMessageType) { + HttpMessageType.RESPONSE -> operationShape.outputShape(model).expectMember(payloadMemberName) + HttpMessageType.REQUEST -> operationShape.inputShape(model).expectMember(payloadMemberName) + } writer.serializeViaPayload(bodyMetadata, self, payloadMember, serializerGenerator) } } @@ -156,11 +167,16 @@ class HttpBoundProtocolPayloadGenerator( private fun RustWriter.serializeViaEventStream( operationShape: OperationShape, memberShape: MemberShape, - serializerGenerator: StructuredDataSerializerGenerator + serializerGenerator: StructuredDataSerializerGenerator, + outerName: String, ) { val memberName = symbolProvider.toMemberName(memberShape) val unionShape = model.expectShape(memberShape.target, UnionShape::class.java) + val contentType = when (target) { + CodegenTarget.CLIENT -> httpBindingResolver.requestContentType(operationShape) + CodegenTarget.SERVER -> httpBindingResolver.responseContentType(operationShape) + } val marshallerConstructorFn = EventStreamMarshallerGenerator( model, target, @@ -168,27 +184,50 @@ class HttpBoundProtocolPayloadGenerator( symbolProvider, unionShape, serializerGenerator, - httpBindingResolver.requestContentType(operationShape) - ?: throw CodegenException("event streams must set a content type"), + contentType ?: throw CodegenException("event streams must set a content type"), ).render() + val operationError = if (operationShape.errors.isNotEmpty()) { + operationShape.errorSymbol(symbolProvider) + } else { + RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream") + } + // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the - // parameters that are not `@eventHeader` or `@eventPayload`. - rustTemplate( - """ - { - let marshaller = #{marshallerConstructorFn}(); - let signer = _config.new_event_stream_signer(properties.clone()); - let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, #{OperationError}> = - self.$memberName.into_body_stream(marshaller, signer); - let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); - body - } - """, - *codegenScope, - "marshallerConstructorFn" to marshallerConstructorFn, - "OperationError" to operationShape.errorSymbol(symbolProvider) - ) + // parameters that are not `@eventHeader` or `@eventPayload`. + when (target) { + CodegenTarget.CLIENT -> + rustTemplate( + """ + { + let marshaller = #{marshallerConstructorFn}(); + let signer = _config.new_event_stream_signer(properties.clone()); + let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, #{OperationError}> = + $outerName.$memberName.into_body_stream(marshaller, signer); + let body: #{SdkBody} = #{hyper}::Body::wrap_stream(adapter).into(); + body + } + """, + *codegenScope, + "marshallerConstructorFn" to marshallerConstructorFn, + "OperationError" to operationError, + ) + CodegenTarget.SERVER -> + rustTemplate( + """ + { + let marshaller = #{marshallerConstructorFn}(); + let signer = #{NoOpSigner}{}; + let adapter: #{SmithyHttp}::event_stream::MessageStreamAdapter<_, #{OperationError}> = + $outerName.$memberName.into_body_stream(marshaller, signer); + adapter + } + """, + *codegenScope, + "marshallerConstructorFn" to marshallerConstructorFn, + "OperationError" to operationError, + ) + } } private fun RustWriter.serializeViaPayload( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index 1bbfc6482e3..3a8ae1dd535 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -55,8 +55,13 @@ class EventStreamUnmarshallerGenerator( private val target: CodegenTarget, ) { private val unionSymbol = symbolProvider.toSymbol(unionShape) - private val operationErrorSymbol = operationShape.errorSymbol(symbolProvider) private val smithyEventStream = CargoDependency.SmithyEventStream(runtimeConfig) + private val operationErrorSymbol = if (operationShape.errors.isNotEmpty()) { + operationShape.errorSymbol(symbolProvider).toSymbol() + } else { + RuntimeType("MessageStreamError", smithyEventStream, "aws_smithy_http::event_stream") + .toSymbol() + } private val eventStreamSerdeModule = RustModule.private("event_stream_serde") private val codegenScope = arrayOf( "Blob" to RuntimeType("Blob", CargoDependency.SmithyTypes(runtimeConfig), "aws_smithy_types"), diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RemoveEventStreamOperations.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RemoveEventStreamOperations.kt index 5d2bc11784e..98d33ef8d00 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RemoveEventStreamOperations.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RemoveEventStreamOperations.kt @@ -21,7 +21,7 @@ object RemoveEventStreamOperations { fun transform(model: Model, settings: RustSettings): Model { // If Event Stream is allowed in build config, then don't remove the operations - if (settings.codegenConfig.eventStreamAllowList.contains(settings.moduleName)) { + if (true || settings.codegenConfig.eventStreamAllowList.contains(settings.moduleName)) { return model } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt index 5c9b2ed202b..a9b9149ca2f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait inline fun Model.lookup(shapeId: String): T { return this.expectShape(ShapeId.from(shapeId), T::class.java) @@ -57,7 +58,7 @@ fun MemberShape.isInputEventStream(model: Model): Boolean { } fun MemberShape.isOutputEventStream(model: Model): Boolean { - return isEventStream(model) && model.expectShape(container).hasTrait() + return isEventStream(model) && model.expectShape(container).hasTrait() } fun Shape.hasEventStreamMember(model: Model): Boolean { diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/customizations/HttpVersionListGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/customizations/HttpVersionListGeneratorTest.kt index 5165610be59..d84747eb8e1 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/customizations/HttpVersionListGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/customizations/HttpVersionListGeneratorTest.kt @@ -289,8 +289,8 @@ class FakeSigningConfig( Ok(message) } - fn sign_empty(&mut self) -> Result<#{Message}, #{SignMessageError}> { - Ok(#{Message}::new(Vec::new())) + fn sign_empty(&mut self) -> Option> { + Some(Ok(#{Message}::new(Vec::new()))) } } """, diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProviderTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProviderTest.kt index 0a85b92d1ad..96cb1bdb6fb 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProviderTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/EventStreamSymbolProviderTest.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.rustlang.RustType +import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.testutil.asSmithyModel @@ -41,7 +42,7 @@ class EventStreamSymbolProviderTest { ) val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, DefaultConfig), model) + val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, DefaultConfig), model, CodegenTarget.CLIENT) // Look up the synthetic input/output rather than the original input/output val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape @@ -50,7 +51,7 @@ class EventStreamSymbolProviderTest { val inputType = provider.toSymbol(inputStream).rustType() val outputType = provider.toSymbol(outputStream).rustType() - inputType shouldBe RustType.Opaque("EventStreamInput", "aws_smithy_http::event_stream") + inputType shouldBe RustType.Opaque("EventStreamSender", "aws_smithy_http::event_stream") outputType shouldBe RustType.Opaque("Receiver", "aws_smithy_http::event_stream") } @@ -77,7 +78,7 @@ class EventStreamSymbolProviderTest { ) val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, DefaultConfig), model) + val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, DefaultConfig), model, CodegenTarget.CLIENT) // Look up the synthetic input/output rather than the original input/output val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape diff --git a/rust-runtime/aws-smithy-eventstream/src/frame.rs b/rust-runtime/aws-smithy-eventstream/src/frame.rs index 4cdb95099f3..6732ae72a59 100644 --- a/rust-runtime/aws-smithy-eventstream/src/frame.rs +++ b/rust-runtime/aws-smithy-eventstream/src/frame.rs @@ -27,7 +27,19 @@ pub type SignMessageError = Box; pub trait SignMessage: fmt::Debug { fn sign(&mut self, message: Message) -> Result; - fn sign_empty(&mut self) -> Result; + fn sign_empty(&mut self) -> Option>; +} + +#[derive(Debug)] +pub struct NoOpSigner {} +impl SignMessage for NoOpSigner { + fn sign(&mut self, message: Message) -> Result { + Ok(message) + } + + fn sign_empty(&mut self) -> Option> { + None + } } /// Converts a Smithy modeled Event Stream type into a [`Message`](Message). diff --git a/rust-runtime/aws-smithy-http-server/examples/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/Cargo.toml index 2102e16a700..c799f19ac7e 100644 --- a/rust-runtime/aws-smithy-http-server/examples/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/examples/Cargo.toml @@ -2,8 +2,9 @@ [workspace] members = [ "pokemon_service", + "pokemon_client", "pokemon_service_sdk", - "pokemon_service_client" + "pokemon_service_client", ] [profile.release] diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_client/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/pokemon_client/Cargo.toml new file mode 100644 index 00000000000..053f9f714b8 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_client/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "pokemon_client" +version = "0.1.0" +edition = "2021" +publish = false +authors = ["Smithy-rs Server Team "] +description = "A smithy Rust client to retrieve information about and capture Pokémon." + +[dependencies] +async-stream = "0.3" +clap = { version = "~3.2.1", features = ["derive"] } +hyper = {version = "0.14", features = ["server"] } +rand = "0.8" +tokio = "1" +tower = "0.4" +tower-http = { version = "0.3", features = ["trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +# Local paths +aws-smithy-client = { path = "../../../aws-smithy-client/", features = ["rustls"] } +pokemon_service_client = { path = "../pokemon_service_client/" } + +[dev-dependencies] +assert_cmd = "2.0" +home = "0.5" +wrk-api-bench = "0.0.7" diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_client/src/main.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon_client/src/main.rs new file mode 100644 index 00000000000..16697ed29c9 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_client/src/main.rs @@ -0,0 +1,106 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use async_stream::stream; +use pokemon_service_client::model::AttemptCapturingPokemonEvent; +use pokemon_service_client::model::CapturingEvent; +use pokemon_service_client::{Builder, Client, Config}; +use rand::Rng; + +fn get_pokeball() -> String { + let random = rand::thread_rng().gen_range(0..100); + let pokeball = if random < 5 { + "Master Ball" + } else if random < 30 { + "Great Ball" + } else { + "Fast Ball" + }; + pokeball.to_string() +} +fn get_pokemon_to_capture() -> String { + let pokemons = vec!["Charizard", "Pikachu", "Regieleki"]; + pokemons[rand::thread_rng().gen_range(0..pokemons.len())].to_string() +} + +#[tokio::main] +pub async fn main() -> Result<(), ()> { + let raw_client = Builder::dyn_https() + .middleware_fn(|mut req| { + let http_req = req.http_mut(); + let uri = format!("http://localhost:13734{}", http_req.uri().path()); + *http_req.uri_mut() = uri.parse().unwrap(); + req + }) + .build_dyn(); + let config = Config::builder().build(); + let client = Client::with_config(raw_client, config); + + let mut team = vec![]; + let input_stream = stream! { + // Always Pikachu + yield Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder().name("Pikachu").pokeball("Master Ball").build() + )); + yield Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder().name("Regieleki").pokeball("Fast Ball").build() + )); + yield Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder().name("Charizard").pokeball("Great Ball").build() + )); + }; + + // Throw many! + let mut output = client + .capture_pokemon_operation() + .events(input_stream.into()) + .send() + .await + .unwrap(); + loop { + match output.events.recv().await { + Ok(Some(capture)) => { + let pokemon = capture.as_event().unwrap().name.as_ref().unwrap().clone(); + println!("captured {}", pokemon); + team.push(pokemon); + } + Err(e) => { + println!("error {:?}", e); + break; + } + Ok(None) => break, + } + } + + while team.len() < 6 { + let pokeball = get_pokeball(); + let pokemon = get_pokemon_to_capture(); + let input_stream = stream! { + yield Ok(AttemptCapturingPokemonEvent::Event( + CapturingEvent::builder().name(pokemon).pokeball(pokeball).build() + )) + }; + let mut output = client + .capture_pokemon_operation() + .events(input_stream.into()) + .send() + .await + .unwrap(); + match output.events.recv().await { + Ok(Some(capture)) => { + let pokemon = capture.as_event().unwrap().name.as_ref().unwrap().clone(); + println!("captured {}", pokemon); + team.push(pokemon); + } + Err(e) => { + println!("error {:?}", e); + break; + } + Ok(None) => {} + } + } + println!("Team: {:?}", team); + Ok(()) +} diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/Cargo.toml index 5a2e295e423..a61ba550928 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/Cargo.toml @@ -7,8 +7,10 @@ authors = ["Smithy-rs Server Team "] description = "A smithy Rust service to retrieve information about Pokémon." [dependencies] +async-stream = "0.3" clap = { version = "~3.2.1", features = ["derive"] } hyper = {version = "0.14", features = ["server"] } +rand = "0.8" tokio = "1" tower = "0.4" tower-http = { version = "0.3", features = ["trace"] } diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/lib.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/lib.rs index 8418e54d15c..d8b8f2125cd 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/lib.rs @@ -13,8 +13,10 @@ use std::{ sync::{atomic::AtomicU64, Arc}, }; +use async_stream::stream; use aws_smithy_http_server::Extension; use pokemon_service_sdk::{error, input, model, output}; +use rand::Rng; use tracing_subscriber::{prelude::*, EnvFilter}; const PIKACHU_ENGLISH_FLAVOR_TEXT: &str = @@ -183,6 +185,55 @@ pub async fn get_server_statistics( output::GetServerStatisticsOutput { calls_count } } +/// Attempts to capture a Pokémon +pub async fn capture_pokemon(mut input: input::CapturePokemonOperationInput) -> output::CapturePokemonOperationOutput { + let mut output_events = vec![]; + loop { + match input.events.recv().await { + Ok(maybe_event) => match maybe_event { + Some(event) => { + let capturing_event = event.as_event(); + // TODO: verify the events from the Pokémon trainer + if let Ok(attempt) = capturing_event { + let pokeball = attempt.pokeball.as_ref().map(|ball| ball.as_str()).unwrap_or(""); + let pokemon = attempt + .name + .as_ref() + .map(|name| name.as_str()) + .unwrap_or("") + .to_string(); + let captured = match pokeball { + "Master Ball" => true, + "Great Ball" => rand::thread_rng().gen_range(0..100) > 33, + "Fast Ball" => rand::thread_rng().gen_range(0..100) > 66, + _ => false, + }; + if captured { + output_events.push(Ok(crate::model::CapturePokemonEvents::Event( + crate::model::CaptureEvent::builder().name(pokemon).build(), + ))); + } + } + } + None => break, + }, + Err(e) => println!("{:?}", e), + } + } + let output_stream = stream! { + use std::time::Duration; + // Will it capture the Pokémon? + tokio::time::sleep(Duration::from_millis(1000)).await; + for output in output_events { + yield output; + } + }; + return output::CapturePokemonOperationOutput::builder() + .events(output_stream.into()) + .build() + .unwrap(); +} + /// Empty operation used to benchmark the service. pub async fn empty_operation(_input: input::EmptyOperationInput) -> output::EmptyOperationOutput { output::EmptyOperationOutput {} diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/main.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/main.rs index e84559daa38..a2a6f19a7c8 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/main.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon_service/src/main.rs @@ -8,7 +8,9 @@ use std::{net::SocketAddr, sync::Arc}; use aws_smithy_http_server::{AddExtensionLayer, Router}; use clap::Parser; -use pokemon_service::{empty_operation, get_pokemon_species, get_server_statistics, setup_tracing, State}; +use pokemon_service::{ + capture_pokemon, empty_operation, get_pokemon_species, get_server_statistics, setup_tracing, State, +}; use pokemon_service_sdk::operation_registry::OperationRegistryBuilder; use tower::ServiceBuilder; use tower_http::trace::TraceLayer; @@ -34,6 +36,7 @@ pub async fn main() { // return the operation's output. .get_pokemon_species(get_pokemon_species) .get_server_statistics(get_server_statistics) + .capture_pokemon_operation(capture_pokemon) .empty_operation(empty_operation) .build() .expect("Unable to build operation registry") diff --git a/rust-runtime/aws-smithy-http/src/event_stream.rs b/rust-runtime/aws-smithy-http/src/event_stream.rs index bc0f97ccd6f..ae6176cbd25 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream.rs @@ -7,13 +7,13 @@ use std::error::Error as StdError; -mod input; -mod output; +mod receiver; +mod sender; pub type BoxError = Box; #[doc(inline)] -pub use input::{EventStreamInput, MessageStreamAdapter}; +pub use sender::{EventStreamSender, MessageStreamAdapter, MessageStreamError}; #[doc(inline)] -pub use output::{Error, RawMessage, Receiver}; +pub use receiver::{Error, RawMessage, Receiver}; diff --git a/rust-runtime/aws-smithy-http/src/event_stream/output.rs b/rust-runtime/aws-smithy-http/src/event_stream/receiver.rs similarity index 100% rename from rust-runtime/aws-smithy-http/src/event_stream/output.rs rename to rust-runtime/aws-smithy-http/src/event_stream/receiver.rs diff --git a/rust-runtime/aws-smithy-http/src/event_stream/input.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs similarity index 77% rename from rust-runtime/aws-smithy-http/src/event_stream/input.rs rename to rust-runtime/aws-smithy-http/src/event_stream/sender.rs index 7107146c8dd..a14e508e91c 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/input.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -11,22 +11,23 @@ use futures_core::Stream; use pin_project::pin_project; use std::error::Error as StdError; use std::fmt; +use std::fmt::Debug; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; /// Input type for Event Streams. -pub struct EventStreamInput { +pub struct EventStreamSender { input_stream: Pin> + Send>>, } -impl fmt::Debug for EventStreamInput { +impl Debug for EventStreamSender { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "EventStreamInput(Box)") + write!(f, "EventStreamSender(Box)") } } -impl EventStreamInput { +impl EventStreamSender { #[doc(hidden)] pub fn into_body_stream( self, @@ -37,24 +38,68 @@ impl EventStreamInput { } } -impl From for EventStreamInput +impl From for EventStreamSender where S: Stream> + Send + 'static, { fn from(stream: S) -> Self { - EventStreamInput { + EventStreamSender { input_stream: Box::pin(stream), } } } +#[derive(Debug)] +pub struct MessageStreamError { + kind: MessageStreamErrorKind, + pub(crate) meta: aws_smithy_types::Error, +} + +#[derive(Debug)] +pub enum MessageStreamErrorKind { + Unhandled(Box), +} + +impl MessageStreamError { + /// Creates the `MessageStreamError::Unhandled` variant from any error type. + pub fn unhandled(err: impl Into>) -> Self { + Self { + meta: Default::default(), + kind: MessageStreamErrorKind::Unhandled(err.into()), + } + } + + /// Creates the `MessageStreamError::Unhandled` variant from a `aws_smithy_types::Error`. + pub fn generic(err: aws_smithy_types::Error) -> Self { + Self { + meta: err.clone(), + kind: MessageStreamErrorKind::Unhandled(err.into()), + } + } + + /// Returns error metadata, which includes the error code, message, + /// request ID, and potentially additional information. + pub fn meta(&self) -> &aws_smithy_types::Error { + &self.meta + } +} + +impl StdError for MessageStreamError {} +impl fmt::Display for MessageStreamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + MessageStreamErrorKind::Unhandled(inner) => std::fmt::Debug::fmt(inner, f), + } + } +} + /// Adapts a `Stream` to a signed `Stream` by using the provided /// message marshaller and signer implementations. /// /// This will yield an `Err(SdkError::ConstructionFailure)` if a message can't be /// marshalled into an Event Stream frame, (e.g., if the message payload was too large). #[pin_project] -pub struct MessageStreamAdapter { +pub struct MessageStreamAdapter { marshaller: Box + Send + Sync>, signer: Box, #[pin] @@ -111,12 +156,15 @@ where } else if !*this.end_signal_sent { *this.end_signal_sent = true; let mut buffer = Vec::new(); - this.signer - .sign_empty() - .map_err(|err| SdkError::ConstructionFailure(err))? - .write_to(&mut buffer) - .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; - Poll::Ready(Some(Ok(Bytes::from(buffer)))) + match this.signer.sign_empty() { + Some(sign) => { + sign.map_err(|err| SdkError::ConstructionFailure(err))? + .write_to(&mut buffer) + .map_err(|err| SdkError::ConstructionFailure(Box::new(err)))?; + Poll::Ready(Some(Ok(Bytes::from(buffer)))) + } + None => Poll::Ready(None), + } } else { Poll::Ready(None) } @@ -129,7 +177,7 @@ where #[cfg(test)] mod tests { use super::MarshallMessage; - use crate::event_stream::{EventStreamInput, MessageStreamAdapter}; + use crate::event_stream::{EventStreamSender, MessageStreamAdapter}; use crate::result::SdkError; use async_stream::stream; use aws_smithy_eventstream::error::Error as EventStreamError; @@ -246,8 +294,8 @@ mod tests { // Verify the developer experience for this compiles #[allow(unused)] fn event_stream_input_ergonomics() { - fn check(input: impl Into>) { - let _: EventStreamInput = input.into(); + fn check(input: impl Into>) { + let _: EventStreamSender = input.into(); } check(stream! { yield Ok(TestMessage("test".into()));