diff --git a/conn.go b/conn.go index 2ac19e9a..72f5df81 100644 --- a/conn.go +++ b/conn.go @@ -192,13 +192,6 @@ func newMaskKey() [4]byte { return k } -func hideTempErr(err error) error { - if e, ok := err.(net.Error); ok { - err = &netError{msg: e.Error(), timeout: e.Timeout()} - } - return err -} - func isControl(frameType int) bool { return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage } @@ -364,7 +357,6 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods func (c *Conn) writeFatal(err error) error { - err = hideTempErr(err) c.writeErrMu.Lock() if c.writeErr == nil { c.writeErr = err @@ -1033,7 +1025,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { for c.readErr == nil { frameType, err := c.advanceFrame() if err != nil { - c.readErr = hideTempErr(err) + c.readErr = err break } @@ -1073,7 +1065,7 @@ func (r *messageReader) Read(b []byte) (int, error) { b = b[:c.readRemaining] } n, err := c.br.Read(b) - c.readErr = hideTempErr(err) + c.readErr = err if c.isServer { c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } @@ -1096,7 +1088,7 @@ func (r *messageReader) Read(b []byte) (int, error) { frameType, err := c.advanceFrame() switch { case err != nil: - c.readErr = hideTempErr(err) + c.readErr = err case frameType == TextMessage || frameType == BinaryMessage: c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } diff --git a/conn_test.go b/conn_test.go index f0c29c39..3bd4f61b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -814,3 +814,25 @@ func TestFormatMessageType(t *testing.T) { t.Error("failed to format message type") } } + +type fakeNetClosedReader struct { +} + +func (r fakeNetClosedReader) Read([]byte) (int, error) { + return 0, net.ErrClosed +} + +func TestConnectionClosed(t *testing.T) { + var b1, b2 bytes.Buffer + + client := newTestConn(fakeNetClosedReader{}, &b1, false) + server := newTestConn(fakeNetClosedReader{}, &b2, true) + + if _, _, err := server.NextReader(); !errors.Is(err, net.ErrClosed) { + t.Fatalf("server expects a net.ErrClosed error, %v returned", err) + } + + if _, _, err := client.NextReader(); !errors.Is(err, net.ErrClosed) { + t.Fatalf("client expects a net.ErrClosed error, %v returned", err) + } +}