diff --git a/middleware/jwt.go b/middleware/jwt.go index da00ea56b..fab4d6fdf 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -29,15 +29,19 @@ type ( // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. ErrorHandlerWithContext JWTErrorHandlerWithContext - // Signing key to validate token. Used as fallback if SigningKeys has length 0. - // Required. This or SigningKeys. + // Signing key to validate token. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither user-defined KeyFunc nor SigningKeys is provided. SigningKey interface{} // Map of signing keys to validate token with kid field usage. - // Required. This or SigningKey. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither user-defined KeyFunc nor SigningKey is provided. SigningKeys map[string]interface{} - // Signing method, used to check token signing method. + // Signing method used to check the token's signing algorithm. // Optional. Default value HS256. SigningMethod string @@ -64,7 +68,16 @@ type ( // Optional. Default value "Bearer". AuthScheme string - keyFunc jwt.Keyfunc + // KeyFunc defines a user-defined function that supplies the public key for a token validation. + // The function shall take care of verifying the signing algorithm and selecting the proper key. + // A user-defined KeyFunc can be useful if tokens are issued by an external party. + // + // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither SigningKeys nor SigningKey is provided. + // Default to an internal implementation verifying the signing algorithm and selecting the proper key. + KeyFunc jwt.Keyfunc } // JWTSuccessHandler defines a function which is executed for a valid token. @@ -99,6 +112,7 @@ var ( TokenLookup: "header:" + echo.HeaderAuthorization, AuthScheme: "Bearer", Claims: jwt.MapClaims{}, + KeyFunc: nil, } ) @@ -123,7 +137,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } - if config.SigningKey == nil && len(config.SigningKeys) == 0 { + if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil { panic("echo: jwt middleware requires signing key") } if config.SigningMethod == "" { @@ -141,21 +155,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.AuthScheme == "" { config.AuthScheme = DefaultJWTConfig.AuthScheme } - config.keyFunc = func(t *jwt.Token) (interface{}, error) { - // Check the signing method - if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - if len(config.SigningKeys) > 0 { - if kid, ok := t.Header["kid"].(string); ok { - if key, ok := config.SigningKeys[kid]; ok { - return key, nil - } - } - return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) - } - - return config.SigningKey, nil + if config.KeyFunc == nil { + config.KeyFunc = config.defaultKeyFunc } // Initialize @@ -196,11 +197,11 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { token := new(jwt.Token) // Issue #647, #656 if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.keyFunc) + token, err = jwt.Parse(auth, config.KeyFunc) } else { t := reflect.ValueOf(config.Claims).Type().Elem() claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc) + token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) } if err == nil && token.Valid { // Store user information from token into context. @@ -225,6 +226,24 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } } +// defaultKeyFunc returns a signing key of the given token. +func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { + // Check the signing method + if t.Method.Alg() != config.SigningMethod { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + if len(config.SigningKeys) > 0 { + if kid, ok := t.Header["kid"].(string); ok { + if key, ok := config.SigningKeys[kid]; ok { + return key, nil + } + } + return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) + } + + return config.SigningKey, nil +} + // jwtFromHeader returns a `jwtExtractor` that extracts token from the request header. func jwtFromHeader(header string, authScheme string) jwtExtractor { return func(c echo.Context) (string, error) { diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 205721aec..37de31fcc 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "net/http/httptest" "net/url" @@ -220,6 +221,35 @@ func TestJWT(t *testing.T) { expErrCode: http.StatusBadRequest, info: "Empty form field", }, + { + hdrAuth: validAuth, + config: JWTConfig{ + KeyFunc: func(*jwt.Token) (interface{}, error) { + return validKey, nil + }, + }, + info: "Valid JWT with a valid key using a user-defined KeyFunc", + }, + { + hdrAuth: validAuth, + config: JWTConfig{ + KeyFunc: func(*jwt.Token) (interface{}, error) { + return invalidKey, nil + }, + }, + expErrCode: http.StatusUnauthorized, + info: "Valid JWT with an invalid key using a user-defined KeyFunc", + }, + { + hdrAuth: validAuth, + config: JWTConfig{ + KeyFunc: func(*jwt.Token) (interface{}, error) { + return nil, errors.New("faulty KeyFunc") + }, + }, + expErrCode: http.StatusUnauthorized, + info: "Token verification does not pass using a user-defined KeyFunc", + }, } { if tc.reqURL == "" { tc.reqURL = "/"