Skip to content

Commit

Permalink
gormschema: fixed circular foreign keys on postgresql example
Browse files Browse the repository at this point in the history
  • Loading branch information
dorav committed Dec 13, 2023
1 parent 577fd9c commit a7a8234
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 4 deletions.
74 changes: 71 additions & 3 deletions gormschema/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/migrator"
)

// New returns a new Loader.
Expand Down Expand Up @@ -73,12 +74,79 @@ func (l *Loader) Load(models ...any) (string, error) {
if err != nil {
return "", err
}
if err := db.AutoMigrate(models...); err != nil {
return "", err
db.Config.DisableForeignKeyConstraintWhenMigrating = true
err = db.AutoMigrate(models...)

Check failure on line 78 in gormschema/gorm.go

View workflow job for this annotation

GitHub Actions / golangci-lint

ineffectual assignment to err (ineffassign)
if !l.config.DisableForeignKeyConstraintWhenMigrating {
db, err = gorm.Open(customDialector{
Dialector: di,
}, l.config)
if err != nil {
return "", err
}
cm, ok := db.Migrator().(*customMigrator)
if !ok {
return "", err
}
if err = cm.CreateConstraints(models); err != nil {
return "", err
}
}
s, ok := recordriver.Session("gorm")
if !ok {
return "", err
return "", fmt.Errorf("gorm db session not found")
}
return s.Stmts(), nil
}

type customMigrator struct {
migrator.Migrator
dialectMigrator gorm.Migrator
}

type customDialector struct {
gorm.Dialector
}

func (d customDialector) newCustomMigrator(db *gorm.DB) *customMigrator {
return &customMigrator{
Migrator: migrator.Migrator{
Config: migrator.Config{
DB: db,
Dialector: d,
CreateIndexAfterCreateTable: true,
},
},
dialectMigrator: d.Dialector.Migrator(db),
}
}

func (d customDialector) Migrator(db *gorm.DB) gorm.Migrator {
return d.newCustomMigrator(db)
}

func (m *customMigrator) HasTable(dst interface{}) bool {
return true
}

func (m *customMigrator) CreateConstraints(models []interface{}) error {
for _, model := range m.ReorderModels(models, true) {
err := m.Migrator.RunWithValue(model, func(stmt *gorm.Statement) error {
for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil &&
constraint.Schema == stmt.Schema {
if err := m.dialectMigrator.CreateConstraint(model, constraint.Name); err != nil {
return err
}
}
}
return nil
})
if err != nil {
return err
}
}
return nil
}
29 changes: 28 additions & 1 deletion gormschema/gorm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package gormschema
import (
"testing"

"ariga.io/atlas-go-sdk/recordriver"
"ariga.io/atlas-provider-gorm/internal/testdata/models"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)

func TestConfig(t *testing.T) {
func TestSQLiteConfig(t *testing.T) {
l := New("sqlite", WithConfig(
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Expand All @@ -20,3 +21,29 @@ func TestConfig(t *testing.T) {
require.Contains(t, sql, "CREATE TABLE `users`")
require.NotContains(t, sql, "FOREIGN KEY")
}

func TestPostgreSQLConfig(t *testing.T) {
l := New("postgres", WithConfig(&gorm.Config{}))
sql, err := l.Load(models.Location{}, models.Event{})
require.NoError(t, err)
require.Contains(t, sql, `CREATE TABLE "events"`)
require.Contains(t, sql, `CREATE UNIQUE INDEX IF NOT EXISTS "idx_events_location_id"`)
require.Contains(t, sql, `CREATE TABLE "locations"`)
require.Contains(t, sql, `CREATE UNIQUE INDEX IF NOT EXISTS "idx_locations_event_id"`)
require.Contains(t, sql, `ALTER TABLE "events" ADD CONSTRAINT "fk_locations_event" FOREIGN KEY ("locationId")`)
require.Contains(t, sql, `ALTER TABLE "locations" ADD CONSTRAINT "fk_events_location"`)
sess, ok := recordriver.Session("gorm")
require.True(t, ok)
sess.Statements = []string{}
l = New("postgres", WithConfig(
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
},
))
sql, err = l.Load(models.Location{}, models.Event{})
require.NoError(t, err)
require.Contains(t, sql, `CREATE TABLE "events"`)
require.Contains(t, sql, `CREATE TABLE "locations"`)
require.Contains(t, sql, `CREATE UNIQUE INDEX IF NOT EXISTS "idx_locations_event_id"`)
require.NotContains(t, sql, "FOREIGN KEY")
}
12 changes: 12 additions & 0 deletions internal/testdata/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,15 @@ type Toy struct {
ID uint
Name string
}

type Location struct {
LocationID string `gorm:"primaryKey;column:locationId;"`
EventID string `gorm:"uniqueIndex;column:eventId;"`
Event *Event `gorm:"foreignKey:locationId;references:locationId;OnUpdate:CASCADE,OnDelete:CASCADE"`
}

type Event struct {
EventID string `gorm:"primaryKey;column:eventId;"`
LocationID string `gorm:"column:locationId;"`
Location *Location `gorm:"foreignKey:eventId;references:eventId;OnUpdate:CASCADE,OnDelete:CASCADE"`
}

0 comments on commit a7a8234

Please sign in to comment.