Skip to content

Commit

Permalink
sql: add context to db methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisSchinnerl committed May 21, 2024
1 parent 84becc5 commit 528d943
Show file tree
Hide file tree
Showing 13 changed files with 170 additions and 174 deletions.
34 changes: 11 additions & 23 deletions internal/sql/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,7 @@ func (lr *loggedRow) Scan(dest ...any) error {
return err
}

func (ls *loggedStmt) Exec(args ...any) (sql.Result, error) {
return ls.ExecContext(context.Background(), args...)
}

func (ls *loggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
func (ls *loggedStmt) Exec(ctx context.Context, args ...any) (sql.Result, error) {
start := time.Now()
result, err := ls.Stmt.ExecContext(ctx, args...)
if dur := time.Since(start); dur > ls.longQueryDuration {
Expand All @@ -77,11 +73,7 @@ func (ls *loggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result,
return result, err
}

func (ls *loggedStmt) Query(args ...any) (*sql.Rows, error) {
return ls.QueryContext(context.Background(), args...)
}

func (ls *loggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
func (ls *loggedStmt) Query(ctx context.Context, args ...any) (*sql.Rows, error) {
start := time.Now()
rows, err := ls.Stmt.QueryContext(ctx, args...)
if dur := time.Since(start); dur > ls.longQueryDuration {
Expand All @@ -90,11 +82,7 @@ func (ls *loggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows,
return rows, err
}

func (ls *loggedStmt) QueryRow(args ...any) *loggedRow {
return ls.QueryRowContext(context.Background(), args...)
}

func (ls *loggedStmt) QueryRowContext(ctx context.Context, args ...any) *loggedRow {
func (ls *loggedStmt) QueryRow(ctx context.Context, args ...any) *loggedRow {
start := time.Now()
row := ls.Stmt.QueryRowContext(ctx, args...)
if dur := time.Since(start); dur > ls.longQueryDuration {
Expand All @@ -105,9 +93,9 @@ func (ls *loggedStmt) QueryRowContext(ctx context.Context, args ...any) *loggedR

// Exec executes a query without returning any rows. The args are for
// any placeholder parameters in the query.
func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) {
func (lt *loggedTxn) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
start := time.Now()
result, err := lt.Tx.Exec(query, args...)
result, err := lt.Tx.ExecContext(ctx, query, args...)
if dur := time.Since(start); dur > lt.longQueryDuration {
lt.log.Warn("slow exec", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack"))
}
Expand All @@ -118,9 +106,9 @@ func (lt *loggedTxn) Exec(query string, args ...any) (sql.Result, error) {
// Multiple queries or executions may be run concurrently from the
// returned statement. The caller must call the statement's Close method
// when the statement is no longer needed.
func (lt *loggedTxn) Prepare(query string) (*loggedStmt, error) {
func (lt *loggedTxn) Prepare(ctx context.Context, query string) (*loggedStmt, error) {
start := time.Now()
stmt, err := lt.Tx.Prepare(query)
stmt, err := lt.Tx.PrepareContext(ctx, query)
if dur := time.Since(start); dur > lt.longQueryDuration {
lt.log.Warn("slow prepare", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack"))
} else if err != nil {
Expand All @@ -136,9 +124,9 @@ func (lt *loggedTxn) Prepare(query string) (*loggedStmt, error) {

// Query executes a query that returns rows, typically a SELECT. The
// args are for any placeholder parameters in the query.
func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) {
func (lt *loggedTxn) Query(ctx context.Context, query string, args ...any) (*loggedRows, error) {
start := time.Now()
rows, err := lt.Tx.Query(query, args...)
rows, err := lt.Tx.QueryContext(ctx, query, args...)
if dur := time.Since(start); dur > lt.longQueryDuration {
lt.log.Warn("slow query", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack"))
}
Expand All @@ -150,9 +138,9 @@ func (lt *loggedTxn) Query(query string, args ...any) (*loggedRows, error) {
// Row's Scan method is called. If the query selects no rows, the *Row's
// Scan will return ErrNoRows. Otherwise, the *Row's Scan scans the
// first selected row and discards the rest.
func (lt *loggedTxn) QueryRow(query string, args ...any) *loggedRow {
func (lt *loggedTxn) QueryRow(ctx context.Context, query string, args ...any) *loggedRow {
start := time.Now()
row := lt.Tx.QueryRow(query, args...)
row := lt.Tx.QueryRowContext(ctx, query, args...)
if dur := time.Since(start); dur > lt.longQueryDuration {
lt.log.Warn("slow query row", zap.String("query", query), zap.Duration("elapsed", dur), zap.Stack("stack"))
}
Expand Down
67 changes: 34 additions & 33 deletions internal/sql/migrations.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sql

import (
"context"
"embed"
"fmt"
"strings"
Expand All @@ -19,19 +20,19 @@ type (
// Migrator is an interface for defining database-specific helper methods
// required during migrations
Migrator interface {
ApplyMigration(func(tx Tx) (bool, error)) error
CreateMigrationTable() error
ApplyMigration(ctx context.Context, fn func(tx Tx) (bool, error)) error
CreateMigrationTable(ctx context.Context) error
DB() *DB
}

MainMigrator interface {
Migrator
MakeDirsForPath(tx Tx, path string) (uint, error)
MakeDirsForPath(ctx context.Context, tx Tx, path string) (uint, error)
}
)

var (
MainMigrations = func(m MainMigrator, migrationsFs embed.FS, log *zap.SugaredLogger) []Migration {
MainMigrations = func(ctx context.Context, m MainMigrator, migrationsFs embed.FS, log *zap.SugaredLogger) []Migration {
dbIdentifier := "main"
return []Migration{
{
Expand All @@ -41,13 +42,13 @@ var (
{
ID: "00001_object_metadata",
Migrate: func(tx Tx) error {
return performMigration(tx, migrationsFs, dbIdentifier, "00001_object_metadata", log)
return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00001_object_metadata", log)
},
},
{
ID: "00002_prune_slabs_trigger",
Migrate: func(tx Tx) error {
err := performMigration(tx, migrationsFs, dbIdentifier, "00002_prune_slabs_trigger", log)
err := performMigration(ctx, tx, migrationsFs, dbIdentifier, "00002_prune_slabs_trigger", log)
if utils.IsErr(err, ErrMySQLNoSuperPrivilege) {
log.Warn("migration 00002_prune_slabs_trigger requires the user to have the SUPER privilege to register triggers")
}
Expand All @@ -57,37 +58,37 @@ var (
{
ID: "00003_idx_objects_size",
Migrate: func(tx Tx) error {
return performMigration(tx, migrationsFs, dbIdentifier, "00003_idx_objects_size", log)
return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00003_idx_objects_size", log)
},
},
{
ID: "00004_prune_slabs_cascade",
Migrate: func(tx Tx) error {
return performMigration(tx, migrationsFs, dbIdentifier, "00004_prune_slabs_cascade", log)
return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00004_prune_slabs_cascade", log)
},
},
{
ID: "00005_zero_size_object_health",
Migrate: func(tx Tx) error {
return performMigration(tx, migrationsFs, dbIdentifier, "00005_zero_size_object_health", log)
return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00005_zero_size_object_health", log)
},
},
{
ID: "00006_idx_objects_created_at",
Migrate: func(tx Tx) error {
return performMigration(tx, migrationsFs, dbIdentifier, "00006_idx_objects_created_at", log)
return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00006_idx_objects_created_at", log)
},
},
{
ID: "00007_host_checks",
Migrate: func(tx Tx) error {
return performMigration(tx, migrationsFs, dbIdentifier, "00007_host_checks", log)
return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00007_host_checks", log)
},
},
{
ID: "00008_directories",
Migrate: func(tx Tx) error {
if err := performMigration(tx, migrationsFs, dbIdentifier, "00008_directories_1", log); err != nil {
if err := performMigration(ctx, tx, migrationsFs, dbIdentifier, "00008_directories_1", log); err != nil {
return fmt.Errorf("failed to migrate: %v", err)
}
// helper type
Expand All @@ -104,7 +105,7 @@ var (
log.Infof("processed %v objects", offset)
}
var objBatch []obj
rows, err := tx.Query("SELECT id, object_id FROM objects ORDER BY id LIMIT ? OFFSET ?", batchSize, offset)
rows, err := tx.Query(ctx, "SELECT id, object_id FROM objects ORDER BY id LIMIT ? OFFSET ?", batchSize, offset)
if err != nil {
return fmt.Errorf("failed to fetch objects: %v", err)
}
Expand Down Expand Up @@ -135,12 +136,12 @@ var (
processedDirs[dir] = struct{}{}

// process
dirID, err := m.MakeDirsForPath(tx, obj.ObjectID)
dirID, err := m.MakeDirsForPath(ctx, tx, obj.ObjectID)
if err != nil {
return fmt.Errorf("failed to create directory %s: %w", obj.ObjectID, err)
}

if _, err := tx.Exec(`
if _, err := tx.Exec(context.Background(), `
UPDATE objects
SET db_directory_id = ?
WHERE object_id LIKE ? AND
Expand All @@ -156,15 +157,15 @@ var (
}
}
log.Info("post-migration directory creation complete")
if err := performMigration(tx, migrationsFs, dbIdentifier, "00008_directories_2", log); err != nil {
if err := performMigration(ctx, tx, migrationsFs, dbIdentifier, "00008_directories_2", log); err != nil {
return fmt.Errorf("failed to migrate: %v", err)
}
return nil
},
},
}
}
MetricsMigrations = func(migrationsFs embed.FS, log *zap.SugaredLogger) []Migration {
MetricsMigrations = func(ctx context.Context, migrationsFs embed.FS, log *zap.SugaredLogger) []Migration {
dbIdentifier := "metrics"
return []Migration{
{
Expand All @@ -174,35 +175,35 @@ var (
{
ID: "00001_idx_contracts_fcid_timestamp",
Migrate: func(tx Tx) error {
return performMigration(tx, migrationsFs, dbIdentifier, "00001_idx_contracts_fcid_timestamp", log)
return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00001_idx_contracts_fcid_timestamp", log)
},
},
}
}
)

func PerformMigrations(m Migrator, fs embed.FS, identifier string, migrations []Migration) error {
func PerformMigrations(ctx context.Context, m Migrator, fs embed.FS, identifier string, migrations []Migration) error {
// try to create migrations table
err := m.CreateMigrationTable()
err := m.CreateMigrationTable(ctx)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}

// check if the migrations table is empty
var isEmpty bool
if err := m.DB().QueryRow("SELECT COUNT(*) = 0 FROM migrations").Scan(&isEmpty); err != nil {
if err := m.DB().QueryRow(ctx, "SELECT COUNT(*) = 0 FROM migrations").Scan(&isEmpty); err != nil {
return fmt.Errorf("failed to count rows in migrations table: %w", err)
} else if isEmpty {
// table is empty, init schema
return initSchema(m.DB(), fs, identifier, migrations)
return initSchema(ctx, m.DB(), fs, identifier, migrations)
}

// apply missing migrations
for _, migration := range migrations {
if err := m.ApplyMigration(func(tx Tx) (bool, error) {
if err := m.ApplyMigration(ctx, func(tx Tx) (bool, error) {
// check if migration was already applied
var applied bool
if err := tx.QueryRow("SELECT EXISTS (SELECT 1 FROM migrations WHERE id = ?)", migration.ID).Scan(&applied); err != nil {
if err := tx.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM migrations WHERE id = ?)", migration.ID).Scan(&applied); err != nil {
return false, fmt.Errorf("failed to check if migration '%s' was already applied: %w", migration.ID, err)
} else if applied {
return false, nil
Expand All @@ -212,7 +213,7 @@ func PerformMigrations(m Migrator, fs embed.FS, identifier string, migrations []
return false, fmt.Errorf("migration '%s' failed: %w", migration.ID, err)
}
// insert migration
if _, err := tx.Exec("INSERT INTO migrations (id) VALUES (?)", migration.ID); err != nil {
if _, err := tx.Exec(ctx, "INSERT INTO migrations (id) VALUES (?)", migration.ID); err != nil {
return false, fmt.Errorf("failed to insert migration '%s': %w", migration.ID, err)
}
return true, nil
Expand All @@ -223,7 +224,7 @@ func PerformMigrations(m Migrator, fs embed.FS, identifier string, migrations []
return nil
}

func execSQLFile(tx Tx, fs embed.FS, folder, filename string) error {
func execSQLFile(ctx context.Context, tx Tx, fs embed.FS, folder, filename string) error {
path := fmt.Sprintf("migrations/%s/%s.sql", folder, filename)

// read file
Expand All @@ -233,31 +234,31 @@ func execSQLFile(tx Tx, fs embed.FS, folder, filename string) error {
}

// execute it
if _, err := tx.Exec(string(file)); err != nil {
if _, err := tx.Exec(ctx, string(file)); err != nil {
return fmt.Errorf("failed to execute %s: %w", path, err)
}
return nil
}

func initSchema(db *DB, fs embed.FS, identifier string, migrations []Migration) error {
return db.Transaction(func(tx Tx) error {
func initSchema(ctx context.Context, db *DB, fs embed.FS, identifier string, migrations []Migration) error {
return db.Transaction(ctx, func(tx Tx) error {
// init schema
if err := execSQLFile(tx, fs, identifier, "schema"); err != nil {
if err := execSQLFile(ctx, tx, fs, identifier, "schema"); err != nil {
return fmt.Errorf("failed to execute schema: %w", err)
}
// insert migration ids
for _, migration := range migrations {
if _, err := tx.Exec("INSERT INTO migrations (id) VALUES (?)", migration.ID); err != nil {
if _, err := tx.Exec(ctx, "INSERT INTO migrations (id) VALUES (?)", migration.ID); err != nil {
return fmt.Errorf("failed to insert migration '%s': %w", migration.ID, err)
}
}
return nil
})
}

func performMigration(tx Tx, fs embed.FS, kind, migration string, logger *zap.SugaredLogger) error {
func performMigration(ctx context.Context, tx Tx, fs embed.FS, kind, migration string, logger *zap.SugaredLogger) error {
logger.Infof("performing %s migration '%s'", kind, migration)
if err := execSQLFile(tx, fs, kind, fmt.Sprintf("migration_%s", migration)); err != nil {
if err := execSQLFile(ctx, tx, fs, kind, fmt.Sprintf("migration_%s", migration)); err != nil {
return err
}
logger.Infof("migration '%s' complete", migration)
Expand Down
Loading

0 comments on commit 528d943

Please sign in to comment.