diff --git a/cmd/server/assets/realmadmin/_form_codes.html b/cmd/server/assets/realmadmin/_form_codes.html
index 10082c338..e98649a53 100644
--- a/cmd/server/assets/realmadmin/_form_codes.html
+++ b/cmd/server/assets/realmadmin/_form_codes.html
@@ -199,7 +199,7 @@
- {{if $realm.ErrorsFor "SMSTextTemplate"}}
{{end}}
+ {{if $realm.ErrorsFor "smsTextTemplate"}}
{{end}}
@@ -214,7 +214,7 @@
New SMS template
- {{if $realm.ErrorsFor "SMSTextTemplate"}}
+ {{if $realm.ErrorsFor "smsTextTemplate"}}
Errors found for one or more SMS templates
diff --git a/pkg/database/audit_entry.go b/pkg/database/audit_entry.go
index a2985ae3e..37e6a6e2b 100644
--- a/pkg/database/audit_entry.go
+++ b/pkg/database/audit_entry.go
@@ -15,9 +15,12 @@
package database
import (
+ "fmt"
+ "strings"
"time"
"github.com/google/exposure-notifications-verification-server/pkg/pagination"
+ "github.com/jinzhu/gorm"
)
// AuditEntry represents an event in the system. These records are purged after
@@ -63,6 +66,32 @@ type AuditEntry struct {
CreatedAt time.Time
}
+// BeforeSave runs validations. If there are errors, the save fails.
+func (a *AuditEntry) BeforeSave(tx *gorm.DB) error {
+ if a.ActorID == "" {
+ a.AddError("actor_id", "cannot be blank")
+ }
+ if a.ActorDisplay == "" {
+ a.AddError("actor_display", "cannot be blank")
+ }
+
+ if a.Action == "" {
+ a.AddError("action", "cannot be blank")
+ }
+
+ if a.TargetID == "" {
+ a.AddError("target_id", "cannot be blank")
+ }
+ if a.TargetDisplay == "" {
+ a.AddError("target_display", "cannot be blank")
+ }
+
+ if msgs := a.ErrorMessages(); len(msgs) > 0 {
+ return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", "))
+ }
+ return nil
+}
+
// SaveAuditEntry saves the audit entry.
func (db *Database) SaveAuditEntry(a *AuditEntry) error {
return db.db.Save(a).Error
diff --git a/pkg/database/audit_entry_test.go b/pkg/database/audit_entry_test.go
new file mode 100644
index 000000000..e0c49d886
--- /dev/null
+++ b/pkg/database/audit_entry_test.go
@@ -0,0 +1,130 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/exposure-notifications-verification-server/pkg/pagination"
+)
+
+func TestAuditEntry_BeforeSave(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ structField string
+ field string
+ }{
+ {"ActorID", "actor_id"},
+ {"ActorDisplay", "actor_display"},
+ {"Action", "action"},
+ {"TargetID", "target_id"},
+ {"TargetDisplay", "target_display"},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.field, func(t *testing.T) {
+ t.Parallel()
+ exerciseValidation(t, &AuditEntry{}, tc.structField, tc.field)
+ })
+ }
+}
+
+func TestDatabase_PurgeAuditEntries(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+ for i := 0; i < 5; i++ {
+ if err := db.SaveAuditEntry(&AuditEntry{
+ RealmID: 1,
+ ActorID: "actor:1",
+ ActorDisplay: "Actor",
+ Action: "created",
+ TargetID: "target:1",
+ TargetDisplay: "Target",
+ }); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // Should not purge entries (too young).
+ {
+ n, err := db.PurgeAuditEntries(24 * time.Hour)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := n, int64(0); got != want {
+ t.Errorf("expected %d to purge, got %d", want, got)
+ }
+ }
+
+ // Purges entries.
+ {
+ n, err := db.PurgeAuditEntries(1 * time.Nanosecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := n, int64(5); got != want {
+ t.Errorf("expected %d to purge, got %d", want, got)
+ }
+ }
+}
+
+func TestDatabase_ListAudits(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty", func(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ audits, _, err := db.ListAudits(&pagination.PageParams{Limit: 1})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := len(audits), 0; got != want {
+ t.Errorf("expected %d audits, got %d: %v", want, got, audits)
+ }
+ })
+
+ t.Run("lists", func(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+ for i := 0; i < 5; i++ {
+ if err := db.SaveAuditEntry(&AuditEntry{
+ RealmID: 1,
+ ActorID: "actor:1",
+ ActorDisplay: "Actor",
+ Action: "created",
+ TargetID: "target:1",
+ TargetDisplay: "Target",
+ }); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ audits, _, err := db.ListAudits(&pagination.PageParams{Limit: 10})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := len(audits), 5; got != want {
+ t.Errorf("expected %d audits, got %d: %v", want, got, audits)
+ }
+ })
+}
diff --git a/pkg/database/authorized_app.go b/pkg/database/authorized_app.go
index 9c2d0c765..1bb8c4e6f 100644
--- a/pkg/database/authorized_app.go
+++ b/pkg/database/authorized_app.go
@@ -95,8 +95,8 @@ func (a *AuthorizedApp) BeforeSave(tx *gorm.DB) error {
a.AddError("type", "is invalid")
}
- if len(a.Errors()) > 0 {
- return fmt.Errorf("validation failed")
+ if msgs := a.ErrorMessages(); len(msgs) > 0 {
+ return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", "))
}
return nil
}
@@ -109,51 +109,20 @@ func (a *AuthorizedApp) IsDeviceType() bool {
return a.APIKeyType == APIKeyTypeDevice
}
-// Realm returns the associated realm for this app.
+// Realm returns the associated realm for this app. If you only need the ID,
+// call .RealmID instead of a full database lookup.
func (a *AuthorizedApp) Realm(db *Database) (*Realm, error) {
var realm Realm
- if err := db.db.Model(a).Related(&realm).Error; err != nil {
+ if err := db.db.
+ Model(&Realm{}).
+ Where("id = ?", a.RealmID).
+ First(&realm).
+ Error; err != nil {
return nil, err
}
return &realm, nil
}
-// TableName definition for the authorized apps relation.
-func (AuthorizedApp) TableName() string {
- return "authorized_apps"
-}
-
-// CreateAuthorizedApp generates a new API key and assigns it to the specified
-// app. Note that the API key is NOT stored in the database, only a hash. The
-// only time the API key is available is as the string return parameter from
-// invoking this function.
-func (r *Realm) CreateAuthorizedApp(db *Database, app *AuthorizedApp, actor Auditable) (string, error) {
- fullAPIKey, err := db.GenerateAPIKey(r.ID)
- if err != nil {
- return "", fmt.Errorf("failed to generate API key: %w", err)
- }
-
- parts := strings.SplitN(fullAPIKey, ".", 3)
- if len(parts) != 3 {
- return "", fmt.Errorf("internal error, key is invalid")
- }
- apiKey := parts[0]
-
- hmacedKey, err := db.GenerateAPIKeyHMAC(apiKey)
- if err != nil {
- return "", fmt.Errorf("failed to create hmac: %w", err)
- }
-
- app.RealmID = r.ID
- app.APIKey = hmacedKey
- app.APIKeyPreview = apiKey[:6]
-
- if err := db.SaveAuthorizedApp(app, actor); err != nil {
- return "", err
- }
- return fullAPIKey, nil
-}
-
// FindAuthorizedApp finds the authorized app by the given id.
func (db *Database) FindAuthorizedApp(id interface{}) (*AuthorizedApp, error) {
var app AuthorizedApp
diff --git a/pkg/database/authorized_app_stats_test.go b/pkg/database/authorized_app_stats_test.go
new file mode 100644
index 000000000..1a8f3512f
--- /dev/null
+++ b/pkg/database/authorized_app_stats_test.go
@@ -0,0 +1,96 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestAuthorizedAppStats_MarshalCSV(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ stats AuthorizedAppStats
+ exp string
+ }{
+ {
+ name: "empty",
+ stats: nil,
+ exp: "",
+ },
+ {
+ name: "single",
+ stats: []*AuthorizedAppStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ AuthorizedAppID: 1,
+ CodesIssued: 10,
+ AuthorizedAppName: "Appy",
+ },
+ },
+ exp: `date,authorized_app_id,authorized_app_name,codes_issued
+2020-02-03,1,Appy,10
+`,
+ },
+ {
+ name: "multi",
+ stats: []*AuthorizedAppStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ AuthorizedAppID: 1,
+ CodesIssued: 10,
+ AuthorizedAppName: "Appy",
+ },
+ {
+ Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC),
+ AuthorizedAppID: 1,
+ CodesIssued: 45,
+ AuthorizedAppName: "Mc",
+ },
+ {
+ Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC),
+ AuthorizedAppID: 1,
+ CodesIssued: 15,
+ AuthorizedAppName: "Apperson",
+ },
+ },
+ exp: `date,authorized_app_id,authorized_app_name,codes_issued
+2020-02-03,1,Appy,10
+2020-02-04,1,Mc,45
+2020-02-05,1,Apperson,15
+`,
+ },
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ b, err := tc.stats.MarshalCSV()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(string(b), tc.exp); diff != "" {
+ t.Errorf("bad csv (+got, -want): %s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/database/authorized_app_test.go b/pkg/database/authorized_app_test.go
index 82f8a6008..16a3688d8 100644
--- a/pkg/database/authorized_app_test.go
+++ b/pkg/database/authorized_app_test.go
@@ -19,8 +19,160 @@ import (
"fmt"
"strings"
"testing"
+ "time"
+
+ "github.com/google/exposure-notifications-server/pkg/timeutils"
+ "github.com/jinzhu/gorm"
)
+func TestAPIKeyType(t *testing.T) {
+ t.Parallel()
+
+ // This test might seem like it's redundant, but it's designed to ensure that
+ // the exact values for existing types remain unchanged.
+ cases := []struct {
+ t APIKeyType
+ exp int
+ }{
+ {APIKeyTypeInvalid, -1},
+ {APIKeyTypeDevice, 0},
+ {APIKeyTypeAdmin, 1},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.t.Display(), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := int(tc.t), tc.exp; got != want {
+ t.Errorf("expected %d to be %d", got, want)
+ }
+ })
+ }
+}
+
+func TestAPIKeyType_Display(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ t APIKeyType
+ exp string
+ }{
+ {APIKeyTypeInvalid, "invalid"},
+ {APIKeyTypeDevice, "device"},
+ {APIKeyTypeAdmin, "admin"},
+ {1991, "invalid"},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := tc.t.Display(), tc.exp; got != want {
+ t.Errorf("expected %q to be %q", got, want)
+ }
+ })
+ }
+}
+
+func TestAuthorizedApp_BeforeSave(t *testing.T) {
+ t.Parallel()
+
+ t.Run("name", func(t *testing.T) {
+ t.Parallel()
+ exerciseValidation(t, &AuthorizedApp{}, "Name", "name")
+ })
+
+ t.Run("type", func(t *testing.T) {
+ t.Parallel()
+
+ {
+ var m AuthorizedApp
+ m.APIKeyType = -1
+ _ = m.BeforeSave(&gorm.DB{})
+ if errs := m.ErrorsFor("type"); len(errs) < 1 {
+ t.Errorf("expected errors for type")
+ }
+ }
+
+ {
+ var m AuthorizedApp
+ m.APIKeyType = 55
+ _ = m.BeforeSave(&gorm.DB{})
+ if errs := m.ErrorsFor("type"); len(errs) < 1 {
+ t.Errorf("expected errors for type")
+ }
+ }
+
+ {
+ var m AuthorizedApp
+ m.APIKeyType = 0
+ _ = m.BeforeSave(&gorm.DB{})
+ if errs := m.ErrorsFor("type"); len(errs) != 0 {
+ t.Errorf("expected no errors for type, got %v", errs)
+ }
+ }
+ })
+}
+
+func TestAuthorizedApp_Realm(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+ realm, err := db.FindRealm(1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ authorizedApp := &AuthorizedApp{
+ Name: "Appy",
+ }
+ if _, err := realm.CreateAuthorizedApp(db, authorizedApp, SystemTest); err != nil {
+ t.Fatal(err)
+ }
+
+ gotRealm, err := authorizedApp.Realm(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := gotRealm.ID, realm.ID; got != want {
+ t.Errorf("expected %d to be %d", got, want)
+ }
+}
+
+func TestAuthorizedApp_Stats(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+ realm, err := db.FindRealm(1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ authorizedApp := &AuthorizedApp{
+ Name: "Appy",
+ }
+ if _, err := realm.CreateAuthorizedApp(db, authorizedApp, SystemTest); err != nil {
+ t.Fatal(err)
+ }
+
+ // Ensure graph is contiguous.
+ {
+ stop := timeutils.Midnight(time.Now().UTC())
+ start := stop.Add(6 * -24 * time.Hour)
+ stats, err := authorizedApp.Stats(db, start, stop)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := len(stats), 7; got != want {
+ t.Errorf("expected stats for %d days, got %d", want, got)
+ }
+ }
+}
+
func TestDatabase_CreateFindAPIKey(t *testing.T) {
t.Parallel()
@@ -126,3 +278,45 @@ func TestDatabase_GenerateVerifyAPIKeySignature(t *testing.T) {
t.Errorf("expected %v to be %v", got, want)
}
}
+
+func TestDatabase_PurgeAuthorizedApps(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ now := time.Now().UTC()
+ for i := 0; i < 5; i++ {
+ if err := db.SaveAuthorizedApp(&AuthorizedApp{
+ RealmID: 1,
+ Name: fmt.Sprintf("appy%d", i),
+ APIKey: fmt.Sprintf("%d", i),
+ Model: gorm.Model{
+ DeletedAt: &now,
+ },
+ }, SystemTest); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // Should not purge entries (too young).
+ {
+ n, err := db.PurgeAuthorizedApps(24 * time.Hour)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := n, int64(0); got != want {
+ t.Errorf("expected %d to purge, got %d", want, got)
+ }
+ }
+
+ // Purges entries.
+ {
+ n, err := db.PurgeAuthorizedApps(1 * time.Nanosecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := n, int64(5); got != want {
+ t.Errorf("expected %d to purge, got %d", want, got)
+ }
+ }
+}
diff --git a/pkg/database/cleanup.go b/pkg/database/cleanup.go
index a34e4db7c..2dc4c3dff 100644
--- a/pkg/database/cleanup.go
+++ b/pkg/database/cleanup.go
@@ -34,25 +34,22 @@ type CleanupStatus struct {
NotBefore time.Time
}
-// TableName sets the CleanupStatus table name
-func (CleanupStatus) TableName() string {
- return "cleanup_statuses"
-}
-
// CreateCleanup is used to create a new 'cleanup' type/row in the database.
func (db *Database) CreateCleanup(cType string) (*CleanupStatus, error) {
- cstat := &CleanupStatus{
- Type: cType,
- Generation: 1,
- NotBefore: time.Now().UTC(),
- }
+ var cstat CleanupStatus
+
+ sql := `INSERT INTO cleanup_statuses (type, generation, not_before)
+ VALUES ($1, $2, $3)
+ ON CONFLICT (type) DO UPDATE SET type = EXCLUDED.type
+ RETURNING *`
+
+ now := time.Now().UTC()
if err := db.db.
- Set("gorm:insert_option", "ON CONFLICT (type) DO NOTHING").
- FirstOrCreate(cstat).
- Error; err != nil {
+ Raw(sql, cType, 1, now).
+ Scan(&cstat).Error; err != nil {
return nil, err
}
- return cstat, nil
+ return &cstat, nil
}
// FindCleanupStatus looks up the current cleanup state in the database by cleanup type.
@@ -64,13 +61,14 @@ func (db *Database) FindCleanupStatus(cType string) (*CleanupStatus, error) {
return &cstat, nil
}
-// ClaimCleanup attempts to obtain a lock for the specified `lockTime` so that that type of
-// cleanup can be perofmed exclusively by the owner.
+// ClaimCleanup attempts to obtain a lock for the specified `lockTime` so that
+// that type of cleanup can be performed exclusively by the owner.
func (db *Database) ClaimCleanup(current *CleanupStatus, lockTime time.Duration) (*CleanupStatus, error) {
var cstat CleanupStatus
- err := db.db.Transaction(func(tx *gorm.DB) error {
+ if err := db.db.Transaction(func(tx *gorm.DB) error {
if err := tx.
Set("gorm:query_option", "FOR UPDATE").
+ Model(&CleanupStatus{}).
Where("type = ?", current.Type).
First(&cstat).
Error; err != nil {
@@ -83,8 +81,7 @@ func (db *Database) ClaimCleanup(current *CleanupStatus, lockTime time.Duration)
cstat.Generation++
cstat.NotBefore = time.Now().UTC().Add(lockTime)
return tx.Save(&cstat).Error
- })
- if err != nil {
+ }); err != nil {
return nil, err
}
return &cstat, nil
diff --git a/pkg/database/cleanup_test.go b/pkg/database/cleanup_test.go
new file mode 100644
index 000000000..45d4ca561
--- /dev/null
+++ b/pkg/database/cleanup_test.go
@@ -0,0 +1,120 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "errors"
+ "testing"
+ "time"
+)
+
+func TestDatabase_CreateCleanup(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ cleanup1, err := db.CreateCleanup("x")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // If the cleanup already exists, it's a noop
+ cleanup2, err := db.CreateCleanup("x")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if got, want := cleanup1.ID, cleanup2.ID; got != want {
+ t.Errorf("expected %d to be %d", got, want)
+ }
+}
+
+func TestDatabase_FindCleanupStatus(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ want, err := db.CreateCleanup("x")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ got, err := db.FindCleanupStatus("x")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if got, want := got.ID, want.ID; got != want {
+ t.Errorf("expected %d to be %d", got, want)
+ }
+}
+
+func TestDatabase_ClaimCleanup(t *testing.T) {
+ t.Parallel()
+
+ t.Run("no_exist", func(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+ cleanup, err := db.ClaimCleanup(&CleanupStatus{Type: "nope"}, 5*time.Second)
+ if !IsNotFound(err) {
+ t.Errorf("expected error, got: %v: %v", err, cleanup)
+ }
+ })
+
+ t.Run("wrong_generation", func(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ if _, err := db.CreateCleanup("dirty"); err != nil {
+ t.Fatal(err)
+ }
+
+ _, err := db.ClaimCleanup(&CleanupStatus{
+ Type: "dirty",
+ Generation: 2,
+ }, 1*time.Second)
+ if got, want := err, ErrCleanupWrongGeneration; !errors.Is(err, want) {
+ t.Errorf("expected %v to be %v", got, want)
+ }
+ })
+
+ t.Run("exists", func(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ if _, err := db.CreateCleanup("dirty"); err != nil {
+ t.Fatal(err)
+ }
+
+ cleanup, err := db.ClaimCleanup(&CleanupStatus{
+ Type: "dirty",
+ Generation: 1,
+ }, 1*time.Second)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if got, want := cleanup.Generation, uint(2); got != want {
+ t.Errorf("expected generation %d to be %d", got, want)
+ }
+
+ if got, now := cleanup.NotBefore, time.Now().UTC(); !got.After(now) {
+ t.Errorf("expected %q to be after %q", got, now)
+ }
+ })
+}
diff --git a/pkg/database/database_test.go b/pkg/database/database_test.go
index 18acc770b..c75548c2d 100644
--- a/pkg/database/database_test.go
+++ b/pkg/database/database_test.go
@@ -15,7 +15,12 @@
package database
import (
+ "io/ioutil"
+ "log"
+ "reflect"
"testing"
+
+ "github.com/jinzhu/gorm"
)
var testDatabaseInstance *TestInstance
@@ -25,3 +30,53 @@ func TestMain(m *testing.M) {
defer testDatabaseInstance.MustClose()
m.Run()
}
+
+type validateable interface {
+ ErrorsFor(s string) []string
+ BeforeSave(tx *gorm.DB) error
+}
+
+// exerciseValidation exercises zero value validation (not empty) for the given
+// model and struct fields.
+func exerciseValidation(t *testing.T, i validateable, structField, field string) {
+ // Get interface underlying value.
+ v := reflect.ValueOf(&i)
+ for v.Kind() == reflect.Ptr {
+ v = v.Elem()
+ }
+ if !v.CanInterface() {
+ t.Fatalf("%v cannot interface", v)
+ }
+
+ // Convert interface to struct.
+ sv := reflect.ValueOf(v.Interface())
+ for sv.Kind() == reflect.Ptr {
+ sv = sv.Elem()
+ }
+ if sv.Kind() != reflect.Struct {
+ t.Fatalf("%T is not a struct: %v", i, sv.Kind())
+ }
+
+ // Get struct field name.
+ f := sv.FieldByName(structField)
+ if !f.IsValid() {
+ t.Fatalf("%s is not valid", structField)
+ }
+ if !f.CanSet() {
+ t.Fatalf("%s is not settable", structField)
+ }
+
+ // Set to the zero value.
+ valueV := reflect.Zero(f.Type())
+ f.Set(valueV)
+
+ // Create db
+ var db gorm.DB
+ db.SetLogger(gorm.Logger{LogWriter: log.New(ioutil.Discard, "", 0)})
+
+ // Run the validation.
+ _ = i.BeforeSave(&gorm.DB{})
+ if errs := i.ErrorsFor(field); len(errs) < 1 {
+ t.Errorf("expected errors for %s", field)
+ }
+}
diff --git a/pkg/database/errors.go b/pkg/database/errors.go
index 83469b176..56bbca06b 100644
--- a/pkg/database/errors.go
+++ b/pkg/database/errors.go
@@ -17,6 +17,7 @@ package database
import (
"errors"
"fmt"
+ "sort"
)
var (
@@ -47,8 +48,16 @@ func (e *Errorable) Errors() map[string][]string {
func (e *Errorable) ErrorMessages() []string {
e.init()
+ // Sort keys so the response is in predictable ordering.
+ keys := make([]string, 0, len(e.errors))
+ for k := range e.errors {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+
l := make([]string, 0, len(e.errors))
- for k, v := range e.errors {
+ for _, k := range keys {
+ v := e.errors[k]
for _, msg := range v {
l = append(l, fmt.Sprintf("%s %s", k, msg))
}
diff --git a/pkg/database/external_issuer_stats_test.go b/pkg/database/external_issuer_stats_test.go
new file mode 100644
index 000000000..314b6c22f
--- /dev/null
+++ b/pkg/database/external_issuer_stats_test.go
@@ -0,0 +1,96 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestExternalIssuerStats_MarshalCSV(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ stats ExternalIssuerStats
+ exp string
+ }{
+ {
+ name: "empty",
+ stats: nil,
+ exp: "",
+ },
+ {
+ name: "single",
+ stats: []*ExternalIssuerStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ IssuerID: "user:2",
+ CodesIssued: 10,
+ },
+ },
+ exp: `date,realm_id,issuer_id,codes_issued
+2020-02-03,1,user:2,10
+`,
+ },
+ {
+ name: "multi",
+ stats: []*ExternalIssuerStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ IssuerID: "user:2",
+ CodesIssued: 10,
+ },
+ {
+ Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ IssuerID: "user:2",
+ CodesIssued: 45,
+ },
+ {
+ Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ IssuerID: "user:2",
+ CodesIssued: 15,
+ },
+ },
+ exp: `date,realm_id,issuer_id,codes_issued
+2020-02-03,1,user:2,10
+2020-02-04,1,user:2,45
+2020-02-05,1,user:2,15
+`,
+ },
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ b, err := tc.stats.MarshalCSV()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(string(b), tc.exp); diff != "" {
+ t.Errorf("bad csv (+got, -want): %s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/database/membership.go b/pkg/database/membership.go
index f4a6c6d36..64ab2f10a 100644
--- a/pkg/database/membership.go
+++ b/pkg/database/membership.go
@@ -16,12 +16,15 @@ package database
import (
"fmt"
+ "strings"
"github.com/google/exposure-notifications-verification-server/pkg/rbac"
)
// Membership represents a user's membership in a realm.
type Membership struct {
+ Errorable
+
UserID uint
User *User
@@ -35,13 +38,16 @@ type Membership struct {
// preloaded and the referenced values exist.
func (m *Membership) AfterFind() error {
if m.User == nil {
- return fmt.Errorf("membership user does not exist")
+ m.AddError("user", "does not exist")
}
if m.Realm == nil {
- return fmt.Errorf("membership realm does not exist")
+ m.AddError("realm", "does not exist")
}
+ if msgs := m.ErrorMessages(); len(msgs) > 0 {
+ return fmt.Errorf("lookup failed: %s", strings.Join(msgs, ", "))
+ }
return nil
}
diff --git a/pkg/database/membership_test.go b/pkg/database/membership_test.go
new file mode 100644
index 000000000..858847a18
--- /dev/null
+++ b/pkg/database/membership_test.go
@@ -0,0 +1,43 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "testing"
+)
+
+func TestMembership_AfterFind(t *testing.T) {
+ t.Parallel()
+
+ t.Run("user", func(t *testing.T) {
+ t.Parallel()
+
+ var m Membership
+ _ = m.AfterFind()
+ if errs := m.ErrorsFor("user"); len(errs) < 1 {
+ t.Errorf("expected errors for %s", "user")
+ }
+ })
+
+ t.Run("realm", func(t *testing.T) {
+ t.Parallel()
+
+ var m Membership
+ _ = m.AfterFind()
+ if errs := m.ErrorsFor("realm"); len(errs) < 1 {
+ t.Errorf("expected errors for %s", "realm")
+ }
+ })
+}
diff --git a/pkg/database/mobile_app.go b/pkg/database/mobile_app.go
index 5e92880e0..d4abfa7c6 100644
--- a/pkg/database/mobile_app.go
+++ b/pkg/database/mobile_app.go
@@ -130,10 +130,14 @@ func (a *MobileApp) BeforeSave(tx *gorm.DB) error {
}
a.SHA = strings.Join(shas, "\n")
- if len(a.Errors()) > 0 {
- return fmt.Errorf("validation failed: %v", a.Errors())
+ if msgs := a.ErrorMessages(); len(msgs) > 0 {
+ return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", "))
}
+ return nil
+}
+func (a *MobileApp) AfterFind(tx *gorm.DB) error {
+ a.URL = stringValue(a.URLPtr)
return nil
}
@@ -290,11 +294,6 @@ func (a *MobileApp) AuditDisplay() string {
return fmt.Sprintf("%s (%s)", a.Name, a.OS.Display())
}
-func (a *MobileApp) AfterFind(tx *gorm.DB) error {
- a.URL = stringValue(a.URLPtr)
- return nil
-}
-
// PurgeMobileApps will delete mobile apps that have been deleted for more than
// the specified time.
func (db *Database) PurgeMobileApps(maxAge time.Duration) (int64, error) {
diff --git a/pkg/database/mobile_app_test.go b/pkg/database/mobile_app_test.go
index 5cadd7866..53235521a 100644
--- a/pkg/database/mobile_app_test.go
+++ b/pkg/database/mobile_app_test.go
@@ -15,82 +15,94 @@
package database
import (
+ "fmt"
"testing"
+ "time"
"github.com/google/exposure-notifications-verification-server/pkg/pagination"
+ "github.com/jinzhu/gorm"
)
-func TestMobileApp_Validation(t *testing.T) {
+func TestOSType(t *testing.T) {
t.Parallel()
- db, _ := testDatabaseInstance.NewDatabase(t, nil)
-
- t.Run("name", func(t *testing.T) {
- t.Parallel()
-
- var m MobileApp
- m.Name = ""
- _ = m.BeforeSave(db.RawDB())
-
- nameErrs := m.ErrorsFor("name")
- if len(nameErrs) < 1 {
- t.Fatal("expected error")
- }
- })
-
- t.Run("app_id", func(t *testing.T) {
- t.Parallel()
-
- var m MobileApp
- m.AppID = ""
- _ = m.BeforeSave(db.RawDB())
-
- appIDErrs := m.ErrorsFor("app_id")
- if len(appIDErrs) < 1 {
- t.Fatal("expected error")
- }
- })
-
- t.Run("os", func(t *testing.T) {
- t.Parallel()
-
- var m MobileApp
- m.OS = 0
- _ = m.BeforeSave(db.RawDB())
+ // This test might seem like it's redundant, but it's designed to ensure that
+ // the exact values for existing types remain unchanged.
+ cases := []struct {
+ t OSType
+ exp int
+ }{
+ {OSTypeInvalid, 0},
+ {OSTypeIOS, 1},
+ {OSTypeAndroid, 2},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.t.Display(), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := int(tc.t), tc.exp; got != want {
+ t.Errorf("expected %d to be %d", got, want)
+ }
+ })
+ }
+}
- osErrs := m.ErrorsFor("os")
- if len(osErrs) < 1 {
- t.Fatal("expected error")
- }
+func TestOSType_Display(t *testing.T) {
+ t.Parallel()
- m.OS = 4
- _ = m.BeforeSave(db.RawDB())
+ cases := []struct {
+ t OSType
+ exp string
+ }{
+ {OSTypeInvalid, "Unknown"},
+ {OSTypeIOS, "iOS"},
+ {OSTypeAndroid, "Android"},
+ {1991, "Unknown"},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := tc.t.Display(), tc.exp; got != want {
+ t.Errorf("expected %q to be %q", got, want)
+ }
+ })
+ }
+}
- osErrs = m.ErrorsFor("os")
- if len(osErrs) < 1 {
- t.Fatal("expected error")
- }
- })
+func TestMobileApp_Validation(t *testing.T) {
+ t.Parallel()
- t.Run("url", func(t *testing.T) {
- t.Parallel()
+ cases := []struct {
+ structField string
+ field string
+ }{
+ {"Name", "name"},
+ {"AppID", "app_id"},
+ {"OS", "os"},
+ }
- var m MobileApp
- m.URL = ""
- _ = m.BeforeSave(db.RawDB())
+ for _, tc := range cases {
+ tc := tc
- urlErrors := m.ErrorsFor("url")
- if len(urlErrors) != 0 {
- t.Fatal("no xpected error")
- }
- })
+ t.Run(tc.field, func(t *testing.T) {
+ t.Parallel()
+ exerciseValidation(t, &MobileApp{}, tc.structField, tc.field)
+ })
+ }
t.Run("sha", func(t *testing.T) {
t.Parallel()
var m MobileApp
m.OS = OSTypeIOS
- _ = m.BeforeSave(db.RawDB())
+ _ = m.BeforeSave(&gorm.DB{})
shaErrs := m.ErrorsFor("sha")
if len(shaErrs) > 0 {
@@ -98,7 +110,7 @@ func TestMobileApp_Validation(t *testing.T) {
}
m.OS = OSTypeAndroid
- _ = m.BeforeSave(db.RawDB())
+ _ = m.BeforeSave(&gorm.DB{})
shaErrs = m.ErrorsFor("sha")
if len(shaErrs) < 1 {
@@ -136,7 +148,7 @@ func TestMobileApp_Validation(t *testing.T) {
var m MobileApp
m.SHA = tc.sha
- _ = m.BeforeSave(db.RawDB())
+ _ = m.BeforeSave(&gorm.DB{})
shaErrs := m.ErrorsFor("sha")
if !tc.err && len(shaErrs) > 0 {
@@ -147,7 +159,7 @@ func TestMobileApp_Validation(t *testing.T) {
})
}
-func TestMobileApp_List(t *testing.T) {
+func TestDatabase_ListActiveApps(t *testing.T) {
t.Parallel()
t.Run("access_mobileapps_and_realms", func(t *testing.T) {
@@ -211,3 +223,47 @@ func TestMobileApp_List(t *testing.T) {
}
})
}
+
+func TestDatabase_PurgeMobileApps(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ now := time.Now().UTC()
+ for i := 0; i < 5; i++ {
+ if err := db.SaveMobileApp(&MobileApp{
+ RealmID: 1,
+ Name: fmt.Sprintf("Appy%d", i),
+ OS: OSTypeIOS,
+ URL: fmt.Sprintf("https://%d.example.com", i),
+ AppID: fmt.Sprintf("app.%d.com", i),
+ Model: gorm.Model{
+ DeletedAt: &now,
+ },
+ }, SystemTest); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // Should not purge entries (too young).
+ {
+ n, err := db.PurgeMobileApps(24 * time.Hour)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := n, int64(0); got != want {
+ t.Errorf("expected %d to purge, got %d", want, got)
+ }
+ }
+
+ // Purges entries.
+ {
+ n, err := db.PurgeMobileApps(1 * time.Nanosecond)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := n, int64(5); got != want {
+ t.Errorf("expected %d to purge, got %d", want, got)
+ }
+ }
+}
diff --git a/pkg/database/modeler_status_test.go b/pkg/database/modeler_status_test.go
new file mode 100644
index 000000000..447d75463
--- /dev/null
+++ b/pkg/database/modeler_status_test.go
@@ -0,0 +1,42 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "testing"
+ "time"
+)
+
+func TestDatabase_ClaimModelerStatus(t *testing.T) {
+ t.Parallel()
+
+ // Create this now so we don't get clock skew
+ later := time.Now().UTC().Add(modelerLockTime)
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ if err := db.ClaimModelerStatus(); err != nil {
+ t.Fatal(err)
+ }
+
+ var status ModelerStatus
+ if err := db.db.Model(&ModelerStatus{}).First(&status).Error; err != nil {
+ t.Fatal(err)
+ }
+
+ if got, now := status.NotBefore, later; !got.After(now) {
+ t.Errorf("expected %q to be after %q", got, now)
+ }
+}
diff --git a/pkg/database/realm.go b/pkg/database/realm.go
index e8a4b7567..7c39a4b15 100644
--- a/pkg/database/realm.go
+++ b/pkg/database/realm.go
@@ -49,7 +49,6 @@ const (
TestTypeConfirmed
TestTypeLikely
TestTypeNegative
- DefaultTemplateLabel = "Default SMS template"
)
func (t TestType) Display() string {
@@ -70,6 +69,30 @@ func (t TestType) Display() string {
return strings.Join(types, ", ")
}
+// AuthRequirement represents authentication requirements for the realm
+type AuthRequirement int16
+
+const (
+ // MFAOptionalPrompt will prompt users for MFA on login.
+ MFAOptionalPrompt AuthRequirement = iota
+ // MFARequired will not allow users to proceed without MFA on their account.
+ MFARequired
+ // MFAOptional will not prompt users to enable MFA.
+ MFAOptional
+)
+
+func (r AuthRequirement) String() string {
+ switch r {
+ case MFAOptionalPrompt:
+ return "prompt"
+ case MFARequired:
+ return "required"
+ case MFAOptional:
+ return "optional"
+ }
+ return ""
+}
+
var (
ErrNoSigningKeyManagement = errors.New("no signing key management")
ErrBadDateRange = errors.New("bad date range")
@@ -89,22 +112,12 @@ const (
SMSTemplateMaxLength = 800
SMSTemplateExpansionMax = 918
+ DefaultTemplateLabel = "Default SMS template"
+
EmailInviteLink = "[invitelink]"
EmailPasswordResetLink = "[passwordresetlink]"
EmailVerifyLink = "[verifylink]"
RealmName = "[realmname]"
-)
-
-// AuthRequirement represents authentication requirements for the realm
-type AuthRequirement int16
-
-const (
- // MFAOptionalPrompt will prompt users for MFA on login.
- MFAOptionalPrompt = iota
- // MFARequired will not allow users to proceed without MFA on their account.
- MFARequired
- // MFAOptional will not prompt users to enable MFA.
- MFAOptional
// MaxPageSize is the maximum allowed page size for a list query.
MaxPageSize = 1000
@@ -250,31 +263,6 @@ type Realm struct {
Tokens []*Token `gorm:"PRELOAD:false; SAVE_ASSOCIATIONS:false; ASSOCIATION_AUTOUPDATE:false, ASSOCIATION_SAVE_REFERENCE:false"`
}
-// EffectiveMFAMode returns the realm's default MFAMode but first
-// checks if the user is in the grace-period (if so, required becomes prompt).
-func (r *Realm) EffectiveMFAMode(user *User) AuthRequirement {
- if r == nil {
- return MFARequired
- }
-
- if time.Since(user.CreatedAt) <= r.MFARequiredGracePeriod.Duration {
- return MFAOptionalPrompt
- }
- return r.MFAMode
-}
-
-func (mode *AuthRequirement) String() string {
- switch *mode {
- case MFAOptionalPrompt:
- return "prompt"
- case MFARequired:
- return "required"
- case MFAOptional:
- return "optional"
- }
- return ""
-}
-
// NewRealmWithDefaults initializes a new Realm with the default settings populated,
// and the provided name. It does NOT save the Realm to the database.
func NewRealmWithDefaults(name string) *Realm {
@@ -291,14 +279,6 @@ func NewRealmWithDefaults(name string) *Realm {
}
}
-func (r *Realm) CanUpgradeToRealmSigningKeys() bool {
- return r.CertificateIssuer != "" && r.CertificateAudience != ""
-}
-
-func (r *Realm) SigningKeyID() string {
- return fmt.Sprintf("realm-%d", r.ID)
-}
-
// AfterFind runs after a realm is found.
func (r *Realm) AfterFind(tx *gorm.DB) error {
r.RegionCode = stringValue(r.RegionCodePtr)
@@ -365,12 +345,12 @@ func (r *Realm) BeforeSave(tx *gorm.DB) error {
if r.SMSTextAlternateTemplates != nil {
for l, t := range r.SMSTextAlternateTemplates {
if t == nil || *t == "" {
- r.AddError("SMSTextTemplate", fmt.Sprintf("no template for label %s", l))
+ r.AddError("smsTextTemplate", fmt.Sprintf("no template for label %s", l))
r.AddError(l, fmt.Sprintf("no template for label %s", l))
continue
}
if l == "" {
- r.AddError("SMSTextTemplate", fmt.Sprintf("no label for template %s", *t))
+ r.AddError("smsTextTemplate", fmt.Sprintf("no label for template %s", *t))
continue
}
r.validateSMSTemplate(l, *t)
@@ -383,19 +363,19 @@ func (r *Realm) BeforeSave(tx *gorm.DB) error {
if r.EmailInviteTemplate != "" {
if !strings.Contains(r.EmailInviteTemplate, EmailInviteLink) {
- r.AddError("EmailInviteLink", fmt.Sprintf("must contain %q", EmailInviteLink))
+ r.AddError("emailInviteLink", fmt.Sprintf("must contain %q", EmailInviteLink))
}
}
if r.EmailPasswordResetTemplate != "" {
if !strings.Contains(r.EmailPasswordResetTemplate, EmailPasswordResetLink) {
- r.AddError("EmailPasswordResetTemplate", fmt.Sprintf("must contain %q", EmailPasswordResetLink))
+ r.AddError("emailPasswordResetTemplate", fmt.Sprintf("must contain %q", EmailPasswordResetLink))
}
}
if r.EmailVerifyTemplate != "" {
if !strings.Contains(r.EmailVerifyTemplate, EmailVerifyLink) {
- r.AddError("EmailVerifyTemplate", fmt.Sprintf("must contain %q", EmailVerifyLink))
+ r.AddError("emailVerifyTemplate", fmt.Sprintf("must contain %q", EmailVerifyLink))
}
}
@@ -416,8 +396,8 @@ func (r *Realm) BeforeSave(tx *gorm.DB) error {
}
}
- if len(r.Errors()) > 0 {
- return fmt.Errorf("realm validation failed: %s", strings.Join(r.ErrorMessages(), ", "))
+ if msgs := r.ErrorMessages(); len(msgs) > 0 {
+ return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", "))
}
return nil
}
@@ -427,29 +407,29 @@ func (r *Realm) BeforeSave(tx *gorm.DB) error {
func (r *Realm) validateSMSTemplate(label, t string) {
if r.EnableENExpress {
if !strings.Contains(t, SMSENExpressLink) {
- r.AddError("SMSTextTemplate", fmt.Sprintf("must contain %q", SMSENExpressLink))
+ r.AddError("smsTextTemplate", fmt.Sprintf("must contain %q", SMSENExpressLink))
r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink))
}
if strings.Contains(t, SMSRegion) {
- r.AddError("SMSTextTemplate", fmt.Sprintf("cannot contain %q - this is automatically included in %q", SMSRegion, SMSENExpressLink))
+ r.AddError("smsTextTemplate", fmt.Sprintf("cannot contain %q - this is automatically included in %q", SMSRegion, SMSENExpressLink))
r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink))
}
if strings.Contains(t, SMSLongCode) {
- r.AddError("SMSTextTemplate", fmt.Sprintf("cannot contain %q - the long code is automatically included in %q", SMSLongCode, SMSENExpressLink))
+ r.AddError("smsTextTemplate", fmt.Sprintf("cannot contain %q - the long code is automatically included in %q", SMSLongCode, SMSENExpressLink))
r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink))
}
} else {
// Check that we have exactly one of [code] or [longcode] as template substitutions.
if c, lc := strings.Contains(t, SMSCode), strings.Contains(t, SMSLongCode); !(c || lc) || (c && lc) {
- r.AddError("SMSTextTemplate", fmt.Sprintf("must contain exactly one of %q or %q", SMSCode, SMSLongCode))
+ r.AddError("smsTextTemplate", fmt.Sprintf("must contain exactly one of %q or %q", SMSCode, SMSLongCode))
r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink))
}
}
// Check template length.
if l := len(t); l > SMSTemplateMaxLength {
- r.AddError("SMSTextTemplate", fmt.Sprintf("must be %d characters or less, current message is %v characters long", SMSTemplateMaxLength, l))
+ r.AddError("smsTextTemplate", fmt.Sprintf("must be %d characters or less, current message is %v characters long", SMSTemplateMaxLength, l))
r.AddError(label, fmt.Sprintf("must contain %q", SMSENExpressLink))
}
@@ -459,11 +439,11 @@ func (r *Realm) validateSMSTemplate(label, t string) {
enxDomain := os.Getenv("ENX_REDIRECT_DOMAIN")
expandedSMSText, err := r.BuildSMSText(fakeCode, fakeLongCode, enxDomain, label)
if err != nil {
- r.AddError("SMSTextTemplate", fmt.Sprintf("SMS template expansion failed: %s", err))
+ r.AddError("smsTextTemplate", fmt.Sprintf("SMS template expansion failed: %s", err))
r.AddError(label, fmt.Sprintf("SMS template expansion failed: %s", err))
}
if l := len(expandedSMSText); l > SMSTemplateExpansionMax {
- r.AddError("SMSTextTemplate", fmt.Sprintf("when expanded, the result message is too long (%v characters). The max expanded message is %v characters", l, SMSTemplateExpansionMax))
+ r.AddError("smsTextTemplate", fmt.Sprintf("when expanded, the result message is too long (%v characters). The max expanded message is %v characters", l, SMSTemplateExpansionMax))
r.AddError(label, fmt.Sprintf("when expanded, the result message is too long (%v characters). The max expanded message is %v characters", l, SMSTemplateExpansionMax))
}
@@ -481,6 +461,19 @@ func (r *Realm) GetLongCodeDurationHours() int {
return int(r.LongCodeDuration.Duration.Hours())
}
+// EffectiveMFAMode returns the realm's default MFAMode but first
+// checks if the user is in the grace-period (if so, required becomes prompt).
+func (r *Realm) EffectiveMFAMode(user *User) AuthRequirement {
+ if r == nil {
+ return MFARequired
+ }
+
+ if time.Since(user.CreatedAt) <= r.MFARequiredGracePeriod.Duration {
+ return MFAOptionalPrompt
+ }
+ return r.MFAMode
+}
+
// FindVerificationCodeByUUID find a verification codes by UUID.
func (r *Realm) FindVerificationCodeByUUID(db *Database, uuid string) (*VerificationCode, error) {
var vc VerificationCode
@@ -1242,6 +1235,45 @@ func (db *Database) SaveRealm(r *Realm, actor Auditable) error {
})
}
+// CreateAuthorizedApp generates a new API key and assigns it to the specified
+// app. Note that the API key is NOT stored in the database, only a hash. The
+// only time the API key is available is as the string return parameter from
+// invoking this function.
+func (r *Realm) CreateAuthorizedApp(db *Database, app *AuthorizedApp, actor Auditable) (string, error) {
+ fullAPIKey, err := db.GenerateAPIKey(r.ID)
+ if err != nil {
+ return "", fmt.Errorf("failed to generate API key: %w", err)
+ }
+
+ parts := strings.SplitN(fullAPIKey, ".", 3)
+ if len(parts) != 3 {
+ return "", fmt.Errorf("internal error, key is invalid")
+ }
+ apiKey := parts[0]
+
+ hmacedKey, err := db.GenerateAPIKeyHMAC(apiKey)
+ if err != nil {
+ return "", fmt.Errorf("failed to create hmac: %w", err)
+ }
+
+ app.RealmID = r.ID
+ app.APIKey = hmacedKey
+ app.APIKeyPreview = apiKey[:6]
+
+ if err := db.SaveAuthorizedApp(app, actor); err != nil {
+ return "", err
+ }
+ return fullAPIKey, nil
+}
+
+func (r *Realm) CanUpgradeToRealmSigningKeys() bool {
+ return r.CertificateIssuer != "" && r.CertificateAudience != ""
+}
+
+func (r *Realm) SigningKeyID() string {
+ return fmt.Sprintf("realm-%d", r.ID)
+}
+
// CreateSigningKeyVersion creates a new signing key version on the key manager
// and saves a reference to the new key version in the database. If creating the
// key in the key manager fails, the database is not updated. However, if
diff --git a/pkg/database/realm_stats_test.go b/pkg/database/realm_stats_test.go
new file mode 100644
index 000000000..edc54bece
--- /dev/null
+++ b/pkg/database/realm_stats_test.go
@@ -0,0 +1,100 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestRealmStats_MarshalCSV(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ stats RealmStats
+ exp string
+ }{
+ {
+ name: "empty",
+ stats: nil,
+ exp: "",
+ },
+ {
+ name: "single",
+ stats: []*RealmStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ CodesIssued: 10,
+ CodesClaimed: 9,
+ DailyActiveUsers: 2,
+ },
+ },
+ exp: `date,codes_issued,codes_claimed,daily_active_users
+2020-02-03,10,9,2
+`,
+ },
+ {
+ name: "multi",
+ stats: []*RealmStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ CodesIssued: 10,
+ CodesClaimed: 9,
+ DailyActiveUsers: 12,
+ },
+ {
+ Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ CodesIssued: 45,
+ CodesClaimed: 30,
+ DailyActiveUsers: 24,
+ },
+ {
+ Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ CodesIssued: 15,
+ CodesClaimed: 2,
+ DailyActiveUsers: 18,
+ },
+ },
+ exp: `date,codes_issued,codes_claimed,daily_active_users
+2020-02-03,10,9,12
+2020-02-04,45,30,24
+2020-02-05,15,2,18
+`,
+ },
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ b, err := tc.stats.MarshalCSV()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(string(b), tc.exp); diff != "" {
+ t.Errorf("bad csv (+got, -want): %s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/database/realm_test.go b/pkg/database/realm_test.go
index 559b125bf..4be805000 100644
--- a/pkg/database/realm_test.go
+++ b/pkg/database/realm_test.go
@@ -16,6 +16,7 @@ package database
import (
"context"
+ "fmt"
"os"
"path/filepath"
"strings"
@@ -24,38 +25,115 @@ import (
"github.com/google/exposure-notifications-server/pkg/timeutils"
"github.com/google/exposure-notifications-verification-server/internal/project"
- "github.com/google/exposure-notifications-verification-server/pkg/rbac"
+ "github.com/jinzhu/gorm"
)
-func TestSMS(t *testing.T) {
+func TestTestType(t *testing.T) {
t.Parallel()
- realm := NewRealmWithDefaults("test")
- realm.SMSTextTemplate = "This is your Exposure Notifications Verification code: [enslink] Expires in [longexpires] hours"
- realm.RegionCode = "US-WA"
+ // This test might seem like it's redundant, but it's designed to ensure that
+ // the exact values for existing types remain unchanged.
+ cases := []struct {
+ t TestType
+ exp int
+ }{
+ {TestTypeConfirmed, 2},
+ {TestTypeLikely, 4},
+ {TestTypeNegative, 8},
+ }
- got, err := realm.BuildSMSText("12345678", "abcdefgh12345678", "en.express", "")
- if err != nil {
- t.Fatalf("failed to buildSMS, %v", err)
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.t.Display(), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := int(tc.t), tc.exp; got != want {
+ t.Errorf("expected %d to be %d", got, want)
+ }
+ })
}
- want := "This is your Exposure Notifications Verification code: https://us-wa.en.express/v?c=abcdefgh12345678 Expires in 24 hours"
- if got != want {
- t.Errorf("SMS text wrong, want: %q got %q", want, got)
+}
+
+func TestTestType_Display(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ t TestType
+ exp string
+ }{
+ {TestTypeConfirmed, "confirmed"},
+ {TestTypeConfirmed | TestTypeLikely, "confirmed, likely"},
+ {TestTypeLikely, "likely"},
+ {TestTypeNegative, "negative"},
}
- realm.SMSTextTemplate = "State of Wonder, COVID-19 Exposure Verification code [code]. Expires in [expires] minutes. Act now!"
- got, err = realm.BuildSMSText("654321", "asdflkjasdlkfjl", "", "")
- if err != nil {
- t.Fatalf("failed to buildSMS, %v", err)
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := tc.t.Display(), tc.exp; got != want {
+ t.Errorf("expected %q to be %q", got, want)
+ }
+ })
}
- want = "State of Wonder, COVID-19 Exposure Verification code 654321. Expires in 15 minutes. Act now!"
- if got != want {
- t.Errorf("SMS text wrong, want: %q got %q", want, got)
+}
+
+func TestAuthRequirement(t *testing.T) {
+ t.Parallel()
+
+ // This test might seem like it's redundant, but it's designed to ensure that
+ // the exact values for existing types remain unchanged.
+ cases := []struct {
+ t AuthRequirement
+ exp int
+ }{
+ {MFAOptionalPrompt, 0},
+ {MFARequired, 1},
+ {MFAOptional, 2},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.t.String(), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := int(tc.t), tc.exp; got != want {
+ t.Errorf("expected %d to be %d", got, want)
+ }
+ })
}
}
-func TestValidation(t *testing.T) {
- db, _ := testDatabaseInstance.NewDatabase(t, nil)
+func TestAuthRequirement_String(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ t AuthRequirement
+ exp string
+ }{
+ {MFAOptionalPrompt, "prompt"},
+ {MFARequired, "required"},
+ {MFAOptional, "optional"},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := tc.t.String(), tc.exp; got != want {
+ t.Errorf("expected %q to be %q", got, want)
+ }
+ })
+ }
+}
+
+func TestRealm_BeforeSave(t *testing.T) {
os.Setenv("ENX_REDIRECT_DOMAIN", "https://en.express")
valid := "State of Wonder, COVID-19 Exposure Verification code [code]. Expires in [expires] minutes. Act now!"
@@ -75,7 +153,6 @@ func TestValidation(t *testing.T) {
{
Name: "region_code_too_long",
Input: &Realm{
- Name: "foo",
RegionCode: "USA-IS-A-OK",
},
Error: "regionCode cannot be more than 10 characters",
@@ -83,12 +160,28 @@ func TestValidation(t *testing.T) {
{
Name: "enx_region_code_mismatch",
Input: &Realm{
- Name: "foo",
RegionCode: " ",
EnableENExpress: true,
},
Error: "regionCode cannot be blank when using EN Express",
},
+ {
+ Name: "system_sms_forbidden",
+ Input: &Realm{
+ UseSystemSMSConfig: true,
+ CanUseSystemSMSConfig: false,
+ },
+ Error: "useSystemSMSConfig is not allowed on this realm",
+ },
+ {
+ Name: "system_sms_missing_from",
+ Input: &Realm{
+ UseSystemSMSConfig: true,
+ CanUseSystemSMSConfig: true,
+ SMSFromNumberID: 0,
+ },
+ Error: "smsFromNumber is required to use the system config",
+ },
{
Name: "rotation_warning_too_big",
Input: &Realm{
@@ -141,7 +234,7 @@ func TestValidation(t *testing.T) {
EnableENExpress: true,
SMSTextTemplate: "call 1-800-555-1234",
},
- Error: "SMSTextTemplate must contain \"[enslink]\"",
+ Error: "smsTextTemplate must contain \"[enslink]\"",
},
{
Name: "enx_link_contains_region",
@@ -150,7 +243,7 @@ func TestValidation(t *testing.T) {
EnableENExpress: true,
SMSTextTemplate: "[enslink] [region]",
},
- Error: "SMSTextTemplate cannot contain \"[region]\" - this is automatically included in \"[enslink]\"",
+ Error: "smsTextTemplate cannot contain \"[region]\" - this is automatically included in \"[enslink]\"",
},
{
Name: "enx_link_contains_long_code",
@@ -159,7 +252,7 @@ func TestValidation(t *testing.T) {
EnableENExpress: true,
SMSTextTemplate: "[enslink] [longcode]",
},
- Error: "SMSTextTemplate cannot contain \"[longcode]\" - the long code is automatically included in \"[enslink]\"",
+ Error: "smsTextTemplate cannot contain \"[longcode]\" - the long code is automatically included in \"[enslink]\"",
},
{
Name: "link_missing_code",
@@ -168,7 +261,7 @@ func TestValidation(t *testing.T) {
EnableENExpress: false,
SMSTextTemplate: "call me",
},
- Error: "SMSTextTemplate must contain exactly one of \"[code]\" or \"[longcode]\"",
+ Error: "smsTextTemplate must contain exactly one of \"[code]\" or \"[longcode]\"",
},
{
Name: "link_both_codess",
@@ -177,7 +270,7 @@ func TestValidation(t *testing.T) {
EnableENExpress: false,
SMSTextTemplate: "[code][longcode]",
},
- Error: "SMSTextTemplate must contain exactly one of \"[code]\" or \"[longcode]\"",
+ Error: "smsTextTemplate must contain exactly one of \"[code]\" or \"[longcode]\"",
},
{
Name: "text_too_long",
@@ -188,7 +281,7 @@ func TestValidation(t *testing.T) {
Curabitur non massa urna. Phasellus sit amet nisi ut quam dapibus pretium eget in turpis. Phasellus et justo odio. In auctor, felis a tincidunt maximus, nunc erat vehicula ligula, ac posuere felis odio eget mauris. Nulla gravida.`,
},
- Error: "SMSTextTemplate must be 800 characters or less, current message is 807 characters long",
+ Error: "smsTextTemplate must be 800 characters or less, current message is 807 characters long",
},
{
Name: "text_too_long",
@@ -197,7 +290,7 @@ func TestValidation(t *testing.T) {
EnableENExpress: false,
SMSTextTemplate: strings.Repeat("[enslink]", 88),
},
- Error: "SMSTextTemplate when expanded, the result message is too long (3168 characters). The max expanded message is 918 characters",
+ Error: "smsTextTemplate when expanded, the result message is too long (3168 characters). The max expanded message is 918 characters",
},
{
Name: "valid",
@@ -220,7 +313,7 @@ func TestValidation(t *testing.T) {
Error: "no template for label alternate1",
},
{
- Name: "alternate_sms_template valid",
+ Name: "alternate_sms_template_valid",
Input: &Realm{
Name: "b",
CodeLength: 6,
@@ -230,54 +323,154 @@ func TestValidation(t *testing.T) {
SMSTextAlternateTemplates: map[string]*string{"alternate1": &valid},
},
},
+ {
+ Name: "system_email_forbidden",
+ Input: &Realm{
+ UseSystemEmailConfig: true,
+ CanUseSystemEmailConfig: false,
+ },
+ Error: "useSystemEmailConfig is not allowed on this realm",
+ },
+ {
+ Name: "email_invite_template_missing_link",
+ Input: &Realm{
+ EmailInviteTemplate: "banana",
+ },
+ Error: "emailInviteLink must contain \"[invitelink]\"",
+ },
+ {
+ Name: "email_password_reset_template_missing_link",
+ Input: &Realm{
+ EmailPasswordResetTemplate: "banana",
+ },
+ Error: "emailPasswordResetTemplate must contain \"[passwordresetlink]\"",
+ },
+ {
+ Name: "email_verify_template_missing_link",
+ Input: &Realm{
+ EmailVerifyTemplate: "banana",
+ },
+ Error: "emailVerifyTemplate must contain \"[verifylink]\"",
+ },
+ {
+ Name: "certificate_issuer_blank",
+ Input: &Realm{
+ UseRealmCertificateKey: true,
+ CertificateIssuer: "",
+ },
+ Error: "certificateIssuer cannot be blank",
+ },
+ {
+ Name: "certificate_audience_blank",
+ Input: &Realm{
+ UseRealmCertificateKey: true,
+ CertificateAudience: "",
+ },
+ Error: "certificateAudience cannot be blank",
+ },
}
for _, tc := range cases {
+ tc := tc
+
t.Run(tc.Name, func(t *testing.T) {
- if err := db.SaveRealm(tc.Input, SystemTest); err == nil {
+ t.Parallel()
+
+ if err := tc.Input.BeforeSave(&gorm.DB{}); err != nil {
if tc.Error != "" {
- t.Fatalf("expected error: %q got: nil", tc.Error)
+ if got, want := err.Error(), tc.Error; !strings.Contains(got, want) {
+ t.Errorf("expected %q to be %q", got, want)
+ }
+ } else {
+ t.Errorf("bad error: %s", err)
}
- } else if tc.Error == "" {
- t.Fatalf("expected no error, got %q", err.Error())
- } else if !strings.Contains(err.Error(), tc.Error) {
- t.Fatalf("wrong error, want: %q got: %q", tc.Error, err.Error())
}
})
}
}
-func TestPerUserRealmStats(t *testing.T) {
+func TestRealm_BuildSMSText(t *testing.T) {
+ t.Parallel()
+
+ realm := NewRealmWithDefaults("test")
+ realm.SMSTextTemplate = "This is your Exposure Notifications Verification code: [enslink] Expires in [longexpires] hours"
+ realm.RegionCode = "US-WA"
+
+ got, err := realm.BuildSMSText("12345678", "abcdefgh12345678", "en.express", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "This is your Exposure Notifications Verification code: https://us-wa.en.express/v?c=abcdefgh12345678 Expires in 24 hours"
+ if got != want {
+ t.Errorf("SMS text wrong, want: %q got %q", want, got)
+ }
+
+ realm.SMSTextTemplate = "State of Wonder, COVID-19 Exposure Verification code [code]. Expires in [expires] minutes. Act now!"
+ got, err = realm.BuildSMSText("654321", "asdflkjasdlkfjl", "", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ want = "State of Wonder, COVID-19 Exposure Verification code 654321. Expires in 15 minutes. Act now!"
+ if got != want {
+ t.Errorf("SMS text wrong, want: %q got %q", want, got)
+ }
+}
+
+func TestRealm_BuildInviteEmail(t *testing.T) {
+ t.Parallel()
+
+ realm := NewRealmWithDefaults("test")
+ realm.EmailInviteTemplate = "Welcome to [realmname] [invitelink]."
+
+ if got, want := realm.BuildInviteEmail("https://join.now"), "Welcome to test https://join.now."; got != want {
+ t.Errorf("expected %q to be %q", got, want)
+ }
+}
+
+func TestRealm_BuildPasswordResetEmail(t *testing.T) {
+ t.Parallel()
+
+ realm := NewRealmWithDefaults("test")
+ realm.EmailPasswordResetTemplate = "Hey [realmname] reset [passwordresetlink]."
+
+ if got, want := realm.BuildPasswordResetEmail("https://reset.now"), "Hey test reset https://reset.now."; got != want {
+ t.Errorf("expected %q to be %q", got, want)
+ }
+}
+
+func TestRealm_BuildVerifyEmail(t *testing.T) {
+ t.Parallel()
+
+ realm := NewRealmWithDefaults("test")
+ realm.EmailVerifyTemplate = "Hey [realmname] verify [verifylink]."
+
+ if got, want := realm.BuildVerifyEmail("https://verify.now"), "Hey test verify https://verify.now."; got != want {
+ t.Errorf("expected %q to be %q", got, want)
+ }
+}
+
+func TestRealm_UserStats(t *testing.T) {
t.Parallel()
db, _ := testDatabaseInstance.NewDatabase(t, nil)
+ realm, err := db.FindRealm(1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
numDays := 7
endDate := timeutils.Midnight(time.Now())
startDate := timeutils.Midnight(endDate.Add(time.Duration(numDays) * -24 * time.Hour))
- // Create a new realm
- realm := NewRealmWithDefaults("test")
- if err := db.SaveRealm(realm, SystemTest); err != nil {
- t.Fatalf("error saving realm: %v", err)
- }
-
- // Create the users.
- users := []*User{}
+ // Create users.
for userIdx, name := range []string{"Rocky", "Bullwinkle", "Boris", "Natasha"} {
user := &User{
- Name: name,
- Email: name + "@gmail.com",
- SystemAdmin: false,
+ Name: name,
+ Email: name + "@example.com",
}
if err := db.SaveUser(user, SystemTest); err != nil {
- t.Fatalf("[%v] error creating user: %v", name, err)
- }
- users = append(users, user)
-
- // Add to realm
- if err := user.AddToRealm(db, realm, rbac.CodeIssue, SystemTest); err != nil {
- t.Fatal(err)
+ t.Fatalf("failed to create user %q: %s", name, err)
}
// Add some stats per user.
@@ -294,10 +487,6 @@ func TestPerUserRealmStats(t *testing.T) {
}
}
- if len(users) == 0 { // sanity check
- t.Error("len(users) = 0, expected ≠ 0")
- }
-
stats, err := realm.UserStats(db, startDate, endDate)
if err != nil {
t.Fatalf("error getting stats: %v", err)
@@ -509,3 +698,82 @@ func TestRealm_SMSConfig(t *testing.T) {
}
}
}
+
+func TestRealm_EmailConfig(t *testing.T) {
+ t.Parallel()
+
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
+
+ realm, err := db.FindRealm(1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Initial realm should have no config
+ if _, err := realm.EmailConfig(db); err == nil {
+ t.Fatalf("expected error")
+ }
+
+ // Create config
+ if err := db.SaveEmailConfig(&EmailConfig{
+ RealmID: realm.ID,
+ SMTPAccount: "account",
+ SMTPHost: "host",
+ SMTPPort: "port",
+ SMTPPassword: "password",
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ {
+ // Now the realm should have a config
+ EmailConfig, err := realm.EmailConfig(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := EmailConfig.SMTPAccount, "account"; got != want {
+ t.Errorf("expected %v to be %v", got, want)
+ }
+ if got, want := EmailConfig.SMTPHost, "host"; got != want {
+ t.Errorf("expected %v to be %v", got, want)
+ }
+ if got, want := EmailConfig.SMTPPort, "port"; got != want {
+ t.Errorf("expected %v to be %v", got, want)
+ }
+ }
+
+ // Create system config
+ if err := db.SaveEmailConfig(&EmailConfig{
+ SMTPAccount: "system-account",
+ SMTPHost: "system-host",
+ SMTPPort: "system-port",
+ SMTPPassword: "system-password",
+ IsSystem: true,
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // Update to use system config
+ realm.CanUseSystemEmailConfig = true
+ realm.UseSystemEmailConfig = true
+ if err := db.SaveRealm(realm, SystemTest); err != nil {
+ t.Fatal(err)
+ }
+
+ // The realm should use the system config.
+ {
+ emailConfig, err := realm.EmailConfig(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := emailConfig.SMTPAccount, "system-account"; got != want {
+ t.Errorf("expected %v to be %v", got, want)
+ }
+ if got, want := emailConfig.SMTPHost, "system-host"; got != want {
+ t.Errorf("expected %v to be %v", got, want)
+ }
+ if got, want := emailConfig.SMTPPort, "system-port"; got != want {
+ t.Errorf("expected %v to be %v", got, want)
+ }
+ }
+}
diff --git a/pkg/database/realm_user_stats_test.go b/pkg/database/realm_user_stats_test.go
new file mode 100644
index 000000000..42919e9a7
--- /dev/null
+++ b/pkg/database/realm_user_stats_test.go
@@ -0,0 +1,104 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestRealmUserStats_MarshalCSV(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ stats RealmUserStats
+ exp string
+ }{
+ {
+ name: "empty",
+ stats: nil,
+ exp: "",
+ },
+ {
+ name: "single",
+ stats: []*RealmUserStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ UserID: 1,
+ Name: "You",
+ Email: "you@example.com",
+ CodesIssued: 10,
+ },
+ },
+ exp: `date,realm_id,user_id,name,email,codes_issued
+2020-02-03,1,1,You,you@example.com,10
+`,
+ },
+ {
+ name: "multi",
+ stats: []*RealmUserStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ UserID: 1,
+ Name: "You",
+ Email: "you@example.com",
+ CodesIssued: 10,
+ },
+ {
+ Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ UserID: 2,
+ Name: "Them",
+ Email: "them@example.com",
+ CodesIssued: 45,
+ },
+ {
+ Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC),
+ RealmID: 1,
+ UserID: 3,
+ Name: "Us",
+ Email: "us@example.com",
+ CodesIssued: 15,
+ },
+ },
+ exp: `date,realm_id,user_id,name,email,codes_issued
+2020-02-03,1,1,You,you@example.com,10
+2020-02-04,1,2,Them,them@example.com,45
+2020-02-05,1,3,Us,us@example.com,15
+`,
+ },
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ b, err := tc.stats.MarshalCSV()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(string(b), tc.exp); diff != "" {
+ t.Errorf("bad csv (+got, -want): %s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/database/signing_key.go b/pkg/database/signing_key.go
index c31057ea0..415ec709f 100644
--- a/pkg/database/signing_key.go
+++ b/pkg/database/signing_key.go
@@ -38,11 +38,3 @@ type SigningKey struct {
func (s *SigningKey) GetKID() string {
return fmt.Sprintf("r%dv%d", s.RealmID, s.ID)
}
-
-func (s *SigningKey) Delete(db *Database) error {
- return db.db.Delete(s).Error
-}
-
-func (db *Database) SaveSigningKey(s *SigningKey) error {
- return db.db.Save(s).Error
-}
diff --git a/pkg/database/sms_config_test.go b/pkg/database/sms_config_test.go
index 34dcde84f..0feadb894 100644
--- a/pkg/database/sms_config_test.go
+++ b/pkg/database/sms_config_test.go
@@ -25,9 +25,8 @@ func TestSMSConfig_Lifecycle(t *testing.T) {
db, _ := testDatabaseInstance.NewDatabase(t, nil)
- // Create realm
- realm := NewRealmWithDefaults(t.Name())
- if err := db.SaveRealm(realm, SystemTest); err != nil {
+ realm, err := db.FindRealm(1)
+ if err != nil {
t.Fatal(err)
}
@@ -55,7 +54,7 @@ func TestSMSConfig_Lifecycle(t *testing.T) {
}
// Get the realm to verify SMS configs are NOT preloaded
- realm, err := db.FindRealm(realm.ID)
+ realm, err = db.FindRealm(realm.ID)
if err != nil {
t.Fatal(err)
}
@@ -117,8 +116,8 @@ func TestSMSProvider(t *testing.T) {
db, _ := testDatabaseInstance.NewDatabase(t, nil)
- realm := NewRealmWithDefaults("test-sms-realm-1")
- if err := db.SaveRealm(realm, SystemTest); err != nil {
+ realm, err := db.FindRealm(1)
+ if err != nil {
t.Fatal(err)
}
diff --git a/pkg/database/sms_from_number_test.go b/pkg/database/sms_from_number_test.go
index 9f072b663..85b0bdff2 100644
--- a/pkg/database/sms_from_number_test.go
+++ b/pkg/database/sms_from_number_test.go
@@ -22,33 +22,22 @@ import (
func TestSMSFromNumber_BeforeSave(t *testing.T) {
t.Parallel()
- db, _ := testDatabaseInstance.NewDatabase(t, nil)
-
- t.Run("label", func(t *testing.T) {
- t.Parallel()
-
- var n SMSFromNumber
- n.Label = ""
- _ = n.BeforeSave(db.RawDB())
-
- errs := n.ErrorsFor("label")
- if len(errs) < 1 {
- t.Fatal("expected error")
- }
- })
-
- t.Run("value", func(t *testing.T) {
- t.Parallel()
+ cases := []struct {
+ structField string
+ field string
+ }{
+ {"Label", "label"},
+ {"Value", "value"},
+ }
- var n SMSFromNumber
- n.Value = ""
- _ = n.BeforeSave(db.RawDB())
+ for _, tc := range cases {
+ tc := tc
- errs := n.ErrorsFor("value")
- if len(errs) < 1 {
- t.Fatal("expected error")
- }
- })
+ t.Run(tc.field, func(t *testing.T) {
+ t.Parallel()
+ exerciseValidation(t, &SMSFromNumber{}, tc.structField, tc.field)
+ })
+ }
}
func TestSMSFromNumbers(t *testing.T) {
diff --git a/pkg/database/user.go b/pkg/database/user.go
index cf61a77b5..3e217f35c 100644
--- a/pkg/database/user.go
+++ b/pkg/database/user.go
@@ -47,6 +47,27 @@ type User struct {
LastPasswordChange time.Time
}
+// BeforeSave runs validations. If there are errors, the save fails.
+func (u *User) BeforeSave(tx *gorm.DB) error {
+ u.Email = project.TrimSpace(u.Email)
+ if u.Email == "" {
+ u.AddError("email", "cannot be blank")
+ }
+ if !strings.Contains(u.Email, "@") {
+ u.AddError("email", "appears to be invalid")
+ }
+
+ u.Name = project.TrimSpace(u.Name)
+ if u.Name == "" {
+ u.AddError("name", "cannot be blank")
+ }
+
+ if msgs := u.ErrorMessages(); len(msgs) > 0 {
+ return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", "))
+ }
+ return nil
+}
+
// PasswordChanged returns password change time or account creation time if unset.
func (u *User) PasswordChanged() time.Time {
if u.LastPasswordChange.Before(launched) {
@@ -75,29 +96,6 @@ func (u *User) PasswordAgeString() string {
return fmt.Sprintf("%d seconds", int(ago.Seconds()))
}
-// BeforeSave runs validations. If there are errors, the save fails.
-func (u *User) BeforeSave(tx *gorm.DB) error {
- // Validation
- u.Email = project.TrimSpace(u.Email)
- if u.Email == "" {
- u.AddError("email", "cannot be blank")
- }
- if !strings.Contains(u.Email, "@") {
- u.AddError("email", "appears to be invalid")
- }
-
- u.Name = project.TrimSpace(u.Name)
- if u.Name == "" {
- u.AddError("name", "cannot be blank")
- }
-
- if len(u.Errors()) > 0 {
- return fmt.Errorf("validation failed: %s", strings.Join(u.ErrorMessages(), ", "))
- }
-
- return nil
-}
-
// FindUser finds a user by the given id, if one exists. The id can be a string
// or integer value. It returns an error if the record is not found.
func (db *Database) FindUser(id interface{}) (*User, error) {
diff --git a/pkg/database/user_stats_test.go b/pkg/database/user_stats_test.go
new file mode 100644
index 000000000..f73d88f02
--- /dev/null
+++ b/pkg/database/user_stats_test.go
@@ -0,0 +1,104 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestUserStats_MarshalCSV(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ stats UserStats
+ exp string
+ }{
+ {
+ name: "empty",
+ stats: nil,
+ exp: "",
+ },
+ {
+ name: "single",
+ stats: []*UserStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ UserID: 1,
+ RealmID: 1,
+ CodesIssued: 10,
+ UserName: "You",
+ UserEmail: "you@example.com",
+ },
+ },
+ exp: `date,realm_id,user_id,user_name,user_email,codes_issued
+2020-02-03,1,1,You,you@example.com,10
+`,
+ },
+ {
+ name: "multi",
+ stats: []*UserStat{
+ {
+ Date: time.Date(2020, 2, 3, 0, 0, 0, 0, time.UTC),
+ UserID: 1,
+ RealmID: 1,
+ CodesIssued: 10,
+ UserName: "You",
+ UserEmail: "you@example.com",
+ },
+ {
+ Date: time.Date(2020, 2, 4, 0, 0, 0, 0, time.UTC),
+ UserID: 2,
+ RealmID: 1,
+ CodesIssued: 45,
+ UserName: "Them",
+ UserEmail: "them@example.com",
+ },
+ {
+ Date: time.Date(2020, 2, 5, 0, 0, 0, 0, time.UTC),
+ UserID: 1,
+ RealmID: 3,
+ CodesIssued: 15,
+ UserName: "Us",
+ UserEmail: "us@example.com",
+ },
+ },
+ exp: `date,realm_id,user_id,user_name,user_email,codes_issued
+2020-02-03,1,1,You,you@example.com,10
+2020-02-04,1,2,Them,them@example.com,45
+2020-02-05,3,1,Us,us@example.com,15
+`,
+ },
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ b, err := tc.stats.MarshalCSV()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(string(b), tc.exp); diff != "" {
+ t.Errorf("bad csv (+got, -want): %s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/database/user_test.go b/pkg/database/user_test.go
index 825a2c276..7561c5eb9 100644
--- a/pkg/database/user_test.go
+++ b/pkg/database/user_test.go
@@ -21,7 +21,28 @@ import (
"github.com/google/exposure-notifications-verification-server/pkg/rbac"
)
-func TestUserLifecycle(t *testing.T) {
+func TestUser_BeforeSave(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ structField string
+ field string
+ }{
+ {"Email", "email"},
+ {"Name", "name"},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(tc.field, func(t *testing.T) {
+ t.Parallel()
+ exerciseValidation(t, &User{}, tc.structField, tc.field)
+ })
+ }
+}
+
+func TestUser_Lifecycle(t *testing.T) {
t.Parallel()
db, _ := testDatabaseInstance.NewDatabase(t, nil)
@@ -99,13 +120,13 @@ func TestUserLifecycle(t *testing.T) {
}
}
-func TestPurgeUsers(t *testing.T) {
+func TestDatabase_PurgeUsers(t *testing.T) {
t.Parallel()
db, _ := testDatabaseInstance.NewDatabase(t, nil)
- realm := NewRealmWithDefaults("test")
- if err := db.SaveRealm(realm, SystemTest); err != nil {
+ realm, err := db.FindRealm(1)
+ if err != nil {
t.Fatal(err)
}
@@ -168,58 +189,60 @@ func TestPurgeUsers(t *testing.T) {
}
}
-func TestRemoveRealmUpdatesTime(t *testing.T) {
+func TestUser_DeleteFromRealm(t *testing.T) {
t.Parallel()
- db, _ := testDatabaseInstance.NewDatabase(t, nil)
+ t.Run("updates_time", func(t *testing.T) {
+ db, _ := testDatabaseInstance.NewDatabase(t, nil)
- realm := NewRealmWithDefaults("test")
- if err := db.SaveRealm(realm, SystemTest); err != nil {
- t.Fatal(err)
- }
+ realm := NewRealmWithDefaults("test")
+ if err := db.SaveRealm(realm, SystemTest); err != nil {
+ t.Fatal(err)
+ }
- email := "purge@example.com"
- user := &User{
- Email: email,
- Name: "Dr Delete",
- }
- if err := db.SaveUser(user, SystemTest); err != nil {
- t.Fatal(err)
- }
+ email := "purge@example.com"
+ user := &User{
+ Email: email,
+ Name: "Dr Delete",
+ }
+ if err := db.SaveUser(user, SystemTest); err != nil {
+ t.Fatal(err)
+ }
- // Add to realm
- if err := user.AddToRealm(db, realm, rbac.LegacyRealmAdmin, SystemTest); err != nil {
- t.Fatal(err)
- }
+ // Add to realm
+ if err := user.AddToRealm(db, realm, rbac.LegacyRealmAdmin, SystemTest); err != nil {
+ t.Fatal(err)
+ }
- got, err := db.FindUser(user.ID)
- if err != nil {
- t.Fatal(err)
- }
+ got, err := db.FindUser(user.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
- if got, want := got.ID, user.ID; got != want {
- t.Errorf("expected %#v to be %#v", got, want)
- }
+ if got, want := got.ID, user.ID; got != want {
+ t.Errorf("expected %#v to be %#v", got, want)
+ }
- time.Sleep(time.Second) // in case this executes in under a nanosecond.
+ time.Sleep(time.Second) // in case this executes in under a nanosecond.
- originalTime := got.Model.UpdatedAt
- if err := user.DeleteFromRealm(db, realm, SystemTest); err != nil {
- t.Fatal(err)
- }
+ originalTime := got.Model.UpdatedAt
+ if err := user.DeleteFromRealm(db, realm, SystemTest); err != nil {
+ t.Fatal(err)
+ }
- got, err = db.FindUser(user.ID)
- if err != nil {
- t.Fatal(err)
- }
+ got, err = db.FindUser(user.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
- if got, want := got.ID, user.ID; got != want {
- t.Errorf("expected %#v to be %#v", got, want)
- }
- // Assert that the user time was updated.
- if originalTime == got.Model.UpdatedAt {
- t.Errorf("expected user time to be updated. Got %#v", originalTime.Format(time.RFC3339))
- }
+ if got, want := got.ID, user.ID; got != want {
+ t.Errorf("expected %#v to be %#v", got, want)
+ }
+ // Assert that the user time was updated.
+ if originalTime == got.Model.UpdatedAt {
+ t.Errorf("expected user time to be updated. Got %#v", originalTime.Format(time.RFC3339))
+ }
+ })
}
func expectExists(t *testing.T, db *Database, id uint) {
diff --git a/pkg/database/vercode.go b/pkg/database/vercode.go
index ee02f9626..f439cfe05 100644
--- a/pkg/database/vercode.go
+++ b/pkg/database/vercode.go
@@ -30,6 +30,7 @@ import (
const (
oneDay = 24 * time.Hour
+
// MinCodeLength defines the minimum number of digits in a code.
MinCodeLength = 6
)
@@ -37,9 +38,9 @@ const (
type CodeType int
const (
- InvalidCode CodeType = iota
- ShortCode
- LongCode
+ _ CodeType = iota
+ CodeTypeShort
+ CodeTypeLong
)
var (
@@ -90,19 +91,14 @@ type VerificationCode struct {
IssuingExternalID string `gorm:"column:issuing_external_id; type:varchar(255);"`
}
-// TableName sets the VerificationCode table name
-func (VerificationCode) TableName() string {
- return "verification_codes"
-}
-
// BeforeSave is used by callbacks.
-func (v *VerificationCode) BeforeSave(scope *gorm.Scope) error {
+func (v *VerificationCode) BeforeSave(tx *gorm.DB) error {
if len(v.IssuingExternalID) > 255 {
v.AddError("issuingExternalID", "cannot exceed 255 characters")
}
- if len(v.Errors()) > 0 {
- return fmt.Errorf("email config validation failed: %s", strings.Join(v.ErrorMessages(), ", "))
+ if msgs := v.ErrorMessages(); len(msgs) > 0 {
+ return fmt.Errorf("validation failed: %s", strings.Join(msgs, ", "))
}
return nil
}
@@ -170,8 +166,6 @@ func (v *VerificationCode) AfterCreate(scope *gorm.Scope) {
}
}
-// TODO(mikehelmick) - Add method to soft delete expired codes
-
// FormatSymptomDate returns YYYY-MM-DD formatted test date, or "" if nil.
func (v *VerificationCode) FormatSymptomDate() string {
if v.SymptomDate == nil {
@@ -184,13 +178,13 @@ func (v *VerificationCode) FormatSymptomDate() string {
// code, and determines if it is expired based on that.
func (db *Database) IsCodeExpired(v *VerificationCode, code string) (bool, CodeType, error) {
if v == nil {
- return false, InvalidCode, fmt.Errorf("provided code is nil")
+ return false, 0, fmt.Errorf("provided code is nil")
}
// It's possible that this could be called with the already HMACd version.
possibles, err := db.generateVerificationCodeHMACs(code)
if err != nil {
- return false, InvalidCode, fmt.Errorf("failed to create hmac: %w", err)
+ return false, 0, fmt.Errorf("failed to create hmac: %w", err)
}
possibles = append(possibles, code)
@@ -206,11 +200,11 @@ func (db *Database) IsCodeExpired(v *VerificationCode, code string) (bool, CodeT
now := time.Now().UTC()
switch {
case inList(v.Code, possibles):
- return !v.ExpiresAt.After(now), ShortCode, nil
+ return !v.ExpiresAt.After(now), CodeTypeShort, nil
case inList(v.LongCode, possibles):
- return !v.LongExpiresAt.After(now), LongCode, nil
+ return !v.LongExpiresAt.After(now), CodeTypeLong, nil
default:
- return true, InvalidCode, fmt.Errorf("not found")
+ return true, 0, fmt.Errorf("not found")
}
}
diff --git a/pkg/database/vercode_test.go b/pkg/database/vercode_test.go
index 4ce9c7dd9..c7c189862 100644
--- a/pkg/database/vercode_test.go
+++ b/pkg/database/vercode_test.go
@@ -16,6 +16,8 @@ package database
import (
"errors"
+ "fmt"
+ "strings"
"testing"
"time"
@@ -26,6 +28,47 @@ import (
"github.com/jinzhu/gorm"
)
+func TestCodeType(t *testing.T) {
+ t.Parallel()
+
+ // This test might seem like it's redundant, but it's designed to ensure that
+ // the exact values for existing types remain unchanged.
+ cases := []struct {
+ t CodeType
+ exp int
+ }{
+ {CodeTypeShort, 1},
+ {CodeTypeLong, 2},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+
+ t.Run(fmt.Sprintf("%d", tc.t), func(t *testing.T) {
+ t.Parallel()
+
+ if got, want := int(tc.t), tc.exp; got != want {
+ t.Errorf("expected %d to be %d", got, want)
+ }
+ })
+ }
+}
+
+func TestVerificationCode_BeforeSave(t *testing.T) {
+ t.Parallel()
+
+ t.Run("issuingExternalID", func(t *testing.T) {
+ t.Parallel()
+
+ var v VerificationCode
+ v.IssuingExternalID = strings.Repeat("*", 256)
+ _ = v.BeforeSave(&gorm.DB{})
+ if errs := v.ErrorsFor("issuingExternalID"); len(errs) < 1 {
+ t.Errorf("expected errors for %s", "issuingExternalID")
+ }
+ })
+}
+
func TestVerificationCode_FindVerificationCode(t *testing.T) {
t.Parallel()
diff --git a/pkg/otp/code_test.go b/pkg/otp/code_test.go
index b7066eae0..fe7f50a94 100644
--- a/pkg/otp/code_test.go
+++ b/pkg/otp/code_test.go
@@ -114,8 +114,8 @@ func TestIssue(t *testing.T) {
}
if exp, codeType, err := db.IsCodeExpired(verCode, code); exp || err != nil {
t.Fatalf("loaded code doesn't match requested code, %v %v", exp, err)
- } else if codeType != database.ShortCode {
- t.Errorf("wrong code type, want: %v got: %v", database.ShortCode, codeType)
+ } else if codeType != database.CodeTypeShort {
+ t.Errorf("wrong code type, want: %v got: %v", database.CodeTypeShort, codeType)
}
}
@@ -126,8 +126,8 @@ func TestIssue(t *testing.T) {
}
if exp, codeType, err := db.IsCodeExpired(verCode, code); exp || err != nil {
t.Fatalf("loaded code doesn't match requested code")
- } else if codeType != database.LongCode {
- t.Errorf("wrong code type, want: %v got: %v", database.LongCode, codeType)
+ } else if codeType != database.CodeTypeLong {
+ t.Errorf("wrong code type, want: %v got: %v", database.CodeTypeLong, codeType)
}
}
}