Skip to content

Commit

Permalink
Resolve AWS IAM unique IDs (#2814)
Browse files Browse the repository at this point in the history
  • Loading branch information
joelthompson authored and jefferai committed Jun 7, 2017
1 parent 5be733d commit d858511
Show file tree
Hide file tree
Showing 9 changed files with 558 additions and 130 deletions.
90 changes: 90 additions & 0 deletions builtin/credential/aws/backend.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package awsauth

import (
"fmt"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/hashicorp/vault/logical"
Expand Down Expand Up @@ -54,6 +56,15 @@ type backend struct {
// When the credentials are modified or deleted, all the cached client objects
// will be flushed. The empty STS role signifies the master account
IAMClientsMap map[string]map[string]*iam.IAM

// AWS Account ID of the "default" AWS credentials
// This cache avoids the need to call GetCallerIdentity repeatedly to learn it
// We can't store this because, in certain pathological cases, it could change
// out from under us, such as a standby and active Vault server in different AWS
// accounts using their IAM instance profile to get their credentials.
defaultAWSAccountID string

resolveArnToUniqueIDFunc func(logical.Storage, string) (string, error)
}

func Backend(conf *logical.BackendConfig) (*backend, error) {
Expand All @@ -65,6 +76,8 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
IAMClientsMap: make(map[string]map[string]*iam.IAM),
}

b.resolveArnToUniqueIDFunc = b.resolveArnToRealUniqueId

b.Backend = &framework.Backend{
PeriodicFunc: b.periodicFunc,
AuthRenew: b.pathLoginRenew,
Expand Down Expand Up @@ -171,9 +184,86 @@ func (b *backend) invalidate(key string) {
defer b.configMutex.Unlock()
b.flushCachedEC2Clients()
b.flushCachedIAMClients()
b.defaultAWSAccountID = ""
}
}

// Putting this here so we can inject a fake resolver into the backend for unit testing
// purposes
func (b *backend) resolveArnToRealUniqueId(s logical.Storage, arn string) (string, error) {
entity, err := parseIamArn(arn)
if err != nil {
return "", err
}
// This odd-looking code is here because IAM is an inherently global service. IAM and STS ARNs
// don't have regions in them, and there is only a single global endpoint for IAM; see
// http://docs.aws.amazon.com/general/latest/gr/rande.html#iam_region
// However, the ARNs do have a partition in them, because the GovCloud and China partitions DO
// have their own separate endpoints, and the partition is encoded in the ARN. If Amazon's Go SDK
// would allow us to pass a partition back to the IAM client, it would be much simpler. But it
// doesn't appear that's possible, so in order to properly support GovCloud and China, we do a
// circular dance of extracting the partition from the ARN, finding any arbitrary region in the
// partition, and passing that region back back to the SDK, so that the SDK can figure out the
// proper partition from the arbitrary region we passed in to look up the endpoint.
// Sigh
region := getAnyRegionForAwsPartition(entity.Partition)
if region == nil {
return "", fmt.Errorf("Unable to resolve partition %q to a region", entity.Partition)
}
iamClient, err := b.clientIAM(s, region.ID(), entity.AccountNumber)
if err != nil {
return "", err
}

switch entity.Type {
case "user":
userInfo, err := iamClient.GetUser(&iam.GetUserInput{UserName: &entity.FriendlyName})
if err != nil {
return "", err
}
if userInfo == nil {
return "", fmt.Errorf("got nil result from GetUser")
}
return *userInfo.User.UserId, nil
case "role":
roleInfo, err := iamClient.GetRole(&iam.GetRoleInput{RoleName: &entity.FriendlyName})
if err != nil {
return "", err
}
if roleInfo == nil {
return "", fmt.Errorf("got nil result from GetRole")
}
return *roleInfo.Role.RoleId, nil
case "instance-profile":
profileInfo, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName})
if err != nil {
return "", err
}
if profileInfo == nil {
return "", fmt.Errorf("got nil result from GetInstanceProfile")
}
return *profileInfo.InstanceProfile.InstanceProfileId, nil
default:
return "", fmt.Errorf("unrecognized error type %#v", entity.Type)
}
}

// Adapted from https://docs.aws.amazon.com/sdk-for-go/api/aws/endpoints/
// the "Enumerating Regions and Endpoint Metadata" section
func getAnyRegionForAwsPartition(partitionId string) *endpoints.Region {
resolver := endpoints.DefaultResolver()
partitions := resolver.(endpoints.EnumPartitions).Partitions()

for _, p := range partitions {
if p.ID() == partitionId {
for _, r := range p.Regions() {
return &r
}
}
}
return nil
}

