Skip to content

Commit

Permalink
feat: allow check trusted origins
Browse files Browse the repository at this point in the history
Signed-off-by: rogerogers <rogers@rogerogers.com>
  • Loading branch information
rogerogers committed Aug 31, 2024
1 parent cb5eab0 commit 580d548
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 7 deletions.
23 changes: 22 additions & 1 deletion csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"io"
"math/rand"
"net/textproto"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -92,12 +93,19 @@ func New(opts ...Option) app.HandlerFunc {
return
}

if cfg.checkTrustedOrigins {
if !matchTrustedOrigins(c, cfg.TrustedOrigins) {
c.Error(errBadReferer)
cfg.ErrorFunc(ctx, c)
return
}
}

if tokenize(cfg.Secret, salt) != token {
c.Error(errInvalidToken)
cfg.ErrorFunc(ctx, c)
return
}

c.Next(ctx)
}
}
Expand Down Expand Up @@ -166,3 +174,16 @@ func randStr(n int) string {
}
return sb.String()
}

func matchTrustedOrigins(c *app.RequestContext, trustedOrigins []string) bool {
match := false
if referer, err := url.Parse(string(c.GetHeader("Referer"))); err == nil && referer.String() != "" {
for _, trustedOrigin := range trustedOrigins {
if referer.Scheme+"://"+referer.Host == trustedOrigin {
match = true
break
}
}
}
return match
}
40 changes: 34 additions & 6 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ const (
)

var (
ErrMissingCookie = errors.New("[CSRF] missing csrf token in cookie")
errMissingHeader = errors.New("[CSRF] missing csrf token in header")
errMissingQuery = errors.New("[CSRF] missing csrf token in query")
errMissingParam = errors.New("[CSRF] missing csrf token in param")
errMissingForm = errors.New("[CSRF] missing csrf token in form")
errMissingSalt = errors.New("[CSRF] missing salt")
errInvalidToken = errors.New("[CSRF] invalid token")
errBadReferer = errors.New("[CSRF] invalid referer")
)

type CsrfNextHandler func(ctx context.Context, c *app.RequestContext) bool
Expand Down Expand Up @@ -103,6 +105,12 @@ type Options struct {
//
// Optional. Default will create an Extractor based on KeyLookup.
Extractor CsrfExtractorHandler

// Optional. Default: []
TrustedOrigins []string

// Optional. Default: false
checkTrustedOrigins bool
}

func (o *Options) Apply(opts []Option) {
Expand All @@ -123,11 +131,13 @@ var OptionsDefault = Options{

func NewOptions(opts ...Option) *Options {
options := &Options{
Secret: OptionsDefault.Secret,
IgnoreMethods: OptionsDefault.IgnoreMethods,
Next: OptionsDefault.Next,
KeyLookup: OptionsDefault.KeyLookup,
ErrorFunc: OptionsDefault.ErrorFunc,
Secret: OptionsDefault.Secret,
IgnoreMethods: OptionsDefault.IgnoreMethods,
Next: OptionsDefault.Next,
KeyLookup: OptionsDefault.KeyLookup,
ErrorFunc: OptionsDefault.ErrorFunc,
TrustedOrigins: []string{},
checkTrustedOrigins: false,
}
options.Apply(opts)
return options
Expand Down Expand Up @@ -179,11 +189,29 @@ func WithErrorFunc(f app.HandlerFunc) Option {
}
}

// WithExtractor sets extractor.
// WithExtractor sets Extractor.
func WithExtractor(f CsrfExtractorHandler) Option {
return Option{
F: func(o *Options) {
o.Extractor = f
},
}
}

// WithTrustedOrigins sets TrustedOrigins
func WithTrustedOrigins(t []string) Option {
return Option{
F: func(o *Options) {
o.TrustedOrigins = t
},
}
}

// WithCheckTrustedOrigins sets checkTrustedOrigins
func WithCheckTrustedOrigins() Option {
return Option{
F: func(o *Options) {
o.checkTrustedOrigins = true
},
}
}

0 comments on commit 580d548

Please sign in to comment.