From bcd8bdefc1827aadbc099458ed9f044cf9f36b15 Mon Sep 17 00:00:00 2001 From: Leon Klingele Date: Sat, 28 Jan 2023 21:39:27 +0100 Subject: [PATCH] middleware/earlydata: backport to v2 Backport of https://github.com/gofiber/fiber/pull/2270 to v2. --- middleware/earlydata/README.md | 14 ++--- middleware/earlydata/config.go | 14 +++-- middleware/earlydata/earlydata.go | 6 +- middleware/earlydata/earlydata_test.go | 82 ++++++++++++++------------ 4 files changed, 61 insertions(+), 55 deletions(-) diff --git a/middleware/earlydata/README.md b/middleware/earlydata/README.md index d12d6946539..862e78b496f 100644 --- a/middleware/earlydata/README.md +++ b/middleware/earlydata/README.md @@ -36,8 +36,8 @@ First import the middleware from Fiber, ```go import ( - "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/middleware/earlydata" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/earlydata" ) ``` @@ -65,17 +65,17 @@ type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil - Next func(c fiber.Ctx) bool + Next func(c *fiber.Ctx) bool // IsEarlyData returns whether the request is an early-data request. // // Optional. Default: a function which checks if the "Early-Data" request header equals "1". - IsEarlyData func(c fiber.Ctx) bool + IsEarlyData func(c *fiber.Ctx) bool // AllowEarlyData returns whether the early-data request should be allowed or rejected. // // Optional. Default: a function which rejects the request on unsafe and allows the request on safe HTTP request methods. - AllowEarlyData func(c fiber.Ctx) bool + AllowEarlyData func(c *fiber.Ctx) bool // Error is returned in case an early-data request is rejected. // @@ -88,11 +88,11 @@ type Config struct { ```go var ConfigDefault = Config{ - IsEarlyData: func(c fiber.Ctx) bool { + IsEarlyData: func(c *fiber.Ctx) bool { return c.Get("Early-Data") == "1" }, - AllowEarlyData: func(c fiber.Ctx) bool { + AllowEarlyData: func(c *fiber.Ctx) bool { return fiber.IsMethodSafe(c.Method()) }, diff --git a/middleware/earlydata/config.go b/middleware/earlydata/config.go index ced705dd570..9ec223a8b75 100644 --- a/middleware/earlydata/config.go +++ b/middleware/earlydata/config.go @@ -1,7 +1,7 @@ package earlydata import ( - "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v2" ) const ( @@ -14,17 +14,17 @@ type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil - Next func(c fiber.Ctx) bool + Next func(c *fiber.Ctx) bool // IsEarlyData returns whether the request is an early-data request. // // Optional. Default: a function which checks if the "Early-Data" request header equals "1". - IsEarlyData func(c fiber.Ctx) bool + IsEarlyData func(c *fiber.Ctx) bool // AllowEarlyData returns whether the early-data request should be allowed or rejected. // // Optional. Default: a function which rejects the request on unsafe and allows the request on safe HTTP request methods. - AllowEarlyData func(c fiber.Ctx) bool + AllowEarlyData func(c *fiber.Ctx) bool // Error is returned in case an early-data request is rejected. // @@ -33,12 +33,14 @@ type Config struct { } // ConfigDefault is the default config +// +//nolint:gochecknoglobals // Using a global var is fine here var ConfigDefault = Config{ - IsEarlyData: func(c fiber.Ctx) bool { + IsEarlyData: func(c *fiber.Ctx) bool { return c.Get(DefaultHeaderName) == DefaultHeaderTrueValue }, - AllowEarlyData: func(c fiber.Ctx) bool { + AllowEarlyData: func(c *fiber.Ctx) bool { return fiber.IsMethodSafe(c.Method()) }, diff --git a/middleware/earlydata/earlydata.go b/middleware/earlydata/earlydata.go index 2b53341f9c0..638db3c6fb9 100644 --- a/middleware/earlydata/earlydata.go +++ b/middleware/earlydata/earlydata.go @@ -1,14 +1,14 @@ package earlydata import ( - "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v2" ) const ( localsKeyAllowed = "earlydata_allowed" ) -func IsEarly(c fiber.Ctx) bool { +func IsEarly(c *fiber.Ctx) bool { return c.Locals(localsKeyAllowed) != nil } @@ -19,7 +19,7 @@ func New(config ...Config) fiber.Handler { cfg := configDefault(config...) // Return new handler - return func(c fiber.Ctx) error { + return func(c *fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() diff --git a/middleware/earlydata/earlydata_test.go b/middleware/earlydata/earlydata_test.go index aee22c61f2b..7d650d539ea 100644 --- a/middleware/earlydata/earlydata_test.go +++ b/middleware/earlydata/earlydata_test.go @@ -1,3 +1,4 @@ +//nolint:bodyclose // Much easier to just ignore memory leaks in tests package earlydata_test import ( @@ -7,9 +8,9 @@ import ( "net/http/httptest" "testing" - "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/middleware/earlydata" - "github.com/stretchr/testify/require" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/earlydata" + "github.com/gofiber/fiber/v2/utils" ) const ( @@ -33,12 +34,11 @@ func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App { // Middleware to test IsEarly func const localsKeyTestValid = "earlydata_testvalid" - app.Use(func(c fiber.Ctx) error { + app.Use(func(c *fiber.Ctx) error { isEarly := earlydata.IsEarly(c) switch h := c.Get(headerName); h { - case "", - headerValOff: + case "", headerValOff: if isEarly { return errors.New("is early-data even though it's not") } @@ -64,16 +64,20 @@ func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App { return c.Next() }) - app.Add([]string{ - fiber.MethodGet, - fiber.MethodPost, - }, "/", func(c fiber.Ctx) error { - if !c.Locals(localsKeyTestValid).(bool) { - return errors.New("handler called even though validation failed") - } + { + { + handler := func(c *fiber.Ctx) error { + if !c.Locals(localsKeyTestValid).(bool) { //nolint:forcetypeassert,errcheck // We store nothing else in the pool + return errors.New("handler called even though validation failed") + } - return nil - }) + return nil + } + + app.Get("/", handler) + app.Post("/", handler) + } + } return app } @@ -89,36 +93,36 @@ func Test_EarlyData(t *testing.T) { req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) resp, err := app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) req.Header.Set(headerName, headerValOff) resp, err = app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) req.Header.Set(headerName, headerValOn) resp, err = app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) } { req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody) resp, err := app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) req.Header.Set(headerName, headerValOff) resp, err = app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) req.Header.Set(headerName, headerValOn) resp, err = app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusTooEarly, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) } } @@ -129,36 +133,36 @@ func Test_EarlyData(t *testing.T) { req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) resp, err := app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusTooEarly, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) req.Header.Set(headerName, headerValOff) resp, err = app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusTooEarly, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) req.Header.Set(headerName, headerValOn) resp, err = app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusTooEarly, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) } { req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody) resp, err := app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusTooEarly, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) req.Header.Set(headerName, headerValOff) resp, err = app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusTooEarly, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) req.Header.Set(headerName, headerValOn) resp, err = app.Test(req) - require.NoError(t, err) - require.Equal(t, fiber.StatusTooEarly, resp.StatusCode) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode) } }