Skip to content

Commit

Permalink
Add auth query params (#470)
Browse files Browse the repository at this point in the history
 Refactor, move functions to new packages, sort them
  • Loading branch information
p53 authored Jun 4, 2024
1 parent 583cc2f commit c86921f
Show file tree
Hide file tree
Showing 16 changed files with 571 additions and 572 deletions.
89 changes: 0 additions & 89 deletions pkg/keycloak/proxy/forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,101 +18,12 @@ package proxy
import (
"fmt"
"net/http"
"net/url"

"github.com/gogatekeeper/gatekeeper/pkg/apperrors"
"github.com/gogatekeeper/gatekeeper/pkg/constant"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/models"
"github.com/gogatekeeper/gatekeeper/pkg/utils"
"go.uber.org/zap"
)

/*
proxyMiddleware is responsible for handles reverse proxy
request to the upstream endpoint
*/
//nolint:cyclop
func proxyMiddleware(
logger *zap.Logger,
corsOrigins []string,
headers map[string]string,
endpoint *url.URL,
preserveHost bool,
upstream reverseProxy,
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
next.ServeHTTP(wrt, req)

// @step: retrieve the request scope
ctxVal := req.Context().Value(constant.ContextScopeName)
var scope *models.RequestScope
if ctxVal != nil {
var assertOk bool
scope, assertOk = ctxVal.(*models.RequestScope)
if !assertOk {
logger.Error(apperrors.ErrAssertionFailed.Error())
return
}
if scope.AccessDenied {
return
}
}

// @step: add the proxy forwarding headers
req.Header.Set("X-Real-IP", utils.RealIP(req))
if xff := req.Header.Get(constant.HeaderXForwardedFor); xff == "" {
req.Header.Set("X-Forwarded-For", utils.RealIP(req))
} else {
req.Header.Set("X-Forwarded-For", xff)
}
req.Header.Set("X-Forwarded-Host", req.Host)
req.Header.Set("X-Forwarded-Proto", req.Header.Get("X-Forwarded-Proto"))

if len(corsOrigins) > 0 {
// if CORS is enabled by Gatekeeper, do not propagate CORS requests upstream
req.Header.Del("Origin")
}
// @step: add any custom headers to the request
for k, v := range headers {
req.Header.Set(k, v)
}

// @note: by default goproxy only provides a forwarding proxy, thus all requests have to be absolute and we must update the host headers
req.URL.Host = endpoint.Host
req.URL.Scheme = endpoint.Scheme
// Restore the unprocessed original path, so that we pass upstream exactly what we received
// as the resource request.
if scope != nil {
req.URL.Path = scope.Path
req.URL.RawPath = scope.RawPath
}
if v := req.Header.Get("Host"); v != "" {
req.Host = v
req.Header.Del("Host")
} else if !preserveHost {
req.Host = endpoint.Host
}

if utils.IsUpgradedConnection(req) {
clientIP := utils.RealIP(req)
logger.Debug("upgrading the connnection",
zap.String("client_ip", clientIP),
zap.String("remote_addr", req.RemoteAddr),
)
if err := utils.TryUpdateConnection(req, wrt, endpoint); err != nil {
logger.Error("failed to upgrade connection", zap.Error(err))
wrt.WriteHeader(http.StatusInternalServerError)
return
}
return
}

upstream.ServeHTTP(wrt, req)
})
}
}

// forwardProxyHandler is responsible for signing outbound requests
func forwardProxyHandler(
logger *zap.Logger,
Expand Down
25 changes: 13 additions & 12 deletions pkg/keycloak/proxy/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/gogatekeeper/gatekeeper/pkg/constant"
"github.com/gogatekeeper/gatekeeper/pkg/encryption"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/cookie"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/core"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/handlers"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/metrics"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/models"
Expand Down Expand Up @@ -154,7 +155,7 @@ func oauthAuthorizationHandler(
}

scope.Logger.Debug("redirecting to auth_url", zap.String("auth_url", authURL))
redirectToURL(scope.Logger, authURL, wrt, req, http.StatusSeeOther)
core.RedirectToURL(scope.Logger, authURL, wrt, req, http.StatusSeeOther)
}
}

Expand Down Expand Up @@ -203,13 +204,13 @@ func oauthCallbackHandler(
}

scope.Logger.Debug("callback handler")
accessToken, identityToken, refreshToken, err := getCodeFlowTokens(
accessToken, identityToken, refreshToken, err := session.GetCodeFlowTokens(
scope,
writer,
req,
enablePKCE,
cookiePKCEName,
idpClient,
idpClient.RestyClient().GetClient(),
accessForbidden,
accessError,
newOAuth2Config,
Expand All @@ -220,7 +221,7 @@ func oauthCallbackHandler(
}

rawAccessToken := accessToken
oAccToken, _, err := verifyOIDCTokens(
oAccToken, _, err := utils.VerifyOIDCTokens(
req.Context(),
provider,
clientID,
Expand Down Expand Up @@ -272,7 +273,7 @@ func oauthCallbackHandler(
}

oidcTokensCookiesExp = time.Until(stdRefreshClaims.Expiry.Time())
encrypted, err = encryptToken(scope, refreshToken, encryptionKey, "refresh", writer)
encrypted, err = core.EncryptToken(scope, refreshToken, encryptionKey, "refresh", writer)
if err != nil {
return
}
Expand All @@ -297,7 +298,7 @@ func oauthCallbackHandler(
redirectURI := "/"
if req.URL.Query().Get("state") != "" {
if encodedRequestURI, _ := req.Cookie(cookieRequestURIName); encodedRequestURI != nil {
redirectURI = getRequestURIFromCookie(scope, encodedRequestURI)
redirectURI = session.GetRequestURIFromCookie(scope, encodedRequestURI)
}
}

Expand Down Expand Up @@ -338,18 +339,18 @@ func oauthCallbackHandler(

// step: are we encrypting the access token?
if enableEncryptedToken || forceEncryptedCookie {
accessToken, err = encryptToken(scope, accessToken, encryptionKey, "access", writer)
accessToken, err = core.EncryptToken(scope, accessToken, encryptionKey, "access", writer)
if err != nil {
return
}

identityToken, err = encryptToken(scope, identityToken, encryptionKey, "id", writer)
identityToken, err = core.EncryptToken(scope, identityToken, encryptionKey, "id", writer)
if err != nil {
return
}

if enableUma && umaError == nil {
umaToken, err = encryptToken(scope, umaToken, encryptionKey, "uma", writer)
umaToken, err = core.EncryptToken(scope, umaToken, encryptionKey, "uma", writer)
if err != nil {
return
}
Expand All @@ -373,7 +374,7 @@ func oauthCallbackHandler(
}

scope.Logger.Debug("redirecting to", zap.String("location", redirectURI))
redirectToURL(scope.Logger, redirectURI, writer, req, http.StatusSeeOther)
core.RedirectToURL(scope.Logger, redirectURI, writer, req, http.StatusSeeOther)
}
}

Expand Down Expand Up @@ -747,7 +748,7 @@ func logoutHandler(
postLogoutParams,
)

redirectToURL(
core.RedirectToURL(
scope.Logger,
sendTo,
writer,
Expand Down Expand Up @@ -829,7 +830,7 @@ func logoutHandler(

// step: should we redirect the user
if redirectURL != "" {
redirectToURL(scope.Logger, redirectURL, writer, req, http.StatusSeeOther)
core.RedirectToURL(scope.Logger, redirectURL, writer, req, http.StatusSeeOther)
}
}
}
Loading

0 comments on commit c86921f

Please sign in to comment.