Skip to content
This repository has been archived by the owner on Jul 12, 2023. It is now read-only.

Defensively check deleted_at fields in realm/user join tables for memberships #1565

Merged
merged 2 commits into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pkg/database/pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ const (
// If page is 0, it defaults to 1. If limit is 0, it defaults to the global
// pagination default limit.
func Paginate(query *gorm.DB, result interface{}, page, limit uint64) (*pagination.Paginator, error) {
if page < 1 {
page = 1
}
if limit < 1 {
limit = pagination.DefaultLimit
}

return PaginateFn(query, page, limit, func(query *gorm.DB, offset uint64) error {
return query.
Limit(limit).
Expand All @@ -44,13 +51,6 @@ func Paginate(query *gorm.DB, result interface{}, page, limit uint64) (*paginati

// PaginateFn paginates with a custom function for returning results.
func PaginateFn(query *gorm.DB, page, limit uint64, populateFn func(query *gorm.DB, offset uint64) error) (*pagination.Paginator, error) {
if page < 1 {
page = 1
}
if limit < 1 {
limit = pagination.DefaultLimit
}

var total uint64
if err := query.
Count(&total).
Expand Down
3 changes: 3 additions & 0 deletions pkg/database/realm.go
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,9 @@ func (r *Realm) ListMemberships(db *Database, p *pagination.PageParams, scopes .
Model(&Membership{}).
Scopes(scopes...).
Where("realm_id = ?", r.ID).
Where("realms.deleted_at IS NULL").
Where("users.deleted_at IS NULL").
Joins("JOIN realms ON realms.id = memberships.realm_id").
Joins("JOIN users ON users.id = memberships.user_id").
Order("LOWER(users.name)")

Expand Down
45 changes: 45 additions & 0 deletions pkg/database/realm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/google/exposure-notifications-server/pkg/timeutils"
"github.com/google/exposure-notifications-verification-server/internal/project"
"github.com/google/exposure-notifications-verification-server/pkg/pagination"
"github.com/google/exposure-notifications-verification-server/pkg/rbac"
"github.com/jinzhu/gorm"
"github.com/lib/pq"
)
Expand Down Expand Up @@ -737,6 +738,50 @@ func TestRealm_CreateSigningKeyVersion(t *testing.T) {
}
}

func TestRealm_ListMemberships(t *testing.T) {
t.Parallel()

db, _ := testDatabaseInstance.NewDatabase(t, nil)
realm, err := db.FindRealm(1)
if err != nil {
t.Fatal(err)
}
user, err := db.FindUser(1)
if err != nil {
t.Fatal(err)
}
if err := user.AddToRealm(db, realm, rbac.CodeIssue, SystemTest); err != nil {
t.Fatal(err)
}

deletedUser := &User{
Name: "User",
Email: "foo@bar.com",
}
if err := db.SaveUser(deletedUser, SystemTest); err != nil {
t.Fatal(err)
}
if err := deletedUser.AddToRealm(db, realm, rbac.CodeIssue, SystemTest); err != nil {
t.Fatal(err)
}
now := time.Now().UTC().Add(-10 * time.Minute)
deletedUser.DeletedAt = &now
if err := db.SaveUser(deletedUser, SystemTest); err != nil {
t.Fatal(err)
}

memberships, _, err := realm.ListMemberships(db, nil)
if err != nil {
t.Fatal(err)
}
if len(memberships) != 1 {
t.Fatalf("expected %#v to have 1 element", memberships)
}
if got, want := memberships[0].UserID, user.ID; got != want {
t.Errorf("expected %d to be %d", got, want)
}
}

func TestRealm_SMSConfig(t *testing.T) {
t.Parallel()

Expand Down
6 changes: 3 additions & 3 deletions pkg/database/scopes.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ func WithUserSearch(q string) Scope {
}

if search.name != "" {
db = db.Where("name ~* ?", fmt.Sprintf("(%s)", search.name))
db = db.Where("users.name ~* ?", fmt.Sprintf("(%s)", search.name))
}

if search.email != "" {
db = db.Where("email ~* ?", fmt.Sprintf("(%s)", search.email))
db = db.Where("users.email ~* ?", fmt.Sprintf("(%s)", search.email))
}

// For backwards-compatibility with previous versions of search, other could
// have been a name or email.
if search.other != nil {
s := strings.Join(search.other, "|")
db = db.Where("name ~* ? OR email ~* ?", fmt.Sprintf("(%s)", s), fmt.Sprintf("(%s)", s))
db = db.Where("users.name ~* ? OR users.email ~* ?", fmt.Sprintf("(%s)", s), fmt.Sprintf("(%s)", s))
}

if p := search.withPerms; p != 0 {
Expand Down