Skip to content

Commit

Permalink
Merge pull request higress-group#5 from Jing-ze/provider-github
Browse files Browse the repository at this point in the history
feat: support provider type github
  • Loading branch information
johnlanni authored Dec 27, 2024
2 parents 2e4ee25 + 55566a0 commit c1a05d7
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 78 deletions.
89 changes: 56 additions & 33 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ type OAuthProxy struct {
appDirector redirect.AppDirector

passAuthorization bool
passAccessToken bool
encodeState bool

client wrapper.HttpClient
client wrapper.HttpClient
validateClient wrapper.HttpClient
}

// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
Expand Down Expand Up @@ -103,11 +105,19 @@ func NewOAuthProxy(opts *options.Options) (*OAuthProxy, error) {
return nil, err
}

var validateServiceClient wrapper.HttpClient
if opts.ValidateService.ServiceName != "" {
validateServiceClient, err = opts.ValidateService.NewService()
if err != nil {
return nil, err
}
}

preAuthChain, err := buildPreAuthChain(opts)
if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
}
sessionChain := buildSessionChain(opts, provider, sessionStore, serviceClient)
sessionChain := buildSessionChain(opts, provider, sessionStore, serviceClient, validateServiceClient)

redirectValidator := redirect.NewValidator(opts.WhitelistDomains)
appDirector := redirect.NewAppDirector(redirect.AppDirectorOpts{
Expand Down Expand Up @@ -137,16 +147,18 @@ func NewOAuthProxy(opts *options.Options) (*OAuthProxy, error) {
appDirector: appDirector,
encodeState: opts.EncodeState,
passAuthorization: opts.PassAuthorization,
passAccessToken: opts.PassAccessToken,

client: serviceClient,
client: serviceClient,
validateClient: validateServiceClient,
}
p.buildServeMux(opts.ProxyPrefix)

return p, nil
}

func SetLogger(log wrapper.Log) {
util.Logger = &log
util.Logger = log
}

func (p *OAuthProxy) buildServeMux(proxyPrefix string) {
Expand Down Expand Up @@ -183,16 +195,18 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
return chain, nil
}

func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, serviceClient wrapper.HttpClient) alice.Chain {
func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, serviceClient wrapper.HttpClient, validateClient wrapper.HttpClient) alice.Chain {
chain := alice.New()

ss, loadSession := middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
RefreshClient: serviceClient,
RefreshRequestTimeout: provider.Data().RedeemTimeout,
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
RefreshClient: serviceClient,
ValidateClient: validateClient,
RefreshRequestTimeout: provider.Data().RedeemTimeout,
ValidateRequestTimeout: provider.Data().RedeemTimeout,
})
chain = chain.Append(loadSession)
provider.Data().StoredSession = ss
Expand Down Expand Up @@ -381,30 +395,35 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
csrf.SetSessionNonce(session)

