-
Notifications
You must be signed in to change notification settings - Fork 72
/
proxy.go
101 lines (85 loc) · 2.46 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package tunnel
import (
"io"
"net"
"sync"
"github.com/koding/logging"
"github.com/koding/tunnel/proto"
)
// ProxyFunc is responsible for forwarding a remote connection to a local server
// and writing the response back.
//
// remote is the connection to forward; msg is the control message that
// accompanies it (Proxy uses msg.Protocol to pick an implementation).
type ProxyFunc func(remote net.Conn, msg *proto.ControlMessage)
var (
// DefaultProxyFuncs holds the global default proxy functions for all transport protocols.
// HTTP and WS are each bound to the Proxy method of their own HTTPProxy value;
// TCP is bound to the Proxy method of a TCPProxy value.
DefaultProxyFuncs = ProxyFuncs{
HTTP: new(HTTPProxy).Proxy,
TCP: new(TCPProxy).Proxy,
WS: new(HTTPProxy).Proxy,
}
// DefaultProxy is a ProxyFunc that uses DefaultProxyFuncs. It is built from an
// empty ProxyFuncs, so no custom function overrides the defaults.
DefaultProxy = Proxy(ProxyFuncs{})
)
// ProxyFuncs is a collection of ProxyFunc, one per transport protocol.
// A nil field means "no custom implementation"; Proxy then falls back to the
// corresponding field of DefaultProxyFuncs.
type ProxyFuncs struct {
// HTTP is a custom implementation of HTTP proxying.
HTTP ProxyFunc
// TCP is a custom implementation of TCP proxying.
TCP ProxyFunc
// WS is a custom implementation of web socket proxying.
WS ProxyFunc
}
// Proxy returns a ProxyFunc that dispatches on msg.Protocol. For each known
// protocol (HTTP, TCP, WS) it uses the custom function from p if one is
// provided, otherwise it falls back to the matching DefaultProxyFuncs entry.
//
// If no function can be determined (unknown protocol, or the default entry is
// nil as well), the error is logged, the remote connection is closed and the
// returned ProxyFunc does nothing else.
func Proxy(p ProxyFuncs) ProxyFunc {
	return func(remote net.Conn, msg *proto.ControlMessage) {
		var f ProxyFunc

		switch msg.Protocol {
		case proto.HTTP:
			f = DefaultProxyFuncs.HTTP
			if p.HTTP != nil {
				f = p.HTTP
			}
		case proto.TCP:
			f = DefaultProxyFuncs.TCP
			if p.TCP != nil {
				f = p.TCP
			}
		case proto.WS:
			f = DefaultProxyFuncs.WS
			if p.WS != nil {
				f = p.WS
			}
		}

		if f == nil {
			logging.Error("Could not determine proxy function for %v", msg)
			// Best-effort close: there is nothing useful to do with a close
			// error on a connection we are abandoning anyway.
			remote.Close()
			// Must return here: calling a nil ProxyFunc would panic.
			return
		}

		f(remote, msg)
	}
}
// Join pipes data in both directions between the local and remote
// connections: each direction reads from one connection and writes to the
// other. It blocks until both directions have finished.
// It's a building block for ProxyFunc implementations.
func Join(local, remote net.Conn, log logging.Logger) {
	var pending sync.WaitGroup

	// pipe copies everything from `from` into `to`, then closes `from` and,
	// when possible, half-closes the write side of `to`.
	pipe := func(label string, to, from net.Conn) {
		log.Debug("proxing %s -> %s", from.RemoteAddr(), to.RemoteAddr())

		copied, copyErr := io.Copy(to, from)
		if copyErr != nil {
			log.Error("%s: copy error: %s", label, copyErr)
		}

		closeErr := from.Close()
		if closeErr != nil {
			log.Debug("%s: close error: %s", label, closeErr)
		}

		// CloseWrite is only attempted on plain TCP connections (the client to
		// local server side); other net.Conn implementations, such as yamux
		// streams, are skipped.
		switch tcp := to.(type) {
		case *net.TCPConn:
			if cwErr := tcp.CloseWrite(); cwErr != nil {
				log.Debug("%s: closeWrite error: %s", label, cwErr)
			}
		}

		// Signal completion before the final log line, so Join may return
		// while this message is still being written.
		pending.Done()
		log.Debug("done proxing %s -> %s: %d bytes", from.RemoteAddr(), to.RemoteAddr(), copied)
	}

	pending.Add(2)
	go pipe("remote to local", local, remote)
	go pipe("local to remote", remote, local)
	pending.Wait()
}