From 85ffd0280f7d5fd9df04718a4bde0ef56ff4c2e2 Mon Sep 17 00:00:00 2001 From: aeitzman Date: Tue, 16 Jan 2024 16:21:42 -0800 Subject: [PATCH] Using struct for AWS supplier to better match java implementation --- google/internal/externalaccount/aws.go | 30 ++++++++++--------- google/internal/externalaccount/aws_test.go | 24 +++++++++------ .../externalaccount/basecredentials.go | 22 +++++++------- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index d550ed5a4..61e3c7fcc 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -264,10 +264,10 @@ type awsCredentialSource struct { IMDSv2SessionTokenURL string TargetResource string requestSigner *awsRequestSigner - Region string + region string ctx context.Context client *http.Client - AwsSecurityCredentialsSupplier func() (AwsSecurityCredentials, error) + awsSecurityCredentialsSupplier *AwsSecurityCredentialsSupplier } type awsRequestHeader struct { @@ -300,11 +300,11 @@ func canRetrieveSecurityCredentialFromEnvironment() bool { } func (cs awsCredentialSource) shouldUseMetadataServer() bool { - return cs.AwsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()) + return cs.awsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()) } func (cs awsCredentialSource) credentialSourceType() string { - if cs.AwsSecurityCredentialsSupplier != nil { + if cs.awsSecurityCredentialsSupplier != nil { return "programmatic" } return "aws" @@ -332,22 +332,20 @@ func (cs awsCredentialSource) subjectToken() (string, error) { if err != nil { return "", err } - if cs.Region == "" { - cs.Region, err = cs.getRegion(headers) - if err != nil { - return "", err - } + cs.region, err = cs.getRegion(headers) + if err != nil { + return "", err } cs.requestSigner = &awsRequestSigner{ - RegionName: cs.Region, + RegionName: cs.region, AwsSecurityCredentials: awsSecurityCredentials, } } // Generate the signed request to AWS STS GetCallerIdentity API. // Use the required regional endpoint. Otherwise, the request will fail. - req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.Region, 1), nil) + req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil) if err != nil { return "", err } @@ -433,8 +431,12 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { } func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { + if cs.awsSecurityCredentialsSupplier != nil { + return cs.awsSecurityCredentialsSupplier.AwsRegion, nil + } if canRetrieveRegionFromEnvironment() { if envAwsRegion := getenv(awsRegion); envAwsRegion != "" { + cs.region = envAwsRegion return envAwsRegion, nil } return getenv("AWS_DEFAULT_REGION"), nil @@ -478,11 +480,11 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err } func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result AwsSecurityCredentials, err error) { - if cs.AwsSecurityCredentialsSupplier != nil { - if cs.Region == "" { + if cs.awsSecurityCredentialsSupplier != nil { + if cs.awsSecurityCredentialsSupplier.AwsRegion == "" { return result, errors.New("oauth2/google: AwsRegion must be provided when using an AwsSecurityCredentialsSupplier") } - return cs.AwsSecurityCredentialsSupplier() + return cs.awsSecurityCredentialsSupplier.GetAwsSecurityCredentials() } if canRetrieveSecurityCredentialFromEnvironment() { return AwsSecurityCredentials{ diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 63eb1d913..cd35c5da8 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -1245,10 +1245,12 @@ func TestAWSCredential_ProgrammaticAuth(t *testing.T) { } tfc.CredentialSource = CredentialSource{ - AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) { - return securityCredentials, nil + AwsSecurityCredentialsSupplier: &AwsSecurityCredentialsSupplier{ + GetAwsSecurityCredentials: func() (AwsSecurityCredentials, error) { + return securityCredentials, nil + }, + AwsRegion: "us-east-2", }, - AwsRegion: "us-east-2", } oldNow := now @@ -1288,10 +1290,12 @@ func TestAWSCredential_ProgrammaticAuthNoSessionToken(t *testing.T) { } tfc.CredentialSource = CredentialSource{ - AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) { - return securityCredentials, nil + AwsSecurityCredentialsSupplier: &AwsSecurityCredentialsSupplier{ + GetAwsSecurityCredentials: func() (AwsSecurityCredentials, error) { + return securityCredentials, nil + }, + AwsRegion: "us-east-2", }, - AwsRegion: "us-east-2", } oldNow := now @@ -1327,10 +1331,12 @@ func TestAWSCredential_ProgrammaticAuthError(t *testing.T) { tfc := testFileConfig tfc.CredentialSource = CredentialSource{ - AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) { - return AwsSecurityCredentials{}, errors.New("test error") + AwsSecurityCredentialsSupplier: &AwsSecurityCredentialsSupplier{ + GetAwsSecurityCredentials: func() (AwsSecurityCredentials, error) { + return AwsSecurityCredentials{}, errors.New("test error") + }, + AwsRegion: "us-east-2", }, - AwsRegion: "us-east-2", } base, err := tfc.parse(context.Background()) diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 432f3c8c9..8eca02113 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -120,10 +120,9 @@ type format struct { } // CredentialSource stores the information necessary to retrieve the credentials for the STS exchange. -// One field amongst File, URL, Executable, SubjectTokenSupplier, or AwsSecurityCredentialSupplier +// One field amongst File, URL, Executable, SubjectTokenSupplier, or AwsSecurityCredentialSupplierConfig. // should be filled, depending on the kind of credential in question. // The EnvironmentID should start with AWS if being used for an AWS credential. -// AwsRegion is required when AwsSecurityCredentialsSupplier is used. type CredentialSource struct { // File is the location for file sourced credentials. File string `json:"file"` @@ -149,15 +148,13 @@ type CredentialSource struct { IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"` // Format is the format type for the subject token. Used for File and URL sourced credentials. Expected values are "text" or "json". Format format `json:"format"` - // AwsRegion is the AWS region, required when an AwsSecurityCredentialsSupplier is provided. - AwsRegion string `json:"-"` // Ignore for json. // SubjectTokenSupplier is an optional token supplier for OIDC/SAML credentials. This should be a function that returns // a valid subject token as a string. SubjectTokenSupplier func() (string, error) `json:"-"` // Ignore for json. - // AwsSecurityCredentialsSupplier is an optional AWS Security Credential supplier for AWS credentials. This should be a function - // that returns a valid AwsSecurityCredentials object. - AwsSecurityCredentialsSupplier func() (AwsSecurityCredentials, error) `json:"-"` // Ignore for json. + // AwsSecurityCredentialsSupplier is an optional AWS Security Credential supplier. This should contain a + // function that returns valid AwsSecurityCredentials and a valid AwsRegion. + AwsSecurityCredentialsSupplier *AwsSecurityCredentialsSupplier `json:"-"` // Ignore for json. } type ExecutableConfig struct { @@ -166,6 +163,13 @@ type ExecutableConfig struct { OutputFile string `json:"output_file"` } +type AwsSecurityCredentialsSupplier struct { + // AwsRegion is the AWS region. + AwsRegion string + // GetAwsSecurityCredentials is a function that should return valid AwsSecurityCredentials. + GetAwsSecurityCredentials func() (AwsSecurityCredentials, error) +} + // parse determines the type of CredentialSource needed. func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { //set Defaults @@ -175,9 +179,8 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { if c.CredentialSource.AwsSecurityCredentialsSupplier != nil { awsCredSource := awsCredentialSource{ - Region: c.CredentialSource.AwsRegion, RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, - AwsSecurityCredentialsSupplier: c.CredentialSource.AwsSecurityCredentialsSupplier, + awsSecurityCredentialsSupplier: c.CredentialSource.AwsSecurityCredentialsSupplier, TargetResource: c.Audience, } return awsCredSource, nil @@ -195,7 +198,6 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, CredVerificationURL: c.CredentialSource.URL, TargetResource: c.Audience, - Region: c.CredentialSource.AwsRegion, ctx: ctx, } if c.CredentialSource.IMDSv2SessionTokenURL != "" {