Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔥 Feature: Add support for custom KeyLookup functions in the Keyauth middleware #3028

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
0b6b5e9
port over FallbackKeyLookups from v2 middleware to v3
dave-gray101 Jun 9, 2024
f118663
bot pointed out that I missed the format variable
dave-gray101 Jun 9, 2024
a432b80
fix lint and gofumpt issues
dave-gray101 Jun 10, 2024
4e061aa
major revision: instead of FallbackKeyLookups, expose CustomKeyLookup…
dave-gray101 Jun 10, 2024
397ff49
add more tests to boost coverage
dave-gray101 Jun 10, 2024
2968439
teardown code and cleanup
dave-gray101 Jun 10, 2024
5b24181
Merge branch 'main' into feat-keyauth-fallback-keylookup
gaby Jun 11, 2024
7bfc96d
test fixes
dave-gray101 Jun 11, 2024
8c13f25
Merge branch 'main' into feat-keyauth-fallback-keylookup
dave-gray101 Jun 12, 2024
7191f65
slight boost to test coverage
dave-gray101 Jun 16, 2024
1004aa0
Merge branch 'feat-keyauth-fallback-keylookup' of ghgray101:dave-gray…
dave-gray101 Jun 16, 2024
961e8de
docs: fix md table alignment
sixcolors Jun 16, 2024
293c01b
fix comments - change some names, expose functions, improve docs
dave-gray101 Jun 16, 2024
825e11a
Merge branch 'feat-keyauth-fallback-keylookup' of ghgray101:dave-gray…
dave-gray101 Jun 16, 2024
26bc132
missed one old name
dave-gray101 Jun 16, 2024
9588706
fix some suggestions from the bot - error messages, test coverage, ma…
dave-gray101 Jun 17, 2024
2711dc3
Merge branch 'main' into feat-keyauth-fallback-keylookup
gaby Jun 17, 2024
4da76e4
Merge branch 'main' into feat-keyauth-fallback-keylookup
dave-gray101 Jun 18, 2024
12eeca8
Merge branch 'main' into feat-keyauth-fallback-keylookup
dave-gray101 Jun 20, 2024
d6d5bfe
Merge branch 'main' into feat-keyauth-fallback-keylookup
ReneWerner87 Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions docs/middleware/keyauth.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,15 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000

## Config

| Property | Type | Description | Default |
|:---------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. | "header:Authorization" |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
| Property | Type | Description | Default |
|:----------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract the key from the request. | "header:Authorization" |
| CustomKeyLookup | `KeyauthKeyLookupFunc` aka `func(c fiber.Ctx) (string, error)` | If more complex logic is required to extract the key from the request, an arbitrary function to extract it can be specified here. Utility helper functions are described below. | `nil` |
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |

## Default Config

