Skip to content

Commit

Permalink
Adding programattic refreshable credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
aeitzman committed Dec 18, 2023
1 parent 6e9ec93 commit 82aa92e
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 43 deletions.
79 changes: 51 additions & 28 deletions google/internal/externalaccount/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,28 @@ import (
"golang.org/x/oauth2"
)

type awsSecurityCredentials struct {
AccessKeyID string `json:"AccessKeyID"`
// Models AWS security credentials.
type AwsSecurityCredentials struct {
// AWS Access Key ID - Required.
AccessKeyID string `json:"AccessKeyID"`
// AWS Secret Access Key - Required.
SecretAccessKey string `json:"SecretAccessKey"`
SecurityToken string `json:"Token"`
// AWS Session token - Optional.
SessionToken string `json:"Token"`
}

// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsRequestSigner struct {
RegionName string
AwsSecurityCredentials awsSecurityCredentials
AwsSecurityCredentials AwsSecurityCredentials
}

// getenv aliases os.Getenv for testing
var getenv = os.Getenv

const (
defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"

// AWS Signature Version 4 signing algorithm identifier.
awsAlgorithm = "AWS4-HMAC-SHA256"

Expand Down Expand Up @@ -197,8 +203,8 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {

signedRequest.Header.Add("host", requestHost(req))

if rs.AwsSecurityCredentials.SecurityToken != "" {
signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken)
if rs.AwsSecurityCredentials.SessionToken != "" {
signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken)
}

if signedRequest.Header.Get("date") == "" {
Expand Down Expand Up @@ -251,16 +257,17 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
}

type awsCredentialSource struct {
EnvironmentID string
RegionURL string
RegionalCredVerificationURL string
CredVerificationURL string
IMDSv2SessionTokenURL string
TargetResource string
requestSigner *awsRequestSigner
region string
ctx context.Context
client *http.Client
EnvironmentID string
RegionURL string
RegionalCredVerificationURL string
CredVerificationURL string
IMDSv2SessionTokenURL string
TargetResource string
requestSigner *awsRequestSigner
Region string
ctx context.Context
client *http.Client
AwsSecurityCredentialsSupplier func() (AwsSecurityCredentials, error)
}

type awsRequestHeader struct {
Expand Down Expand Up @@ -292,18 +299,25 @@ func canRetrieveSecurityCredentialFromEnvironment() bool {
return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
}

func shouldUseMetadataServer() bool {
return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()
func (cs awsCredentialSource) shouldUseMetadataServer() bool {
return cs.AwsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment())
}

func (cs awsCredentialSource) credentialSourceType() string {
if cs.AwsSecurityCredentialsSupplier != nil {
return "programmatic"
}
return "aws"
}

func (cs awsCredentialSource) subjectToken() (string, error) {
// Set Defaults
if cs.RegionalCredVerificationURL == "" {
cs.RegionalCredVerificationURL = defaultRegionalCredentialVerificationUrl
}
if cs.requestSigner == nil {
headers := make(map[string]string)
if shouldUseMetadataServer() {
if cs.shouldUseMetadataServer() {
awsSessionToken, err := cs.getAWSSessionToken()
if err != nil {
return "", err
Expand All @@ -318,20 +332,20 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
if err != nil {
return "", err
}

if cs.region, err = cs.getRegion(headers); err != nil {
cs.Region, err = cs.getRegion(headers)
if err != nil {
return "", err
}

cs.requestSigner = &awsRequestSigner{
RegionName: cs.region,
RegionName: cs.Region,
AwsSecurityCredentials: awsSecurityCredentials,
}
}

// Generate the signed request to AWS STS GetCallerIdentity API.
// Use the required regional endpoint. Otherwise, the request will fail.
req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.Region, 1), nil)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -417,6 +431,12 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
}

func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
if cs.Region != "" {
return cs.Region, nil
}
if cs.AwsSecurityCredentialsSupplier != nil {
return "", errors.New("oauth2/google: AwsRegion must be provided when using an AwsSecurityCredentialsSupplier")
}
if canRetrieveRegionFromEnvironment() {
if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
return envAwsRegion, nil
Expand Down Expand Up @@ -461,12 +481,15 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
return string(respBody[:respBodyEnd]), nil
}

func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) {
func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result AwsSecurityCredentials, err error) {
if cs.AwsSecurityCredentialsSupplier != nil {
return cs.AwsSecurityCredentialsSupplier()
}
if canRetrieveSecurityCredentialFromEnvironment() {
return awsSecurityCredentials{
return AwsSecurityCredentials{
AccessKeyID: getenv(awsAccessKeyId),
SecretAccessKey: getenv(awsSecretAccessKey),
SecurityToken: getenv(awsSessionToken),
SessionToken: getenv(awsSessionToken),
}, nil
}

Expand All @@ -491,8 +514,8 @@ func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string)
return credentials, nil
}

func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) {
var result awsSecurityCredentials
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (AwsSecurityCredentials, error) {
var result AwsSecurityCredentials

req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
if err != nil {
Expand Down
124 changes: 118 additions & 6 deletions google/internal/externalaccount/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package externalaccount
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -36,7 +37,7 @@ func setEnvironment(env map[string]string) func(string) string {

var defaultRequestSigner = &awsRequestSigner{
RegionName: "us-east-1",
AwsSecurityCredentials: awsSecurityCredentials{
AwsSecurityCredentials: AwsSecurityCredentials{
AccessKeyID: "AKIDEXAMPLE",
SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
},
Expand All @@ -50,10 +51,10 @@ const (

var requestSignerWithToken = &awsRequestSigner{
RegionName: "us-east-2",
AwsSecurityCredentials: awsSecurityCredentials{
AwsSecurityCredentials: AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SecurityToken: securityToken,
SessionToken: securityToken,
},
}

Expand Down Expand Up @@ -388,7 +389,7 @@ func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test
func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
var requestSigner = &awsRequestSigner{
RegionName: "us-east-2",
AwsSecurityCredentials: awsSecurityCredentials{
AwsSecurityCredentials: AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
},
Expand Down Expand Up @@ -541,10 +542,10 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
req.Header.Add("x-goog-cloud-target-resource", testFileConfig.Audience)
signer := &awsRequestSigner{
RegionName: region,
AwsSecurityCredentials: awsSecurityCredentials{
AwsSecurityCredentials: AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SecurityToken: securityToken,
SessionToken: securityToken,
},
}
signer.SignRequest(req)
Expand Down Expand Up @@ -1235,6 +1236,117 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testin
}
}

func TestAWSCredential_ProgrammaticAuth(t *testing.T) {
tfc := testFileConfig
securityCredentials := AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: securityToken,
}

tfc.CredentialSource = CredentialSource{
AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) {
return securityCredentials, nil
},
AwsRegion: "us-east-2",
}

oldNow := now
defer func() {
now = oldNow
}()
now = setTime(defaultTime)

base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}

out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}

expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyID,
secretAccessKey,
securityToken,
)

if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}

func TestAWSCredential_ProgrammaticAuthNoSessionToken(t *testing.T) {
tfc := testFileConfig
securityCredentials := AwsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
}

tfc.CredentialSource = CredentialSource{
AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) {
return securityCredentials, nil
},
AwsRegion: "us-east-2",
}

oldNow := now
defer func() {
now = oldNow
}()
now = setTime(defaultTime)

base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}

out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}

expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyID,
secretAccessKey,
"",
)

if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}

func TestAWSCredential_ProgrammaticAuthError(t *testing.T) {
tfc := testFileConfig

tfc.CredentialSource = CredentialSource{
AwsSecurityCredentialsSupplier: func() (AwsSecurityCredentials, error) {
return AwsSecurityCredentials{}, errors.New("test error")
},
AwsRegion: "us-east-2",
}

base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}

_, err = base.subjectToken()
if err == nil {
t.Fatalf("subjectToken() should have failed")
}
if got, want := err.Error(), "test error"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}

func TestAwsCredential_CredentialSourceType(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
Expand Down
Loading

0 comments on commit 82aa92e

Please sign in to comment.