Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow to specify read-only replica for SELECTs #1085

Merged
merged 12 commits into from
Jan 22, 2025
165 changes: 157 additions & 8 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"strings"
"sync/atomic"
"time"

"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
Expand All @@ -26,32 +27,56 @@ type DBStats struct {

type DBOption func(db *DB)

func WithOptions(opts ...DBOption) DBOption {
return func(db *DB) {
for _, opt := range opts {
opt(db)
}
}
}

func WithDiscardUnknownColumns() DBOption {
return func(db *DB) {
db.flags = db.flags.Set(discardUnknownColumns)
}
}

type DB struct {
*sql.DB
func WithConnResolver(resolver ConnResolver) DBOption {
return func(db *DB) {
db.resolver = resolver
}
}

dialect schema.Dialect
type DB struct {
// Must be a pointer so we copy the whole state, not individual fields.
*noCopyState

queryHooks []QueryHook

fmter schema.Formatter
flags internal.Flag

stats DBStats
}

// noCopyState contains DB fields that must not be copied on clone(),
// for example, it is forbidden to copy atomic.Pointer.
type noCopyState struct {
*sql.DB
dialect schema.Dialect
resolver ConnResolver

flags internal.Flag
closed atomic.Bool
}

func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)

db := &DB{
DB: sqldb,
dialect: dialect,
fmter: schema.NewFormatter(dialect),
noCopyState: &noCopyState{
DB: sqldb,
dialect: dialect,
},
fmter: schema.NewFormatter(dialect),
}

for _, opt := range opts {
Expand All @@ -69,6 +94,20 @@ func (db *DB) String() string {
return b.String()
}

func (db *DB) Close() error {
db.closed.Store(true)
vmihailenco marked this conversation as resolved.
Show resolved Hide resolved

firstErr := db.DB.Close()

if db.resolver != nil {
if err := db.resolver.Close(); err != nil && firstErr == nil {
firstErr = err
}
}

return firstErr
}

func (db *DB) DBStats() DBStats {
return DBStats{
Queries: atomic.LoadUint32(&db.stats.Queries),
Expand Down Expand Up @@ -703,3 +742,113 @@ func (tx Tx) NewDropColumn() *DropColumnQuery {
func (db *DB) makeQueryBytes() []byte {
return internal.MakeQueryBytes()
}

//------------------------------------------------------------------------------

// ConnResolver enables routing queries to multiple databases.
type ConnResolver interface {
ResolveConn(query Query) IConn
Close() error
}

// TODO:
// - make monitoring interval configurable
// - make ping timeout configutable
// - allow adding read/write replicas for multi-master replication
type ReadWriteConnResolver struct {
replicas []*sql.DB // read-only replicas
healthyReplicas atomic.Pointer[[]*sql.DB]
nextReplica atomic.Int64
closed atomic.Bool
}

func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteConnResolver {
r := new(ReadWriteConnResolver)

for _, opt := range opts {
opt(r)
}

if len(r.replicas) > 0 {
r.healthyReplicas.Store(&r.replicas)
go r.monitor()
}

return r
}

type ReadWriteConnResolverOption func(r *ReadWriteConnResolver)

func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption {
return func(r *ReadWriteConnResolver) {
r.replicas = append(r.replicas, dbs...)
}
}

func (r *ReadWriteConnResolver) Close() error {
r.closed.Store(true)
vmihailenco marked this conversation as resolved.
Show resolved Hide resolved

var firstErr error
for _, db := range r.replicas {
if err := db.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}

// healthyReplica returns a random healthy replica.
func (r *ReadWriteConnResolver) ResolveConn(query Query) IConn {
if len(r.replicas) == 0 || !isReadOnlyQuery(query) {
return nil
}

replicas := r.loadHealthyReplicas()
if len(replicas) == 0 {
return nil
}
if len(replicas) == 1 {
return replicas[0]
}
i := r.nextReplica.Add(1)
return replicas[int(i)%len(replicas)]
}

func isReadOnlyQuery(query Query) bool {
sel, ok := query.(*SelectQuery)
if !ok {
return false
}
for _, el := range sel.with {
if !isReadOnlyQuery(el.query) {
return false
}
}
return true
}

func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB {
if ptr := r.healthyReplicas.Load(); ptr != nil {
return *ptr
}
return nil
}

func (r *ReadWriteConnResolver) monitor() {
const interval = 5 * time.Second
for !r.closed.Load() {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
vmihailenco marked this conversation as resolved.
Show resolved Hide resolved
defer cancel()
vmihailenco marked this conversation as resolved.
Show resolved Hide resolved

healthy := make([]*sql.DB, 0, len(r.replicas))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to declare two slices and always use the one not utilized by r.healthyReplicas to avoid memory allocation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It sounds like you suggest to declare 2 slices and re-cycle them in the monitor.

It is not totally safe since we don't exactly know when the slice becomes free, but more importantly I believe it is not worth it since a slice allocation every 3-5 seconds is not going to change much.


for _, replica := range r.replicas {
if err := replica.PingContext(ctx); err == nil {
healthy = append(healthy, replica)
}
}

r.healthyReplicas.Store(&healthy)
time.Sleep(interval)
}
}
85 changes: 45 additions & 40 deletions driver/pgdriver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func newDefaultConfig() *Config {
host := env("PGHOST", "localhost")
port := env("PGPORT", "5432")

cfg := &Config{
conf := &Config{
Network: "tcp",
Addr: net.JoinHostPort(host, port),
DialTimeout: 5 * time.Second,
Expand All @@ -63,52 +63,57 @@ func newDefaultConfig() *Config {
WriteTimeout: 5 * time.Second,
}

cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
conf.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: cfg.DialTimeout,
Timeout: conf.DialTimeout,
KeepAlive: 5 * time.Minute,
}
return netDialer.DialContext(ctx, network, addr)
}

return cfg
return conf
}

type Option func(cfg *Config)
type Option func(conf *Config)

// Deprecated. Use Option instead.
type DriverOption = Option
func WithOptions(opts ...Option) Option {
return func(conf *Config) {
for _, opt := range opts {
opt(conf)
}
}
}

func WithNetwork(network string) Option {
if network == "" {
panic("network is empty")
}
return func(cfg *Config) {
cfg.Network = network
return func(conf *Config) {
conf.Network = network
}
}

func WithAddr(addr string) Option {
if addr == "" {
panic("addr is empty")
}
return func(cfg *Config) {
cfg.Addr = addr
return func(conf *Config) {
conf.Addr = addr
}
}

func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *Config) {
cfg.TLSConfig = tlsConfig
return func(conf *Config) {
conf.TLSConfig = tlsConfig
}
}

func WithInsecure(on bool) Option {
return func(cfg *Config) {
return func(conf *Config) {
if on {
cfg.TLSConfig = nil
conf.TLSConfig = nil
} else {
cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true}
conf.TLSConfig = &tls.Config{InsecureSkipVerify: true}
}
}
}
Expand All @@ -117,81 +122,81 @@ func WithUser(user string) Option {
if user == "" {
panic("user is empty")
}
return func(cfg *Config) {
cfg.User = user
return func(conf *Config) {
conf.User = user
}
}

func WithPassword(password string) Option {
return func(cfg *Config) {
cfg.Password = password
return func(conf *Config) {
conf.Password = password
}
}

func WithDatabase(database string) Option {
if database == "" {
panic("database is empty")
}
return func(cfg *Config) {
cfg.Database = database
return func(conf *Config) {
conf.Database = database
}
}

func WithApplicationName(appName string) Option {
return func(cfg *Config) {
cfg.AppName = appName
return func(conf *Config) {
conf.AppName = appName
}
}

func WithConnParams(params map[string]interface{}) Option {
return func(cfg *Config) {
cfg.ConnParams = params
return func(conf *Config) {
conf.ConnParams = params
}
}

func WithTimeout(timeout time.Duration) Option {
return func(cfg *Config) {
cfg.DialTimeout = timeout
cfg.ReadTimeout = timeout
cfg.WriteTimeout = timeout
return func(conf *Config) {
conf.DialTimeout = timeout
conf.ReadTimeout = timeout
conf.WriteTimeout = timeout
}
}

func WithDialTimeout(dialTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.DialTimeout = dialTimeout
return func(conf *Config) {
conf.DialTimeout = dialTimeout
}
}

func WithReadTimeout(readTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.ReadTimeout = readTimeout
return func(conf *Config) {
conf.ReadTimeout = readTimeout
}
}

func WithWriteTimeout(writeTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.WriteTimeout = writeTimeout
return func(conf *Config) {
conf.WriteTimeout = writeTimeout
}
}

// WithResetSessionFunc configures a function that is called prior to executing
// a query on a connection that has been used before.
// If the func returns driver.ErrBadConn, the connection is discarded.
func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option {
return func(cfg *Config) {
cfg.ResetSessionFunc = fn
return func(conf *Config) {
conf.ResetSessionFunc = fn
}
}

func WithDSN(dsn string) Option {
return func(cfg *Config) {
return func(conf *Config) {
opts, err := parseDSN(dsn)
if err != nil {
panic(err)
}
for _, opt := range opts {
opt(cfg)
opt(conf)
}
}
}
Expand Down
Loading
Loading