diff --git a/pkg/cli/aws/provider.go b/pkg/cli/aws/provider.go index 442bf16d63..b9ed76aba4 100644 --- a/pkg/cli/aws/provider.go +++ b/pkg/cli/aws/provider.go @@ -16,22 +16,44 @@ limitations under the License. package aws +// AWSCredentialKind - AWS credential kinds supported. +type AWSCredentialKind string + const ( // ProviderDisplayName is the text used in display for AWS. - ProviderDisplayName = "AWS" + ProviderDisplayName = "AWS" + AWSCredentialKindAccessKey = "AccessKey" + AWSCredentialKindIRSA = "IRSA" ) // Provider specifies the properties required to configure AWS provider for cloud resources. type Provider struct { - // AccessKeyID is the access key id for the AWS account. - AccessKeyID string - - // SecretAccessKey is the secret access key for the AWS account. - SecretAccessKey string // Region is the AWS region to use. Region string // AccountID is the AWS account id. AccountID string + + // CredentialKind represents ucp credential kind for aws provider. + CredentialKind AWSCredentialKind + + // AccessKey represents ucp credential kind for aws access key credentials. + AccessKey *AccessKeyCredential + + // IRSA represents ucp credential kind for aws irsa credentials. + IRSA *IRSACredential +} + +type AccessKeyCredential struct { + // AccessKeyID is the access key id for the AWS account. + AccessKeyID string + + // SecretAccessKey is the secret access key for the AWS account. + SecretAccessKey string +} + +type IRSACredential struct { + // RoleARN for AWS IRSA identity + RoleARN string } diff --git a/pkg/cli/azure/provider.go b/pkg/cli/azure/provider.go index 8a60c14ddb..76895deb67 100644 --- a/pkg/cli/azure/provider.go +++ b/pkg/cli/azure/provider.go @@ -35,7 +35,7 @@ type Provider struct { ServicePrincipal *ServicePrincipalCredential } -// Wor specifies the properties of an Azure service principal +// WorkloadIdentityCredential specifies the properties of an Azure service principal type WorkloadIdentityCredential struct { ClientID string TenantID string diff --git a/pkg/cli/cmd/radinit/aws.go b/pkg/cli/cmd/radinit/aws.go index 6d3747309e..1bac374808 100644 --- a/pkg/cli/cmd/radinit/aws.go +++ b/pkg/cli/cmd/radinit/aws.go @@ -29,7 +29,10 @@ import ( const ( selectAWSRegionPrompt = "Select the region you would like to deploy AWS resources to:" + selectAWSCredentialKindPrompt = "Select a credential kind for the AWS credential:" enterAWSIAMAcessKeyIDPrompt = "Enter the IAM access key id:" + enterAWSRoleARNPrompt = "Enter the role ARN:" + enterAWSRoleARNPlaceholder = "Enter IAM role ARN..." enterAWSIAMAcessKeyIDPlaceholder = "Enter IAM access key id..." enterAWSIAMSecretAccessKeyPrompt = "Enter your IAM Secret Access Key:" enterAWSIAMSecretAccessKeyPlaceholder = "Enter IAM secret access key..." @@ -39,37 +42,80 @@ const ( enterAWSAccountIDPlaceholder = "Enter the account ID you want to use..." awsAccessKeysCreateInstructionFmt = "\nAWS IAM Access keys (Access key ID and Secret access key) are required to access and create AWS resources.\n\nFor example, you can create one using the following command:\n\033[36maws iam create-access-key\033[0m\n\nFor more information refer to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html.\n\n" + awsIRSACredentialKind = "IRSA" + awsAccessKeyCredentialKind = "Access Key" ) -func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, error) { - r.Output.LogInfo(awsAccessKeysCreateInstructionFmt) - - accessKeyID, err := r.Prompter.GetTextInput(enterAWSIAMAcessKeyIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSIAMAcessKeyIDPlaceholder}) +func (r *Runner) enterAWSCloudProvider(ctx context.Context, options *initOptions) (*aws.Provider, error) { + credentialKind, err := r.selectAWSCredentialKind() if err != nil { return nil, err } - secretAccessKey, err := r.Prompter.GetTextInput(enterAWSIAMSecretAccessKeyPrompt, prompt.TextInputOptions{Placeholder: enterAWSIAMSecretAccessKeyPlaceholder, EchoMode: textinput.EchoPassword}) - if err != nil { - return nil, err - } + switch credentialKind { + case awsAccessKeyCredentialKind: + r.Output.LogInfo(awsAccessKeysCreateInstructionFmt) - accountId, err := r.getAccountId(ctx) - if err != nil { - return nil, err - } + accessKeyID, err := r.Prompter.GetTextInput(enterAWSIAMAcessKeyIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSIAMAcessKeyIDPlaceholder}) + if err != nil { + return nil, err + } - region, err := r.selectAWSRegion(ctx) - if err != nil { - return nil, err - } + secretAccessKey, err := r.Prompter.GetTextInput(enterAWSIAMSecretAccessKeyPrompt, prompt.TextInputOptions{Placeholder: enterAWSIAMSecretAccessKeyPlaceholder, EchoMode: textinput.EchoPassword}) + if err != nil { + return nil, err + } - return &aws.Provider{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - AccountID: accountId, - Region: region, - }, nil + accountId, err := r.getAccountId(ctx) + if err != nil { + return nil, err + } + + region, err := r.selectAWSRegion(ctx) + if err != nil { + return nil, err + } + + return &aws.Provider{ + AccessKey: &aws.AccessKeyCredential{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + }, + CredentialKind: aws.AWSCredentialKindAccessKey, + AccountID: accountId, + Region: region, + }, nil + case awsIRSACredentialKind: + r.Output.LogInfo(awsAccessKeysCreateInstructionFmt) + + roleARN, err := r.Prompter.GetTextInput(enterAWSRoleARNPrompt, prompt.TextInputOptions{Placeholder: enterAWSRoleARNPlaceholder}) + if err != nil { + return nil, err + } + + accountId, err := r.getAccountId(ctx) + if err != nil { + return nil, err + } + + region, err := r.selectAWSRegion(ctx) + if err != nil { + return nil, err + } + + // Set the value for the Helm chart + options.SetValues = append(options.SetValues, "global.aws.irsa.enabled=true") + return &aws.Provider{ + AccountID: accountId, + Region: region, + CredentialKind: aws.AWSCredentialKindIRSA, + IRSA: &aws.IRSACredential{ + RoleARN: roleARN, + }, + }, nil + default: + return nil, clierrors.Message("Invalid AWS credential kind: %s", credentialKind) + } } func (r *Runner) getAccountId(ctx context.Context) (string, error) { @@ -123,3 +169,15 @@ func (r *Runner) buildAWSRegionsList(listRegionsOutput *ec2.DescribeRegionsOutpu return regions } + +func (r *Runner) selectAWSCredentialKind() (string, error) { + credentialKinds := r.buildAWSCredentialKind() + return r.Prompter.GetListInput(credentialKinds, selectAWSCredentialKindPrompt) +} + +func (r *Runner) buildAWSCredentialKind() []string { + return []string{ + awsAccessKeyCredentialKind, + awsIRSACredentialKind, + } +} diff --git a/pkg/cli/cmd/radinit/aws_test.go b/pkg/cli/cmd/radinit/aws_test.go index 1a8e7e3cb7..b9def5a525 100644 --- a/pkg/cli/cmd/radinit/aws_test.go +++ b/pkg/cli/cmd/radinit/aws_test.go @@ -31,7 +31,7 @@ import ( "go.uber.org/mock/gomock" ) -func Test_enterAWSCloudProvider(t *testing.T) { +func Test_enterAWSCloudProvider_AccessKey(t *testing.T) { ctrl := gomock.NewController(t) prompter := prompt.NewMockInterface(ctrl) client := aws.NewMockClient(ctrl) @@ -43,6 +43,7 @@ func Test_enterAWSCloudProvider(t *testing.T) { } regions := []string{"region", "region2"} + setAWSCredentialKindPrompt(prompter, "Access Key") setAWSAccessKeyIDPrompt(prompter, "access-key-id") setAWSSecretAccessKeyPrompt(prompter, "secret-access-key") setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")}) @@ -50,14 +51,53 @@ func Test_enterAWSCloudProvider(t *testing.T) { setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: ec2Regions}) setAWSRegionPrompt(prompter, regions, "region") - provider, err := runner.enterAWSCloudProvider(context.Background()) + options := &initOptions{} + provider, err := runner.enterAWSCloudProvider(context.Background(), options) require.NoError(t, err) expected := &aws.Provider{ - AccessKeyID: "access-key-id", - SecretAccessKey: "secret-access-key", - Region: "region", - AccountID: "account-id", + AccessKey: &aws.AccessKeyCredential{ + AccessKeyID: "access-key-id", + SecretAccessKey: "secret-access-key", + }, + CredentialKind: "AccessKey", + Region: "region", + AccountID: "account-id", + } + require.Equal(t, expected, provider) + require.Equal(t, []any{output.LogOutput{Format: awsAccessKeysCreateInstructionFmt}}, outputSink.Writes) +} + +func Test_enterAWSCloudProvider_IRSA(t *testing.T) { + ctrl := gomock.NewController(t) + prompter := prompt.NewMockInterface(ctrl) + client := aws.NewMockClient(ctrl) + outputSink := output.MockOutput{} + runner := Runner{Prompter: prompter, awsClient: client, Output: &outputSink} + ec2Regions := []ec2_types.Region{ + {RegionName: to.Ptr("region")}, + {RegionName: to.Ptr("region2")}, + } + regions := []string{"region", "region2"} + + setAWSCredentialKindPrompt(prompter, "IRSA") + setAwsIRSARoleARNPrompt(prompter, "role-arn") + setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")}) + setAWSAccountIDConfirmPrompt(prompter, "account-id", prompt.ConfirmYes) + setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: ec2Regions}) + setAWSRegionPrompt(prompter, regions, "region") + + options := &initOptions{} + provider, err := runner.enterAWSCloudProvider(context.Background(), options) + require.NoError(t, err) + + expected := &aws.Provider{ + IRSA: &aws.IRSACredential{ + RoleARN: "role-arn", + }, + CredentialKind: "IRSA", + Region: "region", + AccountID: "account-id", } require.Equal(t, expected, provider) require.Equal(t, []any{output.LogOutput{Format: awsAccessKeysCreateInstructionFmt}}, outputSink.Writes) diff --git a/pkg/cli/cmd/radinit/cloud.go b/pkg/cli/cmd/radinit/cloud.go index e492d439cd..4e0a2e3ff4 100644 --- a/pkg/cli/cmd/radinit/cloud.go +++ b/pkg/cli/cmd/radinit/cloud.go @@ -64,7 +64,7 @@ func (r *Runner) enterCloudProviderOptions(ctx context.Context, options *initOpt options.CloudProviders.Azure = provider case aws.ProviderDisplayName: - provider, err := r.enterAWSCloudProvider(ctx) + provider, err := r.enterAWSCloudProvider(ctx, options) if err != nil { return err } diff --git a/pkg/cli/cmd/radinit/cloud_test.go b/pkg/cli/cmd/radinit/cloud_test.go index f9949673b5..3d161eb6be 100644 --- a/pkg/cli/cmd/radinit/cloud_test.go +++ b/pkg/cli/cmd/radinit/cloud_test.go @@ -50,11 +50,23 @@ func Test_enterCloudProviderOptions(t *testing.T) { }, } - awsProvider := aws.Provider{ - Region: "test-region", - AccessKeyID: "test-access-key-id", - SecretAccessKey: "test-secret-access-key", - AccountID: "test-account-id", + awsProviderAccessKey := aws.Provider{ + Region: "test-region", + CredentialKind: "AccessKey", + AccessKey: &aws.AccessKeyCredential{ + AccessKeyID: "test-access-key-id", + SecretAccessKey: "test-secret-access-key", + }, + AccountID: "test-account-id", + } + + awsProviderIRSA := aws.Provider{ + Region: "test-region", + CredentialKind: "IRSA", + IRSA: &aws.IRSACredential{ + RoleARN: "test-role-arn", + }, + AccountID: "test-account-id", } t.Run("cloud providers skipped when no flags specified", func(t *testing.T) { @@ -114,7 +126,7 @@ func Test_enterCloudProviderOptions(t *testing.T) { require.Empty(t, outputSink.Writes) }) - t.Run("--full - aws provider", func(t *testing.T) { + t.Run("--full - aws provider - accesskey", func(t *testing.T) { ctrl := gomock.NewController(t) prompter := prompt.NewMockInterface(ctrl) awsClient := aws.NewMockClient(ctrl) @@ -124,14 +136,41 @@ func Test_enterCloudProviderOptions(t *testing.T) { initAddCloudProviderPromptYes(prompter) initSelectCloudProvider(prompter, aws.ProviderDisplayName) - setAWSCloudProvider(prompter, awsClient, awsProvider) + setAWSCloudProviderAccessKey(prompter, awsClient, awsProviderAccessKey) initAddCloudProviderPromptNo(prompter) options := initOptions{Environment: environmentOptions{Create: true}} err := runner.enterCloudProviderOptions(context.Background(), &options) require.NoError(t, err) require.Nil(t, options.CloudProviders.Azure) - require.Equal(t, awsProvider, *options.CloudProviders.AWS) + require.Equal(t, awsProviderAccessKey, *options.CloudProviders.AWS) + + expectedWrites := []any{ + output.LogOutput{ + Format: awsAccessKeysCreateInstructionFmt, + }, + } + require.Equal(t, expectedWrites, outputSink.Writes) + }) + + t.Run("--full - aws provider - irsa", func(t *testing.T) { + ctrl := gomock.NewController(t) + prompter := prompt.NewMockInterface(ctrl) + awsClient := aws.NewMockClient(ctrl) + azureClient := azure.NewMockClient(ctrl) + outputSink := output.MockOutput{} + runner := Runner{Prompter: prompter, awsClient: awsClient, azureClient: azureClient, Output: &outputSink, Full: true} + + initAddCloudProviderPromptYes(prompter) + initSelectCloudProvider(prompter, aws.ProviderDisplayName) + setAWSCloudProviderIRSA(prompter, awsClient, awsProviderIRSA) + initAddCloudProviderPromptNo(prompter) + + options := initOptions{Environment: environmentOptions{Create: true}} + err := runner.enterCloudProviderOptions(context.Background(), &options) + require.NoError(t, err) + require.Nil(t, options.CloudProviders.Azure) + require.Equal(t, awsProviderIRSA, *options.CloudProviders.AWS) expectedWrites := []any{ output.LogOutput{ @@ -206,7 +245,7 @@ func Test_enterCloudProviderOptions(t *testing.T) { initAddCloudProviderPromptYes(prompter) initSelectCloudProvider(prompter, aws.ProviderDisplayName) - setAWSCloudProvider(prompter, awsClient, awsProvider) + setAWSCloudProviderAccessKey(prompter, awsClient, awsProviderAccessKey) initAddCloudProviderPromptYes(prompter) initSelectCloudProvider(prompter, azure.ProviderDisplayName) @@ -217,7 +256,7 @@ func Test_enterCloudProviderOptions(t *testing.T) { options := initOptions{Environment: environmentOptions{Create: true}} err := runner.enterCloudProviderOptions(context.Background(), &options) require.NoError(t, err) - require.Equal(t, awsProvider, *options.CloudProviders.AWS) + require.Equal(t, awsProviderAccessKey, *options.CloudProviders.AWS) require.Equal(t, azureProviderServicePrincipal, *options.CloudProviders.Azure) expectedWrites := []any{ @@ -243,13 +282,13 @@ func Test_enterCloudProviderOptions(t *testing.T) { initAddCloudProviderPromptYes(prompter) initSelectCloudProvider(prompter, aws.ProviderDisplayName) - setAWSCloudProvider(prompter, awsClient, awsProvider) + setAWSCloudProviderAccessKey(prompter, awsClient, awsProviderAccessKey) - awsProvider := awsProvider + awsProvider := awsProviderAccessKey awsProvider.Region = "another-region" initAddCloudProviderPromptYes(prompter) initSelectCloudProvider(prompter, aws.ProviderDisplayName) - setAWSCloudProvider(prompter, awsClient, awsProvider) + setAWSCloudProviderAccessKey(prompter, awsClient, awsProvider) initAddCloudProviderPromptNo(prompter) diff --git a/pkg/cli/cmd/radinit/display.go b/pkg/cli/cmd/radinit/display.go index ead182de57..d16772c931 100644 --- a/pkg/cli/cmd/radinit/display.go +++ b/pkg/cli/cmd/radinit/display.go @@ -26,6 +26,7 @@ import ( "github.com/charmbracelet/bubbles/spinner" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/radius-project/radius/pkg/cli/aws" "github.com/radius-project/radius/pkg/cli/azure" "github.com/radius-project/radius/pkg/cli/prompt" ) @@ -36,7 +37,7 @@ const ( summaryFooter = "\n(press enter to confirm or esc to restart)\n" summaryKubernetesHeadingIcon = "🔧 " summaryKubernetesInstallHeadingFmt = "Install Radius %s\n" + summaryIndent + "Kubernetes cluster: %s\n" + summaryIndent + "Kubernetes namespace: %s\n" - summaryKubernetesInstallAWSCloudProviderFmt = summaryIndent + "AWS IAM access key id: %s\n" + summaryKubernetesInstallAWSCloudProviderFmt = summaryIndent + "AWS credential: %s\n" summaryKubernetesInstallAzureCloudProviderFmt = summaryIndent + "Azure credential: %s\n" summaryKubernetesExistingHeadingFmt = "Use existing Radius %s install on %s\n" summaryEnvironmentHeadingIcon = "🌏 " @@ -205,7 +206,14 @@ func (m *summaryModel) View() string { message.WriteString(fmt.Sprintf(summaryKubernetesInstallHeadingFmt, highlight(options.Cluster.Version), highlight(options.Cluster.Context), highlight(options.Cluster.Namespace))) if options.CloudProviders.AWS != nil { - message.WriteString(fmt.Sprintf(summaryKubernetesInstallAWSCloudProviderFmt, highlight(options.CloudProviders.AWS.AccessKeyID))) + message.WriteString(fmt.Sprintf(summaryKubernetesInstallAWSCloudProviderFmt, highlight(string(options.CloudProviders.AWS.CredentialKind)))) + switch options.CloudProviders.AWS.CredentialKind { + case aws.AWSCredentialKindAccessKey: + message.WriteString(fmt.Sprintf(summaryIndent+"AccessKey ID: %s\n", highlight(options.CloudProviders.AWS.AccessKey.AccessKeyID))) + case aws.AWSCredentialKindIRSA: + message.WriteString(fmt.Sprintf(summaryIndent+"IAM Role ARN: %s\n", highlight(options.CloudProviders.AWS.IRSA.RoleARN))) + } + } if options.CloudProviders.Azure != nil { message.WriteString(fmt.Sprintf(summaryKubernetesInstallAzureCloudProviderFmt, highlight(string(options.CloudProviders.Azure.CredentialKind)))) @@ -337,8 +345,15 @@ func (m *progressModel) View() string { message.WriteString(fmt.Sprintf(summaryKubernetesInstallHeadingFmt, highlight(options.Cluster.Version), highlight(options.Cluster.Context), highlight(options.Cluster.Namespace))) if options.CloudProviders.AWS != nil { - message.WriteString(fmt.Sprintf(summaryKubernetesInstallAWSCloudProviderFmt, highlight(options.CloudProviders.AWS.AccessKeyID))) + message.WriteString(fmt.Sprintf(summaryKubernetesInstallAWSCloudProviderFmt, highlight(string(options.CloudProviders.AWS.CredentialKind)))) + switch options.CloudProviders.AWS.CredentialKind { + case aws.AWSCredentialKindAccessKey: + message.WriteString(fmt.Sprintf(summaryIndent+"AccessKey ID: %s\n", highlight(options.CloudProviders.AWS.AccessKey.AccessKeyID))) + case aws.AWSCredentialKindIRSA: + message.WriteString(fmt.Sprintf(summaryIndent+"IAM Role ARN: %s\n", highlight(options.CloudProviders.AWS.IRSA.RoleARN))) + } } + if options.CloudProviders.Azure != nil { message.WriteString(fmt.Sprintf(summaryKubernetesInstallAzureCloudProviderFmt, highlight(string(options.CloudProviders.Azure.CredentialKind)))) switch options.CloudProviders.Azure.CredentialKind { diff --git a/pkg/cli/cmd/radinit/environment.go b/pkg/cli/cmd/radinit/environment.go index a4542c8cbf..f09e547736 100644 --- a/pkg/cli/cmd/radinit/environment.go +++ b/pkg/cli/cmd/radinit/environment.go @@ -108,8 +108,12 @@ func (r *Runner) CreateEnvironment(ctx context.Context) error { } if r.Options.CloudProviders.AWS != nil { - credential := r.getAWSCredential() - err := credentialClient.PutAWS(ctx, credential) + credential, err := r.getAWSCredential() + if err != nil { + return clierrors.MessageWithCause(err, "Failed to configure AWS credentials.") + } + + err = credentialClient.PutAWS(ctx, credential) if err != nil { return clierrors.MessageWithCause(err, "Failed to configure AWS credentials.") } diff --git a/pkg/cli/cmd/radinit/init.go b/pkg/cli/cmd/radinit/init.go index 84133d9619..ce7e6df3d4 100644 --- a/pkg/cli/cmd/radinit/init.go +++ b/pkg/cli/cmd/radinit/init.go @@ -316,17 +316,33 @@ func (r *Runner) getAzureCredential() (ucp.AzureCredentialResource, error) { } } -func (r *Runner) getAWSCredential() ucp.AwsCredentialResource { - return ucp.AwsCredentialResource{ - Location: to.Ptr(v1.LocationGlobal), - Type: to.Ptr(cli_credential.AWSCredential), - Properties: &ucp.AwsAccessKeyCredentialProperties{ - Storage: &ucp.CredentialStorageProperties{ - Kind: to.Ptr(ucp.CredentialStorageKindInternal), +func (r *Runner) getAWSCredential() (ucp.AwsCredentialResource, error) { + switch r.Options.CloudProviders.AWS.CredentialKind { + case aws.AWSCredentialKindAccessKey: + return ucp.AwsCredentialResource{ + Location: to.Ptr(v1.LocationGlobal), + Type: to.Ptr(cli_credential.AWSCredential), + Properties: &ucp.AwsAccessKeyCredentialProperties{ + Storage: &ucp.CredentialStorageProperties{ + Kind: to.Ptr(ucp.CredentialStorageKindInternal), + }, + AccessKeyID: &r.Options.CloudProviders.AWS.AccessKey.AccessKeyID, + SecretAccessKey: &r.Options.CloudProviders.AWS.AccessKey.SecretAccessKey, + }, + }, nil + case aws.AWSCredentialKindIRSA: + return ucp.AwsCredentialResource{ + Location: to.Ptr(v1.LocationGlobal), + Type: to.Ptr(cli_credential.AWSCredential), + Properties: &ucp.AwsIRSACredentialProperties{ + Storage: &ucp.CredentialStorageProperties{ + Kind: to.Ptr(ucp.CredentialStorageKindInternal), + }, + RoleARN: &r.Options.CloudProviders.AWS.IRSA.RoleARN, }, - AccessKeyID: &r.Options.CloudProviders.AWS.AccessKeyID, - SecretAccessKey: &r.Options.CloudProviders.AWS.SecretAccessKey, - }, + }, nil + default: + return ucp.AwsCredentialResource{}, fmt.Errorf("unsupported AWS credential kind: %s", r.Options.CloudProviders.AWS.CredentialKind) } } diff --git a/pkg/cli/cmd/radinit/init_test.go b/pkg/cli/cmd/radinit/init_test.go index 894dd658df..24b1008695 100644 --- a/pkg/cli/cmd/radinit/init_test.go +++ b/pkg/cli/cmd/radinit/init_test.go @@ -80,11 +80,23 @@ func Test_Validate(t *testing.T) { }, } - awsProvider := aws.Provider{ - Region: "test-region", - AccessKeyID: "test-access-key-id", - SecretAccessKey: "test-secret-access-key", - AccountID: "test-account-id", + awsProviderAccessKey := aws.Provider{ + Region: "test-region", + CredentialKind: "AccessKey", + AccessKey: &aws.AccessKeyCredential{ + AccessKeyID: "test-access-key-id", + SecretAccessKey: "test-secret-access-key", + }, + AccountID: "test-account-id", + } + + awsProviderIRSA := aws.Provider{ + Region: "test-region", + CredentialKind: "IRSA", + IRSA: &aws.IRSACredential{ + RoleARN: "test-role-arn", + }, + AccountID: "test-account-id", } testcases := []radcli.ValidateInput{ @@ -325,7 +337,7 @@ func Test_Validate(t *testing.T) { }, }, { - Name: "Init --full Command With AWS Cloud Provider", + Name: "Init --full Command With AWS Cloud Provider - Access Key", Input: []string{"--full"}, ExpectedValid: true, ConfigHolder: framework.ConfigHolder{ @@ -348,7 +360,42 @@ func Test_Validate(t *testing.T) { // Add aws provider initAddCloudProviderPromptYes(mocks.Prompter) initSelectCloudProvider(mocks.Prompter, aws.ProviderDisplayName) - setAWSCloudProvider(mocks.Prompter, mocks.AWSClient, awsProvider) + setAWSCloudProviderAccessKey(mocks.Prompter, mocks.AWSClient, awsProviderAccessKey) + + // Don't add any other cloud providers + initAddCloudProviderPromptNo(mocks.Prompter) + + // No application + setScaffoldApplicationPromptNo(mocks.Prompter) + + setConfirmOption(mocks.Prompter, resultConfimed) + }, + }, + { + Name: "Init --full Command With AWS Cloud Provider - IRSA", + Input: []string{"--full"}, + ExpectedValid: true, + ConfigHolder: framework.ConfigHolder{ + ConfigFilePath: "", + Config: config, + }, + ConfigureMocks: func(mocks radcli.ValidateMocks) { + // Radius is already installed + initGetKubeContextSuccess(mocks.Kubernetes) + initKubeContextWithKind(mocks.Prompter) + initHelmMockRadiusInstalled(mocks.Helm) + + // No existing environment, users will be prompted to create a new one + setExistingEnvironments(mocks.ApplicationManagementClient, []corerp.EnvironmentResource{}) + + // Choose default name and namespace + initEnvNamePrompt(mocks.Prompter, "default") + initNamespacePrompt(mocks.Prompter, "default") + + // Add aws provider + initAddCloudProviderPromptYes(mocks.Prompter) + initSelectCloudProvider(mocks.Prompter, aws.ProviderDisplayName) + setAWSCloudProviderIRSA(mocks.Prompter, mocks.AWSClient, awsProviderIRSA) // Don't add any other cloud providers initAddCloudProviderPromptNo(mocks.Prompter) @@ -692,23 +739,44 @@ func Test_Run_InstallAndCreateEnvironment(t *testing.T) { full: false, azureProvider: nil, awsProvider: &aws.Provider{ - AccessKeyID: "test-access-key", - SecretAccessKey: "test-secret-access", - Region: "us-west-2", - AccountID: "test-account-id", + AccessKey: &aws.AccessKeyCredential{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-access", + }, + CredentialKind: "AccessKey", + Region: "us-west-2", + AccountID: "test-account-id", }, recipes: map[string]map[string]corerp.RecipePropertiesClassification{}, expectedOutput: []any{}, }, { - name: "`rad init --full` with AWS Provider", + name: "`rad init --full` with AWS Provider - Access Key", + full: true, + azureProvider: nil, + awsProvider: &aws.Provider{ + AccessKey: &aws.AccessKeyCredential{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-access", + }, + CredentialKind: "AccessKey", + Region: "us-west-2", + AccountID: "test-account-id", + }, + recipes: nil, + expectedOutput: []any{}, + }, + { + name: "`rad init --full` with AWS Provider - IRSA", full: true, azureProvider: nil, awsProvider: &aws.Provider{ - AccessKeyID: "test-access-key", - SecretAccessKey: "test-secret-access", - Region: "us-west-2", - AccountID: "test-account-id", + IRSA: &aws.IRSACredential{ + RoleARN: "role-arn", + }, + CredentialKind: "IRSA", + Region: "us-west-2", + AccountID: "test-account-id", }, recipes: nil, expectedOutput: []any{}, @@ -777,20 +845,37 @@ func Test_Run_InstallAndCreateEnvironment(t *testing.T) { Times(1) } if tc.awsProvider != nil { - credentialManagementClient.EXPECT(). - PutAWS(context.Background(), ucp.AwsCredentialResource{ - Location: to.Ptr(v1.LocationGlobal), - Type: to.Ptr(cli_credential.AWSCredential), - Properties: &ucp.AwsAccessKeyCredentialProperties{ - Storage: &ucp.CredentialStorageProperties{ - Kind: to.Ptr(ucp.CredentialStorageKindInternal), + if tc.awsProvider.AccessKey != nil { + credentialManagementClient.EXPECT(). + PutAWS(context.Background(), ucp.AwsCredentialResource{ + Location: to.Ptr(v1.LocationGlobal), + Type: to.Ptr(cli_credential.AWSCredential), + Properties: &ucp.AwsAccessKeyCredentialProperties{ + Storage: &ucp.CredentialStorageProperties{ + Kind: to.Ptr(ucp.CredentialStorageKindInternal), + }, + AccessKeyID: to.Ptr(tc.awsProvider.AccessKey.AccessKeyID), + SecretAccessKey: to.Ptr(tc.awsProvider.AccessKey.SecretAccessKey), }, - AccessKeyID: to.Ptr(tc.awsProvider.AccessKeyID), - SecretAccessKey: to.Ptr(tc.awsProvider.SecretAccessKey), - }, - }). - Return(nil). - Times(1) + }). + Return(nil). + Times(1) + } else { + credentialManagementClient.EXPECT(). + PutAWS(context.Background(), ucp.AwsCredentialResource{ + Location: to.Ptr(v1.LocationGlobal), + Type: to.Ptr(cli_credential.AWSCredential), + Properties: &ucp.AwsIRSACredentialProperties{ + Storage: &ucp.CredentialStorageProperties{ + Kind: to.Ptr(ucp.CredentialStorageKindInternal), + }, + RoleARN: to.Ptr(tc.awsProvider.IRSA.RoleARN), + }, + }). + Return(nil). + Times(1) + } + } configFileInterface.EXPECT(). @@ -1058,10 +1143,21 @@ func setAWSListRegions(client *aws.MockClient, ec2DescribeRegionsOutput *ec2.Des Times(1) } -// setAWSCloudProvider sets up mocks that will configure an AWS cloud provider. -func setAWSCloudProvider(prompter *prompt.MockInterface, client *aws.MockClient, provider aws.Provider) { - setAWSAccessKeyIDPrompt(prompter, provider.AccessKeyID) - setAWSSecretAccessKeyPrompt(prompter, provider.SecretAccessKey) +// setAWSCloudProviderAccessKey sets up mocks that will configure an AWS cloud provider with access key. +func setAWSCloudProviderAccessKey(prompter *prompt.MockInterface, client *aws.MockClient, provider aws.Provider) { + setAWSCredentialKindPrompt(prompter, "Access Key") + setAWSAccessKeyIDPrompt(prompter, provider.AccessKey.AccessKeyID) + setAWSSecretAccessKeyPrompt(prompter, provider.AccessKey.SecretAccessKey) + setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: &provider.AccountID}) + setAWSAccountIDConfirmPrompt(prompter, provider.AccountID, prompt.ConfirmYes) + setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()}) + setAWSRegionPrompt(prompter, getMockAWSRegionsString(), provider.Region) +} + +// setAWSCloudProviderIRSA sets up mocks that will configure an AWS cloud provider with IRSA. +func setAWSCloudProviderIRSA(prompter *prompt.MockInterface, client *aws.MockClient, provider aws.Provider) { + setAWSCredentialKindPrompt(prompter, "IRSA") + setAwsIRSARoleARNPrompt(prompter, provider.IRSA.RoleARN) setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: &provider.AccountID}) setAWSAccountIDConfirmPrompt(prompter, provider.AccountID, prompt.ConfirmYes) setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: getMockAWSRegions()}) @@ -1187,6 +1283,20 @@ func setAzureCredentialKindPrompt(prompter *prompt.MockInterface, choice string) Times(1) } +func setAWSCredentialKindPrompt(prompter *prompt.MockInterface, choice string) { + prompter.EXPECT(). + GetListInput([]string{"Access Key", "IRSA"}, selectAWSCredentialKindPrompt). + Return(choice, nil). + Times(1) +} + +func setAwsIRSARoleARNPrompt(prompter *prompt.MockInterface, roleARN string) { + prompter.EXPECT(). + GetTextInput(enterAWSRoleARNPrompt, gomock.Any()). + Return(roleARN, nil). + Times(1) +} + // setAzureCloudProviderServicePrincipal sets up mocks that will configure an Azure cloud provider with service principal credential. func setAzureCloudProviderServicePrincipal(prompter *prompt.MockInterface, client *azure.MockClient, provider azure.Provider) { subscriptions := &azure.SubscriptionResult{