From cbbe1e94fd0c72d1870395a663c8053d7e8c6ace Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 4 Dec 2024 15:16:31 +0200 Subject: [PATCH 01/10] feat: allow to specify read-only replica for SELECTs --- db.go | 64 +++++++++++++++++++++++++++-- internal/dbtest/docker-compose.yaml | 2 - query_base.go | 52 ++++++++++++++++------- query_column_add.go | 3 +- query_column_drop.go | 3 +- query_delete.go | 7 ++-- query_index_create.go | 3 +- query_index_drop.go | 3 +- query_insert.go | 7 ++-- query_merge.go | 7 ++-- query_raw.go | 15 +------ query_select.go | 7 ++-- query_table_create.go | 3 +- query_table_drop.go | 3 +- query_table_truncate.go | 3 +- query_update.go | 7 ++-- query_values.go | 3 +- 17 files changed, 123 insertions(+), 69 deletions(-) diff --git a/db.go b/db.go index c283f56bd..a19f4cc3e 100644 --- a/db.go +++ b/db.go @@ -9,6 +9,7 @@ import ( "reflect" "strings" "sync/atomic" + "time" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" @@ -32,15 +33,25 @@ func WithDiscardUnknownColumns() DBOption { } } +func WithReadOnlyReplica(replica *sql.DB) DBOption { + return func(db *DB) { + db.replicas = append(db.replicas, replica) + } +} + type DB struct { *sql.DB - dialect schema.Dialect + replicas []*sql.DB + healthyReplicas atomic.Pointer[[]*sql.DB] + nextReplica atomic.Int64 + dialect schema.Dialect queryHooks []QueryHook - fmter schema.Formatter - flags internal.Flag + fmter schema.Formatter + flags internal.Flag + closed atomic.Bool stats DBStats } @@ -58,6 +69,10 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { opt(db) } + if len(db.replicas) > 0 { + go db.monitorReplicas() + } + return db } @@ -69,6 +84,11 @@ func (db *DB) String() string { return b.String() } +func (db *DB) Close() error { + db.closed.Store(true) + return db.DB.Close() +} + func (db *DB) DBStats() DBStats { return DBStats{ Queries: atomic.LoadUint32(&db.stats.Queries), @@ -232,6 +252,44 @@ func (db *DB) HasFeature(feat feature.Feature) bool { return db.dialect.Features().Has(feat) } +// healthyReplica returns a random healthy replica. +func (db *DB) healthyReplica() *sql.DB { + replicas := db.loadHealthyReplicas() + if len(replicas) == 0 { + return db.DB + } + if len(replicas) == 1 { + return replicas[0] + } + i := db.nextReplica.Add(1) + return replicas[int(i)%len(replicas)] +} + +func (db *DB) loadHealthyReplicas() []*sql.DB { + if ptr := db.healthyReplicas.Load(); ptr != nil { + return *ptr + } + return nil +} + +func (db *DB) monitorReplicas() { + for !db.closed.Load() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + healthy := make([]*sql.DB, 0, len(db.replicas)) + + for _, replica := range db.replicas { + if err := replica.PingContext(ctx); err == nil { + healthy = append(healthy, replica) + } + } + + db.healthyReplicas.Store(&healthy) + time.Sleep(5 * time.Second) + } +} + //------------------------------------------------------------------------------ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { diff --git a/internal/dbtest/docker-compose.yaml b/internal/dbtest/docker-compose.yaml index d47ab5b31..0223bf704 100755 --- a/internal/dbtest/docker-compose.yaml +++ b/internal/dbtest/docker-compose.yaml @@ -1,5 +1,3 @@ -version: '3.9' - services: mysql8: image: mysql:8.0 diff --git a/query_base.go b/query_base.go index 08ff8e5d9..70bedb999 100644 --- a/query_base.go +++ b/query_base.go @@ -24,7 +24,7 @@ const ( type withQuery struct { name string - query schema.QueryAppender + query Query recursive bool } @@ -114,8 +114,27 @@ func (q *baseQuery) DB() *DB { return q.db } -func (q *baseQuery) GetConn() IConn { - return q.conn +func (q *baseQuery) resolveConn(query Query) IConn { + if q.conn != nil { + return q.conn + } + if len(q.db.replicas) == 0 || !isReadOnlyQuery(query) { + return q.db.DB + } + return q.db.healthyReplica() +} + +func isReadOnlyQuery(query Query) bool { + sel, ok := query.(*SelectQuery) + if !ok { + return false + } + for _, el := range sel.with { + if !isReadOnlyQuery(el.query) { + return false + } + } + return true } func (q *baseQuery) GetModel() Model { @@ -249,7 +268,7 @@ func (q *baseQuery) isSoftDelete() bool { //------------------------------------------------------------------------------ -func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) { +func (q *baseQuery) addWith(name string, query Query, recursive bool) { q.with = append(q.with, withQuery{ name: name, query: query, @@ -565,28 +584,33 @@ func (q *baseQuery) scan( hasDest bool, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) + res, err := q._scan(ctx, iquery, query, model, hasDest) + q.db.afterQuery(ctx, event, res, err) + return res, err +} - rows, err := q.conn.QueryContext(ctx, query) +func (q *baseQuery) _scan( + ctx context.Context, + iquery Query, + query string, + model Model, + hasDest bool, +) (sql.Result, error) { + rows, err := q.resolveConn(iquery).QueryContext(ctx, query) if err != nil { - q.db.afterQuery(ctx, event, nil, err) return nil, err } defer rows.Close() numRow, err := model.ScanRows(ctx, rows) if err != nil { - q.db.afterQuery(ctx, event, nil, err) return nil, err } if numRow == 0 && hasDest && isSingleRowModel(model) { - err = sql.ErrNoRows + return nil, sql.ErrNoRows } - - res := driver.RowsAffected(numRow) - q.db.afterQuery(ctx, event, res, err) - - return res, err + return driver.RowsAffected(numRow), nil } func (q *baseQuery) exec( @@ -595,7 +619,7 @@ func (q *baseQuery) exec( query string, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) - res, err := q.conn.ExecContext(ctx, query) + res, err := q.resolveConn(iquery).ExecContext(ctx, query) q.db.afterQuery(ctx, event, res, err) return res, err } diff --git a/query_column_add.go b/query_column_add.go index 50576873c..de4ff15fe 100644 --- a/query_column_add.go +++ b/query_column_add.go @@ -20,8 +20,7 @@ var _ Query = (*AddColumnQuery)(nil) func NewAddColumnQuery(db *DB) *AddColumnQuery { q := &AddColumnQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_column_drop.go b/query_column_drop.go index 24fc93cfd..a67084a63 100644 --- a/query_column_drop.go +++ b/query_column_drop.go @@ -18,8 +18,7 @@ var _ Query = (*DropColumnQuery)(nil) func NewDropColumnQuery(db *DB) *DropColumnQuery { q := &DropColumnQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_delete.go b/query_delete.go index 1235ba718..3467fdb83 100644 --- a/query_delete.go +++ b/query_delete.go @@ -23,8 +23,7 @@ func NewDeleteQuery(db *DB) *DeleteQuery { q := &DeleteQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -56,12 +55,12 @@ func (q *DeleteQuery) Apply(fns ...func(*DeleteQuery) *DeleteQuery) *DeleteQuery return q } -func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery { +func (q *DeleteQuery) With(name string, query Query) *DeleteQuery { q.addWith(name, query, false) return q } -func (q *DeleteQuery) WithRecursive(name string, query schema.QueryAppender) *DeleteQuery { +func (q *DeleteQuery) WithRecursive(name string, query Query) *DeleteQuery { q.addWith(name, query, true) return q } diff --git a/query_index_create.go b/query_index_create.go index 11824cfa4..f229bb5c7 100644 --- a/query_index_create.go +++ b/query_index_create.go @@ -28,8 +28,7 @@ func NewCreateIndexQuery(db *DB) *CreateIndexQuery { q := &CreateIndexQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } diff --git a/query_index_drop.go b/query_index_drop.go index ae28e7956..6300bb67f 100644 --- a/query_index_drop.go +++ b/query_index_drop.go @@ -23,8 +23,7 @@ var _ Query = (*DropIndexQuery)(nil) func NewDropIndexQuery(db *DB) *DropIndexQuery { q := &DropIndexQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_insert.go b/query_insert.go index 8bec4ce26..3013a51d0 100644 --- a/query_insert.go +++ b/query_insert.go @@ -30,8 +30,7 @@ func NewInsertQuery(db *DB) *InsertQuery { q := &InsertQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -63,12 +62,12 @@ func (q *InsertQuery) Apply(fns ...func(*InsertQuery) *InsertQuery) *InsertQuery return q } -func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery { +func (q *InsertQuery) With(name string, query Query) *InsertQuery { q.addWith(name, query, false) return q } -func (q *InsertQuery) WithRecursive(name string, query schema.QueryAppender) *InsertQuery { +func (q *InsertQuery) WithRecursive(name string, query Query) *InsertQuery { q.addWith(name, query, true) return q } diff --git a/query_merge.go b/query_merge.go index 3c3f4f7f8..aa30456a6 100644 --- a/query_merge.go +++ b/query_merge.go @@ -25,8 +25,7 @@ var _ Query = (*MergeQuery)(nil) func NewMergeQuery(db *DB) *MergeQuery { q := &MergeQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } if q.db.dialect.Name() != dialect.MSSQL && q.db.dialect.Name() != dialect.PG { @@ -60,12 +59,12 @@ func (q *MergeQuery) Apply(fns ...func(*MergeQuery) *MergeQuery) *MergeQuery { return q } -func (q *MergeQuery) With(name string, query schema.QueryAppender) *MergeQuery { +func (q *MergeQuery) With(name string, query Query) *MergeQuery { q.addWith(name, query, false) return q } -func (q *MergeQuery) WithRecursive(name string, query schema.QueryAppender) *MergeQuery { +func (q *MergeQuery) WithRecursive(name string, query Query) *MergeQuery { q.addWith(name, query, true) return q } diff --git a/query_raw.go b/query_raw.go index 1634d0e5b..b1f43af9d 100644 --- a/query_raw.go +++ b/query_raw.go @@ -14,23 +14,10 @@ type RawQuery struct { args []interface{} } -// Deprecated: Use NewRaw instead. When add it to IDB, it conflicts with the sql.Conn#Raw -func (db *DB) Raw(query string, args ...interface{}) *RawQuery { - return &RawQuery{ - baseQuery: baseQuery{ - db: db, - conn: db.DB, - }, - query: query, - args: args, - } -} - func NewRawQuery(db *DB, query string, args ...interface{}) *RawQuery { return &RawQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, query: query, args: args, diff --git a/query_select.go b/query_select.go index 2b0872ae0..70be52e27 100644 --- a/query_select.go +++ b/query_select.go @@ -40,8 +40,7 @@ func NewSelectQuery(db *DB) *SelectQuery { return &SelectQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -72,12 +71,12 @@ func (q *SelectQuery) Apply(fns ...func(*SelectQuery) *SelectQuery) *SelectQuery return q } -func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) With(name string, query Query) *SelectQuery { q.addWith(name, query, false) return q } -func (q *SelectQuery) WithRecursive(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) WithRecursive(name string, query Query) *SelectQuery { q.addWith(name, query, true) return q } diff --git a/query_table_create.go b/query_table_create.go index aeb79cd37..ce14deb46 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -39,8 +39,7 @@ var _ Query = (*CreateTableQuery)(nil) func NewCreateTableQuery(db *DB) *CreateTableQuery { q := &CreateTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, varchar: db.Dialect().DefaultVarcharLen(), } diff --git a/query_table_drop.go b/query_table_drop.go index a92014515..e937723a5 100644 --- a/query_table_drop.go +++ b/query_table_drop.go @@ -20,8 +20,7 @@ var _ Query = (*DropTableQuery)(nil) func NewDropTableQuery(db *DB) *DropTableQuery { q := &DropTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_table_truncate.go b/query_table_truncate.go index 1db81fb53..9805630ea 100644 --- a/query_table_truncate.go +++ b/query_table_truncate.go @@ -21,8 +21,7 @@ var _ Query = (*TruncateTableQuery)(nil) func NewTruncateTableQuery(db *DB) *TruncateTableQuery { q := &TruncateTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_update.go b/query_update.go index bb9264084..1d6fd3bad 100644 --- a/query_update.go +++ b/query_update.go @@ -31,8 +31,7 @@ func NewUpdateQuery(db *DB) *UpdateQuery { q := &UpdateQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -64,12 +63,12 @@ func (q *UpdateQuery) Apply(fns ...func(*UpdateQuery) *UpdateQuery) *UpdateQuery return q } -func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery { +func (q *UpdateQuery) With(name string, query Query) *UpdateQuery { q.addWith(name, query, false) return q } -func (q *UpdateQuery) WithRecursive(name string, query schema.QueryAppender) *UpdateQuery { +func (q *UpdateQuery) WithRecursive(name string, query Query) *UpdateQuery { q.addWith(name, query, true) return q } diff --git a/query_values.go b/query_values.go index 34deb1ee4..24b85aee6 100644 --- a/query_values.go +++ b/query_values.go @@ -24,8 +24,7 @@ var ( func NewValuesQuery(db *DB, model interface{}) *ValuesQuery { q := &ValuesQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } q.setModel(model) From 702e525e30ec93b6d4611359518e1008b67744af Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 4 Dec 2024 15:27:36 +0200 Subject: [PATCH 02/10] fix: build --- db.go | 27 ++++++++++++++++++--------- query_base.go | 6 ++---- query_select.go | 14 ++++++++------ 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/db.go b/db.go index a19f4cc3e..abeb7aa97 100644 --- a/db.go +++ b/db.go @@ -40,29 +40,38 @@ func WithReadOnlyReplica(replica *sql.DB) DBOption { } type DB struct { + // Must be a pointer so we copy the state, not the state fields. + *noCopyState + + queryHooks []QueryHook + + fmter schema.Formatter + stats DBStats +} + +// noCopyState contains DB fields that must not be copied on clone(), +// for example, it is forbidden to copy atomic.Pointer. +type noCopyState struct { *sql.DB + dialect schema.Dialect replicas []*sql.DB healthyReplicas atomic.Pointer[[]*sql.DB] nextReplica atomic.Int64 - dialect schema.Dialect - queryHooks []QueryHook - - fmter schema.Formatter flags internal.Flag closed atomic.Bool - - stats DBStats } func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { dialect.Init(sqldb) db := &DB{ - DB: sqldb, - dialect: dialect, - fmter: schema.NewFormatter(dialect), + noCopyState: &noCopyState{ + DB: sqldb, + dialect: dialect, + }, + fmter: schema.NewFormatter(dialect), } for _, opt := range opts { diff --git a/query_base.go b/query_base.go index 70bedb999..27f557426 100644 --- a/query_base.go +++ b/query_base.go @@ -147,10 +147,8 @@ func (q *baseQuery) GetTableName() string { } for _, wq := range q.with { - if v, ok := wq.query.(Query); ok { - if model := v.GetModel(); model != nil { - return v.GetTableName() - } + if model := wq.query.GetModel(); model != nil { + return wq.query.GetTableName() } } diff --git a/query_select.go b/query_select.go index 70be52e27..95e04f455 100644 --- a/query_select.go +++ b/query_select.go @@ -748,7 +748,7 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { query := internal.String(queryBytes) ctx, event := q.db.beforeQuery(ctx, q, query, nil, query, q.model) - rows, err := q.conn.QueryContext(ctx, query) + rows, err := q.resolveConn(q).QueryContext(ctx, query) q.db.afterQuery(ctx, event, nil, err) return rows, err } @@ -876,7 +876,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var num int - err = q.conn.QueryRowContext(ctx, query).Scan(&num) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&num) q.db.afterQuery(ctx, event, nil, err) @@ -894,13 +894,15 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in return int(n), nil } } - if _, ok := q.conn.(*DB); ok { - return q.scanAndCountConc(ctx, dest...) + if q.conn == nil { + return q.scanAndCountConcurrently(ctx, dest...) } return q.scanAndCountSeq(ctx, dest...) } -func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) { +func (q *SelectQuery) scanAndCountConcurrently( + ctx context.Context, dest ...interface{}, +) (int, error) { var count int var wg sync.WaitGroup var mu sync.Mutex @@ -978,7 +980,7 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var exists bool - err = q.conn.QueryRowContext(ctx, query).Scan(&exists) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&exists) q.db.afterQuery(ctx, event, nil, err) From 97a1fa5127afca4752247da72676155ee83ce69d Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 18 Jan 2025 09:44:00 +0200 Subject: [PATCH 03/10] chore: extract resolver --- db.go | 169 +++++++++++++++++++++++++++++++++++--------------- query_base.go | 19 ++---- 2 files changed, 122 insertions(+), 66 deletions(-) diff --git a/db.go b/db.go index abeb7aa97..ddd922eee 100644 --- a/db.go +++ b/db.go @@ -33,14 +33,14 @@ func WithDiscardUnknownColumns() DBOption { } } -func WithReadOnlyReplica(replica *sql.DB) DBOption { +func WithConnResolver(resolver ConnResolver) DBOption { return func(db *DB) { - db.replicas = append(db.replicas, replica) + db.resolver = resolver } } type DB struct { - // Must be a pointer so we copy the state, not the state fields. + // Must be a pointer so we copy the whole state, not individual fields. *noCopyState queryHooks []QueryHook @@ -53,11 +53,8 @@ type DB struct { // for example, it is forbidden to copy atomic.Pointer. type noCopyState struct { *sql.DB - dialect schema.Dialect - - replicas []*sql.DB - healthyReplicas atomic.Pointer[[]*sql.DB] - nextReplica atomic.Int64 + dialect schema.Dialect + resolver ConnResolver flags internal.Flag closed atomic.Bool @@ -78,10 +75,6 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { opt(db) } - if len(db.replicas) > 0 { - go db.monitorReplicas() - } - return db } @@ -95,7 +88,16 @@ func (db *DB) String() string { func (db *DB) Close() error { db.closed.Store(true) - return db.DB.Close() + + firstErr := db.DB.Close() + + if db.resolver != nil { + if err := db.resolver.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr } func (db *DB) DBStats() DBStats { @@ -261,44 +263,6 @@ func (db *DB) HasFeature(feat feature.Feature) bool { return db.dialect.Features().Has(feat) } -// healthyReplica returns a random healthy replica. -func (db *DB) healthyReplica() *sql.DB { - replicas := db.loadHealthyReplicas() - if len(replicas) == 0 { - return db.DB - } - if len(replicas) == 1 { - return replicas[0] - } - i := db.nextReplica.Add(1) - return replicas[int(i)%len(replicas)] -} - -func (db *DB) loadHealthyReplicas() []*sql.DB { - if ptr := db.healthyReplicas.Load(); ptr != nil { - return *ptr - } - return nil -} - -func (db *DB) monitorReplicas() { - for !db.closed.Load() { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - healthy := make([]*sql.DB, 0, len(db.replicas)) - - for _, replica := range db.replicas { - if err := replica.PingContext(ctx); err == nil { - healthy = append(healthy, replica) - } - } - - db.healthyReplicas.Store(&healthy) - time.Sleep(5 * time.Second) - } -} - //------------------------------------------------------------------------------ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { @@ -770,3 +734,106 @@ func (tx Tx) NewDropColumn() *DropColumnQuery { func (db *DB) makeQueryBytes() []byte { return internal.MakeQueryBytes() } + +//------------------------------------------------------------------------------ + +type ConnResolver interface { + ResolveConn(query Query) IConn + Close() error +} + +type ReadWriteConnResolver struct { + replicas []*sql.DB // read-only replicas + healthyReplicas atomic.Pointer[[]*sql.DB] + nextReplica atomic.Int64 + closed atomic.Bool +} + +func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver { + r := new(ReadWriteConnResolver) + + for _, opt := range opts { + opt(r) + } + + if len(r.replicas) > 0 { + go r.monitor() + } + return r +} + +type ReadWriteConnResolverOption func(r *ReadWriteConnResolver) + +func WithReadOnlyReplica(db *sql.DB) ReadWriteConnResolverOption { + return func(r *ReadWriteConnResolver) { + r.replicas = append(r.replicas, db) + } +} + +func (r *ReadWriteConnResolver) Close() error { + r.closed.Store(true) + + var firstErr error + for _, db := range r.replicas { + if err := db.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// healthyReplica returns a random healthy replica. +func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn { + if len(r.replicas) == 0 || !isReadOnlyQuery(query) { + return nil + } + + replicas := r.loadHealthyReplicas() + if len(replicas) == 0 { + return nil + } + if len(replicas) == 1 { + return replicas[0] + } + i := r.nextReplica.Add(1) + return replicas[int(i)%len(replicas)] +} + +func isReadOnlyQuery(query Query) bool { + sel, ok := query.(*SelectQuery) + if !ok { + return false + } + for _, el := range sel.with { + if !isReadOnlyQuery(el.query) { + return false + } + } + return true +} + +func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB { + if ptr := r.healthyReplicas.Load(); ptr != nil { + return *ptr + } + return nil +} + +func (r *ReadWriteConnResolver) monitor() { + const interval = 5 * time.Second + for !r.closed.Load() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + healthy := make([]*sql.DB, 0, len(r.replicas)) + + for _, replica := range r.replicas { + if err := replica.PingContext(ctx); err == nil { + healthy = append(healthy, replica) + } + } + + r.healthyReplicas.Store(&healthy) + time.Sleep(interval) + } +} diff --git a/query_base.go b/query_base.go index 27f557426..b17498742 100644 --- a/query_base.go +++ b/query_base.go @@ -118,23 +118,12 @@ func (q *baseQuery) resolveConn(query Query) IConn { if q.conn != nil { return q.conn } - if len(q.db.replicas) == 0 || !isReadOnlyQuery(query) { - return q.db.DB - } - return q.db.healthyReplica() -} - -func isReadOnlyQuery(query Query) bool { - sel, ok := query.(*SelectQuery) - if !ok { - return false - } - for _, el := range sel.with { - if !isReadOnlyQuery(el.query) { - return false + if q.db.resolver != nil { + if conn := q.db.resolver.ResolveConn(query); conn != nil { + return conn } } - return true + return q.db.DB } func (q *baseQuery) GetModel() Model { From 5a7ac56ea0f16c869484691f4f5e69bdefc421f4 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 18 Jan 2025 10:01:55 +0200 Subject: [PATCH 04/10] chore: add test --- db.go | 5 +++++ internal/dbtest/db_test.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/db.go b/db.go index ddd922eee..7cd2a7243 100644 --- a/db.go +++ b/db.go @@ -737,11 +737,16 @@ func (db *DB) makeQueryBytes() []byte { //------------------------------------------------------------------------------ +// ConnResolver enables routing queries to multiple databases. type ConnResolver interface { ResolveConn(query Query) IConn Close() error } +// TODO: +// - make monitoring interval configurable +// - make ping timeout configutable +// - allow adding read/write replicas for multi-master replication type ReadWriteConnResolver struct { replicas []*sql.DB // read-only replicas healthyReplicas atomic.Pointer[[]*sql.DB] diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 0d423d4c5..de930a995 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -1845,3 +1845,35 @@ func mustDropTableOnCleanup(tb testing.TB, ctx context.Context, db *bun.DB, mode } }) } + +func TestConnResolver(t *testing.T) { + dsn := os.Getenv("PG") + if dsn == "" { + dsn = "postgres://postgres:postgres@localhost:5432/test?sslmode=disable" + } + + rwdb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) + t.Cleanup(func() { + require.NoError(t, rwdb.Close()) + }) + + rodb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) + t.Cleanup(func() { + require.NoError(t, rodb.Close()) + }) + + resolver := bun.NewReadWriteConnResolver(bun.WithReadOnlyReplica(rodb)) + + db := bun.NewDB(rwdb, pgdialect.New(), bun.WithConnResolver(resolver)) + db.AddQueryHook(bundebug.NewQueryHook( + bundebug.WithEnabled(false), + bundebug.FromEnv(), + )) + + var num int + err := db.NewSelect().ColumnExpr("1").Scan(ctx, &num) + require.NoError(t, err) + require.Equal(t, 1, num) + require.Equal(t, 1, rodb.Stats().OpenConnections) + require.Equal(t, 0, rwdb.Stats().OpenConnections) +} From dfc405901907419d043bb6ced3ad20c131c1b972 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 18 Jan 2025 10:11:23 +0200 Subject: [PATCH 05/10] fix: test --- db.go | 2 ++ internal/dbtest/db_test.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index 7cd2a7243..cff0abb8d 100644 --- a/db.go +++ b/db.go @@ -762,8 +762,10 @@ func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteCon } if len(r.replicas) > 0 { + r.healthyReplicas.Store(&r.replicas) go r.monitor() } + return r } diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index de930a995..225c278ab 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -1874,6 +1874,6 @@ func TestConnResolver(t *testing.T) { err := db.NewSelect().ColumnExpr("1").Scan(ctx, &num) require.NoError(t, err) require.Equal(t, 1, num) - require.Equal(t, 1, rodb.Stats().OpenConnections) + require.GreaterOrEqual(t, rodb.Stats().OpenConnections, 1) require.Equal(t, 0, rwdb.Stats().OpenConnections) } From 4cbb15a53e566e03284253aa46be372338968954 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 18 Jan 2025 10:13:12 +0200 Subject: [PATCH 06/10] feat: make WithReadOnlyReplica variadic --- db.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/db.go b/db.go index cff0abb8d..8f0c432e3 100644 --- a/db.go +++ b/db.go @@ -771,9 +771,9 @@ func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteCon type ReadWriteConnResolverOption func(r *ReadWriteConnResolver) -func WithReadOnlyReplica(db *sql.DB) ReadWriteConnResolverOption { +func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption { return func(r *ReadWriteConnResolver) { - r.replicas = append(r.replicas, db) + r.replicas = append(r.replicas, dbs...) } } From 815e11a023d2babf65d528a20ddffc7628636e7e Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 18 Jan 2025 10:40:27 +0200 Subject: [PATCH 07/10] feat: add Options --- db.go | 8 ++++ driver/pgdriver/config.go | 85 +++++++++++++++++++++------------------ 2 files changed, 53 insertions(+), 40 deletions(-) diff --git a/db.go b/db.go index 8f0c432e3..2797290b5 100644 --- a/db.go +++ b/db.go @@ -27,6 +27,14 @@ type DBStats struct { type DBOption func(db *DB) +func WithOptions(opts ...DBOption) DBOption { + return func(db *DB) { + for _, opt := range opts { + opt(db) + } + } +} + func WithDiscardUnknownColumns() DBOption { return func(db *DB) { db.flags = db.flags.Set(discardUnknownColumns) diff --git a/driver/pgdriver/config.go b/driver/pgdriver/config.go index ccc038b21..50424dd48 100644 --- a/driver/pgdriver/config.go +++ b/driver/pgdriver/config.go @@ -50,7 +50,7 @@ func newDefaultConfig() *Config { host := env("PGHOST", "localhost") port := env("PGPORT", "5432") - cfg := &Config{ + conf := &Config{ Network: "tcp", Addr: net.JoinHostPort(host, port), DialTimeout: 5 * time.Second, @@ -63,28 +63,33 @@ func newDefaultConfig() *Config { WriteTimeout: 5 * time.Second, } - cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { + conf.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { netDialer := &net.Dialer{ - Timeout: cfg.DialTimeout, + Timeout: conf.DialTimeout, KeepAlive: 5 * time.Minute, } return netDialer.DialContext(ctx, network, addr) } - return cfg + return conf } -type Option func(cfg *Config) +type Option func(conf *Config) -// Deprecated. Use Option instead. -type DriverOption = Option +func WithOptions(opts ...Option) Option { + return func(conf *Config) { + for _, opt := range opts { + opt(conf) + } + } +} func WithNetwork(network string) Option { if network == "" { panic("network is empty") } - return func(cfg *Config) { - cfg.Network = network + return func(conf *Config) { + conf.Network = network } } @@ -92,23 +97,23 @@ func WithAddr(addr string) Option { if addr == "" { panic("addr is empty") } - return func(cfg *Config) { - cfg.Addr = addr + return func(conf *Config) { + conf.Addr = addr } } func WithTLSConfig(tlsConfig *tls.Config) Option { - return func(cfg *Config) { - cfg.TLSConfig = tlsConfig + return func(conf *Config) { + conf.TLSConfig = tlsConfig } } func WithInsecure(on bool) Option { - return func(cfg *Config) { + return func(conf *Config) { if on { - cfg.TLSConfig = nil + conf.TLSConfig = nil } else { - cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true} + conf.TLSConfig = &tls.Config{InsecureSkipVerify: true} } } } @@ -117,14 +122,14 @@ func WithUser(user string) Option { if user == "" { panic("user is empty") } - return func(cfg *Config) { - cfg.User = user + return func(conf *Config) { + conf.User = user } } func WithPassword(password string) Option { - return func(cfg *Config) { - cfg.Password = password + return func(conf *Config) { + conf.Password = password } } @@ -132,46 +137,46 @@ func WithDatabase(database string) Option { if database == "" { panic("database is empty") } - return func(cfg *Config) { - cfg.Database = database + return func(conf *Config) { + conf.Database = database } } func WithApplicationName(appName string) Option { - return func(cfg *Config) { - cfg.AppName = appName + return func(conf *Config) { + conf.AppName = appName } } func WithConnParams(params map[string]interface{}) Option { - return func(cfg *Config) { - cfg.ConnParams = params + return func(conf *Config) { + conf.ConnParams = params } } func WithTimeout(timeout time.Duration) Option { - return func(cfg *Config) { - cfg.DialTimeout = timeout - cfg.ReadTimeout = timeout - cfg.WriteTimeout = timeout + return func(conf *Config) { + conf.DialTimeout = timeout + conf.ReadTimeout = timeout + conf.WriteTimeout = timeout } } func WithDialTimeout(dialTimeout time.Duration) Option { - return func(cfg *Config) { - cfg.DialTimeout = dialTimeout + return func(conf *Config) { + conf.DialTimeout = dialTimeout } } func WithReadTimeout(readTimeout time.Duration) Option { - return func(cfg *Config) { - cfg.ReadTimeout = readTimeout + return func(conf *Config) { + conf.ReadTimeout = readTimeout } } func WithWriteTimeout(writeTimeout time.Duration) Option { - return func(cfg *Config) { - cfg.WriteTimeout = writeTimeout + return func(conf *Config) { + conf.WriteTimeout = writeTimeout } } @@ -179,19 +184,19 @@ func WithWriteTimeout(writeTimeout time.Duration) Option { // a query on a connection that has been used before. // If the func returns driver.ErrBadConn, the connection is discarded. func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option { - return func(cfg *Config) { - cfg.ResetSessionFunc = fn + return func(conf *Config) { + conf.ResetSessionFunc = fn } } func WithDSN(dsn string) Option { - return func(cfg *Config) { + return func(conf *Config) { opts, err := parseDSN(dsn) if err != nil { panic(err) } for _, opt := range opts { - opt(cfg) + opt(conf) } } } From 8f0cb6043b70d9c41ea44ca711b462aa636f66a2 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 22 Jan 2025 10:25:02 +0200 Subject: [PATCH 08/10] Update db.go Co-authored-by: Aoang --- db.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index 2797290b5..2fc0ffd3c 100644 --- a/db.go +++ b/db.go @@ -95,7 +95,9 @@ func (db *DB) String() string { } func (db *DB) Close() error { - db.closed.Store(true) + if db.closed.Swap(true) { + return nil + } firstErr := db.DB.Close() From 06dd7792aeb6817cfe62d3623a481b358a6be5ec Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 22 Jan 2025 10:25:09 +0200 Subject: [PATCH 09/10] Update db.go Co-authored-by: Aoang --- db.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index 2fc0ffd3c..9ec70da4d 100644 --- a/db.go +++ b/db.go @@ -788,7 +788,9 @@ func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption { } func (r *ReadWriteConnResolver) Close() error { - r.closed.Store(true) + if r.closed.Swap(true) { + return nil + } var firstErr error for _, db := range r.replicas { From 9f5e8b1c46673bd1779bd4309a28db33dcd695bf Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 22 Jan 2025 10:32:09 +0200 Subject: [PATCH 10/10] fix: individual replica timeout --- db.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/db.go b/db.go index 2797290b5..ae35e7dc1 100644 --- a/db.go +++ b/db.go @@ -837,13 +837,14 @@ func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB { func (r *ReadWriteConnResolver) monitor() { const interval = 5 * time.Second for !r.closed.Load() { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - healthy := make([]*sql.DB, 0, len(r.replicas)) for _, replica := range r.replicas { - if err := replica.PingContext(ctx); err == nil { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := replica.PingContext(ctx) + cancel() + + if err == nil { healthy = append(healthy, replica) } }