From 66bbfaca0bc28b3a6cdd3c630dde631067a5a79b Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Sun, 6 Nov 2022 10:56:27 -0500 Subject: [PATCH] Changes from code review --- app_test.go | 57 +++++++++++++++++++++++++++++++++++------------------ command.go | 44 +++++++++++++++++++++++------------------ 2 files changed, 63 insertions(+), 38 deletions(-) diff --git a/app_test.go b/app_test.go index 9df0f4f69b..c68bff8ad1 100644 --- a/app_test.go +++ b/app_test.go @@ -3080,30 +3080,38 @@ func TestFlagAction(t *testing.T) { func TestPersistentFlag(t *testing.T) { - var vflag1, vflag2, vflag3 int + var topInt, topPersistentInt, subCommandInt int + var appFlag string a := &App{ + Flags: []Flag{ + &StringFlag{ + Name: "persistentAppFlag", + Persistent: true, + Destination: &appFlag, + }, + }, Commands: []*Command{ { - Name: "foo", + Name: "cmd", Flags: []Flag{ &IntFlag{ - Name: "flag1", - Destination: &vflag1, + Name: "cmdFlag", + Destination: &topInt, }, &IntFlag{ - Name: "flag2", + Name: "cmdPersistentFlag", Persistent: true, - Destination: &vflag2, + Destination: &topPersistentInt, }, }, Subcommands: []*Command{ { - Name: "bar", + Name: "subcmd", Flags: []Flag{ &IntFlag{ - Name: "flag1", - Destination: &vflag3, + Name: "cmdFlag", + Destination: &subCommandInt, }, }, }, @@ -3112,29 +3120,40 @@ func TestPersistentFlag(t *testing.T) { }, } - err := a.Run([]string{"app", "foo", "--flag1", "10"}) + err := a.Run([]string{"app", "cmd", "--cmdFlag", "10", "--persistentAppFlag", "hello"}) if err != nil { t.Fatal(err) } - if vflag1 != 10 { - t.Errorf("Expected 10 got %d", vflag1) + if topInt != 10 { + t.Errorf("Expected 10 got %d", topInt) } - err = a.Run([]string{"app", "foo", "--flag1", "12", "bar", "--flag2", "20", "--flag1", "11"}) + if appFlag != "hello" { + t.Errorf("Expected 'hello' got %s", appFlag) + } + + err = a.Run([]string{"app", "--persistentAppFlag", "hello", + "cmd", "--cmdFlag", "12", + "subcmd", "--cmdPersistentFlag", "20", "--cmdFlag", "11", "--persistentAppFlag", "bar"}) + if err != nil { t.Fatal(err) } - if vflag1 != 12 { - t.Errorf("Expected 12 got %d", vflag1) + if appFlag != "bar" { + t.Errorf("Expected 'bar' got %s", appFlag) + } + + if topInt != 12 { + t.Errorf("Expected 12 got %d", topInt) } - if vflag2 != 20 { - t.Errorf("Expected 20 got %d", vflag2) + if topPersistentInt != 20 { + t.Errorf("Expected 20 got %d", topPersistentInt) } - if vflag3 != 11 { - t.Errorf("Expected 11 got %d", vflag3) + if subCommandInt != 11 { + t.Errorf("Expected 11 got %d", subCommandInt) } } diff --git a/command.go b/command.go index cdec02b711..cd3a8b0378 100644 --- a/command.go +++ b/command.go @@ -318,34 +318,40 @@ func (c *Command) parseFlags(args Args, ctx *Context) (*flag.FlagSet, error) { } for pCtx := ctx.parentContext; pCtx != nil; pCtx = pCtx.parentContext { - if pCtx.Command != nil { - for _, fl := range pCtx.Command.Flags { - if pfl, ok := fl.(PersistentFlag); ok && pfl.IsPersistent() { - applyPersistentFlag := true - set.VisitAll(func(f *flag.Flag) { - for _, name := range fl.Names() { - if name == f.Name { - applyPersistentFlag = false - } - } - }) - if applyPersistentFlag { - if err := fl.Apply(set); err != nil { - return nil, err - } + if pCtx.Command == nil { + continue + } + + for _, fl := range pCtx.Command.Flags { + pfl, ok := fl.(PersistentFlag) + if !ok || !pfl.IsPersistent() { + continue + } + + applyPersistentFlag := true + set.VisitAll(func(f *flag.Flag) { + for _, name := range fl.Names() { + if name == f.Name { + applyPersistentFlag = false } } + }) + + if !applyPersistentFlag { + continue + } + + if err := fl.Apply(set); err != nil { + return nil, err } } } - err = parseIter(set, c, args.Tail(), ctx.shellComplete) - if err != nil { + if err = parseIter(set, c, args.Tail(), ctx.shellComplete); err != nil { return nil, err } - err = normalizeFlags(c.Flags, set) - if err != nil { + if err = normalizeFlags(c.Flags, set); err != nil { return nil, err }