Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move almost all functions' parameter db.Engine to context.Context #19748

Merged
merged 18 commits into from
May 20, 2022
6 changes: 3 additions & 3 deletions cmd/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ func runChangePassword(c *cli.Context) error {
return errors.New("The password you chose is on a list of stolen passwords previously exposed in public data breaches. Please try again with a different password.\nFor more details, see https://haveibeenpwned.com/Passwords")
}
uname := c.String("username")
user, err := user_model.GetUserByName(uname)
user, err := user_model.GetUserByName(ctx, uname)
if err != nil {
return err
}
Expand Down Expand Up @@ -659,7 +659,7 @@ func runDeleteUser(c *cli.Context) error {
if c.IsSet("email") {
user, err = user_model.GetUserByEmail(c.String("email"))
} else if c.IsSet("username") {
user, err = user_model.GetUserByName(c.String("username"))
user, err = user_model.GetUserByName(ctx, c.String("username"))
} else {
user, err = user_model.GetUserByID(c.Int64("id"))
}
Expand Down Expand Up @@ -689,7 +689,7 @@ func runGenerateAccessToken(c *cli.Context) error {
return err
}

user, err := user_model.GetUserByName(c.String("username"))
user, err := user_model.GetUserByName(ctx, c.String("username"))
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions integrations/api_issue_tracked_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestAPIGetTrackedTimes(t *testing.T) {
resp := session.MakeRequest(t, req, http.StatusOK)
var apiTimes api.TrackedTimeList
DecodeJSON(t, resp, &apiTimes)
expect, err := models.GetTrackedTimes(&models.FindTrackedTimesOptions{IssueID: issue2.ID})
expect, err := models.GetTrackedTimes(db.DefaultContext, &models.FindTrackedTimesOptions{IssueID: issue2.ID})
assert.NoError(t, err)
assert.Len(t, apiTimes, 3)

Expand Down Expand Up @@ -83,15 +83,15 @@ func TestAPIDeleteTrackedTime(t *testing.T) {
session.MakeRequest(t, req, http.StatusNotFound)

// Reset time of user 2 on issue 2
trackedSeconds, err := models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
trackedSeconds, err := models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
assert.NoError(t, err)
assert.Equal(t, int64(3661), trackedSeconds)

req = NewRequestf(t, "DELETE", "/api/v1/repos/%s/%s/issues/%d/times?token=%s", user2.Name, issue2.Repo.Name, issue2.Index, token)
session.MakeRequest(t, req, http.StatusNoContent)
session.MakeRequest(t, req, http.StatusNotFound)

trackedSeconds, err = models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
trackedSeconds, err = models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2})
assert.NoError(t, err)
assert.Equal(t, int64(0), trackedSeconds)
}
Expand Down
2 changes: 1 addition & 1 deletion integrations/api_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func testAPIRepoMigrateConflict(t *testing.T, u *url.URL) {
defer util.RemoveAll(dstPath)
t.Run("CreateRepo", doAPICreateRepository(httpContext, false))

user, err := user_model.GetUserByName(httpContext.Username)
user, err := user_model.GetUserByName(db.DefaultContext, httpContext.Username)
assert.NoError(t, err)
userID := user.ID

Expand Down
4 changes: 2 additions & 2 deletions integrations/auth_ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func TestLDAPGroupTeamSyncAddMember(t *testing.T) {
addAuthSourceLDAP(t, "", "on", `{"cn=ship_crew,ou=people,dc=planetexpress,dc=com":{"org26": ["team11"]},"cn=admin_staff,ou=people,dc=planetexpress,dc=com": {"non-existent": ["non-existent"]}}`)
org, err := organization.GetOrgByName("org26")
assert.NoError(t, err)
team, err := organization.GetTeam(org.ID, "team11")
team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11")
assert.NoError(t, err)
auth.SyncExternalUsers(context.Background(), true)
for _, gitLDAPUser := range gitLDAPUsers {
Expand Down Expand Up @@ -366,7 +366,7 @@ func TestLDAPGroupTeamSyncRemoveMember(t *testing.T) {
addAuthSourceLDAP(t, "", "on", `{"cn=dispatch,ou=people,dc=planetexpress,dc=com": {"org26": ["team11"]}}`)
org, err := organization.GetOrgByName("org26")
assert.NoError(t, err)
team, err := organization.GetTeam(org.ID, "team11")
team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11")
assert.NoError(t, err)
loginUserWithPassword(t, gitLDAPUsers[0].UserName, gitLDAPUsers[0].Password)
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{
Expand Down
3 changes: 2 additions & 1 deletion integrations/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/perm"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
Expand Down Expand Up @@ -438,7 +439,7 @@ func doProtectBranch(ctx APITestContext, branch, userToWhitelist, unprotectedFil
})
ctx.Session.MakeRequest(t, req, http.StatusSeeOther)
} else {
user, err := user_model.GetUserByName(userToWhitelist)
user, err := user_model.GetUserByName(db.DefaultContext, userToWhitelist)
assert.NoError(t, err)
// Change branch to protected
req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings/branches/%s", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame), url.PathEscape(branch)), map[string]string{
Expand Down
2 changes: 1 addition & 1 deletion integrations/mirror_pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestMirrorPull(t *testing.T) {
IsTag: true,
}, nil, ""))

_, err = repo_model.GetMirrorByRepoID(mirror.ID)
_, err = repo_model.GetMirrorByRepoID(ctx, mirror.ID)
assert.NoError(t, err)

ok := mirror_service.SyncPullMirror(ctx, mirror.ID)
Expand Down
3 changes: 2 additions & 1 deletion integrations/pull_merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
Expand Down Expand Up @@ -407,7 +408,7 @@ func TestConflictChecking(t *testing.T) {
assert.NoError(t, err)

issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "PR with conflict!"}).(*models.Issue)
conflictingPR, err := models.GetPullRequestByIssueID(issue.ID)
conflictingPR, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID)
assert.NoError(t, err)

