Skip to content

Commit

Permalink
Remove parameterized AWS session from token.go
Browse files Browse the repository at this point in the history
This simplifies the API, and removes the unnecessary GetWithRoleForSession()
method. This also simplifies migration to aws-sdk-go-v2 by allowing
both Generator and TokenOptions to be not bound to a specific SDK version.
  • Loading branch information
micahhausler committed Aug 26, 2024
1 parent 80ee77e commit 2140ea6
Showing 1 changed file with 33 additions and 50 deletions.
83 changes: 33 additions & 50 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ type GetTokenOptions struct {
AssumeRoleARN string
AssumeRoleExternalID string
SessionName string
Session *session.Session
}

// FormatError is returned when there is a problem with token that is
Expand Down Expand Up @@ -186,8 +185,6 @@ type Generator interface {
Get(string) (Token, error)
// GetWithRole creates a token by assuming the provided role, using the credentials in the default chain.
GetWithRole(clusterID, roleARN string) (Token, error)
// GetWithRoleForSession creates a token by assuming the provided role, using the provided session.
GetWithRoleForSession(clusterID string, roleARN string, sess *session.Session) (Token, error)
// Get a token using the provided options
GetWithOptions(options *GetTokenOptions) (Token, error)
// GetWithSTS returns a token valid for clusterID using the given STS client.
Expand Down Expand Up @@ -226,16 +223,6 @@ func (g generator) GetWithRole(clusterID string, roleARN string) (Token, error)
})
}

// GetWithRoleForSession assumes the given AWS IAM role for the given session and behaves
// like GetWithRole.
func (g generator) GetWithRoleForSession(clusterID string, roleARN string, sess *session.Session) (Token, error) {
return g.GetWithOptions(&GetTokenOptions{
ClusterID: clusterID,
AssumeRoleARN: roleARN,
Session: sess,
})
}

// StdinStderrTokenProvider gets MFA token from standard input.
func StdinStderrTokenProvider() (string, error) {
var v string
Expand All @@ -252,46 +239,42 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
return Token{}, fmt.Errorf("ClusterID is required")
}

if options.Session == nil {
// create a session with the "base" credentials available
// (from environment variable, profile files, EC2 metadata, etc)
sess, err := session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: StdinStderrTokenProvider,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return Token{}, fmt.Errorf("could not create session: %v", err)
}
sess.Handlers.Build.PushFrontNamed(request.NamedHandler{
Name: "authenticatorUserAgent",
Fn: request.MakeAddToUserAgentHandler(
"aws-iam-authenticator", pkg.Version),
})
if options.Region != "" {
sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint))
}
// create a session with the "base" credentials available
// (from environment variable, profile files, EC2 metadata, etc)
sess, err := session.NewSessionWithOptions(session.Options{
AssumeRoleTokenProvider: StdinStderrTokenProvider,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return Token{}, fmt.Errorf("could not create session: %v", err)
}
sess.Handlers.Build.PushFrontNamed(request.NamedHandler{
Name: "authenticatorUserAgent",
Fn: request.MakeAddToUserAgentHandler(
"aws-iam-authenticator", pkg.Version),
})
if options.Region != "" {
sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint))
}

if g.cache {
// figure out what profile we're using
var profile string
if v := os.Getenv("AWS_PROFILE"); len(v) > 0 {
profile = v
} else {
profile = session.DefaultSharedConfigProfile
}
// create a cacheing Provider wrapper around the Credentials
if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil {
sess.Config.Credentials = credentials.NewCredentials(&cacheProvider)
} else {
_, _ = fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err)
}
if g.cache {
// figure out what profile we're using
var profile string
if v := os.Getenv("AWS_PROFILE"); len(v) > 0 {
profile = v
} else {
profile = session.DefaultSharedConfigProfile
}
// create a cacheing Provider wrapper around the Credentials
if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil {
sess.Config.Credentials = credentials.NewCredentials(&cacheProvider)
} else {
fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err)
}

options.Session = sess
}

// use an STS client based on the direct credentials
stsAPI := sts.New(options.Session)
stsAPI := sts.New(sess)

// if a roleARN was specified, replace the STS client with one that uses
// temporary credentials from that role.
Expand Down Expand Up @@ -326,10 +309,10 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
}

// create STS-based credentials that will assume the given role
creds := stscreds.NewCredentials(options.Session, options.AssumeRoleARN, sessionSetters...)
creds := stscreds.NewCredentials(sess, options.AssumeRoleARN, sessionSetters...)

// create an STS API interface that uses the assumed role's temporary credentials
stsAPI = sts.New(options.Session, &aws.Config{Credentials: creds})
stsAPI = sts.New(sess, &aws.Config{Credentials: creds})
}

return g.GetWithSTS(options.ClusterID, stsAPI)
Expand Down

0 comments on commit 2140ea6

Please sign in to comment.