Skip to content

Commit

Permalink
Merge pull request #125 from hashicorp/assume-role-timeout
Browse files Browse the repository at this point in the history
Add credentials adapter from v2 to v1
  • Loading branch information
gdavison authored Feb 18, 2022
2 parents c32dd32 + 0867894 commit 5496790
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 11 deletions.
92 changes: 92 additions & 0 deletions v2/awsv1shim/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package awsv1shim

import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim
"context"
"fmt"
"sync/atomic"
"time"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
)

type v2CredentialsProvider struct {
provider awsv2.CredentialsProvider

v2creds atomic.Value
}

// This adapter deals with multiple levels of caching and a slight mismatch between the AWS SDK for Go v1 and v2 credentials models.
// In the SDK v1 model has a root `credentials.Credentials` struct that handles caching. The `credentials.Value` contains only keys.
// The `credentials.Credentials` struct handles expiry information by calling the credentials provider.
// In the SDK v2 model, the SDK returns an `aws.CredentialsCache` which handles caching. The `aws.Credentials` value contains keys
// as well as the expiry information.
//
// The `v2CredentialsProvider` will typically be used with the following layout:
// (v1)`credentials.Credentials` ==> `v2CredentialsProvider` ==> (v2)`aws.CredentialsCache` ==> (v2)<actual credentials provider>
//
// Since the SDK v1 `credentials.Credentials` handles expiry, it has an `Expire` function to explicitly expire credentials. This is
// used, for example, in the SDK v1 default retry handler to catch an expired credentials error. Because of this, the result of
// `RetrieveWithContext` cannot be cached in `v2CredentialsProvider`.
// NOTE: Since the `Expire()` call is not passed up the chain, the (v2)`aws.CredentialsCache` will not have its cache cleared. This
// may cause problems if a credential is revoked early. If this becomes a problem, every call to `RetrieveWithContext` may need to
// call `Invalidate()` on the (v2)`aws.CredentialsCache`. In practice, `RetrieveWithContext` is rarely called, so this is not likely
// to have a significant impact.
//
// The expiry information is cached in `v2CredentialsProvider` because the SDK v1 model handles expiry separately from the credential
// information, and otherwise calling `IsExpired()` and `ExpiresAt()` would potentially call the actual credential provider on each call.

func (p *v2CredentialsProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
v2creds, err := p.provider.Retrieve(ctx)
if err != nil {
return credentials.Value{}, err
}
p.v2creds.Store(&v2creds)

return credentials.Value{
AccessKeyID: v2creds.AccessKeyID,
SecretAccessKey: v2creds.SecretAccessKey,
SessionToken: v2creds.SessionToken,
ProviderName: fmt.Sprintf("v2Credentials(%s)", v2creds.Source),
}, nil
}

func (p *v2CredentialsProvider) IsExpired() bool {
v2creds := p.credentials()
if v2creds != nil {
return v2creds.Expired()
}
return true
}

func (p *v2CredentialsProvider) ExpiresAt() time.Time {
v2creds := p.credentials()
if v2creds != nil {
return v2creds.Expires
}
return time.Time{}
}

func (p *v2CredentialsProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(context.Background())
}

func (p *v2CredentialsProvider) credentials() *awsv2.Credentials {
v := p.v2creds.Load()
if v == nil {
return nil
}

c := v.(*awsv2.Credentials)
if c != nil && c.HasKeys() && !c.Expired() {
return c
}

return nil
}

func newV2Credentials(v2provider awsv2.CredentialsProvider) *credentials.Credentials {
return credentials.NewCredentials(&v2CredentialsProvider{
provider: v2provider,
})
}
231 changes: 231 additions & 0 deletions v2/awsv1shim/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
package awsv1shim

import (
"context"
"fmt"
"testing"
"time"

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
stscredsv2 "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
ststypesv2 "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/aws/aws-sdk-go/aws"
)

func TestV2CredentialsProviderPassthrough(t *testing.T) {
v2creds := credentialsv2.NewStaticCredentialsProvider("key", "secret", "session")

creds := newV2Credentials(v2creds)

value, err := creds.GetWithContext(context.Background())
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if a, e := value.AccessKeyID, "key"; a != e {
t.Errorf("AccessKeyID: expected %q, got %q", e, a)
}
if a, e := value.SecretAccessKey, "secret"; a != e {
t.Errorf("SecretAccessKey: expected %q, got %q", e, a)
}
if a, e := value.SessionToken, "session"; a != e {
t.Errorf("SecretAccessKey: expected %q, got %q", e, a)
}
if a, e := value.ProviderName, fmt.Sprintf("v2Credentials(%s)", credentialsv2.StaticCredentialsName); a != e {
t.Errorf("ProviderName: expected %q, got %q", e, a)
}
}

func TestV2CredentialsProviderExpriry(t *testing.T) {
testcases := map[string]struct {
v2creds awsv2.CredentialsProvider
}{
credentialsv2.StaticCredentialsName: {
v2creds: credentialsv2.NewStaticCredentialsProvider("key", "secret", "session"),
},
}

for name, testcase := range testcases {
t.Run(name, func(t *testing.T) {
creds := newV2Credentials(testcase.v2creds)
// Credentials need to be retrieved before we can check
_, err := creds.GetWithContext(context.Background())
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if creds.IsExpired() {
t.Fatalf("did not expect creds to be expired")
}
expiry, err := creds.ExpiresAt()
if err != nil {
t.Fatalf("unexpected error getting expiry: %s", err)
}
if !expiry.Equal(time.Time{}) {
t.Fatalf("expected no expiry time, got %s", expiry)
}

creds.Expire()
if !creds.IsExpired() {
t.Fatalf("expected creds to be expired")
}

value, err := creds.GetWithContext(context.Background())
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if value.AccessKeyID == "" {
t.Error("AccessKeyID: expected a value")
}
if value.SecretAccessKey == "" {
t.Error("SecretAccessKey: expected a value")
}
if value.SessionToken == "" {
t.Error("SessionToken: expected a value")
}
})
}
}

