Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use local aws config in cli to get account and regions #7758

36 changes: 21 additions & 15 deletions pkg/cli/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package aws
import (
"context"

"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/sts"
)
Expand All @@ -28,9 +28,9 @@ import (
// Client is an interface that abstracts `rad init`'s interactions with AWS. This is for testing purposes. This is only exported because mockgen requires it.
type Client interface {
// GetCallerIdentity gets information about the provided credentials.
GetCallerIdentity(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*sts.GetCallerIdentityOutput, error)
GetCallerIdentity(ctx context.Context) (*sts.GetCallerIdentityOutput, error)
// ListRegions lists the AWS regions available (fetched from EC2.DescribeRegions API).
ListRegions(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*ec2.DescribeRegionsOutput, error)
ListRegions(ctx context.Context) (*ec2.DescribeRegionsOutput, error)
}

// NewClient returns a new Client.
Expand All @@ -43,12 +43,15 @@ type client struct{}
var _ Client = &client{}

// GetCallerIdentity gets information about the provided credentials.
func (c *client) GetCallerIdentity(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*sts.GetCallerIdentityOutput, error) {
credentialsProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "")
stsClient := sts.New(sts.Options{
Region: region,
Credentials: credentialsProvider,
})
func (c *client) GetCallerIdentity(ctx context.Context) (*sts.GetCallerIdentityOutput, error) {
// Load the AWS SDK config and credentials
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should "default" be a const?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the config.WithSharedConfigProfile

return nil, err
}

stsClient := sts.NewFromConfig(cfg)

result, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, err
Expand All @@ -58,12 +61,15 @@ func (c *client) GetCallerIdentity(ctx context.Context, region string, accessKey
}

// ListRegions lists the AWS regions available (fetched from EC2.DescribeRegions API).
func (c *client) ListRegions(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*ec2.DescribeRegionsOutput, error) {
credentialsProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "")
ec2Client := ec2.New(ec2.Options{
Region: region,
Credentials: credentialsProvider,
})
func (c *client) ListRegions(ctx context.Context) (*ec2.DescribeRegionsOutput, error) {
// Load the AWS SDK config and credentials
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}

ec2Client := ec2.NewFromConfig(cfg)

result, err := ec2Client.DescribeRegions(ctx, &ec2.DescribeRegionsInput{})
if err != nil {
return nil, err
Expand Down
24 changes: 12 additions & 12 deletions pkg/cli/aws/client_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 27 additions & 11 deletions pkg/cli/cmd/radinit/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package radinit

import (
"context"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/charmbracelet/bubbles/textinput"
Expand All @@ -27,15 +28,15 @@ import (
)

const (
// QueryRegion is the region used for querying AWS before the user selects a region.
QueryRegion = "us-east-1"

selectAWSRegionPrompt = "Select the region you would like to deploy AWS resources to:"
enterAWSIAMAcessKeyIDPrompt = "Enter the IAM access key id:"
enterAWSIAMAcessKeyIDPlaceholder = "Enter IAM access key id..."
enterAWSIAMSecretAccessKeyPrompt = "Enter your IAM Secret Access Key:"
enterAWSIAMSecretAccessKeyPlaceholder = "Enter IAM secret access key..."
errNotEmptyTemplate = "%s cannot be empty"
confirmAWSAccountIDPromptFmt = "Use account id '%v'?"
enterAWSAccountIDPrompt = "Enter the account ID:"
enterAWSAccountIDPlaceholder = "Enter the account ID you want to use..."

awsAccessKeysCreateInstructionFmt = "\nAWS IAM Access keys (Access key ID and Secret access key) are required to access and create AWS resources.\n\nFor example, you can create one using the following command:\n\033[36maws iam create-access-key\033[0m\n\nFor more information refer to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html.\n\n"
)
Expand All @@ -53,12 +54,12 @@ func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, erro
return nil, err
}

accountId, err := r.getAccountId(ctx, accessKeyID, secretAccessKey)
accountId, err := r.getAccountId(ctx)
if err != nil {
return nil, err
}

region, err := r.selectAWSRegion(ctx, QueryRegion, accessKeyID, secretAccessKey)
region, err := r.selectAWSRegion(ctx)
if err != nil {
return nil, err
}
Expand All @@ -71,21 +72,36 @@ func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, erro
}, nil
}

