From 6fbbff456082ff87c1f8da818fcdb9cf557af29e Mon Sep 17 00:00:00 2001 From: daijingze_mac <18373118@buaa.edu.cn> Date: Thu, 26 Dec 2024 11:25:25 +0800 Subject: [PATCH] github validate token --- pkg/middleware/stored_session.go | 44 +++++++++++------ providers/github.go | 25 ++++++---- providers/internal_util | 83 ++++++++++++++++++++++++++++++++ providers/oidc.go | 10 ++-- providers/provider_default.go | 4 +- providers/providers.go | 2 +- 6 files changed, 134 insertions(+), 34 deletions(-) create mode 100644 providers/internal_util diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 1bbf57c785..de2fd17f57 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -28,13 +28,15 @@ type StoredSessionLoaderOptions struct { // How often should sessions be refreshed RefreshPeriod time.Duration - // Provider based session refreshing + // Provider based session + // return value isAsync and error RefreshSession func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, error) // Provider based session validation. // If the session is older than `RefreshPeriod` but the provider doesn't // refresh it, we must re-validate using this validation. - ValidateSession func(context.Context, *sessionsapi.SessionState) bool + // return value valid and isAsync + ValidateSession func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, bool) // Refresh request parameters RefreshClient wrapper.HttpClient @@ -63,7 +65,7 @@ type StoredSessionLoader struct { store sessionsapi.SessionStore refreshPeriod time.Duration sessionRefresher func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, error) - sessionValidator func(context.Context, *sessionsapi.SessionState) bool + sessionValidator func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, bool) // Refresh request parameters refreshClient wrapper.HttpClient @@ -90,15 +92,25 @@ func (s *StoredSessionLoader) loadSession(next http.Handler) http.Handler { resumeFlag := args[1].(bool) updateKeysCallback := func(args ...interface{}) { resumeFlag := args[0].(bool) - if session != nil && s.validateSession(req.Context(), session) != nil { - session = nil + validateSessionCallback := func(args ...interface{}) { + scope.Session = session + next.ServeHTTP(rw, req) + if resumeFlag { + if rw.Header().Get(util.ResponseCode) == string(http.StatusOK) { + proxywasm.ResumeHttpRequest() + } + } } - scope.Session = session - next.ServeHTTP(rw, req) - if resumeFlag { - if rw.Header().Get(util.ResponseCode) == string(http.StatusOK) { - proxywasm.ResumeHttpRequest() + if session != nil { + err, isAsync := s.validateSession(req.Context(), session, validateSessionCallback) + if err != nil { + session = nil + } + if !isAsync { + validateSessionCallback() } + } else { + validateSessionCallback() } } keysNeedsUpdate := (session != nil) && (s.NeedsVerifier) @@ -206,14 +218,14 @@ func (s *StoredSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R // validateSession checks whether the session has expired and performs // provider validation on the session. // An error implies the session is no longer valid. -func (s *StoredSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error { +func (s *StoredSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState, callback func(args ...interface{})) (error, bool) { if session.IsExpired() { - return errors.New("session is expired") + return errors.New("session is expired"), false } - - if !s.sessionValidator(ctx, session) { - return errors.New("session is invalid") + valid, isAsync := s.sessionValidator(ctx, session, s.refreshClient, callback, s.refreshRequestTimeout) + if !valid { + return errors.New("session is invalid"), isAsync } - return nil + return nil, isAsync } diff --git a/providers/github.go b/providers/github.go index 750beba0e7..45ffae9882 100644 --- a/providers/github.go +++ b/providers/github.go @@ -1,7 +1,12 @@ package providers import ( + "context" + "net/http" "net/url" + + "github.com/Jing-ze/oauth2-proxy/pkg/apis/sessions" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) // GitHubProvider represents an GitHub based Identity Provider @@ -60,15 +65,15 @@ func NewGitHubProvider(p *ProviderData) *GitHubProvider { return provider } -// func makeGitHubHeader(accessToken string) http.Header { -// // extra headers required by the GitHub API when making authenticated requests -// extraHeaders := map[string]string{ -// acceptHeader: "application/vnd.github.v3+json", -// } -// return makeAuthorizationHeader(tokenTypeToken, accessToken, extraHeaders) -// } +func makeGitHubHeader(accessToken string) http.Header { + // extra headers required by the GitHub API when making authenticated requests + extraHeaders := map[string]string{ + acceptHeader: "application/vnd.github.v3+json", + } + return makeAuthorizationHeader(tokenTypeToken, accessToken, extraHeaders) +} // ValidateSession validates the AccessToken -// func (p *GitHubProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { -// return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken)) -// } +func (p *GitHubProvider) ValidateSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) { + return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken), client, callback, timeout) +} diff --git a/providers/internal_util b/providers/internal_util new file mode 100644 index 0000000000..f42e4e15c9 --- /dev/null +++ b/providers/internal_util @@ -0,0 +1,83 @@ +package providers + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/Jing-ze/oauth2-proxy/pkg/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +// stripToken is a helper function to obfuscate "access_token" +// query parameters +func stripToken(endpoint string) string { + return stripParam("access_token", endpoint) +} + +// stripParam generalizes the obfuscation of a particular +// query parameter - typically 'access_token' or 'client_secret' +// The parameter's second half is replaced by '...' and returned +// as part of the encoded query parameters. +// If the target parameter isn't found, the endpoint is returned +// unmodified. +func stripParam(param, endpoint string) string { + u, err := url.Parse(endpoint) + if err != nil { + util.Logger.Errorf("error attempting to strip %s: %s", param, err) + return endpoint + } + + if u.RawQuery != "" { + values, err := url.ParseQuery(u.RawQuery) + if err != nil { + util.Logger.Errorf("error attempting to strip %s: %s", param, err) + return u.String() + } + + if val := values.Get(param); val != "" { + values.Set(param, val[:(len(val)/2)]+"...") + u.RawQuery = values.Encode() + return u.String() + } + } + + return endpoint +} + +// validateToken returns true if token is valid +func validateToken(ctx context.Context, p Provider, accessToken string, header http.Header, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) { + if accessToken == "" || p.Data().ValidateURL == nil || p.Data().ValidateURL.String() == "" { + return false, false + } + endpoint := p.Data().ValidateURL.String() + if len(header) == 0 { + params := url.Values{"access_token": {accessToken}} + if hasQueryParams(endpoint) { + endpoint = endpoint + "&" + params.Encode() + } else { + endpoint = endpoint + "?" + params.Encode() + } + } + + client.Get(endpoint, headers, []byte{}, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + util.Logger.Debugf("%d GET %s %s", statusCode, stripToken(endpoint), responseBody) + if statusCode == 200 { + callback() + } else { + util.SendError(fmt.Sprintf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()), nil, http.StatusInternalServerError) + } + }, timeout) + return true, true +} + +// hasQueryParams check if URL has query parameters +func hasQueryParams(endpoint string) bool { + endpointURL, err := url.Parse(endpoint) + if err != nil { + return false + } + + return len(endpointURL.RawQuery) != 0 +} diff --git a/providers/oidc.go b/providers/oidc.go index 5354827981..609feaf08d 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -117,21 +117,21 @@ func (p *OIDCProvider) EnrichSession(_ context.Context, s *sessions.SessionState } // ValidateSession checks that the session's IDToken is still valid -func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { +func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) { _, err := p.Verifier.Verify(ctx, s.IDToken) if err != nil { util.Logger.Errorf("id_token verification failed: %v", err) - return false + return false, false } if p.SkipNonce { - return true + return true, false } err = p.checkNonce(s) if err != nil { util.Logger.Errorf("nonce verification failed: %v", err) - return false + return false, false } - return true + return true, false } // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens diff --git a/providers/provider_default.go b/providers/provider_default.go index 9311593fd8..ff2cb0927c 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -141,8 +141,8 @@ func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (b } // ValidateSession validates the AccessToken -func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { - return true +func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) { + return true, false } // RefreshSession refreshes the user's session diff --git a/providers/providers.go b/providers/providers.go index ec1de38fdc..236f0f244f 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -31,7 +31,7 @@ type Provider interface { GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) EnrichSession(ctx context.Context, s *sessions.SessionState) error Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) - ValidateSession(ctx context.Context, s *sessions.SessionState) bool + ValidateSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, bool) RefreshSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, error) }