Skip to content

Commit

Permalink
Connect aws-config token providers to service config via codegen (#3443)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdisanti authored Mar 1, 2024
2 parents 414c137 + 3d924d6 commit b49b88d
Show file tree
Hide file tree
Showing 19 changed files with 6,301 additions and 279 deletions.
9 changes: 8 additions & 1 deletion aws/rust-runtime/aws-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,14 @@ mod loader {

/// Set test credentials for use when signing requests
pub fn test_credentials(self) -> Self {
self.credentials_provider(Credentials::for_tests())
#[allow(unused_mut)]
let mut ret = self.credentials_provider(Credentials::for_tests());
#[cfg(all(feature = "sso", feature = "test-util"))]
{
use aws_smithy_runtime_api::client::identity::http::Token;
ret = ret.token_provider(Token::for_tests());
}
ret
}

/// Override the access token provider used to build [`SdkConfig`].
Expand Down
2 changes: 1 addition & 1 deletion aws/rust-runtime/aws-credential-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repository = "https://github.com/smithy-lang/smithy-rs"

[features]
hardcoded-credentials = []
test-util = []
test-util = ["aws-smithy-runtime-api/test-util"]

[dependencies]
aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async" }
Expand Down
1 change: 1 addition & 0 deletions aws/rust-runtime/aws-credential-types/external-types.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ allowed_external_types = [
"aws_smithy_async::rt::sleep::SharedAsyncSleep",
"aws_smithy_runtime_api::client::identity::ResolveIdentity",
"aws_smithy_runtime_api::client::identity::http::Token",
"aws_smithy_runtime_api::shared::FromUnshared",
"aws_smithy_types::config_bag::storable::Storable",
"aws_smithy_types::config_bag::storable::StoreReplace",
"aws_smithy_types::config_bag::storable::Storer",
Expand Down
18 changes: 18 additions & 0 deletions aws/rust-runtime/aws-credential-types/src/provider/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
//! token providers in the SDK config.
use crate::{provider::error::TokenError, provider::future, Token};
use aws_smithy_runtime_api::client::{
identity::{IdentityFuture, ResolveIdentity},
runtime_components::RuntimeComponents,
};
use aws_smithy_runtime_api::impl_shared_conversions;
use aws_smithy_types::config_bag::ConfigBag;
use std::sync::Arc;

/// Result type for token providers
Expand Down Expand Up @@ -71,3 +77,15 @@ impl ProvideToken for SharedTokenProvider {
self.0.provide_token()
}
}

impl ResolveIdentity for SharedTokenProvider {
fn resolve_identity<'a>(
&'a self,
_runtime_components: &'a RuntimeComponents,
_config_bag: &'a ConfigBag,
) -> IdentityFuture<'a> {
IdentityFuture::new(async move { Ok(self.provide_token().await?.into()) })
}
}

impl_shared_conversions!(convert SharedTokenProvider from ProvideToken using SharedTokenProvider::new);
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ val DECORATORS: List<ClientCodegenDecorator> =
InvocationIdDecorator(),
RetryInformationHeaderDecorator(),
RemoveDefaultsDecorator(),
TokenProvidersDecorator(),
),
// Service specific decorators
ApiGatewayDecorator().onlyApplyTo("com.amazonaws.apigateway#BackplaneControlService"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.configReexport
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.ConditionalDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.TestUtilFeature
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.supportedAuthSchemes
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.featureGateBlock
Expand All @@ -23,38 +23,40 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.adhocCustomization

class CredentialsProviderDecorator : ClientCodegenDecorator {
override val name: String = "CredentialsProvider"
override val order: Byte = 0
class CredentialsProviderDecorator : ConditionalDecorator(
predicate = { codegenContext, _ -> codegenContext?.usesSigAuth() ?: false },
delegateTo =
object : ClientCodegenDecorator {
override val name: String = "CredentialsProviderDecorator"
override val order: Byte = 0

override fun configCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ConfigCustomization>,
): List<ConfigCustomization> {
return baseCustomizations + CredentialProviderConfig(codegenContext)
}
override fun configCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ConfigCustomization>,
): List<ConfigCustomization> = baseCustomizations + CredentialProviderConfig(codegenContext)

override fun extraSections(codegenContext: ClientCodegenContext): List<AdHocCustomization> =
listOf(
adhocCustomization<SdkConfigSection.CopySdkConfigToClientConfig> { section ->
rust("${section.serviceConfigBuilder}.set_credentials_provider(${section.sdkConfig}.credentials_provider());")
},
)
override fun extraSections(codegenContext: ClientCodegenContext): List<AdHocCustomization> =
listOf(
adhocCustomization<SdkConfigSection.CopySdkConfigToClientConfig> { section ->
rust("${section.serviceConfigBuilder}.set_credentials_provider(${section.sdkConfig}.credentials_provider());")
},
)

override fun extras(
codegenContext: ClientCodegenContext,
rustCrate: RustCrate,
) {
rustCrate.mergeFeature(TestUtilFeature.copy(deps = listOf("aws-credential-types/test-util")))
override fun extras(
codegenContext: ClientCodegenContext,
rustCrate: RustCrate,
) {
rustCrate.mergeFeature(TestUtilFeature.copy(deps = listOf("aws-credential-types/test-util")))

rustCrate.withModule(ClientRustModule.config) {
rust(
"pub use #T::Credentials;",
AwsRuntimeType.awsCredentialTypes(codegenContext.runtimeConfig),
)
}
}
}
rustCrate.withModule(ClientRustModule.config) {
rust(
"pub use #T::Credentials;",
AwsRuntimeType.awsCredentialTypes(codegenContext.runtimeConfig),
)
}
}
},
)

