Skip to content
This repository has been archived by the owner on Jul 12, 2023. It is now read-only.

Fix data races in tests #484

Merged
merged 3 commits into from
Sep 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,10 @@ func (db *Database) Open(ctx context.Context) error {
func (db *Database) OpenWithCacher(ctx context.Context, cacher cache.Cacher) error {
c := db.config

// Establish a connection to the database.
b, err := retry.NewFibonacci(250 * time.Millisecond)
if err != nil {
return fmt.Errorf("failed to configure database backoff: %w", err)
}
b = retry.WithMaxRetries(10, b)
b = retry.WithCappedDuration(2*time.Second, b)

// Establish a connection to the database. We use this later to register
// opencenusus stats.
var rawDB *gorm.DB
if err := retry.Do(ctx, b, func(ctx context.Context) error {
if err := withRetries(ctx, func(ctx context.Context) error {
var err error
rawDB, err = gorm.Open("postgres", c.ConnectionString())
if err != nil {
Expand Down Expand Up @@ -449,3 +443,16 @@ func getFieldString(scope *gorm.Scope, name string) (*gorm.Field, string, bool)

return field, typ, true
}

// withRetries is a helper for creating a fibonacci backoff with capped retries,
// useful for retrying database queries.
func withRetries(ctx context.Context, f retry.RetryFunc) error {
b, err := retry.NewFibonacci(50 * time.Millisecond)
if err != nil {
return fmt.Errorf("failed to configure backoff: %w", err)
}
b = retry.WithMaxRetries(10, b)
b = retry.WithCappedDuration(1*time.Second, b)

return retry.Do(ctx, b, f)
}
5 changes: 0 additions & 5 deletions pkg/database/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ func (db *Database) getMigrations(ctx context.Context) *gormigrate.Gormigrate {
logger := logging.FromContext(ctx)
options := gormigrate.DefaultOptions

// Each migration runs in its own transacton already. Setting to true forces
// all unrun migrations to run in a _single_ transaction which is probably
// undesirable.
options.UseTransaction = false

return gormigrate.New(db.db, options, []*gormigrate.Migration{
{
ID: initState,
Expand Down
157 changes: 89 additions & 68 deletions pkg/database/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func TestIssueToken(t *testing.T) {

cases := []struct {
Name string
Verification VerificationCode
Verification func() *VerificationCode
Accept api.AcceptTypes
UseLongCode bool
Error string
Expand All @@ -100,27 +100,31 @@ func TestIssueToken(t *testing.T) {
}{
{
Name: "normal_token_issue",
Verification: VerificationCode{
Code: "12345678",
LongCode: "12345678ABC",
TestType: "confirmed",
SymptomDate: &symptomDate,
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
Verification: func() *VerificationCode {
return &VerificationCode{
Code: "12345678",
LongCode: "12345678ABC",
TestType: "confirmed",
SymptomDate: &symptomDate,
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
}
},
Accept: acceptConfirmed,
Error: "",
TokenAge: time.Hour,
},
{
Name: "long_code_token_issue",
Verification: VerificationCode{
Code: "22332244",
LongCode: "abcd1234efgh5678",
TestType: "confirmed",
SymptomDate: &symptomDate,
ExpiresAt: time.Now().Add(5 * time.Second),
LongExpiresAt: time.Now().Add(time.Hour),
Verification: func() *VerificationCode {
return &VerificationCode{
Code: "22332244",
LongCode: "abcd1234efgh5678",
TestType: "confirmed",
SymptomDate: &symptomDate,
ExpiresAt: time.Now().Add(5 * time.Second),
LongExpiresAt: time.Now().Add(time.Hour),
}
},
Accept: acceptConfirmed,
UseLongCode: true,
Expand All @@ -129,27 +133,31 @@ func TestIssueToken(t *testing.T) {
},
{
Name: "already_claimed",
Verification: VerificationCode{
Code: "00000001",
LongCode: "00000001ABC",
Claimed: true,
TestType: "confirmed",
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
Verification: func() *VerificationCode {
return &VerificationCode{
Code: "00000001",
LongCode: "00000001ABC",
Claimed: true,
TestType: "confirmed",
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
}
},
Accept: acceptConfirmed,
Error: ErrVerificationCodeUsed.Error(),
TokenAge: time.Hour,
},
{
Name: "code_expired",
Verification: VerificationCode{
Code: "00000002",
LongCode: "00000002ABC",
Claimed: false,
TestType: "confirmed",
ExpiresAt: time.Now().Add(2 * time.Second),
LongExpiresAt: time.Now().Add(2 * time.Second),
Verification: func() *VerificationCode {
return &VerificationCode{
Code: "00000002",
LongCode: "00000002ABC",
Claimed: false,
TestType: "confirmed",
ExpiresAt: time.Now().Add(1 * time.Second),
LongExpiresAt: time.Now().Add(1 * time.Second),
}
},
Accept: acceptConfirmed,
Delay: 2 * time.Second,
Expand All @@ -158,13 +166,15 @@ func TestIssueToken(t *testing.T) {
},
{
Name: "token_expired",
Verification: VerificationCode{
Code: "00000003",
LongCode: "00000003ABC",
Claimed: false,
TestType: "confirmed",
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
Verification: func() *VerificationCode {
return &VerificationCode{
Code: "00000003",
LongCode: "00000003ABC",
Claimed: false,
TestType: "confirmed",
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
}
},
Accept: acceptConfirmed,
Delay: time.Second,
Expand All @@ -173,13 +183,15 @@ func TestIssueToken(t *testing.T) {
},
{
Name: "wrong_test_type",
Verification: VerificationCode{
Code: "00000005",
LongCode: "00000005ABC",
Claimed: false,
TestType: "confirmed",
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
Verification: func() *VerificationCode {
return &VerificationCode{
Code: "00000005",
LongCode: "00000005ABC",
Claimed: false,
TestType: "confirmed",
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
}
},
Accept: acceptConfirmed,
ClaimError: ErrTokenMetadataMismatch.Error(),
Expand All @@ -188,14 +200,16 @@ func TestIssueToken(t *testing.T) {
},
{
Name: "wrong_test_date",
Verification: VerificationCode{
Code: "00000007",
LongCode: "00000007ABC",
Claimed: false,
TestType: "confirmed",
SymptomDate: &symptomDate,
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
Verification: func() *VerificationCode {
return &VerificationCode{
Code: "00000007",
LongCode: "00000007ABC",
Claimed: false,
TestType: "confirmed",
SymptomDate: &symptomDate,
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
}
},
Accept: acceptConfirmed,
ClaimError: ErrTokenMetadataMismatch.Error(),
Expand All @@ -204,14 +218,16 @@ func TestIssueToken(t *testing.T) {
},
{
Name: "unsupported_test_type",
Verification: VerificationCode{
Code: "00000008",
LongCode: "00000008ABC",
Claimed: false,
TestType: "likely",
SymptomDate: &symptomDate,
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
Verification: func() *VerificationCode {
return &VerificationCode{
Code: "00000008",
LongCode: "00000008ABC",
Claimed: false,
TestType: "likely",
SymptomDate: &symptomDate,
ExpiresAt: time.Now().Add(time.Hour),
LongExpiresAt: time.Now().Add(time.Hour),
}
},
Accept: acceptConfirmed,
Error: ErrUnsupportedTestType.Error(),
Expand All @@ -229,15 +245,20 @@ func TestIssueToken(t *testing.T) {
t.Fatal(err)
}

// Create the verification. We do this here instead of inside the test
// struct to mitigate as much time drift as possible. It also ensures we
// get a new VerificationCode on each invocation.
verification := tc.Verification()
verification.RealmID = realm.ID

// Extract the code before saving. After saving, the code on the struct
// will be the HMAC.
code := tc.Verification.Code
code := verification.Code
if tc.UseLongCode {
code = tc.Verification.LongCode
code = verification.LongCode
}

tc.Verification.RealmID = realm.ID
if err := db.SaveVerificationCode(&tc.Verification, codeAge); err != nil {
if err := db.SaveVerificationCode(verification, codeAge); err != nil {
t.Fatalf("error creating verification code: %v", err)
}

Expand All @@ -257,11 +278,11 @@ func TestIssueToken(t *testing.T) {
}

if tc.Error == "" {
if tok.TestType != tc.Verification.TestType {
t.Errorf("test type missmatch want: %v, got %v", tc.Verification.TestType, tok.TestType)
if tok.TestType != verification.TestType {
t.Errorf("test type missmatch want: %v, got %v", verification.TestType, tok.TestType)
}
if tok.FormatSymptomDate() != tc.Verification.FormatSymptomDate() {
t.Errorf("test date missmatch want: %v, got %v", tc.Verification.FormatSymptomDate(), tok.FormatSymptomDate())
if tok.FormatSymptomDate() != verification.FormatSymptomDate() {
t.Errorf("test date missmatch want: %v, got %v", verification.FormatSymptomDate(), tok.FormatSymptomDate())
}

got, err := db.FindTokenByID(tok.TokenID)
Expand All @@ -277,7 +298,7 @@ func TestIssueToken(t *testing.T) {
time.Sleep(tc.Delay)
}

subject := &Subject{TestType: tc.Verification.TestType, SymptomDate: tc.Verification.SymptomDate}
subject := &Subject{TestType: verification.TestType, SymptomDate: verification.SymptomDate}
if tc.Subject != nil {
subject = tc.Subject
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/database/vercode.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (v *VerificationCode) Validate(maxAge time.Duration) error {
return ErrTestTooOld
}
}
if !v.ExpiresAt.After(now) || !v.LongExpiresAt.After(now) {
if now.After(v.ExpiresAt) || now.After(v.LongExpiresAt) {
return ErrCodeAlreadyExpired
}
return nil
Expand Down