Skip to content

Commit

Permalink
feat: Improve OIDC config options [DEL-479] (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolerenee authored Aug 10, 2021
1 parent da1df3e commit c7a53d0
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 65 deletions.
31 changes: 21 additions & 10 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"go.metalkube.net/hollow/internal/db"
"go.metalkube.net/hollow/internal/hollowserver"
"go.metalkube.net/hollow/pkg/ginjwt"
)

// serveCmd represents the serve command
Expand All @@ -25,12 +26,18 @@ func init() {
serveCmd.Flags().String("db-uri", "postgresql://root@db:26257/hollow_dev?sslmode=disable", "URI for database connection")
viperBindFlag("db.uri", serveCmd.Flags().Lookup("db-uri"))

serveCmd.Flags().String("jwt-aud", "", "expected audience on JWT tokens")
viperBindFlag("jwt.audience", serveCmd.Flags().Lookup("jwt-aud"))
serveCmd.Flags().String("jwt-issuer", "https://equinixmetal.us.auth0.com/", "expected issuer of JWT tokens")
viperBindFlag("jwt.issuer", serveCmd.Flags().Lookup("jwt-issuer"))
serveCmd.Flags().String("jwt-jwksuri", "https://equinixmetal.us.auth0.com/.well-known/jwks.json", "URI for JWKS listing for JWTs")
viperBindFlag("jwt.jwksuri", serveCmd.Flags().Lookup("jwt-jwksuri"))
serveCmd.Flags().Bool("oidc", true, "use oidc auth")
viperBindFlag("oidc.enabled", serveCmd.Flags().Lookup("oidc"))
serveCmd.Flags().String("oidc-aud", "", "expected audience on OIDC JWT")
viperBindFlag("oidc.audience", serveCmd.Flags().Lookup("oidc-aud"))
serveCmd.Flags().String("oidc-issuer", "", "expected issuer of OIDC JWT")
viperBindFlag("oidc.issuer", serveCmd.Flags().Lookup("oidc-issuer"))
serveCmd.Flags().String("oidc-jwksuri", "", "URI for JWKS listing for JWTs")
viperBindFlag("oidc.jwksuri", serveCmd.Flags().Lookup("oidc-jwksuri"))
serveCmd.Flags().String("oidc-roles-claim", "claim", "field containing the permissions of an OIDC JWT")
viperBindFlag("oidc.claims.roles", serveCmd.Flags().Lookup("oidc-roles-claim"))
serveCmd.Flags().String("oidc-username-claim", "", "additional fields to output in logs from the JWT token, ex (email)")
viperBindFlag("oidc.claims.username", serveCmd.Flags().Lookup("oidc-username-claim"))
}

