From 6a6dc12f30e8b1761e290170995c1b340c1f5152 Mon Sep 17 00:00:00 2001 From: jackHay22 Date: Wed, 27 Sep 2023 13:38:55 -0400 Subject: [PATCH 1/4] WIP: add auth type to external account linking --- models/auth/oauth2.go | 16 ++-- models/auth/saml.go | 28 ------ models/auth/source.go | 8 ++ routers/web/auth/auth.go | 20 ++-- routers/web/auth/linkaccount.go | 34 +++---- routers/web/auth/oauth.go | 19 ++-- routers/web/auth/openid.go | 5 +- routers/web/auth/saml.go | 14 +-- services/auth/source/oauth2/init.go | 2 +- services/auth/source/oauth2/providers.go | 2 +- services/auth/source/saml/init.go | 2 +- services/auth/source/saml/providers.go | 85 ----------------- services/auth/source/saml/source.go | 97 +++++++++++++++++++- services/auth/source/saml/source_callout.go | 21 ++++- services/auth/source/saml/source_register.go | 11 +-- services/externalaccount/link.go | 11 ++- services/externalaccount/user.go | 12 +-- 17 files changed, 195 insertions(+), 192 deletions(-) delete mode 100644 models/auth/saml.go diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go index 9c419eff69af..605f775edaef 100644 --- a/models/auth/oauth2.go +++ b/models/auth/oauth2.go @@ -386,7 +386,7 @@ func ListOAuth2Applications(uid int64, listOptions db.ListOptions) ([]*OAuth2App return apps, total, err } -////////////////////////////////////////////////////// +// //////////////////////////////////////////////////// // OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime. type OAuth2AuthorizationCode struct { @@ -461,7 +461,7 @@ func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth return auth, nil } -////////////////////////////////////////////////////// +// //////////////////////////////////////////////////// // OAuth2Grant represents the permission of an user for a specific application to access resources type OAuth2Grant struct { @@ -626,19 +626,19 @@ func (err ErrOAuthApplicationNotFound) Unwrap() error { return util.ErrNotExist } -// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources -func GetActiveOAuth2ProviderSources() ([]*Source, error) { +// GetActiveAuthProviderSources returns all actived LoginOAuth2 sources +func GetActiveAuthProviderSources(authType Type) ([]*Source, error) { sources := make([]*Source, 0, 1) - if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil { + if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, authType).Find(&sources); err != nil { return nil, err } return sources, nil } -// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name -func GetActiveOAuth2SourceByName(name string) (*Source, error) { +// GetActiveAuthSourceByName returns a OAuth2 AuthSource based on the given name +func GetActiveAuthSourceByName(name string, authType Type) (*Source, error) { authSource := new(Source) - has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource) + has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, authType, true).Get(authSource) if err != nil { return nil, err } diff --git a/models/auth/saml.go b/models/auth/saml.go deleted file mode 100644 index 5d20baa82a3c..000000000000 --- a/models/auth/saml.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2023 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package auth - -import ( - "code.gitea.io/gitea/models/db" -) - -// GetActiveSAMLProviderLoginSources returns all actived LoginSAML sources -func GetActiveSAMLProviderLoginSources() ([]*Source, error) { - sources := make([]*Source, 0, 1) - if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, SAML).Find(&sources); err != nil { - return nil, err - } - return sources, nil -} - -// GetActiveSAMLLoginSourceByName returns a SAML LoginSource based on the given name -func GetActiveSAMLLoginSourceByName(name string) (*Source, error) { - loginSource := new(Source) - has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, SAML, true).Get(loginSource) - if !has || err != nil { - return nil, err - } - - return loginSource, nil -} diff --git a/models/auth/source.go b/models/auth/source.go index f1142c714ac9..8963c442681e 100644 --- a/models/auth/source.go +++ b/models/auth/source.go @@ -8,6 +8,8 @@ import ( "fmt" "reflect" + "github.com/markbates/goth" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/timeutil" @@ -121,6 +123,12 @@ type Source struct { UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` } +// LinkAccountUser is used to link an external user with a local user +type LinkAccountUser struct { + Type Type + GothUser goth.User +} + // TableName xorm will read the table name from this method func (Source) TableName() string { return "login_source" diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index 61e083e22845..36faabfa7a44 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -154,7 +154,7 @@ func SignIn(ctx *context.Context) { ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers - samlProviders, err := auth.GetActiveSAMLProviderLoginSources() + samlProviders, err := auth.GetActiveAuthProviderSources(auth.SAML) if err != nil { ctx.ServerError("UserSignIn", err) return @@ -496,7 +496,7 @@ func SignUpPost(ctx *context.Context) { Passwd: form.Password, } - if !createAndHandleCreatedUser(ctx, tplSignUp, form, u, nil, nil, false) { + if !createAndHandleCreatedUser(ctx, tplSignUp, form, u, nil, nil, false, auth.NoType) { // error already handled return } @@ -507,16 +507,16 @@ func SignUpPost(ctx *context.Context) { // createAndHandleCreatedUser calls createUserInContext and // then handleUserCreated. -func createAndHandleCreatedUser(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool) bool { - if !createUserInContext(ctx, tpl, form, u, overwrites, gothUser, allowLink) { +func createAndHandleCreatedUser(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool, authType auth.Type) bool { + if !createUserInContext(ctx, tpl, form, u, overwrites, gothUser, allowLink, authType) { return false } - return handleUserCreated(ctx, u, gothUser) + return handleUserCreated(ctx, u, gothUser, authType) } // createUserInContext creates a user and handles errors within a given context. // Optionally a template can be specified. -func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool) (ok bool) { +func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool, authType auth.Type) (ok bool) { if err := user_model.CreateUser(ctx, u, overwrites); err != nil { if allowLink && (user_model.IsErrUserAlreadyExist(err) || user_model.IsErrEmailAlreadyUsed(err)) { if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingAuto { @@ -533,10 +533,10 @@ func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *us } // TODO: probably we should respect 'remember' user's choice... - linkAccount(ctx, user, *gothUser, true) + linkAccount(ctx, user, *gothUser, true, authType) return false // user is already created here, all redirects are handled } else if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingLogin { - showLinkingLogin(ctx, *gothUser) + showLinkingLogin(ctx, *gothUser, authType) return false // user will be created only after linking login } } @@ -582,7 +582,7 @@ func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *us // handleUserCreated does additional steps after a new user is created. // It auto-sets admin for the only user, updates the optional external user and // sends a confirmation email if required. -func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.User) (ok bool) { +func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.User, authType auth.Type) (ok bool) { // Auto-set admin for the only user. if user_model.CountUsers(ctx, nil) == 1 { u.IsAdmin = true @@ -596,7 +596,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth. // update external user information if gothUser != nil { - if err := externalaccount.UpdateExternalUser(u, *gothUser); err != nil { + if err := externalaccount.UpdateExternalUser(u, *gothUser, authType); err != nil { if !errors.Is(err, util.ErrNotExist) { log.Error("UpdateExternalUser failed: %v", err) } diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index 42d846180d61..3820b33ce55a 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -48,13 +48,13 @@ func LinkAccount(ctx *context.Context) { ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" - gothUser := ctx.Session.Get("linkAccountGothUser") - if gothUser == nil { + externalLinkUser := ctx.Session.Get("linkAccountUser") + if externalLinkUser == nil { ctx.ServerError("UserSignIn", errors.New("not in LinkAccount session")) return } - gu, _ := gothUser.(goth.User) + gu := externalLinkUser.(auth.LinkAccountUser).GothUser uname := getUserName(&gu) email := gu.Email ctx.Data["user_name"] = uname @@ -131,12 +131,14 @@ func LinkAccountPostSignIn(ctx *context.Context) { ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" - gothUser := ctx.Session.Get("linkAccountGothUser") - if gothUser == nil { + externalLinkUserInterface := ctx.Session.Get("linkAccountUser") + if externalLinkUserInterface == nil { ctx.ServerError("UserSignIn", errors.New("not in LinkAccount session")) return } + externalLinkUser := externalLinkUserInterface.(auth.LinkAccountUser) + if ctx.HasError() { ctx.HTML(http.StatusOK, tplLinkAccount) return @@ -148,10 +150,10 @@ func LinkAccountPostSignIn(ctx *context.Context) { return } - linkAccount(ctx, u, gothUser.(goth.User), signInForm.Remember) + linkAccount(ctx, u, externalLinkUser.GothUser, signInForm.Remember, externalLinkUser.Type) } -func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, remember bool) { +func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, remember bool, authType auth.Type) { updateAvatarIfNeed(gothUser.AvatarURL, u) // If this user is enrolled in 2FA, we can't sign the user in just yet. @@ -164,7 +166,7 @@ func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, r return } - err = externalaccount.LinkAccountToUser(ctx, u, gothUser) + err = externalaccount.LinkAccountToUser(ctx, u, gothUser, authType) if err != nil { ctx.ServerError("UserLinkAccount", err) return @@ -218,14 +220,14 @@ func LinkAccountPostRegister(ctx *context.Context) { ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin" ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup" - gothUserInterface := ctx.Session.Get("linkAccountGothUser") - if gothUserInterface == nil { + externalLinkUser := ctx.Session.Get("linkAccountUser") + if externalLinkUser == nil { ctx.ServerError("UserSignUp", errors.New("not in LinkAccount session")) return } - gothUser, ok := gothUserInterface.(goth.User) + linkUser, ok := externalLinkUser.(auth.LinkAccountUser) if !ok { - ctx.ServerError("UserSignUp", fmt.Errorf("session linkAccountGothUser type is %t but not goth.User", gothUserInterface)) + ctx.ServerError("UserSignUp", fmt.Errorf("session linkAccountUser type is %t but not goth.User", externalLinkUser)) return } @@ -271,7 +273,7 @@ func LinkAccountPostRegister(ctx *context.Context) { } } - authSource, err := auth.GetActiveOAuth2SourceByName(gothUser.Provider) + authSource, err := auth.GetActiveAuthSourceByName(linkUser.GothUser.Name, linkUser.Type) if err != nil { ctx.ServerError("CreateUser", err) return @@ -283,16 +285,16 @@ func LinkAccountPostRegister(ctx *context.Context) { Passwd: form.Password, LoginType: auth.OAuth2, LoginSource: authSource.ID, - LoginName: gothUser.UserID, + LoginName: linkUser.GothUser.UserID, } - if !createAndHandleCreatedUser(ctx, tplLinkAccount, form, u, nil, &gothUser, false) { + if !createAndHandleCreatedUser(ctx, tplLinkAccount, form, u, nil, &linkUser.GothUser, false, linkUser.Type) { // error already handled return } source := authSource.Cfg.(*oauth2.Source) - if err := syncGroupsToTeams(ctx, source, &gothUser, u); err != nil { + if err := syncGroupsToTeams(ctx, source, &linkUser.GothUser, u); err != nil { ctx.ServerError("SyncGroupsToTeams", err) return } diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 79f4711c26f8..b20dc64413cb 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -847,7 +847,7 @@ func handleAuthorizeError(ctx *context.Context, authErr AuthorizeError, redirect func SignInOAuth(ctx *context.Context) { provider := ctx.Params(":provider") - authSource, err := auth.GetActiveOAuth2SourceByName(provider) + authSource, err := auth.GetActiveAuthSourceByName(provider, auth.OAuth2) if err != nil { ctx.ServerError("SignIn", err) return @@ -898,7 +898,7 @@ func SignInOAuthCallback(ctx *context.Context) { } // first look if the provider is still active - authSource, err := auth.GetActiveOAuth2SourceByName(provider) + authSource, err := auth.GetActiveAuthSourceByName(provider, auth.OAuth2) if err != nil { ctx.ServerError("SignIn", err) return @@ -941,7 +941,7 @@ func SignInOAuthCallback(ctx *context.Context) { if u == nil { if ctx.Doer != nil { // attach user to already logged in user - err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser) + err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser, auth.OAuth2) if err != nil { ctx.ServerError("UserLinkAccount", err) return @@ -987,7 +987,7 @@ func SignInOAuthCallback(ctx *context.Context) { setUserAdminAndRestrictedFromGroupClaims(source, u, &gothUser) - if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, overwriteDefault, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled) { + if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, overwriteDefault, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled, auth.OAuth2) { // error already handled return } @@ -998,7 +998,7 @@ func SignInOAuthCallback(ctx *context.Context) { } } else { // no existing user is found, request attach or new account - showLinkingLogin(ctx, gothUser) + showLinkingLogin(ctx, gothUser, auth.OAuth2) return } } @@ -1064,9 +1064,12 @@ func setUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, u *user_mod return wasAdmin != u.IsAdmin || wasRestricted != u.IsRestricted } -func showLinkingLogin(ctx *context.Context, gothUser goth.User) { +func showLinkingLogin(ctx *context.Context, gothUser goth.User, authType auth.Type) { if err := updateSession(ctx, nil, map[string]any{ - "linkAccountGothUser": gothUser, + "linkAccountUser": auth.LinkAccountUser{ + Type: authType, + GothUser: gothUser, + }, }); err != nil { ctx.ServerError("updateSession", err) return @@ -1151,7 +1154,7 @@ func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model } // update external user information - if err := externalaccount.UpdateExternalUser(u, gothUser); err != nil { + if err := externalaccount.UpdateExternalUser(u, gothUser, auth.OAuth2); err != nil { if !errors.Is(err, util.ErrNotExist) { log.Error("UpdateExternalUser failed: %v", err) } diff --git a/routers/web/auth/openid.go b/routers/web/auth/openid.go index aa0712963259..4861d262baf5 100644 --- a/routers/web/auth/openid.go +++ b/routers/web/auth/openid.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" + auth_model "code.gitea.io/gitea/models/auth" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/auth/openid" "code.gitea.io/gitea/modules/base" @@ -380,7 +381,7 @@ func RegisterOpenIDPost(ctx *context.Context) { Email: form.Email, Passwd: password, } - if !createUserInContext(ctx, tplSignUpOID, form, u, nil, nil, false) { + if !createUserInContext(ctx, tplSignUpOID, form, u, nil, nil, false, auth_model.NoType) { // error already handled return } @@ -396,7 +397,7 @@ func RegisterOpenIDPost(ctx *context.Context) { return } - if !handleUserCreated(ctx, u, nil) { + if !handleUserCreated(ctx, u, nil, auth_model.NoType) { // error already handled return } diff --git a/routers/web/auth/saml.go b/routers/web/auth/saml.go index 954c5758fe3d..f5095c5770af 100644 --- a/routers/web/auth/saml.go +++ b/routers/web/auth/saml.go @@ -26,7 +26,7 @@ import ( func SignInSAML(ctx *context.Context) { provider := ctx.Params(":provider") - loginSource, err := auth.GetActiveSAMLLoginSourceByName(provider) + loginSource, err := auth.GetActiveAuthSourceByName(provider, auth.SAML) if err != nil || loginSource == nil { ctx.NotFound("SAMLMetadata", err) return @@ -44,7 +44,7 @@ func SignInSAML(ctx *context.Context) { // SignInSAMLCallback func SignInSAMLCallback(ctx *context.Context) { provider := ctx.Params(":provider") - loginSource, err := auth.GetActiveSAMLLoginSourceByName(provider) + loginSource, err := auth.GetActiveAuthSourceByName(provider, auth.SAML) if err != nil || loginSource == nil { ctx.NotFound("SignInSAMLCallback", err) return @@ -65,7 +65,7 @@ func SignInSAMLCallback(ctx *context.Context) { if u == nil { if ctx.Doer != nil { // attach user to already logged in user - err = externalaccount.LinkAccountToUser(ctx.Doer, gothUser) + err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser, auth.SAML) if err != nil { ctx.ServerError("UserLinkAccount", err) return @@ -77,7 +77,7 @@ func SignInSAMLCallback(ctx *context.Context) { // TODO: allow auto registration from saml users (OAuth2 uses the following setting.OAuth2Client.EnableAutoRegistration) } else { // no existing user is found, request attach or new account - showLinkingLogin(ctx, gothUser) + showLinkingLogin(ctx, gothUser, auth.SAML) return } } @@ -101,7 +101,7 @@ func handleSamlSignIn(ctx *context.Context, source *auth.Source, u *user_model.U u.SetLastLogin() // update external user information - if err := externalaccount.UpdateExternalUser(u, gothUser); err != nil { + if err := externalaccount.UpdateExternalUser(u, gothUser, auth.SAML); err != nil { if !errors.Is(err, util.ErrNotExist) { log.Error("UpdateExternalUser failed: %v", err) } @@ -131,7 +131,7 @@ func samlUserLoginCallback(authSource *auth.Source, request *http.Request, respo user := &user_model.User{ LoginName: gothUser.UserID, - LoginType: auth.OAuth2, + LoginType: auth.SAML, LoginSource: authSource.ID, } @@ -165,7 +165,7 @@ func samlUserLoginCallback(authSource *auth.Source, request *http.Request, respo // SAMLMetadata func SAMLMetadata(ctx *context.Context) { provider := ctx.Params(":provider") - loginSource, err := auth.GetActiveSAMLLoginSourceByName(provider) + loginSource, err := auth.GetActiveAuthSourceByName(provider, auth.SAML) if err != nil || loginSource == nil { ctx.NotFound("SAMLMetadata", err) return diff --git a/services/auth/source/oauth2/init.go b/services/auth/source/oauth2/init.go index 32fe545c9065..4b1fd94c00c1 100644 --- a/services/auth/source/oauth2/init.go +++ b/services/auth/source/oauth2/init.go @@ -62,7 +62,7 @@ func ResetOAuth2() error { // initOAuth2Sources is used to load and register all active OAuth2 providers func initOAuth2Sources() error { - authSources, _ := auth.GetActiveOAuth2ProviderSources() + authSources, _ := auth.GetActiveAuthProviderSources(auth.OAuth2) for _, source := range authSources { oauth2Source, ok := source.Cfg.(*Source) if !ok { diff --git a/services/auth/source/oauth2/providers.go b/services/auth/source/oauth2/providers.go index e3a0cb0335db..a7c68b3c90d7 100644 --- a/services/auth/source/oauth2/providers.go +++ b/services/auth/source/oauth2/providers.go @@ -100,7 +100,7 @@ func GetOAuth2Providers() []Provider { func GetActiveOAuth2Providers() ([]string, map[string]Provider, error) { // Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type - authSources, err := auth.GetActiveOAuth2ProviderSources() + authSources, err := auth.GetActiveAuthProviderSources(auth.OAuth2) if err != nil { return nil, nil, err } diff --git a/services/auth/source/saml/init.go b/services/auth/source/saml/init.go index 52d2b92d7167..14553e9fe759 100644 --- a/services/auth/source/saml/init.go +++ b/services/auth/source/saml/init.go @@ -13,7 +13,7 @@ import ( var samlRWMutex = sync.RWMutex{} func Init() error { - loginSources, _ := auth.GetActiveSAMLProviderLoginSources() + loginSources, _ := auth.GetActiveAuthProviderSources(auth.SAML) for _, source := range loginSources { samlSource, ok := source.Cfg.(*Source) if !ok { diff --git a/services/auth/source/saml/providers.go b/services/auth/source/saml/providers.go index c41b6e09b2f3..ff6f205f5797 100644 --- a/services/auth/source/saml/providers.go +++ b/services/auth/source/saml/providers.go @@ -5,23 +5,12 @@ package saml import ( "context" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/xml" "fmt" "io" "net/http" - "net/url" "time" "code.gitea.io/gitea/modules/httplib" - "code.gitea.io/gitea/modules/log" - "code.gitea.io/gitea/modules/setting" - - saml2 "github.com/russellhaering/gosaml2" - "github.com/russellhaering/gosaml2/types" - dsig "github.com/russellhaering/goxmldsig" ) // Providers is list of known/available providers. @@ -29,80 +18,6 @@ type Providers map[string]Source var providers = Providers{} -// used to create different types of goth providers -func createProvider(ctx context.Context, source *Source) (*Source, error) { - source.CallbackURL = setting.AppURL + "user/saml/" + url.PathEscape(source.authSource.Name) + "/acs" - - idpMetadata, err := readIdentityProviderMetadata(ctx, source) - if err != nil { - return source, err - } - { - if source.IdentityProviderMetadataURL != "" { - log.Trace(fmt.Sprintf("Identity Provider metadata: %s", source.IdentityProviderMetadataURL), string(idpMetadata)) - } - } - - metadata := &types.EntityDescriptor{} - err = xml.Unmarshal(idpMetadata, metadata) - if err != nil { - return source, err - } - - certStore := dsig.MemoryX509CertificateStore{ - Roots: []*x509.Certificate{}, - } - - for _, kd := range metadata.IDPSSODescriptor.KeyDescriptors { - for idx, xcert := range kd.KeyInfo.X509Data.X509Certificates { - if xcert.Data == "" { - return source, fmt.Errorf("metadata certificate(%d) must not be empty", idx) - } - certData, err := base64.StdEncoding.DecodeString(xcert.Data) - if err != nil { - return source, err - } - - idpCert, err := x509.ParseCertificate(certData) - if err != nil { - return source, err - } - - certStore.Roots = append(certStore.Roots, idpCert) - } - } - - var keyStore dsig.X509KeyStore - - if source.ServiceProviderCertificate != "" && source.ServiceProviderPrivateKey != "" { - keyPair, err := tls.X509KeyPair([]byte(source.ServiceProviderCertificate), []byte(source.ServiceProviderPrivateKey)) - if err != nil { - return nil, err - } - keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) - if err != nil { - return nil, err - } - keyStore = dsig.TLSCertKeyStore(keyPair) - } else { - keyStore = dsig.RandomKeyStoreForTest() - } - - source.samlSP = &saml2.SAMLServiceProvider{ - IdentityProviderSSOURL: metadata.IDPSSODescriptor.SingleSignOnServices[0].Location, - IdentityProviderIssuer: metadata.EntityID, - AudienceURI: setting.AppURL + "user/saml/" + url.PathEscape(source.authSource.Name) + "/metadata", - AssertionConsumerServiceURL: source.CallbackURL, - SkipSignatureValidation: source.InsecureSkipAssertionSignatureValidation, - NameIdFormat: source.NameIDFormat.String(), - IDPCertificateStore: &certStore, - SPKeyStore: keyStore, - ServiceProviderIssuer: setting.AppURL + "user/saml/" + url.PathEscape(source.authSource.Name) + "/metadata", - } - - return source, err -} - func readIdentityProviderMetadata(ctx context.Context, source *Source) ([]byte, error) { if source.IdentityProviderMetadata != "" { return []byte(source.IdentityProviderMetadata), nil diff --git a/services/auth/source/saml/source.go b/services/auth/source/saml/source.go index 871fca8cbd97..260b7367b9d8 100644 --- a/services/auth/source/saml/source.go +++ b/services/auth/source/saml/source.go @@ -4,17 +4,29 @@ package saml import ( - "code.gitea.io/gitea/models/auth" - "code.gitea.io/gitea/modules/json" + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/xml" + "fmt" + "net/url" saml2 "github.com/russellhaering/gosaml2" + "github.com/russellhaering/gosaml2/types" + dsig "github.com/russellhaering/goxmldsig" + + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" ) // _________ _____ _____ .____ // / _____/ / _ \ / \ | | // \_____ \ / /_\ \ / \ / \| | // / \/ | \/ Y \ |___ -///_______ /\____|__ /\____|__ /_______ \ +// /_______ /\____|__ /\____|__ /_______ \ // \/ \/ \/ \/ // Source holds configuration for the SAML login source. @@ -51,9 +63,86 @@ type Source struct { samlSP *saml2.SAMLServiceProvider } +func (source *Source) initSAMLSp() error { + source.CallbackURL = setting.AppURL + "user/saml/" + url.PathEscape(source.authSource.Name) + "/acs" + + idpMetadata, err := readIdentityProviderMetadata(context.Background(), source) + if err != nil { + return err + } + { + if source.IdentityProviderMetadataURL != "" { + log.Trace(fmt.Sprintf("Identity Provider metadata: %s", source.IdentityProviderMetadataURL), string(idpMetadata)) + } + } + + metadata := &types.EntityDescriptor{} + err = xml.Unmarshal(idpMetadata, metadata) + if err != nil { + return err + } + + certStore := dsig.MemoryX509CertificateStore{ + Roots: []*x509.Certificate{}, + } + + for _, kd := range metadata.IDPSSODescriptor.KeyDescriptors { + for idx, xcert := range kd.KeyInfo.X509Data.X509Certificates { + if xcert.Data == "" { + return fmt.Errorf("metadata certificate(%d) must not be empty", idx) + } + certData, err := base64.StdEncoding.DecodeString(xcert.Data) + if err != nil { + return err + } + + idpCert, err := x509.ParseCertificate(certData) + if err != nil { + return err + } + + certStore.Roots = append(certStore.Roots, idpCert) + } + } + + var keyStore dsig.X509KeyStore + + if source.ServiceProviderCertificate != "" && source.ServiceProviderPrivateKey != "" { + keyPair, err := tls.X509KeyPair([]byte(source.ServiceProviderCertificate), []byte(source.ServiceProviderPrivateKey)) + if err != nil { + return err + } + keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) + if err != nil { + return err + } + keyStore = dsig.TLSCertKeyStore(keyPair) + } else { + keyStore = dsig.RandomKeyStoreForTest() + } + + source.samlSP = &saml2.SAMLServiceProvider{ + IdentityProviderSSOURL: metadata.IDPSSODescriptor.SingleSignOnServices[0].Location, + IdentityProviderIssuer: metadata.EntityID, + AudienceURI: setting.AppURL + "user/saml/" + url.PathEscape(source.authSource.Name) + "/metadata", + AssertionConsumerServiceURL: source.CallbackURL, + SkipSignatureValidation: source.InsecureSkipAssertionSignatureValidation, + NameIdFormat: source.NameIDFormat.String(), + IDPCertificateStore: &certStore, + SPKeyStore: keyStore, + ServiceProviderIssuer: setting.AppURL + "user/saml/" + url.PathEscape(source.authSource.Name) + "/metadata", + } + + return nil +} + // FromDB fills up a SAML from serialized format. func (source *Source) FromDB(bs []byte) error { - return json.UnmarshalHandleDoubleEncode(bs, &source) + if err := json.UnmarshalHandleDoubleEncode(bs, &source); err != nil { + return err + } + + return source.initSAMLSp() } // ToDB exports a SAML to a serialized format. diff --git a/services/auth/source/saml/source_callout.go b/services/auth/source/saml/source_callout.go index 301cadd16e12..eb5d106be9c8 100644 --- a/services/auth/source/saml/source_callout.go +++ b/services/auth/source/saml/source_callout.go @@ -32,14 +32,29 @@ func (source *Source) Callback(request *http.Request, response http.ResponseWrit samlRWMutex.RLock() defer samlRWMutex.RUnlock() - user := goth.User{} + user := goth.User{ + Provider: source.authSource.Name, + } samlResponse := request.FormValue("SAMLResponse") assertions, err := source.samlSP.RetrieveAssertionInfo(samlResponse) if err != nil { return user, err } - if warningInfo := assertions.WarningInfo; warningInfo != nil { - return user, fmt.Errorf("SAML response contains warnings: %v", warningInfo) + + if assertions.WarningInfo.OneTimeUse { + return user, fmt.Errorf("SAML response contains one time use warning") + } + + if assertions.WarningInfo.ProxyRestriction != nil { + return user, fmt.Errorf("SAML response contains proxy restriction warning: %v", assertions.WarningInfo.ProxyRestriction) + } + + if assertions.WarningInfo.NotInAudience { + return user, fmt.Errorf("SAML response contains audience warning") + } + + if assertions.WarningInfo.InvalidTime { + return user, fmt.Errorf("SAML response contains invalid time warning") } samlMap := make(map[string]string) // Global. diff --git a/services/auth/source/saml/source_register.go b/services/auth/source/saml/source_register.go index deb300a5ba4c..93eaaa88b66a 100644 --- a/services/auth/source/saml/source_register.go +++ b/services/auth/source/saml/source_register.go @@ -3,18 +3,15 @@ package saml -import "context" - // RegisterSource causes an OAuth2 configuration to be registered func (source *Source) RegisterSource() error { samlRWMutex.Lock() defer samlRWMutex.Unlock() - var err error - source, err = createProvider(context.Background(), source) - if err == nil { - providers[source.authSource.Name] = *source + if err := source.initSAMLSp(); err != nil { + return err } - return err + providers[source.authSource.Name] = *source + return nil } // UnregisterSource causes an SAML configuration to be unregistered diff --git a/services/externalaccount/link.go b/services/externalaccount/link.go index d6e2ea7e9427..1f4c6728b86a 100644 --- a/services/externalaccount/link.go +++ b/services/externalaccount/link.go @@ -7,9 +7,8 @@ import ( "context" "fmt" + "code.gitea.io/gitea/models/auth" user_model "code.gitea.io/gitea/models/user" - - "github.com/markbates/goth" ) // Store represents a thing that stores things @@ -21,10 +20,12 @@ type Store interface { // LinkAccountFromStore links the provided user with a stored external user func LinkAccountFromStore(ctx context.Context, store Store, user *user_model.User) error { - gothUser := store.Get("linkAccountGothUser") - if gothUser == nil { + externalLinkUserInterface := store.Get("linkAccountUser") + if externalLinkUserInterface == nil { return fmt.Errorf("not in LinkAccount session") } - return LinkAccountToUser(ctx, user, gothUser.(goth.User)) + externalLinkUser := externalLinkUserInterface.(auth.LinkAccountUser) + + return LinkAccountToUser(ctx, user, externalLinkUser.GothUser, externalLinkUser.Type) } diff --git a/services/externalaccount/user.go b/services/externalaccount/user.go index 51a0f9a4ef21..e4634935d1d9 100644 --- a/services/externalaccount/user.go +++ b/services/externalaccount/user.go @@ -16,8 +16,8 @@ import ( "github.com/markbates/goth" ) -func toExternalLoginUser(user *user_model.User, gothUser goth.User) (*user_model.ExternalLoginUser, error) { - authSource, err := auth.GetActiveOAuth2SourceByName(gothUser.Provider) +func toExternalLoginUser(user *user_model.User, gothUser goth.User, authType auth.Type) (*user_model.ExternalLoginUser, error) { + authSource, err := auth.GetActiveAuthSourceByName(gothUser.Provider, authType) if err != nil { return nil, err } @@ -43,8 +43,8 @@ func toExternalLoginUser(user *user_model.User, gothUser goth.User) (*user_model } // LinkAccountToUser link the gothUser to the user -func LinkAccountToUser(ctx context.Context, user *user_model.User, gothUser goth.User) error { - externalLoginUser, err := toExternalLoginUser(user, gothUser) +func LinkAccountToUser(ctx context.Context, user *user_model.User, gothUser goth.User, authType auth.Type) error { + externalLoginUser, err := toExternalLoginUser(user, gothUser, authType) if err != nil { return err } @@ -71,8 +71,8 @@ func LinkAccountToUser(ctx context.Context, user *user_model.User, gothUser goth } // UpdateExternalUser updates external user's information -func UpdateExternalUser(user *user_model.User, gothUser goth.User) error { - externalLoginUser, err := toExternalLoginUser(user, gothUser) +func UpdateExternalUser(user *user_model.User, gothUser goth.User, authType auth.Type) error { + externalLoginUser, err := toExternalLoginUser(user, gothUser, authType) if err != nil { return err } From 3042a0848f2cd5a1d21b372deefc330bfcfe5ccc Mon Sep 17 00:00:00 2001 From: jackHay22 Date: Fri, 29 Sep 2023 15:01:59 -0400 Subject: [PATCH 2/4] fixes for account link --- models/auth/oauth2.go | 5 ++++- routers/web/auth/linkaccount.go | 13 ++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go index 605f775edaef..33baaa79ad04 100644 --- a/models/auth/oauth2.go +++ b/models/auth/oauth2.go @@ -7,6 +7,7 @@ import ( "context" "encoding/base32" "encoding/base64" + "encoding/gob" "fmt" "net" "net/url" @@ -76,6 +77,8 @@ func Init(ctx context.Context) error { builtinAllClientIDs = append(builtinAllClientIDs, clientID) } + gob.Register(LinkAccountUser{}) + var registeredApps []*OAuth2Application if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(®isteredApps); err != nil { return err @@ -644,7 +647,7 @@ func GetActiveAuthSourceByName(name string, authType Type) (*Source, error) { } if !has { - return nil, fmt.Errorf("oauth2 source not found, name: %q", name) + return nil, fmt.Errorf("auth source not found, name: %q", name) } return authSource, nil diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index 3820b33ce55a..99229c35ccf6 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -273,7 +273,7 @@ func LinkAccountPostRegister(ctx *context.Context) { } } - authSource, err := auth.GetActiveAuthSourceByName(linkUser.GothUser.Name, linkUser.Type) + authSource, err := auth.GetActiveAuthSourceByName(linkUser.GothUser.Provider, linkUser.Type) if err != nil { ctx.ServerError("CreateUser", err) return @@ -293,11 +293,14 @@ func LinkAccountPostRegister(ctx *context.Context) { return } - source := authSource.Cfg.(*oauth2.Source) - if err := syncGroupsToTeams(ctx, source, &linkUser.GothUser, u); err != nil { - ctx.ServerError("SyncGroupsToTeams", err) - return + if linkUser.Type == auth.OAuth2 { + source := authSource.Cfg.(*oauth2.Source) + if err := syncGroupsToTeams(ctx, source, &linkUser.GothUser, u); err != nil { + ctx.ServerError("SyncGroupsToTeams", err) + return + } } + // TODO groups for SAML? handleSignIn(ctx, u, false) } From 1dae9fd8748aae5cc9cdbe3fd8b137f4160f3716 Mon Sep 17 00:00:00 2001 From: jackHay22 Date: Mon, 2 Oct 2023 15:45:33 -0400 Subject: [PATCH 3/4] rebase cleanup --- models/auth/source.go | 3 +-- routers/web/auth/saml.go | 6 +++--- services/auth/source/saml/source.go | 8 ++++---- services/auth/source/saml/source_authenticate.go | 6 ++++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/models/auth/source.go b/models/auth/source.go index 8963c442681e..e4fde0d8f07d 100644 --- a/models/auth/source.go +++ b/models/auth/source.go @@ -8,13 +8,12 @@ import ( "fmt" "reflect" - "github.com/markbates/goth" - "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/util" + "github.com/markbates/goth" "xorm.io/xorm" "xorm.io/xorm/convert" ) diff --git a/routers/web/auth/saml.go b/routers/web/auth/saml.go index f5095c5770af..e2e66e79c237 100644 --- a/routers/web/auth/saml.go +++ b/routers/web/auth/saml.go @@ -55,7 +55,7 @@ func SignInSAMLCallback(ctx *context.Context) { return } - u, gothUser, err := samlUserLoginCallback(loginSource, ctx.Req, ctx.Resp) + u, gothUser, err := samlUserLoginCallback(*ctx, loginSource, ctx.Req, ctx.Resp) if err != nil { // TODO: improve error display ctx.ServerError("SignIn", err) @@ -121,7 +121,7 @@ func handleSamlSignIn(ctx *context.Context, source *auth.Source, u *user_model.U ctx.Redirect(setting.AppSubURL + "/") } -func samlUserLoginCallback(authSource *auth.Source, request *http.Request, response http.ResponseWriter) (*user_model.User, goth.User, error) { +func samlUserLoginCallback(ctx context.Context, authSource *auth.Source, request *http.Request, response http.ResponseWriter) (*user_model.User, goth.User, error) { samlSource := authSource.Cfg.(*saml.Source) gothUser, err := samlSource.Callback(request, response) @@ -135,7 +135,7 @@ func samlUserLoginCallback(authSource *auth.Source, request *http.Request, respo LoginSource: authSource.ID, } - hasUser, err := user_model.GetUser(user) + hasUser, err := user_model.GetUser(ctx, user) if err != nil { return nil, goth.User{}, err } diff --git a/services/auth/source/saml/source.go b/services/auth/source/saml/source.go index 260b7367b9d8..e2321bcdaf4d 100644 --- a/services/auth/source/saml/source.go +++ b/services/auth/source/saml/source.go @@ -12,14 +12,14 @@ import ( "fmt" "net/url" - saml2 "github.com/russellhaering/gosaml2" - "github.com/russellhaering/gosaml2/types" - dsig "github.com/russellhaering/goxmldsig" - "code.gitea.io/gitea/models/auth" "code.gitea.io/gitea/modules/json" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + + saml2 "github.com/russellhaering/gosaml2" + "github.com/russellhaering/gosaml2/types" + dsig "github.com/russellhaering/goxmldsig" ) // _________ _____ _____ .____ diff --git a/services/auth/source/saml/source_authenticate.go b/services/auth/source/saml/source_authenticate.go index ad4a7d247ca3..d118917f8740 100644 --- a/services/auth/source/saml/source_authenticate.go +++ b/services/auth/source/saml/source_authenticate.go @@ -4,11 +4,13 @@ package saml import ( + "context" + user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/services/auth/source/db" ) // Authenticate falls back to the db authenticator -func (source *Source) Authenticate(user *user_model.User, login, password string) (*user_model.User, error) { - return db.Authenticate(user, login, password) +func (source *Source) Authenticate(ctx context.Context, user *user_model.User, login, password string) (*user_model.User, error) { + return db.Authenticate(ctx, user, login, password) } From d555cf929f563d429843b04ad161faeb32d000bb Mon Sep 17 00:00:00 2001 From: jackHay22 Date: Mon, 2 Oct 2023 15:48:21 -0400 Subject: [PATCH 4/4] whitespace fix --- models/auth/oauth2.go | 4 ++-- services/auth/source/saml/source.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go index 33baaa79ad04..0ab042a16964 100644 --- a/models/auth/oauth2.go +++ b/models/auth/oauth2.go @@ -389,7 +389,7 @@ func ListOAuth2Applications(uid int64, listOptions db.ListOptions) ([]*OAuth2App return apps, total, err } -// //////////////////////////////////////////////////// +////////////////////////////////////////////////////// // OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime. type OAuth2AuthorizationCode struct { @@ -464,7 +464,7 @@ func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth return auth, nil } -// //////////////////////////////////////////////////// +////////////////////////////////////////////////////// // OAuth2Grant represents the permission of an user for a specific application to access resources type OAuth2Grant struct { diff --git a/services/auth/source/saml/source.go b/services/auth/source/saml/source.go index e2321bcdaf4d..6106810bced4 100644 --- a/services/auth/source/saml/source.go +++ b/services/auth/source/saml/source.go @@ -26,7 +26,7 @@ import ( // / _____/ / _ \ / \ | | // \_____ \ / /_\ \ / \ / \| | // / \/ | \/ Y \ |___ -// /_______ /\____|__ /\____|__ /_______ \ +///_______ /\____|__ /\____|__ /_______ \ // \/ \/ \/ \/ // Source holds configuration for the SAML login source.