Skip to content

Commit

Permalink
Move more functions to db.Find (#28419)
Browse files Browse the repository at this point in the history
Following #28220

This PR move more functions to use `db.Find`.

---------

Co-authored-by: delvh <dev.lh@web.de>
  • Loading branch information
lunny and delvh authored Jan 15, 2024
1 parent e531324 commit 70c4aad
Show file tree
Hide file tree
Showing 36 changed files with 305 additions and 238 deletions.
4 changes: 2 additions & 2 deletions cmd/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ func runRepoSyncReleases(_ *cli.Context) error {
}

func getReleaseCount(ctx context.Context, id int64) (int64, error) {
return repo_model.GetReleaseCountByRepoID(
return db.Count[repo_model.Release](
ctx,
id,
repo_model.FindReleasesOptions{
RepoID: id,
IncludeTags: true,
},
)
Expand Down
55 changes: 23 additions & 32 deletions models/asymkey/gpg_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,13 @@ import (

"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"

"github.com/keybase/go-crypto/openpgp"
"github.com/keybase/go-crypto/openpgp/packet"
"xorm.io/xorm"
"xorm.io/builder"
)

// __________________ ________ ____ __.
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
// / \ ___ | ___/ \ ___ | <_/ __ < | |
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
// \______ /|____| \______ / |____|__ \___ > ____|
// \/ \/ \/ \/\/

// GPGKey represents a GPG key.
type GPGKey struct {
ID int64 `xorm:"pk autoincr"`
Expand Down Expand Up @@ -54,12 +46,11 @@ func (key *GPGKey) BeforeInsert() {
key.AddedUnix = timeutil.TimeStampNow()
}

// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (key *GPGKey) AfterLoad(session *xorm.Session) {
err := session.Where("primary_key_id=?", key.KeyID).Find(&key.SubsKey)
if err != nil {
log.Error("Find Sub GPGkeys[%s]: %v", key.KeyID, err)
func (key *GPGKey) LoadSubKeys(ctx context.Context) error {
if err := db.GetEngine(ctx).Where("primary_key_id=?", key.KeyID).Find(&key.SubsKey); err != nil {
return fmt.Errorf("find Sub GPGkeys[%s]: %v", key.KeyID, err)
}
return nil
}

// PaddedKeyID show KeyID padded to 16 characters
Expand All @@ -76,20 +67,26 @@ func PaddedKeyID(keyID string) string {
return zeros[0:16-len(keyID)] + keyID
}

// ListGPGKeys returns a list of public keys belongs to given user.
func ListGPGKeys(ctx context.Context, uid int64, listOptions db.ListOptions) ([]*GPGKey, error) {
sess := db.GetEngine(ctx).Table(&GPGKey{}).Where("owner_id=? AND primary_key_id=''", uid)
if listOptions.Page != 0 {
sess = db.SetSessionPagination(sess, &listOptions)
}

keys := make([]*GPGKey, 0, 2)
return keys, sess.Find(&keys)
type FindGPGKeyOptions struct {
db.ListOptions
OwnerID int64
KeyID string
IncludeSubKeys bool
}

// CountUserGPGKeys return number of gpg keys a user own
func CountUserGPGKeys(ctx context.Context, userID int64) (int64, error) {
return db.GetEngine(ctx).Where("owner_id=? AND primary_key_id=''", userID).Count(&GPGKey{})
func (opts FindGPGKeyOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if !opts.IncludeSubKeys {
cond = cond.And(builder.Eq{"primary_key_id": ""})
}

if opts.OwnerID > 0 {
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
}
if opts.KeyID != "" {
cond = cond.And(builder.Eq{"key_id": opts.KeyID})
}
return cond
}

func GetGPGKeyForUserByID(ctx context.Context, ownerID, keyID int64) (*GPGKey, error) {
Expand All @@ -103,12 +100,6 @@ func GetGPGKeyForUserByID(ctx context.Context, ownerID, keyID int64) (*GPGKey, e
return key, nil
}

// GetGPGKeysByKeyID returns public key by given ID.
func GetGPGKeysByKeyID(ctx context.Context, keyID string) ([]*GPGKey, error) {
keys := make([]*GPGKey, 0, 1)
return keys, db.GetEngine(ctx).Where("key_id=?", keyID).Find(&keys)
}

// GPGKeyToEntity retrieve the imported key and the traducted entity
func GPGKeyToEntity(ctx context.Context, k *GPGKey) (*openpgp.Entity, error) {
impKey, err := GetGPGImportByKeyID(ctx, k.KeyID)
Expand Down
23 changes: 20 additions & 3 deletions models/asymkey/gpg_key_commit_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific

// Now try to associate the signature with the committer, if present
if committer.ID != 0 {
keys, err := ListGPGKeys(ctx, committer.ID, db.ListOptions{})
keys, err := db.Find[GPGKey](ctx, FindGPGKeyOptions{
OwnerID: committer.ID,
})
if err != nil { // Skipping failed to get gpg keys of user
log.Error("ListGPGKeys: %v", err)
return &CommitVerification{
Expand All @@ -176,6 +178,15 @@ func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerific
}
}

if err := GPGKeyList(keys).LoadSubKeys(ctx); err != nil {
log.Error("LoadSubKeys: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.failed_retrieval_gpg_keys",
}
}

committerEmailAddresses, _ := user_model.GetEmailAddresses(ctx, committer.ID)
activated := false
for _, e := range committerEmailAddresses {
Expand Down Expand Up @@ -392,7 +403,10 @@ func hashAndVerifyForKeyID(ctx context.Context, sig *packet.Signature, payload s
if keyID == "" {
return nil
}
keys, err := GetGPGKeysByKeyID(ctx, keyID)
keys, err := db.Find[GPGKey](ctx, FindGPGKeyOptions{
KeyID: keyID,
IncludeSubKeys: true,
})
if err != nil {
log.Error("GetGPGKeysByKeyID: %v", err)
return &CommitVerification{
Expand All @@ -407,7 +421,10 @@ func hashAndVerifyForKeyID(ctx context.Context, sig *packet.Signature, payload s
for _, key := range keys {
var primaryKeys []*GPGKey
if key.PrimaryKeyID != "" {
primaryKeys, err = GetGPGKeysByKeyID(ctx, key.PrimaryKeyID)
primaryKeys, err = db.Find[GPGKey](ctx, FindGPGKeyOptions{
KeyID: key.PrimaryKeyID,
IncludeSubKeys: true,
})
if err != nil {
log.Error("GetGPGKeysByKeyID: %v", err)
return &CommitVerification{
Expand Down
38 changes: 38 additions & 0 deletions models/asymkey/gpg_key_list.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package asymkey

import (
"context"

"code.gitea.io/gitea/models/db"
)

type GPGKeyList []*GPGKey

func (keys GPGKeyList) keyIDs() []string {
ids := make([]string, len(keys))
for i, key := range keys {
ids[i] = key.KeyID
}
return ids
}

func (keys GPGKeyList) LoadSubKeys(ctx context.Context) error {
subKeys := make([]*GPGKey, 0, len(keys))
if err := db.GetEngine(ctx).In("primary_key_id", keys.keyIDs()).Find(&subKeys); err != nil {
return err
}
subKeysMap := make(map[string][]*GPGKey, len(subKeys))
for _, key := range subKeys {
subKeysMap[key.PrimaryKeyID] = append(subKeysMap[key.PrimaryKeyID], key)
}

for _, key := range keys {
if subKeys, ok := subKeysMap[key.KeyID]; ok {
key.SubsKey = subKeys
}
}
return nil
}
4 changes: 2 additions & 2 deletions models/asymkey/ssh_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ func (opts FindPublicKeyOptions) ToConds() builder.Cond {
cond = cond.And(builder.Eq{"fingerprint": opts.Fingerprint})
}
if len(opts.KeyTypes) > 0 {
cond = cond.And(builder.In("type", opts.KeyTypes))
cond = cond.And(builder.In("`type`", opts.KeyTypes))
}
if opts.NotKeytype > 0 {
cond = cond.And(builder.Neq{"type": opts.NotKeytype})
cond = cond.And(builder.Neq{"`type`": opts.NotKeytype})
}
if opts.LoginSourceID > 0 {
cond = cond.And(builder.Eq{"login_source_id": opts.LoginSourceID})
Expand Down
23 changes: 0 additions & 23 deletions models/asymkey/ssh_key_principals.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,6 @@ import (
"code.gitea.io/gitea/modules/util"
)

// __________ .__ .__ .__
// \______ _______|__| ____ ____ |_____________ | | ______
// | ___\_ __ | |/ \_/ ___\| \____ \__ \ | | / ___/
// | | | | \| | | \ \___| | |_> / __ \| |__\___ \
// |____| |__| |__|___| /\___ |__| __(____ |____/____ >
// \/ \/ |__| \/ \/
//
// This file contains functions related to principals

// AddPrincipalKey adds new principal to database and authorized_principals file.
func AddPrincipalKey(ctx context.Context, ownerID int64, content string, authSourceID int64) (*PublicKey, error) {
dbCtx, committer, err := db.TxContext(ctx)
Expand Down Expand Up @@ -103,17 +94,3 @@ func CheckPrincipalKeyString(ctx context.Context, user *user_model.User, content

return "", fmt.Errorf("didn't match allowed principals: %s", setting.SSH.AuthorizedPrincipalsAllow)
}

// ListPrincipalKeys returns a list of principals belongs to given user.
func ListPrincipalKeys(ctx context.Context, uid int64, listOptions db.ListOptions) ([]*PublicKey, error) {
sess := db.GetEngine(ctx).Where("owner_id = ? AND type = ?", uid, KeyTypePrincipal)
if listOptions.Page != 0 {
sess = db.SetSessionPagination(sess, &listOptions)

keys := make([]*PublicKey, 0, listOptions.PageSize)
return keys, sess.Find(&keys)
}

keys := make([]*PublicKey, 0, 5)
return keys, sess.Find(&keys)
}
3 changes: 1 addition & 2 deletions models/auth/webauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"code.gitea.io/gitea/modules/util"

"github.com/go-webauthn/webauthn/webauthn"
"xorm.io/xorm"
)

// ErrWebAuthnCredentialNotExist represents a "ErrWebAuthnCRedentialNotExist" kind of error.
Expand Down Expand Up @@ -83,7 +82,7 @@ func (cred *WebAuthnCredential) BeforeUpdate() {
}

// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (cred *WebAuthnCredential) AfterLoad(session *xorm.Session) {
func (cred *WebAuthnCredential) AfterLoad() {
cred.LowerName = strings.ToLower(cred.Name)
}

Expand Down
61 changes: 31 additions & 30 deletions models/db/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,16 @@ const (
// Paginator is the base for different ListOptions types
type Paginator interface {
GetSkipTake() (skip, take int)
GetStartEnd() (start, end int)
IsListAll() bool
}

// GetPaginatedSession creates a paginated database session
func GetPaginatedSession(p Paginator) *xorm.Session {
skip, take := p.GetSkipTake()

return x.Limit(take, skip)
}

// SetSessionPagination sets pagination for a database session
func SetSessionPagination(sess Engine, p Paginator) *xorm.Session {
skip, take := p.GetSkipTake()

return sess.Limit(take, skip)
}

// SetEnginePagination sets pagination for a database engine
func SetEnginePagination(e Engine, p Paginator) Engine {
skip, take := p.GetSkipTake()

return e.Limit(take, skip)
}

// ListOptions options to paginate results
type ListOptions struct {
PageSize int
Expand All @@ -66,13 +51,6 @@ func (opts *ListOptions) GetSkipTake() (skip, take int) {
return (opts.Page - 1) * opts.PageSize, opts.PageSize
}

// GetStartEnd returns the start and end of the ListOptions
func (opts *ListOptions) GetStartEnd() (start, end int) {
start, take := opts.GetSkipTake()
end = start + take
return start, end
}

func (opts ListOptions) GetPage() int {
return opts.Page
}
Expand Down Expand Up @@ -135,11 +113,6 @@ func (opts *AbsoluteListOptions) GetSkipTake() (skip, take int) {
return opts.skip, opts.take
}

// GetStartEnd returns the start and end values
func (opts *AbsoluteListOptions) GetStartEnd() (start, end int) {
return opts.skip, opts.skip + opts.take
}

// FindOptions represents a find options
type FindOptions interface {
GetPage() int
Expand All @@ -148,15 +121,34 @@ type FindOptions interface {
ToConds() builder.Cond
}

type JoinFunc func(sess Engine) error

type FindOptionsJoin interface {
ToJoins() []JoinFunc
}

type FindOptionsOrder interface {
ToOrders() string
}

// Find represents a common find function which accept an options interface
func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {
sess := GetEngine(ctx).Where(opts.ToConds())
sess := GetEngine(ctx)

if joinOpt, ok := opts.(FindOptionsJoin); ok && len(joinOpt.ToJoins()) > 0 {
for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return nil, err
}
}
}

sess = sess.Where(opts.ToConds())
page, pageSize := opts.GetPage(), opts.GetPageSize()
if !opts.IsListAll() && pageSize > 0 && page >= 1 {
if !opts.IsListAll() && pageSize > 0 {
if page == 0 {
page = 1
}
sess.Limit(pageSize, (page-1)*pageSize)
}
if newOpt, ok := opts.(FindOptionsOrder); ok && newOpt.ToOrders() != "" {
Expand All @@ -176,8 +168,17 @@ func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {

// Count represents a common count function which accept an options interface
func Count[T any](ctx context.Context, opts FindOptions) (int64, error) {
sess := GetEngine(ctx)
if joinOpt, ok := opts.(FindOptionsJoin); ok && len(joinOpt.ToJoins()) > 0 {
for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return 0, err
}
}
}

var object T
return GetEngine(ctx).Where(opts.ToConds()).Count(&object)
return sess.Where(opts.ToConds()).Count(&object)
}

// FindAndCount represents a common findandcount function which accept an options interface
Expand Down
3 changes: 0 additions & 3 deletions models/db/paginator/paginator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,8 @@ func TestPaginator(t *testing.T) {

for _, c := range cases {
skip, take := c.Paginator.GetSkipTake()
start, end := c.Paginator.GetStartEnd()

assert.Equal(t, c.Skip, skip)
assert.Equal(t, c.Take, take)
assert.Equal(t, c.Start, start)
assert.Equal(t, c.End, end)
}
}
Loading

0 comments on commit 70c4aad

Please sign in to comment.