forked from armon/go-proxyproto
-
Notifications
You must be signed in to change notification settings - Fork 0
/
protocol.go
245 lines (216 loc) · 6.38 KB
/
protocol.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
package proxyproto
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
)
var (
	// prefix is the signature that opens every PROXY protocol v1 header.
	// A connection whose first bytes match it is treated as carrying a header.
	prefix = []byte("PROXY ")

	// prefixLen is len(prefix); it bounds the byte-by-byte prefix comparison
	// performed when a connection is first read.
	prefixLen = len(prefix)

	// ErrInvalidUpstream is the error a SourceChecker should return when the
	// connecting address is not trusted to supply PROXY information.
	ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
)
// SourceChecker decides, for each accepted connection, whether the address
// information in that connection's PROXY header should be trusted. It is
// called with the underlying connection's remote address (the socket peer).
//
// Its results are interpreted as follows:
//   - (true, nil): the PROXY-supplied address is used.
//   - (false, nil): the PROXY-supplied address is ignored and the
//     connection's own remote address is used instead.
//   - (_, err) with err != nil: the call to Accept() fails with err. When the
//     failure is because the source is disallowed, err should be
//     ErrInvalidUpstream so callers can recognise it.
type SourceChecker func(net.Addr) (bool, error)
// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol (version 1).
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
type Listener struct {
	// Listener is the wrapped listener whose connections are accepted.
	Listener net.Listener

	// ProxyHeaderTimeout, if non-zero, is the maximum time to wait for the
	// Proxy Protocol header. If it expires while the prefix is being checked,
	// the connection is treated as not using the protocol. Zero means no
	// timeout.
	ProxyHeaderTimeout time.Duration

	// SourceCheck, if non-nil, is called for every accepted connection to
	// decide whether the PROXY header's address may be trusted, or whether
	// the connection should be rejected. See SourceChecker.
	SourceCheck SourceChecker
}
// Conn is used to wrap an underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
	// bufReader buffers reads from conn so the start of the stream can be
	// peeked at for the PROXY prefix without losing application data.
	bufReader *bufio.Reader
	// conn is the wrapped connection; writes and addresses delegate to it.
	conn net.Conn
	// dstAddr and srcAddr are filled in from the PROXY header's destination
	// and source fields; they stay nil when no header was seen.
	dstAddr *net.TCPAddr
	srcAddr *net.TCPAddr
	// useConnRemoteAddr makes RemoteAddr report conn's own peer address even
	// when srcAddr is set (the upstream was not trusted).
	useConnRemoteAddr bool
	// once ensures the header check runs only once, from whichever of Read
	// or RemoteAddr is called first.
	once sync.Once
	// proxyHeaderTimeout bounds how long the header check waits for data;
	// zero disables the deadline.
	proxyHeaderTimeout time.Duration
}
// Accept waits for and returns the next connection to the listener, wrapped
// in a *Conn so that a leading PROXY header is handled transparently.
//
// If SourceCheck is set, it is called with the accepted socket's remote
// address. If it returns an error, the accepted connection is closed and
// that error is returned. If it returns false, the returned Conn reports the
// socket's own remote address from RemoteAddr instead of the PROXY-supplied
// one.
func (p *Listener) Accept() (net.Conn, error) {
	conn, err := p.Listener.Accept()
	if err != nil {
		return nil, err
	}

	useConnRemoteAddr := false
	if p.SourceCheck != nil {
		allowed, err := p.SourceCheck(conn.RemoteAddr())
		if err != nil {
			// The connection never reaches the caller, so close it here
			// instead of leaking the socket. The SourceCheck error is the
			// one worth reporting; a Close error is ignored.
			conn.Close()
			return nil, err
		}
		useConnRemoteAddr = !allowed
	}

	newConn := NewConn(conn, p.ProxyHeaderTimeout)
	newConn.useConnRemoteAddr = useConnRemoteAddr
	return newConn, nil
}
// Close shuts down the wrapped listener. It does not itself close
// connections previously returned by Accept.
func (p *Listener) Close() error {
	inner := p.Listener
	return inner.Close()
}
// Addr reports the network address the wrapped listener is bound to.
func (p *Listener) Addr() net.Addr {
	inner := p.Listener
	return inner.Addr()
}
// NewConn wraps conn, which may or may not begin with a PROXY protocol
// header, in a *Conn. timeout bounds how long the header check may wait for
// data; zero disables that deadline. No data is read here: the header is
// checked on the first call to Read or RemoteAddr.
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
	return &Conn{
		conn:               conn,
		bufReader:          bufio.NewReader(conn),
		proxyHeaderTimeout: timeout,
	}
}
// Read reads application data from the connection. If the PROXY header has
// not been checked yet (by Read or RemoteAddr), that check runs first and
// consumes the header. If the check fails, its error is returned and nothing
// is read; header parse failures also close the underlying connection.
func (p *Conn) Read(b []byte) (int, error) {
	var headerErr error
	p.once.Do(func() {
		headerErr = p.checkPrefix()
	})
	if headerErr != nil {
		return 0, headerErr
	}
	return p.bufReader.Read(b)
}
// Write writes b directly to the underlying connection. Only the read side
// carries the PROXY header, so writes pass through unchanged.
func (p *Conn) Write(b []byte) (int, error) {
	n, err := p.conn.Write(b)
	return n, err
}
// Close closes the underlying connection.
func (p *Conn) Close() error {
	underlying := p.conn
	return underlying.Close()
}
// LocalAddr returns the local address of the underlying connection. The
// destination address from a PROXY header is not used here.
func (p *Conn) LocalAddr() net.Addr {
	underlying := p.conn
	return underlying.LocalAddr()
}
// RemoteAddr returns the client address announced in the PROXY header when
// the header supplied one and the connection was not marked to ignore it;
// otherwise it returns the address of the socket peer.
//
// If the header has not been checked yet, RemoteAddr performs that check
// itself. The check reads from the connection, so this call can block on a
// slow client; set a deadline first if calling RemoteAddr before Read. If
// the check fails with an error other than io.EOF, the error is logged and
// the connection is closed.
func (p *Conn) RemoteAddr() net.Addr {
	p.once.Do(func() {
		err := p.checkPrefix()
		if err == nil || err == io.EOF {
			return
		}
		log.Printf("[ERR] Failed to read proxy prefix: err=%v remote_addr=%s", err, p.conn.RemoteAddr())
		p.Close()
		// Replace the reader, discarding whatever was buffered during the
		// failed check.
		p.bufReader = bufio.NewReader(p.conn)
	})

	if p.useConnRemoteAddr || p.srcAddr == nil {
		return p.conn.RemoteAddr()
	}
	return p.srcAddr
}
// SetDeadline sets the read and write deadlines of the underlying
// connection. If a proxy header timeout is configured, the one-time header
// check clears the read deadline when it finishes. A read deadline set
// before the first Read/RemoteAddr call is therefore not kept.
func (p *Conn) SetDeadline(t time.Time) error {
	underlying := p.conn
	return underlying.SetDeadline(t)
}
// SetReadDeadline sets the read deadline of the underlying connection. If a
// proxy header timeout is configured, the one-time header check clears the
// read deadline when it finishes. Set it after the first Read/RemoteAddr call
// if it must persist.
func (p *Conn) SetReadDeadline(t time.Time) error {
	underlying := p.conn
	return underlying.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline of the underlying connection.
func (p *Conn) SetWriteDeadline(t time.Time) error {
	underlying := p.conn
	return underlying.SetWriteDeadline(t)
}
// checkPrefix looks for a PROXY protocol v1 header at the start of the
// stream. If one is present, it consumes the header line and stores the
// addresses it carries in p.srcAddr and p.dstAddr.
//
// Outcomes:
//   - The stream does not start with "PROXY ", ends inside a partial prefix,
//     or the proxy header timeout expires while the prefix is checked: nil is
//     returned, nothing is consumed, and the connection is treated as not
//     using the protocol.
//   - Any other error while peeking at the prefix is returned as-is, without
//     closing the connection.
//   - A header line that cannot be read or parsed: the underlying connection
//     is closed and an error is returned. No addresses are stored.
//   - A "PROXY UNKNOWN" header: the line is consumed and ignored, leaving the
//     addresses unset.
//
// When a proxy header timeout is configured, a read deadline is set for the
// duration of the check and cleared (reset to the zero time) afterwards.
func (p *Conn) checkPrefix() error {
	if p.proxyHeaderTimeout != 0 {
		// Deadline errors are ignored (best effort), as before.
		p.conn.SetReadDeadline(time.Now().Add(p.proxyHeaderTimeout))
		defer p.conn.SetReadDeadline(time.Time{})
	}

	// Check the prefix one byte at a time. A client speaking another protocol
	// may send fewer than prefixLen bytes and then wait for a reply. It is
	// recognised at its first mismatching byte instead of blocking until
	// prefixLen bytes arrive.
	for i := 1; i <= prefixLen; i++ {
		inp, err := p.bufReader.Peek(i)
		if err != nil {
			var neterr net.Error
			if errors.As(err, &neterr) && neterr.Timeout() {
				return nil
			}
			if errors.Is(err, io.EOF) {
				// The stream ended inside a partial prefix, which is not a
				// PROXY header. Let the buffered bytes, then EOF, reach the
				// reader instead of hiding them behind an early EOF.
				return nil
			}
			return err
		}
		// Check for a prefix mis-match, quit early
		if !bytes.Equal(inp, prefix[:i]) {
			return nil
		}
	}

	// ReadSlice bounds the header line by the bufio buffer size. A peer that
	// sends "PROXY " followed by an endless line therefore cannot make us
	// buffer unbounded data (the spec caps v1 headers at 107 bytes).
	line, err := p.bufReader.ReadSlice('\n')
	if err != nil {
		p.conn.Close()
		if errors.Is(err, bufio.ErrBufferFull) {
			return fmt.Errorf("Invalid header line: %w", err)
		}
		return err
	}

	// string(line) copies, which matters because ReadSlice's result is only
	// valid until the next read.
	src, dst, err := parseProxyV1Line(string(line))
	if err != nil {
		p.conn.Close()
		return err
	}
	p.srcAddr, p.dstAddr = src, dst
	return nil
}

// parseProxyV1Line parses one PROXY protocol v1 header line, including its
// line terminator. It returns the declared source and destination addresses,
// and returns them only once the whole line has validated.
//
// For "PROXY UNKNOWN ..." it returns nil addresses and a nil error. The spec
// says the receiver must accept such a connection and use the real
// connection endpoints.
func parseProxyV1Line(header string) (*net.TCPAddr, *net.TCPAddr, error) {
	// The spec requires CRLF. Tolerate a bare LF rather than silently chopping
	// the last character of the destination port.
	header = strings.TrimSuffix(header, "\n")
	header = strings.TrimSuffix(header, "\r")

	// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
	parts := strings.Split(header, " ")
	if len(parts) >= 2 && parts[1] == "UNKNOWN" {
		return nil, nil, nil
	}
	if len(parts) != 6 {
		return nil, nil, fmt.Errorf("Invalid header line: %s", header)
	}

	// Verify the type is known
	switch parts[1] {
	case "TCP4", "TCP6":
	default:
		return nil, nil, fmt.Errorf("Unhandled address type: %s", parts[1])
	}

	srcIP := net.ParseIP(parts[2])
	if srcIP == nil {
		return nil, nil, fmt.Errorf("Invalid source ip: %s", parts[2])
	}
	srcPort, err := parseProxyPort(parts[4])
	if err != nil {
		return nil, nil, fmt.Errorf("Invalid source port: %s", parts[4])
	}

	dstIP := net.ParseIP(parts[3])
	if dstIP == nil {
		return nil, nil, fmt.Errorf("Invalid destination ip: %s", parts[3])
	}
	dstPort, err := parseProxyPort(parts[5])
	if err != nil {
		return nil, nil, fmt.Errorf("Invalid destination port: %s", parts[5])
	}

	return &net.TCPAddr{IP: srcIP, Port: srcPort}, &net.TCPAddr{IP: dstIP, Port: dstPort}, nil
}

// parseProxyPort parses a decimal TCP port. It rejects signs and values
// outside 0-65535, which strconv.Atoi would accept (e.g. "-1", "99999").
func parseProxyPort(s string) (int, error) {
	port, err := strconv.ParseUint(s, 10, 16)
	if err != nil {
		return 0, err
	}
	return int(port), nil
}