-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #125 from hashicorp/assume-role-timeout
Add credentials adapter from v2 to v1
- Loading branch information
Showing
4 changed files
with
326 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package awsv1shim | ||
|
||
import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim | ||
"context" | ||
"fmt" | ||
"sync/atomic" | ||
"time" | ||
|
||
awsv2 "github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
) | ||
|
||
type v2CredentialsProvider struct { | ||
provider awsv2.CredentialsProvider | ||
|
||
v2creds atomic.Value | ||
} | ||
|
||
// This adapter deals with multiple levels of caching and a slight mismatch between the AWS SDK for Go v1 and v2 credentials models. | ||
// In the SDK v1 model has a root `credentials.Credentials` struct that handles caching. The `credentials.Value` contains only keys. | ||
// The `credentials.Credentials` struct handles expiry information by calling the credentials provider. | ||
// In the SDK v2 model, the SDK returns an `aws.CredentialsCache` which handles caching. The `aws.Credentials` value contains keys | ||
// as well as the expiry information. | ||
// | ||
// The `v2CredentialsProvider` will typically be used with the following layout: | ||
// (v1)`credentials.Credentials` ==> `v2CredentialsProvider` ==> (v2)`aws.CredentialsCache` ==> (v2)<actual credentials provider> | ||
// | ||
// Since the SDK v1 `credentials.Credentials` handles expiry, it has an `Expire` function to explicitly expire credentials. This is | ||
// used, for example, in the SDK v1 default retry handler to catch an expired credentials error. Because of this, the result of | ||
// `RetrieveWithContext` cannot be cached in `v2CredentialsProvider`. | ||
// NOTE: Since the `Expire()` call is not passed up the chain, the (v2)`aws.CredentialsCache` will not have its cache cleared. This | ||
// may cause problems if a credential is revoked early. If this becomes a problem, every call to `RetrieveWithContext` may need to | ||
// call `Invalidate()` on the (v2)`aws.CredentialsCache`. In practice, `RetrieveWithContext` is rarely called, so this is not likely | ||
// to have a significant impact. | ||
// | ||
// The expiry information is cached in `v2CredentialsProvider` because the SDK v1 model handles expiry separately from the credential | ||
// information, and otherwise calling `IsExpired()` and `ExpiresAt()` would potentially call the actual credential provider on each call. | ||
|
||
func (p *v2CredentialsProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { | ||
v2creds, err := p.provider.Retrieve(ctx) | ||
if err != nil { | ||
return credentials.Value{}, err | ||
} | ||
p.v2creds.Store(&v2creds) | ||
|
||
return credentials.Value{ | ||
AccessKeyID: v2creds.AccessKeyID, | ||
SecretAccessKey: v2creds.SecretAccessKey, | ||
SessionToken: v2creds.SessionToken, | ||
ProviderName: fmt.Sprintf("v2Credentials(%s)", v2creds.Source), | ||
}, nil | ||
} | ||
|
||
func (p *v2CredentialsProvider) IsExpired() bool { | ||
v2creds := p.credentials() | ||
if v2creds != nil { | ||
return v2creds.Expired() | ||
} | ||
return true | ||
} | ||
|
||
func (p *v2CredentialsProvider) ExpiresAt() time.Time { | ||
v2creds := p.credentials() | ||
if v2creds != nil { | ||
return v2creds.Expires | ||
} | ||
return time.Time{} | ||
} | ||
|
||
func (p *v2CredentialsProvider) Retrieve() (credentials.Value, error) { | ||
return p.RetrieveWithContext(context.Background()) | ||
} | ||
|
||
func (p *v2CredentialsProvider) credentials() *awsv2.Credentials { | ||
v := p.v2creds.Load() | ||
if v == nil { | ||
return nil | ||
} | ||
|
||
c := v.(*awsv2.Credentials) | ||
if c != nil && c.HasKeys() && !c.Expired() { | ||
return c | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func newV2Credentials(v2provider awsv2.CredentialsProvider) *credentials.Credentials { | ||
return credentials.NewCredentials(&v2CredentialsProvider{ | ||
provider: v2provider, | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
package awsv1shim | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
awsv2 "github.com/aws/aws-sdk-go-v2/aws" | ||
credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" | ||
stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds" | ||
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts" | ||
ststypesv2 "github.com/aws/aws-sdk-go-v2/service/sts/types" | ||
"github.com/aws/aws-sdk-go/aws" | ||
) | ||
|
||
func TestV2CredentialsProviderPassthrough(t *testing.T) { | ||
v2creds := credentialsv2.NewStaticCredentialsProvider("key", "secret", "session") | ||
|
||
creds := newV2Credentials(v2creds) | ||
|
||
value, err := creds.GetWithContext(context.Background()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
|
||
if a, e := value.AccessKeyID, "key"; a != e { | ||
t.Errorf("AccessKeyID: expected %q, got %q", e, a) | ||
} | ||
if a, e := value.SecretAccessKey, "secret"; a != e { | ||
t.Errorf("SecretAccessKey: expected %q, got %q", e, a) | ||
} | ||
if a, e := value.SessionToken, "session"; a != e { | ||
t.Errorf("SecretAccessKey: expected %q, got %q", e, a) | ||
} | ||
if a, e := value.ProviderName, fmt.Sprintf("v2Credentials(%s)", credentialsv2.StaticCredentialsName); a != e { | ||
t.Errorf("ProviderName: expected %q, got %q", e, a) | ||
} | ||
} | ||
|
||
func TestV2CredentialsProviderExpriry(t *testing.T) { | ||
testcases := map[string]struct { | ||
v2creds awsv2.CredentialsProvider | ||
}{ | ||
credentialsv2.StaticCredentialsName: { | ||
v2creds: credentialsv2.NewStaticCredentialsProvider("key", "secret", "session"), | ||
}, | ||
} | ||
|
||
for name, testcase := range testcases { | ||
t.Run(name, func(t *testing.T) { | ||
creds := newV2Credentials(testcase.v2creds) | ||
// Credentials need to be retrieved before we can check | ||
_, err := creds.GetWithContext(context.Background()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
if creds.IsExpired() { | ||
t.Fatalf("did not expect creds to be expired") | ||
} | ||
expiry, err := creds.ExpiresAt() | ||
if err != nil { | ||
t.Fatalf("unexpected error getting expiry: %s", err) | ||
} | ||
if !expiry.Equal(time.Time{}) { | ||
t.Fatalf("expected no expiry time, got %s", expiry) | ||
} | ||
|
||
creds.Expire() | ||
if !creds.IsExpired() { | ||
t.Fatalf("expected creds to be expired") | ||
} | ||
|
||
value, err := creds.GetWithContext(context.Background()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
if value.AccessKeyID == "" { | ||
t.Error("AccessKeyID: expected a value") | ||
} | ||
if value.SecretAccessKey == "" { | ||
t.Error("SecretAccessKey: expected a value") | ||
} | ||
if value.SessionToken == "" { | ||
t.Error("SessionToken: expected a value") | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestV2CredentialsProviderExpriry_AssumeRole(t *testing.T) { | ||
stsClient := &mockAssumeRole{} | ||
v2creds := stscredsv2.NewAssumeRoleProvider(stsClient, "role") | ||
|
||
creds := newV2Credentials(v2creds) | ||
// Credentials need to be retrieved before we can check expiry information | ||
_, err := creds.GetWithContext(context.Background()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
if creds.IsExpired() { | ||
t.Fatalf("did not expect creds to be expired") | ||
} | ||
expiry, err := creds.ExpiresAt() | ||
if err != nil { | ||
t.Fatalf("unexpected error getting expiry: %s", err) | ||
} | ||
if expiry.Equal(time.Time{}) { | ||
t.Fatal("expected expiry time, got none") | ||
} | ||
|
||
creds.Expire() | ||
if !creds.IsExpired() { | ||
t.Fatalf("expected creds to be expired") | ||
} | ||
|
||
value, err := creds.GetWithContext(context.Background()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
if value.AccessKeyID == "" { | ||
t.Error("AccessKeyID: expected a value") | ||
} | ||
if value.SecretAccessKey == "" { | ||
t.Error("SecretAccessKey: expected a value") | ||
} | ||
if value.SessionToken == "" { | ||
t.Error("SessionToken: expected a value") | ||
} | ||
} | ||
|
||
func TestV2CredentialsProviderCaching(t *testing.T) { | ||
stsClientCalls := 0 | ||
expectedStsClientCalls := 0 | ||
stsClient := &mockAssumeRole{ | ||
TestInput: func(in *stsv2.AssumeRoleInput) { | ||
stsClientCalls++ | ||
}, | ||
} | ||
v2creds := stscredsv2.NewAssumeRoleProvider(stsClient, "role") | ||
creds := newV2Credentials(v2creds) | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("did not expect call to STS client") | ||
expectedStsClientCalls = stsClientCalls | ||
} | ||
|
||
_, err := creds.GetWithContext(context.Background()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
expectedStsClientCalls++ | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("expected call to STS client") | ||
expectedStsClientCalls = stsClientCalls | ||
} | ||
|
||
_, err = creds.GetWithContext(context.Background()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("did not expect call to STS client") | ||
expectedStsClientCalls = stsClientCalls | ||
} | ||
|
||
creds.IsExpired() | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("did not expect call to STS client") | ||
expectedStsClientCalls = stsClientCalls | ||
} | ||
|
||
_, err = creds.ExpiresAt() | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("did not expect call to STS client") | ||
expectedStsClientCalls = stsClientCalls | ||
} | ||
|
||
creds.Expire() | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("did not expect call to STS client") | ||
expectedStsClientCalls = stsClientCalls | ||
} | ||
|
||
creds.IsExpired() | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("did not expect call to STS client") | ||
expectedStsClientCalls = stsClientCalls | ||
} | ||
|
||
_, err = creds.ExpiresAt() | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("did not expect call to STS client") | ||
expectedStsClientCalls = stsClientCalls | ||
} | ||
|
||
_, err = creds.GetWithContext(context.Background()) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
expectedStsClientCalls++ | ||
if stsClientCalls != expectedStsClientCalls { | ||
t.Errorf("expected call to STS client") | ||
} | ||
} | ||
|
||
type mockAssumeRole struct { | ||
TestInput func(*stsv2.AssumeRoleInput) | ||
} | ||
|
||
func (s *mockAssumeRole) AssumeRole(ctx context.Context, params *stsv2.AssumeRoleInput, optFns ...func(*stsv2.Options)) (*stsv2.AssumeRoleOutput, error) { | ||
if s.TestInput != nil { | ||
s.TestInput(params) | ||
} | ||
expiry := time.Now().Add(60 * time.Minute) | ||
|
||
return &stsv2.AssumeRoleOutput{ | ||
Credentials: &ststypesv2.Credentials{ | ||
// Just reflect the role arn to the provider. | ||
AccessKeyId: params.RoleArn, | ||
SecretAccessKey: aws.String("assumedSecretAccessKey"), | ||
SessionToken: aws.String("assumedSessionToken"), | ||
Expiration: &expiry, | ||
}, | ||
}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters