Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance the auth hook func to support external JWT #811

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions bootstrap/container/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,15 @@ func ScheduleActionRecordClientFrom(get di.Get) interfaces.ScheduleActionRecordC

return get(ScheduleActionRecordClientName).(interfaces.ScheduleActionRecordClient)
}

// SecurityProxyAuthClientName contains the name of the AuthClient's implementation in the DIC.
var SecurityProxyAuthClientName = di.TypeInstanceToName((*interfaces.AuthClient)(nil))

// SecurityProxyAuthClientFrom helper function queries the DIC and returns the AuthClient's implementation.
func SecurityProxyAuthClientFrom(get di.Get) interfaces.AuthClient {
if get(SecurityProxyAuthClientName) == nil {
return nil
}

return get(SecurityProxyAuthClientName).(interfaces.AuthClient)
}
3 changes: 1 addition & 2 deletions bootstrap/controller/commonapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ type config struct {

func NewCommonController(dic *di.Container, r *echo.Echo, serviceName string, serviceVersion string) *CommonController {
lc := container.LoggingClientFrom(dic.Get)
secretProvider := container.SecretProviderExtFrom(dic.Get)
authenticationHook := handlers.AutoConfigAuthenticationFunc(secretProvider, lc)
authenticationHook := handlers.AutoConfigAuthenticationFunc(dic)
configuration := container.ConfigurationFrom(dic.Get)
c := CommonController{
dic: dic,
Expand Down
8 changes: 3 additions & 5 deletions bootstrap/handlers/auth_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ import (
"os"
"strconv"

"github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger"

"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/secret"
"github.com/edgexfoundry/go-mod-bootstrap/v4/di"

"github.com/labstack/echo/v4"
)
Expand All @@ -44,12 +42,12 @@ func NilAuthenticationHandlerFunc() echo.MiddlewareFunc {
// to disable JWT validation. This might be wanted for an EdgeX
// adopter that wanted to only validate JWT's at the proxy layer,
// or as an escape hatch for a caller that cannot authenticate.
func AutoConfigAuthenticationFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc {
func AutoConfigAuthenticationFunc(dic *di.Container) echo.MiddlewareFunc {
// Golang standard library treats an error as false
disableJWTValidation, _ := strconv.ParseBool(os.Getenv("EDGEX_DISABLE_JWT_VALIDATION"))
authenticationHook := NilAuthenticationHandlerFunc()
if secret.IsSecurityEnabled() && !disableJWTValidation {
authenticationHook = SecretStoreAuthenticationHandlerFunc(secretProvider, lc)
authenticationHook = AuthenticationHandlerFunc(dic)
}
return authenticationHook
}
85 changes: 62 additions & 23 deletions bootstrap/handlers/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,43 @@ import (
"net/http"
"strings"

"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/handlers/headers"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/zerotrust"
"github.com/edgexfoundry/go-mod-bootstrap/v4/di"
"github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger"
dtoCommon "github.com/edgexfoundry/go-mod-core-contracts/v4/dtos/common"

"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/openziti/sdk-golang/ziti/edge"
)

// SecretStoreAuthenticationHandlerFunc prefixes an existing HandlerFunc
// with a OpenBao-based JWT authentication check. Usage:
// openBaoIssuer defines the issuer if JWT was issued from OpenBao
const openBaoIssuer = "/v1/identity/oidc"
cloudxxx8 marked this conversation as resolved.
Show resolved Hide resolved

// AuthenticationHandlerFunc prefixes an existing HandlerFunc,
// performing authentication checks based on OpenBao-issued JWTs or external JWTs by checking the Authorization header. Usage:
//
// authenticationHook := handlers.NilAuthenticationHandlerFunc()
//
// authenticationHook := handlers.NilAuthenticationHandlerFunc()
// if secret.IsSecurityEnabled() {
// lc := container.LoggingClientFrom(dic.Get)
// secretProvider := container.SecretProviderFrom(dic.Get)
// authenticationHook = handlers.SecretStoreAuthenticationHandlerFunc(secretProvider, lc)
// }
// For optionally-authenticated requests
// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet)
// if secret.IsSecurityEnabled() {
// authenticationHook = handlers.AuthenticationHandlerFunc(dic)
// }
// For optionally-authenticated requests
// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet)
//
// For unauthenticated requests
// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet)
// For unauthenticated requests
// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet)
//
// For typical usage, it is preferred to use AutoConfigAuthenticationFunc which
// will automatically select between a real and a fake JWT validation handler.
func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc {
func AuthenticationHandlerFunc(dic *di.Container) echo.MiddlewareFunc {
return func(inner echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
lc := container.LoggingClientFrom(dic.Get)
secretProvider := container.SecretProviderExtFrom(dic.Get)
r := c.Request()
w := c.Response()
authHeader := r.Header.Get("Authorization")
Expand All @@ -70,20 +79,29 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid
authParts := strings.Split(authHeader, " ")
if len(authParts) >= 2 && strings.EqualFold(authParts[0], "Bearer") {
token := authParts[1]
validToken, err := secretProvider.IsJWTValid(token)
if err != nil {
lc.Errorf("Error checking JWT validity: %v", err)
// set Response.Committed to true in order to rewrite the status code

parser := jwt.NewParser()
parsedToken, _, jwtErr := parser.ParseUnverified(token, &jwt.MapClaims{})
if jwtErr != nil {
w.Committed = false
return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
} else if !validToken {
lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path)
// set Response.Committed to true in order to rewrite the status code
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
issuer, jwtErr := parsedToken.Claims.GetIssuer()
if jwtErr != nil {
w.Committed = false
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
lc.Debugf("Request to '%s' authorized", r.URL.Path)
return inner(c)

if issuer == openBaoIssuer {
return SecretStoreAuthenticationHandlerFunc(secretProvider, lc, token, c)
} else {
// Verify the JWT by invoking security-proxy-auth http client
err := headers.VerifyJWT(token, issuer, parsedToken.Method.Alg(), dic, r.Context())
if err != nil {
errResp := dtoCommon.NewBaseResponse("", err.Error(), err.Code())
return c.JSON(err.Code(), errResp)
}
}
}
err := fmt.Errorf("unable to parse JWT for call to '%s'; unauthorized", r.URL.Path)
lc.Errorf("%v", err)
Expand All @@ -93,3 +111,24 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid
}
}
}

// SecretStoreAuthenticationHandlerFunc verifies the JWT with a OpenBao-based JWT authentication check
func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient, token string, c echo.Context) error {
cloudxxx8 marked this conversation as resolved.
Show resolved Hide resolved
r := c.Request()
w := c.Response()

validToken, err := secretProvider.IsJWTValid(token)
if err != nil {
lc.Errorf("Error checking JWT validity by the secret provider: %v ", err)
// set Response.Committed to true in order to rewrite the status code
w.Committed = false
return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
} else if !validToken {
lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path)
// set Response.Committed to true in order to rewrite the status code
w.Committed = false
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
lc.Debugf("Request to '%s' authorized", r.URL.Path)
return nil
}
84 changes: 60 additions & 24 deletions bootstrap/handlers/auth_middleware_no_ziti.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,38 @@ import (
"net/http"
"strings"

"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/handlers/headers"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v4/di"
"github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger"
dtoCommon "github.com/edgexfoundry/go-mod-core-contracts/v4/dtos/common"

"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/interfaces"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
)

// SecretStoreAuthenticationHandlerFunc prefixes an existing HandlerFunc
// with a OpenBao-based JWT authentication check. Usage:
// AuthenticationHandlerFunc prefixes an existing HandlerFunc,
// performing authentication checks based on OpenBao-issued JWTs or external JWTs by checking the Authorization header. Usage:
//
// authenticationHook := handlers.NilAuthenticationHandlerFunc()
//
// authenticationHook := handlers.NilAuthenticationHandlerFunc()
// if secret.IsSecurityEnabled() {
// lc := container.LoggingClientFrom(dic.Get)
// secretProvider := container.SecretProviderFrom(dic.Get)
// authenticationHook = handlers.SecretStoreAuthenticationHandlerFunc(secretProvider, lc)
// }
// For optionally-authenticated requests
// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet)
// if secret.IsSecurityEnabled() {
// authenticationHook = handlers.AuthenticationHandlerFunc(dic)
// }
// For optionally-authenticated requests
// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet)
//
// For unauthenticated requests
// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet)
// For unauthenticated requests
// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet)
//
// For typical usage, it is preferred to use AutoConfigAuthenticationFunc which
// will automatically select between a real and a fake JWT validation handler.
func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc {
func AuthenticationHandlerFunc(dic *di.Container) echo.MiddlewareFunc {
return func(inner echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
lc := container.LoggingClientFrom(dic.Get)
secretProvider := container.SecretProviderExtFrom(dic.Get)
r := c.Request()
w := c.Response()
authHeader := r.Header.Get("Authorization")
Expand All @@ -61,20 +67,29 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid
authParts := strings.Split(authHeader, " ")
if len(authParts) >= 2 && strings.EqualFold(authParts[0], "Bearer") {
token := authParts[1]
validToken, err := secretProvider.IsJWTValid(token)
if err != nil {
lc.Errorf("Error checking JWT validity: %v", err)
// set Response.Committed to true in order to rewrite the status code

parser := jwt.NewParser()
parsedToken, _, jwtErr := parser.ParseUnverified(token, &jwt.MapClaims{})
if jwtErr != nil {
w.Committed = false
return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
} else if !validToken {
lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path)
// set Response.Committed to true in order to rewrite the status code
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
issuer, jwtErr := parsedToken.Claims.GetIssuer()
if jwtErr != nil {
w.Committed = false
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
lc.Debugf("Request to '%s' authorized", r.URL.Path)
return inner(c)

if issuer == openBaoIssuer {
return SecretStoreAuthenticationHandlerFunc(secretProvider, lc, token, c)
} else {
// Verify the JWT by invoking security-proxy-auth http client
err := headers.VerifyJWT(token, issuer, parsedToken.Method.Alg(), dic, r.Context())
if err != nil {
errResp := dtoCommon.NewBaseResponse("", err.Error(), err.Code())
return c.JSON(err.Code(), errResp)
}
}
}
err := fmt.Errorf("unable to parse JWT for call to '%s'; unauthorized", r.URL.Path)
lc.Errorf("%v", err)
Expand All @@ -84,3 +99,24 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid
}
}
}

// SecretStoreAuthenticationHandlerFunc verifies the JWT with a OpenBao-based JWT authentication check
func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient, token string, c echo.Context) error {
r := c.Request()
w := c.Response()

validToken, err := secretProvider.IsJWTValid(token)
if err != nil {
lc.Errorf("Error checking JWT validity by the secret provider: %v ", err)
// set Response.Committed to true in order to rewrite the status code
w.Committed = false
return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
} else if !validToken {
lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path)
// set Response.Committed to true in order to rewrite the status code
w.Committed = false
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
lc.Debugf("Request to '%s' authorized", r.URL.Path)
return nil
}
70 changes: 70 additions & 0 deletions bootstrap/handlers/headers/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//
// Copyright (C) 2024 IOTech Ltd
//
// SPDX-License-Identifier: Apache-2.0

package headers

import (
"context"
stdErrs "errors"

"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v4/di"
"github.com/edgexfoundry/go-mod-core-contracts/v4/errors"

"github.com/golang-jwt/jwt/v5"
)

// VerifyJWT validates the JWT issued by security-proxy-auth by using the verification key provided by the security-proxy-auth service
func VerifyJWT(token string,
issuer string,
alg string,
dic *di.Container,
ctx context.Context) errors.EdgeX {
lc := container.LoggingClientFrom(dic.Get)

verifyKey, edgexErr := GetVerificationKey(dic, issuer, alg, ctx)
if edgexErr != nil {
return errors.NewCommonEdgeXWrapper(edgexErr)
}

err := ParseJWT(token, verifyKey, &jwt.MapClaims{}, jwt.WithExpirationRequired())
if err != nil {
if stdErrs.Is(err, jwt.ErrTokenExpired) {
// Skip the JWT expired error
lc.Debug("JWT is valid but expired")
return nil
} else {
if stdErrs.Is(err, jwt.ErrTokenMalformed) ||
stdErrs.Is(err, jwt.ErrTokenUnverifiable) ||
stdErrs.Is(err, jwt.ErrTokenSignatureInvalid) ||
stdErrs.Is(err, jwt.ErrTokenRequiredClaimMissing) {
lc.Errorf("Invalid jwt : %v\n", err)
return errors.NewCommonEdgeX(errors.KindUnauthorized, "invalid jwt", err)
}
lc.Errorf("Error occurred while validating JWT: %v", err)
return errors.NewCommonEdgeX(errors.Kind(err), "failed to parse jwt", err)
}
}
return nil
}

// ParseJWT parses and validates the JWT with the passed ParserOptions and returns the token which implements the Claim interface
func ParseJWT(token string, verifyKey any, claims jwt.Claims, parserOption ...jwt.ParserOption) error {
_, err := jwt.ParseWithClaims(token, claims, func(_ *jwt.Token) (any, error) {
return verifyKey, nil
}, parserOption...)
if err != nil {
return err
}

issuer, err := claims.GetIssuer()
if err != nil {
return errors.NewCommonEdgeX(errors.KindServerError, "failed to retrieve the issuer", err)
}
if len(issuer) == 0 {
return errors.NewCommonEdgeX(errors.KindUnauthorized, "issuer is empty", err)
}
return nil
}
Loading