Skip to content
This repository has been archived by the owner on Mar 2, 2024. It is now read-only.

🐛 Fix broken SSM parameter path construction #1034

Merged
merged 4 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/release_notes/0.0.105.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

## Bugfixes

🐛 Fix broken SSM parameter path construction
deifyed marked this conversation as resolved.
Show resolved Hide resolved

## Changes

## Other
Expand Down
39 changes: 31 additions & 8 deletions pkg/cognito/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/AlecAivazis/survey/v2"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/cognitoidentityprovider"
"github.com/aws/aws-sdk-go/service/cognitoidentityprovider/cognitoidentityprovideriface"
"github.com/oslokommune/okctl/pkg/apis/okctl.io/v1alpha1"
)

Expand Down Expand Up @@ -54,7 +53,12 @@ func acquireSession(opts RegisterMFADeviceOpts) (string, error) {
return "", fmt.Errorf("acquiring Cognito client ID: %w", err)
}

clientSecret, err := getCognitoClientSecretForClient(opts.Ctx, opts.ParameterStoreProvider, cognitoUserPoolclient.Name)
clientSecret, err := getCognitoClientSecretForClient(
opts.Ctx,
opts.ParameterStoreProvider,
opts.Cluster.Metadata.Name,
cognitoUserPoolclient.Name,
)
if err != nil {
return "", fmt.Errorf("acquiring Cognito client secret: %w", err)
}
Expand Down Expand Up @@ -89,7 +93,7 @@ func acquireSession(opts RegisterMFADeviceOpts) (string, error) {
return *initiateAuthResult.Session, nil
}

func getCognitoClientForCluster(ctx context.Context, provider cognitoidentityprovideriface.CognitoIdentityProviderAPI, cluster v1alpha1.Cluster) (userPoolClient, error) {
func getCognitoClientForCluster(ctx context.Context, provider userpoolAPI, cluster v1alpha1.Cluster) (userPoolClient, error) {
relevantUserPoolID, err := getRelevantUserPoolID(ctx, provider, cluster)
if err != nil {
return userPoolClient{}, fmt.Errorf("getting relevant user pool ID: %w", err)
Expand All @@ -100,10 +104,16 @@ func getCognitoClientForCluster(ctx context.Context, provider cognitoidentitypro
return userPoolClient{}, fmt.Errorf("getting relevant user pool client: %w", err)
}

return userPoolClient{Name: *relevantUserPoolClient.ClientName, ID: *relevantUserPoolClient.ClientId}, nil
clientName := strings.ReplaceAll(
*relevantUserPoolClient.ClientName,
fmt.Sprintf("okctl-%s-", cluster.Metadata.Name),
"",
)

return userPoolClient{Name: clientName, ID: *relevantUserPoolClient.ClientId}, nil
}

func getRelevantUserPoolID(ctx context.Context, provider cognitoidentityprovideriface.CognitoIdentityProviderAPI, cluster v1alpha1.Cluster) (string, error) {
func getRelevantUserPoolID(ctx context.Context, provider userpoolLister, cluster v1alpha1.Cluster) (string, error) {
var nextToken *string

for {
Expand Down Expand Up @@ -166,15 +176,15 @@ func getRelevantUserPoolClient(ctx context.Context, provider userpoolClientsList
return cognitoidentityprovider.UserPoolClientDescription{}, fmt.Errorf("no clients found for user pool %s", userPoolID)
}

func getCognitoClientSecretForClient(ctx context.Context, provider ssmiface.SSMAPI, clientName string) (string, error) {
parameterPath := fmt.Sprintf("/%s/client_secret", strings.ReplaceAll(clientName, "-", "/"))
func getCognitoClientSecretForClient(ctx context.Context, provider ssmiface.SSMAPI, clusterName string, clientName string) (string, error) {
parameterPath := path.Join("/", "okctl", clusterName, clientName, "client_secret")

getParameterResult, err := provider.GetParameterWithContext(ctx, &ssm.GetParameterInput{
Name: aws.String(parameterPath),
WithDecryption: aws.Bool(true),
})
if err != nil {
return "", fmt.Errorf("retrieving parameter: %w", err)
return "", fmt.Errorf("retrieving parameter \"%s\": %w", parameterPath, err)
}

return *getParameterResult.Parameter.Value, nil
Expand Down Expand Up @@ -270,3 +280,16 @@ type userpoolClientsLister interface {
...request.Option,
) (*cognitoidentityprovider.ListUserPoolClientsOutput, error)
}

type userpoolLister interface {
ListUserPoolsWithContext(
context.Context,
*cognitoidentityprovider.ListUserPoolsInput,
...request.Option,
) (*cognitoidentityprovider.ListUserPoolsOutput, error)
}

type userpoolAPI interface {
userpoolLister
userpoolClientsLister
}
54 changes: 53 additions & 1 deletion pkg/cognito/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package cognito

import (
"context"
"fmt"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/cognitoidentityprovider"
"github.com/oslokommune/okctl/pkg/apis/okctl.io/v1alpha1"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -57,8 +59,46 @@ func TestGetRelevantUserPoolClient(t *testing.T) {
}
}

func TestGetCognitoClientForCluster(t *testing.T) {
testCases := []struct {
name string
withClusterName string
expectClientName string
}{
{
name: "Should work",
withClusterName: "mock-prod",
expectClientName: "argocd",
},
}

for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

cluster := v1alpha1.NewCluster()
cluster.Metadata.Name = tc.withClusterName

client, err := getCognitoClientForCluster(
context.Background(),
&mockCognitoIdentityProviderAPI{
clusterName: cluster.Metadata.Name,
clients: []string{fmt.Sprintf("okctl-%s-argocd", tc.withClusterName)},
},
cluster,
)
assert.NoError(t, err)

assert.Equal(t, tc.expectClientName, client.Name)
})
}
}

type mockCognitoIdentityProviderAPI struct {
clients []string
clusterName string
clients []string
}

func (m mockCognitoIdentityProviderAPI) ListUserPoolClientsWithContext(
Expand All @@ -80,3 +120,15 @@ func (m mockCognitoIdentityProviderAPI) ListUserPoolClientsWithContext(
UserPoolClients: clients,
}, nil
}

func (m mockCognitoIdentityProviderAPI) ListUserPoolsWithContext(
_ context.Context,
_ *cognitoidentityprovider.ListUserPoolsInput,
_ ...request.Option,
) (*cognitoidentityprovider.ListUserPoolsOutput, error) {
return &cognitoidentityprovider.ListUserPoolsOutput{
UserPools: []*cognitoidentityprovider.UserPoolDescriptionType{
{Id: aws.String("mock-id"), Name: aws.String(m.clusterName)},
},
}, nil
}