Skip to content

Commit

Permalink
Add AWS cloud provider option for IAM role
Browse files Browse the repository at this point in the history
Currently the AWS cloud provider uses the EC2 instance role when
interacting with AWS APIs. This change gives the option to provide and IAM
role that the cloud provider will assume before calling the APIs. All
resources created by the role will be owned by that account instead of
the account where the EC2 instance is running.
  • Loading branch information
Bryce Carman committed Feb 16, 2018
1 parent c105796 commit 3b99e1b
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 38 deletions.
2 changes: 2 additions & 0 deletions pkg/cloudprovider/providers/aws/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/aws/session:go_default_library",
Expand All @@ -46,6 +47,7 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/service/elb:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/elbv2:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/kms:go_default_library",
"//vendor/github.com/aws/aws-sdk-go/service/sts:go_default_library",
"//vendor/github.com/golang/glog:go_default_library",
"//vendor/github.com/prometheus/client_golang/prometheus:go_default_library",
"//vendor/gopkg.in/gcfg.v1:go_default_library",
Expand Down
61 changes: 46 additions & 15 deletions pkg/cloudprovider/providers/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -41,6 +42,7 @@ import (
"github.com/aws/aws-sdk-go/service/elb"
"github.com/aws/aws-sdk-go/service/elbv2"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/golang/glog"
"github.com/prometheus/client_golang/prometheus"
clientset "k8s.io/client-go/kubernetes"
Expand Down Expand Up @@ -526,6 +528,9 @@ type CloudConfig struct {
// RouteTableID enables using a specific RouteTable
RouteTableID string

// RoleARN is the IAM role to assume when interaction with AWS APIs.
RoleARN string

// KubernetesClusterTag is the legacy cluster id we'll use to identify our cluster resources
KubernetesClusterTag string
// KubernetesClusterID is the cluster id we'll use to identify our cluster resources
Expand Down Expand Up @@ -927,22 +932,43 @@ func (s *awsSdkEC2) DescribeVpcs(request *ec2.DescribeVpcsInput) (*ec2.DescribeV
func init() {
registerMetrics()
cloudprovider.RegisterCloudProvider(ProviderName, func(config io.Reader) (cloudprovider.Interface, error) {
cfg, err := readAWSCloudConfig(config)
if err != nil {
return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err)
}

sess, err := session.NewSession(&aws.Config{})
if err != nil {
return nil, fmt.Errorf("unable to initialize AWS session: %v", err)
}

var provider credentials.Provider
if cfg.Global.RoleARN == "" {
provider = &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(sess),
}
} else {
glog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN)
provider = &stscreds.AssumeRoleProvider{
Client: sts.New(sess),
RoleARN: cfg.Global.RoleARN,
}
}

creds := credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(session.New(&aws.Config{})),
},
provider,
&credentials.SharedCredentialsProvider{},
})

aws := newAWSSDKProvider(creds)
return newAWSCloud(config, aws)
return newAWSCloud(*cfg, aws)
})
}

// readAWSCloudConfig reads an instance of AWSCloudConfig from config reader.
func readAWSCloudConfig(config io.Reader, metadata EC2Metadata) (*CloudConfig, error) {
func readAWSCloudConfig(config io.Reader) (*CloudConfig, error) {
var cfg CloudConfig
var err error

Expand All @@ -953,20 +979,25 @@ func readAWSCloudConfig(config io.Reader, metadata EC2Metadata) (*CloudConfig, e
}
}

return &cfg, nil
}

func updateConfigZone(cfg *CloudConfig, metadata EC2Metadata) error {
if cfg.Global.Zone == "" {
if metadata != nil {
glog.Info("Zone not specified in configuration file; querying AWS metadata service")
var err error
cfg.Global.Zone, err = getAvailabilityZone(metadata)
if err != nil {
return nil, err
return err
}
}
if cfg.Global.Zone == "" {
return nil, fmt.Errorf("no zone specified in configuration file")
return fmt.Errorf("no zone specified in configuration file")
}
}

return &cfg, nil
return nil
}

