diff --git a/pkg/database/pagination.go b/pkg/database/pagination.go index 50acff830..6b825643f 100644 --- a/pkg/database/pagination.go +++ b/pkg/database/pagination.go @@ -34,6 +34,13 @@ const ( // If page is 0, it defaults to 1. If limit is 0, it defaults to the global // pagination default limit. func Paginate(query *gorm.DB, result interface{}, page, limit uint64) (*pagination.Paginator, error) { + if page < 1 { + page = 1 + } + if limit < 1 { + limit = pagination.DefaultLimit + } + return PaginateFn(query, page, limit, func(query *gorm.DB, offset uint64) error { return query. Limit(limit). @@ -44,13 +51,6 @@ func Paginate(query *gorm.DB, result interface{}, page, limit uint64) (*paginati // PaginateFn paginates with a custom function for returning results. func PaginateFn(query *gorm.DB, page, limit uint64, populateFn func(query *gorm.DB, offset uint64) error) (*pagination.Paginator, error) { - if page < 1 { - page = 1 - } - if limit < 1 { - limit = pagination.DefaultLimit - } - var total uint64 if err := query. Count(&total). diff --git a/pkg/database/realm.go b/pkg/database/realm.go index 04fded10c..a08bf6daa 100644 --- a/pkg/database/realm.go +++ b/pkg/database/realm.go @@ -860,6 +860,9 @@ func (r *Realm) ListMemberships(db *Database, p *pagination.PageParams, scopes . Model(&Membership{}). Scopes(scopes...). Where("realm_id = ?", r.ID). + Where("realms.deleted_at IS NULL"). + Where("users.deleted_at IS NULL"). + Joins("JOIN realms ON realms.id = memberships.realm_id"). Joins("JOIN users ON users.id = memberships.user_id"). Order("LOWER(users.name)") diff --git a/pkg/database/realm_test.go b/pkg/database/realm_test.go index 051c31213..9575d5703 100644 --- a/pkg/database/realm_test.go +++ b/pkg/database/realm_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/exposure-notifications-server/pkg/timeutils" "github.com/google/exposure-notifications-verification-server/internal/project" "github.com/google/exposure-notifications-verification-server/pkg/pagination" + "github.com/google/exposure-notifications-verification-server/pkg/rbac" "github.com/jinzhu/gorm" "github.com/lib/pq" ) @@ -737,6 +738,50 @@ func TestRealm_CreateSigningKeyVersion(t *testing.T) { } } +func TestRealm_ListMemberships(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + realm, err := db.FindRealm(1) + if err != nil { + t.Fatal(err) + } + user, err := db.FindUser(1) + if err != nil { + t.Fatal(err) + } + if err := user.AddToRealm(db, realm, rbac.CodeIssue, SystemTest); err != nil { + t.Fatal(err) + } + + deletedUser := &User{ + Name: "User", + Email: "foo@bar.com", + } + if err := db.SaveUser(deletedUser, SystemTest); err != nil { + t.Fatal(err) + } + if err := deletedUser.AddToRealm(db, realm, rbac.CodeIssue, SystemTest); err != nil { + t.Fatal(err) + } + now := time.Now().UTC().Add(-10 * time.Minute) + deletedUser.DeletedAt = &now + if err := db.SaveUser(deletedUser, SystemTest); err != nil { + t.Fatal(err) + } + + memberships, _, err := realm.ListMemberships(db, nil) + if err != nil { + t.Fatal(err) + } + if len(memberships) != 1 { + t.Fatalf("expected %#v to have 1 element", memberships) + } + if got, want := memberships[0].UserID, user.ID; got != want { + t.Errorf("expected %d to be %d", got, want) + } +} + func TestRealm_SMSConfig(t *testing.T) { t.Parallel() diff --git a/pkg/database/scopes.go b/pkg/database/scopes.go index bb11165aa..1bdc54de3 100644 --- a/pkg/database/scopes.go +++ b/pkg/database/scopes.go @@ -46,18 +46,18 @@ func WithUserSearch(q string) Scope { } if search.name != "" { - db = db.Where("name ~* ?", fmt.Sprintf("(%s)", search.name)) + db = db.Where("users.name ~* ?", fmt.Sprintf("(%s)", search.name)) } if search.email != "" { - db = db.Where("email ~* ?", fmt.Sprintf("(%s)", search.email)) + db = db.Where("users.email ~* ?", fmt.Sprintf("(%s)", search.email)) } // For backwards-compatibility with previous versions of search, other could // have been a name or email. if search.other != nil { s := strings.Join(search.other, "|") - db = db.Where("name ~* ? OR email ~* ?", fmt.Sprintf("(%s)", s), fmt.Sprintf("(%s)", s)) + db = db.Where("users.name ~* ? OR users.email ~* ?", fmt.Sprintf("(%s)", s), fmt.Sprintf("(%s)", s)) } if p := search.withPerms; p != 0 {