diff --git a/pkg/generator/credentials.go b/pkg/generator/credentials.go index 15aa71329..fe158705a 100644 --- a/pkg/generator/credentials.go +++ b/pkg/generator/credentials.go @@ -51,7 +51,7 @@ func genCredentialSet(name string, creds map[string]bundle.Credential, fn genera sort.Strings(credentialNames) for _, name := range credentialNames { - c, err := fn(name, surveyCredentials, "") + c, err := fn(name, surveyCredentials, nil) if err != nil { return cs, err } diff --git a/pkg/generator/generator.go b/pkg/generator/generator.go index 7099d8aae..ecf3584be 100644 --- a/pkg/generator/generator.go +++ b/pkg/generator/generator.go @@ -32,16 +32,16 @@ const ( questionCommand = "shell command" ) -type generator func(name string, surveyType SurveyType, defaultVal string) (valuesource.Strategy, error) +type generator func(name string, surveyType SurveyType, defaultVal interface{}) (valuesource.Strategy, error) -func genEmptySet(name string, surveyType SurveyType, defaultVal string) (valuesource.Strategy, error) { +func genEmptySet(name string, surveyType SurveyType, defaultVal interface{}) (valuesource.Strategy, error) { return valuesource.Strategy{ Name: name, Source: valuesource.Source{Value: "TODO"}, }, nil } -func genSurvey(name string, surveyType SurveyType, defaultVal string) (valuesource.Strategy, error) { +func genSurvey(name string, surveyType SurveyType, defaultVal interface{}) (valuesource.Strategy, error) { if surveyType != surveyCredentials && surveyType != surveyParameters { return valuesource.Strategy{}, fmt.Errorf("unsupported survey type: %s", surveyType) } @@ -49,7 +49,7 @@ func genSurvey(name string, surveyType SurveyType, defaultVal string) (valuesour options := []string{questionSecret, questionValue, questionEnvVar, questionPath, questionCommand} questionDefault := fmt.Sprintf("use default value (%s)", defaultVal) - if defaultVal != "" { + if defaultVal != nil { options = append(options, questionDefault) } @@ -93,15 +93,13 @@ func genSurvey(name string, surveyType SurveyType, defaultVal string) (valuesour if err := survey.AskOne(sourceValuePrompt, &value, nil); err != nil { return c, err } - } else { - value = defaultVal } switch source { case questionSecret: c.Source.Key = secrets.SourceSecret c.Source.Value = value - case questionValue, questionDefault: + case questionValue: c.Source.Key = host.SourceValue c.Source.Value = value case questionEnvVar: diff --git a/pkg/generator/generator_test.go b/pkg/generator/generator_test.go index 9040adae4..207631446 100644 --- a/pkg/generator/generator_test.go +++ b/pkg/generator/generator_test.go @@ -13,13 +13,13 @@ func Test_genEmptySet(t *testing.T) { Source: valuesource.Source{Value: "TODO"}, } - got, err := genEmptySet("emptyset", surveyParameters, "") + got, err := genEmptySet("emptyset", surveyParameters, nil) require.NoError(t, err) require.Equal(t, expected, got) } func Test_genSurvey_unsupported(t *testing.T) { - got, err := genSurvey("myturtleset", SurveyType("turtles"), "") + got, err := genSurvey("myturtleset", SurveyType("turtles"), nil) require.EqualError(t, err, "unsupported survey type: turtles") require.Equal(t, valuesource.Strategy{}, got) } diff --git a/pkg/generator/parameters.go b/pkg/generator/parameters.go index af2be8bdc..f04c10209 100644 --- a/pkg/generator/parameters.go +++ b/pkg/generator/parameters.go @@ -52,7 +52,11 @@ func (opts *GenerateParametersOptions) genParameterSet(fn generator) (parameters if parameters.IsInternal(name, opts.Bundle) { continue } - defaultVal, _ := getDefaultParamValue(opts.Bundle, name) + defaultVal, err := getDefaultParamValue(opts.Bundle, name) + + if err != nil { + return pset, err + } c, err := fn(name, surveyParameters, defaultVal) if err != nil { @@ -64,7 +68,7 @@ func (opts *GenerateParametersOptions) genParameterSet(fn generator) (parameters return pset, nil } -func getDefaultParamValue(bun bundle.Bundle, name string) (string, error) { +func getDefaultParamValue(bun bundle.Bundle, name string) (interface{}, error) { for p, v := range bun.Parameters { if p == name { def, ok := bun.Definitions[v.Definition] @@ -75,9 +79,7 @@ func getDefaultParamValue(bun bundle.Bundle, name string) (string, error) { return "", fmt.Errorf("parameter definition for %s is empty", name) } - if def.Default != nil { - return fmt.Sprintf("%s", def.Default), nil - } + return def.Default, nil } } return "", nil