diff --git a/services/horizon/internal/httpx/server.go b/services/horizon/internal/httpx/server.go index c3f6983c2c..7d9dc5419f 100644 --- a/services/horizon/internal/httpx/server.go +++ b/services/horizon/internal/httpx/server.go @@ -58,6 +58,7 @@ func init() { problem.RegisterError(context.Canceled, hProblem.ClientDisconnected) problem.RegisterError(db.ErrCancelled, hProblem.ClientDisconnected) problem.RegisterError(db.ErrTimeout, hProblem.ServiceUnavailable) + problem.RegisterError(db.ErrStatementTimeout, hProblem.ServiceUnavailable) problem.RegisterError(db.ErrConflictWithRecovery, hProblem.ServiceUnavailable) problem.RegisterError(db.ErrBadConnection, hProblem.ServiceUnavailable) } diff --git a/services/horizon/internal/init.go b/services/horizon/internal/init.go index c2b438d839..26f2d1baea 100644 --- a/services/horizon/internal/init.go +++ b/services/horizon/internal/init.go @@ -18,9 +18,9 @@ import ( "github.com/stellar/go/support/log" ) -func mustNewDBSession(subservice db.Subservice, databaseURL string, maxIdle, maxOpen int, registry *prometheus.Registry) db.SessionInterface { +func mustNewDBSession(subservice db.Subservice, databaseURL string, maxIdle, maxOpen int, registry *prometheus.Registry, clientConfigs ...db.ClientConfig) db.SessionInterface { log.Infof("Establishing database session for %v", subservice) - session, err := db.Open("postgres", databaseURL) + session, err := db.Open("postgres", databaseURL, clientConfigs...) if err != nil { log.Fatalf("cannot open %v DB: %v", subservice, err) } @@ -47,21 +47,36 @@ func mustInitHorizonDB(app *App) { } if app.config.RoDatabaseURL == "" { + var clientConfigs []db.ClientConfig + if !app.config.Ingest { + // if we are not ingesting then we don't expect to have long db queries / transactions + clientConfigs = append( + clientConfigs, + db.StatementTimeout(app.config.ConnectionTimeout), + db.IdleTransactionTimeout(app.config.ConnectionTimeout), + ) + } app.historyQ = &history.Q{mustNewDBSession( db.HistorySubservice, app.config.DatabaseURL, maxIdle, maxOpen, app.prometheusRegistry, + clientConfigs..., )} } else { // If RO set, use it for all DB queries + roClientConfigs := []db.ClientConfig{ + db.StatementTimeout(app.config.ConnectionTimeout), + db.IdleTransactionTimeout(app.config.ConnectionTimeout), + } app.historyQ = &history.Q{mustNewDBSession( db.HistorySubservice, app.config.RoDatabaseURL, maxIdle, maxOpen, app.prometheusRegistry, + roClientConfigs..., )} app.primaryHistoryQ = &history.Q{mustNewDBSession( diff --git a/support/db/main.go b/support/db/main.go index 5a316899c5..56d496382f 100644 --- a/support/db/main.go +++ b/support/db/main.go @@ -14,6 +14,9 @@ package db import ( "context" "database/sql" + "net/url" + "strconv" + "strings" "time" "github.com/Masterminds/squirrel" @@ -44,6 +47,9 @@ var ( // ErrBadConnection is an error returned when driver returns `bad connection` // error. ErrBadConnection = errors.New("bad connection") + // ErrStatementTimeout is an error returned by Session methods when request has + // been cancelled due to a statement timeout. + ErrStatementTimeout = errors.New("canceling statement due to statement timeout") ) // Conn represents a connection to a single database. @@ -163,8 +169,59 @@ func pingDB(db *sqlx.DB) error { return errors.Wrapf(err, "failed to connect to DB after %v attempts", maxDBPingAttempts) } +type ClientConfig struct { + Key string + Value string +} + +func StatementTimeout(timeout time.Duration) ClientConfig { + return ClientConfig{ + Key: "statement_timeout", + Value: strconv.FormatInt(timeout.Milliseconds(), 10), + } +} + +func IdleTransactionTimeout(timeout time.Duration) ClientConfig { + return ClientConfig{ + Key: "idle_in_transaction_session_timeout", + Value: strconv.FormatInt(timeout.Milliseconds(), 10), + } +} + +func augmentDSN(dsn string, clientConfigs []ClientConfig) string { + parsed, err := url.Parse(dsn) + // dsn can either be a postgres url like "postgres://postgres:123456@127.0.0.1:5432" + // or, it can be a white space separated string of key value pairs like + // "host=localhost port=5432 user=bob password=secret" + if err != nil || parsed.Scheme == "" { + // if dsn does not parse as a postgres url, we assume it must be take + // the form of a white space separated string + parts := []string{dsn} + for _, config := range clientConfigs { + // do not override if the key is already present in dsn + if strings.Contains(dsn, config.Key+"=") { + continue + } + parts = append(parts, config.Key+"="+config.Value) + } + return strings.Join(parts, " ") + } + + q := parsed.Query() + for _, config := range clientConfigs { + // do not override if the key is already present in dsn + if len(q.Get(config.Key)) > 0 { + continue + } + q.Set(config.Key, config.Value) + } + parsed.RawQuery = q.Encode() + return parsed.String() +} + // Open the database at `dsn` and returns a new *Session using it. -func Open(dialect, dsn string) (*Session, error) { +func Open(dialect, dsn string, clientConfigs ...ClientConfig) (*Session, error) { + dsn = augmentDSN(dsn, clientConfigs) db, err := sqlx.Open(dialect, dsn) if err != nil { return nil, errors.Wrap(err, "open failed") diff --git a/support/db/main_test.go b/support/db/main_test.go index 8ca94f1e3b..68724d197d 100644 --- a/support/db/main_test.go +++ b/support/db/main_test.go @@ -2,6 +2,7 @@ package db import ( "testing" + "time" "github.com/stellar/go/support/db/dbtest" "github.com/stretchr/testify/assert" @@ -27,3 +28,27 @@ func TestGetTable(t *testing.T) { } } + +func TestAugmentDSN(t *testing.T) { + configs := []ClientConfig{ + IdleTransactionTimeout(2 * time.Second), + StatementTimeout(4 * time.Millisecond), + } + for _, testCase := range []struct { + input string + expected string + }{ + {"postgresql://localhost", "postgresql://localhost?idle_in_transaction_session_timeout=2000&statement_timeout=4"}, + {"postgresql://localhost/mydb?user=other&password=secret", "postgresql://localhost/mydb?idle_in_transaction_session_timeout=2000&password=secret&statement_timeout=4&user=other"}, + {"postgresql://localhost/mydb?user=other&idle_in_transaction_session_timeout=500", "postgresql://localhost/mydb?idle_in_transaction_session_timeout=500&statement_timeout=4&user=other"}, + {"host=localhost user=bob password=secret", "host=localhost user=bob password=secret idle_in_transaction_session_timeout=2000 statement_timeout=4"}, + {"host=localhost user=bob password=secret statement_timeout=32", "host=localhost user=bob password=secret statement_timeout=32 idle_in_transaction_session_timeout=2000"}, + } { + t.Run(testCase.input, func(t *testing.T) { + output := augmentDSN(testCase.input, configs) + if output != testCase.expected { + t.Fatalf("got %v but expected %v", output, testCase.expected) + } + }) + } +} diff --git a/support/db/session.go b/support/db/session.go index 4bc0218f90..33514e2a09 100644 --- a/support/db/session.go +++ b/support/db/session.go @@ -91,6 +91,10 @@ func (s *Session) Commit() error { log.Debug("sql: commit") s.tx = nil s.txOptions = nil + + if knownErr := s.replaceWithKnownError(err, context.Background()); knownErr != nil { + return knownErr + } return err } @@ -231,6 +235,10 @@ func (s *Session) NoRows(err error) bool { // replaceWithKnownError tries to replace Postgres error with package error. // Returns a new error if the err is known. func (s *Session) replaceWithKnownError(err error, ctx context.Context) error { + if err == nil { + return nil + } + switch { case ctx.Err() == context.Canceled: return ErrCancelled @@ -243,6 +251,8 @@ func (s *Session) replaceWithKnownError(err error, ctx context.Context) error { return ErrConflictWithRecovery case strings.Contains(err.Error(), "driver: bad connection"): return ErrBadConnection + case strings.Contains(err.Error(), "pq: canceling statement due to statement timeout"): + return ErrStatementTimeout default: return nil } @@ -305,6 +315,10 @@ func (s *Session) Rollback() error { log.Debug("sql: rollback") s.tx = nil s.txOptions = nil + + if knownErr := s.replaceWithKnownError(err, context.Background()); knownErr != nil { + return knownErr + } return err } diff --git a/support/db/session_test.go b/support/db/session_test.go index 742167bee0..8b13ba2736 100644 --- a/support/db/session_test.go +++ b/support/db/session_test.go @@ -129,3 +129,34 @@ func TestSession(t *testing.T) { assert.Equal("$1 = $2 = $3 = ?", out) } } + +func TestStatementTimeout(t *testing.T) { + assert := assert.New(t) + db := dbtest.Postgres(t).Load(testSchema) + defer db.Close() + + sess, err := Open(db.Dialect, db.DSN, StatementTimeout(50*time.Millisecond)) + assert.NoError(err) + defer sess.Close() + + var count int + err = sess.GetRaw(context.Background(), &count, "SELECT pg_sleep(2), COUNT(*) FROM people") + assert.ErrorIs(err, ErrStatementTimeout) +} + +func TestIdleTransactionTimeout(t *testing.T) { + assert := assert.New(t) + db := dbtest.Postgres(t).Load(testSchema) + defer db.Close() + + sess, err := Open(db.Dialect, db.DSN, IdleTransactionTimeout(50*time.Millisecond)) + assert.NoError(err) + defer sess.Close() + + assert.NoError(sess.Begin()) + <-time.After(100 * time.Millisecond) + + var count int + err = sess.GetRaw(context.Background(), &count, "SELECT COUNT(*) FROM people") + assert.ErrorIs(err, ErrBadConnection) +}