From 3810b2346f49a47aa0b99c23a7aa619d5f5dcf80 Mon Sep 17 00:00:00 2001 From: Canelo Hill <172609632+canelohill@users.noreply.github.com> Date: Fri, 5 Jul 2024 13:35:59 -0400 Subject: [PATCH 1/5] Handle errcheck warnings The package ignored errors from net.Conn Set*Deadline in a few places. Update the package to return these errors to the caller. Ignore all other errors reported by errcheck. These errors are safe to ignore because - The function is making a best effort to cleanup while handling another error. - The function call is guaranteed to succeed. - The error is ignored in a test. --- client.go | 18 +++++++++++++++--- client_server_test.go | 8 ++++---- compression.go | 6 +++++- compression_test.go | 6 +++--- conn.go | 36 ++++++++++++++++++++++-------------- conn_broadcast_test.go | 4 ++-- conn_test.go | 32 ++++++++++++++++---------------- join_test.go | 2 +- prepared_test.go | 4 +++- server.go | 34 ++++++++++++++++++++++++++++------ 10 files changed, 99 insertions(+), 51 deletions(-) diff --git a/client.go b/client.go index 73eada1..24bd7ff 100644 --- a/client.go +++ b/client.go @@ -305,9 +305,15 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h }) } + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. defer func() { if netConn != nil { - netConn.Close() + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() } }() @@ -398,8 +404,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h resp.Body = io.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - netConn.SetDeadline(time.Time{}) - netConn = nil // to avoid close in defer. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, resp, err + } + + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + return conn, resp, nil } diff --git a/client_server_test.go b/client_server_test.go index ec555b4..7de9e88 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -578,7 +578,7 @@ func TestRespOnBadHandshake(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(expectedStatus) - io.WriteString(w, expectedBody) + _, _ = io.WriteString(w, expectedBody) })) defer s.Close() @@ -828,7 +828,7 @@ func TestSocksProxyDial(t *testing.T) { } defer c1.Close() - c1.SetDeadline(time.Now().Add(30 * time.Second)) + _ = c1.SetDeadline(time.Now().Add(30 * time.Second)) buf := make([]byte, 32) if _, err := io.ReadFull(c1, buf[:3]); err != nil { @@ -867,10 +867,10 @@ func TestSocksProxyDial(t *testing.T) { defer c2.Close() done := make(chan struct{}) go func() { - io.Copy(c1, c2) + _, _ = io.Copy(c1, c2) close(done) }() - io.Copy(c2, c1) + _, _ = io.Copy(c2, c1) <-done }() diff --git a/compression.go b/compression.go index 813ffb1..fe1079e 100644 --- a/compression.go +++ b/compression.go @@ -33,7 +33,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) - fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + mr := io.MultiReader(r, strings.NewReader(tail)) + if err := fr.(flate.Resetter).Reset(mr, nil); err != nil { + // Reset never fails, but handle error in case that changes. + fr = flate.NewReader(mr) + } return &flateReadWrapper{fr} } diff --git a/compression_test.go b/compression_test.go index 23591c4..00ae42f 100644 --- a/compression_test.go +++ b/compression_test.go @@ -22,7 +22,7 @@ func TestTruncWriter(t *testing.T) { if m > n { m = n } - w.Write(p[:m]) + _, _ = w.Write(p[:m]) p = p[m:] } if b.String() != data[:len(data)-len(w.p)] { @@ -46,7 +46,7 @@ func BenchmarkWriteNoCompression(b *testing.B) { messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } @@ -59,7 +59,7 @@ func BenchmarkWriteWithCompression(b *testing.B) { c.newCompressionWriter = compressNoContextTakeover b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(TextMessage, messages[i%len(messages)]) + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } diff --git a/conn.go b/conn.go index 1bc4d3d..476616a 100644 --- a/conn.go +++ b/conn.go @@ -370,7 +370,9 @@ func (c *Conn) read(n int) ([]byte, error) { if err == io.EOF { err = errUnexpectedEOF } - c.br.Discard(len(p)) + // Discard is guaranteed to succeed because the number of bytes to discard + // is less than or equal to the number of bytes buffered. + _, _ = c.br.Discard(len(p)) return p, err } @@ -385,7 +387,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return err } - c.conn.SetWriteDeadline(deadline) + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } if len(buf1) == 0 { _, err = c.conn.Write(buf0) } else { @@ -395,7 +399,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return c.writeFatal(err) } if frameType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return nil } @@ -458,13 +462,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return err } - c.conn.SetWriteDeadline(deadline) - _, err = c.conn.Write(buf) - if err != nil { + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if _, err = c.conn.Write(buf); err != nil { return c.writeFatal(err) } if messageType == CloseMessage { - c.writeFatal(ErrCloseSent) + _ = c.writeFatal(ErrCloseSent) } return err } @@ -628,7 +633,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { } if final { - w.endMessage(errWriteClosed) + _ = w.endMessage(errWriteClosed) return nil } @@ -815,7 +820,7 @@ func (c *Conn) advanceFrame() (int, error) { rsv2 := p[0]&rsv2Bit != 0 rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 - c.setReadRemaining(int64(p[1] & 0x7f)) + _ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0 c.readDecompress = false if rsv1 { @@ -920,7 +925,8 @@ func (c *Conn) advanceFrame() (int, error) { } if c.readLimit > 0 && c.readLength > c.readLimit { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + // Make a best effort to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit } @@ -932,7 +938,7 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - c.setReadRemaining(0) + _ = c.setReadRemaining(0) // will not fail because argument is >= 0 if err != nil { return noFrame, err } @@ -979,7 +985,8 @@ func (c *Conn) handleProtocolError(message string) error { if len(data) > maxControlFramePayloadSize { data = data[:maxControlFramePayloadSize] } - c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) + // Make a best effor to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) return errors.New("websocket: " + message) } @@ -1052,7 +1059,7 @@ func (r *messageReader) Read(b []byte) (int, error) { } rem := c.readRemaining rem -= int64(n) - c.setReadRemaining(rem) + _ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0 if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } @@ -1134,7 +1141,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") - c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + // Make a best effor to send the close message. + _ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) return nil } } diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index d8a6492..540be6b 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -69,9 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) { select { case msg := <-c.msgCh: if msg.prepared != nil { - c.conn.WritePreparedMessage(msg.prepared) + _ = c.conn.WritePreparedMessage(msg.prepared) } else { - c.conn.WriteMessage(TextMessage, msg.payload) + _ = c.conn.WriteMessage(TextMessage, msg.payload) } val := atomic.AddInt32(&b.count, 1) if val%int32(numConns) == 0 { diff --git a/conn_test.go b/conn_test.go index e9f5441..3b244a9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -157,7 +157,7 @@ func TestControl(t *testing.T) { wc := newTestConn(nil, &connBuf, isServer) rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { - wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { w, err := wc.NextWriter(PongMessage) if err != nil { @@ -174,7 +174,7 @@ func TestControl(t *testing.T) { } var actualMessage string rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) - rc.NextReader() + _, _, _ = rc.NextReader() if actualMessage != message { t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) continue @@ -358,8 +358,8 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) + _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() @@ -385,7 +385,7 @@ func TestEOFWithinFrame(t *testing.T) { rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize)) + _, _ = w.Write(make([]byte, bufSize)) w.Close() if n >= b.Len() { @@ -419,7 +419,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -438,7 +438,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) { func TestWriteAfterMessageWriterClose(t *testing.T) { wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + _, _ = io.WriteString(w, "hello") if err := w.Close(); err != nil { t.Fatalf("unxpected error closing message writer, %v", err) } @@ -448,7 +448,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } w, _ = wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") + _, _ = io.WriteString(w, "hello") // close w by getting next writer _, err := wc.NextWriter(BinaryMessage) @@ -473,13 +473,13 @@ func TestReadLimit(t *testing.T) { // Send message at the limit with interleaved pong. w, _ := wc.NextWriter(BinaryMessage) - w.Write(message[:readLimit-1]) - wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) - w.Write(message[:1]) + _, _ = w.Write(message[:readLimit-1]) + _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + _, _ = w.Write(message[:1]) w.Close() // Send message larger than the limit. - wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1]) op, _, err := rc.NextReader() if op != BinaryMessage || err != nil { @@ -592,7 +592,7 @@ func TestBufioReadBytes(t *testing.T) { rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) - w.Write(m) + _, _ = w.Write(m) w.Close() op, r, err := rc.NextReader() @@ -666,7 +666,7 @@ func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} c := newTestConn(nil, w, false) go func() { - c.WriteMessage(TextMessage, []byte{}) + _ = c.WriteMessage(TextMessage, []byte{}) }() // wait for goroutine to block in write. @@ -679,7 +679,7 @@ func TestConcurrentWritePanic(t *testing.T) { } }() - c.WriteMessage(TextMessage, []byte{}) + _ = c.WriteMessage(TextMessage, []byte{}) t.Fatal("should not get here") } @@ -699,7 +699,7 @@ func TestFailedConnectionReadPanic(t *testing.T) { }() for i := 0; i < 20000; i++ { - c.ReadMessage() + _, _, _ = c.ReadMessage() } t.Fatal("should not get here") } diff --git a/join_test.go b/join_test.go index 961ac04..37bb30f 100644 --- a/join_test.go +++ b/join_test.go @@ -19,7 +19,7 @@ func TestJoinMessages(t *testing.T) { wc := newTestConn(nil, &connBuf, true) rc := newTestConn(&connBuf, nil, false) for _, m := range messages { - wc.WriteMessage(BinaryMessage, []byte(m)) + _ = wc.WriteMessage(BinaryMessage, []byte(m)) } var result bytes.Buffer diff --git a/prepared_test.go b/prepared_test.go index 536d58d..50d065e 100644 --- a/prepared_test.go +++ b/prepared_test.go @@ -45,7 +45,9 @@ func TestPreparedMessage(t *testing.T) { if tt.enableWriteCompression { c.newCompressionWriter = compressNoContextTakeover } - c.SetCompressionLevel(tt.compressionLevel) + if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { + t.Fatal(err) + } // Seed random number generator for consistent frame mask. testRand.Seed(1234) diff --git a/server.go b/server.go index b76131d..02ea01f 100644 --- a/server.go +++ b/server.go @@ -178,6 +178,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade "websocket: hijack: "+err.Error()) } + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. + defer func() { + if netConn != nil { + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() + } + }() + var br *bufio.Reader if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 { // Use hijacked buffered reader as the connection reader. @@ -244,20 +256,30 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } p = append(p, "\r\n"...) - // Clear deadlines set by HTTP server. - netConn.SetDeadline(time.Time{}) - if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) + if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { + return nil, err + } + } else { + // Clear deadlines set by HTTP server. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, err + } } + if _, err = netConn.Write(p); err != nil { - netConn.Close() return nil, err } if u.HandshakeTimeout > 0 { - netConn.SetWriteDeadline(time.Time{}) + if err := netConn.SetWriteDeadline(time.Time{}); err != nil { + return nil, err + } } + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + return c, nil } From 85fb2d8136476e486befcd2352f7e9e8f9fa4d0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Sz=C3=A9pe?= Date: Sun, 14 Jul 2024 09:59:04 +0000 Subject: [PATCH 2/5] Fix typos --- client_server_test.go | 2 +- conn.go | 2 +- conn_test.go | 6 +++--- proxy.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/client_server_test.go b/client_server_test.go index 7de9e88..e4546ae 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -522,7 +522,7 @@ func TestNoUpgrade(t *testing.T) { } resp.Body.Close() if u := resp.Header.Get("Upgrade"); u != "websocket" { - t.Errorf("Uprade response header is %q, want %q", u, "websocket") + t.Errorf("Upgrade response header is %q, want %q", u, "websocket") } if resp.StatusCode != http.StatusUpgradeRequired { t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired) diff --git a/conn.go b/conn.go index 476616a..13c86b2 100644 --- a/conn.go +++ b/conn.go @@ -1164,7 +1164,7 @@ func (c *Conn) PingHandler() func(appData string) error { func (c *Conn) SetPingHandler(h func(appData string) error) { if h == nil { h = func(message string) error { - // Make a best effort to send the pong mesage. + // Make a best effort to send the pong message. _ = c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) return nil } diff --git a/conn_test.go b/conn_test.go index 3b244a9..1dc52ec 100644 --- a/conn_test.go +++ b/conn_test.go @@ -47,7 +47,7 @@ func (a fakeAddr) String() string { return "str" } -// newTestConn creates a connnection backed by a fake network connection using +// newTestConn creates a connection backed by a fake network connection using // default values for buffering. func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) @@ -149,7 +149,7 @@ func TestFraming(t *testing.T) { } func TestControl(t *testing.T) { - const message = "this is a ping/pong messsage" + const message = "this is a ping/pong message" for _, isServer := range []bool{true, false} { for _, isWriteControl := range []bool{true, false} { name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) @@ -440,7 +440,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { w, _ := wc.NextWriter(BinaryMessage) _, _ = io.WriteString(w, "hello") if err := w.Close(); err != nil { - t.Fatalf("unxpected error closing message writer, %v", err) + t.Fatalf("unexpected error closing message writer, %v", err) } if _, err := io.WriteString(w, "world"); err == nil { diff --git a/proxy.go b/proxy.go index f113710..b4683b9 100644 --- a/proxy.go +++ b/proxy.go @@ -77,7 +77,7 @@ func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, add return nil, err } - // Read response. It's OK to use and discard buffered reader here becaue + // Read response. It's OK to use and discard buffered reader here because // the remote server does not speak until spoken to. br := bufio.NewReader(conn) resp, err := http.ReadResponse(br, connectReq) From f01629e7ea03b4e4f9433436a50fd4bdc290271c Mon Sep 17 00:00:00 2001 From: rfyiamcool Date: Fri, 8 Dec 2023 17:24:58 +0800 Subject: [PATCH 3/5] perf: reduce timer in write_control Signed-off-by: rfyiamcool --- conn.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 13c86b2..a3979da 100644 --- a/conn.go +++ b/conn.go @@ -446,13 +446,18 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er } } - timer := time.NewTimer(d) select { case <-c.mu: - timer.Stop() - case <-timer.C: - return errWriteTimeout + default: + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } } + defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() From c5b8b8c38ae0aa1d2f75079f430a2393fce0d59d Mon Sep 17 00:00:00 2001 From: rfyiamcool Date: Mon, 22 Jan 2024 13:47:31 +0800 Subject: [PATCH 4/5] perf: reduce timer in write_control Signed-off-by: rfyiamcool --- conn_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/conn_test.go b/conn_test.go index 1dc52ec..61ed74e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -148,6 +148,31 @@ func TestFraming(t *testing.T) { } } +func TestConcurrencyWriteControl(t *testing.T) { + const message = "this is a ping/pong messsage" + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { + t.Errorf("concurrently wc.WriteControl() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } +} + func TestControl(t *testing.T) { const message = "this is a ping/pong message" for _, isServer := range []bool{true, false} { From 5e002381133d322c5f1305d171f3bdd07decf229 Mon Sep 17 00:00:00 2001 From: Martin Greenwald Date: Wed, 14 Feb 2024 19:06:54 -0800 Subject: [PATCH 5/5] Do not timeout when WriteControl deadline is zero A zero value for the Conn.WriteControl deadline specifies no timeout, but the feature was implemented as a very long timeout (1000 hours). This PR updates the code to use no timeout when the deadline is zero. See the discussion in #895 for more details. --- conn.go | 25 +++++++++++++------------ conn_test.go | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/conn.go b/conn.go index a3979da..9562ffd 100644 --- a/conn.go +++ b/conn.go @@ -438,23 +438,24 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er maskBytes(key, 0, buf[6:]) } - d := 1000 * time.Hour - if !deadline.IsZero() { - d = deadline.Sub(time.Now()) + if deadline.IsZero() { + // No timeout for zero time. + <-c.mu + } else { + d := time.Until(deadline) if d < 0 { return errWriteTimeout } - } - - select { - case <-c.mu: - default: - timer := time.NewTimer(d) select { case <-c.mu: - timer.Stop() - case <-timer.C: - return errWriteTimeout + default: + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } } } diff --git a/conn_test.go b/conn_test.go index 61ed74e..28f5c4a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -148,6 +148,22 @@ func TestFraming(t *testing.T) { } } +func TestWriteControlDeadline(t *testing.T) { + t.Parallel() + message := []byte("hello") + var connBuf bytes.Buffer + c := newTestConn(nil, &connBuf, true) + if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil { + t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil { + t.Errorf("WriteControl(..., future deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil { + t.Errorf("WriteControl(..., past deadline) = nil, want timeout error") + } +} + func TestConcurrencyWriteControl(t *testing.T) { const message = "this is a ping/pong messsage" loop := 10