Skip to content

Commit

Permalink
github validate token
Browse files Browse the repository at this point in the history
  • Loading branch information
Jing-ze committed Dec 26, 2024
1 parent 99545cf commit 6fbbff4
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 34 deletions.
44 changes: 28 additions & 16 deletions pkg/middleware/stored_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ type StoredSessionLoaderOptions struct {
// How often should sessions be refreshed
RefreshPeriod time.Duration

// Provider based session refreshing
// Provider based session
// return value isAsync and error
RefreshSession func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, error)

// Provider based session validation.
// If the session is older than `RefreshPeriod` but the provider doesn't
// refresh it, we must re-validate using this validation.
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
// return value valid and isAsync
ValidateSession func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, bool)

// Refresh request parameters
RefreshClient wrapper.HttpClient
Expand Down Expand Up @@ -63,7 +65,7 @@ type StoredSessionLoader struct {
store sessionsapi.SessionStore
refreshPeriod time.Duration
sessionRefresher func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, error)
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
sessionValidator func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, bool)

// Refresh request parameters
refreshClient wrapper.HttpClient
Expand All @@ -90,15 +92,25 @@ func (s *StoredSessionLoader) loadSession(next http.Handler) http.Handler {
resumeFlag := args[1].(bool)
updateKeysCallback := func(args ...interface{}) {
resumeFlag := args[0].(bool)
if session != nil && s.validateSession(req.Context(), session) != nil {
session = nil
validateSessionCallback := func(args ...interface{}) {
scope.Session = session
next.ServeHTTP(rw, req)
if resumeFlag {
if rw.Header().Get(util.ResponseCode) == string(http.StatusOK) {
proxywasm.ResumeHttpRequest()
}
}
}
scope.Session = session
next.ServeHTTP(rw, req)
if resumeFlag {
if rw.Header().Get(util.ResponseCode) == string(http.StatusOK) {
proxywasm.ResumeHttpRequest()
if session != nil {
err, isAsync := s.validateSession(req.Context(), session, validateSessionCallback)
if err != nil {
session = nil
}
if !isAsync {
validateSessionCallback()
}
} else {
validateSessionCallback()
}
}
keysNeedsUpdate := (session != nil) && (s.NeedsVerifier)
Expand Down Expand Up @@ -206,14 +218,14 @@ func (s *StoredSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
// validateSession checks whether the session has expired and performs
// provider validation on the session.
// An error implies the session is no longer valid.
func (s *StoredSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error {
func (s *StoredSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState, callback func(args ...interface{})) (error, bool) {
if session.IsExpired() {
return errors.New("session is expired")
return errors.New("session is expired"), false
}

if !s.sessionValidator(ctx, session) {
return errors.New("session is invalid")
valid, isAsync := s.sessionValidator(ctx, session, s.refreshClient, callback, s.refreshRequestTimeout)
if !valid {
return errors.New("session is invalid"), isAsync
}

return nil
return nil, isAsync
}
25 changes: 15 additions & 10 deletions providers/github.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package providers

import (
"context"
"net/http"
"net/url"

"github.com/Jing-ze/oauth2-proxy/pkg/apis/sessions"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

// GitHubProvider represents an GitHub based Identity Provider
Expand Down Expand Up @@ -60,15 +65,15 @@ func NewGitHubProvider(p *ProviderData) *GitHubProvider {
return provider
}

// func makeGitHubHeader(accessToken string) http.Header {
// // extra headers required by the GitHub API when making authenticated requests
// extraHeaders := map[string]string{
// acceptHeader: "application/vnd.github.v3+json",
// }
// return makeAuthorizationHeader(tokenTypeToken, accessToken, extraHeaders)
// }
func makeGitHubHeader(accessToken string) http.Header {
// extra headers required by the GitHub API when making authenticated requests
extraHeaders := map[string]string{
acceptHeader: "application/vnd.github.v3+json",
}
return makeAuthorizationHeader(tokenTypeToken, accessToken, extraHeaders)
}

// ValidateSession validates the AccessToken
// func (p *GitHubProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
// return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken))
// }
func (p *GitHubProvider) ValidateSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) {
return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken), client, callback, timeout)
}
83 changes: 83 additions & 0 deletions providers/internal_util
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package providers

import (
"context"
"fmt"
"net/http"
"net/url"

"github.com/Jing-ze/oauth2-proxy/pkg/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

// stripToken is a helper function to obfuscate "access_token"
// query parameters
func stripToken(endpoint string) string {
return stripParam("access_token", endpoint)
}

// stripParam generalizes the obfuscation of a particular
// query parameter - typically 'access_token' or 'client_secret'
// The parameter's second half is replaced by '...' and returned
// as part of the encoded query parameters.
// If the target parameter isn't found, the endpoint is returned
// unmodified.
func stripParam(param, endpoint string) string {
u, err := url.Parse(endpoint)
if err != nil {
util.Logger.Errorf("error attempting to strip %s: %s", param, err)
return endpoint
}

if u.RawQuery != "" {
values, err := url.ParseQuery(u.RawQuery)
if err != nil {
util.Logger.Errorf("error attempting to strip %s: %s", param, err)
return u.String()
}

if val := values.Get(param); val != "" {
values.Set(param, val[:(len(val)/2)]+"...")
u.RawQuery = values.Encode()
return u.String()
}
}

return endpoint
}

// validateToken returns true if token is valid
func validateToken(ctx context.Context, p Provider, accessToken string, header http.Header, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) {
if accessToken == "" || p.Data().ValidateURL == nil || p.Data().ValidateURL.String() == "" {
return false, false
}
endpoint := p.Data().ValidateURL.String()
if len(header) == 0 {
params := url.Values{"access_token": {accessToken}}
if hasQueryParams(endpoint) {
endpoint = endpoint + "&" + params.Encode()
} else {
endpoint = endpoint + "?" + params.Encode()
}
}

client.Get(endpoint, headers, []byte{}, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
util.Logger.Debugf("%d GET %s %s", statusCode, stripToken(endpoint), responseBody)
if statusCode == 200 {
callback()
} else {
util.SendError(fmt.Sprintf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()), nil, http.StatusInternalServerError)
}
}, timeout)
return true, true
}

// hasQueryParams check if URL has query parameters
func hasQueryParams(endpoint string) bool {
endpointURL, err := url.Parse(endpoint)
if err != nil {
return false
}

return len(endpointURL.RawQuery) != 0
}
10 changes: 5 additions & 5 deletions providers/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,21 @@ func (p *OIDCProvider) EnrichSession(_ context.Context, s *sessions.SessionState
}

// ValidateSession checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) {
_, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil {
util.Logger.Errorf("id_token verification failed: %v", err)
return false
return false, false
}
if p.SkipNonce {
return true
return true, false
}
err = p.checkNonce(s)
if err != nil {
util.Logger.Errorf("nonce verification failed: %v", err)
return false
return false, false
}
return true
return true, false
}

// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
Expand Down
4 changes: 2 additions & 2 deletions providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (b
}

// ValidateSession validates the AccessToken
func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
return true
func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) {
return true, false
}

// RefreshSession refreshes the user's session
Expand Down
2 changes: 1 addition & 1 deletion providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Provider interface {
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
EnrichSession(ctx context.Context, s *sessions.SessionState) error
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
ValidateSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool)
RefreshSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, error)
}

Expand Down

0 comments on commit 6fbbff4

Please sign in to comment.