diff --git a/pkg/cli/aws/client.go b/pkg/cli/aws/client.go index b28389459b..afedb44b8d 100644 --- a/pkg/cli/aws/client.go +++ b/pkg/cli/aws/client.go @@ -18,7 +18,7 @@ package aws import ( "context" - "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/sts" ) @@ -28,9 +28,9 @@ import ( // Client is an interface that abstracts `rad init`'s interactions with AWS. This is for testing purposes. This is only exported because mockgen requires it. type Client interface { // GetCallerIdentity gets information about the provided credentials. - GetCallerIdentity(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*sts.GetCallerIdentityOutput, error) + GetCallerIdentity(ctx context.Context) (*sts.GetCallerIdentityOutput, error) // ListRegions lists the AWS regions available (fetched from EC2.DescribeRegions API). - ListRegions(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*ec2.DescribeRegionsOutput, error) + ListRegions(ctx context.Context) (*ec2.DescribeRegionsOutput, error) } // NewClient returns a new Client. @@ -43,12 +43,15 @@ type client struct{} var _ Client = &client{} // GetCallerIdentity gets information about the provided credentials. -func (c *client) GetCallerIdentity(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*sts.GetCallerIdentityOutput, error) { - credentialsProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "") - stsClient := sts.New(sts.Options{ - Region: region, - Credentials: credentialsProvider, - }) +func (c *client) GetCallerIdentity(ctx context.Context) (*sts.GetCallerIdentityOutput, error) { + // Load the AWS SDK config and credentials + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return nil, err + } + + stsClient := sts.NewFromConfig(cfg) + result, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) if err != nil { return nil, err @@ -58,12 +61,15 @@ func (c *client) GetCallerIdentity(ctx context.Context, region string, accessKey } // ListRegions lists the AWS regions available (fetched from EC2.DescribeRegions API). -func (c *client) ListRegions(ctx context.Context, region string, accessKeyID string, secretAccessKey string) (*ec2.DescribeRegionsOutput, error) { - credentialsProvider := credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, "") - ec2Client := ec2.New(ec2.Options{ - Region: region, - Credentials: credentialsProvider, - }) +func (c *client) ListRegions(ctx context.Context) (*ec2.DescribeRegionsOutput, error) { + // Load the AWS SDK config and credentials + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return nil, err + } + + ec2Client := ec2.NewFromConfig(cfg) + result, err := ec2Client.DescribeRegions(ctx, &ec2.DescribeRegionsInput{}) if err != nil { return nil, err diff --git a/pkg/cli/aws/client_mock.go b/pkg/cli/aws/client_mock.go index 13391cd4c8..0e8c5e27bc 100644 --- a/pkg/cli/aws/client_mock.go +++ b/pkg/cli/aws/client_mock.go @@ -42,18 +42,18 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { } // GetCallerIdentity mocks base method. -func (m *MockClient) GetCallerIdentity(arg0 context.Context, arg1, arg2, arg3 string) (*sts.GetCallerIdentityOutput, error) { +func (m *MockClient) GetCallerIdentity(arg0 context.Context) (*sts.GetCallerIdentityOutput, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCallerIdentity", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "GetCallerIdentity", arg0) ret0, _ := ret[0].(*sts.GetCallerIdentityOutput) ret1, _ := ret[1].(error) return ret0, ret1 } // GetCallerIdentity indicates an expected call of GetCallerIdentity. -func (mr *MockClientMockRecorder) GetCallerIdentity(arg0, arg1, arg2, arg3 any) *MockClientGetCallerIdentityCall { +func (mr *MockClientMockRecorder) GetCallerIdentity(arg0 any) *MockClientGetCallerIdentityCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCallerIdentity", reflect.TypeOf((*MockClient)(nil).GetCallerIdentity), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCallerIdentity", reflect.TypeOf((*MockClient)(nil).GetCallerIdentity), arg0) return &MockClientGetCallerIdentityCall{Call: call} } @@ -69,30 +69,30 @@ func (c *MockClientGetCallerIdentityCall) Return(arg0 *sts.GetCallerIdentityOutp } // Do rewrite *gomock.Call.Do -func (c *MockClientGetCallerIdentityCall) Do(f func(context.Context, string, string, string) (*sts.GetCallerIdentityOutput, error)) *MockClientGetCallerIdentityCall { +func (c *MockClientGetCallerIdentityCall) Do(f func(context.Context) (*sts.GetCallerIdentityOutput, error)) *MockClientGetCallerIdentityCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockClientGetCallerIdentityCall) DoAndReturn(f func(context.Context, string, string, string) (*sts.GetCallerIdentityOutput, error)) *MockClientGetCallerIdentityCall { +func (c *MockClientGetCallerIdentityCall) DoAndReturn(f func(context.Context) (*sts.GetCallerIdentityOutput, error)) *MockClientGetCallerIdentityCall { c.Call = c.Call.DoAndReturn(f) return c } // ListRegions mocks base method. -func (m *MockClient) ListRegions(arg0 context.Context, arg1, arg2, arg3 string) (*ec2.DescribeRegionsOutput, error) { +func (m *MockClient) ListRegions(arg0 context.Context) (*ec2.DescribeRegionsOutput, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListRegions", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ListRegions", arg0) ret0, _ := ret[0].(*ec2.DescribeRegionsOutput) ret1, _ := ret[1].(error) return ret0, ret1 } // ListRegions indicates an expected call of ListRegions. -func (mr *MockClientMockRecorder) ListRegions(arg0, arg1, arg2, arg3 any) *MockClientListRegionsCall { +func (mr *MockClientMockRecorder) ListRegions(arg0 any) *MockClientListRegionsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRegions", reflect.TypeOf((*MockClient)(nil).ListRegions), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRegions", reflect.TypeOf((*MockClient)(nil).ListRegions), arg0) return &MockClientListRegionsCall{Call: call} } @@ -108,13 +108,13 @@ func (c *MockClientListRegionsCall) Return(arg0 *ec2.DescribeRegionsOutput, arg1 } // Do rewrite *gomock.Call.Do -func (c *MockClientListRegionsCall) Do(f func(context.Context, string, string, string) (*ec2.DescribeRegionsOutput, error)) *MockClientListRegionsCall { +func (c *MockClientListRegionsCall) Do(f func(context.Context) (*ec2.DescribeRegionsOutput, error)) *MockClientListRegionsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockClientListRegionsCall) DoAndReturn(f func(context.Context, string, string, string) (*ec2.DescribeRegionsOutput, error)) *MockClientListRegionsCall { +func (c *MockClientListRegionsCall) DoAndReturn(f func(context.Context) (*ec2.DescribeRegionsOutput, error)) *MockClientListRegionsCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/pkg/cli/cmd/radinit/aws.go b/pkg/cli/cmd/radinit/aws.go index 38782c3ed0..6d3747309e 100644 --- a/pkg/cli/cmd/radinit/aws.go +++ b/pkg/cli/cmd/radinit/aws.go @@ -18,6 +18,7 @@ package radinit import ( "context" + "fmt" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/charmbracelet/bubbles/textinput" @@ -27,15 +28,15 @@ import ( ) const ( - // QueryRegion is the region used for querying AWS before the user selects a region. - QueryRegion = "us-east-1" - selectAWSRegionPrompt = "Select the region you would like to deploy AWS resources to:" enterAWSIAMAcessKeyIDPrompt = "Enter the IAM access key id:" enterAWSIAMAcessKeyIDPlaceholder = "Enter IAM access key id..." enterAWSIAMSecretAccessKeyPrompt = "Enter your IAM Secret Access Key:" enterAWSIAMSecretAccessKeyPlaceholder = "Enter IAM secret access key..." errNotEmptyTemplate = "%s cannot be empty" + confirmAWSAccountIDPromptFmt = "Use account id '%v'?" + enterAWSAccountIDPrompt = "Enter the account ID:" + enterAWSAccountIDPlaceholder = "Enter the account ID you want to use..." awsAccessKeysCreateInstructionFmt = "\nAWS IAM Access keys (Access key ID and Secret access key) are required to access and create AWS resources.\n\nFor example, you can create one using the following command:\n\033[36maws iam create-access-key\033[0m\n\nFor more information refer to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html.\n\n" ) @@ -53,12 +54,12 @@ func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, erro return nil, err } - accountId, err := r.getAccountId(ctx, accessKeyID, secretAccessKey) + accountId, err := r.getAccountId(ctx) if err != nil { return nil, err } - region, err := r.selectAWSRegion(ctx, QueryRegion, accessKeyID, secretAccessKey) + region, err := r.selectAWSRegion(ctx) if err != nil { return nil, err } @@ -71,21 +72,36 @@ func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, erro }, nil } -func (r *Runner) getAccountId(ctx context.Context, accessKeyID, secretAccessKey string) (string, error) { - callerIdentityOutput, err := r.awsClient.GetCallerIdentity(ctx, QueryRegion, accessKeyID, secretAccessKey) +func (r *Runner) getAccountId(ctx context.Context) (string, error) { + callerIdentityOutput, err := r.awsClient.GetCallerIdentity(ctx) if err != nil { - return "", clierrors.MessageWithCause(err, "AWS credential verification failed.") + return "", clierrors.MessageWithCause(err, "AWS Cloud Provider setup failed, please use aws configure to set up the configuration. More information :https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html") } if callerIdentityOutput.Account == nil { return "", clierrors.MessageWithCause(err, "AWS credential verification failed: Account ID is nil.") } - return *callerIdentityOutput.Account, nil + accountID := *callerIdentityOutput.Account + addAlternateAccountID, err := prompt.YesOrNoPrompt(fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountID), prompt.ConfirmYes, r.Prompter) + if err != nil { + return "", err + } + + if !addAlternateAccountID { + accountID, err = r.Prompter.GetTextInput(enterAWSAccountIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSAccountIDPlaceholder}) + if err != nil { + return "", err + } + } + + return accountID, nil } -func (r *Runner) selectAWSRegion(ctx context.Context, region, accessKeyID, secretAccessKey string) (string, error) { - listRegionsOutput, err := r.awsClient.ListRegions(ctx, region, accessKeyID, secretAccessKey) +// selectAWSRegion prompts the user to select an AWS region from a list of available regions. +// Region list is retrieved using the locally configured AWS account. +func (r *Runner) selectAWSRegion(ctx context.Context) (string, error) { + listRegionsOutput, err := r.awsClient.ListRegions(ctx) if err != nil { return "", clierrors.MessageWithCause(err, "Listing AWS regions failed.") } diff --git a/pkg/cli/cmd/radinit/aws_test.go b/pkg/cli/cmd/radinit/aws_test.go index 337b08d732..1a8e7e3cb7 100644 --- a/pkg/cli/cmd/radinit/aws_test.go +++ b/pkg/cli/cmd/radinit/aws_test.go @@ -45,8 +45,9 @@ func Test_enterAWSCloudProvider(t *testing.T) { setAWSAccessKeyIDPrompt(prompter, "access-key-id") setAWSSecretAccessKeyPrompt(prompter, "secret-access-key") - setAWSCallerIdentity(client, QueryRegion, "access-key-id", "secret-access-key", &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")}) - setAWSListRegions(client, QueryRegion, "access-key-id", "secret-access-key", &ec2.DescribeRegionsOutput{Regions: ec2Regions}) + setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")}) + setAWSAccountIDConfirmPrompt(prompter, "account-id", prompt.ConfirmYes) + setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: ec2Regions}) setAWSRegionPrompt(prompter, regions, "region") provider, err := runner.enterAWSCloudProvider(context.Background()) diff --git a/pkg/cli/cmd/radinit/init_test.go b/pkg/cli/cmd/radinit/init_test.go index 30b4c51071..894dd658df 100644 --- a/pkg/cli/cmd/radinit/init_test.go +++ b/pkg/cli/cmd/radinit/init_test.go @@ -1037,16 +1037,23 @@ func setAWSSecretAccessKeyPrompt(prompter *prompt.MockInterface, secretAccessKey Return(secretAccessKey, nil).Times(1) } -func setAWSCallerIdentity(client *aws.MockClient, region string, accessKeyID string, secretAccessKey string, callerIdentityOutput *sts.GetCallerIdentityOutput) { +func setAWSCallerIdentity(client *aws.MockClient, callerIdentityOutput *sts.GetCallerIdentityOutput) { client.EXPECT(). - GetCallerIdentity(gomock.Any(), region, accessKeyID, secretAccessKey). + GetCallerIdentity(gomock.Any()). Return(callerIdentityOutput, nil). Times(1) } -func setAWSListRegions(client *aws.MockClient, region string, accessKeyID string, secretAccessKey string, ec2DescribeRegionsOutput *ec2.DescribeRegionsOutput) { +func setAWSAccountIDConfirmPrompt(prompter *prompt.MockInterface, accountName string, choice string) { + prompter.EXPECT(). + GetListInput([]string{prompt.ConfirmYes, prompt.ConfirmNo}, fmt.Sprintf(confirmAWSAccountIDPromptFmt, accountName)). + Return(choice, nil). + Times(1) +} + +func setAWSListRegions(client *aws.MockClient, ec2DescribeRegionsOutput *ec2.DescribeRegionsOutput) { client.EXPECT(). - ListRegions(gomock.Any(), region, accessKeyID, secretAccessKey). + ListRegions(gomock.Any()). Return(ec2DescribeRegionsOutput, nil). Times(1) } @@ -1055,8 +1062,9 @@ func setAWSListRegions(client *aws.MockClient, region string, accessKeyID string func setAWSCloudProvider(prompter *prompt.MockInterface, client *aws.MockClient, provider aws.Provider) { setAWSAccessKeyIDPrompt(prompter, provider.AccessKeyID) setAWSSecretAccessKeyPrompt(prompter, provider.SecretAccessKey) - setAWSCallerIdentity(client, QueryRegion, provider.AccessKeyID, provider.SecretAccessKey, &sts.GetCallerIdentityOutput{Account: &provider.AccountID}) - setAWSListRegions(client, QueryRegion, provider.AccessKeyID, provider.SecretAccessKey, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()}) + setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: &provider.AccountID}) + setAWSAccountIDConfirmPrompt(prompter, provider.AccountID, prompt.ConfirmYes) + setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()}) setAWSRegionPrompt(prompter, getMockAWSRegionsString(), provider.Region) }