Skip to content

Commit

Permalink
Add support for context.Context
Browse files Browse the repository at this point in the history
  • Loading branch information
burdiyan committed Feb 20, 2020
1 parent 21cab29 commit 0da0687
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 2 deletions.
30 changes: 28 additions & 2 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cobra

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -143,9 +144,11 @@ type Command struct {
// TraverseChildren parses flags on all parents before executing child command.
TraverseChildren bool

//FParseErrWhitelist flag parse errors to be ignored
// FParseErrWhitelist flag parse errors to be ignored
FParseErrWhitelist FParseErrWhitelist

ctx context.Context

// commands is the list of commands supported by this program.
commands []*Command
// parent is a parent command for this command.
Expand Down Expand Up @@ -205,6 +208,12 @@ type Command struct {
errWriter io.Writer
}

// Context returns underlying command context. If command wasn't
// executed with ExecuteContext Context returns Background context.
func (c *Command) Context() context.Context {
return c.ctx
}

// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden
// particularly useful when testing.
func (c *Command) SetArgs(a []string) {
Expand Down Expand Up @@ -862,6 +871,13 @@ func (c *Command) preRun() {
}
}

// ExecuteContext is the same as Execute(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions.
func (c *Command) ExecuteContext(ctx context.Context) error {
c.ctx = ctx
return c.Execute()
}

// Execute uses the args (os.Args[1:] by default)
// and run through the command tree finding appropriate matches
// for commands and then corresponding flags.
Expand All @@ -872,6 +888,10 @@ func (c *Command) Execute() error {

// ExecuteC executes the command.
func (c *Command) ExecuteC() (cmd *Command, err error) {
if c.ctx == nil {
c.ctx = context.Background()
}

// Regardless of what command execute is called on, run on Root only
if c.HasParent() {
return c.Root().ExecuteC()
Expand Down Expand Up @@ -916,6 +936,12 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
cmd.commandCalledAs.name = cmd.Name()
}

// We have to pass global context to children command
// if context is present on the parent command.
if cmd.ctx == nil {
cmd.ctx = c.ctx
}

err = cmd.execute(flags)
if err != nil {
// Always show help if requested, even if SilenceErrors is in
Expand Down Expand Up @@ -1560,7 +1586,7 @@ func (c *Command) ParseFlags(args []string) error {
beforeErrorBufLen := c.flagErrorBuf.Len()
c.mergePersistentFlags()

//do it here after merging all flags and just before parse
// do it here after merging all flags and just before parse
c.Flags().ParseErrorsWhitelist = flag.ParseErrorsWhitelist(c.FParseErrWhitelist)

err := c.Flags().Parse(args)
Expand Down
67 changes: 67 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cobra

import (
"bytes"
"context"
"fmt"
"os"
"reflect"
Expand All @@ -18,6 +19,16 @@ func executeCommand(root *Command, args ...string) (output string, err error) {
return output, err
}

func executeCommandWithContext(ctx context.Context, root *Command, args ...string) (output string, err error) {
buf := new(bytes.Buffer)
root.SetOutput(buf)
root.SetArgs(args)

err = root.ExecuteContext(ctx)

return buf.String(), err
}

func executeCommandC(root *Command, args ...string) (c *Command, output string, err error) {
buf := new(bytes.Buffer)
root.SetOutput(buf)
Expand Down Expand Up @@ -135,6 +146,62 @@ func TestSubcommandExecuteC(t *testing.T) {
}
}

func TestExecuteContext(t *testing.T) {
ctx := context.TODO()

ctxRun := func(cmd *Command, args []string) {
if cmd.Context() != ctx {
t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use)
}
}

rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun}
childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun}
granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun}

childCmd.AddCommand(granchildCmd)
rootCmd.AddCommand(childCmd)

if _, err := executeCommandWithContext(ctx, rootCmd, ""); err != nil {
t.Errorf("Root command must not fail: %+v", err)
}

if _, err := executeCommandWithContext(ctx, rootCmd, "child"); err != nil {
t.Errorf("Subcommand must not fail: %+v", err)
}

if _, err := executeCommandWithContext(ctx, rootCmd, "child", "grandchild"); err != nil {
t.Errorf("Command child must not fail: %+v", err)
}
}

func TestExecute_NoContext(t *testing.T) {
run := func(cmd *Command, args []string) {
if cmd.Context() != context.Background() {
t.Errorf("Command %s must have background context", cmd.Use)
}
}

rootCmd := &Command{Use: "root", Run: run, PreRun: run}
childCmd := &Command{Use: "child", Run: run, PreRun: run}
granchildCmd := &Command{Use: "grandchild", Run: run, PreRun: run}

childCmd.AddCommand(granchildCmd)
rootCmd.AddCommand(childCmd)

if _, err := executeCommand(rootCmd, ""); err != nil {
t.Errorf("Root command must not fail: %+v", err)
}

if _, err := executeCommand(rootCmd, "child"); err != nil {
t.Errorf("Subcommand must not fail: %+v", err)
}

if _, err := executeCommand(rootCmd, "child", "grandchild"); err != nil {
t.Errorf("Command child must not fail: %+v", err)
}
}

func TestRootUnknownCommandSilenced(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}
rootCmd.SilenceErrors = true
Expand Down

0 comments on commit 0da0687

Please sign in to comment.