Skip to content

Commit

Permalink
feat: add PFlagSetProvider interface to access underlying pflag.FlagS…
Browse files Browse the repository at this point in the history
…et (#18)

Signed-off-by: Keith Zantow <kzantow@gmail.com>
  • Loading branch information
kzantow authored Aug 5, 2023
1 parent 329a9a4 commit 205b14e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
13 changes: 12 additions & 1 deletion flag_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@ type FlagSet interface {
StringArrayVarP(p *[]string, name, shorthand, usage string)
}

type PFlagSetProvider interface {
PFlagSet() *pflag.FlagSet
}

type pflagSet struct {
ignoreDuplicates bool
log logger.Logger
flagSet *pflag.FlagSet
}

var _ FlagSet = (*pflagSet)(nil)
var _ interface {
FlagSet
PFlagSetProvider
} = (*pflagSet)(nil)

func NewPFlagSet(log logger.Logger, flags *pflag.FlagSet) FlagSet {
return &pflagSet{
Expand All @@ -36,6 +43,10 @@ func NewPFlagSet(log logger.Logger, flags *pflag.FlagSet) FlagSet {
}
}

func (f *pflagSet) PFlagSet() *pflag.FlagSet {
return f.flagSet
}

func (f *pflagSet) exists(name, shorthand string) bool {
if !f.ignoreDuplicates {
return false
Expand Down
20 changes: 14 additions & 6 deletions flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@ import (
"testing"

"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/anchore/go-logger/adapter/discard"
)

func Test_PFlagSetProvider(t *testing.T) {
flags := pflag.NewFlagSet("set", pflag.ContinueOnError)
flagSet := NewPFlagSet(discard.New(), flags)
prov, ok := flagSet.(PFlagSetProvider)
require.True(t, ok)
require.Equal(t, flags, prov.PFlagSet())
}

func Test_EmbeddedAddFlags(t *testing.T) {
type ty1 struct {
Something string
Expand All @@ -25,7 +33,7 @@ func Test_EmbeddedAddFlags(t *testing.T) {
flagNames = append(flagNames, flag.Name)
})

assert.Equal(t, flagNames, []string{"sub2-flag"})
require.Equal(t, flagNames, []string{"sub2-flag"})
}

func Test_AddFlags(t *testing.T) {
Expand All @@ -38,10 +46,10 @@ func Test_AddFlags(t *testing.T) {
flagNames = append(flagNames, flag.Name)
})

assert.Len(t, flagNames, 3)
assert.Contains(t, flagNames, "t1-flag")
assert.Contains(t, flagNames, "sub2-flag")
assert.Contains(t, flagNames, "sub3-flag")
require.Len(t, flagNames, 3)
require.Contains(t, flagNames, "t1-flag")
require.Contains(t, flagNames, "sub2-flag")
require.Contains(t, flagNames, "sub3-flag")
}

type Sub2 struct {
Expand Down

0 comments on commit 205b14e

Please sign in to comment.