Skip to content

Commit

Permalink
Allow server decorators to inject methods on config (#3111)
Browse files Browse the repository at this point in the history
PR #3095 added a code-generated `${serviceName}Config` object on which
users can register layers and plugins. For example:

```rust
let config = PokemonServiceConfig::builder()
    .layer(layers)
    .http_plugin(authn_plugin)
    .model_plugin(authz_plugin)
    .build();
```

This PR makes it so that server decorators can inject methods on this
config builder object. These methods can apply arbitrary layers, HTTP
plugins, and/or model plugins. Moreover, the decorator can configure
whether invoking such method is required or not.

For example, a decorator can inject an `aws_auth` method that configures
some plugins using its input arguments. Missing invocation of this
method
will result in the config failing to build:

```rust
let _: SimpleServiceConfig<
    // No layers have been applied.
    tower::layer::util::Identity,
    // One HTTP plugin has been applied.
    PluginStack<IdentityPlugin, IdentityPlugin>,
    // One model plugin has been applied.
    PluginStack<IdentityPlugin, IdentityPlugin>,
> = SimpleServiceConfig::builder()
    // This method has been injected in the config builder!
    .aws_auth("a".repeat(69).to_owned(), 69)
    // The method will apply one HTTP plugin and one model plugin,
    // configuring them with the input arguments. Configuration can be
    // declared to be fallible, in which case we get a `Result` we unwrap
    // here.
    .expect("failed to configure aws_auth")
    .build()
    // Since `aws_auth` has been marked as required, if the user misses
    // invoking it, this would panic here.
    .unwrap();
```

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
david-perez authored and rcoh committed Nov 1, 2023
1 parent 20a89c2 commit 2a6af3b
Show file tree
Hide file tree
Showing 9 changed files with 614 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class TestWriterDelegator(
}

/**
* Generate a newtest module
* Generate a new test module
*
* This should only be used in test code—the generated module name will be something like `tests_123`
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ object ServerCargoDependency {
val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
val OnceCell: CargoDependency = CargoDependency("once_cell", CratesIo("1.13"))
val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
val ThisError: CargoDependency = CargoDependency("thiserror", CratesIo("1.0"))
val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev)
val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.Unconstraine
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isBuilderFallible
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator
Expand Down Expand Up @@ -591,11 +592,15 @@ open class ServerCodegenVisitor(
logger.info("[rust-server-codegen] Generating a service $shape")
val serverProtocol = protocolGeneratorFactory.protocol(codegenContext) as ServerProtocol

val configMethods = codegenDecorator.configMethods(codegenContext)
val isConfigBuilderFallible = configMethods.isBuilderFallible()

// Generate root.
rustCrate.lib {
ServerRootGenerator(
serverProtocol,
codegenContext,
isConfigBuilderFallible,
).render(this)
}

Expand All @@ -612,9 +617,10 @@ open class ServerCodegenVisitor(
ServerServiceGenerator(
codegenContext,
serverProtocol,
isConfigBuilderFallible,
).render(this)

ServiceConfigGenerator(codegenContext).render(this)
ServiceConfigGenerator(codegenContext, configMethods).render(this)

ScopeMacroGenerator(codegenContext).render(this)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import software.amazon.smithy.rust.codegen.server.smithy.ValidationResult
import software.amazon.smithy.rust.codegen.server.smithy.generators.ConfigMethod
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import java.util.logging.Logger
Expand All @@ -41,6 +42,12 @@ interface ServerCodegenDecorator : CoreCodegenDecorator<ServerCodegenContext, Se
* Therefore, ensure that all the structure shapes returned by this method are not in the service's closure.
*/
fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> = emptyList()

/**
* Configuration methods that should be injected into the `${serviceName}Config` struct to allow users to configure
* pre-applied layers and plugins.
*/
fun configMethods(codegenContext: ServerCodegenContext): List<ConfigMethod> = emptyList()
}

