Skip to content

Commit

Permalink
Merge branch 'main' into prerrcheck
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Holmes <daniel.holmes@djcentric.com>
  • Loading branch information
jaitaiwan authored Jul 1, 2024
2 parents eb890c8 + 8915bad commit 12b2063
Show file tree
Hide file tree
Showing 18 changed files with 311 additions and 599 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ workflows:
- test:
matrix:
parameters:
version: ["1.18", "1.17", "1.16"]
version: ["1.22", "1.21", "1.20"]
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
### Documentation

* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)

### Status

Expand All @@ -29,5 +29,4 @@ package API is stable.

The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).

subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).
57 changes: 17 additions & 40 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
Expand Down Expand Up @@ -53,7 +52,7 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
// NetDial is nil, net.Dialer DialContext is used.
NetDial func(network, addr string) (net.Conn, error)

// NetDialContext specifies the dial function for creating TCP connections. If
Expand Down Expand Up @@ -245,46 +244,25 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer cancel()
}

// Get network dial function.
var netDial func(network, add string) (net.Conn, error)

switch u.Scheme {
case "http":
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
case "https":
if d.NetDialTLSContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
var netDial netDialerFunc
switch {
case u.Scheme == "https" && d.NetDialTLSContext != nil:
netDial = d.NetDialTLSContext
case d.NetDialContext != nil:
netDial = d.NetDialContext
case d.NetDial != nil:
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return d.NetDial(net, addr)
}
default:
return nil, nil, errMalformedURL
}

if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
}
netDial = (&net.Dialer{}).DialContext
}

// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr)
netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := forwardDial(ctx, network, addr)
if err != nil {
return nil, err
}
Expand All @@ -304,11 +282,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
netDial, err = proxyFromURL(proxyURL, netDial)
if err != nil {
return nil, nil, err
}
netDial = dialer.Dial
}
}

Expand All @@ -318,7 +295,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
trace.GetConn(hostPort)
}

netConn, err := netDial("tcp", hostPort)
netConn, err := netDial(ctx, "tcp", hostPort)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -406,7 +383,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake
}

Expand All @@ -424,7 +401,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break
}

resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")

if err := netConn.SetDeadline(time.Time{}); err != nil {
Expand Down
126 changes: 118 additions & 8 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package websocket

import (
"bufio"
"bytes"
"context"
"crypto/tls"
Expand All @@ -14,7 +15,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
Expand All @@ -24,6 +24,7 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"testing"
"time"
)
Expand All @@ -45,12 +46,15 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second,
}

type cstHandler struct{ *testing.T }
type cstHandler struct {
*testing.T
s *cstServer
}

type cstServer struct {
*httptest.Server
URL string
t *testing.T
URL string
Server *httptest.Server
wg sync.WaitGroup
}

const (
Expand All @@ -59,23 +63,35 @@ const (
cstRequestURI = cstPath + "?" + cstRawQuery
)

func (s *cstServer) Close() {
s.Server.Close()
// Wait for handler functions to complete.
s.wg.Wait()
}

func newServer(t *testing.T) *cstServer {
var s cstServer
s.Server = httptest.NewServer(cstHandler{t})
s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL)
return &s
}

func newTLSServer(t *testing.T) *cstServer {
var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t})
s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL)
return &s
}

func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
// in the call to (*cstServer).Close().
t.s.wg.Add(1)
defer t.s.wg.Done()

if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", http.StatusBadRequest)
Expand Down Expand Up @@ -482,6 +498,37 @@ func TestBadMethod(t *testing.T) {
}
}

func TestNoUpgrade(t *testing.T) {
t.Parallel()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := cstUpgrader.Upgrade(w, r, nil)
if err == nil {
t.Errorf("handshake succeeded, expect fail")
ws.Close()
}
}))
defer s.Close()

req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest returned error %v", err)
}
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Version", "13")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Do returned error %v", err)
}
resp.Body.Close()
if u := resp.Header.Get("Upgrade"); u != "websocket" {
t.Errorf("Uprade response header is %q, want %q", u, "websocket")
}
if resp.StatusCode != http.StatusUpgradeRequired {
t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired)
}
}

func TestDialExtraTokensInRespHeaders(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
challengeKey := r.Header.Get("Sec-Websocket-Key")
Expand Down Expand Up @@ -549,7 +596,7 @@ func TestRespOnBadHandshake(t *testing.T) {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
}

p, err := ioutil.ReadAll(resp.Body)
p, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
}
Expand Down Expand Up @@ -1133,3 +1180,66 @@ func TestNextProtos(t *testing.T) {
t.Fatalf("Dial succeeded, expect fail ")
}
}

type dataBeforeHandshakeResponseWriter struct {
http.ResponseWriter
}

type dataBeforeHandshakeConnection struct {
net.Conn
io.Reader
}

func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) {
return c.Reader.Read(p)
}

func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Example single-frame masked text message from section 5.7 of the RFC.
message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58}
n := len(message) / 2

c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack()
if rw != nil {
// Load first part of message into bufio.Reader. If the websocket
// connection reads more than n bytes from the bufio.Reader, then the
// test will fail with an unexpected EOF error.
rw.Reader.Reset(bytes.NewReader(message[:n]))
rw.Reader.Peek(n)
}
if c != nil {
// Inject second part of message before data read from the network connection.
c = &dataBeforeHandshakeConnection{
Conn: c,
Reader: io.MultiReader(bytes.NewReader(message[n:]), c),
}
}
return c, rw, err
}

func TestDataReceivedBeforeHandshake(t *testing.T) {
s := newServer(t)
defer s.Close()

origHandler := s.Server.Config.Handler
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r)
})

for _, readBufferSize := range []int{0, 1024} {
t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) {
dialer := cstDialer
dialer.ReadBufferSize = readBufferSize
ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
_, m, err := ws.ReadMessage()
if err != nil || string(m) != "Hello" {
t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err)
}
})
}
}
5 changes: 2 additions & 3 deletions compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"io"
"io/ioutil"
"testing"
)

Expand Down Expand Up @@ -42,7 +41,7 @@ func textMessages(num int) [][]byte {
}

func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard
w := io.Discard
c := newTestConn(nil, w, false)
messages := textMessages(100)
b.ResetTimer()
Expand All @@ -53,7 +52,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
}

func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard
w := io.Discard
c := newTestConn(nil, w, false)
messages := textMessages(100)
c.enableWriteCompression = true
Expand Down
Loading

0 comments on commit 12b2063

Please sign in to comment.