Skip to content

Commit

Permalink
aws/session: Add support for assuming role via Web Identity Tokens (#…
Browse files Browse the repository at this point in the history
…2667)

Adds support for assuming an role via the Web Identity Token. Allows for OIDC token files to be used by specifying the token path through the AWS_WEB_IDENTITY_TOKEN_FILE, and AWS_ROLE_ARN environment variables.

Replaces PR #2193
  • Loading branch information
jasdel authored Jul 17, 2019
1 parent 8598ee2 commit 2e1d76a
Show file tree
Hide file tree
Showing 12 changed files with 838 additions and 400 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
### SDK Features
* `aws/session`: Add support for assuming role via Web Identity Token ([#2667](https://github.com/aws/aws-sdk-go/pull/2667))
* Adds support for assuming an role via the Web Identity Token. Allows for OIDC token files to be used by specifying the token path through the AWS_WEB_IDENTITY_TOKEN_FILE, and AWS_ROLE_ARN environment variables.

### SDK Enhancements

Expand Down
6 changes: 6 additions & 0 deletions aws/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ type Value struct {
ProviderName string
}

// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}

// A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what to
// be expired means.
Expand Down
1 change: 1 addition & 0 deletions aws/credentials/stscreds/testdata/token.jwt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
foo
99 changes: 99 additions & 0 deletions aws/credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package stscreds

import (
"fmt"
"io/ioutil"
"strconv"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

const (
// ErrCodeWebIdentity will be used as an error code when constructing
// a new error to be returned during session creation or retrieval.
ErrCodeWebIdentity = "WebIdentityErr"

// WebIdentityProviderName is the web identity provider name
WebIdentityProviderName = "WebIdentityCredentials"
)

// now is used to return a time.Time object representing
// the current time. This can be used to easily test and
// compare test values.
var now = func() time.Time {
return time.Now()
}

// WebIdentityRoleProvider is used to retrieve credentials using
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry

client stsiface.STSAPI
ExpiryWindow time.Duration

tokenFilePath string
roleARN string
roleSessionName string
}

// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
return credentials.NewCredentials(p)
}

// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
client: svc,
tokenFilePath: path,
roleARN: roleARN,
roleSessionName: roleSessionName,
}
}

// Retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
b, err := ioutil.ReadFile(p.tokenFilePath)
if err != nil {
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath)
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err)
}

sessionName := p.roleSessionName
if len(sessionName) == 0 {
// session name is used to uniquely identify a session. This simply
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(now().UnixNano(), 10)
}
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{
RoleArn: &p.roleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
})
if err != nil {
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
}

p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)

value := credentials.Value{
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId),
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
SessionToken: aws.StringValue(resp.Credentials.SessionToken),
ProviderName: WebIdentityProviderName,
}
return value, nil
}
113 changes: 113 additions & 0 deletions aws/credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// +build go1.7

package stscreds

import (
"fmt"
"reflect"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
)

type mockSTS struct {
*sts.STS
AssumeRoleWithWebIdentityFn func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error)
}

func (m *mockSTS) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if m.AssumeRoleWithWebIdentityFn != nil {
return m.AssumeRoleWithWebIdentityFn(input)
}

return nil, nil
}

func TestWebIdentityProviderRetrieve(t *testing.T) {
now = func() time.Time {
return time.Time{}
}

cases := []struct {
name string
mockSTS *mockSTS
roleARN string
tokenFilepath string
sessionName string
expectedError error
expectedCredValue credentials.Value
}{
{
name: "session name case",
roleARN: "arn",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: WebIdentityProviderName,
},
},
{
name: "valid case",
roleARN: "arn",
tokenFilepath: "testdata/token.jwt",
mockSTS: &mockSTS{
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) {
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: WebIdentityProviderName,
},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
p := NewWebIdentityRoleProvider(c.mockSTS, c.roleARN, c.sessionName, c.tokenFilepath)
credValue, err := p.Retrieve()
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
})
}
}
Loading

0 comments on commit 2e1d76a

Please sign in to comment.