Skip to content

Commit

Permalink
Using struct for AWS supplier to better match java implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
aeitzman committed Jan 17, 2024
1 parent 0d06665 commit 85ffd02
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 33 deletions.
30 changes: 16 additions & 14 deletions google/internal/externalaccount/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,10 @@ type awsCredentialSource struct {
IMDSv2SessionTokenURL string
TargetResource string
requestSigner *awsRequestSigner
Region string
region string
ctx context.Context
client *http.Client
AwsSecurityCredentialsSupplier func() (AwsSecurityCredentials, error)
awsSecurityCredentialsSupplier *AwsSecurityCredentialsSupplier
}

type awsRequestHeader struct {
Expand Down Expand Up @@ -300,11 +300,11 @@ func canRetrieveSecurityCredentialFromEnvironment() bool {
}

func (cs awsCredentialSource) shouldUseMetadataServer() bool {
return cs.AwsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
return cs.awsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
}

func (cs awsCredentialSource) credentialSourceType() string {
if cs.AwsSecurityCredentialsSupplier != nil {
if cs.awsSecurityCredentialsSupplier != nil {
return "programmatic"
}
return "aws"
Expand Down Expand Up @@ -332,22 +332,20 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
if err != nil {
return "", err
}
if cs.Region == "" {
cs.Region, err = cs.getRegion(headers)
if err != nil {
return "", err
}
cs.region, err = cs.getRegion(headers)
if err != nil {
return "", err
}

cs.requestSigner = &awsRequestSigner{
RegionName: cs.Region,
RegionName: cs.region,
AwsSecurityCredentials: awsSecurityCredentials,
}
}

// Generate the signed request to AWS STS GetCallerIdentity API.
// Use the required regional endpoint. Otherwise, the request will fail.
req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.Region, 1), nil)
req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -433,8 +431,12 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
}

func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
if cs.awsSecurityCredentialsSupplier != nil {
return cs.awsSecurityCredentialsSupplier.AwsRegion, nil
}
if canRetrieveRegionFromEnvironment() {
if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
cs.region = envAwsRegion
return envAwsRegion, nil
}
return getenv("AWS_DEFAULT_REGION"), nil
Expand Down Expand Up @@ -478,11 +480,11 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
}

