Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[performance] add user cache and database #879

Merged
merged 6 commits into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions cmd/gotosocial/action/admin/account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ import (

"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/validate"
"golang.org/x/crypto/bcrypt"
)
Expand Down Expand Up @@ -92,8 +90,8 @@ var Confirm action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

Expand Down Expand Up @@ -130,16 +128,16 @@ var Promote action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

updatingColumns := []string{"admin", "updated_at"}
admin := true
u.Admin = &admin
u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}

Expand All @@ -166,16 +164,16 @@ var Demote action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

updatingColumns := []string{"admin", "updated_at"}
admin := false
u.Admin = &admin
u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}

Expand All @@ -202,16 +200,16 @@ var Disable action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

updatingColumns := []string{"disabled", "updated_at"}
disabled := true
u.Disabled = &disabled
u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}

Expand Down Expand Up @@ -252,8 +250,8 @@ var Password action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

Expand All @@ -265,7 +263,7 @@ var Password action.GTSAction = func(ctx context.Context) error {
updatingColumns := []string{"encrypted_password", "updated_at"}
u.EncryptedPassword = string(pw)
u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}

Expand Down
8 changes: 4 additions & 4 deletions internal/api/client/auth/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
return
}

user := &gtsmodel.User{}
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
user, err := m.db.GetUserByID(c.Request.Context(), userID)
if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
Expand Down Expand Up @@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
return
}

user := &gtsmodel.User{}
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
user, err := m.db.GetUserByID(c.Request.Context(), userID)
if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
Expand Down
10 changes: 6 additions & 4 deletions internal/api/client/auth/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
doTest := func(testCase authorizeHandlerTestCase) {
ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "")

user := suite.testUsers["unconfirmed_account"]
account := suite.testAccounts["unconfirmed_account"]
user := &gtsmodel.User{}
account := &gtsmodel.Account{}

*user = *suite.testUsers["unconfirmed_account"]
*account = *suite.testAccounts["unconfirmed_account"]

testSession := sessions.Default(ctx)
testSession.Set(sessionUserID, user.ID)
Expand All @@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)

updatingColumns = append(updatingColumns, "updated_at")
user.UpdatedAt = time.Now()
err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...)
_, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)
suite.NoError(err)
_, err = suite.db.UpdateAccount(context.Background(), account)
suite.NoError(err)
Expand Down
3 changes: 1 addition & 2 deletions internal/api/client/auth/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i

// see if we already have a user for this email address
// if so, we don't need to continue + create one
user := &gtsmodel.User{}
err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
user, err := m.db.GetUserByEmailAddress(ctx, claims.Email)
if err == nil {
return user, nil
}
Expand Down
6 changes: 2 additions & 4 deletions internal/api/client/auth/signin.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"golang.org/x/crypto/bcrypt"
)

Expand Down Expand Up @@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st
return incorrectPassword(err)
}

user := &gtsmodel.User{}
if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil {
user, err := m.db.GetUserByEmailAddress(ctx, email)
if err != nil {
err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword(err)
}
Expand Down
23 changes: 13 additions & 10 deletions internal/api/security/tokencheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) {
log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope())