/**
* Add a `.credentials_provider` field and builder to the `Config` for a given service
Expand Down Expand Up @@ -125,7 +127,7 @@ class CredentialProviderConfig(private val codegenContext: ClientCodegenContext)
""",
*codegenScope,
) {
if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
if (codegenContext.usesSigV4a()) {
featureGateBlock("sigv4a") {
rustTemplate(
"self.runtime_components.set_identity_resolver(#{SIGV4A_SCHEME_ID}, credentials_provider.clone());",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rulesengine.language.EndpointRuleSet
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthSchemeOption
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.supportedAuthSchemes
import software.amazon.smithy.rust.codegen.client.smithy.customize.ConditionalDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.AuthSchemeLister
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
Expand All @@ -37,69 +40,88 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isInputEventStream
import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf

class SigV4AuthDecorator : ClientCodegenDecorator {
override val name: String get() = "SigV4AuthDecorator"
override val order: Byte = 0
internal fun ClientCodegenContext.usesSigAuth(): Boolean =
ServiceIndex.of(model).getEffectiveAuthSchemes(serviceShape).containsKey(SigV4Trait.ID) ||
usesSigV4a()

private val sigv4a = "sigv4a"
/**
* SigV4a doesn't have a Smithy auth trait yet, so this is a hack to determine if a service supports it.
*
* In the future, Smithy's `ServiceIndex.getEffectiveAuthSchemes` should be used instead.
*/
internal fun ClientCodegenContext.usesSigV4a(): Boolean {
val endpointAuthSchemes =
serviceShape.getTrait<EndpointRuleSetTrait>()?.ruleSet?.let { EndpointRuleSet.fromNode(it) }
?.also { it.typeCheck() }?.let { AuthSchemeLister.authSchemesForRuleset(it) } ?: setOf()
return endpointAuthSchemes.contains("sigv4a")
}