// Ensure conflictedFiles is populated.
Expand Down
3 changes: 2 additions & 1 deletion integrations/pull_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
Expand Down Expand Up @@ -165,7 +166,7 @@ func createOutdatedPR(t *testing.T, actor, forkOrg *user_model.User) *models.Pul
assert.NoError(t, err)

issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "Test Pull -to-update-"}).(*models.Issue)
pr, err := models.GetPullRequestByIssueID(issue.ID)
pr, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID)
assert.NoError(t, err)

return pr
Expand Down
12 changes: 5 additions & 7 deletions models/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,8 @@ func (a *Action) getCommentLink(ctx context.Context) string {
if a == nil {
return "#"
}
e := db.GetEngine(ctx)
if a.Comment == nil && a.CommentID != 0 {
a.Comment, _ = getCommentByID(e, a.CommentID)
a.Comment, _ = GetCommentByID(ctx, a.CommentID)
}
if a.Comment != nil {
return a.Comment.HTMLURL()
Expand All @@ -239,7 +238,7 @@ func (a *Action) getCommentLink(ctx context.Context) string {
return "#"
}

issue, err := getIssueByID(e, issueID)
issue, err := getIssueByID(ctx, issueID)
if err != nil {
return "#"
}
Expand Down Expand Up @@ -340,8 +339,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) {
return nil, err
}

e := db.GetEngine(ctx)
sess := e.Where(cond).
sess := db.GetEngine(ctx).Where(cond).
Select("`action`.*"). // this line will avoid select other joined table's columns
Join("INNER", "repository", "`repository`.id = `action`.repo_id")

Expand All @@ -354,7 +352,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) {
return nil, fmt.Errorf("Find: %v", err)
}

