diff --git a/.changelog/1b61ec1ce18c4cdfae74f8852ecbf877.json b/.changelog/1b61ec1ce18c4cdfae74f8852ecbf877.json new file mode 100644 index 00000000000..2c086adc345 --- /dev/null +++ b/.changelog/1b61ec1ce18c4cdfae74f8852ecbf877.json @@ -0,0 +1,9 @@ +{ + "id": "1b61ec1c-e18c-4cdf-ae74-f8852ecbf877", + "type": "bugfix", + "description": "The SDK client has been updated to utilize the `aws.IsCredentialsProvider` function for determining if `aws.AnonymousCredentials` has been configured for the `CredentialProvider`.", + "modules": [ + "service/eventbridge", + "service/s3" + ] +} \ No newline at end of file diff --git a/.changelog/869890a030aa4f8e8ddd5ef80b7a01df.json b/.changelog/869890a030aa4f8e8ddd5ef80b7a01df.json new file mode 100644 index 00000000000..5eca3364ece --- /dev/null +++ b/.changelog/869890a030aa4f8e8ddd5ef80b7a01df.json @@ -0,0 +1,8 @@ +{ + "id": "869890a0-30aa-4f8e-8ddd-5ef80b7a01df", + "type": "feature", + "description": "Adds `aws.IsCredentialsProvider` for inspecting `CredentialProvider` types when needing to determine if the underlying implementation type matches a target type. This resolves an issue where `CredentialsCache` could mask `AnonymousCredentials` providers, breaking downstream detection logic.", + "modules": [ + "." + ] +} diff --git a/aws/credential_cache.go b/aws/credential_cache.go index 9e9525231c5..781ac0ae2c0 100644 --- a/aws/credential_cache.go +++ b/aws/credential_cache.go @@ -178,6 +178,12 @@ func (p *CredentialsCache) Invalidate() { p.creds.Store((*Credentials)(nil)) } +// IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache +// matches the target provider type. +func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool { + return IsCredentialsProvider(p.provider, target) +} + // HandleFailRefreshCredentialsCacheStrategy is an interface for // CredentialsCache to allow CredentialsProvider how failed to refresh // credentials is handled. diff --git a/aws/credential_cache_test.go b/aws/credential_cache_test.go index 28f43d3871d..0c40fa6d48f 100644 --- a/aws/credential_cache_test.go +++ b/aws/credential_cache_test.go @@ -617,3 +617,46 @@ func (m mockAdjustExpiryBy) AdjustExpiresBy(creds Credentials, dur time.Duration } return m.creds, m.err } + +func TestCredentialsCache_IsCredentialsProvider(t *testing.T) { + tests := map[string]struct { + provider CredentialsProvider + target CredentialsProvider + want bool + }{ + "nil provider and target": { + provider: nil, + target: nil, + want: true, + }, + "matches value implementations": { + provider: NewCredentialsCache(AnonymousCredentials{}), + target: AnonymousCredentials{}, + want: true, + }, + "matches value and pointer implementations, wrapped pointer": { + provider: NewCredentialsCache(&AnonymousCredentials{}), + target: AnonymousCredentials{}, + want: true, + }, + "matches value and pointer implementations, pointer target": { + provider: NewCredentialsCache(AnonymousCredentials{}), + target: &AnonymousCredentials{}, + want: true, + }, + "does not match mismatched provider types": { + provider: NewCredentialsCache(AnonymousCredentials{}), + target: &stubCredentialsProvider{}, + want: false, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + if got := NewCredentialsCache(tt.provider).IsCredentialsProvider(tt.target); got != tt.want { + t.Errorf("IsCredentialsProvider() = %v, want %v", got, tt.want) + } + }) + } +} + +var _ isCredentialsProvider = (*CredentialsCache)(nil) diff --git a/aws/credentials.go b/aws/credentials.go index 24c8ce4a73f..714d4ad85cb 100644 --- a/aws/credentials.go +++ b/aws/credentials.go @@ -3,6 +3,7 @@ package aws import ( "context" "fmt" + "reflect" "time" "github.com/aws/aws-sdk-go-v2/internal/sdk" @@ -129,3 +130,41 @@ type CredentialsProviderFunc func(context.Context) (Credentials, error) func (fn CredentialsProviderFunc) Retrieve(ctx context.Context) (Credentials, error) { return fn(ctx) } + +type isCredentialsProvider interface { + IsCredentialsProvider(CredentialsProvider) bool +} + +// IsCredentialsProvider returns whether the target CredentialProvider is the same type as provider when comparing the +// implementation type. +// +// If provider has a method IsCredentialsProvider(CredentialsProvider) bool it will be responsible for validating +// whether target matches the credential provider type. +// +// When comparing the CredentialProvider implementations provider and target for equality, the following rules are used: +// +// If provider is of type T and target is of type V, true if type *T is the same as type *V, otherwise false +// If provider is of type *T and target is of type V, true if type *T is the same as type *V, otherwise false +// If provider is of type T and target is of type *V, true if type *T is the same as type *V, otherwise false +// If provider is of type *T and target is of type *V,true if type *T is the same as type *V, otherwise false +func IsCredentialsProvider(provider, target CredentialsProvider) bool { + if target == nil || provider == nil { + return provider == target + } + + if x, ok := provider.(isCredentialsProvider); ok { + return x.IsCredentialsProvider(target) + } + + targetType := reflect.TypeOf(target) + if targetType.Kind() != reflect.Ptr { + targetType = reflect.PtrTo(targetType) + } + + providerType := reflect.TypeOf(provider) + if providerType.Kind() != reflect.Ptr { + providerType = reflect.PtrTo(providerType) + } + + return targetType.AssignableTo(providerType) +} diff --git a/aws/credentials_test.go b/aws/credentials_test.go new file mode 100644 index 00000000000..0ce61f360c1 --- /dev/null +++ b/aws/credentials_test.go @@ -0,0 +1,83 @@ +package aws + +import ( + "context" + "testing" +) + +type anonymousNamedType AnonymousCredentials + +func (f anonymousNamedType) Retrieve(ctx context.Context) (Credentials, error) { + return AnonymousCredentials(f).Retrieve(ctx) +} + +func TestIsCredentialsProvider(t *testing.T) { + tests := map[string]struct { + provider CredentialsProvider + target CredentialsProvider + want bool + }{ + "same implementations": { + provider: AnonymousCredentials{}, + target: AnonymousCredentials{}, + want: true, + }, + "same implementations, pointer target": { + provider: AnonymousCredentials{}, + target: &AnonymousCredentials{}, + want: true, + }, + "same implementations, pointer provider": { + provider: &AnonymousCredentials{}, + target: AnonymousCredentials{}, + want: true, + }, + "same implementations, both pointers": { + provider: &AnonymousCredentials{}, + target: &AnonymousCredentials{}, + want: true, + }, + "different implementations, nil target": { + provider: AnonymousCredentials{}, + target: nil, + want: false, + }, + "different implementations, nil provider": { + provider: nil, + target: AnonymousCredentials{}, + want: false, + }, + "different implementations": { + provider: AnonymousCredentials{}, + target: &stubCredentialsProvider{}, + want: false, + }, + "nil provider and target": { + provider: nil, + target: nil, + want: true, + }, + "implements IsCredentialsProvider, match": { + provider: NewCredentialsCache(AnonymousCredentials{}), + target: AnonymousCredentials{}, + want: true, + }, + "implements IsCredentialsProvider, no match": { + provider: NewCredentialsCache(AnonymousCredentials{}), + target: &stubCredentialsProvider{}, + want: false, + }, + "named types aliasing underlying types": { + provider: AnonymousCredentials{}, + target: anonymousNamedType{}, + want: false, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + if got := IsCredentialsProvider(tt.provider, tt.target); got != tt.want { + t.Errorf("IsCredentialsProvider() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/aws/signer/v4/middleware.go b/aws/signer/v4/middleware.go index db8377ae509..749bda69eed 100644 --- a/aws/signer/v4/middleware.go +++ b/aws/signer/v4/middleware.go @@ -371,13 +371,8 @@ func haveCredentialProvider(p aws.CredentialsProvider) bool { if p == nil { return false } - switch p.(type) { - case aws.AnonymousCredentials, - *aws.AnonymousCredentials: - return false - } - return true + return !aws.IsCredentialsProvider(p, (*aws.AnonymousCredentials)(nil)) } type payloadHashKey struct{} diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsSignatureVersion4aUtils.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsSignatureVersion4aUtils.java index bc4fe28db3b..8f3fa0bf35b 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsSignatureVersion4aUtils.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsSignatureVersion4aUtils.java @@ -42,6 +42,8 @@ public static void writeCredentialProviderResolver(GoWriter writer) { AwsCustomGoDependency.INTERNAL_SIGV4A).build()); writer.putContext("anonType", SymbolUtils.createPointableSymbolBuilder("AnonymousCredentials", AwsCustomGoDependency.AWS_CORE).build()); + writer.putContext("isProvider", SymbolUtils.createValueSymbolBuilder("IsCredentialsProvider", + AwsCustomGoDependency.AWS_CORE).build()); writer.putContext("adapType", SymbolUtils.createPointableSymbolBuilder("SymmetricCredentialAdaptor", AwsCustomGoDependency.INTERNAL_SIGV4A).build()); writer.write(""" @@ -54,9 +56,8 @@ public static void writeCredentialProviderResolver(GoWriter writer) { return } - switch o.$fieldName:L.(type) { - case $anonType:T, $anonType:P: - return + if $isProvider:T(o.$fieldName:L, ($anonType:P)(nil)) { + return } o.$fieldName:L = &$adapType:T{SymmetricProvider: o.$fieldName:L} diff --git a/service/eventbridge/api_client.go b/service/eventbridge/api_client.go index e24a63307ca..a48190c4066 100644 --- a/service/eventbridge/api_client.go +++ b/service/eventbridge/api_client.go @@ -405,8 +405,7 @@ func resolveCredentialProvider(o *Options) { return } - switch o.Credentials.(type) { - case aws.AnonymousCredentials, *aws.AnonymousCredentials: + if aws.IsCredentialsProvider(o.Credentials, (*aws.AnonymousCredentials)(nil)) { return } diff --git a/service/s3/api_client.go b/service/s3/api_client.go index 790cc43807b..f228819fc28 100644 --- a/service/s3/api_client.go +++ b/service/s3/api_client.go @@ -487,8 +487,7 @@ func resolveCredentialProvider(o *Options) { return } - switch o.Credentials.(type) { - case aws.AnonymousCredentials, *aws.AnonymousCredentials: + if aws.IsCredentialsProvider(o.Credentials, (*aws.AnonymousCredentials)(nil)) { return }