diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index bcfa6d4c9..52dd2c276 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -240,18 +240,31 @@ func TestAutoMigrator_CreateSQLMigrations(t *testing.T) { ctx := context.Background() m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*NewTable)(nil))) - migrations, err := m.CreateSQLMigrations(ctx) - require.NoError(t, err, "should create migrations successfully") + t.Run("basic", func(t *testing.T) { + migrations, err := m.CreateSQLMigrations(ctx) + require.NoError(t, err, "should create migrations successfully") + + require.Len(t, migrations, 2, "expected up/down migration pair") + require.DirExists(t, migrationsDir) + checkMigrationFileContains(t, ".up.sql", "CREATE TABLE") + checkMigrationFileContains(t, ".down.sql", "DROP TABLE") + }) + + t.Run("transactional", func(t *testing.T) { + migrations, err := m.CreateTxSQLMigrations(ctx) + require.NoError(t, err, "should create migrations successfully") + + require.Len(t, migrations, 2, "expected up/down migration pair") + require.DirExists(t, migrationsDir) + checkMigrationFileContains(t, "tx.up.sql", "CREATE TABLE", "SET statement_timeout = 0") + checkMigrationFileContains(t, "tx.down.sql", "DROP TABLE", "SET statement_timeout = 0") + }) - require.Len(t, migrations, 2, "expected up/down migration pair") - require.DirExists(t, migrationsDir) - checkMigrationFileContains(t, ".up.sql", "CREATE TABLE") - checkMigrationFileContains(t, ".down.sql", "DROP TABLE") }) } // checkMigrationFileContains expected SQL snippet. -func checkMigrationFileContains(t *testing.T, fileSuffix string, content string) { +func checkMigrationFileContains(t *testing.T, fileSuffix string, snippets ...string) { t.Helper() files, err := os.ReadDir(migrationsDir) @@ -261,7 +274,9 @@ func checkMigrationFileContains(t *testing.T, fileSuffix string, content string) if strings.HasSuffix(f.Name(), fileSuffix) { b, err := os.ReadFile(filepath.Join(migrationsDir, f.Name())) require.NoError(t, err) - require.Containsf(t, string(b), content, "expected %s file to contain string", f.Name()) + for _, content := range snippets { + require.Containsf(t, string(b), content, "expected %s file to contain string", f.Name()) + } return } } diff --git a/migrate/auto.go b/migrate/auto.go index 44f595ba5..10ab3f2a7 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -25,6 +25,8 @@ func WithModel(models ...interface{}) AutoMigratorOption { } // WithExcludeTable tells the AutoMigrator to ignore a table in the database. +// This prevents AutoMigrator from dropping tables which may exist in the schema +// but which are not used by the application. func WithExcludeTable(tables ...string) AutoMigratorOption { return func(m *AutoMigrator) { m.excludeTables = append(m.excludeTables, tables...) @@ -55,6 +57,7 @@ func WithMarkAppliedOnSuccessAuto(enabled bool) AutoMigratorOption { } } +// WithMigrationsDirectoryAuto overrides the default directory for migration files. func WithMigrationsDirectoryAuto(directory string) AutoMigratorOption { return func(m *AutoMigrator) { m.migrationsOpts = append(m.migrationsOpts, WithMigrationsDirectory(directory)) @@ -146,9 +149,9 @@ func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) { // Migrate writes required changes to a new migration file and runs the migration. // This will create and entry in the migrations table, making it possible to revert -// the changes with Migrator.Rollback(). +// the changes with Migrator.Rollback(). MigrationOptions are passed on to Migrator.Migrate(). func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { - migrations, _, err := am.createSQLMigrations(ctx) + migrations, _, err := am.createSQLMigrations(ctx, false) if err != nil { return nil, fmt.Errorf("auto migrate: %w", err) } @@ -165,12 +168,21 @@ func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) (* return group, nil } +// CreateSQLMigration writes required changes to a new migration file. +// Use migrate.Migrator to apply the generated migrations. func (am *AutoMigrator) CreateSQLMigrations(ctx context.Context) ([]*MigrationFile, error) { - _, files, err := am.createSQLMigrations(ctx) + _, files, err := am.createSQLMigrations(ctx, true) return files, err } -func (am *AutoMigrator) createSQLMigrations(ctx context.Context) (*Migrations, []*MigrationFile, error) { +// CreateTxSQLMigration writes required changes to a new migration file making sure they will be executed +// in a transaction when applied. Use migrate.Migrator to apply the generated migrations. +func (am *AutoMigrator) CreateTxSQLMigrations(ctx context.Context) ([]*MigrationFile, error) { + _, files, err := am.createSQLMigrations(ctx, false) + return files, err +} + +func (am *AutoMigrator) createSQLMigrations(ctx context.Context, transactional bool) (*Migrations, []*MigrationFile, error) { changes, err := am.plan(ctx) if err != nil { return nil, nil, fmt.Errorf("create sql migrations: %w", err) @@ -185,20 +197,30 @@ func (am *AutoMigrator) createSQLMigrations(ctx context.Context) (*Migrations, [ Comment: "Changes detected by bun.migrate.AutoMigrator", }) - up, err := am.createSQL(ctx, migrations, name+".up.sql", changes) + // Append .tx.up.sql or .up.sql to migration name, dependin if it should be transactional. + fname := func(direction string) string { + return name + map[bool]string{true: ".tx.", false: "."}[transactional] + direction + ".sql" + } + + up, err := am.createSQL(ctx, migrations, fname("up"), changes, transactional) if err != nil { return nil, nil, fmt.Errorf("create sql migration up: %w", err) } - down, err := am.createSQL(ctx, migrations, name+".down.sql", changes.GetReverse()) + down, err := am.createSQL(ctx, migrations, fname("down"), changes.GetReverse(), transactional) if err != nil { return nil, nil, fmt.Errorf("create sql migration down: %w", err) } return migrations, []*MigrationFile{up, down}, nil } -func (am *AutoMigrator) createSQL(_ context.Context, migrations *Migrations, fname string, changes *changeset) (*MigrationFile, error) { +func (am *AutoMigrator) createSQL(_ context.Context, migrations *Migrations, fname string, changes *changeset, transactional bool) (*MigrationFile, error) { var buf bytes.Buffer + + if transactional { + buf.WriteString("SET statement_timeout = 0;") + } + if err := changes.WriteTo(&buf, am.dbMigrator); err != nil { return nil, err }