From f0929e74ba5859759c8cdf7b95e27d0b8b244f6e Mon Sep 17 00:00:00 2001 From: david-perez Date: Tue, 31 Oct 2023 17:10:09 +0100 Subject: [PATCH] Allow server decorators to inject methods on config (#3111) PR #3095 added a code-generated `${serviceName}Config` object on which users can register layers and plugins. For example: ```rust let config = PokemonServiceConfig::builder() .layer(layers) .http_plugin(authn_plugin) .model_plugin(authz_plugin) .build(); ``` This PR makes it so that server decorators can inject methods on this config builder object. These methods can apply arbitrary layers, HTTP plugins, and/or model plugins. Moreover, the decorator can configure whether invoking such method is required or not. For example, a decorator can inject an `aws_auth` method that configures some plugins using its input arguments. Missing invocation of this method will result in the config failing to build: ```rust let _: SimpleServiceConfig< // No layers have been applied. tower::layer::util::Identity, // One HTTP plugin has been applied. PluginStack, // One model plugin has been applied. PluginStack, > = SimpleServiceConfig::builder() // This method has been injected in the config builder! .aws_auth("a".repeat(69).to_owned(), 69) // The method will apply one HTTP plugin and one model plugin, // configuring them with the input arguments. Configuration can be // declared to be fallible, in which case we get a `Result` we unwrap // here. .expect("failed to configure aws_auth") .build() // Since `aws_auth` has been marked as required, if the user misses // invoking it, this would panic here. .unwrap(); ``` ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --- .../smithy/rust/codegen/core/testutil/Rust.kt | 2 +- .../server/smithy/ServerCargoDependency.kt | 1 + .../server/smithy/ServerCodegenVisitor.kt | 8 +- .../customize/ServerCodegenDecorator.kt | 16 +- .../smithy/generators/ServerRootGenerator.kt | 38 +- .../generators/ServerServiceGenerator.kt | 10 +- .../generators/ServiceConfigGenerator.kt | 339 +++++++++++++++++- .../generators/ServiceConfigGeneratorTest.kt | 231 ++++++++++++ .../src/plugin/stack.rs | 2 + 9 files changed, 614 insertions(+), 33 deletions(-) create mode 100644 codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt index c443110805..9fdab3965d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt @@ -294,7 +294,7 @@ class TestWriterDelegator( } /** - * Generate a newtest module + * Generate a new test module * * This should only be used in test codeā€”the generated module name will be something like `tests_123` */ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 956a88ecf2..779fcb1686 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -22,6 +22,7 @@ object ServerCargoDependency { val Nom: CargoDependency = CargoDependency("nom", CratesIo("7")) val OnceCell: CargoDependency = CargoDependency("once_cell", CratesIo("1.13")) val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2")) + val ThisError: CargoDependency = CargoDependency("thiserror", CratesIo("1.0")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev) val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 6f66c72166..adfa466257 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -78,6 +78,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.Unconstraine import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.isBuilderFallible import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator @@ -591,11 +592,15 @@ open class ServerCodegenVisitor( logger.info("[rust-server-codegen] Generating a service $shape") val serverProtocol = protocolGeneratorFactory.protocol(codegenContext) as ServerProtocol + val configMethods = codegenDecorator.configMethods(codegenContext) + val isConfigBuilderFallible = configMethods.isBuilderFallible() + // Generate root. rustCrate.lib { ServerRootGenerator( serverProtocol, codegenContext, + isConfigBuilderFallible, ).render(this) } @@ -612,9 +617,10 @@ open class ServerCodegenVisitor( ServerServiceGenerator( codegenContext, serverProtocol, + isConfigBuilderFallible, ).render(this) - ServiceConfigGenerator(codegenContext).render(this) + ServiceConfigGenerator(codegenContext, configMethods).render(this) ScopeMacroGenerator(codegenContext).render(this) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt index 06f9c3c09b..5470c0902c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt @@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings import software.amazon.smithy.rust.codegen.server.smithy.ValidationResult +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConfigMethod import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import java.util.logging.Logger @@ -41,6 +42,12 @@ interface ServerCodegenDecorator : CoreCodegenDecorator = emptyList() + + /** + * Configuration methods that should be injected into the `${serviceName}Config` struct to allow users to configure + * pre-applied layers and plugins. + */ + fun configMethods(codegenContext: ServerCodegenContext): List = emptyList() } /** @@ -74,10 +81,11 @@ class CombinedServerCodegenDecorator(decorators: List) : decorator.postprocessValidationExceptionNotAttachedErrorMessage(accumulated) } - override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List { - return orderedDecorators.map { decorator -> decorator.postprocessGenerateAdditionalStructures(operationShape) } - .flatten() - } + override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List = + orderedDecorators.flatMap { it.postprocessGenerateAdditionalStructures(operationShape) } + + override fun configMethods(codegenContext: ServerCodegenContext): List = + orderedDecorators.flatMap { it.configMethods(codegenContext) } companion object { fun fromClasspath( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt index a1ea4b90f6..63f55954da 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output open class ServerRootGenerator( val protocol: ServerProtocol, private val codegenContext: ServerCodegenContext, + private val isConfigBuilderFallible: Boolean, ) { private val index = TopDownIndex.of(codegenContext.model) private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet( @@ -57,6 +58,8 @@ open class ServerRootGenerator( } .join("//!\n") + val unwrapConfigBuilder = if (isConfigBuilderFallible) ".expect(\"config failed to build\")" else "" + writer.rustTemplate( """ //! A fast and customizable Rust implementation of the $serviceName Smithy service. @@ -75,7 +78,10 @@ open class ServerRootGenerator( //! ## async fn dummy() { //! use $crateName::{$serviceName, ${serviceName}Config}; //! - //! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked(); + //! ## let app = $serviceName::builder( + //! ## ${serviceName}Config::builder() + //! ## .build()$unwrapConfigBuilder + //! ## ).build_unchecked(); //! let server = app.into_make_service(); //! let bind: SocketAddr = "127.0.0.1:6969".parse() //! .expect("unable to parse the server bind address and port"); @@ -92,7 +98,10 @@ open class ServerRootGenerator( //! use $crateName::$serviceName; //! //! ## async fn dummy() { - //! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked(); + //! ## let app = $serviceName::builder( + //! ## ${serviceName}Config::builder() + //! ## .build()$unwrapConfigBuilder + //! ## ).build_unchecked(); //! let handler = LambdaHandler::new(app); //! lambda_http::run(handler).await.unwrap(); //! ## } @@ -118,7 +127,7 @@ open class ServerRootGenerator( //! let http_plugins = HttpPlugins::new() //! .push(LoggingPlugin) //! .push(MetricsPlugin); - //! let config = ${serviceName}Config::builder().build(); + //! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; //! let builder: $builderName = $serviceName::builder(config); //! ``` //! @@ -183,13 +192,13 @@ open class ServerRootGenerator( //! //! ## Example //! - //! ```rust + //! ```rust,no_run //! ## use std::net::SocketAddr; //! use $crateName::{$serviceName, ${serviceName}Config}; //! //! ##[#{Tokio}::main] //! pub async fn main() { - //! let config = ${serviceName}Config::builder().build(); + //! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; //! let app = $serviceName::builder(config) ${builderFieldNames.values.joinToString("\n") { "//! .$it($it)" }} //! .build() @@ -236,6 +245,23 @@ open class ServerRootGenerator( fun render(rustWriter: RustWriter) { documentation(rustWriter) - rustWriter.rust("pub use crate::service::{$serviceName, ${serviceName}Config, ${serviceName}ConfigBuilder, ${serviceName}Builder, MissingOperationsError};") + // Only export config builder error if fallible. + val configErrorReExport = if (isConfigBuilderFallible) { + "${serviceName}ConfigError," + } else { + "" + } + rustWriter.rust( + """ + pub use crate::service::{ + $serviceName, + ${serviceName}Config, + ${serviceName}ConfigBuilder, + $configErrorReExport + ${serviceName}Builder, + MissingOperationsError + }; + """, + ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index 97963cbb40..e1da585d05 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -33,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output class ServerServiceGenerator( private val codegenContext: ServerCodegenContext, private val protocol: ServerProtocol, + private val isConfigBuilderFallible: Boolean, ) { private val runtimeConfig = codegenContext.runtimeConfig private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() @@ -107,6 +108,11 @@ class ServerServiceGenerator( val docHandler = DocHandlerGenerator(codegenContext, operationShape, "handler", "///") val handler = docHandler.docSignature() val handlerFixed = docHandler.docFixedSignature() + val unwrapConfigBuilder = if (isConfigBuilderFallible) { + ".expect(\"config failed to build\")" + } else { + "" + } rustTemplate( """ /// Sets the [`$structName`](crate::operation_shape::$structName) operation. @@ -123,7 +129,7 @@ class ServerServiceGenerator( /// #{Handler:W} /// - /// let config = ${serviceName}Config::builder().build(); + /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; /// let app = $serviceName::builder(config) /// .$fieldName(handler) /// /* Set other handlers */ @@ -186,7 +192,7 @@ class ServerServiceGenerator( /// #{HandlerFixed:W} /// - /// let config = ${serviceName}Config::builder().build(); + /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; /// let svc = #{Tower}::util::service_fn(handler); /// let app = $serviceName::builder(config) /// .${fieldName}_service(svc) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt index 960f9d0df7..525487968f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt @@ -6,31 +6,141 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +fun List.isBuilderFallible() = this.any { it.isRequired } + +/** + * Contains all data necessary to render a method on the config builder object to apply arbitrary layers, HTTP plugins, + * and model plugins. + */ +data class ConfigMethod( + /** The name of the method. **/ + val name: String, + /** The Rust docs for the method. **/ + val docs: String, + /** The parameters of the method. **/ + val params: List, + /** In case the method is fallible, the error type it returns. **/ + val errorType: RuntimeType?, + /** The code block inside the method. **/ + val initializer: Initializer, + /** Whether the user must invoke the method or not. **/ + val isRequired: Boolean, +) { + /** The name of the flag on the config builder object that tracks whether the _required_ method has already been invoked or not. **/ + fun requiredBuilderFlagName(): String { + check(isRequired) { + "Config method is not required so it shouldn't need a field in the builder tracking whether it has been configured" + } + return "${name}_configured" + } + + /** The name of the enum variant on the config builder's error struct for a _required_ method. **/ + fun requiredErrorVariant(): String { + check(isRequired) { + "Config method is not required so it shouldn't need an error variant" + } + return "${name.toPascalCase()}NotConfigured" + } +} + +/** + * Represents the code block inside the method that initializes and configures a set of layers, HTTP plugins, and/or model + * plugins. + */ +data class Initializer( + /** + * The code itself that initializes and configures the layers, HTTP plugins, and/or model plugins. This should be + * a set of [Rust statements] that, after execution, defines one variable binding per layer/HTTP plugin/model plugin + * that it has configured and wants to apply. The code may use the method's input arguments (see [params] in + * [ConfigMethod]) to perform checks and initialize the bindings. + * + * For example, the following code performs checks on the `authorizer` and `auth_spec` input arguments, returning + * an error (see [errorType] in [ConfigMethod]) in case these checks fail, and leaves two plugins defined in two + * variable bindings, `authn_plugin` and `authz_plugin`. + * + * ```rust + * if authorizer != 69 { + * return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); + * } + + * if auth_spec.len() != 69 { + * return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); + * } + * let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + * let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + * ``` + * + * [Rust statements]: https://doc.rust-lang.org/reference/statements.html + */ + val code: Writable, + /** Ordered list of layers that should be applied. Layers are executed in the order they appear in the list. **/ + val layerBindings: List, + /** Ordered list of HTTP plugins that should be applied. Http plugins are executed in the order they appear in the list. **/ + val httpPluginBindings: List, + /** Ordered list of model plugins that should be applied. Model plugins are executed in the order they appear in the list. **/ + val modelPluginBindings: List, +) + +/** + * Represents a variable binding. For example, the following Rust code: + * + * ```rust + * fn foo(bar: String) { + * let baz: u64 = 69; + * } + * + * has two variable bindings. The `bar` name is bound to a `String` variable and the `baz` name is bound to a + * `u64` variable. + * ``` + */ +data class Binding( + /** The name of the variable. */ + val name: String, + /** The type of the variable. */ + val ty: RuntimeType, +) + class ServiceConfigGenerator( codegenContext: ServerCodegenContext, + private val configMethods: List, ) { private val crateName = codegenContext.moduleUseName() - private val codegenScope = codegenContext.runtimeConfig.let { runtimeConfig -> - val smithyHttpServer = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() - arrayOf( - "Debug" to RuntimeType.Debug, - "SmithyHttpServer" to smithyHttpServer, - "PluginStack" to smithyHttpServer.resolve("plugin::PluginStack"), - "ModelMarker" to smithyHttpServer.resolve("plugin::ModelMarker"), - "HttpMarker" to smithyHttpServer.resolve("plugin::HttpMarker"), - "Tower" to RuntimeType.Tower, - "Stack" to RuntimeType.Tower.resolve("layer::util::Stack"), - ) - } + private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() + private val codegenScope = arrayOf( + *preludeScope, + "Debug" to RuntimeType.Debug, + "SmithyHttpServer" to smithyHttpServer, + "PluginStack" to smithyHttpServer.resolve("plugin::PluginStack"), + "ModelMarker" to smithyHttpServer.resolve("plugin::ModelMarker"), + "HttpMarker" to smithyHttpServer.resolve("plugin::HttpMarker"), + "Tower" to RuntimeType.Tower, + "Stack" to RuntimeType.Tower.resolve("layer::util::Stack"), + ) private val serviceName = codegenContext.serviceShape.id.name.toPascalCase() fun render(writer: RustWriter) { + val unwrapConfigBuilder = if (isBuilderFallible) { + """ + /// .expect("config failed to build"); + """ + } else { + ";" + } + writer.rustTemplate( """ /// Configuration for the [`$serviceName`]. This is the central place where to register and @@ -50,7 +160,7 @@ class ServiceConfigGenerator( /// .http_plugin(authentication_plugin) /// // ...and right after deserialization, model plugins. /// .model_plugin(authorization_plugin) - /// .build(); + /// .build()$unwrapConfigBuilder /// ``` /// /// See the [`plugin`] system for details. @@ -74,6 +184,7 @@ class ServiceConfigGenerator( layers: #{Tower}::layer::util::Identity::new(), http_plugins: #{SmithyHttpServer}::plugin::IdentityPlugin, model_plugins: #{SmithyHttpServer}::plugin::IdentityPlugin, + #{BuilderRequiredMethodFlagsInit:W} } } } @@ -84,15 +195,21 @@ class ServiceConfigGenerator( pub(crate) layers: L, pub(crate) http_plugins: H, pub(crate) model_plugins: M, + #{BuilderRequiredMethodFlagDefinitions:W} } + + #{BuilderRequiredMethodError:W} impl ${serviceName}ConfigBuilder { + #{InjectedMethods:W} + /// Add a [`#{Tower}::Layer`] to the service. pub fn layer(self, layer: NewLayer) -> ${serviceName}ConfigBuilder<#{Stack}, H, M> { ${serviceName}ConfigBuilder { layers: #{Stack}::new(layer, self.layers), http_plugins: self.http_plugins, model_plugins: self.model_plugins, + #{BuilderRequiredMethodFlagsMove1:W} } } @@ -109,6 +226,7 @@ class ServiceConfigGenerator( layers: self.layers, http_plugins: #{PluginStack}::new(http_plugin, self.http_plugins), model_plugins: self.model_plugins, + #{BuilderRequiredMethodFlagsMove2:W} } } @@ -125,20 +243,203 @@ class ServiceConfigGenerator( layers: self.layers, http_plugins: self.http_plugins, model_plugins: #{PluginStack}::new(model_plugin, self.model_plugins), + #{BuilderRequiredMethodFlagsMove3:W} } } + + #{BuilderBuildMethod:W} + } + """, + *codegenScope, + "BuilderRequiredMethodFlagsInit" to builderRequiredMethodFlagsInit(), + "BuilderRequiredMethodFlagDefinitions" to builderRequiredMethodFlagsDefinitions(), + "BuilderRequiredMethodError" to builderRequiredMethodError(), + "InjectedMethods" to injectedMethods(), + "BuilderRequiredMethodFlagsMove1" to builderRequiredMethodFlagsMove(), + "BuilderRequiredMethodFlagsMove2" to builderRequiredMethodFlagsMove(), + "BuilderRequiredMethodFlagsMove3" to builderRequiredMethodFlagsMove(), + "BuilderBuildMethod" to builderBuildMethod(), + ) + } + + private val isBuilderFallible = configMethods.isBuilderFallible() - /// Build the configuration. - pub fn build(self) -> super::${serviceName}Config { + private fun builderBuildRequiredMethodChecks() = configMethods.filter { it.isRequired }.map { + writable { + rustTemplate( + """ + if !self.${it.requiredBuilderFlagName()} { + return #{Err}(${serviceName}ConfigError::${it.requiredErrorVariant()}); + } + """, + *codegenScope, + ) + } + }.join("\n") + + private fun builderRequiredMethodFlagsDefinitions() = configMethods.filter { it.isRequired }.map { + writable { rust("pub(crate) ${it.requiredBuilderFlagName()}: bool,") } + }.join("\n") + + private fun builderRequiredMethodFlagsInit() = configMethods.filter { it.isRequired }.map { + writable { rust("${it.requiredBuilderFlagName()}: false,") } + }.join("\n") + + private fun builderRequiredMethodFlagsMove() = configMethods.filter { it.isRequired }.map { + writable { rust("${it.requiredBuilderFlagName()}: self.${it.requiredBuilderFlagName()},") } + }.join("\n") + + private fun builderRequiredMethodError() = writable { + if (isBuilderFallible) { + val variants = configMethods.filter { it.isRequired }.map { + writable { + rust( + """ + ##[error("service is not fully configured; invoke `${it.name}` on the config builder")] + ${it.requiredErrorVariant()}, + """, + ) + } + } + rustTemplate( + """ + ##[derive(Debug, #{ThisError}::Error)] + pub enum ${serviceName}ConfigError { + #{Variants:W} + } + """, + "ThisError" to ServerCargoDependency.ThisError.toType(), + "Variants" to variants.join("\n"), + ) + } + } + + private fun injectedMethods() = configMethods.map { + writable { + val paramBindings = it.params.map { binding -> + writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) } + }.join("\n") + + // This produces a nested type like: "S>", where + // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack + // and the second is the "outer" part of the stack. The outer part gets executed first. For an example, + // see `aws_smithy_http_server::plugin::PluginStack`. + // - "A", "B" are the types of the "things" that are added. + // - "T" is the generic type variable name used in the enclosing impl block. + fun List.stackReturnType(genericTypeVarName: String, stackType: RuntimeType): Writable = + this.fold(writable { rust(genericTypeVarName) }) { acc, next -> + writable { + rustTemplate( + "#{StackType}<#{Ty}, #{Acc:W}>", + "StackType" to stackType, + "Ty" to next.ty, + "Acc" to acc, + ) + } + } + + val layersReturnTy = + it.initializer.layerBindings.stackReturnType("L", RuntimeType.Tower.resolve("layer::util::Stack")) + val httpPluginsReturnTy = + it.initializer.httpPluginBindings.stackReturnType("H", smithyHttpServer.resolve("plugin::PluginStack")) + val modelPluginsReturnTy = + it.initializer.modelPluginBindings.stackReturnType("M", smithyHttpServer.resolve("plugin::PluginStack")) + + val configBuilderReturnTy = writable { + rustTemplate( + """ + ${serviceName}ConfigBuilder< + #{LayersReturnTy:W}, + #{HttpPluginsReturnTy:W}, + #{ModelPluginsReturnTy:W}, + > + """, + "LayersReturnTy" to layersReturnTy, + "HttpPluginsReturnTy" to httpPluginsReturnTy, + "ModelPluginsReturnTy" to modelPluginsReturnTy, + ) + } + + val returnTy = if (it.errorType != null) { + writable { + rustTemplate( + "#{Result}<#{T:W}, #{E}>", + "T" to configBuilderReturnTy, + "E" to it.errorType, + *codegenScope, + ) + } + } else { + configBuilderReturnTy + } + + docs(it.docs) + rustBlockTemplate( + """ + pub fn ${it.name}( + ##[allow(unused_mut)] + mut self, + #{ParamBindings:W} + ) -> #{ReturnTy:W} + """, + "ReturnTy" to returnTy, + "ParamBindings" to paramBindings, + ) { + rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) + + check(it.initializer.layerBindings.size + it.initializer.httpPluginBindings.size + it.initializer.modelPluginBindings.size > 0) { + "This method's initializer does not register any layers, HTTP plugins, or model plugins. It must register at least something!" + } + + if (it.isRequired) { + rust("self.${it.requiredBuilderFlagName()} = true;") + } + conditionalBlock("Ok(", ")", conditional = it.errorType != null) { + val registrations = ( + it.initializer.layerBindings.map { ".layer(${it.name})" } + + it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } + + it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" } + ).joinToString("") + rust("self$registrations") + } + } + } + }.join("\n\n") + + private fun builderBuildReturnType() = writable { + val t = "super::${serviceName}Config" + + if (isBuilderFallible) { + rustTemplate("#{Result}<$t, ${serviceName}ConfigError>", *codegenScope) + } else { + rust(t) + } + } + + private fun builderBuildMethod() = writable { + rustBlockTemplate( + """ + /// Build the configuration. + pub fn build(self) -> #{BuilderBuildReturnTy:W} + """, + "BuilderBuildReturnTy" to builderBuildReturnType(), + ) { + rustTemplate( + "#{BuilderBuildRequiredMethodChecks:W}", + "BuilderBuildRequiredMethodChecks" to builderBuildRequiredMethodChecks(), + ) + + conditionalBlock("Ok(", ")", isBuilderFallible) { + rust( + """ super::${serviceName}Config { layers: self.layers, http_plugins: self.http_plugins, model_plugins: self.model_plugins, } - } + """, + ) } - """, - *codegenScope, - ) + } } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt new file mode 100644 index 0000000000..c2c568b291 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt @@ -0,0 +1,231 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.io.File + +internal class ServiceConfigGeneratorTest { + @Test + fun `it should inject an aws_auth method that configures an HTTP plugin and a model plugin`() { + val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() + + val decorator = object : ServerCodegenDecorator { + override val name: String + get() = "AWSAuth pre-applied middleware decorator" + override val order: Byte + get() = -69 + + override fun configMethods(codegenContext: ServerCodegenContext): List { + val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() + val codegenScope = arrayOf( + "SmithyHttpServer" to smithyHttpServer, + ) + return listOf( + ConfigMethod( + name = "aws_auth", + docs = "Docs", + params = listOf( + Binding("auth_spec", RuntimeType.String), + Binding("authorizer", RuntimeType.U64), + ), + errorType = RuntimeType.std.resolve("io::Error"), + initializer = Initializer( + code = writable { + rustTemplate( + """ + if authorizer != 69 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); + } + + if auth_spec.len() != 69 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); + } + let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + """, + *codegenScope, + ) + }, + layerBindings = emptyList(), + httpPluginBindings = listOf( + Binding( + "authn_plugin", + smithyHttpServer.resolve("plugin::IdentityPlugin"), + ), + ), + modelPluginBindings = listOf( + Binding( + "authz_plugin", + smithyHttpServer.resolve("plugin::IdentityPlugin"), + ), + ), + ), + isRequired = true, + ), + ) + } + } + + serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, rustCrate -> + rustCrate.testModule { + rust( + """ + use crate::{SimpleServiceConfig, SimpleServiceConfigError}; + use aws_smithy_http_server::plugin::IdentityPlugin; + use crate::server::plugin::PluginStack; + """, + ) + + unitTest("successful_config_initialization") { + rust( + """ + let _: SimpleServiceConfig< + tower::layer::util::Identity, + // One HTTP plugin has been applied. + PluginStack, + // One model plugin has been applied. + PluginStack, + > = SimpleServiceConfig::builder() + .aws_auth("a".repeat(69).to_owned(), 69) + .expect("failed to configure aws_auth") + .build() + .unwrap(); + """, + ) + } + + unitTest("wrong_aws_auth_auth_spec") { + rust( + """ + let actual_err = SimpleServiceConfig::builder() + .aws_auth("a".to_owned(), 69) + .unwrap_err(); + let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 2").to_string(); + assert_eq!(actual_err.to_string(), expected); + """, + ) + } + + unitTest("wrong_aws_auth_authorizer") { + rust( + """ + let actual_err = SimpleServiceConfig::builder() + .aws_auth("a".repeat(69).to_owned(), 6969) + .unwrap_err(); + let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 1").to_string(); + assert_eq!(actual_err.to_string(), expected); + """, + ) + } + + unitTest("aws_auth_not_configured") { + rust( + """ + let actual_err = SimpleServiceConfig::builder().build().unwrap_err(); + let expected = SimpleServiceConfigError::AwsAuthNotConfigured.to_string(); + assert_eq!(actual_err.to_string(), expected); + """, + ) + } + } + } + } + + @Test + fun `it should inject an method that applies three non-required layers`() { + val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() + + val decorator = object : ServerCodegenDecorator { + override val name: String + get() = "ApplyThreeNonRequiredLayers" + override val order: Byte + get() = 69 + + override fun configMethods(codegenContext: ServerCodegenContext): List { + val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") + val codegenScope = arrayOf( + "Identity" to identityLayer, + ) + return listOf( + ConfigMethod( + name = "three_non_required_layers", + docs = "Docs", + params = emptyList(), + errorType = null, + initializer = Initializer( + code = writable { + rustTemplate( + """ + let layer1 = #{Identity}::new(); + let layer2 = #{Identity}::new(); + let layer3 = #{Identity}::new(); + """, + *codegenScope, + ) + }, + layerBindings = listOf( + Binding("layer1", identityLayer), + Binding("layer2", identityLayer), + Binding("layer3", identityLayer), + ), + httpPluginBindings = emptyList(), + modelPluginBindings = emptyList(), + ), + isRequired = false, + ), + ) + } + } + + serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, rustCrate -> + rustCrate.testModule { + unitTest("successful_config_initialization_applying_the_three_layers") { + rust( + """ + let _: crate::SimpleServiceConfig< + // Three Tower layers have been applied. + tower::layer::util::Stack< + tower::layer::util::Identity, + tower::layer::util::Stack< + tower::layer::util::Identity, + tower::layer::util::Stack< + tower::layer::util::Identity, + tower::layer::util::Identity, + >, + >, + >, + aws_smithy_http_server::plugin::IdentityPlugin, + aws_smithy_http_server::plugin::IdentityPlugin, + > = crate::SimpleServiceConfig::builder() + .three_non_required_layers() + .build(); + """, + ) + } + + unitTest("successful_config_initialization_without_applying_the_three_layers") { + rust( + """ + crate::SimpleServiceConfig::builder().build(); + """, + ) + } + } + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/plugin/stack.rs b/rust-runtime/aws-smithy-http-server/src/plugin/stack.rs index 6c96ebaca0..c42462ec52 100644 --- a/rust-runtime/aws-smithy-http-server/src/plugin/stack.rs +++ b/rust-runtime/aws-smithy-http-server/src/plugin/stack.rs @@ -4,6 +4,7 @@ */ use super::{HttpMarker, ModelMarker, Plugin}; +use std::fmt::Debug; /// A wrapper struct which composes an `Inner` and an `Outer` [`Plugin`]. /// @@ -13,6 +14,7 @@ use super::{HttpMarker, ModelMarker, Plugin}; /// [`HttpPlugins`](crate::plugin::HttpPlugins), and the primary tool for composing HTTP plugins is /// [`ModelPlugins`](crate::plugin::ModelPlugins); if you are an application writer, you should /// prefer composing plugins using these. +#[derive(Debug)] pub struct PluginStack { inner: Inner, outer: Outer,