diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 4caa3b291..b1de9ad37 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -313,8 +313,7 @@ func TestBasicAuthPassword(t *testing.T) { })) opts := NewOptions() opts.Upstreams = append(opts.Upstreams, provider_server.URL) - // The CookieSecret must be 32 bytes in order to create the AES - // cipher. + // The CookieSecret must be 32 bytes in order to create the AES cipher opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" opts.ClientID = "bazquux" opts.ClientSecret = "foobar" @@ -407,8 +406,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes t.opts = NewOptions() t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL) - // The CookieSecret must be 32 bytes in order to create the AES - // cipher. + // The CookieSecret must be 32 bytes in order to create the AES cipher t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" t.opts.ClientID = "bazquux" t.opts.ClientSecret = "foobar" @@ -639,7 +637,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { pc_test.opts = NewOptions() pc_test.opts.ClientID = "bazquux" pc_test.opts.ClientSecret = "xyzzyplugh" - pc_test.opts.CookieSecret = "0123456789abcdefabcd" + pc_test.opts.CookieSecret = "0123456789abcdef012345==" // First, set the CookieRefresh option so proxy.AesCipher is created, // needed to encrypt the access_token. pc_test.opts.CookieRefresh = time.Hour diff --git a/options.go b/options.go index 14c3d4625..4d1e5b094 100644 --- a/options.go +++ b/options.go @@ -331,11 +331,9 @@ func validateCookieName(o *Options, msgs []string) []string { return msgs } +// for base64 which has had '=' padding trimmed off func addPadding(secret string) string { - padding := len(secret) % 4 - switch padding { - case 1: - return secret + "===" + switch len(secret) % 4 { case 2: return secret + "==" case 3: @@ -349,7 +347,9 @@ func addPadding(secret string) string { func secretBytes(secret string) []byte { b, err := base64.URLEncoding.DecodeString(addPadding(secret)) if err == nil { - return []byte(addPadding(string(b))) + if len(b) == 16 || len(b) == 24 || len(b) == 32 { + return b + } } return []byte(secret) } diff --git a/options_test.go b/options_test.go index e1017eb12..3ddff572e 100644 --- a/options_test.go +++ b/options_test.go @@ -2,7 +2,10 @@ package main import ( "crypto" + "crypto/rand" + "encoding/base64" "fmt" + "io" "net/url" "strings" "testing" @@ -196,7 +199,7 @@ func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { o := testOptions() assert.Equal(t, nil, o.Validate()) - o.CookieSecret = "0123456789abcdefabcd" + o.CookieSecret = "0123456789abcdef012345" o.CookieRefresh = o.CookieExpire assert.NotEqual(t, nil, o.Validate()) @@ -283,3 +286,73 @@ func TestSkipOIDCDiscovery(t *testing.T) { assert.Equal(t, nil, o.Validate()) } + +func TestSecretBytesEncoded(t *testing.T) { + for _, secretSize := range []int{16, 24, 32} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + // We test both padded & raw Base64 to ensure we handle both + // potential user input routes for Base64 + base64Padded := base64.URLEncoding.EncodeToString(secret) + sb := secretBytes(base64Padded) + assert.Equal(t, secret, sb) + assert.Equal(t, len(sb), secretSize) + + base64Raw := base64.RawURLEncoding.EncodeToString(secret) + sb = secretBytes(base64Raw) + assert.Equal(t, secret, sb) + assert.Equal(t, len(sb), secretSize) + }) + } +} + +// A string that isn't intended as Base64 and still decodes (but to unintended length) +// will return the original secret as bytes +func TestSecretBytesEncodedWrongSize(t *testing.T) { + for _, secretSize := range []int{15, 20, 28, 33, 44} { + t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) { + secret := make([]byte, secretSize) + _, err := io.ReadFull(rand.Reader, secret) + assert.Equal(t, nil, err) + + // We test both padded & raw Base64 to ensure we handle both + // potential user input routes for Base64 + base64Padded := base64.URLEncoding.EncodeToString(secret) + sb := secretBytes(base64Padded) + assert.NotEqual(t, secret, sb) + assert.NotEqual(t, len(sb), secretSize) + // The given secret is returned as []byte + assert.Equal(t, base64Padded, string(sb)) + + base64Raw := base64.RawURLEncoding.EncodeToString(secret) + sb = secretBytes(base64Raw) + assert.NotEqual(t, secret, sb) + assert.NotEqual(t, len(sb), secretSize) + // The given secret is returned as []byte + assert.Equal(t, base64Raw, string(sb)) + }) + } +} + +func TestSecretBytesNonBase64(t *testing.T) { + trailer := "equals==========" + assert.Equal(t, trailer, string(secretBytes(trailer))) + + raw16 := "asdflkjhqwer)(*&" + sb16 := secretBytes(raw16) + assert.Equal(t, raw16, string(sb16)) + assert.Equal(t, 16, len(sb16)) + + raw24 := "asdflkjhqwer)(*&CJEN#$%^" + sb24 := secretBytes(raw24) + assert.Equal(t, raw24, string(sb24)) + assert.Equal(t, 24, len(sb24)) + + raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&" + sb32 := secretBytes(raw32) + assert.Equal(t, raw32, string(sb32)) + assert.Equal(t, 32, len(sb32)) +}