diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go
index 89bae8042..821191c53 100644
--- a/dialect/pgdialect/alter_table.go
+++ b/dialect/pgdialect/alter_table.go
@@ -62,10 +62,14 @@ func (m *migrator) Apply(ctx context.Context, changes ...interface{}) error {
b, err = m.addColumn(fmter, b, change)
case *migrate.DropColumn:
b, err = m.dropColumn(fmter, b, change)
- case *migrate.DropConstraint:
- b, err = m.dropContraint(fmter, b, change)
case *migrate.AddForeignKey:
b, err = m.addForeignKey(fmter, b, change)
+ case *migrate.AddUniqueConstraint:
+ b, err = m.addUnique(fmter, b, change)
+ case *migrate.DropUniqueConstraint:
+ b, err = m.dropConstraint(fmter, b, change.FQN, change.Unique.Name)
+ case *migrate.DropConstraint:
+ b, err = m.dropConstraint(fmter, b, change.FQN(), change.ConstraintName)
case *migrate.RenameConstraint:
b, err = m.renameConstraint(fmter, b, change)
case *migrate.ChangeColumnType:
@@ -147,15 +151,34 @@ func (m *migrator) renameConstraint(fmter schema.Formatter, b []byte, rename *mi
return b, nil
}
-func (m *migrator) dropContraint(fmter schema.Formatter, b []byte, drop *migrate.DropConstraint) (_ []byte, err error) {
+func (m *migrator) addUnique(fmter schema.Formatter, b []byte, change *migrate.AddUniqueConstraint) (_ []byte, err error) {
+ b = append(b, "ALTER TABLE "...)
+ if b, err = change.FQN.AppendQuery(fmter, b); err != nil {
+ return b, err
+ }
+
+ b = append(b, " ADD CONSTRAINT "...)
+ if change.Unique.Name != "" {
+ b = fmter.AppendName(b, change.Unique.Name)
+ } else {
+ // Default naming scheme for unique constraints in Postgres is <table>_<column>_key
+ b = fmter.AppendName(b, fmt.Sprintf("%s_%s_key", change.FQN.Table, change.Unique.Columns))
+ }
+ b = append(b, " UNIQUE ("...)
+ b, _ = change.Unique.Columns.Safe().AppendQuery(fmter, b)
+ b = append(b, ")"...)
+
+ return b, nil
+}
+
+func (m *migrator) dropConstraint(fmter schema.Formatter, b []byte, fqn schema.FQN, name string) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
- fqn := drop.FQN()
if b, err = fqn.AppendQuery(fmter, b); err != nil {
return b, err
}
b = append(b, " DROP CONSTRAINT "...)
- b = fmter.AppendName(b, drop.ConstraintName)
+ b = fmter.AppendName(b, name)
return b, nil
}
diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go
index c95e95cfb..dc4ea2707 100644
--- a/dialect/pgdialect/inspector.go
+++ b/dialect/pgdialect/inspector.go
@@ -48,7 +48,10 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) {
if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
return state, err
}
+
colDefs := make(map[string]sqlschema.Column)
+ uniqueGroups := make(map[string][]string)
+
for _, c := range columns {
def := c.Default
if c.IsSerial || c.IsIdentity {
@@ -66,12 +69,25 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) {
IsAutoIncrement: c.IsSerial,
IsIdentity: c.IsIdentity,
}
+
+ for _, group := range c.UniqueGroups {
+ uniqueGroups[group] = append(uniqueGroups[group], c.Name)
+ }
+ }
+
+ var unique []sqlschema.Unique
+ for name, columns := range uniqueGroups {
+ unique = append(unique, sqlschema.Unique{
+ Name: name,
+ Columns: sqlschema.NewComposite(columns...),
+ })
}
state.Tables = append(state.Tables, sqlschema.Table{
- Schema: table.Schema,
- Name: table.Name,
- Columns: colDefs,
+ Schema: table.Schema,
+ Name: table.Name,
+ Columns: colDefs,
+ UniqueContraints: unique,
})
}
@@ -106,8 +122,7 @@ type InformationSchemaColumn struct {
IndentityType string `bun:"identity_type"`
IsSerial bool `bun:"is_serial"`
IsNullable bool `bun:"is_nullable"`
- IsUnique bool `bun:"is_unique"`
- UniqueGroup []string `bun:"unique_group,array"`
+ UniqueGroups []string `bun:"unique_groups,array"`
}
type ForeignKey struct {
@@ -156,8 +171,7 @@ SELECT
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
"c".is_nullable = 'YES' AS is_nullable,
- 'u' = ANY("c".constraint_type) AS is_unique,
- "c"."constraint_name" AS unique_group
+ "c"."unique_groups" AS unique_groups
FROM (
SELECT
"table_schema",
@@ -170,7 +184,7 @@ FROM (
"c".is_nullable,
att.array_dims,
att.identity_type,
- att."constraint_name",
+ att."unique_groups",
att."constraint_type"
FROM information_schema.columns "c"
LEFT JOIN (
@@ -180,7 +194,7 @@ FROM (
"c".attname AS "column_name",
"c".attndims AS array_dims,
"c".attidentity AS identity_type,
- ARRAY_AGG(con.conname) AS "constraint_name",
+ ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups",
ARRAY_AGG(con.contype) AS "constraint_type"
FROM (
SELECT
@@ -200,76 +214,6 @@ FROM (
) "c"
WHERE "table_schema" = ? AND "table_name" = ?
ORDER BY "table_schema", "table_name", "column_name"
-`
-
- // sqlInspectSchema retrieves column type definitions for all user-defined tables.
- // Other relations, such as views and indices, as well as Posgres's internal relations are excluded.
- //
- // TODO: implement scanning ORM relations for RawQuery too, so that one could scan this query directly to InformationSchemaTable.
- sqlInspectSchema = `
-SELECT
- "t"."table_schema",
- "t".table_name,
- "c".column_name,
- "c".data_type,
- "c".character_maximum_length::integer AS varchar_len,
- "c".data_type = 'ARRAY' AS is_array,
- COALESCE("c".array_dims, 0) AS array_dims,
- CASE
- WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
- ELSE "c".column_default
- END AS "default",
- "c".constraint_type = 'p' AS is_pk,
- "c".is_identity = 'YES' AS is_identity,
- "c".column_default = format('nextval(''%s_%s_seq''::regclass)', "t".table_name, "c".column_name) AS is_serial,
- COALESCE("c".identity_type, '') AS identity_type,
- "c".is_nullable = 'YES' AS is_nullable,
- "c".constraint_type = 'u' AS is_unique,
- "c"."constraint_name" AS unique_group
-FROM information_schema.tables "t"
- LEFT JOIN (
- SELECT
- "table_schema",
- "table_name",
- "column_name",
- "c".data_type,
- "c".character_maximum_length,
- "c".column_default,
- "c".is_identity,
- "c".is_nullable,
- att.array_dims,
- att.identity_type,
- att."constraint_name",
- att."constraint_type"
- FROM information_schema.columns "c"
- LEFT JOIN (
- SELECT
- s.nspname AS table_schema,
- "t".relname AS "table_name",
- "c".attname AS "column_name",
- "c".attndims AS array_dims,
- "c".attidentity AS identity_type,
- con.conname AS "constraint_name",
- con.contype AS "constraint_type"
- FROM (
- SELECT
- conname,
- contype,
- connamespace,
- conrelid,
- conrelid AS attrelid,
- UNNEST(conkey) AS attnum
- FROM pg_constraint
- ) con
- LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
- LEFT JOIN pg_namespace s ON s.oid = con.connamespace
- LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
- ) att USING (table_schema, "table_name", "column_name")
- ) "c" USING (table_schema, "table_name")
-WHERE table_type = 'BASE TABLE'
- AND table_schema <> 'information_schema'
- AND table_schema NOT LIKE 'pg_%'
-ORDER BY table_schema, table_name
`
// sqlInspectForeignKeys get FK definitions for user-defined tables.
diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go
index bd758d1f9..6d3124261 100644
--- a/internal/dbtest/inspect_test.go
+++ b/internal/dbtest/inspect_test.go
@@ -42,7 +42,7 @@ type Office struct {
type Publisher struct {
ID string `bun:"publisher_id,pk,default:gen_random_uuid(),unique:office_fk"`
- Name string `bun:"publisher_name,unique,notnull,unique:office_fk"`
+ Name string `bun:"publisher_name,notnull,unique:office_fk"`
CreatedAt time.Time `bun:"created_at,default:current_timestamp"`
// Writers write articles for this publisher.
@@ -63,8 +63,9 @@ type PublisherToJournalist struct {
type Journalist struct {
bun.BaseModel `bun:"table:authors"`
ID int `bun:"author_id,pk,identity"`
- FirstName string `bun:",notnull"`
- LastName string
+ FirstName string `bun:"first_name,notnull,unique:full_name"`
+ LastName string `bun:"last_name,notnull,unique:full_name"`
+ Email string `bun:"email,notnull,unique"`
// Articles that this journalist has written.
Articles []*Article `bun:"rel:has-many,join:author_id=author_id"`
@@ -171,6 +172,9 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
SQLType: "bigint",
},
},
+ UniqueContraints: []sqlschema.Unique{
+ {Columns: sqlschema.NewComposite("editor", "title")},
+ },
},
{
Schema: defaultSchema,
@@ -185,10 +189,16 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
SQLType: sqltype.VarChar,
},
"last_name": {
- SQLType: sqltype.VarChar,
- IsNullable: true,
+ SQLType: sqltype.VarChar,
+ },
+ "email": {
+ SQLType: sqltype.VarChar,
},
},
+ UniqueContraints: []sqlschema.Unique{
+ {Columns: sqlschema.NewComposite("first_name", "last_name")},
+ {Columns: sqlschema.NewComposite("email")},
+ },
},
{
Schema: defaultSchema,
@@ -222,6 +232,9 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
IsNullable: true,
},
},
+ UniqueContraints: []sqlschema.Unique{
+ {Columns: sqlschema.NewComposite("publisher_id", "publisher_name")},
+ },
},
}
@@ -268,7 +281,7 @@ func mustCreateTableWithFKs(tb testing.TB, ctx context.Context, db *bun.DB, mode
for _, model := range models {
create := db.NewCreateTable().Model(model).WithForeignKeys()
_, err := create.Exec(ctx)
- require.NoError(tb, err, "must create table %q:", create.GetTableName())
+ require.NoError(tb, err, "arrange: must create table %q:", create.GetTableName())
mustDropTableOnCleanup(tb, ctx, db, model)
}
}
@@ -303,9 +316,11 @@ func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got []sqlschem
}
cmpColumns(tb, d, wt.Name, wt.Columns, gt.Columns)
+ cmpConstraints(tb, wt, gt)
}
}
+// cmpColumns compares that the column definitions on the tables are equal.
func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, want, got map[string]sqlschema.Column) {
tb.Helper()
var errs []string
@@ -362,6 +377,20 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w
}
}
+// cmpConstraints compares constraints defined on the table with the expected ones.
+func cmpConstraints(tb testing.TB, want, got sqlschema.Table) {
+ tb.Helper()
+
+ // Only keep columns included in each unique constraint for comparison.
+ stripNames := func(uniques []sqlschema.Unique) (res []string) {
+ for _, u := range uniques {
+ res = append(res, u.Columns.String())
+ }
+ return
+ }
+ require.ElementsMatch(tb, stripNames(want.UniqueContraints), stripNames(got.UniqueContraints), "table %q does not have expected unique constraints (listA=want, listB=got)", want.Name)
+}
+
func tableNames(tables []sqlschema.Table) (names []string) {
for i := range tables {
names = append(names, tables[i].Name)
@@ -441,5 +470,31 @@ func TestSchemaInspector_Inspect(t *testing.T) {
require.Len(t, got.Tables, 1)
cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, got.Tables[0].Columns)
})
+
+ t.Run("inspect unique constraints", func(t *testing.T) {
+ type Model struct {
+ ID string `bun:",unique"`
+ FirstName string `bun:"first_name,unique:full_name"`
+ LastName string `bun:"last_name,unique:full_name"`
+ }
+
+ tables := schema.NewTables(dialect)
+ tables.Register((*Model)(nil))
+ inspector := sqlschema.NewSchemaInspector(tables)
+
+ want := sqlschema.Table{
+ Name: "models",
+ UniqueContraints: []sqlschema.Unique{
+ {Columns: sqlschema.NewComposite("id")},
+ {Name: "full_name", Columns: sqlschema.NewComposite("first_name", "last_name")},
+ },
+ }
+
+ got, err := inspector.Inspect(context.Background())
+ require.NoError(t, err)
+
+ require.Len(t, got.Tables, 1)
+ cmpConstraints(t, want, got.Tables[0])
+ })
})
}
diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go
index 0259a498c..47d28b0e8 100644
--- a/internal/dbtest/migrate_test.go
+++ b/internal/dbtest/migrate_test.go
@@ -209,7 +209,8 @@ func TestAutoMigrator_Run(t *testing.T) {
{testChangeColumnType_AutoCast},
{testIdentity},
{testAddDropColumn},
- // {testUnique},
+ {testUnique},
+ {testUniqueRenamedTable},
}
testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
@@ -542,16 +543,16 @@ func testRenamedColumns(t *testing.T, db *bun.DB) {
func testRenameColumnRenamesFK(t *testing.T, db *bun.DB) {
type TennantBefore struct {
bun.BaseModel `bun:"table:tennants"`
- ID int64 `bun:",pk,identity"`
+ ID int64 `bun:"id,pk,identity"`
Apartment int8
- NeighbourID int64
+ NeighbourID int64 `bun:"neighbour_id"`
Neighbour *TennantBefore `bun:"rel:has-one,join:neighbour_id=id"`
}
type TennantAfter struct {
bun.BaseModel `bun:"table:tennants"`
- TennantID int64 `bun:",pk,identity"`
+ TennantID int64 `bun:"tennant_id,pk,identity"`
Apartment int8
NeighbourID int64 `bun:"my_neighbour"`
@@ -760,6 +761,8 @@ func testUnique(t *testing.T, db *bun.DB) {
FirstName string `bun:"first_name,unique:full_name"`
LastName string `bun:"last_name,unique:full_name"`
Birthday string `bun:"birthday,unique"`
+ PetName string `bun:"pet_name,unique:pet"`
+ PetBreed string `bun:"pet_breed,unique:pet"`
}
type TableAfter struct {
@@ -767,8 +770,12 @@ func testUnique(t *testing.T, db *bun.DB) {
FirstName string `bun:"first_name,unique:full_name"`
MiddleName string `bun:"middle_name,unique:full_name"` // extend "full_name" unique group
LastName string `bun:"last_name,unique:full_name"`
- Birthday string `bun:"birthday"` // doesn't have to be unique any more
- Email string `bun:"email,unique"` // new column, unique
+
+ Birthday string `bun:"birthday"` // doesn't have to be unique any more
+ Email string `bun:"email,unique"` // new column, unique
+
+ PetName string `bun:"pet_name,unique"`
+ PetBreed string `bun:"pet_breed"` // shrink "pet" unique group
}
wantTables := []sqlschema.Table{
@@ -796,6 +803,90 @@ func testUnique(t *testing.T, db *bun.DB) {
SQLType: sqltype.VarChar,
IsNullable: true,
},
+ "pet_name": {
+ SQLType: sqltype.VarChar,
+ IsNullable: true,
+ },
+ "pet_breed": {
+ SQLType: sqltype.VarChar,
+ IsNullable: true,
+ },
+ },
+ UniqueContraints: []sqlschema.Unique{
+ {Columns: sqlschema.NewComposite("email")},
+ {Columns: sqlschema.NewComposite("pet_name")},
+ // We can only be sure of the user-defined index name
+ {Name: "full_name", Columns: sqlschema.NewComposite("first_name", "middle_name", "last_name")},
+ },
+ },
+ }
+
+ ctx := context.Background()
+ inspect := inspectDbOrSkip(t, db)
+ mustResetModel(t, ctx, db, (*TableBefore)(nil))
+ m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil)))
+
+ // Act
+ err := m.Run(ctx)
+ require.NoError(t, err)
+
+ // Assert
+ state := inspect(ctx)
+ cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables)
+}
+
+func testUniqueRenamedTable(t *testing.T, db *bun.DB) {
+ type TableBefore struct {
+ bun.BaseModel `bun:"table:before"`
+ FirstName string `bun:"first_name,unique:full_name"`
+ LastName string `bun:"last_name,unique:full_name"`
+ Birthday string `bun:"birthday,unique"`
+ PetName string `bun:"pet_name,unique:pet"`
+ PetBreed string `bun:"pet_breed,unique:pet"`
+ }
+
+ type TableAfter struct {
+ bun.BaseModel `bun:"table:after"`
+ // Expand full_name unique group and rename it.
+ FirstName string `bun:"first_name,unique:birth_certificate"`
+ LastName string `bun:"last_name,unique:birth_certificate"`
+ Birthday string `bun:"birthday,unique:birth_certificate"`
+
+ // pet_name and pet_breed have their own unique indices now.
+ PetName string `bun:"pet_name,unique"`
+ PetBreed string `bun:"pet_breed,unique"`
+ }
+
+ wantTables := []sqlschema.Table{
+ {
+ Schema: db.Dialect().DefaultSchema(),
+ Name: "after",
+ Columns: map[string]sqlschema.Column{
+ "first_name": {
+ SQLType: sqltype.VarChar,
+ IsNullable: true,
+ },
+ "last_name": {
+ SQLType: sqltype.VarChar,
+ IsNullable: true,
+ },
+ "birthday": {
+ SQLType: sqltype.VarChar,
+ IsNullable: true,
+ },
+ "pet_name": {
+ SQLType: sqltype.VarChar,
+ IsNullable: true,
+ },
+ "pet_breed": {
+ SQLType: sqltype.VarChar,
+ IsNullable: true,
+ },
+ },
+ UniqueContraints: []sqlschema.Unique{
+ {Columns: sqlschema.NewComposite("pet_name")},
+ {Columns: sqlschema.NewComposite("pet_breed")},
+ {Name: "full_name", Columns: sqlschema.NewComposite("first_name", "last_name", "birthday")},
},
},
}
diff --git a/migrate/diff.go b/migrate/diff.go
index 341329737..99b0fbf13 100644
--- a/migrate/diff.go
+++ b/migrate/diff.go
@@ -37,6 +37,7 @@ AddedLoop:
// Here we do not check for created / dropped columns, as well as column type changes,
// because it is only possible to detect a renamed table if its signature (see state.go) did not change.
d.detectColumnChanges(removed, added, false)
+ d.detectConstraintChanges(removed, added)
// Update referenced table in all related FKs.
if d.detectRenamedFKs {
@@ -82,6 +83,7 @@ AddedLoop:
}
}
d.detectColumnChanges(current, target, true)
+ d.detectConstraintChanges(current, target)
}
// Compare and update FKs ----------------
@@ -338,6 +340,8 @@ func (d *detector) makeTargetColDef(current, target sqlschema.Column) sqlschema.
// detechColumnChanges finds renamed columns and, if checkType == true, columns with changed type.
func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkType bool) {
+ fqn := schema.FQN{target.Schema, target.Name}
+
ChangedRenamed:
for tName, tCol := range target.Columns {
@@ -347,7 +351,7 @@ ChangedRenamed:
if cCol, ok := current.Columns[tName]; ok {
if checkType && !d.equalColumns(cCol, tCol) {
d.changes.Add(&ChangeColumnType{
- FQN: schema.FQN{target.Schema, target.Name},
+ FQN: fqn,
Column: tName,
From: cCol,
To: d.makeTargetColDef(cCol, tCol),
@@ -364,7 +368,7 @@ ChangedRenamed:
continue
}
d.changes.Add(&RenameColumn{
- FQN: schema.FQN{target.Schema, target.Name},
+ FQN: fqn,
OldName: cName,
NewName: tName,
})
@@ -375,7 +379,7 @@ ChangedRenamed:
}
d.changes.Add(&AddColumn{
- FQN: schema.FQN{target.Schema, target.Name},
+ FQN: fqn,
Column: tName,
ColDef: tCol,
})
@@ -385,7 +389,7 @@ ChangedRenamed:
for cName, cCol := range current.Columns {
if _, keep := target.Columns[cName]; !keep {
d.changes.Add(&DropColumn{
- FQN: schema.FQN{target.Schema, target.Name},
+ FQN: fqn,
Column: cName,
ColDef: cCol,
})
@@ -393,6 +397,37 @@ ChangedRenamed:
}
}
+func (d *detector) detectConstraintChanges(current, target sqlschema.Table) {
+ fqn := schema.FQN{target.Schema, target.Name}
+
+Add:
+ for _, want := range target.UniqueContraints {
+ for _, got := range current.UniqueContraints {
+ if got.Equals(want) {
+ continue Add
+ }
+ }
+ d.changes.Add(&AddUniqueConstraint{
+ FQN: fqn,
+ Unique: want,
+ })
+ }
+
+Drop:
+ for _, got := range current.UniqueContraints {
+ for _, want := range target.UniqueContraints {
+ if got.Equals(want) {
+ continue Drop
+ }
+ }
+
+ d.changes.Add(&DropUniqueConstraint{
+ FQN: fqn,
+ Unique: got,
+ })
+ }
+}
+
// sqlschema utils ------------------------------------------------------------
// tableSet stores unique table definitions.
diff --git a/migrate/operations.go b/migrate/operations.go
index 4b3958b5d..c4bbc6b80 100644
--- a/migrate/operations.go
+++ b/migrate/operations.go
@@ -202,6 +202,7 @@ func (op *AddForeignKey) GetReverse() Operation {
}
}
+// TODO: Rename to DropForeignKey
// DropConstraint.
type DropConstraint struct {
FK sqlschema.FK
@@ -224,6 +225,63 @@ func (op *DropConstraint) GetReverse() Operation {
}
}
+type AddUniqueConstraint struct {
+ FQN schema.FQN
+ Unique sqlschema.Unique
+}
+
+var _ Operation = (*AddUniqueConstraint)(nil)
+
+func (op *AddUniqueConstraint) GetReverse() Operation {
+ return &DropUniqueConstraint{
+ FQN: op.FQN,
+ Unique: op.Unique,
+ }
+}
+
+func (op *AddUniqueConstraint) DependsOn(another Operation) bool {
+ switch another := another.(type) {
+ case *AddColumn:
+ var sameColumn bool
+ for _, column := range op.Unique.Columns.Split() {
+ if column == another.Column {
+ sameColumn = true
+ break
+ }
+ }
+ return op.FQN == another.FQN && sameColumn
+ case *RenameTable:
+ return op.FQN.Schema == another.FQN.Schema && op.FQN.Table == another.NewName
+ case *DropUniqueConstraint:
+ // We want to drop the constraint with the same name before adding this one.
+ return op.FQN == another.FQN && op.Unique.Name == another.Unique.Name
+ default:
+ return false
+ }
+
+}
+
+type DropUniqueConstraint struct {
+ FQN schema.FQN
+ Unique sqlschema.Unique
+}
+
+var _ Operation = (*DropUniqueConstraint)(nil)
+
+func (op *DropUniqueConstraint) DependsOn(another Operation) bool {
+ if rename, ok := another.(*RenameTable); ok {
+ return op.FQN.Schema == rename.FQN.Schema && op.FQN.Table == rename.NewName
+ }
+ return false
+}
+
+func (op *DropUniqueConstraint) GetReverse() Operation {
+ return &AddUniqueConstraint{
+ FQN: op.FQN,
+ Unique: op.Unique,
+ }
+}
+
// Change column type.
type ChangeColumnType struct {
FQN schema.FQN
diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go
index 4c62289a3..cf809a343 100644
--- a/migrate/sqlschema/inspector.go
+++ b/migrate/sqlschema/inspector.go
@@ -75,11 +75,31 @@ func (si *SchemaInspector) Inspect(ctx context.Context) (State, error) {
}
}
+ var unique []Unique
+ for name, group := range t.Unique {
+ // Create a separate unique index for single-column unique constraints and
+ // let each dialect apply the default naming convention.
+ if name == "" {
+ for _, f := range group {
+ unique = append(unique, Unique{Columns: NewComposite(f.Name)})
+ }
+ continue
+ }
+
+ // Set the name if it is a "unique group", in which case the user has provided the name.
+ var columns []string
+ for _, f := range group {
+ columns = append(columns, f.Name)
+ }
+ unique = append(unique, Unique{Name: name, Columns: NewComposite(columns...)})
+ }
+
state.Tables = append(state.Tables, Table{
- Schema: t.Schema,
- Name: t.Name,
- Model: t.ZeroIface,
- Columns: columns,
+ Schema: t.Schema,
+ Name: t.Name,
+ Model: t.ZeroIface,
+ Columns: columns,
+ UniqueContraints: unique,
})
for _, rel := range t.Relations {
diff --git a/migrate/sqlschema/state.go b/migrate/sqlschema/state.go
index 789145196..b6139d29d 100644
--- a/migrate/sqlschema/state.go
+++ b/migrate/sqlschema/state.go
@@ -2,6 +2,7 @@ package sqlschema
import (
"fmt"
+ "slices"
"strings"
"github.com/uptrace/bun/schema"
@@ -13,10 +14,20 @@ type State struct {
}
type Table struct {
- Schema string
- Name string
- Model interface{}
+ // Schema containing the table.
+ Schema string
+
+ // Table name.
+ Name string
+
+ // Model stores a pointer to the bun's underlying Go struct for the table.
+ Model interface{}
+
+ // Columns map each column name to the column type definition.
Columns map[string]Column
+
+ // UniqueConstraints defined on the table.
+ UniqueContraints []Unique
}
// T returns a fully-qualified name object for the table.
@@ -131,7 +142,7 @@ type cFQN struct {
// C creates a fully-qualified column name object.
func C(schema, table string, columns ...string) cFQN {
- return cFQN{tFQN: T(schema, table), Column: newComposite(columns...)}
+ return cFQN{tFQN: T(schema, table), Column: NewComposite(columns...)}
}
// T returns the FQN of the column's parent table.
@@ -143,8 +154,9 @@ func (c cFQN) T() tFQN {
// Although having duplicated column references in a FK is illegal, composite neither validates nor enforces this constraint on the caller.
type composite string
-// newComposite creates a composite column from a slice of column names.
-func newComposite(columns ...string) composite {
+// NewComposite creates a composite column from a slice of column names.
+func NewComposite(columns ...string) composite {
+ slices.Sort(columns)
return composite(strings.Join(columns, ","))
}
@@ -162,9 +174,14 @@ func (c composite) Split() []string {
}
// Contains checks that a composite column contains every part of another composite.
-func (c composite) Contains(other composite) bool {
+func (c composite) contains(other composite) bool {
+ return c.Contains(string(other))
+}
+
+// Contains checks that the composite column contains every column of the other (comma-separated) column list.
+func (c composite) Contains(other string) bool {
var count int
- checkColumns := other.Split()
+ checkColumns := composite(other).Split()
wantCount := len(checkColumns)
for _, check := range checkColumns {
@@ -187,7 +204,7 @@ func (c composite) Replace(oldColumn, newColumn string) composite {
for i, column := range columns {
if column == oldColumn {
columns[i] = newColumn
- return newComposite(columns...)
+ return NewComposite(columns...)
}
}
return c
@@ -242,9 +259,9 @@ func (fk *FK) dependsT(t tFQN) (ok bool, cols []*cFQN) {
// depends on C("a", "b", "c_1"), C("a", "b", "c_2"), C("w", "x", "y_1"), and C("w", "x", "y_2")
func (fk *FK) dependsC(c cFQN) (bool, *cFQN) {
switch {
- case fk.From.Column.Contains(c.Column):
+ case fk.From.Column.contains(c.Column):
return true, &fk.From
- case fk.To.Column.Contains(c.Column):
+ case fk.To.Column.contains(c.Column):
return true, &fk.To
}
return false, nil
@@ -347,3 +364,14 @@ func (r RefMap) Deleted() (fks []FK) {
}
return
}
+
+// Unique represents a unique constraint defined on 1 or more columns.
+type Unique struct {
+ Name string
+ Columns composite
+}
+
+// Equals checks that two unique constraint are the same, assuming both are defined for the same table.
+func (u Unique) Equals(other Unique) bool {
+ return u.Columns == other.Columns
+}