Skip to content

Commit

Permalink
Merge pull request #1097 from takaaa220/add_conditions_to_relation
Browse files Browse the repository at this point in the history
Add RelationWithOpts to SelectQuery for Customizing JOIN ON Clause
  • Loading branch information
j2gg0s authored Jan 3, 2025
2 parents 254c441 + dd3ef52 commit cbb687d
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 2 deletions.
148 changes: 148 additions & 0 deletions internal/dbtest/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/uptrace/bun/dbfixture"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/schema"
)

func TestORM(t *testing.T) {
Expand All @@ -34,6 +35,8 @@ func TestORM(t *testing.T) {
{testRelationBelongsToSelf},
{testCompositeHasMany},
{testCompositeM2M},
{testHasOneRelationWithOpts},
{testHasManyRelationWithOpts},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -519,6 +522,151 @@ func testCompositeM2M(t *testing.T, db *bun.DB) {
require.Equal(t, 1, len(ordersOut2[0].Items))
}

func testHasOneRelationWithOpts(t *testing.T, db *bun.DB) {
type Profile struct {
ID int64 `bun:",pk"`
Lang string
UserID int64
}

type User struct {
bun.BaseModel `bun:"alias:u"`
ID int64 `bun:",pk"`
Name string
Profile *Profile `bun:"rel:has-one,join:id=user_id"`
}

mustResetModel(t, ctx, db, (*User)(nil), (*Profile)(nil))

users := []*User{
{ID: 1, Name: "user 1"},
{ID: 2, Name: "user 2"},
{ID: 3, Name: "user 3"},
}
_, err := db.NewInsert().Model(&users).Exec(ctx)
require.NoError(t, err)

profiles := []*Profile{
{ID: 1, Lang: "en", UserID: 1},
{ID: 2, Lang: "ru", UserID: 2},
{ID: 3, Lang: "md", UserID: 3},
}
_, err = db.NewInsert().Model(&profiles).Exec(ctx)
require.NoError(t, err)

var outUsers1 []*User
err = db.
NewSelect().
Model(&outUsers1).
RelationWithOpts("Profile", bun.RelationOpts{
AdditionalJoinOnConditions: []schema.QueryWithArgs{
{
Query: "profile.lang = ?",
Args: []any{"ru"},
},
},
}).
Where("u.id IN (?)", bun.In([]int64{1, 2})).
Scan(ctx)
require.NoError(t, err)
require.Len(t, outUsers1, 2)
require.ElementsMatch(t, []*User{
{ID: 1, Name: "user 1", Profile: nil},
{ID: 2, Name: "user 2", Profile: &Profile{ID: 2, Lang: "ru", UserID: 2}},
}, outUsers1)

var outUsers2 []*User
err = db.
NewSelect().
Model(&outUsers2).
RelationWithOpts("Profile", bun.RelationOpts{
Apply: func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Where("profile.lang = ?", "ru")
},
}).
Where("u.id IN (?)", bun.In([]int64{1, 2})).
Scan(ctx)
require.NoError(t, err)
require.Len(t, outUsers2, 1)
require.ElementsMatch(t, []*User{
{ID: 2, Name: "user 2", Profile: &Profile{ID: 2, Lang: "ru", UserID: 2}},
}, outUsers2)
}

func testHasManyRelationWithOpts(t *testing.T, db *bun.DB) {
type Profile struct {
ID int64 `bun:",pk"`
Name string
Lang string
Active bool
UserID int64
}

type User struct {
bun.BaseModel `bun:"alias:u"`
ID int64 `bun:",pk"`
Name string
Profiles []*Profile `bun:"rel:has-many,join:id=user_id"`
}

mustResetModel(t, ctx, db, (*User)(nil), (*Profile)(nil))

users := []*User{
{ID: 1, Name: "user 1"},
{ID: 2, Name: "user 2"},
{ID: 3, Name: "user 3"},
}
_, err := db.NewInsert().Model(&users).Exec(ctx)
require.NoError(t, err)

profiles := []*Profile{
{ID: 1, Name: "name1-en", Lang: "en", UserID: 1},
{ID: 2, Name: "name2-ru", Lang: "ru", UserID: 2},
{ID: 3, Name: "name2-ja", Lang: "ja", UserID: 2},
{ID: 4, Name: "name3-md", Lang: "md", UserID: 3},
{ID: 5, Name: "name3-en", Lang: "en", UserID: 3},
}
_, err = db.NewInsert().Model(&profiles).Exec(ctx)
require.NoError(t, err)

var outUsers1 []*User
err = db.
NewSelect().
Model(&outUsers1).
RelationWithOpts("Profiles", bun.RelationOpts{
AdditionalJoinOnConditions: []schema.QueryWithArgs{
{
Query: "profile.lang = ?",
Args: []any{"ru"},
},
},
}).
Where("u.id IN (?)", bun.In([]int64{1, 2})).
Scan(ctx)
require.NoError(t, err)
require.Equal(t, []*User{
{ID: 1, Name: "user 1", Profiles: nil},
{ID: 2, Name: "user 2", Profiles: []*Profile{{ID: 2, Name: "name2-ru", Lang: "ru", UserID: 2}}},
}, outUsers1)

var outUsers2 []*User
err = db.
NewSelect().
Model(&outUsers2).
RelationWithOpts("Profiles", bun.RelationOpts{
Apply: func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Where("profile.lang = ?", "ru")
},
}).
Where("u.id IN (?)", bun.In([]int64{1, 2})).
Scan(ctx)
require.NoError(t, err)
require.Equal(t, []*User{
{ID: 1, Name: "user 1", Profiles: nil},
{ID: 2, Name: "user 2", Profiles: []*Profile{{ID: 2, Name: "name2-ru", Lang: "ru", UserID: 2}}},
}, outUsers2)
}

