Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SAML account link #1963

Merged
merged 4 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions models/auth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/base32"
"encoding/base64"
"encoding/gob"
"fmt"
"net"
"net/url"
Expand Down Expand Up @@ -76,6 +77,8 @@ func Init(ctx context.Context) error {
builtinAllClientIDs = append(builtinAllClientIDs, clientID)
}

gob.Register(LinkAccountUser{})

var registeredApps []*OAuth2Application
if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(&registeredApps); err != nil {
return err
Expand Down Expand Up @@ -626,25 +629,25 @@ func (err ErrOAuthApplicationNotFound) Unwrap() error {
return util.ErrNotExist
}

// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources
func GetActiveOAuth2ProviderSources() ([]*Source, error) {
// GetActiveAuthProviderSources returns all actived LoginOAuth2 sources
func GetActiveAuthProviderSources(authType Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil {
if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, authType).Find(&sources); err != nil {
return nil, err
}
return sources, nil
}

// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
func GetActiveOAuth2SourceByName(name string) (*Source, error) {
// GetActiveAuthSourceByName returns a OAuth2 AuthSource based on the given name
func GetActiveAuthSourceByName(name string, authType Type) (*Source, error) {
authSource := new(Source)
has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, authType, true).Get(authSource)
if err != nil {
return nil, err
}

if !has {
return nil, fmt.Errorf("oauth2 source not found, name: %q", name)
return nil, fmt.Errorf("auth source not found, name: %q", name)
}

return authSource, nil
Expand Down
28 changes: 0 additions & 28 deletions models/auth/saml.go

This file was deleted.

7 changes: 7 additions & 0 deletions models/auth/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"

"github.com/markbates/goth"
"xorm.io/xorm"
"xorm.io/xorm/convert"
)
Expand Down Expand Up @@ -121,6 +122,12 @@ type Source struct {
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}

// LinkAccountUser is used to link an external user with a local user
type LinkAccountUser struct {
Type Type
GothUser goth.User
}

// TableName xorm will read the table name from this method
func (Source) TableName() string {
return "login_source"
Expand Down
20 changes: 10 additions & 10 deletions routers/web/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func SignIn(ctx *context.Context) {
ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names
ctx.Data["OAuth2Providers"] = oauth2Providers

samlProviders, err := auth.GetActiveSAMLProviderLoginSources()
samlProviders, err := auth.GetActiveAuthProviderSources(auth.SAML)
if err != nil {
ctx.ServerError("UserSignIn", err)
return
Expand Down Expand Up @@ -496,7 +496,7 @@ func SignUpPost(ctx *context.Context) {
Passwd: form.Password,
}

if !createAndHandleCreatedUser(ctx, tplSignUp, form, u, nil, nil, false) {
if !createAndHandleCreatedUser(ctx, tplSignUp, form, u, nil, nil, false, auth.NoType) {
// error already handled
return
}
Expand All @@ -507,16 +507,16 @@ func SignUpPost(ctx *context.Context) {

// createAndHandleCreatedUser calls createUserInContext and
// then handleUserCreated.
func createAndHandleCreatedUser(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool) bool {
if !createUserInContext(ctx, tpl, form, u, overwrites, gothUser, allowLink) {
func createAndHandleCreatedUser(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool, authType auth.Type) bool {
if !createUserInContext(ctx, tpl, form, u, overwrites, gothUser, allowLink, authType) {
return false
}
return handleUserCreated(ctx, u, gothUser)
return handleUserCreated(ctx, u, gothUser, authType)
}

// createUserInContext creates a user and handles errors within a given context.
// Optionally a template can be specified.
func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool) (ok bool) {
func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *user_model.User, overwrites *user_model.CreateUserOverwriteOptions, gothUser *goth.User, allowLink bool, authType auth.Type) (ok bool) {
if err := user_model.CreateUser(ctx, u, overwrites); err != nil {
if allowLink && (user_model.IsErrUserAlreadyExist(err) || user_model.IsErrEmailAlreadyUsed(err)) {
if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingAuto {
Expand All @@ -533,10 +533,10 @@ func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *us
}

// TODO: probably we should respect 'remember' user's choice...
linkAccount(ctx, user, *gothUser, true)
linkAccount(ctx, user, *gothUser, true, authType)
return false // user is already created here, all redirects are handled
} else if setting.OAuth2Client.AccountLinking == setting.OAuth2AccountLinkingLogin {
showLinkingLogin(ctx, *gothUser)
showLinkingLogin(ctx, *gothUser, authType)
return false // user will be created only after linking login
}
}
Expand Down Expand Up @@ -582,7 +582,7 @@ func createUserInContext(ctx *context.Context, tpl base.TplName, form any, u *us
// handleUserCreated does additional steps after a new user is created.
// It auto-sets admin for the only user, updates the optional external user and
// sends a confirmation email if required.
func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.User) (ok bool) {
func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.User, authType auth.Type) (ok bool) {
// Auto-set admin for the only user.
if user_model.CountUsers(ctx, nil) == 1 {
u.IsAdmin = true
Expand All @@ -596,7 +596,7 @@ func handleUserCreated(ctx *context.Context, u *user_model.User, gothUser *goth.

// update external user information
if gothUser != nil {
if err := externalaccount.UpdateExternalUser(u, *gothUser); err != nil {
if err := externalaccount.UpdateExternalUser(u, *gothUser, authType); err != nil {
if !errors.Is(err, util.ErrNotExist) {
log.Error("UpdateExternalUser failed: %v", err)
}
Expand Down
43 changes: 24 additions & 19 deletions routers/web/auth/linkaccount.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ func LinkAccount(ctx *context.Context) {
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"

gothUser := ctx.Session.Get("linkAccountGothUser")
if gothUser == nil {
externalLinkUser := ctx.Session.Get("linkAccountUser")
if externalLinkUser == nil {
ctx.ServerError("UserSignIn", errors.New("not in LinkAccount session"))
return
}

gu, _ := gothUser.(goth.User)
gu := externalLinkUser.(auth.LinkAccountUser).GothUser
uname := getUserName(&gu)
email := gu.Email
ctx.Data["user_name"] = uname
Expand Down Expand Up @@ -131,12 +131,14 @@ func LinkAccountPostSignIn(ctx *context.Context) {
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"

gothUser := ctx.Session.Get("linkAccountGothUser")
if gothUser == nil {
externalLinkUserInterface := ctx.Session.Get("linkAccountUser")
if externalLinkUserInterface == nil {
ctx.ServerError("UserSignIn", errors.New("not in LinkAccount session"))
return
}

externalLinkUser := externalLinkUserInterface.(auth.LinkAccountUser)

if ctx.HasError() {
ctx.HTML(http.StatusOK, tplLinkAccount)
return
Expand All @@ -148,10 +150,10 @@ func LinkAccountPostSignIn(ctx *context.Context) {
return
}

linkAccount(ctx, u, gothUser.(goth.User), signInForm.Remember)
linkAccount(ctx, u, externalLinkUser.GothUser, signInForm.Remember, externalLinkUser.Type)
}

func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, remember bool) {
func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, remember bool, authType auth.Type) {
updateAvatarIfNeed(gothUser.AvatarURL, u)

// If this user is enrolled in 2FA, we can't sign the user in just yet.
Expand All @@ -164,7 +166,7 @@ func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, r
return
}

err = externalaccount.LinkAccountToUser(ctx, u, gothUser)
err = externalaccount.LinkAccountToUser(ctx, u, gothUser, authType)
if err != nil {
ctx.ServerError("UserLinkAccount", err)
return
Expand Down Expand Up @@ -218,14 +220,14 @@ func LinkAccountPostRegister(ctx *context.Context) {
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"

gothUserInterface := ctx.Session.Get("linkAccountGothUser")
if gothUserInterface == nil {
externalLinkUser := ctx.Session.Get("linkAccountUser")
if externalLinkUser == nil {
ctx.ServerError("UserSignUp", errors.New("not in LinkAccount session"))
return
}
gothUser, ok := gothUserInterface.(goth.User)
linkUser, ok := externalLinkUser.(auth.LinkAccountUser)
if !ok {
ctx.ServerError("UserSignUp", fmt.Errorf("session linkAccountGothUser type is %t but not goth.User", gothUserInterface))
ctx.ServerError("UserSignUp", fmt.Errorf("session linkAccountUser type is %t but not goth.User", externalLinkUser))
return
}

Expand Down Expand Up @@ -271,7 +273,7 @@ func LinkAccountPostRegister(ctx *context.Context) {
}
}

authSource, err := auth.GetActiveOAuth2SourceByName(gothUser.Provider)
authSource, err := auth.GetActiveAuthSourceByName(linkUser.GothUser.Provider, linkUser.Type)
if err != nil {
ctx.ServerError("CreateUser", err)
return
Expand All @@ -283,19 +285,22 @@ func LinkAccountPostRegister(ctx *context.Context) {
Passwd: form.Password,
LoginType: auth.OAuth2,
LoginSource: authSource.ID,
LoginName: gothUser.UserID,
LoginName: linkUser.GothUser.UserID,
}

if !createAndHandleCreatedUser(ctx, tplLinkAccount, form, u, nil, &gothUser, false) {
if !createAndHandleCreatedUser(ctx, tplLinkAccount, form, u, nil, &linkUser.GothUser, false, linkUser.Type) {
// error already handled
return
}

source := authSource.Cfg.(*oauth2.Source)
if err := syncGroupsToTeams(ctx, source, &gothUser, u); err != nil {
ctx.ServerError("SyncGroupsToTeams", err)
return
if linkUser.Type == auth.OAuth2 {
source := authSource.Cfg.(*oauth2.Source)
if err := syncGroupsToTeams(ctx, source, &linkUser.GothUser, u); err != nil {
ctx.ServerError("SyncGroupsToTeams", err)
return
}
}
// TODO groups for SAML?

handleSignIn(ctx, u, false)
}
19 changes: 11 additions & 8 deletions routers/web/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ func handleAuthorizeError(ctx *context.Context, authErr AuthorizeError, redirect
func SignInOAuth(ctx *context.Context) {
provider := ctx.Params(":provider")

authSource, err := auth.GetActiveOAuth2SourceByName(provider)
authSource, err := auth.GetActiveAuthSourceByName(provider, auth.OAuth2)
if err != nil {
ctx.ServerError("SignIn", err)
return
Expand Down Expand Up @@ -898,7 +898,7 @@ func SignInOAuthCallback(ctx *context.Context) {
}

// first look if the provider is still active
authSource, err := auth.GetActiveOAuth2SourceByName(provider)
authSource, err := auth.GetActiveAuthSourceByName(provider, auth.OAuth2)
if err != nil {
ctx.ServerError("SignIn", err)
return
Expand Down Expand Up @@ -941,7 +941,7 @@ func SignInOAuthCallback(ctx *context.Context) {
if u == nil {
if ctx.Doer != nil {
// attach user to already logged in user
err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser)
err = externalaccount.LinkAccountToUser(ctx, ctx.Doer, gothUser, auth.OAuth2)
if err != nil {
ctx.ServerError("UserLinkAccount", err)
return
Expand Down Expand Up @@ -987,7 +987,7 @@ func SignInOAuthCallback(ctx *context.Context) {

setUserAdminAndRestrictedFromGroupClaims(source, u, &gothUser)

if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, overwriteDefault, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled) {
if !createAndHandleCreatedUser(ctx, base.TplName(""), nil, u, overwriteDefault, &gothUser, setting.OAuth2Client.AccountLinking != setting.OAuth2AccountLinkingDisabled, auth.OAuth2) {
// error already handled
return
}
Expand All @@ -998,7 +998,7 @@ func SignInOAuthCallback(ctx *context.Context) {
}
} else {
// no existing user is found, request attach or new account
showLinkingLogin(ctx, gothUser)
showLinkingLogin(ctx, gothUser, auth.OAuth2)
return
}
}
Expand Down Expand Up @@ -1064,9 +1064,12 @@ func setUserAdminAndRestrictedFromGroupClaims(source *oauth2.Source, u *user_mod
return wasAdmin != u.IsAdmin || wasRestricted != u.IsRestricted
}

func showLinkingLogin(ctx *context.Context, gothUser goth.User) {
func showLinkingLogin(ctx *context.Context, gothUser goth.User, authType auth.Type) {
if err := updateSession(ctx, nil, map[string]any{
"linkAccountGothUser": gothUser,
"linkAccountUser": auth.LinkAccountUser{
Type: authType,
GothUser: gothUser,
},
}); err != nil {
ctx.ServerError("updateSession", err)
return
Expand Down Expand Up @@ -1151,7 +1154,7 @@ func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model
}

// update external user information
if err := externalaccount.UpdateExternalUser(u, gothUser); err != nil {
if err := externalaccount.UpdateExternalUser(u, gothUser, auth.OAuth2); err != nil {
if !errors.Is(err, util.ErrNotExist) {
log.Error("UpdateExternalUser failed: %v", err)
}
Expand Down
5 changes: 3 additions & 2 deletions routers/web/auth/openid.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/url"

auth_model "code.gitea.io/gitea/models/auth"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/auth/openid"
"code.gitea.io/gitea/modules/base"
Expand Down Expand Up @@ -380,7 +381,7 @@ func RegisterOpenIDPost(ctx *context.Context) {
Email: form.Email,
Passwd: password,
}
if !createUserInContext(ctx, tplSignUpOID, form, u, nil, nil, false) {
if !createUserInContext(ctx, tplSignUpOID, form, u, nil, nil, false, auth_model.NoType) {
// error already handled
return
}
Expand All @@ -396,7 +397,7 @@ func RegisterOpenIDPost(ctx *context.Context) {
return
}

if !handleUserCreated(ctx, u, nil) {
if !handleUserCreated(ctx, u, nil, auth_model.NoType) {
// error already handled
return
}
Expand Down
Loading
Loading