diff --git a/internal/dbtest/orm_test.go b/internal/dbtest/orm_test.go index 5d3be7a5d..117a2cb47 100644 --- a/internal/dbtest/orm_test.go +++ b/internal/dbtest/orm_test.go @@ -15,6 +15,7 @@ import ( "github.com/uptrace/bun/dbfixture" "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/schema" ) func TestORM(t *testing.T) { @@ -34,6 +35,8 @@ func TestORM(t *testing.T) { {testRelationBelongsToSelf}, {testCompositeHasMany}, {testCompositeM2M}, + {testHasOneRelationWithOpts}, + {testHasManyRelationWithOpts}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -519,6 +522,151 @@ func testCompositeM2M(t *testing.T, db *bun.DB) { require.Equal(t, 1, len(ordersOut2[0].Items)) } +func testHasOneRelationWithOpts(t *testing.T, db *bun.DB) { + type Profile struct { + ID int64 `bun:",pk"` + Lang string + UserID int64 + } + + type User struct { + bun.BaseModel `bun:"alias:u"` + ID int64 `bun:",pk"` + Name string + Profile *Profile `bun:"rel:has-one,join:id=user_id"` + } + + mustResetModel(t, ctx, db, (*User)(nil), (*Profile)(nil)) + + users := []*User{ + {ID: 1, Name: "user 1"}, + {ID: 2, Name: "user 2"}, + {ID: 3, Name: "user 3"}, + } + _, err := db.NewInsert().Model(&users).Exec(ctx) + require.NoError(t, err) + + profiles := []*Profile{ + {ID: 1, Lang: "en", UserID: 1}, + {ID: 2, Lang: "ru", UserID: 2}, + {ID: 3, Lang: "md", UserID: 3}, + } + _, err = db.NewInsert().Model(&profiles).Exec(ctx) + require.NoError(t, err) + + var outUsers1 []*User + err = db. + NewSelect(). + Model(&outUsers1). + RelationWithOpts("Profile", bun.RelationOpts{ + AdditionalJoinOnConditions: []schema.QueryWithArgs{ + { + Query: "profile.lang = ?", + Args: []any{"ru"}, + }, + }, + }). + Where("u.id IN (?)", bun.In([]int64{1, 2})). + Scan(ctx) + require.NoError(t, err) + require.Len(t, outUsers1, 2) + require.ElementsMatch(t, []*User{ + {ID: 1, Name: "user 1", Profile: nil}, + {ID: 2, Name: "user 2", Profile: &Profile{ID: 2, Lang: "ru", UserID: 2}}, + }, outUsers1) + + var outUsers2 []*User + err = db. + NewSelect(). + Model(&outUsers2). + RelationWithOpts("Profile", bun.RelationOpts{ + Apply: func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Where("profile.lang = ?", "ru") + }, + }). + Where("u.id IN (?)", bun.In([]int64{1, 2})). + Scan(ctx) + require.NoError(t, err) + require.Len(t, outUsers2, 1) + require.ElementsMatch(t, []*User{ + {ID: 2, Name: "user 2", Profile: &Profile{ID: 2, Lang: "ru", UserID: 2}}, + }, outUsers2) +} + +func testHasManyRelationWithOpts(t *testing.T, db *bun.DB) { + type Profile struct { + ID int64 `bun:",pk"` + Name string + Lang string + Active bool + UserID int64 + } + + type User struct { + bun.BaseModel `bun:"alias:u"` + ID int64 `bun:",pk"` + Name string + Profiles []*Profile `bun:"rel:has-many,join:id=user_id"` + } + + mustResetModel(t, ctx, db, (*User)(nil), (*Profile)(nil)) + + users := []*User{ + {ID: 1, Name: "user 1"}, + {ID: 2, Name: "user 2"}, + {ID: 3, Name: "user 3"}, + } + _, err := db.NewInsert().Model(&users).Exec(ctx) + require.NoError(t, err) + + profiles := []*Profile{ + {ID: 1, Name: "name1-en", Lang: "en", UserID: 1}, + {ID: 2, Name: "name2-ru", Lang: "ru", UserID: 2}, + {ID: 3, Name: "name2-ja", Lang: "ja", UserID: 2}, + {ID: 4, Name: "name3-md", Lang: "md", UserID: 3}, + {ID: 5, Name: "name3-en", Lang: "en", UserID: 3}, + } + _, err = db.NewInsert().Model(&profiles).Exec(ctx) + require.NoError(t, err) + + var outUsers1 []*User + err = db. + NewSelect(). + Model(&outUsers1). + RelationWithOpts("Profiles", bun.RelationOpts{ + AdditionalJoinOnConditions: []schema.QueryWithArgs{ + { + Query: "profile.lang = ?", + Args: []any{"ru"}, + }, + }, + }). + Where("u.id IN (?)", bun.In([]int64{1, 2})). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, []*User{ + {ID: 1, Name: "user 1", Profiles: nil}, + {ID: 2, Name: "user 2", Profiles: []*Profile{{ID: 2, Name: "name2-ru", Lang: "ru", UserID: 2}}}, + }, outUsers1) + + var outUsers2 []*User + err = db. + NewSelect(). + Model(&outUsers2). + RelationWithOpts("Profiles", bun.RelationOpts{ + Apply: func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Where("profile.lang = ?", "ru") + }, + }). + Where("u.id IN (?)", bun.In([]int64{1, 2})). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, []*User{ + {ID: 1, Name: "user 1", Profiles: nil}, + {ID: 2, Name: "user 2", Profiles: []*Profile{{ID: 2, Name: "name2-ru", Lang: "ru", UserID: 2}}}, + }, outUsers2) +} + type Genre struct { ID int `bun:",pk"` Name string diff --git a/query_select.go b/query_select.go index 2b0872ae0..a0f680ea6 100644 --- a/query_select.go +++ b/query_select.go @@ -381,6 +381,43 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ return q } + q.applyToRelation(join, apply...) + + return q +} + +type RelationOpts struct { + // Apply applies additional options to the relation. + Apply func(*SelectQuery) *SelectQuery + // AdditionalJoinOnConditions adds additional conditions to the JOIN ON clause. + AdditionalJoinOnConditions []schema.QueryWithArgs +} + +// RelationWithOpts adds a relation to the query with additional options. +func (q *SelectQuery) RelationWithOpts(name string, opts RelationOpts) *SelectQuery { + if q.tableModel == nil { + q.setErr(errNilModel) + return q + } + + join := q.tableModel.join(name) + if join == nil { + q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name)) + return q + } + + if opts.Apply != nil { + q.applyToRelation(join, opts.Apply) + } + + if len(opts.AdditionalJoinOnConditions) > 0 { + join.additionalJoinOnConditions = opts.AdditionalJoinOnConditions + } + + return q +} + +func (q *SelectQuery) applyToRelation(join *relationJoin, apply ...func(*SelectQuery) *SelectQuery) { var apply1, apply2 func(*SelectQuery) *SelectQuery if len(join.Relation.Condition) > 0 { @@ -407,8 +444,6 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ return q } - - return q } func (q *SelectQuery) forEachInlineRelJoin(fn func(*relationJoin) error) error { diff --git a/relation_join.go b/relation_join.go index 19dede4f9..47f27afd5 100644 --- a/relation_join.go +++ b/relation_join.go @@ -16,6 +16,8 @@ type relationJoin struct { JoinModel TableModel Relation *schema.Relation + additionalJoinOnConditions []schema.QueryWithArgs + apply func(*SelectQuery) *SelectQuery columns []schema.QueryWithArgs } @@ -86,6 +88,11 @@ func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *Selec j.Relation.BasePKs, ) where = append(where, ")"...) + if len(j.additionalJoinOnConditions) > 0 { + where = append(where, " AND "...) + where = appendAdditionalJoinOnConditions(q.db.Formatter(), where, j.additionalJoinOnConditions) + } + q = q.Where(internal.String(where)) if j.Relation.PolymorphicField != nil { @@ -111,6 +118,10 @@ func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery q = q.Where(internal.String(where)) + if len(j.additionalJoinOnConditions) > 0 { + q = q.Where(internal.String(appendAdditionalJoinOnConditions(q.db.Formatter(), []byte{}, j.additionalJoinOnConditions))) + } + if j.Relation.PolymorphicField != nil { q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) } @@ -204,6 +215,12 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { join = append(join, ") IN ("...) join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, j.Relation.BasePKs) join = append(join, ")"...) + + if len(j.additionalJoinOnConditions) > 0 { + join = append(join, " AND "...) + join = appendAdditionalJoinOnConditions(fmter, join, j.additionalJoinOnConditions) + } + q = q.Join(internal.String(join)) joinTable := j.JoinModel.Table() @@ -330,6 +347,11 @@ func (j *relationJoin) appendHasOneJoin( b = j.appendSoftDelete(fmter, b, q.flags) } + if len(j.additionalJoinOnConditions) > 0 { + b = append(b, " AND "...) + b = appendAdditionalJoinOnConditions(fmter, b, j.additionalJoinOnConditions) + } + return b, nil } @@ -417,3 +439,15 @@ func appendMultiValues( b = append(b, ')') return b } + +func appendAdditionalJoinOnConditions( + fmter schema.Formatter, b []byte, conditions []schema.QueryWithArgs, +) []byte { + for i, cond := range conditions { + if i > 0 { + b = append(b, " AND "...) + } + b = fmter.AppendQuery(b, cond.Query, cond.Args...) + } + return b +}