type Genre struct {
ID int `bun:",pk"`
Name string
Expand Down
39 changes: 37 additions & 2 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,43 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ
return q
}

q.applyToRelation(join, apply...)

return q
}

type RelationOpts struct {
// Apply applies additional options to the relation.
Apply func(*SelectQuery) *SelectQuery
// AdditionalJoinOnConditions adds additional conditions to the JOIN ON clause.
AdditionalJoinOnConditions []schema.QueryWithArgs
}

// RelationWithOpts adds a relation to the query with additional options.
func (q *SelectQuery) RelationWithOpts(name string, opts RelationOpts) *SelectQuery {
if q.tableModel == nil {
q.setErr(errNilModel)
return q
}

join := q.tableModel.join(name)
if join == nil {
q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name))
return q
}

if opts.Apply != nil {
q.applyToRelation(join, opts.Apply)
}

if len(opts.AdditionalJoinOnConditions) > 0 {
join.additionalJoinOnConditions = opts.AdditionalJoinOnConditions
}

return q
}

func (q *SelectQuery) applyToRelation(join *relationJoin, apply ...func(*SelectQuery) *SelectQuery) {
var apply1, apply2 func(*SelectQuery) *SelectQuery

if len(join.Relation.Condition) > 0 {
Expand All @@ -407,8 +444,6 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ

return q
}

return q
}

func (q *SelectQuery) forEachInlineRelJoin(fn func(*relationJoin) error) error {
Expand Down
34 changes: 34 additions & 0 deletions relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type relationJoin struct {
JoinModel TableModel
Relation *schema.Relation

additionalJoinOnConditions []schema.QueryWithArgs

apply func(*SelectQuery) *SelectQuery
columns []schema.QueryWithArgs
}
Expand Down Expand Up @@ -86,6 +88,11 @@ func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *Selec
j.Relation.BasePKs,
)
where = append(where, ")"...)
if len(j.additionalJoinOnConditions) > 0 {
where = append(where, " AND "...)
where = appendAdditionalJoinOnConditions(q.db.Formatter(), where, j.additionalJoinOnConditions)
}

q = q.Where(internal.String(where))

if j.Relation.PolymorphicField != nil {
Expand All @@ -111,6 +118,10 @@ func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery

q = q.Where(internal.String(where))

if len(j.additionalJoinOnConditions) > 0 {
q = q.Where(internal.String(appendAdditionalJoinOnConditions(q.db.Formatter(), []byte{}, j.additionalJoinOnConditions)))
}

if j.Relation.PolymorphicField != nil {
q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
}
Expand Down Expand Up @@ -204,6 +215,12 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
join = append(join, ") IN ("...)
join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, j.Relation.BasePKs)
join = append(join, ")"...)

if len(j.additionalJoinOnConditions) > 0 {
join = append(join, " AND "...)
join = appendAdditionalJoinOnConditions(fmter, join, j.additionalJoinOnConditions)
}

q = q.Join(internal.String(join))

joinTable := j.JoinModel.Table()
Expand Down Expand Up @@ -330,6 +347,11 @@ func (j *relationJoin) appendHasOneJoin(
b = j.appendSoftDelete(fmter, b, q.flags)
}

if len(j.additionalJoinOnConditions) > 0 {
b = append(b, " AND "...)
b = appendAdditionalJoinOnConditions(fmter, b, j.additionalJoinOnConditions)
}

return b, nil
}

Expand Down Expand Up @@ -417,3 +439,15 @@ func appendMultiValues(
b = append(b, ')')
return b
}

func appendAdditionalJoinOnConditions(
fmter schema.Formatter, b []byte, conditions []schema.QueryWithArgs,
) []byte {
for i, cond := range conditions {
if i > 0 {
b = append(b, " AND "...)
}
b = fmter.AppendQuery(b, cond.Query, cond.Args...)
}
return b
}

0 comments on commit cbb687d

Please sign in to comment.