Skip to content

Commit

Permalink
reverse_proxy: Add and fix cookie lb_policy
Browse files Browse the repository at this point in the history
  • Loading branch information
d-masson committed Nov 4, 2020
1 parent 0cc999d commit 0b14630
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 36 deletions.
81 changes: 69 additions & 12 deletions modules/caddyhttp/reverseproxy/selectionpolicies.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
package reverseproxy

import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash/fnv"
weakrand "math/rand"
Expand Down Expand Up @@ -392,7 +395,8 @@ func (s *HeaderHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
// a host based on a given cookie name.
type CookieHashSelection struct {
// The HTTP cookie name whose value is to be hashed and used for upstream selection.
Field string `json:"field,omitempty"`
Name string `json:"name,omitempty"`
Secret string `json:"secret,omitempty"`
}

// CaddyModule returns the Caddy module information.
Expand All @@ -405,31 +409,84 @@ func (CookieHashSelection) CaddyModule() caddy.ModuleInfo {

// Select returns an available host, if any.
func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream {
if s.Field == "" {
return nil
if s.Secret == "" {
s.Secret = "caddysecret"
}
cookie, err := req.Cookie(s.Field)
if s.Name == "" {
s.Name = "lb"
}
cookie, err := req.Cookie(s.Name)
var cookieValue string
// If there's no cookie, select new random host
if err != nil || cookie == nil {
cookieValue = caddy.RandomString(16)
http.SetCookie(w, &http.Cookie{Name: s.Field, Value: cookieValue, Secure: false})
return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
} else {
// If the cookie is present, loop over the available upstreams until we find a match
cookieValue = cookie.Value
for _, upstream := range pool {
if !upstream.Available() {
continue
}
if hashCookie(s.Secret, upstream.Dial) == cookieValue {
return upstream
}
}
}
return hostByHashing(pool, cookieValue)
// If there is no matching host, select new random host
return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
}

// UnmarshalCaddyfile sets up the module from Caddyfile tokens.
func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
for d.Next() {
if !d.NextArg() {
return d.ArgErr()
}
s.Field = d.Val()
args := d.RemainingArgs()
switch len(args) {
case 1:
case 2:
s.Name = args[1]
case 3:
s.Name = args[1]
s.Secret = args[2]
default:
return d.ArgErr()
}
return nil
}

// Select a new Host using RandomChoose () and add a sticky session cookie
func selectNewHostWithCookieHashSelection(pool []*Upstream, w http.ResponseWriter, cookieSecret string, cookieName string) *Upstream {
var randomHost *Upstream
var count int
for _, upstream := range pool {
if !upstream.Available() {
continue
}
// (n % 1 == 0) holds for all n, therefore a
// upstream will always be chosen if there is at
// least one available
count++
if (weakrand.Int() % count) == 0 {
randomHost = upstream
}
}

if randomHost != nil {
// Hash (HMAC with some key for privacy) the upstream.Dial string as the cookie value
sha := hashCookie(cookieSecret, randomHost.Dial)
// write the cookie.
http.SetCookie(w, &http.Cookie{Name: cookieName, Value: sha, Secure: false})
}
return randomHost
}

// Hash (Hmac256) some data with the secret
func hashCookie(secret string, data string) string {
// Create a new HMAC by defining the hash type and the key (as byte array)
h := hmac.New(sha256.New, []byte(secret))
// Write Data to it
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}

// leastRequests returns the host with the
// least number of active requests to it.
// If more than one host has the same
Expand Down
61 changes: 61 additions & 0 deletions modules/caddyhttp/reverseproxy/selectionpolicies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,64 @@ func TestURIHashPolicy(t *testing.T) {
t.Error("Expected uri policy policy host to be nil.")
}
}

func TestCookieHashPolicy(t *testing.T) {
pool := testPool()
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
pool[2].Dial = "localhost:8082"
pool[0].SetHealthy(true)
pool[1].SetHealthy(false)
pool[2].SetHealthy(false)

request := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
cookieHashPolicy := new(CookieHashSelection)
h := cookieHashPolicy.Select(pool, request, w)

cookie_server1 := w.Result().Cookies()[0]

if cookie_server1 == nil {
t.Error("cookieHashPolicy should set a cookie")
}

if cookie_server1.Name != "lb" {
t.Error("cookieHashPolicy should set a cookie with name lb")
}

if h != pool[0] {
t.Error("Expected cookieHashPolicy host to be the first only available host.")
}

pool[1].SetHealthy(true)
pool[2].SetHealthy(true)
request = httptest.NewRequest(http.MethodGet, "/test", nil)
w = httptest.NewRecorder()
request.AddCookie(cookie_server1)
h = cookieHashPolicy.Select(pool, request, w)

if h != pool[0] {
t.Error("Expected cookieHashPolicy host to stick to the first host (matching cookie).")
}

s := w.Result().Cookies()

if len(s) != 0 {
t.Error("Expected cookieHashPolicy to not set a new cookie.")
}

pool[0].SetHealthy(false)
request = httptest.NewRequest(http.MethodGet, "/test", nil)
w = httptest.NewRecorder()
request.AddCookie(cookie_server1)
h = cookieHashPolicy.Select(pool, request, w)

if h == pool[0] {
t.Error("Expected cookieHashPolicy to select a new host.")
}

if w.Result().Cookies() == nil {
t.Error("Expected cookieHashPolicy to set a new cookie.")
}

}
24 changes: 0 additions & 24 deletions rand.go

This file was deleted.

0 comments on commit 0b14630

Please sign in to comment.