diff --git a/cmd/server/assets/realmadmin/_form_codes.html b/cmd/server/assets/realmadmin/_form_codes.html index 10082c338..e98649a53 100644 --- a/cmd/server/assets/realmadmin/_form_codes.html +++ b/cmd/server/assets/realmadmin/_form_codes.html @@ -199,7 +199,7 @@
- {{if $realm.ErrorsFor "SMSTextTemplate"}}{{end}} + {{if $realm.ErrorsFor "smsTextTemplate"}}{{end}} @@ -214,7 +214,7 @@ New SMS template
- {{if $realm.ErrorsFor "SMSTextTemplate"}} + {{if $realm.ErrorsFor "smsTextTemplate"}}
Errors found for one or more SMS templates
diff --git a/pkg/database/audit_entry.go b/pkg/database/audit_entry.go index a2985ae3e..37e6a6e2b 100644 --- a/pkg/database/audit_entry.go +++ b/pkg/database/audit_entry.go @@ -15,9 +15,12 @@ package database import ( + "fmt" + "strings" "time" "github.com/google/exposure-notifications-verification-server/pkg/pagination" + "github.com/jinzhu/gorm" ) // AuditEntry represents an event in the system. These records are purged after @@ -63,6 +66,32 @@ type AuditEntry struct { CreatedAt time.Time } +// BeforeSave runs validations. If there are errors, the save fails. +func (a *AuditEntry) BeforeSave(tx *gorm.DB) error { + if a.ActorID == "" { + a.AddError("actor_id", "cannot be blank") + } + if a.ActorDisplay == "" { + a.AddError("actor_display", "cannot be blank") + } + + if a.Action == "" { + a.AddError("action", "cannot be blank") + } + + if a.TargetID == "" { + a.AddError("target_id", "cannot be blank") + } + if a.TargetDisplay == "" { + a.AddError("target_display", "cannot be blank") + } + + if msgs := a.ErrorMessages(); len(msgs) > 0 { + return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", ")) + } + return nil +} + // SaveAuditEntry saves the audit entry. func (db *Database) SaveAuditEntry(a *AuditEntry) error { return db.db.Save(a).Error diff --git a/pkg/database/audit_entry_test.go b/pkg/database/audit_entry_test.go new file mode 100644 index 000000000..e0c49d886 --- /dev/null +++ b/pkg/database/audit_entry_test.go @@ -0,0 +1,130 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "testing" + "time" + + "github.com/google/exposure-notifications-verification-server/pkg/pagination" +) + +func TestAuditEntry_BeforeSave(t *testing.T) { + t.Parallel() + + cases := []struct { + structField string + field string + }{ + {"ActorID", "actor_id"}, + {"ActorDisplay", "actor_display"}, + {"Action", "action"}, + {"TargetID", "target_id"}, + {"TargetDisplay", "target_display"}, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.field, func(t *testing.T) { + t.Parallel() + exerciseValidation(t, &AuditEntry{}, tc.structField, tc.field) + }) + } +} + +func TestDatabase_PurgeAuditEntries(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + for i := 0; i < 5; i++ { + if err := db.SaveAuditEntry(&AuditEntry{ + RealmID: 1, + ActorID: "actor:1", + ActorDisplay: "Actor", + Action: "created", + TargetID: "target:1", + TargetDisplay: "Target", + }); err != nil { + t.Fatal(err) + } + } + + // Should not purge entries (too young). + { + n, err := db.PurgeAuditEntries(24 * time.Hour) + if err != nil { + t.Fatal(err) + } + if got, want := n, int64(0); got != want { + t.Errorf("expected %d to purge, got %d", want, got) + } + } + + // Purges entries. + { + n, err := db.PurgeAuditEntries(1 * time.Nanosecond) + if err != nil { + t.Fatal(err) + } + if got, want := n, int64(5); got != want { + t.Errorf("expected %d to purge, got %d", want, got) + } + } +} + +func TestDatabase_ListAudits(t *testing.T) { + t.Parallel() + + t.Run("empty", func(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + audits, _, err := db.ListAudits(&pagination.PageParams{Limit: 1}) + if err != nil { + t.Fatal(err) + } + if got, want := len(audits), 0; got != want { + t.Errorf("expected %d audits, got %d: %v", want, got, audits) + } + }) + + t.Run("lists", func(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + for i := 0; i < 5; i++ { + if err := db.SaveAuditEntry(&AuditEntry{ + RealmID: 1, + ActorID: "actor:1", + ActorDisplay: "Actor", + Action: "created", + TargetID: "target:1", + TargetDisplay: "Target", + }); err != nil { + t.Fatal(err) + } + } + + audits, _, err := db.ListAudits(&pagination.PageParams{Limit: 10}) + if err != nil { + t.Fatal(err) + } + if got, want := len(audits), 5; got != want { + t.Errorf("expected %d audits, got %d: %v", want, got, audits) + } + }) +} diff --git a/pkg/database/authorized_app.go b/pkg/database/authorized_app.go index 9c2d0c765..1bb8c4e6f 100644 --- a/pkg/database/authorized_app.go +++ b/pkg/database/authorized_app.go @@ -95,8 +95,8 @@ func (a *AuthorizedApp) BeforeSave(tx *gorm.DB) error { a.AddError("type", "is invalid") } - if len(a.Errors()) > 0 { - return fmt.Errorf("validation failed") + if msgs := a.ErrorMessages(); len(msgs) > 0 { + return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", ")) } return nil } @@ -109,51 +109,20 @@ func (a *AuthorizedApp) IsDeviceType() bool { return a.APIKeyType == APIKeyTypeDevice } -// Realm returns the associated realm for this app. +// Realm returns the associated realm for this app. If you only need the ID, +// call .RealmID instead of a full database lookup. func (a *AuthorizedApp) Realm(db *Database) (*Realm, error) { var realm Realm - if err := db.db.Model(a).Related(&realm).Error; err != nil { + if err := db.db. + Model(&Realm{}). + Where("id = ?", a.RealmID). + First(&realm). + Error; err != nil { return nil, err } return &realm, nil } -// TableName definition for the authorized apps relation. -func (AuthorizedApp) TableName() string { - return "authorized_apps" -} - -// CreateAuthorizedApp generates a new API key and assigns it to the specified -// app. Note that the API key is NOT stored in the database, only a hash. The -// only time the API key is available is as the string return parameter from -// invoking this function. -func (r *Realm) CreateAuthorizedApp(db *Database, app *AuthorizedApp, actor Auditable) (string, error) { - fullAPIKey, err := db.GenerateAPIKey(r.ID) - if err != nil { - return "", fmt.Errorf("failed to generate API key: %w", err) - } - - parts := strings.SplitN(fullAPIKey, ".", 3) - if len(parts) != 3 { - return "", fmt.Errorf("internal error, key is invalid") - } - apiKey := parts[0] - - hmacedKey, err := db.GenerateAPIKeyHMAC(apiKey) - if err != nil { - return "", fmt.Errorf("failed to create hmac: %w", err) - } - - app.RealmID = r.ID - app.APIKey = hmacedKey - app.APIKeyPreview = apiKey[:6] - - if err := db.SaveAuthorizedApp(app, actor); err != nil { - return "", err - } - return fullAPIKey, nil -} - // FindAuthorizedApp finds the authorized app by the given id. func (db *Database) FindAuthorizedApp(id interface{}) (*AuthorizedApp, error) { var app AuthorizedApp diff --git a/pkg/database/authorized_app_stats_test.go b/pkg/database/authorized_app_stats_test.go new file mode 100644 index 000000000..1a8f3512f --- /dev/null +++ b/pkg/database/authorized_app_stats_test.go @@ -0,0 +1,96 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestAuthorizedAppStats_MarshalCSV(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + stats AuthorizedAppStats + exp string + }{ + { + name: "empty", + stats: nil, + exp: "", + }, + { + name: "single", + stats: []*AuthorizedAppStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + AuthorizedAppID: 1, + CodesIssued: 10, + AuthorizedAppName: "Appy", + }, + }, + exp: `date,authorized_app_id,authorized_app_name,codes_issued +2020-02-03,1,Appy,10 +`, + }, + { + name: "multi", + stats: []*AuthorizedAppStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + AuthorizedAppID: 1, + CodesIssued: 10, + AuthorizedAppName: "Appy", + }, + { + Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC), + AuthorizedAppID: 1, + CodesIssued: 45, + AuthorizedAppName: "Mc", + }, + { + Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC), + AuthorizedAppID: 1, + CodesIssued: 15, + AuthorizedAppName: "Apperson", + }, + }, + exp: `date,authorized_app_id,authorized_app_name,codes_issued +2020-02-03,1,Appy,10 +2020-02-04,1,Mc,45 +2020-02-05,1,Apperson,15 +`, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b, err := tc.stats.MarshalCSV() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(string(b), tc.exp); diff != "" { + t.Errorf("bad csv (+got, -want): %s", diff) + } + }) + } +} diff --git a/pkg/database/authorized_app_test.go b/pkg/database/authorized_app_test.go index 82f8a6008..16a3688d8 100644 --- a/pkg/database/authorized_app_test.go +++ b/pkg/database/authorized_app_test.go @@ -19,8 +19,160 @@ import ( "fmt" "strings" "testing" + "time" + + "github.com/google/exposure-notifications-server/pkg/timeutils" + "github.com/jinzhu/gorm" ) +func TestAPIKeyType(t *testing.T) { + t.Parallel() + + // This test might seem like it's redundant, but it's designed to ensure that + // the exact values for existing types remain unchanged. + cases := []struct { + t APIKeyType + exp int + }{ + {APIKeyTypeInvalid, -1}, + {APIKeyTypeDevice, 0}, + {APIKeyTypeAdmin, 1}, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.t.Display(), func(t *testing.T) { + t.Parallel() + + if got, want := int(tc.t), tc.exp; got != want { + t.Errorf("expected %d to be %d", got, want) + } + }) + } +} + +func TestAPIKeyType_Display(t *testing.T) { + t.Parallel() + + cases := []struct { + t APIKeyType + exp string + }{ + {APIKeyTypeInvalid, "invalid"}, + {APIKeyTypeDevice, "device"}, + {APIKeyTypeAdmin, "admin"}, + {1991, "invalid"}, + } + + for _, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) { + t.Parallel() + + if got, want := tc.t.Display(), tc.exp; got != want { + t.Errorf("expected %q to be %q", got, want) + } + }) + } +} + +func TestAuthorizedApp_BeforeSave(t *testing.T) { + t.Parallel() + + t.Run("name", func(t *testing.T) { + t.Parallel() + exerciseValidation(t, &AuthorizedApp{}, "Name", "name") + }) + + t.Run("type", func(t *testing.T) { + t.Parallel() + + { + var m AuthorizedApp + m.APIKeyType = -1 + _ = m.BeforeSave(&gorm.DB{}) + if errs := m.ErrorsFor("type"); len(errs) < 1 { + t.Errorf("expected errors for type") + } + } + + { + var m AuthorizedApp + m.APIKeyType = 55 + _ = m.BeforeSave(&gorm.DB{}) + if errs := m.ErrorsFor("type"); len(errs) < 1 { + t.Errorf("expected errors for type") + } + } + + { + var m AuthorizedApp + m.APIKeyType = 0 + _ = m.BeforeSave(&gorm.DB{}) + if errs := m.ErrorsFor("type"); len(errs) != 0 { + t.Errorf("expected no errors for type, got %v", errs) + } + } + }) +} + +func TestAuthorizedApp_Realm(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + realm, err := db.FindRealm(1) + if err != nil { + t.Fatal(err) + } + + authorizedApp := &AuthorizedApp{ + Name: "Appy", + } + if _, err := realm.CreateAuthorizedApp(db, authorizedApp, SystemTest); err != nil { + t.Fatal(err) + } + + gotRealm, err := authorizedApp.Realm(db) + if err != nil { + t.Fatal(err) + } + if got, want := gotRealm.ID, realm.ID; got != want { + t.Errorf("expected %d to be %d", got, want) + } +} + +func TestAuthorizedApp_Stats(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + realm, err := db.FindRealm(1) + if err != nil { + t.Fatal(err) + } + + authorizedApp := &AuthorizedApp{ + Name: "Appy", + } + if _, err := realm.CreateAuthorizedApp(db, authorizedApp, SystemTest); err != nil { + t.Fatal(err) + } + + // Ensure graph is contiguous. + { + stop := timeutils.Midnight(time.Now().UTC()) + start := stop.Add(6 * -24 * time.Hour) + stats, err := authorizedApp.Stats(db, start, stop) + if err != nil { + t.Fatal(err) + } + if got, want := len(stats), 7; got != want { + t.Errorf("expected stats for %d days, got %d", want, got) + } + } +} + func TestDatabase_CreateFindAPIKey(t *testing.T) { t.Parallel() @@ -126,3 +278,45 @@ func TestDatabase_GenerateVerifyAPIKeySignature(t *testing.T) { t.Errorf("expected %v to be %v", got, want) } } + +func TestDatabase_PurgeAuthorizedApps(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + now := time.Now().UTC() + for i := 0; i < 5; i++ { + if err := db.SaveAuthorizedApp(&AuthorizedApp{ + RealmID: 1, + Name: fmt.Sprintf("appy%d", i), + APIKey: fmt.Sprintf("%d", i), + Model: gorm.Model{ + DeletedAt: &now, + }, + }, SystemTest); err != nil { + t.Fatal(err) + } + } + + // Should not purge entries (too young). + { + n, err := db.PurgeAuthorizedApps(24 * time.Hour) + if err != nil { + t.Fatal(err) + } + if got, want := n, int64(0); got != want { + t.Errorf("expected %d to purge, got %d", want, got) + } + } + + // Purges entries. + { + n, err := db.PurgeAuthorizedApps(1 * time.Nanosecond) + if err != nil { + t.Fatal(err) + } + if got, want := n, int64(5); got != want { + t.Errorf("expected %d to purge, got %d", want, got) + } + } +} diff --git a/pkg/database/cleanup.go b/pkg/database/cleanup.go index a34e4db7c..2dc4c3dff 100644 --- a/pkg/database/cleanup.go +++ b/pkg/database/cleanup.go @@ -34,25 +34,22 @@ type CleanupStatus struct { NotBefore time.Time } -// TableName sets the CleanupStatus table name -func (CleanupStatus) TableName() string { - return "cleanup_statuses" -} - // CreateCleanup is used to create a new 'cleanup' type/row in the database. func (db *Database) CreateCleanup(cType string) (*CleanupStatus, error) { - cstat := &CleanupStatus{ - Type: cType, - Generation: 1, - NotBefore: time.Now().UTC(), - } + var cstat CleanupStatus + + sql := `INSERT INTO cleanup_statuses (type, generation, not_before) + VALUES ($1, $2, $3) + ON CONFLICT (type) DO UPDATE SET type = EXCLUDED.type + RETURNING *` + + now := time.Now().UTC() if err := db.db. - Set("gorm:insert_option", "ON CONFLICT (type) DO NOTHING"). - FirstOrCreate(cstat). - Error; err != nil { + Raw(sql, cType, 1, now). + Scan(&cstat).Error; err != nil { return nil, err } - return cstat, nil + return &cstat, nil } // FindCleanupStatus looks up the current cleanup state in the database by cleanup type. @@ -64,13 +61,14 @@ func (db *Database) FindCleanupStatus(cType string) (*CleanupStatus, error) { return &cstat, nil } -// ClaimCleanup attempts to obtain a lock for the specified `lockTime` so that that type of -// cleanup can be perofmed exclusively by the owner. +// ClaimCleanup attempts to obtain a lock for the specified `lockTime` so that +// that type of cleanup can be performed exclusively by the owner. func (db *Database) ClaimCleanup(current *CleanupStatus, lockTime time.Duration) (*CleanupStatus, error) { var cstat CleanupStatus - err := db.db.Transaction(func(tx *gorm.DB) error { + if err := db.db.Transaction(func(tx *gorm.DB) error { if err := tx. Set("gorm:query_option", "FOR UPDATE"). + Model(&CleanupStatus{}). Where("type = ?", current.Type). First(&cstat). Error; err != nil { @@ -83,8 +81,7 @@ func (db *Database) ClaimCleanup(current *CleanupStatus, lockTime time.Duration) cstat.Generation++ cstat.NotBefore = time.Now().UTC().Add(lockTime) return tx.Save(&cstat).Error - }) - if err != nil { + }); err != nil { return nil, err } return &cstat, nil diff --git a/pkg/database/cleanup_test.go b/pkg/database/cleanup_test.go new file mode 100644 index 000000000..45d4ca561 --- /dev/null +++ b/pkg/database/cleanup_test.go @@ -0,0 +1,120 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "errors" + "testing" + "time" +) + +func TestDatabase_CreateCleanup(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + cleanup1, err := db.CreateCleanup("x") + if err != nil { + t.Fatal(err) + } + + // If the cleanup already exists, it's a noop + cleanup2, err := db.CreateCleanup("x") + if err != nil { + t.Fatal(err) + } + + if got, want := cleanup1.ID, cleanup2.ID; got != want { + t.Errorf("expected %d to be %d", got, want) + } +} + +func TestDatabase_FindCleanupStatus(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + want, err := db.CreateCleanup("x") + if err != nil { + t.Fatal(err) + } + + got, err := db.FindCleanupStatus("x") + if err != nil { + t.Fatal(err) + } + + if got, want := got.ID, want.ID; got != want { + t.Errorf("expected %d to be %d", got, want) + } +} + +func TestDatabase_ClaimCleanup(t *testing.T) { + t.Parallel() + + t.Run("no_exist", func(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + cleanup, err := db.ClaimCleanup(&CleanupStatus{Type: "nope"}, 5*time.Second) + if !IsNotFound(err) { + t.Errorf("expected error, got: %v: %v", err, cleanup) + } + }) + + t.Run("wrong_generation", func(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + if _, err := db.CreateCleanup("dirty"); err != nil { + t.Fatal(err) + } + + _, err := db.ClaimCleanup(&CleanupStatus{ + Type: "dirty", + Generation: 2, + }, 1*time.Second) + if got, want := err, ErrCleanupWrongGeneration; !errors.Is(err, want) { + t.Errorf("expected %v to be %v", got, want) + } + }) + + t.Run("exists", func(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + if _, err := db.CreateCleanup("dirty"); err != nil { + t.Fatal(err) + } + + cleanup, err := db.ClaimCleanup(&CleanupStatus{ + Type: "dirty", + Generation: 1, + }, 1*time.Second) + if err != nil { + t.Fatal(err) + } + + if got, want := cleanup.Generation, uint(2); got != want { + t.Errorf("expected generation %d to be %d", got, want) + } + + if got, now := cleanup.NotBefore, time.Now().UTC(); !got.After(now) { + t.Errorf("expected %q to be after %q", got, now) + } + }) +} diff --git a/pkg/database/database_test.go b/pkg/database/database_test.go index 18acc770b..c75548c2d 100644 --- a/pkg/database/database_test.go +++ b/pkg/database/database_test.go @@ -15,7 +15,12 @@ package database import ( + "io/ioutil" + "log" + "reflect" "testing" + + "github.com/jinzhu/gorm" ) var testDatabaseInstance *TestInstance @@ -25,3 +30,53 @@ func TestMain(m *testing.M) { defer testDatabaseInstance.MustClose() m.Run() } + +type validateable interface { + ErrorsFor(s string) []string + BeforeSave(tx *gorm.DB) error +} + +// exerciseValidation exercises zero value validation (not empty) for the given +// model and struct fields. +func exerciseValidation(t *testing.T, i validateable, structField, field string) { + // Get interface underlying value. + v := reflect.ValueOf(&i) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + if !v.CanInterface() { + t.Fatalf("%v cannot interface", v) + } + + // Convert interface to struct. + sv := reflect.ValueOf(v.Interface()) + for sv.Kind() == reflect.Ptr { + sv = sv.Elem() + } + if sv.Kind() != reflect.Struct { + t.Fatalf("%T is not a struct: %v", i, sv.Kind()) + } + + // Get struct field name. + f := sv.FieldByName(structField) + if !f.IsValid() { + t.Fatalf("%s is not valid", structField) + } + if !f.CanSet() { + t.Fatalf("%s is not settable", structField) + } + + // Set to the zero value. + valueV := reflect.Zero(f.Type()) + f.Set(valueV) + + // Create db + var db gorm.DB + db.SetLogger(gorm.Logger{LogWriter: log.New(ioutil.Discard, "", 0)}) + + // Run the validation. + _ = i.BeforeSave(&gorm.DB{}) + if errs := i.ErrorsFor(field); len(errs) < 1 { + t.Errorf("expected errors for %s", field) + } +} diff --git a/pkg/database/errors.go b/pkg/database/errors.go index 83469b176..56bbca06b 100644 --- a/pkg/database/errors.go +++ b/pkg/database/errors.go @@ -17,6 +17,7 @@ package database import ( "errors" "fmt" + "sort" ) var ( @@ -47,8 +48,16 @@ func (e *Errorable) Errors() map[string][]string { func (e *Errorable) ErrorMessages() []string { e.init() + // Sort keys so the response is in predictable ordering. + keys := make([]string, 0, len(e.errors)) + for k := range e.errors { + keys = append(keys, k) + } + sort.Strings(keys) + l := make([]string, 0, len(e.errors)) - for k, v := range e.errors { + for _, k := range keys { + v := e.errors[k] for _, msg := range v { l = append(l, fmt.Sprintf("%s %s", k, msg)) } diff --git a/pkg/database/external_issuer_stats_test.go b/pkg/database/external_issuer_stats_test.go new file mode 100644 index 000000000..314b6c22f --- /dev/null +++ b/pkg/database/external_issuer_stats_test.go @@ -0,0 +1,96 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestExternalIssuerStats_MarshalCSV(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + stats ExternalIssuerStats + exp string + }{ + { + name: "empty", + stats: nil, + exp: "", + }, + { + name: "single", + stats: []*ExternalIssuerStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + RealmID: 1, + IssuerID: "user:2", + CodesIssued: 10, + }, + }, + exp: `date,realm_id,issuer_id,codes_issued +2020-02-03,1,user:2,10 +`, + }, + { + name: "multi", + stats: []*ExternalIssuerStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + RealmID: 1, + IssuerID: "user:2", + CodesIssued: 10, + }, + { + Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC), + RealmID: 1, + IssuerID: "user:2", + CodesIssued: 45, + }, + { + Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC), + RealmID: 1, + IssuerID: "user:2", + CodesIssued: 15, + }, + }, + exp: `date,realm_id,issuer_id,codes_issued +2020-02-03,1,user:2,10 +2020-02-04,1,user:2,45 +2020-02-05,1,user:2,15 +`, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b, err := tc.stats.MarshalCSV() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(string(b), tc.exp); diff != "" { + t.Errorf("bad csv (+got, -want): %s", diff) + } + }) + } +} diff --git a/pkg/database/membership.go b/pkg/database/membership.go index f4a6c6d36..64ab2f10a 100644 --- a/pkg/database/membership.go +++ b/pkg/database/membership.go @@ -16,12 +16,15 @@ package database import ( "fmt" + "strings" "github.com/google/exposure-notifications-verification-server/pkg/rbac" ) // Membership represents a user's membership in a realm. type Membership struct { + Errorable + UserID uint User *User @@ -35,13 +38,16 @@ type Membership struct { // preloaded and the referenced values exist. func (m *Membership) AfterFind() error { if m.User == nil { - return fmt.Errorf("membership user does not exist") + m.AddError("user", "does not exist") } if m.Realm == nil { - return fmt.Errorf("membership realm does not exist") + m.AddError("realm", "does not exist") } + if msgs := m.ErrorMessages(); len(msgs) > 0 { + return fmt.Errorf("lookup failed: %s", strings.Join(msgs, ", ")) + } return nil } diff --git a/pkg/database/membership_test.go b/pkg/database/membership_test.go new file mode 100644 index 000000000..858847a18 --- /dev/null +++ b/pkg/database/membership_test.go @@ -0,0 +1,43 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "testing" +) + +func TestMembership_AfterFind(t *testing.T) { + t.Parallel() + + t.Run("user", func(t *testing.T) { + t.Parallel() + + var m Membership + _ = m.AfterFind() + if errs := m.ErrorsFor("user"); len(errs) < 1 { + t.Errorf("expected errors for %s", "user") + } + }) + + t.Run("realm", func(t *testing.T) { + t.Parallel() + + var m Membership + _ = m.AfterFind() + if errs := m.ErrorsFor("realm"); len(errs) < 1 { + t.Errorf("expected errors for %s", "realm") + } + }) +} diff --git a/pkg/database/mobile_app.go b/pkg/database/mobile_app.go index 5e92880e0..d4abfa7c6 100644 --- a/pkg/database/mobile_app.go +++ b/pkg/database/mobile_app.go @@ -130,10 +130,14 @@ func (a *MobileApp) BeforeSave(tx *gorm.DB) error { } a.SHA = strings.Join(shas, "\n") - if len(a.Errors()) > 0 { - return fmt.Errorf("validation failed: %v", a.Errors()) + if msgs := a.ErrorMessages(); len(msgs) > 0 { + return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", ")) } + return nil +} +func (a *MobileApp) AfterFind(tx *gorm.DB) error { + a.URL = stringValue(a.URLPtr) return nil } @@ -290,11 +294,6 @@ func (a *MobileApp) AuditDisplay() string { return fmt.Sprintf("%s (%s)", a.Name, a.OS.Display()) } -func (a *MobileApp) AfterFind(tx *gorm.DB) error { - a.URL = stringValue(a.URLPtr) - return nil -} - // PurgeMobileApps will delete mobile apps that have been deleted for more than // the specified time. func (db *Database) PurgeMobileApps(maxAge time.Duration) (int64, error) { diff --git a/pkg/database/mobile_app_test.go b/pkg/database/mobile_app_test.go index 5cadd7866..53235521a 100644 --- a/pkg/database/mobile_app_test.go +++ b/pkg/database/mobile_app_test.go @@ -15,82 +15,94 @@ package database import ( + "fmt" "testing" + "time" "github.com/google/exposure-notifications-verification-server/pkg/pagination" + "github.com/jinzhu/gorm" ) -func TestMobileApp_Validation(t *testing.T) { +func TestOSType(t *testing.T) { t.Parallel() - db, _ := testDatabaseInstance.NewDatabase(t, nil) - - t.Run("name", func(t *testing.T) { - t.Parallel() - - var m MobileApp - m.Name = "" - _ = m.BeforeSave(db.RawDB()) - - nameErrs := m.ErrorsFor("name") - if len(nameErrs) < 1 { - t.Fatal("expected error") - } - }) - - t.Run("app_id", func(t *testing.T) { - t.Parallel() - - var m MobileApp - m.AppID = "" - _ = m.BeforeSave(db.RawDB()) - - appIDErrs := m.ErrorsFor("app_id") - if len(appIDErrs) < 1 { - t.Fatal("expected error") - } - }) - - t.Run("os", func(t *testing.T) { - t.Parallel() - - var m MobileApp - m.OS = 0 - _ = m.BeforeSave(db.RawDB()) + // This test might seem like it's redundant, but it's designed to ensure that + // the exact values for existing types remain unchanged. + cases := []struct { + t OSType + exp int + }{ + {OSTypeInvalid, 0}, + {OSTypeIOS, 1}, + {OSTypeAndroid, 2}, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.t.Display(), func(t *testing.T) { + t.Parallel() + + if got, want := int(tc.t), tc.exp; got != want { + t.Errorf("expected %d to be %d", got, want) + } + }) + } +} - osErrs := m.ErrorsFor("os") - if len(osErrs) < 1 { - t.Fatal("expected error") - } +func TestOSType_Display(t *testing.T) { + t.Parallel() - m.OS = 4 - _ = m.BeforeSave(db.RawDB()) + cases := []struct { + t OSType + exp string + }{ + {OSTypeInvalid, "Unknown"}, + {OSTypeIOS, "iOS"}, + {OSTypeAndroid, "Android"}, + {1991, "Unknown"}, + } + + for _, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) { + t.Parallel() + + if got, want := tc.t.Display(), tc.exp; got != want { + t.Errorf("expected %q to be %q", got, want) + } + }) + } +} - osErrs = m.ErrorsFor("os") - if len(osErrs) < 1 { - t.Fatal("expected error") - } - }) +func TestMobileApp_Validation(t *testing.T) { + t.Parallel() - t.Run("url", func(t *testing.T) { - t.Parallel() + cases := []struct { + structField string + field string + }{ + {"Name", "name"}, + {"AppID", "app_id"}, + {"OS", "os"}, + } - var m MobileApp - m.URL = "" - _ = m.BeforeSave(db.RawDB()) + for _, tc := range cases { + tc := tc - urlErrors := m.ErrorsFor("url") - if len(urlErrors) != 0 { - t.Fatal("no xpected error") - } - }) + t.Run(tc.field, func(t *testing.T) { + t.Parallel() + exerciseValidation(t, &MobileApp{}, tc.structField, tc.field) + }) + } t.Run("sha", func(t *testing.T) { t.Parallel() var m MobileApp m.OS = OSTypeIOS - _ = m.BeforeSave(db.RawDB()) + _ = m.BeforeSave(&gorm.DB{}) shaErrs := m.ErrorsFor("sha") if len(shaErrs) > 0 { @@ -98,7 +110,7 @@ func TestMobileApp_Validation(t *testing.T) { } m.OS = OSTypeAndroid - _ = m.BeforeSave(db.RawDB()) + _ = m.BeforeSave(&gorm.DB{}) shaErrs = m.ErrorsFor("sha") if len(shaErrs) < 1 { @@ -136,7 +148,7 @@ func TestMobileApp_Validation(t *testing.T) { var m MobileApp m.SHA = tc.sha - _ = m.BeforeSave(db.RawDB()) + _ = m.BeforeSave(&gorm.DB{}) shaErrs := m.ErrorsFor("sha") if !tc.err && len(shaErrs) > 0 { @@ -147,7 +159,7 @@ func TestMobileApp_Validation(t *testing.T) { }) } -func TestMobileApp_List(t *testing.T) { +func TestDatabase_ListActiveApps(t *testing.T) { t.Parallel() t.Run("access_mobileapps_and_realms", func(t *testing.T) { @@ -211,3 +223,47 @@ func TestMobileApp_List(t *testing.T) { } }) } + +func TestDatabase_PurgeMobileApps(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + now := time.Now().UTC() + for i := 0; i < 5; i++ { + if err := db.SaveMobileApp(&MobileApp{ + RealmID: 1, + Name: fmt.Sprintf("Appy%d", i), + OS: OSTypeIOS, + URL: fmt.Sprintf("https://%d.example.com", i), + AppID: fmt.Sprintf("app.%d.com", i), + Model: gorm.Model{ + DeletedAt: &now, + }, + }, SystemTest); err != nil { + t.Fatal(err) + } + } + + // Should not purge entries (too young). + { + n, err := db.PurgeMobileApps(24 * time.Hour) + if err != nil { + t.Fatal(err) + } + if got, want := n, int64(0); got != want { + t.Errorf("expected %d to purge, got %d", want, got) + } + } + + // Purges entries. + { + n, err := db.PurgeMobileApps(1 * time.Nanosecond) + if err != nil { + t.Fatal(err) + } + if got, want := n, int64(5); got != want { + t.Errorf("expected %d to purge, got %d", want, got) + } + } +} diff --git a/pkg/database/modeler_status_test.go b/pkg/database/modeler_status_test.go new file mode 100644 index 000000000..447d75463 --- /dev/null +++ b/pkg/database/modeler_status_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "testing" + "time" +) + +func TestDatabase_ClaimModelerStatus(t *testing.T) { + t.Parallel() + + // Create this now so we don't get clock skew + later := time.Now().UTC().Add(modelerLockTime) + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + if err := db.ClaimModelerStatus(); err != nil { + t.Fatal(err) + } + + var status ModelerStatus + if err := db.db.Model(&ModelerStatus{}).First(&status).Error; err != nil { + t.Fatal(err) + } + + if got, now := status.NotBefore, later; !got.After(now) { + t.Errorf("expected %q to be after %q", got, now) + } +} diff --git a/pkg/database/realm.go b/pkg/database/realm.go index e8a4b7567..7c39a4b15 100644 --- a/pkg/database/realm.go +++ b/pkg/database/realm.go @@ -49,7 +49,6 @@ const ( TestTypeConfirmed TestTypeLikely TestTypeNegative - DefaultTemplateLabel = "Default SMS template" ) func (t TestType) Display() string { @@ -70,6 +69,30 @@ func (t TestType) Display() string { return strings.Join(types, ", ") } +// AuthRequirement represents authentication requirements for the realm +type AuthRequirement int16 + +const ( + // MFAOptionalPrompt will prompt users for MFA on login. + MFAOptionalPrompt AuthRequirement = iota + // MFARequired will not allow users to proceed without MFA on their account. + MFARequired + // MFAOptional will not prompt users to enable MFA. + MFAOptional +) + +func (r AuthRequirement) String() string { + switch r { + case MFAOptionalPrompt: + return "prompt" + case MFARequired: + return "required" + case MFAOptional: + return "optional" + } + return "" +} + var ( ErrNoSigningKeyManagement = errors.New("no signing key management") ErrBadDateRange = errors.New("bad date range") @@ -89,22 +112,12 @@ const ( SMSTemplateMaxLength = 800 SMSTemplateExpansionMax = 918 + DefaultTemplateLabel = "Default SMS template" + EmailInviteLink = "[invitelink]" EmailPasswordResetLink = "[passwordresetlink]" EmailVerifyLink = "[verifylink]" RealmName = "[realmname]" -) - -// AuthRequirement represents authentication requirements for the realm -type AuthRequirement int16 - -const ( - // MFAOptionalPrompt will prompt users for MFA on login. - MFAOptionalPrompt = iota - // MFARequired will not allow users to proceed without MFA on their account. - MFARequired - // MFAOptional will not prompt users to enable MFA. - MFAOptional // MaxPageSize is the maximum allowed page size for a list query. MaxPageSize = 1000 @@ -250,31 +263,6 @@ type Realm struct { Tokens []*Token `gorm:"PRELOAD:false; SAVE_ASSOCIATIONS:false; ASSOCIATION_AUTOUPDATE:false, ASSOCIATION_SAVE_REFERENCE:false"` } -// EffectiveMFAMode returns the realm's default MFAMode but first -// checks if the user is in the grace-period (if so, required becomes prompt). -func (r *Realm) EffectiveMFAMode(user *User) AuthRequirement { - if r == nil { - return MFARequired - } - - if time.Since(user.CreatedAt) <= r.MFARequiredGracePeriod.Duration { - return MFAOptionalPrompt - } - return r.MFAMode -} - -func (mode *AuthRequirement) String() string { - switch *mode { - case MFAOptionalPrompt: - return "prompt" - case MFARequired: - return "required" - case MFAOptional: - return "optional" - } - return "" -} - // NewRealmWithDefaults initializes a new Realm with the default settings populated, // and the provided name. It does NOT save the Realm to the database. func NewRealmWithDefaults(name string) *Realm { @@ -291,14 +279,6 @@ func NewRealmWithDefaults(name string) *Realm { } } -func (r *Realm) CanUpgradeToRealmSigningKeys() bool { - return r.CertificateIssuer != "" && r.CertificateAudience != "" -} - -func (r *Realm) SigningKeyID() string { - return fmt.Sprintf("realm-%d", r.ID) -} - // AfterFind runs after a realm is found. func (r *Realm) AfterFind(tx *gorm.DB) error { r.RegionCode = stringValue(r.RegionCodePtr) @@ -365,12 +345,12 @@ func (r *Realm) BeforeSave(tx *gorm.DB) error { if r.SMSTextAlternateTemplates != nil { for l, t := range r.SMSTextAlternateTemplates { if t == nil || *t == "" { - r.AddError("SMSTextTemplate", fmt.Sprintf("no template for label %s", l)) + r.AddError("smsTextTemplate", fmt.Sprintf("no template for label %s", l)) r.AddError(l, fmt.Sprintf("no template for label %s", l)) continue } if l == "" { - r.AddError("SMSTextTemplate", fmt.Sprintf("no label for template %s", *t)) + r.AddError("smsTextTemplate", fmt.Sprintf("no label for template %s", *t)) continue } r.validateSMSTemplate(l, *t) @@ -383,19 +363,19 @@ func (r *Realm) BeforeSave(tx *gorm.DB) error { if r.EmailInviteTemplate != "" { if !strings.Contains(r.EmailInviteTemplate, EmailInviteLink) { - r.AddError("EmailInviteLink", fmt.Sprintf("must contain %q", EmailInviteLink)) + r.AddError("emailInviteLink", fmt.Sprintf("must contain %q", EmailInviteLink)) } } if r.EmailPasswordResetTemplate != "" { if !strings.Contains(r.EmailPasswordResetTemplate, EmailPasswordResetLink) { - r.AddError("EmailPasswordResetTemplate", fmt.Sprintf("must contain %q", EmailPasswordResetLink)) + r.AddError("emailPasswordResetTemplate", fmt.Sprintf("must contain %q", EmailPasswordResetLink)) } } if r.EmailVerifyTemplate != "" { if !strings.Contains(r.EmailVerifyTemplate, EmailVerifyLink) { - r.AddError("EmailVerifyTemplate", fmt.Sprintf("must contain %q", EmailVerifyLink)) + r.AddError("emailVerifyTemplate", fmt.Sprintf("must contain %q", EmailVerifyLink)) } } @@ -416,8 +396,8 @@ func (r *Realm) BeforeSave(tx *gorm.DB) error { } } - if len(r.Errors()) > 0 { - return fmt.Errorf("realm validation failed: %s", strings.Join(r.ErrorMessages(), ", ")) + if msgs := r.ErrorMessages(); len(msgs) > 0 { + return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", ")) } return nil } @@ -427,29 +407,29 @@ func (r *Realm) BeforeSave(tx *gorm.DB) error { func (r *Realm) validateSMSTemplate(label, t string) { if r.EnableENExpress { if !strings.Contains(t, SMSENExpressLink) { - r.AddError("SMSTextTemplate", fmt.Sprintf("must contain %q", SMSENExpressLink)) + r.AddError("smsTextTemplate", fmt.Sprintf("must contain %q", SMSENExpressLink)) r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink)) } if strings.Contains(t, SMSRegion) { - r.AddError("SMSTextTemplate", fmt.Sprintf("cannot contain %q - this is automatically included in %q", SMSRegion, SMSENExpressLink)) + r.AddError("smsTextTemplate", fmt.Sprintf("cannot contain %q - this is automatically included in %q", SMSRegion, SMSENExpressLink)) r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink)) } if strings.Contains(t, SMSLongCode) { - r.AddError("SMSTextTemplate", fmt.Sprintf("cannot contain %q - the long code is automatically included in %q", SMSLongCode, SMSENExpressLink)) + r.AddError("smsTextTemplate", fmt.Sprintf("cannot contain %q - the long code is automatically included in %q", SMSLongCode, SMSENExpressLink)) r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink)) } } else { // Check that we have exactly one of [code] or [longcode] as template substitutions. if c, lc := strings.Contains(t, SMSCode), strings.Contains(t, SMSLongCode); !(c || lc) || (c && lc) { - r.AddError("SMSTextTemplate", fmt.Sprintf("must contain exactly one of %q or %q", SMSCode, SMSLongCode)) + r.AddError("smsTextTemplate", fmt.Sprintf("must contain exactly one of %q or %q", SMSCode, SMSLongCode)) r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink)) } } // Check template length. if l := len(t); l > SMSTemplateMaxLength { - r.AddError("SMSTextTemplate", fmt.Sprintf("must be %d characters or less, current message is %v characters long", SMSTemplateMaxLength, l)) + r.AddError("smsTextTemplate", fmt.Sprintf("must be %d characters or less, current message is %v characters long", SMSTemplateMaxLength, l)) r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink)) } @@ -459,11 +439,11 @@ func (r *Realm) validateSMSTemplate(label, t string) { enxDomain := os.Getenv("ENX_REDIRECT_DOMAIN") expandedSMSText, err := r.BuildSMSText(fakeCode, fakeLongCode, enxDomain, label) if err != nil { - r.AddError("SMSTextTemplate", fmt.Sprintf("SMS template expansion failed: %s", err)) + r.AddError("smsTextTemplate", fmt.Sprintf("SMS template expansion failed: %s", err)) r.AddError(label, fmt.Sprintf("SMS template expansion failed: %s", err)) } if l := len(expandedSMSText); l > SMSTemplateExpansionMax { - r.AddError("SMSTextTemplate", fmt.Sprintf("when expanded, the result message is too long (%v characters). The max expanded message is %v characters", l, SMSTemplateExpansionMax)) + r.AddError("smsTextTemplate", fmt.Sprintf("when expanded, the result message is too long (%v characters). The max expanded message is %v characters", l, SMSTemplateExpansionMax)) r.AddError(label, fmt.Sprintf("when expanded, the result message is too long (%v characters). The max expanded message is %v characters", l, SMSTemplateExpansionMax)) } @@ -481,6 +461,19 @@ func (r *Realm) GetLongCodeDurationHours() int { return int(r.LongCodeDuration.Duration.Hours()) } +// EffectiveMFAMode returns the realm's default MFAMode but first +// checks if the user is in the grace-period (if so, required becomes prompt). +func (r *Realm) EffectiveMFAMode(user *User) AuthRequirement { + if r == nil { + return MFARequired + } + + if time.Since(user.CreatedAt) <= r.MFARequiredGracePeriod.Duration { + return MFAOptionalPrompt + } + return r.MFAMode +} + // FindVerificationCodeByUUID find a verification codes by UUID. func (r *Realm) FindVerificationCodeByUUID(db *Database, uuid string) (*VerificationCode, error) { var vc VerificationCode @@ -1242,6 +1235,45 @@ func (db *Database) SaveRealm(r *Realm, actor Auditable) error { }) } +// CreateAuthorizedApp generates a new API key and assigns it to the specified +// app. Note that the API key is NOT stored in the database, only a hash. The +// only time the API key is available is as the string return parameter from +// invoking this function. +func (r *Realm) CreateAuthorizedApp(db *Database, app *AuthorizedApp, actor Auditable) (string, error) { + fullAPIKey, err := db.GenerateAPIKey(r.ID) + if err != nil { + return "", fmt.Errorf("failed to generate API key: %w", err) + } + + parts := strings.SplitN(fullAPIKey, ".", 3) + if len(parts) != 3 { + return "", fmt.Errorf("internal error, key is invalid") + } + apiKey := parts[0] + + hmacedKey, err := db.GenerateAPIKeyHMAC(apiKey) + if err != nil { + return "", fmt.Errorf("failed to create hmac: %w", err) + } + + app.RealmID = r.ID + app.APIKey = hmacedKey + app.APIKeyPreview = apiKey[:6] + + if err := db.SaveAuthorizedApp(app, actor); err != nil { + return "", err + } + return fullAPIKey, nil +} + +func (r *Realm) CanUpgradeToRealmSigningKeys() bool { + return r.CertificateIssuer != "" && r.CertificateAudience != "" +} + +func (r *Realm) SigningKeyID() string { + return fmt.Sprintf("realm-%d", r.ID) +} + // CreateSigningKeyVersion creates a new signing key version on the key manager // and saves a reference to the new key version in the database. If creating the // key in the key manager fails, the database is not updated. However, if diff --git a/pkg/database/realm_stats_test.go b/pkg/database/realm_stats_test.go new file mode 100644 index 000000000..edc54bece --- /dev/null +++ b/pkg/database/realm_stats_test.go @@ -0,0 +1,100 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestRealmStats_MarshalCSV(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + stats RealmStats + exp string + }{ + { + name: "empty", + stats: nil, + exp: "", + }, + { + name: "single", + stats: []*RealmStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + RealmID: 1, + CodesIssued: 10, + CodesClaimed: 9, + DailyActiveUsers: 2, + }, + }, + exp: `date,codes_issued,codes_claimed,daily_active_users +2020-02-03,10,9,2 +`, + }, + { + name: "multi", + stats: []*RealmStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + RealmID: 1, + CodesIssued: 10, + CodesClaimed: 9, + DailyActiveUsers: 12, + }, + { + Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC), + RealmID: 1, + CodesIssued: 45, + CodesClaimed: 30, + DailyActiveUsers: 24, + }, + { + Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC), + RealmID: 1, + CodesIssued: 15, + CodesClaimed: 2, + DailyActiveUsers: 18, + }, + }, + exp: `date,codes_issued,codes_claimed,daily_active_users +2020-02-03,10,9,12 +2020-02-04,45,30,24 +2020-02-05,15,2,18 +`, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b, err := tc.stats.MarshalCSV() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(string(b), tc.exp); diff != "" { + t.Errorf("bad csv (+got, -want): %s", diff) + } + }) + } +} diff --git a/pkg/database/realm_test.go b/pkg/database/realm_test.go index 559b125bf..4be805000 100644 --- a/pkg/database/realm_test.go +++ b/pkg/database/realm_test.go @@ -16,6 +16,7 @@ package database import ( "context" + "fmt" "os" "path/filepath" "strings" @@ -24,38 +25,115 @@ import ( "github.com/google/exposure-notifications-server/pkg/timeutils" "github.com/google/exposure-notifications-verification-server/internal/project" - "github.com/google/exposure-notifications-verification-server/pkg/rbac" + "github.com/jinzhu/gorm" ) -func TestSMS(t *testing.T) { +func TestTestType(t *testing.T) { t.Parallel() - realm := NewRealmWithDefaults("test") - realm.SMSTextTemplate = "This is your Exposure Notifications Verification code: [enslink] Expires in [longexpires] hours" - realm.RegionCode = "US-WA" + // This test might seem like it's redundant, but it's designed to ensure that + // the exact values for existing types remain unchanged. + cases := []struct { + t TestType + exp int + }{ + {TestTypeConfirmed, 2}, + {TestTypeLikely, 4}, + {TestTypeNegative, 8}, + } - got, err := realm.BuildSMSText("12345678", "abcdefgh12345678", "en.express", "") - if err != nil { - t.Fatalf("failed to buildSMS, %v", err) + for _, tc := range cases { + tc := tc + + t.Run(tc.t.Display(), func(t *testing.T) { + t.Parallel() + + if got, want := int(tc.t), tc.exp; got != want { + t.Errorf("expected %d to be %d", got, want) + } + }) } - want := "This is your Exposure Notifications Verification code: https://us-wa.en.express/v?c=abcdefgh12345678 Expires in 24 hours" - if got != want { - t.Errorf("SMS text wrong, want: %q got %q", want, got) +} + +func TestTestType_Display(t *testing.T) { + t.Parallel() + + cases := []struct { + t TestType + exp string + }{ + {TestTypeConfirmed, "confirmed"}, + {TestTypeConfirmed | TestTypeLikely, "confirmed, likely"}, + {TestTypeLikely, "likely"}, + {TestTypeNegative, "negative"}, } - realm.SMSTextTemplate = "State of Wonder, COVID-19 Exposure Verification code [code]. Expires in [expires] minutes. Act now!" - got, err = realm.BuildSMSText("654321", "asdflkjasdlkfjl", "", "") - if err != nil { - t.Fatalf("failed to buildSMS, %v", err) + for _, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) { + t.Parallel() + + if got, want := tc.t.Display(), tc.exp; got != want { + t.Errorf("expected %q to be %q", got, want) + } + }) } - want = "State of Wonder, COVID-19 Exposure Verification code 654321. Expires in 15 minutes. Act now!" - if got != want { - t.Errorf("SMS text wrong, want: %q got %q", want, got) +} + +func TestAuthRequirement(t *testing.T) { + t.Parallel() + + // This test might seem like it's redundant, but it's designed to ensure that + // the exact values for existing types remain unchanged. + cases := []struct { + t AuthRequirement + exp int + }{ + {MFAOptionalPrompt, 0}, + {MFARequired, 1}, + {MFAOptional, 2}, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.t.String(), func(t *testing.T) { + t.Parallel() + + if got, want := int(tc.t), tc.exp; got != want { + t.Errorf("expected %d to be %d", got, want) + } + }) } } -func TestValidation(t *testing.T) { - db, _ := testDatabaseInstance.NewDatabase(t, nil) +func TestAuthRequirement_String(t *testing.T) { + t.Parallel() + + cases := []struct { + t AuthRequirement + exp string + }{ + {MFAOptionalPrompt, "prompt"}, + {MFARequired, "required"}, + {MFAOptional, "optional"}, + } + + for _, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) { + t.Parallel() + + if got, want := tc.t.String(), tc.exp; got != want { + t.Errorf("expected %q to be %q", got, want) + } + }) + } +} + +func TestRealm_BeforeSave(t *testing.T) { os.Setenv("ENX_REDIRECT_DOMAIN", "https://en.express") valid := "State of Wonder, COVID-19 Exposure Verification code [code]. Expires in [expires] minutes. Act now!" @@ -75,7 +153,6 @@ func TestValidation(t *testing.T) { { Name: "region_code_too_long", Input: &Realm{ - Name: "foo", RegionCode: "USA-IS-A-OK", }, Error: "regionCode cannot be more than 10 characters", @@ -83,12 +160,28 @@ func TestValidation(t *testing.T) { { Name: "enx_region_code_mismatch", Input: &Realm{ - Name: "foo", RegionCode: " ", EnableENExpress: true, }, Error: "regionCode cannot be blank when using EN Express", }, + { + Name: "system_sms_forbidden", + Input: &Realm{ + UseSystemSMSConfig: true, + CanUseSystemSMSConfig: false, + }, + Error: "useSystemSMSConfig is not allowed on this realm", + }, + { + Name: "system_sms_missing_from", + Input: &Realm{ + UseSystemSMSConfig: true, + CanUseSystemSMSConfig: true, + SMSFromNumberID: 0, + }, + Error: "smsFromNumber is required to use the system config", + }, { Name: "rotation_warning_too_big", Input: &Realm{ @@ -141,7 +234,7 @@ func TestValidation(t *testing.T) { EnableENExpress: true, SMSTextTemplate: "call 1-800-555-1234", }, - Error: "SMSTextTemplate must contain \"[enslink]\"", + Error: "smsTextTemplate must contain \"[enslink]\"", }, { Name: "enx_link_contains_region", @@ -150,7 +243,7 @@ func TestValidation(t *testing.T) { EnableENExpress: true, SMSTextTemplate: "[enslink] [region]", }, - Error: "SMSTextTemplate cannot contain \"[region]\" - this is automatically included in \"[enslink]\"", + Error: "smsTextTemplate cannot contain \"[region]\" - this is automatically included in \"[enslink]\"", }, { Name: "enx_link_contains_long_code", @@ -159,7 +252,7 @@ func TestValidation(t *testing.T) { EnableENExpress: true, SMSTextTemplate: "[enslink] [longcode]", }, - Error: "SMSTextTemplate cannot contain \"[longcode]\" - the long code is automatically included in \"[enslink]\"", + Error: "smsTextTemplate cannot contain \"[longcode]\" - the long code is automatically included in \"[enslink]\"", }, { Name: "link_missing_code", @@ -168,7 +261,7 @@ func TestValidation(t *testing.T) { EnableENExpress: false, SMSTextTemplate: "call me", }, - Error: "SMSTextTemplate must contain exactly one of \"[code]\" or \"[longcode]\"", + Error: "smsTextTemplate must contain exactly one of \"[code]\" or \"[longcode]\"", }, { Name: "link_both_codess", @@ -177,7 +270,7 @@ func TestValidation(t *testing.T) { EnableENExpress: false, SMSTextTemplate: "[code][longcode]", }, - Error: "SMSTextTemplate must contain exactly one of \"[code]\" or \"[longcode]\"", + Error: "smsTextTemplate must contain exactly one of \"[code]\" or \"[longcode]\"", }, { Name: "text_too_long", @@ -188,7 +281,7 @@ func TestValidation(t *testing.T) { Curabitur non massa urna. Phasellus sit amet nisi ut quam dapibus pretium eget in turpis. Phasellus et justo odio. In auctor, felis a tincidunt maximus, nunc erat vehicula ligula, ac posuere felis odio eget mauris. Nulla gravida.`, }, - Error: "SMSTextTemplate must be 800 characters or less, current message is 807 characters long", + Error: "smsTextTemplate must be 800 characters or less, current message is 807 characters long", }, { Name: "text_too_long", @@ -197,7 +290,7 @@ func TestValidation(t *testing.T) { EnableENExpress: false, SMSTextTemplate: strings.Repeat("[enslink]", 88), }, - Error: "SMSTextTemplate when expanded, the result message is too long (3168 characters). The max expanded message is 918 characters", + Error: "smsTextTemplate when expanded, the result message is too long (3168 characters). The max expanded message is 918 characters", }, { Name: "valid", @@ -220,7 +313,7 @@ func TestValidation(t *testing.T) { Error: "no template for label alternate1", }, { - Name: "alternate_sms_template valid", + Name: "alternate_sms_template_valid", Input: &Realm{ Name: "b", CodeLength: 6, @@ -230,54 +323,154 @@ func TestValidation(t *testing.T) { SMSTextAlternateTemplates: map[string]*string{"alternate1": &valid}, }, }, + { + Name: "system_email_forbidden", + Input: &Realm{ + UseSystemEmailConfig: true, + CanUseSystemEmailConfig: false, + }, + Error: "useSystemEmailConfig is not allowed on this realm", + }, + { + Name: "email_invite_template_missing_link", + Input: &Realm{ + EmailInviteTemplate: "banana", + }, + Error: "emailInviteLink must contain \"[invitelink]\"", + }, + { + Name: "email_password_reset_template_missing_link", + Input: &Realm{ + EmailPasswordResetTemplate: "banana", + }, + Error: "emailPasswordResetTemplate must contain \"[passwordresetlink]\"", + }, + { + Name: "email_verify_template_missing_link", + Input: &Realm{ + EmailVerifyTemplate: "banana", + }, + Error: "emailVerifyTemplate must contain \"[verifylink]\"", + }, + { + Name: "certificate_issuer_blank", + Input: &Realm{ + UseRealmCertificateKey: true, + CertificateIssuer: "", + }, + Error: "certificateIssuer cannot be blank", + }, + { + Name: "certificate_audience_blank", + Input: &Realm{ + UseRealmCertificateKey: true, + CertificateAudience: "", + }, + Error: "certificateAudience cannot be blank", + }, } for _, tc := range cases { + tc := tc + t.Run(tc.Name, func(t *testing.T) { - if err := db.SaveRealm(tc.Input, SystemTest); err == nil { + t.Parallel() + + if err := tc.Input.BeforeSave(&gorm.DB{}); err != nil { if tc.Error != "" { - t.Fatalf("expected error: %q got: nil", tc.Error) + if got, want := err.Error(), tc.Error; !strings.Contains(got, want) { + t.Errorf("expected %q to be %q", got, want) + } + } else { + t.Errorf("bad error: %s", err) } - } else if tc.Error == "" { - t.Fatalf("expected no error, got %q", err.Error()) - } else if !strings.Contains(err.Error(), tc.Error) { - t.Fatalf("wrong error, want: %q got: %q", tc.Error, err.Error()) } }) } } -func TestPerUserRealmStats(t *testing.T) { +func TestRealm_BuildSMSText(t *testing.T) { + t.Parallel() + + realm := NewRealmWithDefaults("test") + realm.SMSTextTemplate = "This is your Exposure Notifications Verification code: [enslink] Expires in [longexpires] hours" + realm.RegionCode = "US-WA" + + got, err := realm.BuildSMSText("12345678", "abcdefgh12345678", "en.express", "") + if err != nil { + t.Fatal(err) + } + want := "This is your Exposure Notifications Verification code: https://us-wa.en.express/v?c=abcdefgh12345678 Expires in 24 hours" + if got != want { + t.Errorf("SMS text wrong, want: %q got %q", want, got) + } + + realm.SMSTextTemplate = "State of Wonder, COVID-19 Exposure Verification code [code]. Expires in [expires] minutes. Act now!" + got, err = realm.BuildSMSText("654321", "asdflkjasdlkfjl", "", "") + if err != nil { + t.Fatal(err) + } + want = "State of Wonder, COVID-19 Exposure Verification code 654321. Expires in 15 minutes. Act now!" + if got != want { + t.Errorf("SMS text wrong, want: %q got %q", want, got) + } +} + +func TestRealm_BuildInviteEmail(t *testing.T) { + t.Parallel() + + realm := NewRealmWithDefaults("test") + realm.EmailInviteTemplate = "Welcome to [realmname] [invitelink]." + + if got, want := realm.BuildInviteEmail("https://join.now"), "Welcome to test https://join.now."; got != want { + t.Errorf("expected %q to be %q", got, want) + } +} + +func TestRealm_BuildPasswordResetEmail(t *testing.T) { + t.Parallel() + + realm := NewRealmWithDefaults("test") + realm.EmailPasswordResetTemplate = "Hey [realmname] reset [passwordresetlink]." + + if got, want := realm.BuildPasswordResetEmail("https://reset.now"), "Hey test reset https://reset.now."; got != want { + t.Errorf("expected %q to be %q", got, want) + } +} + +func TestRealm_BuildVerifyEmail(t *testing.T) { + t.Parallel() + + realm := NewRealmWithDefaults("test") + realm.EmailVerifyTemplate = "Hey [realmname] verify [verifylink]." + + if got, want := realm.BuildVerifyEmail("https://verify.now"), "Hey test verify https://verify.now."; got != want { + t.Errorf("expected %q to be %q", got, want) + } +} + +func TestRealm_UserStats(t *testing.T) { t.Parallel() db, _ := testDatabaseInstance.NewDatabase(t, nil) + realm, err := db.FindRealm(1) + if err != nil { + t.Fatal(err) + } + numDays := 7 endDate := timeutils.Midnight(time.Now()) startDate := timeutils.Midnight(endDate.Add(time.Duration(numDays) * -24 * time.Hour)) - // Create a new realm - realm := NewRealmWithDefaults("test") - if err := db.SaveRealm(realm, SystemTest); err != nil { - t.Fatalf("error saving realm: %v", err) - } - - // Create the users. - users := []*User{} + // Create users. for userIdx, name := range []string{"Rocky", "Bullwinkle", "Boris", "Natasha"} { user := &User{ - Name: name, - Email: name + "@gmail.com", - SystemAdmin: false, + Name: name, + Email: name + "@example.com", } if err := db.SaveUser(user, SystemTest); err != nil { - t.Fatalf("[%v] error creating user: %v", name, err) - } - users = append(users, user) - - // Add to realm - if err := user.AddToRealm(db, realm, rbac.CodeIssue, SystemTest); err != nil { - t.Fatal(err) + t.Fatalf("failed to create user %q: %s", name, err) } // Add some stats per user. @@ -294,10 +487,6 @@ func TestPerUserRealmStats(t *testing.T) { } } - if len(users) == 0 { // sanity check - t.Error("len(users) = 0, expected ≠ 0") - } - stats, err := realm.UserStats(db, startDate, endDate) if err != nil { t.Fatalf("error getting stats: %v", err) @@ -509,3 +698,82 @@ func TestRealm_SMSConfig(t *testing.T) { } } } + +func TestRealm_EmailConfig(t *testing.T) { + t.Parallel() + + db, _ := testDatabaseInstance.NewDatabase(t, nil) + + realm, err := db.FindRealm(1) + if err != nil { + t.Fatal(err) + } + + // Initial realm should have no config + if _, err := realm.EmailConfig(db); err == nil { + t.Fatalf("expected error") + } + + // Create config + if err := db.SaveEmailConfig(&EmailConfig{ + RealmID: realm.ID, + SMTPAccount: "account", + SMTPHost: "host", + SMTPPort: "port", + SMTPPassword: "password", + }); err != nil { + t.Fatal(err) + } + + { + // Now the realm should have a config + EmailConfig, err := realm.EmailConfig(db) + if err != nil { + t.Fatal(err) + } + if got, want := EmailConfig.SMTPAccount, "account"; got != want { + t.Errorf("expected %v to be %v", got, want) + } + if got, want := EmailConfig.SMTPHost, "host"; got != want { + t.Errorf("expected %v to be %v", got, want) + } + if got, want := EmailConfig.SMTPPort, "port"; got != want { + t.Errorf("expected %v to be %v", got, want) + } + } + + // Create system config + if err := db.SaveEmailConfig(&EmailConfig{ + SMTPAccount: "system-account", + SMTPHost: "system-host", + SMTPPort: "system-port", + SMTPPassword: "system-password", + IsSystem: true, + }); err != nil { + t.Fatal(err) + } + + // Update to use system config + realm.CanUseSystemEmailConfig = true + realm.UseSystemEmailConfig = true + if err := db.SaveRealm(realm, SystemTest); err != nil { + t.Fatal(err) + } + + // The realm should use the system config. + { + emailConfig, err := realm.EmailConfig(db) + if err != nil { + t.Fatal(err) + } + if got, want := emailConfig.SMTPAccount, "system-account"; got != want { + t.Errorf("expected %v to be %v", got, want) + } + if got, want := emailConfig.SMTPHost, "system-host"; got != want { + t.Errorf("expected %v to be %v", got, want) + } + if got, want := emailConfig.SMTPPort, "system-port"; got != want { + t.Errorf("expected %v to be %v", got, want) + } + } +} diff --git a/pkg/database/realm_user_stats_test.go b/pkg/database/realm_user_stats_test.go new file mode 100644 index 000000000..42919e9a7 --- /dev/null +++ b/pkg/database/realm_user_stats_test.go @@ -0,0 +1,104 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestRealmUserStats_MarshalCSV(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + stats RealmUserStats + exp string + }{ + { + name: "empty", + stats: nil, + exp: "", + }, + { + name: "single", + stats: []*RealmUserStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + RealmID: 1, + UserID: 1, + Name: "You", + Email: "you@example.com", + CodesIssued: 10, + }, + }, + exp: `date,realm_id,user_id,name,email,codes_issued +2020-02-03,1,1,You,you@example.com,10 +`, + }, + { + name: "multi", + stats: []*RealmUserStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + RealmID: 1, + UserID: 1, + Name: "You", + Email: "you@example.com", + CodesIssued: 10, + }, + { + Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC), + RealmID: 1, + UserID: 2, + Name: "Them", + Email: "them@example.com", + CodesIssued: 45, + }, + { + Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC), + RealmID: 1, + UserID: 3, + Name: "Us", + Email: "us@example.com", + CodesIssued: 15, + }, + }, + exp: `date,realm_id,user_id,name,email,codes_issued +2020-02-03,1,1,You,you@example.com,10 +2020-02-04,1,2,Them,them@example.com,45 +2020-02-05,1,3,Us,us@example.com,15 +`, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b, err := tc.stats.MarshalCSV() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(string(b), tc.exp); diff != "" { + t.Errorf("bad csv (+got, -want): %s", diff) + } + }) + } +} diff --git a/pkg/database/signing_key.go b/pkg/database/signing_key.go index c31057ea0..415ec709f 100644 --- a/pkg/database/signing_key.go +++ b/pkg/database/signing_key.go @@ -38,11 +38,3 @@ type SigningKey struct { func (s *SigningKey) GetKID() string { return fmt.Sprintf("r%dv%d", s.RealmID, s.ID) } - -func (s *SigningKey) Delete(db *Database) error { - return db.db.Delete(s).Error -} - -func (db *Database) SaveSigningKey(s *SigningKey) error { - return db.db.Save(s).Error -} diff --git a/pkg/database/sms_config_test.go b/pkg/database/sms_config_test.go index 34dcde84f..0feadb894 100644 --- a/pkg/database/sms_config_test.go +++ b/pkg/database/sms_config_test.go @@ -25,9 +25,8 @@ func TestSMSConfig_Lifecycle(t *testing.T) { db, _ := testDatabaseInstance.NewDatabase(t, nil) - // Create realm - realm := NewRealmWithDefaults(t.Name()) - if err := db.SaveRealm(realm, SystemTest); err != nil { + realm, err := db.FindRealm(1) + if err != nil { t.Fatal(err) } @@ -55,7 +54,7 @@ func TestSMSConfig_Lifecycle(t *testing.T) { } // Get the realm to verify SMS configs are NOT preloaded - realm, err := db.FindRealm(realm.ID) + realm, err = db.FindRealm(realm.ID) if err != nil { t.Fatal(err) } @@ -117,8 +116,8 @@ func TestSMSProvider(t *testing.T) { db, _ := testDatabaseInstance.NewDatabase(t, nil) - realm := NewRealmWithDefaults("test-sms-realm-1") - if err := db.SaveRealm(realm, SystemTest); err != nil { + realm, err := db.FindRealm(1) + if err != nil { t.Fatal(err) } diff --git a/pkg/database/sms_from_number_test.go b/pkg/database/sms_from_number_test.go index 9f072b663..85b0bdff2 100644 --- a/pkg/database/sms_from_number_test.go +++ b/pkg/database/sms_from_number_test.go @@ -22,33 +22,22 @@ import ( func TestSMSFromNumber_BeforeSave(t *testing.T) { t.Parallel() - db, _ := testDatabaseInstance.NewDatabase(t, nil) - - t.Run("label", func(t *testing.T) { - t.Parallel() - - var n SMSFromNumber - n.Label = "" - _ = n.BeforeSave(db.RawDB()) - - errs := n.ErrorsFor("label") - if len(errs) < 1 { - t.Fatal("expected error") - } - }) - - t.Run("value", func(t *testing.T) { - t.Parallel() + cases := []struct { + structField string + field string + }{ + {"Label", "label"}, + {"Value", "value"}, + } - var n SMSFromNumber - n.Value = "" - _ = n.BeforeSave(db.RawDB()) + for _, tc := range cases { + tc := tc - errs := n.ErrorsFor("value") - if len(errs) < 1 { - t.Fatal("expected error") - } - }) + t.Run(tc.field, func(t *testing.T) { + t.Parallel() + exerciseValidation(t, &SMSFromNumber{}, tc.structField, tc.field) + }) + } } func TestSMSFromNumbers(t *testing.T) { diff --git a/pkg/database/user.go b/pkg/database/user.go index cf61a77b5..3e217f35c 100644 --- a/pkg/database/user.go +++ b/pkg/database/user.go @@ -47,6 +47,27 @@ type User struct { LastPasswordChange time.Time } +// BeforeSave runs validations. If there are errors, the save fails. +func (u *User) BeforeSave(tx *gorm.DB) error { + u.Email = project.TrimSpace(u.Email) + if u.Email == "" { + u.AddError("email", "cannot be blank") + } + if !strings.Contains(u.Email, "@") { + u.AddError("email", "appears to be invalid") + } + + u.Name = project.TrimSpace(u.Name) + if u.Name == "" { + u.AddError("name", "cannot be blank") + } + + if msgs := u.ErrorMessages(); len(msgs) > 0 { + return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", ")) + } + return nil +} + // PasswordChanged returns password change time or account creation time if unset. func (u *User) PasswordChanged() time.Time { if u.LastPasswordChange.Before(launched) { @@ -75,29 +96,6 @@ func (u *User) PasswordAgeString() string { return fmt.Sprintf("%d seconds", int(ago.Seconds())) } -// BeforeSave runs validations. If there are errors, the save fails. -func (u *User) BeforeSave(tx *gorm.DB) error { - // Validation - u.Email = project.TrimSpace(u.Email) - if u.Email == "" { - u.AddError("email", "cannot be blank") - } - if !strings.Contains(u.Email, "@") { - u.AddError("email", "appears to be invalid") - } - - u.Name = project.TrimSpace(u.Name) - if u.Name == "" { - u.AddError("name", "cannot be blank") - } - - if len(u.Errors()) > 0 { - return fmt.Errorf("validation failed: %s", strings.Join(u.ErrorMessages(), ", ")) - } - - return nil -} - // FindUser finds a user by the given id, if one exists. The id can be a string // or integer value. It returns an error if the record is not found. func (db *Database) FindUser(id interface{}) (*User, error) { diff --git a/pkg/database/user_stats_test.go b/pkg/database/user_stats_test.go new file mode 100644 index 000000000..f73d88f02 --- /dev/null +++ b/pkg/database/user_stats_test.go @@ -0,0 +1,104 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestUserStats_MarshalCSV(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + stats UserStats + exp string + }{ + { + name: "empty", + stats: nil, + exp: "", + }, + { + name: "single", + stats: []*UserStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + UserID: 1, + RealmID: 1, + CodesIssued: 10, + UserName: "You", + UserEmail: "you@example.com", + }, + }, + exp: `date,realm_id,user_id,user_name,user_email,codes_issued +2020-02-03,1,1,You,you@example.com,10 +`, + }, + { + name: "multi", + stats: []*UserStat{ + { + Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC), + UserID: 1, + RealmID: 1, + CodesIssued: 10, + UserName: "You", + UserEmail: "you@example.com", + }, + { + Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC), + UserID: 2, + RealmID: 1, + CodesIssued: 45, + UserName: "Them", + UserEmail: "them@example.com", + }, + { + Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC), + UserID: 1, + RealmID: 3, + CodesIssued: 15, + UserName: "Us", + UserEmail: "us@example.com", + }, + }, + exp: `date,realm_id,user_id,user_name,user_email,codes_issued +2020-02-03,1,1,You,you@example.com,10 +2020-02-04,1,2,Them,them@example.com,45 +2020-02-05,3,1,Us,us@example.com,15 +`, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b, err := tc.stats.MarshalCSV() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(string(b), tc.exp); diff != "" { + t.Errorf("bad csv (+got, -want): %s", diff) + } + }) + } +} diff --git a/pkg/database/user_test.go b/pkg/database/user_test.go index 825a2c276..7561c5eb9 100644 --- a/pkg/database/user_test.go +++ b/pkg/database/user_test.go @@ -21,7 +21,28 @@ import ( "github.com/google/exposure-notifications-verification-server/pkg/rbac" ) -func TestUserLifecycle(t *testing.T) { +func TestUser_BeforeSave(t *testing.T) { + t.Parallel() + + cases := []struct { + structField string + field string + }{ + {"Email", "email"}, + {"Name", "name"}, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.field, func(t *testing.T) { + t.Parallel() + exerciseValidation(t, &User{}, tc.structField, tc.field) + }) + } +} + +func TestUser_Lifecycle(t *testing.T) { t.Parallel() db, _ := testDatabaseInstance.NewDatabase(t, nil) @@ -99,13 +120,13 @@ func TestUserLifecycle(t *testing.T) { } } -func TestPurgeUsers(t *testing.T) { +func TestDatabase_PurgeUsers(t *testing.T) { t.Parallel() db, _ := testDatabaseInstance.NewDatabase(t, nil) - realm := NewRealmWithDefaults("test") - if err := db.SaveRealm(realm, SystemTest); err != nil { + realm, err := db.FindRealm(1) + if err != nil { t.Fatal(err) } @@ -168,58 +189,60 @@ func TestPurgeUsers(t *testing.T) { } } -func TestRemoveRealmUpdatesTime(t *testing.T) { +func TestUser_DeleteFromRealm(t *testing.T) { t.Parallel() - db, _ := testDatabaseInstance.NewDatabase(t, nil) + t.Run("updates_time", func(t *testing.T) { + db, _ := testDatabaseInstance.NewDatabase(t, nil) - realm := NewRealmWithDefaults("test") - if err := db.SaveRealm(realm, SystemTest); err != nil { - t.Fatal(err) - } + realm := NewRealmWithDefaults("test") + if err := db.SaveRealm(realm, SystemTest); err != nil { + t.Fatal(err) + } - email := "purge@example.com" - user := &User{ - Email: email, - Name: "Dr Delete", - } - if err := db.SaveUser(user, SystemTest); err != nil { - t.Fatal(err) - } + email := "purge@example.com" + user := &User{ + Email: email, + Name: "Dr Delete", + } + if err := db.SaveUser(user, SystemTest); err != nil { + t.Fatal(err) + } - // Add to realm - if err := user.AddToRealm(db, realm, rbac.LegacyRealmAdmin, SystemTest); err != nil { - t.Fatal(err) - } + // Add to realm + if err := user.AddToRealm(db, realm, rbac.LegacyRealmAdmin, SystemTest); err != nil { + t.Fatal(err) + } - got, err := db.FindUser(user.ID) - if err != nil { - t.Fatal(err) - } + got, err := db.FindUser(user.ID) + if err != nil { + t.Fatal(err) + } - if got, want := got.ID, user.ID; got != want { - t.Errorf("expected %#v to be %#v", got, want) - } + if got, want := got.ID, user.ID; got != want { + t.Errorf("expected %#v to be %#v", got, want) + } - time.Sleep(time.Second) // in case this executes in under a nanosecond. + time.Sleep(time.Second) // in case this executes in under a nanosecond. - originalTime := got.Model.UpdatedAt - if err := user.DeleteFromRealm(db, realm, SystemTest); err != nil { - t.Fatal(err) - } + originalTime := got.Model.UpdatedAt + if err := user.DeleteFromRealm(db, realm, SystemTest); err != nil { + t.Fatal(err) + } - got, err = db.FindUser(user.ID) - if err != nil { - t.Fatal(err) - } + got, err = db.FindUser(user.ID) + if err != nil { + t.Fatal(err) + } - if got, want := got.ID, user.ID; got != want { - t.Errorf("expected %#v to be %#v", got, want) - } - // Assert that the user time was updated. - if originalTime == got.Model.UpdatedAt { - t.Errorf("expected user time to be updated. Got %#v", originalTime.Format(time.RFC3339)) - } + if got, want := got.ID, user.ID; got != want { + t.Errorf("expected %#v to be %#v", got, want) + } + // Assert that the user time was updated. + if originalTime == got.Model.UpdatedAt { + t.Errorf("expected user time to be updated. Got %#v", originalTime.Format(time.RFC3339)) + } + }) } func expectExists(t *testing.T, db *Database, id uint) { diff --git a/pkg/database/vercode.go b/pkg/database/vercode.go index ee02f9626..f439cfe05 100644 --- a/pkg/database/vercode.go +++ b/pkg/database/vercode.go @@ -30,6 +30,7 @@ import ( const ( oneDay = 24 * time.Hour + // MinCodeLength defines the minimum number of digits in a code. MinCodeLength = 6 ) @@ -37,9 +38,9 @@ const ( type CodeType int const ( - InvalidCode CodeType = iota - ShortCode - LongCode + _ CodeType = iota + CodeTypeShort + CodeTypeLong ) var ( @@ -90,19 +91,14 @@ type VerificationCode struct { IssuingExternalID string `gorm:"column:issuing_external_id; type:varchar(255);"` } -// TableName sets the VerificationCode table name -func (VerificationCode) TableName() string { - return "verification_codes" -} - // BeforeSave is used by callbacks. -func (v *VerificationCode) BeforeSave(scope *gorm.Scope) error { +func (v *VerificationCode) BeforeSave(tx *gorm.DB) error { if len(v.IssuingExternalID) > 255 { v.AddError("issuingExternalID", "cannot exceed 255 characters") } - if len(v.Errors()) > 0 { - return fmt.Errorf("email config validation failed: %s", strings.Join(v.ErrorMessages(), ", ")) + if msgs := v.ErrorMessages(); len(msgs) > 0 { + return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", ")) } return nil } @@ -170,8 +166,6 @@ func (v *VerificationCode) AfterCreate(scope *gorm.Scope) { } } -// TODO(mikehelmick) - Add method to soft delete expired codes - // FormatSymptomDate returns YYYY-MM-DD formatted test date, or "" if nil. func (v *VerificationCode) FormatSymptomDate() string { if v.SymptomDate == nil { @@ -184,13 +178,13 @@ func (v *VerificationCode) FormatSymptomDate() string { // code, and determines if it is expired based on that. func (db *Database) IsCodeExpired(v *VerificationCode, code string) (bool, CodeType, error) { if v == nil { - return false, InvalidCode, fmt.Errorf("provided code is nil") + return false, 0, fmt.Errorf("provided code is nil") } // It's possible that this could be called with the already HMACd version. possibles, err := db.generateVerificationCodeHMACs(code) if err != nil { - return false, InvalidCode, fmt.Errorf("failed to create hmac: %w", err) + return false, 0, fmt.Errorf("failed to create hmac: %w", err) } possibles = append(possibles, code) @@ -206,11 +200,11 @@ func (db *Database) IsCodeExpired(v *VerificationCode, code string) (bool, CodeT now := time.Now().UTC() switch { case inList(v.Code, possibles): - return !v.ExpiresAt.After(now), ShortCode, nil + return !v.ExpiresAt.After(now), CodeTypeShort, nil case inList(v.LongCode, possibles): - return !v.LongExpiresAt.After(now), LongCode, nil + return !v.LongExpiresAt.After(now), CodeTypeLong, nil default: - return true, InvalidCode, fmt.Errorf("not found") + return true, 0, fmt.Errorf("not found") } } diff --git a/pkg/database/vercode_test.go b/pkg/database/vercode_test.go index 4ce9c7dd9..c7c189862 100644 --- a/pkg/database/vercode_test.go +++ b/pkg/database/vercode_test.go @@ -16,6 +16,8 @@ package database import ( "errors" + "fmt" + "strings" "testing" "time" @@ -26,6 +28,47 @@ import ( "github.com/jinzhu/gorm" ) +func TestCodeType(t *testing.T) { + t.Parallel() + + // This test might seem like it's redundant, but it's designed to ensure that + // the exact values for existing types remain unchanged. + cases := []struct { + t CodeType + exp int + }{ + {CodeTypeShort, 1}, + {CodeTypeLong, 2}, + } + + for _, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) { + t.Parallel() + + if got, want := int(tc.t), tc.exp; got != want { + t.Errorf("expected %d to be %d", got, want) + } + }) + } +} + +func TestVerificationCode_BeforeSave(t *testing.T) { + t.Parallel() + + t.Run("issuingExternalID", func(t *testing.T) { + t.Parallel() + + var v VerificationCode + v.IssuingExternalID = strings.Repeat("*", 256) + _ = v.BeforeSave(&gorm.DB{}) + if errs := v.ErrorsFor("issuingExternalID"); len(errs) < 1 { + t.Errorf("expected errors for %s", "issuingExternalID") + } + }) +} + func TestVerificationCode_FindVerificationCode(t *testing.T) { t.Parallel() diff --git a/pkg/otp/code_test.go b/pkg/otp/code_test.go index b7066eae0..fe7f50a94 100644 --- a/pkg/otp/code_test.go +++ b/pkg/otp/code_test.go @@ -114,8 +114,8 @@ func TestIssue(t *testing.T) { } if exp, codeType, err := db.IsCodeExpired(verCode, code); exp || err != nil { t.Fatalf("loaded code doesn't match requested code, %v %v", exp, err) - } else if codeType != database.ShortCode { - t.Errorf("wrong code type, want: %v got: %v", database.ShortCode, codeType) + } else if codeType != database.CodeTypeShort { + t.Errorf("wrong code type, want: %v got: %v", database.CodeTypeShort, codeType) } } @@ -126,8 +126,8 @@ func TestIssue(t *testing.T) { } if exp, codeType, err := db.IsCodeExpired(verCode, code); exp || err != nil { t.Fatalf("loaded code doesn't match requested code") - } else if codeType != database.LongCode { - t.Errorf("wrong code type, want: %v got: %v", database.LongCode, codeType) + } else if codeType != database.CodeTypeLong { + t.Errorf("wrong code type, want: %v got: %v", database.CodeTypeLong, codeType) } } }