Skip to content

Commit

Permalink
google/internal/externalaccount: Added check for aws region and secur…
Browse files Browse the repository at this point in the history
…ity credential environment variables before aws metadata call

Adds check for aws values in environment variables before the metadata server is called to prevent unnecessary off box calls. See googleapis/google-auth-library-java#1100 for same change in java library.

Change-Id: Ie86a899be88c38d3fcbbe377f9bf30a7a66530c0
GitHub-Last-Rev: bcab695
GitHub-Pull-Request: #612
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/453715
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
  • Loading branch information
aeitzman authored and gopherbot committed Nov 30, 2022
1 parent ec4a9b2 commit 510acbc
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 49 deletions.
62 changes: 42 additions & 20 deletions google/internal/externalaccount/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ const (
// The AWS authorization header name for the auto-generated date.
awsDateHeader = "x-amz-date"

// Supported AWS configuration environment variables.
awsAccessKeyId = "AWS_ACCESS_KEY_ID"
awsDefaultRegion = "AWS_DEFAULT_REGION"
awsRegion = "AWS_REGION"
awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
awsSessionToken = "AWS_SESSION_TOKEN"

awsTimeFormatLong = "20060102T150405Z"
awsTimeFormatShort = "20060102"
)
Expand Down Expand Up @@ -317,16 +324,33 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro
return cs.client.Do(req.WithContext(cs.ctx))
}

func canRetrieveRegionFromEnvironment() bool {
// The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is
// required.
return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != ""
}

func canRetrieveSecurityCredentialFromEnvironment() bool {
// Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available.
return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != ""
}

func shouldUseMetadataServer() bool {
return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()
}

func (cs awsCredentialSource) subjectToken() (string, error) {
if cs.requestSigner == nil {
awsSessionToken, err := cs.getAWSSessionToken()
if err != nil {
return "", err
}

headers := make(map[string]string)
if awsSessionToken != "" {
headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
if shouldUseMetadataServer() {
awsSessionToken, err := cs.getAWSSessionToken()
if err != nil {
return "", err
}

if awsSessionToken != "" {
headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
}
}

awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
Expand Down Expand Up @@ -432,11 +456,11 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
}

func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
return envAwsRegion, nil
}
if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" {
return envAwsRegion, nil
if canRetrieveRegionFromEnvironment() {
if envAwsRegion := getenv(awsRegion); envAwsRegion != "" {
return envAwsRegion, nil
}
return getenv("AWS_DEFAULT_REGION"), nil
}

if cs.RegionURL == "" {
Expand Down Expand Up @@ -477,14 +501,12 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
}

func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) {
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
return awsSecurityCredentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SecurityToken: getenv("AWS_SESSION_TOKEN"),
}, nil
}
if canRetrieveSecurityCredentialFromEnvironment() {
return awsSecurityCredentials{
AccessKeyID: getenv(awsAccessKeyId),
SecretAccessKey: getenv(awsSecretAccessKey),
SecurityToken: getenv(awsSessionToken),
}, nil
}

roleName, err := cs.getMetadataRoleName(headers)
Expand Down
264 changes: 235 additions & 29 deletions google/internal/externalaccount/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,38 @@ func createDefaultAwsTestServer() *testAwsServer {
)
}

func createDefaultAwsTestServerWithImdsv2(t *testing.T) *testAwsServer {
validateSessionTokenHeaders := func(r *http.Request) {
if r.URL.Path == "/latest/api/token" {
headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader)
if headerValue != awsIMDSv2SessionTtl {
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl)
}
} else {
headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader)
if headerValue != "sessiontoken" {
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken")
}
}
}

return createAwsTestServer(
"/latest/meta-data/iam/security-credentials",
"/latest/meta-data/placement/availability-zone",
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"/latest/api/token",
"gcp-aws-role",
"us-east-2b",
map[string]string{
"SecretAccessKey": secretAccessKey,
"AccessKeyId": accessKeyID,
"Token": securityToken,
},
"sessiontoken",
validateSessionTokenHeaders,
)
}

func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch p := r.URL.Path; p {
case server.url:
Expand Down Expand Up @@ -597,35 +629,7 @@ func TestAWSCredential_BasicRequest(t *testing.T) {
}

func TestAWSCredential_IMDSv2(t *testing.T) {
validateSessionTokenHeaders := func(r *http.Request) {
if r.URL.Path == "/latest/api/token" {
headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader)
if headerValue != awsIMDSv2SessionTtl {
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl)
}
} else {
headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader)
if headerValue != "sessiontoken" {
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken")
}
}
}

server := createAwsTestServer(
"/latest/meta-data/iam/security-credentials",
"/latest/meta-data/placement/availability-zone",
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"/latest/api/token",
"gcp-aws-role",
"us-east-2b",
map[string]string{
"SecretAccessKey": secretAccessKey,
"AccessKeyId": accessKeyID,
"Token": securityToken,
},
"sessiontoken",
validateSessionTokenHeaders,
)
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
Expand Down Expand Up @@ -1152,6 +1156,208 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
}
}

func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}

metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("Metadata server should not have been called.")
}))

tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.IMDSv2SessionTokenURL = metadataTs.URL

oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}

base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}

out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}

expected := getExpectedSubjectToken(
"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-west-1",
"AKIDEXAMPLE",
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"",
)

if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}

func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}

tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)

oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": accessKeyID,
"AWS_SECRET_ACCESS_KEY": secretAccessKey,
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}

base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}

out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}

expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyID,
secretAccessKey,
"",
)

if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}

func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}

tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)

oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}

base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}

out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}

expected := getExpectedSubjectToken(
"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-west-1",
accessKeyID,
secretAccessKey,
securityToken,
)

if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}

func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
server := createDefaultAwsTestServerWithImdsv2(t)
ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}

tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)

oldGetenv := getenv
oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_REGION": "us-west-1",
})
now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}

base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}

out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}

expected := getExpectedSubjectToken(
"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-west-1",
accessKeyID,
secretAccessKey,
securityToken,
)

if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}

func TestAWSCredential_Validations(t *testing.T) {
var metadataServerValidityTests = []struct {
name string
Expand Down

0 comments on commit 510acbc

Please sign in to comment.