Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance token checking #3842

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/api/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func PostHook(c *gin.Context) {
//

// get the token and verify the hook is authorized
parsedToken, err := token.ParseRequest(c.Request, func(_ *token.Token) (string, error) {
parsedToken, err := token.ParseRequest([]token.Type{token.HookToken}, c.Request, func(_ *token.Token) (string, error) {
return repo.Hash, nil
})
if err != nil {
Expand Down
13 changes: 5 additions & 8 deletions server/router/middleware/session/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,14 @@ func AuthorizeAgent(c *gin.Context) {
return
}

parsed, err := token.ParseRequest(c.Request, func(_ *token.Token) (string, error) {
_, err := token.ParseRequest([]token.Type{token.AgentToken}, c.Request, func(_ *token.Token) (string, error) {
return secret, nil
})
switch {
case err != nil:
if err != nil {
c.String(http.StatusInternalServerError, "invalid or empty token. %s", err)
c.Abort()
case parsed.Kind != token.AgentToken:
c.String(http.StatusForbidden, "invalid token. please use an agent token")
c.Abort()
default:
c.Next()
return
}

c.Next()
}
4 changes: 2 additions & 2 deletions server/router/middleware/session/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func SetUser() gin.HandlerFunc {
return func(c *gin.Context) {
var user *model.User

t, err := token.ParseRequest(c.Request, func(t *token.Token) (string, error) {
t, err := token.ParseRequest([]token.Type{token.UserToken, token.SessToken}, c.Request, func(t *token.Token) (string, error) {
var err error
userID, err := strconv.ParseInt(t.Get("user-id"), 10, 64)
if err != nil {
Expand All @@ -58,7 +58,7 @@ func SetUser() gin.HandlerFunc {
// if this is a session token (ie not the API token)
// this means the user is accessing with a web browser,
// so we should implement CSRF protection measures.
if t.Kind == token.SessToken {
if t.Type == token.SessToken {
err = token.CheckCsrf(c.Request, func(_ *token.Token) (string, error) {
return user.Hash, nil
})
Expand Down
56 changes: 36 additions & 20 deletions shared/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,52 @@ import (

type SecretFunc func(*Token) (string, error)

type Type string

const (
UserToken = "user"
SessToken = "sess"
HookToken = "hook"
CsrfToken = "csrf"
AgentToken = "agent"
UserToken Type = "user" // user token (exp cli)
SessToken Type = "sess" // session token (ui token requires csrf check)
HookToken Type = "hook" // repo hook token
CsrfToken Type = "csrf"
AgentToken Type = "agent"
)

// SignerAlgo id default algorithm used to sign JWT tokens.
const SignerAlgo = "HS256"

type Token struct {
Kind string
Type Type
claims jwt.MapClaims
}

func parse(raw string, fn SecretFunc) (*Token, error) {
func Parse(allowedTypes []Type, raw string, fn SecretFunc) (*Token, error) {
token := &Token{
claims: jwt.MapClaims{},
}
parsed, err := jwt.Parse(raw, keyFunc(token, fn))
if err != nil {
return nil, err
} else if !parsed.Valid {
}
if !parsed.Valid {
return nil, jwt.ErrTokenUnverifiable
}

hasAllowedType := false
for _, k := range allowedTypes {
if k == token.Type {
hasAllowedType = true
break
}
}

if !hasAllowedType {
return nil, jwt.ErrInvalidType
}

return token, nil
}

func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) {
func ParseRequest(allowedTypes []Type, r *http.Request, fn SecretFunc) (*Token, error) {
// first we attempt to get the token from the
// authorization header.
token := r.Header.Get("Authorization")
Expand All @@ -63,19 +79,19 @@ func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) {
if _, err := fmt.Sscanf(token, "Bearer %s", &bearer); err != nil {
return nil, err
}
return parse(bearer, fn)
return Parse(allowedTypes, bearer, fn)
}

token = r.Header.Get("X-Gitlab-Token")
if len(token) != 0 {
return parse(token, fn)
return Parse(allowedTypes, token, fn)
}

// then we attempt to get the token from the
// access_token url query parameter
token = r.FormValue("access_token")
if len(token) != 0 {
return parse(token, fn)
return Parse(allowedTypes, token, fn)
}

// and finally we attempt to get the token from
Expand All @@ -84,7 +100,7 @@ func ParseRequest(r *http.Request, fn SecretFunc) (*Token, error) {
if err != nil {
return nil, err
}
return parse(cookie.Value, fn)
return Parse(allowedTypes, cookie.Value, fn)
}

func CheckCsrf(r *http.Request, fn SecretFunc) error {
Expand All @@ -97,12 +113,12 @@ func CheckCsrf(r *http.Request, fn SecretFunc) error {

// parse the raw CSRF token value and validate
raw := r.Header.Get("X-CSRF-TOKEN")
_, err := parse(raw, fn)
_, err := Parse([]Type{CsrfToken}, raw, fn)
return err
}

func New(kind string) *Token {
return &Token{Kind: kind, claims: jwt.MapClaims{}}
func New(tokenType Type) *Token {
return &Token{Type: tokenType, claims: jwt.MapClaims{}}
}

// Sign signs the token using the given secret hash
Expand All @@ -124,7 +140,7 @@ func (t *Token) SignExpires(secret string, exp int64) (string, error) {
claims[k] = v
}

claims["type"] = t.Kind
claims["type"] = t.Type
if exp > 0 {
claims["exp"] = float64(exp)
}
Expand Down Expand Up @@ -157,12 +173,12 @@ func keyFunc(token *Token, fn SecretFunc) jwt.Keyfunc {
return nil, jwt.ErrSignatureInvalid
}

// extract the token kind and cast to the expected type
kind, ok := claims["type"]
// extract the token type and cast to the expected type
tokenType, ok := claims["type"].(string)
if !ok {
return nil, jwt.ErrInvalidType
}
token.Kind, _ = kind.(string)
token.Type = Type(tokenType)

// copy custom claims
for k, v := range claims {
Expand Down
62 changes: 62 additions & 0 deletions shared/token/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package token_test

import (
"testing"

"github.com/franela/goblin"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"

"go.woodpecker-ci.org/woodpecker/v2/shared/token"
)

func TestToken(t *testing.T) {
gin.SetMode(gin.TestMode)

g := goblin.Goblin(t)
g.Describe("Token", func() {
jwtSecret := "secret-to-sign-the-token"

g.It("should parse a valid token", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)

parsed, err := token.Parse([]token.Type{token.UserToken}, signedToken, func(_ *token.Token) (string, error) {
return jwtSecret, nil
})

assert.NoError(g, err)
assert.NotNil(g, parsed)
assert.Equal(g, "1", parsed.Get("user-id"))
})

g.It("should fail to parse a token with a wrong type", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)

_, err = token.Parse([]token.Type{token.AgentToken}, signedToken, func(_ *token.Token) (string, error) {
return jwtSecret, nil
})

assert.ErrorIs(g, err, jwt.ErrInvalidType)
})

g.It("should fail to parse a token with a wrong secret", func() {
_token := token.New(token.UserToken)
_token.Set("user-id", "1")
signedToken, err := _token.Sign(jwtSecret)
assert.NoError(g, err)

_, err = token.Parse([]token.Type{token.UserToken}, signedToken, func(_ *token.Token) (string, error) {
return "this-is-a-wrong-secret", nil
})

assert.ErrorIs(g, err, jwt.ErrSignatureInvalid)
})
})
}