func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result AwsSecurityCredentials, err error) {
if cs.AwsSecurityCredentialsSupplier != nil {
if cs.Region == "" {
if cs.awsSecurityCredentialsSupplier != nil {
if cs.awsSecurityCredentialsSupplier.AwsRegion == "" {
return result, errors.New("oauth2/google: AwsRegion must be provided when using an AwsSecurityCredentialsSupplier")
}
return cs.AwsSecurityCredentialsSupplier()
return cs.awsSecurityCredentialsSupplier.GetAwsSecurityCredentials()
}
if canRetrieveSecurityCredentialFromEnvironment() {
return AwsSecurityCredentials{
Expand Down
24 changes: 15 additions & 9 deletions google/internal/externalaccount/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1245,10 +1245,12 @@ func TestAWSCredential_ProgrammaticAuth(t *testing.T) {
}

tfc.CredentialSource = CredentialSource{
AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) {
return securityCredentials, nil
AwsSecurityCredentialsSupplier: &AwsSecurityCredentialsSupplier{
GetAwsSecurityCredentials: func() (AwsSecurityCredentials, error) {
return securityCredentials, nil
},
AwsRegion: "us-east-2",
},
AwsRegion: "us-east-2",
}

oldNow := now
Expand Down Expand Up @@ -1288,10 +1290,12 @@ func TestAWSCredential_ProgrammaticAuthNoSessionToken(t *testing.T) {
}

tfc.CredentialSource = CredentialSource{
AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) {
return securityCredentials, nil
AwsSecurityCredentialsSupplier: &AwsSecurityCredentialsSupplier{
GetAwsSecurityCredentials: func() (AwsSecurityCredentials, error) {
return securityCredentials, nil
},
AwsRegion: "us-east-2",
},
AwsRegion: "us-east-2",
}

oldNow := now
Expand Down Expand Up @@ -1327,10 +1331,12 @@ func TestAWSCredential_ProgrammaticAuthError(t *testing.T) {
tfc := testFileConfig

tfc.CredentialSource = CredentialSource{
AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) {
return AwsSecurityCredentials{}, errors.New("test error")
AwsSecurityCredentialsSupplier: &AwsSecurityCredentialsSupplier{
GetAwsSecurityCredentials: func() (AwsSecurityCredentials, error) {
return AwsSecurityCredentials{}, errors.New("test error")
},
AwsRegion: "us-east-2",
},
AwsRegion: "us-east-2",
}

base, err := tfc.parse(context.Background())
Expand Down
22 changes: 12 additions & 10 deletions google/internal/externalaccount/basecredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,9 @@ type format struct {
}

// CredentialSource stores the information necessary to retrieve the credentials for the STS exchange.
// One field amongst File, URL, Executable, SubjectTokenSupplier, or AwsSecurityCredentialSupplier
// One field amongst File, URL, Executable, SubjectTokenSupplier, or AwsSecurityCredentialSupplierConfig.
// should be filled, depending on the kind of credential in question.
// The EnvironmentID should start with AWS if being used for an AWS credential.
// AwsRegion is required when AwsSecurityCredentialsSupplier is used.
type CredentialSource struct {
// File is the location for file sourced credentials.
File string `json:"file"`
Expand All @@ -149,15 +148,13 @@ type CredentialSource struct {
IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"`
// Format is the format type for the subject token. Used for File and URL sourced credentials. Expected values are "text" or "json".
Format format `json:"format"`
// AwsRegion is the AWS region, required when an AwsSecurityCredentialsSupplier is provided.
AwsRegion string `json:"-"` // Ignore for json.

// SubjectTokenSupplier is an optional token supplier for OIDC/SAML credentials. This should be a function that returns
// a valid subject token as a string.
SubjectTokenSupplier func() (string, error) `json:"-"` // Ignore for json.
// AwsSecurityCredentialsSupplier is an optional AWS Security Credential supplier for AWS credentials. This should be a function
// that returns a valid AwsSecurityCredentials object.
AwsSecurityCredentialsSupplier func() (AwsSecurityCredentials, error) `json:"-"` // Ignore for json.
// AwsSecurityCredentialsSupplier is an optional AWS Security Credential supplier. This should contain a
// function that returns valid AwsSecurityCredentials and a valid AwsRegion.
AwsSecurityCredentialsSupplier *AwsSecurityCredentialsSupplier `json:"-"` // Ignore for json.
}

type ExecutableConfig struct {
Expand All @@ -166,6 +163,13 @@ type ExecutableConfig struct {
OutputFile string `json:"output_file"`
}

type AwsSecurityCredentialsSupplier struct {
// AwsRegion is the AWS region.
AwsRegion string
// GetAwsSecurityCredentials is a function that should return valid AwsSecurityCredentials.
GetAwsSecurityCredentials func() (AwsSecurityCredentials, error)
}

// parse determines the type of CredentialSource needed.
func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
//set Defaults
Expand All @@ -175,9 +179,8 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {

if c.CredentialSource.AwsSecurityCredentialsSupplier != nil {
awsCredSource := awsCredentialSource{
Region: c.CredentialSource.AwsRegion,
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
AwsSecurityCredentialsSupplier: c.CredentialSource.AwsSecurityCredentialsSupplier,
awsSecurityCredentialsSupplier: c.CredentialSource.AwsSecurityCredentialsSupplier,
TargetResource: c.Audience,
}
return awsCredSource, nil
Expand All @@ -195,7 +198,6 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
CredVerificationURL: c.CredentialSource.URL,
TargetResource: c.Audience,
Region: c.CredentialSource.AwsRegion,
ctx: ctx,
}
if c.CredentialSource.IMDSv2SessionTokenURL != "" {
Expand Down

0 comments on commit 85ffd02

Please sign in to comment.