Skip to content

Commit

Permalink
Refactor credentials finder function
Browse files Browse the repository at this point in the history
To unblock unit-tests.

Signed-off-by: Artiom Diomin <artiom@kubermatic.com>
  • Loading branch information
kron4eg committed Jan 4, 2024
1 parent 998b347 commit 67e40f8
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 29 deletions.
92 changes: 63 additions & 29 deletions pkg/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ type ProviderEnvironmentVariable struct {
}

func Any(credentialsFilePath string) (map[string]string, error) {
credentialsFinder, err := newCredsFinder(credentialsFilePath, TypeUniversal)
credentialsFinder, err := newCredentialsFinder(withYAMLFile(credentialsFilePath))
if err != nil {
return nil, err
}

creds := map[string]string{}

for _, key := range allKeys {
if val := credentialsFinder(key); val != "" {
if val := credentialsFinder.get(key); val != "" {
creds[key] = val
// NB: We want to use Equinix Metal env vars everywhere, even if
// users has PACKET_ env vars on their systems.
Expand All @@ -178,11 +178,12 @@ func Any(credentialsFilePath string) (map[string]string, error) {

// ProviderCredentials implements fetching credentials for each supported provider
func ProviderCredentials(cloudProvider kubeoneapi.CloudProviderSpec, credentialsFilePath string, credentialsType Type) (map[string]string, error) {
credentialsFinder, err := newCredsFinder(credentialsFilePath, credentialsType)
credentialsFinderStore, err := newCredentialsFinder(withYAMLFile(credentialsFilePath), withType(credentialsType))
if err != nil {
return nil, err
}

credentialsFinder := credentialsFinderStore.lookupFunc()
switch {
case cloudProvider.AWS != nil:
return credentialsFinder.aws()
Expand Down Expand Up @@ -277,43 +278,76 @@ func ProviderCredentials(cloudProvider kubeoneapi.CloudProviderSpec, credentials
}
}

func newCredsFinder(credentialsFilePath string, credentialsType Type) (lookupFunc, error) {
staticMap := map[string]string{}
finder := func(name string) string {
switch {
case credentialsType != TypeUniversal:
typedName := string(credentialsType) + "_" + name
if val := os.Getenv(typedName); val != "" {
return val
}
if val, ok := staticMap[typedName]; ok && val != "" {
return val
}
func withYAMLFile(filePath string) func(*credentialsFinder) error {
return func(cf *credentialsFinder) error {
if filePath == "" {
return nil
}

fallthrough
default:
if val := os.Getenv(name); val != "" {
return val
}
buf, err := os.ReadFile(filePath)
if err != nil {
return fail.Runtime(err, "reading credentials file")
}

return staticMap[name]
if err = yaml.Unmarshal(buf, &cf.static); err != nil {
return fail.Runtime(err, "unmarshalling credentials file")
}

return nil
}
}

func withType(typ Type) func(*credentialsFinder) error {
return func(cf *credentialsFinder) error {
cf.typ = typ

if credentialsFilePath == "" {
return finder, nil
return nil
}
}

buf, err := os.ReadFile(credentialsFilePath)
if err != nil {
return nil, fail.Runtime(err, "loading credentials file")
func newCredentialsFinder(opts ...func(*credentialsFinder) error) (*credentialsFinder, error) {
cf := credentialsFinder{
static: map[string]string{},
dynamic: os.Getenv,
}

if err = yaml.Unmarshal(buf, &staticMap); err != nil {
return nil, fail.Runtime(err, "unmarshalling credentials file")
for _, optFn := range opts {
if err := optFn(&cf); err != nil {
return nil, err
}
}

return &cf, nil
}

type credentialsFinder struct {
static map[string]string
dynamic func(string) string
typ Type
}

func (cf *credentialsFinder) lookupFunc() lookupFunc { return cf.get }

func (cf *credentialsFinder) typedKey(name string) string {
return string(cf.typ) + "_" + name
}

func (cf *credentialsFinder) fetch(name string) string {
if val := cf.static[name]; val != "" {
return val
}

return cf.dynamic(name)
}

func (cf *credentialsFinder) get(name string) string {
if cf.typ != TypeUniversal {
if val := cf.fetch(cf.typedKey(name)); val != "" {
return val
}
}

return finder, nil
return cf.fetch(name)
}

// lookupFunc is function that retrieves credentials from the sources
Expand Down
92 changes: 92 additions & 0 deletions pkg/credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,95 @@ func TestVmwareCloudDirectorValidationFunc(t *testing.T) {
})
}
}

func TestCredentialsFinder(t *testing.T) {
withDynamicFixture := func(dynamicFn func(string) string) func(*credentialsFinder) error {
return func(cf *credentialsFinder) error {
cf.dynamic = dynamicFn

return nil
}
}

withStaticFixture := func(static map[string]string) func(*credentialsFinder) error {
return func(cf *credentialsFinder) error {
cf.static = static

return nil
}
}

tests := []struct {
name string
key string
want string
opts []func(*credentialsFinder) error
}{
{
name: "static universal",
key: "key1",
want: "val1",
opts: []func(*credentialsFinder) error{
withStaticFixture(map[string]string{
"key1": "val1",
}),
},
},
{
name: "static with type OSM",
key: "key1",
want: "OSM_val1",
opts: []func(*credentialsFinder) error{
withType(TypeOSM),
withStaticFixture(map[string]string{
"OSM_key1": "OSM_val1",
}),
},
},
{
name: "dynamic with type OSM",
key: "key1",
want: "OSM_val1",
opts: []func(*credentialsFinder) error{
withType(TypeOSM),
withStaticFixture(map[string]string{
"key1": "from_static",
}),
withDynamicFixture(func(key string) string {
return map[string]string{
"OSM_key1": "OSM_val1",
}[key]
}),
},
},
{
name: "static precedence over dynamic with type OSM",
key: "key1",
want: "from_static",
opts: []func(*credentialsFinder) error{
withType(TypeOSM),
withStaticFixture(map[string]string{
"OSM_key1": "from_static",
}),
withDynamicFixture(func(key string) string {
return map[string]string{
"OSM_key1": "from_dynamic",
}[key]
}),
},
},
}

for _, tcase := range tests {
t.Run(tcase.name, func(t *testing.T) {
finder, err := newCredentialsFinder(tcase.opts...)
if err != nil {
t.Fatalf("got unexpcted error: %v", err)
}

if result := finder.get(tcase.key); result != tcase.want {
t.Errorf("get(%q)=%q, want %q", tcase.key, result, tcase.want)
}
})
}
}

0 comments on commit 67e40f8

Please sign in to comment.