diff --git a/command.go b/command.go index b257f91b6..c7e898303 100644 --- a/command.go +++ b/command.go @@ -177,8 +177,6 @@ type Command struct { // that we can use on every pflag set and children commands globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName - // output is an output writer defined by user. - output io.Writer // usageFunc is usage func defined by user. usageFunc func(*Command) error // usageTemplate is usage template defined by user. @@ -195,6 +193,13 @@ type Command struct { helpCommand *Command // versionTemplate is the version template defined by user. versionTemplate string + + // inReader is a reader defined by the user that replaces stdin + inReader io.Reader + // outWriter is a writer defined by the user that replaces stdout + outWriter io.Writer + // errWriter is a writer defined by the user that replaces stderr + errWriter io.Writer } // SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden @@ -205,8 +210,28 @@ func (c *Command) SetArgs(a []string) { // SetOutput sets the destination for usage and error messages. // If output is nil, os.Stderr is used. +// Deprecated: Use SetOut and/or SetErr instead func (c *Command) SetOutput(output io.Writer) { - c.output = output + c.outWriter = output + c.errWriter = output +} + +// SetOut sets the destination for usage messages. +// If newOut is nil, os.Stdout is used. +func (c *Command) SetOut(newOut io.Writer) { + c.outWriter = newOut +} + +// SetErr sets the destination for error messages. +// If newErr is nil, os.Stderr is used. +func (c *Command) SetErr(newErr io.Writer) { + c.errWriter = newErr +} + +// SetOut sets the source for input data +// If newIn is nil, os.Stdin is used. +func (c *Command) SetIn(newIn io.Reader) { + c.inReader = newIn } // SetUsageFunc sets usage function. Usage can be defined by application. @@ -267,9 +292,19 @@ func (c *Command) OutOrStderr() io.Writer { return c.getOut(os.Stderr) } +// ErrOrStderr returns output to stderr +func (c *Command) ErrOrStderr() io.Writer { + return c.getErr(os.Stderr) +} + +// ErrOrStderr returns output to stderr +func (c *Command) InOrStdin() io.Reader { + return c.getIn(os.Stdin) +} + func (c *Command) getOut(def io.Writer) io.Writer { - if c.output != nil { - return c.output + if c.outWriter != nil { + return c.outWriter } if c.HasParent() { return c.parent.getOut(def) @@ -277,6 +312,26 @@ func (c *Command) getOut(def io.Writer) io.Writer { return def } +func (c *Command) getErr(def io.Writer) io.Writer { + if c.errWriter != nil { + return c.errWriter + } + if c.HasParent() { + return c.parent.getErr(def) + } + return def +} + +func (c *Command) getIn(def io.Reader) io.Reader { + if c.inReader != nil { + return c.inReader + } + if c.HasParent() { + return c.parent.getIn(def) + } + return def +} + // UsageFunc returns either the function set by SetUsageFunc for this command // or a parent, or it returns a default usage function. func (c *Command) UsageFunc() (f func(*Command) error) { @@ -329,13 +384,22 @@ func (c *Command) Help() error { return nil } -// UsageString return usage string. +// UsageString returns usage string. func (c *Command) UsageString() string { - tmpOutput := c.output + // Storing normal writers + tmpOutput := c.outWriter + tmpErr := c.errWriter + bb := new(bytes.Buffer) - c.SetOutput(bb) + c.outWriter = bb + c.errWriter = bb + c.Usage() - c.output = tmpOutput + + // Setting things back to normal + c.outWriter = tmpOutput + c.errWriter = tmpErr + return bb.String() } @@ -1068,6 +1132,21 @@ func (c *Command) Printf(format string, i ...interface{}) { c.Print(fmt.Sprintf(format, i...)) } +// PrintErr is a convenience method to Print to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErr(i ...interface{}) { + fmt.Fprint(c.ErrOrStderr(), i...) +} + +// PrintErrln is a convenience method to Println to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErrln(i ...interface{}) { + c.Print(fmt.Sprintln(i...)) +} + +// PrintErrf is a convenience method to Printf to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErrf(format string, i ...interface{}) { + c.Print(fmt.Sprintf(format, i...)) +} + // CommandPath returns the full path to this command. func (c *Command) CommandPath() string { if c.HasParent() { diff --git a/command_test.go b/command_test.go index 6e483a3ec..2fa2003cb 100644 --- a/command_test.go +++ b/command_test.go @@ -1381,6 +1381,46 @@ func TestSetOutput(t *testing.T) { } } +func TestSetOut(t *testing.T) { + c := &Command{} + c.SetOut(nil) + if out := c.OutOrStdout(); out != os.Stdout { + t.Errorf("Expected setting output to nil to revert back to stdout") + } +} + +func TestSetErr(t *testing.T) { + c := &Command{} + c.SetErr(nil) + if out := c.ErrOrStderr(); out != os.Stderr { + t.Errorf("Expected setting error to nil to revert back to stderr") + } +} + +func TestSetIn(t *testing.T) { + c := &Command{} + c.SetIn(nil) + if out := c.InOrStdin(); out != os.Stdin { + t.Errorf("Expected setting input to nil to revert back to stdin") + } +} + +func TestUsageStringRedirected(t *testing.T) { + c := &Command{} + + c.usageFunc = func(cmd *Command) error { + cmd.Print("[stdout1]") + cmd.PrintErr("[stderr2]") + cmd.Print("[stdout3]") + return nil + } + + expected := "[stdout1][stderr2][stdout3]" + if got := c.UsageString(); got != expected { + t.Errorf("Expected usage string to consider both stdout and stderr") + } +} + func TestFlagErrorFunc(t *testing.T) { c := &Command{Use: "c", Run: emptyRun}