diff --git a/aws/credentials/credentials.go b/aws/credentials/credentials.go index 894bbc7f82c..83bbc311b4d 100644 --- a/aws/credentials/credentials.go +++ b/aws/credentials/credentials.go @@ -50,9 +50,10 @@ package credentials import ( "fmt" - "github.com/aws/aws-sdk-go/aws/awserr" "sync" "time" + + "github.com/aws/aws-sdk-go/aws/awserr" ) // AnonymousCredentials is an empty Credential object that can be used as diff --git a/aws/credentials/stscreds/assume_role_provider.go b/aws/credentials/stscreds/assume_role_provider.go index b6dbfd2467d..2e528d130d4 100644 --- a/aws/credentials/stscreds/assume_role_provider.go +++ b/aws/credentials/stscreds/assume_role_provider.go @@ -200,7 +200,7 @@ type AssumeRoleProvider struct { // by a random percentage between 0 and MaxJitterFraction. MaxJitterFrac must // have a value between 0 and 1. Any other value may lead to expected behavior. // With a MaxJitterFrac value of 0, default) will no jitter will be used. - // + // // For example, with a Duration of 30m and a MaxJitterFrac of 0.1, the // AssumeRole call will be made with an arbitrary Duration between 27m and // 30m. @@ -258,7 +258,6 @@ func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(* // Retrieve generates a new set of temporary credentials using STS. func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) { - // Apply defaults where parameters are not set. if p.RoleSessionName == "" { // Try to work out a role name that will hopefully end up unique. diff --git a/aws/request/handlers.go b/aws/request/handlers.go index 8ef8548a96d..627ec722c05 100644 --- a/aws/request/handlers.go +++ b/aws/request/handlers.go @@ -59,6 +59,51 @@ func (h *Handlers) Clear() { h.Complete.Clear() } +// IsEmpty returns if there are no handlers in any of the handlerlists. +func (h *Handlers) IsEmpty() bool { + if h.Validate.Len() != 0 { + return false + } + if h.Build.Len() != 0 { + return false + } + if h.Send.Len() != 0 { + return false + } + if h.Sign.Len() != 0 { + return false + } + if h.Unmarshal.Len() != 0 { + return false + } + if h.UnmarshalStream.Len() != 0 { + return false + } + if h.UnmarshalMeta.Len() != 0 { + return false + } + if h.UnmarshalError.Len() != 0 { + return false + } + if h.ValidateResponse.Len() != 0 { + return false + } + if h.Retry.Len() != 0 { + return false + } + if h.AfterRetry.Len() != 0 { + return false + } + if h.CompleteAttempt.Len() != 0 { + return false + } + if h.Complete.Len() != 0 { + return false + } + + return true +} + // A HandlerListRunItem represents an entry in the HandlerList which // is being run. type HandlerListRunItem struct { diff --git a/aws/session/credentials.go b/aws/session/credentials.go new file mode 100644 index 00000000000..9e2652c5596 --- /dev/null +++ b/aws/session/credentials.go @@ -0,0 +1,202 @@ +package session + +import ( + "fmt" + "os" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/processcreds" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/defaults" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/internal/shareddefaults" +) + +// valid credential source values +const ( + credSourceEc2Metadata = "Ec2InstanceMetadata" + credSourceEnvironment = "Environment" + credSourceECSContainer = "EcsContainer" +) + +func resolveCredentials(cfg *aws.Config, + envCfg envConfig, sharedCfg sharedConfig, + handlers request.Handlers, + sessOpts Options, +) (*credentials.Credentials, error) { + // Credentials from Assume Role with specific credentials source. + if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.CredentialSource) > 0 { + return resolveCredsFromSource(cfg, envCfg, sharedCfg, handlers, sessOpts) + } + + // Credentials from environment variables + if len(envCfg.Creds.AccessKeyID) > 0 { + return credentials.NewStaticCredentialsFromCreds(envCfg.Creds), nil + } + + // Fallback to the "default" credential resolution chain. + return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts) +} + +func resolveCredsFromProfile(cfg *aws.Config, + envCfg envConfig, sharedCfg sharedConfig, + handlers request.Handlers, + sessOpts Options, +) (*credentials.Credentials, error) { + + if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil { + // Assume IAM role with credentials source from a different profile. + cred, err := resolveCredsFromProfile(cfg, envCfg, *sharedCfg.AssumeRoleSource, handlers, sessOpts) + if err != nil { + return nil, err + } + + cfgCp := *cfg + cfgCp.Credentials = cred + return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts) + + } else if len(sharedCfg.Creds.AccessKeyID) > 0 { + // Static Credentials from Shared Config/Credentials file. + return credentials.NewStaticCredentialsFromCreds( + sharedCfg.Creds, + ), nil + + } else if len(sharedCfg.CredentialProcess) > 0 { + // Credential Process credentials from Shared Config/Credentials file. + return processcreds.NewCredentials( + sharedCfg.CredentialProcess, + ), nil + + } else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.CredentialSource) > 0 { + // Assume IAM Role with specific credential source. + return resolveCredsFromSource(cfg, envCfg, sharedCfg, handlers, sessOpts) + } + + // Fallback to default credentials provider, include mock errors + // for the credential chain so user can identify why credentials + // failed to be retrieved. + return credentials.NewCredentials(&credentials.ChainProvider{ + VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors), + Providers: []credentials.Provider{ + &credProviderError{ + Err: awserr.New("EnvAccessKeyNotFound", + "failed to find credentials in the environment.", nil), + }, + &credProviderError{ + Err: awserr.New("SharedCredsLoad", + fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil), + }, + defaults.RemoteCredProvider(*cfg, handlers), + }, + }), nil +} + +func resolveCredsFromSource(cfg *aws.Config, + envCfg envConfig, sharedCfg sharedConfig, + handlers request.Handlers, + sessOpts Options, +) (*credentials.Credentials, error) { + // if both credential_source and source_profile have been set, return an + // error as this is undefined behavior. Only one can be used at a time + // within a profile. + if len(sharedCfg.AssumeRole.SourceProfile) > 0 { + return nil, ErrSharedConfigSourceCollision + } + + cfgCp := *cfg + switch sharedCfg.AssumeRole.CredentialSource { + case credSourceEc2Metadata: + p := defaults.RemoteCredProvider(cfgCp, handlers) + cfgCp.Credentials = credentials.NewCredentials(p) + + case credSourceEnvironment: + cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds(envCfg.Creds) + + case credSourceECSContainer: + if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 { + return nil, ErrSharedConfigECSContainerEnvVarEmpty + } + + p := defaults.RemoteCredProvider(cfgCp, handlers) + cfgCp.Credentials = credentials.NewCredentials(p) + + default: + return nil, ErrSharedConfigInvalidCredSource + } + + return credsFromAssumeRole(cfgCp, handlers, sharedCfg, sessOpts) +} + +func credsFromAssumeRole(cfg aws.Config, + handlers request.Handlers, + sharedCfg sharedConfig, + sessOpts Options, +) (*credentials.Credentials, error) { + if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil { + // AssumeRole Token provider is required if doing Assume Role + // with MFA. + return nil, AssumeRoleTokenProviderNotSetError{} + } + + return stscreds.NewCredentials( + &Session{ + Config: &cfg, + Handlers: handlers.Copy(), + }, + sharedCfg.AssumeRole.RoleARN, + func(opt *stscreds.AssumeRoleProvider) { + opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName + + // Assume role with external ID + if len(sharedCfg.AssumeRole.ExternalID) > 0 { + opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID) + } + + // Assume role with MFA + if len(sharedCfg.AssumeRole.MFASerial) > 0 { + opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial) + opt.TokenProvider = sessOpts.AssumeRoleTokenProvider + } + }, + ), nil +} + +// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the +// MFAToken option is not set when shared config is configured load assume a +// role with an MFA token. +type AssumeRoleTokenProviderNotSetError struct{} + +// Code is the short id of the error. +func (e AssumeRoleTokenProviderNotSetError) Code() string { + return "AssumeRoleTokenProviderNotSetError" +} + +// Message is the description of the error +func (e AssumeRoleTokenProviderNotSetError) Message() string { + return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.") +} + +// OrigErr is the underlying error that caused the failure. +func (e AssumeRoleTokenProviderNotSetError) OrigErr() error { + return nil +} + +// Error satisfies the error interface. +func (e AssumeRoleTokenProviderNotSetError) Error() string { + return awserr.SprintError(e.Code(), e.Message(), "", nil) +} + +type credProviderError struct { + Err error +} + +var emptyCreds = credentials.Value{} + +func (c credProviderError) Retrieve() (credentials.Value, error) { + return credentials.Value{}, c.Err +} +func (c credProviderError) IsExpired() bool { + return true +} diff --git a/aws/session/credentials_test.go b/aws/session/credentials_test.go new file mode 100644 index 00000000000..ad6a4edf50a --- /dev/null +++ b/aws/session/credentials_test.go @@ -0,0 +1,416 @@ +// +build go1.7 + +package session + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "reflect" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/defaults" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/awstesting" + "github.com/aws/aws-sdk-go/internal/shareddefaults" + "github.com/aws/aws-sdk-go/service/sts" +) + +func setupCredentialsEndpoints(t *testing.T) (endpoints.Resolver, func()) { + origECSEndpoint := shareddefaults.ECSContainerCredentialsURI + + ecsMetadataServer := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ECS" { + w.Write([]byte(ecsResponse)) + } else { + w.Write([]byte("")) + } + })) + shareddefaults.ECSContainerCredentialsURI = ecsMetadataServer.URL + + ec2MetadataServer := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/meta-data/iam/security-credentials/RoleName" { + w.Write([]byte(ec2MetadataResponse)) + } else if r.URL.Path == "/meta-data/iam/security-credentials/" { + w.Write([]byte("RoleName")) + } else { + w.Write([]byte("")) + } + })) + + stsServer := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf( + assumeRoleRespMsg, + time.Now(). + Add(15*time.Minute). + Format("2006-01-02T15:04:05Z")))) + })) + + resolver := endpoints.ResolverFunc( + func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + switch service { + case "ec2metadata": + return endpoints.ResolvedEndpoint{ + URL: ec2MetadataServer.URL, + }, nil + case "sts": + return endpoints.ResolvedEndpoint{ + URL: stsServer.URL, + }, nil + default: + return endpoints.ResolvedEndpoint{}, + fmt.Errorf("unknown service endpoint, %s", service) + } + }) + + return resolver, func() { + shareddefaults.ECSContainerCredentialsURI = origECSEndpoint + ecsMetadataServer.Close() + ec2MetadataServer.Close() + stsServer.Close() + } +} + +func TestSharedConfigCredentialSource(t *testing.T) { + const configFile = "testdata/credential_source_config" + + cases := []struct { + name string + profile string + expectedError error + expectedAccessKey string + expectedSecretKey string + expectedChain []string + init func() + }{ + { + name: "credential source and source profile", + profile: "invalid_source_and_credential_source", + expectedError: ErrSharedConfigSourceCollision, + init: func() { + os.Setenv("AWS_ACCESS_KEY", "access_key") + os.Setenv("AWS_SECRET_KEY", "secret_key") + }, + }, + { + name: "env var credential source", + profile: "env_var_credential_source", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedChain: []string{ + "assume_role_w_creds_role_arn_env", + }, + init: func() { + os.Setenv("AWS_ACCESS_KEY", "access_key") + os.Setenv("AWS_SECRET_KEY", "secret_key") + }, + }, + { + name: "ec2metadata credential source", + profile: "ec2metadata", + expectedChain: []string{ + "assume_role_w_creds_role_arn_ec2", + }, + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + }, + { + name: "ecs container credential source", + profile: "ecscontainer", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedChain: []string{ + "assume_role_w_creds_role_arn_ecs", + }, + init: func() { + os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS") + }, + }, + { + name: "chained assume role with env creds", + profile: "chained_assume_role", + expectedAccessKey: "AKID", + expectedSecretKey: "SECRET", + expectedChain: []string{ + "assume_role_w_creds_role_arn_chain", + "assume_role_w_creds_role_arn_ec2", + }, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%d %s", i, c.name), + func(t *testing.T) { + env := awstesting.StashEnv() + defer awstesting.PopEnv(env) + + os.Setenv("AWS_REGION", "us-east-1") + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_CONFIG_FILE", configFile) + os.Setenv("AWS_PROFILE", c.profile) + + endpointResolver, cleanupFn := setupCredentialsEndpoints(t) + defer cleanupFn() + + if c.init != nil { + c.init() + } + + var credChain []string + handlers := defaults.Handlers() + handlers.Sign.PushBack(func(r *request.Request) { + if r.Config.Credentials == credentials.AnonymousCredentials { + return + } + params := r.Params.(*sts.AssumeRoleInput) + credChain = append(credChain, *params.RoleArn) + }) + + sess, err := NewSessionWithOptions(Options{ + Config: aws.Config{ + Logger: t, + // LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody), + EndpointResolver: endpointResolver, + }, + Handlers: handlers, + }) + if e, a := c.expectedError, err; e != a { + t.Errorf("expected %v, but received %v", e, a) + } + + if c.expectedError != nil { + return + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Fatalf("expected no error, but received %v", err) + } + + if e, a := c.expectedChain, credChain; !reflect.DeepEqual(e, a) { + t.Errorf("expected %v, but received %v", e, a) + } + + if e, a := c.expectedAccessKey, creds.AccessKeyID; e != a { + t.Errorf("expected %v, but received %v", e, a) + } + + if e, a := c.expectedSecretKey, creds.SecretAccessKey; e != a { + t.Errorf("expected %v, but received %v", e, a) + } + }) + } +} + +const ecsResponse = `{ + "Code": "Success", + "Type": "AWS-HMAC", + "AccessKeyId" : "ecs-access-key", + "SecretAccessKey" : "ecs-secret-key", + "Token" : "token", + "Expiration" : "2100-01-01T00:00:00Z", + "LastUpdated" : "2009-11-23T0:00:00Z" + }` + +const ec2MetadataResponse = `{ + "Code": "Success", + "Type": "AWS-HMAC", + "AccessKeyId" : "ec2-access-key", + "SecretAccessKey" : "ec2-secret-key", + "Token" : "token", + "Expiration" : "2100-01-01T00:00:00Z", + "LastUpdated" : "2009-11-23T0:00:00Z" + }` + +const assumeRoleRespMsg = ` + + + + arn:aws:sts::account_id:assumed-role/role/session_name + AKID:session_name + + + AKID + SECRET + SESSION_TOKEN + %s + + + + request-id + + +` + +func TestSesisonAssumeRole(t *testing.T) { + oldEnv := initSessionTestEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_REGION", "us-east-1") + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) + os.Setenv("AWS_PROFILE", "assume_role_w_creds") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z")))) + })) + + s, err := NewSession(&aws.Config{Endpoint: aws.String(server.URL), DisableSSL: aws.Bool(true)}) + + creds, err := s.Config.Credentials.Get() + if err != nil { + t.Errorf("expect nil, %v", err) + } + if e, a := "AKID", creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "SECRET", creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "SESSION_TOKEN", creds.SessionToken; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "AssumeRoleProvider", creds.ProviderName; !strings.Contains(a, e) { + t.Errorf("expect %v, to contain %v", e, a) + } +} + +func TestSessionAssumeRole_WithMFA(t *testing.T) { + oldEnv := initSessionTestEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_REGION", "us-east-1") + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) + os.Setenv("AWS_PROFILE", "assume_role_w_creds") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if e, a := r.FormValue("SerialNumber"), "0123456789"; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := r.FormValue("TokenCode"), "tokencode"; e != a { + t.Errorf("expect %v, got %v", e, a) + } + + w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z")))) + })) + + customProviderCalled := false + sess, err := NewSessionWithOptions(Options{ + Profile: "assume_role_w_mfa", + Config: aws.Config{ + Region: aws.String("us-east-1"), + Endpoint: aws.String(server.URL), + DisableSSL: aws.Bool(true), + }, + SharedConfigState: SharedConfigEnable, + AssumeRoleTokenProvider: func() (string, error) { + customProviderCalled = true + + return "tokencode", nil + }, + }) + if err != nil { + t.Errorf("expect nil, %v", err) + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Errorf("expect nil, %v", err) + } + if !customProviderCalled { + t.Errorf("expect true") + } + + if e, a := "AKID", creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "SECRET", creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "SESSION_TOKEN", creds.SessionToken; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "AssumeRoleProvider", creds.ProviderName; !strings.Contains(a, e) { + t.Errorf("expect %v, to contain %v", e, a) + } +} + +func TestSessionAssumeRole_WithMFA_NoTokenProvider(t *testing.T) { + oldEnv := initSessionTestEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_REGION", "us-east-1") + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) + os.Setenv("AWS_PROFILE", "assume_role_w_creds") + + _, err := NewSessionWithOptions(Options{ + Profile: "assume_role_w_mfa", + SharedConfigState: SharedConfigEnable, + }) + if e, a := (AssumeRoleTokenProviderNotSetError{}), err; e != a { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestSessionAssumeRole_DisableSharedConfig(t *testing.T) { + // Backwards compatibility with Shared config disabled + // assume role should not be built into the config. + oldEnv := initSessionTestEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_SDK_LOAD_CONFIG", "0") + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) + os.Setenv("AWS_PROFILE", "assume_role_w_creds") + + s, err := NewSession() + if err != nil { + t.Errorf("expect nil, %v", err) + } + + creds, err := s.Config.Credentials.Get() + if err != nil { + t.Errorf("expect nil, %v", err) + } + if e, a := "assume_role_w_creds_akid", creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "assume_role_w_creds_secret", creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "SharedConfigCredentials", creds.ProviderName; !strings.Contains(a, e) { + t.Errorf("expect %v, to contain %v", e, a) + } +} + +func TestSessionAssumeRole_InvalidSourceProfile(t *testing.T) { + // Backwards compatibility with Shared config disabled + // assume role should not be built into the config. + oldEnv := initSessionTestEnv() + defer awstesting.PopEnv(oldEnv) + + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) + os.Setenv("AWS_PROFILE", "assume_role_invalid_source_profile") + + s, err := NewSession() + if err == nil { + t.Errorf("expect error") + } + if e, a := "SharedConfigAssumeRoleError: failed to load assume role", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %v, to contain %v", e, a) + } + if s != nil { + t.Errorf("expect nil, %v", err) + } +} diff --git a/aws/session/session.go b/aws/session/session.go index be4b5f07772..dab8fea3998 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -3,7 +3,6 @@ package session import ( "crypto/tls" "crypto/x509" - "fmt" "io" "io/ioutil" "net/http" @@ -14,13 +13,10 @@ import ( "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/corehandlers" "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/processcreds" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/csm" "github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/internal/shareddefaults" ) const ( @@ -224,6 +220,12 @@ type Options struct { // to also enable this feature. CustomCABundle session option field has priority // over the AWS_CA_BUNDLE environment variable, and will be used if both are set. CustomCABundle io.Reader + + // The handlers that the session and all API clients will be created with. + // This must be a complete set of handlers. Use the defaults.Handlers() + // function to initialize this value before changing the handlers to be + // used by the SDK. + Handlers request.Handlers } // NewSessionWithOptions returns a new Session created from SDK defaults, config files, @@ -344,7 +346,11 @@ func enableCSM(handlers *request.Handlers, clientID string, port string, logger func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, error) { cfg := defaults.Config() - handlers := defaults.Handlers() + + handlers := opts.Handlers + if handlers.IsEmpty() { + handlers = defaults.Handlers() + } // Get a merged version of the user provided config to determine if // credentials were. @@ -443,7 +449,11 @@ func loadCertPool(r io.Reader) (*x509.CertPool, error) { return p, nil } -func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, sessOpts Options) error { +func mergeConfigSrcs(cfg, userCfg *aws.Config, + envCfg envConfig, sharedCfg sharedConfig, + handlers request.Handlers, + sessOpts Options, +) error { // Merge in user provided configuration cfg.MergeIn(userCfg) @@ -464,164 +474,19 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, envCfg envConfig, sharedCfg share } } - // Configure credentials if not already set + // Configure credentials if not already set by the user when creating the + // Session. if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil { - - // inspect the profile to see if a credential source has been specified. - if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.CredentialSource) > 0 { - - // if both credential_source and source_profile have been set, return an error - // as this is undefined behavior. - if len(sharedCfg.AssumeRole.SourceProfile) > 0 { - return ErrSharedConfigSourceCollision - } - - // valid credential source values - const ( - credSourceEc2Metadata = "Ec2InstanceMetadata" - credSourceEnvironment = "Environment" - credSourceECSContainer = "EcsContainer" - ) - - switch sharedCfg.AssumeRole.CredentialSource { - case credSourceEc2Metadata: - cfgCp := *cfg - p := defaults.RemoteCredProvider(cfgCp, handlers) - cfgCp.Credentials = credentials.NewCredentials(p) - - if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil { - // AssumeRole Token provider is required if doing Assume Role - // with MFA. - return AssumeRoleTokenProviderNotSetError{} - } - - cfg.Credentials = assumeRoleCredentials(cfgCp, handlers, sharedCfg, sessOpts) - case credSourceEnvironment: - cfg.Credentials = credentials.NewStaticCredentialsFromCreds( - envCfg.Creds, - ) - case credSourceECSContainer: - if len(os.Getenv(shareddefaults.ECSCredsProviderEnvVar)) == 0 { - return ErrSharedConfigECSContainerEnvVarEmpty - } - - cfgCp := *cfg - p := defaults.RemoteCredProvider(cfgCp, handlers) - creds := credentials.NewCredentials(p) - - cfg.Credentials = creds - default: - return ErrSharedConfigInvalidCredSource - } - - return nil - } - - if len(envCfg.Creds.AccessKeyID) > 0 { - cfg.Credentials = credentials.NewStaticCredentialsFromCreds( - envCfg.Creds, - ) - } else if envCfg.EnableSharedConfig && len(sharedCfg.AssumeRole.RoleARN) > 0 && sharedCfg.AssumeRoleSource != nil { - cfgCp := *cfg - cfgCp.Credentials = credentials.NewStaticCredentialsFromCreds( - sharedCfg.AssumeRoleSource.Creds, - ) - - if len(sharedCfg.AssumeRole.MFASerial) > 0 && sessOpts.AssumeRoleTokenProvider == nil { - // AssumeRole Token provider is required if doing Assume Role - // with MFA. - return AssumeRoleTokenProviderNotSetError{} - } - - cfg.Credentials = assumeRoleCredentials(cfgCp, handlers, sharedCfg, sessOpts) - } else if len(sharedCfg.Creds.AccessKeyID) > 0 { - cfg.Credentials = credentials.NewStaticCredentialsFromCreds( - sharedCfg.Creds, - ) - } else if len(sharedCfg.CredentialProcess) > 0 { - cfg.Credentials = processcreds.NewCredentials( - sharedCfg.CredentialProcess, - ) - } else { - // Fallback to default credentials provider, include mock errors - // for the credential chain so user can identify why credentials - // failed to be retrieved. - cfg.Credentials = credentials.NewCredentials(&credentials.ChainProvider{ - VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors), - Providers: []credentials.Provider{ - &credProviderError{Err: awserr.New("EnvAccessKeyNotFound", "failed to find credentials in the environment.", nil)}, - &credProviderError{Err: awserr.New("SharedCredsLoad", fmt.Sprintf("failed to load profile, %s.", envCfg.Profile), nil)}, - defaults.RemoteCredProvider(*cfg, handlers), - }, - }) + creds, err := resolveCredentials(cfg, envCfg, sharedCfg, handlers, sessOpts) + if err != nil { + return err } + cfg.Credentials = creds } return nil } -func assumeRoleCredentials(cfg aws.Config, handlers request.Handlers, sharedCfg sharedConfig, sessOpts Options) *credentials.Credentials { - return stscreds.NewCredentials( - &Session{ - Config: &cfg, - Handlers: handlers.Copy(), - }, - sharedCfg.AssumeRole.RoleARN, - func(opt *stscreds.AssumeRoleProvider) { - opt.RoleSessionName = sharedCfg.AssumeRole.RoleSessionName - - // Assume role with external ID - if len(sharedCfg.AssumeRole.ExternalID) > 0 { - opt.ExternalID = aws.String(sharedCfg.AssumeRole.ExternalID) - } - - // Assume role with MFA - if len(sharedCfg.AssumeRole.MFASerial) > 0 { - opt.SerialNumber = aws.String(sharedCfg.AssumeRole.MFASerial) - opt.TokenProvider = sessOpts.AssumeRoleTokenProvider - } - }, - ) -} - -// AssumeRoleTokenProviderNotSetError is an error returned when creating a session when the -// MFAToken option is not set when shared config is configured load assume a -// role with an MFA token. -type AssumeRoleTokenProviderNotSetError struct{} - -// Code is the short id of the error. -func (e AssumeRoleTokenProviderNotSetError) Code() string { - return "AssumeRoleTokenProviderNotSetError" -} - -// Message is the description of the error -func (e AssumeRoleTokenProviderNotSetError) Message() string { - return fmt.Sprintf("assume role with MFA enabled, but AssumeRoleTokenProvider session option not set.") -} - -// OrigErr is the underlying error that caused the failure. -func (e AssumeRoleTokenProviderNotSetError) OrigErr() error { - return nil -} - -// Error satisfies the error interface. -func (e AssumeRoleTokenProviderNotSetError) Error() string { - return awserr.SprintError(e.Code(), e.Message(), "", nil) -} - -type credProviderError struct { - Err error -} - -var emptyCreds = credentials.Value{} - -func (c credProviderError) Retrieve() (credentials.Value, error) { - return credentials.Value{}, c.Err -} -func (c credProviderError) IsExpired() bool { - return true -} - func initHandlers(s *Session) { // Add the Validate parameter handler if it is not disabled. s.Handlers.Validate.Remove(corehandlers.ValidateParametersHandler) diff --git a/aws/session/session_test.go b/aws/session/session_test.go index 61a1fd2f5fb..af63d85ca22 100644 --- a/aws/session/session_test.go +++ b/aws/session/session_test.go @@ -1,21 +1,21 @@ +// +build go1.7 + package session import ( "bytes" "fmt" "net/http" - "net/http/httptest" "os" + "strconv" "strings" "testing" - "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/awstesting" - "github.com/aws/aws-sdk-go/internal/shareddefaults" "github.com/aws/aws-sdk-go/service/s3" ) @@ -367,454 +367,42 @@ func TestNewSessionWithOptions_Overrides(t *testing.T) { }, } - for _, c := range cases { - oldEnv := initSessionTestEnv() - defer awstesting.PopEnv(oldEnv) - - for k, v := range c.InEnvs { - os.Setenv(k, v) - } - - s, err := NewSessionWithOptions(Options{ - Profile: c.InProfile, - SharedConfigState: SharedConfigEnable, + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + oldEnv := initSessionTestEnv() + defer awstesting.PopEnv(oldEnv) + + for k, v := range c.InEnvs { + os.Setenv(k, v) + } + + s, err := NewSessionWithOptions(Options{ + Profile: c.InProfile, + SharedConfigState: SharedConfigEnable, + }) + if err != nil { + t.Errorf("expect nil, %v", err) + } + + creds, err := s.Config.Credentials.Get() + if err != nil { + t.Errorf("expect nil, %v", err) + } + if e, a := c.OutRegion, *s.Config.Region; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := c.OutCreds.AccessKeyID, creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := c.OutCreds.SecretAccessKey, creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := c.OutCreds.SessionToken, creds.SessionToken; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := c.OutCreds.ProviderName, creds.ProviderName; !strings.Contains(a, e) { + t.Errorf("expect %v, to contain %v", e, a) + } }) - if err != nil { - t.Errorf("expect nil, %v", err) - } - - creds, err := s.Config.Credentials.Get() - if err != nil { - t.Errorf("expect nil, %v", err) - } - if e, a := c.OutRegion, *s.Config.Region; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := c.OutCreds.AccessKeyID, creds.AccessKeyID; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := c.OutCreds.SecretAccessKey, creds.SecretAccessKey; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := c.OutCreds.SessionToken, creds.SessionToken; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := c.OutCreds.ProviderName, creds.ProviderName; !strings.Contains(a, e) { - t.Errorf("expect %v, to contain %v", e, a) - } - } -} - -const assumeRoleRespMsg = ` - - - - arn:aws:sts::account_id:assumed-role/role/session_name - AKID:session_name - - - AKID - SECRET - SESSION_TOKEN - %s - - - - request-id - - -` - -func TestSesisonAssumeRole(t *testing.T) { - oldEnv := initSessionTestEnv() - defer awstesting.PopEnv(oldEnv) - - os.Setenv("AWS_REGION", "us-east-1") - os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) - os.Setenv("AWS_PROFILE", "assume_role_w_creds") - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z")))) - })) - - s, err := NewSession(&aws.Config{Endpoint: aws.String(server.URL), DisableSSL: aws.Bool(true)}) - - creds, err := s.Config.Credentials.Get() - if err != nil { - t.Errorf("expect nil, %v", err) } - if e, a := "AKID", creds.AccessKeyID; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "SECRET", creds.SecretAccessKey; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "SESSION_TOKEN", creds.SessionToken; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "AssumeRoleProvider", creds.ProviderName; !strings.Contains(a, e) { - t.Errorf("expect %v, to contain %v", e, a) - } -} - -func TestSessionAssumeRole_WithMFA(t *testing.T) { - oldEnv := initSessionTestEnv() - defer awstesting.PopEnv(oldEnv) - - os.Setenv("AWS_REGION", "us-east-1") - os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) - os.Setenv("AWS_PROFILE", "assume_role_w_creds") - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if e, a := r.FormValue("SerialNumber"), "0123456789"; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := r.FormValue("TokenCode"), "tokencode"; e != a { - t.Errorf("expect %v, got %v", e, a) - } - - w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z")))) - })) - - customProviderCalled := false - sess, err := NewSessionWithOptions(Options{ - Profile: "assume_role_w_mfa", - Config: aws.Config{ - Region: aws.String("us-east-1"), - Endpoint: aws.String(server.URL), - DisableSSL: aws.Bool(true), - }, - SharedConfigState: SharedConfigEnable, - AssumeRoleTokenProvider: func() (string, error) { - customProviderCalled = true - - return "tokencode", nil - }, - }) - if err != nil { - t.Errorf("expect nil, %v", err) - } - - creds, err := sess.Config.Credentials.Get() - if err != nil { - t.Errorf("expect nil, %v", err) - } - if !customProviderCalled { - t.Errorf("expect true") - } - - if e, a := "AKID", creds.AccessKeyID; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "SECRET", creds.SecretAccessKey; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "SESSION_TOKEN", creds.SessionToken; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "AssumeRoleProvider", creds.ProviderName; !strings.Contains(a, e) { - t.Errorf("expect %v, to contain %v", e, a) - } -} - -func TestSessionAssumeRole_WithMFA_NoTokenProvider(t *testing.T) { - oldEnv := initSessionTestEnv() - defer awstesting.PopEnv(oldEnv) - - os.Setenv("AWS_REGION", "us-east-1") - os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) - os.Setenv("AWS_PROFILE", "assume_role_w_creds") - - _, err := NewSessionWithOptions(Options{ - Profile: "assume_role_w_mfa", - SharedConfigState: SharedConfigEnable, - }) - if e, a := (AssumeRoleTokenProviderNotSetError{}), err; e != a { - t.Errorf("expect %v, got %v", e, a) - } -} - -func TestSessionAssumeRole_DisableSharedConfig(t *testing.T) { - // Backwards compatibility with Shared config disabled - // assume role should not be built into the config. - oldEnv := initSessionTestEnv() - defer awstesting.PopEnv(oldEnv) - - os.Setenv("AWS_SDK_LOAD_CONFIG", "0") - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) - os.Setenv("AWS_PROFILE", "assume_role_w_creds") - - s, err := NewSession() - if err != nil { - t.Errorf("expect nil, %v", err) - } - - creds, err := s.Config.Credentials.Get() - if err != nil { - t.Errorf("expect nil, %v", err) - } - if e, a := "assume_role_w_creds_akid", creds.AccessKeyID; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "assume_role_w_creds_secret", creds.SecretAccessKey; e != a { - t.Errorf("expect %v, got %v", e, a) - } - if e, a := "SharedConfigCredentials", creds.ProviderName; !strings.Contains(a, e) { - t.Errorf("expect %v, to contain %v", e, a) - } -} - -func TestSessionAssumeRole_InvalidSourceProfile(t *testing.T) { - // Backwards compatibility with Shared config disabled - // assume role should not be built into the config. - oldEnv := initSessionTestEnv() - defer awstesting.PopEnv(oldEnv) - - os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", testConfigFilename) - os.Setenv("AWS_PROFILE", "assume_role_invalid_source_profile") - - s, err := NewSession() - if err == nil { - t.Errorf("expect error") - } - if e, a := "SharedConfigAssumeRoleError: failed to load assume role", err.Error(); !strings.Contains(a, e) { - t.Errorf("expect %v, to contain %v", e, a) - } - if s != nil { - t.Errorf("expect nil, %v", err) - } -} - -func TestSharedConfigCredentialSource(t *testing.T) { - cases := []struct { - name string - profile string - expectedError error - expectedAccessKey string - expectedSecretKey string - init func(*aws.Config, string) func() error - }{ - { - name: "env var credential source", - profile: "env_var_credential_source", - expectedAccessKey: "access_key", - expectedSecretKey: "secret_key", - init: func(cfg *aws.Config, profile string) func() error { - os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_CONFIG_FILE", "testdata/credential_source_config") - os.Setenv("AWS_PROFILE", profile) - os.Setenv("AWS_ACCESS_KEY", "access_key") - os.Setenv("AWS_SECRET_KEY", "secret_key") - - return func() error { - os.Unsetenv("AWS_SDK_LOAD_CONFIG") - os.Unsetenv("AWS_CONFIG_FILE") - os.Unsetenv("AWS_PROFILE") - os.Unsetenv("AWS_ACCESS_KEY") - os.Unsetenv("AWS_SECRET_KEY") - - return nil - } - }, - }, - { - name: "credential source and source profile", - profile: "invalid_source_and_credential_source", - expectedError: ErrSharedConfigSourceCollision, - init: func(cfg *aws.Config, profile string) func() error { - os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_CONFIG_FILE", "testdata/credential_source_config") - os.Setenv("AWS_PROFILE", profile) - os.Setenv("AWS_ACCESS_KEY", "access_key") - os.Setenv("AWS_SECRET_KEY", "secret_key") - - return func() error { - os.Unsetenv("AWS_SDK_LOAD_CONFIG") - os.Unsetenv("AWS_CONFIG_FILE") - os.Unsetenv("AWS_PROFILE") - os.Unsetenv("AWS_ACCESS_KEY") - os.Unsetenv("AWS_SECRET_KEY") - - return nil - } - }, - }, - { - name: "ec2metadata credential source", - profile: "ec2metadata", - expectedAccessKey: "AKID", - expectedSecretKey: "SECRET", - init: func(cfg *aws.Config, profile string) func() error { - os.Setenv("AWS_REGION", "us-east-1") - os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_CONFIG_FILE", "testdata/credential_source_config") - os.Setenv("AWS_PROFILE", "ec2metadata") - - const ec2MetadataResponse = `{ - "Code": "Success", - "Type": "AWS-HMAC", - "AccessKeyId" : "access-key", - "SecretAccessKey" : "secret-key", - "Token" : "token", - "Expiration" : "2100-01-01T00:00:00Z", - "LastUpdated" : "2009-11-23T0:00:00Z" - }` - - ec2MetadataCalled := false - ec2MetadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/meta-data/iam/security-credentials/RoleName" { - ec2MetadataCalled = true - w.Write([]byte(ec2MetadataResponse)) - } else if r.URL.Path == "/meta-data/iam/security-credentials/" { - w.Write([]byte("RoleName")) - } else { - w.Write([]byte("")) - } - })) - - stsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z")))) - })) - - cfg.EndpointResolver = endpoints.ResolverFunc( - func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - if service == "ec2metadata" { - return endpoints.ResolvedEndpoint{ - URL: ec2MetadataServer.URL, - }, nil - } - - return endpoints.ResolvedEndpoint{ - URL: stsServer.URL, - }, nil - }, - ) - - return func() error { - os.Unsetenv("AWS_SDK_LOAD_CONFIG") - os.Unsetenv("AWS_CONFIG_FILE") - os.Unsetenv("AWS_PROFILE") - os.Unsetenv("AWS_REGION") - - ec2MetadataServer.Close() - stsServer.Close() - - if !ec2MetadataCalled { - return fmt.Errorf("expected ec2metadata to be called") - } - - return nil - } - }, - }, - { - name: "ecs container credential source", - profile: "ecscontainer", - expectedAccessKey: "access-key", - expectedSecretKey: "secret-key", - init: func(cfg *aws.Config, profile string) func() error { - os.Setenv("AWS_REGION", "us-east-1") - os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - os.Setenv("AWS_CONFIG_FILE", "testdata/credential_source_config") - os.Setenv("AWS_PROFILE", "ecscontainer") - os.Setenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/ECS") - - const ecsResponse = `{ - "Code": "Success", - "Type": "AWS-HMAC", - "AccessKeyId" : "access-key", - "SecretAccessKey" : "secret-key", - "Token" : "token", - "Expiration" : "2100-01-01T00:00:00Z", - "LastUpdated" : "2009-11-23T0:00:00Z" - }` - - ecsCredsCalled := false - ecsMetadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/ECS" { - ecsCredsCalled = true - w.Write([]byte(ecsResponse)) - } else { - w.Write([]byte("")) - } - })) - - shareddefaults.ECSContainerCredentialsURI = ecsMetadataServer.URL - - stsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(fmt.Sprintf(assumeRoleRespMsg, time.Now().Add(15*time.Minute).Format("2006-01-02T15:04:05Z")))) - })) - - cfg.Endpoint = aws.String(stsServer.URL) - - cfg.EndpointResolver = endpoints.ResolverFunc( - func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - fmt.Println("SERVICE", service) - return endpoints.ResolvedEndpoint{ - URL: stsServer.URL, - }, nil - }, - ) - - return func() error { - os.Unsetenv("AWS_SDK_LOAD_CONFIG") - os.Unsetenv("AWS_CONFIG_FILE") - os.Unsetenv("AWS_PROFILE") - os.Unsetenv("AWS_REGION") - os.Unsetenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") - - ecsMetadataServer.Close() - stsServer.Close() - - if !ecsCredsCalled { - return fmt.Errorf("expected ec2metadata to be called") - } - - return nil - } - }, - }, - } - - for _, c := range cases { - cfg := &aws.Config{} - clean := c.init(cfg, c.profile) - sess, err := NewSession(cfg) - if e, a := c.expectedError, err; e != a { - t.Errorf("expected %v, but received %v", e, a) - } - - if c.expectedError != nil { - continue - } - - creds, err := sess.Config.Credentials.Get() - if err != nil { - t.Errorf("expected no error, but received %v", err) - } - - if e, a := c.expectedAccessKey, creds.AccessKeyID; e != a { - t.Errorf("expected %v, but received %v", e, a) - } - - if e, a := c.expectedSecretKey, creds.SecretAccessKey; e != a { - t.Errorf("expected %v, but received %v", e, a) - } - - if err := clean(); err != nil { - t.Errorf("expected no error, but received %v", err) - } - } -} - -func initSessionTestEnv() (oldEnv []string) { - oldEnv = awstesting.StashEnv() - os.Setenv("AWS_CONFIG_FILE", "file_not_exists") - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists") - - return oldEnv } diff --git a/aws/session/shared_config.go b/aws/session/shared_config.go index 7cb44021b3f..e0102363ddd 100644 --- a/aws/session/shared_config.go +++ b/aws/session/shared_config.go @@ -156,10 +156,20 @@ func (cfg *sharedConfig) setAssumeRoleSource(origProfile string, files []sharedC if err != nil { return err } + + // Chain if profile depends of other profiles + if len(assumeRoleSrc.AssumeRole.SourceProfile) > 0 { + err := assumeRoleSrc.setAssumeRoleSource(cfg.AssumeRole.SourceProfile, files) + if err != nil { + return err + } + } } - if len(assumeRoleSrc.Creds.AccessKeyID) == 0 { - return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN} + if cfg.AssumeRole.SourceProfile == origProfile || len(assumeRoleSrc.AssumeRole.SourceProfile) == 0 { + if len(assumeRoleSrc.AssumeRole.CredentialSource) == 0 && len(assumeRoleSrc.Creds.AccessKeyID) == 0 { + return SharedConfigAssumeRoleError{RoleARN: cfg.AssumeRole.RoleARN} + } } cfg.AssumeRoleSource = &assumeRoleSrc diff --git a/aws/session/shared_config_test.go b/aws/session/shared_config_test.go index b92a09b6007..2c93a2d3699 100644 --- a/aws/session/shared_config_test.go +++ b/aws/session/shared_config_test.go @@ -125,6 +125,77 @@ func TestLoadSharedConfig(t *testing.T) { Profile: "profile_name", Err: SharedConfigLoadError{Filename: filepath.Join("testdata", "shared_config_invalid_ini")}, }, + { + Filenames: []string{testConfigOtherFilename, testConfigFilename}, + Profile: "assume_role_with_credential_source", + Expected: sharedConfig{ + AssumeRole: assumeRoleConfig{ + RoleARN: "assume_role_with_credential_source_role_arn", + CredentialSource: credSourceEc2Metadata, + }, + }, + }, + { + Filenames: []string{testConfigOtherFilename, testConfigFilename}, + Profile: "multiple_assume_role", + Expected: sharedConfig{ + AssumeRole: assumeRoleConfig{ + RoleARN: "multiple_assume_role_role_arn", + SourceProfile: "assume_role", + }, + AssumeRoleSource: &sharedConfig{ + AssumeRole: assumeRoleConfig{ + RoleARN: "assume_role_role_arn", + SourceProfile: "complete_creds", + }, + AssumeRoleSource: &sharedConfig{ + Creds: credentials.Value{ + AccessKeyID: "complete_creds_akid", + SecretAccessKey: "complete_creds_secret", + ProviderName: fmt.Sprintf("SharedConfigCredentials: %s", testConfigFilename), + }, + }, + }, + }, + }, + { + Filenames: []string{testConfigOtherFilename, testConfigFilename}, + Profile: "multiple_assume_role_with_credential_source", + Expected: sharedConfig{ + AssumeRole: assumeRoleConfig{ + RoleARN: "multiple_assume_role_with_credential_source_role_arn", + SourceProfile: "assume_role_with_credential_source", + }, + AssumeRoleSource: &sharedConfig{ + AssumeRole: assumeRoleConfig{ + RoleARN: "assume_role_with_credential_source_role_arn", + CredentialSource: credSourceEc2Metadata, + }, + }, + }, + }, + { + Filenames: []string{testConfigOtherFilename, testConfigFilename}, + Profile: "multiple_assume_role_with_credential_source2", + Expected: sharedConfig{ + AssumeRole: assumeRoleConfig{ + RoleARN: "multiple_assume_role_with_credential_source2_role_arn", + SourceProfile: "multiple_assume_role_with_credential_source", + }, + AssumeRoleSource: &sharedConfig{ + AssumeRole: assumeRoleConfig{ + RoleARN: "multiple_assume_role_with_credential_source_role_arn", + SourceProfile: "assume_role_with_credential_source", + }, + AssumeRoleSource: &sharedConfig{ + AssumeRole: assumeRoleConfig{ + RoleARN: "assume_role_with_credential_source_role_arn", + CredentialSource: credSourceEc2Metadata, + }, + }, + }, + }, + }, } for i, c := range cases { @@ -139,7 +210,7 @@ func TestLoadSharedConfig(t *testing.T) { if err != nil { t.Errorf("%d, expect nil, %v", i, err) } - if e, a := c.Expected, cfg; !reflect.DeepEqual(e,a) { + if e, a := c.Expected, cfg; !reflect.DeepEqual(e, a) { t.Errorf("%d, expect %v, got %v", i, e, a) } } @@ -249,7 +320,7 @@ func TestLoadSharedConfigFromFile(t *testing.T) { if err != nil { t.Errorf("%d, expect nil, %v", i, err) } - if e, a := c.Expected, cfg; e != a { + if e, a := c.Expected, cfg; !reflect.DeepEqual(e, a) { t.Errorf("%d, expect %v, got %v", i, e, a) } } diff --git a/aws/session/shared_test.go b/aws/session/shared_test.go new file mode 100644 index 00000000000..13f01677f1a --- /dev/null +++ b/aws/session/shared_test.go @@ -0,0 +1,15 @@ +package session + +import ( + "os" + + "github.com/aws/aws-sdk-go/awstesting" +) + +func initSessionTestEnv() (oldEnv []string) { + oldEnv = awstesting.StashEnv() + os.Setenv("AWS_CONFIG_FILE", "file_not_exists") + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists") + + return oldEnv +} diff --git a/aws/session/testdata/credential_source_config b/aws/session/testdata/credential_source_config index 58b66614ad3..4379eb84df9 100644 --- a/aws/session/testdata/credential_source_config +++ b/aws/session/testdata/credential_source_config @@ -1,16 +1,20 @@ [env_var_credential_source] -role_arn = arn +role_arn = assume_role_w_creds_role_arn_env credential_source = Environment [invalid_source_and_credential_source] -role_arn = arn +role_arn = assume_role_w_creds_role_arn_bad credential_source = Environment source_profile = env_var_credential_source [ec2metadata] -role_arn = assume_role_w_creds_role_arn +role_arn = assume_role_w_creds_role_arn_ec2 credential_source = Ec2InstanceMetadata [ecscontainer] -role_arn = assume_role_w_creds_role_arn +role_arn = assume_role_w_creds_role_arn_ecs credential_source = EcsContainer + +[chained_assume_role] +role_arn = assume_role_w_creds_role_arn_chain +source_profile = ec2metadata diff --git a/aws/session/testdata/shared_config b/aws/session/testdata/shared_config index fe816fe201b..7d645d0c1cf 100644 --- a/aws/session/testdata/shared_config +++ b/aws/session/testdata/shared_config @@ -63,3 +63,19 @@ aws_secret_access_key = assume_role_w_creds_secret [assume_role_wo_creds] role_arn = assume_role_wo_creds_role_arn source_profile = assume_role_wo_creds + +[assume_role_with_credential_source] +role_arn = assume_role_with_credential_source_role_arn +credential_source = Ec2InstanceMetadata + +[multiple_assume_role] +role_arn = multiple_assume_role_role_arn +source_profile = assume_role + +[multiple_assume_role_with_credential_source] +role_arn = multiple_assume_role_with_credential_source_role_arn +source_profile = assume_role_with_credential_source + +[multiple_assume_role_with_credential_source2] +role_arn = multiple_assume_role_with_credential_source2_role_arn +source_profile = multiple_assume_role_with_credential_source