diff --git a/db.go b/db.go index c283f56bd..a19f4cc3e 100644 --- a/db.go +++ b/db.go @@ -9,6 +9,7 @@ import ( "reflect" "strings" "sync/atomic" + "time" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" @@ -32,15 +33,25 @@ func WithDiscardUnknownColumns() DBOption { } } +func WithReadOnlyReplica(replica *sql.DB) DBOption { + return func(db *DB) { + db.replicas = append(db.replicas, replica) + } +} + type DB struct { *sql.DB - dialect schema.Dialect + replicas []*sql.DB + healthyReplicas atomic.Pointer[[]*sql.DB] + nextReplica atomic.Int64 + dialect schema.Dialect queryHooks []QueryHook - fmter schema.Formatter - flags internal.Flag + fmter schema.Formatter + flags internal.Flag + closed atomic.Bool stats DBStats } @@ -58,6 +69,10 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { opt(db) } + if len(db.replicas) > 0 { + go db.monitorReplicas() + } + return db } @@ -69,6 +84,11 @@ func (db *DB) String() string { return b.String() } +func (db *DB) Close() error { + db.closed.Store(true) + return db.DB.Close() +} + func (db *DB) DBStats() DBStats { return DBStats{ Queries: atomic.LoadUint32(&db.stats.Queries), @@ -232,6 +252,44 @@ func (db *DB) HasFeature(feat feature.Feature) bool { return db.dialect.Features().Has(feat) } +// healthyReplica returns a random healthy replica. +func (db *DB) healthyReplica() *sql.DB { + replicas := db.loadHealthyReplicas() + if len(replicas) == 0 { + return db.DB + } + if len(replicas) == 1 { + return replicas[0] + } + i := db.nextReplica.Add(1) + return replicas[int(i)%len(replicas)] +} + +func (db *DB) loadHealthyReplicas() []*sql.DB { + if ptr := db.healthyReplicas.Load(); ptr != nil { + return *ptr + } + return nil +} + +func (db *DB) monitorReplicas() { + for !db.closed.Load() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + healthy := make([]*sql.DB, 0, len(db.replicas)) + + for _, replica := range db.replicas { + if err := replica.PingContext(ctx); err == nil { + healthy = append(healthy, replica) + } + } + + db.healthyReplicas.Store(&healthy) + time.Sleep(5 * time.Second) + } +} + //------------------------------------------------------------------------------ func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) { diff --git a/internal/dbtest/docker-compose.yaml b/internal/dbtest/docker-compose.yaml index d47ab5b31..0223bf704 100755 --- a/internal/dbtest/docker-compose.yaml +++ b/internal/dbtest/docker-compose.yaml @@ -1,5 +1,3 @@ -version: '3.9' - services: mysql8: image: mysql:8.0 diff --git a/query_base.go b/query_base.go index 08ff8e5d9..70bedb999 100644 --- a/query_base.go +++ b/query_base.go @@ -24,7 +24,7 @@ const ( type withQuery struct { name string - query schema.QueryAppender + query Query recursive bool } @@ -114,8 +114,27 @@ func (q *baseQuery) DB() *DB { return q.db } -func (q *baseQuery) GetConn() IConn { - return q.conn +func (q *baseQuery) resolveConn(query Query) IConn { + if q.conn != nil { + return q.conn + } + if len(q.db.replicas) == 0 || !isReadOnlyQuery(query) { + return q.db.DB + } + return q.db.healthyReplica() +} + +func isReadOnlyQuery(query Query) bool { + sel, ok := query.(*SelectQuery) + if !ok { + return false + } + for _, el := range sel.with { + if !isReadOnlyQuery(el.query) { + return false + } + } + return true } func (q *baseQuery) GetModel() Model { @@ -249,7 +268,7 @@ func (q *baseQuery) isSoftDelete() bool { //------------------------------------------------------------------------------ -func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) { +func (q *baseQuery) addWith(name string, query Query, recursive bool) { q.with = append(q.with, withQuery{ name: name, query: query, @@ -565,28 +584,33 @@ func (q *baseQuery) scan( hasDest bool, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) + res, err := q._scan(ctx, iquery, query, model, hasDest) + q.db.afterQuery(ctx, event, res, err) + return res, err +} - rows, err := q.conn.QueryContext(ctx, query) +func (q *baseQuery) _scan( + ctx context.Context, + iquery Query, + query string, + model Model, + hasDest bool, +) (sql.Result, error) { + rows, err := q.resolveConn(iquery).QueryContext(ctx, query) if err != nil { - q.db.afterQuery(ctx, event, nil, err) return nil, err } defer rows.Close() numRow, err := model.ScanRows(ctx, rows) if err != nil { - q.db.afterQuery(ctx, event, nil, err) return nil, err } if numRow == 0 && hasDest && isSingleRowModel(model) { - err = sql.ErrNoRows + return nil, sql.ErrNoRows } - - res := driver.RowsAffected(numRow) - q.db.afterQuery(ctx, event, res, err) - - return res, err + return driver.RowsAffected(numRow), nil } func (q *baseQuery) exec( @@ -595,7 +619,7 @@ func (q *baseQuery) exec( query string, ) (sql.Result, error) { ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) - res, err := q.conn.ExecContext(ctx, query) + res, err := q.resolveConn(iquery).ExecContext(ctx, query) q.db.afterQuery(ctx, event, res, err) return res, err } diff --git a/query_column_add.go b/query_column_add.go index 50576873c..de4ff15fe 100644 --- a/query_column_add.go +++ b/query_column_add.go @@ -20,8 +20,7 @@ var _ Query = (*AddColumnQuery)(nil) func NewAddColumnQuery(db *DB) *AddColumnQuery { q := &AddColumnQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_column_drop.go b/query_column_drop.go index 24fc93cfd..a67084a63 100644 --- a/query_column_drop.go +++ b/query_column_drop.go @@ -18,8 +18,7 @@ var _ Query = (*DropColumnQuery)(nil) func NewDropColumnQuery(db *DB) *DropColumnQuery { q := &DropColumnQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_delete.go b/query_delete.go index 1235ba718..3467fdb83 100644 --- a/query_delete.go +++ b/query_delete.go @@ -23,8 +23,7 @@ func NewDeleteQuery(db *DB) *DeleteQuery { q := &DeleteQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -56,12 +55,12 @@ func (q *DeleteQuery) Apply(fns ...func(*DeleteQuery) *DeleteQuery) *DeleteQuery return q } -func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery { +func (q *DeleteQuery) With(name string, query Query) *DeleteQuery { q.addWith(name, query, false) return q } -func (q *DeleteQuery) WithRecursive(name string, query schema.QueryAppender) *DeleteQuery { +func (q *DeleteQuery) WithRecursive(name string, query Query) *DeleteQuery { q.addWith(name, query, true) return q } diff --git a/query_index_create.go b/query_index_create.go index 11824cfa4..f229bb5c7 100644 --- a/query_index_create.go +++ b/query_index_create.go @@ -28,8 +28,7 @@ func NewCreateIndexQuery(db *DB) *CreateIndexQuery { q := &CreateIndexQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } diff --git a/query_index_drop.go b/query_index_drop.go index ae28e7956..6300bb67f 100644 --- a/query_index_drop.go +++ b/query_index_drop.go @@ -23,8 +23,7 @@ var _ Query = (*DropIndexQuery)(nil) func NewDropIndexQuery(db *DB) *DropIndexQuery { q := &DropIndexQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_insert.go b/query_insert.go index 8bec4ce26..3013a51d0 100644 --- a/query_insert.go +++ b/query_insert.go @@ -30,8 +30,7 @@ func NewInsertQuery(db *DB) *InsertQuery { q := &InsertQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -63,12 +62,12 @@ func (q *InsertQuery) Apply(fns ...func(*InsertQuery) *InsertQuery) *InsertQuery return q } -func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery { +func (q *InsertQuery) With(name string, query Query) *InsertQuery { q.addWith(name, query, false) return q } -func (q *InsertQuery) WithRecursive(name string, query schema.QueryAppender) *InsertQuery { +func (q *InsertQuery) WithRecursive(name string, query Query) *InsertQuery { q.addWith(name, query, true) return q } diff --git a/query_merge.go b/query_merge.go index 3c3f4f7f8..aa30456a6 100644 --- a/query_merge.go +++ b/query_merge.go @@ -25,8 +25,7 @@ var _ Query = (*MergeQuery)(nil) func NewMergeQuery(db *DB) *MergeQuery { q := &MergeQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } if q.db.dialect.Name() != dialect.MSSQL && q.db.dialect.Name() != dialect.PG { @@ -60,12 +59,12 @@ func (q *MergeQuery) Apply(fns ...func(*MergeQuery) *MergeQuery) *MergeQuery { return q } -func (q *MergeQuery) With(name string, query schema.QueryAppender) *MergeQuery { +func (q *MergeQuery) With(name string, query Query) *MergeQuery { q.addWith(name, query, false) return q } -func (q *MergeQuery) WithRecursive(name string, query schema.QueryAppender) *MergeQuery { +func (q *MergeQuery) WithRecursive(name string, query Query) *MergeQuery { q.addWith(name, query, true) return q } diff --git a/query_raw.go b/query_raw.go index 1634d0e5b..b1f43af9d 100644 --- a/query_raw.go +++ b/query_raw.go @@ -14,23 +14,10 @@ type RawQuery struct { args []interface{} } -// Deprecated: Use NewRaw instead. When add it to IDB, it conflicts with the sql.Conn#Raw -func (db *DB) Raw(query string, args ...interface{}) *RawQuery { - return &RawQuery{ - baseQuery: baseQuery{ - db: db, - conn: db.DB, - }, - query: query, - args: args, - } -} - func NewRawQuery(db *DB, query string, args ...interface{}) *RawQuery { return &RawQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, query: query, args: args, diff --git a/query_select.go b/query_select.go index 2b0872ae0..70be52e27 100644 --- a/query_select.go +++ b/query_select.go @@ -40,8 +40,7 @@ func NewSelectQuery(db *DB) *SelectQuery { return &SelectQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -72,12 +71,12 @@ func (q *SelectQuery) Apply(fns ...func(*SelectQuery) *SelectQuery) *SelectQuery return q } -func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) With(name string, query Query) *SelectQuery { q.addWith(name, query, false) return q } -func (q *SelectQuery) WithRecursive(name string, query schema.QueryAppender) *SelectQuery { +func (q *SelectQuery) WithRecursive(name string, query Query) *SelectQuery { q.addWith(name, query, true) return q } diff --git a/query_table_create.go b/query_table_create.go index aeb79cd37..ce14deb46 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -39,8 +39,7 @@ var _ Query = (*CreateTableQuery)(nil) func NewCreateTableQuery(db *DB) *CreateTableQuery { q := &CreateTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, varchar: db.Dialect().DefaultVarcharLen(), } diff --git a/query_table_drop.go b/query_table_drop.go index a92014515..e937723a5 100644 --- a/query_table_drop.go +++ b/query_table_drop.go @@ -20,8 +20,7 @@ var _ Query = (*DropTableQuery)(nil) func NewDropTableQuery(db *DB) *DropTableQuery { q := &DropTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_table_truncate.go b/query_table_truncate.go index 1db81fb53..9805630ea 100644 --- a/query_table_truncate.go +++ b/query_table_truncate.go @@ -21,8 +21,7 @@ var _ Query = (*TruncateTableQuery)(nil) func NewTruncateTableQuery(db *DB) *TruncateTableQuery { q := &TruncateTableQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } return q diff --git a/query_update.go b/query_update.go index bb9264084..1d6fd3bad 100644 --- a/query_update.go +++ b/query_update.go @@ -31,8 +31,7 @@ func NewUpdateQuery(db *DB) *UpdateQuery { q := &UpdateQuery{ whereBaseQuery: whereBaseQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, }, } @@ -64,12 +63,12 @@ func (q *UpdateQuery) Apply(fns ...func(*UpdateQuery) *UpdateQuery) *UpdateQuery return q } -func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery { +func (q *UpdateQuery) With(name string, query Query) *UpdateQuery { q.addWith(name, query, false) return q } -func (q *UpdateQuery) WithRecursive(name string, query schema.QueryAppender) *UpdateQuery { +func (q *UpdateQuery) WithRecursive(name string, query Query) *UpdateQuery { q.addWith(name, query, true) return q } diff --git a/query_values.go b/query_values.go index 34deb1ee4..24b85aee6 100644 --- a/query_values.go +++ b/query_values.go @@ -24,8 +24,7 @@ var ( func NewValuesQuery(db *DB, model interface{}) *ValuesQuery { q := &ValuesQuery{ baseQuery: baseQuery{ - db: db, - conn: db.DB, + db: db, }, } q.setModel(model)