func getInstanceType(metadata EC2Metadata) (string, error) {
Expand All @@ -989,7 +1020,7 @@ func azToRegion(az string) (string, error) {

// newAWSCloud creates a new instance of AWSCloud.
// AWSProvider and instanceId are primarily for tests
func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) {
func newAWSCloud(cfg CloudConfig, awsServices Services) (*Cloud, error) {
// We have some state in the Cloud object - in particular the attaching map
// Log so that if we are building multiple Cloud objects, it is obvious!
glog.Infof("Building AWS cloudprovider")
Expand All @@ -999,9 +1030,9 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) {
return nil, fmt.Errorf("error creating AWS metadata client: %q", err)
}

cfg, err := readAWSCloudConfig(config, metadata)
err = updateConfigZone(&cfg, metadata)
if err != nil {
return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err)
return nil, fmt.Errorf("unable to determine AWS zone from cloud provider config or EC2 instance metadata: %v", err)
}

zone := cfg.Global.Zone
Expand Down Expand Up @@ -1059,16 +1090,17 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) {
asg: asg,
metadata: metadata,
kms: kms,
cfg: cfg,
cfg: &cfg,
region: regionName,

attaching: make(map[types.NodeName]map[mountDevice]awsVolumeID),
deviceAllocators: make(map[types.NodeName]DeviceAllocator),
}
awsCloud.instanceCache.cloud = awsCloud

