Skip to content

Commit

Permalink
Merge pull request #118 from skriss/fix-stderr-race
Browse files Browse the repository at this point in the history
ensure we finish reading stderr pipe before calling cmd.Wait()
  • Loading branch information
jbardin authored Jun 10, 2019
2 parents a1756f3 + 9d4515e commit a1bc615
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ type Client struct {
// goroutines.
clientWaitGroup sync.WaitGroup

// stderrWaitGroup is used to prevent the command's Wait() function from
// being called before we've finished reading from the stderr pipe.
stderrWaitGroup sync.WaitGroup

// processKilled is used for testing only, to flag when the process was
// forcefully killed.
processKilled bool
Expand Down Expand Up @@ -590,6 +594,12 @@ func (c *Client) Start() (addr net.Addr, err error) {
// Create a context for when we kill
c.doneCtx, c.ctxCancel = context.WithCancel(context.Background())

// Start goroutine that logs the stderr
c.clientWaitGroup.Add(1)
c.stderrWaitGroup.Add(1)
// logStderr calls Done()
go c.logStderr(cmdStderr)

c.clientWaitGroup.Add(1)
go func() {
// ensure the context is cancelled when we're done
Expand All @@ -602,6 +612,10 @@ func (c *Client) Start() (addr net.Addr, err error) {
pid := c.process.Pid
path := cmd.Path

// wait to finish reading from stderr since the stderr pipe reader
// will be closed by the subsequent call to cmd.Wait().
c.stderrWaitGroup.Wait()

// Wait for the command to end.
err := cmd.Wait()

Expand All @@ -624,11 +638,6 @@ func (c *Client) Start() (addr net.Addr, err error) {
c.exited = true
}()

// Start goroutine that logs the stderr
c.clientWaitGroup.Add(1)
// logStderr calls Done()
go c.logStderr(cmdStderr)

// Start a goroutine that is going to be reading the lines
// out of stdout
linesCh := make(chan string)
Expand Down Expand Up @@ -936,6 +945,7 @@ var stdErrBufferSize = 64 * 1024

func (c *Client) logStderr(r io.Reader) {
defer c.clientWaitGroup.Done()
defer c.stderrWaitGroup.Done()
l := c.logger.Named(filepath.Base(c.config.Cmd.Path))

reader := bufio.NewReaderSize(r, stdErrBufferSize)
Expand Down

0 comments on commit a1bc615

Please sign in to comment.