func TestV2CredentialsProviderExpriry_AssumeRole(t *testing.T) {
stsClient := &mockAssumeRole{}
v2creds := stscredsv2.NewAssumeRoleProvider(stsClient, "role")

creds := newV2Credentials(v2creds)
// Credentials need to be retrieved before we can check expiry information
_, err := creds.GetWithContext(context.Background())
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if creds.IsExpired() {
t.Fatalf("did not expect creds to be expired")
}
expiry, err := creds.ExpiresAt()
if err != nil {
t.Fatalf("unexpected error getting expiry: %s", err)
}
if expiry.Equal(time.Time{}) {
t.Fatal("expected expiry time, got none")
}

creds.Expire()
if !creds.IsExpired() {
t.Fatalf("expected creds to be expired")
}

value, err := creds.GetWithContext(context.Background())
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if value.AccessKeyID == "" {
t.Error("AccessKeyID: expected a value")
}
if value.SecretAccessKey == "" {
t.Error("SecretAccessKey: expected a value")
}
if value.SessionToken == "" {
t.Error("SessionToken: expected a value")
}
}

func TestV2CredentialsProviderCaching(t *testing.T) {
stsClientCalls := 0
expectedStsClientCalls := 0
stsClient := &mockAssumeRole{
TestInput: func(in *stsv2.AssumeRoleInput) {
stsClientCalls++
},
}
v2creds := stscredsv2.NewAssumeRoleProvider(stsClient, "role")
creds := newV2Credentials(v2creds)
if stsClientCalls != expectedStsClientCalls {
t.Errorf("did not expect call to STS client")
expectedStsClientCalls = stsClientCalls
}

_, err := creds.GetWithContext(context.Background())
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expectedStsClientCalls++
if stsClientCalls != expectedStsClientCalls {
t.Errorf("expected call to STS client")
expectedStsClientCalls = stsClientCalls
}

_, err = creds.GetWithContext(context.Background())
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if stsClientCalls != expectedStsClientCalls {
t.Errorf("did not expect call to STS client")
expectedStsClientCalls = stsClientCalls
}

creds.IsExpired()
if stsClientCalls != expectedStsClientCalls {
t.Errorf("did not expect call to STS client")
expectedStsClientCalls = stsClientCalls
}

_, err = creds.ExpiresAt()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if stsClientCalls != expectedStsClientCalls {
t.Errorf("did not expect call to STS client")
expectedStsClientCalls = stsClientCalls
}

creds.Expire()
if stsClientCalls != expectedStsClientCalls {
t.Errorf("did not expect call to STS client")
expectedStsClientCalls = stsClientCalls
}

creds.IsExpired()
if stsClientCalls != expectedStsClientCalls {
t.Errorf("did not expect call to STS client")
expectedStsClientCalls = stsClientCalls
}

_, err = creds.ExpiresAt()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if stsClientCalls != expectedStsClientCalls {
t.Errorf("did not expect call to STS client")
expectedStsClientCalls = stsClientCalls
}

_, err = creds.GetWithContext(context.Background())
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
expectedStsClientCalls++
if stsClientCalls != expectedStsClientCalls {
t.Errorf("expected call to STS client")
}
}

type mockAssumeRole struct {
TestInput func(*stsv2.AssumeRoleInput)
}

func (s *mockAssumeRole) AssumeRole(ctx context.Context, params *stsv2.AssumeRoleInput, optFns ...func(*stsv2.Options)) (*stsv2.AssumeRoleOutput, error) {
if s.TestInput != nil {
s.TestInput(params)
}
expiry := time.Now().Add(60 * time.Minute)

return &stsv2.AssumeRoleOutput{
Credentials: &ststypesv2.Credentials{
// Just reflect the role arn to the provider.
AccessKeyId: params.RoleArn,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,
},
}, nil
}
2 changes: 2 additions & 0 deletions v2/awsv1shim/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2
require (
github.com/aws/aws-sdk-go v1.42.52
github.com/aws/aws-sdk-go-v2 v1.13.0
github.com/aws/aws-sdk-go-v2/credentials v1.8.0
github.com/aws/aws-sdk-go-v2/service/sts v1.14.0
github.com/google/go-cmp v0.5.7
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.7
github.com/hashicorp/go-cleanhttp v0.5.2
Expand Down
12 changes: 1 addition & 11 deletions v2/awsv1shim/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim

awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -22,11 +21,6 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim
// options based on pre-existing credential provider, configured profile, or
// fallback to automatically a determined session via the AWS Go SDK.
func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, error) {
creds, err := awsC.Credentials.Retrieve(context.Background())
if err != nil {
return nil, fmt.Errorf("error accessing credentials: %w", err)
}

useFIPSEndpoint, _, err := awsconfig.ResolveUseFIPSEndpoint(context.Background(), awsC.ConfigSources)
if err != nil {
return nil, fmt.Errorf("error resolving configuration: %w", err)
Expand All @@ -44,11 +38,7 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options,

options := &session.Options{
Config: aws.Config{
Credentials: credentials.NewStaticCredentials(
creds.AccessKeyID,
creds.SecretAccessKey,
creds.SessionToken,
),
Credentials: newV2Credentials(awsC.Credentials),
HTTPClient: httpClient,
LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody | aws.LogDebugWithRequestRetries | aws.LogDebugWithRequestErrors),
Logger: debugLogger{},
Expand Down

0 comments on commit 5496790

Please sign in to comment.