diff --git a/db.go b/db.go index c283f56bd..067996d1c 100644 --- a/db.go +++ b/db.go @@ -9,6 +9,7 @@ import ( "reflect" "strings" "sync/atomic" + "time" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" @@ -26,32 +27,56 @@ type DBStats struct { type DBOption func(db *DB) +func WithOptions(opts ...DBOption) DBOption { + return func(db *DB) { + for _, opt := range opts { + opt(db) + } + } +} + func WithDiscardUnknownColumns() DBOption { return func(db *DB) { db.flags = db.flags.Set(discardUnknownColumns) } } -type DB struct { - *sql.DB +func WithConnResolver(resolver ConnResolver) DBOption { + return func(db *DB) { + db.resolver = resolver + } +} - dialect schema.Dialect +type DB struct { + // Must be a pointer so we copy the whole state, not individual fields. + *noCopyState queryHooks []QueryHook fmter schema.Formatter - flags internal.Flag - stats DBStats } +// noCopyState contains DB fields that must not be copied on clone(), +// for example, it is forbidden to copy atomic.Pointer. +type noCopyState struct { + *sql.DB + dialect schema.Dialect + resolver ConnResolver + + flags internal.Flag + closed atomic.Bool +} + func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { dialect.Init(sqldb) db := &DB{ - DB: sqldb, - dialect: dialect, - fmter: schema.NewFormatter(dialect), + noCopyState: &noCopyState{ + DB: sqldb, + dialect: dialect, + }, + fmter: schema.NewFormatter(dialect), } for _, opt := range opts { @@ -69,6 +94,22 @@ func (db *DB) String() string { return b.String() } +func (db *DB) Close() error { + if db.closed.Swap(true) { + return nil + } + + firstErr := db.DB.Close() + + if db.resolver != nil { + if err := db.resolver.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr +} + func (db *DB) DBStats() DBStats { return DBStats{ Queries: atomic.LoadUint32(&db.stats.Queries), @@ -703,3 +744,116 @@ func (tx Tx) NewDropColumn() *DropColumnQuery { func (db *DB) makeQueryBytes() []byte { return internal.MakeQueryBytes() } + +//------------------------------------------------------------------------------ + +// ConnResolver enables routing queries to multiple databases. +type ConnResolver interface { + ResolveConn(query Query) IConn + Close() error +} + +// TODO: +// - make monitoring interval configurable +// - make ping timeout configutable +// - allow adding read/write replicas for multi-master replication +type ReadWriteConnResolver struct { + replicas []*sql.DB // read-only replicas + healthyReplicas atomic.Pointer[[]*sql.DB] + nextReplica atomic.Int64 + closed atomic.Bool +} + +func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver { + r := new(ReadWriteConnResolver) + + for _, opt := range opts { + opt(r) + } + + if len(r.replicas) > 0 { + r.healthyReplicas.Store(&r.replicas) + go r.monitor() + } + + return r +} + +type ReadWriteConnResolverOption func(r *ReadWriteConnResolver) + +func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption { + return func(r *ReadWriteConnResolver) { + r.replicas = append(r.replicas, dbs...) + } +} + +func (r *ReadWriteConnResolver) Close() error { + if r.closed.Swap(true) { + return nil + } + + var firstErr error + for _, db := range r.replicas { + if err := db.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// healthyReplica returns a random healthy replica. +func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn { + if len(r.replicas) == 0 || !isReadOnlyQuery(query) { + return nil + } + + replicas := r.loadHealthyReplicas() + if len(replicas) == 0 { + return nil + } + if len(replicas) == 1 { + return replicas[0] + } + i := r.nextReplica.Add(1) + return replicas[int(i)%len(replicas)] +} + +func isReadOnlyQuery(query Query) bool { + sel, ok := query.(*SelectQuery) + if !ok { + return false + } + for _, el := range sel.with { + if !isReadOnlyQuery(el.query) { + return false + } + } + return true +} + +func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB { + if ptr := r.healthyReplicas.Load(); ptr != nil { + return *ptr + } + return nil +} + +func (r *ReadWriteConnResolver) monitor() { + const interval = 5 * time.Second + for !r.closed.Load() { + healthy := make([]*sql.DB, 0, len(r.replicas)) + + for _, replica := range r.replicas { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := replica.PingContext(ctx) + cancel() + + if err == nil { + healthy = append(healthy, replica) + } + } + + r.healthyReplicas.Store(&healthy) + time.Sleep(interval) + } +} diff --git a/driver/pgdriver/config.go b/driver/pgdriver/config.go index ccc038b21..50424dd48 100644 --- a/driver/pgdriver/config.go +++ b/driver/pgdriver/config.go @@ -50,7 +50,7 @@ func newDefaultConfig() *Config { host := env("PGHOST", "localhost") port := env("PGPORT", "5432") - cfg := &Config{ + conf := &Config{ Network: "tcp", Addr: net.JoinHostPort(host, port), DialTimeout: 5 * time.Second, @@ -63,28 +63,33 @@ func newDefaultConfig() *Config { WriteTimeout: 5 * time.Second, } - cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { + conf.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { netDialer := &net.Dialer{ - Timeout: cfg.DialTimeout, + Timeout: conf.DialTimeout, KeepAlive: 5 * time.Minute, } return netDialer.DialContext(ctx, network, addr) } - return cfg + return conf } -type Option func(cfg *Config) +type Option func(conf *Config) -// Deprecated. Use Option instead. -type DriverOption = Option +func WithOptions(opts ...Option) Option { + return func(conf *Config) { + for _, opt := range opts { + opt(conf) + } + } +} func WithNetwork(network string) Option { if network == "" { panic("network is empty") } - return func(cfg *Config) { - cfg.Network = network + return func(conf *Config) { + conf.Network = network } } @@ -92,23 +97,23 @@ func WithAddr(addr string) Option { if addr == "" { panic("addr is empty") } - return func(cfg *Config) { - cfg.Addr = addr + return func(conf *Config) { + conf.Addr = addr } } func WithTLSConfig(tlsConfig *tls.Config) Option { - return func(cfg *Config) { - cfg.TLSConfig = tlsConfig + return func(conf *Config) { + conf.TLSConfig = tlsConfig } } func WithInsecure(on bool) Option { - return func(cfg *Config) { + return func(conf *Config) { if on { - cfg.TLSConfig = nil + conf.TLSConfig = nil } else { - cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true} + conf.TLSConfig = &tls.Config{InsecureSkipVerify: true} } } } @@ -117,14 +122,14 @@ func WithUser(user string) Option { if user == "" { panic("user is empty") } - return func(cfg *Config) { - cfg.User = user + return func(conf *Config) { + conf.User = user } } func WithPassword(password string) Option { - return func(cfg *Config) { - cfg.Password = password + return func(conf *Config) { + conf.Password = password } } @@ -132,46 +137,46 @@ func WithDatabase(database string) Option { if database == "" { panic("database is empty") } - return func(cfg *Config) { - cfg.Database = database + return func(conf *Config) { + conf.Database = database } } func WithApplicationName(appName string) Option { - return func(cfg *Config) { - cfg.AppName = appName + return func(conf *Config) { + conf.AppName = appName } } func WithConnParams(params map[string]interface{}) Option { - return func(cfg *Config) { - cfg.ConnParams = params + return func(conf *Config) { + conf.ConnParams = params } } func WithTimeout(timeout time.Duration) Option { - return func(cfg *Config) { - cfg.DialTimeout = timeout - cfg.ReadTimeout = timeout - cfg.WriteTimeout = timeout + return func(conf *Config) { + conf.DialTimeout = timeout + conf.ReadTimeout = timeout + conf.WriteTimeout = timeout } } func WithDialTimeout(dialTimeout time.Duration) Option { - return func(cfg *Config) { - cfg.DialTimeout = dialTimeout + return func(conf *Config) { + conf.DialTimeout = dialTimeout } } func WithReadTimeout(readTimeout time.Duration) Option { - return func(cfg *Config) { - cfg.ReadTimeout = readTimeout + return func(conf *Config) { + conf.ReadTimeout = readTimeout } } func WithWriteTimeout(writeTimeout time.Duration) Option { - return func(cfg *Config) { - cfg.WriteTimeout = writeTimeout + return func(conf *Config) { + conf.WriteTimeout = writeTimeout } } @@ -179,19 +184,19 @@ func WithWriteTimeout(writeTimeout time.Duration) Option { // a query on a connection that has been used before. // If the func returns driver.ErrBadConn, the connection is discarded. func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option { - return func(cfg *Config) { - cfg.ResetSessionFunc = fn + return func(conf *Config) { + conf.ResetSessionFunc = fn } } func WithDSN(dsn string) Option { - return func(cfg *Config) { + return func(conf *Config) { opts, err := parseDSN(dsn) if err != nil { panic(err) } for _, opt := range opts { - opt(cfg) + opt(conf) } } } diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 0d423d4c5..225c278ab 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -1845,3 +1845,35 @@ func mustDropTableOnCleanup(tb testing.TB, ctx context.Context, db *bun.DB, mode } }) } + +func TestConnResolver(t *testing.T) { + dsn := os.Getenv("PG") + if dsn == "" { + dsn = "postgres://postgres:postgres@localhost:5432/test?sslmode=disable" + } + + rwdb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) + t.Cleanup(func() { + require.NoError(t, rwdb.Close()) + }) + + rodb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) + t.Cleanup(func() { + require.NoError(t, rodb.Close()) + }) + + resolver := bun.NewReadWriteConnResolver(bun.WithReadOnlyReplica(rodb)) + + db := bun.NewDB(rwdb, pgdialect.New(), bun.WithConnResolver(resolver)) + db.AddQueryHook(bundebug.NewQueryHook( + bundebug.WithEnabled(false), + bundebug.FromEnv(), + )) + + var num int + err := db.NewSelect().ColumnExpr("1").Scan(ctx, &num) + require.NoError(t, err) + require.Equal(t, 1, num) + require.GreaterOrEqual(t, rodb.Stats().OpenConnections, 1) + require.Equal(t, 0, rwdb.Stats().OpenConnections) +} diff --git a/internal/dbtest/docker-compose.yaml b/internal/dbtest/docker-compose.yaml index d47ab5b31..0223bf704 100755 --- a/internal/dbtest/docker-compose.yaml +++ b/internal/dbtest/docker-compose.yaml @@ -1,5 +1,3 @@ -version: '3.9' - services: mysql8: image: mysql:8.0 diff --git a/query_base.go b/query_base.go index 08ff8e5d9..b17498742 100644 --- a/query_base.go +++ b/query_base.go @@ -24,7 +24,7 @@ const ( type withQuery struct { name string - query schema.QueryAppender + query Query recursive bool } @@ -114,8 +114,16 @@ func (q *baseQuery) DB() *DB { return q.db } -func (q *baseQuery) GetConn() IConn { - return q.conn +func (q *baseQuery) resolveConn(query Query) IConn { + if q.conn != nil { + return q.conn + } + if q.db.resolver != nil { + if conn := q.db.resolver.ResolveConn(query); conn != nil { + return conn + } + } + return q.db.DB } func (q *baseQuery) GetModel() Model { @@ -128,10 +136,8 @@ func (q *baseQuery) GetTableName() string { } for _, wq := range q.with { - if v, ok := wq.query.(Query); ok { - if model := v.GetModel(); model != nil { - return v.GetTableName() - } + if model := wq.query.GetModel(); model != nil { + return wq.query.GetTableName() } } @@ -249,7 +255,7 @@ func (q *baseQuery) isSoftDelete() bool { //------------------------------------------------------------------------------ -func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) { +func (q *baseQuery) addWith(name string, query Query, recursive bool) { q.with = append(q.with, withQuery{ name: name, query: query, @@ -565,28 +571,33 @@ func (q *baseQuery) scan( hasDest bool, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) + res, err := q._scan(ctx, iquery, query, model, hasDest) + q.db.afterQuery(ctx, event, res, err) + return res, err +} - rows, err := q.conn.QueryContext(ctx, query) +func (q *baseQuery) _scan( + ctx context.Context, + iquery Query, + query string, + model Model, + hasDest bool, +) (sql.Result, error) { + rows, err := q.resolveConn(iquery).QueryContext(ctx, query) if err != nil { - q.db.afterQuery(ctx, event, nil, err) return nil, err } defer rows.Close() numRow, err := model.ScanRows(ctx, rows) if err != nil { - q.db.afterQuery(ctx, event, nil, err) return nil, err } if numRow == 0 && hasDest && isSingleRowModel(model) { - err = sql.ErrNoRows + return nil, sql.ErrNoRows } - - res := driver.RowsAffected(numRow) - q.db.afterQuery(ctx, event, res, err) - - return res, err + return driver.RowsAffected(numRow), nil } func (q *baseQuery) exec( @@ -595,7 +606,7 @@ func (q *baseQuery) exec( query string, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) - res, err := q.conn.ExecContext(ctx, query) + res, err := q.resolveConn(iquery).ExecContext(ctx, query) q.db.afterQuery(ctx, event, res, err) return res, err } diff --git a/query_column_add.go b/query_column_add.go index aa9aacf35..c3c781a1d 100644 --- a/query_column_add.go +++ b/query_column_add.go @@ -22,8 +22,7 @@ var _ Query = (*AddColumnQuery)(nil) func NewAddColumnQuery(db *DB) *AddColumnQuery { q := &AddColumnQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_column_drop.go b/query_column_drop.go index 986bffed3..e66e35b9a 100644 --- a/query_column_drop.go +++ b/query_column_drop.go @@ -20,8 +20,7 @@ var _ Query = (*DropColumnQuery)(nil) func NewDropColumnQuery(db *DB) *DropColumnQuery { q := &DropColumnQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_delete.go b/query_delete.go index 8fffe448b..99ec37bb7 100644 --- a/query_delete.go +++ b/query_delete.go @@ -25,8 +25,7 @@ func NewDeleteQuery(db *DB) *DeleteQuery { q := &DeleteQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -58,12 +57,12 @@ func (q *DeleteQuery) Apply(fns ...func(*DeleteQuery) *DeleteQuery) *DeleteQuery return q } -func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery { +func (q *DeleteQuery) With(name string, query Query) *DeleteQuery { q.addWith(name, query, false) return q } -func (q *DeleteQuery) WithRecursive(name string, query schema.QueryAppender) *DeleteQuery { +func (q *DeleteQuery) WithRecursive(name string, query Query) *DeleteQuery { q.addWith(name, query, true) return q } diff --git a/query_index_create.go b/query_index_create.go index ad1905cc3..4ac4ffd10 100644 --- a/query_index_create.go +++ b/query_index_create.go @@ -29,8 +29,7 @@ func NewCreateIndexQuery(db *DB) *CreateIndexQuery { q := &CreateIndexQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } diff --git a/query_index_drop.go b/query_index_drop.go index a2a23fb8a..27c6e7f67 100644 --- a/query_index_drop.go +++ b/query_index_drop.go @@ -24,8 +24,7 @@ var _ Query = (*DropIndexQuery)(nil) func NewDropIndexQuery(db *DB) *DropIndexQuery { q := &DropIndexQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_insert.go b/query_insert.go index 63e84545a..d2e158d77 100644 --- a/query_insert.go +++ b/query_insert.go @@ -31,8 +31,7 @@ func NewInsertQuery(db *DB) *InsertQuery { q := &InsertQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -64,12 +63,12 @@ func (q *InsertQuery) Apply(fns ...func(*InsertQuery) *InsertQuery) *InsertQuery return q } -func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery { +func (q *InsertQuery) With(name string, query Query) *InsertQuery { q.addWith(name, query, false) return q } -func (q *InsertQuery) WithRecursive(name string, query schema.QueryAppender) *InsertQuery { +func (q *InsertQuery) WithRecursive(name string, query Query) *InsertQuery { q.addWith(name, query, true) return q } diff --git a/query_merge.go b/query_merge.go index 7dee02002..0c172f180 100644 --- a/query_merge.go +++ b/query_merge.go @@ -26,8 +26,7 @@ var _ Query = (*MergeQuery)(nil) func NewMergeQuery(db *DB) *MergeQuery { q := &MergeQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } if q.db.dialect.Name() != dialect.MSSQL && q.db.dialect.Name() != dialect.PG { @@ -61,12 +60,12 @@ func (q *MergeQuery) Apply(fns ...func(*MergeQuery) *MergeQuery) *MergeQuery { return q } -func (q *MergeQuery) With(name string, query schema.QueryAppender) *MergeQuery { +func (q *MergeQuery) With(name string, query Query) *MergeQuery { q.addWith(name, query, false) return q } -func (q *MergeQuery) WithRecursive(name string, query schema.QueryAppender) *MergeQuery { +func (q *MergeQuery) WithRecursive(name string, query Query) *MergeQuery { q.addWith(name, query, true) return q } diff --git a/query_raw.go b/query_raw.go index 8c3a6a7f8..308329567 100644 --- a/query_raw.go +++ b/query_raw.go @@ -15,23 +15,10 @@ type RawQuery struct { comment string } -// Deprecated: Use NewRaw instead. When add it to IDB, it conflicts with the sql.Conn#Raw -func (db *DB) Raw(query string, args ...interface{}) *RawQuery { - return &RawQuery{ - baseQuery: baseQuery{ - db: db, - conn: db.DB, - }, - query: query, - args: args, - } -} - func NewRawQuery(db *DB, query string, args ...interface{}) *RawQuery { return &RawQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, query: query, args: args, diff --git a/query_select.go b/query_select.go index b079537c3..1ef7e3bb1 100644 --- a/query_select.go +++ b/query_select.go @@ -41,8 +41,7 @@ func NewSelectQuery(db *DB) *SelectQuery { return &SelectQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -73,12 +72,12 @@ func (q *SelectQuery) Apply(fns ...func(*SelectQuery) *SelectQuery) *SelectQuery return q } -func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) With(name string, query Query) *SelectQuery { q.addWith(name, query, false) return q } -func (q *SelectQuery) WithRecursive(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) WithRecursive(name string, query Query) *SelectQuery { q.addWith(name, query, true) return q } @@ -800,7 +799,7 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { query := internal.String(queryBytes) ctx, event := q.db.beforeQuery(ctx, q, query, nil, query, q.model) - rows, err := q.conn.QueryContext(ctx, query) + rows, err := q.resolveConn(q).QueryContext(ctx, query) q.db.afterQuery(ctx, event, nil, err) return rows, err } @@ -936,7 +935,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var num int - err = q.conn.QueryRowContext(ctx, query).Scan(&num) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&num) q.db.afterQuery(ctx, event, nil, err) @@ -954,13 +953,15 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in return int(n), nil } } - if _, ok := q.conn.(*DB); ok { - return q.scanAndCountConc(ctx, dest...) + if q.conn == nil { + return q.scanAndCountConcurrently(ctx, dest...) } return q.scanAndCountSeq(ctx, dest...) } -func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) { +func (q *SelectQuery) scanAndCountConcurrently( + ctx context.Context, dest ...interface{}, +) (int, error) { var count int var wg sync.WaitGroup var mu sync.Mutex @@ -1038,7 +1039,7 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var exists bool - err = q.conn.QueryRowContext(ctx, query).Scan(&exists) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&exists) q.db.afterQuery(ctx, event, nil, err) diff --git a/query_table_create.go b/query_table_create.go index 2c7855e7a..d8c4566cb 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -40,8 +40,7 @@ var _ Query = (*CreateTableQuery)(nil) func NewCreateTableQuery(db *DB) *CreateTableQuery { q := &CreateTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, varchar: db.Dialect().DefaultVarcharLen(), } diff --git a/query_table_drop.go b/query_table_drop.go index 01f000293..4e7d305a9 100644 --- a/query_table_drop.go +++ b/query_table_drop.go @@ -21,8 +21,7 @@ var _ Query = (*DropTableQuery)(nil) func NewDropTableQuery(db *DB) *DropTableQuery { q := &DropTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_table_truncate.go b/query_table_truncate.go index 7ee5d2a8d..0f30a1d04 100644 --- a/query_table_truncate.go +++ b/query_table_truncate.go @@ -22,8 +22,7 @@ var _ Query = (*TruncateTableQuery)(nil) func NewTruncateTableQuery(db *DB) *TruncateTableQuery { q := &TruncateTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_update.go b/query_update.go index c16b751ae..b700f2180 100644 --- a/query_update.go +++ b/query_update.go @@ -32,8 +32,7 @@ func NewUpdateQuery(db *DB) *UpdateQuery { q := &UpdateQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -65,12 +64,12 @@ func (q *UpdateQuery) Apply(fns ...func(*UpdateQuery) *UpdateQuery) *UpdateQuery return q } -func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery { +func (q *UpdateQuery) With(name string, query Query) *UpdateQuery { q.addWith(name, query, false) return q } -func (q *UpdateQuery) WithRecursive(name string, query schema.QueryAppender) *UpdateQuery { +func (q *UpdateQuery) WithRecursive(name string, query Query) *UpdateQuery { q.addWith(name, query, true) return q } diff --git a/query_values.go b/query_values.go index 97fbc65fa..db6c852c3 100644 --- a/query_values.go +++ b/query_values.go @@ -25,8 +25,7 @@ var ( func NewValuesQuery(db *DB, model interface{}) *ValuesQuery { q := &ValuesQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } q.setModel(model)