-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
aws/session: Add support for assuming role via Web Identity Tokens (#…
…2667) Adds support for assuming an role via the Web Identity Token. Allows for OIDC token files to be used by specifying the token path through the AWS_WEB_IDENTITY_TOKEN_FILE, and AWS_ROLE_ARN environment variables. Replaces PR #2193
- Loading branch information
Showing
12 changed files
with
838 additions
and
400 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
foo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package stscreds | ||
|
||
import ( | ||
"fmt" | ||
"io/ioutil" | ||
"strconv" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/awserr" | ||
"github.com/aws/aws-sdk-go/aws/client" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
"github.com/aws/aws-sdk-go/service/sts/stsiface" | ||
) | ||
|
||
const ( | ||
// ErrCodeWebIdentity will be used as an error code when constructing | ||
// a new error to be returned during session creation or retrieval. | ||
ErrCodeWebIdentity = "WebIdentityErr" | ||
|
||
// WebIdentityProviderName is the web identity provider name | ||
WebIdentityProviderName = "WebIdentityCredentials" | ||
) | ||
|
||
// now is used to return a time.Time object representing | ||
// the current time. This can be used to easily test and | ||
// compare test values. | ||
var now = func() time.Time { | ||
return time.Now() | ||
} | ||
|
||
// WebIdentityRoleProvider is used to retrieve credentials using | ||
// an OIDC token. | ||
type WebIdentityRoleProvider struct { | ||
credentials.Expiry | ||
|
||
client stsiface.STSAPI | ||
ExpiryWindow time.Duration | ||
|
||
tokenFilePath string | ||
roleARN string | ||
roleSessionName string | ||
} | ||
|
||
// NewWebIdentityCredentials will return a new set of credentials with a given | ||
// configuration, role arn, and token file path. | ||
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials { | ||
svc := sts.New(c) | ||
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path) | ||
return credentials.NewCredentials(p) | ||
} | ||
|
||
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the | ||
// provided stsiface.STSAPI | ||
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider { | ||
return &WebIdentityRoleProvider{ | ||
client: svc, | ||
tokenFilePath: path, | ||
roleARN: roleARN, | ||
roleSessionName: roleSessionName, | ||
} | ||
} | ||
|
||
// Retrieve will attempt to assume a role from a token which is located at | ||
// 'WebIdentityTokenFilePath' specified destination and if that is empty an | ||
// error will be returned. | ||
func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { | ||
b, err := ioutil.ReadFile(p.tokenFilePath) | ||
if err != nil { | ||
errMsg := fmt.Sprintf("unable to read file at %s", p.tokenFilePath) | ||
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, errMsg, err) | ||
} | ||
|
||
sessionName := p.roleSessionName | ||
if len(sessionName) == 0 { | ||
// session name is used to uniquely identify a session. This simply | ||
// uses unix time in nanoseconds to uniquely identify sessions. | ||
sessionName = strconv.FormatInt(now().UnixNano(), 10) | ||
} | ||
resp, err := p.client.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{ | ||
RoleArn: &p.roleARN, | ||
RoleSessionName: &sessionName, | ||
WebIdentityToken: aws.String(string(b)), | ||
}) | ||
if err != nil { | ||
return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err) | ||
} | ||
|
||
p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow) | ||
|
||
value := credentials.Value{ | ||
AccessKeyID: aws.StringValue(resp.Credentials.AccessKeyId), | ||
SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey), | ||
SessionToken: aws.StringValue(resp.Credentials.SessionToken), | ||
ProviderName: WebIdentityProviderName, | ||
} | ||
return value, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// +build go1.7 | ||
|
||
package stscreds | ||
|
||
import ( | ||
"fmt" | ||
"reflect" | ||
"testing" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
) | ||
|
||
type mockSTS struct { | ||
*sts.STS | ||
AssumeRoleWithWebIdentityFn func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) | ||
} | ||
|
||
func (m *mockSTS) AssumeRoleWithWebIdentity(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if m.AssumeRoleWithWebIdentityFn != nil { | ||
return m.AssumeRoleWithWebIdentityFn(input) | ||
} | ||
|
||
return nil, nil | ||
} | ||
|
||
func TestWebIdentityProviderRetrieve(t *testing.T) { | ||
now = func() time.Time { | ||
return time.Time{} | ||
} | ||
|
||
cases := []struct { | ||
name string | ||
mockSTS *mockSTS | ||
roleARN string | ||
tokenFilepath string | ||
sessionName string | ||
expectedError error | ||
expectedCredValue credentials.Value | ||
}{ | ||
{ | ||
name: "session name case", | ||
roleARN: "arn", | ||
tokenFilepath: "testdata/token.jwt", | ||
sessionName: "foo", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
return &sts.AssumeRoleWithWebIdentityOutput{ | ||
Credentials: &sts.Credentials{ | ||
Expiration: aws.Time(time.Now()), | ||
AccessKeyId: aws.String("access-key-id"), | ||
SecretAccessKey: aws.String("secret-access-key"), | ||
SessionToken: aws.String("session-token"), | ||
}, | ||
}, nil | ||
}, | ||
}, | ||
expectedCredValue: credentials.Value{ | ||
AccessKeyID: "access-key-id", | ||
SecretAccessKey: "secret-access-key", | ||
SessionToken: "session-token", | ||
ProviderName: WebIdentityProviderName, | ||
}, | ||
}, | ||
{ | ||
name: "valid case", | ||
roleARN: "arn", | ||
tokenFilepath: "testdata/token.jwt", | ||
mockSTS: &mockSTS{ | ||
AssumeRoleWithWebIdentityFn: func(input *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { | ||
if e, a := fmt.Sprintf("%d", now().UnixNano()), *input.RoleSessionName; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
return &sts.AssumeRoleWithWebIdentityOutput{ | ||
Credentials: &sts.Credentials{ | ||
Expiration: aws.Time(time.Now()), | ||
AccessKeyId: aws.String("access-key-id"), | ||
SecretAccessKey: aws.String("secret-access-key"), | ||
SessionToken: aws.String("session-token"), | ||
}, | ||
}, nil | ||
}, | ||
}, | ||
expectedCredValue: credentials.Value{ | ||
AccessKeyID: "access-key-id", | ||
SecretAccessKey: "secret-access-key", | ||
SessionToken: "session-token", | ||
ProviderName: WebIdentityProviderName, | ||
}, | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
t.Run(c.name, func(t *testing.T) { | ||
p := NewWebIdentityRoleProvider(c.mockSTS, c.roleARN, c.sessionName, c.tokenFilepath) | ||
credValue, err := p.Retrieve() | ||
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
|
||
if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) { | ||
t.Errorf("expected %v, but received %v", e, a) | ||
} | ||
}) | ||
} | ||
} |
Oops, something went wrong.