From 82aa92eca34f8ee4a249327d22931f42277aba69 Mon Sep 17 00:00:00 2001 From: aeitzman Date: Mon, 18 Dec 2023 09:38:27 -0800 Subject: [PATCH] Adding programattic refreshable credentials --- google/internal/externalaccount/aws.go | 79 +++++++---- google/internal/externalaccount/aws_test.go | 124 +++++++++++++++++- .../externalaccount/basecredentials.go | 55 ++++++-- .../programmaticrefreshcredsource.go | 17 +++ .../programmaticrefreshcredsource_test.go | 71 ++++++++++ 5 files changed, 303 insertions(+), 43 deletions(-) create mode 100644 google/internal/externalaccount/programmaticrefreshcredsource.go create mode 100644 google/internal/externalaccount/programmaticrefreshcredsource_test.go diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index bd4efd19b..7ad9390ba 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -26,22 +26,28 @@ import ( "golang.org/x/oauth2" ) -type awsSecurityCredentials struct { - AccessKeyID string `json:"AccessKeyID"` +// Models AWS security credentials. +type AwsSecurityCredentials struct { + // AWS Access Key ID - Required. + AccessKeyID string `json:"AccessKeyID"` + // AWS Secret Access Key - Required. SecretAccessKey string `json:"SecretAccessKey"` - SecurityToken string `json:"Token"` + // AWS Session token - Optional. + SessionToken string `json:"Token"` } // awsRequestSigner is a utility class to sign http requests using a AWS V4 signature. type awsRequestSigner struct { RegionName string - AwsSecurityCredentials awsSecurityCredentials + AwsSecurityCredentials AwsSecurityCredentials } // getenv aliases os.Getenv for testing var getenv = os.Getenv const ( + defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" + // AWS Signature Version 4 signing algorithm identifier. awsAlgorithm = "AWS4-HMAC-SHA256" @@ -197,8 +203,8 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error { signedRequest.Header.Add("host", requestHost(req)) - if rs.AwsSecurityCredentials.SecurityToken != "" { - signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken) + if rs.AwsSecurityCredentials.SessionToken != "" { + signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken) } if signedRequest.Header.Get("date") == "" { @@ -251,16 +257,17 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp } type awsCredentialSource struct { - EnvironmentID string - RegionURL string - RegionalCredVerificationURL string - CredVerificationURL string - IMDSv2SessionTokenURL string - TargetResource string - requestSigner *awsRequestSigner - region string - ctx context.Context - client *http.Client + EnvironmentID string + RegionURL string + RegionalCredVerificationURL string + CredVerificationURL string + IMDSv2SessionTokenURL string + TargetResource string + requestSigner *awsRequestSigner + Region string + ctx context.Context + client *http.Client + AwsSecurityCredentialsSupplier func() (AwsSecurityCredentials, error) } type awsRequestHeader struct { @@ -292,18 +299,25 @@ func canRetrieveSecurityCredentialFromEnvironment() bool { return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != "" } -func shouldUseMetadataServer() bool { - return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment() +func (cs awsCredentialSource) shouldUseMetadataServer() bool { + return cs.AwsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()) } func (cs awsCredentialSource) credentialSourceType() string { + if cs.AwsSecurityCredentialsSupplier != nil { + return "programmatic" + } return "aws" } func (cs awsCredentialSource) subjectToken() (string, error) { + // Set Defaults + if cs.RegionalCredVerificationURL == "" { + cs.RegionalCredVerificationURL = defaultRegionalCredentialVerificationUrl + } if cs.requestSigner == nil { headers := make(map[string]string) - if shouldUseMetadataServer() { + if cs.shouldUseMetadataServer() { awsSessionToken, err := cs.getAWSSessionToken() if err != nil { return "", err @@ -318,20 +332,20 @@ func (cs awsCredentialSource) subjectToken() (string, error) { if err != nil { return "", err } - - if cs.region, err = cs.getRegion(headers); err != nil { + cs.Region, err = cs.getRegion(headers) + if err != nil { return "", err } cs.requestSigner = &awsRequestSigner{ - RegionName: cs.region, + RegionName: cs.Region, AwsSecurityCredentials: awsSecurityCredentials, } } // Generate the signed request to AWS STS GetCallerIdentity API. // Use the required regional endpoint. Otherwise, the request will fail. - req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil) + req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.Region, 1), nil) if err != nil { return "", err } @@ -417,6 +431,12 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { } func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { + if cs.Region != "" { + return cs.Region, nil + } + if cs.AwsSecurityCredentialsSupplier != nil { + return "", errors.New("oauth2/google: AwsRegion must be provided when using an AwsSecurityCredentialsSupplier") + } if canRetrieveRegionFromEnvironment() { if envAwsRegion := getenv(awsRegion); envAwsRegion != "" { return envAwsRegion, nil @@ -461,12 +481,15 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err return string(respBody[:respBodyEnd]), nil } -func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { +func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result AwsSecurityCredentials, err error) { + if cs.AwsSecurityCredentialsSupplier != nil { + return cs.AwsSecurityCredentialsSupplier() + } if canRetrieveSecurityCredentialFromEnvironment() { - return awsSecurityCredentials{ + return AwsSecurityCredentials{ AccessKeyID: getenv(awsAccessKeyId), SecretAccessKey: getenv(awsSecretAccessKey), - SecurityToken: getenv(awsSessionToken), + SessionToken: getenv(awsSessionToken), }, nil } @@ -491,8 +514,8 @@ func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) return credentials, nil } -func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) { - var result awsSecurityCredentials +func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (AwsSecurityCredentials, error) { + var result AwsSecurityCredentials req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) if err != nil { diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 28dc5284b..63eb1d913 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -7,6 +7,7 @@ package externalaccount import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -36,7 +37,7 @@ func setEnvironment(env map[string]string) func(string) string { var defaultRequestSigner = &awsRequestSigner{ RegionName: "us-east-1", - AwsSecurityCredentials: awsSecurityCredentials{ + AwsSecurityCredentials: AwsSecurityCredentials{ AccessKeyID: "AKIDEXAMPLE", SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", }, @@ -50,10 +51,10 @@ const ( var requestSignerWithToken = &awsRequestSigner{ RegionName: "us-east-2", - AwsSecurityCredentials: awsSecurityCredentials{ + AwsSecurityCredentials: AwsSecurityCredentials{ AccessKeyID: accessKeyID, SecretAccessKey: secretAccessKey, - SecurityToken: securityToken, + SessionToken: securityToken, }, } @@ -388,7 +389,7 @@ func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) { var requestSigner = &awsRequestSigner{ RegionName: "us-east-2", - AwsSecurityCredentials: awsSecurityCredentials{ + AwsSecurityCredentials: AwsSecurityCredentials{ AccessKeyID: accessKeyID, SecretAccessKey: secretAccessKey, }, @@ -541,10 +542,10 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security req.Header.Add("x-goog-cloud-target-resource", testFileConfig.Audience) signer := &awsRequestSigner{ RegionName: region, - AwsSecurityCredentials: awsSecurityCredentials{ + AwsSecurityCredentials: AwsSecurityCredentials{ AccessKeyID: accessKeyID, SecretAccessKey: secretAccessKey, - SecurityToken: securityToken, + SessionToken: securityToken, }, } signer.SignRequest(req) @@ -1235,6 +1236,117 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testin } } +func TestAWSCredential_ProgrammaticAuth(t *testing.T) { + tfc := testFileConfig + securityCredentials := AwsSecurityCredentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: securityToken, + } + + tfc.CredentialSource = CredentialSource{ + AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) { + return securityCredentials, nil + }, + AwsRegion: "us-east-2", + } + + oldNow := now + defer func() { + now = oldNow + }() + now = setTime(defaultTime) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ProgrammaticAuthNoSessionToken(t *testing.T) { + tfc := testFileConfig + securityCredentials := AwsSecurityCredentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + } + + tfc.CredentialSource = CredentialSource{ + AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) { + return securityCredentials, nil + }, + AwsRegion: "us-east-2", + } + + oldNow := now + defer func() { + now = oldNow + }() + now = setTime(defaultTime) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyID, + secretAccessKey, + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ProgrammaticAuthError(t *testing.T) { + tfc := testFileConfig + + tfc.CredentialSource = CredentialSource{ + AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) { + return AwsSecurityCredentials{}, errors.New("test error") + }, + AwsRegion: "us-east-2", + } + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() + if err == nil { + t.Fatalf("subjectToken() should have failed") + } + if got, want := err.Error(), "test error"; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = %q, want %q", got, want) + } +} + func TestAwsCredential_CredentialSourceType(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 33288d367..5c6129690 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -107,8 +107,9 @@ func (c *Config) tokenSource(ctx context.Context, scheme string) (oauth2.TokenSo // Subject token file types. const ( - fileTypeText = "text" - fileTypeJSON = "json" + fileTypeText = "text" + fileTypeJSON = "json" + defaultTokenUrl = "https://sts.googleapis.com/v1/token" ) type format struct { @@ -119,22 +120,42 @@ type format struct { } // CredentialSource stores the information necessary to retrieve the credentials for the STS exchange. -// One field amongst File, URL, and Executable should be filled, depending on the kind of credential in question. +// One field amongst File, URL, Executable, SubjectTokenSupplier, or AwsSecurityCredentialSupplier +// should be filled, depending on the kind of credential in question. // The EnvironmentID should start with AWS if being used for an AWS credential. type CredentialSource struct { + // File location for file sourced credentials. File string `json:"file"` - URL string `json:"url"` + // Url to call for URL sourced credentials. + URL string `json:"url"` + // Headers to attach to the request for URL sourced credentials. Headers map[string]string `json:"headers"` + // Configuration object for executable sourced credentials. Executable *ExecutableConfig `json:"executable"` - EnvironmentID string `json:"environment_id"` - RegionURL string `json:"region_url"` + // Environment ID used for AWS sourced credentials. + EnvironmentID string `json:"environment_id"` + // Metadata URL to retrieve the region from for EC2 AWS credentials. + RegionURL string `json:"region_url"` + // AWS regional credential verification URL, will default to "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" if not provided." RegionalCredVerificationURL string `json:"regional_cred_verification_url"` - CredVerificationURL string `json:"cred_verification_url"` - IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"` - Format format `json:"format"` + // DEPRECATED + CredVerificationURL string `json:"cred_verification_url"` + // URL to retrieve the session token when using IMDSv2 in AWS. + IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"` + // Format type for the subject token. Used for File and URL sourced credentials. Expected values are "text" or "json". + Format format `json:"format"` + // AWS region, required when an AwsSecurityCredentials Supplier is provided. + AwsRegion string `json:"-"` // Ignore for json. + + // Token supplier for OIDC/SAML credentials. This should be a function that returns + // a valid subject token as a string. + SubjectTokenSupplier func() (string, error) `json:"-"` // Ignore for json. + // AWS Security Credential supplier for AWS credentials. This should be a function + // that returns a valid AwsSecurityCredentials object. + AwsSecurityCredentialsSupplier func() (AwsSecurityCredentials, error) `json:"-"` // Ignore for json. } type ExecutableConfig struct { @@ -145,6 +166,11 @@ type ExecutableConfig struct { // parse determines the type of CredentialSource needed. func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { + //set Defaults + if c.TokenURL == "" { + c.TokenURL = defaultTokenUrl + } + if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" { if awsVersion, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil { if awsVersion != 1 { @@ -157,6 +183,7 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, CredVerificationURL: c.CredentialSource.URL, TargetResource: c.Audience, + Region: c.CredentialSource.AwsRegion, ctx: ctx, } if c.CredentialSource.IMDSv2SessionTokenURL != "" { @@ -165,12 +192,22 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { return awsCredSource, nil } + } else if c.CredentialSource.AwsSecurityCredentialsSupplier != nil { + awsCredSource := awsCredentialSource{ + Region: c.CredentialSource.AwsRegion, + RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, + AwsSecurityCredentialsSupplier: c.CredentialSource.AwsSecurityCredentialsSupplier, + TargetResource: c.Audience, + } + return awsCredSource, nil } else if c.CredentialSource.File != "" { return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil } else if c.CredentialSource.URL != "" { return urlCredentialSource{URL: c.CredentialSource.URL, Headers: c.CredentialSource.Headers, Format: c.CredentialSource.Format, ctx: ctx}, nil } else if c.CredentialSource.Executable != nil { return CreateExecutableCredential(ctx, c.CredentialSource.Executable, c) + } else if c.CredentialSource.SubjectTokenSupplier != nil { + return programmaticRefreshCredentialSource{SubjectTokenSupplier: c.CredentialSource.SubjectTokenSupplier}, nil } return nil, fmt.Errorf("oauth2/google: unable to parse credential source") } diff --git a/google/internal/externalaccount/programmaticrefreshcredsource.go b/google/internal/externalaccount/programmaticrefreshcredsource.go new file mode 100644 index 000000000..6808930e1 --- /dev/null +++ b/google/internal/externalaccount/programmaticrefreshcredsource.go @@ -0,0 +1,17 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package externalaccount + +type programmaticRefreshCredentialSource struct { + SubjectTokenSupplier func() (string, error) +} + +func (cs programmaticRefreshCredentialSource) credentialSourceType() string { + return "programmatic" +} + +func (cs programmaticRefreshCredentialSource) subjectToken() (string, error) { + return cs.SubjectTokenSupplier() +} diff --git a/google/internal/externalaccount/programmaticrefreshcredsource_test.go b/google/internal/externalaccount/programmaticrefreshcredsource_test.go new file mode 100644 index 000000000..9e5454826 --- /dev/null +++ b/google/internal/externalaccount/programmaticrefreshcredsource_test.go @@ -0,0 +1,71 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package externalaccount + +import ( + "context" + "errors" + "reflect" + "testing" +) + +func TestRetrieveSubjectToken_ProgrammaticAuth(t *testing.T) { + tfc := testConfig + + tfc.CredentialSource = CredentialSource{ + SubjectTokenSupplier: func() (string, error) { + return "subjectToken", nil + }, + } + + oldNow := now + defer func() { + now = oldNow + }() + now = setTime(defaultTime) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + if got, want := out, "subjectToken"; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestRetrieveSubjectToken_ProgrammaticAuthFails(t *testing.T) { + tfc := testConfig + + tfc.CredentialSource = CredentialSource{ + SubjectTokenSupplier: func() (string, error) { + return "", errors.New("test error") + }, + } + + oldNow := now + defer func() { + now = oldNow + }() + now = setTime(defaultTime) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() + if err == nil { + t.Fatalf("subjectToken() should have failed") + } + if got, want := err.Error(), "test error"; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = %q, want %q", got, want) + } +}