updateKeysCallback := func(args ...interface{}) {
if !p.provider.ValidateSession(req.Context(), session) {
util.SendError(fmt.Sprintf("Session validation failed: %s", session), rw, http.StatusForbidden)
return
}
util.Logger.Debug("Session validated successfully")
if !p.redirectValidator.IsValidRedirect(appRedirect) {
appRedirect = "/"
util.Logger.Debugf("Invalid redirect, defaulting to root: %s", appRedirect)
}
// set cookie, or deny
authorized, err := p.provider.Authorize(req.Context(), session)
if err != nil {
util.Logger.Errorf("Error with authorization: %v", err)
}
if p.validator(session.Email) && authorized {
util.Logger.Infof("Authenticated successfully via OAuth2: %s", session)
err := p.SaveSession(rw, req, session)
validateSessionCallback := func(args ...interface{}) {
util.Logger.Debug("Session validated successfully")
if !p.redirectValidator.IsValidRedirect(appRedirect) {
appRedirect = "/"
util.Logger.Debugf("Invalid redirect, defaulting to root: %s", appRedirect)
}
// set cookie, or deny
authorized, err := p.provider.Authorize(req.Context(), session)
if err != nil {
util.SendError(fmt.Sprintf("Error saving session state: %v", err), rw, http.StatusInternalServerError)
return
util.Logger.Errorf("Error with authorization: %v", err)
}
if p.validator(session.Email) && authorized {
util.Logger.Infof("Authenticated successfully via OAuth2: %s", session)
err := p.SaveSession(rw, req, session)
if err != nil {
util.SendError(fmt.Sprintf("Error saving session state: %v", err), rw, http.StatusInternalServerError)
return
}
redirectToLocation(rw, appRedirect)
} else {
util.SendError("Invalid authentication via OAuth2: unauthorized", rw, http.StatusForbidden)
}
redirectToLocation(rw, appRedirect)
} else {
util.SendError("Invalid authentication via OAuth2: unauthorized", rw, http.StatusForbidden)
}
valid, isAsync := p.provider.ValidateSession(req.Context(), session, p.validateClient, validateSessionCallback, p.provider.Data().RedeemTimeout)
if !valid {
util.SendError(fmt.Sprintf("Session validation failed: %s", session), rw, http.StatusForbidden)
return
} else if !isAsync {
validateSessionCallback()
}
}
if p.provider.Data().NeedsVerifier {
Expand Down Expand Up @@ -441,6 +460,10 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
proxywasm.AddHttpRequestHeader("Authorization", fmt.Sprintf("%s %s", providers.TokenTypeBearer, session.IDToken))
util.Logger.Debug("Authorization header add id token")
}
if p.passAccessToken {
proxywasm.AddHttpRequestHeader("X-Forwarded-Access-Token", session.AccessToken)
util.Logger.Debug("X-Forwarded-Access-Token header add access token")
}
if cookies, ok := rw.Header()[SetCookieHeader]; ok && len(cookies) > 0 {
newCookieValue := strings.Join(cookies, ",")
if p.ctx != nil {
Expand Down
11 changes: 7 additions & 4 deletions pkg/apis/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ type Options struct {

WhitelistDomains []string `mapstructure:"whitelist_domains"`

Cookie Cookie `mapstructure:",squash"`
Session SessionOptions `mapstructure:",squash"`
Service Service `mapstructure:",squash"`
MatchRules MatchRules `mapstructure:",squash"`
Cookie Cookie `mapstructure:",squash"`
Session SessionOptions `mapstructure:",squash"`
Service Service `mapstructure:",squash"`
ValidateService ValidateService `mapstructure:",squash"`
MatchRules MatchRules `mapstructure:",squash"`

Providers Providers

SkipAuthPreflight bool `mapstructure:"skip_auth_preflight"`
EncodeState bool `mapstructure:"encode_state"`
PassAuthorization bool `mapstructure:"pass_authorization_header"`
PassAccessToken bool `mapstructure:"pass_access_token"`

VerifierInterval time.Duration `mapstructure:"verifier_interval"`
UpdateKeysInterval time.Duration `mapstructure:"update_keys_interval"`
Expand All @@ -57,6 +59,7 @@ func NewOptions() *Options {
Session: sessionOptionsDefaults(),
SkipAuthPreflight: false,
PassAuthorization: true,
PassAccessToken: false,
VerifierInterval: 2 * time.Second, // 5 seconds
UpdateKeysInterval: 24 * time.Hour, // 24 hours
MatchRules: matchRulesDefaults(),
Expand Down
2 changes: 2 additions & 0 deletions pkg/apis/options/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ const (
OIDCProvider ProviderType = "oidc"

AliyunProvider ProviderType = "aliyun"

GitHubProvider ProviderType = "github"
)

type OIDCOptions struct {
Expand Down
19 changes: 19 additions & 0 deletions pkg/apis/options/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,22 @@ func (s *Service) NewService() (wrapper.HttpClient, error) {
})
return client, nil
}

type ValidateService struct {
// 带服务类型的完整 FQDN 名称,例如 keycloak.static, auth.dns
ServiceName string `mapstructure:"validate_service_name"`
ServicePort int64 `mapstructure:"validate_service_port"`
ServiceHost string `mapstructure:"validate_service_host"`
}

func (s *ValidateService) NewService() (wrapper.HttpClient, error) {
if s.ServiceName == "" || s.ServicePort == 0 {
return nil, errors.New("invalid service config")
}
client := wrapper.NewClusterClient(&wrapper.FQDNCluster{
FQDN: s.ServiceName,
Host: s.ServiceHost,
Port: s.ServicePort,
})
return client, nil
}
73 changes: 49 additions & 24 deletions pkg/middleware/stored_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,23 @@ type StoredSessionLoaderOptions struct {
// How often should sessions be refreshed
RefreshPeriod time.Duration

// Provider based session refreshing
// Provider based session
// return value isAsync and error
RefreshSession func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, error)

// Provider based session validation.
// If the session is older than `RefreshPeriod` but the provider doesn't
// refresh it, we must re-validate using this validation.
ValidateSession func(context.Context, *sessionsapi.SessionState) bool
// return value valid and isAsync
ValidateSession func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, bool)

// Refresh request parameters
RefreshClient wrapper.HttpClient
RefreshRequestTimeout uint32

// Validate request parameters
ValidateClient wrapper.HttpClient
ValidateRequestTimeout uint32
}

// NewStoredSessionLoader creates a new StoredSessionLoader which loads
Expand All @@ -47,12 +53,14 @@ type StoredSessionLoaderOptions struct {
// If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) (*StoredSessionLoader, alice.Constructor) {
ss := &StoredSessionLoader{
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
sessionRefresher: opts.RefreshSession,
sessionValidator: opts.ValidateSession,
refreshClient: opts.RefreshClient,
refreshRequestTimeout: opts.RefreshRequestTimeout,
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
sessionRefresher: opts.RefreshSession,
sessionValidator: opts.ValidateSession,
refreshClient: opts.RefreshClient,
refreshRequestTimeout: opts.RefreshRequestTimeout,
validateClient: opts.ValidateClient,
validateRequestTimeout: opts.ValidateRequestTimeout,
}
return ss, ss.loadSession
}
Expand All @@ -63,13 +71,18 @@ type StoredSessionLoader struct {
store sessionsapi.SessionStore
refreshPeriod time.Duration
sessionRefresher func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, error)
sessionValidator func(context.Context, *sessionsapi.SessionState) bool
sessionValidator func(context.Context, *sessionsapi.SessionState, wrapper.HttpClient, func(args ...interface{}), uint32) (bool, bool)

// Refresh request parameters
refreshClient wrapper.HttpClient
refreshRequestTimeout uint32
RemoteKeySet *oidc.KeySet
NeedsVerifier bool

// Validate request parameters
validateClient wrapper.HttpClient
validateRequestTimeout uint32

RemoteKeySet *oidc.KeySet
NeedsVerifier bool
}

// loadSession attempts to load a session as identified by the request cookies.
Expand All @@ -90,15 +103,27 @@ func (s *StoredSessionLoader) loadSession(next http.Handler) http.Handler {
resumeFlag := args[1].(bool)
updateKeysCallback := func(args ...interface{}) {
resumeFlag := args[0].(bool)
if session != nil && s.validateSession(req.Context(), session) != nil {
session = nil
validateSessionCallback := func(args ...interface{}) {
resumeFlag := args[0].(bool)
sessionValid := args[1].(bool)
if !sessionValid {
session = nil
}
scope.Session = session
next.ServeHTTP(rw, req)
if resumeFlag {
if rw.Header().Get(util.ResponseCode) == string(http.StatusOK) {
proxywasm.ResumeHttpRequest()
}
}
}
scope.Session = session
next.ServeHTTP(rw, req)
if resumeFlag {
if rw.Header().Get(util.ResponseCode) == string(http.StatusOK) {
proxywasm.ResumeHttpRequest()
if session != nil {
err, isAsync := s.validateSession(req.Context(), session, validateSessionCallback)
if !isAsync {
validateSessionCallback(resumeFlag, err == nil)
}
} else {
validateSessionCallback(resumeFlag, true)
}
}
keysNeedsUpdate := (session != nil) && (s.NeedsVerifier)
Expand Down Expand Up @@ -206,14 +231,14 @@ func (s *StoredSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R
// validateSession checks whether the session has expired and performs
// provider validation on the session.
// An error implies the session is no longer valid.
func (s *StoredSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error {
func (s *StoredSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState, callback func(args ...interface{})) (error, bool) {
if session.IsExpired() {
return errors.New("session is expired")
return errors.New("session is expired"), false
}

if !s.sessionValidator(ctx, session) {
return errors.New("session is invalid")
valid, isAsync := s.sessionValidator(ctx, session, s.validateClient, callback, s.validateRequestTimeout)
if !valid {
return errors.New("session is invalid"), isAsync
}

return nil
return nil, isAsync
}
2 changes: 1 addition & 1 deletion pkg/util/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)

var Logger *wrapper.Log
var Logger wrapper.Log

func SendError(errMsg string, rw http.ResponseWriter, status int) {
Logger.Errorf(errMsg)
Expand Down
Loading

0 comments on commit c1a05d7

Please sign in to comment.