Skip to content

Commit

Permalink
Add standardized errors where possible (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
p53 authored Oct 26, 2023
1 parent 3fe32ed commit a082bfa
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 118 deletions.
28 changes: 26 additions & 2 deletions pkg/apperrors/apperrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,37 @@ var (
ErrSessionExpiredRefreshOff = errors.New("session expired and access token refreshing is disabled")
ErrRefreshTokenNotFound = errors.New("unable to find refresh token for user")
ErrAccTokenRefreshFailure = errors.New("failed to refresh the access token")
ErrEncryptAccToken = errors.New("unable to encode access token")
ErrEncryptAccToken = errors.New("unable to encrypt access token")
ErrEncryptRefreshToken = errors.New("failed to encrypt refresh token")
ErrEncryptIDToken = errors.New("unable to encode idToken token")
ErrEncryptIDToken = errors.New("unable to encrypt idToken token")

ErrDelTokFromStore = errors.New("failed to remove old token")
ErrSaveTokToStore = errors.New("failed to store refresh token")

ErrLoginWithLoginHandleDisabled = errors.New("attempt to login when login handler is disabled")
ErrMissingLoginCreds = errors.New("request does not have both username and password")
ErrInvalidUserCreds = errors.New("invalid user credentials")
ErrAcquireTokenViaPassCredsGrant = errors.New("unable to request the access token via grant_type 'password'")
ErrExtractIdentityFromAccessToken = errors.New("unable to extract identity from access token")
ErrResponseMissingIDToken = errors.New("token response does not contain an id_token")
ErrResponseMissingExpires = errors.New("token response does not contain expires_in")
ErrParseRefreshToken = errors.New("failed to parse refresh token")
ErrParseIDToken = errors.New("failed to parse id token")
ErrParseAccessToken = errors.New("failed to parse access token")
ErrParseIDTokenClaims = errors.New("faled to parse id token claims")
ErrParseAccessTokenClaims = errors.New("faled to parse access token claims")
ErrParseRefreshTokenClaims = errors.New("faled to parse refresh token claims")

ErrVerifyIDToken = errors.New("unable to verify the ID token")
ErrVerifyAccessToken = errors.New("unable to verify the access token")

ErrCreateRevocationReq = errors.New("unable to construct the revocation request")
ErrRevocationReqFailure = errors.New("request to revocation endpoint failed")
ErrInvalidRevocationResp = errors.New("invalid response from revocation endpoint")

ErrMarshallDiscoveryResp = errors.New("problem marshalling discovery response")
ErrDiscoveryResponseWrite = errors.New("problem during discovery response write")

// config errors

ErrInvalidPostLoginRedirectPath = errors.New("post login redirect path invalid, should be only path not absolute url (no hostname, scheme)")
Expand Down
104 changes: 39 additions & 65 deletions pkg/keycloak/proxy/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (r *OauthProxy) oauthCallbackHandler(writer http.ResponseWriter, req *http.
case r.useStore():
if err = r.StoreRefreshToken(req.Context(), rawAccessToken, encrypted, oidcTokensCookiesExp); err != nil {
scope.Logger.Error(
"failed to save the refresh token in the store",
apperrors.ErrSaveTokToStore.Error(),
zap.Error(err),
zap.String("sub", stdClaims.Subject),
zap.String("email", customClaims.Email),
Expand Down Expand Up @@ -311,7 +311,7 @@ func (r *OauthProxy) oauthCallbackHandler(writer http.ResponseWriter, req *http.
/*
loginHandler provide's a generic endpoint for clients to perform a user_credentials login to the provider
*/
//nolint:funlen,cyclop // refactor
//nolint:cyclop // refactor
func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request) {
scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)

Expand All @@ -321,7 +321,7 @@ func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request)
return
}

errorMsg, code, err := func() (string, int, error) {
code, err := func() (int, error) {
ctx, cancel := context.WithTimeout(
context.Background(),
r.Config.OpenIDProviderTimeout,
Expand All @@ -338,33 +338,30 @@ func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request)
defer cancel()

if !r.Config.EnableLoginHandler {
return "attempt to login when login handler is disabled",
http.StatusNotImplemented,
errors.New("login handler disabled")
return http.StatusNotImplemented,
apperrors.ErrLoginWithLoginHandleDisabled
}

username := req.PostFormValue("username")
password := req.PostFormValue("password")

if username == "" || password == "" {
return "request does not have both username and password",
http.StatusBadRequest,
errors.New("no credentials")
return http.StatusBadRequest,
apperrors.ErrMissingLoginCreds
}

conf := r.newOAuth2Config(r.getRedirectionURL(writer, req))

start := time.Now()
token, err := conf.PasswordCredentialsToken(ctx, username, password)

if err != nil {
if !token.Valid() {
return "invalid user credentials provided", http.StatusUnauthorized, err
return http.StatusUnauthorized,
errors.Join(apperrors.ErrInvalidUserCreds, err)
}

return "unable to request the access token via grant_type 'password'",
http.StatusInternalServerError,
err
return http.StatusInternalServerError,
errors.Join(apperrors.ErrAcquireTokenViaPassCredsGrant, err)
}

// @metric observe the time taken for a login request
Expand All @@ -373,34 +370,28 @@ func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request)
accessToken := token.AccessToken
refreshToken := ""
accessTokenObj, err := jwt.ParseSigned(token.AccessToken)

if err != nil {
return "unable to decode the access token", http.StatusNotImplemented, err
return http.StatusNotImplemented,
errors.Join(apperrors.ErrParseAccessToken, err)
}

identity, err := ExtractIdentity(accessTokenObj)

if err != nil {
return "unable to extract identity from access token",
http.StatusNotImplemented,
err
return http.StatusNotImplemented,
errors.Join(apperrors.ErrExtractIdentityFromAccessToken, err)
}

writer.Header().Set("Content-Type", "application/json")
idToken, assertOk := token.Extra("id_token").(string)

if !assertOk {
return "",
http.StatusInternalServerError,
fmt.Errorf("token response does not contain an id_token")
return http.StatusInternalServerError,
apperrors.ErrResponseMissingIDToken
}

expiresIn, assertOk := token.Extra("expires_in").(float64)

if !assertOk {
return "",
http.StatusInternalServerError,
fmt.Errorf("token response does not contain expires_in")
return http.StatusInternalServerError,
apperrors.ErrResponseMissingExpires
}

// step: are we encrypting the access token?
Expand All @@ -409,28 +400,24 @@ func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request)
if r.Config.EnableEncryptedToken || r.Config.ForceEncryptedCookie {
if accessToken, err = encryption.EncodeText(accessToken, r.Config.EncryptionKey); err != nil {
scope.Logger.Error(apperrors.ErrEncryptAccToken.Error(), zap.Error(err))
return apperrors.ErrEncryptAccToken.Error(),
http.StatusInternalServerError,
err
return http.StatusInternalServerError,
errors.Join(apperrors.ErrEncryptAccToken, err)
}

if idToken, err = encryption.EncodeText(idToken, r.Config.EncryptionKey); err != nil {
scope.Logger.Error(apperrors.ErrEncryptIDToken.Error(), zap.Error(err))
return apperrors.ErrEncryptIDToken.Error(),
http.StatusInternalServerError,
err
return http.StatusInternalServerError,
errors.Join(apperrors.ErrEncryptIDToken, err)
}
}

// step: does the response have a refresh token and we do NOT ignore refresh tokens?
if r.Config.EnableRefreshTokens && token.RefreshToken != "" {
refreshToken, err = encryption.EncodeText(token.RefreshToken, r.Config.EncryptionKey)

if err != nil {
scope.Logger.Error(apperrors.ErrEncryptRefreshToken.Error(), zap.Error(err))
return apperrors.ErrEncryptRefreshToken.Error(),
http.StatusInternalServerError,
err
return http.StatusInternalServerError,
errors.Join(apperrors.ErrEncryptRefreshToken, err)
}

// drop in the access token - cookie expiration = access token
Expand All @@ -452,18 +439,14 @@ func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request)
// notes: not all idp refresh tokens are readable, google for example, so we attempt to decode into
// a jwt and if possible extract the expiration, else we default to 10 days
refreshTokenObj, errRef := jwt.ParseSigned(token.RefreshToken)

if errRef != nil {
scope.Logger.Error("failed to parse refresh token", zap.Error(errRef))
return "failed to parse refresh token",
http.StatusInternalServerError,
errRef
return http.StatusInternalServerError,
errors.Join(apperrors.ErrParseRefreshToken, err)
}

stdRefreshClaims := &jwt.Claims{}

err = refreshTokenObj.UnsafeClaimsWithoutVerification(stdRefreshClaims)

if err != nil {
expiration = 0
} else {
Expand All @@ -474,7 +457,7 @@ func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request)
case true:
if err = r.StoreRefreshToken(req.Context(), token.AccessToken, refreshToken, expiration); err != nil {
scope.Logger.Warn(
"failed to save the refresh token in the store",
apperrors.ErrSaveTokToStore.Error(),
zap.Error(err),
)
}
Expand Down Expand Up @@ -503,10 +486,8 @@ func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request)

if tokenScope != nil {
tScope, assertOk = tokenScope.(string)

if !assertOk {
return "",
http.StatusInternalServerError,
return http.StatusInternalServerError,
apperrors.ErrAssertionFailed
}
}
Expand Down Expand Up @@ -534,21 +515,19 @@ func (r *OauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request)
}

err = json.NewEncoder(writer).Encode(resp)

if err != nil {
return "", http.StatusInternalServerError, err
return http.StatusInternalServerError, err
}

return "", http.StatusOK, nil
return http.StatusOK, nil
}()

