diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 404e8cdf2..4e620f119 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -254,6 +254,7 @@ func TestDB(t *testing.T) { {testJSONValuer}, {testSelectBool}, {testFKViolation}, + {testWithForeignKeysAndRules}, {testWithForeignKeys}, {testInterfaceAny}, {testInterfaceJSON}, @@ -869,6 +870,94 @@ func testFKViolation(t *testing.T, db *bun.DB) { require.Equal(t, 0, n) } +func testWithForeignKeysAndRules(t *testing.T, db *bun.DB) { + type User struct { + ID int `bun:",pk"` + Type string `bun:",pk"` + Name string + } + type Deck struct { + ID int `bun:",pk"` + UserID int + UserType string + User *User `bun:"rel:belongs-to,join:user_id=id,join:user_type=type,on_update:cascade,on_delete:set null"` + } + + if db.Dialect().Name() == dialect.SQLite { + _, err := db.Exec("PRAGMA foreign_keys = ON;") + require.NoError(t, err) + } + + for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} { + _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx) + require.NoError(t, err) + } + + _, err := db.NewCreateTable(). + Model((*User)(nil)). + IfNotExists(). + Exec(ctx) + require.NoError(t, err) + + _, err = db.NewCreateTable(). + Model((*Deck)(nil)). + IfNotExists(). + WithForeignKeys(). + Exec(ctx) + require.NoError(t, err) + + // Empty deck should violate FK constraint. + _, err = db.NewInsert().Model(new(Deck)).Exec(ctx) + require.Error(t, err) + + // Create a deck that violates the user_id FK contraint + deck := &Deck{UserID: 42} + + _, err = db.NewInsert().Model(deck).Exec(ctx) + require.Error(t, err) + + decks := []*Deck{deck} + _, err = db.NewInsert().Model(&decks).Exec(ctx) + require.Error(t, err) + + n, err := db.NewSelect().Model((*Deck)(nil)).Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, n) + + _, err = db.NewInsert().Model(&User{ID: 1, Type: "admin", Name: "root"}).Exec(ctx) + require.NoError(t, err) + res, err := db.NewInsert().Model(&Deck{UserID: 1, UserType: "admin"}).Exec(ctx) + require.NoError(t, err) + + affected, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), affected) + + // Update User ID and check for FK update + res, err = db.NewUpdate().Model(&User{}).Where("id = ?", 1).Where("type = ?", "admin").Set("id = ?", 2).Exec(ctx) + require.NoError(t, err) + + affected, err = res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), affected) + + n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 1").Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, n) + + n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 2").Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, n) + + // Delete user and check for FK delete + _, err = db.NewDelete().Model(&User{}).Where("id = ?", 2).Exec(ctx) + require.NoError(t, err) + + n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 2").Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, n) +} + func testWithForeignKeys(t *testing.T, db *bun.DB) { type User struct { ID int `bun:",pk,autoincrement"` diff --git a/query_table_create.go b/query_table_create.go index 4aad10070..daa1ccca9 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -105,13 +105,16 @@ func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery { func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery { for _, relation := range q.tableModel.Table().Relations { if relation.Type == schema.ManyToManyRelation || - relation.Type == schema.HasManyRelation { + relation.Type == schema.HasManyRelation { continue - } - q = q.ForeignKey("(?) REFERENCES ? (?)", + } + + q = q.ForeignKey("(?) REFERENCES ? (?) ? ?", Safe(appendColumns(nil, "", relation.BaseFields)), relation.JoinTable.SQLName, Safe(appendColumns(nil, "", relation.JoinFields)), + Safe(relation.OnUpdate), + Safe(relation.OnDelete), ) } return q diff --git a/schema/relation.go b/schema/relation.go index 8d1baeb3f..06ef8c05c 100644 --- a/schema/relation.go +++ b/schema/relation.go @@ -18,6 +18,8 @@ type Relation struct { JoinTable *Table BaseFields []*Field JoinFields []*Field + OnUpdate string + OnDelete string PolymorphicField *Field PolymorphicValue string diff --git a/schema/table.go b/schema/table.go index e39cefb30..3abb354be 100644 --- a/schema/table.go +++ b/schema/table.go @@ -479,6 +479,35 @@ func (t *Table) belongsToRelation(field *Field) *Relation { JoinTable: joinTable, } + rel.OnUpdate = "ON UPDATE NO ACTION" + if onUpdate, ok := field.Tag.Options["on_update"]; ok { + if len(onUpdate) > 1 { + panic(fmt.Errorf("bun: %s belongs-to %s: on_update option must be a single field", t.TypeName, field.GoName)) + } + + rule := strings.ToUpper(onUpdate[0]) + if !isKnownFKRule(rule) { + internal.Warn.Printf("bun: %s belongs-to %s: unknown on_update rule %s", t.TypeName, field.GoName, rule) + } + + s := fmt.Sprintf("ON UPDATE %s", rule) + rel.OnUpdate = s + } + + rel.OnDelete = "ON DELETE NO ACTION" + if onDelete, ok := field.Tag.Options["on_delete"]; ok { + if len(onDelete) > 1 { + panic(fmt.Errorf("bun: %s belongs-to %s: on_delete option must be a single field", t.TypeName, field.GoName)) + } + + rule := strings.ToUpper(onDelete[0]) + if !isKnownFKRule(rule) { + internal.Warn.Printf("bun: %s belongs-to %s: unknown on_delete rule %s", t.TypeName, field.GoName, rule) + } + s := fmt.Sprintf("ON DELETE %s", rule) + rel.OnDelete = s + } + if join, ok := field.Tag.Options["join"]; ok { baseColumns, joinColumns := parseRelationJoin(join) for i, baseColumn := range baseColumns { @@ -859,6 +888,8 @@ func isKnownFieldOption(name string) bool { "autoincrement", "rel", "join", + "on_update", + "on_delete", "m2m", "polymorphic": return true @@ -866,6 +897,17 @@ func isKnownFieldOption(name string) bool { return false } +func isKnownFKRule(name string) bool { + switch name { + case "CASCADE", + "RESTRICT", + "SET NULL", + "SET DEFAULT": + return true + } + return false +} + func removeField(fields []*Field, field *Field) []*Field { for i, f := range fields { if f == field {