diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go index 89bae8042..821191c53 100644 --- a/dialect/pgdialect/alter_table.go +++ b/dialect/pgdialect/alter_table.go @@ -62,10 +62,14 @@ func (m *migrator) Apply(ctx context.Context, changes ...interface{}) error { b, err = m.addColumn(fmter, b, change) case *migrate.DropColumn: b, err = m.dropColumn(fmter, b, change) - case *migrate.DropConstraint: - b, err = m.dropContraint(fmter, b, change) case *migrate.AddForeignKey: b, err = m.addForeignKey(fmter, b, change) + case *migrate.AddUniqueConstraint: + b, err = m.addUnique(fmter, b, change) + case *migrate.DropUniqueConstraint: + b, err = m.dropConstraint(fmter, b, change.FQN, change.Unique.Name) + case *migrate.DropConstraint: + b, err = m.dropConstraint(fmter, b, change.FQN(), change.ConstraintName) case *migrate.RenameConstraint: b, err = m.renameConstraint(fmter, b, change) case *migrate.ChangeColumnType: @@ -147,15 +151,34 @@ func (m *migrator) renameConstraint(fmter schema.Formatter, b []byte, rename *mi return b, nil } -func (m *migrator) dropContraint(fmter schema.Formatter, b []byte, drop *migrate.DropConstraint) (_ []byte, err error) { +func (m *migrator) addUnique(fmter schema.Formatter, b []byte, change *migrate.AddUniqueConstraint) (_ []byte, err error) { + b = append(b, "ALTER TABLE "...) + if b, err = change.FQN.AppendQuery(fmter, b); err != nil { + return b, err + } + + b = append(b, " ADD CONSTRAINT "...) + if change.Unique.Name != "" { + b = fmter.AppendName(b, change.Unique.Name) + } else { + // Default naming scheme for unique constraints in Postgres is __key + b = fmter.AppendName(b, fmt.Sprintf("%s_%s_key", change.FQN.Table, change.Unique.Columns)) + } + b = append(b, " UNIQUE ("...) + b, _ = change.Unique.Columns.Safe().AppendQuery(fmter, b) + b = append(b, ")"...) + + return b, nil +} + +func (m *migrator) dropConstraint(fmter schema.Formatter, b []byte, fqn schema.FQN, name string) (_ []byte, err error) { b = append(b, "ALTER TABLE "...) - fqn := drop.FQN() if b, err = fqn.AppendQuery(fmter, b); err != nil { return b, err } b = append(b, " DROP CONSTRAINT "...) - b = fmter.AppendName(b, drop.ConstraintName) + b = fmter.AppendName(b, name) return b, nil } diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go index c95e95cfb..dc4ea2707 100644 --- a/dialect/pgdialect/inspector.go +++ b/dialect/pgdialect/inspector.go @@ -48,7 +48,10 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) { if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil { return state, err } + colDefs := make(map[string]sqlschema.Column) + uniqueGroups := make(map[string][]string) + for _, c := range columns { def := c.Default if c.IsSerial || c.IsIdentity { @@ -66,12 +69,25 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) { IsAutoIncrement: c.IsSerial, IsIdentity: c.IsIdentity, } + + for _, group := range c.UniqueGroups { + uniqueGroups[group] = append(uniqueGroups[group], c.Name) + } + } + + var unique []sqlschema.Unique + for name, columns := range uniqueGroups { + unique = append(unique, sqlschema.Unique{ + Name: name, + Columns: sqlschema.NewComposite(columns...), + }) } state.Tables = append(state.Tables, sqlschema.Table{ - Schema: table.Schema, - Name: table.Name, - Columns: colDefs, + Schema: table.Schema, + Name: table.Name, + Columns: colDefs, + UniqueContraints: unique, }) } @@ -106,8 +122,7 @@ type InformationSchemaColumn struct { IndentityType string `bun:"identity_type"` IsSerial bool `bun:"is_serial"` IsNullable bool `bun:"is_nullable"` - IsUnique bool `bun:"is_unique"` - UniqueGroup []string `bun:"unique_group,array"` + UniqueGroups []string `bun:"unique_groups,array"` } type ForeignKey struct { @@ -156,8 +171,7 @@ SELECT "c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial, COALESCE("c".identity_type, '') AS identity_type, "c".is_nullable = 'YES' AS is_nullable, - 'u' = ANY("c".constraint_type) AS is_unique, - "c"."constraint_name" AS unique_group + "c"."unique_groups" AS unique_groups FROM ( SELECT "table_schema", @@ -170,7 +184,7 @@ FROM ( "c".is_nullable, att.array_dims, att.identity_type, - att."constraint_name", + att."unique_groups", att."constraint_type" FROM information_schema.columns "c" LEFT JOIN ( @@ -180,7 +194,7 @@ FROM ( "c".attname AS "column_name", "c".attndims AS array_dims, "c".attidentity AS identity_type, - ARRAY_AGG(con.conname) AS "constraint_name", + ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups", ARRAY_AGG(con.contype) AS "constraint_type" FROM ( SELECT @@ -200,76 +214,6 @@ FROM ( ) "c" WHERE "table_schema" = ? AND "table_name" = ? ORDER BY "table_schema", "table_name", "column_name" -` - - // sqlInspectSchema retrieves column type definitions for all user-defined tables. - // Other relations, such as views and indices, as well as Posgres's internal relations are excluded. - // - // TODO: implement scanning ORM relations for RawQuery too, so that one could scan this query directly to InformationSchemaTable. - sqlInspectSchema = ` -SELECT - "t"."table_schema", - "t".table_name, - "c".column_name, - "c".data_type, - "c".character_maximum_length::integer AS varchar_len, - "c".data_type = 'ARRAY' AS is_array, - COALESCE("c".array_dims, 0) AS array_dims, - CASE - WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$') - ELSE "c".column_default - END AS "default", - "c".constraint_type = 'p' AS is_pk, - "c".is_identity = 'YES' AS is_identity, - "c".column_default = format('nextval(''%s_%s_seq''::regclass)', "t".table_name, "c".column_name) AS is_serial, - COALESCE("c".identity_type, '') AS identity_type, - "c".is_nullable = 'YES' AS is_nullable, - "c".constraint_type = 'u' AS is_unique, - "c"."constraint_name" AS unique_group -FROM information_schema.tables "t" - LEFT JOIN ( - SELECT - "table_schema", - "table_name", - "column_name", - "c".data_type, - "c".character_maximum_length, - "c".column_default, - "c".is_identity, - "c".is_nullable, - att.array_dims, - att.identity_type, - att."constraint_name", - att."constraint_type" - FROM information_schema.columns "c" - LEFT JOIN ( - SELECT - s.nspname AS table_schema, - "t".relname AS "table_name", - "c".attname AS "column_name", - "c".attndims AS array_dims, - "c".attidentity AS identity_type, - con.conname AS "constraint_name", - con.contype AS "constraint_type" - FROM ( - SELECT - conname, - contype, - connamespace, - conrelid, - conrelid AS attrelid, - UNNEST(conkey) AS attnum - FROM pg_constraint - ) con - LEFT JOIN pg_attribute "c" USING (attrelid, attnum) - LEFT JOIN pg_namespace s ON s.oid = con.connamespace - LEFT JOIN pg_class "t" ON "t".oid = con.conrelid - ) att USING (table_schema, "table_name", "column_name") - ) "c" USING (table_schema, "table_name") -WHERE table_type = 'BASE TABLE' - AND table_schema <> 'information_schema' - AND table_schema NOT LIKE 'pg_%' -ORDER BY table_schema, table_name ` // sqlInspectForeignKeys get FK definitions for user-defined tables. diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index bd758d1f9..6d3124261 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -42,7 +42,7 @@ type Office struct { type Publisher struct { ID string `bun:"publisher_id,pk,default:gen_random_uuid(),unique:office_fk"` - Name string `bun:"publisher_name,unique,notnull,unique:office_fk"` + Name string `bun:"publisher_name,notnull,unique:office_fk"` CreatedAt time.Time `bun:"created_at,default:current_timestamp"` // Writers write articles for this publisher. @@ -63,8 +63,9 @@ type PublisherToJournalist struct { type Journalist struct { bun.BaseModel `bun:"table:authors"` ID int `bun:"author_id,pk,identity"` - FirstName string `bun:",notnull"` - LastName string + FirstName string `bun:"first_name,notnull,unique:full_name"` + LastName string `bun:"last_name,notnull,unique:full_name"` + Email string `bun:"email,notnull,unique"` // Articles that this journalist has written. Articles []*Article `bun:"rel:has-many,join:author_id=author_id"` @@ -171,6 +172,9 @@ func TestDatabaseInspector_Inspect(t *testing.T) { SQLType: "bigint", }, }, + UniqueContraints: []sqlschema.Unique{ + {Columns: sqlschema.NewComposite("editor", "title")}, + }, }, { Schema: defaultSchema, @@ -185,10 +189,16 @@ func TestDatabaseInspector_Inspect(t *testing.T) { SQLType: sqltype.VarChar, }, "last_name": { - SQLType: sqltype.VarChar, - IsNullable: true, + SQLType: sqltype.VarChar, + }, + "email": { + SQLType: sqltype.VarChar, }, }, + UniqueContraints: []sqlschema.Unique{ + {Columns: sqlschema.NewComposite("first_name", "last_name")}, + {Columns: sqlschema.NewComposite("email")}, + }, }, { Schema: defaultSchema, @@ -222,6 +232,9 @@ func TestDatabaseInspector_Inspect(t *testing.T) { IsNullable: true, }, }, + UniqueContraints: []sqlschema.Unique{ + {Columns: sqlschema.NewComposite("publisher_id", "publisher_name")}, + }, }, } @@ -268,7 +281,7 @@ func mustCreateTableWithFKs(tb testing.TB, ctx context.Context, db *bun.DB, mode for _, model := range models { create := db.NewCreateTable().Model(model).WithForeignKeys() _, err := create.Exec(ctx) - require.NoError(tb, err, "must create table %q:", create.GetTableName()) + require.NoError(tb, err, "arrange: must create table %q:", create.GetTableName()) mustDropTableOnCleanup(tb, ctx, db, model) } } @@ -303,9 +316,11 @@ func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got []sqlschem } cmpColumns(tb, d, wt.Name, wt.Columns, gt.Columns) + cmpConstraints(tb, wt, gt) } } +// cmpColumns compares that column definitions on the tables are func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, want, got map[string]sqlschema.Column) { tb.Helper() var errs []string @@ -362,6 +377,20 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w } } +// cmpConstraints compares constraints defined on the table with the expected ones. +func cmpConstraints(tb testing.TB, want, got sqlschema.Table) { + tb.Helper() + + // Only keep columns included in each unique constraint for comparison. + stripNames := func(uniques []sqlschema.Unique) (res []string) { + for _, u := range uniques { + res = append(res, u.Columns.String()) + } + return + } + require.ElementsMatch(tb, stripNames(want.UniqueContraints), stripNames(got.UniqueContraints), "table %q does not have expected unique constraints (listA=want, listB=got)", want.Name) +} + func tableNames(tables []sqlschema.Table) (names []string) { for i := range tables { names = append(names, tables[i].Name) @@ -441,5 +470,31 @@ func TestSchemaInspector_Inspect(t *testing.T) { require.Len(t, got.Tables, 1) cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, got.Tables[0].Columns) }) + + t.Run("inspect unique constraints", func(t *testing.T) { + type Model struct { + ID string `bun:",unique"` + FirstName string `bun:"first_name,unique:full_name"` + LastName string `bun:"last_name,unique:full_name"` + } + + tables := schema.NewTables(dialect) + tables.Register((*Model)(nil)) + inspector := sqlschema.NewSchemaInspector(tables) + + want := sqlschema.Table{ + Name: "models", + UniqueContraints: []sqlschema.Unique{ + {Columns: sqlschema.NewComposite("id")}, + {Name: "full_name", Columns: sqlschema.NewComposite("first_name", "last_name")}, + }, + } + + got, err := inspector.Inspect(context.Background()) + require.NoError(t, err) + + require.Len(t, got.Tables, 1) + cmpConstraints(t, want, got.Tables[0]) + }) }) } diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 0259a498c..47d28b0e8 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -209,7 +209,8 @@ func TestAutoMigrator_Run(t *testing.T) { {testChangeColumnType_AutoCast}, {testIdentity}, {testAddDropColumn}, - // {testUnique}, + {testUnique}, + {testUniqueRenamedTable}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -542,16 +543,16 @@ func testRenamedColumns(t *testing.T, db *bun.DB) { func testRenameColumnRenamesFK(t *testing.T, db *bun.DB) { type TennantBefore struct { bun.BaseModel `bun:"table:tennants"` - ID int64 `bun:",pk,identity"` + ID int64 `bun:"id,pk,identity"` Apartment int8 - NeighbourID int64 + NeighbourID int64 `bun:"neighbour_id"` Neighbour *TennantBefore `bun:"rel:has-one,join:neighbour_id=id"` } type TennantAfter struct { bun.BaseModel `bun:"table:tennants"` - TennantID int64 `bun:",pk,identity"` + TennantID int64 `bun:"tennant_id,pk,identity"` Apartment int8 NeighbourID int64 `bun:"my_neighbour"` @@ -760,6 +761,8 @@ func testUnique(t *testing.T, db *bun.DB) { FirstName string `bun:"first_name,unique:full_name"` LastName string `bun:"last_name,unique:full_name"` Birthday string `bun:"birthday,unique"` + PetName string `bun:"pet_name,unique:pet"` + PetBreed string `bun:"pet_breed,unique:pet"` } type TableAfter struct { @@ -767,8 +770,12 @@ func testUnique(t *testing.T, db *bun.DB) { FirstName string `bun:"first_name,unique:full_name"` MiddleName string `bun:"middle_name,unique:full_name"` // extend "full_name" unique group LastName string `bun:"last_name,unique:full_name"` - Birthday string `bun:"birthday"` // doesn't have to be unique any more - Email string `bun:"email,unique"` // new column, unique + + Birthday string `bun:"birthday"` // doesn't have to be unique any more + Email string `bun:"email,unique"` // new column, unique + + PetName string `bun:"pet_name,unique"` + PetBreed string `bun:"pet_breed"` // shrink "pet" unique group } wantTables := []sqlschema.Table{ @@ -796,6 +803,90 @@ func testUnique(t *testing.T, db *bun.DB) { SQLType: sqltype.VarChar, IsNullable: true, }, + "pet_name": { + SQLType: sqltype.VarChar, + IsNullable: true, + }, + "pet_breed": { + SQLType: sqltype.VarChar, + IsNullable: true, + }, + }, + UniqueContraints: []sqlschema.Unique{ + {Columns: sqlschema.NewComposite("email")}, + {Columns: sqlschema.NewComposite("pet_name")}, + // We can only be sure of the user-defined index name + {Name: "full_name", Columns: sqlschema.NewComposite("first_name", "middle_name", "last_name")}, + }, + }, + } + + ctx := context.Background() + inspect := inspectDbOrSkip(t, db) + mustResetModel(t, ctx, db, (*TableBefore)(nil)) + m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil))) + + // Act + err := m.Run(ctx) + require.NoError(t, err) + + // Assert + state := inspect(ctx) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) +} + +func testUniqueRenamedTable(t *testing.T, db *bun.DB) { + type TableBefore struct { + bun.BaseModel `bun:"table:before"` + FirstName string `bun:"first_name,unique:full_name"` + LastName string `bun:"last_name,unique:full_name"` + Birthday string `bun:"birthday,unique"` + PetName string `bun:"pet_name,unique:pet"` + PetBreed string `bun:"pet_breed,unique:pet"` + } + + type TableAfter struct { + bun.BaseModel `bun:"table:after"` + // Expand full_name unique group and rename it. + FirstName string `bun:"first_name,unique:birth_certificate"` + LastName string `bun:"last_name,unique:birth_certificate"` + Birthday string `bun:"birthday,unique:birth_certificate"` + + // pet_name and pet_breed have their own unique indices now. + PetName string `bun:"pet_name,unique"` + PetBreed string `bun:"pet_breed,unique"` + } + + wantTables := []sqlschema.Table{ + { + Schema: db.Dialect().DefaultSchema(), + Name: "after", + Columns: map[string]sqlschema.Column{ + "first_name": { + SQLType: sqltype.VarChar, + IsNullable: true, + }, + "last_name": { + SQLType: sqltype.VarChar, + IsNullable: true, + }, + "birthday": { + SQLType: sqltype.VarChar, + IsNullable: true, + }, + "pet_name": { + SQLType: sqltype.VarChar, + IsNullable: true, + }, + "pet_breed": { + SQLType: sqltype.VarChar, + IsNullable: true, + }, + }, + UniqueContraints: []sqlschema.Unique{ + {Columns: sqlschema.NewComposite("pet_name")}, + {Columns: sqlschema.NewComposite("pet_breed")}, + {Name: "full_name", Columns: sqlschema.NewComposite("first_name", "last_name", "birthday")}, }, }, } diff --git a/migrate/diff.go b/migrate/diff.go index 341329737..99b0fbf13 100644 --- a/migrate/diff.go +++ b/migrate/diff.go @@ -37,6 +37,7 @@ AddedLoop: // Here we do not check for created / dropped columns, as well as column type changes, // because it is only possible to detect a renamed table if its signature (see state.go) did not change. d.detectColumnChanges(removed, added, false) + d.detectConstraintChanges(removed, added) // Update referenced table in all related FKs. if d.detectRenamedFKs { @@ -82,6 +83,7 @@ AddedLoop: } } d.detectColumnChanges(current, target, true) + d.detectConstraintChanges(current, target) } // Compare and update FKs ---------------- @@ -338,6 +340,8 @@ func (d *detector) makeTargetColDef(current, target sqlschema.Column) sqlschema. // detechColumnChanges finds renamed columns and, if checkType == true, columns with changed type. func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkType bool) { + fqn := schema.FQN{target.Schema, target.Name} + ChangedRenamed: for tName, tCol := range target.Columns { @@ -347,7 +351,7 @@ ChangedRenamed: if cCol, ok := current.Columns[tName]; ok { if checkType && !d.equalColumns(cCol, tCol) { d.changes.Add(&ChangeColumnType{ - FQN: schema.FQN{target.Schema, target.Name}, + FQN: fqn, Column: tName, From: cCol, To: d.makeTargetColDef(cCol, tCol), @@ -364,7 +368,7 @@ ChangedRenamed: continue } d.changes.Add(&RenameColumn{ - FQN: schema.FQN{target.Schema, target.Name}, + FQN: fqn, OldName: cName, NewName: tName, }) @@ -375,7 +379,7 @@ ChangedRenamed: } d.changes.Add(&AddColumn{ - FQN: schema.FQN{target.Schema, target.Name}, + FQN: fqn, Column: tName, ColDef: tCol, }) @@ -385,7 +389,7 @@ ChangedRenamed: for cName, cCol := range current.Columns { if _, keep := target.Columns[cName]; !keep { d.changes.Add(&DropColumn{ - FQN: schema.FQN{target.Schema, target.Name}, + FQN: fqn, Column: cName, ColDef: cCol, }) @@ -393,6 +397,37 @@ ChangedRenamed: } } +func (d *detector) detectConstraintChanges(current, target sqlschema.Table) { + fqn := schema.FQN{target.Schema, target.Name} + +Add: + for _, want := range target.UniqueContraints { + for _, got := range current.UniqueContraints { + if got.Equals(want) { + continue Add + } + } + d.changes.Add(&AddUniqueConstraint{ + FQN: fqn, + Unique: want, + }) + } + +Drop: + for _, got := range current.UniqueContraints { + for _, want := range target.UniqueContraints { + if got.Equals(want) { + continue Drop + } + } + + d.changes.Add(&DropUniqueConstraint{ + FQN: fqn, + Unique: got, + }) + } +} + // sqlschema utils ------------------------------------------------------------ // tableSet stores unique table definitions. diff --git a/migrate/operations.go b/migrate/operations.go index 4b3958b5d..c4bbc6b80 100644 --- a/migrate/operations.go +++ b/migrate/operations.go @@ -202,6 +202,7 @@ func (op *AddForeignKey) GetReverse() Operation { } } +// TODO: Rename to DropForeignKey // DropConstraint. type DropConstraint struct { FK sqlschema.FK @@ -224,6 +225,63 @@ func (op *DropConstraint) GetReverse() Operation { } } +type AddUniqueConstraint struct { + FQN schema.FQN + Unique sqlschema.Unique +} + +var _ Operation = (*AddUniqueConstraint)(nil) + +func (op *AddUniqueConstraint) GetReverse() Operation { + return &DropUniqueConstraint{ + FQN: op.FQN, + Unique: op.Unique, + } +} + +func (op *AddUniqueConstraint) DependsOn(another Operation) bool { + switch another := another.(type) { + case *AddColumn: + var sameColumn bool + for _, column := range op.Unique.Columns.Split() { + if column == another.Column { + sameColumn = true + break + } + } + return op.FQN == another.FQN && sameColumn + case *RenameTable: + return op.FQN.Schema == another.FQN.Schema && op.FQN.Table == another.NewName + case *DropUniqueConstraint: + // We want to drop the constraint with the same name before adding this one. + return op.FQN == another.FQN && op.Unique.Name == another.Unique.Name + default: + return false + } + +} + +type DropUniqueConstraint struct { + FQN schema.FQN + Unique sqlschema.Unique +} + +var _ Operation = (*DropUniqueConstraint)(nil) + +func (op *DropUniqueConstraint) DependsOn(another Operation) bool { + if rename, ok := another.(*RenameTable); ok { + return op.FQN.Schema == rename.FQN.Schema && op.FQN.Table == rename.NewName + } + return false +} + +func (op *DropUniqueConstraint) GetReverse() Operation { + return &AddUniqueConstraint{ + FQN: op.FQN, + Unique: op.Unique, + } +} + // Change column type. type ChangeColumnType struct { FQN schema.FQN diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go index 4c62289a3..cf809a343 100644 --- a/migrate/sqlschema/inspector.go +++ b/migrate/sqlschema/inspector.go @@ -75,11 +75,31 @@ func (si *SchemaInspector) Inspect(ctx context.Context) (State, error) { } } + var unique []Unique + for name, group := range t.Unique { + // Create a separate unique index for single-column unique constraints + // let each dialect apply the default naming convention. + if name == "" { + for _, f := range group { + unique = append(unique, Unique{Columns: NewComposite(f.Name)}) + } + continue + } + + // Set the name if it is a "unique group", in which case the user has provided the name. + var columns []string + for _, f := range group { + columns = append(columns, f.Name) + } + unique = append(unique, Unique{Name: name, Columns: NewComposite(columns...)}) + } + state.Tables = append(state.Tables, Table{ - Schema: t.Schema, - Name: t.Name, - Model: t.ZeroIface, - Columns: columns, + Schema: t.Schema, + Name: t.Name, + Model: t.ZeroIface, + Columns: columns, + UniqueContraints: unique, }) for _, rel := range t.Relations { diff --git a/migrate/sqlschema/state.go b/migrate/sqlschema/state.go index 789145196..b6139d29d 100644 --- a/migrate/sqlschema/state.go +++ b/migrate/sqlschema/state.go @@ -2,6 +2,7 @@ package sqlschema import ( "fmt" + "slices" "strings" "github.com/uptrace/bun/schema" @@ -13,10 +14,20 @@ type State struct { } type Table struct { - Schema string - Name string - Model interface{} + // Schema containing the table. + Schema string + + // Table name. + Name string + + // Model stores a pointer to the bun's underlying Go struct for the table. + Model interface{} + + // Columns map each column name to the column type definition. Columns map[string]Column + + // UniqueConstraints defined on the table. + UniqueContraints []Unique } // T returns a fully-qualified name object for the table. @@ -131,7 +142,7 @@ type cFQN struct { // C creates a fully-qualified column name object. func C(schema, table string, columns ...string) cFQN { - return cFQN{tFQN: T(schema, table), Column: newComposite(columns...)} + return cFQN{tFQN: T(schema, table), Column: NewComposite(columns...)} } // T returns the FQN of the column's parent table. @@ -143,8 +154,9 @@ func (c cFQN) T() tFQN { // Although having duplicated column references in a FK is illegal, composite neither validates nor enforces this constraint on the caller. type composite string -// newComposite creates a composite column from a slice of column names. -func newComposite(columns ...string) composite { +// NewComposite creates a composite column from a slice of column names. +func NewComposite(columns ...string) composite { + slices.Sort(columns) return composite(strings.Join(columns, ",")) } @@ -162,9 +174,14 @@ func (c composite) Split() []string { } // Contains checks that a composite column contains every part of another composite. -func (c composite) Contains(other composite) bool { +func (c composite) contains(other composite) bool { + return c.Contains(string(other)) +} + +// Contains checks that a composite column contains the current column. +func (c composite) Contains(other string) bool { var count int - checkColumns := other.Split() + checkColumns := composite(other).Split() wantCount := len(checkColumns) for _, check := range checkColumns { @@ -187,7 +204,7 @@ func (c composite) Replace(oldColumn, newColumn string) composite { for i, column := range columns { if column == oldColumn { columns[i] = newColumn - return newComposite(columns...) + return NewComposite(columns...) } } return c @@ -242,9 +259,9 @@ func (fk *FK) dependsT(t tFQN) (ok bool, cols []*cFQN) { // depends on C("a", "b", "c_1"), C("a", "b", "c_2"), C("w", "x", "y_1"), and C("w", "x", "y_2") func (fk *FK) dependsC(c cFQN) (bool, *cFQN) { switch { - case fk.From.Column.Contains(c.Column): + case fk.From.Column.contains(c.Column): return true, &fk.From - case fk.To.Column.Contains(c.Column): + case fk.To.Column.contains(c.Column): return true, &fk.To } return false, nil @@ -347,3 +364,14 @@ func (r RefMap) Deleted() (fks []FK) { } return } + +// Unique represents a unique constraint defined on 1 or more columns. +type Unique struct { + Name string + Columns composite +} + +// Equals checks that two unique constraint are the same, assuming both are defined for the same table. +func (u Unique) Equals(other Unique) bool { + return u.Columns == other.Columns +}