From 205b14e7411a6e897aec849ed1f068f674269121 Mon Sep 17 00:00:00 2001 From: Keith Zantow Date: Sat, 5 Aug 2023 13:51:15 -0400 Subject: [PATCH] feat: add PFlagSetProvider interface to access underlying pflag.FlagSet (#18) Signed-off-by: Keith Zantow --- flag_set.go | 13 ++++++++++++- flags_test.go | 20 ++++++++++++++------ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/flag_set.go b/flag_set.go index 9ba0f0b..060ec0a 100644 --- a/flag_set.go +++ b/flag_set.go @@ -20,13 +20,20 @@ type FlagSet interface { StringArrayVarP(p *[]string, name, shorthand, usage string) } +type PFlagSetProvider interface { + PFlagSet() *pflag.FlagSet +} + type pflagSet struct { ignoreDuplicates bool log logger.Logger flagSet *pflag.FlagSet } -var _ FlagSet = (*pflagSet)(nil) +var _ interface { + FlagSet + PFlagSetProvider +} = (*pflagSet)(nil) func NewPFlagSet(log logger.Logger, flags *pflag.FlagSet) FlagSet { return &pflagSet{ @@ -36,6 +43,10 @@ func NewPFlagSet(log logger.Logger, flags *pflag.FlagSet) FlagSet { } } +func (f *pflagSet) PFlagSet() *pflag.FlagSet { + return f.flagSet +} + func (f *pflagSet) exists(name, shorthand string) bool { if !f.ignoreDuplicates { return false diff --git a/flags_test.go b/flags_test.go index 87aee5d..809c911 100644 --- a/flags_test.go +++ b/flags_test.go @@ -4,11 +4,19 @@ import ( "testing" "github.com/spf13/pflag" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/anchore/go-logger/adapter/discard" ) +func Test_PFlagSetProvider(t *testing.T) { + flags := pflag.NewFlagSet("set", pflag.ContinueOnError) + flagSet := NewPFlagSet(discard.New(), flags) + prov, ok := flagSet.(PFlagSetProvider) + require.True(t, ok) + require.Equal(t, flags, prov.PFlagSet()) +} + func Test_EmbeddedAddFlags(t *testing.T) { type ty1 struct { Something string @@ -25,7 +33,7 @@ func Test_EmbeddedAddFlags(t *testing.T) { flagNames = append(flagNames, flag.Name) }) - assert.Equal(t, flagNames, []string{"sub2-flag"}) + require.Equal(t, flagNames, []string{"sub2-flag"}) } func Test_AddFlags(t *testing.T) { @@ -38,10 +46,10 @@ func Test_AddFlags(t *testing.T) { flagNames = append(flagNames, flag.Name) }) - assert.Len(t, flagNames, 3) - assert.Contains(t, flagNames, "t1-flag") - assert.Contains(t, flagNames, "sub2-flag") - assert.Contains(t, flagNames, "sub3-flag") + require.Len(t, flagNames, 3) + require.Contains(t, flagNames, "t1-flag") + require.Contains(t, flagNames, "sub2-flag") + require.Contains(t, flagNames, "sub3-flag") } type Sub2 struct {