Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🩹 Fix: CORS middleware should use the defined AllowedOriginsFunc config when AllowedOrigins is empty #2771

Merged
merged 1 commit into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,14 @@ func New(config ...Config) fiber.Handler {
if cfg.AllowMethods == "" {
cfg.AllowMethods = ConfigDefault.AllowMethods
}
if cfg.AllowOrigins == "" {
// When none of the AllowOrigins or AllowOriginsFunc config was defined, set the default AllowOrigins value with "*"
if cfg.AllowOrigins == "" && cfg.AllowOriginsFunc == nil {
cfg.AllowOrigins = ConfigDefault.AllowOrigins
}
}

// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if cfg.AllowOrigins != ConfigDefault.AllowOrigins && cfg.AllowOriginsFunc != nil {
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
}

Expand Down Expand Up @@ -145,7 +146,7 @@ func New(config ...Config) fiber.Handler {
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if (allowOrigin == "" || allowOrigin == ConfigDefault.AllowOrigins) && cfg.AllowOriginsFunc != nil {
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(origin) {
allowOrigin = origin
}
Expand Down
201 changes: 198 additions & 3 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,9 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
// Perform request
handler(ctx)

// Allow-Origin header should be "*" because http://google.com does not satisfy 'strings.Contains(origin, "example-2")'
// and AllowOrigins has not been set so the default "*" is used
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
// Allow-Origin header should be empty because http://google.com does not satisfy 'strings.Contains(origin, "example-2")'
// and AllowOrigins has not been set
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))

ctx.Request.Reset()
ctx.Response.Reset()
Expand All @@ -348,3 +348,198 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
// Allow-Origin header should be "http://example-2.com"
utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}

func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
testCases := []struct {
Name string
Config Config
RequestOrigin string
ResponseOrigin string
}{
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "http://bbb.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: func(origin string) bool {
return false
},
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
Config: Config{
AllowOrigins: "http://aaa.com",
AllowOriginsFunc: func(origin string) bool {
return false
},
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "",
},
{
Name: "AllowOriginsEmpty/AllowOriginsFuncUndefined/OriginAllowed",
Config: Config{
AllowOrigins: "",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "*",
},
{
Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed",
Config: Config{
AllowOrigins: "",
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
Config: Config{
AllowOrigins: "",
AllowOriginsFunc: func(origin string) bool {
return false
},
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "",
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
app := fiber.New()
app.Use("/", New(tc.Config))

handler := app.Handler()

ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)

handler(ctx)

utils.AssertEqual(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
})
}
}

// The fix for issue #2422
func Test_CORS_AllowCredetials(t *testing.T) {
testCases := []struct {
Name string
Config Config
RequestOrigin string
ResponseOrigin string
}{
{
Name: "AllowOriginsFuncDefined",
Config: Config{
AllowCredentials: true,
AllowOriginsFunc: func(origin string) bool {
return true
},
},
RequestOrigin: "http://aaa.com",
// The AllowOriginsFunc config was defined, should use the real origin of the function
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsFuncNotDefined",
Config: Config{
AllowCredentials: true,
},
RequestOrigin: "http://aaa.com",
// None of the AllowOrigins or AllowOriginsFunc config was defined, should use the default origin of "*"
// which will cause the CORS error in the client:
// The value of the 'Access-Control-Allow-Origin' header in the response must not be the wildcard '*'
// when the request's credentials mode is 'include'.
ResponseOrigin: "*",
},
{
Name: "AllowOriginsDefined",
Config: Config{
AllowCredentials: true,
AllowOrigins: "http://aaa.com",
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/UnallowedOrigin",
Config: Config{
AllowCredentials: true,
AllowOrigins: "http://aaa.com",
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "",
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
app := fiber.New()
app.Use("/", New(tc.Config))

handler := app.Handler()

ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)

handler(ctx)

if tc.Config.AllowCredentials {
utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
}
utils.AssertEqual(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
})
}
}
Loading