/**
Expand Down Expand Up @@ -74,10 +81,11 @@ class CombinedServerCodegenDecorator(decorators: List<ServerCodegenDecorator>) :
decorator.postprocessValidationExceptionNotAttachedErrorMessage(accumulated)
}

override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> {
return orderedDecorators.map { decorator -> decorator.postprocessGenerateAdditionalStructures(operationShape) }
.flatten()
}
override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> =
orderedDecorators.flatMap { it.postprocessGenerateAdditionalStructures(operationShape) }

override fun configMethods(codegenContext: ServerCodegenContext): List<ConfigMethod> =
orderedDecorators.flatMap { it.configMethods(codegenContext) }

companion object {
fun fromClasspath(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output
open class ServerRootGenerator(
val protocol: ServerProtocol,
private val codegenContext: ServerCodegenContext,
private val isConfigBuilderFallible: Boolean,
) {
private val index = TopDownIndex.of(codegenContext.model)
private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(
Expand All @@ -57,6 +58,8 @@ open class ServerRootGenerator(
}
.join("//!\n")

val unwrapConfigBuilder = if (isConfigBuilderFallible) ".expect(\"config failed to build\")" else ""

writer.rustTemplate(
"""
//! A fast and customizable Rust implementation of the $serviceName Smithy service.
Expand All @@ -75,7 +78,10 @@ open class ServerRootGenerator(
//! ## async fn dummy() {
//! use $crateName::{$serviceName, ${serviceName}Config};
//!
//! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked();
//! ## let app = $serviceName::builder(
//! ## ${serviceName}Config::builder()
//! ## .build()$unwrapConfigBuilder
//! ## ).build_unchecked();
//! let server = app.into_make_service();
//! let bind: SocketAddr = "127.0.0.1:6969".parse()
//! .expect("unable to parse the server bind address and port");
Expand All @@ -92,7 +98,10 @@ open class ServerRootGenerator(
//! use $crateName::$serviceName;
//!
//! ## async fn dummy() {
//! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked();
//! ## let app = $serviceName::builder(
//! ## ${serviceName}Config::builder()
//! ## .build()$unwrapConfigBuilder
//! ## ).build_unchecked();
//! let handler = LambdaHandler::new(app);
//! lambda_http::run(handler).await.unwrap();
//! ## }
Expand All @@ -118,7 +127,7 @@ open class ServerRootGenerator(
//! let http_plugins = HttpPlugins::new()
//! .push(LoggingPlugin)
//! .push(MetricsPlugin);
//! let config = ${serviceName}Config::builder().build();
//! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
//! let builder: $builderName<Body, _, _, _> = $serviceName::builder(config);
//! ```
//!
Expand Down Expand Up @@ -183,13 +192,13 @@ open class ServerRootGenerator(
//!
//! ## Example
//!
//! ```rust
//! ```rust,no_run
//! ## use std::net::SocketAddr;
//! use $crateName::{$serviceName, ${serviceName}Config};
//!
//! ##[#{Tokio}::main]
//! pub async fn main() {
//! let config = ${serviceName}Config::builder().build();
//! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
//! let app = $serviceName::builder(config)
${builderFieldNames.values.joinToString("\n") { "//! .$it($it)" }}
//! .build()
Expand Down Expand Up @@ -236,6 +245,23 @@ open class ServerRootGenerator(
fun render(rustWriter: RustWriter) {
documentation(rustWriter)

rustWriter.rust("pub use crate::service::{$serviceName, ${serviceName}Config, ${serviceName}ConfigBuilder, ${serviceName}Builder, MissingOperationsError};")
// Only export config builder error if fallible.
val configErrorReExport = if (isConfigBuilderFallible) {
"${serviceName}ConfigError,"
} else {
""
}
rustWriter.rust(
"""
pub use crate::service::{
$serviceName,
${serviceName}Config,
${serviceName}ConfigBuilder,
$configErrorReExport
${serviceName}Builder,
MissingOperationsError
};
""",
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output
class ServerServiceGenerator(
private val codegenContext: ServerCodegenContext,
private val protocol: ServerProtocol,
private val isConfigBuilderFallible: Boolean,
) {
private val runtimeConfig = codegenContext.runtimeConfig
private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType()
Expand Down Expand Up @@ -107,6 +108,11 @@ class ServerServiceGenerator(
val docHandler = DocHandlerGenerator(codegenContext, operationShape, "handler", "///")
val handler = docHandler.docSignature()
val handlerFixed = docHandler.docFixedSignature()
val unwrapConfigBuilder = if (isConfigBuilderFallible) {
".expect(\"config failed to build\")"
} else {
""
}
rustTemplate(
"""
/// Sets the [`$structName`](crate::operation_shape::$structName) operation.
Expand All @@ -123,7 +129,7 @@ class ServerServiceGenerator(
///
#{Handler:W}
///
/// let config = ${serviceName}Config::builder().build();
/// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
/// let app = $serviceName::builder(config)
/// .$fieldName(handler)
/// /* Set other handlers */
Expand Down Expand Up @@ -186,7 +192,7 @@ class ServerServiceGenerator(
///
#{HandlerFixed:W}
///
/// let config = ${serviceName}Config::builder().build();
/// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
/// let svc = #{Tower}::util::service_fn(handler);
/// let app = $serviceName::builder(config)
/// .${fieldName}_service(svc)
Expand Down
Loading

0 comments on commit 2a6af3b

Please sign in to comment.