Skip to content

Commit

Permalink
Use local aws config in cli to get account and regions (radius-projec…
Browse files Browse the repository at this point in the history
…t#7758)

# Description

Today for `rad init --full` command while configuring `aws` provider, we
are using `accesskey` and `secretkey` provided by user to get accountid
and regions, but for `irsa` we cannot get the account id and region
information.

So instead of using `accesskey` and `secretkey` to authenticate ,
updated it to use the local aws config to authenticate and get account
id and regions informaton.

## Type of change

<!--

Please select **one** of the following options that describes your
change and delete the others. Clearly identifying the type of change you
are making will help us review your PR faster, and is used in authoring
release notes.

If you are making a bug fix or functionality change to Radius and do not
have an associated issue link please create one now.

-->

- This pull request fixes a bug in Radius and has an approved issue
(issue link required).
- This pull request adds or changes features of Radius and has an
approved issue (issue link required).
- This pull request is a minor refactor, code cleanup, test improvement,
or other maintenance task and doesn't change the functionality of Radius
(issue link optional).

<!--

Please update the following to link the associated issue. This is
required for some kinds of changes (see above).

-->

Fixes: #issue_number

---------

Signed-off-by: Vishwanath Hiremath <vhiremath@microsoft.com>
Co-authored-by: Karishma Chawla <kachawla@microsoft.com>
  • Loading branch information
vishwahiremat and kachawla authored Aug 5, 2024
1 parent 63490d6 commit 4ba025d
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 46 deletions.
36 changes: 21 additions & 15 deletions pkg/cli/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package aws
import (
"context"

"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/sts"
)
Expand All @@ -28,9 +28,9 @@ import (
// Client is an interface that abstracts `rad init`'s interactions with AWS. This is for testing purposes. This is only exported because mockgen requires it.
type Client interface {
// GetCallerIdentity gets information about the provided credentials.
GetCallerIdentity(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*sts.GetCallerIdentityOutput, error)
GetCallerIdentity(ctx context.Context) (*sts.GetCallerIdentityOutput, error)
// ListRegions lists the AWS regions available (fetched from EC2.DescribeRegions API).
ListRegions(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*ec2.DescribeRegionsOutput, error)
ListRegions(ctx context.Context) (*ec2.DescribeRegionsOutput, error)
}

// NewClient returns a new Client.
Expand All @@ -43,12 +43,15 @@ type client struct{}
var _ Client = &client{}

// GetCallerIdentity gets information about the provided credentials.
func (c *client) GetCallerIdentity(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*sts.GetCallerIdentityOutput, error) {
credentialsProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "")
stsClient := sts.New(sts.Options{
Region: region,
Credentials: credentialsProvider,
})
func (c *client) GetCallerIdentity(ctx context.Context) (*sts.GetCallerIdentityOutput, error) {
// Load the AWS SDK config and credentials
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}

stsClient := sts.NewFromConfig(cfg)

result, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, err
Expand All @@ -58,12 +61,15 @@ func (c *client) GetCallerIdentity(ctx context.Context, region string, accessKey
}

// ListRegions lists the AWS regions available (fetched from EC2.DescribeRegions API).
func (c *client) ListRegions(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*ec2.DescribeRegionsOutput, error) {
credentialsProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "")
ec2Client := ec2.New(ec2.Options{
Region: region,
Credentials: credentialsProvider,
})
func (c *client) ListRegions(ctx context.Context) (*ec2.DescribeRegionsOutput, error) {
// Load the AWS SDK config and credentials
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}

ec2Client := ec2.NewFromConfig(cfg)

result, err := ec2Client.DescribeRegions(ctx, &ec2.DescribeRegionsInput{})
if err != nil {
return nil, err
Expand Down
24 changes: 12 additions & 12 deletions pkg/cli/aws/client_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 27 additions & 11 deletions pkg/cli/cmd/radinit/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package radinit

import (
"context"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/charmbracelet/bubbles/textinput"
Expand All @@ -27,15 +28,15 @@ import (
)

const (
// QueryRegion is the region used for querying AWS before the user selects a region.
QueryRegion = "us-east-1"

selectAWSRegionPrompt = "Select the region you would like to deploy AWS resources to:"
enterAWSIAMAcessKeyIDPrompt = "Enter the IAM access key id:"
enterAWSIAMAcessKeyIDPlaceholder = "Enter IAM access key id..."
enterAWSIAMSecretAccessKeyPrompt = "Enter your IAM Secret Access Key:"
enterAWSIAMSecretAccessKeyPlaceholder = "Enter IAM secret access key..."
errNotEmptyTemplate = "%s cannot be empty"
confirmAWSAccountIDPromptFmt = "Use account id '%v'?"
enterAWSAccountIDPrompt = "Enter the account ID:"
enterAWSAccountIDPlaceholder = "Enter the account ID you want to use..."

awsAccessKeysCreateInstructionFmt = "\nAWS IAM Access keys (Access key ID and Secret access key) are required to access and create AWS resources.\n\nFor example, you can create one using the following command:\n\033[36maws iam create-access-key\033[0m\n\nFor more information refer to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html.\n\n"
)
Expand All @@ -53,12 +54,12 @@ func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, erro
return nil, err
}

accountId, err := r.getAccountId(ctx, accessKeyID, secretAccessKey)
accountId, err := r.getAccountId(ctx)
if err != nil {
return nil, err
}

region, err := r.selectAWSRegion(ctx, QueryRegion, accessKeyID, secretAccessKey)
region, err := r.selectAWSRegion(ctx)
if err != nil {
return nil, err
}
Expand All @@ -71,21 +72,36 @@ func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, erro
}, nil
}

func (r *Runner) getAccountId(ctx context.Context, accessKeyID, secretAccessKey string) (string, error) {
callerIdentityOutput, err := r.awsClient.GetCallerIdentity(ctx, QueryRegion, accessKeyID, secretAccessKey)
func (r *Runner) getAccountId(ctx context.Context) (string, error) {
callerIdentityOutput, err := r.awsClient.GetCallerIdentity(ctx)
if err != nil {
return "", clierrors.MessageWithCause(err, "AWS credential verification failed.")
return "", clierrors.MessageWithCause(err, "AWS Cloud Provider setup failed, please use aws configure to set up the configuration. More information :https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html")
}

if callerIdentityOutput.Account == nil {
return "", clierrors.MessageWithCause(err, "AWS credential verification failed: Account ID is nil.")
}

return *callerIdentityOutput.Account, nil
accountID := *callerIdentityOutput.Account
addAlternateAccountID, err := prompt.YesOrNoPrompt(fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountID), prompt.ConfirmYes, r.Prompter)
if err != nil {
return "", err
}

if !addAlternateAccountID {
accountID, err = r.Prompter.GetTextInput(enterAWSAccountIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSAccountIDPlaceholder})
if err != nil {
return "", err
}
}

return accountID, nil
}