Expand All @@ -237,6 +238,13 @@ var ConfigDefault = Config{
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
CustomKeyLookup: nil,
AuthScheme: "Bearer",
}
```

## CustomKeyLookup

Two public utility functions are provided that may be useful when creating custom extraction:
* `SingleKeyLookup(keyLookup string, authScheme string)`: This is the function that implements the default `KeyLookup` behavior, exposed to be used as a component of custom parsing logic
* `MultipleKeySourceLookup(keyLookups []string, authScheme string)`: Creates a CustomKeyLookup function that checks each listed source using the above function until a key is found or the options are all exhausted
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
9 changes: 7 additions & 2 deletions middleware/keyauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"github.com/gofiber/fiber/v3"
)

type KeyauthKeyLookupFunc func(c fiber.Ctx) (string, error)
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved

// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip middleware.
Expand All @@ -32,6 +34,8 @@ type Config struct {
// - "cookie:<name>"
KeyLookup string

CustomKeyLookup KeyauthKeyLookupFunc
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved

// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
Expand All @@ -51,8 +55,9 @@ var ConfigDefault = Config{
}
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
AuthScheme: "Bearer",
KeyLookup: "header:" + fiber.HeaderAuthorization,
CustomKeyLookup: nil,
AuthScheme: "Bearer",
}

// Helper function to set default values
Expand Down
74 changes: 57 additions & 17 deletions middleware/keyauth/keyauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import (
"errors"
"fmt"
"net/url"
"strings"

Expand Down Expand Up @@ -34,17 +35,12 @@
cfg := configDefault(config...)

// Initialize
parts := strings.Split(cfg.KeyLookup, ":")
extractor := keyFromHeader(parts[1], cfg.AuthScheme)
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])
if cfg.CustomKeyLookup == nil {
var err error
cfg.CustomKeyLookup, err = SingleKeyLookup(cfg.KeyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("unable to create lookup function: %w", err))

Check warning on line 42 in middleware/keyauth/keyauth.go

View check run for this annotation

Codecov / codecov/patch

middleware/keyauth/keyauth.go#L42

Added line #L42 was not covered by tests
}
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
}

// Return middleware handler
Expand All @@ -55,7 +51,7 @@
}

// Extract and verify key
key, err := extractor(c)
key, err := cfg.CustomKeyLookup(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}
Expand All @@ -80,8 +76,52 @@
return token
}

// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found
// Each element should be specified according to the format used in KeyLookup
func MultipleKeySourceLookup(keyLookups []string, authScheme string) (KeyauthKeyLookupFunc, error) {
subExtractors := map[string]KeyauthKeyLookupFunc{}
var err error
for _, keyLookup := range keyLookups {
subExtractors[keyLookup], err = SingleKeyLookup(keyLookup, authScheme)
if err != nil {
return nil, err
}
}
return func(c fiber.Ctx) (string, error) {
for keyLookup, subExtractor := range subExtractors {
res, err := subExtractor(c)
if err == nil && res != "" {
return res, nil
}
if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
return "", fmt.Errorf("[%s] %w", keyLookup, err)

Check warning on line 97 in middleware/keyauth/keyauth.go

View check run for this annotation

Codecov / codecov/patch

middleware/keyauth/keyauth.go#L97

Added line #L97 was not covered by tests
}
}
return "", ErrMissingOrMalformedAPIKey
}, nil
}

func SingleKeyLookup(keyLookup, authScheme string) (KeyauthKeyLookupFunc, error) {
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
parts := strings.Split(keyLookup, ":")
if len(parts) <= 1 {
return nil, fmt.Errorf("invalid keyLookup: %s", keyLookup)
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
}
extractor := keyFromHeader(parts[1], authScheme) // in the event of an invalid prefix, it is interpreted as header:
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])
}
return extractor, nil
}
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved

// keyFromHeader returns a function that extracts api key from the request header.
func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error) {
func keyFromHeader(header, authScheme string) KeyauthKeyLookupFunc {
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
return func(c fiber.Ctx) (string, error) {
auth := c.Get(header)
l := len(authScheme)
Expand All @@ -96,7 +136,7 @@
}

// keyFromQuery returns a function that extracts api key from the query string.
func keyFromQuery(param string) func(c fiber.Ctx) (string, error) {
func keyFromQuery(param string) KeyauthKeyLookupFunc {
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
return func(c fiber.Ctx) (string, error) {
key := fiber.Query[string](c, param)
if key == "" {
Expand All @@ -107,7 +147,7 @@
}

// keyFromForm returns a function that extracts api key from the form.
func keyFromForm(param string) func(c fiber.Ctx) (string, error) {
func keyFromForm(param string) KeyauthKeyLookupFunc {
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
return func(c fiber.Ctx) (string, error) {
key := c.FormValue(param)
if key == "" {
Expand All @@ -118,7 +158,7 @@
}

// keyFromParam returns a function that extracts api key from the url param string.
func keyFromParam(param string) func(c fiber.Ctx) (string, error) {
func keyFromParam(param string) KeyauthKeyLookupFunc {
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
return func(c fiber.Ctx) (string, error) {
key, err := url.PathUnescape(c.Params(param))
if err != nil {
Expand All @@ -129,7 +169,7 @@
}

// keyFromCookie returns a function that extracts api key from the named cookie.
func keyFromCookie(name string) func(c fiber.Ctx) (string, error) {
func keyFromCookie(name string) KeyauthKeyLookupFunc {
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
return func(c fiber.Ctx) (string, error) {
key := c.Cookies(name)
if key == "" {
Expand Down
136 changes: 136 additions & 0 deletions middleware/keyauth/keyauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,93 @@ func Test_AuthSources(t *testing.T) {
}
}

func TestPanicOnInvalidConfiguration(t *testing.T) {
require.Panics(t, func() {
authMiddleware := New(Config{
KeyLookup: "invalid",
})
// We shouldn't even make it this far, but these next two lines prevent authMiddleware from being an unused variable.
app := fiber.New()
defer func() { // testing panics, defer block to ensure cleanup
err := app.Shutdown()
require.NoError(t, err)
}()
app.Use(authMiddleware)
})
}
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved

func TestCustomKeyUtilityFunctionErrors(t *testing.T) {
const (
scheme = "Bearer"
)

// Invalid element while parsing
_, err := SingleKeyLookup("invalid", scheme)
require.Error(t, err)

_, err = MultipleKeySourceLookup([]string{"header:key", "invalid"}, scheme)
require.Error(t, err)
}
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved

func TestMultipleKeyLookup(t *testing.T) {
const (
desc = "auth with correct key"
success = "Success!"
scheme = "Bearer"
)

// setup the fiber endpoint
app := fiber.New()

customKeyLookup, err := MultipleKeySourceLookup([]string{"header:key", "cookie:key", "query:key"}, scheme)
require.NoError(t, err)

authMiddleware := New(Config{
CustomKeyLookup: customKeyLookup,
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
})
app.Use(authMiddleware)
app.Get("/foo", func(c fiber.Ctx) error {
return c.SendString(success)
})

// construct the test HTTP request
var req *http.Request
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
require.NoError(t, err)
q := req.URL.Query()
q.Add("key", CorrectKey)
req.URL.RawQuery = q.Encode()

res, err := app.Test(req, -1)

require.NoError(t, err)

// test the body of the request
body, err := io.ReadAll(res.Body)
require.Equal(t, 200, res.StatusCode, desc)
// body
require.NoError(t, err)
require.Equal(t, success, string(body), desc)

err = res.Body.Close()
require.NoError(t, err)

// construct a second request without proper key
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
require.NoError(t, err)
res, err = app.Test(req, -1)
require.NoError(t, err)
errBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(errBody))
}

func Test_MultipleKeyAuth(t *testing.T) {
// setup the fiber endpoint
app := fiber.New()
Expand Down Expand Up @@ -376,6 +463,55 @@ func Test_CustomNextFunc(t *testing.T) {
require.Equal(t, string(body), ErrMissingOrMalformedAPIKey.Error())
}

func Test_TokenFromContext_None(t *testing.T) {
app := fiber.New()
// Define a test handler that checks TokenFromContext
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(TokenFromContext(c))
})

// Verify a "" is sent back if nothing sets the token on the context.
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
// Send
res, err := app.Test(req)
require.NoError(t, err)

// Read the response body into a string
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Empty(t, body)
}

func Test_TokenFromContext(t *testing.T) {
app := fiber.New()
// Wire up keyauth middleware to set TokenFromContext now
app.Use(New(Config{
KeyLookup: "header:Authorization",
AuthScheme: "Basic",
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
}))
// Define a test handler that checks TokenFromContext
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(TokenFromContext(c))
})

req := httptest.NewRequest(fiber.MethodGet, "/", nil)
req.Header.Add("Authorization", "Basic "+CorrectKey)
// Send
res, err := app.Test(req)
require.NoError(t, err)

// Read the response body into a string
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, CorrectKey, string(body))
}

func Test_AuthSchemeToken(t *testing.T) {
app := fiber.New()

Expand Down
Loading