From 2db18585134f96bef865f0532100d740d3886bc1 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 27 Jun 2024 14:30:47 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20Feature:=20Add=20support=20for?= =?UTF-8?q?=20custom=20KeyLookup=20functions=20in=20the=20Keyauth=20middle?= =?UTF-8?q?ware=20(#3028)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * port over FallbackKeyLookups from v2 middleware to v3 Signed-off-by: Dave Lee * bot pointed out that I missed the format variable Signed-off-by: Dave Lee * fix lint and gofumpt issues Signed-off-by: Dave Lee * major revision: instead of FallbackKeyLookups, expose CustomKeyLookup as function, with utility functions to make creating these easy Signed-off-by: Dave Lee * add more tests to boost coverage Signed-off-by: Dave Lee * teardown code and cleanup Signed-off-by: Dave Lee * test fixes Signed-off-by: Dave Lee * slight boost to test coverage Signed-off-by: Dave Lee * docs: fix md table alignment * fix comments - change some names, expose functions, improve docs Signed-off-by: Dave Lee * missed one old name Signed-off-by: Dave Lee * fix some suggestions from the bot - error messages, test coverage, mark purely defensive code Signed-off-by: Dave Lee --------- Signed-off-by: Dave Lee Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Co-authored-by: Jason McNeil Co-authored-by: RW --- docs/middleware/keyauth.md | 24 +++-- middleware/keyauth/config.go | 9 +- middleware/keyauth/keyauth.go | 75 ++++++++++---- middleware/keyauth/keyauth_test.go | 152 +++++++++++++++++++++++++++++ 4 files changed, 233 insertions(+), 27 deletions(-) diff --git a/docs/middleware/keyauth.md b/docs/middleware/keyauth.md index 0699108588..9907a5dabf 100644 --- a/docs/middleware/keyauth.md +++ b/docs/middleware/keyauth.md @@ -214,14 +214,15 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000 ## Config -| Property | Type | Description | Default | -|:---------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------| -| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` | -| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` | -| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` | -| KeyLookup | `string` | KeyLookup is a string in the form of "`:`" that is used to extract key from the request. | "header:Authorization" | -| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" | -| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation | +| Property | Type | Description | Default | +|:----------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------| +| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` | +| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` | +| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` | +| KeyLookup | `string` | KeyLookup is a string in the form of "`:`" that is used to extract the key from the request. | "header:Authorization" | +| CustomKeyLookup | `KeyLookupFunc` aka `func(c fiber.Ctx) (string, error)` | If more complex logic is required to extract the key from the request, an arbitrary function to extract it can be specified here. Utility helper functions are described below. | `nil` | +| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" | +| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation | ## Default Config @@ -237,6 +238,13 @@ var ConfigDefault = Config{ return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key") }, KeyLookup: "header:" + fiber.HeaderAuthorization, + CustomKeyLookup: nil, AuthScheme: "Bearer", } ``` + +## CustomKeyLookup + +Two public utility functions are provided that may be useful when creating custom extraction: +* `DefaultKeyLookup(keyLookup string, authScheme string)`: This is the function that implements the default `KeyLookup` behavior, exposed to be used as a component of custom parsing logic +* `MultipleKeySourceLookup(keyLookups []string, authScheme string)`: Creates a CustomKeyLookup function that checks each listed source using the above function until a key is found or the options are all exhausted. For example, `MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "cookie:apikey"}, "Bearer")` would first check the standard Authorization header, checks the `x-api-key` header next, and finally checks for a cookie named `apikey`. If any of these contain a valid API key, the request continues. Otherwise, an error is returned. diff --git a/middleware/keyauth/config.go b/middleware/keyauth/config.go index 3b02a9e1fa..8c41d60f11 100644 --- a/middleware/keyauth/config.go +++ b/middleware/keyauth/config.go @@ -6,6 +6,8 @@ import ( "github.com/gofiber/fiber/v3" ) +type KeyLookupFunc func(c fiber.Ctx) (string, error) + // Config defines the config for middleware. type Config struct { // Next defines a function to skip middleware. @@ -32,6 +34,8 @@ type Config struct { // - "cookie:" KeyLookup string + CustomKeyLookup KeyLookupFunc + // AuthScheme to be used in the Authorization header. // Optional. Default value "Bearer". AuthScheme string @@ -51,8 +55,9 @@ var ConfigDefault = Config{ } return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key") }, - KeyLookup: "header:" + fiber.HeaderAuthorization, - AuthScheme: "Bearer", + KeyLookup: "header:" + fiber.HeaderAuthorization, + CustomKeyLookup: nil, + AuthScheme: "Bearer", } // Helper function to set default values diff --git a/middleware/keyauth/keyauth.go b/middleware/keyauth/keyauth.go index 914bca03f6..e245ba4247 100644 --- a/middleware/keyauth/keyauth.go +++ b/middleware/keyauth/keyauth.go @@ -3,6 +3,7 @@ package keyauth import ( "errors" + "fmt" "net/url" "strings" @@ -34,17 +35,12 @@ func New(config ...Config) fiber.Handler { cfg := configDefault(config...) // Initialize - parts := strings.Split(cfg.KeyLookup, ":") - extractor := keyFromHeader(parts[1], cfg.AuthScheme) - switch parts[0] { - case query: - extractor = keyFromQuery(parts[1]) - case form: - extractor = keyFromForm(parts[1]) - case param: - extractor = keyFromParam(parts[1]) - case cookie: - extractor = keyFromCookie(parts[1]) + if cfg.CustomKeyLookup == nil { + var err error + cfg.CustomKeyLookup, err = DefaultKeyLookup(cfg.KeyLookup, cfg.AuthScheme) + if err != nil { + panic(fmt.Errorf("unable to create lookup function: %w", err)) + } } // Return middleware handler @@ -55,7 +51,7 @@ func New(config ...Config) fiber.Handler { } // Extract and verify key - key, err := extractor(c) + key, err := cfg.CustomKeyLookup(c) if err != nil { return cfg.ErrorHandler(c, err) } @@ -80,8 +76,53 @@ func TokenFromContext(c fiber.Ctx) string { return token } +// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found +// Each element should be specified according to the format used in KeyLookup +func MultipleKeySourceLookup(keyLookups []string, authScheme string) (KeyLookupFunc, error) { + subExtractors := map[string]KeyLookupFunc{} + var err error + for _, keyLookup := range keyLookups { + subExtractors[keyLookup], err = DefaultKeyLookup(keyLookup, authScheme) + if err != nil { + return nil, err + } + } + return func(c fiber.Ctx) (string, error) { + for keyLookup, subExtractor := range subExtractors { + res, err := subExtractor(c) + if err == nil && res != "" { + return res, nil + } + if !errors.Is(err, ErrMissingOrMalformedAPIKey) { + // Defensive Code - not currently possible to hit + return "", fmt.Errorf("[%s] %w", keyLookup, err) + } + } + return "", ErrMissingOrMalformedAPIKey + }, nil +} + +func DefaultKeyLookup(keyLookup, authScheme string) (KeyLookupFunc, error) { + parts := strings.Split(keyLookup, ":") + if len(parts) <= 1 { + return nil, fmt.Errorf("invalid keyLookup: %q, expected format 'source:name'", keyLookup) + } + extractor := KeyFromHeader(parts[1], authScheme) // in the event of an invalid prefix, it is interpreted as header: + switch parts[0] { + case query: + extractor = KeyFromQuery(parts[1]) + case form: + extractor = KeyFromForm(parts[1]) + case param: + extractor = KeyFromParam(parts[1]) + case cookie: + extractor = KeyFromCookie(parts[1]) + } + return extractor, nil +} + // keyFromHeader returns a function that extracts api key from the request header. -func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error) { +func KeyFromHeader(header, authScheme string) KeyLookupFunc { return func(c fiber.Ctx) (string, error) { auth := c.Get(header) l := len(authScheme) @@ -96,7 +137,7 @@ func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error) } // keyFromQuery returns a function that extracts api key from the query string. -func keyFromQuery(param string) func(c fiber.Ctx) (string, error) { +func KeyFromQuery(param string) KeyLookupFunc { return func(c fiber.Ctx) (string, error) { key := fiber.Query[string](c, param) if key == "" { @@ -107,7 +148,7 @@ func keyFromQuery(param string) func(c fiber.Ctx) (string, error) { } // keyFromForm returns a function that extracts api key from the form. -func keyFromForm(param string) func(c fiber.Ctx) (string, error) { +func KeyFromForm(param string) KeyLookupFunc { return func(c fiber.Ctx) (string, error) { key := c.FormValue(param) if key == "" { @@ -118,7 +159,7 @@ func keyFromForm(param string) func(c fiber.Ctx) (string, error) { } // keyFromParam returns a function that extracts api key from the url param string. -func keyFromParam(param string) func(c fiber.Ctx) (string, error) { +func KeyFromParam(param string) KeyLookupFunc { return func(c fiber.Ctx) (string, error) { key, err := url.PathUnescape(c.Params(param)) if err != nil { @@ -129,7 +170,7 @@ func keyFromParam(param string) func(c fiber.Ctx) (string, error) { } // keyFromCookie returns a function that extracts api key from the named cookie. -func keyFromCookie(name string) func(c fiber.Ctx) (string, error) { +func KeyFromCookie(name string) KeyLookupFunc { return func(c fiber.Ctx) (string, error) { key := c.Cookies(name) if key == "" { diff --git a/middleware/keyauth/keyauth_test.go b/middleware/keyauth/keyauth_test.go index 3cb756dc32..e59e0936f7 100644 --- a/middleware/keyauth/keyauth_test.go +++ b/middleware/keyauth/keyauth_test.go @@ -130,6 +130,109 @@ func Test_AuthSources(t *testing.T) { } } +func TestPanicOnInvalidConfiguration(t *testing.T) { + require.Panics(t, func() { + authMiddleware := New(Config{ + KeyLookup: "invalid", + }) + // We shouldn't even make it this far, but these next two lines prevent authMiddleware from being an unused variable. + app := fiber.New() + defer func() { // testing panics, defer block to ensure cleanup + err := app.Shutdown() + require.NoError(t, err) + }() + app.Use(authMiddleware) + }, "should panic if Validator is missing") + + require.Panics(t, func() { + authMiddleware := New(Config{ + KeyLookup: "invalid", + Validator: func(_ fiber.Ctx, _ string) (bool, error) { + return true, nil + }, + }) + // We shouldn't even make it this far, but these next two lines prevent authMiddleware from being an unused variable. + app := fiber.New() + defer func() { // testing panics, defer block to ensure cleanup + err := app.Shutdown() + require.NoError(t, err) + }() + app.Use(authMiddleware) + }, "should panic if CustomKeyLookup is not set AND KeyLookup has an invalid value") +} + +func TestCustomKeyUtilityFunctionErrors(t *testing.T) { + const ( + scheme = "Bearer" + ) + + // Invalid element while parsing + _, err := DefaultKeyLookup("invalid", scheme) + require.Error(t, err, "DefaultKeyLookup should fail for 'invalid' keyLookup") + + _, err = MultipleKeySourceLookup([]string{"header:key", "invalid"}, scheme) + require.Error(t, err, "MultipleKeySourceLookup should fail for 'invalid' keyLookup") +} + +func TestMultipleKeyLookup(t *testing.T) { + const ( + desc = "auth with correct key" + success = "Success!" + scheme = "Bearer" + ) + + // setup the fiber endpoint + app := fiber.New() + + customKeyLookup, err := MultipleKeySourceLookup([]string{"header:key", "cookie:key", "query:key"}, scheme) + require.NoError(t, err) + + authMiddleware := New(Config{ + CustomKeyLookup: customKeyLookup, + Validator: func(_ fiber.Ctx, key string) (bool, error) { + if key == CorrectKey { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + }) + app.Use(authMiddleware) + app.Get("/foo", func(c fiber.Ctx) error { + return c.SendString(success) + }) + + // construct the test HTTP request + var req *http.Request + req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil) + require.NoError(t, err) + q := req.URL.Query() + q.Add("key", CorrectKey) + req.URL.RawQuery = q.Encode() + + res, err := app.Test(req, -1) + + require.NoError(t, err) + + // test the body of the request + body, err := io.ReadAll(res.Body) + require.Equal(t, 200, res.StatusCode, desc) + // body + require.NoError(t, err) + require.Equal(t, success, string(body), desc) + + err = res.Body.Close() + require.NoError(t, err) + + // construct a second request without proper key + req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil) + require.NoError(t, err) + res, err = app.Test(req, -1) + require.NoError(t, err) + errBody, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(errBody)) +} + func Test_MultipleKeyAuth(t *testing.T) { // setup the fiber endpoint app := fiber.New() @@ -376,6 +479,55 @@ func Test_CustomNextFunc(t *testing.T) { require.Equal(t, string(body), ErrMissingOrMalformedAPIKey.Error()) } +func Test_TokenFromContext_None(t *testing.T) { + app := fiber.New() + // Define a test handler that checks TokenFromContext + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(TokenFromContext(c)) + }) + + // Verify a "" is sent back if nothing sets the token on the context. + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + // Send + res, err := app.Test(req) + require.NoError(t, err) + + // Read the response body into a string + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Empty(t, body) +} + +func Test_TokenFromContext(t *testing.T) { + app := fiber.New() + // Wire up keyauth middleware to set TokenFromContext now + app.Use(New(Config{ + KeyLookup: "header:Authorization", + AuthScheme: "Basic", + Validator: func(_ fiber.Ctx, key string) (bool, error) { + if key == CorrectKey { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + })) + // Define a test handler that checks TokenFromContext + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(TokenFromContext(c)) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Add("Authorization", "Basic "+CorrectKey) + // Send + res, err := app.Test(req) + require.NoError(t, err) + + // Read the response body into a string + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, CorrectKey, string(body)) +} + func Test_AuthSchemeToken(t *testing.T) { app := fiber.New()