From 669904cd060186d0ea760057699aa3f4076ac13b Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:49:00 +0100 Subject: [PATCH] [management] Remove context from database calls (#2863) --- management/server/sql_store.go | 274 ++++++--------------------------- 1 file changed, 45 insertions(+), 229 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 646184578eb..5bc521437d7 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -300,13 +300,11 @@ func (s *SqlStore) GetInstallationID() string { } func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { - startTime := time.Now() - // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := s.db.Transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -327,9 +325,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. }) if err != nil { - if errors.Is(err, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return err } @@ -337,8 +332,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. } func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { - startTime := time.Now() - accountCopy := Account{ Domain: domain, DomainCategory: category, @@ -346,14 +339,11 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID } fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} - result := s.db.WithContext(ctx).Model(&Account{}). + result := s.db.Model(&Account{}). Select(fieldsToUpdate). Where(idQueryCondition, accountID). Updates(&accountCopy) if result.Error != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return result.Error } @@ -365,8 +355,6 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID } func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { - startTime := time.Now() - var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -379,9 +367,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) if result.Error != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return result.Error } @@ -393,8 +378,6 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe } func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { - startTime := time.Now() - // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer // Since the location field has been migrated to JSON serialization, @@ -406,9 +389,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P Updates(peerCopy) if result.Error != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return result.Error } @@ -422,8 +402,6 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P // SaveUsers saves the given list of users to the database. // It updates existing users if a conflict occurs. func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { - startTime := time.Now() - usersToSave := make([]User, 0, len(users)) for _, user := range users { user.AccountID = accountID @@ -437,9 +415,6 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { Clauses(clause.OnConflict{UpdateAll: true}). Create(&usersToSave).Error if err != nil { - if errors.Is(err, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return status.Errorf(status.Internal, "failed to save users to store: %v", err) } @@ -448,13 +423,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { // SaveUser saves the given user to the database. func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { - startTime := time.Now() - - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) } return nil @@ -462,17 +432,12 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u // SaveGroups saves the given list of groups to the database. func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { - startTime := time.Now() - if len(groups) == 0 { return nil } - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) if result.Error != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) } return nil @@ -499,10 +464,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) } func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { - startTime := time.Now() - var accountID string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", strings.ToLower(domain), true, PrivateCategory, ).First(&accountID) @@ -510,9 +473,6 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } - if errors.Is(result.Error, context.Canceled) { - return "", status.NewStoreContextCanceledError(time.Since(startTime)) - } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) return "", status.NewGetAccountFromStoreError(result.Error) } @@ -521,17 +481,12 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength } func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { - startTime := time.Now() - var key SetupKey - result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey) + result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.NewSetupKeyNotFoundError(result.Error) } @@ -543,17 +498,12 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* } func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { - startTime := time.Now() - var token PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return "", status.NewStoreContextCanceledError(time.Since(startTime)) - } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return "", status.NewGetAccountFromStoreError(result.Error) } @@ -562,17 +512,12 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri } func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { - startTime := time.Now() - var token PersonalAccessToken result := s.db.First(&token, idQueryCondition, tokenID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -596,18 +541,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { - startTime := time.Now() - var user User - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Preload(clause.Associations).First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.NewGetUserFromStoreError() } @@ -615,17 +555,12 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { - startTime := time.Now() - var users []*User result := s.db.Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting users from store") } @@ -634,17 +569,12 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us } func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - startTime := time.Now() - var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting groups from store") } @@ -744,17 +674,12 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { - startTime := time.Now() - var user User - result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID) + result := s.db.Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -766,17 +691,12 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun } func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { - startTime := time.Now() - var peer nbpeer.Peer - result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID) + result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -788,17 +708,12 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco } func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { - startTime := time.Now() - var peer nbpeer.Peer - result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey) + result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -810,18 +725,13 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( } func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { - startTime := time.Now() - var peer nbpeer.Peer var accountID string - result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) + result := s.db.Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return "", status.NewStoreContextCanceledError(time.Since(startTime)) - } return "", status.NewGetAccountFromStoreError(result.Error) } @@ -829,17 +739,12 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) } func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { - startTime := time.Now() - var accountID string result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return "", status.NewStoreContextCanceledError(time.Since(startTime)) - } return "", status.NewGetAccountFromStoreError(result.Error) } @@ -847,17 +752,12 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { } func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { - startTime := time.Now() - var accountID string - result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) + result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - if errors.Is(result.Error, context.Canceled) { - return "", status.NewStoreContextCanceledError(time.Since(startTime)) - } return "", status.NewSetupKeyNotFoundError(result.Error) } @@ -869,21 +769,16 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) } func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { - startTime := time.Now() - var ipJSONStrings []string // Fetch the IP addresses as JSON strings - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Where("account_id = ?", accountID). Pluck("ip", &ipJSONStrings) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } @@ -901,10 +796,8 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { - startTime := time.Now() - var labels []string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Where("account_id = ?", accountID). Pluck("dns_label", &labels) @@ -912,9 +805,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) } @@ -923,33 +813,23 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - startTime := time.Now() - var accountNetwork AccountNetwork - if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { + if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - if errors.Is(err, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { - startTime := time.Now() - var peer nbpeer.Peer - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } @@ -957,16 +837,11 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { - startTime := time.Now() - var accountSettings AccountSettings - if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - if errors.Is(err, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil @@ -974,17 +849,13 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - startTime := time.Now() - var user User - result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) } - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } + return status.NewGetUserFromStoreError() } user.LastLogin = lastLogin @@ -993,8 +864,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri } func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { - startTime := time.Now() - definitionJSON, err := json.Marshal(checks) if err != nil { return nil, err @@ -1003,9 +872,6 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p var postureCheck posture.Checks err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error if err != nil { - if errors.Is(err, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, err } @@ -1115,27 +981,20 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, } func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { - startTime := time.Now() - var setupKey SetupKey - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, keyQueryCondition, key) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.NewSetupKeyNotFoundError(result.Error) } return &setupKey, nil } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - startTime := time.Now() - - result := s.db.WithContext(ctx).Model(&SetupKey{}). + result := s.db.Model(&SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -1143,9 +1002,6 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string }) if result.Error != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error) } @@ -1157,17 +1013,12 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - startTime := time.Now() - var group nbgroup.Group - result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group) + result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") } - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } @@ -1180,9 +1031,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) if err := s.db.Save(&group).Error; err != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } @@ -1190,17 +1038,13 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer } func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - startTime := time.Now() - var group nbgroup.Group - result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group not found for account") } - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) } @@ -1213,9 +1057,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) if err := s.db.Save(&group).Error; err != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return status.Errorf(status.Internal, "issue updating group: %s", err) } @@ -1224,16 +1065,11 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { - return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) + return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID) } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - startTime := time.Now() - - if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { - if errors.Is(err, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } + if err := s.db.Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1241,20 +1077,15 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - startTime := time.Now() - - result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - if errors.Is(result.Error, context.Canceled) { - return status.NewStoreContextCanceledError(time.Since(startTime)) - } return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } return nil } func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { - tx := s.db.WithContext(ctx).Begin() + tx := s.db.Begin() if tx.Error != nil { return tx.Error } @@ -1278,18 +1109,13 @@ func (s *SqlStore) GetDB() *gorm.DB { } func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { - startTime := time.Now() - var accountDNSSettings AccountDNSSettings - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "dns settings not found") } - if errors.Is(result.Error, context.Canceled) { - return nil, status.NewStoreContextCanceledError(time.Since(startTime)) - } return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) } return &accountDNSSettings.DNSSettings, nil @@ -1297,18 +1123,13 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { - startTime := time.Now() - var accountID string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). Select("id").First(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return false, nil } - if errors.Is(result.Error, context.Canceled) { - return false, status.NewStoreContextCanceledError(time.Since(startTime)) - } return false, result.Error } @@ -1317,18 +1138,13 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { - startTime := time.Now() - var account Account - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", "", status.Errorf(status.NotFound, "account not found") } - if errors.Is(result.Error, context.Canceled) { - return "", "", status.NewStoreContextCanceledError(time.Since(startTime)) - } return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) } @@ -1337,7 +1153,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength // GetGroupByID retrieves a group by ID and account ID. func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { - return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) + return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID) } // GetGroupByName retrieves a group by name and account ID. @@ -1346,7 +1162,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. - query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) + query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) if s.storeEngine == PostgresStoreEngine { query = query.Order("json_array_length(peers::json) DESC") } else { @@ -1365,7 +1181,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // SaveGroup saves a group to the store. func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) } @@ -1374,56 +1190,56 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) + return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) } // GetPolicyByID retrieves a policy by its ID and account ID. func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { - return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) + return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID) } // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[*posture.Checks](s.db, lockStrength, accountID) } // GetPostureChecksByID retrieves posture checks by their ID and account ID. func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) + return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID) } // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { - return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[*route.Route](s.db, lockStrength, accountID) } // GetRouteByID retrieves a route by its ID and account ID. func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { - return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) + return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) } // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[*SetupKey](s.db, lockStrength, accountID) } // GetSetupKeyByID retrieves a setup key by its ID and account ID. func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { - return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) + return getRecordByID[SetupKey](s.db, lockStrength, setupKeyID, accountID) } // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID) } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { - return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) + return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID) } func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { - return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID) + return deleteRecordByID[SetupKey](s.db, LockingStrengthUpdate, keyID, accountID) } // getRecords retrieves records from the database based on the account ID.