private fun sigv4(runtimeConfig: RuntimeConfig) =
writable {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
rust("#T", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID"))
}
class SigV4AuthDecorator : ConditionalDecorator(
predicate = { codegenContext, _ -> codegenContext?.usesSigAuth() ?: false },
delegateTo =
object : ClientCodegenDecorator {
override val name: String get() = "SigV4AuthDecorator"
override val order: Byte = 0

private fun sigv4a(runtimeConfig: RuntimeConfig) =
writable {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
featureGateBlock(sigv4a) {
rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID"))
}
}
private val sigv4a = "sigv4a"

override fun authOptions(
codegenContext: ClientCodegenContext,
operationShape: OperationShape,
baseAuthSchemeOptions: List<AuthSchemeOption>,
): List<AuthSchemeOption> {
val supportsSigV4a =
codegenContext.serviceShape.supportedAuthSchemes().contains(sigv4a)
.thenSingletonListOf { sigv4a(codegenContext.runtimeConfig) }
return baseAuthSchemeOptions +
AuthSchemeOption.StaticAuthSchemeOption(
SigV4Trait.ID,
listOf(sigv4(codegenContext.runtimeConfig)) + supportsSigV4a,
)
}
private fun sigv4(runtimeConfig: RuntimeConfig) =
writable {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
rust("#T", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID"))
}

override fun serviceRuntimePluginCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ServiceRuntimePluginCustomization>,
): List<ServiceRuntimePluginCustomization> =
baseCustomizations + listOf(AuthServiceRuntimePluginCustomization(codegenContext))
private fun sigv4a(runtimeConfig: RuntimeConfig) =
writable {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
featureGateBlock(sigv4a) {
rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID"))
}
}

override fun operationCustomizations(
codegenContext: ClientCodegenContext,
operation: OperationShape,
baseCustomizations: List<OperationCustomization>,
): List<OperationCustomization> = baseCustomizations + AuthOperationCustomization(codegenContext)
override fun authOptions(
codegenContext: ClientCodegenContext,
operationShape: OperationShape,
baseAuthSchemeOptions: List<AuthSchemeOption>,
): List<AuthSchemeOption> {
val supportsSigV4a =
codegenContext.usesSigV4a().thenSingletonListOf { sigv4a(codegenContext.runtimeConfig) }
return baseAuthSchemeOptions +
AuthSchemeOption.StaticAuthSchemeOption(
SigV4Trait.ID,
listOf(sigv4(codegenContext.runtimeConfig)) + supportsSigV4a,
)
}

override fun configCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ConfigCustomization>,
): List<ConfigCustomization> =
baseCustomizations + SigV4SigningConfig(codegenContext.runtimeConfig, codegenContext.serviceShape.getTrait())
override fun serviceRuntimePluginCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ServiceRuntimePluginCustomization>,
): List<ServiceRuntimePluginCustomization> =
baseCustomizations + listOf(AuthServiceRuntimePluginCustomization(codegenContext))

override fun extras(
codegenContext: ClientCodegenContext,
rustCrate: RustCrate,
) {
if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
// Add optional feature for SigV4a support
rustCrate.mergeFeature(Feature("sigv4a", true, listOf("aws-runtime/sigv4a")))
}
}
}
override fun operationCustomizations(
codegenContext: ClientCodegenContext,
operation: OperationShape,
baseCustomizations: List<OperationCustomization>,
): List<OperationCustomization> = baseCustomizations + AuthOperationCustomization(codegenContext)

override fun configCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ConfigCustomization>,
): List<ConfigCustomization> =
baseCustomizations + SigV4SigningConfig(codegenContext.runtimeConfig, codegenContext.serviceShape.getTrait())

override fun extras(
codegenContext: ClientCodegenContext,
rustCrate: RustCrate,
) {
if (codegenContext.usesSigV4a()) {
// Add optional feature for SigV4a support
rustCrate.mergeFeature(Feature("sigv4a", true, listOf("aws-runtime/sigv4a")))
}
}
},
)

private class SigV4SigningConfig(
runtimeConfig: RuntimeConfig,
Expand Down Expand Up @@ -166,26 +188,30 @@ private class AuthServiceRuntimePluginCustomization(private val codegenContext:
arrayOf(
"SigV4AuthScheme" to awsRuntime.resolve("auth::sigv4::SigV4AuthScheme"),
"SigV4aAuthScheme" to awsRuntime.resolve("auth::sigv4a::SigV4aAuthScheme"),
"SharedAuthScheme" to RuntimeType.smithyRuntimeApiClient(runtimeConfig).resolve("client::auth::SharedAuthScheme"),
"SharedAuthScheme" to
RuntimeType.smithyRuntimeApiClient(runtimeConfig)
.resolve("client::auth::SharedAuthScheme"),
)
}

override fun section(section: ServiceRuntimePluginSection): Writable =
writable {
when (section) {
is ServiceRuntimePluginSection.RegisterRuntimeComponents -> {
val serviceHasEventStream = codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model)
val serviceHasEventStream =
codegenContext.serviceShape.hasEventStreamOperations(codegenContext.model)
if (serviceHasEventStream) {
// enable the aws-runtime `sign-eventstream` feature
addDependency(
AwsCargoDependency.awsRuntime(runtimeConfig).withFeature("event-stream").toType().toSymbol(),
AwsCargoDependency.awsRuntime(runtimeConfig).withFeature("event-stream").toType()
.toSymbol(),
)
}
section.registerAuthScheme(this) {
rustTemplate("#{SharedAuthScheme}::new(#{SigV4AuthScheme}::new())", *codegenScope)
}

if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
if (codegenContext.usesSigV4a()) {
featureGateBlock("sigv4a") {
section.registerAuthScheme(this) {
rustTemplate("#{SharedAuthScheme}::new(#{SigV4aAuthScheme}::new())", *codegenScope)
Expand Down
Loading

0 comments on commit b49b88d

Please sign in to comment.