const backendHelp = `
aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
created nonce to authenticates the EC2 instance with Vault.
Expand Down
67 changes: 63 additions & 4 deletions builtin/credential/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (
"os"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/vault/helper/policyutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
logicaltest "github.com/hashicorp/vault/logical/testing"
)

Expand Down Expand Up @@ -1346,7 +1348,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
if err != nil {
t.Fatalf("Received error retrieving identity: %s", err)
}
testIdentityArn, _, _, err := parseIamArn(*testIdentity.Arn)
entity, err := parseIamArn(*testIdentity.Arn)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1385,7 +1387,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {

// configuring the valid role we'll be able to login to
roleData := map[string]interface{}{
"bound_iam_principal_arn": testIdentityArn,
"bound_iam_principal_arn": entity.canonicalArn(),
"policies": "root",
"auth_type": iamAuthType,
}
Expand Down Expand Up @@ -1417,8 +1419,17 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
t.Fatalf("bad: failed to create role; resp:%#v\nerr:%v", resp, err)
}

fakeArn := "arn:aws:iam::123456789012:role/FakeRole"
fakeArnResolver := func(s logical.Storage, arn string) (string, error) {
if arn == fakeArn {
return fmt.Sprintf("FakeUniqueIdFor%s", fakeArn), nil
}
return b.resolveArnToRealUniqueId(s, arn)
}
b.resolveArnToUniqueIDFunc = fakeArnResolver

// now we're creating the invalid role we won't be able to login to
roleData["bound_iam_principal_arn"] = "arn:aws:iam::123456789012:role/FakeRole"
roleData["bound_iam_principal_arn"] = fakeArn
roleRequest.Path = "role/" + testInvalidRoleName
resp, err = b.HandleRequest(roleRequest)
if err != nil || (resp != nil && resp.IsError()) {
Expand Down Expand Up @@ -1491,7 +1502,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
t.Errorf("bad: expected failed login due to bad auth type: resp:%#v\nerr:%v", resp, err)
}

// finally, the happy path tests :)
// finally, the happy path test :)

loginData["role"] = testValidRoleName
resp, err = b.HandleRequest(loginRequest)
Expand All @@ -1501,4 +1512,52 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
if resp == nil || resp.Auth == nil || resp.IsError() {
t.Errorf("bad: expected valid login: resp:%#v", resp)
}

renewReq := &logical.Request{
Storage: storage,
Auth: &logical.Auth{},
}
empty_login_fd := &framework.FieldData{
Raw: map[string]interface{}{},
Schema: pathLogin(b).Fields,
}
renewReq.Auth.InternalData = resp.Auth.InternalData
renewReq.Auth.Metadata = resp.Auth.Metadata
renewReq.Auth.LeaseOptions = resp.Auth.LeaseOptions
renewReq.Auth.Policies = resp.Auth.Policies
renewReq.Auth.IssueTime = time.Now()
// ensure we can renew
resp, err = b.pathLoginRenew(renewReq, empty_login_fd)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response from renew")
}
if resp.IsError() {
t.Fatalf("got error when renewing: %#v", *resp)
}

// Now, fake out the unique ID resolver to ensure we fail login if the unique ID
// changes from under us
b.resolveArnToUniqueIDFunc = resolveArnToFakeUniqueId
// First, we need to update the role to force Vault to use our fake resolver to
// pick up the fake user ID
roleData["bound_iam_principal_arn"] = entity.canonicalArn()
roleRequest.Path = "role/" + testValidRoleName
resp, err = b.HandleRequest(roleRequest)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: failed to recreate role: resp:%#v\nerr:%v", resp, err)
}
resp, err = b.HandleRequest(loginRequest)
if err != nil || resp == nil || !resp.IsError() {
t.Errorf("bad: expected failed login due to changed AWS role ID: resp: %#v\nerr:%v", resp, err)
}

// and ensure a renew no longer works
resp, err = b.pathLoginRenew(renewReq, empty_login_fd)
if err == nil || (resp != nil && !resp.IsError()) {
t.Errorf("bad: expected failed renew due to changed AWS role ID: resp: %#v", resp, err)
}

}
69 changes: 54 additions & 15 deletions builtin/credential/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/helper/awsutil"
"github.com/hashicorp/vault/logical"
Expand Down Expand Up @@ -70,7 +71,7 @@ func (b *backend) getRawClientConfig(s logical.Storage, region, clientType strin
// It uses getRawClientConfig to obtain config for the runtime environemnt, and if
// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
// credentials. The credentials will expire after 15 minutes but will auto-refresh.
func (b *backend) getClientConfig(s logical.Storage, region, stsRole, clientType string) (*aws.Config, error) {
func (b *backend) getClientConfig(s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {

config, err := b.getRawClientConfig(s, region, clientType)
if err != nil {
Expand All @@ -80,20 +81,39 @@ func (b *backend) getClientConfig(s logical.Storage, region, stsRole, clientType
return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
}

stsConfig, err := b.getRawClientConfig(s, region, "sts")
if stsConfig == nil {
return nil, fmt.Errorf("could not configure STS client")
}
if err != nil {
return nil, err
}
if stsRole != "" {
assumeRoleConfig, err := b.getRawClientConfig(s, region, "sts")
if err != nil {
return nil, err
}
if assumeRoleConfig == nil {
return nil, fmt.Errorf("could not configure STS client")
}
assumedCredentials := stscreds.NewCredentials(session.New(assumeRoleConfig), stsRole)
assumedCredentials := stscreds.NewCredentials(session.New(stsConfig), stsRole)
// Test that we actually have permissions to assume the role
if _, err = assumedCredentials.Get(); err != nil {
return nil, err
}
config.Credentials = assumedCredentials
} else {
if b.defaultAWSAccountID == "" {
client := sts.New(session.New(stsConfig))
if client == nil {
return nil, fmt.Errorf("could not obtain sts client: %v", err)
}
inputParams := &sts.GetCallerIdentityInput{}
identity, err := client.GetCallerIdentity(inputParams)
if err != nil {
return nil, fmt.Errorf("unable to fetch current caller: %v", err)
}
if identity == nil {
return nil, fmt.Errorf("got nil result from GetCallerIdentity")
}
b.defaultAWSAccountID = *identity.Account
}
if b.defaultAWSAccountID != accountID {
return nil, fmt.Errorf("unable to fetch client for account ID %s -- default client is for account %s", accountID, b.defaultAWSAccountID)
}
}

return config, nil
Expand Down Expand Up @@ -121,8 +141,25 @@ func (b *backend) flushCachedIAMClients() {
}
}

func (b *backend) stsRoleForAccount(s logical.Storage, accountID string) (string, error) {
// Check if an STS configuration exists for the AWS account
sts, err := b.lockedAwsStsEntry(s, accountID)
if err != nil {
return "", fmt.Errorf("error fetching STS config for account ID %q: %q\n", accountID, err)
}
// An empty STS role signifies the master account
if sts != nil {
return sts.StsRole, nil
}
return "", nil
}

// clientEC2 creates a client to interact with AWS EC2 API
func (b *backend) clientEC2(s logical.Storage, region string, stsRole string) (*ec2.EC2, error) {
func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.EC2, error) {
stsRole, err := b.stsRoleForAccount(s, accountID)
if err != nil {
return nil, err
}
b.configMutex.RLock()
if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
defer b.configMutex.RUnlock()
Expand All @@ -142,8 +179,7 @@ func (b *backend) clientEC2(s logical.Storage, region string, stsRole string) (*

// Create an AWS config object using a chain of providers
var awsConfig *aws.Config
var err error
awsConfig, err = b.getClientConfig(s, region, stsRole, "ec2")
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "ec2")

if err != nil {
return nil, err
Expand All @@ -168,7 +204,11 @@ func (b *backend) clientEC2(s logical.Storage, region string, stsRole string) (*
}

// clientIAM creates a client to interact with AWS IAM API
func (b *backend) clientIAM(s logical.Storage, region string, stsRole string) (*iam.IAM, error) {
func (b *backend) clientIAM(s logical.Storage, region, accountID string) (*iam.IAM, error) {
stsRole, err := b.stsRoleForAccount(s, accountID)
if err != nil {
return nil, err
}
b.configMutex.RLock()
if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
defer b.configMutex.RUnlock()
Expand All @@ -188,8 +228,7 @@ func (b *backend) clientIAM(s logical.Storage, region string, stsRole string) (*

// Create an AWS config object using a chain of providers
var awsConfig *aws.Config
var err error
awsConfig, err = b.getClientConfig(s, region, stsRole, "iam")
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "iam")

if err != nil {
return nil, err
Expand Down
4 changes: 4 additions & 0 deletions builtin/credential/aws/path_config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ func (b *backend) pathConfigClientDelete(
// Remove all the cached EC2 client objects in the backend.
b.flushCachedIAMClients()

// unset the cached default AWS account ID
b.defaultAWSAccountID = ""

return nil, nil
}

Expand Down Expand Up @@ -234,6 +237,7 @@ func (b *backend) pathConfigClientCreateUpdate(
if changedCreds {
b.flushCachedEC2Clients()
b.flushCachedIAMClients()
b.defaultAWSAccountID = ""
}

return nil, nil
Expand Down
Loading

0 comments on commit d858511

Please sign in to comment.