Skip to content

Commit

Permalink
proxy: add trusted proxies
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jul 22, 2021
1 parent 7d3625f commit a8ae14f
Show file tree
Hide file tree
Showing 9 changed files with 428 additions and 95 deletions.
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ func createProxyConfig(options *Options) proxy.Config {
CacheMaxTTL: options.CacheMaxTTL,
CacheOptimistic: options.CacheOptimistic,
RefuseAny: options.RefuseAny,
TrustedProxies: []string{"0.0.0.0/0", "::0/0"},
EnableEDNSClientSubnet: options.EnableEDNSSubnet,
UDPBufferSize: options.UDPBufferSize,
MaxGoroutines: options.MaxGoRoutines,
Expand Down
6 changes: 6 additions & 0 deletions proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ type Config struct {
RatelimitWhitelist []string // a list of whitelisted client IP addresses
RefuseAny bool // if true, refuse ANY requests

// TrustedProxies is the list of IP addresses and CIDR networks to
// detect proxy servers addresses the DoH requests from which should be
// handled. The value of nil or an empty slice for this field makes
// Proxy not trust any address.
TrustedProxies []string

// Upstream DNS servers and their settings
// --

Expand Down
16 changes: 0 additions & 16 deletions proxy/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package proxy

import (
"net"
"strings"

"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
Expand Down Expand Up @@ -191,18 +190,3 @@ func isPublicIP(ip net.IP) bool {

return true
}

// split string by a byte and return the first chunk
// Whitespace is trimmed
func splitNext(str *string, splitBy byte) string {
i := strings.IndexByte(*str, splitBy)
s := ""
if i != -1 {
s = (*str)[0:i]
*str = (*str)[i+1:]
} else {
s = *str
*str = ""
}
return strings.TrimSpace(s)
}
8 changes: 8 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ type Proxy struct {
ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP
ratelimitLock sync.Mutex // Synchronizes access to ratelimitBuckets

// proxyVerifier checks if the proxy is in the trusted list.
proxyVerifier *subnetDetector

// DNS cache
// --

Expand Down Expand Up @@ -183,6 +186,11 @@ func (p *Proxy) Init() (err error) {
p.fastestAddr = fastip.NewFastestAddr()
}

p.proxyVerifier, err = newSubnetDetector(p.TrustedProxies)
if err != nil {
return fmt.Errorf("initializing subnet detector for proxies verifying: %w", err)
}

return nil
}

Expand Down
7 changes: 7 additions & 0 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,9 @@ func createTestProxy(t *testing.T, tlsConfig *tls.Config) *Proxy {
}
p.UpstreamConfig = &UpstreamConfig{}
p.UpstreamConfig.Upstreams = append(upstreams, dnsUpstream)

p.TrustedProxies = []string{"0.0.0.0/0", "::0/0"}

return &p
}

Expand Down Expand Up @@ -1072,6 +1075,8 @@ func createHostTestMessage(host string) *dns.Msg {
}

func assertResponse(t *testing.T, reply *dns.Msg) {
t.Helper()

if len(reply.Answer) != 1 {
t.Fatalf("DNS upstream returned reply with wrong number of answers - %d", len(reply.Answer))
}
Expand All @@ -1085,6 +1090,8 @@ func assertResponse(t *testing.T, reply *dns.Msg) {
}

func createServerTLSConfig(t *testing.T) (*tls.Config, []byte) {
t.Helper()

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("cannot generate RSA key: %s", err)
Expand Down
7 changes: 7 additions & 0 deletions proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,13 @@ func (p *Proxy) genNotImpl(request *dns.Msg) *dns.Msg {
return &resp
}

func (p *Proxy) genRefused(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeRefused)
resp.RecursionAvailable = true
return &resp
}

func (p *Proxy) genNXDomain(req *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(req, dns.RcodeNameError)
Expand Down
101 changes: 62 additions & 39 deletions proxy/server_https.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"net/http"
"strconv"
"strings"

"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
Expand Down Expand Up @@ -95,41 +96,28 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

addr, _ := p.remoteAddr(r)
addr, prx, _ := remoteAddr(r)

d := p.newDNSContext(ProtoHTTPS, req)
d.Addr = addr
d.HTTPRequest = r
d.HTTPResponseWriter = w

err = p.handleDNSRequest(d)
if err != nil {
log.Tracef("error handling DNS (%s) request: %s", d.Proto, err)
}
}
if prx != nil {
log.Debug("request came from proxy server %s", prx)
if !p.proxyVerifier.detect(prx) {
log.Debug("the proxy server %s is not trusted", prx)
d.Res = p.genRefused(req)
p.respond(d)

// Get a client IP address from HTTP headers that proxy servers may set
func getIPFromHTTPRequest(r *http.Request) net.IP {
names := []string{
"CF-Connecting-IP", "True-Client-IP", // set by CloudFlare servers
"X-Real-IP",
}
for _, name := range names {
s := r.Header.Get(name)
ip := net.ParseIP(s)
if ip != nil {
return ip
return
}
}

s := r.Header.Get("X-Forwarded-For")
s = splitNext(&s, ',') // get left-most IP address
ip := net.ParseIP(s)
if ip != nil {
return ip
err = p.handleDNSRequest(d)
if err != nil {
log.Tracef("error handling DNS (%s) request: %s", d.Proto, err)
}

return nil
}

// Writes a response to the DOH client
Expand All @@ -155,26 +143,61 @@ func (p *Proxy) respondHTTPS(d *DNSContext) error {
return err
}

func (p *Proxy) remoteAddr(r *http.Request) (net.Addr, error) {
host, port, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return nil, err
func addrsFromRequest(r *http.Request) (realIP net.IP, prx net.IP) {
for _, h := range []string{
// Headers set by CloudFlare proxy servers.
"CF-Connecting-IP",
"True-Client-IP",
// Other proxying headers.
"X-Real-IP",
} {
realIP = net.ParseIP(r.Header.Get(h))
if realIP != nil {
return realIP, nil
}
}

portValue, err := strconv.Atoi(port)
if err != nil {
return nil, err
xff := r.Header.Get("X-Forwarded-For")
firstComma := strings.IndexByte(xff, ',')
if firstComma == -1 {
return net.ParseIP(xff), nil
}

ip := getIPFromHTTPRequest(r)
if ip != nil {
log.Tracef("Using IP address from HTTP request: %s", ip)
} else {
ip = net.ParseIP(host)
if ip == nil {
return nil, fmt.Errorf("invalid IP: %s", host)
realIP = net.ParseIP(xff[:firstComma])
lastComma := strings.LastIndexByte(xff, ',')

return realIP, net.ParseIP(xff[lastComma+1:])
}

func remoteAddr(r *http.Request) (addr net.Addr, prx net.IP, err error) {
var hostStr, portStr string
if hostStr, portStr, err = net.SplitHostPort(r.RemoteAddr); err != nil {
return nil, nil, err
}

var port int
if port, err = strconv.Atoi(portStr); err != nil {
return nil, nil, err
}

var realIP net.IP
realIP, prx = addrsFromRequest(r)
if realIP == nil {
realIP = net.ParseIP(hostStr)
if realIP == nil {
return nil, nil, fmt.Errorf("invalid ip: %s", hostStr)
}

return &net.TCPAddr{IP: realIP, Port: port}, prx, nil
}

log.Tracef("Using IP address from HTTP request: %s", realIP)
if prx == nil {
prx = net.ParseIP(hostStr)
if prx == nil {
return nil, nil, fmt.Errorf("invalid ip: %s", hostStr)
}
}

return &net.TCPAddr{IP: ip, Port: portValue}, nil
return &net.TCPAddr{IP: realIP, Port: port}, prx, nil
}
Loading

0 comments on commit a8ae14f

Please sign in to comment.