Skip to content

Commit

Permalink
reverseproxy: Implement cookie hash selection policy (#3809)
Browse files Browse the repository at this point in the history
* add CookieHashSelection for session affinity

* add CookieHashSelection for session affinity

* register module

* reverse_proxy: Add and fix cookie lb_policy

* reverse_proxy: Manage hmac.write error on cookie hash selection

* reverse_proxy: fix some comments

* reverse_proxy: variable `cookieValue` is inside the else block

* reverse_proxy: Abstract duplicate nuanced logic of reservoir sampling into a function

* reverse_proxy: Set a default secret is indeed useless

* reverse_proxy: add configuration syntax for cookie lb_policy

* reverse_proxy: doc typo and improvement

Co-authored-by: utick <123liuqingdong@163.com>
  • Loading branch information
d-masson and utick authored Nov 20, 2020
1 parent b0d5c2c commit 6e0849d
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 63 deletions.
4 changes: 2 additions & 2 deletions modules/caddyhttp/reverseproxy/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
var proxyErr error
for {
// choose an available upstream
upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r)
upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, r, w)
if upstream == nil {
if proxyErr == nil {
proxyErr = fmt.Errorf("no upstreams available")
Expand Down Expand Up @@ -816,7 +816,7 @@ type LoadBalancing struct {

// Selector selects an available upstream from the pool.
type Selector interface {
Select(UpstreamPool, *http.Request) *Upstream
Select(UpstreamPool, *http.Request, http.ResponseWriter) *Upstream
}

// Hop-by-hop headers. These are removed when sent to the backend.
Expand Down
149 changes: 123 additions & 26 deletions modules/caddyhttp/reverseproxy/selectionpolicies.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
package reverseproxy

import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash/fnv"
weakrand "math/rand"
Expand All @@ -37,6 +40,7 @@ func init() {
caddy.RegisterModule(IPHashSelection{})
caddy.RegisterModule(URIHashSelection{})
caddy.RegisterModule(HeaderHashSelection{})
caddy.RegisterModule(CookieHashSelection{})

weakrand.Seed(time.Now().UTC().UnixNano())
}
Expand All @@ -54,24 +58,8 @@ func (RandomSelection) CaddyModule() caddy.ModuleInfo {
}

// Select returns an available host, if any.
func (r RandomSelection) Select(pool UpstreamPool, request *http.Request) *Upstream {
// use reservoir sampling because the number of available
// hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling
var randomHost *Upstream
var count int
for _, upstream := range pool {
if !upstream.Available() {
continue
}
// (n % 1 == 0) holds for all n, therefore a
// upstream will always be chosen if there is at
// least one available
count++
if (weakrand.Int() % count) == 0 {
randomHost = upstream
}
}
return randomHost
func (r RandomSelection) Select(pool UpstreamPool, request *http.Request, _ http.ResponseWriter) *Upstream {
return selectRandomHost(pool)
}

// UnmarshalCaddyfile sets up the module from Caddyfile tokens.
Expand Down Expand Up @@ -134,7 +122,7 @@ func (r RandomChoiceSelection) Validate() error {
}

// Select returns an available host, if any.
func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream {
func (r RandomChoiceSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream {
k := r.Choose
if k > len(pool) {
k = len(pool)
Expand Down Expand Up @@ -174,7 +162,7 @@ func (LeastConnSelection) CaddyModule() caddy.ModuleInfo {
// Select selects the up host with the least number of connections in the
// pool. If more than one host has the same least number of connections,
// one of the hosts is chosen at random.
func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream {
func (LeastConnSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream {
var bestHost *Upstream
var count int
leastReqs := -1
Expand Down Expand Up @@ -227,7 +215,7 @@ func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo {
}

// Select returns an available host, if any.
func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream {
func (r *RoundRobinSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream {
n := uint32(len(pool))
if n == 0 {
return nil
Expand Down Expand Up @@ -265,7 +253,7 @@ func (FirstSelection) CaddyModule() caddy.ModuleInfo {
}

// Select returns an available host, if any.
func (FirstSelection) Select(pool UpstreamPool, _ *http.Request) *Upstream {
func (FirstSelection) Select(pool UpstreamPool, _ *http.Request, _ http.ResponseWriter) *Upstream {
for _, host := range pool {
if host.Available() {
return host
Expand Down Expand Up @@ -297,7 +285,7 @@ func (IPHashSelection) CaddyModule() caddy.ModuleInfo {
}

// Select returns an available host, if any.
func (IPHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream {
func (IPHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
clientIP = req.RemoteAddr
Expand Down Expand Up @@ -328,7 +316,7 @@ func (URIHashSelection) CaddyModule() caddy.ModuleInfo {
}

// Select returns an available host, if any.
func (URIHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream {
func (URIHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
return hostByHashing(pool, req.RequestURI)
}

Expand Down Expand Up @@ -358,7 +346,7 @@ func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo {
}

// Select returns an available host, if any.
func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstream {
func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
if s.Field == "" {
return nil
}
Expand All @@ -371,7 +359,7 @@ func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request) *Upstr

val := req.Header.Get(s.Field)
if val == "" {
return RandomSelection{}.Select(pool, req)
return RandomSelection{}.Select(pool, req, nil)
}
return hostByHashing(pool, val)
}
Expand All @@ -387,6 +375,114 @@ func (s *HeaderHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil
}

// CookieHashSelection is a policy that selects
// a host based on a given cookie name.
type CookieHashSelection struct {
// The HTTP cookie name whose value is to be hashed and used for upstream selection.
Name string `json:"name,omitempty"`
// Secret to hash (Hmac256) chosen upstream in cookie
Secret string `json:"secret,omitempty"`
}

// CaddyModule returns the Caddy module information.
func (CookieHashSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
ID: "http.reverse_proxy.selection_policies.cookie",
New: func() caddy.Module { return new(CookieHashSelection) },
}
}

// Select returns an available host, if any.
func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream {
if s.Name == "" {
s.Name = "lb"
}
cookie, err := req.Cookie(s.Name)
// If there's no cookie, select new random host
if err != nil || cookie == nil {
return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
} else {
// If the cookie is present, loop over the available upstreams until we find a match
cookieValue := cookie.Value
for _, upstream := range pool {
if !upstream.Available() {
continue
}
sha, err := hashCookie(s.Secret, upstream.Dial)
if err == nil && sha == cookieValue {
return upstream
}
}
}
// If there is no matching host, select new random host
return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
}

// UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax:
// lb_policy cookie [<name> [<secret>]]
//
// By default name is `lb`
func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
args := d.RemainingArgs()
switch len(args) {
case 1:
case 2:
s.Name = args[1]
case 3:
s.Name = args[1]
s.Secret = args[2]
default:
return d.ArgErr()
}
return nil
}

// Select a new Host randomly and add a sticky session cookie
func selectNewHostWithCookieHashSelection(pool []*Upstream, w http.ResponseWriter, cookieSecret string, cookieName string) *Upstream {
randomHost := selectRandomHost(pool)

if randomHost != nil {
// Hash (HMAC with some key for privacy) the upstream.Dial string as the cookie value
sha, err := hashCookie(cookieSecret, randomHost.Dial)
if err == nil {
// write the cookie.
http.SetCookie(w, &http.Cookie{Name: cookieName, Value: sha, Secure: false})
}
}
return randomHost
}

// hashCookie hashes (HMAC 256) some data with the secret
func hashCookie(secret string, data string) (string, error) {
h := hmac.New(sha256.New, []byte(secret))
_, err := h.Write([]byte(data))
if err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}

// selectRandomHost returns a random available host
func selectRandomHost(pool []*Upstream) *Upstream {
// use reservoir sampling because the number of available
// hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling
var randomHost *Upstream
var count int
for _, upstream := range pool {
if !upstream.Available() {
continue
}
// (n % 1 == 0) holds for all n, therefore a
// upstream will always be chosen if there is at
// least one available
count++
if (weakrand.Int() % count) == 0 {
randomHost = upstream
}
}
return randomHost
}

// leastRequests returns the host with the
// least number of active requests to it.
// If more than one host has the same
Expand Down Expand Up @@ -454,6 +550,7 @@ var (
_ Selector = (*IPHashSelection)(nil)
_ Selector = (*URIHashSelection)(nil)
_ Selector = (*HeaderHashSelection)(nil)
_ Selector = (*CookieHashSelection)(nil)

_ caddy.Validator = (*RandomChoiceSelection)(nil)
_ caddy.Provisioner = (*RandomChoiceSelection)(nil)
Expand Down
Loading

0 comments on commit 6e0849d

Please sign in to comment.