Skip to content

Commit

Permalink
Move code to packages, regroup (#467)
Browse files Browse the repository at this point in the history
Refactor, regrouping code to packages, separating what is reusable
  • Loading branch information
p53 authored May 23, 2024
1 parent 28b055b commit 9de8d1d
Show file tree
Hide file tree
Showing 23 changed files with 584 additions and 572 deletions.
15 changes: 3 additions & 12 deletions pkg/authorization/external_keycloak.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,13 @@ import (

"github.com/Nerzal/gocloak/v12"
"github.com/gogatekeeper/gatekeeper/pkg/apperrors"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/models"
)

type Permission struct {
Scopes []string `json:"scopes"`
ResourceID string `json:"rsid"`
ResourceName string `json:"rsname"`
}

type Permissions struct {
Permissions []Permission `json:"permissions"`
}

var _ Provider = (*KeycloakAuthorizationProvider)(nil)

type KeycloakAuthorizationProvider struct {
perms Permissions
perms models.Permissions
targetPath string
idpClient *gocloak.GoCloak
idpTimeout time.Duration
Expand All @@ -31,7 +22,7 @@ type KeycloakAuthorizationProvider struct {
}

func NewKeycloakAuthorizationProvider(
perms Permissions,
perms models.Permissions,
targetPath string,
idpClient *gocloak.GoCloak,
idpTimeout time.Duration,
Expand Down
5 changes: 3 additions & 2 deletions pkg/keycloak/proxy/forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/gogatekeeper/gatekeeper/pkg/apperrors"
"github.com/gogatekeeper/gatekeeper/pkg/constant"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/models"
"github.com/gogatekeeper/gatekeeper/pkg/utils"
"go.uber.org/zap"
)
Expand All @@ -45,10 +46,10 @@ func proxyMiddleware(

// @step: retrieve the request scope
ctxVal := req.Context().Value(constant.ContextScopeName)
var scope *RequestScope
var scope *models.RequestScope
if ctxVal != nil {
var assertOk bool
scope, assertOk = ctxVal.(*RequestScope)
scope, assertOk = ctxVal.(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down
28 changes: 15 additions & 13 deletions pkg/keycloak/proxy/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import (
"github.com/gogatekeeper/gatekeeper/pkg/proxy/cookie"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/handlers"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/metrics"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/models"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/session"
"github.com/gogatekeeper/gatekeeper/pkg/storage"
"github.com/gogatekeeper/gatekeeper/pkg/utils"
"github.com/grokify/go-pkce"
Expand Down Expand Up @@ -66,7 +68,7 @@ func oauthAuthorizationHandler(
return
}

scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -194,7 +196,7 @@ func oauthCallbackHandler(
return
}

scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -396,7 +398,7 @@ func loginHandler(
store storage.Storage,
) func(wrt http.ResponseWriter, req *http.Request) {
return func(writer http.ResponseWriter, req *http.Request) {
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)

if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
Expand Down Expand Up @@ -455,7 +457,7 @@ func loginHandler(
errors.Join(apperrors.ErrParseAccessToken, err)
}

identity, err := ExtractIdentity(accessTokenObj)
identity, err := session.ExtractIdentity(accessTokenObj)
if err != nil {
return http.StatusNotImplemented,
errors.Join(apperrors.ErrExtractIdentityFromAccessToken, err)
Expand Down Expand Up @@ -576,10 +578,10 @@ func loginHandler(
}
}

var resp TokenResponse
var resp models.TokenResponse

if enableEncryptedToken {
resp = TokenResponse{
resp = models.TokenResponse{
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
Expand All @@ -588,7 +590,7 @@ func loginHandler(
TokenType: token.TokenType,
}
} else {
resp = TokenResponse{
resp = models.TokenResponse{
IDToken: plainIDToken,
AccessToken: token.AccessToken,
RefreshToken: refreshToken,
Expand Down Expand Up @@ -643,7 +645,7 @@ func logoutHandler(
cookManager *cookie.Manager,
idpClient *gocloak.GoCloak,
accessError func(wrt http.ResponseWriter, req *http.Request) context.Context,
GetIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*UserContext, error),
GetIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error),
) func(wrt http.ResponseWriter, req *http.Request) {
return func(writer http.ResponseWriter, req *http.Request) {
// @check if the redirection is there
Expand All @@ -667,7 +669,7 @@ func logoutHandler(
}
}

scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
writer.WriteHeader(http.StatusInternalServerError)
Expand Down Expand Up @@ -835,7 +837,7 @@ func logoutHandler(

// expirationHandler checks if the token has expired
func expirationHandler(
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*UserContext, error),
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error),
cookieAccessName string,
) func(wrt http.ResponseWriter, req *http.Request) {
return func(wrt http.ResponseWriter, req *http.Request) {
Expand All @@ -856,7 +858,7 @@ func expirationHandler(

// tokenHandler display access token to screen
func tokenHandler(
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*UserContext, error),
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error),
cookieAccessName string,
accessError func(wrt http.ResponseWriter, req *http.Request) context.Context,
) func(wrt http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -897,7 +899,7 @@ func retrieveRefreshToken(
cookieRefreshName string,
encryptionKey string,
req *http.Request,
user *UserContext,
user *models.UserContext,
) (string, string, error) {
var token string
var err error
Expand All @@ -906,7 +908,7 @@ func retrieveRefreshToken(
case true:
token, err = GetRefreshTokenFromStore(req.Context(), store, user.RawToken)
default:
token, err = utils.GetRefreshTokenFromCookie(req, cookieRefreshName)
token, err = session.GetRefreshTokenFromCookie(req, cookieRefreshName)
}

if err != nil {
Expand Down
38 changes: 20 additions & 18 deletions pkg/keycloak/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import (
"github.com/gogatekeeper/gatekeeper/pkg/encryption"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/cookie"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/metrics"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/models"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/session"
"github.com/gogatekeeper/gatekeeper/pkg/storage"
"github.com/gogatekeeper/gatekeeper/pkg/utils"
"golang.org/x/oauth2"
Expand All @@ -59,7 +61,7 @@ func entrypointMiddleware(logger *zap.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
// @step: create a context for the request
scope := &RequestScope{}
scope := &models.RequestScope{}
// Save the exact formatting of the incoming request so we can use it later
scope.Path = req.URL.Path
scope.RawPath = req.URL.RawPath
Expand Down Expand Up @@ -123,7 +125,7 @@ func loggingMiddleware(
return
}

scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -174,7 +176,7 @@ func authenticationMiddleware(
logger *zap.Logger,
cookieAccessName string,
cookieRefreshName string,
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*UserContext, error),
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error),
idpClient *gocloak.GoCloak,
enableIDPSessionCheck bool,
provider *oidc3.Provider,
Expand All @@ -196,7 +198,7 @@ func authenticationMiddleware(
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -404,7 +406,7 @@ func authenticationMiddleware(
cookMgr.DropAccessTokenCookie(req.WithContext(ctx), wrt, accessToken, accessExpiresIn)

// update the with the new access token and inject into the context
newUser, err := ExtractIdentity(&newAccToken)
newUser, err := session.ExtractIdentity(&newAccToken)
if err != nil {
lLog.Error(err.Error())
accessForbidden(wrt, req)
Expand Down Expand Up @@ -492,12 +494,12 @@ func authorizationMiddleware(
clientID string,
skipClientIDCheck bool,
skipIssuerCheck bool,
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*UserContext, error),
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error),
accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context,
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -545,7 +547,7 @@ func authorizationMiddleware(

authzFunc := func(
targetPath string,
userPerms authorization.Permissions,
userPerms models.Permissions,
) (authorization.AuthzDecision, error) {
pat.m.RLock()
token := pat.Token.AccessToken
Expand Down Expand Up @@ -575,7 +577,7 @@ func authorizationMiddleware(
authzFunc,
)
if err != nil {
var umaUser *UserContext
var umaUser *models.UserContext
scope.Logger.Error(err.Error())
scope.Logger.Info("trying to get new uma token")

Expand Down Expand Up @@ -686,7 +688,7 @@ func authorizationMiddleware(
//nolint:cyclop
func checkClaim(
logger *zap.Logger,
user *UserContext,
user *models.UserContext,
claimName string,
match *regexp.Regexp,
resourceURL string,
Expand Down Expand Up @@ -783,7 +785,7 @@ func admissionMiddleware(
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
// we don't need to continue is a decision has been made
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -912,7 +914,7 @@ func identityHeadersMiddleware(

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -946,7 +948,7 @@ func identityHeadersMiddleware(
}
// are we filtering out the cookies
if !enableAuthzCookies {
_ = filterCookies(req, cookieFilter)
_ = cookie.FilterCookies(req, cookieFilter)
}
// inject any custom claims
for claim, header := range customClaims {
Expand Down Expand Up @@ -986,7 +988,7 @@ func securityMiddleware(
})

return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -1026,12 +1028,12 @@ func proxyDenyMiddleware(logger *zap.Logger) func(http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
ctxVal := req.Context().Value(constant.ContextScopeName)

var scope *RequestScope
var scope *models.RequestScope
if ctxVal == nil {
scope = &RequestScope{}
scope = &models.RequestScope{}
} else {
var assertOk bool
scope, assertOk = ctxVal.(*RequestScope)
scope, assertOk = ctxVal.(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down Expand Up @@ -1064,7 +1066,7 @@ func denyMiddleware(
func hmacMiddleware(logger *zap.Logger, encKey string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
Expand Down
Loading

0 comments on commit 9de8d1d

Please sign in to comment.