func (r *Runner) getAccountId(ctx context.Context, accessKeyID, secretAccessKey string) (string, error) {
callerIdentityOutput, err := r.awsClient.GetCallerIdentity(ctx, QueryRegion, accessKeyID, secretAccessKey)
func (r *Runner) getAccountId(ctx context.Context) (string, error) {
callerIdentityOutput, err := r.awsClient.GetCallerIdentity(ctx)
if err != nil {
return "", clierrors.MessageWithCause(err, "AWS credential verification failed.")
return "", clierrors.MessageWithCause(err, "AWS Cloud Provider setup failed, please use aws configure to set up the configuration. More information :https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html")
}
Copy link
Contributor

@nithyatsu nithyatsu Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here, would it help user if we add some detail, like "AWS credential verification failed. Please use aws configure to configure credentials and then try again " ? cc @Reshrahim


if callerIdentityOutput.Account == nil {
return "", clierrors.MessageWithCause(err, "AWS credential verification failed: Account ID is nil.")
}

return *callerIdentityOutput.Account, nil
accountID := *callerIdentityOutput.Account
addAlternateAccountID, err := prompt.YesOrNoPrompt(fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountID), prompt.ConfirmYes, r.Prompter)
if err != nil {
return "", err
}

if !addAlternateAccountID {
accountID, err = r.Prompter.GetTextInput(enterAWSAccountIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSAccountIDPlaceholder})
if err != nil {
return "", err
}
}

return accountID, nil
}

func (r *Runner) selectAWSRegion(ctx context.Context, region, accessKeyID, secretAccessKey string) (string, error) {
listRegionsOutput, err := r.awsClient.ListRegions(ctx, region, accessKeyID, secretAccessKey)
// selectAWSRegion prompts the user to select an AWS region from a list of available regions.
// Region list is retrieved using the locally configured AWS account.
func (r *Runner) selectAWSRegion(ctx context.Context) (string, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just for easier context, can you please add a comment here that it uses local aws config to list regions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added it

listRegionsOutput, err := r.awsClient.ListRegions(ctx)
if err != nil {
return "", clierrors.MessageWithCause(err, "Listing AWS regions failed.")
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/cli/cmd/radinit/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ func Test_enterAWSCloudProvider(t *testing.T) {

setAWSAccessKeyIDPrompt(prompter, "access-key-id")
setAWSSecretAccessKeyPrompt(prompter, "secret-access-key")
setAWSCallerIdentity(client, QueryRegion, "access-key-id", "secret-access-key", &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")})
setAWSListRegions(client, QueryRegion, "access-key-id", "secret-access-key", &ec2.DescribeRegionsOutput{Regions: ec2Regions})
setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")})
setAWSAccountIDConfirmPrompt(prompter, "account-id", prompt.ConfirmYes)
setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: ec2Regions})
setAWSRegionPrompt(prompter, regions, "region")

provider, err := runner.enterAWSCloudProvider(context.Background())
Expand Down
20 changes: 14 additions & 6 deletions pkg/cli/cmd/radinit/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1037,16 +1037,23 @@ func setAWSSecretAccessKeyPrompt(prompter *prompt.MockInterface, secretAccessKey
Return(secretAccessKey, nil).Times(1)
}

func setAWSCallerIdentity(client *aws.MockClient, region string, accessKeyID string, secretAccessKey string, callerIdentityOutput *sts.GetCallerIdentityOutput) {
func setAWSCallerIdentity(client *aws.MockClient, callerIdentityOutput *sts.GetCallerIdentityOutput) {
client.EXPECT().
GetCallerIdentity(gomock.Any(), region, accessKeyID, secretAccessKey).
GetCallerIdentity(gomock.Any()).
Return(callerIdentityOutput, nil).
Times(1)
}

func setAWSListRegions(client *aws.MockClient, region string, accessKeyID string, secretAccessKey string, ec2DescribeRegionsOutput *ec2.DescribeRegionsOutput) {
func setAWSAccountIDConfirmPrompt(prompter *prompt.MockInterface, accountName string, choice string) {
prompter.EXPECT().
GetListInput([]string{prompt.ConfirmYes, prompt.ConfirmNo}, fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountName)).
Return(choice, nil).
Times(1)
}

func setAWSListRegions(client *aws.MockClient, ec2DescribeRegionsOutput *ec2.DescribeRegionsOutput) {
client.EXPECT().
ListRegions(gomock.Any(), region, accessKeyID, secretAccessKey).
ListRegions(gomock.Any()).
Return(ec2DescribeRegionsOutput, nil).
Times(1)
}
Expand All @@ -1055,8 +1062,9 @@ func setAWSListRegions(client *aws.MockClient, region string, accessKeyID string
func setAWSCloudProvider(prompter *prompt.MockInterface, client *aws.MockClient, provider aws.Provider) {
setAWSAccessKeyIDPrompt(prompter, provider.AccessKeyID)
setAWSSecretAccessKeyPrompt(prompter, provider.SecretAccessKey)
setAWSCallerIdentity(client, QueryRegion, provider.AccessKeyID, provider.SecretAccessKey, &sts.GetCallerIdentityOutput{Account: &provider.AccountID})
setAWSListRegions(client, QueryRegion, provider.AccessKeyID, provider.SecretAccessKey, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()})
setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: &provider.AccountID})
setAWSAccountIDConfirmPrompt(prompter, provider.AccountID, prompt.ConfirmYes)
setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()})
setAWSRegionPrompt(prompter, getMockAWSRegionsString(), provider.Region)
}

Expand Down