Skip to content

Commit

Permalink
Adapt interface to allow returning the additional labels
Browse files Browse the repository at this point in the history
  • Loading branch information
angelbarrera92 committed Jun 19, 2023
1 parent bd975b8 commit 5f9d622
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 19 deletions.
5 changes: 3 additions & 2 deletions internal/app/prometheus-multi-tenant-proxy/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type key int
// Auth implements an authentication middleware
type Auth interface {
// IsAuthorized authenticates a request and returns the list of namespaces the user has access to
IsAuthorized(r *http.Request) (bool, []string)
IsAuthorized(r *http.Request) (bool, []string, map[string]string)
// WriteUnauthorisedResponse writes an HTTP response in case the user is forbidden
WriteUnauthorisedResponse(w http.ResponseWriter)
// Load loads or reloads the configuration
Expand All @@ -20,12 +20,13 @@ type Auth interface {
// AuthHandler returns au authentication middleware handler
func AuthHandler(auth Auth, handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authorized, namespaces := auth.IsAuthorized(r)
authorized, namespaces, labels := auth.IsAuthorized(r)
if !authorized {
auth.WriteUnauthorisedResponse(w)
return
}
ctx := context.WithValue(r.Context(), Namespaces, namespaces)
ctx = context.WithValue(ctx, Labels, labels)
handler(w, r.WithContext(ctx))
}
}
14 changes: 8 additions & 6 deletions internal/app/prometheus-multi-tenant-proxy/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import (
const (
//Namespaces Key used to pass prometheus tenant id though the middleware context
Namespaces key = iota
realm = "Prometheus multi-tenant proxy"
//Labels Key used to pass prometheus additional labels though the middleware context
Labels key = iota
realm = "Prometheus multi-tenant proxy"
)

// BasicAuth can be used as a middleware chain to authenticate users
Expand Down Expand Up @@ -60,22 +62,22 @@ func (auth *BasicAuth) Load() bool {

// IsAuthorized uses the basic authentication and the Authn file to authenticate a user
// and return the namespace he has access to
func (auth *BasicAuth) IsAuthorized(r *http.Request) (bool, []string) {
func (auth *BasicAuth) IsAuthorized(r *http.Request) (bool, []string, map[string]string) {
user, pass, ok := r.BasicAuth()
if !ok {
return false, nil
return false, nil, nil
}
return auth.isAuthorized(user, pass)
}

func (auth *BasicAuth) isAuthorized(user, pass string) (bool, []string) {
func (auth *BasicAuth) isAuthorized(user, pass string) (bool, []string, map[string]string) {
authConfig := auth.getConfig()
for _, v := range authConfig.Users {
if subtle.ConstantTimeCompare([]byte(user), []byte(v.Username)) == 1 && subtle.ConstantTimeCompare([]byte(pass), []byte(v.Password)) == 1 {
return true, append(v.Namespaces, v.Namespace)
return true, append(v.Namespaces, v.Namespace), nil
}
}
return false, nil
return false, nil, nil
}

// WriteUnauthorisedResponse writes a 401 Unauthorized HTTP response with
Expand Down
8 changes: 7 additions & 1 deletion internal/app/prometheus-multi-tenant-proxy/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func TestBasic_isAuthorized(t *testing.T) {
args args
want bool
want1 []string
want2 map[string]string
}{
{
"Valid User",
Expand All @@ -48,6 +49,7 @@ func TestBasic_isAuthorized(t *testing.T) {
},
true,
[]string{"tenant-a"},
nil,
}, {
"Invalid User",
args{
Expand All @@ -56,17 +58,21 @@ func TestBasic_isAuthorized(t *testing.T) {
},
false,
nil,
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := auth.isAuthorized(tt.args.user, tt.args.pass)
got, got1, got2 := auth.isAuthorized(tt.args.user, tt.args.pass)
if got != tt.want {
t.Errorf("isAuthorized() got = %v, want %v", got, tt.want)
}
if !reflect.DeepEqual(got1, tt.want1) {
t.Errorf("isAuthorized() got1 = %v, want1 %v", got1, tt.want1)
}
if !reflect.DeepEqual(got2, tt.want2) {
t.Errorf("isAuthorized() got2 = %v, want2 %v", got2, tt.want2)
}
})
}
}
12 changes: 6 additions & 6 deletions internal/app/prometheus-multi-tenant-proxy/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ func (auth *JwtAuth) loadFromFile(location *string) bool {

// IsAuthorized validates the user by verifying the JWT token in
// the request and returning the namespaces claim found in token the payload.
func (auth *JwtAuth) IsAuthorized(r *http.Request) (bool, []string) {
func (auth *JwtAuth) IsAuthorized(r *http.Request) (bool, []string, map[string]string) {
tokenString := extractTokens(&r.Header)
if tokenString == "" {
log.Printf("Token is missing from header request")
return false, nil
return false, nil, nil
}
return auth.isAuthorized(tokenString)
}
Expand All @@ -144,19 +144,19 @@ func (auth *JwtAuth) WriteUnauthorisedResponse(w http.ResponseWriter) {
w.Write([]byte("Unauthorised\n"))
}

func (auth *JwtAuth) isAuthorized(tokenString string) (bool, []string) {
func (auth *JwtAuth) isAuthorized(tokenString string) (bool, []string, map[string]string) {
token, err := jwt.ParseWithClaims(tokenString, &NamespaceClaim{}, auth.jwks.Keyfunc)
if err != nil || !token.Valid {
log.Printf("%s\n", err)
return false, nil
return false, nil, nil
}

claims := token.Claims.(*NamespaceClaim)
if len(claims.Namespaces) == 0 {
log.Printf("token claim is invalid: namespaces is missing or empty")
return false, nil
return false, nil, nil
}
return true, claims.Namespaces
return true, claims.Namespaces, nil
}

func isValidSigningMethod(signingMethod string) bool {
Expand Down
8 changes: 4 additions & 4 deletions internal/app/prometheus-multi-tenant-proxy/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ const (
)

func (auth *JwtAuth) assertHmac(t *testing.T, expectAuthorized bool) {
authorized, _ := auth.isAuthorized(validHmacToken)
authorized, _, _ := auth.isAuthorized(validHmacToken)
if authorized != expectAuthorized {
t.Errorf("HMAC authorized=%v, expected=%v", authorized, expectAuthorized)
}
}
func (auth *JwtAuth) assertRSA(t *testing.T, expectAuthorized bool) {
authorized, _ := auth.isAuthorized(validRsaToken)
authorized, _, _ := auth.isAuthorized(validRsaToken)
if authorized != expectAuthorized {
t.Errorf("RSA authorized=%v, expected=%v", authorized, expectAuthorized)
}
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestJWT_IsAuthorized(t *testing.T) {

for _, tc := range validTestCases {
t.Run(tc.desc, func(t *testing.T) {
authorized, namespaces := auth.isAuthorized(tc.token)
authorized, namespaces, _ := auth.isAuthorized(tc.token)
if !authorized {
t.Fatal("Should be authorized")
}
Expand All @@ -156,7 +156,7 @@ func TestJWT_IsAuthorized(t *testing.T) {

for _, tc := range invalidTestCases {
t.Run(tc.reason, func(t *testing.T) {
if authorized, _ := auth.isAuthorized(tc.token); authorized {
if authorized, _, _ := auth.isAuthorized(tc.token); authorized {
t.Error("Signature should be invalid - invalid secret signature")
}
})
Expand Down

0 comments on commit 5f9d622

Please sign in to comment.