Skip to content

Commit

Permalink
Add SOCKS5 support
Browse files Browse the repository at this point in the history
- Bundle the golang.org/x/net/proxy package to x_net_proxy.go. The
package contains a SOCKS5 proxy. The package is bundled to avoid adding
a dependency from the weboscket package to golang.org/x/net.
- Restructure the existing HTTP proxy code so the code can be used as a
dialer with the proxy package.
- Modify Dialer.Dial to use proxy.FromURL.
- Improve tests (avoid modifying package-level data, use timeouts in
tests, use correct proxy URLs in tests).

Fixes #297.
  • Loading branch information
garyburd committed Dec 1, 2017
1 parent 8c6cfd4 commit b89020e
Show file tree
Hide file tree
Showing 4 changed files with 677 additions and 67 deletions.
94 changes: 36 additions & 58 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
package websocket

import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"io/ioutil"
Expand Down Expand Up @@ -106,7 +104,7 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
return hostPort, hostNoPort
}

// DefaultDialer is a dialer with all fields set to the default zero values.
// DefaultDialer is a dialer with all fields set to the default values.
var DefaultDialer = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
Expand Down Expand Up @@ -202,36 +200,52 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}

hostPort, hostNoPort := hostPortNoPort(u)

var proxyURL *url.URL
// Check wether the proxy method has been configured
if d.Proxy != nil {
proxyURL, err = d.Proxy(req)
}
if err != nil {
return nil, nil, err
}

var targetHostPort string
if proxyURL != nil {
targetHostPort, _ = hostPortNoPort(proxyURL)
} else {
targetHostPort = hostPort
}

var deadline time.Time
if d.HandshakeTimeout != 0 {
deadline = time.Now().Add(d.HandshakeTimeout)
}

// Get network dial function.
netDial := d.NetDial
if netDial == nil {
netDialer := &net.Dialer{Deadline: deadline}
netDial = netDialer.Dial
}

netConn, err := netDial("tcp", targetHostPort)
// If needed, wrap the dial function to set the connection deadline.
if !deadline.Equal(time.Time{}) {
forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}

// If needed, wrap the dial function to connect through a proxy.
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
if err != nil {
return nil, nil, err
}
netDial = dialer.Dial
}
}

hostPort, hostNoPort := hostPortNoPort(u)
netConn, err := netDial("tcp", hostPort)
if err != nil {
return nil, nil, err
}
Expand All @@ -242,42 +256,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
}
}()

if err := netConn.SetDeadline(deadline); err != nil {
return nil, nil, err
}

if proxyURL != nil {
connectHeader := make(http.Header)
if user := proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: hostPort},
Host: hostPort,
Header: connectHeader,
}

connectReq.Write(netConn)

// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(netConn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2)
return nil, nil, errors.New(f[1])
}
}

if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
Expand Down
100 changes: 91 additions & 9 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
package websocket

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
Expand All @@ -31,9 +34,10 @@ var cstUpgrader = Upgrader{
}

var cstDialer = Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 30 * time.Second,
}

type cstHandler struct{ *testing.T }
Expand Down Expand Up @@ -143,8 +147,9 @@ func TestProxyDial(t *testing.T) {
s := newServer(t)
defer s.Close()

surl, _ := url.Parse(s.URL)
surl, _ := url.Parse(s.Server.URL)

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(surl)

connect := false
Expand Down Expand Up @@ -173,16 +178,16 @@ func TestProxyDial(t *testing.T) {
}
defer ws.Close()
sendRecv(t, ws)

cstDialer.Proxy = http.ProxyFromEnvironment
}

func TestProxyAuthorizationDial(t *testing.T) {
s := newServer(t)
defer s.Close()

surl, _ := url.Parse(s.URL)
surl, _ := url.Parse(s.Server.URL)
surl.User = url.UserPassword("username", "password")

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(surl)

connect := false
Expand Down Expand Up @@ -213,8 +218,6 @@ func TestProxyAuthorizationDial(t *testing.T) {
}
defer ws.Close()
sendRecv(t, ws)

cstDialer.Proxy = http.ProxyFromEnvironment
}

func TestDial(t *testing.T) {
Expand Down Expand Up @@ -518,3 +521,82 @@ func TestDialCompression(t *testing.T) {
defer ws.Close()
sendRecv(t, ws)
}

func TestSocksProxyDial(t *testing.T) {
s := newServer(t)
defer s.Close()

proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen failed: %v", err)
}
defer proxyListener.Close()
go func() {
c1, err := proxyListener.Accept()
if err != nil {
t.Errorf("proxy accept failed: %v", err)
return
}
defer c1.Close()

c1.SetDeadline(time.Now().Add(30 * time.Second))

buf := make([]byte, 32)
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
t.Errorf("read failed: %v", err)
return
}
if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
t.Errorf("read %x, want %x", buf[:len(want)], want)
}
if _, err := c1.Write([]byte{5, 0}); err != nil {
t.Errorf("write failed: %v", err)
return
}
if _, err := io.ReadFull(c1, buf[:10]); err != nil {
t.Errorf("read failed: %v", err)
return
}
if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
t.Errorf("read %x, want %x", buf[:len(want)], want)
return
}
buf[1] = 0
if _, err := c1.Write(buf[:10]); err != nil {
t.Errorf("write failed: %v", err)
return
}

ip := net.IP(buf[4:8])
port := binary.BigEndian.Uint16(buf[8:10])

c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
if err != nil {
t.Errorf("dial failed; %v", err)
return
}
defer c2.Close()
done := make(chan struct{})
go func() {
io.Copy(c1, c2)
close(done)
}()
io.Copy(c2, c1)
<-done
}()

purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
if err != nil {
t.Fatalf("parse failed: %v", err)
}

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(purl)

ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}
77 changes: 77 additions & 0 deletions proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package websocket

import (
"bufio"
"encoding/base64"
"errors"
"net"
"net/http"
"net/url"
"strings"
)

type netDialerFunc func(netowrk, addr string) (net.Conn, error)

func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
}

func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil
})
}

type httpProxyDialer struct {
proxyURL *url.URL
fowardDial func(network, addr string) (net.Conn, error)
}

func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.fowardDial(network, hostPort)
if err != nil {
return nil, err
}

connectHeader := make(http.Header)
if user := hpd.proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}

connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: addr},
Host: addr,
Header: connectHeader,
}

if err := connectReq.Write(conn); err != nil {
conn.Close()
return nil, err
}

// Read response. It's OK to use and discard buffered reader here becaue
// the remote server does not speak until spoken to.
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
conn.Close()
return nil, err
}

if resp.StatusCode != 200 {
conn.Close()
f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1])
}
return conn, nil
}
Loading

0 comments on commit b89020e

Please sign in to comment.