Skip to content

Commit

Permalink
feat: create transactional migration files
Browse files Browse the repository at this point in the history
  • Loading branch information
bevzzz committed Nov 8, 2024
1 parent da0d8e3 commit c3320f6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
31 changes: 23 additions & 8 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,31 @@ func TestAutoMigrator_CreateSQLMigrations(t *testing.T) {
ctx := context.Background()
m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*NewTable)(nil)))

migrations, err := m.CreateSQLMigrations(ctx)
require.NoError(t, err, "should create migrations successfully")
t.Run("basic", func(t *testing.T) {
migrations, err := m.CreateSQLMigrations(ctx)
require.NoError(t, err, "should create migrations successfully")

require.Len(t, migrations, 2, "expected up/down migration pair")
require.DirExists(t, migrationsDir)
checkMigrationFileContains(t, ".up.sql", "CREATE TABLE")
checkMigrationFileContains(t, ".down.sql", "DROP TABLE")
})

t.Run("transactional", func(t *testing.T) {
migrations, err := m.CreateTxSQLMigrations(ctx)
require.NoError(t, err, "should create migrations successfully")

require.Len(t, migrations, 2, "expected up/down migration pair")
require.DirExists(t, migrationsDir)
checkMigrationFileContains(t, "tx.up.sql", "CREATE TABLE", "SET statement_timeout = 0")
checkMigrationFileContains(t, "tx.down.sql", "DROP TABLE", "SET statement_timeout = 0")
})

require.Len(t, migrations, 2, "expected up/down migration pair")
require.DirExists(t, migrationsDir)
checkMigrationFileContains(t, ".up.sql", "CREATE TABLE")
checkMigrationFileContains(t, ".down.sql", "DROP TABLE")
})
}

// checkMigrationFileContains expected SQL snippet.
func checkMigrationFileContains(t *testing.T, fileSuffix string, content string) {
func checkMigrationFileContains(t *testing.T, fileSuffix string, snippets ...string) {
t.Helper()

files, err := os.ReadDir(migrationsDir)
Expand All @@ -261,7 +274,9 @@ func checkMigrationFileContains(t *testing.T, fileSuffix string, content string)
if strings.HasSuffix(f.Name(), fileSuffix) {
b, err := os.ReadFile(filepath.Join(migrationsDir, f.Name()))
require.NoError(t, err)
require.Containsf(t, string(b), content, "expected %s file to contain string", f.Name())
for _, content := range snippets {
require.Containsf(t, string(b), content, "expected %s file to contain string", f.Name())
}
return
}
}
Expand Down
36 changes: 29 additions & 7 deletions migrate/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func WithModel(models ...interface{}) AutoMigratorOption {
}

// WithExcludeTable tells the AutoMigrator to ignore a table in the database.
// This prevents AutoMigrator from dropping tables which may exist in the schema
// but which are not used by the application.
func WithExcludeTable(tables ...string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.excludeTables = append(m.excludeTables, tables...)
Expand Down Expand Up @@ -55,6 +57,7 @@ func WithMarkAppliedOnSuccessAuto(enabled bool) AutoMigratorOption {
}
}

// WithMigrationsDirectoryAuto overrides the default directory for migration files.
func WithMigrationsDirectoryAuto(directory string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.migrationsOpts = append(m.migrationsOpts, WithMigrationsDirectory(directory))
Expand Down Expand Up @@ -146,9 +149,9 @@ func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) {

// Migrate writes required changes to a new migration file and runs the migration.
// This will create and entry in the migrations table, making it possible to revert
// the changes with Migrator.Rollback().
// the changes with Migrator.Rollback(). MigrationOptions are passed on to Migrator.Migrate().
func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
migrations, _, err := am.createSQLMigrations(ctx)
migrations, _, err := am.createSQLMigrations(ctx, false)
if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err)
}
Expand All @@ -165,12 +168,21 @@ func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) (*
return group, nil
}

// CreateSQLMigration writes required changes to a new migration file.
// Use migrate.Migrator to apply the generated migrations.
func (am *AutoMigrator) CreateSQLMigrations(ctx context.Context) ([]*MigrationFile, error) {
_, files, err := am.createSQLMigrations(ctx)
_, files, err := am.createSQLMigrations(ctx, true)
return files, err
}

func (am *AutoMigrator) createSQLMigrations(ctx context.Context) (*Migrations, []*MigrationFile, error) {
// CreateTxSQLMigration writes required changes to a new migration file making sure they will be executed
// in a transaction when applied. Use migrate.Migrator to apply the generated migrations.
func (am *AutoMigrator) CreateTxSQLMigrations(ctx context.Context) ([]*MigrationFile, error) {
_, files, err := am.createSQLMigrations(ctx, false)
return files, err
}

func (am *AutoMigrator) createSQLMigrations(ctx context.Context, transactional bool) (*Migrations, []*MigrationFile, error) {
changes, err := am.plan(ctx)
if err != nil {
return nil, nil, fmt.Errorf("create sql migrations: %w", err)
Expand All @@ -185,20 +197,30 @@ func (am *AutoMigrator) createSQLMigrations(ctx context.Context) (*Migrations, [
Comment: "Changes detected by bun.migrate.AutoMigrator",
})

up, err := am.createSQL(ctx, migrations, name+".up.sql", changes)
// Append .tx.up.sql or .up.sql to migration name, dependin if it should be transactional.
fname := func(direction string) string {
return name + map[bool]string{true: ".tx.", false: "."}[transactional] + direction + ".sql"
}

up, err := am.createSQL(ctx, migrations, fname("up"), changes, transactional)
if err != nil {
return nil, nil, fmt.Errorf("create sql migration up: %w", err)
}

down, err := am.createSQL(ctx, migrations, name+".down.sql", changes.GetReverse())
down, err := am.createSQL(ctx, migrations, fname("down"), changes.GetReverse(), transactional)
if err != nil {
return nil, nil, fmt.Errorf("create sql migration down: %w", err)
}
return migrations, []*MigrationFile{up, down}, nil
}

func (am *AutoMigrator) createSQL(_ context.Context, migrations *Migrations, fname string, changes *changeset) (*MigrationFile, error) {
func (am *AutoMigrator) createSQL(_ context.Context, migrations *Migrations, fname string, changes *changeset, transactional bool) (*MigrationFile, error) {
var buf bytes.Buffer

if transactional {
buf.WriteString("SET statement_timeout = 0;")
}

if err := changes.WriteTo(&buf, am.dbMigrator); err != nil {
return nil, err
}
Expand Down

0 comments on commit c3320f6

Please sign in to comment.