diff --git a/middleware/healthcheck/healthcheck.go b/middleware/healthcheck/healthcheck.go index 14ff33430c..c9d6a6476b 100644 --- a/middleware/healthcheck/healthcheck.go +++ b/middleware/healthcheck/healthcheck.go @@ -2,6 +2,7 @@ package healthcheck import ( "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" ) // HealthChecker defines a function to check liveness or readiness of the application @@ -40,11 +41,19 @@ func New(config ...Config) fiber.Handler { return c.Next() } - switch c.Path() { - case cfg.ReadinessEndpoint: - return isReadyHandler(c) - case cfg.LivenessEndpoint: - return isLiveHandler(c) + prefixCount := len(utils.TrimRight(c.Route().Path, '/')) + if len(c.Path()) >= prefixCount { + checkPath := c.Path()[prefixCount:] + checkPathTrimmed := checkPath + if !c.App().Config().StrictRouting { + checkPathTrimmed = utils.TrimRight(checkPath, '/') + } + switch { + case checkPath == cfg.ReadinessEndpoint || checkPathTrimmed == cfg.ReadinessEndpoint: + return isReadyHandler(c) + case checkPath == cfg.LivenessEndpoint || checkPathTrimmed == cfg.LivenessEndpoint: + return isLiveHandler(c) + } } return c.Next() diff --git a/middleware/healthcheck/healthcheck_test.go b/middleware/healthcheck/healthcheck_test.go index df0165f158..84fbb43da0 100644 --- a/middleware/healthcheck/healthcheck_test.go +++ b/middleware/healthcheck/healthcheck_test.go @@ -1,28 +1,120 @@ package healthcheck import ( + "fmt" "net/http/httptest" "testing" - "time" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/utils" "github.com/valyala/fasthttp" ) +func shouldGiveStatus(t *testing.T, app *fiber.App, path string, expectedStatus int) { + t.Helper() + req, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, expectedStatus, req.StatusCode, "path: "+path+" should match "+fmt.Sprint(expectedStatus)) +} + +func shouldGiveOK(t *testing.T, app *fiber.App, path string) { + t.Helper() + shouldGiveStatus(t, app, path, fiber.StatusOK) +} + +func shouldGiveNotFound(t *testing.T, app *fiber.App, path string) { + t.Helper() + shouldGiveStatus(t, app, path, fiber.StatusNotFound) +} + +func Test_HealthCheck_Strict_Routing_Default(t *testing.T) { + t.Parallel() + + app := fiber.New(fiber.Config{ + StrictRouting: true, + }) + + app.Use(New()) + + shouldGiveOK(t, app, "/readyz") + shouldGiveOK(t, app, "/livez") + shouldGiveNotFound(t, app, "/readyz/") + shouldGiveNotFound(t, app, "/livez/") + shouldGiveNotFound(t, app, "/notDefined/readyz") + shouldGiveNotFound(t, app, "/notDefined/livez") +} + +func Test_HealthCheck_Group_Default(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Group("/v1", New()) + v2Group := app.Group("/v2/") + customer := v2Group.Group("/customer/") + customer.Use(New()) + + v3Group := app.Group("/v3/") + v3Group.Group("/todos/", New(Config{ReadinessEndpoint: "/readyz/", LivenessEndpoint: "/livez/"})) + + shouldGiveOK(t, app, "/v1/readyz") + shouldGiveOK(t, app, "/v1/livez") + shouldGiveOK(t, app, "/v1/readyz/") + shouldGiveOK(t, app, "/v1/livez/") + shouldGiveOK(t, app, "/v2/customer/readyz") + shouldGiveOK(t, app, "/v2/customer/livez") + shouldGiveOK(t, app, "/v2/customer/readyz/") + shouldGiveOK(t, app, "/v2/customer/livez/") + shouldGiveNotFound(t, app, "/v3/todos/readyz") + shouldGiveNotFound(t, app, "/v3/todos/livez") + shouldGiveOK(t, app, "/v3/todos/readyz/") + shouldGiveOK(t, app, "/v3/todos/livez/") + shouldGiveNotFound(t, app, "/notDefined/readyz") + shouldGiveNotFound(t, app, "/notDefined/livez") + shouldGiveNotFound(t, app, "/notDefined/readyz/") + shouldGiveNotFound(t, app, "/notDefined/livez/") + + // strict routing + app = fiber.New(fiber.Config{ + StrictRouting: true, + }) + app.Group("/v1", New()) + v2Group = app.Group("/v2/") + customer = v2Group.Group("/customer/") + customer.Use(New()) + + v3Group = app.Group("/v3/") + v3Group.Group("/todos/", New(Config{ReadinessEndpoint: "/readyz/", LivenessEndpoint: "/livez/"})) + + shouldGiveOK(t, app, "/v1/readyz") + shouldGiveOK(t, app, "/v1/livez") + shouldGiveNotFound(t, app, "/v1/readyz/") + shouldGiveNotFound(t, app, "/v1/livez/") + shouldGiveOK(t, app, "/v2/customer/readyz") + shouldGiveOK(t, app, "/v2/customer/livez") + shouldGiveNotFound(t, app, "/v2/customer/readyz/") + shouldGiveNotFound(t, app, "/v2/customer/livez/") + shouldGiveNotFound(t, app, "/v3/todos/readyz") + shouldGiveNotFound(t, app, "/v3/todos/livez") + shouldGiveOK(t, app, "/v3/todos/readyz/") + shouldGiveOK(t, app, "/v3/todos/livez/") + shouldGiveNotFound(t, app, "/notDefined/readyz") + shouldGiveNotFound(t, app, "/notDefined/livez") + shouldGiveNotFound(t, app, "/notDefined/readyz/") + shouldGiveNotFound(t, app, "/notDefined/livez/") +} + func Test_HealthCheck_Default(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) - req, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/readyz", nil)) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, req.StatusCode) - - req, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/livez", nil)) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, req.StatusCode) + shouldGiveOK(t, app, "/readyz") + shouldGiveOK(t, app, "/livez") + shouldGiveOK(t, app, "/readyz/") + shouldGiveOK(t, app, "/livez/") + shouldGiveNotFound(t, app, "/notDefined/readyz") + shouldGiveNotFound(t, app, "/notDefined/livez") } func Test_HealthCheck_Custom(t *testing.T) { @@ -31,11 +123,6 @@ func Test_HealthCheck_Custom(t *testing.T) { app := fiber.New() c1 := make(chan struct{}, 1) - go func() { - time.Sleep(1 * time.Second) - c1 <- struct{}{} - }() - app.Use(New(Config{ LivenessProbe: func(c *fiber.Ctx) bool { return true @@ -53,12 +140,9 @@ func Test_HealthCheck_Custom(t *testing.T) { })) // Live should return 200 with GET request - req, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/live", nil)) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, req.StatusCode) - + shouldGiveOK(t, app, "/live") // Live should return 404 with POST request - req, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/live", nil)) + req, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/live", nil)) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, fiber.StatusNotFound, req.StatusCode) @@ -68,16 +152,53 @@ func Test_HealthCheck_Custom(t *testing.T) { utils.AssertEqual(t, fiber.StatusNotFound, req.StatusCode) // Ready should return 503 with GET request before the channel is closed - req, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/ready", nil)) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusServiceUnavailable, req.StatusCode) - - time.Sleep(1 * time.Second) + shouldGiveStatus(t, app, "/ready", fiber.StatusServiceUnavailable) // Ready should return 200 with GET request after the channel is closed - req, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/ready", nil)) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusOK, req.StatusCode) + c1 <- struct{}{} + shouldGiveOK(t, app, "/ready") +} + +func Test_HealthCheck_Custom_Nested(t *testing.T) { + t.Parallel() + + app := fiber.New() + + c1 := make(chan struct{}, 1) + + app.Use(New(Config{ + LivenessProbe: func(c *fiber.Ctx) bool { + return true + }, + LivenessEndpoint: "/probe/live", + ReadinessProbe: func(c *fiber.Ctx) bool { + select { + case <-c1: + return true + default: + return false + } + }, + ReadinessEndpoint: "/probe/ready", + })) + + shouldGiveOK(t, app, "/probe/live") + shouldGiveStatus(t, app, "/probe/ready", fiber.StatusServiceUnavailable) + shouldGiveOK(t, app, "/probe/live/") + shouldGiveStatus(t, app, "/probe/ready/", fiber.StatusServiceUnavailable) + shouldGiveNotFound(t, app, "/probe/livez") + shouldGiveNotFound(t, app, "/probe/readyz") + shouldGiveNotFound(t, app, "/probe/livez/") + shouldGiveNotFound(t, app, "/probe/readyz/") + shouldGiveNotFound(t, app, "/livez") + shouldGiveNotFound(t, app, "/readyz") + shouldGiveNotFound(t, app, "/readyz/") + shouldGiveNotFound(t, app, "/livez/") + + c1 <- struct{}{} + shouldGiveOK(t, app, "/probe/ready") + c1 <- struct{}{} + shouldGiveOK(t, app, "/probe/ready/") } func Test_HealthCheck_Next(t *testing.T) { @@ -91,9 +212,8 @@ func Test_HealthCheck_Next(t *testing.T) { }, })) - req, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/livez", nil)) - utils.AssertEqual(t, nil, err) - utils.AssertEqual(t, fiber.StatusNotFound, req.StatusCode) + shouldGiveNotFound(t, app, "/readyz") + shouldGiveNotFound(t, app, "/livez") } func Benchmark_HealthCheck(b *testing.B) {