Skip to content

Commit

Permalink
support/db, services/horizon/internal: Configure postgres client conn…
Browse files Browse the repository at this point in the history
…ection timeouts for read only db (#4390)

Configure postgres client connection timeouts for read only database sessions
  • Loading branch information
tamirms authored May 20, 2022
1 parent 156db13 commit e96fc25
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 3 deletions.
1 change: 1 addition & 0 deletions services/horizon/internal/httpx/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func init() {
problem.RegisterError(context.Canceled, hProblem.ClientDisconnected)
problem.RegisterError(db.ErrCancelled, hProblem.ClientDisconnected)
problem.RegisterError(db.ErrTimeout, hProblem.ServiceUnavailable)
problem.RegisterError(db.ErrStatementTimeout, hProblem.ServiceUnavailable)
problem.RegisterError(db.ErrConflictWithRecovery, hProblem.ServiceUnavailable)
problem.RegisterError(db.ErrBadConnection, hProblem.ServiceUnavailable)
}
Expand Down
19 changes: 17 additions & 2 deletions services/horizon/internal/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
"github.com/stellar/go/support/log"
)

func mustNewDBSession(subservice db.Subservice, databaseURL string, maxIdle, maxOpen int, registry *prometheus.Registry) db.SessionInterface {
func mustNewDBSession(subservice db.Subservice, databaseURL string, maxIdle, maxOpen int, registry *prometheus.Registry, clientConfigs ...db.ClientConfig) db.SessionInterface {
log.Infof("Establishing database session for %v", subservice)
session, err := db.Open("postgres", databaseURL)
session, err := db.Open("postgres", databaseURL, clientConfigs...)
if err != nil {
log.Fatalf("cannot open %v DB: %v", subservice, err)
}
Expand All @@ -47,21 +47,36 @@ func mustInitHorizonDB(app *App) {
}

if app.config.RoDatabaseURL == "" {
var clientConfigs []db.ClientConfig
if !app.config.Ingest {
// if we are not ingesting then we don't expect to have long db queries / transactions
clientConfigs = append(
clientConfigs,
db.StatementTimeout(app.config.ConnectionTimeout),
db.IdleTransactionTimeout(app.config.ConnectionTimeout),
)
}
app.historyQ = &history.Q{mustNewDBSession(
db.HistorySubservice,
app.config.DatabaseURL,
maxIdle,
maxOpen,
app.prometheusRegistry,
clientConfigs...,
)}
} else {
// If RO set, use it for all DB queries
roClientConfigs := []db.ClientConfig{
db.StatementTimeout(app.config.ConnectionTimeout),
db.IdleTransactionTimeout(app.config.ConnectionTimeout),
}
app.historyQ = &history.Q{mustNewDBSession(
db.HistorySubservice,
app.config.RoDatabaseURL,
maxIdle,
maxOpen,
app.prometheusRegistry,
roClientConfigs...,
)}

app.primaryHistoryQ = &history.Q{mustNewDBSession(
Expand Down
59 changes: 58 additions & 1 deletion support/db/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ package db
import (
"context"
"database/sql"
"net/url"
"strconv"
"strings"
"time"

"github.com/Masterminds/squirrel"
Expand Down Expand Up @@ -44,6 +47,9 @@ var (
// ErrBadConnection is an error returned when driver returns `bad connection`
// error.
ErrBadConnection = errors.New("bad connection")
// ErrStatementTimeout is an error returned by Session methods when request has
// been cancelled due to a statement timeout.
ErrStatementTimeout = errors.New("canceling statement due to statement timeout")
)

// Conn represents a connection to a single database.
Expand Down Expand Up @@ -163,8 +169,59 @@ func pingDB(db *sqlx.DB) error {
return errors.Wrapf(err, "failed to connect to DB after %v attempts", maxDBPingAttempts)
}

type ClientConfig struct {
Key string
Value string
}

func StatementTimeout(timeout time.Duration) ClientConfig {
return ClientConfig{
Key: "statement_timeout",
Value: strconv.FormatInt(timeout.Milliseconds(), 10),
}
}

func IdleTransactionTimeout(timeout time.Duration) ClientConfig {
return ClientConfig{
Key: "idle_in_transaction_session_timeout",
Value: strconv.FormatInt(timeout.Milliseconds(), 10),
}
}

func augmentDSN(dsn string, clientConfigs []ClientConfig) string {
parsed, err := url.Parse(dsn)
// dsn can either be a postgres url like "postgres://postgres:123456@127.0.0.1:5432"
// or, it can be a white space separated string of key value pairs like
// "host=localhost port=5432 user=bob password=secret"
if err != nil || parsed.Scheme == "" {
// if dsn does not parse as a postgres url, we assume it must be take
// the form of a white space separated string
parts := []string{dsn}
for _, config := range clientConfigs {
// do not override if the key is already present in dsn
if strings.Contains(dsn, config.Key+"=") {
continue
}
parts = append(parts, config.Key+"="+config.Value)
}
return strings.Join(parts, " ")
}

q := parsed.Query()
for _, config := range clientConfigs {
// do not override if the key is already present in dsn
if len(q.Get(config.Key)) > 0 {
continue
}
q.Set(config.Key, config.Value)
}
parsed.RawQuery = q.Encode()
return parsed.String()
}

// Open the database at `dsn` and returns a new *Session using it.
func Open(dialect, dsn string) (*Session, error) {
func Open(dialect, dsn string, clientConfigs ...ClientConfig) (*Session, error) {
dsn = augmentDSN(dsn, clientConfigs)
db, err := sqlx.Open(dialect, dsn)
if err != nil {
return nil, errors.Wrap(err, "open failed")
Expand Down
25 changes: 25 additions & 0 deletions support/db/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package db

import (
"testing"
"time"

"github.com/stellar/go/support/db/dbtest"
"github.com/stretchr/testify/assert"
Expand All @@ -27,3 +28,27 @@ func TestGetTable(t *testing.T) {
}

}

func TestAugmentDSN(t *testing.T) {
configs := []ClientConfig{
IdleTransactionTimeout(2 * time.Second),
StatementTimeout(4 * time.Millisecond),
}
for _, testCase := range []struct {
input string
expected string
}{
{"postgresql://localhost", "postgresql://localhost?idle_in_transaction_session_timeout=2000&statement_timeout=4"},
{"postgresql://localhost/mydb?user=other&password=secret", "postgresql://localhost/mydb?idle_in_transaction_session_timeout=2000&password=secret&statement_timeout=4&user=other"},
{"postgresql://localhost/mydb?user=other&idle_in_transaction_session_timeout=500", "postgresql://localhost/mydb?idle_in_transaction_session_timeout=500&statement_timeout=4&user=other"},
{"host=localhost user=bob password=secret", "host=localhost user=bob password=secret idle_in_transaction_session_timeout=2000 statement_timeout=4"},
{"host=localhost user=bob password=secret statement_timeout=32", "host=localhost user=bob password=secret statement_timeout=32 idle_in_transaction_session_timeout=2000"},
} {
t.Run(testCase.input, func(t *testing.T) {
output := augmentDSN(testCase.input, configs)
if output != testCase.expected {
t.Fatalf("got %v but expected %v", output, testCase.expected)
}
})
}
}
14 changes: 14 additions & 0 deletions support/db/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ func (s *Session) Commit() error {
log.Debug("sql: commit")
s.tx = nil
s.txOptions = nil

if knownErr := s.replaceWithKnownError(err, context.Background()); knownErr != nil {
return knownErr
}
return err
}

Expand Down Expand Up @@ -231,6 +235,10 @@ func (s *Session) NoRows(err error) bool {
// replaceWithKnownError tries to replace Postgres error with package error.
// Returns a new error if the err is known.
func (s *Session) replaceWithKnownError(err error, ctx context.Context) error {
if err == nil {
return nil
}

switch {
case ctx.Err() == context.Canceled:
return ErrCancelled
Expand All @@ -243,6 +251,8 @@ func (s *Session) replaceWithKnownError(err error, ctx context.Context) error {
return ErrConflictWithRecovery
case strings.Contains(err.Error(), "driver: bad connection"):
return ErrBadConnection
case strings.Contains(err.Error(), "pq: canceling statement due to statement timeout"):
return ErrStatementTimeout
default:
return nil
}
Expand Down Expand Up @@ -305,6 +315,10 @@ func (s *Session) Rollback() error {
log.Debug("sql: rollback")
s.tx = nil
s.txOptions = nil

if knownErr := s.replaceWithKnownError(err, context.Background()); knownErr != nil {
return knownErr
}
return err
}

Expand Down
31 changes: 31 additions & 0 deletions support/db/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,34 @@ func TestSession(t *testing.T) {
assert.Equal("$1 = $2 = $3 = ?", out)
}
}

func TestStatementTimeout(t *testing.T) {
assert := assert.New(t)
db := dbtest.Postgres(t).Load(testSchema)
defer db.Close()

sess, err := Open(db.Dialect, db.DSN, StatementTimeout(50*time.Millisecond))
assert.NoError(err)
defer sess.Close()

var count int
err = sess.GetRaw(context.Background(), &count, "SELECT pg_sleep(2), COUNT(*) FROM people")
assert.ErrorIs(err, ErrStatementTimeout)
}

func TestIdleTransactionTimeout(t *testing.T) {
assert := assert.New(t)
db := dbtest.Postgres(t).Load(testSchema)
defer db.Close()

sess, err := Open(db.Dialect, db.DSN, IdleTransactionTimeout(50*time.Millisecond))
assert.NoError(err)
defer sess.Close()

assert.NoError(sess.Begin())
<-time.After(100 * time.Millisecond)

var count int
err = sess.GetRaw(context.Background(), &count, "SELECT COUNT(*) FROM people")
assert.ErrorIs(err, ErrBadConnection)
}

0 comments on commit e96fc25

Please sign in to comment.