if err := ActionList(actions).loadAttributes(e); err != nil {
if err := ActionList(actions).loadAttributes(ctx); err != nil {
return nil, fmt.Errorf("LoadAttributes: %v", err)
}

Expand Down Expand Up @@ -504,7 +502,7 @@ func notifyWatchers(ctx context.Context, actions ...*Action) error {
permIssue = make([]bool, len(watchers))
permPR = make([]bool, len(watchers))
for i, watcher := range watchers {
user, err := user_model.GetUserByIDEngine(e, watcher.UserID)
user, err := user_model.GetUserByIDCtx(ctx, watcher.UserID)
if err != nil {
permCode[i] = false
permIssue[i] = false
Expand Down
21 changes: 11 additions & 10 deletions models/action_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package models

import (
"context"
"fmt"

"code.gitea.io/gitea/models/db"
Expand All @@ -26,14 +27,14 @@ func (actions ActionList) getUserIDs() []int64 {
return container.KeysInt64(userIDs)
}

func (actions ActionList) loadUsers(e db.Engine) (map[int64]*user_model.User, error) {
func (actions ActionList) loadUsers(ctx context.Context) (map[int64]*user_model.User, error) {
if len(actions) == 0 {
return nil, nil
}

userIDs := actions.getUserIDs()
userMaps := make(map[int64]*user_model.User, len(userIDs))
err := e.
err := db.GetEngine(ctx).
In("id", userIDs).
Find(&userMaps)
if err != nil {
Expand All @@ -56,14 +57,14 @@ func (actions ActionList) getRepoIDs() []int64 {
return container.KeysInt64(repoIDs)
}

func (actions ActionList) loadRepositories(e db.Engine) error {
func (actions ActionList) loadRepositories(ctx context.Context) error {
if len(actions) == 0 {
return nil
}

repoIDs := actions.getRepoIDs()
repoMaps := make(map[int64]*repo_model.Repository, len(repoIDs))
err := e.In("id", repoIDs).Find(&repoMaps)
err := db.GetEngine(ctx).In("id", repoIDs).Find(&repoMaps)
if err != nil {
return fmt.Errorf("find repository: %v", err)
}
Expand All @@ -74,7 +75,7 @@ func (actions ActionList) loadRepositories(e db.Engine) error {
return nil
}

func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_model.User) (err error) {
func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]*user_model.User) (err error) {
if userMap == nil {
userMap = make(map[int64]*user_model.User)
}
Expand All @@ -85,7 +86,7 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod
}
repoOwner, ok := userMap[action.Repo.OwnerID]
if !ok {
repoOwner, err = user_model.GetUserByID(action.Repo.OwnerID)
repoOwner, err = user_model.GetUserByIDCtx(ctx, action.Repo.OwnerID)
lunny marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
if user_model.IsErrUserNotExist(err) {
continue
Expand All @@ -101,15 +102,15 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod
}

// loadAttributes loads all attributes
func (actions ActionList) loadAttributes(e db.Engine) error {
userMap, err := actions.loadUsers(e)
func (actions ActionList) loadAttributes(ctx context.Context) error {
userMap, err := actions.loadUsers(ctx)
if err != nil {
return err
}

if err := actions.loadRepositories(e); err != nil {
if err := actions.loadRepositories(ctx); err != nil {
return err
}

return actions.loadRepoOwner(e, userMap)
return actions.loadRepoOwner(ctx, userMap)
}
8 changes: 4 additions & 4 deletions models/asymkey/gpg_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,16 @@ func parseGPGKey(ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, erro
}

// deleteGPGKey does the actual key deletion
func deleteGPGKey(e db.Engine, keyID string) (int64, error) {
func deleteGPGKey(ctx context.Context, keyID string) (int64, error) {
if keyID == "" {
return 0, fmt.Errorf("empty KeyId forbidden") // Should never happen but just to be sure
}
// Delete imported key
n, err := e.Where("key_id=?", keyID).Delete(new(GPGKeyImport))
n, err := db.GetEngine(ctx).Where("key_id=?", keyID).Delete(new(GPGKeyImport))
if err != nil {
return n, err
}
return e.Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
return db.GetEngine(ctx).Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
}

// DeleteGPGKey deletes GPG key information in database.
Expand All @@ -231,7 +231,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) {
}
defer committer.Close()

if _, err = deleteGPGKey(db.GetEngine(ctx), key.KeyID); err != nil {
if _, err = deleteGPGKey(ctx, key.KeyID); err != nil {
return err
}

Expand Down
17 changes: 9 additions & 8 deletions models/asymkey/gpg_key_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package asymkey

import (
"context"
"strings"

"code.gitea.io/gitea/models/db"
Expand All @@ -29,36 +30,36 @@ import (
// This file contains functions relating to adding GPG Keys

// addGPGKey add key, import and subkeys to database
func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) {
func addGPGKey(ctx context.Context, key *GPGKey, content string) (err error) {
// Add GPGKeyImport
if _, err = e.Insert(GPGKeyImport{
if err = db.Insert(ctx, &GPGKeyImport{
KeyID: key.KeyID,
Content: content,
}); err != nil {
return err
}
// Save GPG primary key.
if _, err = e.Insert(key); err != nil {
if err = db.Insert(ctx, key); err != nil {
return err
}
// Save GPG subs key.
for _, subkey := range key.SubsKey {
if err := addGPGSubKey(e, subkey); err != nil {
if err := addGPGSubKey(ctx, subkey); err != nil {
return err
}
}
return nil
}

// addGPGSubKey add subkeys to database
func addGPGSubKey(e db.Engine, key *GPGKey) (err error) {
func addGPGSubKey(ctx context.Context, key *GPGKey) (err error) {
// Save GPG primary key.
if _, err = e.Insert(key); err != nil {
if err = db.Insert(ctx, key); err != nil {
return err
}
// Save GPG subs key.
for _, subkey := range key.SubsKey {
if err := addGPGSubKey(e, subkey); err != nil {
if err := addGPGSubKey(ctx, subkey); err != nil {
return err
}
}
Expand Down Expand Up @@ -158,7 +159,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
return nil, err
}

if err = addGPGKey(db.GetEngine(ctx), key, content); err != nil {
if err = addGPGKey(ctx, key, content); err != nil {
return nil, err
}
keys = append(keys, key)
Expand Down
Loading