Skip to content

Commit

Permalink
chore: additional migration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kian99 committed Dec 19, 2024
1 parent a498392 commit a09ff0c
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 15 deletions.
35 changes: 20 additions & 15 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package db
import (
"context"
"database/sql"
"embed"
stderr "errors"
"fmt"
"path"
Expand All @@ -26,7 +27,7 @@ import (
// Use a custom table name so that we don't run into collisions when OpenFGA or other tools
// are using the same DB as JIMM in our Docker Compose setup.
const (
migrationTableName = "jimm_schema_migrations"
MigrationTableName = "jimm_schema_migrations"
)

// A Database provides access to the database model. A Database instance
Expand Down Expand Up @@ -87,46 +88,50 @@ func (d *Database) Migrate(ctx context.Context) error {
if d == nil || d.DB == nil {
return errors.E(op, errors.CodeServerConfiguration, "database not configured")
}
db := d.DB.WithContext(ctx)

return d.migrateFromSource(ctx, dbmodel.SQL, path.Join("sql", d.DB.Name()))
}

func (d *Database) migrateFromSource(ctx context.Context, fs embed.FS, sqlPath string) error {
sqlDir, err := iofs.New(fs, sqlPath)
if err != nil {
return fmt.Errorf("unable to create new sql filesys: %w", err)
}

db := d.DB.WithContext(ctx)
sqlDB, err := db.DB()
if err != nil {
return errors.E(op, fmt.Errorf("unable to obtain raw DB: %w", err))
return fmt.Errorf("failed to obtain raw DB: %w", err)
}
conn, err := sqlDB.Conn(ctx)
if err != nil {
return fmt.Errorf("failed to obtain DB conn: %w", err)
}

sqlDir, err := iofs.New(dbmodel.SQL, path.Join("sql", db.Name()))
if err != nil {
return errors.E(op, fmt.Errorf("unable to create new sql filesys: %w", err))
}

driver, err := postgres.WithConnection(ctx, conn, &postgres.Config{MigrationsTable: MigrationTableName})
if err != nil {
return errors.E(op, fmt.Errorf("unable to create new driver instance: %w", err))
return fmt.Errorf("unable to create new driver instance: %w", err)
}

// DB name is left blank because it is contained in the driver/DB connection.
m, err := migrate.NewWithInstance("iofs", sqlDir, "", driver)
if err != nil {
return errors.E(op, fmt.Errorf("unable to create new migrator: %w", err))
return fmt.Errorf("unable to create new migrator: %w", err)
}
defer m.Close()

// Setup custom logger for consistent output.
logger := migrationLogger{logger: zapctx.Logger(ctx), verbose: false}
logger := migrationLogger{logger: zapctx.Logger(ctx), verbose: true}
m.Log = logger

if err := d.handleDeprecatedMigrations(ctx, m); err != nil {
return errors.E(op, fmt.Errorf("failed to handle deprecated migrations: %w", err))
return fmt.Errorf("failed to handle deprecated migrations: %w", err)
}

v, dirty, err := m.Version()
if err != nil {
if !stderr.Is(err, migrate.ErrNilVersion) {
return errors.E(op, fmt.Errorf("failed to get db version: %w", err))
return fmt.Errorf("failed to get db version: %w", err)
}
}

Expand All @@ -135,13 +140,13 @@ func (d *Database) Migrate(ctx context.Context) error {
workingVersion := int(v) - 1
zapctx.Info(ctx, "dirty database, reverting version", zap.Int("version", workingVersion))
if err := m.Force(workingVersion); err != nil {
return errors.E(op, fmt.Errorf("failed to fix dirty db version: %w", err))
return fmt.Errorf("failed to fix dirty db version: %w", err)
}
}

if err := m.Up(); err != nil {
if !stderr.Is(err, migrate.ErrNoChange) {
return errors.E(op, fmt.Errorf("failed to migrate db: %w", err))
return fmt.Errorf("failed to migrate db: %w", err)
}
}

Expand Down
52 changes: 52 additions & 0 deletions internal/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,42 @@ package db_test

import (
"context"
"embed"
"testing"

qt "github.com/frankban/quicktest"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source/iofs"

"github.com/canonical/jimm/v3/internal/db"
"github.com/canonical/jimm/v3/internal/dbmodel"
"github.com/canonical/jimm/v3/internal/errors"
)

//go:embed testdata/invalidmigrations/*.sql
var invalidSQL embed.FS

//go:embed testdata/validmigrations/*.sql
var validSQL embed.FS

func newTestMigrator(c *qt.C, d *db.Database, fs embed.FS, sqlPath string) *migrate.Migrate {
sqlDir, err := iofs.New(fs, sqlPath)
c.Assert(err, qt.IsNil)

sqlDB, err := d.DB.DB()
c.Assert(err, qt.IsNil)

driver, err := postgres.WithInstance(sqlDB, &postgres.Config{MigrationsTable: db.MigrationTableName})
c.Assert(err, qt.IsNil)

m, err := migrate.NewWithInstance("iofs", sqlDir, "", driver)
c.Assert(err, qt.IsNil)
c.Cleanup(func() { m.Close() })

return m
}

// dbSuite contains a suite of database tests that are run against
// different database engines.
type dbSuite struct {
Expand All @@ -30,6 +57,31 @@ func (s *dbSuite) TestMigrate(c *qt.C) {
c.Assert(err, qt.IsNil)
}

// TestFailedMigration verifies a failed migration will cause a dirty migration
// that when fixed should automatically work.
func (s *dbSuite) TestFailedMigration(c *qt.C) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

err := s.Database.MigrateFromSource(ctx, invalidSQL, "testdata/invalidmigrations")
c.Assert(err, qt.Not(qt.IsNil))

m := newTestMigrator(c, s.Database, invalidSQL, "testdata/invalidmigrations")
v, dirty, err := m.Version()
c.Assert(err, qt.IsNil)
c.Assert(dirty, qt.IsTrue)
c.Assert(v, qt.Equals, uint(2))

err = s.Database.MigrateFromSource(ctx, validSQL, "testdata/validmigrations")
c.Assert(err, qt.IsNil)

m = newTestMigrator(c, s.Database, validSQL, "testdata/validmigrations")
v, dirty, err = m.Version()
c.Assert(err, qt.IsNil)
c.Assert(dirty, qt.IsFalse)
c.Assert(v, qt.Equals, uint(3))
}

func TestMigrateUnconfiguredDatabase(t *testing.T) {
c := qt.New(t)

Expand Down
9 changes: 9 additions & 0 deletions internal/db/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

package db

import (
"context"
"embed"
)

var (
JwksKind = jwksKind
JwksPublicKeyTag = jwksPublicKeyTag
Expand All @@ -12,3 +17,7 @@ var (
OAuthSessionStoreSecretTag = oauthSessionStoreSecretTag
NewUUID = &newUUID
)

func (d *Database) MigrateFromSource(ctx context.Context, fs embed.FS, sqlPath string) error {
return d.migrateFromSource(ctx, fs, sqlPath)
}
6 changes: 6 additions & 0 deletions internal/db/testdata/invalidmigrations/001_initial.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- 1_1.sql initialises an empty database.

CREATE TABLE IF NOT EXISTS test (
id BIGSERIAL PRIMARY KEY,
time TIMESTAMP WITH TIME ZONE
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- 1_2.sql is a migration that adds an invalid table.

CREATE TABLE IF NOT EXISTS invalid (
id BIGSERIAL PRIMARY KEY,
time INVALIDTYPE
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- 1_3.sql is a migration that adds an controller table.

CREATE TABLE IF NOT EXISTS controller (
id BIGSERIAL PRIMARY KEY,
name TEXT
);
6 changes: 6 additions & 0 deletions internal/db/testdata/validmigrations/001_initial.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- 1_1.sql initialises an empty database.

CREATE TABLE IF NOT EXISTS test (
id BIGSERIAL PRIMARY KEY,
time TIMESTAMP WITH TIME ZONE
);
5 changes: 5 additions & 0 deletions internal/db/testdata/validmigrations/002_add_table.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- 1_2.sql is a migration that adds a valid table.

CREATE TABLE IF NOT EXISTS valid (
id BIGSERIAL PRIMARY KEY
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- 1_3.sql is a migration that adds an controller table.

CREATE TABLE IF NOT EXISTS controller (
id BIGSERIAL PRIMARY KEY,
name TEXT
);
18 changes: 18 additions & 0 deletions internal/dbmodel/sql/migrations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Notes

Migrations are applied using [golang-migrate](https://github.com/golang-migrate/migrate).
Previously migrations were applied using a home-grown solution and a `versions` table.
The switch to golang-migrate was done to simplify our code.

To cater for existing deployments, we handle the case that the `versions` table still
exists and "force" the new migration tool to align with the old.

No "down" migrations are used currently. We aim to work with the philosophy that application
changes should be done such that we deprecate the use of any tables/columns, deploy these changes
and then later create a migration to make permanent changes to the DB. Ideally always moving
migrations forwards and never backwards.

By default, golang-migrate does not run migrations in a transactions.
**But**, the [postgres](https://github.com/golang-migrate/migrate/blob/master/database/postgres/README.md#multi-statement-mode) driver has slightly unique behavior - "running multiple SQL statements in one Exec executes them inside a transaction".
So each migration file is in fact run in a transaction when using PostgreSQL. To be more explicit,
one can wrap the migration file with BEGIN/COMMIT instructions.

0 comments on commit a09ff0c

Please sign in to comment.