Skip to content

Commit

Permalink
Implements IsCredentialsProvider for checking if a provider matches a…
Browse files Browse the repository at this point in the history
… target provider type. (#1890)
  • Loading branch information
skmcgrail authored Oct 21, 2022
1 parent 0fab39a commit 1c05fb6
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 13 deletions.
9 changes: 9 additions & 0 deletions .changelog/1b61ec1ce18c4cdfae74f8852ecbf877.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"id": "1b61ec1c-e18c-4cdf-ae74-f8852ecbf877",
"type": "bugfix",
"description": "The SDK client has been updated to utilize the `aws.IsCredentialsProvider` function for determining if `aws.AnonymousCredentials` has been configured for the `CredentialProvider`.",
"modules": [
"service/eventbridge",
"service/s3"
]
}
8 changes: 8 additions & 0 deletions .changelog/869890a030aa4f8e8ddd5ef80b7a01df.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "869890a0-30aa-4f8e-8ddd-5ef80b7a01df",
"type": "feature",
"description": "Adds `aws.IsCredentialsProvider` for inspecting `CredentialProvider` types when needing to determine if the underlying implementation type matches a target type. This resolves an issue where `CredentialsCache` could mask `AnonymousCredentials` providers, breaking downstream detection logic.",
"modules": [
"."
]
}
6 changes: 6 additions & 0 deletions aws/credential_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ func (p *CredentialsCache) Invalidate() {
p.creds.Store((*Credentials)(nil))
}

// IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache
// matches the target provider type.
func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool {
return IsCredentialsProvider(p.provider, target)
}

// HandleFailRefreshCredentialsCacheStrategy is an interface for
// CredentialsCache to allow CredentialsProvider how failed to refresh
// credentials is handled.
Expand Down
43 changes: 43 additions & 0 deletions aws/credential_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,3 +617,46 @@ func (m mockAdjustExpiryBy) AdjustExpiresBy(creds Credentials, dur time.Duration
}
return m.creds, m.err
}

func TestCredentialsCache_IsCredentialsProvider(t *testing.T) {
tests := map[string]struct {
provider CredentialsProvider
target CredentialsProvider
want bool
}{
"nil provider and target": {
provider: nil,
target: nil,
want: true,
},
"matches value implementations": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: AnonymousCredentials{},
want: true,
},
"matches value and pointer implementations, wrapped pointer": {
provider: NewCredentialsCache(&AnonymousCredentials{}),
target: AnonymousCredentials{},
want: true,
},
"matches value and pointer implementations, pointer target": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: &AnonymousCredentials{},
want: true,
},
"does not match mismatched provider types": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: &stubCredentialsProvider{},
want: false,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := NewCredentialsCache(tt.provider).IsCredentialsProvider(tt.target); got != tt.want {
t.Errorf("IsCredentialsProvider() = %v, want %v", got, tt.want)
}
})
}
}