if cfg.Global.VPC != "" && cfg.Global.SubnetID != "" && (cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "") {
// When the master is running on a different AWS account, cloud provider or on-premises
tagged := cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != ""
if cfg.Global.VPC != "" && (cfg.Global.SubnetID != "" || cfg.Global.RoleARN != "") && tagged {
// When the master is running on a different AWS account, cloud provider or on-premise
// build up a dummy instance and use the VPC from the nodes account
glog.Info("Master is configured to run on a different AWS account, different cloud provider or on-premises")
awsCloud.selfAWSInstance = &awsInstance{
Expand All @@ -1084,7 +1116,6 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) {
}
awsCloud.selfAWSInstance = selfAWSInstance
awsCloud.vpcID = selfAWSInstance.vpcID

}

if cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "" {
Expand Down
47 changes: 27 additions & 20 deletions pkg/cloudprovider/providers/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ func TestReadAWSCloudConfig(t *testing.T) {
if test.aws != nil {
metadata, _ = test.aws.Metadata()
}
cfg, err := readAWSCloudConfig(test.reader, metadata)
cfg, err := readAWSCloudConfig(test.reader)
if err == nil {
err = updateConfigZone(cfg, metadata)
}
if test.expectError {
if err == nil {
t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg)
Expand Down Expand Up @@ -213,7 +216,11 @@ func TestNewAWSCloud(t *testing.T) {

for _, test := range tests {
t.Logf("Running test case %s", test.name)
c, err := newAWSCloud(test.reader, test.awsServices)
cfg, err := readAWSCloudConfig(test.reader)
var c *Cloud
if err == nil {
c, err = newAWSCloud(*cfg, test.awsServices)
}
if test.expectError {
if err == nil {
t.Errorf("Should error for case %s", test.name)
Expand All @@ -233,7 +240,7 @@ func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (*
awsServices := newMockedFakeAWSServices(TestClusterId)
awsServices.instances = instances
awsServices.selfInstance = selfInstance
awsCloud, err := newAWSCloud(nil, awsServices)
awsCloud, err := newAWSCloud(CloudConfig{}, awsServices)
if err != nil {
panic(err)
}
Expand All @@ -242,7 +249,7 @@ func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (*

func mockAvailabilityZone(availabilityZone string) *Cloud {
awsServices := newMockedFakeAWSServices(TestClusterId).WithAz(availabilityZone)
awsCloud, err := newAWSCloud(nil, awsServices)
awsCloud, err := newAWSCloud(CloudConfig{}, awsServices)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -389,7 +396,7 @@ func TestGetRegion(t *testing.T) {

func TestFindVPCID(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
if err != nil {
t.Errorf("Error building aws cloud: %v", err)
return
Expand Down Expand Up @@ -463,7 +470,7 @@ func constructRouteTable(subnetID string, public bool) *ec2.RouteTable {

func TestSubnetIDsinVPC(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
if err != nil {
t.Errorf("Error building aws cloud: %v", err)
return
Expand Down Expand Up @@ -798,7 +805,7 @@ func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) {
instances := []*ec2.Instance{&terminatedInstance, &runningInstance}
awsServices.instances = append(awsServices.instances, instances...)

c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
if err != nil {
t.Errorf("Error building aws cloud: %v", err)
return
Expand All @@ -818,7 +825,7 @@ func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) {

func TestGetInstanceByNodeNameBatching(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
assert.Nil(t, err, "Error building aws cloud: %v", err)
var tag ec2.Tag
tag.Key = aws.String(TagNameKubernetesClusterPrefix + TestClusterId)
Expand All @@ -845,7 +852,7 @@ func TestGetInstanceByNodeNameBatching(t *testing.T) {

func TestGetVolumeLabels(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
assert.Nil(t, err, "Error building aws cloud: %v", err)
volumeId := awsVolumeID("vol-VolumeId")
expectedVolumeRequest := &ec2.DescribeVolumesInput{VolumeIds: []*string{volumeId.awsString()}}
Expand All @@ -867,31 +874,31 @@ func TestGetVolumeLabels(t *testing.T) {

func TestDescribeLoadBalancerOnDelete(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, _ := newAWSCloud(CloudConfig{}, awsServices)
awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid")

c.EnsureLoadBalancerDeleted(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}})
}

func TestDescribeLoadBalancerOnUpdate(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, _ := newAWSCloud(CloudConfig{}, awsServices)
awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid")

c.UpdateLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{})
}

func TestDescribeLoadBalancerOnGet(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, _ := newAWSCloud(CloudConfig{}, awsServices)
awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid")

c.GetLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}})
}

func TestDescribeLoadBalancerOnEnsure(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, _ := newAWSCloud(CloudConfig{}, awsServices)
awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid")

c.EnsureLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{})
Expand Down Expand Up @@ -1123,7 +1130,7 @@ func TestGetLoadBalancerAdditionalTags(t *testing.T) {

func TestLBExtraSecurityGroupsAnnotation(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, _ := newAWSCloud(CloudConfig{}, awsServices)

sg1 := "sg-000001"
sg2 := "sg-000002"
Expand Down Expand Up @@ -1159,7 +1166,7 @@ func TestLBExtraSecurityGroupsAnnotation(t *testing.T) {
func TestAddLoadBalancerTags(t *testing.T) {
loadBalancerName := "test-elb"
awsServices := newMockedFakeAWSServices(TestClusterId)
c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, _ := newAWSCloud(CloudConfig{}, awsServices)

want := make(map[string]string)
want["tag1"] = "val1"
Expand Down Expand Up @@ -1215,7 +1222,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
assert.Nil(t, err, "Error building aws cloud: %v", err)
expectedHC := *defaultHC
if test.overriddenFieldName != "" { // cater for test case with no overrides
Expand All @@ -1233,7 +1240,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) {

t.Run("does not make an API call if the current health check is the same", func(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
assert.Nil(t, err, "Error building aws cloud: %v", err)
expectedHC := *defaultHC
timeout := int64(3)
Expand All @@ -1255,7 +1262,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) {

t.Run("validates resulting expected health check before making an API call", func(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
assert.Nil(t, err, "Error building aws cloud: %v", err)
expectedHC := *defaultHC
invalidThreshold := int64(1)
Expand All @@ -1271,7 +1278,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) {

t.Run("handles invalid override values", func(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
assert.Nil(t, err, "Error building aws cloud: %v", err)
annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "3.3"}

Expand All @@ -1283,7 +1290,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) {

t.Run("returns error when updating the health check fails", func(t *testing.T) {
awsServices := newMockedFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
assert.Nil(t, err, "Error building aws cloud: %v", err)
returnErr := fmt.Errorf("throttling error")
awsServices.elb.(*MockedFakeELB).expectConfigureHealthCheck(&lbName, defaultHC, returnErr)
Expand Down
2 changes: 1 addition & 1 deletion pkg/cloudprovider/providers/aws/regions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestRecognizesNewRegion(t *testing.T) {
}

awsServices := NewFakeAWSServices(TestClusterId).WithAz(region + "a")
_, err := newAWSCloud(nil, awsServices)
_, err := newAWSCloud(CloudConfig{}, awsServices)
if err != nil {
t.Errorf("error building AWS cloud: %v", err)
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/cloudprovider/providers/aws/tags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ package aws
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"strings"
"testing"
)

func TestFilterTags(t *testing.T) {
awsServices := NewFakeAWSServices(TestClusterId)
c, err := newAWSCloud(strings.NewReader("[global]"), awsServices)
c, err := newAWSCloud(CloudConfig{}, awsServices)
if err != nil {
t.Errorf("Error building aws cloud: %v", err)
return
Expand Down

0 comments on commit 3b99e1b

Please sign in to comment.