// fetch user for this token
user := &gtsmodel.User{}
if err := m.db.GetByID(ctx, userID, user); err != nil {
user, err := m.db.GetUserByID(ctx, userID)
if err != nil {
if err != db.ErrNoEntries {
log.Errorf("database error looking for user with id %s: %s", userID, err)
return
Expand All @@ -80,22 +80,25 @@ func (m *Module) TokenCheck(c *gin.Context) {
c.Set(oauth.SessionAuthorizedUser, user)

// fetch account for this token
acct, err := m.db.GetAccountByID(ctx, user.AccountID)
if err != nil {
if err != db.ErrNoEntries {
log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
if user.Account == nil {
acct, err := m.db.GetAccountByID(ctx, user.AccountID)
if err != nil {
if err != db.ErrNoEntries {
log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
return
}
log.Warnf("no account found for userID %s", userID)
return
}
log.Warnf("no account found for userID %s", userID)
return
user.Account = acct
}

if !acct.SuspendedAt.IsZero() {
if !user.Account.SuspendedAt.IsZero() {
log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID)
return
}

c.Set(oauth.SessionAuthorizedAccount, acct)
c.Set(oauth.SessionAuthorizedAccount, user.Account)
}

// check for application token
Expand Down
141 changes: 141 additions & 0 deletions internal/cache/user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package cache

import (
"time"

"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)

// UserCache is a cache wrapper to provide lookups for gtsmodel.User
type UserCache struct {
cache cache.LookupCache[string, string, *gtsmodel.User]
}

// NewUserCache returns a new instantiated UserCache object
func NewUserCache() *UserCache {
c := &UserCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("accountid")
lm.RegisterLookup("email")
lm.RegisterLookup("unconfirmedemail")
lm.RegisterLookup("confirmationtoken")
},

AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
lm.Set("accountid", user.AccountID, user.ID)
if email := user.Email; email != "" {
lm.Set("email", email, user.ID)
}
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
lm.Set("unconfirmedemail", unconfirmedEmail, user.ID)
}
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
lm.Set("confirmationtoken", confirmationToken, user.ID)
}
},

DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
lm.Delete("accountid", user.AccountID)
if email := user.Email; email != "" {
lm.Delete("email", email)
}
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
lm.Delete("unconfirmedemail", unconfirmedEmail)
}
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
lm.Delete("confirmationtoken", confirmationToken)
}
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}

// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety
func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) {
return c.cache.Get(id)
}

// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety
func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) {
return c.cache.GetBy("accountid", accountID)
}

// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety
func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) {
return c.cache.GetBy("email", email)
}

// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety
func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) {
return c.cache.GetBy("confirmationtoken", token)
}

// Put places a user in the cache, ensuring that the object place is a copy for thread-safety
func (c *UserCache) Put(user *gtsmodel.User) {
if user == nil || user.ID == "" {
panic("invalid user")
}
c.cache.Set(user.ID, copyUser(user))
}

// Invalidate invalidates one user from the cache using the ID of the user as key.
func (c *UserCache) Invalidate(userID string) {
c.cache.Invalidate(userID)
}

func copyUser(user *gtsmodel.User) *gtsmodel.User {
return &gtsmodel.User{
ID: user.ID,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
Email: user.Email,
AccountID: user.AccountID,
Account: nil,
EncryptedPassword: user.EncryptedPassword,
SignUpIP: user.SignUpIP,
CurrentSignInAt: user.CurrentSignInAt,
CurrentSignInIP: user.CurrentSignInIP,
LastSignInAt: user.LastSignInAt,
LastSignInIP: user.LastSignInIP,
SignInCount: user.SignInCount,
InviteID: user.InviteID,
ChosenLanguages: user.ChosenLanguages,
FilteredLanguages: user.FilteredLanguages,
Locale: user.Locale,
CreatedByApplicationID: user.CreatedByApplicationID,
CreatedByApplication: nil,
LastEmailedAt: user.LastEmailedAt,
ConfirmationToken: user.ConfirmationToken,
ConfirmationSentAt: user.ConfirmationSentAt,
ConfirmedAt: user.ConfirmedAt,
UnconfirmedEmail: user.UnconfirmedEmail,
Moderator: copyBoolPtr(user.Moderator),
Admin: copyBoolPtr(user.Admin),
Disabled: copyBoolPtr(user.Disabled),
Approved: copyBoolPtr(user.Approved),
ResetPasswordToken: user.ResetPasswordToken,
ResetPasswordSentAt: user.ResetPasswordSentAt,
}
}
5 changes: 4 additions & 1 deletion internal/db/bundb/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"time"

"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
Expand All @@ -40,7 +41,8 @@ import (
)

type adminDB struct {
conn *DBConn
conn *DBConn
userCache *cache.UserCache
}

func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
Expand Down Expand Up @@ -175,6 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err)
}
a.userCache.Put(u)

return u, nil
}
Expand Down
Loading