Skip to content

Commit

Permalink
Add AllowOriginVaryRequestFunc to correctly handle Vary header
Browse files Browse the repository at this point in the history
Fixes #157
  • Loading branch information
rs committed Sep 5, 2023
1 parent 6c6189c commit 6599721
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 31 deletions.
86 changes: 55 additions & 31 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,23 @@ type Options struct {
// Only one wildcard can be used per origin.
// Default value is ["*"]
AllowedOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It take the origin
// as argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowedOrigins is ignored.
// AllowOriginFunc is a custom function to validate the origin. It take the
// origin as argument and returns true if allowed or false otherwise. If
// this option is set, the content of `AllowedOrigins` is ignored.
AllowOriginFunc func(origin string) bool
// AllowOriginRequestFunc is a custom function to validate the origin. It takes the HTTP Request object and the origin as
// argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins`
// and `AllowOriginFunc` is ignored.
// AllowOriginRequestFunc is a custom function to validate the origin. It
// takes the HTTP Request object and the origin as argument and returns true
// if allowed or false otherwise. If headers are used take the decision,
// consider using AllowOriginVaryRequestFunc instead. If this option is set,
// the content of `AllowedOrigins`, `AllowOriginFunc` are ignored.
AllowOriginRequestFunc func(r *http.Request, origin string) bool
// AllowOriginVaryRequestFunc is a custom function to validate the origin.
// It takes the HTTP Request object and the origin as argument and returns
// true if allowed or false otherwise with a list of headers used to take
// that decision if any so they can be added to the Vary header. If this
// option is set, the content of `AllowedOrigins`, `AllowOriginFunc` and
// `AllowOriginRequestFunc` are ignored.
AllowOriginVaryRequestFunc func(r *http.Request, origin string) (bool, []string)
// AllowedMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (HEAD, GET and POST).
AllowedMethods []string
Expand Down Expand Up @@ -91,9 +100,7 @@ type Cors struct {
// List of allowed origins containing wildcards
allowedWOrigins []wildcard
// Optional origin validator function
allowOriginFunc func(origin string) bool
// Optional origin validator (with request) function
allowOriginRequestFunc func(r *http.Request, origin string) bool
allowOriginFunc func(r *http.Request, origin string) (bool, []string)
// Normalized list of allowed headers
allowedHeaders []string
// Normalized list of allowed methods
Expand All @@ -115,26 +122,36 @@ type Cors struct {
// New creates a new Cors handler with the provided options.
func New(options Options) *Cors {
c := &Cors{
exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
allowOriginFunc: options.AllowOriginFunc,
allowOriginRequestFunc: options.AllowOriginRequestFunc,
allowCredentials: options.AllowCredentials,
allowPrivateNetwork: options.AllowPrivateNetwork,
maxAge: options.MaxAge,
optionPassthrough: options.OptionsPassthrough,
Log: options.Logger,
exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
allowCredentials: options.AllowCredentials,
allowPrivateNetwork: options.AllowPrivateNetwork,
maxAge: options.MaxAge,
optionPassthrough: options.OptionsPassthrough,
Log: options.Logger,
}
if options.Debug && c.Log == nil {
c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
}

if options.AllowOriginVaryRequestFunc != nil {
c.allowOriginFunc = options.AllowOriginVaryRequestFunc
} else if options.AllowOriginRequestFunc != nil {
c.allowOriginFunc = func(r *http.Request, origin string) (bool, []string) {
return options.AllowOriginRequestFunc(r, origin), nil
}
} else if options.AllowOriginFunc != nil {
c.allowOriginFunc = func(r *http.Request, origin string) (bool, []string) {
return options.AllowOriginFunc(origin), nil
}
}

// Normalize options
// Note: for origins matching, the spec requires a case-sensitive matching.
// As it may error prone, we chose to ignore the spec here.

// Allowed Origins
if len(options.AllowedOrigins) == 0 {
if options.AllowOriginFunc == nil && options.AllowOriginRequestFunc == nil {
if c.allowOriginFunc == nil {
// Default is all origins
c.allowedOriginsAll = true
}
Expand Down Expand Up @@ -294,11 +311,16 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
headers.Add("Vary", "Access-Control-Request-Private-Network")
}

allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin)
if len(additionalVaryHeaders) > 0 {
headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", "))
}

if origin == "" {
c.logf(" Preflight aborted: empty origin")
return
}
if !c.isOriginAllowed(r, origin) {
if !allowed {
c.logf(" Preflight aborted: origin '%s' not allowed", origin)
return
}
Expand Down Expand Up @@ -349,13 +371,18 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
headers := w.Header()
origin := r.Header.Get("Origin")

allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin)

// Always set Vary, see https://github.com/rs/cors/issues/10
headers.Add("Vary", "Origin")
if len(additionalVaryHeaders) > 0 {
headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", "))
}
if origin == "" {
c.logf(" Actual request no headers added: missing origin")
return
}
if !c.isOriginAllowed(r, origin) {
if !allowed {
c.logf(" Actual request no headers added: origin '%s' not allowed", origin)
return
}
Expand All @@ -366,7 +393,6 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
// We think it's a nice feature to be able to have control on those methods though.
if !c.isMethodAllowed(r.Method) {
c.logf(" Actual request no headers added: method '%s' not allowed", r.Method)

return
}
if c.allowedOriginsAll {
Expand All @@ -393,33 +419,31 @@ func (c *Cors) logf(format string, a ...interface{}) {
// check the Origin of a request. No origin at all is also allowed.
func (c *Cors) OriginAllowed(r *http.Request) bool {
origin := r.Header.Get("Origin")
return c.isOriginAllowed(r, origin)
allowed, _ := c.isOriginAllowed(r, origin)
return allowed
}

// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
// on the endpoint
func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool {
if c.allowOriginRequestFunc != nil {
return c.allowOriginRequestFunc(r, origin)
}
func (c *Cors) isOriginAllowed(r *http.Request, origin string) (allowed bool, varyHeaders []string) {
if c.allowOriginFunc != nil {
return c.allowOriginFunc(origin)
return c.allowOriginFunc(r, origin)
}
if c.allowedOriginsAll {
return true
return true, nil
}
origin = strings.ToLower(origin)
for _, o := range c.allowedOrigins {
if o == origin {
return true
return true, nil
}
}
for _, w := range c.allowedWOrigins {
if w.match(origin) {
return true
return true, nil
}
}
return false
return false, nil
}

// isMethodAllowed checks if a given method can be used as part of a cross-domain request
Expand Down
18 changes: 18 additions & 0 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,24 @@ func TestSpec(t *testing.T) {
},
true,
},
{
"AllowOriginVaryRequestFuncMatch",
Options{
AllowOriginVaryRequestFunc: func(r *http.Request, o string) (bool, []string) {
return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret", []string{"Authorization"}
},
},
"GET",
map[string]string{
"Origin": "http://foobar.com",
"Authorization": "secret",
},
map[string]string{
"Vary": "Origin, Authorization",
"Access-Control-Allow-Origin": "http://foobar.com",
},
true,
},
{
"AllowOriginRequestFuncNotMatch",
Options{
Expand Down

0 comments on commit 6599721

Please sign in to comment.