Skip to content

Commit

Permalink
Adding rad init command changes to support irsa (radius-project#7761)
Browse files Browse the repository at this point in the history
# Description

- Added changes to `rad init --full` command to add types i.e
accesskey,irsa while configuring aws provider

![image](https://github.com/user-attachments/assets/da396b4d-5877-4772-a247-9ef6cf0c9e79)
- And adding prompts to accept role arn
- refactoring the code with switch cases to handle it differently for
both cases.
- updated the tests


## Type of change

<!--

Please select **one** of the following options that describes your
change and delete the others. Clearly identifying the type of change you
are making will help us review your PR faster, and is used in authoring
release notes.

If you are making a bug fix or functionality change to Radius and do not
have an associated issue link please create one now.

-->

- This pull request fixes a bug in Radius and has an approved issue
(issue link required).
- This pull request adds or changes features of Radius and has an
approved issue (issue link required).
- This pull request is a minor refactor, code cleanup, test improvement,
or other maintenance task and doesn't change the functionality of Radius
(issue link optional).

<!--

Please update the following to link the associated issue. This is
required for some kinds of changes (see above).

-->

Fixes: #issue_number

---------

Signed-off-by: Vishwanath Hiremath <vhiremath@microsoft.com>
  • Loading branch information
vishwahiremat authored Aug 6, 2024
1 parent 4ba025d commit a4a4b90
Show file tree
Hide file tree
Showing 10 changed files with 401 additions and 97 deletions.
34 changes: 28 additions & 6 deletions pkg/cli/aws/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,44 @@ limitations under the License.

package aws

// AWSCredentialKind - AWS credential kinds supported.
type AWSCredentialKind string

const (
// ProviderDisplayName is the text used in display for AWS.
ProviderDisplayName = "AWS"
ProviderDisplayName = "AWS"
AWSCredentialKindAccessKey = "AccessKey"
AWSCredentialKindIRSA = "IRSA"
)

// Provider specifies the properties required to configure AWS provider for cloud resources.
type Provider struct {
// AccessKeyID is the access key id for the AWS account.
AccessKeyID string

// SecretAccessKey is the secret access key for the AWS account.
SecretAccessKey string

// Region is the AWS region to use.
Region string

// AccountID is the AWS account id.
AccountID string

// CredentialKind represents ucp credential kind for aws provider.
CredentialKind AWSCredentialKind

// AccessKey represents ucp credential kind for aws access key credentials.
AccessKey *AccessKeyCredential

// IRSA represents ucp credential kind for aws irsa credentials.
IRSA *IRSACredential
}

type AccessKeyCredential struct {
// AccessKeyID is the access key id for the AWS account.
AccessKeyID string

// SecretAccessKey is the secret access key for the AWS account.
SecretAccessKey string
}

type IRSACredential struct {
// RoleARN for AWS IRSA identity
RoleARN string
}
2 changes: 1 addition & 1 deletion pkg/cli/azure/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Provider struct {
ServicePrincipal *ServicePrincipalCredential
}

// Wor specifies the properties of an Azure service principal
// WorkloadIdentityCredential specifies the properties of an Azure service principal
type WorkloadIdentityCredential struct {
ClientID string
TenantID string
Expand Down
102 changes: 80 additions & 22 deletions pkg/cli/cmd/radinit/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ import (

const (
selectAWSRegionPrompt = "Select the region you would like to deploy AWS resources to:"
selectAWSCredentialKindPrompt = "Select a credential kind for the AWS credential:"
enterAWSIAMAcessKeyIDPrompt = "Enter the IAM access key id:"
enterAWSRoleARNPrompt = "Enter the role ARN:"
enterAWSRoleARNPlaceholder = "Enter IAM role ARN..."
enterAWSIAMAcessKeyIDPlaceholder = "Enter IAM access key id..."
enterAWSIAMSecretAccessKeyPrompt = "Enter your IAM Secret Access Key:"
enterAWSIAMSecretAccessKeyPlaceholder = "Enter IAM secret access key..."
Expand All @@ -39,37 +42,80 @@ const (
enterAWSAccountIDPlaceholder = "Enter the account ID you want to use..."

awsAccessKeysCreateInstructionFmt = "\nAWS IAM Access keys (Access key ID and Secret access key) are required to access and create AWS resources.\n\nFor example, you can create one using the following command:\n\033[36maws iam create-access-key\033[0m\n\nFor more information refer to https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html.\n\n"
awsIRSACredentialKind = "IRSA"
awsAccessKeyCredentialKind = "Access Key"
)

func (r *Runner) enterAWSCloudProvider(ctx context.Context) (*aws.Provider, error) {
r.Output.LogInfo(awsAccessKeysCreateInstructionFmt)

accessKeyID, err := r.Prompter.GetTextInput(enterAWSIAMAcessKeyIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSIAMAcessKeyIDPlaceholder})
func (r *Runner) enterAWSCloudProvider(ctx context.Context, options *initOptions) (*aws.Provider, error) {
credentialKind, err := r.selectAWSCredentialKind()
if err != nil {
return nil, err
}

secretAccessKey, err := r.Prompter.GetTextInput(enterAWSIAMSecretAccessKeyPrompt, prompt.TextInputOptions{Placeholder: enterAWSIAMSecretAccessKeyPlaceholder, EchoMode: textinput.EchoPassword})
if err != nil {
return nil, err
}
switch credentialKind {
case awsAccessKeyCredentialKind:
r.Output.LogInfo(awsAccessKeysCreateInstructionFmt)

accountId, err := r.getAccountId(ctx)
if err != nil {
return nil, err
}
accessKeyID, err := r.Prompter.GetTextInput(enterAWSIAMAcessKeyIDPrompt, prompt.TextInputOptions{Placeholder: enterAWSIAMAcessKeyIDPlaceholder})
if err != nil {
return nil, err
}

region, err := r.selectAWSRegion(ctx)
if err != nil {
return nil, err
}
secretAccessKey, err := r.Prompter.GetTextInput(enterAWSIAMSecretAccessKeyPrompt, prompt.TextInputOptions{Placeholder: enterAWSIAMSecretAccessKeyPlaceholder, EchoMode: textinput.EchoPassword})
if err != nil {
return nil, err
}

return &aws.Provider{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
AccountID: accountId,
Region: region,
}, nil
accountId, err := r.getAccountId(ctx)
if err != nil {
return nil, err
}

region, err := r.selectAWSRegion(ctx)
if err != nil {
return nil, err
}

return &aws.Provider{
AccessKey: &aws.AccessKeyCredential{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
},
CredentialKind: aws.AWSCredentialKindAccessKey,
AccountID: accountId,
Region: region,
}, nil
case awsIRSACredentialKind:
r.Output.LogInfo(awsAccessKeysCreateInstructionFmt)

roleARN, err := r.Prompter.GetTextInput(enterAWSRoleARNPrompt, prompt.TextInputOptions{Placeholder: enterAWSRoleARNPlaceholder})
if err != nil {
return nil, err
}

accountId, err := r.getAccountId(ctx)
if err != nil {
return nil, err
}

region, err := r.selectAWSRegion(ctx)
if err != nil {
return nil, err
}

// Set the value for the Helm chart
options.SetValues = append(options.SetValues, "global.aws.irsa.enabled=true")
return &aws.Provider{
AccountID: accountId,
Region: region,
CredentialKind: aws.AWSCredentialKindIRSA,
IRSA: &aws.IRSACredential{
RoleARN: roleARN,
},
}, nil
default:
return nil, clierrors.Message("Invalid AWS credential kind: %s", credentialKind)
}
}

func (r *Runner) getAccountId(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -123,3 +169,15 @@ func (r *Runner) buildAWSRegionsList(listRegionsOutput *ec2.DescribeRegionsOutpu

return regions
}

func (r *Runner) selectAWSCredentialKind() (string, error) {
credentialKinds := r.buildAWSCredentialKind()
return r.Prompter.GetListInput(credentialKinds, selectAWSCredentialKindPrompt)
}

func (r *Runner) buildAWSCredentialKind() []string {
return []string{
awsAccessKeyCredentialKind,
awsIRSACredentialKind,
}
}
52 changes: 46 additions & 6 deletions pkg/cli/cmd/radinit/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
"go.uber.org/mock/gomock"
)

func Test_enterAWSCloudProvider(t *testing.T) {
func Test_enterAWSCloudProvider_AccessKey(t *testing.T) {
ctrl := gomock.NewController(t)
prompter := prompt.NewMockInterface(ctrl)
client := aws.NewMockClient(ctrl)
Expand All @@ -43,21 +43,61 @@ func Test_enterAWSCloudProvider(t *testing.T) {
}
regions := []string{"region", "region2"}

setAWSCredentialKindPrompt(prompter, "Access Key")
setAWSAccessKeyIDPrompt(prompter, "access-key-id")
setAWSSecretAccessKeyPrompt(prompter, "secret-access-key")
setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")})
setAWSAccountIDConfirmPrompt(prompter, "account-id", prompt.ConfirmYes)
setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: ec2Regions})
setAWSRegionPrompt(prompter, regions, "region")

provider, err := runner.enterAWSCloudProvider(context.Background())
options := &initOptions{}
provider, err := runner.enterAWSCloudProvider(context.Background(), options)
require.NoError(t, err)

expected := &aws.Provider{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
Region: "region",
AccountID: "account-id",
AccessKey: &aws.AccessKeyCredential{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
},
CredentialKind: "AccessKey",
Region: "region",
AccountID: "account-id",
}
require.Equal(t, expected, provider)
require.Equal(t, []any{output.LogOutput{Format: awsAccessKeysCreateInstructionFmt}}, outputSink.Writes)
}

func Test_enterAWSCloudProvider_IRSA(t *testing.T) {
ctrl := gomock.NewController(t)
prompter := prompt.NewMockInterface(ctrl)
client := aws.NewMockClient(ctrl)
outputSink := output.MockOutput{}
runner := Runner{Prompter: prompter, awsClient: client, Output: &outputSink}
ec2Regions := []ec2_types.Region{
{RegionName: to.Ptr("region")},
{RegionName: to.Ptr("region2")},
}
regions := []string{"region", "region2"}

setAWSCredentialKindPrompt(prompter, "IRSA")
setAwsIRSARoleARNPrompt(prompter, "role-arn")
setAWSCallerIdentity(client, &sts.GetCallerIdentityOutput{Account: to.Ptr("account-id")})
setAWSAccountIDConfirmPrompt(prompter, "account-id", prompt.ConfirmYes)
setAWSListRegions(client, &ec2.DescribeRegionsOutput{Regions: ec2Regions})
setAWSRegionPrompt(prompter, regions, "region")

options := &initOptions{}
provider, err := runner.enterAWSCloudProvider(context.Background(), options)
require.NoError(t, err)

expected := &aws.Provider{
IRSA: &aws.IRSACredential{
RoleARN: "role-arn",
},
CredentialKind: "IRSA",
Region: "region",
AccountID: "account-id",
}
require.Equal(t, expected, provider)
require.Equal(t, []any{output.LogOutput{Format: awsAccessKeysCreateInstructionFmt}}, outputSink.Writes)
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/cmd/radinit/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (r *Runner) enterCloudProviderOptions(ctx context.Context, options *initOpt

options.CloudProviders.Azure = provider
case aws.ProviderDisplayName:
provider, err := r.enterAWSCloudProvider(ctx)
provider, err := r.enterAWSCloudProvider(ctx, options)
if err != nil {
return err
}
Expand Down
65 changes: 52 additions & 13 deletions pkg/cli/cmd/radinit/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,23 @@ func Test_enterCloudProviderOptions(t *testing.T) {
},
}

awsProvider := aws.Provider{
Region: "test-region",
AccessKeyID: "test-access-key-id",
SecretAccessKey: "test-secret-access-key",
AccountID: "test-account-id",
awsProviderAccessKey := aws.Provider{
Region: "test-region",
CredentialKind: "AccessKey",
AccessKey: &aws.AccessKeyCredential{
AccessKeyID: "test-access-key-id",
SecretAccessKey: "test-secret-access-key",
},
AccountID: "test-account-id",
}

awsProviderIRSA := aws.Provider{
Region: "test-region",
CredentialKind: "IRSA",
IRSA: &aws.IRSACredential{
RoleARN: "test-role-arn",
},
AccountID: "test-account-id",
}

t.Run("cloud providers skipped when no flags specified", func(t *testing.T) {
Expand Down Expand Up @@ -114,7 +126,7 @@ func Test_enterCloudProviderOptions(t *testing.T) {
require.Empty(t, outputSink.Writes)
})

t.Run("--full - aws provider", func(t *testing.T) {
t.Run("--full - aws provider - accesskey", func(t *testing.T) {
ctrl := gomock.NewController(t)
prompter := prompt.NewMockInterface(ctrl)
awsClient := aws.NewMockClient(ctrl)
Expand All @@ -124,14 +136,41 @@ func Test_enterCloudProviderOptions(t *testing.T) {

initAddCloudProviderPromptYes(prompter)
initSelectCloudProvider(prompter, aws.ProviderDisplayName)
setAWSCloudProvider(prompter, awsClient, awsProvider)
setAWSCloudProviderAccessKey(prompter, awsClient, awsProviderAccessKey)
initAddCloudProviderPromptNo(prompter)

options := initOptions{Environment: environmentOptions{Create: true}}
err := runner.enterCloudProviderOptions(context.Background(), &options)
require.NoError(t, err)
require.Nil(t, options.CloudProviders.Azure)
require.Equal(t, awsProvider, *options.CloudProviders.AWS)
require.Equal(t, awsProviderAccessKey, *options.CloudProviders.AWS)

expectedWrites := []any{
output.LogOutput{
Format: awsAccessKeysCreateInstructionFmt,
},
}
require.Equal(t, expectedWrites, outputSink.Writes)
})

t.Run("--full - aws provider - irsa", func(t *testing.T) {
ctrl := gomock.NewController(t)
prompter := prompt.NewMockInterface(ctrl)
awsClient := aws.NewMockClient(ctrl)
azureClient := azure.NewMockClient(ctrl)
outputSink := output.MockOutput{}
runner := Runner{Prompter: prompter, awsClient: awsClient, azureClient: azureClient, Output: &outputSink, Full: true}

initAddCloudProviderPromptYes(prompter)
initSelectCloudProvider(prompter, aws.ProviderDisplayName)
setAWSCloudProviderIRSA(prompter, awsClient, awsProviderIRSA)
initAddCloudProviderPromptNo(prompter)

options := initOptions{Environment: environmentOptions{Create: true}}
err := runner.enterCloudProviderOptions(context.Background(), &options)
require.NoError(t, err)
require.Nil(t, options.CloudProviders.Azure)
require.Equal(t, awsProviderIRSA, *options.CloudProviders.AWS)

expectedWrites := []any{
output.LogOutput{
Expand Down Expand Up @@ -206,7 +245,7 @@ func Test_enterCloudProviderOptions(t *testing.T) {

initAddCloudProviderPromptYes(prompter)
initSelectCloudProvider(prompter, aws.ProviderDisplayName)
setAWSCloudProvider(prompter, awsClient, awsProvider)
setAWSCloudProviderAccessKey(prompter, awsClient, awsProviderAccessKey)

initAddCloudProviderPromptYes(prompter)
initSelectCloudProvider(prompter, azure.ProviderDisplayName)
Expand All @@ -217,7 +256,7 @@ func Test_enterCloudProviderOptions(t *testing.T) {
options := initOptions{Environment: environmentOptions{Create: true}}
err := runner.enterCloudProviderOptions(context.Background(), &options)
require.NoError(t, err)
require.Equal(t, awsProvider, *options.CloudProviders.AWS)
require.Equal(t, awsProviderAccessKey, *options.CloudProviders.AWS)
require.Equal(t, azureProviderServicePrincipal, *options.CloudProviders.Azure)

expectedWrites := []any{
Expand All @@ -243,13 +282,13 @@ func Test_enterCloudProviderOptions(t *testing.T) {

initAddCloudProviderPromptYes(prompter)
initSelectCloudProvider(prompter, aws.ProviderDisplayName)
setAWSCloudProvider(prompter, awsClient, awsProvider)
setAWSCloudProviderAccessKey(prompter, awsClient, awsProviderAccessKey)

awsProvider := awsProvider
awsProvider := awsProviderAccessKey
awsProvider.Region = "another-region"
initAddCloudProviderPromptYes(prompter)
initSelectCloudProvider(prompter, aws.ProviderDisplayName)
setAWSCloudProvider(prompter, awsClient, awsProvider)
setAWSCloudProviderAccessKey(prompter, awsClient, awsProvider)

initAddCloudProviderPromptNo(prompter)

Expand Down
Loading

0 comments on commit a4a4b90

Please sign in to comment.