diff --git a/args.go b/args.go index 682eb497da..b9c746be72 100644 --- a/args.go +++ b/args.go @@ -98,11 +98,11 @@ func (a *ArgumentBase[T, C, VC]) Usage() string { func (a *ArgumentBase[T, C, VC]) Parse(s []string) ([]string, error) { tracef("calling arg%[1] parse with args %[2]", &a.Name, s) if a.Max == 0 { - fmt.Printf("WARNING args %s has max 0, not parsing argument", a.Name) + fmt.Printf("WARNING args %s has max 0, not parsing argument\n", a.Name) return s, nil } if a.Max != -1 && a.Min > a.Max { - fmt.Printf("WARNING args %s has min[%d] > max[%d], not parsing argument", a.Name, a.Min, a.Max) + fmt.Printf("WARNING args %s has min[%d] > max[%d], not parsing argument\n", a.Name, a.Min, a.Max) return s, nil } diff --git a/args_test.go b/args_test.go index 983ed5d06a..1d2b5f6c66 100644 --- a/args_test.go +++ b/args_test.go @@ -44,6 +44,19 @@ func TestArgumentsRootCommand(t *testing.T) { require.Error(t, errors.New("No help topic for '12.1"), cmd.Run(context.Background(), []string{"foo", "13", "10.1", "11.09", "12.1"})) require.Equal(t, int64(13), ival) require.Equal(t, []float64{10.1, 11.09}, fvals) + + cmd.Arguments = append(cmd.Arguments, + &StringArg{ + Name: "sa", + }, + &UintArg{ + Name: "ua", + Min: 2, + Max: 1, // max is less than min + }, + ) + + require.NoError(t, cmd.Run(context.Background(), []string{"foo", "10"})) } func TestArgumentsSubcommand(t *testing.T) { @@ -103,6 +116,7 @@ func TestArgsUsage(t *testing.T) { name string min int max int + usage string expected string }{ { @@ -111,6 +125,13 @@ func TestArgsUsage(t *testing.T) { max: 1, expected: "[ia]", }, + { + name: "optional", + min: 0, + max: 1, + usage: "[my optional usage]", + expected: "[my optional usage]", + }, { name: "zero or more", min: 0, @@ -144,7 +165,7 @@ func TestArgsUsage(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - arg.Min, arg.Max = test.min, test.max + arg.Min, arg.Max, arg.UsageText = test.min, test.max, test.usage require.Equal(t, test.expected, arg.Usage()) }) } diff --git a/cli_test.go b/cli_test.go index 0eb3ee667d..4c2c4b26dc 100644 --- a/cli_test.go +++ b/cli_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,3 +27,33 @@ func buildTestContext(t *testing.T) context.Context { return ctx } + +func TestTracing(t *testing.T) { + olderr := os.Stderr + oldtracing := isTracingOn + defer func() { + os.Stderr = olderr + isTracingOn = oldtracing + }() + + file, err := os.CreateTemp(os.TempDir(), "cli*") + assert.NoError(t, err) + os.Stderr = file + + // Note we cant really set the env since the isTracingOn + // is read at module startup so any changes mid code + // wont take effect + isTracingOn = false + tracef("something") + + isTracingOn = true + tracef("foothing") + + assert.NoError(t, file.Close()) + + b, err := os.ReadFile(file.Name()) + assert.NoError(t, err) + + assert.Contains(t, string(b), "foothing") + assert.NotContains(t, string(b), "something") +} diff --git a/command_test.go b/command_test.go index dd822c76f2..e61bc52246 100644 --- a/command_test.go +++ b/command_test.go @@ -3975,6 +3975,21 @@ func TestZeroValueCommand(t *testing.T) { assert.NoError(t, cmd.Run(context.Background(), []string{"foo"})) } +func TestCommandInvalidName(t *testing.T) { + var cmd Command + assert.Equal(t, int64(0), cmd.Int("foo")) + assert.Equal(t, uint64(0), cmd.Uint("foo")) + assert.Equal(t, float64(0), cmd.Float("foo")) + assert.Equal(t, "", cmd.String("foo")) + assert.Equal(t, time.Time{}, cmd.Timestamp("foo")) + assert.Equal(t, time.Duration(0), cmd.Duration("foo")) + + assert.Equal(t, []int64(nil), cmd.IntSlice("foo")) + assert.Equal(t, []uint64(nil), cmd.UintSlice("foo")) + assert.Equal(t, []float64(nil), cmd.FloatSlice("foo")) + assert.Equal(t, []string(nil), cmd.StringSlice("foo")) +} + func TestJSONExportCommand(t *testing.T) { cmd := buildExtendedTestCommand() cmd.Arguments = []Argument{ diff --git a/errors.go b/errors.go index 18938e4d56..6b377f1097 100644 --- a/errors.go +++ b/errors.go @@ -47,7 +47,6 @@ func (m *multiError) Errors() []error { type requiredFlagsErr interface { error - getMissingFlags() []string } type errRequiredFlags struct { @@ -62,10 +61,6 @@ func (e *errRequiredFlags) Error() string { return fmt.Sprintf("Required flags %q not set", joinedMissingFlags) } -func (e *errRequiredFlags) getMissingFlags() []string { - return e.missingFlags -} - type mutuallyExclusiveGroup struct { flag1Name string flag2Name string @@ -86,12 +81,6 @@ func (e *mutuallyExclusiveGroupRequiredFlag) Error() string { for _, f := range grpf { grpString = append(grpString, f.Names()...) } - if len(e.flags.Flags) == 1 { - err := errRequiredFlags{ - missingFlags: grpString, - } - return err.Error() - } missingFlags = append(missingFlags, strings.Join(grpString, " ")) } @@ -148,10 +137,6 @@ func (ee *exitError) ExitCode() int { return ee.exitCode } -func (ee *exitError) Unwrap() error { - return ee.err -} - // HandleExitCoder handles errors implementing ExitCoder by printing their // message and calling OsExiter with the given exit code. // @@ -198,12 +183,3 @@ func handleMultiError(multiErr MultiError) int { } return code } - -type typeError[T any] struct { - other any -} - -func (te *typeError[T]) Error() string { - var t T - return fmt.Sprintf("Expected type %T got instead %T", t, te.other) -} diff --git a/errors_test.go b/errors_test.go index 8d734b621c..35aaab54d4 100644 --- a/errors_test.go +++ b/errors_test.go @@ -81,6 +81,7 @@ func TestHandleExitCoder_MultiErrorWithExitCoder(t *testing.T) { exitErr := Exit("galactic perimeter breach", 9) exitErr2 := Exit("last ExitCoder", 11) + err := newMultiError(errors.New("wowsa"), errors.New("egad"), exitErr, exitErr2) HandleExitCoder(err) @@ -88,6 +89,54 @@ func TestHandleExitCoder_MultiErrorWithExitCoder(t *testing.T) { assert.True(t, called) } +type exitFormatter struct { + code int +} + +func (f *exitFormatter) Format(s fmt.State, verb rune) { + _, _ = s.Write([]byte("some other special")) +} + +func (f *exitFormatter) ExitCode() int { + return f.code +} + +func (f *exitFormatter) Error() string { + return fmt.Sprintf("my special error code %d", f.code) +} + +func TestHandleExitCoder_ErrorFormatter(t *testing.T) { + exitCode := 0 + called := false + + OsExiter = func(rc int) { + if !called { + exitCode = rc + called = true + } + } + + oldWriter := ErrWriter + var buf bytes.Buffer + ErrWriter = &buf + defer func() { + OsExiter = fakeOsExiter + ErrWriter = oldWriter + }() + + exitErr := Exit("galactic perimeter breach", 9) + exitErr2 := Exit("last ExitCoder", 11) + exitErr3 := &exitFormatter{code: 12} + + // add some recursion for multi error to fix test coverage + err := newMultiError(errors.New("wowsa"), errors.New("egad"), exitErr3, newMultiError(exitErr, exitErr2)) + HandleExitCoder(err) + + assert.Equal(t, 11, exitCode) + assert.True(t, called) + assert.Contains(t, buf.String(), "some other special") +} + func TestHandleExitCoder_MultiErrorWithoutExitCoder(t *testing.T) { exitCode := 0 called := false diff --git a/flag.go b/flag.go index 420ea5e939..11b13662c5 100644 --- a/flag.go +++ b/flag.go @@ -129,10 +129,6 @@ type DocGenerationFlag interface { // GetUsage returns the usage string for the flag GetUsage() string - // GetValue returns the flags value as string representation and an empty - // string if the flag takes no value at all. - GetValue() string - // GetDefaultText returns the default text for this flag GetDefaultText() string diff --git a/flag_float_slice.go b/flag_float_slice.go index 390e466e85..1ba65306ff 100644 --- a/flag_float_slice.go +++ b/flag_float_slice.go @@ -1,9 +1,5 @@ package cli -import ( - "flag" -) - type ( FloatSlice = SliceBase[float64, NoConfig, floatValue] FloatSliceFlag = FlagBase[[]float64, NoConfig, FloatSlice] @@ -14,22 +10,11 @@ var NewFloatSlice = NewSliceBase[float64, NoConfig, floatValue] // FloatSlice looks up the value of a local FloatSliceFlag, returns // nil if not found func (cmd *Command) FloatSlice(name string) []float64 { - if flSet := cmd.lookupFlagSet(name); flSet != nil { - return lookupFloatSlice(name, flSet, cmd.Name) - } - - return nil -} - -func lookupFloatSlice(name string, set *flag.FlagSet, cmdName string) []float64 { - fl := set.Lookup(name) - if fl != nil { - if v, ok := fl.Value.(flag.Getter).Get().([]float64); ok { - tracef("float slice available for flag name %[1]q with value=%[2]v (cmd=%[3]q)", name, v, cmdName) - return v - } + if v, ok := cmd.Value(name).([]float64); ok { + tracef("float slice available for flag name %[1]q with value=%[2]v (cmd=%[3]q)", name, v, cmd.Name) + return v } - tracef("float slice NOT available for flag name %[1]q (cmd=%[2]q)", name, cmdName) + tracef("float slice NOT available for flag name %[1]q (cmd=%[2]q)", name, cmd.Name) return nil } diff --git a/flag_impl.go b/flag_impl.go index 77bc591860..ad6678bfc7 100644 --- a/flag_impl.go +++ b/flag_impl.go @@ -33,13 +33,6 @@ func (f *fnValue) String() string { return f.v.String() } -func (f *fnValue) Serialize() string { - if s, ok := f.v.(Serializer); ok { - return s.Serialize() - } - return f.v.String() -} - func (f *fnValue) IsBoolFlag() bool { return f.isBool } func (f *fnValue) Count() int { if s, ok := f.v.(Countable); ok { @@ -96,15 +89,6 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { value Value // value representing this flag's value } -// GetValue returns the flags value as string representation and an empty -// string if the flag takes no value at all. -func (f *FlagBase[T, C, V]) GetValue() string { - if reflect.TypeOf(f.Value).Kind() == reflect.Bool { - return "" - } - return fmt.Sprintf("%v", f.Value) -} - // Apply populates the flag given the flag set and environment func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error { tracef("apply (flag=%[1]q)", f.Name) @@ -128,13 +112,7 @@ func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error { ) } } else if val == "" && reflect.TypeOf(f.Value).Kind() == reflect.Bool { - val = "false" - if err := tmpVal.Set(val); err != nil { - return fmt.Errorf( - "could not parse %[1]q as %[2]T value from %[3]s for flag %[4]s: %[5]s", - val, f.Value, source, f.Name, err, - ) - } + _ = tmpVal.Set("false") } newVal = tmpVal.Get().(T) @@ -149,11 +127,7 @@ func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error { // Validate the given default or values set from external sources as well if f.Validator != nil && f.ValidateDefaults { - if v, ok := f.value.Get().(T); !ok { - return &typeError[T]{ - other: f.value.Get(), - } - } else if err := f.Validator(v); err != nil { + if err := f.Validator(f.value.Get().(T)); err != nil { return err } } @@ -176,11 +150,7 @@ func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error { } f.hasBeenSet = true if f.Validator != nil { - if v, ok := f.value.Get().(T); !ok { - return &typeError[T]{ - other: f.value.Get(), - } - } else if err := f.Validator(v); err != nil { + if err := f.Validator(f.value.Get().(T)); err != nil { return err } } @@ -254,19 +224,10 @@ func (f *FlagBase[T, C, V]) GetDefaultText() string { return v.ToString(f.Value) } -// Get returns the flag’s value in the given Command. -func (f *FlagBase[T, C, V]) Get(cmd *Command) T { - if v, ok := cmd.Value(f.Name).(T); ok { - return v - } - var t T - return t -} - // RunAction executes flag action if set func (f *FlagBase[T, C, V]) RunAction(ctx context.Context, cmd *Command) error { if f.Action != nil { - return f.Action(ctx, cmd, f.Get(cmd)) + return f.Action(ctx, cmd, cmd.Value(f.Name).(T)) } return nil diff --git a/flag_slice_base.go b/flag_slice_base.go index 7278e3ed64..b97c4ff4f2 100644 --- a/flag_slice_base.go +++ b/flag_slice_base.go @@ -33,16 +33,6 @@ func NewSliceBase[T any, C any, VC ValueCreator[T, C]](defaults ...T) *SliceBase } } -// SetOne directly adds a value to the list of values -func (i *SliceBase[T, C, VC]) SetOne(value T) { - if !i.hasBeenSet { - *i.slice = []T{} - i.hasBeenSet = true - } - - *i.slice = append(*i.slice, value) -} - // Set parses the value and appends it to the list of values func (i *SliceBase[T, C, VC]) Set(value string) error { if !i.hasBeenSet { @@ -61,11 +51,7 @@ func (i *SliceBase[T, C, VC]) Set(value string) error { if err := i.value.Set(strings.TrimSpace(s)); err != nil { return err } - tmp, ok := i.value.Get().(T) - if !ok { - return fmt.Errorf("unable to cast %v", i.value) - } - *i.slice = append(*i.slice, tmp) + *i.slice = append(*i.slice, i.value.Get().(T)) } return nil @@ -90,7 +76,7 @@ func (i *SliceBase[T, C, VC]) Serialize() string { // Value returns the slice of values set by this flag func (i *SliceBase[T, C, VC]) Value() []T { if i.slice == nil { - return []T{} + return nil } return *i.slice } diff --git a/flag_test.go b/flag_test.go index 74b6d1d1bb..038fc45cfb 100644 --- a/flag_test.go +++ b/flag_test.go @@ -9,6 +9,7 @@ import ( "os" "reflect" "regexp" + "sort" "strings" "testing" "time" @@ -60,8 +61,8 @@ func TestBoolFlagValueFromCommand(t *testing.T) { ff := &BoolFlag{Name: "falseflag"} r := require.New(t) - r.True(tf.Get(cmd)) - r.False(ff.Get(cmd)) + r.True(cmd.Bool(tf.Name)) + r.False(cmd.Bool(ff.Name)) } func TestBoolFlagApply_SetsCount(t *testing.T) { @@ -663,7 +664,7 @@ func TestStringFlagValueFromCommand(t *testing.T) { set.String("myflag", "foobar", "doc") cmd := &Command{flagSet: set} f := &StringFlag{Name: "myflag"} - require.Equal(t, "foobar", f.Get(cmd)) + require.Equal(t, "foobar", cmd.String(f.Name)) } var _ = []struct { @@ -790,7 +791,7 @@ func TestStringSliceFlagValueFromCommand(t *testing.T) { set.Var(NewStringSlice("a", "b", "c"), "myflag", "doc") cmd := &Command{flagSet: set} f := &StringSliceFlag{Name: "myflag"} - require.Equal(t, []string{"a", "b", "c"}, f.Get(cmd)) + require.Equal(t, []string{"a", "b", "c"}, cmd.StringSlice(f.Name)) } var intFlagTests = []struct { @@ -844,7 +845,7 @@ func TestIntFlagValueFromCommand(t *testing.T) { set.Int64("myflag", int64(42), "doc") cmd := &Command{flagSet: set} fl := &IntFlag{Name: "myflag"} - require.Equal(t, int64(42), fl.Get(cmd)) + require.Equal(t, int64(42), cmd.Int(fl.Name)) } var uintFlagTests = []struct { @@ -887,7 +888,7 @@ func TestUintFlagValueFromCommand(t *testing.T) { set.Uint64("myflag", 42, "doc") cmd := &Command{flagSet: set} fl := &UintFlag{Name: "myflag"} - require.Equal(t, uint64(42), fl.Get(cmd)) + require.Equal(t, uint64(42), cmd.Uint(fl.Name)) } var uint64FlagTests = []struct { @@ -930,7 +931,7 @@ func TestUint64FlagValueFromCommand(t *testing.T) { set.Uint64("myflag", 42, "doc") cmd := &Command{flagSet: set} f := &UintFlag{Name: "myflag"} - require.Equal(t, uint64(42), f.Get(cmd)) + require.Equal(t, uint64(42), cmd.Uint(f.Name)) } var durationFlagTests = []struct { @@ -984,7 +985,7 @@ func TestDurationFlagValueFromCommand(t *testing.T) { set.Duration("myflag", 42*time.Second, "doc") cmd := &Command{flagSet: set} f := &DurationFlag{Name: "myflag"} - require.Equal(t, 42*time.Second, f.Get(cmd)) + require.Equal(t, 42*time.Second, cmd.Duration(f.Name)) } var intSliceFlagTests = []struct { @@ -1105,7 +1106,7 @@ func TestIntSliceFlagValueFromCommand(t *testing.T) { set.Var(NewIntSlice(1, 2, 3), "myflag", "doc") cmd := &Command{flagSet: set} f := &IntSliceFlag{Name: "myflag"} - require.Equal(t, []int64{1, 2, 3}, f.Get(cmd)) + require.Equal(t, []int64{1, 2, 3}, cmd.IntSlice(f.Name)) } var uintSliceFlagTests = []struct { @@ -1441,7 +1442,7 @@ func TestFloat64FlagValueFromCommand(t *testing.T) { set.Float64("myflag", 1.23, "doc") cmd := &Command{flagSet: set} f := &FloatFlag{Name: "myflag"} - require.Equal(t, 1.23, f.Get(cmd)) + require.Equal(t, 1.23, cmd.Float(f.Name)) } var float64SliceFlagTests = []struct { @@ -1536,7 +1537,7 @@ func TestFloat64SliceFlagValueFromCommand(t *testing.T) { set.Var(NewFloatSlice(1.23, 4.56), "myflag", "doc") cmd := &Command{flagSet: set} f := &FloatSliceFlag{Name: "myflag"} - require.Equal(t, []float64{1.23, 4.56}, f.Get(cmd)) + require.Equal(t, []float64{1.23, 4.56}, cmd.FloatSlice(f.Name)) } func TestFloat64SliceFlagApply_ParentCommand(t *testing.T) { @@ -2552,7 +2553,7 @@ func TestTimestampFlagValueFromCommand(t *testing.T) { set.Var(newTimestamp(now), "myflag", "doc") cmd := &Command{flagSet: set} f := &TimestampFlag{Name: "myflag"} - require.Equal(t, now, f.Get(cmd)) + require.Equal(t, now, cmd.Timestamp(f.Name)) } type flagDefaultTestCase struct { @@ -2986,7 +2987,7 @@ func TestStringMapFlagValueFromCommand(t *testing.T) { set.Var(NewStringMap(map[string]string{"a": "b", "c": ""}), "myflag", "doc") cmd := &Command{flagSet: set} f := &StringMapFlag{Name: "myflag"} - require.Equal(t, map[string]string{"a": "b", "c": ""}, f.Get(cmd)) + require.Equal(t, map[string]string{"a": "b", "c": ""}, cmd.StringMap(f.Name)) } func TestStringMapFlagApply_Error(t *testing.T) { @@ -3002,3 +3003,77 @@ func TestZeroValueMutexFlag(t *testing.T) { var fl MutuallyExclusiveFlags assert.NoError(t, fl.check(&Command{})) } + +func TestExtFlag(t *testing.T) { + fs := flag.NewFlagSet("foo", flag.ContinueOnError) + + var iv intValue + var ipv int64 + + f := &flag.Flag{ + Name: "bar", + Usage: "bar usage", + Value: iv.Create(11, &ipv, IntegerConfig{}), + DefValue: "10", + } + + extF := &extFlag{ + f: f, + } + + assert.NoError(t, extF.Apply(fs)) + assert.Equal(t, []string{"bar"}, extF.Names()) + assert.True(t, extF.IsVisible()) + assert.False(t, extF.IsSet()) + assert.False(t, extF.TakesValue()) + assert.Equal(t, "bar usage", extF.GetUsage()) + assert.Equal(t, "11", extF.GetValue()) + assert.Equal(t, "10", extF.GetDefaultText()) + assert.Nil(t, extF.GetEnvVars()) +} + +func TestSliceValuesNil(t *testing.T) { + assert.Equal(t, []float64(nil), NewFloatSlice().Value()) + assert.Equal(t, []int64(nil), NewIntSlice().Value()) + assert.Equal(t, []uint64(nil), NewUintSlice().Value()) + assert.Equal(t, []string(nil), NewStringSlice().Value()) + + assert.Equal(t, []float64(nil), (&FloatSlice{}).Value()) + assert.Equal(t, []int64(nil), (&IntSlice{}).Value()) + assert.Equal(t, []uint64(nil), (&UintSlice{}).Value()) + assert.Equal(t, []string(nil), (&StringSlice{}).Value()) +} + +func TestFileHint(t *testing.T) { + assert.Equal(t, "", withFileHint("", "")) + assert.Equal(t, " [/tmp/foo.txt]", withFileHint("/tmp/foo.txt", "")) + assert.Equal(t, "foo", withFileHint("", "foo")) + assert.Equal(t, "bar [/tmp/foo.txt]", withFileHint("/tmp/foo.txt", "bar")) +} + +func TestFlagsByName(t *testing.T) { + flags := []Flag{ + &StringFlag{ + Name: "b2", + }, + &IntFlag{ + Name: "a0", + }, + &FloatFlag{ + Name: "b1", + }, + } + + flagsByName := FlagsByName(flags) + sort.Sort(flagsByName) + + assert.Equal(t, len(flags), flagsByName.Len()) + + var prev Flag + for _, f := range flags { + if prev != nil { + assert.LessOrEqual(t, prev.Names()[0], f.Names()[0]) + } + prev = f + } +} diff --git a/godoc-current.txt b/godoc-current.txt index 2d120038ab..48d6ff7f0c 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -553,10 +553,6 @@ type DocGenerationFlag interface { // GetUsage returns the usage string for the flag GetUsage() string - // GetValue returns the flags value as string representation and an empty - // string if the flag takes no value at all. - GetValue() string - // GetDefaultText returns the default text for this flag GetDefaultText() string @@ -680,9 +676,6 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error Apply populates the flag given the flag set and environment -func (f *FlagBase[T, C, V]) Get(cmd *Command) T - Get returns the flag’s value in the given Command. - func (f *FlagBase[T, C, V]) GetCategory() string GetCategory returns the category of the flag @@ -695,10 +688,6 @@ func (f *FlagBase[T, C, V]) GetEnvVars() []string func (f *FlagBase[T, C, V]) GetUsage() string GetUsage returns the usage string for the flag -func (f *FlagBase[T, C, V]) GetValue() string - GetValue returns the flags value as string representation and an empty - string if the flag takes no value at all. - func (f *FlagBase[T, C, V]) IsDefaultVisible() bool IsDefaultVisible returns true if the flag is not hidden, otherwise false @@ -904,9 +893,6 @@ func (i *SliceBase[T, C, VC]) Serialize() string func (i *SliceBase[T, C, VC]) Set(value string) error Set parses the value and appends it to the list of values -func (i *SliceBase[T, C, VC]) SetOne(value T) - SetOne directly adds a value to the list of values - func (i *SliceBase[T, C, VC]) String() string String returns a readable representation of this value (for usage defaults) diff --git a/help.go b/help.go index 4a31a49e1a..20c7489709 100644 --- a/help.go +++ b/help.go @@ -169,7 +169,7 @@ func printCommandSuggestions(commands []*Command, writer io.Writer) { } } -func cliArgContains(flagName string) bool { +func cliArgContains(flagName string, args []string) bool { for _, name := range strings.Split(flagName, ",") { name = strings.TrimSpace(name) count := utf8.RuneCountInString(name) @@ -177,7 +177,7 @@ func cliArgContains(flagName string) bool { count = 2 } flag := fmt.Sprintf("%s%s", strings.Repeat("-", count), name) - for _, a := range os.Args { + for _, a := range args { if a == flag { return true } @@ -211,7 +211,7 @@ func printFlagSuggestions(lastArg string, flags []Flag, writer io.Writer) { continue } // match if last argument matches this flag and it is not repeated - if strings.HasPrefix(name, cur) && cur != name && !cliArgContains(name) { + if strings.HasPrefix(name, cur) && cur != name && !cliArgContains(name, os.Args) { flagCompletion := fmt.Sprintf("%s%s", strings.Repeat("-", count), name) if usage != "" && strings.HasSuffix(os.Getenv("SHELL"), "zsh") { flagCompletion = fmt.Sprintf("%s:%s", flagCompletion, usage) @@ -239,6 +239,11 @@ func DefaultCompleteWithFlags(ctx context.Context, cmd *Command) { lastArg = args[argsLen-1] } + if lastArg == "--" { + tracef("not printing flag suggestion as last arg is --") + return + } + if strings.HasPrefix(lastArg, "-") { tracef("printing flag suggestion for flag[%v] on command %[1]q", lastArg, cmd.Name) printFlagSuggestions(lastArg, cmd.Flags, cmd.Root().Writer) @@ -312,10 +317,6 @@ func ShowSubcommandHelpAndExit(cmd *Command, exitCode int) { // ShowSubcommandHelp prints help for the given subcommand func ShowSubcommandHelp(cmd *Command) error { - if cmd == nil { - return nil - } - HelpPrinter(cmd.Root().Writer, SubcommandHelpTemplate, cmd) return nil } @@ -376,6 +377,7 @@ func printHelpCustom(out io.Writer, templ string, data interface{}, customFuncs w := tabwriter.NewWriter(out, 1, 8, 2, ' ', 0) t := template.Must(template.New("help").Funcs(funcMap).Parse(templ)) + if _, err := t.New("helpNameTemplate").Parse(helpNameTemplate); err != nil { handleTemplateError(err) } diff --git a/help_test.go b/help_test.go index 448e6cb3ab..c39e9447d1 100644 --- a/help_test.go +++ b/help_test.go @@ -3,6 +3,7 @@ package cli import ( "bytes" "context" + "errors" "flag" "fmt" "io" @@ -1191,6 +1192,50 @@ func TestDefaultCompleteWithFlags(t *testing.T) { env: map[string]string{"SHELL": "bash"}, expected: "--excitement\n", }, + { + name: "typical-flag-suggestion-hidden-bool", + cmd: &Command{ + Flags: []Flag{ + &BoolFlag{Name: "excitement", Hidden: true}, + &StringFlag{Name: "hat-shape"}, + }, + parent: &Command{ + Name: "cmd", + Flags: []Flag{ + &BoolFlag{Name: "happiness"}, + &IntFlag{Name: "everybody-jump-on"}, + }, + Commands: []*Command{ + {Name: "putz"}, + }, + }, + }, + argv: []string{"cmd", "--e", "--generate-shell-completion"}, + env: map[string]string{"SHELL": "bash"}, + expected: "", + }, + { + name: "flag-suggestion-end-args", + cmd: &Command{ + Flags: []Flag{ + &BoolFlag{Name: "excitement"}, + &StringFlag{Name: "hat-shape"}, + }, + parent: &Command{ + Name: "cmd", + Flags: []Flag{ + &BoolFlag{Name: "happiness"}, + &IntFlag{Name: "everybody-jump-on"}, + }, + Commands: []*Command{ + {Name: "putz"}, + }, + }, + }, + argv: []string{"cmd", "--e", "--", "--generate-shell-completion"}, + env: map[string]string{"SHELL": "bash"}, + expected: "", + }, { name: "typical-command-suggestion", cmd: &Command{ @@ -1772,3 +1817,167 @@ func Test_checkShellCompleteFlag(t *testing.T) { }) } } + +func TestNIndent(t *testing.T) { + t.Parallel() + tests := []struct { + numSpaces int + str string + expected string + }{ + { + numSpaces: 0, + str: "foo", + expected: "\nfoo", + }, + { + numSpaces: 0, + str: "foo\n", + expected: "\nfoo\n", + }, + { + numSpaces: 2, + str: "foo", + expected: "\n foo", + }, + { + numSpaces: 3, + str: "foo\n", + expected: "\n foo\n ", + }, + } + for _, test := range tests { + assert.Equal(t, test.expected, nindent(test.numSpaces, test.str)) + } +} + +func TestTemplateError(t *testing.T) { + oldew := ErrWriter + defer func() { ErrWriter = oldew }() + + var buf bytes.Buffer + ErrWriter = &buf + err := errors.New("some error") + + handleTemplateError(err) + assert.Equal(t, []byte(nil), buf.Bytes()) + + t.Setenv("CLI_TEMPLATE_ERROR_DEBUG", "true") + handleTemplateError(err) + assert.Contains(t, buf.String(), "CLI TEMPLATE ERROR") + assert.Contains(t, buf.String(), err.Error()) +} + +func TestCliArgContainsFlag(t *testing.T) { + tests := []struct { + name string + args []string + contains bool + }{ + { + name: "", + args: []string{}, + }, + { + name: "f", + args: []string{}, + }, + { + name: "f", + args: []string{"g", "foo", "f"}, + }, + { + name: "f", + args: []string{"-f", "foo", "f"}, + contains: true, + }, + { + name: "f", + args: []string{"g", "-f", "f"}, + contains: true, + }, + { + name: "fh", + args: []string{"g", "f", "--fh"}, + contains: true, + }, + { + name: "fhg", + args: []string{"-fhg", "f", "fh"}, + }, + { + name: "fhg", + args: []string{"--fhg", "f", "fh"}, + contains: true, + }, + } + + for _, test := range tests { + if test.contains { + assert.True(t, cliArgContains(test.name, test.args)) + } else { + assert.False(t, cliArgContains(test.name, test.args)) + } + } +} + +func TestCommandHelpSuggest(t *testing.T) { + cmd := &Command{ + Suggest: true, + Commands: []*Command{ + { + Name: "putz", + }, + }, + } + + cmd.setupDefaults([]string{"foo"}) + + err := ShowCommandHelp(context.Background(), cmd, "put") + assert.ErrorContains(t, err, "No help topic for 'put'. putz") +} + +func TestWrapLine(t *testing.T) { + assert.Equal(t, " ", wrapLine(" ", 0, 3, " ")) +} + +func TestPrintHelpCustomTemplateError(t *testing.T) { + tmpls := []*string{ + &helpNameTemplate, + &argsTemplate, + &usageTemplate, + &descriptionTemplate, + &visibleCommandTemplate, + ©rightTemplate, + &versionTemplate, + &visibleFlagCategoryTemplate, + &visibleFlagTemplate, + &visiblePersistentFlagTemplate, + &visibleFlagCategoryTemplate, + &authorsTemplate, + &visibleCommandCategoryTemplate, + } + + oldErrWriter := ErrWriter + defer func() { ErrWriter = oldErrWriter }() + + t.Setenv("CLI_TEMPLATE_ERROR_DEBUG", "true") + + for _, tmpl := range tmpls { + oldtmpl := *tmpl + // safety mechanism in case something fails + defer func(stmpl *string) { *stmpl = oldtmpl }(tmpl) + + errBuf := &bytes.Buffer{} + ErrWriter = errBuf + buf := &bytes.Buffer{} + + *tmpl = "{{junk" + printHelpCustom(buf, "", "", nil) + + assert.Contains(t, errBuf.String(), "CLI TEMPLATE ERROR") + + // reset template back. + *tmpl = oldtmpl + } +} diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index 2d120038ab..48d6ff7f0c 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -553,10 +553,6 @@ type DocGenerationFlag interface { // GetUsage returns the usage string for the flag GetUsage() string - // GetValue returns the flags value as string representation and an empty - // string if the flag takes no value at all. - GetValue() string - // GetDefaultText returns the default text for this flag GetDefaultText() string @@ -680,9 +676,6 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error Apply populates the flag given the flag set and environment -func (f *FlagBase[T, C, V]) Get(cmd *Command) T - Get returns the flag’s value in the given Command. - func (f *FlagBase[T, C, V]) GetCategory() string GetCategory returns the category of the flag @@ -695,10 +688,6 @@ func (f *FlagBase[T, C, V]) GetEnvVars() []string func (f *FlagBase[T, C, V]) GetUsage() string GetUsage returns the usage string for the flag -func (f *FlagBase[T, C, V]) GetValue() string - GetValue returns the flags value as string representation and an empty - string if the flag takes no value at all. - func (f *FlagBase[T, C, V]) IsDefaultVisible() bool IsDefaultVisible returns true if the flag is not hidden, otherwise false @@ -904,9 +893,6 @@ func (i *SliceBase[T, C, VC]) Serialize() string func (i *SliceBase[T, C, VC]) Set(value string) error Set parses the value and appends it to the list of values -func (i *SliceBase[T, C, VC]) SetOne(value T) - SetOne directly adds a value to the list of values - func (i *SliceBase[T, C, VC]) String() string String returns a readable representation of this value (for usage defaults)