func (r *Runner) selectAWSRegion(ctx context.Context, region, accessKeyID, secretAccessKey string) (string, error) {
listRegionsOutput, err := r.awsClient.ListRegions(ctx, region, accessKeyID, secretAccessKey)
// selectAWSRegion prompts the user to select an AWS region from a list of available regions.
// Region list is retrieved using the locally configured AWS account.
func (r *Runner) selectAWSRegion(ctx context.Context) (string, error) {
listRegionsOutput, err := r.awsClient.ListRegions(ctx)
if err != nil {
return "", clierrors.MessageWithCause(err, "Listing AWS regions failed.")
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/cli/cmd/radinit/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ func Test_enterAWSCloudProvider(t *testing.T) {

setAWSAccessKeyIDPrompt(prompter, "access-key-id")
setAWSSecretAccessKeyPrompt(prompter, "secret-access-key")
setAWSCallerIdentity(client, QueryRegion, "access-key-id", "secret-access-key", &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")})
setAWSListRegions(client, QueryRegion, "access-key-id", "secret-access-key", &ec2.DescribeRegionsOutput{Regions: ec2Regions})
setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")})
setAWSAccountIDConfirmPrompt(prompter, "account-id", prompt.ConfirmYes)
setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: ec2Regions})
setAWSRegionPrompt(prompter, regions, "region")

provider, err := runner.enterAWSCloudProvider(context.Background())
Expand Down
20 changes: 14 additions & 6 deletions pkg/cli/cmd/radinit/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1037,16 +1037,23 @@ func setAWSSecretAccessKeyPrompt(prompter *prompt.MockInterface, secretAccessKey
Return(secretAccessKey, nil).Times(1)
}

func setAWSCallerIdentity(client *aws.MockClient, region string, accessKeyID string, secretAccessKey string, callerIdentityOutput *sts.GetCallerIdentityOutput) {
func setAWSCallerIdentity(client *aws.MockClient, callerIdentityOutput *sts.GetCallerIdentityOutput) {
client.EXPECT().
GetCallerIdentity(gomock.Any(), region, accessKeyID, secretAccessKey).
GetCallerIdentity(gomock.Any()).
Return(callerIdentityOutput, nil).
Times(1)
}

func setAWSListRegions(client *aws.MockClient, region string, accessKeyID string, secretAccessKey string, ec2DescribeRegionsOutput *ec2.DescribeRegionsOutput) {
func setAWSAccountIDConfirmPrompt(prompter *prompt.MockInterface, accountName string, choice string) {
prompter.EXPECT().
GetListInput([]string{prompt.ConfirmYes, prompt.ConfirmNo}, fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountName)).
Return(choice, nil).
Times(1)
}

func setAWSListRegions(client *aws.MockClient, ec2DescribeRegionsOutput *ec2.DescribeRegionsOutput) {
client.EXPECT().
ListRegions(gomock.Any(), region, accessKeyID, secretAccessKey).
ListRegions(gomock.Any()).
Return(ec2DescribeRegionsOutput, nil).
Times(1)
}
Expand All @@ -1055,8 +1062,9 @@ func setAWSListRegions(client *aws.MockClient, region string, accessKeyID string
func setAWSCloudProvider(prompter *prompt.MockInterface, client *aws.MockClient, provider aws.Provider) {
setAWSAccessKeyIDPrompt(prompter, provider.AccessKeyID)
setAWSSecretAccessKeyPrompt(prompter, provider.SecretAccessKey)
setAWSCallerIdentity(client, QueryRegion, provider.AccessKeyID, provider.SecretAccessKey, &sts.GetCallerIdentityOutput{Account: &provider.AccountID})
setAWSListRegions(client, QueryRegion, provider.AccessKeyID, provider.SecretAccessKey, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()})
setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: &provider.AccountID})
setAWSAccountIDConfirmPrompt(prompter, provider.AccountID, prompt.ConfirmYes)
setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()})
setAWSRegionPrompt(prompter, getMockAWSRegionsString(), provider.Region)
}

Expand Down

0 comments on commit 4ba025d

Please sign in to comment.