From c7a53d0d82357c776244bbca54b0b8fb845e896e Mon Sep 17 00:00:00 2001 From: Nicole Hubbard Date: Tue, 10 Aug 2021 09:14:05 -0500 Subject: [PATCH] feat: Improve OIDC config options [DEL-479] (#20) --- cmd/serve.go | 31 +++++++---- internal/hollowserver/server.go | 13 ++--- internal/hollowserver/server_test.go | 8 ++- pkg/api/v1/client_test.go | 4 +- pkg/api/v1/router_int_test.go | 12 +++-- pkg/ginjwt/jwt.go | 79 +++++++++++++++++++--------- pkg/ginjwt/jwt_test.go | 12 +++-- pkg/ginjwt/testtools.go | 7 +-- 8 files changed, 101 insertions(+), 65 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index eba5df28..d492b893 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -6,6 +6,7 @@ import ( "go.metalkube.net/hollow/internal/db" "go.metalkube.net/hollow/internal/hollowserver" + "go.metalkube.net/hollow/pkg/ginjwt" ) // serveCmd represents the serve command @@ -25,12 +26,18 @@ func init() { serveCmd.Flags().String("db-uri", "postgresql://root@db:26257/hollow_dev?sslmode=disable", "URI for database connection") viperBindFlag("db.uri", serveCmd.Flags().Lookup("db-uri")) - serveCmd.Flags().String("jwt-aud", "", "expected audience on JWT tokens") - viperBindFlag("jwt.audience", serveCmd.Flags().Lookup("jwt-aud")) - serveCmd.Flags().String("jwt-issuer", "https://equinixmetal.us.auth0.com/", "expected issuer of JWT tokens") - viperBindFlag("jwt.issuer", serveCmd.Flags().Lookup("jwt-issuer")) - serveCmd.Flags().String("jwt-jwksuri", "https://equinixmetal.us.auth0.com/.well-known/jwks.json", "URI for JWKS listing for JWTs") - viperBindFlag("jwt.jwksuri", serveCmd.Flags().Lookup("jwt-jwksuri")) + serveCmd.Flags().Bool("oidc", true, "use oidc auth") + viperBindFlag("oidc.enabled", serveCmd.Flags().Lookup("oidc")) + serveCmd.Flags().String("oidc-aud", "", "expected audience on OIDC JWT") + viperBindFlag("oidc.audience", serveCmd.Flags().Lookup("oidc-aud")) + serveCmd.Flags().String("oidc-issuer", "", "expected issuer of OIDC JWT") + viperBindFlag("oidc.issuer", serveCmd.Flags().Lookup("oidc-issuer")) + serveCmd.Flags().String("oidc-jwksuri", "", "URI for JWKS listing for JWTs") + viperBindFlag("oidc.jwksuri", serveCmd.Flags().Lookup("oidc-jwksuri")) + serveCmd.Flags().String("oidc-roles-claim", "claim", "field containing the permissions of an OIDC JWT") + viperBindFlag("oidc.claims.roles", serveCmd.Flags().Lookup("oidc-roles-claim")) + serveCmd.Flags().String("oidc-username-claim", "", "additional fields to output in logs from the JWT token, ex (email)") + viperBindFlag("oidc.claims.username", serveCmd.Flags().Lookup("oidc-username-claim")) } func serve() { @@ -48,10 +55,14 @@ func serve() { Listen: viper.GetString("listen"), Debug: viper.GetBool("logging.debug"), Store: store, - AuthConfig: hollowserver.AuthConfig{ - Audience: viper.GetString("jwt.audience"), - Issuer: viper.GetString("jwt.issuer"), - JWKSURI: viper.GetString("jwt.jwksuri"), + AuthConfig: ginjwt.AuthConfig{ + Enabled: viper.GetBool("oidc.enabled"), + Audience: viper.GetString("oidc.audience"), + Issuer: viper.GetString("oidc.issuer"), + JWKSURI: viper.GetString("oidc.jwksuri"), + LogFields: viper.GetStringSlice("oidc.log"), + RolesClaim: viper.GetString("oidc.claims.roles"), + UsernameClaim: viper.GetString("oidc.claims.username"), }, } diff --git a/internal/hollowserver/server.go b/internal/hollowserver/server.go index faea389d..c67ce7cc 100644 --- a/internal/hollowserver/server.go +++ b/internal/hollowserver/server.go @@ -20,14 +20,7 @@ type Server struct { Listen string Debug bool Store *db.Store - AuthConfig AuthConfig -} - -// AuthConfig provides the configuration for the authentication service -type AuthConfig struct { - Audience string - Issuer string - JWKSURI string + AuthConfig ginjwt.AuthConfig } var ( @@ -41,7 +34,7 @@ func (s *Server) setup() *gin.Engine { err error ) - authMW, err = ginjwt.NewAuthMiddleware(s.AuthConfig.Audience, s.AuthConfig.Issuer, s.AuthConfig.JWKSURI) + authMW, err = ginjwt.NewAuthMiddleware(s.AuthConfig) if err != nil { s.Logger.Sugar().Fatal("failed to initialize auth middleware", "error", err) } @@ -70,7 +63,7 @@ func (s *Server) setup() *gin.Engine { ginzap.WithUTC(true), ginzap.WithCustomFields( func(c *gin.Context) zap.Field { return zap.String("jwt_subject", ginjwt.GetSubject(c)) }, - func(c *gin.Context) zap.Field { return zap.String("jwt_email", ginjwt.GetEmail(c)) }, + func(c *gin.Context) zap.Field { return zap.String("jwt_user", ginjwt.GetUser(c)) }, ), )) r.Use(ginzap.RecoveryWithZap(s.Logger, true)) diff --git a/internal/hollowserver/server_test.go b/internal/hollowserver/server_test.go index 1df8db94..47d61bd1 100644 --- a/internal/hollowserver/server_test.go +++ b/internal/hollowserver/server_test.go @@ -11,13 +11,11 @@ import ( "go.metalkube.net/hollow/internal/db" "go.metalkube.net/hollow/internal/hollowserver" + "go.metalkube.net/hollow/pkg/ginjwt" ) -// use a public keys listing so tests pass -var serverAuthConfig = hollowserver.AuthConfig{ - Audience: "", - Issuer: "", - JWKSURI: "https://www.googleapis.com/oauth2/v3/certs", +var serverAuthConfig = ginjwt.AuthConfig{ + Enabled: false, } func TestUnknownRoute(t *testing.T) { diff --git a/pkg/api/v1/client_test.go b/pkg/api/v1/client_test.go index f3a6c9e8..2f2b9c49 100644 --- a/pkg/api/v1/client_test.go +++ b/pkg/api/v1/client_test.go @@ -141,9 +141,9 @@ func realClientTests(t *testing.T, f func(ctx context.Context, token string, res "auth token with no scopes", ctx, validToken([]string{}), - http.StatusUnauthorized, + http.StatusForbidden, true, - "server error - response code: 401, message:", + "server error - response code: 403, message:", }, { "fake timeout", diff --git a/pkg/api/v1/router_int_test.go b/pkg/api/v1/router_int_test.go index f14de6e6..359ca4ab 100644 --- a/pkg/api/v1/router_int_test.go +++ b/pkg/api/v1/router_int_test.go @@ -32,10 +32,12 @@ func serverTest(t *testing.T) *integrationServer { hs := hollowserver.Server{ Logger: l, Store: store, - AuthConfig: hollowserver.AuthConfig{ - Audience: "hollow.test", - Issuer: "hollow.test.issuer", - JWKSURI: jwksURI, + AuthConfig: ginjwt.AuthConfig{ + Enabled: true, + Audience: "hollow.test", + Issuer: "hollow.test.issuer", + JWKSURI: jwksURI, + RolesClaim: "userPerms", }, } s := hs.NewServer() @@ -75,5 +77,5 @@ func validToken(scopes []string) string { } signer := ginjwt.TestHelperMustMakeSigner(jose.RS256, ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey1) - return ginjwt.TestHelperGetToken(signer, claims, scopes) + return ginjwt.TestHelperGetToken(signer, claims, "userPerms", scopes) } diff --git a/pkg/ginjwt/jwt.go b/pkg/ginjwt/jwt.go index 69be9d7b..363bd0b1 100644 --- a/pkg/ginjwt/jwt.go +++ b/pkg/ginjwt/jwt.go @@ -13,33 +13,43 @@ import ( const ( contextKeySubject = "jwt.subject" - contextKeyEmail = "jwt.email" + contextKeyUser = "jwt.user" expectedAuthHeaderParts = 2 ) // Middleware provides a gin compatible middleware that will authenticate JWT requests type Middleware struct { - audience string - issuer string - jwksURI string + config AuthConfig cachedJWKS jose.JSONWebKeySet } -type customClaims struct { - Scope string `json:"scope"` - Email string `json:"https://equinixmetal.com/email"` -} - -func (c *customClaims) Scopes() []string { - return strings.Split(c.Scope, " ") +// AuthConfig provides the configuration for the authentication service +type AuthConfig struct { + Enabled bool + Audience string + Issuer string + JWKSURI string + LogFields []string + RolesClaim string + UsernameClaim string } // NewAuthMiddleware will return an auth middleware configured with the jwt parameters passed in -func NewAuthMiddleware(aud, iss, jwksURI string) (*Middleware, error) { +func NewAuthMiddleware(cfg AuthConfig) (*Middleware, error) { + if cfg.RolesClaim == "" { + cfg.RolesClaim = "scope" + } + + if cfg.UsernameClaim == "" { + cfg.UsernameClaim = "sub" + } + mw := &Middleware{ - audience: aud, - issuer: iss, - jwksURI: jwksURI, + config: cfg, + } + + if !cfg.Enabled { + return mw, nil } if err := mw.refreshJWKS(); err != nil { @@ -52,6 +62,10 @@ func NewAuthMiddleware(aud, iss, jwksURI string) (*Middleware, error) { // AuthRequired provides a middleware that ensures a request has authentication func (m *Middleware) AuthRequired(scopes []string) gin.HandlerFunc { return func(c *gin.Context) { + if !m.config.Enabled { + return + } + authHeader := c.Request.Header.Get("Authorization") if authHeader == "" { @@ -86,7 +100,7 @@ func (m *Middleware) AuthRequired(scopes []string) gin.HandlerFunc { } cl := jwt.Claims{} - sc := customClaims{} + sc := map[string]interface{}{} if err := tok.Claims(key, &cl, &sc); err != nil { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "unable to validate auth token"}) @@ -94,8 +108,8 @@ func (m *Middleware) AuthRequired(scopes []string) gin.HandlerFunc { } err = cl.Validate(jwt.Expected{ - Issuer: m.issuer, - Audience: jwt.Audience{m.audience}, + Issuer: m.config.Issuer, + Audience: jwt.Audience{m.config.Audience}, Time: time.Now(), }) if err != nil { @@ -103,18 +117,31 @@ func (m *Middleware) AuthRequired(scopes []string) gin.HandlerFunc { return } - if !hasScope(sc.Scopes(), scopes) { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "not authorized, missing required scope"}) + var roles []string + switch r := sc[m.config.RolesClaim].(type) { + case string: + roles = strings.Split(r, " ") + case []interface{}: + for _, i := range r { + roles = append(roles, i.(string)) + } + } + + if !hasScope(roles, scopes) { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"message": "not authorized, missing required scope"}) return } + u := sc[m.config.UsernameClaim] + user := u.(string) + c.Set(contextKeySubject, cl.Subject) - c.Set(contextKeyEmail, sc.Email) + c.Set(contextKeyUser, user) } } func (m *Middleware) refreshJWKS() error { - resp, err := http.Get(m.jwksURI) //nolint:noctx + resp, err := http.Get(m.config.JWKSURI) //nolint:noctx if err != nil { return err } @@ -163,8 +190,8 @@ func GetSubject(c *gin.Context) string { return c.GetString(contextKeySubject) } -// GetEmail will return the JWT email that is saved in the request. This requires that authentication of the request -// has already occurred. If authentication failed or there isn't an email an empty string is returned. -func GetEmail(c *gin.Context) string { - return c.GetString(contextKeyEmail) +// GetUser will return the JWT user that is saved in the request. This requires that authentication of the request +// has already occurred. If authentication failed or there isn't a user an empty string is returned. +func GetUser(c *gin.Context) string { + return c.GetString(contextKeyUser) } diff --git a/pkg/ginjwt/jwt_test.go b/pkg/ginjwt/jwt_test.go index ecd56ecc..5a9ba659 100644 --- a/pkg/ginjwt/jwt_test.go +++ b/pkg/ginjwt/jwt_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -112,7 +113,7 @@ func TestMiddlewareValidatesTokens(t *testing.T) { Audience: jwt.Audience{"ginjwt.test", "another.test.service"}, }, []string{"testScope", "anotherScope", "more-scopes"}, - http.StatusUnauthorized, + http.StatusForbidden, "missing required scope", }, { @@ -172,7 +173,9 @@ func TestMiddlewareValidatesTokens(t *testing.T) { for _, tt := range testCases { t.Run(tt.testName, func(t *testing.T) { jwksURI := ginjwt.TestHelperJWKSProvider() - authMW, err := ginjwt.NewAuthMiddleware(tt.middlewareAud, tt.middlewareIss, jwksURI) + + cfg := ginjwt.AuthConfig{Enabled: true, Audience: tt.middlewareAud, Issuer: tt.middlewareIss, JWKSURI: jwksURI} + authMW, err := ginjwt.NewAuthMiddleware(cfg) require.NoError(t, err) r := gin.New() @@ -185,7 +188,7 @@ func TestMiddlewareValidatesTokens(t *testing.T) { req := httptest.NewRequest("GET", "http://test/", nil) signer := ginjwt.TestHelperMustMakeSigner(jose.RS256, tt.signingKeyID, tt.signingKey) - rawToken := ginjwt.TestHelperGetToken(signer, tt.claims, tt.claimScopes) + rawToken := ginjwt.TestHelperGetToken(signer, tt.claims, "scope", strings.Join(tt.claimScopes, " ")) req.Header.Set("Authorization", fmt.Sprintf("bearer %s", rawToken)) r.ServeHTTP(w, req) @@ -232,7 +235,8 @@ func TestInvalidAuthHeader(t *testing.T) { for _, tt := range testCases { t.Run(tt.testName, func(t *testing.T) { jwksURI := ginjwt.TestHelperJWKSProvider() - authMW, err := ginjwt.NewAuthMiddleware("aud", "iss", jwksURI) + cfg := ginjwt.AuthConfig{Enabled: true, Audience: "aud", Issuer: "iss", JWKSURI: jwksURI} + authMW, err := ginjwt.NewAuthMiddleware(cfg) require.NoError(t, err) r := gin.New() diff --git a/pkg/ginjwt/testtools.go b/pkg/ginjwt/testtools.go index 12fb6570..c9c00b4a 100644 --- a/pkg/ginjwt/testtools.go +++ b/pkg/ginjwt/testtools.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "net/http" - "strings" "github.com/gin-gonic/gin" "gopkg.in/square/go-jose.v2" @@ -76,8 +75,10 @@ func TestHelperJWKSProvider() string { } // TestHelperGetToken will return a signed token -func TestHelperGetToken(signer jose.Signer, cl jwt.Claims, scopes []string) string { - sc := customClaims{Scope: strings.Join(scopes, " ")} +func TestHelperGetToken(signer jose.Signer, cl jwt.Claims, key string, value interface{}) string { + sc := map[string]interface{}{} + + sc[key] = value raw, err := jwt.Signed(signer).Claims(cl).Claims(sc).CompactSerialize() if err != nil {