Skip to content

Commit

Permalink
Merge pull request #1963 from jackHay22/jh/saml
Browse files Browse the repository at this point in the history
Fix SAML account link
  • Loading branch information
techknowlogick authored Oct 8, 2023
2 parents 7a92296 + d555cf9 commit f8b3a2a
Show file tree
Hide file tree
Showing 18 changed files with 206 additions and 196 deletions.
17 changes: 10 additions & 7 deletions models/auth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/base32"
"encoding/base64"
"encoding/gob"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -76,6 +77,8 @@ func Init(ctx context.Context) error {
builtinAllClientIDs = append(builtinAllClientIDs, clientID)
}

gob.Register(LinkAccountUser{})

var registeredApps []*OAuth2Application
if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(&registeredApps); err != nil {
return err
Expand Down Expand Up @@ -626,25 +629,25 @@ func (err ErrOAuthApplicationNotFound) Unwrap() error {
return util.ErrNotExist
}

// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources
func GetActiveOAuth2ProviderSources() ([]*Source, error) {
// GetActiveAuthProviderSources returns all actived LoginOAuth2 sources
func GetActiveAuthProviderSources(authType Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil {
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, authType).Find(&sources); err != nil {
return nil, err
}
return sources, nil
}

// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
func GetActiveOAuth2SourceByName(name string) (*Source, error) {
// GetActiveAuthSourceByName returns a OAuth2 AuthSource based on the given name
func GetActiveAuthSourceByName(name string, authType Type) (*Source, error) {
authSource := new(Source)
has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, authType, true).Get(authSource)
if err != nil {
return nil, err
}

if !has {
return nil, fmt.Errorf("oauth2 source not found, name: %q", name)
return nil, fmt.Errorf("auth source not found, name: %q", name)
}

return authSource, nil
Expand Down
28 changes: 0 additions & 28 deletions models/auth/saml.go

This file was deleted.

7 changes: 7 additions & 0 deletions models/auth/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"

"github.com/markbates/goth"
"xorm.io/xorm"
"xorm.io/xorm/convert"
)
Expand Down Expand Up @@ -121,6 +122,12 @@ type Source struct {
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}

// LinkAccountUser is used to link an external user with a local user
type LinkAccountUser struct {
Type Type
GothUser goth.User
}

// TableName xorm will read the table name from this method
func (Source) TableName() string {
return "login_source"
Expand Down
20 changes: 10 additions & 10 deletions routers/web/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func SignIn(ctx *context.Context) {
ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names
ctx.Data["OAuth2Providers"] = oauth2Providers

samlProviders, err := auth.GetActiveSAMLProviderLoginSources()
samlProviders, err := auth.GetActiveAuthProviderSources(auth.SAML)
if err != nil {
ctx.ServerError("UserSignIn", err)
return
Expand Down Expand Up @@ -496,7 +496,7 @@ func SignUpPost(ctx *context.Context) {
Passwd: form.Password,
}

if !createAndHandleCreatedUser(ctx, tplSignUp, form, u, nil, nil, false) {
if !createAndHandleCreatedUser(ctx, tplSignUp, form, u, nil, nil, false, auth.NoType) {
// error already handled
return
}
Expand All @@ -507,16 +507,16 @@ func SignUpPost(ctx *context.Context) {

// createAndHandleCreatedUser calls createUserInContext and
// then handleUserCreated.
func createAndHandleCreatedUser(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool) bool {
if !createUserInContext(ctx, tpl, form, u, overwrites, gothUser, allowLink) {
func createAndHandleCreatedUser(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool, authType auth.Type) bool {
if !createUserInContext(ctx, tpl, form, u, overwrites, gothUser, allowLink, authType) {
return false
}
return handleUserCreated(ctx, u, gothUser)
return handleUserCreated(ctx, u, gothUser, authType)
}

// createUserInContext creates a user and handles errors within a given context.
// Optionally a template can be specified.
func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool) (ok bool) {
func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool, authType auth.Type) (ok bool) {
if err := user_model.CreateUser(ctx, u, overwrites); err != nil {
if allowLink && (user_model.IsErrUserAlreadyExist(err) || user_model.IsErrEmailAlreadyUsed(err)) {
if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingAuto {
Expand All @@ -533,10 +533,10 @@ func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *us
}

// TODO: probably we should respect 'remember' user's choice...
linkAccount(ctx, user, *gothUser, true)
linkAccount(ctx, user, *gothUser, true, authType)
return false // user is already created here, all redirects are handled
} else if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingLogin {
showLinkingLogin(ctx, *gothUser)
showLinkingLogin(ctx, *gothUser, authType)
return false // user will be created only after linking login
}
}
Expand Down Expand Up @@ -582,7 +582,7 @@ func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *us
// handleUserCreated does additional steps after a new user is created.
// It auto-sets admin for the only user, updates the optional external user and
// sends a confirmation email if required.
func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.User) (ok bool) {
func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.User, authType auth.Type) (ok bool) {
// Auto-set admin for the only user.
if user_model.CountUsers(ctx, nil) == 1 {
u.IsAdmin = true
Expand All @@ -596,7 +596,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.

// update external user information
if gothUser != nil {
if err := externalaccount.UpdateExternalUser(u, *gothUser); err != nil {
if err := externalaccount.UpdateExternalUser(u, *gothUser, authType); err != nil {
if !errors.Is(err, util.ErrNotExist) {
log.Error("UpdateExternalUser failed: %v", err)
}
Expand Down
43 changes: 24 additions & 19 deletions routers/web/auth/linkaccount.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ func LinkAccount(ctx *context.Context) {
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"

gothUser := ctx.Session.Get("linkAccountGothUser")
if gothUser == nil {
externalLinkUser := ctx.Session.Get("linkAccountUser")
if externalLinkUser == nil {
ctx.ServerError("UserSignIn", errors.New("not in LinkAccount session"))
return
}

gu, _ := gothUser.(goth.User)
gu := externalLinkUser.(auth.LinkAccountUser).GothUser
uname := getUserName(&gu)
email := gu.Email
ctx.Data["user_name"] = uname
Expand Down Expand Up @@ -131,12 +131,14 @@ func LinkAccountPostSignIn(ctx *context.Context) {
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"

gothUser := ctx.Session.Get("linkAccountGothUser")
if gothUser == nil {
externalLinkUserInterface := ctx.Session.Get("linkAccountUser")
if externalLinkUserInterface == nil {
ctx.ServerError("UserSignIn", errors.New("not in LinkAccount session"))
return
}

externalLinkUser := externalLinkUserInterface.(auth.LinkAccountUser)

if ctx.HasError() {
ctx.HTML(http.StatusOK, tplLinkAccount)
return
Expand All @@ -148,10 +150,10 @@ func LinkAccountPostSignIn(ctx *context.Context) {
return
}

linkAccount(ctx, u, gothUser.(goth.User), signInForm.Remember)
linkAccount(ctx, u, externalLinkUser.GothUser, signInForm.Remember, externalLinkUser.Type)
}

func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, remember bool) {
func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, remember bool, authType auth.Type) {
updateAvatarIfNeed(gothUser.AvatarURL, u)

// If this user is enrolled in 2FA, we can't sign the user in just yet.
Expand All @@ -164,7 +166,7 @@ func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, r
return
}

err = externalaccount.LinkAccountToUser(ctx, u, gothUser)
err = externalaccount.LinkAccountToUser(ctx, u, gothUser, authType)
if err != nil {
ctx.ServerError("UserLinkAccount", err)
return
Expand Down Expand Up @@ -218,14 +220,14 @@ func LinkAccountPostRegister(ctx *context.Context) {
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"

gothUserInterface := ctx.Session.Get("linkAccountGothUser")
if gothUserInterface == nil {
externalLinkUser := ctx.Session.Get("linkAccountUser")
if externalLinkUser == nil {
ctx.ServerError("UserSignUp", errors.New("not in LinkAccount session"))
return
}
gothUser, ok := gothUserInterface.(goth.User)
linkUser, ok := externalLinkUser.(auth.LinkAccountUser)
if !ok {
ctx.ServerError("UserSignUp", fmt.Errorf("session linkAccountGothUser type is %t but not goth.User", gothUserInterface))
ctx.ServerError("UserSignUp", fmt.Errorf("session linkAccountUser type is %t but not goth.User", externalLinkUser))
return
}

Expand Down Expand Up @@ -271,7 +273,7 @@ func LinkAccountPostRegister(ctx *context.Context) {
}
}

authSource, err := auth.GetActiveOAuth2SourceByName(gothUser.Provider)
authSource, err := auth.GetActiveAuthSourceByName(linkUser.GothUser.Provider, linkUser.Type)
if err != nil {
ctx.ServerError("CreateUser", err)
return
Expand All @@ -283,19 +285,22 @@ func LinkAccountPostRegister(ctx *context.Context) {
Passwd: form.Password,
LoginType: auth.OAuth2,
LoginSource: authSource.ID,
LoginName: gothUser.UserID,
LoginName: linkUser.GothUser.UserID,
}

if !createAndHandleCreatedUser(ctx, tplLinkAccount, form, u, nil, &gothUser, false) {
if !createAndHandleCreatedUser(ctx, tplLinkAccount, form, u, nil, &linkUser.GothUser, false, linkUser.Type) {
// error already handled
return
}

source := authSource.Cfg.(*oauth2.Source)
if err := syncGroupsToTeams(ctx, source, &gothUser, u); err != nil {
ctx.ServerError("SyncGroupsToTeams", err)
return
if linkUser.Type == auth.OAuth2 {
source := authSource.Cfg.(*oauth2.Source)
if err := syncGroupsToTeams(ctx, source, &linkUser.GothUser, u); err != nil {
ctx.ServerError("SyncGroupsToTeams", err)
return
}
}
// TODO groups for SAML?

handleSignIn(ctx, u, false)
}
19 changes: 11 additions & 8 deletions routers/web/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ func handleAuthorizeError(ctx *context.Context, authErr AuthorizeError, redirect
func SignInOAuth(ctx *context.Context) {
provider := ctx.Params(":provider")

authSource, err := auth.GetActiveOAuth2SourceByName(provider)
authSource, err := auth.GetActiveAuthSourceByName(provider, auth.OAuth2)
if err != nil {
ctx.ServerError("SignIn", err)
return
Expand Down Expand Up @@ -898,7 +898,7 @@ func SignInOAuthCallback(ctx *context.Context) {
}

// first look if the provider is still active
authSource, err := auth.GetActiveOAuth2SourceByName(provider)
authSource, err := auth.GetActiveAuthSourceByName(provider, auth.OAuth2)
if err != nil {
ctx.ServerError("SignIn", err)
return
Expand Down Expand Up @@ -941,7 +941,7 @@ func SignInOAuthCallback(ctx *context.Context) {
if u == nil {
if ctx.Doer != nil {
// attach user to already logged in user
err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser)
err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser, auth.OAuth2)
if err != nil {
ctx.ServerError("UserLinkAccount", err)
return
Expand Down Expand Up @@ -987,7 +987,7 @@ func SignInOAuthCallback(ctx *context.Context) {

setUserAdminAndRestrictedFromGroupClaims(source, u, &gothUser)

if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, overwriteDefault, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled) {
if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, overwriteDefault, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled, auth.OAuth2) {
// error already handled
return
}
Expand All @@ -998,7 +998,7 @@ func SignInOAuthCallback(ctx *context.Context) {
}
} else {
// no existing user is found, request attach or new account
showLinkingLogin(ctx, gothUser)
showLinkingLogin(ctx, gothUser, auth.OAuth2)
return
}
}
Expand Down Expand Up @@ -1064,9 +1064,12 @@ func setUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, u *user_mod
return wasAdmin != u.IsAdmin || wasRestricted != u.IsRestricted
}

func showLinkingLogin(ctx *context.Context, gothUser goth.User) {
func showLinkingLogin(ctx *context.Context, gothUser goth.User, authType auth.Type) {
if err := updateSession(ctx, nil, map[string]any{
"linkAccountGothUser": gothUser,
"linkAccountUser": auth.LinkAccountUser{
Type: authType,
GothUser: gothUser,
},
}); err != nil {
ctx.ServerError("updateSession", err)
return
Expand Down Expand Up @@ -1151,7 +1154,7 @@ func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model
}

// update external user information
if err := externalaccount.UpdateExternalUser(u, gothUser); err != nil {
if err := externalaccount.UpdateExternalUser(u, gothUser, auth.OAuth2); err != nil {
if !errors.Is(err, util.ErrNotExist) {
log.Error("UpdateExternalUser failed: %v", err)
}
Expand Down
5 changes: 3 additions & 2 deletions routers/web/auth/openid.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/url"

auth_model "code.gitea.io/gitea/models/auth"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/auth/openid"
"code.gitea.io/gitea/modules/base"
Expand Down Expand Up @@ -380,7 +381,7 @@ func RegisterOpenIDPost(ctx *context.Context) {
Email: form.Email,
Passwd: password,
}
if !createUserInContext(ctx, tplSignUpOID, form, u, nil, nil, false) {
if !createUserInContext(ctx, tplSignUpOID, form, u, nil, nil, false, auth_model.NoType) {
// error already handled
return
}
Expand All @@ -396,7 +397,7 @@ func RegisterOpenIDPost(ctx *context.Context) {
return
}

if !handleUserCreated(ctx, u, nil) {
if !handleUserCreated(ctx, u, nil, auth_model.NoType) {
// error already handled
return
}
Expand Down
Loading

0 comments on commit f8b3a2a

Please sign in to comment.