Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for ON UPDATE and ON DELETE rules on belongs-to relationships from struct tags #533

Merged
merged 10 commits into from
Jun 8, 2022
89 changes: 89 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ func TestDB(t *testing.T) {
{testJSONValuer},
{testSelectBool},
{testFKViolation},
{testWithForeignKeysAndRules},
{testWithForeignKeys},
{testInterfaceAny},
{testInterfaceJSON},
Expand Down Expand Up @@ -869,6 +870,94 @@ func testFKViolation(t *testing.T, db *bun.DB) {
require.Equal(t, 0, n)
}

func testWithForeignKeysAndRules(t *testing.T, db *bun.DB) {
type User struct {
ID int `bun:",pk"`
Type string `bun:",pk"`
Name string
}
type Deck struct {
ID int `bun:",pk"`
UserID int
UserType string
User *User `bun:"rel:belongs-to,join:user_id=id,join:user_type=type,on_update:cascade,on_delete:set null"`
}

if db.Dialect().Name() == dialect.SQLite {
_, err := db.Exec("PRAGMA foreign_keys = ON;")
require.NoError(t, err)
}

for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} {
_, err := db.NewDropTable().Model(model).IfExists().Exec(ctx)
require.NoError(t, err)
}

_, err := db.NewCreateTable().
Model((*User)(nil)).
IfNotExists().
Exec(ctx)
require.NoError(t, err)

_, err = db.NewCreateTable().
Model((*Deck)(nil)).
IfNotExists().
WithForeignKeys().
Exec(ctx)
require.NoError(t, err)

// Empty deck should violate FK constraint.
_, err = db.NewInsert().Model(new(Deck)).Exec(ctx)
require.Error(t, err)

// Create a deck that violates the user_id FK contraint
deck := &Deck{UserID: 42}

_, err = db.NewInsert().Model(deck).Exec(ctx)
require.Error(t, err)

decks := []*Deck{deck}
_, err = db.NewInsert().Model(&decks).Exec(ctx)
require.Error(t, err)

n, err := db.NewSelect().Model((*Deck)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)

_, err = db.NewInsert().Model(&User{ID: 1, Type: "admin", Name: "root"}).Exec(ctx)
require.NoError(t, err)
res, err := db.NewInsert().Model(&Deck{UserID: 1, UserType: "admin"}).Exec(ctx)
require.NoError(t, err)

affected, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), affected)

// Update User ID and check for FK update
res, err = db.NewUpdate().Model(&User{}).Where("id = ?", 1).Where("type = ?", "admin").Set("id = ?", 2).Exec(ctx)
require.NoError(t, err)

affected, err = res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), affected)

n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 1").Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)

n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 2").Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, n)

// Delete user and check for FK delete
_, err = db.NewDelete().Model(&User{}).Where("id = ?", 2).Exec(ctx)
require.NoError(t, err)

n, err = db.NewSelect().Model(&Deck{}).Where("user_id = 2").Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)
}

func testWithForeignKeys(t *testing.T, db *bun.DB) {
type User struct {
ID int `bun:",pk,autoincrement"`
Expand Down
9 changes: 6 additions & 3 deletions query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,16 @@ func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery {
func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery {
for _, relation := range q.tableModel.Table().Relations {
if relation.Type == schema.ManyToManyRelation ||
relation.Type == schema.HasManyRelation {
relation.Type == schema.HasManyRelation {
continue
}
q = q.ForeignKey("(?) REFERENCES ? (?)",
}

q = q.ForeignKey("(?) REFERENCES ? (?) ? ?",
Safe(appendColumns(nil, "", relation.BaseFields)),
relation.JoinTable.SQLName,
Safe(appendColumns(nil, "", relation.JoinFields)),
Safe(relation.OnUpdate),
Safe(relation.OnDelete),
)
}
return q
Expand Down
2 changes: 2 additions & 0 deletions schema/relation.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type Relation struct {
JoinTable *Table
BaseFields []*Field
JoinFields []*Field
OnUpdate string
OnDelete string

PolymorphicField *Field
PolymorphicValue string
Expand Down
42 changes: 42 additions & 0 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,35 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
JoinTable: joinTable,
}

rel.OnUpdate = "ON UPDATE NO ACTION"
if onUpdate, ok := field.Tag.Options["on_update"]; ok {
if len(onUpdate) > 1 {
panic(fmt.Errorf("bun: %s belongs-to %s: on_update option must be a single field", t.TypeName, field.GoName))
}

rule := strings.ToUpper(onUpdate[0])
if !isKnownFKRule(rule) {
internal.Warn.Printf("bun: %s belongs-to %s: unknown on_update rule %s", t.TypeName, field.GoName, rule)
}

s := fmt.Sprintf("ON UPDATE %s", rule)
rel.OnUpdate = s
}

rel.OnDelete = "ON DELETE NO ACTION"
if onDelete, ok := field.Tag.Options["on_delete"]; ok {
if len(onDelete) > 1 {
panic(fmt.Errorf("bun: %s belongs-to %s: on_delete option must be a single field", t.TypeName, field.GoName))
}

rule := strings.ToUpper(onDelete[0])
if !isKnownFKRule(rule) {
internal.Warn.Printf("bun: %s belongs-to %s: unknown on_delete rule %s", t.TypeName, field.GoName, rule)
}
s := fmt.Sprintf("ON DELETE %s", rule)
rel.OnDelete = s
}

if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
Expand Down Expand Up @@ -859,13 +888,26 @@ func isKnownFieldOption(name string) bool {
"autoincrement",
"rel",
"join",
"on_update",
"on_delete",
"m2m",
"polymorphic":
return true
}
return false
}

func isKnownFKRule(name string) bool {
switch name {
case "CASCADE",
"RESTRICT",
"SET NULL",
"SET DEFAULT":
return true
}
return false
}

func removeField(fields []*Field, field *Field) []*Field {
for i, f := range fields {
if f == field {
Expand Down