From 886d0a5b18aba272f1c86af2a2cf68ce4c8879f2 Mon Sep 17 00:00:00 2001 From: bevzzz Date: Fri, 17 Nov 2023 16:42:00 +0100 Subject: [PATCH] feat: detect renamed columns --- dialect/pgdialect/alter_table.go | 8 ++ internal/dbtest/migrate_test.go | 239 ++++++++++++++++++++++--------- migrate/auto.go | 88 +++++++++++- migrate/sqlschema/migrator.go | 1 + 4 files changed, 259 insertions(+), 77 deletions(-) diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go index f0b1c948b..71b090e46 100644 --- a/dialect/pgdialect/alter_table.go +++ b/dialect/pgdialect/alter_table.go @@ -57,3 +57,11 @@ func (m *Migrator) RenameConstraint(ctx context.Context, schema, table, oldName, ) return m.exec(ctx, q) } + +func (m *Migrator) RenameColumn(ctx context.Context, schema, table, oldName, newName string) error { + q := m.db.NewRaw( + "ALTER TABLE ?.? RENAME COLUMN ? TO ?", + bun.Ident(schema), bun.Ident(table), bun.Ident(oldName), bun.Ident(newName), + ) + return m.exec(ctx, q) +} diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 79c53e1cc..0a2d60e15 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -164,6 +164,37 @@ func testMigrateUpError(t *testing.T, db *bun.DB) { require.Equal(t, []string{"down2", "down1"}, history) } +// newAutoMigrator creates an AutoMigrator configured to use test migratins/locks tables. +// If the dialect doesn't support schema inspections or migrations, the test will fail with the corresponding error. +func newAutoMigrator(tb testing.TB, db *bun.DB, opts ...migrate.AutoMigratorOption) *migrate.AutoMigrator { + tb.Helper() + + opts = append(opts, + migrate.WithTableNameAuto(migrationsTable), + migrate.WithLocksTableNameAuto(migrationLocksTable), + ) + + m, err := migrate.NewAutoMigrator(db, opts...) + require.NoError(tb, err) + return m +} + +// inspectDbOrSkip returns a function to inspect the current state of the database. +// It calls tb.Skip() if the current dialect doesn't support database inpection and +// fails the test if the inspector cannot successfully retrieve database state. +func inspectDbOrSkip(tb testing.TB, db *bun.DB) func(context.Context) sqlschema.State { + tb.Helper() + inspector, err := sqlschema.NewInspector(db) + if err != nil { + tb.Skip(err) + } + return func(ctx context.Context) sqlschema.State { + state, err := inspector.Inspect(ctx) + require.NoError(tb, err) + return state + } +} + func TestAutoMigrator_Run(t *testing.T) { tests := []struct { @@ -174,6 +205,8 @@ func TestAutoMigrator_Run(t *testing.T) { {testAlterForeignKeys}, {testCustomFKNameFunc}, {testForceRenameFK}, + {testRenamedColumns}, + {testRenameColumnRenamesFK}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -198,28 +231,19 @@ func testRenameTable(t *testing.T, db *bun.DB) { // Arrange ctx := context.Background() - dbInspector, err := sqlschema.NewInspector(db) - if err != nil { - t.Skip(err) - } + inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*initial)(nil)) mustDropTableOnCleanup(t, ctx, db, (*changed)(nil)) - - m, err := migrate.NewAutoMigrator(db, - migrate.WithTableNameAuto(migrationsTable), - migrate.WithLocksTableNameAuto(migrationLocksTable), - migrate.WithModel((*changed)(nil))) - require.NoError(t, err) + m := newAutoMigrator(t, db, migrate.WithModel((*changed)(nil))) // Act - err = m.Run(ctx) + err := m.Run(ctx) require.NoError(t, err) // Assert - state, err := dbInspector.Inspect(ctx) - require.NoError(t, err) - + state := inspect(ctx) tables := state.Tables + require.Len(t, tables, 1) require.Equal(t, "changed", tables[0].Name) } @@ -238,28 +262,19 @@ func testCreateDropTable(t *testing.T, db *bun.DB) { // Arrange ctx := context.Background() - dbInspector, err := sqlschema.NewInspector(db) - if err != nil { - t.Skip(err) - } + inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*DropMe)(nil)) mustDropTableOnCleanup(t, ctx, db, (*CreateMe)(nil)) - - m, err := migrate.NewAutoMigrator(db, - migrate.WithTableNameAuto(migrationsTable), - migrate.WithLocksTableNameAuto(migrationLocksTable), - migrate.WithModel((*CreateMe)(nil))) - require.NoError(t, err) + m := newAutoMigrator(t, db, migrate.WithModel((*CreateMe)(nil))) // Act - err = m.Run(ctx) + err := m.Run(ctx) require.NoError(t, err) // Assert - state, err := dbInspector.Inspect(ctx) - require.NoError(t, err) - + state := inspect(ctx) tables := state.Tables + require.Len(t, tables, 1) require.Equal(t, "createme", tables[0].Name) } @@ -301,10 +316,7 @@ func testAlterForeignKeys(t *testing.T, db *bun.DB) { // Arrange ctx := context.Background() - dbInspector, err := sqlschema.NewInspector(db) - if err != nil { - t.Skip(err) - } + inspect := inspectDbOrSkip(t, db) db.RegisterModel((*ThingsToOwner)(nil)) mustCreateTableWithFKs(t, ctx, db, @@ -313,23 +325,18 @@ func testAlterForeignKeys(t *testing.T, db *bun.DB) { ) mustDropTableOnCleanup(t, ctx, db, (*ThingsToOwner)(nil)) - m, err := migrate.NewAutoMigrator(db, - migrate.WithTableNameAuto(migrationsTable), - migrate.WithLocksTableNameAuto(migrationLocksTable), - migrate.WithModel((*ThingCommon)(nil)), - migrate.WithModel((*OwnerCommon)(nil)), - migrate.WithModel((*ThingsToOwner)(nil)), - ) - require.NoError(t, err) + m := newAutoMigrator(t, db, migrate.WithModel( + (*ThingCommon)(nil), + (*OwnerCommon)(nil), + (*ThingsToOwner)(nil), + )) // Act - err = m.Run(ctx) + err := m.Run(ctx) require.NoError(t, err) // Assert - state, err := dbInspector.Inspect(ctx) - require.NoError(t, err) - + state := inspect(ctx) defaultSchema := db.Dialect().DefaultSchema() // Crated 2 new constraints @@ -377,10 +384,7 @@ func testForceRenameFK(t *testing.T, db *bun.DB) { } ctx := context.Background() - dbInspector, err := sqlschema.NewInspector(db) - if err != nil { - t.Skip(err) - } + inspect := inspectDbOrSkip(t, db) mustCreateTableWithFKs(t, ctx, db, (*Owner)(nil), @@ -388,31 +392,27 @@ func testForceRenameFK(t *testing.T, db *bun.DB) { ) mustDropTableOnCleanup(t, ctx, db, (*Person)(nil)) - m, err := migrate.NewAutoMigrator(db, - migrate.WithTableNameAuto(migrationsTable), - migrate.WithLocksTableNameAuto(migrationLocksTable), + m := newAutoMigrator(t, db, migrate.WithModel( (*Person)(nil), (*PersonalThing)(nil), ), + migrate.WithRenameFK(true), migrate.WithFKNameFunc(func(fk sqlschema.FK) string { return strings.Join([]string{ fk.From.Table, fk.To.Table, "fkey", }, "_") }), - migrate.WithRenameFK(true), ) - require.NoError(t, err) // Act - err = m.Run(ctx) + err := m.Run(ctx) require.NoError(t, err) // Assert - state, err := dbInspector.Inspect(ctx) - require.NoError(t, err) - + state := inspect(ctx) schema := db.Dialect().DefaultSchema() + wantName, ok := state.FKs[sqlschema.FK{ From: sqlschema.C(schema, "things", "owner_id"), To: sqlschema.C(schema, "people", "id"), @@ -445,33 +445,27 @@ func testCustomFKNameFunc(t *testing.T, db *bun.DB) { } ctx := context.Background() - dbInspector, err := sqlschema.NewInspector(db) - if err != nil { - t.Skip(err) - } + inspect := inspectDbOrSkip(t, db) mustCreateTableWithFKs(t, ctx, db, (*Table)(nil), (*Column)(nil), ) - m, err := migrate.NewAutoMigrator(db, - migrate.WithTableNameAuto(migrationsTable), - migrate.WithLocksTableNameAuto(migrationLocksTable), + m := newAutoMigrator(t, db, migrate.WithFKNameFunc(func(sqlschema.FK) string { return "test_fkey" }), - migrate.WithModel((*TableM)(nil)), - migrate.WithModel((*ColumnM)(nil)), + migrate.WithModel( + (*TableM)(nil), + (*ColumnM)(nil), + ), ) - require.NoError(t, err) // Act - err = m.Run(ctx) + err := m.Run(ctx) require.NoError(t, err) // Assert - state, err := dbInspector.Inspect(ctx) - require.NoError(t, err) - + state := inspect(ctx) fkName := state.FKs[sqlschema.FK{ From: sqlschema.C(db.Dialect().DefaultSchema(), "columns", "attrelid"), To: sqlschema.C(db.Dialect().DefaultSchema(), "tables", "oid"), @@ -479,6 +473,109 @@ func testCustomFKNameFunc(t *testing.T, db *bun.DB) { require.Equal(t, fkName, "test_fkey") } +func testRenamedColumns(t *testing.T, db *bun.DB) { + // Database state + type Original struct { + ID int64 `bun:",pk"` + } + + type Model1 struct { + bun.BaseModel `bun:"models"` + ID string `bun:",pk"` + DoNotRename string `bun:",default:2"` + ColumnTwo int `bun:",default:2"` + } + + // Model state + type Renamed struct { + bun.BaseModel `bun:"renamed"` + Count int64 `bun:",pk"` // renamed column in renamed model + } + + type Model2 struct { + bun.BaseModel `bun:"models"` + ID string `bun:",pk"` + DoNotRename string `bun:",default:2"` + SecondColumn int `bun:",default:2"` // renamed column + } + + ctx := context.Background() + inspect := inspectDbOrSkip(t, db) + mustResetModel(t, ctx, db, + (*Original)(nil), + (*Model1)(nil), + ) + mustDropTableOnCleanup(t, ctx, db, (*Renamed)(nil)) + m := newAutoMigrator(t, db, migrate.WithModel( + (*Renamed)(nil), + (*Model2)(nil), + )) + + // Act + err := m.Run(ctx) + require.NoError(t, err) + + // Assert + state := inspect(ctx) + + require.Len(t, state.Tables, 2) + + var renamed, model2 sqlschema.Table + for _, tbl := range state.Tables { + switch tbl.Name { + case "renamed": + renamed = tbl + case "models": + model2 = tbl + } + } + + require.Contains(t, renamed.Columns, "count") + require.Contains(t, model2.Columns, "second_column") + require.Contains(t, model2.Columns, "do_not_rename") +} + +func testRenameColumnRenamesFK(t *testing.T, db *bun.DB) { + type TennantBefore struct { + bun.BaseModel `bun:"table:tennants"` + ID int64 `bun:",pk,identity"` + Apartment int8 + NeighbourID int64 + + Neighbour *TennantBefore `bun:"rel:has-one,join:neighbour_id=id"` + } + + type TennantAfter struct { + bun.BaseModel `bun:"table:tennants"` + TennantID int64 `bun:",pk,identity"` + Apartment int8 + NeighbourID int64 `bun:"my_neighbour"` + + Neighbour *TennantAfter `bun:"rel:has-one,join:my_neighbour=tennant_id"` + } + + ctx := context.Background() + inspect := inspectDbOrSkip(t, db) + mustCreateTableWithFKs(t, ctx, db, (*TennantBefore)(nil)) + m := newAutoMigrator(t, db, + migrate.WithRenameFK(true), + migrate.WithModel((*TennantAfter)(nil)), + ) + + // Act + err := m.Run(ctx) + require.NoError(t, err) + + // Assert + state := inspect(ctx) + + fkName := state.FKs[sqlschema.FK{ + From: sqlschema.C(db.Dialect().DefaultSchema(), "tennants", "my_neighbour"), + To: sqlschema.C(db.Dialect().DefaultSchema(), "tennants", "tennant_id"), + }] + require.Equal(t, "tennants_my_neighbour_fkey", fkName) +} + // TODO: rewrite these tests into AutoMigrator tests, Diff should be moved to migrate/internal package func TestDiff(t *testing.T) { type Journal struct { diff --git a/migrate/auto.go b/migrate/auto.go index 87be34f42..5750cab00 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -259,11 +259,14 @@ func newDetector(got, want sqlschema.State, opts ...DiffOption) *detector { } func (d *detector) DetectChanges() Changeset { - // Discover CREATE/RENAME/DROP TABLE targetTables := newTableSet(d.target.Tables...) currentTables := newTableSet(d.current.Tables...) // keeps state (which models still need to be checked) + // These table sets record "updates" to the targetTables set. + created := newTableSet() + renamed := newTableSet() + addedTables := targetTables.Sub(currentTables) AddedLoop: for _, added := range addedTables.Values() { @@ -276,13 +279,15 @@ AddedLoop: To: added.Name, }) - // TODO: check for altered columns. + d.detectRenamedColumns(removed, added) // Update referenced table in all related FKs if d.detectRenamedFKs { d.refMap.UpdateT(removed.T(), added.T()) } + renamed.Add(added) + // Do not check this model further, we know it was renamed. currentTables.Remove(removed.Name) continue AddedLoop @@ -294,18 +299,36 @@ AddedLoop: Name: added.Name, Model: added.Model, }) + created.Add(added) } // Tables that aren't present anymore and weren't renamed or left untouched were deleted. - for _, t := range currentTables.Sub(targetTables).Values() { + dropped := currentTables.Sub(targetTables) + for _, t := range dropped.Values() { d.changes.Add(&DropTable{ Schema: t.Schema, Name: t.Name, }) } - // Compare and update FKs + // Detect changes in existing tables that weren't renamed + // TODO: here having State.Tables be a map[string]Table would be much more convenient. + // Then we can alse retire tableSet, or at least simplify it to a certain extent. + curEx := currentTables.Sub(dropped) + tarEx := targetTables.Sub(created).Sub(renamed) + for _, target := range tarEx.Values() { + // This step is redundant if we have map[string]Table + var current sqlschema.Table + for _, cur := range curEx.Values() { + if cur.Name == target.Name { + current = cur + break + } + } + d.detectRenamedColumns(current, target) + } + // Compare and update FKs ---------------- currentFKs := make(map[sqlschema.FK]string) for k, v := range d.current.FKs { currentFKs[k] = v @@ -355,6 +378,28 @@ func (d detector) canRename(t1, t2 sqlschema.Table) bool { return t1.Schema == t2.Schema && sqlschema.EqualSignatures(t1, t2) } +func (d *detector) detectRenamedColumns(removed, added sqlschema.Table) { + for aName, aCol := range added.Columns { + // This column exists in the database, so it wasn't renamed + if _, ok := removed.Columns[aName]; ok { + continue + } + for rName, rCol := range removed.Columns { + if aCol != rCol { + continue + } + d.changes.Add(&RenameColumn{ + Schema: added.Schema, + Table: added.Name, + From: rName, + To: aName, + }) + delete(removed.Columns, rName) // no need to check this column again + d.refMap.UpdateC(sqlschema.C(added.Schema, added.Name, rName), aName) + } + } +} + // Changeset is a set of changes that alter database state. type Changeset struct { operations []Operation @@ -458,8 +503,9 @@ func (op *RenameTable) Func(m sqlschema.Migrator) MigrationFunc { func (op *RenameTable) GetReverse() Operation { return &RenameTable{ - From: op.To, - To: op.From, + Schema: op.Schema, + From: op.To, + To: op.From, } } @@ -583,6 +629,7 @@ func (op *DropFK) GetReverse() Operation { } } +// RenameFK type RenameFK struct { FK sqlschema.FK From string @@ -610,6 +657,35 @@ func (op *RenameFK) GetReverse() Operation { } } +// RenameColumn +type RenameColumn struct { + Schema string + Table string + From string + To string +} + +var _ Operation = (*RenameColumn)(nil) + +func (op RenameColumn) String() string { + return "" +} + +func (op *RenameColumn) Func(m sqlschema.Migrator) MigrationFunc { + return func(ctx context.Context, db *bun.DB) error { + return m.RenameColumn(ctx, op.Schema, op.Table, op.From, op.To) + } +} + +func (op *RenameColumn) GetReverse() Operation { + return &RenameColumn{ + Schema: op.Schema, + Table: op.Table, + From: op.To, + To: op.From, + } +} + // sqlschema utils ------------------------------------------------------------ // tableSet stores unique table definitions. diff --git a/migrate/sqlschema/migrator.go b/migrate/sqlschema/migrator.go index d8b555a35..befdb8ad5 100644 --- a/migrate/sqlschema/migrator.go +++ b/migrate/sqlschema/migrator.go @@ -20,6 +20,7 @@ type Migrator interface { AddContraint(ctx context.Context, fk FK, name string) error DropContraint(ctx context.Context, schema, table, name string) error RenameConstraint(ctx context.Context, schema, table, oldName, newName string) error + RenameColumn(ctx context.Context, schema, table, oldName, newName string) error } // Migrator is a dialect-agnostic wrapper for sqlschema.Dialect