var _ isCredentialsProvider = (*CredentialsCache)(nil)
39 changes: 39 additions & 0 deletions aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aws
import (
"context"
"fmt"
"reflect"
"time"

"github.com/aws/aws-sdk-go-v2/internal/sdk"
Expand Down Expand Up @@ -129,3 +130,41 @@ type CredentialsProviderFunc func(context.Context) (Credentials, error)
func (fn CredentialsProviderFunc) Retrieve(ctx context.Context) (Credentials, error) {
return fn(ctx)
}

type isCredentialsProvider interface {
IsCredentialsProvider(CredentialsProvider) bool
}

// IsCredentialsProvider returns whether the target CredentialProvider is the same type as provider when comparing the
// implementation type.
//
// If provider has a method IsCredentialsProvider(CredentialsProvider) bool it will be responsible for validating
// whether target matches the credential provider type.
//
// When comparing the CredentialProvider implementations provider and target for equality, the following rules are used:
//
// If provider is of type T and target is of type V, true if type *T is the same as type *V, otherwise false
// If provider is of type *T and target is of type V, true if type *T is the same as type *V, otherwise false
// If provider is of type T and target is of type *V, true if type *T is the same as type *V, otherwise false
// If provider is of type *T and target is of type *V,true if type *T is the same as type *V, otherwise false
func IsCredentialsProvider(provider, target CredentialsProvider) bool {
if target == nil || provider == nil {
return provider == target
}

if x, ok := provider.(isCredentialsProvider); ok {
return x.IsCredentialsProvider(target)
}

targetType := reflect.TypeOf(target)
if targetType.Kind() != reflect.Ptr {
targetType = reflect.PtrTo(targetType)
}

providerType := reflect.TypeOf(provider)
if providerType.Kind() != reflect.Ptr {
providerType = reflect.PtrTo(providerType)
}

return targetType.AssignableTo(providerType)
}
83 changes: 83 additions & 0 deletions aws/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package aws

import (
"context"
"testing"
)

type anonymousNamedType AnonymousCredentials

func (f anonymousNamedType) Retrieve(ctx context.Context) (Credentials, error) {
return AnonymousCredentials(f).Retrieve(ctx)
}

func TestIsCredentialsProvider(t *testing.T) {
tests := map[string]struct {
provider CredentialsProvider
target CredentialsProvider
want bool
}{
"same implementations": {
provider: AnonymousCredentials{},
target: AnonymousCredentials{},
want: true,
},
"same implementations, pointer target": {
provider: AnonymousCredentials{},
target: &AnonymousCredentials{},
want: true,
},
"same implementations, pointer provider": {
provider: &AnonymousCredentials{},
target: AnonymousCredentials{},
want: true,
},
"same implementations, both pointers": {
provider: &AnonymousCredentials{},
target: &AnonymousCredentials{},
want: true,
},
"different implementations, nil target": {
provider: AnonymousCredentials{},
target: nil,
want: false,
},
"different implementations, nil provider": {
provider: nil,
target: AnonymousCredentials{},
want: false,
},
"different implementations": {
provider: AnonymousCredentials{},
target: &stubCredentialsProvider{},
want: false,
},
"nil provider and target": {
provider: nil,
target: nil,
want: true,
},
"implements IsCredentialsProvider, match": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: AnonymousCredentials{},
want: true,
},
"implements IsCredentialsProvider, no match": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: &stubCredentialsProvider{},
want: false,
},
"named types aliasing underlying types": {
provider: AnonymousCredentials{},
target: anonymousNamedType{},
want: false,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := IsCredentialsProvider(tt.provider, tt.target); got != tt.want {
t.Errorf("IsCredentialsProvider() = %v, want %v", got, tt.want)
}
})
}
}
7 changes: 1 addition & 6 deletions aws/signer/v4/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,8 @@ func haveCredentialProvider(p aws.CredentialsProvider) bool {
if p == nil {
return false
}
switch p.(type) {
case aws.AnonymousCredentials,
*aws.AnonymousCredentials:
return false
}

return true
return !aws.IsCredentialsProvider(p, (*aws.AnonymousCredentials)(nil))
}

type payloadHashKey struct{}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public static void writeCredentialProviderResolver(GoWriter writer) {
AwsCustomGoDependency.INTERNAL_SIGV4A).build());
writer.putContext("anonType", SymbolUtils.createPointableSymbolBuilder("AnonymousCredentials",
AwsCustomGoDependency.AWS_CORE).build());
writer.putContext("isProvider", SymbolUtils.createValueSymbolBuilder("IsCredentialsProvider",
AwsCustomGoDependency.AWS_CORE).build());
writer.putContext("adapType", SymbolUtils.createPointableSymbolBuilder("SymmetricCredentialAdaptor",
AwsCustomGoDependency.INTERNAL_SIGV4A).build());
writer.write("""
Expand All @@ -54,9 +56,8 @@ public static void writeCredentialProviderResolver(GoWriter writer) {
return
}
switch o.$fieldName:L.(type) {
case $anonType:T, $anonType:P:
return
if $isProvider:T(o.$fieldName:L, ($anonType:P)(nil)) {
return
}
o.$fieldName:L = &$adapType:T{SymmetricProvider: o.$fieldName:L}
Expand Down
3 changes: 1 addition & 2 deletions service/eventbridge/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions service/s3/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1c05fb6

Please sign in to comment.