func serve() {
Expand All @@ -48,10 +55,14 @@ func serve() {
Listen: viper.GetString("listen"),
Debug: viper.GetBool("logging.debug"),
Store: store,
AuthConfig: hollowserver.AuthConfig{
Audience: viper.GetString("jwt.audience"),
Issuer: viper.GetString("jwt.issuer"),
JWKSURI: viper.GetString("jwt.jwksuri"),
AuthConfig: ginjwt.AuthConfig{
Enabled: viper.GetBool("oidc.enabled"),
Audience: viper.GetString("oidc.audience"),
Issuer: viper.GetString("oidc.issuer"),
JWKSURI: viper.GetString("oidc.jwksuri"),
LogFields: viper.GetStringSlice("oidc.log"),
RolesClaim: viper.GetString("oidc.claims.roles"),
UsernameClaim: viper.GetString("oidc.claims.username"),
},
}

Expand Down
13 changes: 3 additions & 10 deletions internal/hollowserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,7 @@ type Server struct {
Listen string
Debug bool
Store *db.Store
AuthConfig AuthConfig
}

// AuthConfig provides the configuration for the authentication service
type AuthConfig struct {
Audience string
Issuer string
JWKSURI string
AuthConfig ginjwt.AuthConfig
}

var (
Expand All @@ -41,7 +34,7 @@ func (s *Server) setup() *gin.Engine {
err error
)

authMW, err = ginjwt.NewAuthMiddleware(s.AuthConfig.Audience, s.AuthConfig.Issuer, s.AuthConfig.JWKSURI)
authMW, err = ginjwt.NewAuthMiddleware(s.AuthConfig)
if err != nil {
s.Logger.Sugar().Fatal("failed to initialize auth middleware", "error", err)
}
Expand Down Expand Up @@ -70,7 +63,7 @@ func (s *Server) setup() *gin.Engine {
ginzap.WithUTC(true),
ginzap.WithCustomFields(
func(c *gin.Context) zap.Field { return zap.String("jwt_subject", ginjwt.GetSubject(c)) },
func(c *gin.Context) zap.Field { return zap.String("jwt_email", ginjwt.GetEmail(c)) },
func(c *gin.Context) zap.Field { return zap.String("jwt_user", ginjwt.GetUser(c)) },
),
))
r.Use(ginzap.RecoveryWithZap(s.Logger, true))
Expand Down
8 changes: 3 additions & 5 deletions internal/hollowserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@ import (

"go.metalkube.net/hollow/internal/db"
"go.metalkube.net/hollow/internal/hollowserver"
"go.metalkube.net/hollow/pkg/ginjwt"
)

// use a public keys listing so tests pass
var serverAuthConfig = hollowserver.AuthConfig{
Audience: "",
Issuer: "",
JWKSURI: "https://www.googleapis.com/oauth2/v3/certs",
var serverAuthConfig = ginjwt.AuthConfig{
Enabled: false,
}

func TestUnknownRoute(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/api/v1/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ func realClientTests(t *testing.T, f func(ctx context.Context, token string, res
"auth token with no scopes",
ctx,
validToken([]string{}),
http.StatusUnauthorized,
http.StatusForbidden,
true,
"server error - response code: 401, message:",
"server error - response code: 403, message:",
},
{
"fake timeout",
Expand Down
12 changes: 7 additions & 5 deletions pkg/api/v1/router_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ func serverTest(t *testing.T) *integrationServer {
hs := hollowserver.Server{
Logger: l,
Store: store,
AuthConfig: hollowserver.AuthConfig{
Audience: "hollow.test",
Issuer: "hollow.test.issuer",
JWKSURI: jwksURI,
AuthConfig: ginjwt.AuthConfig{
Enabled: true,
Audience: "hollow.test",
Issuer: "hollow.test.issuer",
JWKSURI: jwksURI,
RolesClaim: "userPerms",
},
}
s := hs.NewServer()
Expand Down Expand Up @@ -75,5 +77,5 @@ func validToken(scopes []string) string {
}
signer := ginjwt.TestHelperMustMakeSigner(jose.RS256, ginjwt.TestPrivRSAKey1ID, ginjwt.TestPrivRSAKey1)

return ginjwt.TestHelperGetToken(signer, claims, scopes)
return ginjwt.TestHelperGetToken(signer, claims, "userPerms", scopes)
}
79 changes: 53 additions & 26 deletions pkg/ginjwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,43 @@ import (

const (
contextKeySubject = "jwt.subject"
contextKeyEmail = "jwt.email"
contextKeyUser = "jwt.user"
expectedAuthHeaderParts = 2
)

// Middleware provides a gin compatible middleware that will authenticate JWT requests
type Middleware struct {
audience string
issuer string
jwksURI string
config AuthConfig
cachedJWKS jose.JSONWebKeySet
}

type customClaims struct {
Scope string `json:"scope"`
Email string `json:"https://equinixmetal.com/email"`
}

func (c *customClaims) Scopes() []string {
return strings.Split(c.Scope, " ")
// AuthConfig provides the configuration for the authentication service
type AuthConfig struct {
Enabled bool
Audience string
Issuer string
JWKSURI string
LogFields []string
RolesClaim string
UsernameClaim string
}

// NewAuthMiddleware will return an auth middleware configured with the jwt parameters passed in
func NewAuthMiddleware(aud, iss, jwksURI string) (*Middleware, error) {
func NewAuthMiddleware(cfg AuthConfig) (*Middleware, error) {
if cfg.RolesClaim == "" {
cfg.RolesClaim = "scope"
}

if cfg.UsernameClaim == "" {
cfg.UsernameClaim = "sub"
}

mw := &Middleware{
audience: aud,
issuer: iss,
jwksURI: jwksURI,
config: cfg,
}

if !cfg.Enabled {
return mw, nil
}

if err := mw.refreshJWKS(); err != nil {
Expand All @@ -52,6 +62,10 @@ func NewAuthMiddleware(aud, iss, jwksURI string) (*Middleware, error) {
// AuthRequired provides a middleware that ensures a request has authentication
func (m *Middleware) AuthRequired(scopes []string) gin.HandlerFunc {
return func(c *gin.Context) {
if !m.config.Enabled {
return
}

authHeader := c.Request.Header.Get("Authorization")

if authHeader == "" {
Expand Down Expand Up @@ -86,35 +100,48 @@ func (m *Middleware) AuthRequired(scopes []string) gin.HandlerFunc {
}

cl := jwt.Claims{}
sc := customClaims{}
sc := map[string]interface{}{}

if err := tok.Claims(key, &cl, &sc); err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "unable to validate auth token"})
return
}

err = cl.Validate(jwt.Expected{
Issuer: m.issuer,
Audience: jwt.Audience{m.audience},
Issuer: m.config.Issuer,
Audience: jwt.Audience{m.config.Audience},
Time: time.Now(),
})
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "invalid auth token", "error": err.Error()})
return
}

if !hasScope(sc.Scopes(), scopes) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"message": "not authorized, missing required scope"})
var roles []string
switch r := sc[m.config.RolesClaim].(type) {
case string:
roles = strings.Split(r, " ")
case []interface{}:
for _, i := range r {
roles = append(roles, i.(string))
}
}

if !hasScope(roles, scopes) {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"message": "not authorized, missing required scope"})
return
}

u := sc[m.config.UsernameClaim]
user := u.(string)

c.Set(contextKeySubject, cl.Subject)
c.Set(contextKeyEmail, sc.Email)
c.Set(contextKeyUser, user)
}
}

func (m *Middleware) refreshJWKS() error {
resp, err := http.Get(m.jwksURI) //nolint:noctx
resp, err := http.Get(m.config.JWKSURI) //nolint:noctx
if err != nil {
return err
}
Expand Down Expand Up @@ -163,8 +190,8 @@ func GetSubject(c *gin.Context) string {
return c.GetString(contextKeySubject)
}

// GetEmail will return the JWT email that is saved in the request. This requires that authentication of the request
// has already occurred. If authentication failed or there isn't an email an empty string is returned.
func GetEmail(c *gin.Context) string {
return c.GetString(contextKeyEmail)
// GetUser will return the JWT user that is saved in the request. This requires that authentication of the request
// has already occurred. If authentication failed or there isn't a user an empty string is returned.
func GetUser(c *gin.Context) string {
return c.GetString(contextKeyUser)
}
12 changes: 8 additions & 4 deletions pkg/ginjwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -112,7 +113,7 @@ func TestMiddlewareValidatesTokens(t *testing.T) {
Audience: jwt.Audience{"ginjwt.test", "another.test.service"},
},
[]string{"testScope", "anotherScope", "more-scopes"},
http.StatusUnauthorized,
http.StatusForbidden,
"missing required scope",
},
{
Expand Down Expand Up @@ -172,7 +173,9 @@ func TestMiddlewareValidatesTokens(t *testing.T) {
for _, tt := range testCases {
t.Run(tt.testName, func(t *testing.T) {
jwksURI := ginjwt.TestHelperJWKSProvider()
authMW, err := ginjwt.NewAuthMiddleware(tt.middlewareAud, tt.middlewareIss, jwksURI)

cfg := ginjwt.AuthConfig{Enabled: true, Audience: tt.middlewareAud, Issuer: tt.middlewareIss, JWKSURI: jwksURI}
authMW, err := ginjwt.NewAuthMiddleware(cfg)
require.NoError(t, err)

r := gin.New()
Expand All @@ -185,7 +188,7 @@ func TestMiddlewareValidatesTokens(t *testing.T) {
req := httptest.NewRequest("GET", "http://test/", nil)

signer := ginjwt.TestHelperMustMakeSigner(jose.RS256, tt.signingKeyID, tt.signingKey)
rawToken := ginjwt.TestHelperGetToken(signer, tt.claims, tt.claimScopes)
rawToken := ginjwt.TestHelperGetToken(signer, tt.claims, "scope", strings.Join(tt.claimScopes, " "))
req.Header.Set("Authorization", fmt.Sprintf("bearer %s", rawToken))

r.ServeHTTP(w, req)
Expand Down Expand Up @@ -232,7 +235,8 @@ func TestInvalidAuthHeader(t *testing.T) {
for _, tt := range testCases {
t.Run(tt.testName, func(t *testing.T) {
jwksURI := ginjwt.TestHelperJWKSProvider()
authMW, err := ginjwt.NewAuthMiddleware("aud", "iss", jwksURI)
cfg := ginjwt.AuthConfig{Enabled: true, Audience: "aud", Issuer: "iss", JWKSURI: jwksURI}
authMW, err := ginjwt.NewAuthMiddleware(cfg)
require.NoError(t, err)

r := gin.New()
Expand Down
7 changes: 4 additions & 3 deletions pkg/ginjwt/testtools.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"net"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"gopkg.in/square/go-jose.v2"
Expand Down Expand Up @@ -76,8 +75,10 @@ func TestHelperJWKSProvider() string {
}

// TestHelperGetToken will return a signed token
func TestHelperGetToken(signer jose.Signer, cl jwt.Claims, scopes []string) string {
sc := customClaims{Scope: strings.Join(scopes, " ")}
func TestHelperGetToken(signer jose.Signer, cl jwt.Claims, key string, value interface{}) string {
sc := map[string]interface{}{}

sc[key] = value

raw, err := jwt.Signed(signer).Claims(cl).Claims(sc).CompactSerialize()
if err != nil {
Expand Down

0 comments on commit c7a53d0

Please sign in to comment.