if err != nil {
clientIP := utils.RealIP(req)
scope.Logger.Error(errorMsg,
scope.Logger.Error(err.Error(),
zap.String("client_ip", clientIP),
zap.String("remote_addr", req.RemoteAddr),
zap.Error(err))

)
writer.WriteHeader(code)
}
}
Expand Down Expand Up @@ -583,7 +562,6 @@ func (r *OauthProxy) logoutHandler(writer http.ResponseWriter, req *http.Request
}

scope, assertOk := req.Context().Value(constant.ContextScopeName).(*RequestScope)

if !assertOk {
r.Log.Error(apperrors.ErrAssertionFailed.Error())
writer.WriteHeader(http.StatusInternalServerError)
Expand All @@ -592,7 +570,6 @@ func (r *OauthProxy) logoutHandler(writer http.ResponseWriter, req *http.Request

// @step: drop the access token
user, err := r.GetIdentity(req, r.Config.CookieAccessName, "")

if err != nil {
r.accessError(writer, req)
return
Expand All @@ -607,7 +584,6 @@ func (r *OauthProxy) logoutHandler(writer http.ResponseWriter, req *http.Request
}

idToken, _, err := r.retrieveIDToken(req)

// we are doing it so that in case with no-redirects=true, we can pass
// id token in authorization header
if err != nil {
Expand All @@ -624,7 +600,7 @@ func (r *OauthProxy) logoutHandler(writer http.ResponseWriter, req *http.Request
go func() {
if err = r.DeleteRefreshToken(req.Context(), user.RawToken); err != nil {
scope.Logger.Error(
"unable to remove the refresh token from store",
apperrors.ErrDelTokFromStore.Error(),
zap.Error(err),
)
}
Expand Down Expand Up @@ -695,9 +671,8 @@ func (r *OauthProxy) logoutHandler(writer http.ResponseWriter, req *http.Request
fmt.Sprintf("token=%s", identityToken),
),
)

if err != nil {
scope.Logger.Error("unable to construct the revocation request", zap.Error(err))
scope.Logger.Error(apperrors.ErrCreateRevocationReq.Error(), zap.Error(err))
writer.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -708,9 +683,8 @@ func (r *OauthProxy) logoutHandler(writer http.ResponseWriter, req *http.Request

start := time.Now()
response, err := client.Do(request)

if err != nil {
scope.Logger.Error("unable to post to revocation endpoint", zap.Error(err))
scope.Logger.Error(apperrors.ErrRevocationReqFailure.Error(), zap.Error(err))
writer.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -731,7 +705,7 @@ func (r *OauthProxy) logoutHandler(writer http.ResponseWriter, req *http.Request
content, _ := io.ReadAll(response.Body)

scope.Logger.Error(
"invalid response from revocation endpoint",
apperrors.ErrInvalidRevocationResp.Error(),
zap.Int("status", response.StatusCode),
zap.String("response", string(content)),
)
Expand Down Expand Up @@ -865,7 +839,7 @@ func (r *OauthProxy) discoveryHandler(wrt http.ResponseWriter, _ *http.Request)

if err != nil {
r.Log.Error(
"problem marshalling response",
apperrors.ErrMarshallDiscoveryResp.Error(),
zap.String("error", err.Error()),
)

Expand All @@ -879,7 +853,7 @@ func (r *OauthProxy) discoveryHandler(wrt http.ResponseWriter, _ *http.Request)

if err != nil {
r.Log.Error(
"problem during response write",
apperrors.ErrDiscoveryResponseWrite.Error(),
zap.String("error", err.Error()),
)
}
Expand Down
Loading

0 comments on commit a082bfa

Please sign in to comment.